├── .gitmodules ├── care_nl_ica ├── __init__.py ├── losses │ ├── __init__.py │ ├── dep_mat.py │ ├── utils.py │ └── dsm.py ├── models │ ├── __init__.py │ ├── nflib │ │ └── __init__.py │ ├── ivae │ │ ├── __init__.py │ │ └── ivae_wrapper.py │ ├── tcl │ │ ├── __init__.py │ │ ├── tcl_preprocessing.py │ │ └── tcl_eval.py │ ├── sparsity.py │ ├── model.py │ └── icebeem_wrapper.py ├── independence │ ├── __init__.py │ ├── indep_check.py │ └── hsic.py ├── cl_ica │ ├── datasets │ │ ├── __init__.py │ │ └── simple_image_dataset.py │ ├── kitti_masks │ │ ├── __init__.py │ │ ├── mcc_metric │ │ │ ├── __init__.py │ │ │ └── metric.py │ │ ├── README.md │ │ ├── LICENSE │ │ ├── evaluate_disentanglement.py │ │ ├── model.py │ │ └── solver.py │ ├── tools │ │ └── 3dident │ │ │ ├── data │ │ │ ├── materials_json │ │ │ │ ├── materials.json │ │ │ │ ├── materials_diff.json │ │ │ │ ├── materials_mix.json │ │ │ │ ├── materials_spec.json │ │ │ │ └── materials_cristal.json │ │ │ ├── colors_json │ │ │ │ ├── colors.json │ │ │ │ ├── colors_rgb.json │ │ │ │ ├── colors_value_green.json │ │ │ │ └── colors_saturation_green.json │ │ │ ├── textures │ │ │ │ └── grass.jpg │ │ │ ├── materials │ │ │ │ ├── Crystal.blend │ │ │ │ ├── MyMetal.blend │ │ │ │ └── Rubber.blend │ │ │ ├── scenes │ │ │ │ ├── base_scene.blend │ │ │ │ ├── base_scene_old.blend │ │ │ │ ├── base_scene_spot.blend │ │ │ │ ├── base_scene_simple.blend │ │ │ │ ├── base_scene_equal_xyz.blend │ │ │ │ └── base_scene_equal_xyz.blend1 │ │ │ ├── shapes │ │ │ │ ├── ShapeCube.blend │ │ │ │ ├── ShapeSphere.blend │ │ │ │ ├── ShapeTeapot.blend │ │ │ │ └── ShapeCylinder.blend │ │ │ ├── node_groups │ │ │ │ ├── NodeGroup.blend │ │ │ │ └── NodeGroupMulti4.blend │ │ │ ├── CoGenT_A.json │ │ │ ├── CoGenT_B.json │ │ │ └── properties.json │ │ │ ├── NOTICE │ │ │ ├── get_mean_std.py │ │ │ ├── ORIGINAL_LICENSE │ │ │ └── ORIGINAL_PATENTS │ ├── __init__.py │ ├── infinite_iterator.py │ ├── docker │ │ ├── tmux.conf │ │ ├── Dockerfile.blender │ │ ├── entrypoint.sh │ │ └── 
Dockerfile │ ├── LICENSE │ ├── latent_spaces.py │ ├── layers.py │ └── spaces_utils.py ├── metrics │ ├── __init__.py │ ├── metric_logger.py │ └── dep_mat.py ├── data │ ├── __init__.py │ └── utils.py ├── logger.py ├── dataset.py ├── ica.py ├── cli.py ├── graph_utils.py ├── prob_utils.py └── utils.py ├── scripts ├── fix_no_gpu.sh ├── cancel_all_jobs.sh ├── get_used_space.sh ├── sourcing_script.sh ├── sweeps.sh ├── start_interactive_job.sh ├── profile.sh ├── wandb_sweep.sh ├── start_preemptable_job.sh ├── start_preemptable_cli.sh ├── train_loop.sh ├── loss_compare_for_permutations.sh ├── run_singularity_server.sh ├── nv.def ├── interactive_job_inner.sh └── exclude_nodes.sh ├── configs ├── data │ ├── chain.yaml │ ├── permute.yaml │ ├── nl_sem.yaml │ └── uniform_weights.yaml ├── profile.yaml ├── cpu.yaml └── config.yaml ├── setup.py ├── notebooks ├── monti_sweep_3rqfxkyl.npz ├── monti_sweep_70zssxmx.npz ├── monti_sweep_77huh2ue.npz ├── sem_10d_sweep_7lsb5ud3.npz ├── sem_10d_sweep_i871qu61.npz ├── sem_3d_sweep_mmzkkmw4.npz ├── sem_3d_sweep_vfv1je0d.npz ├── sem_5d_sweep_f5nxtdxz.npz ├── sem_5d_sweep_h6y1gkvo.npz ├── sem_5d_sweep_kdau2seo.npz ├── sem_8d_sweep_7sscc3w1.npz ├── sem_8d_sweep_v3kd7kca.npz ├── sem_10d_sparse_sweep_t7rmrux1.npz ├── sem_3d_permute_sweep_2mgctqko.npz ├── sem_3d_sparse_sweep_lm8s890w.npz ├── sem_5d_munkres_sweep_kdau2seo.npz ├── sem_5d_permute_sweep_rv7yo1qy.npz ├── sem_5d_permute_sweep_x6chdc63.npz ├── sem_5d_sparse_sweep_7n108utd.npz ├── sem_8d_permute_sweep_05whlpmk.npz ├── sem_8d_permute_sweep_l49b2vhx.npz ├── sem_8d_sparse_sweep_2ykg2w21.npz ├── sem_10d_permute_sweep_291qhry5.npz ├── sem_10d_permute_sweep_at138q9q.npz ├── sem_5d_munkres_sweep_kdau2seo.csv ├── sem_3d_sparse_sweep_lm8s890w.csv ├── sem_3d_sweep_mmzkkmw4.csv ├── sem_5d_sweep_kdau2seo.csv ├── sem_8d_sweep_v3kd7kca.csv ├── sem_10d_sparse_sweep_t7rmrux1.csv ├── sem_5d_sparse_sweep_7n108utd.csv ├── sem_8d_sparse_sweep_2ykg2w21.csv ├── sem_10d_sweep_i871qu61.csv ├── 
monti_sweep_77huh2ue.csv ├── monti_sweep_3rqfxkyl.csv ├── sem_3d_sweep_vfv1je0d.csv ├── sem_10d_permute_sweep_at138q9q.csv ├── sem_8d_permute_sweep_05whlpmk.csv ├── sem_5d_sweep_f5nxtdxz.csv ├── sem_5d_permute_sweep_rv7yo1qy.csv └── sem_5d_permute_sweep_x6chdc63.csv ├── tests ├── requirements.txt ├── test_runner.py ├── test_dataset.py ├── test_datamodule.py ├── test_sinkhorn.py ├── test_metrics.py ├── test_dep_mat.py ├── conftest.py └── test_graph_utils.py ├── .gitignore ├── requirements.txt ├── setup.cfg ├── .pre-commit-config.yaml ├── sweeps ├── sem │ ├── mlp_sem3.yaml │ ├── mlp_sem5.yaml │ ├── mlp_sem3_no_permute.yaml │ ├── mlp_sem5_no_permute.yaml │ ├── mlp_sem8.yaml │ ├── mlp_sem8_no_permute.yaml │ ├── mlp_sem10_no_permute.yaml │ ├── mlp_sem15_no_permute.yaml │ ├── mlp_sem10.yaml │ ├── mlp_sem5_sparse.yaml │ ├── mlp_sem3_no_permute_sparse.yaml │ ├── mlp_sem5_no_permute_sparse.yaml │ ├── mlp_sem8_no_permute_sparse.yaml │ ├── mlp_sem8_sparse.yaml │ ├── mlp_sem10_no_permute_sparse.yaml │ └── mlp_sem10_sparse.yaml └── tcl │ ├── mlp_ar.yaml │ ├── mlp_ar_adaptive_offset.yaml │ └── mlp_ar_adaptive_offset_sparse.yaml ├── CITATION.cff ├── LICENSE ├── .github └── workflows │ └── python-package.yml └── README.md /.gitmodules: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /care_nl_ica/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /care_nl_ica/losses/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /care_nl_ica/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 
/care_nl_ica/independence/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/kitti_masks/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /scripts/fix_no_gpu.sh: -------------------------------------------------------------------------------- 1 | unset SLURM_NTASKS_PER_NODE -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/kitti_masks/mcc_metric/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /configs/data/chain.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | force_chain: true 3 | -------------------------------------------------------------------------------- /configs/data/permute.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | permute: true 3 | -------------------------------------------------------------------------------- /configs/data/nl_sem.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | nonlin_sem: true 3 | 4 | -------------------------------------------------------------------------------- /configs/data/uniform_weights.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | force_uniform: true 3 | -------------------------------------------------------------------------------- 
/scripts/cancel_all_jobs.sh: -------------------------------------------------------------------------------- 1 | squeue --me | awk 'NR>1 {print $1}' | xargs scancel -------------------------------------------------------------------------------- /care_nl_ica/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .mcc import mean_corr_coef 2 | 3 | __all__ = ["mcc"] 4 | -------------------------------------------------------------------------------- /care_nl_ica/models/nflib/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ["conditional_flows", "flows", "spline_flows"] 2 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup(packages=find_packages()) 4 | -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/data/materials_json/materials.json: -------------------------------------------------------------------------------- 1 | { 2 | "mat": ["Rubber", "Rubber", "Rubber"] 3 | } -------------------------------------------------------------------------------- /scripts/get_used_space.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | awk '/[0-9]+,preizinger/{print}' /mnt/qb/work/bethge/bethge_user_disk_usage_MB.log -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/data/colors_json/colors.json: -------------------------------------------------------------------------------- 1 | { 2 | "rgb": [[0, 0, 255], [255, 255, 255], [255, 0, 0]] 3 | } -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/data/materials_json/materials_diff.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "mat": ["Rubber", "Rubber", "Rubber"] 3 | } -------------------------------------------------------------------------------- /care_nl_ica/models/ivae/__init__.py: -------------------------------------------------------------------------------- 1 | from .ivae_wrapper import IVAE_wrapper 2 | 3 | __all__ = ["ivae_wrapper", "ivae_core"] 4 | -------------------------------------------------------------------------------- /care_nl_ica/models/tcl/__init__.py: -------------------------------------------------------------------------------- 1 | from .tcl_wrapper_gpu import train 2 | 3 | __all__ = ["tcl_wrapper_gpu", "tcl_core"] 4 | -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/data/colors_json/colors_rgb.json: -------------------------------------------------------------------------------- 1 | { 2 | "rgb": [[255, 0, 0], [0, 255, 0], [0, 0, 255]] 3 | } 4 | -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/data/materials_json/materials_mix.json: -------------------------------------------------------------------------------- 1 | { 2 | "mat": ["Crystal", "MyMetal", "Rubber"] 3 | } 4 | -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/data/materials_json/materials_spec.json: -------------------------------------------------------------------------------- 1 | { 2 | "mat": ["MyMetal", "MyMetal", "MyMetal"] 3 | } 4 | -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/data/materials_json/materials_cristal.json: -------------------------------------------------------------------------------- 1 | { 2 | "mat": ["Crystal", "Crystal", "Crystal"] 3 | } 4 | -------------------------------------------------------------------------------- 
/care_nl_ica/cl_ica/tools/3dident/data/colors_json/colors_value_green.json: -------------------------------------------------------------------------------- 1 | { 2 | "rgb": [[0, 255, 42], [0, 178, 29], [0, 76, 12]] 3 | } 4 | -------------------------------------------------------------------------------- /notebooks/monti_sweep_3rqfxkyl.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/notebooks/monti_sweep_3rqfxkyl.npz -------------------------------------------------------------------------------- /notebooks/monti_sweep_70zssxmx.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/notebooks/monti_sweep_70zssxmx.npz -------------------------------------------------------------------------------- /notebooks/monti_sweep_77huh2ue.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/notebooks/monti_sweep_77huh2ue.npz -------------------------------------------------------------------------------- /notebooks/sem_10d_sweep_7lsb5ud3.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/notebooks/sem_10d_sweep_7lsb5ud3.npz -------------------------------------------------------------------------------- /notebooks/sem_10d_sweep_i871qu61.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/notebooks/sem_10d_sweep_i871qu61.npz -------------------------------------------------------------------------------- /notebooks/sem_3d_sweep_mmzkkmw4.npz: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/notebooks/sem_3d_sweep_mmzkkmw4.npz -------------------------------------------------------------------------------- /notebooks/sem_3d_sweep_vfv1je0d.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/notebooks/sem_3d_sweep_vfv1je0d.npz -------------------------------------------------------------------------------- /notebooks/sem_5d_sweep_f5nxtdxz.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/notebooks/sem_5d_sweep_f5nxtdxz.npz -------------------------------------------------------------------------------- /notebooks/sem_5d_sweep_h6y1gkvo.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/notebooks/sem_5d_sweep_h6y1gkvo.npz -------------------------------------------------------------------------------- /notebooks/sem_5d_sweep_kdau2seo.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/notebooks/sem_5d_sweep_kdau2seo.npz -------------------------------------------------------------------------------- /notebooks/sem_8d_sweep_7sscc3w1.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/notebooks/sem_8d_sweep_7sscc3w1.npz -------------------------------------------------------------------------------- /notebooks/sem_8d_sweep_v3kd7kca.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/notebooks/sem_8d_sweep_v3kd7kca.npz 
-------------------------------------------------------------------------------- /scripts/sourcing_script.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | for f in /.singularity.d/env/*; do echo "$f"; source "$f"; done 3 | 4 | export PS1="\u@\h \W:$" -------------------------------------------------------------------------------- /scripts/sweeps.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for ((i = 1; i <= 1; i++)); 4 | do 5 | ./scripts/wandb_sweep.sh "$@" & 6 | sleep 1 7 | done 8 | -------------------------------------------------------------------------------- /care_nl_ica/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .imca import generate_synthetic_data 2 | from .utils import to_one_hot 3 | 4 | __all__ = ["imca", "utils"] 5 | -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/data/colors_json/colors_saturation_green.json: -------------------------------------------------------------------------------- 1 | { 2 | "rgb": [[0, 255, 42], [76, 255, 106], [178, 255, 191]] 3 | } 4 | -------------------------------------------------------------------------------- /notebooks/sem_10d_sparse_sweep_t7rmrux1.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/notebooks/sem_10d_sparse_sweep_t7rmrux1.npz -------------------------------------------------------------------------------- /notebooks/sem_3d_permute_sweep_2mgctqko.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/notebooks/sem_3d_permute_sweep_2mgctqko.npz -------------------------------------------------------------------------------- 
/notebooks/sem_3d_sparse_sweep_lm8s890w.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/notebooks/sem_3d_sparse_sweep_lm8s890w.npz -------------------------------------------------------------------------------- /notebooks/sem_5d_munkres_sweep_kdau2seo.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/notebooks/sem_5d_munkres_sweep_kdau2seo.npz -------------------------------------------------------------------------------- /notebooks/sem_5d_permute_sweep_rv7yo1qy.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/notebooks/sem_5d_permute_sweep_rv7yo1qy.npz -------------------------------------------------------------------------------- /notebooks/sem_5d_permute_sweep_x6chdc63.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/notebooks/sem_5d_permute_sweep_x6chdc63.npz -------------------------------------------------------------------------------- /notebooks/sem_5d_sparse_sweep_7n108utd.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/notebooks/sem_5d_sparse_sweep_7n108utd.npz -------------------------------------------------------------------------------- /notebooks/sem_8d_permute_sweep_05whlpmk.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/notebooks/sem_8d_permute_sweep_05whlpmk.npz -------------------------------------------------------------------------------- 
/notebooks/sem_8d_permute_sweep_l49b2vhx.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/notebooks/sem_8d_permute_sweep_l49b2vhx.npz -------------------------------------------------------------------------------- /notebooks/sem_8d_sparse_sweep_2ykg2w21.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/notebooks/sem_8d_sparse_sweep_2ykg2w21.npz -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/NOTICE: -------------------------------------------------------------------------------- 1 | 3DIdent is derived from CLEVR. The original rendering pipeline is distributed under the terms listed in ORIGINAL_LICENSE. -------------------------------------------------------------------------------- /configs/profile.yaml: -------------------------------------------------------------------------------- 1 | trainer: 2 | profiler: 3 | class_path: pytorch_lightning.profiler.PyTorchProfiler 4 | init_args: 5 | profile_memory: true 6 | -------------------------------------------------------------------------------- /notebooks/sem_10d_permute_sweep_291qhry5.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/notebooks/sem_10d_permute_sweep_291qhry5.npz -------------------------------------------------------------------------------- /notebooks/sem_10d_permute_sweep_at138q9q.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/notebooks/sem_10d_permute_sweep_at138q9q.npz -------------------------------------------------------------------------------- /tests/requirements.txt: 
-------------------------------------------------------------------------------- 1 | coverage 2 | codecov 3 | pytest~=6.2.5 4 | pytest-cov 5 | pytest-flake8 6 | black~=23.1.0 7 | flake8 8 | check-manifest 9 | twine==1.13.0 -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/kitti_masks/README.md: -------------------------------------------------------------------------------- 1 | Code from https://github.com/bethgelab/slow_disentanglement/blob/master/scripts 2 | to load and evaluate on the KITTI Masks dataset. -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/data/textures/grass.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/care_nl_ica/cl_ica/tools/3dident/data/textures/grass.jpg -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/data/materials/Crystal.blend: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/care_nl_ica/cl_ica/tools/3dident/data/materials/Crystal.blend -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/data/materials/MyMetal.blend: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/care_nl_ica/cl_ica/tools/3dident/data/materials/MyMetal.blend -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/data/materials/Rubber.blend: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/care_nl_ica/cl_ica/tools/3dident/data/materials/Rubber.blend -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/data/scenes/base_scene.blend: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/care_nl_ica/cl_ica/tools/3dident/data/scenes/base_scene.blend -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/data/shapes/ShapeCube.blend: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/care_nl_ica/cl_ica/tools/3dident/data/shapes/ShapeCube.blend -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/data/shapes/ShapeSphere.blend: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/care_nl_ica/cl_ica/tools/3dident/data/shapes/ShapeSphere.blend -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/data/shapes/ShapeTeapot.blend: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/care_nl_ica/cl_ica/tools/3dident/data/shapes/ShapeTeapot.blend -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/data/node_groups/NodeGroup.blend: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/care_nl_ica/cl_ica/tools/3dident/data/node_groups/NodeGroup.blend 
-------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/data/scenes/base_scene_old.blend: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/care_nl_ica/cl_ica/tools/3dident/data/scenes/base_scene_old.blend -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/data/scenes/base_scene_spot.blend: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/care_nl_ica/cl_ica/tools/3dident/data/scenes/base_scene_spot.blend -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/data/shapes/ShapeCylinder.blend: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/care_nl_ica/cl_ica/tools/3dident/data/shapes/ShapeCylinder.blend -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/data/scenes/base_scene_simple.blend: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/care_nl_ica/cl_ica/tools/3dident/data/scenes/base_scene_simple.blend -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/data/node_groups/NodeGroupMulti4.blend: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/care_nl_ica/cl_ica/tools/3dident/data/node_groups/NodeGroupMulti4.blend -------------------------------------------------------------------------------- 
/care_nl_ica/cl_ica/tools/3dident/data/scenes/base_scene_equal_xyz.blend: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/care_nl_ica/cl_ica/tools/3dident/data/scenes/base_scene_equal_xyz.blend -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/data/scenes/base_scene_equal_xyz.blend1: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/care_nl_ica/cl_ica/tools/3dident/data/scenes/base_scene_equal_xyz.blend1 -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/__init__.py: -------------------------------------------------------------------------------- 1 | from . import latent_spaces 2 | from . import spaces 3 | from . import vmf 4 | from . import spaces_utils 5 | from . 
import layers 6 | 7 | 8 | __all__ = ["latent_spaces", "spaces", "vmf", "spaces_utils", "layers"] 9 | -------------------------------------------------------------------------------- /scripts/start_interactive_job.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | PARTITION=gpu-2080ti-interactive 3 | FLAGS=--exclude=slurm-bm-06 4 | 5 | srun --job-name="$JOB_NAME" --partition=$PARTITION --pty --gres=gpu:1 "$FLAGS" -- ./scripts/interactive_job_inner.sh 6 | 7 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | venv 3 | run 4 | .ipynb_checkpoints 5 | */.ipynb_checkpoints/* 6 | __pycache__ 7 | */__pycache__/* 8 | */lightning_logs/* 9 | */wandb/* 10 | *.sif 11 | *.egg-info/* 12 | *.log 13 | *.ckpt 14 | outputs/* 15 | wandb/* 16 | */artifacts/* -------------------------------------------------------------------------------- /scripts/profile.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python3 care_nl_ica/cli.py fit --config configs/config.yaml --config configs/model/ar_non_tri.yaml --config configs/data/permute.yaml --config configs/data/nl_sem.yaml --trainer.profiler=simple --trainer.max_epochs=2 3 | 4 | -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/data/CoGenT_A.json: -------------------------------------------------------------------------------- 1 | { 2 | "cube": [ 3 | "gray", "blue", "brown", "yellow" 4 | ], 5 | "cylinder": [ 6 | "red", "green", "purple", "cyan" 7 | ], 8 | "sphere": [ 9 | "gray", "red", "blue", "green", "brown", "purple", "cyan", "yellow" 10 | ] 11 | } 12 | -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/data/CoGenT_B.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "cube": [ 3 | "red", "green", "purple", "cyan" 4 | ], 5 | "cylinder": [ 6 | "gray", "blue", "brown", "yellow" 7 | ], 8 | "sphere": [ 9 | "gray", "red", "blue", "green", "brown", "purple", "cyan", "yellow" 10 | ] 11 | } 12 | -------------------------------------------------------------------------------- /scripts/wandb_sweep.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | PARTITION=gpu-2080ti-preemptable 3 | FLAGS=--exclude= #slurm-bm-62 4 | PYTHONPATH=. srun --time=600 --job-name="$JOB_NAME" --partition=$PARTITION $FLAGS --mem=8G --gpus=1 -- /mnt/qb/work/bethge/preizinger/nl-causal-representations//scripts/run_singularity_server.sh wandb agent --count 1 "$@" 5 | 6 | -------------------------------------------------------------------------------- /tests/test_runner.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning.trainer import Trainer 2 | 3 | from care_nl_ica.data.datamodules import ContrastiveDataModule 4 | from care_nl_ica.runner import ContrastiveICAModule 5 | 6 | 7 | def test_runner(): 8 | trainer = Trainer(fast_dev_run=True) 9 | runner = ContrastiveICAModule() 10 | dm = ContrastiveDataModule() 11 | trainer.fit(runner, datamodule=dm) 12 | -------------------------------------------------------------------------------- /tests/test_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def test_contrastive_dataset(dataloader): 5 | batches = [] 6 | num_batches = 3 7 | for i in range(num_batches): 8 | batches.append(next(iter(dataloader))[1]) 9 | 10 | # calculates the variance accross the batch and n dimensions 11 | # to check that we do not get the same data 12 | torch.any(torch.stack(batches).var(0).sum([-1, -2]) > 1e-7) 13 | 
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy~=1.19.5 2 | torch~=1.11.0 3 | scipy~=1.6.3 4 | scikit-learn~=0.24.2 5 | matplotlib~=3.4.2 6 | wandb~=0.12.17 7 | h5py~=3.1.0 8 | torchmetrics~=0.7.2 9 | pytest~=6.2.4 10 | setuptools~=59.5.0 11 | pandas~=1.3.1 12 | pip~=21.1.2 13 | pytorch-lightning==1.5.10 14 | jsonargparse[signatures]~=4.3.0 15 | hydra-core~=1.1.1 16 | pre-commit 17 | omegaconf~=2.1.1 18 | pynvml~=11.4.1 19 | functorch~=0.1.0 20 | tueplots -------------------------------------------------------------------------------- /scripts/start_preemptable_job.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | PARTITION=gpu-2080ti-preemptable 3 | #FLAGS= 4 | PYTHONPATH=. srun --job-name="$JOB_NAME" --partition=$PARTITION --cpus-per-task=2 --mem=4G --gpus=1 -- /mnt/qb/work/bethge/preizinger/nl-causal-representations//scripts/run_singularity_server.sh python3 /mnt/qb/work/bethge/preizinger/nl-causal-representations/care_nl_ica/main.py --project mlp-test --use-batch-norm --use-dep-mat --use-wandb "$@" 5 | 6 | -------------------------------------------------------------------------------- /scripts/start_preemptable_cli.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | PARTITION=gpu-2080ti-preemptable 3 | #FLAGS= 4 | PYTHONPATH=. 
srun --time=360 --job-name="$JOB_NAME" --partition=$PARTITION --mem=8G --pty --gpus=1 -- scripts/run_singularity_server.sh python3 care_nl_ica/cli.py fit --config configs/config.yaml # --trainer.max_epochs=25000 --data.latent_dim=10 --model.offline=true --data.variant=4 --data.force_chain=true --data.nonlin_sem=true --data.use_sem=false --data.permute=false --data.data_gen_mode=rvs "$@" 5 | 6 | -------------------------------------------------------------------------------- /scripts/train_loop.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for seed in 84645646 4 | do 5 | for n in 4 6 | do 7 | for l1 in 0 #1e1 1 1e-1 1e-2 1e-3 #1e-1 1 1e2 #0 1e-2 3e-3 1e-3 3e-4 #0 1e-6 1e-5 1e-4 8 | do 9 | ./scripts/start_preemptable_job.sh --use-ar-mlp --seed ${seed} --n ${n} --l1 ${l1} --note "sinkhorn" --tags normalization nonlinear sem residual entropy sinkhorn --use-sem --nonlin-sem --normalize-latents --verbose --permute & 10 | sleep 20 11 | done 12 | done 13 | done 14 | -------------------------------------------------------------------------------- /tests/test_datamodule.py: -------------------------------------------------------------------------------- 1 | from care_nl_ica.data.datamodules import ContrastiveDataModule 2 | import torch 3 | 4 | 5 | def test_contrastive_datamodule(datamodule: ContrastiveDataModule): 6 | batches = [] 7 | num_batches = 3 8 | for i in range(num_batches): 9 | batches.append(next(iter(datamodule.train_dataloader()))[0][0, :]) 10 | 11 | # calculates the variance accross the batch and n dimensions 12 | # to check that we do not get the same data 13 | torch.any(torch.stack(batches).var(0).sum([-1, -2]) > 1e-7) 14 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = causalnlica 3 | version = 1.0.0 4 | author = Patrik Reizinger, Yash Sharma 5 | 
author_email = prmedia17@gmail.com 6 | description = Causal Representation learning with Nonlinear ICA 7 | long_description = file: README.md 8 | long_description_content_type = text/markdown 9 | url = https://github.com/rpatrik96/nl-causal-representations 10 | project_urls = 11 | Bug Tracker = https://github.com/rpatrik96/nl-causal-representations 12 | classifiers = 13 | Programming Language :: Python :: 3 14 | License :: OSI Approved :: MIT License 15 | Operating System :: OS Independent 16 | -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/infinite_iterator.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable 2 | 3 | 4 | class InfiniteIterator: 5 | """Infinitely repeat the iterable.""" 6 | 7 | def __init__(self, iterable: Iterable): 8 | self._iterable = iterable 9 | self.iterator = iter(self._iterable) 10 | 11 | def __iter__(self): 12 | return self 13 | 14 | def __next__(self): 15 | for _ in range(2): 16 | try: 17 | return next(self.iterator) 18 | except StopIteration: 19 | # reset iterator 20 | del self.iterator 21 | self.iterator = iter(self._iterable) 22 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black 3 | rev: 23.1.0 4 | hooks: 5 | - id: black 6 | entry: black 7 | types: [python] 8 | # It is recommended to specify the latest version of Python 9 | # supported by your project here, or alternatively use 10 | # pre-commit's default_language_version, see 11 | # https://pre-commit.com/#top_level-default_language_version 12 | language_version: python3.8 13 | - repo: local 14 | hooks: 15 | - id: tests 16 | name: Unit tests 17 | language: system 18 | entry: bash -c "source activate venv"; pytest tests 
-------------------------------------------------------------------------------- /tests/test_sinkhorn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from care_nl_ica.models.sinkhorn import SinkhornNet 4 | from care_nl_ica.utils import setup_seed 5 | 6 | 7 | def test_sinkhorn_net(): 8 | num_vars = 3 9 | num_steps = 15 10 | temperature = 3e-3 11 | threshold = 0.95 12 | 13 | setup_seed(1) 14 | 15 | sinkhorn_net = SinkhornNet(num_vars, num_steps, temperature) 16 | sinkhorn_net.weight.data = torch.randn(sinkhorn_net.weight.shape) 17 | 18 | print(sinkhorn_net.doubly_stochastic_matrix.abs()) 19 | assert torch.all( 20 | (sinkhorn_net.doubly_stochastic_matrix.abs() > threshold).sum(dim=0) 21 | == torch.ones(1, num_vars) 22 | ) 23 | -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/data/properties.json: -------------------------------------------------------------------------------- 1 | { 2 | "shapes": { 3 | "cube": "ShapeCube", 4 | "sphere": "ShapeSphere", 5 | "cylinder": "ShapeCylinder", 6 | "teapot": "ShapeTeapot" 7 | }, 8 | "colors": { 9 | "gray": [87, 87, 87], 10 | "red": [173, 35, 35], 11 | "blue": [42, 75, 215], 12 | "green": [29, 105, 20], 13 | "brown": [129, 74, 25], 14 | "purple": [129, 38, 192], 15 | "cyan": [41, 208, 208], 16 | "yellow": [255, 238, 51] 17 | }, 18 | "materials": { 19 | "rubber": "Rubber", 20 | "metal": "MyMetal", 21 | "crystal": "Crystal" 22 | }, 23 | "sizes": { 24 | "large": 0.7, 25 | "small": 0.35 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /tests/test_metrics.py: -------------------------------------------------------------------------------- 1 | from care_nl_ica.metrics.dep_mat import JacobianBinnedPrecisionRecall 2 | import torch 3 | 4 | 5 | def test_jacobian_prec_recall(): 6 | num_dim = 3 7 | num_thresholds = 15 8 | jac_pr = 
JacobianBinnedPrecisionRecall(num_thresholds=num_thresholds) 9 | target = ( 10 | ( 11 | torch.tril(torch.bernoulli(0.5 * torch.ones(num_dim, num_dim)), -1) 12 | + torch.eye(num_dim) 13 | ) 14 | .bool() 15 | .float() 16 | ) 17 | preds = torch.tril(torch.randn_like(target)) 18 | jac_pr(preds, target) 19 | precisions, recalls, thresholds = jac_pr.compute() 20 | print(precisions, recalls, thresholds) 21 | -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/docker/tmux.conf: -------------------------------------------------------------------------------- 1 | # change prefix key 2 | unbind C-b 3 | set -g prefix 'C-\' 4 | bind 'C-\' send-prefix 5 | 6 | # large history 7 | set -g history-limit 10000 8 | 9 | # shift arrow to switch windows 10 | bind -n F9 previous-window 11 | bind -n F10 next-window 12 | 13 | # THEME 14 | set -g status-bg black 15 | set -g status-fg white 16 | set -g window-status-current-bg white 17 | set -g window-status-current-fg black 18 | set -g window-status-current-attr bold 19 | set -g status-interval 60 20 | set -g status-left-length 30 21 | set -g status-left '#[fg=green](#S) #(whoami)' 22 | set -g status-right '#[fg=yellow]#(cut -d " " -f 1-3 /proc/loadavg)#[default] #[fg=white]%H:%M#[default]' -------------------------------------------------------------------------------- /care_nl_ica/losses/dep_mat.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from care_nl_ica.metrics.ica_dis import frobenius_diagonality 4 | 5 | 6 | def permutation_loss(matrix: torch.Tensor, matrix_power: bool = False): 7 | if matrix_power is False: 8 | # rows and cols sum up to 1 9 | col_sum = matrix.abs().sum(0) 10 | row_sum = matrix.abs().sum(1) 11 | loss = (col_sum - torch.ones_like(col_sum)).pow(2).mean() + ( 12 | row_sum - torch.ones_like(row_sum) 13 | ).pow(2).mean() 14 | else: 15 | # diagonality (as Q^n = I for permutation matrices) 16 | loss = 
frobenius_diagonality(matrix.matrix_power(matrix.shape[0]).abs()) 17 | 18 | return loss 19 | -------------------------------------------------------------------------------- /sweeps/sem/mlp_sem3.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - ${env} 3 | - python3 4 | - care_nl_ica/cli.py 5 | - fit 6 | - "--config" 7 | - "configs/config.yaml" 8 | - ${args} 9 | method: grid 10 | metric: 11 | goal: minimize 12 | name: val_loss 13 | parameters: 14 | trainer.max_epochs: 15 | value: 15000 16 | seed_everything: 17 | values: [42, 64, 982, 5748, 23567] 18 | model.offline: 19 | value: true 20 | data.variant: 21 | values: [0,1,2,3,4,5] 22 | data.latent_dim: 23 | value: 3 24 | data.use_sem: 25 | value: true 26 | data.nonlin_sem: 27 | values: [false, true] 28 | data.permute: 29 | value: true 30 | data.offset: 31 | value: 0 32 | 33 | 34 | 35 | 36 | -------------------------------------------------------------------------------- /sweeps/sem/mlp_sem5.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - ${env} 3 | - python3 4 | - care_nl_ica/cli.py 5 | - fit 6 | - "--config" 7 | - "configs/config.yaml" 8 | - ${args} 9 | method: grid 10 | metric: 11 | goal: minimize 12 | name: val_loss 13 | parameters: 14 | trainer.max_epochs: 15 | value: 7500 16 | seed_everything: 17 | values: [42, 64, 982, 5748, 23567] 18 | model.offline: 19 | value: true 20 | data.variant: 21 | values: [0,5,31,75, 97, 119] 22 | data.latent_dim: 23 | value: 5 24 | data.use_sem: 25 | value: true 26 | data.nonlin_sem: 27 | values: [false, true] 28 | data.permute: 29 | value: true 30 | data.offset: 31 | value: 0 32 | 33 | 34 | 35 | 36 | -------------------------------------------------------------------------------- /sweeps/sem/mlp_sem3_no_permute.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - ${env} 3 | - python3 4 | - care_nl_ica/cli.py 5 | - fit 6 
| - "--config" 7 | - "configs/config.yaml" 8 | - ${args} 9 | method: grid 10 | metric: 11 | goal: minimize 12 | name: val_loss 13 | parameters: 14 | trainer.max_epochs: 15 | value: 15000 16 | seed_everything: 17 | values: [42, 64, 982, 5748, 23567, 8734, 564, 74452, 96, 26] 18 | model.offline: 19 | value: true 20 | data.variant: 21 | value: 0 22 | data.latent_dim: 23 | value: 3 24 | data.use_sem: 25 | value: true 26 | data.nonlin_sem: 27 | values: [false, true] 28 | data.permute: 29 | value: false 30 | data.offset: 31 | value: 1 32 | 33 | 34 | 35 | 36 | -------------------------------------------------------------------------------- /sweeps/sem/mlp_sem5_no_permute.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - ${env} 3 | - python3 4 | - care_nl_ica/cli.py 5 | - fit 6 | - "--config" 7 | - "configs/config.yaml" 8 | - ${args} 9 | method: grid 10 | metric: 11 | goal: minimize 12 | name: val_loss 13 | parameters: 14 | trainer.max_epochs: 15 | value: 7500 16 | seed_everything: 17 | values: [42, 64, 982, 5748, 23567, 8734, 564, 74452, 96, 26] 18 | model.offline: 19 | value: true 20 | data.variant: 21 | value: 0 22 | data.latent_dim: 23 | value: 5 24 | data.use_sem: 25 | value: true 26 | data.nonlin_sem: 27 | values: [false, true] 28 | data.permute: 29 | value: false 30 | data.offset: 31 | value: 0 32 | 33 | 34 | 35 | 36 | -------------------------------------------------------------------------------- /sweeps/sem/mlp_sem8.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - ${env} 3 | - python3 4 | - care_nl_ica/cli.py 5 | - fit 6 | - "--config" 7 | - "configs/config.yaml" 8 | - ${args} 9 | method: grid 10 | metric: 11 | goal: minimize 12 | name: val_loss 13 | parameters: 14 | trainer.max_epochs: 15 | value: 15000 16 | seed_everything: 17 | values: [42, 64, 982, 5748, 23567] 18 | model.offline: 19 | value: true 20 | data.variant: 21 | values: [0,575, 8754, 
12450, 21345, 36755] 22 | data.latent_dim: 23 | value: 8 24 | data.use_sem: 25 | value: true 26 | data.nonlin_sem: 27 | values: [false, true] 28 | data.permute: 29 | value: true 30 | data.offset: 31 | value: 0 32 | 33 | 34 | 35 | 36 | -------------------------------------------------------------------------------- /sweeps/sem/mlp_sem8_no_permute.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - ${env} 3 | - python3 4 | - care_nl_ica/cli.py 5 | - fit 6 | - "--config" 7 | - "configs/config.yaml" 8 | - ${args} 9 | method: grid 10 | metric: 11 | goal: minimize 12 | name: val_loss 13 | parameters: 14 | trainer.max_epochs: 15 | value: 15000 16 | seed_everything: 17 | values: [42, 64, 982, 5748, 23567, 8734, 564, 74452, 96, 26] 18 | model.offline: 19 | value: true 20 | data.variant: 21 | value: 0 22 | data.latent_dim: 23 | value: 8 24 | data.use_sem: 25 | value: true 26 | data.nonlin_sem: 27 | values: [false, true] 28 | data.permute: 29 | value: false 30 | data.offset: 31 | value: 1 32 | 33 | 34 | 35 | 36 | -------------------------------------------------------------------------------- /sweeps/sem/mlp_sem10_no_permute.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - ${env} 3 | - python3 4 | - care_nl_ica/cli.py 5 | - fit 6 | - "--config" 7 | - "configs/config.yaml" 8 | - ${args} 9 | method: grid 10 | metric: 11 | goal: minimize 12 | name: val_loss 13 | parameters: 14 | trainer.max_epochs: 15 | value: 25000 16 | seed_everything: 17 | values: [42, 64, 982, 5748, 23567, 8734, 564, 74452, 96, 26] 18 | model.offline: 19 | value: true 20 | data.variant: 21 | value: 0 22 | data.latent_dim: 23 | value: 10 24 | data.use_sem: 25 | value: true 26 | data.nonlin_sem: 27 | values: [false, true] 28 | data.permute: 29 | value: false 30 | data.offset: 31 | value: 1 32 | 33 | 34 | 35 | 36 | -------------------------------------------------------------------------------- 
/sweeps/sem/mlp_sem15_no_permute.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - ${env} 3 | - python3 4 | - care_nl_ica/cli.py 5 | - fit 6 | - "--config" 7 | - "configs/config.yaml" 8 | - ${args} 9 | method: grid 10 | metric: 11 | goal: minimize 12 | name: val_loss 13 | parameters: 14 | trainer.max_epochs: 15 | value: 35000 16 | seed_everything: 17 | values: [42, 64, 982, 5748, 23567, 8734, 564, 74452, 96, 26] 18 | model.offline: 19 | value: true 20 | data.variant: 21 | value: 0 22 | data.latent_dim: 23 | value: 15 24 | data.use_sem: 25 | value: true 26 | data.nonlin_sem: 27 | values: [false, true] 28 | data.permute: 29 | value: false 30 | data.offset: 31 | value: 1 32 | 33 | 34 | 35 | 36 | -------------------------------------------------------------------------------- /sweeps/sem/mlp_sem10.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - ${env} 3 | - python3 4 | - care_nl_ica/cli.py 5 | - fit 6 | - "--config" 7 | - "configs/config.yaml" 8 | - ${args} 9 | method: grid 10 | metric: 11 | goal: minimize 12 | name: val_loss 13 | parameters: 14 | trainer.max_epochs: 15 | value: 25000 16 | seed_everything: 17 | values: [42, 64, 982, 5748, 23567] 18 | model.offline: 19 | value: true 20 | data.variant: 21 | values: [0,657894, 1257846, 1875424, 2645120, 3628600] 22 | data.latent_dim: 23 | value: 10 24 | data.use_sem: 25 | value: true 26 | data.nonlin_sem: 27 | values: [false, true] 28 | data.permute: 29 | value: true 30 | data.offset: 31 | value: 0 32 | 33 | 34 | 35 | 36 | -------------------------------------------------------------------------------- /care_nl_ica/losses/utils.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from typing import Dict 4 | 5 | 6 | @dataclass 7 | class ContrastiveLosses: 8 | cl_pos: float = 0.0 9 | cl_neg: float = 0.0 10 | cl_entropy: float = 0.0 11 
| 12 | @property 13 | def total_loss(self): 14 | total_loss = 0.0 15 | for key, loss in self.__dict__.items(): 16 | if key != "cl_entropy": 17 | total_loss += loss 18 | 19 | return total_loss 20 | 21 | def log_dict(self) -> Dict[str, float]: 22 | return { 23 | "cl_pos": self.cl_pos, 24 | "cl_neg": self.cl_neg, 25 | "cl_entropy": self.cl_entropy, 26 | "total": self.total_loss, 27 | } 28 | -------------------------------------------------------------------------------- /tests/test_dep_mat.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | 3 | import torch 4 | 5 | from care_nl_ica.dep_mat import calc_jacobian 6 | from care_nl_ica.models.model import ContrastiveLearningModel 7 | 8 | 9 | def test_calc_jacobian(args): 10 | m = ContrastiveLearningModel( 11 | Namespace( 12 | **{ 13 | **args.model, 14 | **args.data, 15 | "device": "cuda" if torch.cuda.is_available() else "cpu", 16 | } 17 | ) 18 | ) 19 | x = torch.randn(64, args.data.latent_dim) 20 | 21 | assert ( 22 | torch.allclose( 23 | calc_jacobian(m, x, vectorize=True), calc_jacobian(m, x, vectorize=False) 24 | ) 25 | == True 26 | ) 27 | -------------------------------------------------------------------------------- /sweeps/tcl/mlp_ar.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - ${env} 3 | - python3 4 | - care_nl_ica/cli.py 5 | - fit 6 | - "--config" 7 | - "configs/config.yaml" 8 | - ${args} 9 | method: grid 10 | metric: 11 | goal: minimize 12 | name: val_loss 13 | parameters: 14 | seed_everything: 15 | values: [42, 64, 982, 5748, 23567] 16 | trainer.max_epochs: 17 | value: 20000 18 | model.offline: 19 | value: true 20 | data.variant: 21 | value: 3 22 | data.latent_dim: 23 | value: 6 24 | data.use_sem: 25 | value: false 26 | data.n_mixing_layer: 27 | values: [1,2,3,4,5] 28 | data.data_gen_mode: 29 | value: offset 30 | data.mlp_sparsity: 31 | value: true 32 | data.offset: 33 | value: 3 
34 | 35 | 36 | 37 | 38 | 39 | -------------------------------------------------------------------------------- /sweeps/sem/mlp_sem5_sparse.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - ${env} 3 | - python3 4 | - care_nl_ica/cli.py 5 | - fit 6 | - "--config" 7 | - "configs/config.yaml" 8 | - ${args} 9 | method: grid 10 | metric: 11 | goal: minimize 12 | name: val_loss 13 | parameters: 14 | trainer.max_epochs: 15 | value: 7500 16 | seed_everything: 17 | values: [42, 64, 982, 5748, 23567] 18 | model.offline: 19 | value: true 20 | data.variant: 21 | values: [0,5,31,75, 97, 119] 22 | data.latent_dim: 23 | value: 5 24 | data.use_sem: 25 | value: true 26 | data.nonlin_sem: 27 | values: [false, true] 28 | data.permute: 29 | value: true 30 | data.offset: 31 | value: 1 32 | data.mask_prob: 33 | value: 0.25 34 | 35 | 36 | 37 | 38 | -------------------------------------------------------------------------------- /sweeps/tcl/mlp_ar_adaptive_offset.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - ${env} 3 | - python3 4 | - care_nl_ica/cli.py 5 | - fit 6 | - "--config" 7 | - "configs/config.yaml" 8 | - ${args} 9 | method: grid 10 | metric: 11 | goal: minimize 12 | name: val_loss 13 | parameters: 14 | seed_everything: 15 | values: [42, 64, 982, 5748, 23567] 16 | trainer.max_epochs: 17 | value: 20000 18 | model.offline: 19 | value: true 20 | data.variant: 21 | value: 3 22 | data.latent_dim: 23 | value: 6 24 | data.use_sem: 25 | value: false 26 | data.n_mixing_layer: 27 | values: [1,2,3,4,5] 28 | data.data_gen_mode: 29 | value: offset 30 | data.mlp_sparsity: 31 | value: true 32 | data.offset: 33 | value: 0 34 | 35 | 36 | 37 | 38 | 39 | -------------------------------------------------------------------------------- /sweeps/sem/mlp_sem3_no_permute_sparse.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - ${env} 3 | 
- python3 4 | - care_nl_ica/cli.py 5 | - fit 6 | - "--config" 7 | - "configs/config.yaml" 8 | - ${args} 9 | method: grid 10 | metric: 11 | goal: minimize 12 | name: val_loss 13 | parameters: 14 | trainer.max_epochs: 15 | value: 15000 16 | seed_everything: 17 | values: [42, 64, 982, 5748, 23567, 8734, 564, 74452, 96, 26] 18 | model.offline: 19 | value: true 20 | data.variant: 21 | value: 0 22 | data.latent_dim: 23 | value: 3 24 | data.use_sem: 25 | value: true 26 | data.nonlin_sem: 27 | values: [false, true] 28 | data.permute: 29 | value: false 30 | data.offset: 31 | value: 1 32 | data.mask_prob: 33 | value: 0.25 34 | 35 | 36 | 37 | -------------------------------------------------------------------------------- /sweeps/sem/mlp_sem5_no_permute_sparse.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - ${env} 3 | - python3 4 | - care_nl_ica/cli.py 5 | - fit 6 | - "--config" 7 | - "configs/config.yaml" 8 | - ${args} 9 | method: grid 10 | metric: 11 | goal: minimize 12 | name: val_loss 13 | parameters: 14 | trainer.max_epochs: 15 | value: 7500 16 | seed_everything: 17 | values: [42, 64, 982, 5748, 23567, 8734, 564, 74452, 96, 26] 18 | model.offline: 19 | value: true 20 | data.variant: 21 | value: 0 22 | data.latent_dim: 23 | value: 5 24 | data.use_sem: 25 | value: true 26 | data.nonlin_sem: 27 | values: [false, true] 28 | data.permute: 29 | value: false 30 | data.offset: 31 | value: 0 32 | data.mask_prob: 33 | value: 0.25 34 | 35 | 36 | 37 | 38 | -------------------------------------------------------------------------------- /sweeps/sem/mlp_sem8_no_permute_sparse.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - ${env} 3 | - python3 4 | - care_nl_ica/cli.py 5 | - fit 6 | - "--config" 7 | - "configs/config.yaml" 8 | - ${args} 9 | method: grid 10 | metric: 11 | goal: minimize 12 | name: val_loss 13 | parameters: 14 | trainer.max_epochs: 15 | value: 7500 16 | 
seed_everything: 17 | values: [42, 64, 982, 5748, 23567, 8734, 564, 74452, 96, 26] 18 | model.offline: 19 | value: true 20 | data.variant: 21 | value: 0 22 | data.latent_dim: 23 | value: 8 24 | data.use_sem: 25 | value: true 26 | data.nonlin_sem: 27 | values: [false, true] 28 | data.permute: 29 | value: false 30 | data.offset: 31 | value: 0 32 | data.mask_prob: 33 | value: 0.25 34 | 35 | 36 | 37 | 38 | -------------------------------------------------------------------------------- /sweeps/sem/mlp_sem8_sparse.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - ${env} 3 | - python3 4 | - care_nl_ica/cli.py 5 | - fit 6 | - "--config" 7 | - "configs/config.yaml" 8 | - ${args} 9 | method: grid 10 | metric: 11 | goal: minimize 12 | name: val_loss 13 | parameters: 14 | trainer.max_epochs: 15 | value: 15000 16 | seed_everything: 17 | values: [42, 64, 982, 5748, 23567] 18 | model.offline: 19 | value: true 20 | data.variant: 21 | values: [0,575, 8754, 12450, 21345, 36755] 22 | data.latent_dim: 23 | value: 8 24 | data.use_sem: 25 | value: true 26 | data.nonlin_sem: 27 | values: [false, true] 28 | data.permute: 29 | value: true 30 | data.offset: 31 | value: 1 32 | data.mask_prob: 33 | value: 0.5 34 | 35 | 36 | 37 | 38 | -------------------------------------------------------------------------------- /sweeps/sem/mlp_sem10_no_permute_sparse.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - ${env} 3 | - python3 4 | - care_nl_ica/cli.py 5 | - fit 6 | - "--config" 7 | - "configs/config.yaml" 8 | - ${args} 9 | method: grid 10 | metric: 11 | goal: minimize 12 | name: val_loss 13 | parameters: 14 | trainer.max_epochs: 15 | value: 7500 16 | seed_everything: 17 | values: [42, 64, 982, 5748, 23567, 8734, 564, 74452, 96, 26] 18 | model.offline: 19 | value: true 20 | data.variant: 21 | value: 0 22 | data.latent_dim: 23 | value: 10 24 | data.use_sem: 25 | value: true 26 | 
data.nonlin_sem: 27 | values: [false, true] 28 | data.permute: 29 | value: false 30 | data.offset: 31 | value: 0 32 | data.mask_prob: 33 | value: 0.25 34 | 35 | 36 | 37 | 38 | -------------------------------------------------------------------------------- /sweeps/sem/mlp_sem10_sparse.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - ${env} 3 | - python3 4 | - care_nl_ica/cli.py 5 | - fit 6 | - "--config" 7 | - "configs/config.yaml" 8 | - ${args} 9 | method: grid 10 | metric: 11 | goal: minimize 12 | name: val_loss 13 | parameters: 14 | trainer.max_epochs: 15 | value: 25000 16 | seed_everything: 17 | values: [42, 64, 982, 5748, 23567] 18 | model.offline: 19 | value: true 20 | data.variant: 21 | values: [0,657894, 1257846, 1875424, 2645120, 3628600] 22 | data.latent_dim: 23 | value: 10 24 | data.use_sem: 25 | value: true 26 | data.nonlin_sem: 27 | values: [false, true] 28 | data.permute: 29 | value: true 30 | data.offset: 31 | value: 1 32 | data.mask_prob: 33 | value: 0.5 34 | 35 | 36 | 37 | 38 | -------------------------------------------------------------------------------- /sweeps/tcl/mlp_ar_adaptive_offset_sparse.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - ${env} 3 | - python3 4 | - care_nl_ica/cli.py 5 | - fit 6 | - "--config" 7 | - "configs/config.yaml" 8 | - ${args} 9 | method: grid 10 | metric: 11 | goal: minimize 12 | name: val_loss 13 | parameters: 14 | seed_everything: 15 | values: [42, 64, 982, 5748, 23567] 16 | trainer.max_epochs: 17 | value: 20000 18 | model.offline: 19 | value: true 20 | data.variant: 21 | value: 3 22 | data.latent_dim: 23 | value: 6 24 | data.use_sem: 25 | value: false 26 | data.n_mixing_layer: 27 | values: [1,2,3,4,5] 28 | data.data_gen_mode: 29 | value: offset 30 | data.mlp_sparsity: 31 | value: true 32 | data.offset: 33 | value: 0 34 | data.mask_prob: 35 | value: 0.25 36 | 37 | 38 | 39 | 40 | 41 | 
-------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 3 | authors: 4 | - family-names: "Reizinger" 5 | given-names: "Patrik" 6 | orcid: "https://orcid.org/0000-0001-9861-0293" 7 | - family-names: "Sharma" 8 | given-names: "Yash" 9 | - family-names: "Bethge" 10 | given-names: "Matthias" 11 | - family-names: "Schölkopf" 12 | given-names: "Bernhard" 13 | orcid: "https://orcid.org/0000-0002-8177-0925" 14 | - family-names: "Huszár" 15 | given-names: "Ferenc" 16 | orcid: "https://orcid.org/0000-0002-4988-1430" 17 | - family-names: "Brendel" 18 | given-names: "Wieland" 19 | orcid: "https://orcid.org/0000-0003-0982-552X" 20 | title: "nl-causal-representations" 21 | version: 1.0.0 22 | doi: 10.5281/zenodo.7002143 23 | date-released: 2022-08-17 24 | url: "https://github.com/rpatrik96/nl-causal-representations" 25 | -------------------------------------------------------------------------------- /scripts/loss_compare_for_permutations.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | N_STEPS=25001 4 | START_STEP=5000 5 | 6 | QR_LOSS=3e-1 7 | TRIANGULARITY_LOSS=0 8 | ENTROPY_COEFF=0 9 | LOSS_FLAGS="--qr-loss ${QR_LOSS} --triangularity-loss ${TRIANGULARITY_LOSS} --entropy-coeff ${ENTROPY_COEFF}" 10 | 11 | DATA_FLAGS="--use-sem --permute" 12 | LOG_FLAGS="--verbose --normalize-latents" 13 | MODEL_FLAGS="--use-ar-mlp" 14 | 15 | for seed in 84645646; do 16 | for n in 3; do 17 | fact=1 18 | for ((i = 1; i <= ${n}; i++)); do 19 | fact=$(($fact * $i)) 20 | done 21 | 22 | for ((variant = 0; variant < ${fact}; variant++)); do 23 | ./scripts/start_preemptable_job.sh --seed ${seed} --n ${n} --variant "${variant}" ${LOSS_FLAGS} --n-steps ${N_STEPS} --start-step ${START_STEP} ${DATA_FLAGS} ${LOG_FLAGS} "$@" & 24 | sleep 5 25 
| done 26 | done 27 | done 28 | -------------------------------------------------------------------------------- /configs/cpu.yaml: -------------------------------------------------------------------------------- 1 | trainer: 2 | auto_select_gpus: false 3 | overfit_batches: 0.0 4 | check_val_every_n_epoch: 200 5 | fast_dev_run: false 6 | max_epochs: 5500 7 | min_epochs: null 8 | max_steps: -1 9 | min_steps: null 10 | max_time: null 11 | gpus: 0 12 | limit_train_batches: 1.0 13 | limit_val_batches: 1.0 14 | limit_test_batches: 1.0 15 | limit_predict_batches: 1.0 16 | val_check_interval: 1.0 17 | log_every_n_steps: 250 18 | enable_model_summary: true 19 | weights_summary: top 20 | num_sanity_val_steps: 2 21 | profiler: null 22 | benchmark: false 23 | deterministic: true 24 | detect_anomaly: true 25 | plugins: null 26 | data: 27 | batch_size: 512 28 | latent_dim: 3 29 | diag_weight: 0.4 30 | permute: true 31 | variant: 5 32 | force_uniform: false 33 | force_chain: true 34 | model: 35 | start_step: 4500 36 | lr: 1e-3 -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/datasets/simple_image_dataset.py: -------------------------------------------------------------------------------- 1 | """Loads all images in a folder while ignoring class information etc.""" 2 | 3 | import torch.utils.data 4 | from typing import Optional, Callable 5 | import glob 6 | import os 7 | import torchvision 8 | 9 | 10 | class SimpleImageDataset(torch.utils.data.Dataset): 11 | def __init__(self, root: str, transform: Optional[Callable] = None): 12 | self.root = root 13 | self.image_paths = list(sorted(list(glob.glob(os.path.join(root, "*.*"))))) 14 | 15 | if transform is None: 16 | transform = lambda x: x 17 | self.transform = transform 18 | 19 | self.loader = torchvision.datasets.folder.pil_loader 20 | 21 | def __len__(self): 22 | return len(self.image_paths) 23 | 24 | def __getitem__(self, item): 25 | assert 0 <= item < len(self) 26 | return 
self.transform(self.loader(self.image_paths[item])) 27 | -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/docker/Dockerfile.blender: -------------------------------------------------------------------------------- 1 | # To build this Dockerfile you have to let it inherit from the other one 2 | 3 | # Set the time zone correctly 4 | ENV TZ=Europe/Berlin 5 | RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone 6 | 7 | ENV SHELL /bin/bash 8 | 9 | RUN wget https://download.blender.org/release/Blender2.91/blender-2.91.0-linux64.tar.xz 10 | RUN tar -xf blender-2.91.0-linux64.tar.xz -C /usr/local/bin/ 11 | ENV PATH "$PATH:/usr/local/bin/blender-2.91.0-linux64" 12 | 13 | RUN apt-get update && apt-get install -y --no-install-recommends \ 14 | libxxf86vm1 libxfixes-dev libglu1-mesa-dev freeglut3-dev mesa-common-dev \ 15 | && apt-get clean \ 16 | && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* 17 | 18 | COPY entrypoint.sh /usr/local/bin/ 19 | RUN chmod a+x /usr/local/bin/entrypoint.sh 20 | 21 | ENTRYPOINT ["/usr/local/bin/entrypoint.sh"] 22 | 23 | CMD ["/bin/bash"] -------------------------------------------------------------------------------- /notebooks/sem_5d_munkres_sweep_kdau2seo.csv: -------------------------------------------------------------------------------- 1 | ,name,dim,permute,variant,n_mixing_layer,use_sem,nonlin_sem,force_chain,force_uniform,mcc,val_loss 2 | 0,swept-sweep-13,5,False,0,1,True,True,False,False,0.9994050747388996,1.9483914375305176 3 | 1,skilled-sweep-14,5,False,0,1,True,True,False,False,0.9993582123422764,1.893308401107788 4 | 2,royal-sweep-9,5,False,0,1,True,False,False,False,0.9998361494785972,1.94585919380188 5 | 3,valiant-sweep-10,5,False,0,1,True,False,False,False,0.816690862169852,3.03879714012146 6 | 4,breezy-sweep-6,5,False,0,1,True,False,False,False,0.9997581404610733,1.9464752674102783 7 | 
"""Calculate mean and std of dataset."""

from datasets.simple_image_dataset import SimpleImageDataset
import torch.utils.data
import argparse
import torchvision.transforms
import os
from tqdm import tqdm

# Only argument: the folder whose images are scanned by SimpleImageDataset.
parser = argparse.ArgumentParser()
parser.add_argument("--root-folder", required=True)
args = parser.parse_args()

# Each item becomes a (C, H, W) float tensor via ToTensor().
dataset = SimpleImageDataset(
    args.root_folder, transform=torchvision.transforms.ToTensor()
)

full_loader = torch.utils.data.DataLoader(
    dataset, shuffle=True, num_workers=os.cpu_count(), batch_size=256
)

# Per-channel accumulators; assumes 3-channel (RGB) images — TODO confirm
# against the loader used by SimpleImageDataset.
mean = torch.zeros(3)
std = torch.zeros(3)
print("==> Computing mean and std..")
for inputs in tqdm(full_loader):
    for i in range(3):
        # Sum of per-image channel means / stds over the batch dimension.
        mean[i] += inputs[:, i, :, :].mean(dim=(-1, -2)).sum(0)
        std[i] += inputs[:, i, :, :].std(dim=(-1, -2)).sum(0)
# NOTE(review): dividing the summed per-image statistics by the dataset size
# yields the *average of per-image* means/stds, not the pooled std over all
# pixels — a common approximation; confirm this is the intended statistic.
mean.div_(len(dataset))
std.div_(len(dataset))
print(mean, std)
restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/kitti_masks/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Bethge Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Roland S. Zimmermann et al. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
class SparseBudgetNet(nn.Module):
    """Learnable soft mask over a ``num_dim x num_dim`` matrix with a sparsity budget.

    The budget equals the entry count of a triangular matrix,
    ``num_dim * (num_dim + 1) / 2``, i.e. the number of mask entries that a
    DAG-structured (triangular) dependency matrix would occupy.
    """

    def __init__(
        self,
        num_dim: int,
    ):
        """
        :param num_dim: dimensionality of the (square) masked matrix
        """
        super().__init__()
        self.num_dim = num_dim
        # Number of entries in a triangular num_dim x num_dim matrix.
        self.budget: int = self.num_dim * (self.num_dim + 1) // 2
        # Borrow nn.Linear's default weight initialization for the mask logits.
        self.weight = nn.Parameter(
            nn.Linear(self.num_dim, self.num_dim).weight, requires_grad=True
        )

    # NOTE: the previous custom ``to()`` override reassigned
    # ``self.weight = self.weight.to(device)``. On a cross-device move,
    # ``Parameter.to`` returns a plain (non-leaf) Tensor, so the attribute was
    # silently rebound to a non-Parameter and dropped from the module's
    # registered parameters. ``nn.Module.to`` already moves registered
    # parameters in place and returns ``self``, so the override is removed.

    @property
    def mask(self):
        # Elementwise soft mask in (0, 1).
        return torch.sigmoid(self.weight)

    @property
    def entropy(self):
        # Entropy of the mask entries normalized into a categorical
        # distribution (flattened to a single probability vector).
        probs = torch.nn.functional.softmax(self.mask, -1).view(
            -1,
        )

        return torch.distributions.Categorical(probs).entropy()

    @property
    def budget_loss(self):
        # Positive while the total mask mass is *below* the budget, i.e. it
        # pushes the mask sum up toward the budget.
        # NOTE(review): if the intent was to penalize *exceeding* the budget,
        # the relu arguments should be swapped — confirm against the training
        # objective; original behavior is kept here.
        return torch.relu(self.budget - self.mask.sum())
10 | # also, it's faster to have it on local storage 11 | echo "copy image" 12 | LOCAL_IMAGE=$tmp_dir/nv.sif 13 | rsync -av --progress $IMAGE "$LOCAL_IMAGE" 14 | 15 | 16 | mkdir -p "$tmp_dir"/.vscode-server 17 | mkdir -p "$tmp_dir"/.conda 18 | mkdir -p "$tmp_dir"/.ipython 19 | mkdir -p "$tmp_dir"/.jupyter 20 | mkdir -p "$tmp_dir"/.local 21 | mkdir -p "$tmp_dir"/.pylint.d 22 | mkdir -p "$tmp_dir"/.cache 23 | 24 | singularity exec -p --nv \ 25 | --bind "$tmp_dir"/.vscode-server:/mnt/qb/work/bethge/$userName/.vscode-server \ 26 | --bind "$tmp_dir"/.conda:/mnt/qb/work/bethge/$userName/.conda \ 27 | --bind "$tmp_dir"/.ipython:/mnt/qb/work/bethge/$userName/.ipython \ 28 | --bind "$tmp_dir"/.jupyter:/mnt/qb/work/bethge/$userName/.jupyter \ 29 | --bind "$tmp_dir"/.local:/mnt/qb/work/bethge/$userName/.local \ 30 | --bind "$tmp_dir"/.pylint.d:/mnt/qb/work/bethge/$userName/.pylint.d \ 31 | --bind "$tmp_dir"/.cache:/mnt/qb/work/bethge/$userName/.cache \ 32 | --bind /scratch_local \ 33 | --bind /home/bethge/$userName \ 34 | --bind /mnt/qb/work/bethge \ 35 | "$LOCAL_IMAGE" "$@" 36 | 37 | rm -rf "$tmp_dir" 38 | -------------------------------------------------------------------------------- /scripts/nv.def: -------------------------------------------------------------------------------- 1 | Bootstrap: docker 2 | From: nvidia/cuda:11.3.0-cudnn8-runtime-ubuntu20.04 3 | 4 | 5 | %environment 6 | # overwrite Singularity prompt with something more useful 7 | export PS1="\u@\h \W§ " 8 | 9 | %files 10 | # copy the dropbear keys into the container 11 | /mnt/qb/work/bethge/preizinger/dropbear /etc/dropbear 12 | 13 | # project requirement file 14 | ../requirements.txt 15 | 16 | %post 17 | # required to indicate that no interaction is required for installing packages with apt 18 | # otherwise, installing packages can fail 19 | export DEBIAN_FRONTEND=noninteractive 20 | 21 | # if the cuda container cannot be verified as nvidia did not update the key (GPG error) 22 | # 
https://github.com/NVIDIA/nvidia-docker/issues/619#issuecomment-359580806 23 | rm /etc/apt/sources.list.d/cuda.list 24 | 25 | # sometimes fetching the key explicitly can help 26 | # https://github.com/NVIDIA/nvidia-docker/issues/619#issuecomment-359695597 27 | apt-key adv --fetch-keys http://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/7fa2af80.pub 28 | 29 | apt-get -qq -y update 30 | 31 | # apt-get install -y build-essential cmake git rsync 32 | 33 | apt-get install -y python3 python3-pip python3-wheel python3-yaml intel-mkl-full 34 | 35 | # ssh client for the vs code setup 36 | apt-get install -y dropbear 37 | chmod +r -R /etc/dropbear 38 | 39 | apt-get clean 40 | 41 | # project requirements install 42 | pip3 install -r requirements.txt 43 | 44 | 45 | 46 | 47 | 48 | 49 | -------------------------------------------------------------------------------- /notebooks/sem_3d_sparse_sweep_lm8s890w.csv: -------------------------------------------------------------------------------- 1 | ,name,dim,permute,variant,n_mixing_layer,use_sem,nonlin_sem,force_chain,force_uniform,mcc,val_loss 2 | 0,classic-sweep-20,3,False,0,1,True,True,False,False,0.9999596225772244,4.484454154968262 3 | 1,northern-sweep-19,3,False,0,1,True,True,False,False,0.9999133110494473,4.493100643157959 4 | 2,lucky-sweep-18,3,False,0,1,True,True,False,False,0.999934173215041,4.493666648864746 5 | 3,sunny-sweep-16,3,False,0,1,True,True,False,False,0.9999475003711812,4.4822492599487305 6 | 4,ethereal-sweep-15,3,False,0,1,True,True,False,False,0.9999135937156256,4.465629577636719 7 | 5,splendid-sweep-14,3,False,0,1,True,True,False,False,0.9999369829837246,4.470257759094238 8 | 6,cerulean-sweep-13,3,False,0,1,True,True,False,False,0.9998532325373448,4.4454545974731445 9 | 7,restful-sweep-11,3,False,0,1,True,True,False,False,0.999959709496154,4.461050510406494 10 | 8,bumbling-sweep-10,3,False,0,1,True,False,False,False,0.9999579945185338,4.493952751159668 11 | 
import torch
import wandb

from .utils import matrix_to_dict


class Logger(object):
    """Thin experiment-logging wrapper around Weights & Biases.

    Instantiated with the experiment hyperparameters and the model to watch;
    all wandb setup is skipped unless ``hparams.use_wandb`` is True.
    """

    def __init__(self, hparams, model) -> None:
        """
        :param hparams: hyperparameter namespace; must expose ``use_wandb``,
            ``project``, ``notes``, ``tags`` and ``n_log_steps``
        :param model: model passed to ``wandb.watch`` for gradient/parameter logging
        """
        super().__init__()
        self.hparams = hparams

        self._setup_exp_management(model)

        # Populated later by the training loop (not in this class).
        self.total_loss_values = None

    def _setup_exp_management(self, model):
        """Initialize the wandb run, model watching, and summary metrics."""
        if self.hparams.use_wandb is True:
            wandb.init(
                entity="causal-representation-learning",
                project=self.hparams.project,
                notes=self.hparams.notes,
                config=self.hparams,
                tags=self.hparams.tags,
            )
            # Log gradients and parameters every n_log_steps steps.
            wandb.watch(model, log_freq=self.hparams.n_log_steps, log="all")

            # define metrics
            wandb.define_metric("total_loss", summary="min")
            wandb.define_metric("lin_dis_score", summary="max")
            wandb.define_metric("perm_dis_score", summary="max")

    def log_jacobian(
        self, dep_mat, name="gt_decoder", inv_name="gt_encoder", log_inverse=True
    ):
        """Log a Jacobian (and optionally its inverse) as wandb tables.

        :param dep_mat: square Jacobian/dependency matrix (torch tensor);
            must be invertible when ``log_inverse`` is True
        :param name: summary key prefix for the forward Jacobian
        :param inv_name: summary key prefix for the inverse Jacobian
        :param log_inverse: whether to also log ``dep_mat``'s inverse
        """
        jac = dep_mat.detach().cpu()
        cols = [f"a_{i}" for i in range(dep_mat.shape[1])]

        gt_jacobian_dec = wandb.Table(columns=cols, data=jac.tolist())
        # NOTE(review): ``log_summary`` is not defined in this class as shown —
        # presumably provided elsewhere (mixin or later in the file); confirm.
        self.log_summary(**{f"{name}_jacobian": gt_jacobian_dec})

        if log_inverse is True:
            gt_jacobian_enc = wandb.Table(columns=cols, data=jac.inverse().tolist())
            self.log_summary(**{f"{inv_name}_jacobian": gt_jacobian_enc})
31 | ADDRESSES='' 32 | for IP in $(hostname -I) 33 | do 34 | ADDRESSES="$ADDRESSES -p $IP:12345" 35 | done 36 | # -p is for making sure sshd is killed if singularity is stopped 37 | # - R is for generating key 38 | tmux send-keys "./scripts/run_singularity_server.sh /usr/sbin/dropbear -R -E -F $ADDRESSES -s" C-m 39 | 40 | while true; do 41 | # will fail if session ended 42 | tmux has-session -t $SESSION 43 | echo "session still running" 44 | sleep 10 45 | done -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/ORIGINAL_LICENSE: -------------------------------------------------------------------------------- 1 | BSD License 2 | 3 | For clevr-dataset-gen software 4 | 5 | Copyright (c) 2017-present, Facebook, Inc. All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without modification, 8 | are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | * Neither the name Facebook nor the names of its contributors may be used to 18 | endorse or promote products derived from this software without specific 19 | prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | import hydra.core.global_hydra 4 | import pytest 5 | import torch 6 | from torch.utils.data import DataLoader 7 | 8 | from care_nl_ica.data.datamodules import ContrastiveDataModule 9 | from care_nl_ica.dataset import ContrastiveDataset 10 | 11 | arg_matrix = namedtuple("arg_matrix", ["latent_dim"]) 12 | 13 | from hydra import compose, initialize 14 | from pytorch_lightning import seed_everything 15 | from argparse import Namespace 16 | 17 | 18 | @pytest.fixture( 19 | params=[ 20 | arg_matrix(latent_dim=3), 21 | ] 22 | ) 23 | def args(request): 24 | hydra.core.global_hydra.GlobalHydra.instance().clear() 25 | initialize(config_path="../configs", job_name="test_app") 26 | 27 | cfg = compose( 28 | config_name="config", 29 | overrides=[ 30 | f"data.latent_dim={request.param.latent_dim}", 31 | "data.use_sem=true", 32 | "data.nonlin_sem=true", 33 | ], 34 | ) 35 | 36 | seed_everything(cfg.seed_everything) 37 | 38 | return cfg 39 | 40 | 41 | @pytest.fixture() 42 | def dataloader(args): 43 | args = Namespace( 44 | **{**args.data, "device": "cuda" if torch.cuda.is_available() else "cpu"} 45 | ) 46 | ds = ContrastiveDataset( 47 | args, 48 | lambda x: x 49 | @ torch.tril(torch.ones(args.latent_dim, args.latent_dim, 
class ContrastiveDataset(torch.utils.data.IterableDataset):
    """Iterable dataset producing one ``(sources, mixtures)`` batch pair per pass.

    Latents are sampled from a configurable latent space (box / sphere /
    unbounded) and pushed through ``transform`` to obtain the observed
    mixtures.
    """

    def __init__(self, hparams, transform=None):
        """
        :param hparams: namespace with ``space_type``, ``latent_dim``,
            ``batch_size``, ``device`` and the space-specific bounds
        :param transform: mixing function mapping a source batch to mixtures
        """
        super().__init__()
        self.hparams = hparams
        self.transform = transform

        self._setup_space()

        self.latent_space = latent_spaces.LatentSpace(
            space=self.space,
            sample_marginal=setup_marginal(self.hparams),
            sample_conditional=setup_conditional(self.hparams),
        )
        torch.cuda.empty_cache()

    def _setup_space(self):
        """Instantiate the latent space selected by ``hparams.space_type``."""
        dim = self.hparams.latent_dim
        space_type = self.hparams.space_type

        if space_type == "box":
            space = spaces.NBoxSpace(dim, self.hparams.box_min, self.hparams.box_max)
        elif space_type == "sphere":
            space = spaces.NSphereSpace(dim, self.hparams.sphere_r)
        else:
            # Fall back to an unbounded Euclidean space.
            space = spaces.NRealSpace(dim)

        self.space = space

    def __iter__(self):
        # Draw one marginal/conditional sample pair and stack into a tensor.
        samples = sample_marginal_and_conditional(
            self.latent_space,
            size=self.hparams.batch_size,
            device=self.hparams.device,
        )
        sources = torch.stack(samples)

        # Mix each source batch with the (fixed) mixing function.
        mixtures = torch.stack([self.transform(source) for source in sources])

        return iter((sources, mixtures))
/notebooks/sem_3d_sweep_mmzkkmw4.csv: -------------------------------------------------------------------------------- 1 | ,name,dim,permute,variant,n_mixing_layer,use_sem,nonlin_sem,force_chain,force_uniform,mcc,val_loss 2 | 0,vibrant-sweep-19,3,False,0,1,True,True,False,False,0.9999834925346676,4.483795642852783 3 | 1,logical-sweep-17,3,False,0,1,True,True,False,False,0.9999510958654018,4.452263832092285 4 | 2,distinctive-sweep-18,3,False,0,1,True,True,False,False,0.9999742569135218,4.492233753204346 5 | 3,floral-sweep-16,3,False,0,1,True,True,False,False,0.999978368969236,4.482064723968506 6 | 4,laced-sweep-15,3,False,0,1,True,True,False,False,0.9999468021267304,4.464197158813477 7 | 5,dulcet-sweep-14,3,False,0,1,True,True,False,False,0.9998824498059984,4.4695892333984375 8 | 6,serene-sweep-13,3,False,0,1,True,True,False,False,0.9999359140889986,4.4987101554870605 9 | 7,polished-sweep-10,3,False,0,1,True,False,False,False,0.999966678058504,4.492259979248047 10 | 8,cosmic-sweep-7,3,False,0,1,True,False,False,False,0.9999527718070884,4.4820380210876465 11 | 9,solar-sweep-11,3,False,0,1,True,False,False,False,0.9999627495518776,4.49423885345459 12 | 10,azure-sweep-8,3,False,0,1,True,False,False,False,0.9999658911050588,4.451657295227051 13 | 11,lunar-sweep-12,3,False,0,1,True,True,False,False,0.999911378593971,4.443646430969238 14 | 12,lunar-sweep-9,3,False,0,1,True,False,False,False,0.99997813611064,4.492734909057617 15 | 13,deep-sweep-6,3,False,0,1,True,False,False,False,0.9999621276126484,4.463892936706543 16 | 14,noble-sweep-5,3,False,0,1,True,False,False,False,0.9999522251170536,4.491568088531494 17 | 15,spring-sweep-4,3,False,0,1,True,False,False,False,0.9999769496598908,4.497834205627441 18 | 16,dark-sweep-18,3,False,0,1,True,True,False,False,0.6902247615682603,5.701242446899414 19 | 17,eager-sweep-3,3,False,0,1,True,False,False,False,0.9999642529591248,4.443384170532227 20 | 
18,laced-sweep-1,3,False,0,1,True,False,False,False,0.999976140927932,4.460049629211426 21 | -------------------------------------------------------------------------------- /notebooks/sem_5d_sweep_kdau2seo.csv: -------------------------------------------------------------------------------- 1 | ,name,dim,permute,variant,n_mixing_layer,use_sem,nonlin_sem,force_chain,force_uniform,mcc,val_loss 2 | 0,earnest-sweep-20,5,False,0,1,True,True,False,False,0.8190186979735528,3.074489116668701 3 | 1,playful-sweep-18,5,False,0,1,True,True,False,False,0.9993895610460726,1.9437751770019531 4 | 2,glamorous-sweep-19,5,False,0,1,True,True,False,False,0.9998020879474088,1.9493937492370603 5 | 3,celestial-sweep-17,5,False,0,1,True,True,False,False,0.822084413101042,3.0058584213256836 6 | 4,effortless-sweep-15,5,False,0,1,True,True,False,False,0.9993695883932232,1.985156536102295 7 | 5,true-sweep-16,5,False,0,1,True,True,False,False,0.8185182855941202,3.023531675338745 8 | 6,swept-sweep-13,5,False,0,1,True,True,False,False,0.9994050747388996,1.9483914375305176 9 | 7,skilled-sweep-14,5,False,0,1,True,True,False,False,0.9993582123422764,1.893308401107788 10 | 8,worldly-sweep-11,5,False,0,1,True,True,False,False,0.9997244104382724,1.913358688354492 11 | 9,royal-sweep-9,5,False,0,1,True,False,False,False,0.9998361494785972,1.94585919380188 12 | 10,valiant-sweep-10,5,False,0,1,True,False,False,False,0.816690862169852,3.03879714012146 13 | 11,breezy-sweep-6,5,False,0,1,True,False,False,False,0.9997581404610733,1.9464752674102783 14 | 12,winter-sweep-7,5,False,0,1,True,False,False,False,0.8795997458112131,2.9836673736572266 15 | 13,lunar-sweep-4,5,False,0,1,True,False,False,False,0.9997457993188376,1.8816308975219729 16 | 14,glad-sweep-5,5,False,0,1,True,False,False,False,0.9995376654534928,1.968914031982422 17 | 15,honest-sweep-1,5,False,0,1,True,False,False,False,0.9998311645399612,1.9033708572387695 18 | 
#!/bin/bash
# Generate an exclude list of nodes for SLURM
#
# Usage:
#   exclude_nodes.sh sbatch [...]
#   sbatch [...] --exclude=$(exclude_nodes.sh) [...]
#
# Example raw output:
#   λ ~/ bash exclude_nodes.sh
#   bg-slurmb-bm-3,slurm-bm-05,slurm-bm-08,slurm-bm-32,slurm-bm-42,slurm-bm-47,slurm-bm-55,slurm-bm-60,slurm-bm-62,slurm-bm-82%
#
# Based on the MLCloud monitor at
# http://134.2.168.207:3001/d/VbORxQJnk/slurm-monitoring-dashboard?orgId=1&from=now-3h&to=now&refresh=1m
#
# Up to date version of this script available at
# https://gist.github.com/stes/52a139e260e25c72a97e2180d5be3bdb

# Query the monitoring InfluxDB for nodes that are down/drained and not responding.
# SECURITY NOTE(review): the bearer token is hard-coded here; consider moving it
# to an environment variable or a credentials file.
get_exclude_list() {
    curl -XPOST '134.2.168.207:8085/api/v2/query?orgID=1a87e58d4b097066' -sS \
        -H 'Accept:application/csv' \
        -H 'Content-type:application/vnd.flux' \
        -H 'Authorization: Bearer 88JGdt2FvqcPXJwFQi5zGon6D0z7YP54' \
        -d 'from(bucket: "tueslurm")
    |> range(start: -35s)
    |> filter(fn: (r) => r["_measurement"] == "node_state")
    |> filter(fn: (r) => r["_field"] == "responding")
    |> filter(fn: (r) => r["state"] == "down" or r["state"] == "drained")
    |> filter(fn: (r) => r["_value"] == 0)
    |> group(columns: ["_time"])
    |> keep(columns: ["hostname"])'
}

# Turn the raw CSV response into a single comma-separated, de-duplicated
# hostname list suitable for sbatch's --exclude option.
formatted_exclude_list() {
    get_exclude_list \
        | tail -n +2 \
        | cut -f4 -d, \
        | tr -d '\r' \
        | sort | uniq \
        | sed -r '/^\s*$/d' \
        | tr '\n' ',' \
        | sed -e 's/,$//g'
}

# BUGFIX: a stray `echo $(formatted_exclude_list)` used to run here before the
# mode dispatch. It printed the list unconditionally, so the documented
# `--exclude=$(exclude_nodes.sh)` usage captured the list twice (newline
# separated), producing a malformed sbatch argument, and the sbatch mode made
# an extra network query. It has been removed.

mode=$1
shift 1
case $mode in
sbatch)
    # Forward all remaining arguments to sbatch with the exclude list applied.
    sbatch --exclude=$(formatted_exclude_list) "$@"
    ;;

*)
    # Default mode: just print the list (for command substitution).
    formatted_exclude_list
    ;;
esac
15,solar-haze-781,8,False,0,1,True,False,False,False,0.9995237113779494,0.21808767318725583 18 | 16,colorful-eon-782,8,False,0,1,True,False,False,False,0.9996085561632492,0.19951772689819336 19 | 17,autumn-flower-783,8,False,0,1,True,False,False,False,0.999546733108815,0.2105250358581543 20 | 18,still-brook-784,8,False,0,1,True,False,False,False,0.9995925756876288,0.1993727684020996 21 | -------------------------------------------------------------------------------- /notebooks/sem_10d_sparse_sweep_t7rmrux1.csv: -------------------------------------------------------------------------------- 1 | ,name,dim,permute,variant,n_mixing_layer,use_sem,nonlin_sem,force_chain,force_uniform,mcc,val_loss 2 | 0,dandy-bush-896,10,False,0,1,True,True,False,False,0.8880686119855401,0.08689498901367188 3 | 1,vital-yogurt-897,10,False,0,1,True,True,False,False,0.8482571340203637,0.22718429565429688 4 | 2,revived-glitter-898,10,False,0,1,True,True,False,False,0.36292738693421894,4.941876411437988 5 | 3,dazzling-sea-899,10,False,0,1,True,True,False,False,0.7825259533014958,0.4925127029418946 6 | 4,hardy-terrain-900,10,False,0,1,True,True,False,False,0.7131718725702294,0.4598922729492187 7 | 5,absurd-galaxy-901,10,False,0,1,True,True,False,False,0.8253574629148996,0.24730396270751953 8 | 6,mild-firefly-902,10,False,0,1,True,True,False,False,0.4120730846524629,4.418057441711426 9 | 7,peach-sun-903,10,False,0,1,True,True,False,False,0.8316040574714311,0.205294132232666 10 | 8,hearty-monkey-904,10,False,0,1,True,True,False,False,0.9384207512809172,0.08607959747314453 11 | 9,icy-frost-905,10,False,0,1,True,False,False,False,0.8498906843335347,0.20059680938720703 12 | 10,glamorous-surf-906,10,False,0,1,True,False,False,False,0.998845625412927,0.034972190856933594 13 | 11,floral-silence-907,10,False,0,1,True,False,False,False,0.8205658296206438,0.2049999237060547 14 | 12,leafy-energy-908,10,False,0,1,True,False,False,False,0.3835599676275568,4.540416240692139 15 | 
13,lively-paper-909,10,False,0,1,True,False,False,False,0.8501652502760267,0.22983646392822263 16 | 14,gentle-microwave-910,10,False,0,1,True,False,False,False,0.8872115948071577,0.08292198181152344 17 | 15,lively-lion-911,10,False,0,1,True,False,False,False,0.9424619803608764,0.0835275650024414 18 | 16,pretty-bird-912,10,False,0,1,True,False,False,False,0.4180905353243494,4.421955108642578 19 | 17,wise-fog-913,10,False,0,1,True,False,False,False,0.9329449987886742,0.08167648315429688 20 | 18,comfy-bush-914,10,False,0,1,True,False,False,False,0.9125699159725038,0.07466506958007812 21 | -------------------------------------------------------------------------------- /notebooks/sem_5d_sparse_sweep_7n108utd.csv: -------------------------------------------------------------------------------- 1 | ,name,dim,permute,variant,n_mixing_layer,use_sem,nonlin_sem,force_chain,force_uniform,mcc,val_loss 2 | 0,dainty-sweep-20,5,False,0,1,True,True,False,False,0.8605566693153509,3.156348705291748 3 | 1,unique-sweep-19,5,False,0,1,True,True,False,False,0.9997611339139388,1.9037175178527832 4 | 2,brisk-sweep-18,5,False,0,1,True,True,False,False,0.9997199519680472,1.93845534324646 5 | 3,lemon-sweep-17,5,False,0,1,True,True,False,False,0.9984290010990629,1.939425468444824 6 | 4,vague-sweep-16,5,False,0,1,True,True,False,False,0.8597623994691783,3.0247762203216553 7 | 5,swift-sweep-15,5,False,0,1,True,True,False,False,0.9992409951087394,1.988208055496216 8 | 6,serene-sweep-14,5,False,0,1,True,True,False,False,0.9997871540199772,1.886852741241455 9 | 7,desert-sweep-13,5,False,0,1,True,True,False,False,0.9997916521885554,1.9361343383789065 10 | 8,olive-sweep-12,5,False,0,1,True,True,False,False,0.9998924572557067,1.921983242034912 11 | 9,efficient-sweep-11,5,False,0,1,True,True,False,False,0.9996489267749146,1.9088010787963867 12 | 10,dry-sweep-10,5,False,0,1,True,False,False,False,0.8660540846011828,3.0744245052337646 13 | 
import torch
import torch.nn as nn


class ICAModel(nn.Module):
    """
    Linear ICA class.

    Learns a square unmixing matrix ``W`` (initialized to the identity) and
    scores the unmixed signals under a factorized Laplace source model via
    maximum likelihood.
    """

    def __init__(
        self, dim: int, signal_model: torch.distributions.distribution.Distribution
    ):
        """
        :param dim: an integer specifying the number of signals (W is dim x dim)
        :param signal_model: class of the signal model distribution
            (``loss`` expects a Laplace — see ``_ml_objective``)
        """
        super().__init__()

        self.W = torch.nn.Parameter(torch.eye(dim))
        self.signal_model = signal_model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Unmix the observations.

        :param x: data tensor of (num_samples, signal_dim)
        :return: unmixed signals ``x @ W``
        """
        return torch.matmul(x, self.W)

    @staticmethod
    def _ml_objective(
        x: torch.Tensor, signal_model: torch.distributions.laplace.Laplace
    ) -> torch.Tensor:
        """
        Implements the ML objective: the element-wise Laplace log-likelihood
        ``log p(x) = -|x - mu| / b - log(2 b)``.

        :param x: tensor to be transformed
        :param signal_model: Laplace distribution from torch.distributions
        :return: element-wise log-likelihood tensor with the same shape as ``x``
        """
        # transform with location and scale
        x_tr = (x - signal_model.mean).abs() / signal_model.scale

        return -x_tr - torch.log(2 * signal_model.scale)

    def loss(self, x: torch.Tensor):
        """
        Negative ML objective, split into its two additive terms.

        :param x: data tensor of (num_samples, signal_dim)
        :return: tuple of (-mean log-likelihood, -log|det W|); minimizing their
            sum maximizes the ICA likelihood
        """
        # ML objective (mean log-likelihood of the unmixed signals)
        model_entropy = self._ml_objective(self(x), self.signal_model).mean()

        # log of the abs determinant of the unmixing matrix.
        # torch.slogdet is a deprecated alias; torch.linalg.slogdet is the
        # documented replacement with identical return values.
        _, log_abs_det = torch.linalg.slogdet(self.W)

        # as we need to minimize, there is a minus sign here
        return -model_entropy, -log_abs_det

    def dependency(self, row_idx: int, col_idx: int) -> torch.Tensor:
        """Return the unmixing weight ``W[row_idx, col_idx]``."""
        return self.W[row_idx, col_idx]
9,true-moon-840,8,False,0,1,True,True,False,False,0.9111807302769804,0.5042295455932617 12 | 10,celestial-morning-841,8,False,0,1,True,False,False,False,0.6847524150842897,1.8819189071655271 13 | 11,glorious-shadow-842,8,False,0,1,True,False,False,False,0.9984860055248252,0.2131686210632324 14 | 12,good-jazz-843,8,False,0,1,True,False,False,False,0.9995145437982278,0.20080232620239255 15 | 13,elated-donkey-844,8,False,0,1,True,False,False,False,0.998075153819804,0.20594167709350583 16 | 14,neat-lion-845,8,False,0,1,True,False,False,False,0.682697764186909,1.733640432357788 17 | 15,denim-yogurt-846,8,False,0,1,True,False,False,False,0.8985175284775266,0.4691781997680664 18 | 16,gallant-jazz-847,8,False,0,1,True,False,False,False,0.5814280385222736,3.046862840652466 19 | 17,cerulean-disco-848,8,False,0,1,True,False,False,False,0.5956188631893446,3.068350315093994 20 | 18,eternal-morning-849,8,False,0,1,True,False,False,False,0.99880277420109,0.2215604782104492 21 | 19,ruby-morning-850,8,False,0,1,True,False,False,False,0.999422871852962,0.197904109954834 22 | -------------------------------------------------------------------------------- /notebooks/sem_10d_sweep_i871qu61.csv: -------------------------------------------------------------------------------- 1 | ,name,dim,permute,variant,n_mixing_layer,use_sem,nonlin_sem,force_chain,force_uniform,mcc,val_loss 2 | 0,dry-firefly-1108,10,False,0,1,True,True,False,False,0.9462453311432756,0.08972740173339844 3 | 1,quiet-durian-1109,10,False,0,1,True,True,False,False,0.9461753828289016,0.0746774673461914 4 | 2,light-waterfall-1110,10,False,0,1,True,True,False,False,0.9982808458600884,0.03744983673095703 5 | 3,sparkling-armadillo-1111,10,False,0,1,True,True,False,False,0.9975028595230933,0.03430938720703125 6 | 4,northern-fog-1112,10,False,0,1,True,True,False,False,0.9473158082372166,0.08133888244628906 7 | 5,true-lion-1113,10,False,0,1,True,True,False,False,0.9980520715278508,0.03544425964355469 8 | 
6,peachy-rain-1114,10,False,0,1,True,True,False,False,0.9481769773466752,0.09392070770263672 9 | 7,dutiful-sound-1115,10,False,0,1,True,True,False,False,0.9448784720406526,0.08184337615966797 10 | 8,lemon-night-1116,10,False,0,1,True,True,False,False,0.9468860650409092,0.0773916244506836 11 | 9,stellar-monkey-1117,10,False,0,1,True,True,False,False,0.946317823246206,0.076263427734375 12 | 10,wild-sun-1118,10,False,0,1,True,False,False,False,0.945418066209334,0.07754802703857422 13 | 11,sparkling-meadow-1119,10,False,0,1,True,False,False,False,0.9410652837876068,0.07428455352783203 14 | 12,woven-moon-1120,10,False,0,1,True,False,False,False,0.946525500169108,0.08149337768554688 15 | 13,brisk-deluge-1121,10,False,0,1,True,False,False,False,0.9990280483408992,0.03194618225097656 16 | 14,grateful-spaceship-1122,10,False,0,1,True,False,False,False,0.947555533628678,0.08596515655517578 17 | 15,rare-sound-1123,10,False,0,1,True,False,False,False,0.930128627881422,0.08655261993408203 18 | 16,swept-eon-1124,10,False,0,1,True,False,False,False,0.9987026645030392,0.034996986389160156 19 | 17,exalted-capybara-1125,10,False,0,1,True,False,False,False,0.9989967585241224,0.028470993041992188 20 | 18,bright-firefly-1126,10,False,0,1,True,False,False,False,0.9992242973706812,0.02747154235839844 21 | 19,smart-serenity-1127,10,False,0,1,True,False,False,False,0.9415826500938776,0.07753467559814453 22 | -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/ORIGINAL_PATENTS: -------------------------------------------------------------------------------- 1 | Additional Grant of Patent Rights Version 2 2 | 3 | "Software" means the clevr-dataset-gen software contributed by Facebook, Inc. 4 | 5 | Facebook, Inc. 
("Facebook") hereby grants to each recipient of the Software 6 | ("you") a perpetual, worldwide, royalty-free, non-exclusive, irrevocable 7 | (subject to the termination provision below) license under any Necessary 8 | Claims, to make, have made, use, sell, offer to sell, import, and otherwise 9 | transfer the Software. For avoidance of doubt, no license is granted under 10 | Facebook’s rights in any patent claims that are infringed by (i) modifications 11 | to the Software made by you or any third party or (ii) the Software in 12 | combination with any software or other technology. 13 | 14 | The license granted hereunder will terminate, automatically and without notice, 15 | if you (or any of your subsidiaries, corporate affiliates or agents) initiate 16 | directly or indirectly, or take a direct financial interest in, any Patent 17 | Assertion: (i) against Facebook or any of its subsidiaries or corporate 18 | affiliates, (ii) against any party if such Patent Assertion arises in whole or 19 | in part from any software, technology, product or service of Facebook or any of 20 | its subsidiaries or corporate affiliates, or (iii) against any party relating 21 | to the Software. Notwithstanding the foregoing, if Facebook or any of its 22 | subsidiaries or corporate affiliates files a lawsuit alleging patent 23 | infringement against you in the first instance, and you respond by filing a 24 | patent infringement counterclaim in that lawsuit against that party that is 25 | unrelated to the Software, the license granted hereunder will not terminate 26 | under section (i) of this paragraph due to such counterclaim. 27 | 28 | A "Necessary Claim" is a claim of a patent owned by Facebook that is 29 | necessarily infringed by the Software standing alone. 30 | 31 | A "Patent Assertion" is any lawsuit or other action alleging direct, indirect, 32 | or contributory infringement or inducement to infringe any patent, including a 33 | cross-claim or counterclaim. 
-------------------------------------------------------------------------------- /.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Python package 5 | 6 | on: 7 | push: 8 | branches: [ master ] 9 | pull_request: 10 | branches: [ master ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ${{ matrix.os }} 16 | strategy: 17 | fail-fast: false 18 | matrix: 19 | python-version: [3.8] 20 | os: [ubuntu-latest] 21 | 22 | steps: 23 | - uses: actions/checkout@v2 24 | with: 25 | submodules: true 26 | - name: Set up Python ${{ matrix.python-version }} 27 | uses: actions/setup-python@v2 28 | with: 29 | python-version: ${{ matrix.python-version }} 30 | 31 | - name: Python ${{ matrix.python-version }} cache 32 | uses: actions/cache@v2 33 | with: 34 | path: ${{ env.pythonLocation }} 35 | key: ${{ env.pythonLocation }}-${{ hashFiles('setup.py') }}-${{ hashFiles('dev-requirements.txt') }} 36 | 37 | - name: Install dependencies 38 | run: | 39 | python -m pip install --upgrade pip 40 | if [ -f requirements.txt ]; then pip install -r requirements.txt; pip install -e .; fi 41 | pip install --requirement tests/requirements.txt --quiet 42 | - name: Lint with flake8 43 | run: | 44 | python -m pip install flake8 45 | # stop the build if there are Python syntax errors or undefined names 46 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude cl_ica,icebeem,pytorch-flows 47 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 48 | flake8 . 
import torch
import torchmetrics


class MetricLogger(object):
    """Bundles a set of torchmetrics classification metrics behind a single
    update/compute interface, returning results keyed for a "Metrics" panel.

    NOTE(review): this targets an older torchmetrics API — ``F1`` and ``AUC``
    were renamed/removed in later torchmetrics releases; confirm the pinned
    version before upgrading.
    """

    def __init__(self):
        """
        Initialize the metrics.
        """
        super().__init__()

        self.accuracy = torchmetrics.Accuracy()
        self.roc_auc = torchmetrics.AUROC(num_classes=2)
        self.auc_score = torchmetrics.AUC()
        self.precision = torchmetrics.Precision()
        self.recall = torchmetrics.Recall()
        self.f1 = torchmetrics.F1()
        self.hamming_distance = torchmetrics.HammingDistance()
        # compute() unpacks this as [tp, fp, tn, fn, support]
        self.stat_scores = torchmetrics.StatScores()

    def update(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> None:
        """
        Update the metrics given the predictions and the ground truth.

        :param y_pred: predictions (presumably class probabilities or labels —
            confirm against the caller)
        :param y_true: ground-truth labels
        """
        self.accuracy.update(y_pred, y_true)
        self.roc_auc.update(y_pred, y_true)
        self.auc_score.update(y_pred, y_true)
        self.precision.update(y_pred, y_true)
        self.recall.update(y_pred, y_true)
        self.f1.update(y_pred, y_true)
        self.hamming_distance.update(y_pred, y_true)
        self.stat_scores.update(y_pred, y_true)

    def compute(self) -> dict:
        """
        Compute the metrics.

        :return: dict of metric values keyed as "Metrics/<name>"; roc_auc and
            auc_score are updated but deliberately left out of the result
            (see the commented lines below).
        """
        # StatScores.compute() returns true/false positives/negatives + support
        [tp, fp, tn, fn, sup] = self.stat_scores.compute()
        panel_name = "Metrics"
        return {
            f"{panel_name}/accuracy": self.accuracy.compute(),
            # f'{panel_name}/roc_auc': self.roc_auc.compute(),
            # f'{panel_name}/auc_score': self.auc_score.compute(),
            f"{panel_name}/precision": self.precision.compute(),
            f"{panel_name}/recall": self.recall.compute(),
            f"{panel_name}/f1": self.f1.compute(),
            f"{panel_name}/hamming_distance": self.hamming_distance.compute(),
            f"{panel_name}/tp": tp,
            f"{panel_name}/fp": fp,
            f"{panel_name}/tn": tn,
            f"{panel_name}/fn": fn,
            f"{panel_name}/support": sup,
        }
import os

import numpy as np


def to_one_hot(x, m=None):
    """One-hot encode a batch of label arrays.

    :param x: ndarray of integer labels, or a list of such arrays
    :param m: number of classes; inferred as (max label + 1) over all arrays
        when None
    :return: list of one-hot matrices of shape (num_labels, m), one per input,
        with the dtype of the first input array
    """
    if type(x) is not list:
        x = [x]
    if m is None:
        # infer the number of classes from the largest label seen
        ml = []
        for xi in x:
            ml += [xi.max() + 1]
        m = max(ml)
    dtp = x[0].dtype
    xoh = []
    for i, xi in enumerate(x):
        xoh += [np.zeros((xi.size, int(m)), dtype=dtp)]
        # builtin int replaces np.int, which was removed in NumPy 1.24
        xoh[i][np.arange(xi.size), xi.astype(int)] = 1
    return xoh


def one_hot_encode(labels, n_labels=10):
    """
    Transforms numeric labels to 1-hot encoded labels. Assumes numeric labels are in the range 0, 1, ..., n_labels-1.

    :param labels: ndarray of integer labels
    :param n_labels: number of classes
    :return: float32 ndarray of shape (labels.size, n_labels)
    """
    assert np.min(labels) >= 0 and np.max(labels) < n_labels

    y = np.zeros([labels.size, n_labels]).astype(np.float32)
    y[range(labels.size), labels] = 1

    return y


def single_one_hot_encode(label, n_labels=10):
    """
    Transforms a single numeric label to a 1-hot encoded vector. Assumes the label is in the range 0, 1, ..., n_labels-1.

    :param label: integer label
    :param n_labels: number of classes
    :return: float32 ndarray of shape (n_labels,)
    """
    assert label >= 0 and label < n_labels

    y = np.zeros([n_labels]).astype(np.float32)
    y[label] = 1

    return y


def single_one_hot_encode_rev(label, n_labels=10, start_label=0):
    """
    Transforms a single numeric label to a 1-hot encoded vector over the label
    range start_label, ..., n_labels-1 (the vector has n_labels - start_label
    entries and position label - start_label is set).

    :param label: integer label
    :param n_labels: exclusive upper bound of the label range
    :param start_label: inclusive lower bound of the label range
    :return: float32 ndarray of shape (n_labels - start_label,)
    """
    assert label >= start_label and label < n_labels
    y = np.zeros([n_labels - start_label]).astype(np.float32)
    y[label - start_label] = 1
    return y


# Named functions instead of lambda assignments (PEP 8 / E731); behavior and
# public names are unchanged.
def mnist_one_hot_transform(label):
    """One-hot transform for MNIST labels (10 classes)."""
    return single_one_hot_encode(label, n_labels=10)


def contrastive_one_hot_transform(label):
    """One-hot transform for binary contrastive labels (2 classes)."""
    return single_one_hot_encode(label, n_labels=2)


def make_dir(dir_name):
    """Create the directory (if missing) and return its path with a trailing slash."""
    if dir_name[-1] != "/":
        dir_name += "/"
    if not os.path.exists(dir_name):
        os.makedirs(dir_name)
    return dir_name


def make_file(file_name):
    """Create an empty file if it does not exist and return the file name."""
    if not os.path.exists(file_name):
        open(file_name, "a").close()
    return file_name
zip(graph_paths(three_dim_chain).values(), gt_paths) 31 | ] 32 | 33 | 34 | def test_indirect_causes(three_dim_chain: torch.Tensor): 35 | # ground truth 36 | indirect_paths = torch.zeros_like(three_dim_chain) 37 | indirect_paths[-1, 0] = 1 38 | 39 | print(indirect_causes(three_dim_chain)[1]) 40 | 41 | torch.all(indirect_causes(three_dim_chain)[0] == indirect_paths) 42 | 43 | 44 | @pytest.mark.parametrize("weighted", [True, False]) 45 | def test_false_positive_paths(three_dim_chain: torch.Tensor, weighted): 46 | # false positive at [2,0] 47 | dep_mat = torch.tril(torch.rand_like(three_dim_chain)) + 1.3 48 | direct_causes = torch.tril((three_dim_chain.abs() > 1e-6).float(), -1) 49 | 50 | fp = torch.Tensor([1 if weighted is False else dep_mat[2, 0], 0]) 51 | torch.all( 52 | false_positive_paths(dep_mat, graph_paths(direct_causes), weighted=weighted) 53 | == fp 54 | ) 55 | 56 | 57 | @pytest.mark.parametrize("weighted", [True, False]) 58 | def test_false_negative_paths(three_dim_chain: torch.Tensor, weighted): 59 | gt_adjacency = torch.tril(torch.rand_like(three_dim_chain)) + 1.3 60 | direct_causes = torch.tril((gt_adjacency.abs() > 1e-6).float(), -1) 61 | 62 | # false negative at [2,0] - here three_dim_chain is the estimated value 63 | fn = torch.Tensor([1 if weighted is False else gt_adjacency[2, 0], 0]) 64 | torch.all( 65 | false_negative_paths( 66 | three_dim_chain, graph_paths(direct_causes), weighted=weighted 67 | ) 68 | == fn 69 | ) 70 | -------------------------------------------------------------------------------- /notebooks/monti_sweep_77huh2ue.csv: -------------------------------------------------------------------------------- 1 | ,name,dim,permute,variant,n_mixing_layer,use_sem,nonlin_sem,force_chain,force_uniform,mcc,val_loss 2 | 0,charmed-sweep-25,6,False,3,5,False,False,False,False,0.8379107954836166,1.854844570159912 3 | 1,fiery-sweep-24,6,False,3,5,False,False,False,False,0.8497126701689969,1.8678297996520996 4 | 
2,northern-sweep-23,6,False,3,5,False,False,False,False,0.840644201969512,1.8648576736450195 5 | 3,happy-sweep-21,6,False,3,5,False,False,False,False,0.880131429668987,2.045818567276001 6 | 4,clean-sweep-22,6,False,3,5,False,False,False,False,0.8293840628008139,1.9895012378692627 7 | 5,sweepy-sweep-20,6,False,3,4,False,False,False,False,0.8387286044061634,1.864729881286621 8 | 6,noble-sweep-17,6,False,3,4,False,False,False,False,0.8350038247753466,1.88857102394104 9 | 7,fanciful-sweep-16,6,False,3,4,False,False,False,False,0.8100973160695278,1.8514604568481443 10 | 8,golden-sweep-18,6,False,3,4,False,False,False,False,0.8343979017561596,1.8834435939788816 11 | 9,fanciful-sweep-19,6,False,3,4,False,False,False,False,0.847744241010156,1.8607234954833984 12 | 10,woven-sweep-15,6,False,3,3,False,False,False,False,0.8421843332186746,1.828907251358032 13 | 11,trim-sweep-13,6,False,3,3,False,False,False,False,0.8306750434765068,1.8489642143249512 14 | 12,lucky-sweep-14,6,False,3,3,False,False,False,False,0.9993110388035972,1.0367021560668943 15 | 13,zany-sweep-12,6,False,3,3,False,False,False,False,0.9873989134725712,1.1389002799987793 16 | 14,eager-sweep-9,6,False,3,2,False,False,False,False,0.9996324727750469,1.0297703742980957 17 | 15,wild-sweep-10,6,False,3,2,False,False,False,False,0.999326604292618,1.0480751991271973 18 | 16,vocal-sweep-11,6,False,3,3,False,False,False,False,0.9989177035541376,1.04758882522583 19 | 17,zesty-sweep-7,6,False,3,2,False,False,False,False,0.9991668196876312,1.0220303535461426 20 | 18,sage-sweep-8,6,False,3,2,False,False,False,False,0.9990669578066456,1.0225944519042969 21 | 19,good-sweep-5,6,False,3,1,False,False,False,False,0.9998565660776868,1.0176987648010254 22 | 20,cerulean-sweep-6,6,False,3,2,False,False,False,False,0.9998933114660916,1.0363197326660156 23 | 21,zany-sweep-4,6,False,3,1,False,False,False,False,0.999831198266047,1.0289840698242188 24 | 
22,fast-sweep-3,6,False,3,1,False,False,False,False,0.9998441213652508,1.017876148223877 25 | 23,fresh-sweep-2,6,False,3,1,False,False,False,False,0.9997517473969434,1.056753158569336 26 | 24,genial-sweep-1,6,False,3,1,False,False,False,False,0.9999328830460192,1.008960247039795 27 | -------------------------------------------------------------------------------- /notebooks/monti_sweep_3rqfxkyl.csv: -------------------------------------------------------------------------------- 1 | ,name,dim,permute,variant,n_mixing_layer,use_sem,nonlin_sem,force_chain,force_uniform,mcc,val_loss 2 | 0,twilight-snowball-981,6,False,3,5,False,False,False,False,0.6164149991392273,2.9853672981262207 3 | 1,atomic-sun-982,6,False,3,5,False,False,False,False,0.6815620257946304,3.0047552585601807 4 | 2,warm-disco-983,6,False,3,5,False,False,False,False,0.6343687499663566,2.9120864868164062 5 | 3,deep-gorge-984,6,False,3,5,False,False,False,False,0.5551614959805392,3.3627450466156006 6 | 4,amber-cosmos-985,6,False,3,5,False,False,False,False,0.5284008640275427,5.690755844116211 7 | 5,scarlet-planet-986,6,False,3,4,False,False,False,False,0.9876972542156532,1.0956225395202637 8 | 6,fresh-pyramid-987,6,False,3,4,False,False,False,False,0.968844345402296,1.1442599296569824 9 | 7,morning-frog-988,6,False,3,4,False,False,False,False,0.9810430198144292,1.1612944602966309 10 | 8,worldly-frost-989,6,False,3,4,False,False,False,False,0.9958163626763924,1.0813336372375488 11 | 9,avid-brook-990,6,False,3,4,False,False,False,False,0.9548101426294048,1.3085899353027344 12 | 10,expert-frog-991,6,False,3,3,False,False,False,False,0.9984634171836134,1.0292787551879885 13 | 11,pleasant-meadow-992,6,False,3,3,False,False,False,False,0.9978086505101568,1.0431675910949707 14 | 12,upbeat-moon-993,6,False,3,3,False,False,False,False,0.9916541639921048,1.0663161277770996 15 | 13,stellar-flower-994,6,False,3,2,False,False,False,False,0.9999024314825214,1.0448369979858398 16 | 
14,whole-bush-995,6,False,3,2,False,False,False,False,0.9997775690224052,1.030303955078125 17 | 15,whole-hill-996,6,False,3,1,False,False,False,False,0.9997894463259328,1.0464377403259275 18 | 16,lemon-wildflower-997,6,False,3,3,False,False,False,False,0.9980021135135936,1.086085319519043 19 | 17,flowing-dream-998,6,False,3,3,False,False,False,False,0.9971217896171256,1.066732406616211 20 | 18,swept-sun-999,6,False,3,2,False,False,False,False,0.9998097101599744,1.0740838050842283 21 | 19,driven-brook-1000,6,False,3,2,False,False,False,False,0.999663395469589,1.0601258277893066 22 | 20,fresh-tree-1001,6,False,3,2,False,False,False,False,0.9996586516190464,1.0395126342773438 23 | 21,vital-waterfall-1002,6,False,3,1,False,False,False,False,0.9996800321129156,1.0390777587890625 24 | 22,autumn-monkey-1003,6,False,3,1,False,False,False,False,0.9997013129758958,1.030467510223389 25 | 23,dry-snowflake-1004,6,False,3,1,False,False,False,False,0.9998626596067872,1.001848220825195 26 | 24,true-deluge-1005,6,False,3,1,False,False,False,False,0.999675473268248,1.0738143920898438 27 | -------------------------------------------------------------------------------- /care_nl_ica/cli.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning.loggers.wandb import WandbLogger 2 | from pytorch_lightning.utilities.cli import LightningCLI 3 | 4 | from care_nl_ica.utils import add_tags, install_package 5 | from care_nl_ica.data.datamodules import ContrastiveDataModule 6 | from care_nl_ica.runner import ContrastiveICAModule 7 | 8 | 9 | class MyLightningCLI(LightningCLI): 10 | def add_arguments_to_parser(self, parser): 11 | parser.add_argument( 12 | "--notes", 13 | type=str, 14 | default=None, 15 | help="Notes for the run on Weights and Biases", 16 | ) 17 | # todo: process notes based on args in before_instantiate_classes 18 | parser.add_argument( 19 | "--tags", 20 | type=str, 21 | nargs="*", # 0 or more values expected => creates a list 
22 | default=None, 23 | help="Tags for the run on Weights and Biases", 24 | ) 25 | 26 | parser.link_arguments("data.latent_dim", "model.latent_dim") 27 | parser.link_arguments("data.box_min", "model.box_min") 28 | parser.link_arguments("data.box_max", "model.box_max") 29 | parser.link_arguments("data.sphere_r", "model.sphere_r") 30 | parser.link_arguments("data.normalize_latents", "model.normalize_latents") 31 | 32 | def before_instantiate_classes(self) -> None: 33 | self.config[self.subcommand].trainer.logger.init_args.tags = add_tags( 34 | self.config[self.subcommand] 35 | ) 36 | 37 | def before_fit(self): 38 | if isinstance(self.trainer.logger, WandbLogger) is True: 39 | # required as the parser cannot parse the "-" symbol 40 | self.trainer.logger.__dict__["_wandb_init"][ 41 | "entity" 42 | ] = "causal-representation-learning" 43 | 44 | if self.config[self.subcommand].model.offline is True: 45 | self.trainer.logger.__dict__["_wandb_init"]["mode"] = "offline" 46 | else: 47 | self.trainer.logger.__dict__["_wandb_init"]["mode"] = "online" 48 | 49 | # todo: maybe set run in the CLI to false and call watch before? 
50 | self.trainer.logger.watch(self.model, log="all", log_freq=250) 51 | 52 | 53 | if __name__ == "__main__": 54 | install_package() 55 | cli = MyLightningCLI( 56 | ContrastiveICAModule, 57 | ContrastiveDataModule, 58 | save_config_callback=None, 59 | run=True, 60 | parser_kwargs={"parse_as_dict": False}, 61 | ) 62 | -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/latent_spaces.py: -------------------------------------------------------------------------------- 1 | """Classes that combine spaces with specific probability densities.""" 2 | 3 | from typing import Callable, List 4 | from .spaces import Space 5 | import torch 6 | 7 | 8 | class LatentSpace: 9 | """Combines a topological space with a marginal and conditional density to sample from.""" 10 | 11 | def __init__( 12 | self, space: Space, sample_marginal: Callable, sample_conditional: Callable 13 | ): 14 | self.space = space 15 | self._sample_marginal = sample_marginal 16 | self._sample_conditional = sample_conditional 17 | 18 | @property 19 | def sample_conditional(self): 20 | if self._sample_conditional is None: 21 | raise RuntimeError("sample_conditional was not set") 22 | return lambda *args, **kwargs: self._sample_conditional( 23 | self.space, *args, **kwargs 24 | ) 25 | 26 | @sample_conditional.setter 27 | def sample_conditional(self, value: Callable): 28 | assert callable(value) 29 | self._sample_conditional = value 30 | 31 | @property 32 | def sample_marginal(self): 33 | if self._sample_marginal is None: 34 | raise RuntimeError("sample_marginal was not set") 35 | return lambda *args, **kwargs: self._sample_marginal( 36 | self.space, *args, **kwargs 37 | ) 38 | 39 | @sample_marginal.setter 40 | def sample_marginal(self, value: Callable): 41 | assert callable(value) 42 | self._sample_marginal = value 43 | 44 | @property 45 | def dim(self): 46 | return self.space.dim 47 | 48 | 49 | class ProductLatentSpace(LatentSpace): 50 | """A latent space which is 
the cartesian product of other latent spaces.""" 51 | 52 | def __init__(self, spaces: List[LatentSpace]): 53 | self.spaces = spaces 54 | 55 | def sample_conditional(self, z, size, **kwargs): 56 | x = [] 57 | n = 0 58 | for s in self.spaces: 59 | if len(z.shape) == 1: 60 | z_s = z[n : n + s.space.n] 61 | else: 62 | z_s = z[:, n : n + s.space.n] 63 | n += s.space.n 64 | x.append(s.sample_conditional(z=z_s, size=size, **kwargs)) 65 | 66 | return torch.cat(x, -1) 67 | 68 | def sample_marginal(self, size, **kwargs): 69 | x = [s.sample_marginal(size=size, **kwargs) for s in self.spaces] 70 | 71 | return torch.cat(x, -1) 72 | 73 | @property 74 | def dim(self): 75 | return sum([s.dim for s in self.spaces]) 76 | -------------------------------------------------------------------------------- /care_nl_ica/models/tcl/tcl_preprocessing.py: -------------------------------------------------------------------------------- 1 | """Preprocessing""" 2 | 3 | import numpy as np 4 | 5 | 6 | # ============================================================ 7 | # ============================================================ 8 | def pca(x, num_comp=None, params=None, zerotolerance=1e-7): 9 | """Apply PCA whitening to data. 10 | Args: 11 | x: data. 2D ndarray [num_comp, num_data] 12 | num_comp: number of components 13 | params: (option) dictionary of PCA parameters {'mean':?, 'W':?, 'A':?}. 
def pca(x, num_comp=None, params=None, zerotolerance=1e-7):
    """Apply PCA whitening to data.

    Args:
        x: data. 2D ndarray [num_comp, num_data]
        num_comp: number of components (defaults to the data dimension)
        params: (option) dictionary of PCA parameters {'mean':?, 'W':?, 'A':?}.
            If given, apply this previously-learned transform to the data
        zerotolerance: (option) eigenvalues below this fraction of the largest
            eigenvalue are treated as zero (rank deficiency)
    Returns:
        x: whitened data
        params: parameters of PCA
            mean: subtracted mean
            W: whitening matrix
            A: mixing (de-whitening) matrix
    Raises:
        ValueError: if any of the leading num_comp eigenvalues is (near-)zero,
            i.e. the whitening matrix would not be invertible.
    """
    # Dimension: default to the full dimensionality of the data
    if num_comp is None:
        num_comp = x.shape[0]

    # From learned parameters --------------------------------
    if params is not None:
        # Use previously-trained model: center with the stored mean, then whiten
        print(" use learned value")
        data_pca = x - params["mean"]
        x = np.dot(params["W"], data_pca)

    # Learn from data ----------------------------------------
    else:
        # Zero mean
        xmean = np.mean(x, 1).reshape([-1, 1])
        x = x - xmean

        # Eigenvalue decomposition of the (symmetric) covariance matrix
        xcov = np.cov(x)
        d, V = np.linalg.eigh(xcov)  # eigh returns ascending order
        # Convert to descending order
        d = d[::-1]
        V = V[:, ::-1]

        # Do not allow (near-)zero eigenvalues among the kept components:
        # the whitening matrix would blow up / not be invertible
        zeroeigval = np.sum((d[:num_comp] / d[0]) < zerotolerance)
        if zeroeigval > 0:
            raise ValueError(
                f"PCA whitening failed: {zeroeigval} of the leading {num_comp} "
                f"eigenvalues are below the zero tolerance ({zerotolerance}) "
                f"relative to the largest eigenvalue."
            )

        # Construct whitening and dewhitening matrices
        dsqrt = np.sqrt(d[:num_comp])
        dsqrtinv = 1 / dsqrt
        V = V[:, :num_comp]
        W = np.dot(np.diag(dsqrtinv), V.transpose())  # whitening matrix
        A = np.dot(V, np.diag(dsqrt))  # de-whitening (mixing) matrix
        x = np.dot(W, x)

        params = {"mean": xmean, "W": W, "A": A}

    return x, params
#!/bin/bash
# Container entrypoint: optionally creates a user account mirroring the host
# user (UID/groups/shell), propagates the CUDA environment into login shells,
# and finally execs the requested command via gosu.

# start ssh
service ssh restart

# Create user account
if [ -n "$USER" ]; then
    if [ -z "$USER_HOME" ]; then
        export USER_HOME=/home/$USER
    fi

    if [ -z "$USER_ID" ]; then
        export USER_ID=99
    fi

    if [ -n "$USER_ENCRYPTED_PASSWORD" ]; then
        useradd -M -d $USER_HOME -p $USER_ENCRYPTED_PASSWORD -u $USER_ID $USER > /dev/null
    else
        useradd -M -d $USER_HOME -u $USER_ID $USER > /dev/null
    fi

    # expects a comma-separated string of the form GROUP1:GROUP1ID,GROUP2,GROUP3:GROUP3ID,...
    # (the GROUPID is optional, but needs to be separated from the group name by a ':')
    for i in $(echo $USER_GROUPS | sed "s/,/ /g")
    do
        if [[ $i == *":"* ]]
        then
            addgroup ${i%:*} # > /dev/null
            groupmod -g ${i#*:} ${i%:*} #> /dev/null
            adduser $USER ${i%:*} #> /dev/null
        else
            addgroup $i > /dev/null
            adduser $USER $i > /dev/null
        fi
    done

    # add user to sudo group
    adduser $USER sudo

    # set correct primary group
    if [ -n "$USER_GROUPS" ]; then
        group="$( cut -d ',' -f 1 <<< "$USER_GROUPS" )"
        if [[ $group == *":"* ]]
        then
            usermod -g ${group%:*} $USER &
        else
            usermod -g $group $USER &
        fi
    fi

    # set shell
    if [ -z "$USER_SHELL" ]
    then
        usermod -s "/bin/bash" $USER
    else
        usermod -s $USER_SHELL $USER
    fi

    # BUGFIX: the original `[ -n $CWD ]` is always true when CWD is unset
    # (the unquoted expansion leaves a one-argument test), so `cd` ran with
    # no argument and silently changed to $HOME. Quote the variable.
    if [ -n "$CWD" ]; then cd "$CWD"; fi
    echo "Running as user $USER"

    # set environment such that gpus can be used even in ssh connections
    echo "export CUDNN_VERSION=$CUDNN_VERSION" >> /etc/profile
    echo "export NVIDIA_REQUIRE_CUDA='$NVIDIA_REQUIRE_CUDA'" >> /etc/profile
    echo "export LIBRARY_PATH=$LIBRARY_PATH" >> /etc/profile
    echo "export LD_PRELOAD=$LD_PRELOAD" >> /etc/profile
    echo "export NVIDIA_VISIBLE_DEVICES=$NVIDIA_VISIBLE_DEVICES" >> /etc/profile
    echo "export LD_LIBRARY_PATH=$LD_LIBRARY_PATH" >> /etc/profile
    echo "export NVIDIA_DRIVER_CAPABILITIES=$NVIDIA_DRIVER_CAPABILITIES" >> /etc/profile
    echo "export PATH=$PATH" >> /etc/profile
    echo "export CUDA_PKG_VERSION=$CUDA_PKG_VERSION" >> /etc/profile

    exec gosu $USER "$@"
else
    # BUGFIX: same unquoted `[ -n $CWD ]` issue as above
    if [ -n "$CWD" ]; then cd "$CWD"; fi
    echo "Running as default container user"
    exec "$@"
fi
def indirect_causes(gt_jacobian_encoder) -> tuple:
    """
    Calculates all indirect paths in the encoder (SEM/SCM).

    :param gt_jacobian_encoder: Jacobian of the encoder; its strictly
        lower-triangular entries with magnitude > 1e-6 are the direct edges
    :return: tuple of
        - indirect-cause mask: binary matrix marking pairs connected *only*
          indirectly (entries with a direct edge are masked out)
        - paths: dict of adjacency-matrix powers, as returned by graph_paths
    """
    # calculate the indirect cause mask: binarize the strictly
    # lower-triangular part to obtain the direct edges
    eps = 1e-6
    direct_causes = torch.tril((gt_jacobian_encoder.abs() > eps).float(), -1)

    # add together the matrix powers of the adjacency matrix
    # this yields all indirect paths
    paths = graph_paths(direct_causes)

    # edge-free graph: no paths at all (torch.stack on an empty list raises)
    if len(paths) == 0:
        return torch.zeros_like(direct_causes), paths

    indirect_mask = torch.stack(list(paths.values())).sum(0)

    # convert all non-1 values to 1 (for safety)
    indirect_mask = indirect_mask.bool().float()
    # correct for causes where both the direct and indirect paths are present
    indirect_mask = indirect_mask * ((indirect_mask - direct_causes) > 0).float()

    return indirect_mask, paths


def graph_paths(direct_causes: torch.Tensor) -> dict:
    """
    Collects the matrix powers of the adjacency matrix (paths of each length).

    :param direct_causes: binary adjacency matrix of direct edges
    :return: dict mapping str(power index) -> matrix of paths of that length;
        empty when the graph has no edges
    """
    paths = dict()
    matrix_power = direct_causes.clone()

    for i in range(direct_causes.shape[0]):
        # stop as soon as no longer paths exist (nilpotent for DAGs)
        if matrix_power.sum() == 0:
            break

        paths[str(i)] = matrix_power
        matrix_power = matrix_power @ direct_causes

    return paths
def dsm(energy_net, samples, sigma=1):
    """Denoising score matching objective.

    Perturbs the samples with Gaussian noise of scale sigma and penalizes the
    mismatch between the (scaled) model score at the perturbed points and the
    injected noise vector.

    :param energy_net: network mapping inputs to scalar energies
    :param samples: batch of data points
    :param sigma: noise standard deviation
    :return: scalar DSM loss
    """
    samples.requires_grad_(True)
    noise = torch.randn_like(samples) * sigma
    noisy = samples + noise
    # log-density is the negative energy; its gradient is the score
    log_density = -energy_net(noisy)
    scaled_score = (
        sigma**2 * autograd.grad(log_density.sum(), noisy, create_graph=True)[0]
    )
    residual = scaled_score + noise
    squared_error = torch.norm(residual, dim=-1) ** 2
    return squared_error.mean() / 2.0
def dsm_score_estimation(scorenet, samples, sigma=0.01):
    """Denoising score matching for a direct score network.

    :param scorenet: network predicting the score at its input
    :param samples: batch of data points
    :param sigma: noise standard deviation
    :return: scalar loss (mean over the batch)
    """
    perturbed_samples = samples + torch.randn_like(samples) * sigma
    # the target score of the Gaussian perturbation kernel
    target = (samples - perturbed_samples) / (sigma**2)
    scores = scorenet(perturbed_samples)
    # flatten everything after the batch dimension before comparing
    target = target.view(target.shape[0], -1)
    scores = scores.view(scores.shape[0], -1)
    return 0.5 * ((scores - target) ** 2).sum(dim=-1).mean(dim=0)
9 | callbacks: 10 | - class_path: pytorch_lightning.callbacks.EarlyStopping 11 | init_args: 12 | monitor: val_loss 13 | mode: min 14 | patience: 40 15 | auto_select_gpus: true 16 | checkpoint_callback: null 17 | enable_checkpointing: false 18 | default_root_dir: null 19 | gradient_clip_val: null 20 | gradient_clip_algorithm: null 21 | process_position: 0 22 | log_gpu_memory: null 23 | progress_bar_refresh_rate: null 24 | enable_progress_bar: true 25 | overfit_batches: 0.0 26 | track_grad_norm: -1 27 | check_val_every_n_epoch: 100 28 | fast_dev_run: false 29 | accumulate_grad_batches: null 30 | max_epochs: 5000 31 | min_epochs: null 32 | max_steps: -1 33 | min_steps: null 34 | max_time: null 35 | gpus: 1 36 | limit_train_batches: 400.0 37 | limit_val_batches: 5.0 38 | limit_test_batches: 1.0 39 | limit_predict_batches: 1.0 40 | val_check_interval: 1.0 41 | flush_logs_every_n_steps: null 42 | log_every_n_steps: 250 43 | accelerator: null 44 | strategy: null 45 | sync_batchnorm: false 46 | precision: 32 47 | enable_model_summary: true 48 | weights_summary: top 49 | weights_save_path: null 50 | num_sanity_val_steps: 2 51 | resume_from_checkpoint: null 52 | profiler: null 53 | benchmark: false 54 | deterministic: true 55 | reload_dataloaders_every_n_epochs: 0 56 | reload_dataloaders_every_epoch: false 57 | auto_lr_find: false 58 | replace_sampler_ddp: true 59 | auto_scale_batch_size: false 60 | prepare_data_per_node: null 61 | amp_backend: native 62 | amp_level: null 63 | move_metrics_to_cpu: false 64 | multiple_trainloader_mode: max_size_cycle 65 | stochastic_weight_avg: false 66 | detect_anomaly: true 67 | model: 68 | offline: true 69 | lr: 1e-4 70 | verbose: false 71 | p: 1.0 72 | tau: 1.0 73 | alpha: 0.5 74 | normalization: '' 75 | start_step: null 76 | use_bias: false 77 | data: 78 | n_mixing_layer: 1 79 | permute: false 80 | act_fct: leaky_relu 81 | use_sem: true 82 | data_gen_mode: rvs 83 | variant: 0 84 | nonlin_sem: false 85 | box_min: 0.0 86 | box_max: 1.0 87 
def setup_marginal(args):
    """Build the marginal sampler selected by ``args.m_p``.

    p=1 -> Laplace, p=2 -> normal, p=0 -> uniform, otherwise a generalized
    normal with exponent p. The returned callable takes (space, size) and an
    optional device keyword.
    """
    device = args.device
    eta = torch.zeros(args.latent_dim)
    if args.space_type == "sphere":
        # place the location parameter on the sphere surface
        eta[0] = args.sphere_r

    if args.m_p == 1:

        def sample_marginal(space, size, device=device):
            return space.laplace(eta, args.m_param, size, device)

    elif args.m_p == 2:

        def sample_marginal(space, size, device=device):
            return space.normal(eta, args.m_param, size, device)

    elif args.m_p == 0:

        def sample_marginal(space, size, device=device):
            return space.uniform(size, device=device)

    else:

        def sample_marginal(space, size, device=device):
            return space.generalized_normal(
                eta, args.m_param, p=args.m_p, size=size, device=device
            )

    return sample_marginal


def sample_marginal_and_conditional(latent_space, size, device):
    """Draw an anchor, a conditional (positive) and an independent (negative) sample."""
    anchor = latent_space.sample_marginal(size=size, device=device)
    positive = latent_space.sample_conditional(anchor, size=size, device=device)
    negative = latent_space.sample_marginal(size=size, device=device)

    return anchor, positive, negative
def laplace_log_cdf(x: torch.Tensor, signal_model: torch.distributions.laplace.Laplace):
    """
    Log cdf of the Laplace distribution (numerically stable).
    Source: https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/internal/special_math.py#L281

    :param x: tensor to be transformed
    :param signal_model: Laplace distribution from torch.distributions.
    """

    # standardize with location and scale
    standardized = (x - signal_model.mean) / signal_model.scale

    # branch for x < mean: log(0.5 * exp(z)) = log(0.5) + z
    below = torch.log(torch.tensor(0.5)) + standardized

    # branch for x >= mean: log(1 - 0.5 * exp(-|z|)), stable via log1p
    above = torch.log1p(-0.5 * torch.exp(-standardized.abs()))

    return torch.where(x < signal_model.mean, below, above)
class Lambda(nn.Module):
    """Apply a lambda function to the input."""

    def __init__(self, f):
        super().__init__()
        # the wrapped callable; applied verbatim in forward
        self.f = f

    def forward(self, *args, **kwargs):
        return self.f(*args, **kwargs)


class Flatten(Lambda):
    """Flatten the input data after the batch dimension."""

    def __init__(self):
        # keep the batch dimension, collapse everything else
        super().__init__(lambda x: x.view(len(x), -1))


class RescaleLayer(nn.Module):
    """Normalize the data to a hypersphere with fixed/variable radius.

    mode="eq" projects every point onto the sphere surface; mode="leq" only
    rescales points lying outside the radius.
    """

    def __init__(
        self, init_r=1.0, fixed_r=False, mode: Optional[Literal["eq", "leq"]] = "eq"
    ):
        super().__init__()
        self.fixed_r = fixed_r
        assert mode in ("leq", "eq")
        self.mode = mode
        if fixed_r:
            # plain tensor (not a Parameter): the radius stays constant
            self.r = torch.ones(1, requires_grad=False) * init_r
        else:
            # learnable radius, updated by the optimizer
            self.r = nn.Parameter(torch.ones(1, requires_grad=True) * init_r)

    def forward(self, x):
        if self.mode == "eq":
            # normalize to the unit sphere, then scale to radius r
            x = x / torch.norm(x, dim=-1, keepdim=True)
            x = x * self.r.to(x.device)
        elif self.mode == "leq":
            # NOTE(review): in-place masked update — only rows with norm > r are
            # divided back onto the sphere. The mask keeps a trailing singleton
            # dim (keepdim=True); verify the intended broadcasting for inputs
            # with more than 2 dimensions.
            norm = torch.norm(x, dim=-1, keepdim=True)
            x[norm > self.r] /= torch.norm(x, dim=-1, keepdim=True) / self.r

        return x


class SoftclipLayer(nn.Module):
    """Normalize the data to a hyperrectangle with fixed/learnable size."""

    def __init__(self, n, init_abs_bound=1.0, fixed_abs_bound=True):
        # n: number of feature dimensions (one bound per dimension)
        # init_abs_bound: initial per-dimension bound
        # fixed_abs_bound: if True the bound is constant, otherwise learnable
        super().__init__()
        self.fixed_abs_bound = fixed_abs_bound
        if fixed_abs_bound:
            self.max_abs_bound = torch.ones(n, requires_grad=False) * init_abs_bound
        else:
            self.max_abs_bound = nn.Parameter(
                torch.ones(n, requires_grad=True) * init_abs_bound
            )

    def forward(self, x):
        # NOTE(review): sigmoid maps to (0, 1), so the output lies in
        # (0, max_abs_bound) per dimension — i.e. a one-sided box, not a
        # symmetric |x| <= bound clip as the attribute name might suggest.
        x = torch.sigmoid(x)
        x = x * self.max_abs_bound.to(x.device).unsqueeze(0)

        return x
def IVAE_wrapper(
    X,
    U,
    batch_size=256,
    max_iter=7e4,
    seed=0,
    n_layers=3,
    hidden_dim=20,
    lr=1e-3,
    cuda=True,
    ckpt_file="ivae.pt",
    test=False,
):
    """Train an iVAE on (X, U) — or, when test=True, load it from ckpt_file.

    :param X: observations (ndarray), cast to float32
    :param U: auxiliary/conditioning variables (ndarray), cast to float32
    :param batch_size: minibatch size
    :param max_iter: number of optimizer steps to run
    :param seed: random seed for torch and numpy
    :param n_layers: number of layers of the iVAE networks
    :param hidden_dim: hidden width of the iVAE networks
    :param lr: Adam learning rate
    :param cuda: run on cuda:0 if True, else cpu
    :param ckpt_file: path where the model state_dict is saved/loaded
    :param test: if True, skip training and load the checkpoint instead
    :return: (z, model, params) — latent sample, the iVAE, and a dict with the
        decoder/encoder/prior parameters evaluated on the full dataset
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    device = torch.device("cuda:0" if cuda else "cpu")

    # load data into a shuffled loader
    dset = ConditionalDataset(X.astype(np.float32), U.astype(np.float32), device)
    loader_params = {"num_workers": 1, "pin_memory": True} if cuda else {}
    train_loader = DataLoader(
        dset, shuffle=True, batch_size=batch_size, **loader_params
    )
    data_dim, latent_dim, aux_dim = dset.get_dims()

    # define model and optimizer
    model = iVAE(
        latent_dim,
        data_dim,
        aux_dim,
        activation="lrelu",
        device=device,
        n_layers=n_layers,
        hidden_dim=hidden_dim,
    )
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, factor=0.1, patience=20, verbose=True
    )

    # training loop
    if not test:
        print("Training..")
        it = 0
        model.train()
        while it < max_iter:
            elbo_train = 0
            for _, (x, u) in enumerate(train_loader):
                it += 1
                optimizer.zero_grad()
                x, u = x.to(device), u.to(device)
                elbo, z_est = model.elbo(x, u)
                # maximize the ELBO = minimize its negative
                elbo.mul(-1).backward()
                optimizer.step()
                elbo_train += -elbo.item()
            elbo_train /= len(train_loader)
            scheduler.step(elbo_train)
        # save model checkpoint after training
        torch.save(model.state_dict(), ckpt_file)
    else:
        # BUGFIX: the checkpoint stores a state_dict (see torch.save above),
        # so `model = torch.load(...)` returned an OrderedDict, not a module,
        # and the forward call below would fail. Load the weights instead.
        model.load_state_dict(torch.load(ckpt_file, map_location=device))

    # evaluate on the full (unshuffled) dataset
    Xt, Ut = dset.x, dset.y
    decoder_params, encoder_params, z, prior_params = model(
        Xt.to(device), Ut.to(device)
    )
    params = {
        "decoder": decoder_params,
        "encoder": encoder_params,
        "prior": prior_params,
    }

    return z, model, params
class ContrastiveLearningModel(nn.Module):
    """Contrastive model: an MLP unmixing network trained with a SimCLR-style loss.

    The unmixing network maps observations back to the latent space; the loss
    is LpSimCLRLoss when hparams.p is truthy, otherwise SimCLRLoss.
    """

    def __init__(self, hparams):
        """
        :param hparams: namespace with latent_dim, p, tau, alpha, normalization,
            box_min/box_max, sphere_r, device and verbose
        """
        super().__init__()

        self.hparams = hparams

        self._setup_unmixing()
        self._setup_loss()

        torch.cuda.empty_cache()

    def parameters(self, recurse: bool = True):
        # only the unmixing network is exposed to the optimizer
        parameters = list(self.unmixing.parameters(recurse))

        return parameters

    def _setup_unmixing(self):
        """Build the unmixing MLP with the configured output normalization."""
        hparams = self.hparams

        (
            output_normalization,
            output_normalization_kwargs,
        ) = self._configure_output_normalization()

        self.unmixing = encoders.get_mlp(
            n_in=hparams.latent_dim,
            n_out=hparams.latent_dim,
            layers=[
                hparams.latent_dim * 10,
                hparams.latent_dim * 50,
                hparams.latent_dim * 50,
                hparams.latent_dim * 50,
                hparams.latent_dim * 50,
                hparams.latent_dim * 10,
            ],
            output_normalization=output_normalization,
            output_normalization_kwargs=output_normalization_kwargs,
        )

        if self.hparams.verbose is True:
            # BUGFIX: nn.Module has no .detach() (that is a Tensor method);
            # the old f"{self.unmixing.detach()=}" raised an AttributeError
            # whenever verbose was enabled
            print(f"{self.unmixing=}")

        self.unmixing = self.unmixing.to(hparams.device)

    def _setup_loss(self):
        """Select the contrastive loss based on hparams.p."""
        hparams = self.hparams

        if hparams.p:
            self.loss = losses.LpSimCLRLoss(
                p=hparams.p, tau=hparams.tau, simclr_compatibility_mode=True
            )
        else:
            self.loss = losses.SimCLRLoss(
                normalize=False, tau=hparams.tau, alpha=hparams.alpha
            )

    def _configure_output_normalization(self):
        """Map hparams.normalization to the encoder's normalization spec.

        :return: (output_normalization, output_normalization_kwargs)
        :raises ValueError: on an unknown normalization name
        """
        hparams = self.hparams
        output_normalization_kwargs = None
        if hparams.normalization == "learnable_box":
            output_normalization = "learnable_box"
        elif hparams.normalization == "fixed_box":
            output_normalization = "fixed_box"
            output_normalization_kwargs = dict(
                init_abs_bound=hparams.box_max - hparams.box_min
            )
        elif hparams.normalization == "learnable_sphere":
            output_normalization = "learnable_sphere"
        elif hparams.normalization == "fixed_sphere":
            output_normalization = "fixed_sphere"
            output_normalization_kwargs = dict(init_r=hparams.sphere_r)
        elif hparams.normalization == "":
            print("Using no output normalization")
            output_normalization = None
        else:
            raise ValueError("Invalid output normalization:", hparams.normalization)
        return output_normalization, output_normalization_kwargs

    def forward(self, x):
        # apply the unmixing element-wise when given a list/tuple of views
        if isinstance(x, list) or isinstance(x, tuple):
            return tuple(map(self.unmixing, x))
        else:
            return self.unmixing(x)
2 | 3 | [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.7002143.svg)](https://doi.org/10.5281/zenodo.7002143) 4 | 5 | ![CI testing](https://github.com/rpatrik96/nl-causal-representations/workflows/Python%20package/badge.svg?branch=master&event=push) 6 | [![pre-commit](https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit&logoColor=white)](https://github.com/pre-commit/pre-commit) 7 | 8 |
9 | 10 | # Jacobian-based Causal Discovery with Nonlinear ICA 11 | 12 | 13 | ## Description 14 | This is the code for the paper _Jacobian-based Causal Discovery with Nonlinear ICA_, demonstrating how identifiable representations (particularly, with Nonlinear ICA) can be used to extract the causal graph from an underlying structural equation model (SEM). 15 | 16 | ## Before running the code 17 | 18 | ### Singularity container build 19 | 20 | ```bash 21 | singularity build --fakeroot nv.sif nv.def 22 | ``` 23 | 24 | ### Logging 25 | 26 | 1. First, you need to log into `wandb` 27 | ```bash 28 | wandb login #you will find your API key at https://wandb.ai/authorize 29 | ``` 30 | 31 | 2. Second, you need to specify the project for logging, which you can in the `before_fit` method in [cli.py](https://github.com/rpatrik96/nl-causal-representations/blob/master/care_nl_ica/cli.py#L37) 32 | ```python 33 | def before_fit(self): 34 | if isinstance(self.trainer.logger, WandbLogger) is True: 35 | # required as the parser cannot parse the "-" symbol 36 | self.trainer.logger.__dict__["_wandb_init"][ 37 | "entity" 38 | ] = "causal-representation-learning" # <--- modify this line 39 | ``` 40 | 41 | 3. Then, you can create and run the sweep 42 | ```bash 43 | wandb sweep sweeps/sweep_file.yaml # returns sweep ID 44 | wandb agent --count= # when used on a cluster, set it to one and start multiple processes 45 | ``` 46 | 47 | 48 | ## Usage 49 | 50 | 1. Clone 51 | ```bash 52 | git clone https://github.com/rpatrik96/nl-causal-representations.git 53 | ``` 54 | 55 | 2. Install 56 | ```bash 57 | # install package 58 | pip3 install -e . 59 | 60 | # install requirements 61 | pip install -r requirements.txt 62 | 63 | # install pre-commit hooks 64 | pre-commit install 65 | ``` 66 | 67 | 3. 
Run: 68 | ```bash 69 | python3 care_nl_ica/cli.py fit --config configs/config.yaml 70 | ``` 71 | 72 | 73 | 74 | 75 | ### Code credits 76 | Our repo extensively relies on `cl-ica` [repo](https://github.com/brendel-group/cl-ica), so please consider citing the corresponding [paper](http://proceedings.mlr.press/v139/zimmermann21a/zimmermann21a.pdf) as well 77 | 78 | 79 | # Reference 80 | If you find our work useful, please consider citing our [TMLR paper](https://openreview.net/forum?id=2Yo9xqR6Ab) 81 | 82 | ```bibtex 83 | @article{reizinger2023jacobianbased, 84 | author = { 85 | Reizinger, Patrik and 86 | Sharma, Yash and 87 | Bethge, Matthias and 88 | Schölkopf, Bernhard and 89 | Huszár, Ferenc and 90 | Brendel, Wieland 91 | }, 92 | title = { 93 | Jacobian-based Causal Discovery with Nonlinear {ICA} 94 | }, 95 | journal={Transactions on Machine Learning Research}, 96 | issn={2835-8856}, 97 | year={2023}, 98 | url={https://openreview.net/forum?id=2Yo9xqR6Ab}, 99 | } 100 | ``` 101 | 102 | 103 | -------------------------------------------------------------------------------- /notebooks/sem_3d_sweep_vfv1je0d.csv: -------------------------------------------------------------------------------- 1 | ,name,dim,permute,variant,n_mixing_layer,use_sem,nonlin_sem,force_chain,force_uniform,mcc,val_loss 2 | 0,fresh-sweep-40,3,True,3,1,True,False,False,False,0.9999342515022526,4.468616485595703 3 | 1,visionary-sweep-39,3,True,3,1,True,False,False,False,0.9999814500610972,4.498720169067383 4 | 2,gentle-sweep-37,3,True,2,1,True,False,False,False,0.9999192956769172,4.443269729614258 5 | 3,wandering-sweep-38,3,True,2,1,True,False,False,False,0.9999506773760876,4.468270778656006 6 | 4,hopeful-sweep-36,3,True,1,1,True,False,False,False,0.9999572169201444,4.464474678039551 7 | 5,blooming-sweep-35,3,True,0,1,True,False,False,False,0.9999769496598908,4.497834205627441 8 | 6,swept-sweep-57,3,True,4,1,True,True,False,False,0.999940480589484,4.435263633728027 9 | 
7,breezy-sweep-53,3,True,3,1,True,True,False,False,0.999954985860554,4.445283889770508 10 | 8,dutiful-sweep-43,3,True,1,1,True,True,False,False,0.9999537439043902,4.445357799530029 11 | 9,northern-sweep-40,3,True,5,1,True,False,False,False,0.9999668243185336,4.4437408447265625 12 | 10,prime-sweep-37,3,True,5,1,True,False,False,False,0.9999650488552828,4.480136871337891 13 | 11,chocolate-sweep-35,3,True,4,1,True,False,False,False,0.9999592339829574,4.442595481872559 14 | 12,sweepy-sweep-34,3,True,4,1,True,False,False,False,0.999967825503246,4.498065948486328 15 | 13,revived-sweep-32,3,True,3,1,True,False,False,False,0.999943011614768,4.442612648010254 16 | 14,atomic-sweep-30,3,True,3,1,True,False,False,False,0.9999625010903596,4.460474967956543 17 | 15,fragrant-sweep-27,3,True,2,1,True,False,False,False,0.9999617712064278,4.498493194580078 18 | 16,peachy-sweep-26,3,True,2,1,True,False,False,False,0.999970478718227,4.460423946380615 19 | 17,denim-sweep-60,3,True,5,1,True,True,False,False,0.999945124117764,4.445318698883057 20 | 18,clean-sweep-59,3,True,5,1,True,True,False,False,0.9999522559673818,4.434645175933838 21 | 19,comfy-sweep-58,3,True,5,1,True,True,False,False,0.99995990874136,4.486793041229248 22 | 20,misty-sweep-47,3,True,3,1,True,True,False,False,0.9999712113982224,4.453330993652344 23 | 21,decent-sweep-46,3,True,3,1,True,True,False,False,0.9999685169207044,4.481067657470703 24 | 22,firm-sweep-45,3,True,2,1,True,True,False,False,0.9999079993539652,4.4449462890625 25 | 23,icy-sweep-36,3,True,1,1,True,True,False,False,0.999988256279146,4.465059757232666 26 | 24,jolly-sweep-33,3,True,0,1,True,True,False,False,0.9999406827692296,4.4866814613342285 27 | 25,silvery-sweep-35,3,True,0,1,True,True,False,False,0.9999402565917168,4.444202423095703 28 | 26,quiet-sweep-34,3,True,0,1,True,True,False,False,0.9999787443038824,4.4347991943359375 29 | 27,likely-sweep-27,3,True,5,1,True,False,False,False,0.9999564122987868,4.45283842086792 30 | 
28,autumn-sweep-25,3,True,4,1,True,False,False,False,0.9999708795240304,4.443685531616211 31 | 29,still-sweep-21,3,True,4,1,True,False,False,False,0.9999677868259944,4.460321426391602 32 | 30,lyric-sweep-20,3,True,3,1,True,False,False,False,0.99995008620559,4.463925361633301 33 | 31,light-sweep-15,3,True,2,1,True,False,False,False,0.999959367579348,4.464412212371826 34 | 32,vivid-sweep-8,3,True,1,1,True,False,False,False,0.999920813637962,4.443802356719971 35 | 33,classic-sweep-9,3,True,1,1,True,False,False,False,0.999951786052966,4.468830108642578 36 | 34,gallant-sweep-6,3,True,1,1,True,False,False,False,0.999981342050152,4.464670181274414 37 | 35,whole-sweep-7,3,True,1,1,True,False,False,False,0.9999760623145736,4.49854850769043 38 | 36,lively-sweep-3,3,True,0,1,True,False,False,False,0.9999642529591248,4.443384170532227 39 | 37,zesty-sweep-5,3,True,0,1,True,False,False,False,0.9999621276126484,4.463892936706543 40 | 38,devout-sweep-4,3,True,0,1,True,False,False,False,0.9999522251170536,4.491568088531494 41 | 39,neat-sweep-1,3,True,0,1,True,False,False,False,0.999934881915256,4.455718994140625 42 | -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:10.0-cudnn7-devel-ubuntu16.04 2 | ARG PYTHON_VERSION=3.6 3 | 4 | # Set the time zone correctly 5 | ENV TZ=Europe/Berlin 6 | RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone 7 | 8 | ENV SHELL /bin/bash 9 | 10 | RUN apt-get update && apt-get install -y --no-install-recommends \ 11 | build-essential \ 12 | cmake \ 13 | git \ 14 | curl \ 15 | vim \ 16 | ca-certificates \ 17 | libjpeg-dev \ 18 | tmux \ 19 | nano \ 20 | xterm \ 21 | rsync \ 22 | zip \ 23 | zsh \ 24 | htop \ 25 | screen \ 26 | zlib1g-dev \ 27 | libcurl3-dev \ 28 | libfreetype6-dev \ 29 | libpng12-dev \ 30 | libzmq3-dev \ 31 | libpng-dev \ 32 | libglib2.0-0 \ 33 
| openssh-server \ 34 | sudo \ 35 | && apt-get clean \ 36 | && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* 37 | 38 | 39 | RUN curl -o ~/miniconda.sh -L -O https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ 40 | chmod +x ~/miniconda.sh && \ 41 | ~/miniconda.sh -b -p /opt/conda && \ 42 | rm ~/miniconda.sh && \ 43 | /opt/conda/bin/conda install -y python=$PYTHON_VERSION && \ 44 | /opt/conda/bin/conda install -y -c pytorch magma-cuda100 && \ 45 | /opt/conda/bin/conda clean -ya 46 | 47 | ENV PATH /opt/conda/bin:$PATH 48 | RUN pip install ninja tqdm 49 | RUN conda install pytorch torchvision torchaudio cudatoolkit=10.1 -c pytorch 50 | RUN conda install numpy pyyaml scipy ipython mkl mkl-include cython typing 51 | RUN pip install jupyterlab 52 | 53 | RUN conda create -n pytorch_cpu python=3.6 54 | RUN . /opt/conda/etc/profile.d/conda.sh && \ 55 | conda activate pytorch_cpu && \ 56 | conda install pytorch torchvision torchaudio cpuonly -c pytorch && \ 57 | pip install tqdm pillow requests jupyterlab && \ 58 | conda install numpy pyyaml scipy ipython mkl mkl-include cython typing 59 | 60 | RUN pip install matplotlib 61 | RUN . /opt/conda/etc/profile.d/conda.sh && \ 62 | conda activate pytorch_cpu && \ 63 | pip install matplotlib 64 | 65 | #RUN pip install tensorflow-gpu==1.15 lucid 66 | #RUN . /opt/conda/etc/profile.d/conda.sh && \ 67 | # conda activate pytorch_cpu && \ 68 | # pip install tensorflow==1.15 lucid 69 | 70 | RUN pip install git+https://github.com/VLL-HD/FrEIA.git typing_extensions sklearn -q 71 | RUN . /opt/conda/etc/profile.d/conda.sh && \ 72 | conda activate pytorch_cpu && \ 73 | pip install git+https://github.com/VLL-HD/FrEIA.git typing_extensions sklearn -q 74 | 75 | RUN conda install faiss-gpu -c pytorch 76 | RUN . 
/opt/conda/etc/profile.d/conda.sh && \ 77 | conda activate pytorch_cpu && \ 78 | conda install faiss-cpu -c pytorch 79 | 80 | 81 | ### add tmux config 82 | COPY tmux.conf /etc/ 83 | 84 | RUN conda clean --all -y 85 | RUN rm -rf ~/.cache/pip 86 | 87 | # Enable passwordless sudo for all users 88 | RUN echo '%sudo ALL=(ALL:ALL) NOPASSWD:ALL' >> /etc/sudoers 89 | 90 | RUN apt-get update && apt-get install -y --no-install-recommends \ 91 | wget \ 92 | && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* 93 | 94 | # Setup gosu (https://github.com/tianon/gosu) 95 | # gosu is an improved version of su which behaves better inside docker 96 | # we use it to dynamically switch to the desired user in the entrypoint 97 | # (see below) 98 | ENV GOSU_VERSION 1.10 99 | RUN set -x \ 100 | && dpkgArch="$(dpkg --print-architecture | awk -F- '{ print $NF }')" \ 101 | && wget -O /usr/local/bin/gosu "https://github.com/tianon/gosu/releases/download/$GOSU_VERSION/gosu-$dpkgArch" \ 102 | && chmod +x /usr/local/bin/gosu \ 103 | && gosu nobody true 104 | 105 | COPY entrypoint.sh /usr/local/bin/ 106 | RUN chmod a+x /usr/local/bin/entrypoint.sh 107 | 108 | ENTRYPOINT ["/usr/local/bin/entrypoint.sh"] 109 | 110 | CMD ["/bin/bash"] -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/kitti_masks/evaluate_disentanglement.py: -------------------------------------------------------------------------------- 1 | """Modified https://github.com/bethgelab/slow_disentanglement/blob/master/scripts/evaluate_disentanglement.py""" 2 | import torch 3 | import numpy as np 4 | import gin.tf 5 | 6 | gin.enter_interactive_mode() 7 | import time 8 | import os 9 | 10 | # from scripts.model import reparametrize 11 | from kitti_masks.model import reparametrize 12 | 13 | # from scripts.model import BetaVAE_H as BetaVAE 14 | from kitti_masks.model import BetaVAE_H as BetaVAE 15 | from disentanglement_lib.utils import results 16 | 17 | # needed later: 18 | 19 | 20 | def 
main(args, dataset): 21 | device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") 22 | net = BetaVAE 23 | net = net(args.z_dim, args.num_channel, args.box_norm).to(device) 24 | file_path = os.path.join(args.ckpt_dir, "last") 25 | checkpoint = torch.load(file_path) 26 | net.load_state_dict(checkpoint["model_states"]["net"]) 27 | 28 | def mean_rep(x): 29 | distributions = net._encode(torch.from_numpy(x).float().to(device)) 30 | # mu = distributions[:, :net.z_dim] 31 | # logvar = distributions[:, net.z_dim:] 32 | mu = distributions 33 | return np.array(mu.detach().cpu()) 34 | 35 | def sample_rep(x): 36 | distributions = net._encode(torch.from_numpy(x).float().to(device)) 37 | mu = distributions[:, : net.z_dim] 38 | logvar = distributions[:, net.z_dim :] 39 | return np.array(reparametrize(mu, logvar).detach().cpu()) 40 | 41 | @gin.configurable("evaluation") 42 | def evaluate( 43 | post, output_dir, evaluation_fn=gin.REQUIRED, random_seed=gin.REQUIRED, name="" 44 | ): 45 | experiment_timer = time.time() 46 | assert post == "mean" or post == "sampled" 47 | results_dict = evaluation_fn( 48 | dataset, 49 | mean_rep if post == "mean" else sample_rep, 50 | random_state=np.random.RandomState(random_seed), 51 | ) 52 | results_dict["elapsed_time"] = time.time() - experiment_timer 53 | results.update_result_directory(output_dir, "evaluation", results_dict) 54 | 55 | random_state = np.random.RandomState(0) 56 | config_dir = "metric_configs" 57 | eval_config_files = [ 58 | f for f in os.listdir(config_dir) if not (f.startswith(".") or "others" in f) 59 | ] 60 | t0 = time.time() 61 | posts = ["mean"] 62 | for post in posts: 63 | for eval_config in eval_config_files: 64 | metric_name = os.path.basename(eval_config).replace(".gin", "") 65 | continuous = False 66 | if args.dataset == "kittimasks" or ( 67 | args.dataset == "natural" and not args.natural_discrete 68 | ): 69 | continuous = True 70 | if continuous: 71 | if metric_name != "mcc": 72 | continue 73 
| contains = True 74 | if args.specify: 75 | contains = False 76 | for specific in args.specify.split("_"): 77 | if specific in metric_name: 78 | contains = True 79 | break 80 | if contains: 81 | if args.verbose: 82 | print("Computing metric '{}' on '{}'...".format(metric_name, post)) 83 | eval_bindings = [ 84 | "evaluation.random_seed = {}".format(random_state.randint(2**32)), 85 | "evaluation.name = '{}'".format(metric_name), 86 | ] 87 | gin.parse_config_files_and_bindings( 88 | [os.path.join(config_dir, eval_config)], eval_bindings 89 | ) 90 | output_dir = os.path.join( 91 | args.output_dir, "evaluation", args.ckpt_name, post, metric_name 92 | ) 93 | evaluate(post, output_dir) 94 | gin.clear_config() 95 | if args.verbose: 96 | print("took", time.time() - t0, "s") 97 | t0 = time.time() 98 | -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/kitti_masks/model.py: -------------------------------------------------------------------------------- 1 | """Modified https://github.com/bethgelab/slow_disentanglement/blob/master/scripts/model.py""" 2 | 3 | import torch.nn as nn 4 | import torch.nn.init as init 5 | from torch.autograd import Variable 6 | 7 | import layers 8 | 9 | 10 | def reparametrize(mu, logvar): 11 | std = logvar.div(2).exp() 12 | eps = Variable(std.data.new(std.size()).normal_()) 13 | return mu + std * eps 14 | 15 | 16 | class View(nn.Module): 17 | def __init__(self, size): 18 | super(View, self).__init__() 19 | self.size = size 20 | 21 | def forward(self, tensor): 22 | return tensor.view(self.size) 23 | 24 | 25 | class BetaVAE_H(nn.Module): 26 | """Model proposed in original beta-VAE paper(Higgins et al, ICLR, 2017).""" 27 | 28 | def __init__(self, z_dim=10, nc=3, box_norm=False): 29 | super(BetaVAE_H, self).__init__() 30 | self.z_dim = z_dim 31 | self.nc = nc 32 | if box_norm: 33 | non_periodic_rescale_layer = layers.SoftclipLayer( 34 | n=z_dim, init_abs_bound=1.0, fixed_abs_bound=False 35 | ) 36 | 
else: 37 | non_periodic_rescale_layer = layers.Lambda(lambda x: x) 38 | self.encoder = nn.Sequential( 39 | nn.Conv2d(nc, 32, 4, 2, 1), # B, 32, 32, 32 40 | nn.ReLU(True), 41 | nn.Conv2d(32, 32, 4, 2, 1), # B, 32, 16, 16 42 | nn.ReLU(True), 43 | nn.Conv2d(32, 64, 4, 2, 1), # B, 64, 8, 8 44 | nn.ReLU(True), 45 | nn.Conv2d(64, 64, 4, 2, 1), # B, 64, 4, 4 46 | nn.ReLU(True), 47 | nn.Conv2d(64, 256, 4, 1), # B, 256, 1, 1 48 | nn.ReLU(True), 49 | View((-1, 256 * 1 * 1)), # B, 256 50 | # nn.Linear(256, z_dim*2), # B, z_dim*2 51 | nn.Linear(256, z_dim), 52 | non_periodic_rescale_layer, 53 | ) 54 | """ 55 | self.decoder = nn.Sequential( 56 | nn.Linear(z_dim, 256), # B, 256 57 | View((-1, 256, 1, 1)), # B, 256, 1, 1 58 | nn.ReLU(True), 59 | nn.ConvTranspose2d(256, 64, 4), # B, 64, 4, 4 60 | nn.ReLU(True), 61 | nn.ConvTranspose2d(64, 64, 4, 2, 1), # B, 64, 8, 8 62 | nn.ReLU(True), 63 | nn.ConvTranspose2d(64, 32, 4, 2, 1), # B, 32, 16, 16 64 | nn.ReLU(True), 65 | nn.ConvTranspose2d(32, 32, 4, 2, 1), # B, 32, 32, 32 66 | nn.ReLU(True), 67 | nn.ConvTranspose2d(32, nc, 4, 2, 1), # B, nc, 64, 64 68 | ) 69 | """ 70 | self.weight_init() 71 | 72 | def weight_init(self): 73 | for block in self._modules: 74 | for m in self._modules[block]: 75 | kaiming_init(m) 76 | 77 | def forward(self, x, return_z=False): 78 | distributions = self._encode(x) 79 | return distributions 80 | """ 81 | mu = distributions[:, :self.z_dim] 82 | logvar = distributions[:, self.z_dim:] 83 | z = reparametrize(mu, logvar) 84 | x_recon = self._decode(z) 85 | 86 | if return_z: 87 | return x_recon, mu, logvar, z 88 | else: 89 | return x_recon, mu, logvar 90 | """ 91 | 92 | def _encode(self, x): 93 | return self.encoder(x) 94 | 95 | def _decode(self, z): 96 | return self.decoder(z) 97 | 98 | 99 | def kaiming_init(m): 100 | if isinstance(m, (nn.Linear, nn.Conv2d)): 101 | init.kaiming_normal(m.weight) 102 | if m.bias is not None: 103 | m.bias.data.fill_(0) 104 | elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)): 105 
| m.weight.data.fill_(1) 106 | if m.bias is not None: 107 | m.bias.data.fill_(0) 108 | 109 | 110 | def normal_init(m, mean, std): 111 | if isinstance(m, (nn.Linear, nn.Conv2d)): 112 | m.weight.data.normal_(mean, std) 113 | if m.bias.data is not None: 114 | m.bias.data.zero_() 115 | elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)): 116 | m.weight.data.fill_(1) 117 | if m.bias.data is not None: 118 | m.bias.data.zero_() 119 | -------------------------------------------------------------------------------- /care_nl_ica/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from typing import Dict, Literal 4 | 5 | import numpy as np 6 | import pip 7 | import torch 8 | 9 | 10 | def unpack_item_list(lst): 11 | if isinstance(lst, tuple): 12 | lst = list(lst) 13 | result_list = [] 14 | for it in lst: 15 | if isinstance(it, (tuple, list)): 16 | result_list.append(unpack_item_list(it)) 17 | else: 18 | result_list.append(it.item()) 19 | return result_list 20 | 21 | 22 | def setup_seed(seed): 23 | if seed is not None: 24 | np.random.seed(seed) 25 | random.seed(seed) 26 | torch.manual_seed(seed) 27 | 28 | 29 | def save_state_dict(args, model, pth="g.pth"): 30 | if args.save_dir: 31 | if not os.path.exists(args.save_dir): 32 | os.makedirs(args.save_dir) 33 | torch.save(model.state_dict(), os.path.join(args.save_dir, pth)) 34 | 35 | 36 | def set_learning_mode(args): 37 | if args.mode == "unsupervised": 38 | learning_modes = [False] 39 | elif args.mode == "supervised": 40 | learning_modes = [True] 41 | else: 42 | learning_modes = [True, False] 43 | 44 | args.learning_modes = learning_modes 45 | 46 | 47 | def set_device(args) -> None: 48 | device = "cuda" 49 | if not torch.cuda.is_available() or args.no_cuda is True: 50 | device = "cpu" 51 | 52 | if args.verbose is True: 53 | print(f"{device=}") 54 | 55 | args.device = device 56 | 57 | 58 | def matrix_to_dict(matrix, name, panel_name=None, triangular=False) 
-> Dict[str, float]: 59 | if matrix is not None: 60 | if triangular is False: 61 | labels = [ 62 | f"{name}_{i}{j}" 63 | if panel_name is None 64 | else f"{panel_name}/{name}_{i}{j}" 65 | for i in range(matrix.shape[0]) 66 | for j in range(matrix.shape[1]) 67 | ] 68 | else: 69 | labels = [ 70 | f"{name}_{i}{j}" 71 | if panel_name is None 72 | else f"{panel_name}/{name}_{i}{j}" 73 | for i in range(matrix.shape[0]) 74 | for j in range(i + 1) 75 | ] 76 | data = ( 77 | matrix.detach() 78 | .cpu() 79 | .reshape( 80 | -1, 81 | ) 82 | .tolist() 83 | ) 84 | 85 | return {key: val for key, val in zip(labels, data)} 86 | 87 | 88 | OutputNormalizationType = Literal[ 89 | "", "fixed_box", "learnable_box", "fixed_sphere", "learnable_sphere" 90 | ] 91 | SpaceType = Literal["box", "sphere", "unbounded"] 92 | DataGenType = Literal["rvs", "pcl", "offset"] 93 | 94 | 95 | def add_tags(args): 96 | try: 97 | args.tags 98 | except: 99 | args.tags = [] 100 | 101 | if args.tags is None: 102 | args.tags = [] 103 | 104 | if args.data.use_sem is True: 105 | args.tags.append("sem") 106 | 107 | if args.data.nonlin_sem is True: 108 | args.tags.append("nonlinear") 109 | else: 110 | args.tags.append("linear") 111 | 112 | if args.data.permute is True: 113 | args.tags.append("permute") 114 | 115 | if args.model.normalize_latents is True: 116 | args.tags.append("normalization") 117 | 118 | return list(set(args.tags)) 119 | 120 | 121 | def get_cuda_stats(where): 122 | import pynvml 123 | 124 | pynvml.nvmlInit() 125 | handle = pynvml.nvmlDeviceGetHandleByIndex(int(0)) 126 | info = pynvml.nvmlDeviceGetMemoryInfo(handle) 127 | 128 | print(where) 129 | 130 | print(f"total : {info.total // 1024 ** 2}") 131 | print(f"free : {info.free // 1024 ** 2}") 132 | print(f"used : {info.used // 1024 ** 2}") 133 | 134 | 135 | def install_package(): 136 | """ 137 | Install the current package to ensure that imports work. 
138 | """ 139 | try: 140 | import care_nl_ica 141 | except: 142 | print("Package not installed, installing...") 143 | pip.main( 144 | [ 145 | "install", 146 | f"{os.path.dirname(os.path.dirname(os.path.abspath(__file__)))}", 147 | "--upgrade", 148 | ] 149 | ) 150 | -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/kitti_masks/mcc_metric/metric.py: -------------------------------------------------------------------------------- 1 | """Mean Correlation Coefficient from Hyvarinen & Morioka 2 | """ 3 | from absl import logging 4 | from disentanglement_lib.evaluation.metrics import utils 5 | import numpy as np 6 | import gin.tf 7 | import scipy as sp 8 | from kitti_masks.mcc_metric.munkres import Munkres 9 | 10 | 11 | def correlation(x, y, method="Pearson"): 12 | """Evaluate correlation 13 | Args: 14 | x: data to be sorted 15 | y: target data 16 | Returns: 17 | corr_sort: correlation matrix between x and y (after sorting) 18 | sort_idx: sorting index 19 | x_sort: x after sorting 20 | method: correlation method ('Pearson' or 'Spearman') 21 | """ 22 | 23 | print("Calculating correlation...") 24 | 25 | x = x.copy() 26 | y = y.copy() 27 | dim = x.shape[0] 28 | 29 | # Calculate correlation ----------------------------------- 30 | if method == "Pearson": 31 | corr = np.corrcoef(y, x) 32 | corr = corr[0:dim, dim:] 33 | elif method == "Spearman": 34 | corr, pvalue = sp.stats.spearmanr(y.T, x.T) 35 | corr = corr[0:dim, dim:] 36 | 37 | # Sort ---------------------------------------------------- 38 | munk = Munkres() 39 | indexes = munk.compute(-np.absolute(corr)) 40 | 41 | sort_idx = np.zeros(dim) 42 | x_sort = np.zeros(x.shape) 43 | for i in range(dim): 44 | sort_idx[i] = indexes[i][1] 45 | x_sort[i, :] = x[indexes[i][1], :] 46 | 47 | # Re-calculate correlation -------------------------------- 48 | if method == "Pearson": 49 | corr_sort = np.corrcoef(y, x_sort) 50 | corr_sort = corr_sort[0:dim, dim:] 51 | elif method == 
"Spearman": 52 | corr_sort, pvalue = sp.stats.spearmanr(y.T, x_sort.T) 53 | corr_sort = corr_sort[0:dim, dim:] 54 | 55 | return corr_sort, sort_idx, x_sort 56 | 57 | 58 | @gin.configurable( 59 | "mcc", 60 | blacklist=[ 61 | "ground_truth_data", 62 | "representation_function", 63 | "random_state", 64 | "artifact_dir", 65 | ], 66 | ) 67 | def compute_mcc( 68 | ground_truth_data, 69 | representation_function, 70 | random_state, 71 | artifact_dir=None, 72 | num_train=gin.REQUIRED, 73 | correlation_fn=gin.REQUIRED, 74 | batch_size=16, 75 | ): 76 | """Computes the mean correlation coefficient. 77 | 78 | Args: 79 | ground_truth_data: GroundTruthData to be sampled from. 80 | representation_function: Function that takes observations as input and 81 | outputs a dim_representation sized representation for each observation. 82 | random_state: Numpy random state used for randomness. 83 | artifact_dir: Optional path to directory where artifacts can be saved. 84 | num_train: Number of points used for training. 85 | batch_size: Batch size for sampling. 
86 | 87 | Returns: 88 | Dict with mcc stats 89 | """ 90 | del artifact_dir 91 | logging.info("Generating training set.") 92 | mus_train, ys_train = utils.generate_batch_factor_code( 93 | ground_truth_data, representation_function, num_train, random_state, batch_size 94 | ) 95 | assert mus_train.shape[1] == num_train 96 | return _compute_mcc(mus_train, ys_train, correlation_fn, random_state) 97 | 98 | 99 | def _compute_mcc(mus_train, ys_train, correlation_fn, random_state): 100 | """Computes score based on both training and testing codes and factors.""" 101 | score_dict = {} 102 | result = np.zeros(mus_train.shape) 103 | result[: ys_train.shape[0], : ys_train.shape[1]] = ys_train 104 | 105 | for i in range(len(mus_train) - len(ys_train)): 106 | result[ys_train.shape[0] + i, :] = random_state.normal(size=ys_train.shape[1]) 107 | 108 | corr_sorted, sort_idx, mu_sorted = correlation( 109 | mus_train, result, method=correlation_fn 110 | ) 111 | score_dict["meanabscorr"] = np.mean(np.abs(np.diag(corr_sorted)[: len(ys_train)])) 112 | 113 | for i in range(len(corr_sorted)): 114 | for j in range(len(corr_sorted[0])): 115 | score_dict["corr_sorted_{}{}".format(i, j)] = corr_sorted[i][j] 116 | 117 | for i in range(len(sort_idx)): 118 | score_dict["sort_idx_{}".format(i)] = sort_idx[i] 119 | 120 | return score_dict 121 | -------------------------------------------------------------------------------- /notebooks/sem_10d_permute_sweep_at138q9q.csv: -------------------------------------------------------------------------------- 1 | ,name,dim,permute,variant,n_mixing_layer,use_sem,nonlin_sem,force_chain,force_uniform,mcc,val_loss 2 | 0,avid-disco-1026,10,True,1875424,1,True,False,False,False,0.9983341568922408,0.03388404846191406 3 | 1,logical-sun-1027,10,True,1257846,1,True,False,False,False,0.9977168806977076,0.03010082244873047 4 | 2,faithful-sky-1028,10,True,657894,1,True,False,False,False,0.5665253662225751,1.9482862949371336 5 | 
3,polar-capybara-1029,10,True,657894,1,True,False,False,False,0.998623243279074,0.027126312255859375 6 | 4,young-voice-1030,10,True,0,1,True,False,False,False,0.998879020993942,0.02951145172119141 7 | 5,fluent-bird-1031,10,True,3628600,1,True,True,False,False,0.9976653016722092,0.03337574005126953 8 | 6,celestial-donkey-1032,10,True,2645120,1,True,True,False,False,0.9983975127515916,0.031569480895996094 9 | 7,gentle-waterfall-1033,10,True,2645120,1,True,True,False,False,0.5016068996848807,3.0604326725006104 10 | 8,stilted-shadow-1034,10,True,2645120,1,True,True,False,False,0.99789983149808,0.03174877166748047 11 | 9,stilted-firefly-1035,10,True,1875424,1,True,True,False,False,0.998128787009014,0.03840923309326172 12 | 10,stoic-night-1036,10,True,1875424,1,True,True,False,False,0.578531677579783,2.0678510665893555 13 | 11,kind-thunder-1037,10,True,1875424,1,True,True,False,False,0.93353840620663,0.07629680633544922 14 | 12,zesty-night-1038,10,True,1875424,1,True,True,False,False,0.9977001427522548,0.042072296142578125 15 | 13,unique-resonance-1039,10,True,1257846,1,True,True,False,False,0.43464553624681407,4.450830459594727 16 | 14,dandy-lake-1040,10,True,1257846,1,True,True,False,False,0.8957182521453648,0.08549785614013672 17 | 15,exalted-sponge-1041,10,True,1257846,1,True,True,False,False,0.9971554320886372,0.03620433807373047 18 | 16,cosmic-microwave-1042,10,True,657894,1,True,True,False,False,0.9214585223938072,0.0818634033203125 19 | 17,snowy-hill-1043,10,True,657894,1,True,True,False,False,0.5111606211568637,3.0626487731933594 20 | 18,legendary-lion-1044,10,True,0,1,True,True,False,False,0.9452960430570314,0.08781814575195312 21 | 19,pious-pyramid-1045,10,True,0,1,True,True,False,False,0.9983308118998522,0.03724956512451172 22 | 20,easy-hill-1046,10,True,0,1,True,True,False,False,0.895622732550257,0.07729721069335938 23 | 21,pretty-moon-1047,10,True,0,1,True,True,False,False,0.9098679857577728,0.08410358428955078 24 | 
22,daily-thunder-1048,10,True,3628600,1,True,False,False,False,0.9189409737955572,0.08632373809814453 25 | 23,wandering-frost-1049,10,True,3628600,1,True,False,False,False,0.9983997540017708,0.028326988220214844 26 | 24,sunny-mountain-1050,10,True,3628600,1,True,False,False,False,0.5223164111735606,3.0562312602996826 27 | 25,wild-dream-1051,10,True,3628600,1,True,False,False,False,0.9984776063394925,0.037484169006347656 28 | 26,balmy-feather-1052,10,True,3628600,1,True,False,False,False,0.9983932736811468,0.033855438232421875 29 | 27,dutiful-darkness-1053,10,True,2645120,1,True,False,False,False,0.9966642480196592,0.037670135498046875 30 | 28,morning-sun-1054,10,True,2645120,1,True,False,False,False,0.9986003568372032,0.02772045135498047 31 | 29,feasible-morning-1055,10,True,1875424,1,True,False,False,False,0.9987155621834108,0.035594940185546875 32 | 30,polished-pond-1056,10,True,1875424,1,True,False,False,False,0.5018518962215212,3.156581163406372 33 | 31,olive-monkey-1057,10,True,1875424,1,True,False,False,False,0.998320658033261,0.02749156951904297 34 | 32,autumn-bird-1058,10,True,1257846,1,True,False,False,False,0.9985764614778552,0.02850818634033203 35 | 33,lilac-meadow-1059,10,True,1257846,1,True,False,False,False,0.5059723457792604,3.066004514694214 36 | 34,dazzling-feather-1060,10,True,1257846,1,True,False,False,False,0.9988128967104164,0.02695751190185547 37 | 35,winter-hill-1061,10,True,1257846,1,True,False,False,False,0.997694246591539,0.035035133361816406 38 | 36,dashing-thunder-1062,10,True,657894,1,True,False,False,False,0.9982684824643238,0.029120445251464844 39 | 37,hopeful-water-1063,10,True,657894,1,True,False,False,False,0.9984400863462308,0.028257369995117188 40 | 38,volcanic-field-1064,10,True,657894,1,True,False,False,False,0.9983215020813604,0.03453254699707031 41 | 39,vital-wave-1065,10,True,0,1,True,False,False,False,0.9983303065696216,0.02803516387939453 42 | 
40,solar-microwave-1066,10,True,0,1,True,False,False,False,0.9983405833411052,0.03751087188720703 43 | 41,blooming-salad-1067,10,True,0,1,True,False,False,False,0.9985522463979432,0.030382156372070312 44 | 42,golden-microwave-1068,10,True,0,1,True,False,False,False,0.5620545488919371,1.8969221115112305 45 | -------------------------------------------------------------------------------- /notebooks/sem_8d_permute_sweep_05whlpmk.csv: -------------------------------------------------------------------------------- 1 | ,name,dim,permute,variant,n_mixing_layer,use_sem,nonlin_sem,force_chain,force_uniform,mcc,val_loss 2 | 0,golden-frost-785,8,True,36755,1,True,True,False,False,0.7870476626064031,1.0004196166992188 3 | 1,splendid-wildflower-786,8,True,36755,1,True,True,False,False,0.7360380641673412,1.9006996154785156 4 | 2,resilient-lion-787,8,True,36755,1,True,True,False,False,0.99912524360027,0.2155933380126953 5 | 3,polar-deluge-788,8,True,36755,1,True,True,False,False,0.9993331229853856,0.2247653007507324 6 | 4,lemon-glade-789,8,True,21345,1,True,True,False,False,0.7436701532816978,1.9569127559661863 7 | 5,stellar-tree-790,8,True,21345,1,True,True,False,False,0.8747439041034752,0.47955942153930664 8 | 6,faithful-universe-791,8,True,21345,1,True,True,False,False,0.9990683610089762,0.20560073852539065 9 | 7,balmy-surf-792,8,True,12450,1,True,True,False,False,0.7685320785413183,1.0127801895141602 10 | 8,jolly-pond-793,8,True,12450,1,True,True,False,False,0.6612866949212535,1.7968366146087646 11 | 9,pleasant-snowball-794,8,True,12450,1,True,True,False,False,0.725842104647481,1.955034255981445 12 | 10,graceful-puddle-795,8,True,12450,1,True,True,False,False,0.999190881981798,0.21725988388061523 13 | 11,fancy-universe-796,8,True,12450,1,True,True,False,False,0.9994212392635882,0.2027297019958496 14 | 12,logical-terrain-797,8,True,8754,1,True,True,False,False,0.8023882435290682,0.9988770484924316 15 | 
13,radiant-blaze-798,8,True,8754,1,True,True,False,False,0.7120243058616726,1.3677153587341309 16 | 14,azure-wood-799,8,True,8754,1,True,True,False,False,0.9994784140167694,0.2152671813964844 17 | 15,toasty-haze-800,8,True,8754,1,True,True,False,False,0.9991351783593232,0.2244772911071777 18 | 16,quiet-lake-801,8,True,575,1,True,True,False,False,0.8013751352531858,0.9983444213867188 19 | 17,rose-terrain-802,8,True,575,1,True,True,False,False,0.7403898042056241,1.315999984741211 20 | 18,apricot-galaxy-803,8,True,575,1,True,True,False,False,0.9989420310715548,0.2329006195068359 21 | 19,visionary-dust-804,8,True,0,1,True,True,False,False,0.8001806797267774,1.0164227485656738 22 | 20,devoted-star-805,8,True,0,1,True,True,False,False,0.7131030555331175,1.9478073120117188 23 | 21,worthy-elevator-806,8,True,0,1,True,True,False,False,0.9986368165998738,0.2202730178833008 24 | 22,dutiful-butterfly-807,8,True,0,1,True,True,False,False,0.999502073345061,0.22344207763671875 25 | 23,lively-microwave-808,8,True,36755,1,True,False,False,False,0.999220260285246,0.21310853958129883 26 | 24,summer-yogurt-809,8,True,36755,1,True,False,False,False,0.7343398291368135,1.8402633666992188 27 | 25,northern-eon-810,8,True,36755,1,True,False,False,False,0.6745059154624304,1.4644536972045898 28 | 26,northern-yogurt-811,8,True,36755,1,True,False,False,False,0.9994202045137288,0.21187210083007812 29 | 27,proud-surf-812,8,True,36755,1,True,False,False,False,0.9996143923406624,0.19968605041503903 30 | 28,vocal-thunder-813,8,True,21345,1,True,False,False,False,0.9995160618581956,0.21488666534423828 31 | 29,light-vortex-814,8,True,21345,1,True,False,False,False,0.7589819893357959,1.1175751686096191 32 | 30,expert-armadillo-815,8,True,21345,1,True,False,False,False,0.999412151376227,0.2114906311035156 33 | 31,restful-dew-816,8,True,12450,1,True,False,False,False,0.9995392922072224,0.21223068237304688 34 | 32,lemon-eon-817,8,True,12450,1,True,False,False,False,0.7063054124179535,1.870588302612305 35 
| 33,warm-moon-818,8,True,21345,1,True,False,False,False,0.9996033566210312,0.2002120018005371 36 | 34,gallant-breeze-819,8,True,12450,1,True,False,False,False,0.9989459497136968,0.21749353408813477 37 | 35,summer-snow-820,8,True,8754,1,True,False,False,False,0.9993328835265878,0.21512365341186523 38 | 36,flowing-bee-821,8,True,8754,1,True,False,False,False,0.6042372812790955,3.068920850753784 39 | 37,comic-salad-822,8,True,12450,1,True,False,False,False,0.99951846079956,0.1997842788696289 40 | 38,wild-sky-823,8,True,8754,1,True,False,False,False,0.999477441801906,0.20042896270751953 41 | 39,proud-capybara-824,8,True,575,1,True,False,False,False,0.9991039695317347,0.2176990509033203 42 | 40,confused-plant-825,8,True,575,1,True,False,False,False,0.691128721409549,1.8724606037139893 43 | 41,magic-shape-826,8,True,575,1,True,False,False,False,0.9994559180151016,0.21233606338500977 44 | 42,toasty-dust-827,8,True,575,1,True,False,False,False,0.9995909303714188,0.20017671585083008 45 | 43,worthy-spaceship-828,8,True,0,1,True,False,False,False,0.9992727598460192,0.2156882286071777 46 | 44,jolly-paper-829,8,True,0,1,True,False,False,False,0.999545418902575,0.2117609977722168 47 | 45,astral-aardvark-830,8,True,0,1,True,False,False,False,0.9996048160297684,0.20085525512695312 48 | -------------------------------------------------------------------------------- /notebooks/sem_5d_sweep_f5nxtdxz.csv: -------------------------------------------------------------------------------- 1 | ,name,dim,permute,variant,n_mixing_layer,use_sem,nonlin_sem,force_chain,force_uniform,mcc,val_loss 2 | 0,snowy-sweep-56,5,True,119,1,True,True,False,False,0.9996499051994444,1.911044359207153 3 | 1,true-sweep-55,5,True,97,1,True,True,False,False,0.9998941951030564,1.962398052215576 4 | 2,breezy-sweep-53,5,True,97,1,True,True,False,False,0.999866300576302,1.93372106552124 5 | 3,upbeat-sweep-51,5,True,97,1,True,True,False,False,0.8702603415088127,2.94165563583374 6 | 
4,trim-sweep-52,5,True,97,1,True,True,False,False,0.9996855466801964,1.9286811351776123 7 | 5,wise-sweep-48,5,True,75,1,True,True,False,False,0.994490484546619,2.0351738929748535 8 | 6,ruby-sweep-49,5,True,75,1,True,True,False,False,0.9999118592029588,1.883225440979004 9 | 7,rosy-sweep-50,5,True,75,1,True,True,False,False,0.9998608095768386,1.965829610824585 10 | 8,celestial-sweep-45,5,True,31,1,True,True,False,False,0.9996597244939468,1.9706602096557615 11 | 9,cool-sweep-43,5,True,31,1,True,True,False,False,0.823421892355371,3.0211195945739746 12 | 10,dark-sweep-44,5,True,31,1,True,True,False,False,0.9997995883818558,1.8853399753570557 13 | 11,vital-sweep-41,5,True,31,1,True,True,False,False,0.9997076818037208,1.9155378341674805 14 | 12,polar-sweep-42,5,True,31,1,True,True,False,False,0.9997808272892456,1.920363187789917 15 | 13,young-sweep-40,5,True,5,1,True,True,False,False,0.9935377136153132,2.044177293777466 16 | 14,dashing-sweep-37,5,True,5,1,True,True,False,False,0.8168110646412708,2.944557428359986 17 | 15,proud-sweep-39,5,True,5,1,True,True,False,False,0.9998844071870052,1.886756896972656 18 | 16,sunny-sweep-38,5,True,5,1,True,True,False,False,0.99976656848642,1.9395532608032229 19 | 17,icy-sweep-35,5,True,0,1,True,True,False,False,0.9995648418193144,1.978658676147461 20 | 18,sparkling-sweep-36,5,True,5,1,True,True,False,False,0.9996755789753784,1.914533615112305 21 | 19,deft-sweep-34,5,True,0,1,True,True,False,False,0.8174083118706899,2.9354665279388428 22 | 20,polar-sweep-33,5,True,0,1,True,True,False,False,0.9997882424833694,1.9382030963897705 23 | 21,breezy-sweep-30,5,True,119,1,True,False,False,False,0.9999355538294332,1.9596738815307615 24 | 22,classic-sweep-32,5,True,0,1,True,True,False,False,0.8198915003133844,2.966867685317993 25 | 23,copper-sweep-31,5,True,0,1,True,True,False,False,0.9997722409060924,1.9082589149475095 26 | 24,prime-sweep-28,5,True,119,1,True,False,False,False,0.9999188581026828,1.932438850402832 27 | 
25,dashing-sweep-27,5,True,119,1,True,False,False,False,0.9999079590320092,1.9181571006774905 28 | 26,warm-sweep-25,5,True,97,1,True,False,False,False,0.999966780741623,1.959846019744873 29 | 27,logical-sweep-24,5,True,97,1,True,False,False,False,0.999955017868418,1.880798578262329 30 | 28,resilient-sweep-23,5,True,97,1,True,False,False,False,0.9999330353475632,1.9318504333496096 31 | 29,jumping-sweep-21,5,True,97,1,True,False,False,False,0.999820990883572,1.904160976409912 32 | 30,firm-sweep-22,5,True,97,1,True,False,False,False,0.9999171401411682,1.917438507080078 33 | 31,fearless-sweep-19,5,True,75,1,True,False,False,False,0.9999234708418284,1.8802132606506348 34 | 32,elated-sweep-18,5,True,75,1,True,False,False,False,0.9999196889782134,1.9317889213562012 35 | 33,dashing-sweep-20,5,True,75,1,True,False,False,False,0.9999449039595056,1.9598019123077393 36 | 34,silver-sweep-16,5,True,75,1,True,False,False,False,0.999834863136319,1.903186559677124 37 | 35,sage-sweep-17,5,True,75,1,True,False,False,False,0.999917115836632,1.91690993309021 38 | 36,morning-sweep-15,5,True,31,1,True,False,False,False,0.9999577963987886,1.959930419921875 39 | 37,polished-sweep-13,5,True,31,1,True,False,False,False,0.999894347665088,1.931706190109253 40 | 38,exalted-sweep-14,5,True,31,1,True,False,False,False,0.9999484014031094,1.8802127838134768 41 | 39,visionary-sweep-12,5,True,31,1,True,False,False,False,0.9999076880041662,1.917842388153076 42 | 40,earnest-sweep-11,5,True,31,1,True,False,False,False,0.9998738612434436,1.902634620666504 43 | 41,golden-sweep-9,5,True,5,1,True,False,False,False,0.999941533872309,1.8802099227905271 44 | 42,stilted-sweep-7,5,True,5,1,True,False,False,False,0.9999102913517128,1.91702938079834 45 | 43,polar-sweep-5,5,True,0,1,True,False,False,False,0.999927763907165,1.9610137939453125 46 | 44,wobbly-sweep-6,5,True,5,1,True,False,False,False,0.9998481879567616,1.9024906158447263 47 | 
45,ancient-sweep-4,5,True,0,1,True,False,False,False,0.9999019257760446,1.880422592163086 48 | 46,desert-sweep-3,5,True,0,1,True,False,False,False,0.9998995823806212,1.9316697120666504 49 | 47,playful-sweep-2,5,True,0,1,True,False,False,False,0.9998983227571792,1.917553186416626 50 | 48,lively-sweep-1,5,True,0,1,True,False,False,False,0.9998883137472196,1.902991771697998 51 | -------------------------------------------------------------------------------- /notebooks/sem_5d_permute_sweep_rv7yo1qy.csv: -------------------------------------------------------------------------------- 1 | ,name,dim,permute,variant,n_mixing_layer,use_sem,nonlin_sem,force_chain,force_uniform,mcc,val_loss 2 | 0,dazzling-sweep-60,5,True,119,1,True,True,False,False,0.8742166913675901,3.087957859039306 3 | 1,hardy-sweep-59,5,True,119,1,True,True,False,False,0.9992950206086946,1.8929150104522705 4 | 2,legendary-sweep-58,5,True,119,1,True,True,False,False,0.9998347448728876,1.93897008895874 5 | 3,vocal-sweep-53,5,True,97,1,True,True,False,False,0.9997956083325032,1.9444646835327148 6 | 4,chocolate-sweep-55,5,True,97,1,True,True,False,False,0.9991106807020784,1.9985361099243164 7 | 5,devout-sweep-54,5,True,97,1,True,True,False,False,0.9993266201765986,1.8953027725219729 8 | 6,peach-sweep-56,5,True,119,1,True,True,False,False,0.9998215461075076,1.908497095108032 9 | 7,effortless-sweep-52,5,True,97,1,True,True,False,False,0.9997651188213444,1.923126459121704 10 | 8,earnest-sweep-57,5,True,119,1,True,True,False,False,0.9998209840559452,1.9217238426208496 11 | 9,restful-sweep-51,5,True,97,1,True,True,False,False,0.9997012101032964,1.9148337841033936 12 | 10,serene-sweep-50,5,True,75,1,True,True,False,False,0.8501413339127124,3.0400595664978027 13 | 11,radiant-sweep-49,5,True,75,1,True,True,False,False,0.9991810946243016,1.8940644264221191 14 | 12,cerulean-sweep-44,5,True,31,1,True,True,False,False,0.9998914260608046,1.8876442909240725 15 | 
13,clear-sweep-43,5,True,31,1,True,True,False,False,0.999800486546609,1.9419090747833252 16 | 14,trim-sweep-42,5,True,31,1,True,True,False,False,0.9998617421995636,1.9215517044067385 17 | 15,still-sweep-41,5,True,31,1,True,True,False,False,0.9996364255994936,1.9162847995758057 18 | 16,vital-sweep-40,5,True,5,1,True,True,False,False,0.9992987373865097,1.987121105194092 19 | 17,dutiful-sweep-39,5,True,5,1,True,True,False,False,0.9992363666605142,1.8910212516784668 20 | 18,neat-sweep-38,5,True,5,1,True,True,False,False,0.9998267638238868,1.9380648136138916 21 | 19,lyric-sweep-37,5,True,5,1,True,True,False,False,0.9998318013586658,1.9241962432861328 22 | 20,expert-sweep-34,5,True,0,1,True,True,False,False,0.9993582123422764,1.893308401107788 23 | 21,rare-sweep-33,5,True,0,1,True,True,False,False,0.9994050747388996,1.9483914375305176 24 | 22,wandering-sweep-32,5,True,0,1,True,True,False,False,0.9997525776611698,1.9267096519470217 25 | 23,wobbly-sweep-31,5,True,0,1,True,True,False,False,0.9997282468177572,1.9094431400299072 26 | 24,super-sweep-30,5,True,119,1,True,False,False,False,0.9997812256431584,1.9662752151489256 27 | 25,stilted-sweep-29,5,True,119,1,True,False,False,False,0.9998813875491512,1.8807268142700195 28 | 26,volcanic-sweep-28,5,True,119,1,True,False,False,False,0.9998618276142204,1.9327852725982664 29 | 27,toasty-sweep-27,5,True,119,1,True,False,False,False,0.9999008834176388,1.9184045791625977 30 | 28,vital-sweep-26,5,True,119,1,True,False,False,False,0.9997712024733764,1.9034337997436523 31 | 29,genial-sweep-25,5,True,97,1,True,False,False,False,0.8714708978597268,3.0373737812042236 32 | 30,cosmic-sweep-24,5,True,97,1,True,False,False,False,0.9998170503701528,1.881947040557861 33 | 31,fancy-sweep-23,5,True,97,1,True,False,False,False,0.9999062764670932,1.9328882694244385 34 | 32,rural-sweep-22,5,True,97,1,True,False,False,False,0.9998996156998048,1.9187803268432615 35 | 
class HSIC(object):
    """Permutation-based HSIC (Hilbert-Schmidt Independence Criterion) test.

    The observed test statistic is compared against a null distribution that is
    built by repeatedly shuffling the samples of ``y``.
    """

    def __init__(self, num_permutations: int, alpha: float = 0.05):
        """
        :param num_permutations: number of index permutations used to build the null
        :param alpha: type 1 error level
        """
        self.num_permutations = num_permutations
        self.alpha = alpha

    @staticmethod
    def rbf(x: torch.Tensor, y: torch.Tensor, ls: float) -> torch.Tensor:
        """
        Calculates the RBF kernel in a vectorized form.

        :param x: tensor of the first sample in the form of (num_samples, num_dim)
        :param y: tensor of the second sample in the form of (num_samples, num_dim)
        :param ls: length scale of the RBF kernel
        """
        # squared pairwise distances
        dists_sq = torch.cdist(x, y).pow(2)

        return torch.exp(-dists_sq / ls**2)

    def test_statistics(
        self, x: torch.Tensor, y: torch.Tensor, ls_x: float, ls_y: float
    ) -> torch.Tensor:
        """
        Calculates the (biased) HSIC test statistic according to the code at
        http://www.gatsby.ucl.ac.uk/~gretton/indepTestFiles/indep.htm

        :param x: tensor of the first sample in the form of (num_samples, num_dim)
        :param y: tensor of the second sample in the form of (num_samples, num_dim)
        :param ls_x: length scale of the x RBF kernel
        :param ls_y: length scale of the y RBF kernel
        """
        num_samples = x.shape[0]

        # kernel (Gram) matrices of the two samples
        kernel_x = self.rbf(x, x, ls_x)
        kernel_y = self.rbf(y, y, ls_y)

        # centering matrix H = I - 11^T / n
        H = (
            torch.eye(num_samples, device=x.device)
            - torch.ones(num_samples, num_samples, device=x.device) / num_samples
        )

        return torch.trace(H @ kernel_x @ H @ kernel_y) / num_samples**2

    @staticmethod
    def calc_ls(x: torch.Tensor) -> torch.Tensor:
        """
        Calculates the length scale via the median heuristic
        (median distance between points).

        :param x: tensor of the sample in the form of (num_samples, num_dim)
        """
        dists = F.pdist(x)

        return torch.median(dists)

    def run_test(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        ls_x: float = None,
        ls_y: float = None,
        bonferroni: int = 1,
        verbose=False,
    ) -> bool:
        """
        Runs the HSIC test with randomly permuting the indices of y.

        :param verbose: whether to print to CLI
        :param x: tensor of the first sample in the form of (num_samples, num_dim)
        :param y: tensor of the second sample in the form of (num_samples, num_dim)
        :param ls_x: length scale of the x RBF kernel (median heuristic if None)
        :param ls_y: length scale of the y RBF kernel (median heuristic if None)
        :param bonferroni: Bonferroni correction coefficient (= #hypotheses)

        :return: bool, whether H0 (x and y are independent) holds
        """
        if not torch.is_tensor(x):
            x = torch.from_numpy(x)
        if not torch.is_tensor(y):
            y = torch.from_numpy(y)

        # promote 1-D samples to (num_samples, 1)
        if len(x.shape) == 1:
            x = x.unsqueeze(1)
        if len(y.shape) == 1:
            y = y.unsqueeze(1)

        x = x.float()
        y = y.float()

        if ls_x is None:
            ls_x = self.calc_ls(x)
        if ls_y is None:
            ls_y = self.calc_ls(y)

        alpha_corr = self.alpha / bonferroni

        stat_no_perm = self.test_statistics(x, y, ls_x, ls_y)

        num_samples = x.shape[0]

        # build the null distribution by permuting the samples of y
        stats = []
        for _ in range(self.num_permutations):
            idx = torch.randperm(num_samples)

            stats.append(self.test_statistics(x, y[idx], ls_x, ls_y))

        # torch.stack avoids the list-of-tensors copy warning of torch.tensor
        stats = torch.stack(stats).to(x.device)
        crit_val = torch.quantile(stats, 1 - alpha_corr)

        # p-value: fraction of permuted statistics exceeding the observed one
        p = (stats > stat_no_perm).float().mean()

        # Fix: H0 is retained iff the observed statistic stays below the critical
        # value (equivalently, p > alpha_corr). The original code returned
        # `p > crit_val`, comparing a probability against a quantile of the
        # *statistic* distribution, which mixes incommensurable quantities.
        h0_holds = bool(stat_no_perm < crit_val)

        if verbose is True:
            print(f"p={p:.3f}, critical value={crit_val:.3f}")
            print(f"The null hypothesis (x and y is independent) is {h0_holds}")

        return h0_holds
def correct_jacobian_permutations(
    dep_mat: torch.Tensor, ica_permutation: torch.Tensor, sem_permutation: torch.Tensor
) -> torch.Tensor:
    """Undo the ICA (row) and SEM (column) permutations of a Jacobian.

    :param dep_mat: estimated Jacobian (dependency matrix)
    :param ica_permutation: permutation matrix applied from the left
    :param sem_permutation: permutation matrix applied from the right
    :return: the permutation-corrected Jacobian
    """
    return ica_permutation @ dep_mat @ sem_permutation


def jacobian_edge_accuracy(
    dep_mat: torch.Tensor, gt_jacobian_unmixing: torch.Tensor
) -> torch.Tensor:
    """
    Calculates the accuracy of detecting edges based on the GT Jacobian and the estimated one such that the smallest N
    absolute elements of the estimated `dep_mat` are set to 0, where N is the number of zeros in `gt_jacobian_unmixing`.

    Fix: operates on a copy of `dep_mat` — the original implementation zeroed
    entries of the caller's tensor in place.

    :param dep_mat: estimated Jacobian
    :param gt_jacobian_unmixing: ground-truth unmixing Jacobian
    :return: scalar tensor with the fraction of correctly classified entries
    """
    num_zeros = (gt_jacobian_unmixing == 0.0).sum()

    # work on a flattened copy so the caller's tensor stays intact
    est = dep_mat.clone()
    est_flat = est.view(-1)

    # indices of the num_zeros smallest absolute values
    zero_idx = est_flat.abs().sort()[1][:num_zeros]

    # zero them out (on the copy)
    est_flat[zero_idx] = 0

    return (est.bool() == gt_jacobian_unmixing.bool()).float().mean()
class JacobianBinnedPrecisionRecall(Metric):
    """
    Precision/recall of edge detection in an (absolute-valued) Jacobian,
    evaluated over a grid of binned thresholds: entries of the predicted
    Jacobian above a threshold count as predicted edges, and the nonzero
    entries of the target Jacobian are the ground-truth edges.

    Based on https://github.com/PyTorchLightning/metrics/blob/master/torchmetrics/classification/binned_precision_recall.py#L45-L184
    """

    # running per-threshold true-positive / false-positive / false-negative counts
    TPs: torch.Tensor
    FPs: torch.Tensor
    FNs: torch.Tensor

    def __init__(
        self,
        num_thresholds: Optional[int] = None,
        thresholds: Union[int, torch.Tensor, List[float], None] = None,
        log_base: Optional[float] = 10.0,
        start=-4,
        compute_on_step: Optional[bool] = None,
        **kwargs: Dict[str, Any],
    ) -> None:
        """
        :param num_thresholds: number of thresholds to generate automatically
            (takes precedence over `thresholds` when it is an int)
        :param thresholds: explicit threshold values (list or tensor)
        :param log_base: base of the log-spaced grid; values <= 1 fall back to a
            linear grid
        :param start: exponent of the smallest threshold (log grid) or 10**start
            as the smallest threshold (linear grid)
        :param compute_on_step: forwarded to torchmetrics.Metric
        """
        super().__init__(compute_on_step=compute_on_step, **kwargs)

        if isinstance(num_thresholds, int):
            self.num_thresholds = num_thresholds
            if log_base > 1:
                # log-spaced thresholds in [log_base**start, 1]
                thresholds = torch.logspace(
                    start, 0, self.num_thresholds, base=log_base
                )
            else:
                # linearly-spaced thresholds in [10**start, 1]
                thresholds = torch.linspace(10**start, 1.0, num_thresholds)
            self.register_buffer("thresholds", thresholds)
        elif thresholds is not None:
            if not isinstance(thresholds, (list, torch.Tensor)):
                raise ValueError(
                    "Expected argument `thresholds` to either be an integer, list of floats or a tensor"
                )
            thresholds = (
                torch.tensor(thresholds) if isinstance(thresholds, list) else thresholds
            )
            self.num_thresholds = thresholds.numel()
            self.register_buffer("thresholds", thresholds)
        # NOTE(review): if both `num_thresholds` and `thresholds` are None,
        # neither branch runs, so no `thresholds` buffer exists and
        # `self.num_thresholds` below raises — confirm callers always pass one.

        for name in ("TPs", "FPs", "FNs"):
            self.add_state(
                name=name,
                default=torch.zeros(self.num_thresholds, dtype=torch.float32),
                dist_reduce_fx="sum",
            )

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:  # type: ignore
        """
        Accumulates TP/FP/FN counts for one (predicted, ground-truth) Jacobian pair.

        Args
            preds: (n_samples,) tensor
            target: (n_samples, ) tensor
        """
        # flatten both and take absolute values (only magnitude matters for edges)
        preds, target = (
            preds.reshape(
                -1,
            ).abs(),
            target.reshape(
                -1,
            ).abs(),
        )

        assert preds.shape == target.shape

        # rescale predictions to [0, 1] so they are comparable to the thresholds
        # NOTE(review): an all-zero `preds` would divide by zero here — confirm
        # callers never pass an all-zero Jacobian
        if (pred_max := preds.max()) != 1.0:
            preds /= pred_max

        # nonzero target entries are the ground-truth edges
        target = target.bool()
        # Iterate one threshold at a time to conserve memory
        for i in range(self.num_thresholds):
            predictions = preds >= self.thresholds[i]
            self.TPs[i] += (target & predictions).sum(dim=0)
            self.FPs[i] += ((~target) & (predictions)).sum(dim=0)
            self.FNs[i] += (target & (~predictions)).sum(dim=0)

    def compute(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Returns (precisions, recalls, thresholds), each of length num_thresholds."""
        # METRIC_EPS guards against division by zero for empty bins
        precisions = (self.TPs + METRIC_EPS) / (self.TPs + self.FPs + METRIC_EPS)
        recalls = self.TPs / (self.TPs + self.FNs + METRIC_EPS)

        return precisions, recalls, self.thresholds
"""Utility functions for spaces.""" 2 | 3 | import torch 4 | import numpy as np 5 | from typing import Callable 6 | 7 | 8 | def spherical_to_cartesian(r, phi): 9 | """Convert spherical coordinates to cartesian coordinates.""" 10 | 11 | must_convert_to_torch = False 12 | if isinstance(phi, np.ndarray): 13 | must_convert_to_torch = True 14 | phi = torch.Tensor(phi) 15 | 16 | if isinstance(r, (int, float)): 17 | r = torch.ones((len(phi))) * r 18 | 19 | must_flatten = False 20 | if len(phi.shape) == 1: 21 | phi = phi.reshape(1, -1) 22 | must_flatten = True 23 | 24 | a = torch.cat((torch.ones((len(phi), 1), device=phi.device) * 2 * np.pi, phi), 1) 25 | si = torch.sin(a) 26 | si[:, 0] = 1 27 | si = torch.cumprod(si, dim=1) 28 | co = torch.cos(a) 29 | co = torch.roll(co, -1, dims=1) 30 | 31 | result = si * co * r.unsqueeze(-1) 32 | 33 | if must_flatten: 34 | result = result[0] 35 | 36 | if must_convert_to_torch: 37 | result = result.numpy() 38 | 39 | return result 40 | 41 | 42 | def cartesian_to_spherical(x): 43 | """Convert cartesian to spherical coordinates.""" 44 | 45 | must_convert_to_torch = False 46 | if isinstance(x, np.ndarray): 47 | must_convert_to_torch = True 48 | x = torch.Tensor(x) 49 | 50 | must_flatten = False 51 | if len(x.shape) == 1: 52 | x = x.reshape(1, -1, 1) 53 | must_flatten = True 54 | 55 | T = np.triu(np.ones((x.shape[1], x.shape[1]))) 56 | T = torch.Tensor(T).to(x.device) 57 | 58 | rs = torch.matmul(T, (x.unsqueeze(-1) ** 2)).reshape(x.shape) 59 | rs = torch.sqrt(rs) 60 | 61 | rs[rs == 0] = 1 62 | 63 | phi = torch.acos(torch.clamp(x / rs, -1, 1))[:, :-1] 64 | 65 | # if x.shape[-1] > 2: 66 | phi[:, -1] = phi[:, -1] + (2 * np.pi - 2 * phi[:, -1]) * (x[:, -1] <= 0).float() 67 | 68 | rs = rs[:, 0] 69 | 70 | if must_convert_to_torch: 71 | rs = rs.numpy() 72 | phi = phi.numpy() 73 | 74 | if must_flatten: 75 | result = rs[0], phi[0] 76 | else: 77 | result = rs, phi 78 | 79 | return result 80 | 81 | 82 | def sample_generalized_normal(mean: torch.Tensor, 
def sample_generalized_normal(mean: torch.Tensor, lbd: float, p: int, shape):
    """Sample from a generalized Normal distribution.

    Follows the sampling scheme of
    https://www.tensorflow.org/probability/api_docs/python/tfp/distributions/GeneralizedNormal

    Args:
        mean: Mean of the distribution.
        lbd: Parameter controlling the standard deviation of the distribution.
        p: Exponent of the distribution.
        shape: Shape of the samples to generate.
    """
    assert isinstance(lbd, float)

    inv_p = 1.0 / p
    # magnitude: Gamma(1/p, 1) draws raised to the power 1/p
    gamma_draws = torch.distributions.Gamma(inv_p, 1.0).rsample(shape)
    # could speed up operations, but doesnt....
    # gamma_draws = torch._standard_gamma(torch.ones(shape) * inv_p)
    # random sign: uniform on {-1, +1}
    signs = torch.randint(low=0, high=2, size=shape, dtype=mean.dtype) * 2 - 1
    draws = signs * torch.pow(torch.abs(gamma_draws), inv_p)
    return mean + lbd * draws.to(mean.device)


def truncated_rejection_resampling(
    sampler_fn: Callable,
    min_: float,
    max_: float,
    size: int,
    n: int,
    buffer_size_factor: int = 1,
    device: str = "cpu",
):
    """Rejection-sample until every entry of a (size, n) tensor lies in [min_, max_].

    Args:
        sampler_fn:
        min_: Min value of the support.
        max_: Max value of the support.
        size: Number of samples to generate.
        n: Dimensionality of the samples.
        buffer_size_factor: How many more samples to generate
            first to select the feasible ones and samples from them.
        device: Torch device.
    """
    # NaN marks an entry that still needs a feasible draw
    samples = torch.ones((size, n), device=device) * np.nan
    filled = ~torch.isnan(samples)  # all False at the start

    # keep drawing until every entry has been replaced by a feasible value
    while torch.sum(filled).item() < n * size:
        # over-sample without truncation, then filter
        candidates = sampler_fn(size * buffer_size_factor)
        candidates = candidates.view(buffer_size_factor, size, n)
        # entries inside the feasible interval
        feasible = (candidates >= min_) & (candidates <= max_)

        for batch in range(buffer_size_factor):
            take = feasible[batch] & (~filled)
            samples[take] = candidates[batch][take]
            filled[take] = True

    return samples
10,jolly-sweep-48,5,True,75,1,True,True,False,False,0.9998757116007576,1.9361495971679688 13 | 11,wild-sweep-47,5,True,75,1,True,True,False,False,0.9998625652600656,1.9206926822662351 14 | 12,serene-sweep-46,5,True,75,1,True,True,False,False,0.9997365402421534,1.9125754833221436 15 | 13,exalted-sweep-45,5,True,31,1,True,True,False,False,0.8795508543898165,3.042447805404663 16 | 14,iconic-sweep-44,5,True,31,1,True,True,False,False,0.9996180502543012,1.8884527683258057 17 | 15,eager-sweep-43,5,True,31,1,True,True,False,False,0.9997146460982862,1.938737392425537 18 | 16,autumn-sweep-42,5,True,31,1,True,True,False,False,0.9999004590624192,1.920802354812622 19 | 17,wandering-sweep-41,5,True,31,1,True,True,False,False,0.9997387904195564,1.9078187942504885 20 | 18,vital-sweep-40,5,True,5,1,True,True,False,False,0.9990543057090226,1.985670566558838 21 | 19,spring-sweep-39,5,True,5,1,True,True,False,False,0.9997678730756446,1.883044719696045 22 | 20,still-sweep-38,5,True,5,1,True,True,False,False,0.9998427780887252,1.935593605041504 23 | 21,gallant-sweep-37,5,True,5,1,True,True,False,False,0.9998776276779578,1.921824932098389 24 | 22,mild-sweep-36,5,True,5,1,True,True,False,False,0.9998156203425412,1.9052958488464355 25 | 23,colorful-sweep-34,5,True,0,1,True,True,False,False,0.999609707268512,1.886894702911377 26 | 24,genial-sweep-30,5,True,119,1,True,False,False,False,0.9993471253908446,1.9703075885772705 27 | 25,driven-sweep-28,5,True,119,1,True,False,False,False,0.9999208254898224,1.9328854084014893 28 | 26,soft-sweep-27,5,True,119,1,True,False,False,False,0.999839412812903,1.9183566570281985 29 | 27,legendary-sweep-25,5,True,97,1,True,False,False,False,0.9993398966411228,1.968295335769653 30 | 28,balmy-sweep-23,5,True,97,1,True,False,False,False,0.999934957239741,1.9327661991119385 31 | 29,twilight-sweep-22,5,True,97,1,True,False,False,False,0.9998536706903668,1.9173400402069092 32 | 
30,frosty-sweep-21,5,True,97,1,True,False,False,False,0.999837929175316,1.9022855758666992 33 | 31,fresh-sweep-20,5,True,75,1,True,False,False,False,0.9997391358950568,1.9667389392852783 34 | 32,polar-sweep-19,5,True,75,1,True,False,False,False,0.9997047101272916,1.8817198276519775 35 | 33,swift-sweep-18,5,True,75,1,True,False,False,False,0.9998869775350364,1.9324567317962649 36 | 34,quiet-sweep-17,5,True,75,1,True,False,False,False,0.9998312031215908,1.918576717376709 37 | 35,crimson-sweep-16,5,True,75,1,True,False,False,False,0.9998233556649896,1.9022531509399416 38 | 36,silver-sweep-13,5,True,31,1,True,False,False,False,0.9998905965038332,1.9347615242004397 39 | 37,confused-sweep-14,5,True,31,1,True,False,False,False,0.999837668184696,1.881209373474121 40 | 38,mild-sweep-15,5,True,31,1,True,False,False,False,0.9997203499476942,1.9671955108642576 41 | 39,lyric-sweep-12,5,True,31,1,True,False,False,False,0.999881420176493,1.9189906120300293 42 | 40,legendary-sweep-10,5,True,5,1,True,False,False,False,0.999704717051795,1.9659061431884768 43 | 41,earnest-sweep-11,5,True,31,1,True,False,False,False,0.9998578024697564,1.903564453125 44 | 42,quiet-sweep-9,5,True,5,1,True,False,False,False,0.9997725934887132,1.8812062740325928 45 | 43,swift-sweep-8,5,True,5,1,True,False,False,False,0.9999091635224258,1.9331650733947752 46 | 44,usual-sweep-7,5,True,5,1,True,False,False,False,0.9998679143507296,1.9192843437194824 47 | 45,driven-sweep-6,5,True,5,1,True,False,False,False,0.999782312268404,1.9043476581573489 48 | 46,glowing-sweep-5,5,True,0,1,True,False,False,False,0.9997347264032164,1.968245029449463 49 | 47,hardy-sweep-4,5,True,0,1,True,False,False,False,0.9996069409732864,1.8831610679626465 50 | 48,playful-sweep-3,5,True,0,1,True,False,False,False,0.9999002216320252,1.9328291416168213 51 | 49,pious-sweep-2,5,True,0,1,True,False,False,False,0.9998913835217138,1.9182868003845217 52 | 50,toasty-sweep-1,5,True,0,1,True,False,False,False,0.9998414354492512,1.9037504196166992 
def get_tensor(x, vars, sess, data_holder, batch=256):
    """Get tensor data .

    Evaluates the given TF1 graph tensors on `x` in mini-batches and stitches
    the per-batch results back together.

    Args:
        x: input data [Ndim, Ndata]
        vars: tensors (list) — tensor names (str) are resolved via the default graph
        sess: session
        data_holder: data holder (placeholder fed with each batch)
        batch: batch size (None evaluates everything in one pass)
    Returns:
        y: value of tensors — dict keyed by the position in `vars`
    """

    Ndata = x.shape[1]
    if batch is None:
        Nbatch = Ndata
    else:
        Nbatch = batch
    Niter = int(np.ceil(Ndata / Nbatch))

    if not isinstance(vars, list):
        vars = [vars]

    # Convert names to tensors (if necessary) -----------------
    for i in range(len(vars)):
        if not tf.is_numeric_tensor(vars[i]) and isinstance(vars[i], str):
            vars[i] = tf.get_default_graph().get_tensor_by_name(vars[i])

    # Start batch-inputs --------------------------------------
    y = {}
    for iter in range(Niter):
        sys.stdout.write("\r>> Getting tensors... %d/%d" % (iter + 1, Niter))
        sys.stdout.flush()

        # Get batch ------------------------------------------- (last batch may be short)
        batchidx = np.arange(Nbatch * iter, np.minimum(Nbatch * (iter + 1), Ndata))
        xbatch = x[:, batchidx].T

        # Get tensor data -------------------------------------
        feed_dict = {data_holder: xbatch}
        ybatch = sess.run(vars, feed_dict=feed_dict)

        # Storage
        for tn in range(len(ybatch)):
            # Initialize output buffers lazily, once the per-tensor shapes are known
            if iter == 0:
                y[tn] = np.zeros([Ndata] + list(ybatch[tn].shape[1:]), dtype=np.float32)
            # Store this batch's rows at their global indices
            y[tn][batchidx,] = ybatch[tn]

    sys.stdout.write("\r\n")

    return y


# =============================================================
# =============================================================
def calc_accuracy(pred, label, normalize_confmat=True):
    """Calculate accuracy and confusion matrix
    Args:
        pred: [Ndata x Nlabel]
        label: [Ndata x Nlabel]
        normalize_confmat: whether to row-normalize the confusion matrix
    Returns:
        accuracy: accuracy
        conf: confusion matrix
    """

    # print("Calculating accuracy...")

    # Accuracy ------------------------------------------------
    correctflag = pred.reshape(-1) == label.reshape(-1)
    accuracy = np.mean(correctflag)

    # Confusion matrix ----------------------------------------
    conf = confusion_matrix(label[:], pred[:]).astype(np.float32)
    # Normalization (row-wise, so each row sums to 1)
    # NOTE(review): a class absent from `label` yields a zero row and a 0/0
    # division here — confirm all classes are always present.
    if normalize_confmat:
        for i in range(conf.shape[0]):
            conf[i, :] = conf[i, :] / np.sum(conf[i, :])

    return accuracy, conf
def ICEBEEM_wrapper(
    X,
    Y,
    ebm_hidden_size,
    n_layers_ebm,
    n_layers_flow,
    lr_flow,
    lr_ebm,
    seed,
    ckpt_file="icebeem.pt",
    test=False,
):
    """Train (or load) an ICE-BeeM model via flow-contrastive estimation and
    return the recovered latent sources after each EBM/flow update round.

    Args:
        X: observations, shape (n_samples, data_dim)
        Y: segment labels, same leading dimension as X
        ebm_hidden_size: hidden width of the EBM MLP
        n_layers_ebm: number of EBM MLP layers
        n_layers_flow: number of flow blocks in the contrastive noise model
        lr_flow: learning rate for the flow updates
        lr_ebm: learning rate for the EBM updates
        seed: RNG seed for numpy and torch
        ckpt_file: base path for the per-round checkpoints
        test: when True, load checkpoints instead of training

    Returns:
        list of recovered-source arrays, one per round (initial + 3 refinements)
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    data_dim = X.shape[1]

    # energy model: MLP mapping observations to feature space
    model_ebm = MLP(
        input_size=data_dim,
        hidden_size=[ebm_hidden_size] * n_layers_ebm,
        n_layers=n_layers_ebm,
        output_size=data_dim,
        use_bn=True,
        activation_function=F.leaky_relu,
    )

    # logistic prior for the normalizing flow (contrastive noise distribution)
    prior = TransformedDistribution(
        Uniform(torch.zeros(data_dim), torch.ones(data_dim)), SigmoidTransform().inv
    )
    nfs_flow = NSF_AR
    flows = [
        nfs_flow(dim=data_dim, K=8, B=3, hidden_dim=16) for _ in range(n_layers_flow)
    ]
    convs = [Invertible1x1Conv(dim=data_dim) for _ in flows]
    norms = [ActNorm(dim=data_dim) for _ in flows]
    # interleave ActNorm -> 1x1 conv -> spline flow per block
    flows = list(itertools.chain(*zip(norms, convs, flows)))
    # construct the model
    model_flow = NormalizingFlowModel(prior, flows)

    pretrain_flow = True
    augment_ebm = True

    # instantiate ebmFCE object
    fce_ = ConditionalFCE(
        data=X.astype(np.float32),
        segments=Y.astype(np.float32),
        energy_MLP=model_ebm,
        flow_model=model_flow,
        verbose=False,
    )

    # checkpoint for the initial (round-0) training state
    init_ckpt_file = (
        os.path.splitext(ckpt_file)[0] + "_0" + os.path.splitext(ckpt_file)[1]
    )
    if not test:
        if pretrain_flow:
            # print('pretraining flow model..')
            fce_.pretrain_flow_model(epochs=1, lr=1e-4)
            # print('pretraining done.')

        # first we pretrain the final layer of EBM model (this is g(y) as it depends on segments)
        fce_.train_ebm_fce(
            epochs=15, augment=augment_ebm, finalLayerOnly=True, cutoff=0.5
        )

        # then train full EBM via NCE with flow contrastive noise:
        fce_.train_ebm_fce(epochs=50, augment=augment_ebm, cutoff=0.5, useVAT=False)

        torch.save(
            {
                "ebm_mlp": fce_.energy_MLP.state_dict(),
                "ebm_finalLayer": fce_.ebm_finalLayer,
                "flow": fce_.flow_model.state_dict(),
            },
            init_ckpt_file,
        )
    else:
        state = torch.load(init_ckpt_file, map_location=fce_.device)
        fce_.energy_MLP.load_state_dict(state["ebm_mlp"])
        fce_.ebm_finalLayer = state["ebm_finalLayer"]
        # Fix: was `load_stat_dict`, which is not a torch.nn.Module method and
        # raised AttributeError on the checkpoint-loading path.
        fce_.flow_model.load_state_dict(state["flow"])

    # evaluate recovery of latents
    recov = fce_.unmixSamples(X, modelChoice="ebm")
    source_est_ica = FastICA().fit_transform((recov))
    recov_sources = [source_est_ica]

    # iterate between updating noise and tuning the EBM
    eps = 0.025
    for iter_ in range(3):
        mid_ckpt_file = (
            os.path.splitext(ckpt_file)[0]
            + "_"
            + str(iter_ + 1)
            + os.path.splitext(ckpt_file)[1]
        )
        if not test:
            # update flow model:
            fce_.train_flow_fce(
                epochs=5, objConstant=-1.0, cutoff=0.5 - eps, lr=lr_flow
            )
            # update energy based model:
            fce_.train_ebm_fce(
                epochs=50,
                augment=augment_ebm,
                cutoff=0.5 + eps,
                lr=lr_ebm,
                useVAT=False,
            )

            torch.save(
                {
                    "ebm_mlp": fce_.energy_MLP.state_dict(),
                    "ebm_finalLayer": fce_.ebm_finalLayer,
                    "flow": fce_.flow_model.state_dict(),
                },
                mid_ckpt_file,
            )
        else:
            state = torch.load(mid_ckpt_file, map_location=fce_.device)
            fce_.energy_MLP.load_state_dict(state["ebm_mlp"])
            fce_.ebm_finalLayer = state["ebm_finalLayer"]
            # Fix: was `load_stat_dict` (see above).
            fce_.flow_model.load_state_dict(state["flow"])

        # evaluate recovery of latents
        recov = fce_.unmixSamples(X, modelChoice="ebm")
        source_est_ica = FastICA().fit_transform((recov))
        recov_sources.append(source_est_ica)

    return recov_sources
loss.""" 3 | 4 | import os 5 | import shutil 6 | import torch 7 | import torch.optim as optim 8 | from torch.autograd import Variable 9 | from kitti_masks.model import BetaVAE_H as BetaVAE 10 | from care_nl_ica import losses 11 | 12 | 13 | class Solver(object): 14 | def __init__(self, args, data_loader=None): 15 | self.ckpt_dir = args.ckpt_dir 16 | self.output_dir = args.output_dir 17 | self.data_loader = data_loader 18 | self.dataset = args.dataset 19 | self.device = torch.device( 20 | "cuda" if torch.cuda.is_available() and args.cuda else "cpu" 21 | ) 22 | self.max_iter = args.max_iter 23 | self.global_iter = 0 24 | 25 | self.z_dim = args.z_dim 26 | self.nc = args.num_channel 27 | params = [] 28 | 29 | # for adam 30 | self.lr = args.lr 31 | self.beta1 = args.beta1 32 | self.beta2 = args.beta2 33 | 34 | self.net = BetaVAE(self.z_dim, self.nc, args.box_norm).to(self.device) 35 | self.optim = optim.Adam( 36 | params + list(self.net.parameters()), 37 | lr=self.lr, 38 | betas=(self.beta1, self.beta2), 39 | ) 40 | 41 | self.ckpt_name = args.ckpt_name 42 | if False and self.ckpt_name is not None: 43 | self.load_checkpoint(self.ckpt_name) 44 | 45 | self.log_step = args.log_step 46 | self.save_step = args.save_step 47 | 48 | self.loss = losses.LpSimCLRLoss( 49 | p=args.p, tau=1.0, simclr_compatibility_mode=True 50 | ) 51 | 52 | def train(self): 53 | self.net_mode(train=True) 54 | out = False # whether to exit training loop 55 | failure = False # whether training was stopped 56 | running_loss = 0 57 | log = open(os.path.join(self.output_dir, "log.csv"), "a", 1) 58 | log.write("Total Loss\n") 59 | 60 | while not out: 61 | for x, _ in self.data_loader: # don't use label 62 | x = Variable(x.to(self.device)) 63 | mu = self.net(x) 64 | z1_rec = mu[::2] 65 | z2_con_z1_rec = mu[1::2] 66 | z3_rec = torch.roll(z1_rec, 1, 0) 67 | vae_loss, _, _ = self.loss( 68 | None, None, None, z1_rec, z2_con_z1_rec, z3_rec 69 | ) 70 | running_loss += vae_loss.item() 71 | 72 | 
self.optim.zero_grad() 73 | vae_loss.backward() 74 | self.optim.step() 75 | 76 | self.global_iter += 1 77 | if self.global_iter % self.log_step == 0: 78 | running_loss /= self.log_step 79 | log.write("%.6f" % running_loss + "\n") 80 | 81 | running_loss = 0 82 | 83 | if self.global_iter % self.save_step == 0: 84 | self.save_checkpoint("last") 85 | 86 | if self.global_iter % 50000 == 0: 87 | self.save_checkpoint(str(self.global_iter)) 88 | 89 | if self.global_iter >= self.max_iter: 90 | out = True 91 | break 92 | 93 | if failure: 94 | shutil.rmtree(self.ckpt_dir) 95 | 96 | return failure 97 | 98 | def save_checkpoint(self, filename, silent=True): 99 | model_states = { 100 | "net": self.net.state_dict(), 101 | } 102 | optim_states = { 103 | "optim": self.optim.state_dict(), 104 | } 105 | states = { 106 | "iter": self.global_iter, 107 | "model_states": model_states, 108 | "optim_states": optim_states, 109 | } 110 | 111 | file_path = os.path.join(self.ckpt_dir, filename) 112 | with open(file_path, mode="wb+") as f: 113 | torch.save(states, f) 114 | if not silent: 115 | print( 116 | "=> saved checkpoint '{}' (iter {})".format(file_path, self.global_iter) 117 | ) 118 | 119 | def load_checkpoint(self, filename): 120 | file_path = os.path.join(self.ckpt_dir, filename) 121 | if os.path.isfile(file_path): 122 | checkpoint = torch.load(file_path) 123 | self.global_iter = checkpoint["iter"] 124 | self.net.load_state_dict(checkpoint["model_states"]["net"]) 125 | self.optim.load_state_dict(checkpoint["optim_states"]["optim"]) 126 | print( 127 | "=> loaded checkpoint '{} (iter {})'".format( 128 | file_path, self.global_iter 129 | ) 130 | ) 131 | else: 132 | print("=> no checkpoint found at '{}'".format(file_path)) 133 | 134 | def net_mode(self, train): 135 | if not isinstance(train, bool): 136 | raise ValueError("Only bool type is supported. 
True or False") 137 | 138 | if train: 139 | self.net.train() 140 | else: 141 | self.net.eval() 142 | --------------------------------------------------------------------------------