├── model ├── __init__.py ├── nn.py ├── gumbel_masks.py └── ilcm_vae.py ├── scripts ├── __init__.py ├── produce_all_logs.py ├── compute_udr_npy.py └── compute_udr_npy_rand_graphs.py ├── baseline_models ├── __init__.py ├── beta-tcvae │ ├── __init__.py │ ├── metric_helpers │ │ ├── mi_metric.py │ │ └── loader.py │ ├── LICENSE │ ├── README.md │ ├── elbo_decomposition.py │ └── disentanglement_metrics.py ├── icebeem │ ├── __init__.py │ ├── metrics │ │ └── __init__.py │ ├── models │ │ ├── nflib │ │ │ └── __init__.py │ │ ├── tcl │ │ │ ├── __init__.py │ │ │ ├── tcl_preprocessing.py │ │ │ ├── tcl_eval.py │ │ │ └── tcl_wrapper_gpu.py │ │ ├── __init__.py │ │ ├── ivae │ │ │ ├── __init__.py │ │ │ └── ivae_wrapper.py │ │ ├── icebeem_wrapper.py │ │ ├── ebm.py │ │ └── nets.py │ ├── data │ │ ├── __init__.py │ │ └── utils.py │ ├── losses │ │ ├── __init__.py │ │ └── dsm.py │ ├── configs │ │ ├── imca.yaml │ │ ├── mnist_baseline.yaml │ │ ├── cifar10.yaml │ │ ├── fashionmnist_baseline.yaml │ │ ├── cifar100.yaml │ │ ├── cifar10_baseline.yaml │ │ ├── mnist.yaml │ │ ├── fashionmnist.yaml │ │ └── cifar100_baseline.yaml │ ├── LICENSE │ └── train.py └── slowvae_pcl │ ├── scripts │ ├── __init__.py │ ├── evaluate_disentanglement.py │ ├── model.py │ └── data_analysis_utils.py │ ├── mcc_metric │ ├── __init__.py │ └── metric.py │ ├── data_generation │ ├── coco │ │ ├── __init__.py │ │ └── mask.py │ ├── gen_youtube_csv.py │ └── gen_kitti_masks.py │ ├── latent_factors.gif │ ├── requirements.txt │ ├── metric_configs │ ├── dci.gin │ ├── mcc.gin │ ├── mig.gin │ ├── sap_score.gin │ ├── beta_vae_sklearn.gin │ ├── modularity_explicitness.gin │ └── factor_vae_metric.gin │ ├── LICENSE │ ├── README.md │ └── train.py ├── universal_logger ├── __init__.py └── logger.py ├── results ├── CLeaR2022 │ └── all_logs (8).npy ├── JMLR │ ├── penalty │ │ ├── all_logs_28jul2023_udr.npy │ │ └── all_logs_12sep2023_rand_graphs_udr.npy │ └── constraint │ │ ├── graphVisualization │ │ ├── TimeNonDiag │ │ │ ├── fig.png │ │ │ ├── 
fig.xcf │ │ │ ├── gt_gz_0.png │ │ │ ├── G^z_300000.png │ │ │ ├── C_pattern_300000.png │ │ │ ├── estimated_C_300000.png │ │ │ ├── permuted_G_mask_300000.png │ │ │ └── correlation_matrix_300000.png │ │ ├── ActionNonDiag │ │ │ ├── fig.png │ │ │ ├── fig.xcf │ │ │ ├── gt_ga_0.png │ │ │ ├── G^a_300000.png │ │ │ ├── C_pattern_300000.png │ │ │ ├── estimated_C_300000.png │ │ │ ├── correlation_matrix_300000.png │ │ │ └── permuted_GC_mask_300000.png │ │ ├── TimeBlockNonDiag │ │ │ ├── fig.png │ │ │ ├── fig.xcf │ │ │ ├── gt_gz_0.png │ │ │ ├── G^z_300000.png │ │ │ ├── C_pattern_300000.png │ │ │ ├── estimated_C_300000.png │ │ │ ├── permuted_G_mask_300000.png │ │ │ └── correlation_matrix_300000.png │ │ └── ActionBlockNonDiag │ │ │ ├── fig.png │ │ │ ├── fig.xcf │ │ │ ├── gt_ga_0.png │ │ │ ├── G^a_300000.png │ │ │ ├── C_pattern_300000.png │ │ │ ├── estimated_C_300000.png │ │ │ ├── correlation_matrix_300000.png │ │ │ └── permuted_GC_mask_300000.png │ │ ├── all_logs_16sep2023_JMLR_const_ws_udr.npy │ │ ├── all_logs_23sep2023_JMLR_const_rand_g.npy │ │ ├── all_logs_18sep2023_JMLR_const_clear_udr.npy │ │ └── all_logs_19sep2023_JMLR_const_violation.npy └── summary.txt ├── plot.py ├── requirements.txt ├── .gitignore ├── README.md └── optimization.py /model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /baseline_models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /universal_logger/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /baseline_models/beta-tcvae/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /baseline_models/icebeem/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /baseline_models/slowvae_pcl/scripts/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /baseline_models/slowvae_pcl/mcc_metric/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /baseline_models/slowvae_pcl/data_generation/coco/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /baseline_models/icebeem/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .mcc import mean_corr_coef 2 | __all__ = ["mcc"] -------------------------------------------------------------------------------- /baseline_models/icebeem/models/nflib/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ["conditional_flows", "flows", "spline_flows"] -------------------------------------------------------------------------------- /baseline_models/icebeem/models/tcl/__init__.py: -------------------------------------------------------------------------------- 1 | from .tcl_wrapper_gpu import train 2 | __all__ = ["tcl_wrapper_gpu", "tcl_core"] -------------------------------------------------------------------------------- 
/baseline_models/icebeem/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .icebeem_wrapper import ICEBEEM_wrapper 2 | __all__ = ["icebeem_wrapper", "nets"] 3 | -------------------------------------------------------------------------------- /baseline_models/icebeem/models/ivae/__init__.py: -------------------------------------------------------------------------------- 1 | from .ivae_wrapper import IVAE_wrapper 2 | __all__ = ["ivae_wrapper", "ivae_core"] -------------------------------------------------------------------------------- /baseline_models/icebeem/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .imca import generate_synthetic_data 2 | from .utils import to_one_hot 3 | 4 | __all__ = ["imca", "utils"] -------------------------------------------------------------------------------- /results/CLeaR2022/all_logs (8).npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/CLeaR2022/all_logs (8).npy -------------------------------------------------------------------------------- /baseline_models/icebeem/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .dsm import conditional_dsm, dsm_score_estimation, dsm, cdsm 2 | from .fce import ConditionalFCE 3 | __all__ = ["dsm", "fce"] -------------------------------------------------------------------------------- /baseline_models/slowvae_pcl/latent_factors.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/baseline_models/slowvae_pcl/latent_factors.gif -------------------------------------------------------------------------------- /results/JMLR/penalty/all_logs_28jul2023_udr.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/penalty/all_logs_28jul2023_udr.npy -------------------------------------------------------------------------------- /results/JMLR/penalty/all_logs_12sep2023_rand_graphs_udr.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/penalty/all_logs_12sep2023_rand_graphs_udr.npy -------------------------------------------------------------------------------- /results/JMLR/constraint/graphVisualization/TimeNonDiag/fig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/TimeNonDiag/fig.png -------------------------------------------------------------------------------- /results/JMLR/constraint/graphVisualization/TimeNonDiag/fig.xcf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/TimeNonDiag/fig.xcf -------------------------------------------------------------------------------- /results/JMLR/constraint/all_logs_16sep2023_JMLR_const_ws_udr.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/all_logs_16sep2023_JMLR_const_ws_udr.npy -------------------------------------------------------------------------------- /results/JMLR/constraint/all_logs_23sep2023_JMLR_const_rand_g.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/all_logs_23sep2023_JMLR_const_rand_g.npy -------------------------------------------------------------------------------- /results/JMLR/constraint/graphVisualization/ActionNonDiag/fig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/ActionNonDiag/fig.png -------------------------------------------------------------------------------- /results/JMLR/constraint/graphVisualization/ActionNonDiag/fig.xcf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/ActionNonDiag/fig.xcf -------------------------------------------------------------------------------- /results/JMLR/constraint/graphVisualization/TimeNonDiag/gt_gz_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/TimeNonDiag/gt_gz_0.png -------------------------------------------------------------------------------- /results/JMLR/constraint/all_logs_18sep2023_JMLR_const_clear_udr.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/all_logs_18sep2023_JMLR_const_clear_udr.npy -------------------------------------------------------------------------------- /results/JMLR/constraint/all_logs_19sep2023_JMLR_const_violation.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/all_logs_19sep2023_JMLR_const_violation.npy -------------------------------------------------------------------------------- /results/JMLR/constraint/graphVisualization/ActionNonDiag/gt_ga_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/ActionNonDiag/gt_ga_0.png -------------------------------------------------------------------------------- /results/JMLR/constraint/graphVisualization/TimeBlockNonDiag/fig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/TimeBlockNonDiag/fig.png -------------------------------------------------------------------------------- /results/JMLR/constraint/graphVisualization/TimeBlockNonDiag/fig.xcf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/TimeBlockNonDiag/fig.xcf -------------------------------------------------------------------------------- /results/JMLR/constraint/graphVisualization/ActionBlockNonDiag/fig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/ActionBlockNonDiag/fig.png -------------------------------------------------------------------------------- /results/JMLR/constraint/graphVisualization/ActionBlockNonDiag/fig.xcf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/ActionBlockNonDiag/fig.xcf -------------------------------------------------------------------------------- /results/JMLR/constraint/graphVisualization/ActionNonDiag/G^a_300000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/ActionNonDiag/G^a_300000.png -------------------------------------------------------------------------------- /results/JMLR/constraint/graphVisualization/TimeBlockNonDiag/gt_gz_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/TimeBlockNonDiag/gt_gz_0.png -------------------------------------------------------------------------------- /results/JMLR/constraint/graphVisualization/TimeNonDiag/G^z_300000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/TimeNonDiag/G^z_300000.png -------------------------------------------------------------------------------- /results/JMLR/constraint/graphVisualization/ActionBlockNonDiag/gt_ga_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/ActionBlockNonDiag/gt_ga_0.png -------------------------------------------------------------------------------- /results/JMLR/constraint/graphVisualization/ActionBlockNonDiag/G^a_300000.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/ActionBlockNonDiag/G^a_300000.png -------------------------------------------------------------------------------- /results/JMLR/constraint/graphVisualization/TimeBlockNonDiag/G^z_300000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/TimeBlockNonDiag/G^z_300000.png -------------------------------------------------------------------------------- /results/JMLR/constraint/graphVisualization/TimeNonDiag/C_pattern_300000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/TimeNonDiag/C_pattern_300000.png -------------------------------------------------------------------------------- /results/JMLR/constraint/graphVisualization/ActionNonDiag/C_pattern_300000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/ActionNonDiag/C_pattern_300000.png -------------------------------------------------------------------------------- /results/JMLR/constraint/graphVisualization/TimeNonDiag/estimated_C_300000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/TimeNonDiag/estimated_C_300000.png -------------------------------------------------------------------------------- /results/JMLR/constraint/graphVisualization/ActionNonDiag/estimated_C_300000.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/ActionNonDiag/estimated_C_300000.png -------------------------------------------------------------------------------- /results/JMLR/constraint/graphVisualization/TimeBlockNonDiag/C_pattern_300000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/TimeBlockNonDiag/C_pattern_300000.png -------------------------------------------------------------------------------- /results/JMLR/constraint/graphVisualization/TimeNonDiag/permuted_G_mask_300000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/TimeNonDiag/permuted_G_mask_300000.png -------------------------------------------------------------------------------- /results/JMLR/constraint/graphVisualization/ActionBlockNonDiag/C_pattern_300000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/ActionBlockNonDiag/C_pattern_300000.png -------------------------------------------------------------------------------- /results/JMLR/constraint/graphVisualization/TimeBlockNonDiag/estimated_C_300000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/TimeBlockNonDiag/estimated_C_300000.png 
-------------------------------------------------------------------------------- /results/JMLR/constraint/graphVisualization/ActionBlockNonDiag/estimated_C_300000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/ActionBlockNonDiag/estimated_C_300000.png -------------------------------------------------------------------------------- /results/JMLR/constraint/graphVisualization/ActionNonDiag/correlation_matrix_300000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/ActionNonDiag/correlation_matrix_300000.png -------------------------------------------------------------------------------- /results/JMLR/constraint/graphVisualization/ActionNonDiag/permuted_GC_mask_300000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/ActionNonDiag/permuted_GC_mask_300000.png -------------------------------------------------------------------------------- /results/JMLR/constraint/graphVisualization/TimeBlockNonDiag/permuted_G_mask_300000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/TimeBlockNonDiag/permuted_G_mask_300000.png -------------------------------------------------------------------------------- /results/JMLR/constraint/graphVisualization/TimeNonDiag/correlation_matrix_300000.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/TimeNonDiag/correlation_matrix_300000.png -------------------------------------------------------------------------------- /results/JMLR/constraint/graphVisualization/ActionBlockNonDiag/correlation_matrix_300000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/ActionBlockNonDiag/correlation_matrix_300000.png -------------------------------------------------------------------------------- /results/JMLR/constraint/graphVisualization/ActionBlockNonDiag/permuted_GC_mask_300000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/ActionBlockNonDiag/permuted_GC_mask_300000.png -------------------------------------------------------------------------------- /results/JMLR/constraint/graphVisualization/TimeBlockNonDiag/correlation_matrix_300000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slachapelle/disentanglement_via_mechanism_sparsity/HEAD/results/JMLR/constraint/graphVisualization/TimeBlockNonDiag/correlation_matrix_300000.png -------------------------------------------------------------------------------- /baseline_models/slowvae_pcl/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.19.0 2 | torch==1.4.0 3 | torchvision==0.5.0 4 | imageio==2.8.0 5 | matplotlib==3.2.2 6 | scipy==1.1.0 7 | scikit-learn==0.23.1 8 | spriteworld==1.0.2 9 | pandas==1.0.5 10 | disentanglement-lib==1.4 11 | tensorflow==1.14 12 | tqdm==4.31.1 
-------------------------------------------------------------------------------- /baseline_models/icebeem/configs/imca.yaml: -------------------------------------------------------------------------------- 1 | data_dim: 5 2 | n_segments: 20 3 | n_layers: [ 2,4 ] # [2,4] 4 | n_obs_per_seg: [100, 200, 500, 1000, 2000] 5 | data_seed: 1 6 | 7 | ivae: 8 | max_iter: 70000 9 | lr: 0.001 10 | cuda: False 11 | 12 | icebeem: 13 | lr_flow: 0.00001 14 | lr_ebm: 0.0003 15 | n_layers_flow: 10 # make 5 for L=2 IMCA best perf 16 | ebm_hidden_size: 32 # make 16 for L=2 IMCA best perf 17 | -------------------------------------------------------------------------------- /baseline_models/icebeem/configs/mnist_baseline.yaml: -------------------------------------------------------------------------------- 1 | # config file for baseline (of no transfer) example on MNIST 2 | 3 | training: 4 | batch_size: 63 5 | n_epochs: 5 6 | n_iters: 200001 7 | ngpu: 1 8 | snapshot_freq: 50 9 | algo: 'dsm' 10 | anneal_power: 2.0 11 | 12 | data: 13 | dataset: "MNIST_transferBaseline" 14 | image_size: 28 15 | channels: 1 16 | logit_transform: false 17 | random_flip: false 18 | store_loss: true 19 | 20 | model: 21 | sigma_begin: 1 22 | sigma_end: 0.01 23 | num_classes: 10 24 | ngf: 64 25 | final_layer: true 26 | feature_size: 90 27 | augment: false 28 | positive: false 29 | architecture: 'ConvMLP' 30 | 31 | optim: 32 | weight_decay: 0.000 33 | optimizer: "Adam" 34 | lr: 0.001 35 | beta1: 0.9 36 | amsgrad: false 37 | 38 | n_labels: 8 39 | -------------------------------------------------------------------------------- /baseline_models/icebeem/configs/cifar10.yaml: -------------------------------------------------------------------------------- 1 | # config file for transfer learning example on CIFAR10 2 | 3 | training: 4 | batch_size: 63 5 | n_epochs: 500000 6 | n_iters: 5001 7 | ngpu: 1 8 | snapshot_freq: 250 9 | algo: 'dsm' 10 | anneal_power: 2.0 11 | 12 | data: 13 | dataset: "CIFAR10" 14 | image_size: 32 
15 | channels: 3 16 | logit_transform: false 17 | random_flip: true 18 | random_state: 0 19 | split_size: .33 20 | 21 | model: 22 | sigma_begin: 1 23 | sigma_end: 0.01 24 | num_classes: 10 25 | batch_norm: false 26 | ngf: 64 27 | final_layer: true 28 | feature_size: 200 29 | augment: false 30 | positive: false 31 | architecture: 'ConvMLP' 32 | 33 | optim: 34 | weight_decay: 0.000 35 | optimizer: "Adam" 36 | lr: 0.001 37 | beta1: 0.9 38 | amsgrad: false 39 | 40 | n_labels: 8 41 | -------------------------------------------------------------------------------- /baseline_models/icebeem/configs/fashionmnist_baseline.yaml: -------------------------------------------------------------------------------- 1 | # config file for transfer learning example on FashionMNIST 2 | 3 | training: 4 | batch_size: 63 5 | n_epochs: 15 6 | n_iters: 200001 7 | ngpu: 1 8 | snapshot_freq: 50 9 | algo: 'dsm' 10 | anneal_power: 2.0 11 | 12 | data: 13 | dataset: "FashionMNIST_transferBaseline" 14 | image_size: 28 15 | channels: 1 16 | logit_transform: false 17 | random_flip: false 18 | 19 | model: 20 | sigma_begin: 1 21 | sigma_end: 0.01 22 | num_classes: 10 23 | batch_norm: false 24 | ngf: 64 25 | final_layer: true 26 | feature_size: 200 27 | augment: false 28 | positive: false 29 | architecture: 'ConvMLP' 30 | 31 | optim: 32 | weight_decay: 0.000 33 | optimizer: "Adam" 34 | lr: 0.001 35 | beta1: 0.9 36 | amsgrad: false 37 | 38 | n_labels: 8 39 | -------------------------------------------------------------------------------- /baseline_models/icebeem/configs/cifar100.yaml: -------------------------------------------------------------------------------- 1 | # config file for transfer learning example on CIFAR100 2 | 3 | training: 4 | batch_size: 63 5 | n_epochs: 500000 6 | n_iters: 5001 7 | ngpu: 1 8 | snapshot_freq: 250 9 | algo: 'dsm' 10 | anneal_power: 2.0 11 | 12 | data: 13 | dataset: "CIFAR100" 14 | image_size: 32 15 | channels: 3 16 | logit_transform: false 17 | random_flip: true 18 | 
random_state: 0 19 | split_size: .33 20 | 21 | model: 22 | sigma_begin: 1 23 | sigma_end: 0.01 24 | num_classes: 10 25 | batch_norm: false 26 | ngf: 64 27 | final_layer: true 28 | feature_size: 200 29 | augment: false 30 | positive: false 31 | architecture: 'ConvMLP' 32 | 33 | optim: 34 | weight_decay: 0.000 35 | optimizer: "Adam" 36 | lr: 0.001 37 | beta1: 0.9 38 | amsgrad: false 39 | 40 | n_labels: 85 41 | -------------------------------------------------------------------------------- /baseline_models/icebeem/configs/cifar10_baseline.yaml: -------------------------------------------------------------------------------- 1 | # config file for baseline (of no transfer) example on CIFAR10 2 | 3 | training: 4 | batch_size: 63 5 | n_epochs: 15 6 | n_iters: 200001 7 | ngpu: 1 8 | snapshot_freq: 50 9 | algo: 'dsm' 10 | anneal_power: 2.0 11 | 12 | data: 13 | dataset: "CIFAR10_transferBaseline" 14 | image_size: 32 15 | channels: 3 16 | logit_transform: false 17 | random_flip: true 18 | store_loss: true 19 | 20 | model: 21 | sigma_begin: 1 22 | sigma_end: 0.01 23 | num_classes: 10 24 | batch_norm: false 25 | ngf: 64 26 | final_layer: true 27 | feature_size: 200 28 | augment: false 29 | positive: false 30 | architecture: 'ConvMLP' 31 | 32 | optim: 33 | weight_decay: 0.000 34 | optimizer: "Adam" 35 | lr: 0.001 36 | beta1: 0.9 37 | amsgrad: false 38 | 39 | n_labels: 8 40 | -------------------------------------------------------------------------------- /baseline_models/icebeem/configs/mnist.yaml: -------------------------------------------------------------------------------- 1 | # config file for transfer learning example on MNIST 2 | 3 | training: 4 | batch_size: 63 5 | n_epochs: 500000 6 | n_iters: 5001 7 | ngpu: 1 8 | snapshot_freq: 250 9 | algo: 'dsm' 10 | anneal_power: 2.0 11 | 12 | data: 13 | dataset: "MNIST" 14 | image_size: 28 15 | channels: 1 16 | logit_transform: false 17 | random_flip: false 18 | random_state: 0 19 | split_size: .15 20 | 21 | 22 | model: 23 | 
sigma_begin: 1 24 | sigma_end: 0.01 25 | num_classes: 10 26 | batch_norm: false 27 | ngf: 64 28 | final_layer: true 29 | feature_size: 90 30 | augment: false 31 | positive: false 32 | architecture: 'ConvMLP' 33 | 34 | optim: 35 | weight_decay: 0.000 36 | optimizer: "Adam" 37 | lr: 0.001 38 | beta1: 0.9 39 | amsgrad: false 40 | 41 | n_labels: 8 42 | -------------------------------------------------------------------------------- /results/summary.txt: -------------------------------------------------------------------------------- 1 | CLeaR 2022: 2 | /CLeaR2022/all_logs (8).npy 3 | to generate figures: https://colab.research.google.com/drive/1-gUajJMcGEEcpQA9nm9XcYB4ZLpR-wVc#scrollTo=thick-range 4 | 5 | JMLR (penalty): 6 | JMLR/penalty/all_logs_28jul2023_udr.npy (workshop data) 7 | JMLR/penalty/all_logs_12sep2023_rand_graphs_udr.npy 8 | to generate figures: https://colab.research.google.com/drive/1godG3MbAlLhB-FK95HJCHVQu2IjE8mM7#scrollTo=V0oKAFtzMXmw 9 | 10 | JMLR (constraint): 11 | JMLR/constraint/all_logs_18sep2023_JMLR_const_clear_udr.npy 12 | JMLR/constraint/all_logs_16sep2023_JMLR_const_ws_udr.npy 13 | JMLR/constraint/all_logs_19sep2023_JMLR_const_violation.npy 14 | JMLR/constraint/all_logs_23sep2023_JMLR_const_rand_g.npy 15 | to generate figures: https://colab.research.google.com/drive/1lftjqLUeoIQrNlWyKRrtPyXRXwyQz3Y5 16 | 17 | -------------------------------------------------------------------------------- /baseline_models/slowvae_pcl/metric_configs/dci.gin: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The DisentanglementLib Authors. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | evaluation.evaluation_fn = @dci 17 | evaluation.random_seed = 0 18 | dci.num_train=10000 19 | dci.num_test=5000 20 | -------------------------------------------------------------------------------- /baseline_models/icebeem/configs/fashionmnist.yaml: -------------------------------------------------------------------------------- 1 | # config file for transfer learning example on FashionMNIST 2 | 3 | training: 4 | batch_size: 63 5 | n_epochs: 500000 6 | n_iters: 5001 7 | ngpu: 1 8 | snapshot_freq: 250 9 | algo: 'dsm' 10 | anneal_power: 2.0 11 | 12 | data: 13 | dataset: "FashionMNIST" 14 | image_size: 28 15 | channels: 1 16 | logit_transform: false 17 | random_flip: false 18 | random_state: 0 19 | split_size : .15 20 | 21 | model: 22 | sigma_begin: 1 23 | sigma_end: 0.01 24 | num_classes: 10 25 | batch_norm: false 26 | ngf: 64 27 | final_layer: true 28 | feature_size: 200 29 | augment: false 30 | positive: false 31 | architecture: 'ConvMLP' 32 | 33 | optim: 34 | weight_decay: 0.000 35 | optimizer: "Adam" 36 | lr: 0.001 37 | beta1: 0.9 38 | amsgrad: false 39 | 40 | n_labels: 8 41 | -------------------------------------------------------------------------------- /baseline_models/icebeem/configs/cifar100_baseline.yaml: -------------------------------------------------------------------------------- 1 | # config file for transfer learning example on CIFAR100 2 | 3 | training: 4 | batch_size: 63 5 | n_epochs: 15 6 | n_iters: 200001 7 | ngpu: 1 8 | snapshot_freq: 50 9 | algo: 'dsm' 10 | anneal_power: 2.0 11 | 12 | data: 13 
| dataset: "CIFAR100_transferBaseline" 14 | image_size: 32 15 | channels: 3 16 | logit_transform: false 17 | random_flip: true 18 | random_state: 0 19 | split_size: .33 20 | 21 | model: 22 | sigma_begin: 1 23 | sigma_end: 0.01 24 | num_classes: 10 25 | batch_norm: false 26 | ngf: 64 27 | final_layer: true 28 | feature_size: 200 29 | augment: false 30 | positive: false 31 | architecture: 'ConvMLP' 32 | 33 | optim: 34 | weight_decay: 0.000 35 | optimizer: "Adam" 36 | lr: 0.001 37 | beta1: 0.9 38 | amsgrad: false 39 | 40 | n_labels: 85 41 | -------------------------------------------------------------------------------- /baseline_models/slowvae_pcl/metric_configs/mcc.gin: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The DisentanglementLib Authors. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | evaluation.evaluation_fn = @mcc 17 | evaluation.random_seed = 0 18 | mcc.num_train=10000 19 | mcc.correlation_fn = "Spearman" 20 | -------------------------------------------------------------------------------- /baseline_models/slowvae_pcl/metric_configs/mig.gin: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The DisentanglementLib Authors. All rights reserved. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | evaluation.evaluation_fn = @mig 17 | evaluation.random_seed = 0 18 | mig.num_train=100000 19 | discretizer.discretizer_fn = @histogram_discretizer 20 | discretizer.num_bins = 20 21 | -------------------------------------------------------------------------------- /baseline_models/slowvae_pcl/metric_configs/sap_score.gin: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The DisentanglementLib Authors. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | evaluation.evaluation_fn = @sap_score 17 | evaluation.random_seed = 0 18 | sap_score.num_train=10000 19 | sap_score.num_test=5000 20 | sap_score.continuous_factors=False 21 | -------------------------------------------------------------------------------- /baseline_models/slowvae_pcl/metric_configs/beta_vae_sklearn.gin: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The DisentanglementLib Authors. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | evaluation.evaluation_fn = @beta_vae_sklearn 17 | evaluation.random_seed = 0 18 | beta_vae_sklearn.batch_size=64 19 | beta_vae_sklearn.num_train=10000 20 | beta_vae_sklearn.num_eval=5000 21 | -------------------------------------------------------------------------------- /baseline_models/slowvae_pcl/metric_configs/modularity_explicitness.gin: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The DisentanglementLib Authors. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | evaluation.evaluation_fn = @modularity_explicitness 17 | evaluation.random_seed = 0 18 | modularity_explicitness.num_train=100000 19 | modularity_explicitness.num_test=5000 20 | discretizer.discretizer_fn = @histogram_discretizer 21 | discretizer.num_bins = 20 22 | -------------------------------------------------------------------------------- /baseline_models/slowvae_pcl/metric_configs/factor_vae_metric.gin: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The DisentanglementLib Authors. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | evaluation.evaluation_fn = @factor_vae_score 17 | evaluation.random_seed = 0 18 | factor_vae_score.num_variance_estimate=10000 19 | factor_vae_score.num_train=10000 20 | factor_vae_score.num_eval=5000 21 | factor_vae_score.batch_size=64 22 | prune_dims.threshold = 0.05 23 | 24 | -------------------------------------------------------------------------------- /baseline_models/beta-tcvae/metric_helpers/mi_metric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | metric_name = 'MIG' 4 | 5 | 6 | def MIG(mi_normed): 7 | return torch.mean(mi_normed[:, 0] - mi_normed[:, 1]) 8 | 9 | 10 | def compute_metric_shapes(marginal_entropies, cond_entropies): 11 | factor_entropies = [6, 40, 32, 32] 12 | mutual_infos = marginal_entropies[None] - cond_entropies 13 | mutual_infos = torch.sort(mutual_infos, dim=1, descending=True)[0].clamp(min=0) 14 | mi_normed = mutual_infos / torch.Tensor(factor_entropies).log()[:, None] 15 | metric = eval(metric_name)(mi_normed) 16 | return metric 17 | 18 | 19 | def compute_metric_faces(marginal_entropies, cond_entropies): 20 | factor_entropies = [21, 11, 11] 21 | mutual_infos = marginal_entropies[None] - cond_entropies 22 | mutual_infos = torch.sort(mutual_infos, dim=1, descending=True)[0].clamp(min=0) 23 | mi_normed = mutual_infos / torch.Tensor(factor_entropies).log()[:, None] 24 | metric = eval(metric_name)(mi_normed) 25 | return metric 26 | 27 | -------------------------------------------------------------------------------- /baseline_models/icebeem/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Ilyes Khemakhem 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, 
publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /baseline_models/slowvae_pcl/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Bethge Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /baseline_models/beta-tcvae/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Ricky Tian Qi Chen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /baseline_models/beta-tcvae/README.md: -------------------------------------------------------------------------------- 1 | # beta-TCVAE 2 | 3 | This repository contains cleaned-up code for reproducing the quantitative experiments in Isolating Sources of Disentanglement in Variational Autoencoders \[[arxiv](https://arxiv.org/abs/1802.04942)\]. 4 | 5 | ## Usage 6 | 7 | To train a model: 8 | 9 | ``` 10 | python vae_quant.py --dataset [shapes/faces] --beta 6 --tcvae 11 | ``` 12 | Specify `--conv` to use the convolutional VAE. We used an MLP for dSprites and the convolutional VAE for 3D faces. To see all options, use the `-h` flag. 13 | 14 | The main computational difference between beta-VAE and beta-TCVAE is summarized in [these lines](vae_quant.py#L220-L228). 15 | 16 | To evaluate the MIG of a model: 17 | ``` 18 | python disentanglement_metrics.py --checkpt [checkpt] 19 | ``` 20 | To see all options, use the `-h` flag. 21 | 22 | ## Datasets 23 | 24 | ### dSprites 25 | Download the npz file from [here](https://github.com/deepmind/dsprites-dataset) and place it into `data/`. 26 | 27 | ### 3D faces 28 | We cannot publicly distribute this due to the [license](https://faces.dmi.unibas.ch/bfm/main.php?nav=1-2&id=downloads). Please contact me for the data. 29 | 30 | ## Contact 31 | Email rtqichen@cs.toronto.edu if you have questions about the code/data. 32 | 33 | ## Bibtex 34 | ``` 35 | @inproceedings{chen2018isolating, 36 | title={Isolating Sources of Disentanglement in Variational Autoencoders}, 37 | author={Chen, Ricky T. Q.
and Li, Xuechen and Grosse, Roger and Duvenaud, David}, 38 | booktitle = {Advances in Neural Information Processing Systems}, 39 | year={2018} 40 | } 41 | ``` 42 | -------------------------------------------------------------------------------- /baseline_models/beta-tcvae/metric_helpers/loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import lib.dist as dist 3 | import lib.flows as flows 4 | import vae_quant 5 | 6 | 7 | def load_model_and_dataset(checkpt_filename):  # Restore a vae_quant.VAE from a checkpoint dict with 'args' and 'state_dict' keys; returns (vae, dataset, args). 8 | print('Loading model and dataset.') 9 | checkpt = torch.load(checkpt_filename, map_location=lambda storage, loc: storage)  # keep tensors on CPU regardless of the device they were saved from 10 | args = checkpt['args'] 11 | state_dict = checkpt['state_dict'] 12 | 13 | # backwards compatibility: older checkpoints may lack the 'conv' / 'dist' attributes 14 | if not hasattr(args, 'conv'): 15 | args.conv = False 16 | 17 | if not hasattr(args, 'dist') or args.dist == 'normal': 18 | prior_dist = dist.Normal() 19 | q_dist = dist.Normal() 20 | elif args.dist == 'laplace': 21 | prior_dist = dist.Laplace() 22 | q_dist = dist.Laplace() 23 | elif args.dist == 'flow': 24 | prior_dist = flows.FactorialNormalizingFlow(dim=args.latent_dim, nsteps=32) 25 | q_dist = dist.Normal()  # NOTE(review): any other args.dist value leaves prior_dist/q_dist unbound, so the VAE construction below would fail 26 | 27 | # model 28 | if hasattr(args, 'ncon'): 29 | # InfoGAN checkpoint. NOTE(review): `infogan` is never imported in this module, so this branch raises NameError as written -- confirm where it should come from 30 | model = infogan.Model( 31 | args.latent_dim, n_con=args.ncon, n_cat=args.ncat, cat_dim=args.cat_dim, use_cuda=True, conv=args.conv) 32 | model.load_state_dict(state_dict, strict=False) 33 | vae = vae_quant.VAE( 34 | z_dim=args.ncon, use_cuda=True, prior_dist=prior_dist, q_dist=q_dist, conv=args.conv) 35 | vae.encoder = model.encoder  # reuse the InfoGAN networks inside the VAE wrapper 36 | vae.decoder = model.decoder 37 | else: 38 | vae = vae_quant.VAE( 39 | z_dim=args.latent_dim, use_cuda=True, prior_dist=prior_dist, q_dist=q_dist, conv=args.conv) 40 | vae.load_state_dict(state_dict, strict=False)  # strict=False tolerates missing/unexpected keys 41 | 42 | # dataset loader rebuilt from the saved training args 43 | loader = vae_quant.setup_data_loaders(args) 44 | return vae, loader.dataset, args 45 |
-------------------------------------------------------------------------------- /plot.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import torch 7 | 8 | 9 | def plot_matrix(matrix, title="", row_label="", col_label="", col_to_mark=[], row_to_mark=[], kl_values=None, row_names=[], col_names=[], vmin=0., vmax=1.): 10 | fig = plt.figure() 11 | if kl_values is not None: 12 | ax = fig.add_subplot(2, 1, 1) 13 | ax_kl = fig.add_subplot(2, 1, 2) 14 | else: 15 | ax = fig.add_subplot(1, 1, 1) 16 | ax.matshow(matrix, vmin=vmin, vmax=vmax) 17 | ax.set_title(title) 18 | ax.set_xlabel(col_label) 19 | ax.set_ylabel(row_label) 20 | ax.set_xticklabels([''] + col_names) 21 | ax.set_yticklabels([''] + row_names) 22 | 23 | # Loop over data dimensions and create text annotations. 24 | for i in range(matrix.shape[0]): 25 | for j in range(matrix.shape[1]): 26 | text = "{:.2f}".format(matrix[i, j]) 27 | if i in row_to_mark: 28 | text += "*" 29 | if j in col_to_mark: 30 | text += "*" 31 | ax.text(j, i, text, ha="center", va="center", color="w") 32 | 33 | if kl_values is not None: 34 | ax_kl.matshow(kl_values.reshape(1, -1), vmin=0) 35 | ax_kl.set_xlabel("KL values") 36 | 37 | return fig 38 | 39 | def plot_weighted_adjacency_vs_steps(weighted_adjacency, gt_adjacency, iterations=None): 40 | num_row = weighted_adjacency.shape[1] 41 | num_col = weighted_adjacency.shape[2] 42 | if iterations is None: 43 | iterations = range(weighted_adjacency.shape[0]) 44 | assert weighted_adjacency.shape[0] == len(iterations) 45 | 46 | fig, ax1 = plt.subplots() 47 | 48 | # Plot weight of incorrect edges 49 | for i in range(num_row): 50 | for j in range(num_col): 51 | if gt_adjacency[i, j] == 1: 52 | color = 'g' # correct edge 53 | else: 54 | color = 'r' # incorrect edge 55 | y = weighted_adjacency[:, i, j] 56 | ax1.plot(iterations, y, color, linewidth=1) 57 | 58 | 
fig.tight_layout() 59 | return fig 60 | 61 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | argon2-cffi==20.1.0 2 | async-generator==1.10 3 | attrs==20.3.0 4 | backcall==0.2.0 5 | bleach==3.3.0 6 | cached-property==1.5.2 7 | certifi==2020.12.5 8 | cffi==1.14.4 9 | chardet==4.0.0 10 | comet-ml==3.3.2 11 | configobj==5.0.6 12 | cycler==0.10.0 13 | decorator==4.4.2 14 | defusedxml==0.6.0 15 | dulwich==0.20.18 16 | entrypoints==0.3 17 | everett==1.0.3 18 | h5py==3.2.1 19 | idna==2.10 20 | importlib-metadata==3.4.0 21 | ipdb==0.13.4 22 | ipykernel==5.4.3 23 | ipython==7.20.0 24 | ipython-genutils==0.2.0 25 | ipywidgets==7.6.3 26 | jedi==0.18.0 27 | Jinja2==2.11.3 28 | joblib==1.0.1 29 | jsonschema==3.2.0 30 | jupyter==1.0.0 31 | jupyter-client==6.1.11 32 | jupyter-console==6.2.0 33 | jupyter-core==4.7.1 34 | jupyter-http-over-ws==0.0.8 35 | jupyterlab-pygments==0.1.2 36 | jupyterlab-widgets==1.0.0 37 | kiwisolver==1.3.1 38 | MarkupSafe==1.1.1 39 | matplotlib==3.3.4 40 | mistune==0.8.4 41 | nbclient==0.5.2 42 | nbconvert==6.0.7 43 | nbformat==5.1.2 44 | nest-asyncio==1.5.1 45 | netifaces==0.10.9 46 | notebook==6.2.0 47 | numpy==1.20.1 48 | nvidia-ml-py3==7.352.0 49 | packaging==20.9 50 | pandas==1.2.2 51 | pandocfilters==1.4.3 52 | parso==0.8.1 53 | pexpect==4.8.0 54 | pickleshare==0.7.5 55 | Pillow==8.1.0 56 | prometheus-client==0.9.0 57 | prompt-toolkit==3.0.15 58 | ptyprocess==0.7.0 59 | pycparser==2.20 60 | Pygments==2.7.4 61 | pynvml==8.0.4 62 | pyparsing==2.4.7 63 | pyrsistent==0.17.3 64 | python-dateutil==2.8.1 65 | pytorch-ignite==0.4.3 66 | pytz==2021.1 67 | pyzmq==22.0.2 68 | qtconsole==5.0.2 69 | QtPy==1.9.0 70 | requests==2.25.1 71 | requests-toolbelt==0.9.1 72 | scikit-learn==0.24.1 73 | scipy==1.6.0 74 | seaborn==0.11.1 75 | Send2Trash==1.5.0 76 | six==1.15.0 77 | synbols==1.0.1 78 | terminado==0.9.2 79 | testpath==0.4.4 
80 | threadpoolctl==2.1.0 81 | torch==1.7.1 82 | torchvision==0.8.2 83 | tornado==6.1 84 | tqdm==4.56.2 85 | traitlets==5.0.5 86 | typing-extensions==3.7.4.3 87 | urllib3==1.26.3 88 | wcwidth==0.2.5 89 | webencodings==0.5.1 90 | websocket-client==0.57.0 91 | widgetsnbextension==3.5.1 92 | wrapt==1.12.1 93 | wurlitzer==2.0.1 94 | zipp==3.4.0 95 | qj==0.2.0 96 | git+https://github.com/cooper-org/cooper.git 97 | -------------------------------------------------------------------------------- /baseline_models/icebeem/data/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | 5 | 6 | def to_one_hot(x, m=None): 7 | "batch one hot" 8 | if type(x) is not list: 9 | x = [x] 10 | if m is None: 11 | ml = [] 12 | for xi in x: 13 | ml += [xi.max() + 1] 14 | m = max(ml) 15 | dtp = x[0].dtype 16 | xoh = [] 17 | for i, xi in enumerate(x): 18 | xoh += [np.zeros((xi.size, int(m)), dtype=dtp)] 19 | xoh[i][np.arange(xi.size), xi.astype(np.int)] = 1 20 | return xoh 21 | 22 | 23 | def one_hot_encode(labels, n_labels=10): 24 | """ 25 | Transforms numeric labels to 1-hot encoded labels. Assumes numeric labels are in the range 0, 1, ..., n_labels-1. 26 | """ 27 | 28 | assert np.min(labels) >= 0 and np.max(labels) < n_labels 29 | 30 | y = np.zeros([labels.size, n_labels]).astype(np.float32) 31 | y[range(labels.size), labels] = 1 32 | 33 | return y 34 | 35 | 36 | def single_one_hot_encode(label, n_labels=10): 37 | """ 38 | Transforms numeric labels to 1-hot encoded labels. Assumes numeric labels are in the range 0, 1, ..., n_labels-1. 39 | """ 40 | 41 | assert label >= 0 and label < n_labels 42 | 43 | y = np.zeros([n_labels]).astype(np.float32) 44 | y[label] = 1 45 | 46 | return y 47 | 48 | 49 | def single_one_hot_encode_rev(label, n_labels=10, start_label=0): 50 | """ 51 | Transforms numeric labels to 1-hot encoded labels. Assumes numeric labels are in the range 0, 1, ..., n_labels-1. 
52 | """ 53 | assert label >= start_label and label < n_labels 54 | y = np.zeros([n_labels - start_label]).astype(np.float32) 55 | y[label - start_label] = 1 56 | return y 57 | 58 | 59 | mnist_one_hot_transform = lambda label: single_one_hot_encode(label, n_labels=10) 60 | contrastive_one_hot_transform = lambda label: single_one_hot_encode(label, n_labels=2) 61 | 62 | 63 | def make_dir(dir_name): 64 | if dir_name[-1] != '/': 65 | dir_name += '/' 66 | if not os.path.exists(dir_name): 67 | os.makedirs(dir_name) 68 | return dir_name 69 | 70 | 71 | def make_file(file_name): 72 | if not os.path.exists(file_name): 73 | open(file_name, 'a').close() 74 | return file_name 75 | -------------------------------------------------------------------------------- /baseline_models/icebeem/models/tcl/tcl_preprocessing.py: -------------------------------------------------------------------------------- 1 | """Preprocessing""" 2 | 3 | import numpy as np 4 | 5 | 6 | # ============================================================ 7 | # ============================================================ 8 | def pca(x, num_comp=None, params=None, zerotolerance=1e-7): 9 | """Apply PCA whitening to data. 10 | Args: 11 | x: data. 2D ndarray [num_comp, num_data] 12 | num_comp: number of components 13 | params: (option) dictionary of PCA parameters {'mean':?, 'W':?, 'A':?}. 
If given, apply this to the data 14 | zerotolerance: (option) 15 | Returns: 16 | x: whitened data 17 | parms: parameters of PCA 18 | mean: subtracted mean 19 | W: whitening matrix 20 | A: mixing matrix 21 | """ 22 | # print("PCA...") 23 | 24 | # Dimension 25 | if num_comp is None: 26 | num_comp = x.shape[0] 27 | # print(" num_comp={0:d}".format(num_comp)) 28 | 29 | # From learned parameters -------------------------------- 30 | if params is not None: 31 | # Use previously-trained model 32 | print(" use learned value") 33 | data_pca = x - params['mean'] 34 | x = np.dot(params['W'], data_pca) 35 | 36 | # Learn from data ---------------------------------------- 37 | else: 38 | # Zero mean 39 | xmean = np.mean(x, 1).reshape([-1, 1]) 40 | x = x - xmean 41 | 42 | # Eigenvalue decomposition 43 | xcov = np.cov(x) 44 | d, V = np.linalg.eigh(xcov) # Ascending order 45 | # Convert to descending order 46 | d = d[::-1] 47 | V = V[:, ::-1] 48 | 49 | zeroeigval = np.sum((d[:num_comp] / d[0]) < zerotolerance) 50 | if zeroeigval > 0: # Do not allow zero eigenval 51 | raise ValueError 52 | 53 | # Calculate contribution ratio 54 | contratio = np.sum(d[:num_comp]) / np.sum(d) 55 | # print(" contribution ratio={0:f}".format(contratio)) 56 | 57 | # Construct whitening and dewhitening matrices 58 | dsqrt = np.sqrt(d[:num_comp]) 59 | dsqrtinv = 1 / dsqrt 60 | V = V[:, :num_comp] 61 | # Whitening 62 | W = np.dot(np.diag(dsqrtinv), V.transpose()) # whitening matrix 63 | A = np.dot(V, np.diag(dsqrt)) # de-whitening matrix 64 | x = np.dot(W, x) 65 | 66 | params = {'mean': xmean, 'W': W, 'A': A} 67 | 68 | # Check 69 | datacov = np.cov(x) 70 | 71 | return x, params 72 | -------------------------------------------------------------------------------- /baseline_models/icebeem/losses/dsm.py: -------------------------------------------------------------------------------- 1 | ### conditional dsm objective 2 | # 3 | # this code is adapted from: https://github.com/ermongroup/ncsn/ 4 | # 5 | 6 | 
import torch 7 | import torch.autograd as autograd 8 | 9 | 10 | def dsm(energy_net, samples, sigma=1): 11 | samples.requires_grad_(True) 12 | vector = torch.randn_like(samples) * sigma 13 | perturbed_inputs = samples + vector 14 | logp = -energy_net(perturbed_inputs) 15 | dlogp = sigma ** 2 * autograd.grad(logp.sum(), perturbed_inputs, create_graph=True)[0] 16 | kernel = vector 17 | loss = torch.norm(dlogp + kernel, dim=-1) ** 2 18 | loss = loss.mean() / 2. 19 | 20 | return loss 21 | 22 | 23 | def cdsm(energy_net, samples, conditions, sigma=1.): 24 | """ 25 | Conditional denoising score matching 26 | :param energy_net: an energy network that takes x and y as input and outputs energy of shape (batch_size,) 27 | :param samples: values of dependent variable x 28 | :param conditions: values of conditioning variable y 29 | :param sigma: noise level for dsm 30 | :return: cdsm loss of shape (batch_size,) 31 | """ 32 | samples.requires_grad_(True) 33 | vector = torch.randn_like(samples) * sigma 34 | perturbed_inputs = samples + vector 35 | logp = -energy_net(perturbed_inputs, conditions) 36 | assert logp.ndim == 1 37 | dlogp = sigma ** 2 * autograd.grad(logp.sum(), perturbed_inputs, create_graph=True)[0] 38 | kernel = vector 39 | loss = torch.norm(dlogp + kernel, dim=-1) ** 2 40 | loss = loss.mean() / 2. 41 | return loss 42 | 43 | 44 | def conditional_dsm(energy_net, samples, segLabels, energy_net_final_layer, sigma=1): 45 | samples.requires_grad_(True) 46 | vector = torch.randn_like(samples) * sigma 47 | perturbed_inputs = samples + vector 48 | 49 | d = samples.shape[-1] 50 | 51 | # apply conditioning 52 | logp = -energy_net(perturbed_inputs).view(-1, d * d) 53 | logp = torch.mm(logp, energy_net_final_layer) 54 | # take only relevant segment energy 55 | logp = logp[segLabels] 56 | 57 | dlogp = sigma ** 2 * autograd.grad(logp.sum(), perturbed_inputs, create_graph=True)[0] 58 | kernel = vector 59 | loss = torch.norm(dlogp + kernel, dim=-1) ** 2 60 | loss = loss.mean() / 2. 
61 | 62 | return loss 63 | 64 | 65 | def dsm_score_estimation(scorenet, samples, sigma=0.01): 66 | perturbed_samples = samples + torch.randn_like(samples) * sigma 67 | target = - 1 / (sigma ** 2) * (perturbed_samples - samples) 68 | scores = scorenet(perturbed_samples) 69 | target = target.view(target.shape[0], -1) 70 | scores = scores.view(scores.shape[0], -1) 71 | loss = 1 / 2. * ((scores - target) ** 2).sum(dim=-1).mean(dim=0) 72 | 73 | return loss 74 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # IDE files 2 | .idea 3 | 4 | # Byte-compiled / optimized / DLL files 5 | *__pycache__ 6 | 7 | # Docker 8 | Dockerfile 9 | 10 | # roms 11 | roms/* 12 | Roms.rar 13 | 14 | __pycache__/ 15 | *.py[cod] 16 | *$py.class 17 | 18 | # C extensions 19 | *.so 20 | 21 | # Distribution / packaging 22 | .Python 23 | build/ 24 | develop-eggs/ 25 | dist/ 26 | downloads/ 27 | eggs/ 28 | .eggs/ 29 | lib/ 30 | lib64/ 31 | parts/ 32 | sdist/ 33 | var/ 34 | wheels/ 35 | share/python-wheels/ 36 | *.egg-info/ 37 | .installed.cfg 38 | *.egg 39 | MANIFEST 40 | 41 | # PyInstaller 42 | # Usually these files are written by a python script from a template 43 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
44 | *.manifest 45 | *.spec 46 | 47 | # Installer logs 48 | pip-log.txt 49 | pip-delete-this-directory.txt 50 | 51 | # Unit test / coverage reports 52 | htmlcov/ 53 | .tox/ 54 | .nox/ 55 | .coverage 56 | .coverage.* 57 | .cache 58 | nosetests.xml 59 | coverage.xml 60 | *.cover 61 | *.py,cover 62 | .hypothesis/ 63 | .pytest_cache/ 64 | cover/ 65 | 66 | # Translations 67 | *.mo 68 | *.pot 69 | 70 | # Django stuff: 71 | *.log 72 | local_settings.py 73 | db.sqlite3 74 | db.sqlite3-journal 75 | 76 | # Flask stuff: 77 | instance/ 78 | .webassets-cache 79 | 80 | # Scrapy stuff: 81 | .scrapy 82 | 83 | # Sphinx documentation 84 | docs/_build/ 85 | 86 | # PyBuilder 87 | .pybuilder/ 88 | target/ 89 | 90 | # Jupyter Notebook 91 | .ipynb_checkpoints 92 | 93 | # IPython 94 | profile_default/ 95 | ipython_config.py 96 | 97 | # pyenv 98 | # For a library or package, you might want to ignore these files since the code is 99 | # intended to run in multiple environments; otherwise, check them in: 100 | .python-version 101 | 102 | # pipenv 103 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 104 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 105 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 106 | # install all needed dependencies. 107 | #Pipfile.lock 108 | 109 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 110 | __pypackages__/ 111 | 112 | # Celery stuff 113 | celerybeat-schedule 114 | celerybeat.pid 115 | 116 | # SageMath parsed files 117 | *.sage.py 118 | 119 | # Environments 120 | .env 121 | .venv 122 | env/ 123 | venv/ 124 | ENV/ 125 | env.bak/ 126 | venv.bak/ 127 | 128 | # Spyder project settings 129 | .spyderproject 130 | .spyproject 131 | 132 | # Rope project settings 133 | .ropeproject 134 | 135 | # mkdocs documentation 136 | /site 137 | 138 | # mypy 139 | .mypy_cache/ 140 | .dmypy.json 141 | dmypy.json 142 | 143 | # Pyre type checker 144 | .pyre/ 145 | 146 | # pytype static type analyzer 147 | .pytype/ 148 | 149 | # Cython debug symbols 150 | cython_debug/ 151 | 152 | # pytorch files 153 | *.pt 154 | -------------------------------------------------------------------------------- /baseline_models/icebeem/models/ivae/ivae_wrapper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | import numpy as np 5 | import torch 6 | from torch import optim 7 | from torch.utils.data import DataLoader 8 | 9 | from data.imca import ConditionalDataset 10 | from .ivae_core import iVAE 11 | 12 | 13 | 14 | 15 | def IVAE_wrapper(X, U, latent_dim, batch_size=256, max_iter=7e4, seed=0, n_layers=3, hidden_dim=20, lr=1e-3, cuda=True, 16 | ckpt_folder='ivae.pt', architecture="ivae", logger=None, time_limit=None, learn_decoder_var=False): 17 | " args are the arguments from the main.py file" 18 | torch.manual_seed(seed) 19 | np.random.seed(seed) 20 | 21 | device = torch.device('cuda:0' if cuda else 'cpu') 22 | print('training on {}'.format(torch.cuda.get_device_name(device) if cuda else 'cpu')) 23 | 24 | # load data 25 | # print('Creating shuffled dataset..') 26 | dset = ConditionalDataset(X.astype(np.float32), U.astype(np.float32), device) 27 | loader_params = {'num_workers': 1, 'pin_memory': True, "generator": torch.Generator(device=device)} if cuda else {} 28 | train_loader = 
DataLoader(dset, shuffle=True, batch_size=batch_size, **loader_params) 29 | data_dim, _, aux_dim = dset.get_dims() 30 | N = len(dset) 31 | max_epochs = int(max_iter // len(train_loader) + 1) 32 | 33 | # define model and optimizer 34 | # print('Defining model and optimizer..') 35 | model = iVAE(latent_dim, data_dim, aux_dim, activation='lrelu', device=device, 36 | n_layers=n_layers, hidden_dim=hidden_dim, architecture=architecture, learn_decoder_var=learn_decoder_var) 37 | optimizer = optim.Adam(model.parameters(), lr=lr) 38 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=20, verbose=True) 39 | # training loop 40 | print("Training..") 41 | 42 | # timer 43 | if time_limit is not None: 44 | time_limit = time_limit * 60 * 60 # convert to seconds 45 | t0 = time.time() 46 | else: 47 | time_limit = np.inf 48 | t0 = 0 49 | 50 | it = 0 51 | model.train() 52 | while it < max_iter and time.time() - t0 < time_limit : 53 | elbo_train = 0 54 | epoch = it // len(train_loader) + 1 55 | for _, (x, u) in enumerate(train_loader): 56 | it += 1 57 | optimizer.zero_grad() 58 | x, u = x.to(device), u.to(device) 59 | elbo, z_est = model.elbo(x, u) 60 | elbo.mul(-1).backward() 61 | optimizer.step() 62 | elbo_train += -elbo.item() 63 | if logger is not None and it%100 == 0: 64 | metrics = {"loss_train": -elbo.item()} 65 | logger.log_metrics(step=it, metrics=metrics) 66 | if it > max_iter or time.time() - t0 > time_limit : 67 | break 68 | elbo_train /= len(train_loader) 69 | scheduler.step(elbo_train) 70 | 71 | #print('epoch {}/{} \tloss: {}'.format(epoch, max_epochs, elbo_train)) 72 | # save model checkpoint after training 73 | torch.save(model.state_dict(), os.path.join(ckpt_folder, 'ivae.pt')) 74 | 75 | return model 76 | -------------------------------------------------------------------------------- /baseline_models/slowvae_pcl/README.md: -------------------------------------------------------------------------------- 1 | # Towards Nonlinear 
Disentanglement in Natural Data with Temporal Sparse Coding 2 | 3 | This repository contains the code release for: 4 | 5 | **Towards Nonlinear Disentanglement in Natural Data with Temporal Sparse Coding.** 6 | David Klindt*, Lukas Schott*, Yash Sharma*, Ivan Ustyuzhaninov, Wieland Brendel, Matthias Bethge†, Dylan Paiton† 7 | https://arxiv.org/abs/2007.10930 8 | 9 | An example latent traversal using our learned model:
10 | ![Sample traversal](https://github.com/bethgelab/slow_disentanglement/blob/master/latent_factors.gif?raw=true) 11 | 12 | 13 | **Abstract:** We construct an unsupervised learning model that achieves nonlinear disentanglement of underlying factors of variation in naturalistic videos. Previous work suggests that representations can be disentangled if all but a few factors in the environment stay constant at any point in time. As a result, algorithms proposed for this problem have only been tested on carefully constructed datasets with this exact property, leaving it unclear whether they will transfer to natural scenes. Here we provide evidence that objects in segmented natural movies undergo transitions that are typically small in magnitude with occasional large jumps, which is characteristic of a temporally sparse distribution. We leverage this finding and present SlowVAE, a model for unsupervised representation learning that uses a sparse prior on temporally adjacent observations to disentangle generative factors without any assumptions on the number of changing factors. We provide a proof of identifiability and show that the model reliably learns disentangled representations on several established benchmark datasets, often surpassing the current state-of-the-art. We additionally demonstrate transferability towards video datasets with natural dynamics, Natural Sprites and KITTI Masks, which we contribute as benchmarks for guiding disentanglement research towards more natural data domains. 14 | 15 | ### Cite 16 | If you make use of this code in your own work, please cite our paper: 17 | ``` 18 | @inproceedings{klindt2021towards, 19 | title={Towards Nonlinear Disentanglement in Natural Data with Temporal Sparse Coding}, 20 | author={David A. 
Klindt and Lukas Schott and Yash Sharma and Ivan Ustyuzhaninov and Wieland Brendel and Matthias Bethge and Dylan Paiton}, 21 | booktitle={International Conference on Learning Representations}, 22 | year={2021}, 23 | url={https://openreview.net/forum?id=EbIDjBynYJ8} 24 | } 25 | ``` 26 | 27 | ### Datasets 28 | Our work also contributes two new datasets.
29 | The Natural Sprites dataset can be downloaded here: https://zenodo.org/record/3948069
30 | The KITTI Masks dataset can be downloaded here: https://zenodo.org/record/3931823 31 | 32 | 33 | ### Acknowledgements 34 | 35 | The repository is based on the following [Beta-VAE reproduction](https://github.com/1Konny/Beta-VAE). The MCC metric was adopted from the [Time-Contrastive Learning release](https://github.com/hirosm/TCL). 36 | 37 | ### Contact 38 | 39 | - Maintainers: [David Klindt](https://github.com/david-klindt) & [Lukas Schott](https://github.com/lukas-schott) & [Yash Sharma](https://github.com/ysharma1126) 40 | -------------------------------------------------------------------------------- /baseline_models/slowvae_pcl/data_generation/gen_youtube_csv.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os 3 | from coco.coco import COCO 4 | import coco.mask as maskUtils 5 | import numpy as np 6 | from scipy import ndimage 7 | from scipy.misc import imresize 8 | from tqdm import tqdm 9 | 10 | def myRange(start,end,step): 11 | i = start 12 | while i < end: 13 | yield i 14 | i += step 15 | yield end 16 | 17 | def main(args, data_dir): 18 | downscale = args.downscale 19 | keepaspectratio = args.keepaspect 20 | stride = args.stride 21 | annIds = coco.getAnnIds() 22 | anns = coco.loadAnns(annIds) 23 | max_len = np.max([len(x['segmentations']) for x in anns]) 24 | file_name = '{}{}{}'.format('downscale_' if args.downscale else '', 25 | 'keepaspect_' if args.keepaspect else '', 26 | 'stride_{}_'.format(args.stride) if args.stride != 32 else '') 27 | with open('{}.csv'.format(file_name.rstrip('_')), mode='w') as movie_file: 28 | movie_writer = csv.writer(movie_file, delimiter=',') 29 | movie_writer.writerow(['id','cat_id'] + ['t_{}'.format(i) for i in range(max_len)]) 30 | for ann in tqdm(anns) if args.verbose else anns: 31 | vals = [] 32 | for seg in ann['segmentations']: 33 | if seg: 34 | rle = maskUtils.frPyObjects([seg], ann['height'], ann['width']) 35 | mask = np.squeeze(maskUtils.decode(rle)) 
36 | if mask.any(): 37 | if downscale: 38 | if keepaspectratio: 39 | tr_mask = imresize(mask, (64,128)) 40 | tr_masks = [] 41 | window_idxes = list(myRange(0,64,stride)) 42 | for i in range(len(window_idxes)): 43 | tr_masks.append(tr_mask[:,window_idxes[i]:window_idxes[i]+64]) 44 | else: 45 | tr_masks=[imresize(mask, (64,64))] 46 | else: 47 | tr_masks=[mask] 48 | temp = [] 49 | for tr_mask in tr_masks: 50 | if tr_mask.any(): 51 | com_val = np.array(ndimage.measurements.center_of_mass(tr_mask)).astype(np.float).tolist() 52 | y_val = com_val[0] 53 | x_val = com_val[1] 54 | rle = maskUtils.encode(np.asfortranarray(tr_mask)) 55 | area_val = maskUtils.area(rle).astype(np.int) 56 | temp.append((y_val,x_val,area_val)) 57 | else: 58 | temp.append(None) 59 | vals.append(temp) 60 | else: 61 | vals.append(None) 62 | else: 63 | vals.append(None) 64 | movie_writer.writerow([ann['id'],ann['category_id']] + vals) 65 | 66 | 67 | if __name__ == "__main__": 68 | import argparse 69 | parser = argparse.ArgumentParser() 70 | parser.add_argument('--verbose', action='store_true') 71 | parser.add_argument('--stride', type=int, default=32) 72 | parser.add_argument('--downscale', action='store_true') 73 | parser.add_argument('--keepaspect', action='store_true') 74 | args = parser.parse_args() 75 | #download data from https://competitions.codalab.org/competitions/20127#participate-get-data 76 | data_dir = './data/youtube_voc' 77 | annFile= os.path.join(data_dir, 'train.json') 78 | # initialize COCO api for instance annotations 79 | coco=COCO(annFile) 80 | main(args, data_dir) -------------------------------------------------------------------------------- /scripts/produce_all_logs.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import json 4 | import math 5 | import os 6 | import random 7 | import shutil 8 | import sys 9 | # sys.path.insert(0, os.path.abspath('..')) 10 | from copy import deepcopy 11 | 12 | import 
numpy as np 13 | 14 | import time 15 | 16 | def load_hparams(folder_path): 17 | with open(os.path.join(folder_path, 'hparams.json'), 'r') as infile: 18 | opt = json.load(infile) 19 | return opt 20 | # class Bunch: 21 | # def __init__(self, opt): 22 | # self.__dict__.update(opt) 23 | #return Bunch(opt) 24 | 25 | def load_metrics(folder_path): 26 | with open(os.path.join(folder_path, 'log.ndjson'), 'r') as infile: 27 | metrics = json.loads(infile.readlines()[-1]) # reading last line 28 | return metrics 29 | 30 | def main(args=None): 31 | ROOT = '/network/scratch/l/lachaseb/identifiable_latent_causal_model/exp/' 32 | #META_EXPS = {"temporal_ilcm": "e0b63788-0eeb-11ee-8f41-424954050100", 33 | # "temporal_random": "0b4e90b8-0f81-11ee-b8a0-424954050300", 34 | # "temporal_supervised": "e3baec86-0f80-11ee-9b91-424954050300", 35 | # "temporal_tcvae": "25e9291a-0f95-11ee-9cf4-424954050100", 36 | # "temporal_pcl": "81168a28-106f-11ee-bb87-424954050200", 37 | # "temporal_slowvae": "ee3dd866-1071-11ee-ab0d-424954050200", 38 | # "action_ilcm": "4472ac22-1124-11ee-80d1-48df37d42c20", 39 | # "action_random": "f668798e-1124-11ee-86e1-48df37d42c20", 40 | # "action_supervised": "f81f68f4-1125-11ee-9e63-48df37d42c20", 41 | # "action_tcvae": "fc4313ee-1504-11ee-bca8-d8c497b83240", 42 | # "action_ivae": "5d35b610-1119-11ee-bc8f-424954050100", 43 | # } 44 | #META_EXPS = {"temporal_ilcm": "71fee648-4e81-11ee-9556-424954050200", 45 | # "action_ilcm": "3645f1aa-4e81-11ee-b87d-424954050200", 46 | # } 47 | #META_EXPS = {"temporal_ilcm": "462a70b4-5318-11ee-8099-424954050100", 48 | # "action_ilcm": "66143ee0-5332-11ee-b033-424954050100", 49 | # } 50 | #META_EXPS = {"temporal_ilcm": "95b2eb58-53ee-11ee-929b-424954050300", 51 | # "action_ilcm": "3d959e64-5404-11ee-8e54-424954050300", 52 | # } 53 | #META_EXPS = {"temporal_ilcm": "4d5d4c0a-54b0-11ee-be29-424954050100", 54 | # "action_ilcm": "f854193c-54af-11ee-be2a-424954050100", 55 | # } 56 | META_EXPS = {"temporal_ilcm": 
"2cdcae28-58b9-11ee-befb-424954050100", 57 | "action_ilcm": "77382470-58b9-11ee-a8ff-424954050100", 58 | } 59 | 60 | 61 | 62 | all_logs = [] 63 | for name, meta_exp in META_EXPS.items(): 64 | meta_exp_path = os.path.join(ROOT, meta_exp) 65 | num_exps = 0 66 | for exp in os.listdir(meta_exp_path): 67 | exp_path = os.path.join(meta_exp_path, exp) 68 | if not os.path.isdir(exp_path): 69 | continue 70 | 71 | # verify if run is completed 72 | if not os.path.exists(os.path.join(exp_path, "z_hat_final.npy")): 73 | print(f"{name}: Run {meta_exp}/{exp} is not completed thus excluded from all_logs file") 74 | #log = load_hparams(exp_path) 75 | #print("beta", log["beta"]) 76 | continue 77 | 78 | num_exps += 1 79 | 80 | log = load_hparams(exp_path) 81 | metrics = load_metrics(exp_path) 82 | log.update(metrics) 83 | 84 | all_logs.append(log) 85 | print(f"{name}: Done with {num_exps} experiments.") 86 | 87 | save_path = os.path.join(ROOT, "all_logs_23sep2023_JMLR_const_rand_g.npy") 88 | np.save(save_path, all_logs) 89 | print("Saved to:", save_path) 90 | 91 | if __name__ == "__main__": 92 | main() 93 | -------------------------------------------------------------------------------- /baseline_models/slowvae_pcl/mcc_metric/metric.py: -------------------------------------------------------------------------------- 1 | """Mean Correlation Coefficient from Hyvarinen & Morioka 2 | """ 3 | from absl import logging 4 | from disentanglement_lib.evaluation.metrics import utils 5 | import numpy as np 6 | import gin.tf 7 | import scipy as sp 8 | from mcc_metric.munkres import Munkres 9 | 10 | def correlation(x, y, method='Pearson'): 11 | """Evaluate correlation 12 | Args: 13 | x: data to be sorted 14 | y: target data 15 | Returns: 16 | corr_sort: correlation matrix between x and y (after sorting) 17 | sort_idx: sorting index 18 | x_sort: x after sorting 19 | method: correlation method ('Pearson' or 'Spearman') 20 | """ 21 | 22 | print("Calculating correlation...") 23 | 24 | x = x.copy() 25 
| y = y.copy() 26 | dim = x.shape[0] 27 | 28 | # Calculate correlation ----------------------------------- 29 | if method=='Pearson': 30 | corr = np.corrcoef(y, x) 31 | corr = corr[0:dim,dim:] 32 | elif method=='Spearman': 33 | corr, pvalue = sp.stats.spearmanr(y.T, x.T) 34 | corr = corr[0:dim, dim:] 35 | 36 | # Sort ---------------------------------------------------- 37 | munk = Munkres() 38 | indexes = munk.compute(-np.absolute(corr)) 39 | 40 | sort_idx = np.zeros(dim) 41 | x_sort = np.zeros(x.shape) 42 | for i in range(dim): 43 | sort_idx[i] = indexes[i][1] 44 | x_sort[i,:] = x[indexes[i][1],:] 45 | 46 | # Re-calculate correlation -------------------------------- 47 | if method=='Pearson': 48 | corr_sort = np.corrcoef(y, x_sort) 49 | corr_sort = corr_sort[0:dim,dim:] 50 | elif method=='Spearman': 51 | corr_sort, pvalue = sp.stats.spearmanr(y.T, x_sort.T) 52 | corr_sort = corr_sort[0:dim, dim:] 53 | 54 | return corr_sort, sort_idx, x_sort 55 | 56 | 57 | @gin.configurable( 58 | "mcc", 59 | blacklist=["ground_truth_data", "representation_function", "random_state", 60 | "artifact_dir"]) 61 | def compute_mcc(ground_truth_data, 62 | representation_function, 63 | random_state, 64 | artifact_dir=None, 65 | num_train=gin.REQUIRED, 66 | correlation_fn=gin.REQUIRED, 67 | batch_size=16): 68 | """Computes the mean correlation coefficient. 69 | 70 | Args: 71 | ground_truth_data: GroundTruthData to be sampled from. 72 | representation_function: Function that takes observations as input and 73 | outputs a dim_representation sized representation for each observation. 74 | random_state: Numpy random state used for randomness. 75 | artifact_dir: Optional path to directory where artifacts can be saved. 76 | num_train: Number of points used for training. 77 | batch_size: Batch size for sampling. 
78 | 79 | Returns: 80 | Dict with mcc stats 81 | """ 82 | del artifact_dir 83 | logging.info("Generating training set.") 84 | mus_train, ys_train = utils.generate_batch_factor_code( 85 | ground_truth_data, representation_function, num_train, 86 | random_state, batch_size) 87 | assert mus_train.shape[1] == num_train 88 | return _compute_mcc(mus_train, ys_train, correlation_fn, random_state) 89 | 90 | 91 | def _compute_mcc(mus_train, ys_train, correlation_fn, random_state): 92 | """Computes score based on both training and testing codes and factors.""" 93 | score_dict = {} 94 | result = np.zeros(mus_train.shape) 95 | result[:ys_train.shape[0],:ys_train.shape[1]] = ys_train 96 | for i in range(len(mus_train) - len(ys_train)): 97 | result[ys_train.shape[0] + i, :] = random_state.normal(size=ys_train.shape[1]) 98 | corr_sorted, sort_idx, mu_sorted = correlation(mus_train, result, method=correlation_fn) 99 | score_dict["meanabscorr"] = np.mean(np.abs(np.diag(corr_sorted)[:len(ys_train)])) 100 | for i in range(len(corr_sorted)): 101 | for j in range(len(corr_sorted[0])): 102 | score_dict["corr_sorted_{}{}".format(i,j)] = corr_sorted[i][j] 103 | for i in range(len(sort_idx)): 104 | score_dict["sort_idx_{}".format(i)] = sort_idx[i] 105 | return score_dict 106 | 107 | -------------------------------------------------------------------------------- /model/nn.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | import numpy as np 6 | from torch.nn.utils import spectral_norm as sn 7 | 8 | 9 | class MLP(torch.nn.Module): 10 | def __init__(self, ni, no, nhidden, nlayers, spectral_norm=False, batch_norm=True): 11 | super().__init__() 12 | self.nlayers = nlayers 13 | self.batch_norm = batch_norm 14 | for i in range(nlayers): 15 | if i == 0: 16 | if spectral_norm: 17 | setattr(self, "linear%d" % i, sn(torch.nn.Linear(ni, nhidden, bias=(not batch_norm)))) 18 | else: 19 | 
setattr(self, "linear%d" % i, torch.nn.Linear(ni, nhidden, bias=(not batch_norm))) 20 | 21 | else: 22 | if spectral_norm: 23 | setattr(self, "linear%d" % i, sn(torch.nn.Linear(nhidden, nhidden, bias=(not batch_norm)))) 24 | else: 25 | setattr(self, "linear%d" % i, torch.nn.Linear(nhidden, nhidden, bias=(not batch_norm))) 26 | if batch_norm: 27 | setattr(self, "bn%d" % i, torch.nn.BatchNorm1d(nhidden)) 28 | if nlayers == 0: 29 | nhidden = ni 30 | self.linear_out = torch.nn.Linear(nhidden, no) 31 | 32 | def forward(self, x): 33 | for i in range(self.nlayers): 34 | linear = getattr(self, "linear%d" % i) 35 | x = linear(x) 36 | if self.batch_norm: 37 | bn = getattr(self, "bn%d" % i) 38 | x = bn(x) 39 | x = F.leaky_relu(x, 0.2, True) 40 | return self.linear_out(x) 41 | 42 | 43 | class ParallelMLP(torch.nn.Module): 44 | def __init__(self, ni, no, nhidden, nlayers, nMLPs, bn=True): 45 | super().__init__() 46 | self.nlayers = nlayers 47 | self.nMLPs = nMLPs 48 | self.bn = bn 49 | for i in range(nlayers): 50 | if i == 0: 51 | setattr(self, "linear%d" % i, ParallelLinear(ni, nhidden, nMLPs, bias=False)) 52 | else: 53 | setattr(self, "linear%d" % i, ParallelLinear(nhidden, nhidden, nMLPs, bias=False)) 54 | if self.bn: 55 | setattr(self, "bn%d" % i, torch.nn.BatchNorm1d(nhidden * nMLPs)) 56 | if nlayers == 0: 57 | nhidden = ni 58 | self.linear_out = ParallelLinear(nhidden, no, nMLPs) 59 | 60 | def forward(self, x): 61 | assert self.nMLPs == x.shape[1] 62 | bs, nMLPs, ni = x.shape 63 | for i in range(self.nlayers): 64 | linear = getattr(self, "linear%d" % i) 65 | x = linear(x) 66 | # this `reshape` instead of `view` call is necessary since x is not contiguous. 67 | if self.bn: 68 | bn = getattr(self, "bn%d" % i) 69 | x = bn(x.reshape(bs, -1)).view(bs, nMLPs, -1) # TODO: should I worry about the copy made by reshape? 
70 | x = F.leaky_relu(x, 0.2, True) 71 | return self.linear_out(x) 72 | 73 | 74 | class ParallelLinear(torch.nn.Module): 75 | def __init__(self, in_features, out_features, num_linears, bias=True): 76 | super(ParallelLinear, self).__init__() 77 | self.in_features = in_features 78 | self.out_features = out_features 79 | self.num_linears = num_linears 80 | self.weight = torch.nn.Parameter(torch.Tensor(num_linears, out_features, in_features)) 81 | if bias: 82 | self.bias = torch.nn.Parameter(torch.Tensor(num_linears, out_features)) 83 | else: 84 | self.register_parameter('bias', None) 85 | self.reset_parameters() 86 | 87 | def reset_parameters(self): 88 | for linear in range(self.num_linears): 89 | torch.nn.init.kaiming_uniform_(self.weight[linear], a=math.sqrt(5)) 90 | if self.bias is not None: 91 | fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight[linear]) 92 | bound = 1 / math.sqrt(fan_in) 93 | torch.nn.init.uniform_(self.bias, -bound, bound) 94 | 95 | def forward(self, input): 96 | # input shape: (bs, num_linears, in_features) 97 | x = torch.einsum("bli,lji->blj", input, self.weight) 98 | if self.bias is not None: 99 | x = x + self.bias 100 | return x # (bs, num_linears, out_features) 101 | 102 | -------------------------------------------------------------------------------- /baseline_models/slowvae_pcl/scripts/evaluate_disentanglement.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import sys 4 | import numpy as np 5 | import gin.tf 6 | gin.enter_interactive_mode() 7 | import time 8 | import os 9 | import json 10 | import traceback 11 | from scripts.model import reparametrize 12 | from scripts.model import BetaVAE_H as BetaVAE 13 | from disentanglement_lib.utils import results 14 | 15 | # needed later: 16 | from disentanglement_lib.evaluation.metrics import beta_vae # pylint: disable=unused-import 17 | from disentanglement_lib.evaluation.metrics import dci # pylint: 
disable=unused-import 18 | from disentanglement_lib.evaluation.metrics import downstream_task # pylint: disable=unused-import 19 | from disentanglement_lib.evaluation.metrics import factor_vae # pylint: disable=unused-import 20 | from disentanglement_lib.evaluation.metrics import fairness # pylint: disable=unused-import 21 | from disentanglement_lib.evaluation.metrics import irs # pylint: disable=unused-import 22 | from disentanglement_lib.evaluation.metrics import mig # pylint: disable=unused-import 23 | from disentanglement_lib.evaluation.metrics import modularity_explicitness # pylint: disable=unused-import 24 | from disentanglement_lib.evaluation.metrics import reduced_downstream_task # pylint: disable=unused-import 25 | from disentanglement_lib.evaluation.metrics import sap_score # pylint: disable=unused-import 26 | from disentanglement_lib.evaluation.metrics import unsupervised_metrics # pylint: disable=unused-import 27 | 28 | import mcc_metric.metric as mcc # pylint: disable=unused-import 29 | 30 | 31 | 32 | def main(args, dataset): 33 | device = torch.device('cuda' if torch.cuda.is_available() and args.cuda else 'cpu') 34 | net = BetaVAE 35 | net = net(args.z_dim, args.num_channel, args.pcl).to(device) 36 | file_path = os.path.join(args.ckpt_dir, 'last') 37 | checkpoint = torch.load(file_path) 38 | net.load_state_dict(checkpoint['model_states']['net']) 39 | 40 | def mean_rep(x): 41 | distributions = net._encode( 42 | torch.from_numpy(x).float().to(device)) 43 | mu = distributions[:, :net.z_dim] 44 | logvar = distributions[:, net.z_dim:] 45 | return np.array(mu.detach().cpu()) 46 | 47 | def sample_rep(x): 48 | distributions = net._encode( 49 | torch.from_numpy(x).float().to(device)) 50 | mu = distributions[:, :net.z_dim] 51 | logvar = distributions[:, net.z_dim:] 52 | return np.array(reparametrize(mu, logvar).detach().cpu()) 53 | 54 | @gin.configurable("evaluation") 55 | def evaluate(post, 56 | output_dir, 57 | evaluation_fn=gin.REQUIRED, 58 | 
random_seed=gin.REQUIRED, 59 | name=""): 60 | experiment_timer = time.time() 61 | assert post == 'mean' or post == 'sampled' 62 | results_dict = evaluation_fn( 63 | dataset, 64 | mean_rep if post == 'mean' else sample_rep, 65 | random_state=np.random.RandomState(random_seed)) 66 | results_dict["elapsed_time"] = time.time() - experiment_timer 67 | results.update_result_directory(output_dir, "evaluation", results_dict) 68 | 69 | random_state = np.random.RandomState(0) 70 | config_dir = 'metric_configs' 71 | eval_config_files = [f for f in os.listdir(config_dir) if not (f.startswith('.') or 'others' in f)] 72 | t0 = time.time() 73 | posts = ['mean'] 74 | for post in posts: 75 | for eval_config in eval_config_files: 76 | metric_name = os.path.basename(eval_config).replace(".gin", "") 77 | continuous = False 78 | if args.dataset == 'kittimasks' or ( 79 | args.dataset == 'natural' and not args.natural_discrete): 80 | continuous = True 81 | if continuous: 82 | if metric_name != 'mcc': 83 | continue 84 | contains = True 85 | if args.specify: 86 | contains = False 87 | for specific in args.specify.split('_'): 88 | if specific in metric_name: 89 | contains = True 90 | break 91 | if contains: 92 | if args.verbose: 93 | print("Computing metric '{}' on '{}'...".format(metric_name, post)) 94 | eval_bindings = [ 95 | "evaluation.random_seed = {}".format(random_state.randint(2 ** 32)), 96 | "evaluation.name = '{}'".format(metric_name) 97 | ] 98 | gin.parse_config_files_and_bindings( 99 | [os.path.join(config_dir, eval_config)], eval_bindings) 100 | output_dir = os.path.join( 101 | args.output_dir, 'evaluation', args.ckpt_name, post, metric_name) 102 | evaluate(post, output_dir) 103 | gin.clear_config() 104 | if args.verbose: 105 | print('took', time.time() - t0, 's') 106 | t0 = time.time() 107 | -------------------------------------------------------------------------------- /baseline_models/icebeem/models/icebeem_wrapper.py: 
-------------------------------------------------------------------------------- 1 | import itertools 2 | import os 3 | import sys 4 | import pathlib 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | from sklearn.decomposition import FastICA 10 | from torch.distributions import Uniform, TransformedDistribution, SigmoidTransform 11 | 12 | from losses.fce import ConditionalFCE 13 | from .nets import MLP 14 | from .nflib.flows import NormalizingFlowModel, Invertible1x1Conv, ActNorm 15 | from .nflib.spline_flows import NSF_AR 16 | 17 | sys.path.insert(0, str(pathlib.Path(__file__).parent.parent.parent.parent.parent)) 18 | from disentanglement_via_mechanism_sparsity.model.nn import MLP as MLP_ilcm 19 | 20 | 21 | def ICEBEEM_wrapper(X, Y, ebm_hidden_size, n_layers_ebm, n_layers_flow, lr_flow, lr_ebm, seed, 22 | ckpt_file='icebeem.pt', test=False): 23 | np.random.seed(seed) 24 | torch.manual_seed(seed) 25 | data_dim = X.shape[1] 26 | 27 | model_ebm = MLP(input_size=data_dim, hidden_size=[ebm_hidden_size] * n_layers_ebm, 28 | n_layers=n_layers_ebm, output_size=data_dim, use_bn=True, 29 | activation_function=F.leaky_relu) 30 | 31 | prior = TransformedDistribution(Uniform(torch.zeros(data_dim), torch.ones(data_dim)), 32 | SigmoidTransform().inv) 33 | nfs_flow = NSF_AR 34 | flows = [nfs_flow(dim=data_dim, K=8, B=3, hidden_dim=16) for _ in range(n_layers_flow)] 35 | convs = [Invertible1x1Conv(dim=data_dim) for _ in flows] 36 | norms = [ActNorm(dim=data_dim) for _ in flows] 37 | flows = list(itertools.chain(*zip(norms, convs, flows))) 38 | # construct the model 39 | model_flow = NormalizingFlowModel(prior, flows) 40 | 41 | pretrain_flow = True 42 | augment_ebm = True 43 | 44 | # instantiate ebmFCE object 45 | fce_ = ConditionalFCE(data=X.astype(np.float32), segments=Y.astype(np.float32), 46 | energy_MLP=model_ebm, flow_model=model_flow, verbose=False) 47 | 48 | init_ckpt_file = os.path.splitext(ckpt_file)[0] + '_0' + 
os.path.splitext(ckpt_file)[1] 49 | if not test: 50 | if pretrain_flow: 51 | # print('pretraining flow model..') 52 | fce_.pretrain_flow_model(epochs=1, lr=1e-4) 53 | # print('pretraining done.') 54 | 55 | # first we pretrain the final layer of EBM model (this is g(y) as it depends on segments) 56 | fce_.train_ebm_fce(epochs=15, augment=augment_ebm, finalLayerOnly=True, cutoff=.5) 57 | 58 | # then train full EBM via NCE with flow contrastive noise: 59 | fce_.train_ebm_fce(epochs=50, augment=augment_ebm, cutoff=.5, useVAT=False) 60 | 61 | torch.save({'ebm_mlp': fce_.energy_MLP.state_dict(), 62 | 'ebm_finalLayer': fce_.ebm_finalLayer, 63 | 'flow': fce_.flow_model.state_dict()}, init_ckpt_file) 64 | else: 65 | state = torch.load(init_ckpt_file, map_location=fce_.device) 66 | fce_.energy_MLP.load_state_dict(state['ebm_mlp']) 67 | fce_.ebm_finalLayer = state['ebm_finalLayer'] 68 | fce_.flow_model.load_stat_dict(state['flow']) 69 | 70 | # evaluate recovery of latents 71 | recov = fce_.unmixSamples(X, modelChoice='ebm') 72 | source_est_ica = FastICA().fit_transform((recov)) 73 | recov_sources = [source_est_ica] 74 | 75 | # iterate between updating noise and tuning the EBM 76 | eps = .025 77 | for iter_ in range(3): 78 | mid_ckpt_file = os.path.splitext(ckpt_file)[0] + '_' + str(iter_ + 1) + os.path.splitext(ckpt_file)[1] 79 | if not test: 80 | # update flow model: 81 | fce_.train_flow_fce(epochs=5, objConstant=-1., cutoff=.5 - eps, lr=lr_flow) 82 | # update energy based model: 83 | fce_.train_ebm_fce(epochs=50, augment=augment_ebm, cutoff=.5 + eps, lr=lr_ebm, useVAT=False) 84 | 85 | torch.save({'ebm_mlp': fce_.energy_MLP.state_dict(), 86 | 'ebm_finalLayer': fce_.ebm_finalLayer, 87 | 'flow': fce_.flow_model.state_dict()}, mid_ckpt_file) 88 | else: 89 | state = torch.load(mid_ckpt_file, map_location=fce_.device) 90 | fce_.energy_MLP.load_state_dict(state['ebm_mlp']) 91 | fce_.ebm_finalLayer = state['ebm_finalLayer'] 92 | fce_.flow_model.load_stat_dict(state['flow']) 93 | 94 
| # evaluate recovery of latents 95 | recov = fce_.unmixSamples(X, modelChoice='ebm') 96 | source_est_ica = FastICA().fit_transform((recov)) 97 | recov_sources.append(source_est_ica) 98 | 99 | return recov_sources 100 | -------------------------------------------------------------------------------- /baseline_models/icebeem/models/tcl/tcl_eval.py: -------------------------------------------------------------------------------- 1 | """ Fuctions for evaluation 2 | This software includes the work that is distributed in the Apache License 2.0 3 | """ 4 | 5 | import sys 6 | 7 | import numpy as np 8 | import tensorflow as tf 9 | from sklearn.metrics import confusion_matrix 10 | 11 | 12 | # ============================================================= 13 | # ============================================================= 14 | def get_tensor(x, vars, sess, data_holder, batch=256): 15 | """Get tensor data . 16 | Args: 17 | x: input data [Ndim, Ndata] 18 | vars: tensors (list) 19 | sess: session 20 | data_holder: data holder 21 | batch: batch size 22 | Returns: 23 | y: value of tensors 24 | """ 25 | 26 | Ndata = x.shape[1] 27 | if batch is None: 28 | Nbatch = Ndata 29 | else: 30 | Nbatch = batch 31 | Niter = int(np.ceil(Ndata / Nbatch)) 32 | 33 | if not isinstance(vars, list): 34 | vars = [vars] 35 | 36 | # Convert names to tensors (if necessary) ----------------- 37 | for i in range(len(vars)): 38 | if not tf.is_numeric_tensor(vars[i]) and isinstance(vars[i], str): 39 | vars[i] = tf.get_default_graph().get_tensor_by_name(vars[i]) 40 | 41 | # Start batch-inputs -------------------------------------- 42 | y = {} 43 | for iter in range(Niter): 44 | 45 | sys.stdout.write('\r>> Getting tensors... 
%d/%d' % (iter + 1, Niter)) 46 | sys.stdout.flush() 47 | 48 | # Get batch ------------------------------------------- 49 | batchidx = np.arange(Nbatch * iter, np.minimum(Nbatch * (iter + 1), Ndata)) 50 | xbatch = x[:, batchidx].T 51 | 52 | # Get tensor data ------------------------------------- 53 | feed_dict = {data_holder: xbatch} 54 | ybatch = sess.run(vars, feed_dict=feed_dict) 55 | 56 | # Storage 57 | for tn in range(len(ybatch)): 58 | # Initialize 59 | if iter == 0: 60 | y[tn] = np.zeros([Ndata] + list(ybatch[tn].shape[1:]), dtype=np.float32) 61 | # Store 62 | y[tn][batchidx,] = ybatch[tn] 63 | 64 | sys.stdout.write('\r\n') 65 | 66 | return y 67 | 68 | 69 | # ============================================================= 70 | # ============================================================= 71 | def calc_accuracy(pred, label, normalize_confmat=True): 72 | """ Calculate accuracy and confusion matrix 73 | Args: 74 | pred: [Ndata x Nlabel] 75 | label: [Ndata x Nlabel] 76 | Returns: 77 | accuracy: accuracy 78 | conf: confusion matrix 79 | """ 80 | 81 | # print("Calculating accuracy...") 82 | 83 | # Accuracy ------------------------------------------------ 84 | correctflag = pred.reshape(-1) == label.reshape(-1) 85 | accuracy = np.mean(correctflag) 86 | 87 | # Confusion matrix ---------------------------------------- 88 | conf = confusion_matrix(label[:], pred[:]).astype(np.float32) 89 | # Normalization 90 | if normalize_confmat: 91 | for i in range(conf.shape[0]): 92 | conf[i, :] = conf[i, :] / np.sum(conf[i, :]) 93 | 94 | return accuracy, conf 95 | 96 | # ============================================================= 97 | # ============================================================= 98 | # def correlation(x, y, method='Pearson'): 99 | # """Evaluate correlation 100 | # Args: 101 | # x: data to be sorted 102 | # y: target data 103 | # Returns: 104 | # corr_sort: correlation matrix between x and y (after sorting) 105 | # sort_idx: sorting index 106 | # x_sort: x 
after sorting 107 | # """ 108 | # 109 | # print("Calculating correlation...") 110 | # 111 | # x = x.copy() 112 | # y = y.copy() 113 | # dim = x.shape[0] 114 | # 115 | # # Calculate correlation ----------------------------------- 116 | # if method=='Pearson': 117 | # corr = np.corrcoef(y, x) 118 | # corr = corr[0:dim,dim:] 119 | # elif method=='Spearman': 120 | # corr, pvalue = sp.stats.spearmanr(y.T, x.T) 121 | # corr = corr[0:dim, dim:] 122 | # 123 | # # Sort ---------------------------------------------------- 124 | # munk = Munkres() 125 | # indexes = munk.compute(-np.absolute(corr)) 126 | # 127 | # sort_idx = np.zeros(dim) 128 | # x_sort = np.zeros(x.shape) 129 | # for i in range(dim): 130 | # sort_idx[i] = indexes[i][1] 131 | # x_sort[i,:] = x[indexes[i][1],:] 132 | # 133 | # # Re-calculate correlation -------------------------------- 134 | # if method=='Pearson': 135 | # corr_sort = np.corrcoef(y, x_sort) 136 | # corr_sort = corr_sort[0:dim,dim:] 137 | # elif method=='Spearman': 138 | # corr_sort, pvalue = sp.stats.spearmanr(y.T, x_sort.T) 139 | # corr_sort = corr_sort[0:dim, dim:] 140 | # 141 | # return corr_sort, sort_idx, x_sort 142 | -------------------------------------------------------------------------------- /baseline_models/slowvae_pcl/data_generation/coco/mask.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tsungyi' 2 | 3 | import _mask as _mask 4 | 5 | # Interface for manipulating masks stored in RLE format. 6 | # 7 | # RLE is a simple yet efficient format for storing binary masks. RLE 8 | # first divides a vector (or vectorized image) into a series of piecewise 9 | # constant regions and then for each piece simply stores the length of 10 | # that piece. For example, given M=[0 0 1 1 1 0 1] the RLE counts would 11 | # be [2 3 1 1], or for M=[1 1 1 1 1 1 0] the counts would be [0 6 1] 12 | # (note that the odd counts are always the numbers of zeros). 
Instead of 13 | # storing the counts directly, additional compression is achieved with a 14 | # variable bitrate representation based on a common scheme called LEB128. 15 | # 16 | # Compression is greatest given large piecewise constant regions. 17 | # Specifically, the size of the RLE is proportional to the number of 18 | # *boundaries* in M (or for an image the number of boundaries in the y 19 | # direction). Assuming fairly simple shapes, the RLE representation is 20 | # O(sqrt(n)) where n is number of pixels in the object. Hence space usage 21 | # is substantially lower, especially for large simple objects (large n). 22 | # 23 | # Many common operations on masks can be computed directly using the RLE 24 | # (without need for decoding). This includes computations such as area, 25 | # union, intersection, etc. All of these operations are linear in the 26 | # size of the RLE, in other words they are O(sqrt(n)) where n is the area 27 | # of the object. Computing these operations on the original mask is O(n). 28 | # Thus, using the RLE can result in substantial computational savings. 29 | # 30 | # The following API functions are defined: 31 | # encode - Encode binary masks using RLE. 32 | # decode - Decode binary masks encoded via RLE. 33 | # merge - Compute union or intersection of encoded masks. 34 | # iou - Compute intersection over union between masks. 35 | # area - Compute area of encoded masks. 36 | # toBbox - Get bounding boxes surrounding encoded masks. 37 | # frPyObjects - Convert polygon, bbox, and uncompressed RLE to encoded RLE mask. 
38 | # 39 | # Usage: 40 | # Rs = encode( masks ) 41 | # masks = decode( Rs ) 42 | # R = merge( Rs, intersect=false ) 43 | # o = iou( dt, gt, iscrowd ) 44 | # a = area( Rs ) 45 | # bbs = toBbox( Rs ) 46 | # Rs = frPyObjects( [pyObjects], h, w ) 47 | # 48 | # In the API the following formats are used: 49 | # Rs - [dict] Run-length encoding of binary masks 50 | # R - dict Run-length encoding of binary mask 51 | # masks - [hxwxn] Binary mask(s) (must have type np.ndarray(dtype=uint8) in column-major order) 52 | # iscrowd - [nx1] list of np.ndarray. 1 indicates corresponding gt image has crowd region to ignore 53 | # bbs - [nx4] Bounding box(es) stored as [x y w h] 54 | # poly - Polygon stored as [[x1 y1 x2 y2...],[x1 y1 ...],...] (2D list) 55 | # dt,gt - May be either bounding boxes or encoded masks 56 | # Both poly and bbs are 0-indexed (bbox=[0 0 1 1] encloses first pixel). 57 | # 58 | # Finally, a note about the intersection over union (iou) computation. 59 | # The standard iou of a ground truth (gt) and detected (dt) object is 60 | # iou(gt,dt) = area(intersect(gt,dt)) / area(union(gt,dt)) 61 | # For "crowd" regions, we use a modified criteria. If a gt object is 62 | # marked as "iscrowd", we allow a dt to match any subregion of the gt. 63 | # Choosing gt' in the crowd gt that best matches the dt can be done using 64 | # gt'=intersect(dt,gt). Since by definition union(gt',dt)=dt, computing 65 | # iou(gt,dt,iscrowd) = iou(gt',dt) = area(intersect(gt,dt)) / area(dt) 66 | # For crowd gt regions we use this modified criteria above for the iou. 67 | # 68 | # To compile run "python setup.py build_ext --inplace" 69 | # Please do not contact us for help with compiling. 70 | # 71 | # Microsoft COCO Toolbox. version 2.0 72 | # Data, paper, and tutorials available at: http://mscoco.org/ 73 | # Code written by Piotr Dollar and Tsung-Yi Lin, 2015. 
74 | # Licensed under the Simplified BSD License [see coco/license.txt] 75 | 76 | iou = _mask.iou 77 | merge = _mask.merge 78 | frPyObjects = _mask.frPyObjects 79 | 80 | def encode(bimask): 81 | if len(bimask.shape) == 3: 82 | return _mask.encode(bimask) 83 | elif len(bimask.shape) == 2: 84 | h, w = bimask.shape 85 | return _mask.encode(bimask.reshape((h, w, 1), order='F'))[0] 86 | 87 | def decode(rleObjs): 88 | if type(rleObjs) == list: 89 | return _mask.decode(rleObjs) 90 | else: 91 | return _mask.decode([rleObjs])[:,:,0] 92 | 93 | def area(rleObjs): 94 | if type(rleObjs) == list: 95 | return _mask.area(rleObjs) 96 | else: 97 | return _mask.area([rleObjs])[0] 98 | 99 | def toBbox(rleObjs): 100 | if type(rleObjs) == list: 101 | return _mask.toBbox(rleObjs) 102 | else: 103 | return _mask.toBbox([rleObjs])[0] -------------------------------------------------------------------------------- /baseline_models/icebeem/models/ebm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | from .nets import CleanMLP 6 | 7 | 8 | class UnnormalizedConditialEBM(nn.Module): 9 | def __init__(self, input_size, hidden_size, n_hidden, output_size, condition_size, activation='lrelu', 10 | augment=False, positive=False): 11 | super().__init__() 12 | 13 | self.input_size = input_size 14 | self.output_size = output_size 15 | self.hidden_size = hidden_size 16 | self.cond_size = condition_size 17 | self.n_hidden = n_hidden 18 | self.activation = activation 19 | self.augment = augment 20 | self.positive = positive 21 | 22 | self.f = CleanMLP(input_size, hidden_size, n_hidden, output_size, activation=activation) 23 | self.g = nn.Linear(condition_size, output_size, bias=False) 24 | 25 | def forward(self, x, y): 26 | fx, gy = self.f(x).view(-1, self.output_size), self.g(y) 27 | 28 | if self.positive: 29 | fx = F.relu(fx) 30 | gy = F.relu(gy) 31 | 32 | if self.augment: 33 | return 
torch.einsum('bi,bi->b', [fx, gy]) + torch.einsum('bi,bi->b', [fx.pow(2), gy.pow(2)]) 34 | 35 | else: 36 | return torch.einsum('bi,bi->b', [fx, gy]) 37 | 38 | 39 | class ModularUnnormalizedConditionalEBM(nn.Module): 40 | def __init__(self, f_net, g_net, augment=False, positive=False): 41 | super().__init__() 42 | 43 | assert f_net.output_size == g_net.output_size 44 | 45 | self.input_size = f_net.input_size 46 | self.output_size = f_net.output_size 47 | self.cond_size = g_net.input_size 48 | self.augment = augment 49 | self.positive = positive 50 | 51 | self.f = f_net 52 | self.g = g_net 53 | 54 | def forward(self, x, y): 55 | fx, gy = self.f(x).view(-1, self.output_size), self.g(y) 56 | 57 | if self.positive: 58 | fx = F.relu(fx) 59 | gy = F.relu(gy) 60 | 61 | if self.augment: 62 | return torch.einsum('bi,bi->b', [fx, gy]) + torch.einsum('bi,bi->b', [fx.pow(2), gy.pow(2)]) 63 | 64 | else: 65 | return torch.einsum('bi,bi->b', [fx, gy]) 66 | 67 | 68 | class ConditionalEBM(UnnormalizedConditialEBM): 69 | def __init__(self, input_size, hidden_size, n_hidden, output_size, condition_size, activation='lrelu'): 70 | super().__init__(input_size, hidden_size, n_hidden, output_size, condition_size, activation) 71 | 72 | self.log_norm = nn.Parameter(torch.randn(1) - 5, requires_grad=True) 73 | 74 | def forward(self, x, y, augment=True, positive=False): 75 | return super().forward(x, y, augment, positive) + self.log_norm 76 | 77 | 78 | class ModularConditionalEBM(ModularUnnormalizedConditionalEBM): 79 | def __init__(self, f_net, g_net): 80 | super().__init__(f_net, g_net) 81 | 82 | self.log_norm = nn.Parameter(torch.randn(1) - 5, requires_grad=True) 83 | 84 | def forward(self, x, y, augment=True, positive=False): 85 | return super().forward(x, y, augment, positive) + self.log_norm 86 | 87 | 88 | class UnnormalizedEBM(nn.Module): 89 | def __init__(self, input_size, hidden_size, n_hidden, output_size, activation='lrelu'): 90 | super().__init__() 91 | 92 | self.input_size = 
input_size 93 | self.output_size = output_size 94 | self.hidden_size = hidden_size 95 | self.n_hidden = n_hidden 96 | self.activation = activation 97 | 98 | self.f = CleanMLP(input_size, hidden_size, n_hidden, output_size, activation=activation) 99 | self.g = torch.ones(output_size) 100 | 101 | def forward(self, x, y=None): 102 | fx = self.f(x).view(-1, self.output_size) 103 | return torch.einsum('bi,i->b', [fx, self.g]) 104 | 105 | 106 | class ModularUnnormalizedEBM(nn.Module): 107 | def __init__(self, f_net): 108 | super().__init__() 109 | 110 | self.input_size = f_net.input_size 111 | self.output_size = f_net.output_size 112 | 113 | self.f = f_net 114 | self.g = torch.ones(self.output_size) 115 | 116 | def forward(self, x, y=None): 117 | fx = self.f(x).view(-1, self.output_size) 118 | return torch.einsum('bi,i->b', [fx, self.g]) 119 | 120 | 121 | class EBM(UnnormalizedEBM): 122 | def __init__(self, input_size, hidden_size, n_hidden, output_size, activation='lrelu'): 123 | super().__init__(input_size, hidden_size, n_hidden, output_size, activation) 124 | 125 | self.log_norm = nn.Parameter(torch.randn(1) - 5, requires_grad=True) 126 | 127 | def forward(self, x, y=None): 128 | return super().forward(x, y) + self.log_norm 129 | 130 | 131 | class ModularEBM(ModularUnnormalizedEBM): 132 | def __init__(self, f_net): 133 | super().__init__(f_net) 134 | 135 | self.log_norm = nn.Parameter(torch.randn(1) - 5, requires_grad=True) 136 | 137 | def forward(self, x, y=None): 138 | return super().forward(x, y) + self.log_norm 139 | -------------------------------------------------------------------------------- /baseline_models/icebeem/models/tcl/tcl_wrapper_gpu.py: -------------------------------------------------------------------------------- 1 | ### Wrapper function for TCL 2 | # 3 | # this code is adapted from: https://github.com/hirosm/TCL 4 | # 5 | # 6 | import os 7 | 8 | import numpy as np 9 | import tensorflow as tf 10 | from sklearn.decomposition import FastICA 11 | 12 
| from .tcl_core import inference 13 | from .tcl_core import train_gpu as train 14 | from .tcl_eval import get_tensor, calc_accuracy 15 | from .tcl_preprocessing import pca 16 | 17 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 18 | 19 | 20 | def TCL_wrapper(sensor, label, list_hidden_nodes, random_seed=0, max_steps=int(7e4), max_steps_init=int(7e4), 21 | ckpt_dir='./', test=False): 22 | # Training ---------------------------------------------------- 23 | initial_learning_rate = 0.01 # initial learning rate 24 | momentum = 0.9 # momentum parameter of SGD 25 | # max_steps = int(7e4) # number of iterations (mini-batches) 26 | decay_steps = int(5e4) # decay steps (tf.train.exponential_decay) 27 | decay_factor = 0.1 # decay factor (tf.train.exponential_decay) 28 | batch_size = 512 # mini-batch size 29 | moving_average_decay = 0.9999 # moving average decay of variables to be saved 30 | checkpoint_steps = 1e5 # interval to save checkpoint 31 | num_comp = sensor.shape[0] 32 | 33 | # for MLR initialization 34 | decay_steps_init = int(5e4) # decay steps for initializing only MLR 35 | 36 | # Other ------------------------------------------------------- 37 | train_dir = ckpt_dir # save directory 38 | 39 | num_segment = len(np.unique(label)) 40 | 41 | # Preprocessing ----------------------------------------------- 42 | sensor, pca_parm = pca(sensor, num_comp=num_comp) 43 | 44 | if not test: 45 | # Train model (only MLR) -------------------------------------- 46 | train(sensor, 47 | label, 48 | num_class=len(np.unique(label)), # num_segment, 49 | list_hidden_nodes=list_hidden_nodes, 50 | initial_learning_rate=initial_learning_rate, 51 | momentum=momentum, 52 | max_steps=max_steps_init, # For init 53 | decay_steps=decay_steps_init, # For init 54 | decay_factor=decay_factor, 55 | batch_size=batch_size, 56 | train_dir=train_dir, 57 | checkpoint_steps=checkpoint_steps, 58 | moving_average_decay=moving_average_decay, 59 | MLP_trainable=False, # For init 60 | save_file='model_init.ckpt', 
# For init 61 | random_seed=random_seed) 62 | 63 | init_model_path = os.path.join(train_dir, 'model_init.ckpt') 64 | 65 | # Train model ------------------------------------------------- 66 | train(sensor, 67 | label, 68 | num_class=len(np.unique(label)), # num_segment, 69 | list_hidden_nodes=list_hidden_nodes, 70 | initial_learning_rate=initial_learning_rate, 71 | momentum=momentum, 72 | max_steps=max_steps, 73 | decay_steps=decay_steps, 74 | decay_factor=decay_factor, 75 | batch_size=batch_size, 76 | train_dir=train_dir, 77 | checkpoint_steps=checkpoint_steps, 78 | moving_average_decay=moving_average_decay, 79 | load_file=init_model_path, 80 | random_seed=random_seed) 81 | 82 | # now that we have trained everything, we can evaluate results: 83 | eval_dir = ckpt_dir 84 | ckpt = tf.train.get_checkpoint_state(eval_dir) 85 | 86 | with tf.Graph().as_default(): 87 | data_holder = tf.placeholder(tf.float32, shape=[None, sensor.shape[0]], name='data') 88 | 89 | # Build a Graph that computes the logits predictions from the 90 | # inference model. 91 | logits, feats = inference(data_holder, list_hidden_nodes, num_class=num_segment) 92 | 93 | # Calculate predictions. 94 | top_value, preds = tf.nn.top_k(logits, k=1, name='preds') 95 | 96 | # Restore the moving averaged version of the learned variables for eval. 
97 | variable_averages = tf.train.ExponentialMovingAverage(moving_average_decay) 98 | variables_to_restore = variable_averages.variables_to_restore() 99 | saver = tf.train.Saver(variables_to_restore) 100 | 101 | with tf.Session() as sess: 102 | saver.restore(sess, ckpt.model_checkpoint_path) 103 | 104 | tensor_val = get_tensor(sensor, [preds, feats], sess, data_holder, batch=256) 105 | pred_val = tensor_val[0].reshape(-1) 106 | feat_val = tensor_val[1] 107 | 108 | # Calculate accuracy ------------------------------------------ 109 | accuracy, confmat = calc_accuracy(pred_val, label) 110 | 111 | # Apply fastICA ----------------------------------------------- 112 | ica = FastICA(random_state=random_seed) 113 | feat_val_ica = ica.fit_transform(feat_val) 114 | 115 | feat_val_ica = feat_val_ica.T # Estimated feature 116 | feat_val = feat_val.T 117 | 118 | return feat_val, feat_val_ica, accuracy 119 | -------------------------------------------------------------------------------- /model/gumbel_masks.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | 5 | class GumbelSigmoid(torch.nn.Module): 6 | def __init__(self, shape, freeze=False, drawhard=True, tau=1, one_gumbel_sample=False): 7 | super(GumbelSigmoid, self).__init__() 8 | self.shape = shape 9 | self.freeze = freeze 10 | self.drawhard = drawhard 11 | self.log_alpha = torch.nn.Parameter(torch.zeros(self.shape)) 12 | self.tau = tau 13 | self.one_sample_per_batch = one_gumbel_sample 14 | # useful to make sure these parameters will be pushed to the GPU 15 | self.uniform = torch.distributions.uniform.Uniform(0, 1) 16 | self.register_buffer("fixed_mask", torch.ones(shape)) 17 | self.reset_parameters() 18 | 19 | def forward(self, bs): 20 | if self.freeze: 21 | y = self.fixed_mask.unsqueeze(0).expand((bs,) + self.shape) 22 | return y 23 | else: 24 | shape = tuple([bs] + list(self.shape)) 25 | logistic_noise = 
self.sample_logistic(shape).type(self.log_alpha.type()).to(self.log_alpha.device) 26 | y_soft = torch.sigmoid((self.log_alpha + logistic_noise) / self.tau) 27 | 28 | if self.drawhard: 29 | y_hard = (y_soft > 0.5).type(y_soft.type()) 30 | 31 | # This weird line does two things: 32 | # 1) at forward, we get a hard sample. 33 | # 2) at backward, we differentiate the gumbel sigmoid 34 | y = y_hard.detach() - y_soft.detach() + y_soft 35 | 36 | else: 37 | y = y_soft 38 | 39 | return y 40 | 41 | def get_proba(self): 42 | """Returns probability of getting one""" 43 | if self.freeze: 44 | return self.fixed_mask 45 | else: 46 | return torch.sigmoid(self.log_alpha) 47 | 48 | def reset_parameters(self): 49 | torch.nn.init.constant_(self.log_alpha, 5) # 5) # will yield a probability ~0.99. Inspired by DCDI 50 | 51 | def sample_logistic(self, shape): 52 | if self.one_sample_per_batch: 53 | bs = shape[0] 54 | u = self.uniform.sample([1] + list(self.shape)).expand((bs,) + self.shape) 55 | return torch.log(u) - torch.log(1 - u) 56 | else: 57 | u = self.uniform.sample(shape) 58 | return torch.log(u) - torch.log(1 - u) 59 | 60 | def threshold(self): 61 | proba = self.get_proba() 62 | self.fixed_mask.copy_((proba > 0.5).type(proba.type())) 63 | self.freeze = True 64 | 65 | 66 | class LouizosGumbelSigmoid(torch.nn.Module): 67 | """My implementation of https://openreview.net/pdf?id=H1Y8hhg0b""" 68 | def __init__(self, shape, freeze=False, tau=1, gamma=-0.1, zeta=1.1): 69 | super(LouizosGumbelSigmoid, self).__init__() 70 | self.shape = shape 71 | self.freeze=freeze 72 | self.log_alpha = torch.nn.Parameter(torch.zeros(self.shape)) 73 | self.tau = tau 74 | assert gamma < 0 and zeta > 1 75 | self.gamma = gamma 76 | self.zeta = zeta 77 | # useful to make sure these parameters will be pushed to the GPU 78 | self.uniform = torch.distributions.uniform.Uniform(0, 1) 79 | self.register_buffer("fixed_mask", torch.ones(shape)) 80 | self.reset_parameters() 81 | 82 | def forward(self, bs): 83 | if 
self.freeze: 84 | y = self.fixed_mask.unsqueeze(0).expand((bs,) + self.shape) 85 | return y 86 | else: 87 | shape = tuple([bs] + list(self.shape)) 88 | logistic_noise = self.sample_logistic(shape).type(self.log_alpha.type()).to(self.log_alpha.device) 89 | y_soft = torch.sigmoid((self.log_alpha + logistic_noise) / self.tau) 90 | y_soft = y_soft * (self.zeta - self.gamma) + self.gamma 91 | one = torch.ones((1,)).to(self.log_alpha.device) 92 | zero = torch.zeros((1,)).to(self.log_alpha.device) 93 | y_soft = torch.minimum(one , torch.maximum(zero, y_soft)) 94 | 95 | return y_soft 96 | 97 | def get_proba(self): 98 | """Returns probability of mask being > 0""" 99 | if self.freeze: 100 | return self.fixed_mask 101 | else: 102 | return torch.sigmoid(self.log_alpha - self.tau * (math.log(-self.gamma) - math.log(self.zeta))) 103 | 104 | def reset_parameters(self): 105 | #torch.nn.init.constant_(self.log_alpha, 5) # 5) # will yield a probability ~0.99. Inspired by DCDI 106 | torch.nn.init.constant_(self.log_alpha, self.tau * (math.log(1 - self.gamma) - math.log(self.zeta - 1))) # at init, half the samples will be exactly one. 
107 | # general formula is p(M = 1) = sigmoid(log_alpha - beta (log(1-gamma) - log(zeta - 1)) 108 | print(f"initialized so that P(mask != 0) = {self.get_proba().view(-1)[0]}") 109 | 110 | def sample_logistic(self, shape): 111 | u = self.uniform.sample(shape) 112 | return torch.log(u) - torch.log(1 - u) 113 | 114 | def threshold(self): 115 | proba = self.get_proba() 116 | self.fixed_mask.copy_((proba > 0.5).type(proba.type())) 117 | self.freeze = True 118 | -------------------------------------------------------------------------------- /baseline_models/slowvae_pcl/scripts/model.py: -------------------------------------------------------------------------------- 1 | """model.py""" 2 | import sys 3 | import os 4 | import pathlib 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.init as init 9 | from torch.autograd import Variable 10 | import json 11 | import numpy as np 12 | 13 | sys.path.insert(0, str(pathlib.Path(__file__).parent.parent.parent.parent)) 14 | from model.nn import MLP as MLP_ilcm 15 | 16 | 17 | def reparametrize(mu, logvar): 18 | std = logvar.div(2).exp() + 1e-6 19 | eps = Variable(std.data.new(std.size()).normal_()) 20 | return mu + std*eps 21 | 22 | def compute_kl(z_1, z_2, logvar_1, logvar_2): 23 | var_1 = logvar_1.exp() + 1e-6 24 | var_2 = logvar_2.exp() + 1e-6 25 | return var_1/var_2 + ((z_2-z_1)**2)/var_2 - 1 + logvar_2 - logvar_1 26 | 27 | 28 | class View(nn.Module): 29 | def __init__(self, size): 30 | super(View, self).__init__() 31 | self.size = size 32 | 33 | def forward(self, tensor): 34 | return tensor.view(self.size) 35 | 36 | 37 | class BetaVAE_H(nn.Module): 38 | """Model proposed in original beta-VAE paper(Higgins et al, ICLR, 2017).""" 39 | 40 | def __init__(self, z_dim=10, nc=3, pcl=False, architecture="standard_conv", image_shape=None): 41 | super(BetaVAE_H, self).__init__() 42 | self.pcl = pcl 43 | self.z_dim = z_dim 44 | self.nc = nc 45 | self.architecture = architecture 46 | self.image_shape = image_shape 47 | if 
self.architecture == "standard_conv": 48 | assert self.image_shape == (nc, 64, 64) 49 | self.encoder = nn.Sequential( 50 | nn.Conv2d(nc, 32, 4, 2, 1), # B, 32, 32, 32 51 | nn.ReLU(True), 52 | nn.Conv2d(32, 32, 4, 2, 1), # B, 32, 16, 16 53 | nn.ReLU(True), 54 | nn.Conv2d(32, 64, 4, 2, 1), # B, 64, 8, 8 55 | nn.ReLU(True), 56 | nn.Conv2d(64, 64, 4, 2, 1), # B, 64, 4, 4 57 | nn.ReLU(True), 58 | nn.Conv2d(64, 256, 4, 1), # B, 256, 1, 1 59 | nn.ReLU(True), 60 | View((-1, 256*1*1)), # B, 256 61 | nn.Linear(256, z_dim if pcl else z_dim*2), # B, z_dim*2 62 | ) 63 | self.decoder = nn.Sequential( 64 | nn.Linear(z_dim, 256), # B, 256 65 | View((-1, 256, 1, 1)), # B, 256, 1, 1 66 | nn.ReLU(True), 67 | nn.ConvTranspose2d(256, 64, 4), # B, 64, 4, 4 68 | nn.ReLU(True), 69 | nn.ConvTranspose2d(64, 64, 4, 2, 1), # B, 64, 8, 8 70 | nn.ReLU(True), 71 | nn.ConvTranspose2d(64, 32, 4, 2, 1), # B, 32, 16, 16 72 | nn.ReLU(True), 73 | nn.ConvTranspose2d(32, 32, 4, 2, 1), # B, 32, 32, 32 74 | nn.ReLU(True), 75 | nn.ConvTranspose2d(32, nc, 4, 2, 1), # B, nc, 64, 64 76 | ) 77 | 78 | self.weight_init() 79 | elif self.architecture == "ilcm_tabular": 80 | assert len(self.image_shape) == 1 81 | self.encoder = MLP_ilcm(image_shape[0], z_dim if pcl else 2 * z_dim, 512, 6, spectral_norm=False, batch_norm=False) 82 | self.decoder = MLP_ilcm(z_dim, image_shape[0], 512, 6, spectral_norm=False, batch_norm=False) 83 | self.x_logsigma = torch.nn.Parameter(-5 * torch.ones((1,))) 84 | 85 | def weight_init(self): 86 | for block in self._modules: 87 | for m in self._modules[block]: 88 | kaiming_init(m) 89 | 90 | def forward(self, x, return_z=False): 91 | distributions = self._encode(x) 92 | if self.pcl: 93 | return None, distributions, None 94 | else: 95 | mu = distributions[:, :self.z_dim] 96 | logvar = distributions[:, self.z_dim:] 97 | z = reparametrize(mu, logvar) 98 | x_recon = self._decode(z) 99 | 100 | if len(self.image_shape) == 1: 101 | x_recon = (x_recon, 
torch.nn.functional.softplus(self.x_logsigma) + 1e-6) 102 | 103 | if return_z: 104 | return x_recon, mu, logvar, z 105 | else: 106 | return x_recon, mu, logvar 107 | 108 | def _encode(self, x): 109 | return self.encoder(x) 110 | 111 | def _decode(self, z): 112 | return self.decoder(z) 113 | 114 | def kaiming_init(m): 115 | if isinstance(m, (nn.Linear, nn.Conv2d)): 116 | init.kaiming_normal(m.weight) 117 | if m.bias is not None: 118 | m.bias.data.fill_(0) 119 | elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)): 120 | m.weight.data.fill_(1) 121 | if m.bias is not None: 122 | m.bias.data.fill_(0) 123 | 124 | 125 | def normal_init(m, mean, std): 126 | if isinstance(m, (nn.Linear, nn.Conv2d)): 127 | m.weight.data.normal_(mean, std) 128 | if m.bias.data is not None: 129 | m.bias.data.zero_() 130 | elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)): 131 | m.weight.data.fill_(1) 132 | if m.bias.data is not None: 133 | m.bias.data.zero_() -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Disentanglement via Mechanism Sparsity 2 | 3 | This repository contains the code used to run the experiments in the papers:
4 | [1] [Disentanglement via Mechanism Sparsity Regularization: A New Principle for Nonlinear ICA](https://arxiv.org/abs/2107.10098) (CLeaR2022)
5 | [2] [Nonparametric Partial Disentanglement via Mechanism Sparsity: Sparse Actions, Interventions and Sparse Temporal Dependencies](https://arxiv.org/abs/2401.04890) (Preprint)
6 | By Sébastien Lachapelle, Pau Rodríguez López, Yash Sharma, Katie Everett, Rémi Le Priol, Alexandre Lacoste, Simon Lacoste-Julien 7 | 8 | ### Environment: 9 | 10 | Tested on python 3.7. 11 | 12 | See `requirements.txt`. 13 | 14 | ### Action-sparsity experiment 15 | ``` 16 | OUTPUT_DIR= 17 | DATAROOT= 18 | DATASET= 19 | python disentanglement_via_mechanism_sparsity/train.py --dataroot $DATAROOT --output_dir $OUTPUT_DIR --mode vae --dataset toy-nn/action_sparsity_non_trivial --freeze_g --freeze_gc --z_dim 10 --gt_z_dim 10 --gt_x_dim 20 --n_lag 0 --time_limit 3 20 | ``` 21 | 22 | ### Time-sparsity experiment 23 | ``` 24 | OUTPUT_DIR= 25 | DATAROOT= 26 | python disentanglement_via_mechanism_sparsity/train.py --dataroot $DATAROOT --output_dir $OUTPUT_DIR --mode vae --dataset toy-nn/temporal_sparsity_non_trivial --freeze_g --freeze_gc --z_dim 10 --gt_z_dim 10 --gt_x_dim 20 --n_lag 1 --full_seq --time_limit 3 27 | ``` 28 | 29 | ### Adding penalty regularization 30 | In the minimal commands provided above, all regularizations are deactivated (via the `--freeze_g` and `--freeze_gc` flags). 31 | To activate the penalty regularization for say the temporal mask G^z, replace `--freeze_g` by `--g_reg_coeff COEFF_VALUE`. 32 | Same syntax works also for the action mask G^a (named `gc` in the code). Here's the correspondence between the mask names in the code (left) and in the paper (right): 33 | 34 | `g` = Mask G^z (Time sparsity) 35 | 36 | `gc` = Mask G^a (Action sparsity) 37 | 38 | ### Adding constraint regularization 39 | To activate the constraint regularization for say the temporal mask G^z, replace `--freeze_g` by `--g_constraint UPPER_BOUND`. 40 | Same syntax works also for the action mask G^a (named `gc` in the code). 41 | The experiments were all performed with `--constraint_scedule 150000` and `--dual_restarts`. 42 | The option `--set_constraint_to_gt` will automatically set the upper bound of the constraint to the optimal value for the ground-truth graph. 
43 | 44 | ### Synthetic datasets (referencing to sections of [2]) 45 | Here's a list of the synthetic datasets used. The data is generated before training, so no need to download anything. 46 | 47 | #### Section 8.1 (Used in both [1] and [2]) 48 | 49 | - `--dataset toy-nn/action_sparsity_trivial` 50 | - `--dataset toy-nn/action_sparsity_non_trivial` 51 | - `--dataset toy-nn/action_sparsity_non_trivial_no_suff_var` 52 | - `--dataset toy-nn/action_sparsity_non_trivial_k=2` 53 | - `--dataset toy-nn/temporal_sparsity_trivial` 54 | - `--dataset toy-nn/temporal_sparsity_non_trivial` 55 | - `--dataset toy-nn/temporal_sparsity_non_trivial_no_suff_var` 56 | - `--dataset toy-nn/temporal_sparsity_non_trivial_k=2` 57 | 58 | #### Section 8.2 (Used only in [2]) 59 | - `--dataset toy-nn/action_sparsity_non_trivial --graph_name graph_action_3_easy` 60 | - `--dataset toy-nn/action_sparsity_non_trivial --graph_name graph_action_3_hard` 61 | - `--dataset toy-nn/temporal_sparsity_non_trivial --graph_name graph_temporal_3_easy` 62 | - `--dataset toy-nn/temporal_sparsity_non_trivial --graph_name graph_temporal_3_hard` 63 | - `--dataset toy-nn/action_sparsity_non_trivial --rand_g_density PROBA_OF_EDGE` 64 | - `--dataset toy-nn/temporal_sparsity_non_trivial --rand_g_density PROBA_OF_EDGE` 65 | 66 | 67 | #### Datasets Used only in [1] 68 | - `--dataset toy-nn/action_sparsity_non_trivial_no_graph_crit` 69 | - `--dataset toy-nn/temporal_sparsity_non_trivial_no_graph_crit` 70 | 71 | ### Baselines 72 | #### TCVAE 73 | Code adapted from: https://github.com/rtqichen/beta-tcvae 74 | ``` 75 | OUTPUT_DIR= 76 | DATAROOT= 77 | python disentanglement_via_mechanism_sparsity/baseline_models/beta-tcvae/train.py --dataroot $DATAROOT --output_dir $OUTPUT_DIR --dataset toy-nn/action_sparsity_non_trivial --tcvae --beta 1 --gt_z_dim 10 --gt_x_dim 20 --time_limit 3 78 | ``` 79 | 80 | #### iVAE 81 | Code adapted from: https://github.com/ilkhem/icebeem 82 | ``` 83 | OUTPUT_DIR= 84 | DATAROOT= 85 | python 
disentanglement_via_mechanism_sparsity/baseline_models/icebeem/train.py --dataroot $DATAROOT --output_dir $OUTPUT_DIR --dataset toy-nn/action_sparsity_non_trivial --method ivae --gt_z_dim 10 --gt_x_dim 20 --time_limit 3 86 | ``` 87 | 88 | #### SlowVAE 89 | Code adapted from: https://github.com/bethgelab/slow_disentanglement 90 | ``` 91 | OUTPUT_DIR= 92 | DATAROOT= 93 | python disentanglement_via_mechanism_sparsity/baseline_models/slowvae_pcl/train.py --dataroot $DATAROOT --output_dir $OUTPUT_DIR --dataset toy-nn/temporal_sparsity_non_trivial --gt_z_dim 10 --gt_x_dim 20 --time_limit 3 94 | ``` 95 | 96 | #### PCL 97 | Code adapted from: https://github.com/bethgelab/slow_disentanglement/tree/baselines 98 | ``` 99 | OUTPUT_DIR= 100 | DATAROOT= 101 | python disentanglement_via_mechanism_sparsity/baseline_models/slowvae_pcl/train.py --dataroot $DATAROOT --output_dir $OUTPUT_DIR --dataset toy-nn/temporal_sparsity_non_trivial --pcl --r_func mlp --gt_z_dim 10 --gt_x_dim 20 --time_limit 3 102 | ``` 103 | 104 | 105 | -------------------------------------------------------------------------------- /optimization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cooper 3 | 4 | 5 | def compute_nll(log_likelihood, valid, opt): 6 | valid = valid.type(log_likelihood.type()) 7 | if opt.include_invalid: 8 | nll = - torch.mean(log_likelihood) 9 | else: 10 | nll = - torch.dot(log_likelihood, valid) / torch.sum(valid) 11 | 12 | return nll 13 | 14 | 15 | class CustomCMP(cooper.ConstrainedMinimizationProblem): 16 | def __init__(self, g_reg_coeff=0., gc_reg_coeff=0., g_constraint=0, gc_constraint=0., g_scaling=1., gc_scaling=1., 17 | schedule=False, max_g=0, max_gc=0): 18 | self.is_constrained = (g_constraint > 0. or gc_constraint > 0.) 
19 | self.g_reg_coeff = g_reg_coeff 20 | self.gc_reg_coeff = gc_reg_coeff 21 | self.g_constraint = g_constraint 22 | self.gc_constraint = gc_constraint 23 | 24 | if schedule: 25 | self.current_g_constraint = max_g 26 | self.current_gc_constraint = max_gc 27 | else: 28 | self.current_g_constraint = g_constraint 29 | self.current_gc_constraint = gc_constraint 30 | 31 | self.max_g, self.max_gc = max_g, max_gc 32 | 33 | self.g_scaling = g_scaling 34 | self.gc_scaling = gc_scaling 35 | super().__init__(is_constrained=self.is_constrained) 36 | 37 | def closure(self, model, obs, cont_c, disc_c, valid, other, opt): 38 | misc = {} 39 | if opt.mode == "vae": 40 | elbo, reconstruction_loss, kl, _ = model.elbo(obs, cont_c, disc_c) 41 | loss = compute_nll(elbo, valid, opt) 42 | elif opt.mode == "supervised_vae": 43 | obs = obs[:, -1] 44 | other = other[:, -1] 45 | z_hat = model.latent_model.mean(model.latent_model.transform_q_params(model.encode(obs))) 46 | loss = torch.mean((z_hat.view(z_hat.shape[0], -1) - other) ** 2) 47 | reconstruction_loss, kl = 0, 0 48 | elif opt.mode == "latent_transition_only": 49 | ll = model.log_likelihood(other, cont_c, disc_c) 50 | loss = compute_nll(ll, valid, opt) 51 | reconstruction_loss, kl = 0, 0 52 | else: 53 | raise NotImplementedError(f"--mode {opt.mode} is not implemented.") 54 | 55 | misc["nll"] = loss.item() 56 | misc["reconstruction_loss"] = reconstruction_loss 57 | misc["kl"] = kl 58 | 59 | # regularization/constraint 60 | g_reg = model.latent_model.g_regularizer() 61 | gc_reg = model.latent_model.gc_regularizer() 62 | 63 | misc["g_reg"], misc["gc_reg"] = g_reg.item(), gc_reg.item() 64 | 65 | if not self.is_constrained: 66 | if self.g_reg_coeff > 0: 67 | loss += opt.g_reg_coeff * g_reg * self.g_scaling 68 | if self.gc_reg_coeff > 0: 69 | loss += opt.gc_reg_coeff * gc_reg * self.gc_scaling 70 | 71 | return cooper.CMPState(loss=loss, ineq_defect=None, eq_defect=None, misc=misc) 72 | else: 73 | defects = [] 74 | if self.g_constraint > 0: 
75 | defects.append(g_reg - self.current_g_constraint) 76 | if self.gc_constraint > 0: 77 | defects.append(gc_reg - self.current_gc_constraint) 78 | 79 | defects = torch.stack(defects) 80 | 81 | return cooper.CMPState(loss=loss, ineq_defect=defects, eq_defect=None, misc=misc) 82 | 83 | def update_constraint(self, iter, total_iter, no_update_period=0): 84 | if iter <= no_update_period: 85 | if self.g_constraint > 0: 86 | self.current_g_constraint = self.max_g 87 | if self.gc_constraint > 0: 88 | self.current_gc_constraint = self.max_gc 89 | elif no_update_period < iter <= no_update_period + total_iter: 90 | if self.g_constraint > 0: 91 | self.current_g_constraint = (self.max_g - self.g_constraint) * (1 - iter / total_iter) + self.g_constraint 92 | if self.gc_constraint > 0: 93 | self.current_gc_constraint = (self.max_gc - self.gc_constraint) * (1 - iter / total_iter) + self.gc_constraint 94 | else: 95 | if self.g_constraint > 0: 96 | self.current_g_constraint = self.g_constraint 97 | if self.gc_constraint > 0: 98 | self.current_gc_constraint = self.gc_constraint 99 | 100 | return self.current_g_constraint, self.current_gc_constraint 101 | 102 | def update_constraint_adaptive(self, iter, decrease_rate=0.0005, no_update_period=0): 103 | if iter <= no_update_period: 104 | if self.g_constraint > 0: 105 | self.current_g_constraint = self.max_g 106 | if self.gc_constraint > 0: 107 | self.current_gc_constraint = self.max_gc 108 | else: 109 | # decrease constraint only when defect is smaller than 0.1, otherwise do not change constraint. 
110 | if self.g_constraint > 0 and self.state.ineq_defect.sum() <= 0.1: 111 | self.current_g_constraint = max(self.current_g_constraint - decrease_rate, self.g_constraint) 112 | if self.gc_constraint > 0 and self.state.ineq_defect.sum() <= 0.1: 113 | self.current_gc_constraint = max(self.current_gc_constraint - decrease_rate, self.gc_constraint) 114 | 115 | return self.current_g_constraint, self.current_gc_constraint 116 | 117 | #if self.g_constraint > 0 and self.state.ineq_defect.sum() <= 0: 118 | # self.current_g_constraint = max(self.current_g_constraint - 1, self.g_constraint) 119 | #if self.gc_constraint > 0 and self.state.ineq_defect.sum() <= 0: 120 | # self.current_gc_constraint = max(self.current_gc_constraint - 1, self.gc_constraint) 121 | 122 | -------------------------------------------------------------------------------- /baseline_models/slowvae_pcl/data_generation/gen_kitti_masks.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import numpy as np 3 | import os 4 | import pickle 5 | import scipy.ndimage as ndi 6 | import torch 7 | import PIL.Image as Image 8 | from scipy.ndimage.measurements import center_of_mass 9 | 10 | # kitti mots 11 | def get_data(path='./data/kitti/instances/', 12 | folder=range(0, 21), 13 | n_hor_windows=6, 14 | img_size=64): 15 | all_imgs = [] 16 | # get each sequence 17 | for j in folder: # 21 sequences 18 | img_folder = f'{j:04d}' 19 | files = sorted(glob.glob(path + f'{img_folder}/*.png')) 20 | imgs = np.zeros((len(files), n_hor_windows, img_size, img_size), dtype=np.uint32) 21 | # load sequence_i 22 | for i, f in enumerate(files): 23 | img = np.array(Image.open(f)) 24 | if i == 0: 25 | print('folder', img_folder, 'first image shape ', img.shape) 26 | shape = img.shape 27 | x_size = int(shape[1] // (shape[0] / img_size)) 28 | img_t = torch.tensor(img.astype(np.float32)) 29 | img = torch.nn.functional.interpolate(img_t[None, None], size=(img_size, 
x_size)).numpy().astype(np.uint32)[0, 0] 30 | # tile into windows 31 | hor_stride = (x_size - img_size) // (n_hor_windows - 1) 32 | for k in range(n_hor_windows): 33 | offset = k*hor_stride 34 | imgs[i, k, :, :] = img[:, offset:offset+img_size] 35 | all_imgs.append(imgs) 36 | return all_imgs 37 | 38 | def get_individual_sequences(all_imgs, n_hor_windows=6, mask_threshold=30): 39 | sequences = [] 40 | for f_id, imgs_windows in enumerate(all_imgs): # all videos 41 | print('sequence orig', f_id) 42 | for window_i in range(n_hor_windows): # one window 43 | imgs = imgs_windows[:, window_i] 44 | ids = np.where(np.bincount(imgs.ravel()) != 0)[0][1:-1] # first is bg, last is non-category 45 | for id_i in ids: # indiv 46 | if id_i // 1000 == 1: 47 | continue 48 | imgs_id_i = np.zeros(imgs.shape, dtype=np.bool) 49 | imgs_id_i[imgs == id_i] = 1 50 | t_inds = np.where(imgs_id_i != 0)[0] 51 | t_inds = np.arange(t_inds[0], t_inds[-1]+1) # make dense 52 | sequence_id_i = [] 53 | for t_ind in t_inds: # augean stables 54 | frame = imgs_id_i[t_ind] 55 | if np.sum(frame) < mask_threshold: # mask too small 56 | if len(sequence_id_i) > 2: # min sequ len 2 57 | sequences.append(np.stack(sequence_id_i)) # add to sequences 58 | sequence_id_i = [] # hole in sequence, start new one 59 | continue 60 | else: 61 | sequence_id_i.append(frame) # add to sequence 62 | if len(sequence_id_i) > 1: 63 | sequences.append(np.stack(sequence_id_i)) # add to sequences 64 | return sequences 65 | 66 | # get center of mass, and area 67 | def get_latents(sequence): 68 | all_latents = [] 69 | for seq in sequence: 70 | latents = np.zeros((len(seq), 3), dtype=np.float32) 71 | for i, img in enumerate(seq): 72 | com = center_of_mass(img) 73 | latents[i] = np.array([com[0], com[1], np.sum(img)]) # y pos, x pos, area 74 | all_latents.append(latents) 75 | return all_latents 76 | 77 | def main(args): 78 | # raw data from https://www.vision.rwth-aachen.de/page/mots 79 | all_imgs_c = 
get_data(path='./data/kitti/instances/', folder=range(0, 21), 80 | n_hor_windows=args.n_hor_windows, img_size=args.img_size) # mostly cars 81 | all_imgs_p = get_data(path='./data/kitti/mots/instances/', folder=[2, 5, 9, 11], 82 | n_hor_windows=args.n_hor_windows, img_size=args.img_size) # pedestrians 83 | print('number folders mostly cars', len(all_imgs_c)) 84 | print('number folders pedestrians', len(all_imgs_p)) 85 | sequences_p = get_individual_sequences(all_imgs_c, n_hor_windows=args.n_hor_windows) 86 | print() 87 | print('pedestrians') 88 | sequences_c = get_individual_sequences(all_imgs_p, n_hor_windows=args.n_hor_windows) 89 | all_sequences = sequences_p + sequences_c 90 | all_latents = get_latents(all_sequences) 91 | # save data 92 | with open(os.path.join('./data/kitti_peds_v2.pickle'), 'wb') as f: # v0, v1 are only internal and were not released 93 | pickle.dump({'pedestrians':all_sequences, 'pedestrians_latents': all_latents}, f) 94 | # this is to do the data analysis 95 | dd = {} 96 | dd['id'] = [] 97 | dd['category_id'] = [] 98 | dd['category'] = [] 99 | dd['x'] = [] 100 | dd['x_diff'] = [] 101 | dd['y'] = [] 102 | dd['y_diff'] = [] 103 | dd['area'] = [] 104 | dd['area_diff'] = [] 105 | dd['masks'] = [] 106 | rotate = args.rotate 107 | for id_i, (seq, lat) in enumerate(zip(all_sequences, all_latents)): 108 | for (start_img, next_img), (start_latent, next_latent) in zip(zip(seq[:-1], seq[1:]), zip(lat[:-1], lat[1:])): 109 | if rotate: 110 | start_img = ndi.rotate(start_img, 45) 111 | next_img = ndi.rotate(next_img, 45) 112 | start_latent, next_latent 113 | 114 | start_com = center_of_mass(start_img) 115 | next_com = center_of_mass(next_img) 116 | start_latent = [start_com[0], start_com[1], np.sum(start_img)] # y pos, x pos, area 117 | next_latent = [next_com[0], next_com[1], np.sum(next_img)] # y pos, x pos, area 118 | 119 | dd['id'].append(id_i) 120 | dd['category_id'].append(1) 121 | dd['category'].append('pedestrian') 122 | 
dd['x'].append([start_latent[1], next_latent[1]]) 123 | dd['x_diff'].append(next_latent[1] - start_latent[1]) 124 | 125 | dd['y'].append([start_latent[0], next_latent[0]]) 126 | dd['y_diff'].append(next_latent[0] - start_latent[0]) 127 | 128 | dd['area'].append([np.sum(start_img), np.sum(next_img)]) 129 | dd['area_diff'].append(np.sum(next_img) - np.sum(start_img)) 130 | 131 | dd['masks'].append([start_img.astype(np.uint8), next_img.astype(np.uint8)]) 132 | prefix = '' 133 | if rotate: 134 | prefix += '_rotate' 135 | with open(f'./data/kitti_dict_p_v2{prefix}.pkl', 'wb') as f: 136 | pickle.dump(dd, f) 137 | 138 | 139 | if __name__ == "__main__": 140 | import argparse 141 | parser = argparse.ArgumentParser() 142 | parser.add_argument('--img-size', type=int, default=64) 143 | parser.add_argument('--n-hor-windows', type=int, default=6) 144 | parser.add_argument('--rotate', action='store_true') 145 | args = parser.parse_args() 146 | main(args) 147 | 148 | 149 | -------------------------------------------------------------------------------- /model/ilcm_vae.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | from .nn import MLP 7 | 8 | 9 | class ILCM_VAE(torch.nn.Module): 10 | def __init__(self, latent_model, image_shape, cont_c_dim, disc_c_dim, disc_n_values, opt): 11 | super().__init__() 12 | self.latent_model = latent_model 13 | self.z_dim = opt.z_max_dim 14 | self.n_lag = opt.n_lag 15 | self.cont_c_dim = cont_c_dim 16 | self.disc_c_dim = disc_c_dim 17 | self.disc_c_n_values = disc_n_values 18 | self.beta = opt.beta 19 | self.learn_var = opt.learn_decoder_var 20 | self.opt = opt 21 | 22 | if self.opt.full_seq: 23 | assert self.n_lag == 1, "--full_seq is supported only for --n_lag 1" 24 | 25 | # choosing encoder 26 | if opt.encoder == "tabular": 27 | assert len(image_shape) == 1, f"The encoder {opt.encoder} works only on tabular data." 
28 | self.encoder = MLP(image_shape[0], 2 * self.z_dim, 512, 3 * opt.encoder_depth_multiplier, spectral_norm=False, batch_norm=opt.bn_enc_dec) 29 | else: 30 | raise NotImplementedError(f"The encoder {opt.encoder} is not implemented.") 31 | 32 | # choosing decoder 33 | if opt.decoder == "tabular": 34 | assert len(image_shape) == 1, f"The decoder {opt.decoder} works only on tabular data." 35 | self.decoder = MLP(self.z_dim, image_shape[0], 512, 3 * opt.decoder_depth_multiplier, spectral_norm=False, batch_norm=opt.bn_enc_dec) 36 | else: 37 | raise NotImplementedError(f"The decoder {opt.decoder} is not implemented.") 38 | 39 | # dummy parameter which replaces masked out z_i's in the decoder input 40 | self.eta = torch.nn.Parameter(torch.zeros((self.z_dim,))) 41 | if opt.freeze_dummies: 42 | self.eta.requires_grad = False 43 | 44 | if len(image_shape) == 3: 45 | self.image_shape = image_shape[2:3] + image_shape[0:2] 46 | else: 47 | self.image_shape = image_shape 48 | 49 | # that's a bit messy, keep it like this to keep previous default behavior. 
50 | if self.learn_var: 51 | if self.opt.init_decoder_var is None: 52 | self.decoder_logvar = torch.nn.Parameter(-10 * torch.ones((1,))) 53 | else: 54 | self.decoder_logvar = torch.nn.Parameter(math.log(opt.init_decoder_var) * torch.ones((1,))) 55 | else: 56 | if self.opt.init_decoder_var is None: 57 | self.decoder_logvar = torch.nn.Parameter(torch.zeros((1,))) 58 | else: 59 | self.decoder_logvar = torch.nn.Parameter(math.log(opt.init_decoder_var) * torch.ones((1,))) 60 | self.decoder_logvar.requires_grad = False 61 | 62 | def encode(self, x): 63 | b = x.size(0) 64 | z_params = self.encoder(x) 65 | return z_params 66 | 67 | def decode(self, input, logit=False): 68 | if len(self.image_shape) == 1: 69 | return self.decoder(input) 70 | elif len(self.image_shape) == 3: 71 | if logit: 72 | return self.decoder(input) 73 | else: 74 | return torch.sigmoid(self.decoder(input)) 75 | 76 | def _mask_z_before_decoding(self, z, m): 77 | # z shape: (b, t, z_max_dim) 78 | b, t = z.shape[0:2] 79 | num_blocks = m.shape[1] 80 | m = m.view(b, 1, num_blocks, 1) 81 | z = z.view(b,t, num_blocks, -1) 82 | eta = self.eta.view(1, 1, num_blocks, -1) 83 | z_masked_eta = m * z + (1 - m) * eta 84 | return z_masked_eta.view(b * t, -1) 85 | 86 | def elbo(self, obs, cont_c, disc_c): 87 | b, t = obs.shape[0:2] 88 | 89 | q_params = self.encode(obs.view((b * t,) + self.image_shape)) 90 | z = self.latent_model.reparameterize(q_params) 91 | m = self.latent_model.m(b) # sample a mask 92 | 93 | ## --- Reconstruction --- ## 94 | z_masked_eta = self._mask_z_before_decoding(z.view(b, t, -1), m) 95 | # including the reconstruction term not only for x_t, but also for x_t-1, ..., x_t-k. 
96 | 97 | reconstructions = self.decode(z_masked_eta) 98 | std = torch.exp(0.5 * self.decoder_logvar) + 1e-4 99 | # SL: This choice of reduction is picked to keep the relative importance of each terms in line with the original ELBO 100 | rec_loss = - torch.distributions.normal.Normal(reconstructions.view(b, t, -1), std).log_prob(obs.view(b, t, -1)).mean(dim=(1, 2)) 101 | 102 | ## --- KL divergence --- # 103 | # mask z and c 104 | if self.latent_model.n_lag > 0: 105 | g = self.latent_model.g(b) 106 | z_lag = z.view(b, t, -1)[:, :-1] 107 | z_masked_gamma_lag = self.latent_model._mask_z_lag(z_lag, m, g) # (bs, num_blocks, n_lag, z_dim) 108 | z_tm1 = z_lag[:, -1].view(b, -1, self.latent_model.z_block_size) 109 | else: 110 | z_masked_gamma_lag = None 111 | z_tm1 = None 112 | 113 | if self.latent_model.cont_c_dim > 0: 114 | gc = self.latent_model.gc(b) 115 | masked_cont_c = self.latent_model._mask_cont_c(cont_c, gc) 116 | else: 117 | masked_cont_c = None 118 | 119 | if self.disc_c_dim > 0: 120 | gc_disc = self.latent_model.gc_disc(b) # (batch_size, num_z_blocks, disc_c_dim) TODO: disc_c_dim == 1 for now... 
121 | masked_disc_c_one_hot = self.latent_model._mask_disc_c_one_hot(disc_c, gc_disc) # (bs, num_blocks, cont_c_dim) 122 | else: 123 | masked_disc_c_one_hot = None 124 | 125 | p_params = self.latent_model.network(z_masked_gamma_lag, masked_cont_c, masked_disc_c_one_hot) 126 | q_params = q_params.view(b, t, -1) 127 | kl = self.latent_model.compute_kl(p_params, q_params[:, -1], z_tm1) 128 | 129 | if self.opt.full_seq: 130 | # we divide by two to preserve the relative importance between rec_loss and kl in line with original elbo 131 | init_params = self.latent_model.init_p_params.expand(b, -1, -1) 132 | kl = 0.5 * (kl + self.latent_model.compute_kl(init_params, q_params[:, 0], None, init=True)) 133 | 134 | kl_reduced = torch.sum(kl, 1) / int(np.product(self.image_shape)) 135 | 136 | elbo = -rec_loss - self.beta * kl_reduced 137 | return elbo, rec_loss.mean().item(), kl_reduced.mean().item(), kl 138 | 139 | def log_likelihood(self, gt_z, cont_c, disc_c): 140 | b, t = gt_z.shape[0:2] 141 | 142 | m = self.latent_model.m(b) # sample a mask 143 | 144 | # mask z and c 145 | if self.latent_model.n_lag > 0: 146 | g = self.latent_model.g(b) 147 | z_lag = gt_z[:, :-1] 148 | z_masked_gamma_lag = self.latent_model._mask_z_lag(z_lag, m, g) # (bs, num_blocks, n_lag, z_dim) 149 | z_tm1 = z_lag[:, -1].view(b, -1, self.latent_model.z_block_size) 150 | else: 151 | z_masked_gamma_lag = None 152 | z_tm1 = None 153 | 154 | if self.latent_model.cont_c_dim > 0: 155 | gc = self.latent_model.gc(b) 156 | masked_cont_c = self.latent_model._mask_cont_c(cont_c, gc) 157 | else: 158 | masked_cont_c = None 159 | 160 | p_params = self.latent_model.network(z_masked_gamma_lag, masked_cont_c, disc_c) 161 | ll = self.latent_model.log_likelihood(p_params, gt_z[:, -1], z_tm1) 162 | 163 | return ll 164 | 165 | -------------------------------------------------------------------------------- /scripts/compute_udr_npy.py: -------------------------------------------------------------------------------- 1 | 
import argparse 2 | import glob 3 | import json 4 | import math 5 | import os 6 | import random 7 | import shutil 8 | import sys 9 | import time 10 | import pathlib 11 | 12 | import numpy as np 13 | import pandas as pd 14 | 15 | #sys.path.insert(0, str(pathlib.Path(__file__).parent.parent)) 16 | sys.path.insert(0, os.path.abspath('..')) 17 | from metrics import mean_corr_coef_np 18 | 19 | def load_hparams(folder_path): 20 | with open(os.path.join(folder_path, 'hparams.json'), 'r') as infile: 21 | opt = json.load(infile) 22 | 23 | class Bunch: 24 | def __init__(self, opt): 25 | self.__dict__.update(opt) 26 | return Bunch(opt) 27 | 28 | def compute_udr_mcc_from_npy(inferred_model_reps): 29 | """Computes the UDR score using scikit-learn. 30 | 31 | Args: 32 | ground_truth_data: GroundTruthData to be sampled from. 33 | representation_functions: functions that takes observations as input and 34 | outputs a dim_representation sized representation for each observation. 35 | random_state: numpy random state used for randomness. 36 | batch_size: Number of datapoints to compute in a single batch. Useful for 37 | reducing memory overhead for larger models. 38 | num_data_points: total number of representation datapoints to generate for 39 | computing the correlation matrix. 40 | correlation_matrix: Type of correlation matrix to generate. Can be either 41 | "lasso" or "spearman". 42 | filter_low_kl: If True, filter out elements of the representation vector 43 | which have low computed KL divergence. 44 | include_raw_correlations: Whether or not to include the raw correlation 45 | matrices in the results. 46 | kl_filter_threshold: Threshold which latents with average KL divergence 47 | lower than the threshold will be ignored when computing disentanglement. 
48 | 49 | Returns: 50 | scores_dict: a dictionary of the scores computed for UDR with the following 51 | keys: 52 | raw_correlations: (num_models, num_models, latent_dim, latent_dim) - The 53 | raw computed correlation matrices for all models. The pair of models is 54 | indexed by axis 0 and 1 and the matrix represents the computed 55 | correlation matrix between latents in axis 2 and 3. 56 | pairwise_disentanglement_scores: (num_models, num_models, 1) - The 57 | computed disentanglement scores representing the similarity of 58 | representation between pairs of models. 59 | model_scores: (num_models) - List of aggregated model scores corresponding 60 | to the median of the pairwise disentanglement scores for each model. 61 | """ 62 | 63 | num_models = len(inferred_model_reps) 64 | mcc_all = np.zeros((num_models, num_models)) 65 | 66 | for i in range(num_models): 67 | for j in range(num_models): 68 | if i == j: 69 | continue 70 | 71 | mcc = mean_corr_coef_np(inferred_model_reps[i], 72 | inferred_model_reps[j], 73 | method='pearson', indices=None)[0] 74 | 75 | mcc_all[i, j] = mcc 76 | off_diag = mcc_all[~np.eye(mcc_all.shape[0], dtype=bool)] 77 | return {'median': np.median(off_diag), 'mean': np.mean(off_diag)} 78 | 79 | def create_mode_entry(all_logs_pd): 80 | # for tcvae 81 | all_logs_pd.loc[all_logs_pd["tcvae"] == True, 'mode'] = "tcvae" 82 | 83 | # TODO: create for other methods if necessary. 
84 | 85 | return all_logs_pd 86 | 87 | def main(args=None): 88 | parser = argparse.ArgumentParser() 89 | parser.add_argument("--all_logs_file", type=str, 90 | help="Absolute path to all_logs.npy files") 91 | 92 | #GT_GRAPH_NAMES = ["graph_temporal_3_easy", "graph_temporal_3_hard", "graph_action_3_easy", "graph_action_3_hard"] 93 | GT_GRAPH_NAMES = ["toy-nn/temporal_sparsity_trivial", "toy-nn/temporal_sparsity_non_trivial", "toy-nn/action_sparsity_trivial", "toy-nn/action_sparsity_non_trivial"] 94 | MODES = ['vae'] #, 'pcl', 'slowvae', 'tcvae', 'ivae', 'random_vae', 'supervised_vae'] 95 | HPARAM_NAMES = {"vae": ['gc_constraint', 'g_constraint']} #, 96 | #"random_vae": [], 97 | #"supervised_vae": [], 98 | #"tcvae": ["beta"], 99 | #"ivae": [], 100 | #"pcl": [], 101 | #"slowvae": ['gamma', 'rate_prior']} 102 | 103 | opt = parser.parse_args() 104 | 105 | all_logs = np.load(opt.all_logs_file, allow_pickle=True).tolist() 106 | all_logs_pd = pd.DataFrame(all_logs) 107 | #all_logs_pd = create_mode_entry(all_logs_pd) 108 | 109 | # to be filled with udr values 110 | all_logs_pd_udr = all_logs_pd.copy() 111 | 112 | for gt_graph_name in GT_GRAPH_NAMES: 113 | for mode in MODES: 114 | print("########## gt_graph_name:", gt_graph_name, "mode:", mode, "##########") 115 | #condition_data_mode = (all_logs_pd['gt_graph_name'] == gt_graph_name) & (all_logs_pd['mode'] == mode) 116 | condition_data_mode = (all_logs_pd['dataset'] == gt_graph_name) & (all_logs_pd['mode'] == mode) 117 | logs = all_logs_pd[condition_data_mode] 118 | 119 | if len(logs) == 0: 120 | print("No logs found.") 121 | continue 122 | 123 | if len(HPARAM_NAMES[mode]) != 0: 124 | hparams_values = logs[HPARAM_NAMES[mode]].drop_duplicates() 125 | 126 | for i in range(len(hparams_values)): 127 | # selecting only the runs with this specific hparam value 128 | condition_hp = (hparams_values.iloc[i] == logs[HPARAM_NAMES[mode]]).all(axis=1) 129 | logs_specific_hp = logs[condition_hp] 130 | 131 | # logging number of seeds used 
to compute UDR 132 | all_logs_pd_udr.loc[condition_data_mode & condition_hp, "num_seeds"] = len(logs_specific_hp) 133 | 134 | # cannot compute UDR when there is only one seed 135 | if len(logs_specific_hp) == 1: 136 | print(f"Not computing UDR for {hparams_values.iloc[i].to_dict()} since only one seed.") 137 | all_logs_pd_udr.loc[condition_data_mode & condition_hp, "udr_mean"] = -1 138 | all_logs_pd_udr.loc[condition_data_mode & condition_hp, "udr_median"] = -1 139 | continue 140 | 141 | # load their z_hat_final.npy 142 | z_hat_list = [] 143 | for output_dir in logs_specific_hp["output_dir"]: 144 | z_hat_list.append(np.load(os.path.join(output_dir, "z_hat_final.npy"))) 145 | 146 | print(f"Computing UDR for {hparams_values.iloc[i].to_dict()} on {len(logs_specific_hp)} seeds.") 147 | udr = compute_udr_mcc_from_npy(z_hat_list) 148 | 149 | # Add udr values to table 150 | all_logs_pd_udr.loc[condition_data_mode & condition_hp, "udr_mean"] = udr["mean"] 151 | all_logs_pd_udr.loc[condition_data_mode & condition_hp, "udr_median"] = udr["median"] 152 | 153 | else: 154 | print(f"No hyperparameter search. 
Total of {len(logs)} seeds.") 155 | # number of sucessful seeds 156 | all_logs_pd_udr.loc[condition_data_mode, "num_seeds"] = len(logs) 157 | # if the method has no hparameter, just set UDR score to -1 158 | all_logs_pd_udr.loc[condition_data_mode, "udr_mean"] = -1 159 | all_logs_pd_udr.loc[condition_data_mode, "udr_median"] = -1 160 | 161 | 162 | np.save(opt.all_logs_file.replace(".npy", "_udr.npy"), all_logs_pd_udr.to_dict('records')) 163 | 164 | 165 | if __name__ == "__main__": 166 | main() 167 | -------------------------------------------------------------------------------- /scripts/compute_udr_npy_rand_graphs.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import json 4 | import math 5 | import os 6 | import random 7 | import shutil 8 | import sys 9 | import time 10 | import pathlib 11 | 12 | import numpy as np 13 | import pandas as pd 14 | 15 | #sys.path.insert(0, str(pathlib.Path(__file__).parent.parent)) 16 | sys.path.insert(0, os.path.abspath('..')) 17 | from metrics import mean_corr_coef_np 18 | 19 | def load_hparams(folder_path): 20 | with open(os.path.join(folder_path, 'hparams.json'), 'r') as infile: 21 | opt = json.load(infile) 22 | 23 | class Bunch: 24 | def __init__(self, opt): 25 | self.__dict__.update(opt) 26 | return Bunch(opt) 27 | 28 | def compute_udr_mcc_from_npy(inferred_model_reps): 29 | """Computes the UDR score using scikit-learn. 30 | 31 | Args: 32 | ground_truth_data: GroundTruthData to be sampled from. 33 | representation_functions: functions that takes observations as input and 34 | outputs a dim_representation sized representation for each observation. 35 | random_state: numpy random state used for randomness. 36 | batch_size: Number of datapoints to compute in a single batch. Useful for 37 | reducing memory overhead for larger models. 38 | num_data_points: total number of representation datapoints to generate for 39 | computing the correlation matrix. 
40 | correlation_matrix: Type of correlation matrix to generate. Can be either 41 | "lasso" or "spearman". 42 | filter_low_kl: If True, filter out elements of the representation vector 43 | which have low computed KL divergence. 44 | include_raw_correlations: Whether or not to include the raw correlation 45 | matrices in the results. 46 | kl_filter_threshold: Threshold which latents with average KL divergence 47 | lower than the threshold will be ignored when computing disentanglement. 48 | 49 | Returns: 50 | scores_dict: a dictionary of the scores computed for UDR with the following 51 | keys: 52 | raw_correlations: (num_models, num_models, latent_dim, latent_dim) - The 53 | raw computed correlation matrices for all models. The pair of models is 54 | indexed by axis 0 and 1 and the matrix represents the computed 55 | correlation matrix between latents in axis 2 and 3. 56 | pairwise_disentanglement_scores: (num_models, num_models, 1) - The 57 | computed disentanglement scores representing the similarity of 58 | representation between pairs of models. 59 | model_scores: (num_models) - List of aggregated model scores corresponding 60 | to the median of the pairwise disentanglement scores for each model. 61 | """ 62 | 63 | num_models = len(inferred_model_reps) 64 | mcc_all = np.zeros((num_models, num_models)) 65 | 66 | for i in range(num_models): 67 | for j in range(num_models): 68 | if i == j: 69 | continue 70 | 71 | mcc = mean_corr_coef_np(inferred_model_reps[i], 72 | inferred_model_reps[j], 73 | method='pearson', indices=None)[0] 74 | 75 | mcc_all[i, j] = mcc 76 | off_diag = mcc_all[~np.eye(mcc_all.shape[0], dtype=bool)] 77 | return {'median': np.median(off_diag), 'mean': np.mean(off_diag)} 78 | 79 | def create_mode_entry(all_logs_pd): 80 | # for tcvae 81 | all_logs_pd.loc[all_logs_pd["tcvae"] == True, 'mode'] = "tcvae" 82 | 83 | # TODO: create for other methods if necessary. 
84 | 85 | return all_logs_pd 86 | 87 | def main(args=None): 88 | parser = argparse.ArgumentParser() 89 | parser.add_argument("--all_logs_file", type=str, 90 | help="Absolute path to all_logs.npy files") 91 | 92 | GRAPH_DENSITIES = [0.25, 0.5, 0.75] 93 | DATASET_NAMES = ["toy-nn/action_sparsity_non_trivial", "toy-nn/temporal_sparsity_non_trivial"] 94 | MODES = ['vae'] #, 'pcl', 'slowvae', 'tcvae', 'ivae', 'random_vae', 'supervised_vae'] 95 | HPARAM_NAMES = {"vae": ['gc_reg_coeff', 'g_reg_coeff']} 96 | #"random_vae": [], 97 | #"supervised_vae": [], 98 | #"tcvae": ["beta"], 99 | #"ivae": [], 100 | #"pcl": [], 101 | #"slowvae": ['gamma', 'rate_prior']} 102 | 103 | opt = parser.parse_args() 104 | 105 | all_logs = np.load(opt.all_logs_file, allow_pickle=True).tolist() 106 | all_logs_pd = pd.DataFrame(all_logs) 107 | #all_logs_pd = create_mode_entry(all_logs_pd) 108 | 109 | # to be filled with udr values 110 | all_logs_pd_udr = all_logs_pd.copy() 111 | 112 | for dataset in DATASET_NAMES: 113 | for graph_density in GRAPH_DENSITIES: 114 | for mode in MODES: 115 | print("########## dataset:", dataset, "graph_density", graph_density, "mode:", mode, "##########") 116 | condition_data_mode = (all_logs_pd['dataset'] == dataset) & (all_logs_pd['rand_g_density'] == graph_density) & (all_logs_pd['mode'] == mode) 117 | #condition = (all_logs_pd['dataset'] == gt_graph_name) & (all_logs_pd['mode'] == mode) # only for debugging 118 | logs = all_logs_pd[condition_data_mode] 119 | 120 | if len(logs) == 0: 121 | print("No logs found.") 122 | continue 123 | 124 | if len(HPARAM_NAMES[mode]) != 0: 125 | hparams_values = logs[HPARAM_NAMES[mode]].drop_duplicates() 126 | 127 | for i in range(len(hparams_values)): 128 | # selecting only the runs with this specific hparam value 129 | condition_hp = (hparams_values.iloc[i] == logs[HPARAM_NAMES[mode]]).all(axis=1) 130 | logs_specific_hp = logs[condition_hp] 131 | 132 | # logging number of seeds used to compute UDR 133 | 
all_logs_pd_udr.loc[condition_data_mode & condition_hp, "num_seeds"] = len(logs_specific_hp) 134 | 135 | # cannot compute UDR when there is only one seed 136 | if len(logs_specific_hp) == 1: 137 | print(f"Not computing UDR for {hparams_values.iloc[i].to_dict()} since only one seed.") 138 | all_logs_pd_udr.loc[condition_data_mode & condition_hp, "udr_mean"] = -1 139 | all_logs_pd_udr.loc[condition_data_mode & condition_hp, "udr_median"] = -1 140 | continue 141 | 142 | # load their z_hat_final.npy 143 | z_hat_list = [] 144 | for output_dir in logs_specific_hp["output_dir"]: 145 | z_hat_list.append(np.load(os.path.join(output_dir, "z_hat_final.npy"))) 146 | 147 | print(f"Computing UDR for {hparams_values.iloc[i].to_dict()} on {len(logs_specific_hp)} seeds.") 148 | udr = compute_udr_mcc_from_npy(z_hat_list) 149 | 150 | # Add udr values to table 151 | all_logs_pd_udr.loc[condition_data_mode & condition_hp, "udr_mean"] = udr["mean"] 152 | all_logs_pd_udr.loc[condition_data_mode & condition_hp, "udr_median"] = udr["median"] 153 | 154 | else: 155 | print(f"No hyperparameter search. 
Total of {len(logs)} seeds.") 156 | # number of sucessful seeds 157 | all_logs_pd_udr.loc[condition_data_mode, "num_seeds"] = len(logs) 158 | # if the method has no hparameter, just set UDR score to -1 159 | all_logs_pd_udr.loc[condition_data_mode, "udr_mean"] = -1 160 | all_logs_pd_udr.loc[condition_data_mode, "udr_median"] = -1 161 | 162 | 163 | np.save(opt.all_logs_file.replace(".npy", "_udr.npy"), all_logs_pd_udr.to_dict('records')) 164 | 165 | 166 | if __name__ == "__main__": 167 | main() 168 | -------------------------------------------------------------------------------- /baseline_models/slowvae_pcl/scripts/data_analysis_utils.py: -------------------------------------------------------------------------------- 1 | import csv 2 | from collections import defaultdict 3 | import ast 4 | import pickle 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | import pandas as pd 8 | import scipy.stats 9 | import collections 10 | import warnings 11 | from sklearn.feature_selection import mutual_info_regression 12 | import random 13 | #for youtube 14 | name_list = 'person giant_panda lizard parrot skateboard sedan ape dog snake monkey hand rabbit duck cat cow fish train horse turtle bear motorbike giraffe leopard fox deer owl surfboard airplane truck zebra tiger elephant snowboard boat shark mouse frog eagle earless_seal tennis_racket'.split(' ') 15 | 16 | def load_csv(csv_file, sequence=2): 17 | csv_reader = csv.reader(csv_file, delimiter=',') 18 | next(csv_reader) 19 | data = defaultdict(list) 20 | for i,row in enumerate(csv_reader): 21 | for j in range(2,len(row)): 22 | considered_columns = row[j:j+sequence] 23 | if all(considered_columns): 24 | temp = defaultdict(list) 25 | for column in considered_columns: 26 | val_list = ast.literal_eval(column) 27 | for i,val in enumerate(val_list): 28 | if val: 29 | temp['pos'].append(i) 30 | temp['y'].append(val[0]) 31 | temp['x'].append(val[1]) 32 | temp['area'].append(val[2]) 33 | for i in range(len(val_list)): 34 | 
if temp['pos'].count(i) == sequence: 35 | data['id'].append(int(row[0])) 36 | data['category_id'].append(int(row[1])) 37 | data['area'].append([temp['area'][j] for j in range(len(temp['pos'])) if temp['pos'][j]==i]) 38 | data['x'].append([temp['x'][j] for j in range(len(temp['pos'])) if temp['pos'][j]==i]) 39 | data['y'].append([temp['y'][j] for j in range(len(temp['pos'])) if temp['pos'][j]==i]) 40 | for k in range(1,sequence): 41 | data['area_diff{}'.format(k if k>1 else '')].append(data['area'][-1][k] - data['area'][-1][k-1]) 42 | data['x_diff{}'.format(k if k>1 else '')].append(data['x'][-1][k] - data['x'][-1][k-1]) 43 | data['y_diff{}'.format(k if k>1 else '')].append(data['y'][-1][k] - data['y'][-1][k-1]) 44 | else: 45 | assert temp['pos'].count(i) < sequence 46 | return data 47 | 48 | def load_data(path): 49 | with open(path + '.pkl', 'rb') as data: 50 | data_dict = pickle.load(data) 51 | return data_dict 52 | 53 | def plot_bar(data, key): 54 | bars = np.bincount(data[key]) 55 | _ = plt.bar(range(len(bars)), bars) 56 | 57 | def plot_type(data, x, type_='all', semilogy=False): 58 | if type_ == 'all': 59 | _ = plt.hist(x, bins=100) 60 | if semilogy: 61 | _ = plt.semilogy() 62 | else: 63 | fig, axes = plt.subplots(len(name_list) // 8, len(name_list) // 5, 64 | figsize=((len(name_list) // 5)*10,(len(name_list) // 8)*10)) 65 | data_per_cat = collections.defaultdict(list) 66 | for j in range(len(data['id'])): 67 | data_per_cat[data['category_id'][j]].append(x[j]) 68 | for i in range(len(name_list)): 69 | _ = axes[i % (len(name_list) // 8), i % (len(name_list) // 5)].hist(data_per_cat[i+1], bins=100) 70 | _ = axes[i % (len(name_list) // 8), i % (len(name_list) // 5)].set_title(name_list[i]) 71 | plt.show() 72 | 73 | 74 | def plot_val(data, key, low=0., high=1., type_='all'): 75 | x = pd.Series(np.array(data[key])[:,0]) 76 | print('orig_min',x.min()) 77 | print('orig_max',x.max()) 78 | x = x[x.between(x.quantile(low), x.quantile(high))] 79 | 
print('{}_min'.format(low),x.min()) 80 | print('{}_max'.format(high),x.max()) 81 | plot_type(data, x, type_) 82 | 83 | def plot_diff(data, key, semilogy=True, type_='all'): 84 | plot_type(data, data[key + '_diff'], type_, semilogy) 85 | 86 | def visualize_mask(mask): 87 | _ = plt.imshow(mask) 88 | 89 | def generate_dataframe(data, type_='all', mi=False, mi_samples=20000): 90 | stats = collections.defaultdict(list) 91 | warnings.filterwarnings(action='ignore') 92 | distributions = [scipy.stats.gennorm, scipy.stats.norm, scipy.stats.laplace] 93 | for i in range((0 if type_=='all' else np.max(data['category_id']))+1): 94 | if i == 0: 95 | stats['category'].append('all') 96 | stats['N'].append(len(data['id'])) 97 | area_val = data['area_diff'] 98 | x_val = data['x_diff'] 99 | y_val = data['y_diff'] 100 | else: 101 | stats['category'].append(name_list[i-1]) 102 | stats['N'].append(np.count_nonzero(np.array(data['category_id']).astype(int) == i)) 103 | area_val = [data['area_diff'][j] for j in range(len(data['area_diff'])) if i == data['category_id'][j]] 104 | x_val = [data['x_diff'][j] for j in range(len(data['x_diff'])) if i == data['category_id'][j]] 105 | y_val = [data['y_diff'][j] for j in range(len(data['y_diff'])) if i == data['category_id'][j]] 106 | vals = {'area': area_val, 'x': x_val, 'y': y_val} 107 | for x in vals.keys(): 108 | stats['kurtosis_' + x].append("%.2f" % scipy.stats.kurtosis(vals[x])) 109 | for distribution in distributions: 110 | params = distribution.fit(vals[x]) 111 | stats['{}_{}'.format(distribution.name, x)].append(['%.2e' % x for x in params]) 112 | arg = params[:-2] 113 | loc = params[-2] 114 | scale = params[-1] 115 | stats['ll_{}_{}'.format(distribution.name, x)].append("%.2e" % distribution.logpdf(vals[x], *params).sum()) 116 | stats['ks_{}_{}'.format(distribution.name, x)].append("%.2e" % scipy.stats.kstest(vals[x], 117 | lambda x: distribution.cdf(x, 118 | loc=loc, 119 | scale=scale, 120 | *arg))[1]) 121 | 
stats['pearson_area_x'].append(["%.2f" % x for x in scipy.stats.pearsonr(vals['area'], vals['x'])]) 122 | stats['pearson_area_y'].append(["%.2f" % x for x in scipy.stats.pearsonr(vals['area'], vals['y'])]) 123 | stats['pearson_x_y'].append(["%.2f" % x for x in scipy.stats.pearsonr(vals['x'], vals['y'])]) 124 | if mi: 125 | indices = random.sample(range(len(vals['area'])), min(mi_samples,stats['N'][-1])) 126 | stats['mi_area_x'].append("%.2f" % mutual_info_regression(np.array(vals['area']).reshape(-1,1)[indices], 127 | np.array(vals['x'])[indices])) 128 | stats['mi_area_y'].append("%.2f" % mutual_info_regression(np.array(vals['area']).reshape(-1,1)[indices], 129 | np.array(vals['y'])[indices])) 130 | stats['mi_x_y'].append("%.2f" % mutual_info_regression(np.array(vals['x']).reshape(-1,1)[indices], 131 | np.array(vals['y'])[indices])) 132 | return pd.DataFrame.from_dict(stats).sort_values('N', ascending=True) 133 | 134 | 135 | def find_best(df, criterion='ll'): 136 | best_df = pd.concat([df['category'], 137 | df['N'], 138 | df[[*[col for col in df.columns if ('area' in col and criterion in col)]]].astype(np.float64).idxmax(axis=1), 139 | df[[*[col for col in df.columns if ('x' in col and criterion in col)]]].astype(np.float64).idxmax(axis=1), 140 | df[[*[col for col in df.columns if ('y' in col and criterion in col)]]].astype(np.float64).idxmax(axis=1), 141 | ], axis=1) 142 | return best_df.sort_values('N',ascending=False) 143 | 144 | 145 | -------------------------------------------------------------------------------- /universal_logger/logger.py: -------------------------------------------------------------------------------- 1 | # Author: Alexandre Drouin 2 | # Adapted by: Sebastien Lachapelle 3 | 4 | import json 5 | import numpy as np 6 | import os 7 | import queue 8 | 9 | from time import strftime, time 10 | 11 | 12 | try: 13 | import comet_ml 14 | COMET_AVAIL = True 15 | except: 16 | COMET_AVAIL = False 17 | 18 | try: 19 | import tensorboardX 20 | TBX_AVAIL = 
True 21 | except: 22 | TBX_AVAIL = False 23 | 24 | 25 | def _check_randomstate(random_state): 26 | if isinstance(random_state, int): 27 | random_state = np.random.RandomState(random_state) 28 | if not isinstance(random_state, np.random.RandomState): 29 | raise ValueError("Random state must be numpy RandomState or int.") 30 | return random_state 31 | 32 | 33 | def _prefix_stage(metrics, stage): 34 | if stage is not None: 35 | metrics = {"%s/%s" % (stage, k): v for k, v in metrics.items()} 36 | return metrics 37 | 38 | 39 | class CometLogger(object): 40 | def __init__(self, experiment): 41 | self.experiment = experiment 42 | 43 | def log_figure(self, stage, step, name, figure): 44 | prefix = "" 45 | #if stage is not None: 46 | # prefix += stage + "/" 47 | #if step is not None: 48 | # prefix += str(step) + "/" 49 | try: 50 | self.experiment.log_figure(figure_name=prefix + name, figure=figure, step=step, overwrite=False) 51 | except: 52 | self.experiment.log_image(figure, name=prefix + name, step=step, overwrite=False) 53 | 54 | def log_metrics(self, stage, step, metrics): 55 | self.experiment.log_metrics(step=step, dic=_prefix_stage(metrics, stage)) 56 | 57 | 58 | class JsonLogger(object): 59 | def __init__(self, path, time=True, max_fig_save=None): 60 | self.path = path 61 | os.makedirs(os.path.dirname(path), exist_ok=True) 62 | self.time = time 63 | self.max_fig_save = max_fig_save # makes sure we don't save too many .png files. 
64 | self.current_figs = {} 65 | 66 | def log_metrics(self, stage, step, metrics): 67 | metrics["stage"] = stage 68 | metrics["step"] = step 69 | if self.time: 70 | metrics["time"] = time() 71 | with open(os.path.join(self.path, "log.ndjson"), "a") as f: 72 | f.write(json.dumps(metrics) + "\n") 73 | 74 | def log_figure(self, stage, step, name, figure): 75 | # if stage is not None: 76 | # prefix += stage + "/" 77 | # if step is not None: 78 | # prefix += str(step) + "/" 79 | if name not in self.current_figs.keys(): 80 | self.current_figs[name] = queue.Queue() 81 | 82 | file_name = f"{name}_{step}.png" 83 | figure.savefig(os.path.join(self.path, file_name)) 84 | self.current_figs[name].put(file_name) 85 | 86 | # removing old figures 87 | if self.current_figs[name].qsize() > self.max_fig_save: 88 | file_to_delete = self.current_figs[name].get() 89 | os.remove(os.path.join(self.path, file_to_delete)) 90 | 91 | 92 | class StdoutLogger(object): 93 | def __init__(self, time=True): 94 | self.time = time 95 | 96 | def log_metrics(self, stage, step, metrics): 97 | prefix = "" if stage is None else (stage + " -- ") 98 | try: 99 | log = f"{prefix}{step}:\t" + "\t".join("%s: %.6f" % (m, v) for m, v in metrics.items()) 100 | except: 101 | log = f"{prefix}{step}:\t" + "\t".join("%s: %s" % (m, v) for m, v in metrics.items()) 102 | if self.time: 103 | log += "\t time: " + strftime('%X %x') 104 | print(log) 105 | 106 | 107 | class TensorboardXLogger(object): 108 | def __init__(self, summary_writer): 109 | self.summary_writer = summary_writer 110 | 111 | def log_metrics(self, stage, step, metrics): 112 | for m, v in _prefix_stage(metrics, stage).items(): 113 | self.summary_writer.add_scalar(m, v, step) 114 | 115 | 116 | # XXX: I'll eventually move this class to the shared code-base 117 | class UniversalLogger(object): 118 | """ 119 | A logger that simultaneously logs to multiple outputs. 
120 | 121 | Parameters: 122 | ----------- 123 | comet: comet_ml.Experiment 124 | A comet experiment 125 | json: str 126 | The path to a json file 127 | stdout: bool 128 | Whether or not to print metrics to the standard output 129 | tensorboardx: tensorboardx.SummaryWriter 130 | A summary writer for tensorboardx 131 | time: bool 132 | Whether to log the time in some loggers (supported: json, stdout) 133 | throttle: int 134 | The minimum time between logs in seconds 135 | 136 | """ 137 | def __init__(self, comet=None, json=None, stdout=False, tensorboardx=None, time=True, throttle=None, max_fig_save=None): 138 | super().__init__() 139 | loggers = [] 140 | if comet is not None: 141 | if not COMET_AVAIL: 142 | raise RuntimeError("comet_ml is not available on this platform. Please install it.") 143 | loggers.append(CometLogger(experiment=comet)) 144 | if json is not None: 145 | loggers.append(JsonLogger(json, time=time, max_fig_save=max_fig_save)) 146 | if stdout: 147 | loggers.append(StdoutLogger(time=time)) 148 | if tensorboardx is not None: 149 | if not TBX_AVAIL: 150 | raise RuntimeError("TensorboardX is not available on this platform. Please install it.") 151 | loggers.append(TensorboardXLogger(tensorboardx)) 152 | assert len(loggers) >= 1 153 | self.loggers = loggers 154 | 155 | # Throttling 156 | self.throttle = throttle 157 | self.last_log_time = 0 158 | self.current_stage = None 159 | 160 | def _check_stage(self, stage): 161 | # This will force logging if we are entering a new stage 162 | if self.current_stage != stage: 163 | self.last_log_time = 0 164 | self.current_stage = stage 165 | 166 | def _check_throttle(self): 167 | return self.throttle is None or time() - self.last_log_time > self.throttle 168 | 169 | def log_metrics(self, step, metrics, stage=None, throttle=True): 170 | """ 171 | Log a real-valued metric 172 | 173 | Parameters: 174 | ----------- 175 | stage: str 176 | The current stage of execution (e.g., "train", "test", etc.) 
177 | step: uint 178 | The current step number (e.g., epoch) 179 | metrics: dict 180 | A dictionnary with metric names as keys and metric values as values 181 | throttle: bool 182 | Whether to respect log throttling or not (default is True) 183 | 184 | """ 185 | self._check_stage(stage) 186 | if self._check_throttle() or not throttle: 187 | for log in self.loggers: 188 | if hasattr(log, "log_metrics"): 189 | log.log_metrics(stage, step, dict(metrics)) 190 | self.last_log_time = time() 191 | 192 | def log_figure(self, name, figure, stage=None, step=None, throttle=True): 193 | """ 194 | Log a matplotlib figure 195 | 196 | Parameters: 197 | ----------- 198 | name: str 199 | The name of the figure 200 | figure: mpl.Figure 201 | The matplotlib figure 202 | step: uint 203 | The current step number (default: None). If required by a logger and not provided, an exception will be 204 | raised. 205 | throttle: bool 206 | Whether to respect log throttling or not (default is True) 207 | 208 | """ 209 | self._check_stage(stage) 210 | if self._check_throttle() or not throttle: 211 | for log in self.loggers: 212 | if hasattr(log, "log_figure"): 213 | log.log_figure(stage, step, name, figure) 214 | self.last_log_time = time() -------------------------------------------------------------------------------- /baseline_models/icebeem/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import pathlib 5 | import pickle 6 | import json 7 | 8 | try: 9 | from comet_ml import Experiment 10 | COMET_AVAIL = True 11 | except: 12 | COMET_AVAIL = False 13 | import numpy as np 14 | import torch 15 | 16 | sys.path.insert(0, str(pathlib.Path(__file__).parent.parent.parent)) 17 | from baseline_models.icebeem.models.ivae.ivae_wrapper import IVAE_wrapper 18 | from baseline_models.icebeem.models.icebeem_wrapper import ICEBEEM_wrapper 19 | from train import get_dataset, get_loader 20 | from universal_logger.logger 
import UniversalLogger 21 | from metrics import evaluate_disentanglement 22 | 23 | def parse_sim(): 24 | parser = argparse.ArgumentParser(description='') 25 | parser.add_argument("--output_dir", required=True, 26 | help="Directory to output logs and model checkpoints") 27 | parser.add_argument("--dataset", type=str, required=True, 28 | help="Type of the dataset to be used 'toy-MANIFOLD/TRANSITION_MODEL'") 29 | parser.add_argument('--method', type=str, default='icebeem', choices=['icebeem', 'ivae'], 30 | help='Method to employ. Should be TCL, iVAE or ICE-BeeM') 31 | parser.add_argument("--dataroot", type=str, default="./", 32 | help="path to dataset") 33 | parser.add_argument("--gt_z_dim", type=int, default=5, 34 | help="ground truth dimensionality of z (for TRANSITION_MODEL == 'linear_system')") 35 | parser.add_argument("--gt_x_dim", type=int, default=10, 36 | help="ground truth dimensionality of x (for MANIFOLD == 'nn')") 37 | parser.add_argument("--num_samples", type=float, default=int(1e6), 38 | help="Number of samples") 39 | parser.add_argument("--rand_g_density", type=float, default=None, 40 | help="Probability of sampling an edge. 
When None, the graph is set to a default (or to gt_graph_name).") 41 | parser.add_argument("--gt_graph_name", type=str, default=None, 42 | help="Name of the ground-truth graph to use in synthetic data.") 43 | parser.add_argument("--architecture", type=str, default='ilcm_tabular', choices=['ilcm_tabular'], 44 | help="encoder/decoder architecture.") 45 | parser.add_argument("--learn_decoder_var", action='store_true', 46 | help="Whether to learn the variance of the output decoder") 47 | parser.add_argument("--train_prop", type=float, default=None, 48 | help="proportion of all samples used in validation set") 49 | parser.add_argument("--valid_prop", type=float, default=0.10, 50 | help="proportion of all samples used in validation set") 51 | parser.add_argument("--test_prop", type=float, default=0.10, 52 | help="proportion of all samples used in test set") 53 | parser.add_argument("--batch_size", type=int, default=1024, 54 | help="batch size used during training") 55 | parser.add_argument("--eval_batch_size", type=int, default=1024, 56 | help="batch size used during evaluation") 57 | parser.add_argument("--time_limit", type=float, default=None, 58 | help="After this amount of time, terminate training. (in hours)") 59 | parser.add_argument("--max_iter", type=int, default=int(1e6), 60 | help="Maximal amount of iterations") 61 | parser.add_argument("--ivae_lr", type=float, default=1e-4, 62 | help="After this amount of time, terminate training. 
(in hours)") 63 | parser.add_argument("--seed", type=int, default=0, 64 | help="manual seed") 65 | parser.add_argument('--no_print', action="store_true", 66 | help='do not print') 67 | parser.add_argument('--comet_key', type=str, default=None, 68 | help="comet api-key") 69 | parser.add_argument('--comet_tag', type=str, default=None, 70 | help="comet tag, to ease comparison") 71 | parser.add_argument('--comet_workspace', type=str, default=None, 72 | help="comet workspace") 73 | parser.add_argument('--comet_project_name', type=str, default=None, 74 | help="comet project_name") 75 | parser.add_argument("--no_cuda", action="store_false", dest="cuda", 76 | help="Disables cuda") 77 | 78 | #parser.add_argument('--dataset', type=str, default='TCL', help='Dataset to run experiments. Should be TCL or IMCA') 79 | 80 | #parser.add_argument('--config', type=str, default='imca.yaml', help='Path to the config file') 81 | #parser.add_argument('--run', type=str, default='run/', help='Path for saving running related data.') 82 | #parser.add_argument('--nSims', type=int, default=10, help='Number of simulations to run') 83 | 84 | #parser.add_argument('--test', action='store_true', help='Whether to evaluate the models from checkpoints') 85 | #parser.add_argument('--plot', action='store_true') 86 | 87 | return parser.parse_args() 88 | 89 | def main(args): 90 | print('WARNING: this code do not support discrete auxiliary variable. 
See warning in mean_corr_coef function in metrics.py') 91 | print('Running {} experiments using {}'.format(args.dataset, args.method)) 92 | 93 | device = torch.device('cuda:0' if args.cuda else 'cpu') 94 | 95 | if args.method.lower() == 'ivae': 96 | args.mode = "ivae" 97 | else: 98 | raise ValueError('Unsupported method {}'.format(args.method)) 99 | 100 | ## ---- Save hparams ---- ## 101 | kwargs = vars(args) 102 | with open(os.path.join(args.output_dir, "hparams.json"), "w") as fp: 103 | json.dump(kwargs, fp, sort_keys=True, indent=4) 104 | 105 | ## ---- Data ---- ## 106 | args.n_lag = 0 107 | args.no_norm = False 108 | args.n_workers = 0 # can't put it to 4 since we get weird error msg... 109 | image_shape, cont_c_dim, disc_c_dim, disc_c_n_values, train_dataset, valid_dataset, test_dataset = get_dataset(args) 110 | _, _, test_loader = get_loader(args, train_dataset, valid_dataset, test_dataset) 111 | x, y, s = train_dataset.x.cpu().numpy(), train_dataset.c.cpu().numpy(), train_dataset.z.cpu().numpy() 112 | 113 | ## ---- Logger ---- ## 114 | if COMET_AVAIL and args.comet_key is not None: 115 | comet_exp = Experiment(api_key=args.comet_key, project_name=args.comet_project_name, 116 | workspace=args.comet_workspace, auto_metric_logging=False, auto_param_logging=False) 117 | comet_exp.log_parameters(vars(args)) 118 | if args.comet_tag is not None: 119 | comet_exp.add_tag(args.comet_tag) 120 | else: 121 | comet_exp = None 122 | logger = UniversalLogger(comet=comet_exp, 123 | stdout=(not args.no_print), 124 | json=args.output_dir, throttle=None) 125 | 126 | ## ---- Running ---- ## 127 | # the argument `architecture="ilcm_tabular"` will choose the same encoder/decoder as in ilcm for synthetic experiments. 
128 | model = IVAE_wrapper(x, y, args.gt_z_dim, ckpt_folder=args.output_dir, batch_size=args.batch_size, max_iter=args.max_iter, #max_iter=100, 129 | seed=args.seed, n_layers=5, hidden_dim=512, lr=args.ivae_lr, 130 | architecture=args.architecture, logger=logger, time_limit=args.time_limit, learn_decoder_var=args.learn_decoder_var) 131 | 132 | ## ---- Evaluate performance ---- ## 133 | mcc, consistent_r, r, cc, C_hat, C_pattern, perm_mat, z, z_hat, transposed_consistent_r = evaluate_disentanglement(model, test_loader, device, args) 134 | 135 | ## ---- Save ---- ## 136 | # save scores 137 | metrics = {"mean_corr_coef_final": mcc, 138 | "consistent_r_final": consistent_r, 139 | "r_final": r, 140 | "transposed_consistent_r_final": transposed_consistent_r} 141 | logger.log_metrics(step=0, metrics=metrics) 142 | 143 | # save both ground_truth and learned latents 144 | np.save(os.path.join(args.output_dir, "z_hat_final.npy"), z_hat) 145 | np.save(os.path.join(args.output_dir, "z_gt_final.npy"), z) 146 | 147 | if __name__ == '__main__': 148 | args = parse_sim() 149 | main(args) 150 | -------------------------------------------------------------------------------- /baseline_models/slowvae_pcl/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import shutil 3 | import os, json, sys, traceback, time 4 | import pathlib 5 | 6 | try: 7 | from comet_ml import Experiment 8 | COMET_AVAIL = True 9 | except: 10 | COMET_AVAIL = False 11 | import numpy as np 12 | import torch 13 | import datetime 14 | 15 | sys.path.insert(0, str(pathlib.Path(__file__).parent)) 16 | from scripts.solver import Solver 17 | 18 | sys.path.insert(0, str(pathlib.Path(__file__).parent.parent.parent)) 19 | from train import get_dataset, get_loader 20 | from universal_logger.logger import UniversalLogger 21 | from metrics import evaluate_disentanglement 22 | 23 | torch.backends.cudnn.enabled = True 24 | torch.backends.cudnn.benchmark = True 25 | 26 
| 27 | def main(args, writer=None): 28 | torch.manual_seed(args.seed) 29 | torch.cuda.manual_seed(args.seed) 30 | np.random.seed(args.seed) 31 | 32 | device = torch.device('cuda:0' if args.cuda else 'cpu') 33 | 34 | ## ---- Data ---- ## 35 | args.no_norm = False 36 | args.n_lag = 0 # get dataset expects this argument, but no effect. 37 | args.num_workers = args.n_workers 38 | image_shape, cont_c_dim, disc_c_dim, disc_c_n_values, train_dataset, valid_dataset, test_dataset, = get_dataset(args) 39 | train_loader, valid_loader, test_loader = get_loader(args, train_dataset, valid_dataset, test_dataset) 40 | data_loader = train_loader 41 | if len(image_shape) == 3: 42 | args.num_channel = image_shape[-1] 43 | else: 44 | args.num_channel = None 45 | 46 | ## ---- Logging ---- ## 47 | if COMET_AVAIL and args.comet_key is not None and args.comet_workspace is not None and args.comet_project_name is not None: 48 | comet_exp = Experiment(api_key=args.comet_key, project_name=args.comet_project_name, 49 | workspace=args.comet_workspace, auto_metric_logging=False, auto_param_logging=False) 50 | comet_exp.log_parameters(vars(args)) 51 | if args.comet_tag is not None: 52 | comet_exp.add_tag(args.comet_tag) 53 | else: 54 | comet_exp = None 55 | logger = UniversalLogger(comet=comet_exp, 56 | stdout=(not args.no_print), 57 | json=args.output_dir, throttle=None) 58 | 59 | t0 = time.time() 60 | 61 | # saving hp 62 | ## ---- Save hparams ---- ## 63 | if args.pcl: 64 | args.mode = "pcl" 65 | else: 66 | args.mode = 'slowvae' 67 | kwargs = vars(args) 68 | with open(os.path.join(args.output_dir, "hparams.json"), "w") as fp: 69 | json.dump(kwargs, fp, sort_keys=True, indent=4) 70 | with open(os.path.join(args.output_dir, "args"), "w") as f: 71 | json.dump(args.__dict__, f) 72 | 73 | net = Solver(args, image_shape, data_loader=data_loader, logger=logger, z_dim=train_dataset.z_dim) 74 | failure = net.train(writer) 75 | if failure: 76 | print('failed in %.2fs' % (time.time() - t0)) 77 | 
#shutil.rmtree(args.output_dir) 78 | else: 79 | print('done in %.2fs' % (time.time() - t0)) 80 | 81 | ## ---- Evaluate performance ---- # 82 | mcc, consistent_r, r, cc, C_hat, C_pattern, perm_mat, z, z_hat, transposed_consistent_r = evaluate_disentanglement(net.net, test_loader, device, args) 83 | 84 | ## ---- Save ---- ## 85 | # save scores 86 | metrics = {"mean_corr_coef_final": mcc, 87 | "consistent_r_final": consistent_r, 88 | "r_final": r, 89 | "transposed_consistent_r_final": transposed_consistent_r} 90 | logger.log_metrics(step=0, metrics=metrics) 91 | 92 | # save both ground_truth and learned latents 93 | np.save(os.path.join(args.output_dir, "z_hat_final.npy"), z_hat) 94 | np.save(os.path.join(args.output_dir, "z_gt_final.npy"), z) 95 | 96 | 97 | ### For Random Search ### 98 | def randint(low, high): 99 | return np.int(np.random.randint(low, high, 1)[0]) 100 | 101 | def uniform(low, high): 102 | return np.random.uniform(low, high, 1)[0] 103 | 104 | def loguniform(low, high): 105 | return np.exp(np.random.uniform(np.log(low), np.log(high), 1))[0] 106 | 107 | if __name__ == "__main__": 108 | parser = argparse.ArgumentParser(description='slowVAE') 109 | parser.add_argument('--pcl', action='store_true') 110 | parser.add_argument('--r_func', type=str, default='default', choices=['default', 'mlp'], 111 | help='Type of regression function used for PCL') 112 | parser.add_argument("--output_dir", required=True, 113 | help="Directory to output logs and model checkpoints") 114 | parser.add_argument("--dataset", type=str, required=True, 115 | help="Type of the dataset to be used 'toy-MANIFOLD/TRANSITION_MODEL'") 116 | parser.add_argument("--dataroot", type=str, default="./", 117 | help="path to dataset") 118 | parser.add_argument("--gt_z_dim", type=int, default=10, 119 | help="ground truth dimensionality of z (for TRANSITION_MODEL == 'linear_system')") 120 | parser.add_argument("--gt_x_dim", type=int, default=20, 121 | help="ground truth dimensionality of x (for 
MANIFOLD == 'nn')") 122 | parser.add_argument("--num_samples", type=float, default=int(1e6), 123 | help="number of trajectories in toy dataset") 124 | parser.add_argument("--architecture", type=str, default='ilcm_tabular', choices=['ilcm_tabular', 'standard_conv'], 125 | help="VAE encoder/decoder architecture.") 126 | parser.add_argument("--train_prop", type=float, default=None, 127 | help="proportion of all samples used in validation set") 128 | parser.add_argument("--valid_prop", type=float, default=0.10, 129 | help="proportion of all samples used in validation set") 130 | parser.add_argument("--test_prop", type=float, default=0.10, 131 | help="proportion of all samples used in test set") 132 | parser.add_argument("--n_workers", type=int, default=4) 133 | parser.add_argument("--time_limit", type=float, default=None, 134 | help="After this amount of time, terminate training. (in hours)") 135 | parser.add_argument("--max_iter", type=int, default=int(1e6), 136 | help="Maximal amount of iterations") 137 | parser.add_argument("--seed", type=int, default=0, 138 | help="manual seed") 139 | parser.add_argument('--no_print', action="store_true", 140 | help='do not print') 141 | parser.add_argument('--comet_key', type=str, default=None, 142 | help="comet api-key") 143 | parser.add_argument('--comet_tag', type=str, default=None, 144 | help="comet tag, to ease comparison") 145 | parser.add_argument('--comet_workspace', type=str, default=None, 146 | help="comet workspace") 147 | parser.add_argument('--comet_project_name', type=str, default=None, 148 | help="comet project_name") 149 | parser.add_argument("--rand_g_density", type=float, default=None, 150 | help="Probability of sampling an edge. 
When None, the graph is set to a default (or to gt_graph_name).") 151 | parser.add_argument("--gt_graph_name", type=str, default=None, 152 | help="Name of the ground-truth graph to use in synthetic data.") 153 | parser.add_argument("--add_noise", type=float, default=0.0, 154 | help="Add normal noise sigma = add_noise on images (only training data)") 155 | parser.add_argument("--no_cuda", action="store_false", dest="cuda", 156 | help="Disables cuda") 157 | parser.add_argument("--batch_size", type=int, default=1024, 158 | help="batch size used during training") 159 | parser.add_argument("--eval_batch_size", type=int, default=1024, 160 | help="batch size used during evaluation") 161 | parser.add_argument('--beta', default=1, type=float, 162 | help='weight for kl to normal') 163 | parser.add_argument('--gamma', default=10, type=float, 164 | help='weight for kl to laplace') 165 | parser.add_argument('--rate_prior', default=6, type=float, 166 | help='rate (or inverse scale) for prior laplace (larger -> sparser).') 167 | parser.add_argument('--lr', default=1e-4, type=float, help='learning rate') 168 | parser.add_argument('--beta1', default=0.9, type=float, 169 | help='Adam optimizer beta1') 170 | parser.add_argument('--beta2', default=0.999, type=float, 171 | help='Adam optimizer beta2') 172 | parser.add_argument('--ckpt-name', default='last', type=str, 173 | help='load previous checkpoint. 
insert checkpoint filename') 174 | parser.add_argument('--log_step', default=100, type=int, 175 | help='numer of iterations after which data is logged') 176 | parser.add_argument('--save_step', default=10000, type=int, 177 | help='number of iterations after which a checkpoint is saved') 178 | args = parser.parse_args() 179 | 180 | args = main(args) 181 | -------------------------------------------------------------------------------- /baseline_models/beta-tcvae/elbo_decomposition.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | from numbers import Number 4 | from tqdm import tqdm 5 | import torch 6 | from torch.autograd import Variable 7 | 8 | import lib.dist as dist 9 | import lib.flows as flows 10 | 11 | 12 | def estimate_entropies(qz_samples, qz_params, q_dist): 13 | """Computes the term: 14 | E_{p(x)} E_{q(z|x)} [-log q(z)] 15 | and 16 | E_{p(x)} E_{q(z_j|x)} [-log q(z_j)] 17 | where q(z) = 1/N sum_n=1^N q(z|x_n). 18 | Assumes samples are from q(z|x) for *all* x in the dataset. 19 | Assumes that q(z|x) is factorial ie. q(z|x) = prod_j q(z_j|x). 
20 | 21 | Computes numerically stable NLL: 22 | - log q(z) = log N - logsumexp_n=1^N log q(z|x_n) 23 | 24 | Inputs: 25 | ------- 26 | qz_samples (K, S) Variable 27 | qz_params (N, K, nparams) Variable 28 | """ 29 | 30 | # Only take a sample subset of the samples 31 | qz_samples = qz_samples.index_select(1, Variable(torch.randperm(qz_samples.size(1))[:10000].cuda())) 32 | 33 | K, S = qz_samples.size() 34 | N, _, nparams = qz_params.size() 35 | assert(nparams == q_dist.nparams) 36 | assert(K == qz_params.size(1)) 37 | 38 | marginal_entropies = torch.zeros(K).cuda() 39 | joint_entropy = torch.zeros(1).cuda() 40 | 41 | pbar = tqdm(total=S) 42 | k = 0 43 | while k < S: 44 | batch_size = min(10, S - k) 45 | logqz_i = q_dist.log_density( 46 | qz_samples.view(1, K, S).expand(N, K, S)[:, :, k:k + batch_size], 47 | qz_params.view(N, K, 1, nparams).expand(N, K, S, nparams)[:, :, k:k + batch_size]) 48 | k += batch_size 49 | 50 | # computes - log q(z_i) summed over minibatch 51 | marginal_entropies += (math.log(N) - logsumexp(logqz_i, dim=0, keepdim=False).data).sum(1) 52 | # computes - log q(z) summed over minibatch 53 | logqz = logqz_i.sum(1) # (N, S) 54 | joint_entropy += (math.log(N) - logsumexp(logqz, dim=0, keepdim=False).data).sum(0) 55 | pbar.update(batch_size) 56 | pbar.close() 57 | 58 | marginal_entropies /= S 59 | joint_entropy /= S 60 | 61 | return marginal_entropies, joint_entropy 62 | 63 | 64 | def logsumexp(value, dim=None, keepdim=False): 65 | """Numerically stable implementation of the operation 66 | 67 | value.exp().sum(dim, keepdim).log() 68 | """ 69 | if dim is not None: 70 | m, _ = torch.max(value, dim=dim, keepdim=True) 71 | value0 = value - m 72 | if keepdim is False: 73 | m = m.squeeze(dim) 74 | return m + torch.log(torch.sum(torch.exp(value0), 75 | dim=dim, keepdim=keepdim)) 76 | else: 77 | m = torch.max(value) 78 | sum_exp = torch.sum(torch.exp(value - m)) 79 | if isinstance(sum_exp, Number): 80 | return m + math.log(sum_exp) 81 | else: 82 | return m + 
torch.log(sum_exp)


def analytical_NLL(qz_params, q_dist, prior_dist, qz_samples=None):
    """Computes the quantities
        1/N sum_n=1^N E_{q(z|x)} [ - log q(z|x) ]
    and
        1/N sum_n=1^N E_{q(z_j|x)} [ - log p(z_j) ]
    in closed form through the distributions' ``NLL`` methods.

    Inputs:
    -------
        qz_params   (N, K, nparams) Variable
        q_dist      distribution object providing ``NLL(params)``
        prior_dist  distribution object providing ``NLL(params, sample_params)``
        qz_samples  unused; accepted only so existing call sites keep working

    Returns:
    --------
        nlogqz_condx (K,) Variable
        nlogpz (K,) Variable
    """
    # Prior parameters are all zeros, broadcast to qz_params' shape.
    # NOTE(review): the meaning of all-zero parameters depends on prior_dist's parameterisation.
    pz_params = Variable(torch.zeros(1).type_as(qz_params.data).expand(qz_params.size()), volatile=True)

    nlogqz_condx = q_dist.NLL(qz_params).mean(0)
    nlogpz = prior_dist.NLL(pz_params, qz_params).mean(0)
    return nlogqz_condx, nlogpz


def elbo_decomposition(vae, dataset_loader):
    """Decompose the dataset-averaged ELBO of ``vae`` into KL sub-terms.

    Computes (and prints) the reconstruction term, the dependence term
    KL(q(z)||prod_j q(z_j)), the information term KL(q(z|x)||q(z)), the
    dimension-wise term sum_j KL(q(z_j)||p(z_j)) and the analytical
    E_p(x)[KL(q(z|x)||p(z))]. Moves data and parameters to CUDA.

    Args:
        vae: model exposing ``z_dim``, ``encoder``, ``decoder``, ``q_dist``,
            ``x_dist``, ``prior_dist`` and ``_get_prior_params``.
        dataset_loader: iterable of image batches; ``len(dataset_loader.dataset)`` is N.

    Returns:
        (logpx, dependence, information, dimwise_kl, analytical_cond_kl,
         marginal_entropies, joint_entropy)
    """
    N = len(dataset_loader.dataset)  # number of data samples
    K = vae.z_dim  # number of latent variables
    S = 1  # number of latent variable samples
    nparams = vae.q_dist.nparams

    print('Computing q(z|x) distributions.')
    # compute the marginal q(z_j|x_n) distributions
    qz_params = torch.Tensor(N, K, nparams)
    n = 0
    logpx = 0
    for xs in dataset_loader:
        batch_size = xs.size(0)
        xs = Variable(xs.view(batch_size, -1, 64, 64).cuda(), volatile=True)
        z_params = vae.encoder.forward(xs).view(batch_size, K, nparams)
        qz_params[n:n + batch_size] = z_params.data
        n += batch_size

        # estimate reconstruction term with S Monte Carlo samples per input
        for _ in range(S):
            z = vae.q_dist.sample(params=z_params)
            x_params = vae.decoder.forward(z)
            logpx += vae.x_dist.log_density(xs, params=x_params).view(batch_size, -1).data.sum()
    # Reconstruction term, averaged over data points and samples
    logpx = logpx / (N * S)

    qz_params = Variable(qz_params.cuda(), volatile=True)

    print('Sampling from q(z).')
    # sample S times from each marginal q(z_j|x_n)
    qz_params_expanded = qz_params.view(N, K, 1, nparams).expand(N, K, S, nparams)
    qz_samples = vae.q_dist.sample(params=qz_params_expanded)
    qz_samples = qz_samples.transpose(0, 1).contiguous().view(K, N * S)

    print('Estimating entropies.')
    marginal_entropies, joint_entropy = estimate_entropies(qz_samples, qz_params, vae.q_dist)

    # E[-log q(z|x)]: closed form when the distribution provides NLL, Monte Carlo otherwise
    if hasattr(vae.q_dist, 'NLL'):
        nlogqz_condx = vae.q_dist.NLL(qz_params).mean(0)
    else:
        nlogqz_condx = - vae.q_dist.log_density(qz_samples,
            qz_params_expanded.transpose(0, 1).contiguous().view(K, N * S)).mean(1)

    # E[-log p(z)]: same closed-form / Monte Carlo split for the prior
    if hasattr(vae.prior_dist, 'NLL'):
        pz_params = vae._get_prior_params(N * K).contiguous().view(N, K, -1)
        nlogpz = vae.prior_dist.NLL(pz_params, qz_params).mean(0)
    else:
        nlogpz = - vae.prior_dist.log_density(qz_samples.transpose(0, 1)).mean(0)

    nlogqz_condx = nlogqz_condx.data
    nlogpz = nlogpz.data

    # Independence term
    # KL(q(z)||prod_j q(z_j)) = log q(z) - sum_j log q(z_j)
    dependence = (- joint_entropy + marginal_entropies.sum())[0]

    # Information term
    # KL(q(z|x)||q(z)) = log q(z|x) - log q(z)
    information = (- nlogqz_condx.sum() + joint_entropy)[0]

    # Dimension-wise KL term
    # sum_j KL(q(z_j)||p(z_j)) = sum_j (log q(z_j) - log p(z_j))
    dimwise_kl = (- marginal_entropies + nlogpz).sum()

    # Compute sum of terms analytically
    # KL(q(z|x)||p(z)) = log q(z|x) - log p(z)
    analytical_cond_kl = (- nlogqz_condx + nlogpz).sum()

    print('Dependence: {}'.format(dependence))
    print('Information: {}'.format(information))
    print('Dimension-wise KL: {}'.format(dimwise_kl))
    print('Analytical E_p(x)[ KL(q(z|x)||p(z)) ]: {}'.format(analytical_cond_kl))
    print('Estimated ELBO: {}'.format(logpx - analytical_cond_kl))

    return logpx, dependence, information, dimwise_kl, analytical_cond_kl, marginal_entropies, joint_entropy


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('-checkpt', required=True)
    parser.add_argument('-save', type=str, default='.')
    parser.add_argument('-gpu', type=int, default=0)
    args = parser.parse_args()

    def load_model_and_dataset(checkpt_filename):
        """Rebuild the VAE and its data loader from a training checkpoint.

        Raises:
            ValueError: if the checkpoint's ``dist`` is not 'normal', 'laplace' or 'flow'.
        """
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpt = torch.load(checkpt_filename)
        args = checkpt['args']
        state_dict = checkpt['state_dict']

        # backwards compatibility: older checkpoints have no `conv` attribute
        if not hasattr(args, 'conv'):
            args.conv = False

        from vae_quant import VAE, setup_data_loaders

        # model
        if args.dist == 'normal':
            prior_dist = dist.Normal()
            q_dist = dist.Normal()
        elif args.dist == 'laplace':
            prior_dist = dist.Laplace()
            q_dist = dist.Laplace()
        elif args.dist == 'flow':
            prior_dist = flows.FactorialNormalizingFlow(dim=args.latent_dim, nsteps=32)
            q_dist = dist.Normal()
        else:
            # Without this branch prior_dist/q_dist would be unbound below.
            raise ValueError('Unknown latent distribution: {}'.format(args.dist))
        vae = VAE(z_dim=args.latent_dim, use_cuda=True, prior_dist=prior_dist, q_dist=q_dist, conv=args.conv)
        vae.load_state_dict(state_dict, strict=False)
        vae.eval()

        # dataset loader
        loader = setup_data_loaders(args, use_cuda=True)
        return vae, loader

    torch.cuda.set_device(args.gpu)
    vae, dataset_loader = load_model_and_dataset(args.checkpt)
    logpx, dependence, information, dimwise_kl, analytical_cond_kl, marginal_entropies, joint_entropy = \
        elbo_decomposition(vae, dataset_loader)
    torch.save({
        'logpx': logpx,
        'dependence': dependence,
        'information': information,
        'dimwise_kl': dimwise_kl,
        'analytical_cond_kl': analytical_cond_kl,
        'marginal_entropies': marginal_entropies,
        'joint_entropy': joint_entropy
    }, os.path.join(args.save, 'elbo_decomposition.pth'))

# --------------------------------------------------------------------------------
# /baseline_models/beta-tcvae/disentanglement_metrics.py:
# --------------------------------------------------------------------------------
import math
import os
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.autograd import Variable

import lib.utils as utils
from metric_helpers.loader import load_model_and_dataset
from metric_helpers.mi_metric import compute_metric_shapes, compute_metric_faces


def estimate_entropies(qz_samples, qz_params, q_dist, n_samples=10000, weights=None):
    """Estimate the per-dimension marginal entropies
        E_{p(x)} E_{q(z_j|x)} [-log q(z_j)]
    where q(z) = 1/N sum_n=1^N q(z|x_n).
    Assumes samples are from q(z|x) for *all* x in the dataset.
    Assumes that q(z|x) is factorial ie. q(z|x) = prod_j q(z_j|x).

    Computes numerically stable NLL:
        - log q(z_j) = log N - logsumexp_n=1^N log q(z_j|x_n)

    Inputs:
    -------
        qz_samples (K, N) Variable
        qz_params  (N, K, nparams) Variable
        q_dist     distribution object with ``nparams`` and ``log_density``
        n_samples  number of sample columns kept for the estimate
        weights    optional (N) Variable of non-negative mixture weights; when
                   given, columns are drawn with replacement proportionally

    Returns:
    --------
        entropies (K,) CUDA tensor, averaged over the kept sample columns
    """

    # Only take a sample subset of the samples
    if weights is None:
        qz_samples = qz_samples.index_select(1, Variable(torch.randperm(qz_samples.size(1))[:n_samples].cuda()))
    else:
        sample_inds = torch.multinomial(weights, n_samples, replacement=True)
        qz_samples = qz_samples.index_select(1, sample_inds)

    K, S = qz_samples.size()
    N, _, nparams = qz_params.size()
    assert(nparams == q_dist.nparams)
    assert(K == qz_params.size(1))

    # log mixture weight of each q(z|x_n): uniform 1/N unless weights are given
    if weights is None:
        weights = -math.log(N)
    else:
        weights = torch.log(weights.view(N, 1, 1) / weights.sum())

    entropies = torch.zeros(K).cuda()

    pbar = tqdm(total=S)
    k = 0
    while k < S:
        # evaluate log q(z_j | x_n) for all n on at most 10 sample columns at a time
        batch_size = min(10, S - k)
        logqz_i = q_dist.log_density(
            qz_samples.view(1, K, S).expand(N, K, S)[:, :, k:k + batch_size],
            qz_params.view(N, K, 1, nparams).expand(N, K, S, nparams)[:, :, k:k + batch_size])
        k += batch_size

        # computes - log q(z_i) summed over minibatch
        entropies += - utils.logsumexp(logqz_i + weights, dim=0, keepdim=False).data.sum(1)
        pbar.update(batch_size)
    pbar.close()

    entropies /= S

    return entropies


def _encode_dataset(vae, dataset_loader, N, K, nparams):
    """Encode every batch (in loader order) and stack q(z|x) params into an (N, K, nparams) tensor."""
    qz_params = torch.Tensor(N, K, nparams)
    n = 0
    for xs in dataset_loader:
        batch_size = xs.size(0)
        xs = Variable(xs.view(batch_size, 1, 64, 64).cuda(), volatile=True)
        qz_params[n:n + batch_size] = vae.encoder.forward(xs).view(batch_size, vae.z_dim, nparams).data
        n += batch_size
    return qz_params


def _mean_conditional_entropy(vae, qz_samples, qz_params, factor_dim, N, K, nparams):
    """Average over the values of the factor on grid axis `factor_dim` of the
    (K,) entropies estimated from the samples sharing that factor value."""
    n_values = qz_params.size(factor_dim)
    total = torch.zeros(K)
    for i in range(n_values):
        # select(dim, i) is the same slice as indexing `i` on that axis
        samples_i = qz_samples.select(factor_dim, i).contiguous()
        params_i = qz_params.select(factor_dim, i).contiguous()

        cond_entropies_i = estimate_entropies(
            samples_i.view(N // n_values, K).transpose(0, 1),
            params_i.view(N // n_values, K, nparams),
            vae.q_dist)

        total += cond_entropies_i.cpu() / n_values
    return total


def mutual_info_metric_shapes(vae, shapes_dataset):
    """Compute the mutual-information metric on the shapes dataset.

    The dataset must enumerate the 3 x 6 x 40 x 32 x 32 factor grid in order
    (it is loaded unshuffled and reshaped onto that grid). Conditional
    entropies are estimated for grid axes 1-4 (scale, orientation, pos x, pos y).

    Returns:
        (metric, marginal_entropies (K,), cond_entropies (4, K))
    """
    dataset_loader = DataLoader(shapes_dataset, batch_size=1000, num_workers=1, shuffle=False)

    N = len(dataset_loader.dataset)  # number of data samples
    K = vae.z_dim  # number of latent variables
    nparams = vae.q_dist.nparams
    vae.eval()

    print('Computing q(z|x) distributions.')
    qz_params = _encode_dataset(vae, dataset_loader, N, K, nparams)

    qz_params = Variable(qz_params.view(3, 6, 40, 32, 32, K, nparams).cuda())
    qz_samples = vae.q_dist.sample(params=qz_params)

    print('Estimating marginal entropies.')
    # marginal entropies
    marginal_entropies = estimate_entropies(
        qz_samples.view(N, K).transpose(0, 1),
        qz_params.view(N, K, nparams),
        vae.q_dist)

    marginal_entropies = marginal_entropies.cpu()
    cond_entropies = torch.zeros(4, K)

    print('Estimating conditional entropies for scale.')
    cond_entropies[0] = _mean_conditional_entropy(vae, qz_samples, qz_params, 1, N, K, nparams)

    print('Estimating conditional entropies for orientation.')
    cond_entropies[1] = _mean_conditional_entropy(vae, qz_samples, qz_params, 2, N, K, nparams)

    print('Estimating conditional entropies for pos x.')
    cond_entropies[2] = _mean_conditional_entropy(vae, qz_samples, qz_params, 3, N, K, nparams)

    print('Estimating conditional entropies for pox y.')
    cond_entropies[3] = _mean_conditional_entropy(vae, qz_samples, qz_params, 4, N, K, nparams)

    metric = compute_metric_shapes(marginal_entropies, cond_entropies)
    return metric, marginal_entropies, cond_entropies


def mutual_info_metric_faces(vae, shapes_dataset):
    """Compute the mutual-information metric on the faces dataset.

    The dataset must enumerate the 50 x 21 x 11 x 11 factor grid in order
    (it is loaded unshuffled and reshaped onto that grid). Conditional
    entropies are estimated for grid axes 1-3 (azimuth, elevation, lighting).

    Returns:
        (metric, marginal_entropies (K,), cond_entropies (3, K))
    """
    dataset_loader = DataLoader(shapes_dataset, batch_size=1000, num_workers=1, shuffle=False)

    N = len(dataset_loader.dataset)  # number of data samples
    K = vae.z_dim  # number of latent variables
    nparams = vae.q_dist.nparams
    vae.eval()

    print('Computing q(z|x) distributions.')
    qz_params = _encode_dataset(vae, dataset_loader, N, K, nparams)

    qz_params = Variable(qz_params.view(50, 21, 11, 11, K, nparams).cuda())
    qz_samples = vae.q_dist.sample(params=qz_params)

    print('Estimating marginal entropies.')
    # marginal entropies
    marginal_entropies = estimate_entropies(
        qz_samples.view(N, K).transpose(0, 1),
        qz_params.view(N, K, nparams),
        vae.q_dist)

    marginal_entropies = marginal_entropies.cpu()
    cond_entropies = torch.zeros(3, K)

    print('Estimating conditional entropies for azimuth.')
    cond_entropies[0] = _mean_conditional_entropy(vae, qz_samples, qz_params, 1, N, K, nparams)

    print('Estimating conditional entropies for elevation.')
    cond_entropies[1] = _mean_conditional_entropy(vae, qz_samples, qz_params, 2, N, K, nparams)

    print('Estimating conditional entropies for lighting.')
    cond_entropies[2] = _mean_conditional_entropy(vae, qz_samples, qz_params, 3, N, K, nparams)

    metric = compute_metric_faces(marginal_entropies, cond_entropies)
    return metric, marginal_entropies, cond_entropies


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpt', required=True)
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--save', type=str, default='.')
    args = parser.parse_args()

    if args.gpu != 0:
        torch.cuda.set_device(args.gpu)
    vae, dataset, cpargs = load_model_and_dataset(args.checkpt)
    # Explicit dispatch instead of eval(): the dataset name is read from the checkpoint file.
    metric_fns = {
        'shapes': mutual_info_metric_shapes,
        'faces': mutual_info_metric_faces,
    }
    if cpargs.dataset not in metric_fns:
        raise ValueError('No mutual-information metric for dataset: {}'.format(cpargs.dataset))
    metric, marginal_entropies, cond_entropies = metric_fns[cpargs.dataset](vae, dataset)
    torch.save({
        'metric': metric,
        'marginal_entropies': marginal_entropies,
        'cond_entropies': cond_entropies,
    }, os.path.join(args.save, 'disentanglement_metric.pth'))
    print('MIG: {:.2f}'.format(metric))

# --------------------------------------------------------------------------------
# /baseline_models/icebeem/models/nets.py:
# --------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
from torch import nn


class smoothReLU(nn.Module):
    """
    smooth ReLU activation function: x * sigmoid(beta * x)
    """

    def __init__(self, beta=1):
        super().__init__()
        # Previously hard-coded to 1, silently ignoring the `beta` argument.
        self.beta = beta

    def forward(self, x):
        return x / (1 + torch.exp(-self.beta * x))


class LeafParam(nn.Module):
    """
    just ignores the input and outputs a parameter tensor,
    expanded to shape (x.size(0), n)
    """

    def __init__(self, n):
        super().__init__()
        self.p = nn.Parameter(torch.zeros(1, n))

    def forward(self, x):
        return self.p.expand(x.size(0), self.p.size(1))


class PositionalEncoder(nn.Module):
    """
    Each dimension of the input gets expanded out with sins/coses
    to "carve" out the space. Useful in low-dimensional cases with
    tightly "curled up" data. Output width is nin * 2 * len(freqs).
    """

    def __init__(self, freqs=(.5, 1, 2, 4, 8)):
        super().__init__()
        self.freqs = freqs

    def forward(self, x):
        sines = [torch.sin(x * f) for f in self.freqs]
        coses = [torch.cos(x * f) for f in self.freqs]
        out = torch.cat(sines + coses, dim=1)
        return out


class MLP4(nn.Module):
    """ a simple 4-layer MLP with LeakyReLU(0.2) activations """

    def __init__(self, nin, nout, nh):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(nin, nh),
            nn.LeakyReLU(0.2),
            nn.Linear(nh, nh),
            nn.LeakyReLU(0.2),
            nn.Linear(nh, nh),
            nn.LeakyReLU(0.2),
            nn.Linear(nh, nout),
        )

    def forward(self, x):
        return self.net(x)


class PosEncMLP(nn.Module):
    """
    Position Encoded MLP4, where the first layer performs position encoding.
    Each dimension of the input gets transformed to len(freqs)*2 dimensions
    using a fixed transformation of sin/cos of given frequencies.
    """

    def __init__(self, nin, nout, nh, freqs=(.5, 1, 2, 4, 8)):
        super().__init__()
        self.net = nn.Sequential(
            PositionalEncoder(freqs),
            MLP4(nin * len(freqs) * 2, nout, nh),
        )

    def forward(self, x):
        return self.net(x)


class MLPlayer(nn.Module):
    """
    implement basic module for MLP: optional BatchNorm on the input,
    then a linear layer, then the activation.

    by default (output_size=None) this module keeps the dimensions fixed,
    mapping a vector of dimension input_size to another of dimension input_size
    """

    def __init__(self, input_size, output_size=None, activation_function=nn.functional.relu, use_bn=False):
        super().__init__()
        if output_size is None:
            output_size = input_size
        self.activation_function = activation_function
        self.linear_layer = nn.Linear(input_size, output_size)
        self.use_bn = use_bn
        # created even when use_bn is False, so state_dict keys stay stable
        self.bn_layer = nn.BatchNorm1d(input_size)

    def forward(self, x):
        if self.use_bn:
            x = self.bn_layer(x)
        linear_act = self.linear_layer(x)
        H_x = self.activation_function(linear_act)
        return H_x


class MLP(nn.Module):
    """
    define a MLP network - this is a more general class than MLP4 above, allows for user to specify
    the dimensions at each layer of the network
    """

    def __init__(self, input_size, hidden_size, n_layers, output_size=None, activation_function=F.relu, use_bn=False):
        """
        Input:
            - input_size  : dimension of input data (e.g., 784 for MNIST)
            - hidden_size : list of hidden representations, one entry per layer
            - n_layers    : number of hidden layers; the first is a plain linear map,
                            the remaining n_layers - 1 are MLPlayer modules
            - output_size : output dimension, 1 when None (scalar log-density)
        """
        super().__init__()

        if output_size is None:
            output_size = 1  # because we approximating a log density, output should be scalar!

        self.use_bn = use_bn
        self.activation_function = activation_function
        self.linear1st = nn.Linear(input_size, hidden_size[0])  # map from data dim to dimension of hidden units
        self.Layers = nn.ModuleList([MLPlayer(hidden_size[i - 1], hidden_size[i],
                                              activation_function=self.activation_function, use_bn=self.use_bn) for i in
                                     range(1, n_layers)])
        self.linearLast = nn.Linear(hidden_size[-1],
                                    output_size)  # map from dimension of hidden units to dimension of output

    def forward(self, x):
        """
        forward pass: first linear map, hidden MLPlayers, final linear map
        """
        x = self.linear1st(x)
        for current_layer in self.Layers:
            x = current_layer(x)
        x = self.linearLast(x)
        return x


class CleanMLP(nn.Module):
    """MLP with n_hidden equal-width hidden layers, each followed by the
    activation ('lrelu' or 'relu') and optionally BatchNorm1d."""

    def __init__(self, input_size, hidden_size, n_hidden, output_size, activation='lrelu', batch_norm=False):
        super().__init__()

        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.n_hidden = n_hidden
        self.activation = activation
        self.batch_norm = batch_norm

        if activation == 'lrelu':
            act = nn.LeakyReLU(0.2, inplace=True)
        elif activation == 'relu':
            act = nn.ReLU()
        else:
            raise ValueError('wrong activation')

        # construct model; `batch_norm * [...]` yields [] or a single BatchNorm1d
        if n_hidden == 0:
            modules = [nn.Linear(input_size, output_size)]
        else:
            modules = [nn.Linear(input_size, hidden_size), act] + batch_norm * [nn.BatchNorm1d(hidden_size)]

            for i in range(n_hidden - 1):
                modules += [nn.Linear(hidden_size, hidden_size), act] + batch_norm * [nn.BatchNorm1d(hidden_size)]

            modules += [nn.Linear(hidden_size, output_size)]

        self.net = nn.Sequential(*modules)

    def forward(self, x, y=None):
        # `y` is accepted for interface compatibility and ignored
        return self.net(x)


class SimpleLinear(nn.Linear):
    """
    a wrapper around nn.Linear (bias disabled by default) that defines
    custom input_size / output_size fields
    """

    def __init__(self, nin, nout, bias=False):
        super().__init__(nin, nout, bias=bias)
        self.input_size = nin
        self.output_size = nout


class FullMLP(nn.Module):
    """Fully-connected network on flattened images; the output width equals
    the input width unless config.model.final_layer selects feature_size."""

    def __init__(self, config):
        super().__init__()
        self.num_classes = config.model.num_classes
        self.image_size = config.data.image_size
        self.n_channels = config.data.channels
        self.ngf = ngf = config.model.ngf

        self.input_size = config.data.image_size ** 2 * config.data.channels
        self.output_size = self.input_size
        if config.model.final_layer:
            self.output_size = config.model.feature_size

        self.linear = nn.Sequential(
            nn.Linear(self.input_size, ngf * 8),
            nn.LeakyReLU(inplace=True, negative_slope=.1),
            nn.Linear(ngf * 8, ngf * 6),
            nn.LeakyReLU(inplace=True, negative_slope=.1),
            nn.Dropout(p=0.1),
            nn.Linear(ngf * 6, ngf * 4),
            nn.LeakyReLU(inplace=True, negative_slope=.1),
            nn.Linear(ngf * 4, ngf * 4),
            nn.LeakyReLU(inplace=True, negative_slope=.1),
            nn.Linear(ngf * 4, self.output_size)
        )

    def forward(self, x):
        output = x.view(x.shape[0], -1)
        output = self.linear(output)
        return output


class ConvMLP(nn.Module):
    """Convolutional feature extractor followed by a small MLP head."""

    def __init__(self, config):
        super().__init__()
        self.num_classes = config.model.num_classes
        self.image_size = im = config.data.image_size
        self.n_channels = nc = config.data.channels
        self.ngf = ngf = config.model.ngf

        self.input_size = config.data.image_size ** 2 * config.data.channels
        self.output_size = self.input_size
        if config.model.final_layer:
            self.output_size = config.model.feature_size

        # convolutional bit is [(conv, bn, relu, maxpool)*2, resize_conv)
        self.conv = nn.Sequential(
            # input is (nc, im, im)
            nn.Conv2d(nc, ngf // 2, 3, 1, 1),  # (ngf/2, im, im)
            nn.BatchNorm2d(ngf // 2),  # (ngf/2, im, im)
            nn.ReLU(inplace=True),  # (ngf/2, im, im)
            nn.Conv2d(ngf // 2, ngf, 3, 1, 1),  # (ngf, im, im)
            nn.BatchNorm2d(ngf),  # (ngf, im, im)
            nn.ReLU(inplace=True),  # (ngf, im, im)
            nn.MaxPool2d(kernel_size=2, stride=2),  # (ngf, im/2, im/2)
            nn.Conv2d(ngf, ngf * 2, 3, 1, 1),  # (ngf*2, im/2, im/2)
            nn.BatchNorm2d(ngf * 2),  # (ngf*2, im/2, im/2)
            nn.ReLU(inplace=True),  # (ngf*2, im/2, im/2)
            nn.Conv2d(ngf * 2, ngf * 4, 3, 1, 1),  # (ngf*4, im/2, im/2)
            nn.BatchNorm2d(ngf * 4),  # (ngf*4, im/2, im/2)
            nn.ReLU(inplace=True),  # (ngf*4, im/2, im/2)
            nn.MaxPool2d(kernel_size=2, stride=2),  # (ngf*4, im/4, im/4)
            nn.Conv2d(ngf * 4, ngf * 4, im // 4, 1, 0)  # (ngf*4, 1, 1)
        )
        # linear bit is [drop, (lin, lrelu)*2, lin]
        self.linear = nn.Sequential(
            nn.Dropout(p=0.1),
            nn.Linear(ngf * 4, ngf * 4),
            nn.LeakyReLU(inplace=True, negative_slope=.1),
            nn.Linear(ngf * 4, self.output_size)
        )

    def forward(self, x):
        # Flatten per sample; squeeze() would also drop the batch dim when it is 1.
        h = self.conv(x).view(x.size(0), -1)
        output = self.linear(h)
        return output

# --------------------------------------------------------------------------------