├── .gitmodules ├── care_nl_ica ├── __init__.py ├── losses │ ├── __init__.py │ ├── dep_mat.py │ ├── utils.py │ └── dsm.py ├── models │ ├── __init__.py │ ├── nflib │ │ └── __init__.py │ ├── ivae │ │ ├── __init__.py │ │ └── ivae_wrapper.py │ ├── tcl │ │ ├── __init__.py │ │ ├── tcl_preprocessing.py │ │ └── tcl_eval.py │ ├── sparsity.py │ ├── model.py │ └── icebeem_wrapper.py ├── independence │ ├── __init__.py │ ├── indep_check.py │ └── hsic.py ├── cl_ica │ ├── datasets │ │ ├── __init__.py │ │ └── simple_image_dataset.py │ ├── kitti_masks │ │ ├── __init__.py │ │ ├── mcc_metric │ │ │ ├── __init__.py │ │ │ └── metric.py │ │ ├── README.md │ │ ├── LICENSE │ │ ├── evaluate_disentanglement.py │ │ ├── model.py │ │ └── solver.py │ ├── tools │ │ └── 3dident │ │ │ ├── data │ │ │ ├── materials_json │ │ │ │ ├── materials.json │ │ │ │ ├── materials_diff.json │ │ │ │ ├── materials_mix.json │ │ │ │ ├── materials_spec.json │ │ │ │ └── materials_cristal.json │ │ │ ├── colors_json │ │ │ │ ├── colors.json │ │ │ │ ├── colors_rgb.json │ │ │ │ ├── colors_value_green.json │ │ │ │ └── colors_saturation_green.json │ │ │ ├── textures │ │ │ │ └── grass.jpg │ │ │ ├── materials │ │ │ │ ├── Crystal.blend │ │ │ │ ├── MyMetal.blend │ │ │ │ └── Rubber.blend │ │ │ ├── scenes │ │ │ │ ├── base_scene.blend │ │ │ │ ├── base_scene_old.blend │ │ │ │ ├── base_scene_spot.blend │ │ │ │ ├── base_scene_simple.blend │ │ │ │ ├── base_scene_equal_xyz.blend │ │ │ │ └── base_scene_equal_xyz.blend1 │ │ │ ├── shapes │ │ │ │ ├── ShapeCube.blend │ │ │ │ ├── ShapeSphere.blend │ │ │ │ ├── ShapeTeapot.blend │ │ │ │ └── ShapeCylinder.blend │ │ │ ├── node_groups │ │ │ │ ├── NodeGroup.blend │ │ │ │ └── NodeGroupMulti4.blend │ │ │ ├── CoGenT_A.json │ │ │ ├── CoGenT_B.json │ │ │ └── properties.json │ │ │ ├── NOTICE │ │ │ ├── get_mean_std.py │ │ │ ├── ORIGINAL_LICENSE │ │ │ └── ORIGINAL_PATENTS │ ├── __init__.py │ ├── infinite_iterator.py │ ├── docker │ │ ├── tmux.conf │ │ ├── Dockerfile.blender │ │ ├── entrypoint.sh │ │ └── 
Dockerfile │ ├── LICENSE │ ├── latent_spaces.py │ ├── layers.py │ └── spaces_utils.py ├── metrics │ ├── __init__.py │ ├── metric_logger.py │ └── dep_mat.py ├── data │ ├── __init__.py │ └── utils.py ├── logger.py ├── dataset.py ├── ica.py ├── cli.py ├── graph_utils.py ├── prob_utils.py └── utils.py ├── scripts ├── fix_no_gpu.sh ├── cancel_all_jobs.sh ├── get_used_space.sh ├── sourcing_script.sh ├── sweeps.sh ├── start_interactive_job.sh ├── profile.sh ├── wandb_sweep.sh ├── start_preemptable_job.sh ├── start_preemptable_cli.sh ├── train_loop.sh ├── loss_compare_for_permutations.sh ├── run_singularity_server.sh ├── nv.def ├── interactive_job_inner.sh └── exclude_nodes.sh ├── configs ├── data │ ├── chain.yaml │ ├── permute.yaml │ ├── nl_sem.yaml │ └── uniform_weights.yaml ├── profile.yaml ├── cpu.yaml └── config.yaml ├── setup.py ├── notebooks ├── monti_sweep_3rqfxkyl.npz ├── monti_sweep_70zssxmx.npz ├── monti_sweep_77huh2ue.npz ├── sem_10d_sweep_7lsb5ud3.npz ├── sem_10d_sweep_i871qu61.npz ├── sem_3d_sweep_mmzkkmw4.npz ├── sem_3d_sweep_vfv1je0d.npz ├── sem_5d_sweep_f5nxtdxz.npz ├── sem_5d_sweep_h6y1gkvo.npz ├── sem_5d_sweep_kdau2seo.npz ├── sem_8d_sweep_7sscc3w1.npz ├── sem_8d_sweep_v3kd7kca.npz ├── sem_10d_sparse_sweep_t7rmrux1.npz ├── sem_3d_permute_sweep_2mgctqko.npz ├── sem_3d_sparse_sweep_lm8s890w.npz ├── sem_5d_munkres_sweep_kdau2seo.npz ├── sem_5d_permute_sweep_rv7yo1qy.npz ├── sem_5d_permute_sweep_x6chdc63.npz ├── sem_5d_sparse_sweep_7n108utd.npz ├── sem_8d_permute_sweep_05whlpmk.npz ├── sem_8d_permute_sweep_l49b2vhx.npz ├── sem_8d_sparse_sweep_2ykg2w21.npz ├── sem_10d_permute_sweep_291qhry5.npz ├── sem_10d_permute_sweep_at138q9q.npz ├── sem_5d_munkres_sweep_kdau2seo.csv ├── sem_3d_sparse_sweep_lm8s890w.csv ├── sem_3d_sweep_mmzkkmw4.csv ├── sem_5d_sweep_kdau2seo.csv ├── sem_8d_sweep_v3kd7kca.csv ├── sem_10d_sparse_sweep_t7rmrux1.csv ├── sem_5d_sparse_sweep_7n108utd.csv ├── sem_8d_sparse_sweep_2ykg2w21.csv ├── sem_10d_sweep_i871qu61.csv ├── 
monti_sweep_77huh2ue.csv ├── monti_sweep_3rqfxkyl.csv ├── sem_3d_sweep_vfv1je0d.csv ├── sem_10d_permute_sweep_at138q9q.csv ├── sem_8d_permute_sweep_05whlpmk.csv ├── sem_5d_sweep_f5nxtdxz.csv ├── sem_5d_permute_sweep_rv7yo1qy.csv └── sem_5d_permute_sweep_x6chdc63.csv ├── tests ├── requirements.txt ├── test_runner.py ├── test_dataset.py ├── test_datamodule.py ├── test_sinkhorn.py ├── test_metrics.py ├── test_dep_mat.py ├── conftest.py └── test_graph_utils.py ├── .gitignore ├── requirements.txt ├── setup.cfg ├── .pre-commit-config.yaml ├── sweeps ├── sem │ ├── mlp_sem3.yaml │ ├── mlp_sem5.yaml │ ├── mlp_sem3_no_permute.yaml │ ├── mlp_sem5_no_permute.yaml │ ├── mlp_sem8.yaml │ ├── mlp_sem8_no_permute.yaml │ ├── mlp_sem10_no_permute.yaml │ ├── mlp_sem15_no_permute.yaml │ ├── mlp_sem10.yaml │ ├── mlp_sem5_sparse.yaml │ ├── mlp_sem3_no_permute_sparse.yaml │ ├── mlp_sem5_no_permute_sparse.yaml │ ├── mlp_sem8_no_permute_sparse.yaml │ ├── mlp_sem8_sparse.yaml │ ├── mlp_sem10_no_permute_sparse.yaml │ └── mlp_sem10_sparse.yaml └── tcl │ ├── mlp_ar.yaml │ ├── mlp_ar_adaptive_offset.yaml │ └── mlp_ar_adaptive_offset_sparse.yaml ├── CITATION.cff ├── LICENSE ├── .github └── workflows │ └── python-package.yml └── README.md /.gitmodules: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /care_nl_ica/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /care_nl_ica/losses/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /care_nl_ica/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 
/care_nl_ica/independence/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/kitti_masks/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /scripts/fix_no_gpu.sh: -------------------------------------------------------------------------------- 1 | unset SLURM_NTASKS_PER_NODE -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/kitti_masks/mcc_metric/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /configs/data/chain.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | force_chain: true 3 | -------------------------------------------------------------------------------- /configs/data/permute.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | permute: true 3 | -------------------------------------------------------------------------------- /configs/data/nl_sem.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | nonlin_sem: true 3 | 4 | -------------------------------------------------------------------------------- /configs/data/uniform_weights.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | force_uniform: true 3 | -------------------------------------------------------------------------------- 
/scripts/cancel_all_jobs.sh: -------------------------------------------------------------------------------- 1 | squeue --me | awk 'NR>1 {print $1}' | xargs scancel -------------------------------------------------------------------------------- /care_nl_ica/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .mcc import mean_corr_coef 2 | 3 | __all__ = ["mcc"] 4 | -------------------------------------------------------------------------------- /care_nl_ica/models/nflib/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ["conditional_flows", "flows", "spline_flows"] 2 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup(packages=find_packages()) 4 | -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/data/materials_json/materials.json: -------------------------------------------------------------------------------- 1 | { 2 | "mat": ["Rubber", "Rubber", "Rubber"] 3 | } -------------------------------------------------------------------------------- /scripts/get_used_space.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | awk '/[0-9]+,preizinger/{print}' /mnt/qb/work/bethge/bethge_user_disk_usage_MB.log -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/data/colors_json/colors.json: -------------------------------------------------------------------------------- 1 | { 2 | "rgb": [[0, 0, 255], [255, 255, 255], [255, 0, 0]] 3 | } -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/data/materials_json/materials_diff.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "mat": ["Rubber", "Rubber", "Rubber"] 3 | } -------------------------------------------------------------------------------- /care_nl_ica/models/ivae/__init__.py: -------------------------------------------------------------------------------- 1 | from .ivae_wrapper import IVAE_wrapper 2 | 3 | __all__ = ["ivae_wrapper", "ivae_core"] 4 | -------------------------------------------------------------------------------- /care_nl_ica/models/tcl/__init__.py: -------------------------------------------------------------------------------- 1 | from .tcl_wrapper_gpu import train 2 | 3 | __all__ = ["tcl_wrapper_gpu", "tcl_core"] 4 | -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/data/colors_json/colors_rgb.json: -------------------------------------------------------------------------------- 1 | { 2 | "rgb": [[255, 0, 0], [0, 255, 0], [0, 0, 255]] 3 | } 4 | -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/data/materials_json/materials_mix.json: -------------------------------------------------------------------------------- 1 | { 2 | "mat": ["Crystal", "MyMetal", "Rubber"] 3 | } 4 | -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/data/materials_json/materials_spec.json: -------------------------------------------------------------------------------- 1 | { 2 | "mat": ["MyMetal", "MyMetal", "MyMetal"] 3 | } 4 | -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/data/materials_json/materials_cristal.json: -------------------------------------------------------------------------------- 1 | { 2 | "mat": ["Crystal", "Crystal", "Crystal"] 3 | } 4 | -------------------------------------------------------------------------------- 
/care_nl_ica/cl_ica/tools/3dident/data/colors_json/colors_value_green.json: -------------------------------------------------------------------------------- 1 | { 2 | "rgb": [[0, 255, 42], [0, 178, 29], [0, 76, 12]] 3 | } 4 | -------------------------------------------------------------------------------- /notebooks/monti_sweep_3rqfxkyl.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/notebooks/monti_sweep_3rqfxkyl.npz -------------------------------------------------------------------------------- /notebooks/monti_sweep_70zssxmx.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/notebooks/monti_sweep_70zssxmx.npz -------------------------------------------------------------------------------- /notebooks/monti_sweep_77huh2ue.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/notebooks/monti_sweep_77huh2ue.npz -------------------------------------------------------------------------------- /notebooks/sem_10d_sweep_7lsb5ud3.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/notebooks/sem_10d_sweep_7lsb5ud3.npz -------------------------------------------------------------------------------- /notebooks/sem_10d_sweep_i871qu61.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/notebooks/sem_10d_sweep_i871qu61.npz -------------------------------------------------------------------------------- /notebooks/sem_3d_sweep_mmzkkmw4.npz: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/notebooks/sem_3d_sweep_mmzkkmw4.npz -------------------------------------------------------------------------------- /notebooks/sem_3d_sweep_vfv1je0d.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/notebooks/sem_3d_sweep_vfv1je0d.npz -------------------------------------------------------------------------------- /notebooks/sem_5d_sweep_f5nxtdxz.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/notebooks/sem_5d_sweep_f5nxtdxz.npz -------------------------------------------------------------------------------- /notebooks/sem_5d_sweep_h6y1gkvo.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/notebooks/sem_5d_sweep_h6y1gkvo.npz -------------------------------------------------------------------------------- /notebooks/sem_5d_sweep_kdau2seo.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/notebooks/sem_5d_sweep_kdau2seo.npz -------------------------------------------------------------------------------- /notebooks/sem_8d_sweep_7sscc3w1.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/notebooks/sem_8d_sweep_7sscc3w1.npz -------------------------------------------------------------------------------- /notebooks/sem_8d_sweep_v3kd7kca.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/notebooks/sem_8d_sweep_v3kd7kca.npz 
-------------------------------------------------------------------------------- /scripts/sourcing_script.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | for f in /.singularity.d/env/*; do echo "$f"; source "$f"; done 3 | 4 | export PS1="\u@\h \W:$" -------------------------------------------------------------------------------- /scripts/sweeps.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for ((i = 1; i <= 1; i++)); 4 | do 5 | ./scripts/wandb_sweep.sh "$@" & 6 | sleep 1 7 | done 8 | -------------------------------------------------------------------------------- /care_nl_ica/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .imca import generate_synthetic_data 2 | from .utils import to_one_hot 3 | 4 | __all__ = ["imca", "utils"] 5 | -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/data/colors_json/colors_saturation_green.json: -------------------------------------------------------------------------------- 1 | { 2 | "rgb": [[0, 255, 42], [76, 255, 106], [178, 255, 191]] 3 | } 4 | -------------------------------------------------------------------------------- /notebooks/sem_10d_sparse_sweep_t7rmrux1.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/notebooks/sem_10d_sparse_sweep_t7rmrux1.npz -------------------------------------------------------------------------------- /notebooks/sem_3d_permute_sweep_2mgctqko.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/notebooks/sem_3d_permute_sweep_2mgctqko.npz -------------------------------------------------------------------------------- 
/notebooks/sem_3d_sparse_sweep_lm8s890w.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/notebooks/sem_3d_sparse_sweep_lm8s890w.npz -------------------------------------------------------------------------------- /notebooks/sem_5d_munkres_sweep_kdau2seo.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/notebooks/sem_5d_munkres_sweep_kdau2seo.npz -------------------------------------------------------------------------------- /notebooks/sem_5d_permute_sweep_rv7yo1qy.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/notebooks/sem_5d_permute_sweep_rv7yo1qy.npz -------------------------------------------------------------------------------- /notebooks/sem_5d_permute_sweep_x6chdc63.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/notebooks/sem_5d_permute_sweep_x6chdc63.npz -------------------------------------------------------------------------------- /notebooks/sem_5d_sparse_sweep_7n108utd.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/notebooks/sem_5d_sparse_sweep_7n108utd.npz -------------------------------------------------------------------------------- /notebooks/sem_8d_permute_sweep_05whlpmk.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/notebooks/sem_8d_permute_sweep_05whlpmk.npz -------------------------------------------------------------------------------- 
/notebooks/sem_8d_permute_sweep_l49b2vhx.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/notebooks/sem_8d_permute_sweep_l49b2vhx.npz -------------------------------------------------------------------------------- /notebooks/sem_8d_sparse_sweep_2ykg2w21.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/notebooks/sem_8d_sparse_sweep_2ykg2w21.npz -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/NOTICE: -------------------------------------------------------------------------------- 1 | 3DIdent is derived from CLEVR. The original rendering pipeline is distributed under the terms listed in ORIGINAL_LICENSE. -------------------------------------------------------------------------------- /configs/profile.yaml: -------------------------------------------------------------------------------- 1 | trainer: 2 | profiler: 3 | class_path: pytorch_lightning.profiler.PyTorchProfiler 4 | init_args: 5 | profile_memory: true 6 | -------------------------------------------------------------------------------- /notebooks/sem_10d_permute_sweep_291qhry5.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/notebooks/sem_10d_permute_sweep_291qhry5.npz -------------------------------------------------------------------------------- /notebooks/sem_10d_permute_sweep_at138q9q.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/notebooks/sem_10d_permute_sweep_at138q9q.npz -------------------------------------------------------------------------------- /tests/requirements.txt: 
-------------------------------------------------------------------------------- 1 | coverage 2 | codecov 3 | pytest~=6.2.5 4 | pytest-cov 5 | pytest-flake8 6 | black~=23.1.0 7 | flake8 8 | check-manifest 9 | twine==1.13.0 -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/kitti_masks/README.md: -------------------------------------------------------------------------------- 1 | Code from https://github.com/bethgelab/slow_disentanglement/blob/master/scripts 2 | to load and evaluate on the KITTI Masks dataset. -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/data/textures/grass.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/care_nl_ica/cl_ica/tools/3dident/data/textures/grass.jpg -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/data/materials/Crystal.blend: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/care_nl_ica/cl_ica/tools/3dident/data/materials/Crystal.blend -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/data/materials/MyMetal.blend: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/care_nl_ica/cl_ica/tools/3dident/data/materials/MyMetal.blend -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/data/materials/Rubber.blend: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/care_nl_ica/cl_ica/tools/3dident/data/materials/Rubber.blend -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/data/scenes/base_scene.blend: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/care_nl_ica/cl_ica/tools/3dident/data/scenes/base_scene.blend -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/data/shapes/ShapeCube.blend: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/care_nl_ica/cl_ica/tools/3dident/data/shapes/ShapeCube.blend -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/data/shapes/ShapeSphere.blend: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/care_nl_ica/cl_ica/tools/3dident/data/shapes/ShapeSphere.blend -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/data/shapes/ShapeTeapot.blend: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/care_nl_ica/cl_ica/tools/3dident/data/shapes/ShapeTeapot.blend -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/data/node_groups/NodeGroup.blend: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/care_nl_ica/cl_ica/tools/3dident/data/node_groups/NodeGroup.blend 
-------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/data/scenes/base_scene_old.blend: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/care_nl_ica/cl_ica/tools/3dident/data/scenes/base_scene_old.blend -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/data/scenes/base_scene_spot.blend: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/care_nl_ica/cl_ica/tools/3dident/data/scenes/base_scene_spot.blend -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/data/shapes/ShapeCylinder.blend: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/care_nl_ica/cl_ica/tools/3dident/data/shapes/ShapeCylinder.blend -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/data/scenes/base_scene_simple.blend: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/care_nl_ica/cl_ica/tools/3dident/data/scenes/base_scene_simple.blend -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/data/node_groups/NodeGroupMulti4.blend: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/care_nl_ica/cl_ica/tools/3dident/data/node_groups/NodeGroupMulti4.blend -------------------------------------------------------------------------------- 
/care_nl_ica/cl_ica/tools/3dident/data/scenes/base_scene_equal_xyz.blend: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/care_nl_ica/cl_ica/tools/3dident/data/scenes/base_scene_equal_xyz.blend -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/data/scenes/base_scene_equal_xyz.blend1: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rpatrik96/nl-causal-representations/HEAD/care_nl_ica/cl_ica/tools/3dident/data/scenes/base_scene_equal_xyz.blend1 -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/__init__.py: -------------------------------------------------------------------------------- 1 | from . import latent_spaces 2 | from . import spaces 3 | from . import vmf 4 | from . import spaces_utils 5 | from . 
import layers 6 | 7 | 8 | __all__ = ["latent_spaces", "spaces", "vmf", "spaces_utils", "layers"] 9 | -------------------------------------------------------------------------------- /scripts/start_interactive_job.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | PARTITION=gpu-2080ti-interactive 3 | FLAGS=--exclude=slurm-bm-06 4 | 5 | srun --job-name="$JOB_NAME" --partition=$PARTITION --pty --gres=gpu:1 "$FLAGS" -- ./scripts/interactive_job_inner.sh 6 | 7 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | venv 3 | run 4 | .ipynb_checkpoints 5 | */.ipynb_checkpoints/* 6 | __pycache__ 7 | */__pycache__/* 8 | */lightning_logs/* 9 | */wandb/* 10 | *.sif 11 | *.egg-info/* 12 | *.log 13 | *.ckpt 14 | outputs/* 15 | wandb/* 16 | */artifacts/* -------------------------------------------------------------------------------- /scripts/profile.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python3 care_nl_ica/cli.py fit --config configs/config.yaml --config configs/model/ar_non_tri.yaml --config configs/data/permute.yaml --config configs/data/nl_sem.yaml --trainer.profiler=simple --trainer.max_epochs=2 3 | 4 | -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/data/CoGenT_A.json: -------------------------------------------------------------------------------- 1 | { 2 | "cube": [ 3 | "gray", "blue", "brown", "yellow" 4 | ], 5 | "cylinder": [ 6 | "red", "green", "purple", "cyan" 7 | ], 8 | "sphere": [ 9 | "gray", "red", "blue", "green", "brown", "purple", "cyan", "yellow" 10 | ] 11 | } 12 | -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/data/CoGenT_B.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "cube": [ 3 | "red", "green", "purple", "cyan" 4 | ], 5 | "cylinder": [ 6 | "gray", "blue", "brown", "yellow" 7 | ], 8 | "sphere": [ 9 | "gray", "red", "blue", "green", "brown", "purple", "cyan", "yellow" 10 | ] 11 | } 12 | -------------------------------------------------------------------------------- /scripts/wandb_sweep.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | PARTITION=gpu-2080ti-preemptable 3 | FLAGS=--exclude= #slurm-bm-62 4 | PYTHONPATH=. srun --time=600 --job-name="$JOB_NAME" --partition=$PARTITION $FLAGS --mem=8G --gpus=1 -- /mnt/qb/work/bethge/preizinger/nl-causal-representations//scripts/run_singularity_server.sh wandb agent --count 1 "$@" 5 | 6 | -------------------------------------------------------------------------------- /tests/test_runner.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning.trainer import Trainer 2 | 3 | from care_nl_ica.data.datamodules import ContrastiveDataModule 4 | from care_nl_ica.runner import ContrastiveICAModule 5 | 6 | 7 | def test_runner(): 8 | trainer = Trainer(fast_dev_run=True) 9 | runner = ContrastiveICAModule() 10 | dm = ContrastiveDataModule() 11 | trainer.fit(runner, datamodule=dm) 12 | -------------------------------------------------------------------------------- /tests/test_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def test_contrastive_dataset(dataloader): 5 | batches = [] 6 | num_batches = 3 7 | for i in range(num_batches): 8 | batches.append(next(iter(dataloader))[1]) 9 | 10 | # calculates the variance accross the batch and n dimensions 11 | # to check that we do not get the same data 12 | torch.any(torch.stack(batches).var(0).sum([-1, -2]) > 1e-7) 13 | 
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy~=1.19.5 2 | torch~=1.11.0 3 | scipy~=1.6.3 4 | scikit-learn~=0.24.2 5 | matplotlib~=3.4.2 6 | wandb~=0.12.17 7 | h5py~=3.1.0 8 | torchmetrics~=0.7.2 9 | pytest~=6.2.4 10 | setuptools~=59.5.0 11 | pandas~=1.3.1 12 | pip~=21.1.2 13 | pytorch-lightning==1.5.10 14 | jsonargparse[signatures]~=4.3.0 15 | hydra-core~=1.1.1 16 | pre-commit 17 | omegaconf~=2.1.1 18 | pynvml~=11.4.1 19 | functorch~=0.1.0 20 | tueplots -------------------------------------------------------------------------------- /scripts/start_preemptable_job.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | PARTITION=gpu-2080ti-preemptable 3 | #FLAGS= 4 | PYTHONPATH=. srun --job-name="$JOB_NAME" --partition=$PARTITION --cpus-per-task=2 --mem=4G --gpus=1 -- /mnt/qb/work/bethge/preizinger/nl-causal-representations//scripts/run_singularity_server.sh python3 /mnt/qb/work/bethge/preizinger/nl-causal-representations/care_nl_ica/main.py --project mlp-test --use-batch-norm --use-dep-mat --use-wandb "$@" 5 | 6 | -------------------------------------------------------------------------------- /scripts/start_preemptable_cli.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | PARTITION=gpu-2080ti-preemptable 3 | #FLAGS= 4 | PYTHONPATH=. 
srun --time=360 --job-name="$JOB_NAME" --partition=$PARTITION --mem=8G --pty --gpus=1 -- scripts/run_singularity_server.sh python3 care_nl_ica/cli.py fit --config configs/config.yaml # --trainer.max_epochs=25000 --data.latent_dim=10 --model.offline=true --data.variant=4 --data.force_chain=true --data.nonlin_sem=true --data.use_sem=false --data.permute=false --data.data_gen_mode=rvs "$@" 5 | 6 | -------------------------------------------------------------------------------- /scripts/train_loop.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for seed in 84645646 4 | do 5 | for n in 4 6 | do 7 | for l1 in 0 #1e1 1 1e-1 1e-2 1e-3 #1e-1 1 1e2 #0 1e-2 3e-3 1e-3 3e-4 #0 1e-6 1e-5 1e-4 8 | do 9 | ./scripts/start_preemptable_job.sh --use-ar-mlp --seed ${seed} --n ${n} --l1 ${l1} --note "sinkhorn" --tags normalization nonlinear sem residual entropy sinkhorn --use-sem --nonlin-sem --normalize-latents --verbose --permute & 10 | sleep 20 11 | done 12 | done 13 | done 14 | -------------------------------------------------------------------------------- /tests/test_datamodule.py: -------------------------------------------------------------------------------- 1 | from care_nl_ica.data.datamodules import ContrastiveDataModule 2 | import torch 3 | 4 | 5 | def test_contrastive_datamodule(datamodule: ContrastiveDataModule): 6 | batches = [] 7 | num_batches = 3 8 | for i in range(num_batches): 9 | batches.append(next(iter(datamodule.train_dataloader()))[0][0, :]) 10 | 11 | # calculates the variance accross the batch and n dimensions 12 | # to check that we do not get the same data 13 | torch.any(torch.stack(batches).var(0).sum([-1, -2]) > 1e-7) 14 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = causalnlica 3 | version = 1.0.0 4 | author = Patrik Reizinger, Yash Sharma 5 | 
author_email = prmedia17@gmail.com 6 | description = Causal Representation learning with Nonlinear ICA 7 | long_description = file: README.md 8 | long_description_content_type = text/markdown 9 | url = https://github.com/rpatrik96/nl-causal-representations 10 | project_urls = 11 | Bug Tracker = https://github.com/rpatrik96/nl-causal-representations 12 | classifiers = 13 | Programming Language :: Python :: 3 14 | License :: OSI Approved :: MIT License 15 | Operating System :: OS Independent 16 | -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/infinite_iterator.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable 2 | 3 | 4 | class InfiniteIterator: 5 | """Infinitely repeat the iterable.""" 6 | 7 | def __init__(self, iterable: Iterable): 8 | self._iterable = iterable 9 | self.iterator = iter(self._iterable) 10 | 11 | def __iter__(self): 12 | return self 13 | 14 | def __next__(self): 15 | for _ in range(2): 16 | try: 17 | return next(self.iterator) 18 | except StopIteration: 19 | # reset iterator 20 | del self.iterator 21 | self.iterator = iter(self._iterable) 22 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black 3 | rev: 23.1.0 4 | hooks: 5 | - id: black 6 | entry: black 7 | types: [python] 8 | # It is recommended to specify the latest version of Python 9 | # supported by your project here, or alternatively use 10 | # pre-commit's default_language_version, see 11 | # https://pre-commit.com/#top_level-default_language_version 12 | language_version: python3.8 13 | - repo: local 14 | hooks: 15 | - id: tests 16 | name: Unit tests 17 | language: system 18 | entry: bash -c "source activate venv"; pytest tests 
-------------------------------------------------------------------------------- /tests/test_sinkhorn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from care_nl_ica.models.sinkhorn import SinkhornNet 4 | from care_nl_ica.utils import setup_seed 5 | 6 | 7 | def test_sinkhorn_net(): 8 | num_vars = 3 9 | num_steps = 15 10 | temperature = 3e-3 11 | threshold = 0.95 12 | 13 | setup_seed(1) 14 | 15 | sinkhorn_net = SinkhornNet(num_vars, num_steps, temperature) 16 | sinkhorn_net.weight.data = torch.randn(sinkhorn_net.weight.shape) 17 | 18 | print(sinkhorn_net.doubly_stochastic_matrix.abs()) 19 | assert torch.all( 20 | (sinkhorn_net.doubly_stochastic_matrix.abs() > threshold).sum(dim=0) 21 | == torch.ones(1, num_vars) 22 | ) 23 | -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/data/properties.json: -------------------------------------------------------------------------------- 1 | { 2 | "shapes": { 3 | "cube": "ShapeCube", 4 | "sphere": "ShapeSphere", 5 | "cylinder": "ShapeCylinder", 6 | "teapot": "ShapeTeapot" 7 | }, 8 | "colors": { 9 | "gray": [87, 87, 87], 10 | "red": [173, 35, 35], 11 | "blue": [42, 75, 215], 12 | "green": [29, 105, 20], 13 | "brown": [129, 74, 25], 14 | "purple": [129, 38, 192], 15 | "cyan": [41, 208, 208], 16 | "yellow": [255, 238, 51] 17 | }, 18 | "materials": { 19 | "rubber": "Rubber", 20 | "metal": "MyMetal", 21 | "crystal": "Crystal" 22 | }, 23 | "sizes": { 24 | "large": 0.7, 25 | "small": 0.35 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /tests/test_metrics.py: -------------------------------------------------------------------------------- 1 | from care_nl_ica.metrics.dep_mat import JacobianBinnedPrecisionRecall 2 | import torch 3 | 4 | 5 | def test_jacobian_prec_recall(): 6 | num_dim = 3 7 | num_thresholds = 15 8 | jac_pr = 
JacobianBinnedPrecisionRecall(num_thresholds=num_thresholds) 9 | target = ( 10 | ( 11 | torch.tril(torch.bernoulli(0.5 * torch.ones(num_dim, num_dim)), -1) 12 | + torch.eye(num_dim) 13 | ) 14 | .bool() 15 | .float() 16 | ) 17 | preds = torch.tril(torch.randn_like(target)) 18 | jac_pr(preds, target) 19 | precisions, recalls, thresholds = jac_pr.compute() 20 | print(precisions, recalls, thresholds) 21 | -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/docker/tmux.conf: -------------------------------------------------------------------------------- 1 | # change prefix key 2 | unbind C-b 3 | set -g prefix 'C-\' 4 | bind 'C-\' send-prefix 5 | 6 | # large history 7 | set -g history-limit 10000 8 | 9 | # shift arrow to switch windows 10 | bind -n F9 previous-window 11 | bind -n F10 next-window 12 | 13 | # THEME 14 | set -g status-bg black 15 | set -g status-fg white 16 | set -g window-status-current-bg white 17 | set -g window-status-current-fg black 18 | set -g window-status-current-attr bold 19 | set -g status-interval 60 20 | set -g status-left-length 30 21 | set -g status-left '#[fg=green](#S) #(whoami)' 22 | set -g status-right '#[fg=yellow]#(cut -d " " -f 1-3 /proc/loadavg)#[default] #[fg=white]%H:%M#[default]' -------------------------------------------------------------------------------- /care_nl_ica/losses/dep_mat.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from care_nl_ica.metrics.ica_dis import frobenius_diagonality 4 | 5 | 6 | def permutation_loss(matrix: torch.Tensor, matrix_power: bool = False): 7 | if matrix_power is False: 8 | # rows and cols sum up to 1 9 | col_sum = matrix.abs().sum(0) 10 | row_sum = matrix.abs().sum(1) 11 | loss = (col_sum - torch.ones_like(col_sum)).pow(2).mean() + ( 12 | row_sum - torch.ones_like(row_sum) 13 | ).pow(2).mean() 14 | else: 15 | # diagonality (as Q^n = I for permutation matrices) 16 | loss = 
frobenius_diagonality(matrix.matrix_power(matrix.shape[0]).abs()) 17 | 18 | return loss 19 | -------------------------------------------------------------------------------- /sweeps/sem/mlp_sem3.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - ${env} 3 | - python3 4 | - care_nl_ica/cli.py 5 | - fit 6 | - "--config" 7 | - "configs/config.yaml" 8 | - ${args} 9 | method: grid 10 | metric: 11 | goal: minimize 12 | name: val_loss 13 | parameters: 14 | trainer.max_epochs: 15 | value: 15000 16 | seed_everything: 17 | values: [42, 64, 982, 5748, 23567] 18 | model.offline: 19 | value: true 20 | data.variant: 21 | values: [0,1,2,3,4,5] 22 | data.latent_dim: 23 | value: 3 24 | data.use_sem: 25 | value: true 26 | data.nonlin_sem: 27 | values: [false, true] 28 | data.permute: 29 | value: true 30 | data.offset: 31 | value: 0 32 | 33 | 34 | 35 | 36 | -------------------------------------------------------------------------------- /sweeps/sem/mlp_sem5.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - ${env} 3 | - python3 4 | - care_nl_ica/cli.py 5 | - fit 6 | - "--config" 7 | - "configs/config.yaml" 8 | - ${args} 9 | method: grid 10 | metric: 11 | goal: minimize 12 | name: val_loss 13 | parameters: 14 | trainer.max_epochs: 15 | value: 7500 16 | seed_everything: 17 | values: [42, 64, 982, 5748, 23567] 18 | model.offline: 19 | value: true 20 | data.variant: 21 | values: [0,5,31,75, 97, 119] 22 | data.latent_dim: 23 | value: 5 24 | data.use_sem: 25 | value: true 26 | data.nonlin_sem: 27 | values: [false, true] 28 | data.permute: 29 | value: true 30 | data.offset: 31 | value: 0 32 | 33 | 34 | 35 | 36 | -------------------------------------------------------------------------------- /sweeps/sem/mlp_sem3_no_permute.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - ${env} 3 | - python3 4 | - care_nl_ica/cli.py 5 | - fit 6 
| - "--config" 7 | - "configs/config.yaml" 8 | - ${args} 9 | method: grid 10 | metric: 11 | goal: minimize 12 | name: val_loss 13 | parameters: 14 | trainer.max_epochs: 15 | value: 15000 16 | seed_everything: 17 | values: [42, 64, 982, 5748, 23567, 8734, 564, 74452, 96, 26] 18 | model.offline: 19 | value: true 20 | data.variant: 21 | value: 0 22 | data.latent_dim: 23 | value: 3 24 | data.use_sem: 25 | value: true 26 | data.nonlin_sem: 27 | values: [false, true] 28 | data.permute: 29 | value: false 30 | data.offset: 31 | value: 1 32 | 33 | 34 | 35 | 36 | -------------------------------------------------------------------------------- /sweeps/sem/mlp_sem5_no_permute.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - ${env} 3 | - python3 4 | - care_nl_ica/cli.py 5 | - fit 6 | - "--config" 7 | - "configs/config.yaml" 8 | - ${args} 9 | method: grid 10 | metric: 11 | goal: minimize 12 | name: val_loss 13 | parameters: 14 | trainer.max_epochs: 15 | value: 7500 16 | seed_everything: 17 | values: [42, 64, 982, 5748, 23567, 8734, 564, 74452, 96, 26] 18 | model.offline: 19 | value: true 20 | data.variant: 21 | value: 0 22 | data.latent_dim: 23 | value: 5 24 | data.use_sem: 25 | value: true 26 | data.nonlin_sem: 27 | values: [false, true] 28 | data.permute: 29 | value: false 30 | data.offset: 31 | value: 0 32 | 33 | 34 | 35 | 36 | -------------------------------------------------------------------------------- /sweeps/sem/mlp_sem8.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - ${env} 3 | - python3 4 | - care_nl_ica/cli.py 5 | - fit 6 | - "--config" 7 | - "configs/config.yaml" 8 | - ${args} 9 | method: grid 10 | metric: 11 | goal: minimize 12 | name: val_loss 13 | parameters: 14 | trainer.max_epochs: 15 | value: 15000 16 | seed_everything: 17 | values: [42, 64, 982, 5748, 23567] 18 | model.offline: 19 | value: true 20 | data.variant: 21 | values: [0,575, 8754, 
12450, 21345, 36755] 22 | data.latent_dim: 23 | value: 8 24 | data.use_sem: 25 | value: true 26 | data.nonlin_sem: 27 | values: [false, true] 28 | data.permute: 29 | value: true 30 | data.offset: 31 | value: 0 32 | 33 | 34 | 35 | 36 | -------------------------------------------------------------------------------- /sweeps/sem/mlp_sem8_no_permute.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - ${env} 3 | - python3 4 | - care_nl_ica/cli.py 5 | - fit 6 | - "--config" 7 | - "configs/config.yaml" 8 | - ${args} 9 | method: grid 10 | metric: 11 | goal: minimize 12 | name: val_loss 13 | parameters: 14 | trainer.max_epochs: 15 | value: 15000 16 | seed_everything: 17 | values: [42, 64, 982, 5748, 23567, 8734, 564, 74452, 96, 26] 18 | model.offline: 19 | value: true 20 | data.variant: 21 | value: 0 22 | data.latent_dim: 23 | value: 8 24 | data.use_sem: 25 | value: true 26 | data.nonlin_sem: 27 | values: [false, true] 28 | data.permute: 29 | value: false 30 | data.offset: 31 | value: 1 32 | 33 | 34 | 35 | 36 | -------------------------------------------------------------------------------- /sweeps/sem/mlp_sem10_no_permute.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - ${env} 3 | - python3 4 | - care_nl_ica/cli.py 5 | - fit 6 | - "--config" 7 | - "configs/config.yaml" 8 | - ${args} 9 | method: grid 10 | metric: 11 | goal: minimize 12 | name: val_loss 13 | parameters: 14 | trainer.max_epochs: 15 | value: 25000 16 | seed_everything: 17 | values: [42, 64, 982, 5748, 23567, 8734, 564, 74452, 96, 26] 18 | model.offline: 19 | value: true 20 | data.variant: 21 | value: 0 22 | data.latent_dim: 23 | value: 10 24 | data.use_sem: 25 | value: true 26 | data.nonlin_sem: 27 | values: [false, true] 28 | data.permute: 29 | value: false 30 | data.offset: 31 | value: 1 32 | 33 | 34 | 35 | 36 | -------------------------------------------------------------------------------- 
/sweeps/sem/mlp_sem15_no_permute.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - ${env} 3 | - python3 4 | - care_nl_ica/cli.py 5 | - fit 6 | - "--config" 7 | - "configs/config.yaml" 8 | - ${args} 9 | method: grid 10 | metric: 11 | goal: minimize 12 | name: val_loss 13 | parameters: 14 | trainer.max_epochs: 15 | value: 35000 16 | seed_everything: 17 | values: [42, 64, 982, 5748, 23567, 8734, 564, 74452, 96, 26] 18 | model.offline: 19 | value: true 20 | data.variant: 21 | value: 0 22 | data.latent_dim: 23 | value: 15 24 | data.use_sem: 25 | value: true 26 | data.nonlin_sem: 27 | values: [false, true] 28 | data.permute: 29 | value: false 30 | data.offset: 31 | value: 1 32 | 33 | 34 | 35 | 36 | -------------------------------------------------------------------------------- /sweeps/sem/mlp_sem10.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - ${env} 3 | - python3 4 | - care_nl_ica/cli.py 5 | - fit 6 | - "--config" 7 | - "configs/config.yaml" 8 | - ${args} 9 | method: grid 10 | metric: 11 | goal: minimize 12 | name: val_loss 13 | parameters: 14 | trainer.max_epochs: 15 | value: 25000 16 | seed_everything: 17 | values: [42, 64, 982, 5748, 23567] 18 | model.offline: 19 | value: true 20 | data.variant: 21 | values: [0,657894, 1257846, 1875424, 2645120, 3628600] 22 | data.latent_dim: 23 | value: 10 24 | data.use_sem: 25 | value: true 26 | data.nonlin_sem: 27 | values: [false, true] 28 | data.permute: 29 | value: true 30 | data.offset: 31 | value: 0 32 | 33 | 34 | 35 | 36 | -------------------------------------------------------------------------------- /care_nl_ica/losses/utils.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from typing import Dict 4 | 5 | 6 | @dataclass 7 | class ContrastiveLosses: 8 | cl_pos: float = 0.0 9 | cl_neg: float = 0.0 10 | cl_entropy: float = 0.0 11 
| 12 | @property 13 | def total_loss(self): 14 | total_loss = 0.0 15 | for key, loss in self.__dict__.items(): 16 | if key != "cl_entropy": 17 | total_loss += loss 18 | 19 | return total_loss 20 | 21 | def log_dict(self) -> Dict[str, float]: 22 | return { 23 | "cl_pos": self.cl_pos, 24 | "cl_neg": self.cl_neg, 25 | "cl_entropy": self.cl_entropy, 26 | "total": self.total_loss, 27 | } 28 | -------------------------------------------------------------------------------- /tests/test_dep_mat.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | 3 | import torch 4 | 5 | from care_nl_ica.dep_mat import calc_jacobian 6 | from care_nl_ica.models.model import ContrastiveLearningModel 7 | 8 | 9 | def test_calc_jacobian(args): 10 | m = ContrastiveLearningModel( 11 | Namespace( 12 | **{ 13 | **args.model, 14 | **args.data, 15 | "device": "cuda" if torch.cuda.is_available() else "cpu", 16 | } 17 | ) 18 | ) 19 | x = torch.randn(64, args.data.latent_dim) 20 | 21 | assert ( 22 | torch.allclose( 23 | calc_jacobian(m, x, vectorize=True), calc_jacobian(m, x, vectorize=False) 24 | ) 25 | == True 26 | ) 27 | -------------------------------------------------------------------------------- /sweeps/tcl/mlp_ar.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - ${env} 3 | - python3 4 | - care_nl_ica/cli.py 5 | - fit 6 | - "--config" 7 | - "configs/config.yaml" 8 | - ${args} 9 | method: grid 10 | metric: 11 | goal: minimize 12 | name: val_loss 13 | parameters: 14 | seed_everything: 15 | values: [42, 64, 982, 5748, 23567] 16 | trainer.max_epochs: 17 | value: 20000 18 | model.offline: 19 | value: true 20 | data.variant: 21 | value: 3 22 | data.latent_dim: 23 | value: 6 24 | data.use_sem: 25 | value: false 26 | data.n_mixing_layer: 27 | values: [1,2,3,4,5] 28 | data.data_gen_mode: 29 | value: offset 30 | data.mlp_sparsity: 31 | value: true 32 | data.offset: 33 | value: 3 
34 | 35 | 36 | 37 | 38 | 39 | -------------------------------------------------------------------------------- /sweeps/sem/mlp_sem5_sparse.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - ${env} 3 | - python3 4 | - care_nl_ica/cli.py 5 | - fit 6 | - "--config" 7 | - "configs/config.yaml" 8 | - ${args} 9 | method: grid 10 | metric: 11 | goal: minimize 12 | name: val_loss 13 | parameters: 14 | trainer.max_epochs: 15 | value: 7500 16 | seed_everything: 17 | values: [42, 64, 982, 5748, 23567] 18 | model.offline: 19 | value: true 20 | data.variant: 21 | values: [0,5,31,75, 97, 119] 22 | data.latent_dim: 23 | value: 5 24 | data.use_sem: 25 | value: true 26 | data.nonlin_sem: 27 | values: [false, true] 28 | data.permute: 29 | value: true 30 | data.offset: 31 | value: 1 32 | data.mask_prob: 33 | value: 0.25 34 | 35 | 36 | 37 | 38 | -------------------------------------------------------------------------------- /sweeps/tcl/mlp_ar_adaptive_offset.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - ${env} 3 | - python3 4 | - care_nl_ica/cli.py 5 | - fit 6 | - "--config" 7 | - "configs/config.yaml" 8 | - ${args} 9 | method: grid 10 | metric: 11 | goal: minimize 12 | name: val_loss 13 | parameters: 14 | seed_everything: 15 | values: [42, 64, 982, 5748, 23567] 16 | trainer.max_epochs: 17 | value: 20000 18 | model.offline: 19 | value: true 20 | data.variant: 21 | value: 3 22 | data.latent_dim: 23 | value: 6 24 | data.use_sem: 25 | value: false 26 | data.n_mixing_layer: 27 | values: [1,2,3,4,5] 28 | data.data_gen_mode: 29 | value: offset 30 | data.mlp_sparsity: 31 | value: true 32 | data.offset: 33 | value: 0 34 | 35 | 36 | 37 | 38 | 39 | -------------------------------------------------------------------------------- /sweeps/sem/mlp_sem3_no_permute_sparse.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - ${env} 3 | 
- python3 4 | - care_nl_ica/cli.py 5 | - fit 6 | - "--config" 7 | - "configs/config.yaml" 8 | - ${args} 9 | method: grid 10 | metric: 11 | goal: minimize 12 | name: val_loss 13 | parameters: 14 | trainer.max_epochs: 15 | value: 15000 16 | seed_everything: 17 | values: [42, 64, 982, 5748, 23567, 8734, 564, 74452, 96, 26] 18 | model.offline: 19 | value: true 20 | data.variant: 21 | value: 0 22 | data.latent_dim: 23 | value: 3 24 | data.use_sem: 25 | value: true 26 | data.nonlin_sem: 27 | values: [false, true] 28 | data.permute: 29 | value: false 30 | data.offset: 31 | value: 1 32 | data.mask_prob: 33 | value: 0.25 34 | 35 | 36 | 37 | -------------------------------------------------------------------------------- /sweeps/sem/mlp_sem5_no_permute_sparse.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - ${env} 3 | - python3 4 | - care_nl_ica/cli.py 5 | - fit 6 | - "--config" 7 | - "configs/config.yaml" 8 | - ${args} 9 | method: grid 10 | metric: 11 | goal: minimize 12 | name: val_loss 13 | parameters: 14 | trainer.max_epochs: 15 | value: 7500 16 | seed_everything: 17 | values: [42, 64, 982, 5748, 23567, 8734, 564, 74452, 96, 26] 18 | model.offline: 19 | value: true 20 | data.variant: 21 | value: 0 22 | data.latent_dim: 23 | value: 5 24 | data.use_sem: 25 | value: true 26 | data.nonlin_sem: 27 | values: [false, true] 28 | data.permute: 29 | value: false 30 | data.offset: 31 | value: 0 32 | data.mask_prob: 33 | value: 0.25 34 | 35 | 36 | 37 | 38 | -------------------------------------------------------------------------------- /sweeps/sem/mlp_sem8_no_permute_sparse.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - ${env} 3 | - python3 4 | - care_nl_ica/cli.py 5 | - fit 6 | - "--config" 7 | - "configs/config.yaml" 8 | - ${args} 9 | method: grid 10 | metric: 11 | goal: minimize 12 | name: val_loss 13 | parameters: 14 | trainer.max_epochs: 15 | value: 7500 16 | 
seed_everything: 17 | values: [42, 64, 982, 5748, 23567, 8734, 564, 74452, 96, 26] 18 | model.offline: 19 | value: true 20 | data.variant: 21 | value: 0 22 | data.latent_dim: 23 | value: 8 24 | data.use_sem: 25 | value: true 26 | data.nonlin_sem: 27 | values: [false, true] 28 | data.permute: 29 | value: false 30 | data.offset: 31 | value: 0 32 | data.mask_prob: 33 | value: 0.25 34 | 35 | 36 | 37 | 38 | -------------------------------------------------------------------------------- /sweeps/sem/mlp_sem8_sparse.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - ${env} 3 | - python3 4 | - care_nl_ica/cli.py 5 | - fit 6 | - "--config" 7 | - "configs/config.yaml" 8 | - ${args} 9 | method: grid 10 | metric: 11 | goal: minimize 12 | name: val_loss 13 | parameters: 14 | trainer.max_epochs: 15 | value: 15000 16 | seed_everything: 17 | values: [42, 64, 982, 5748, 23567] 18 | model.offline: 19 | value: true 20 | data.variant: 21 | values: [0,575, 8754, 12450, 21345, 36755] 22 | data.latent_dim: 23 | value: 8 24 | data.use_sem: 25 | value: true 26 | data.nonlin_sem: 27 | values: [false, true] 28 | data.permute: 29 | value: true 30 | data.offset: 31 | value: 1 32 | data.mask_prob: 33 | value: 0.5 34 | 35 | 36 | 37 | 38 | -------------------------------------------------------------------------------- /sweeps/sem/mlp_sem10_no_permute_sparse.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - ${env} 3 | - python3 4 | - care_nl_ica/cli.py 5 | - fit 6 | - "--config" 7 | - "configs/config.yaml" 8 | - ${args} 9 | method: grid 10 | metric: 11 | goal: minimize 12 | name: val_loss 13 | parameters: 14 | trainer.max_epochs: 15 | value: 7500 16 | seed_everything: 17 | values: [42, 64, 982, 5748, 23567, 8734, 564, 74452, 96, 26] 18 | model.offline: 19 | value: true 20 | data.variant: 21 | value: 0 22 | data.latent_dim: 23 | value: 10 24 | data.use_sem: 25 | value: true 26 | 
data.nonlin_sem: 27 | values: [false, true] 28 | data.permute: 29 | value: false 30 | data.offset: 31 | value: 0 32 | data.mask_prob: 33 | value: 0.25 34 | 35 | 36 | 37 | 38 | -------------------------------------------------------------------------------- /sweeps/sem/mlp_sem10_sparse.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - ${env} 3 | - python3 4 | - care_nl_ica/cli.py 5 | - fit 6 | - "--config" 7 | - "configs/config.yaml" 8 | - ${args} 9 | method: grid 10 | metric: 11 | goal: minimize 12 | name: val_loss 13 | parameters: 14 | trainer.max_epochs: 15 | value: 25000 16 | seed_everything: 17 | values: [42, 64, 982, 5748, 23567] 18 | model.offline: 19 | value: true 20 | data.variant: 21 | values: [0,657894, 1257846, 1875424, 2645120, 3628600] 22 | data.latent_dim: 23 | value: 10 24 | data.use_sem: 25 | value: true 26 | data.nonlin_sem: 27 | values: [false, true] 28 | data.permute: 29 | value: true 30 | data.offset: 31 | value: 1 32 | data.mask_prob: 33 | value: 0.5 34 | 35 | 36 | 37 | 38 | -------------------------------------------------------------------------------- /sweeps/tcl/mlp_ar_adaptive_offset_sparse.yaml: -------------------------------------------------------------------------------- 1 | command: 2 | - ${env} 3 | - python3 4 | - care_nl_ica/cli.py 5 | - fit 6 | - "--config" 7 | - "configs/config.yaml" 8 | - ${args} 9 | method: grid 10 | metric: 11 | goal: minimize 12 | name: val_loss 13 | parameters: 14 | seed_everything: 15 | values: [42, 64, 982, 5748, 23567] 16 | trainer.max_epochs: 17 | value: 20000 18 | model.offline: 19 | value: true 20 | data.variant: 21 | value: 3 22 | data.latent_dim: 23 | value: 6 24 | data.use_sem: 25 | value: false 26 | data.n_mixing_layer: 27 | values: [1,2,3,4,5] 28 | data.data_gen_mode: 29 | value: offset 30 | data.mlp_sparsity: 31 | value: true 32 | data.offset: 33 | value: 0 34 | data.mask_prob: 35 | value: 0.25 36 | 37 | 38 | 39 | 40 | 41 | 
-------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 3 | authors: 4 | - family-names: "Reizinger" 5 | given-names: "Patrik" 6 | orcid: "https://orcid.org/0000-0001-9861-0293" 7 | - family-names: "Sharma" 8 | given-names: "Yash" 9 | - family-names: "Bethge" 10 | given-names: "Matthias" 11 | - family-names: "Schölkopf" 12 | given-names: "Bernhard" 13 | orcid: "https://orcid.org/0000-0002-8177-0925" 14 | - family-names: "Huszár" 15 | given-names: "Ferenc" 16 | orcid: "https://orcid.org/0000-0002-4988-1430" 17 | - family-names: "Brendel" 18 | given-names: "Wieland" 19 | orcid: "https://orcid.org/0000-0003-0982-552X" 20 | title: "nl-causal-representations" 21 | version: 1.0.0 22 | doi: 10.5281/zenodo.7002143 23 | date-released: 2022-08-17 24 | url: "https://github.com/rpatrik96/nl-causal-representations" 25 | -------------------------------------------------------------------------------- /scripts/loss_compare_for_permutations.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | N_STEPS=25001 4 | START_STEP=5000 5 | 6 | QR_LOSS=3e-1 7 | TRIANGULARITY_LOSS=0 8 | ENTROPY_COEFF=0 9 | LOSS_FLAGS="--qr-loss ${QR_LOSS} --triangularity-loss ${TRIANGULARITY_LOSS} --entropy-coeff ${ENTROPY_COEFF}" 10 | 11 | DATA_FLAGS="--use-sem --permute" 12 | LOG_FLAGS="--verbose --normalize-latents" 13 | MODEL_FLAGS="--use-ar-mlp" 14 | 15 | for seed in 84645646; do 16 | for n in 3; do 17 | fact=1 18 | for ((i = 1; i <= ${n}; i++)); do 19 | fact=$(($fact * $i)) 20 | done 21 | 22 | for ((variant = 0; variant < ${fact}; variant++)); do 23 | ./scripts/start_preemptable_job.sh --seed ${seed} --n ${n} --variant "${variant}" ${LOSS_FLAGS} --n-steps ${N_STEPS} --start-step ${START_STEP} ${DATA_FLAGS} ${LOG_FLAGS} "$@" & 24 | sleep 5 25 
| done 26 | done 27 | done 28 | -------------------------------------------------------------------------------- /configs/cpu.yaml: -------------------------------------------------------------------------------- 1 | trainer: 2 | auto_select_gpus: false 3 | overfit_batches: 0.0 4 | check_val_every_n_epoch: 200 5 | fast_dev_run: false 6 | max_epochs: 5500 7 | min_epochs: null 8 | max_steps: -1 9 | min_steps: null 10 | max_time: null 11 | gpus: 0 12 | limit_train_batches: 1.0 13 | limit_val_batches: 1.0 14 | limit_test_batches: 1.0 15 | limit_predict_batches: 1.0 16 | val_check_interval: 1.0 17 | log_every_n_steps: 250 18 | enable_model_summary: true 19 | weights_summary: top 20 | num_sanity_val_steps: 2 21 | profiler: null 22 | benchmark: false 23 | deterministic: true 24 | detect_anomaly: true 25 | plugins: null 26 | data: 27 | batch_size: 512 28 | latent_dim: 3 29 | diag_weight: 0.4 30 | permute: true 31 | variant: 5 32 | force_uniform: false 33 | force_chain: true 34 | model: 35 | start_step: 4500 36 | lr: 1e-3 -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/datasets/simple_image_dataset.py: -------------------------------------------------------------------------------- 1 | """Loads all images in a folder while ignoring class information etc.""" 2 | 3 | import torch.utils.data 4 | from typing import Optional, Callable 5 | import glob 6 | import os 7 | import torchvision 8 | 9 | 10 | class SimpleImageDataset(torch.utils.data.Dataset): 11 | def __init__(self, root: str, transform: Optional[Callable] = None): 12 | self.root = root 13 | self.image_paths = list(sorted(list(glob.glob(os.path.join(root, "*.*"))))) 14 | 15 | if transform is None: 16 | transform = lambda x: x 17 | self.transform = transform 18 | 19 | self.loader = torchvision.datasets.folder.pil_loader 20 | 21 | def __len__(self): 22 | return len(self.image_paths) 23 | 24 | def __getitem__(self, item): 25 | assert 0 <= item < len(self) 26 | return 
self.transform(self.loader(self.image_paths[item])) 27 | -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/docker/Dockerfile.blender: -------------------------------------------------------------------------------- 1 | # To build this Dockerfile you have to let it inherit from the other one 2 | 3 | # Set the time zone correctly 4 | ENV TZ=Europe/Berlin 5 | RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone 6 | 7 | ENV SHELL /bin/bash 8 | 9 | RUN wget https://download.blender.org/release/Blender2.91/blender-2.91.0-linux64.tar.xz 10 | RUN tar -xf blender-2.91.0-linux64.tar.xz -C /usr/local/bin/ 11 | ENV PATH "$PATH:/usr/local/bin/blender-2.91.0-linux64" 12 | 13 | RUN apt-get update && apt-get install -y --no-install-recommends \ 14 | libxxf86vm1 libxfixes-dev libglu1-mesa-dev freeglut3-dev mesa-common-dev \ 15 | && apt-get clean \ 16 | && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* 17 | 18 | COPY entrypoint.sh /usr/local/bin/ 19 | RUN chmod a+x /usr/local/bin/entrypoint.sh 20 | 21 | ENTRYPOINT ["/usr/local/bin/entrypoint.sh"] 22 | 23 | CMD ["/bin/bash"] -------------------------------------------------------------------------------- /notebooks/sem_5d_munkres_sweep_kdau2seo.csv: -------------------------------------------------------------------------------- 1 | ,name,dim,permute,variant,n_mixing_layer,use_sem,nonlin_sem,force_chain,force_uniform,mcc,val_loss 2 | 0,swept-sweep-13,5,False,0,1,True,True,False,False,0.9994050747388996,1.9483914375305176 3 | 1,skilled-sweep-14,5,False,0,1,True,True,False,False,0.9993582123422764,1.893308401107788 4 | 2,royal-sweep-9,5,False,0,1,True,False,False,False,0.9998361494785972,1.94585919380188 5 | 3,valiant-sweep-10,5,False,0,1,True,False,False,False,0.816690862169852,3.03879714012146 6 | 4,breezy-sweep-6,5,False,0,1,True,False,False,False,0.9997581404610733,1.9464752674102783 7 | 
5,winter-sweep-7,5,False,0,1,True,False,False,False,0.8795997458112131,2.9836673736572266 8 | 6,glad-sweep-5,5,False,0,1,True,False,False,False,0.9995376654534928,1.968914031982422 9 | 7,honest-sweep-1,5,False,0,1,True,False,False,False,0.9998311645399612,1.9033708572387695 10 | 8,wandering-sweep-3,5,False,0,1,True,False,False,False,0.9999524627177776,1.9334661960601809 11 | -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/get_mean_std.py: -------------------------------------------------------------------------------- 1 | """Calculate mean and std of dataset.""" 2 | 3 | from datasets.simple_image_dataset import SimpleImageDataset 4 | import torch.utils.data 5 | import argparse 6 | import torchvision.transforms 7 | import os 8 | from tqdm import tqdm 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("--root-folder", required=True) 12 | args = parser.parse_args() 13 | 14 | dataset = SimpleImageDataset( 15 | args.root_folder, transform=torchvision.transforms.ToTensor() 16 | ) 17 | 18 | full_loader = torch.utils.data.DataLoader( 19 | dataset, shuffle=True, num_workers=os.cpu_count(), batch_size=256 20 | ) 21 | 22 | mean = torch.zeros(3) 23 | std = torch.zeros(3) 24 | print("==> Computing mean and std..") 25 | for inputs in tqdm(full_loader): 26 | for i in range(3): 27 | mean[i] += inputs[:, i, :, :].mean(dim=(-1, -2)).sum(0) 28 | std[i] += inputs[:, i, :, :].std(dim=(-1, -2)).sum(0) 29 | mean.div_(len(dataset)) 30 | std.div_(len(dataset)) 31 | print(mean, std) 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Patrik Reizinger 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without 
restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/kitti_masks/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Bethge Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Roland S. Zimmermann et al. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /care_nl_ica/models/sparsity.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class SparseBudgetNet(nn.Module): 6 | def __init__( 7 | self, 8 | num_dim: int, 9 | ): 10 | super(SparseBudgetNet, self).__init__() 11 | self.num_dim = num_dim 12 | self.budget: int = self.num_dim * (self.num_dim + 1) // 2 13 | self.weight = nn.Parameter( 14 | nn.Linear(self.num_dim, self.num_dim).weight, requires_grad=True 15 | ) 16 | 17 | def to(self, device): 18 | """ 19 | Move the model to the specified device. 20 | 21 | :param device: The device to move the model to. 22 | """ 23 | super().to(device) 24 | self.weight = self.weight.to(device) 25 | 26 | return self 27 | 28 | @property 29 | def mask(self): 30 | return torch.sigmoid(self.weight) 31 | 32 | @property 33 | def entropy(self): 34 | probs = torch.nn.functional.softmax(self.mask, -1).view( 35 | -1, 36 | ) 37 | 38 | return torch.distributions.Categorical(probs).entropy() 39 | 40 | @property 41 | def budget_loss(self): 42 | return torch.relu(self.budget - self.mask.sum()) 43 | -------------------------------------------------------------------------------- /scripts/run_singularity_server.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | IMAGE=./scripts/nv.sif 4 | 5 | userName=preizinger 6 | tmp_dir=$(mktemp -d -t singularity-home-XXXXXXXX) 7 | echo "$tmp_dir" 8 | 9 | # singularity doesn't like the image to be modified while it's running. 
10 | # also, it's faster to have it on local storage 11 | echo "copy image" 12 | LOCAL_IMAGE=$tmp_dir/nv.sif 13 | rsync -av --progress $IMAGE "$LOCAL_IMAGE" 14 | 15 | 16 | mkdir -p "$tmp_dir"/.vscode-server 17 | mkdir -p "$tmp_dir"/.conda 18 | mkdir -p "$tmp_dir"/.ipython 19 | mkdir -p "$tmp_dir"/.jupyter 20 | mkdir -p "$tmp_dir"/.local 21 | mkdir -p "$tmp_dir"/.pylint.d 22 | mkdir -p "$tmp_dir"/.cache 23 | 24 | singularity exec -p --nv \ 25 | --bind "$tmp_dir"/.vscode-server:/mnt/qb/work/bethge/$userName/.vscode-server \ 26 | --bind "$tmp_dir"/.conda:/mnt/qb/work/bethge/$userName/.conda \ 27 | --bind "$tmp_dir"/.ipython:/mnt/qb/work/bethge/$userName/.ipython \ 28 | --bind "$tmp_dir"/.jupyter:/mnt/qb/work/bethge/$userName/.jupyter \ 29 | --bind "$tmp_dir"/.local:/mnt/qb/work/bethge/$userName/.local \ 30 | --bind "$tmp_dir"/.pylint.d:/mnt/qb/work/bethge/$userName/.pylint.d \ 31 | --bind "$tmp_dir"/.cache:/mnt/qb/work/bethge/$userName/.cache \ 32 | --bind /scratch_local \ 33 | --bind /home/bethge/$userName \ 34 | --bind /mnt/qb/work/bethge \ 35 | "$LOCAL_IMAGE" "$@" 36 | 37 | rm -rf "$tmp_dir" 38 | -------------------------------------------------------------------------------- /scripts/nv.def: -------------------------------------------------------------------------------- 1 | Bootstrap: docker 2 | From: nvidia/cuda:11.3.0-cudnn8-runtime-ubuntu20.04 3 | 4 | 5 | %environment 6 | # overwrite Singularity prompt with something more useful 7 | export PS1="\u@\h \W§ " 8 | 9 | %files 10 | # copy the dropbear keys into the container 11 | /mnt/qb/work/bethge/preizinger/dropbear /etc/dropbear 12 | 13 | # project requirement file 14 | ../requirements.txt 15 | 16 | %post 17 | # required to indicate that no interaction is required for installing packages with apt 18 | # otherwise, installing packages can fail 19 | export DEBIAN_FRONTEND=noninteractive 20 | 21 | # if the cuda container cannot be verified as nvidia did not update the key (GPG error) 22 | # 
https://github.com/NVIDIA/nvidia-docker/issues/619#issuecomment-359580806 23 | rm /etc/apt/sources.list.d/cuda.list 24 | 25 | # sometimes fetching the key explicitly can help 26 | # https://github.com/NVIDIA/nvidia-docker/issues/619#issuecomment-359695597 27 | apt-key adv --fetch-keys http://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/7fa2af80.pub 28 | 29 | apt-get -qq -y update 30 | 31 | # apt-get install -y build-essential cmake git rsync 32 | 33 | apt-get install -y python3 python3-pip python3-wheel python3-yaml intel-mkl-full 34 | 35 | # ssh client for the vs code setup 36 | apt-get install -y dropbear 37 | chmod +r -R /etc/dropbear 38 | 39 | apt-get clean 40 | 41 | # project requirements install 42 | pip3 install -r requirements.txt 43 | 44 | 45 | 46 | 47 | 48 | 49 | -------------------------------------------------------------------------------- /notebooks/sem_3d_sparse_sweep_lm8s890w.csv: -------------------------------------------------------------------------------- 1 | ,name,dim,permute,variant,n_mixing_layer,use_sem,nonlin_sem,force_chain,force_uniform,mcc,val_loss 2 | 0,classic-sweep-20,3,False,0,1,True,True,False,False,0.9999596225772244,4.484454154968262 3 | 1,northern-sweep-19,3,False,0,1,True,True,False,False,0.9999133110494473,4.493100643157959 4 | 2,lucky-sweep-18,3,False,0,1,True,True,False,False,0.999934173215041,4.493666648864746 5 | 3,sunny-sweep-16,3,False,0,1,True,True,False,False,0.9999475003711812,4.4822492599487305 6 | 4,ethereal-sweep-15,3,False,0,1,True,True,False,False,0.9999135937156256,4.465629577636719 7 | 5,splendid-sweep-14,3,False,0,1,True,True,False,False,0.9999369829837246,4.470257759094238 8 | 6,cerulean-sweep-13,3,False,0,1,True,True,False,False,0.9998532325373448,4.4454545974731445 9 | 7,restful-sweep-11,3,False,0,1,True,True,False,False,0.999959709496154,4.461050510406494 10 | 8,bumbling-sweep-10,3,False,0,1,True,False,False,False,0.9999579945185338,4.493952751159668 11 | 
9,morning-sweep-9,3,False,0,1,True,False,False,False,0.9999619495424686,4.491983890533447 12 | 10,classic-sweep-8,3,False,0,1,True,False,False,False,0.9999773318709866,4.492518424987793 13 | 11,vital-sweep-7,3,False,0,1,True,False,False,False,0.9999648924388136,4.451825141906738 14 | 12,winter-sweep-6,3,False,0,1,True,False,False,False,0.9999844737286218,4.4821624755859375 15 | 13,golden-sweep-5,3,False,0,1,True,False,False,False,0.9999239649393522,4.46468448638916 16 | 14,apricot-sweep-4,3,False,0,1,True,False,False,False,0.9999426637679648,4.468593120574951 17 | 15,stilted-sweep-1,3,False,0,1,True,False,False,False,0.9999451954775463,4.459792137145996 18 | -------------------------------------------------------------------------------- /care_nl_ica/logger.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import wandb 3 | 4 | from .utils import matrix_to_dict 5 | 6 | 7 | class Logger(object): 8 | def __init__(self, hparams, model) -> None: 9 | super().__init__() 10 | self.hparams = hparams 11 | 12 | self._setup_exp_management(model) 13 | 14 | self.total_loss_values = None 15 | 16 | def _setup_exp_management(self, model): 17 | if self.hparams.use_wandb is True: 18 | wandb.init( 19 | entity="causal-representation-learning", 20 | project=self.hparams.project, 21 | notes=self.hparams.notes, 22 | config=self.hparams, 23 | tags=self.hparams.tags, 24 | ) 25 | wandb.watch(model, log_freq=self.hparams.n_log_steps, log="all") 26 | 27 | # define metrics 28 | wandb.define_metric("total_loss", summary="min") 29 | wandb.define_metric("lin_dis_score", summary="max") 30 | wandb.define_metric("perm_dis_score", summary="max") 31 | 32 | def log_jacobian( 33 | self, dep_mat, name="gt_decoder", inv_name="gt_encoder", log_inverse=True 34 | ): 35 | jac = dep_mat.detach().cpu() 36 | cols = [f"a_{i}" for i in range(dep_mat.shape[1])] 37 | 38 | gt_jacobian_dec = wandb.Table(columns=cols, data=jac.tolist()) 39 | 
self.log_summary(**{f"{name}_jacobian": gt_jacobian_dec}) 40 | 41 | if log_inverse is True: 42 | gt_jacobian_enc = wandb.Table(columns=cols, data=jac.inverse().tolist()) 43 | self.log_summary(**{f"{inv_name}_jacobian": gt_jacobian_enc}) 44 | -------------------------------------------------------------------------------- /scripts/interactive_job_inner.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -xe 3 | 4 | HOSTNAME=$(hostname | sed s/.novalocal//) 5 | echo "Running on $HOSTNAME" 6 | 7 | # update ssh target "interactive" 8 | # this requires you to add an entry 9 | # Include config.d/interactive 10 | # to your ~/.ssh/config 11 | mkdir -p ~/.ssh/config.d 12 | echo "Host interactive" > ~/.ssh/config.d/interactive 13 | echo " Hostname $HOSTNAME" >> ~/.ssh/config.d/interactive 14 | echo " User preizinger" >> ~/.ssh/config.d/interactive 15 | 16 | SESSION=preizinger-interactive 17 | 18 | # reset $TMPDIR. In SLURM-jobs by default it would be 19 | # TMPDIR="/scratch_local/$(whoami)-${SLURM_JOB_ID}/tmp" 20 | # but this setting is lost when sshing into the node. 21 | # Therefore, we unset $TMPDIR already here. 22 | unset TMPDIR 23 | #make sure we can run this script from within tmux 24 | TMUX=0 25 | tmux new-session -d -s $SESSION 26 | tmux send-keys 'echo session started' C-m 27 | # Build correct address parameters. 28 | # Unfortunately, dropbear seems to listen only on IPv6 29 | # since recently. I haven't found the reason yet, but it's 30 | # easiest to just specify the correct addresses. 
31 | ADDRESSES='' 32 | for IP in $(hostname -I) 33 | do 34 | ADDRESSES="$ADDRESSES -p $IP:12345" 35 | done 36 | # -p is for making sure sshd is killed if singularity is stopped 37 | # - R is for generating key 38 | tmux send-keys "./scripts/run_singularity_server.sh /usr/sbin/dropbear -R -E -F $ADDRESSES -s" C-m 39 | 40 | while true; do 41 | # will fail if session ended 42 | tmux has-session -t $SESSION 43 | echo "session still running" 44 | sleep 10 45 | done -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/ORIGINAL_LICENSE: -------------------------------------------------------------------------------- 1 | BSD License 2 | 3 | For clevr-dataset-gen software 4 | 5 | Copyright (c) 2017-present, Facebook, Inc. All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without modification, 8 | are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | * Neither the name Facebook nor the names of its contributors may be used to 18 | endorse or promote products derived from this software without specific 19 | prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | import hydra.core.global_hydra 4 | import pytest 5 | import torch 6 | from torch.utils.data import DataLoader 7 | 8 | from care_nl_ica.data.datamodules import ContrastiveDataModule 9 | from care_nl_ica.dataset import ContrastiveDataset 10 | 11 | arg_matrix = namedtuple("arg_matrix", ["latent_dim"]) 12 | 13 | from hydra import compose, initialize 14 | from pytorch_lightning import seed_everything 15 | from argparse import Namespace 16 | 17 | 18 | @pytest.fixture( 19 | params=[ 20 | arg_matrix(latent_dim=3), 21 | ] 22 | ) 23 | def args(request): 24 | hydra.core.global_hydra.GlobalHydra.instance().clear() 25 | initialize(config_path="../configs", job_name="test_app") 26 | 27 | cfg = compose( 28 | config_name="config", 29 | overrides=[ 30 | f"data.latent_dim={request.param.latent_dim}", 31 | "data.use_sem=true", 32 | "data.nonlin_sem=true", 33 | ], 34 | ) 35 | 36 | seed_everything(cfg.seed_everything) 37 | 38 | return cfg 39 | 40 | 41 | @pytest.fixture() 42 | def dataloader(args): 43 | args = Namespace( 44 | **{**args.data, "device": "cuda" if torch.cuda.is_available() else "cpu"} 45 | ) 46 | ds = ContrastiveDataset( 47 | args, 48 | lambda x: x 49 | @ torch.tril(torch.ones(args.latent_dim, args.latent_dim, 
device=args.device)), 50 | ) 51 | dl = DataLoader(ds, args.batch_size) 52 | return dl 53 | 54 | 55 | @pytest.fixture() 56 | def datamodule(args): 57 | dm = ContrastiveDataModule.from_argparse_args(Namespace(**args.data)) 58 | dm.setup() 59 | return dm 60 | -------------------------------------------------------------------------------- /care_nl_ica/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | 4 | from care_nl_ica.cl_ica import latent_spaces, spaces 5 | from care_nl_ica.prob_utils import ( 6 | setup_marginal, 7 | setup_conditional, 8 | sample_marginal_and_conditional, 9 | ) 10 | 11 | 12 | class ContrastiveDataset(torch.utils.data.IterableDataset): 13 | def __init__(self, hparams, transform=None): 14 | super().__init__() 15 | self.hparams = hparams 16 | self.transform = transform 17 | 18 | self._setup_space() 19 | 20 | self.latent_space = latent_spaces.LatentSpace( 21 | space=self.space, 22 | sample_marginal=setup_marginal(self.hparams), 23 | sample_conditional=setup_conditional(self.hparams), 24 | ) 25 | torch.cuda.empty_cache() 26 | 27 | def _setup_space(self): 28 | if self.hparams.space_type == "box": 29 | self.space = spaces.NBoxSpace( 30 | self.hparams.latent_dim, self.hparams.box_min, self.hparams.box_max 31 | ) 32 | elif self.hparams.space_type == "sphere": 33 | self.space = spaces.NSphereSpace( 34 | self.hparams.latent_dim, self.hparams.sphere_r 35 | ) 36 | else: 37 | self.space = spaces.NRealSpace(self.hparams.latent_dim) 38 | 39 | def __iter__(self): 40 | sources = torch.stack( 41 | sample_marginal_and_conditional( 42 | self.latent_space, 43 | size=self.hparams.batch_size, 44 | device=self.hparams.device, 45 | ) 46 | ) 47 | 48 | mixtures = torch.stack(tuple(map(self.transform, sources))) 49 | 50 | return iter((sources, mixtures)) 51 | -------------------------------------------------------------------------------- 
/notebooks/sem_3d_sweep_mmzkkmw4.csv: -------------------------------------------------------------------------------- 1 | ,name,dim,permute,variant,n_mixing_layer,use_sem,nonlin_sem,force_chain,force_uniform,mcc,val_loss 2 | 0,vibrant-sweep-19,3,False,0,1,True,True,False,False,0.9999834925346676,4.483795642852783 3 | 1,logical-sweep-17,3,False,0,1,True,True,False,False,0.9999510958654018,4.452263832092285 4 | 2,distinctive-sweep-18,3,False,0,1,True,True,False,False,0.9999742569135218,4.492233753204346 5 | 3,floral-sweep-16,3,False,0,1,True,True,False,False,0.999978368969236,4.482064723968506 6 | 4,laced-sweep-15,3,False,0,1,True,True,False,False,0.9999468021267304,4.464197158813477 7 | 5,dulcet-sweep-14,3,False,0,1,True,True,False,False,0.9998824498059984,4.4695892333984375 8 | 6,serene-sweep-13,3,False,0,1,True,True,False,False,0.9999359140889986,4.4987101554870605 9 | 7,polished-sweep-10,3,False,0,1,True,False,False,False,0.999966678058504,4.492259979248047 10 | 8,cosmic-sweep-7,3,False,0,1,True,False,False,False,0.9999527718070884,4.4820380210876465 11 | 9,solar-sweep-11,3,False,0,1,True,False,False,False,0.9999627495518776,4.49423885345459 12 | 10,azure-sweep-8,3,False,0,1,True,False,False,False,0.9999658911050588,4.451657295227051 13 | 11,lunar-sweep-12,3,False,0,1,True,True,False,False,0.999911378593971,4.443646430969238 14 | 12,lunar-sweep-9,3,False,0,1,True,False,False,False,0.99997813611064,4.492734909057617 15 | 13,deep-sweep-6,3,False,0,1,True,False,False,False,0.9999621276126484,4.463892936706543 16 | 14,noble-sweep-5,3,False,0,1,True,False,False,False,0.9999522251170536,4.491568088531494 17 | 15,spring-sweep-4,3,False,0,1,True,False,False,False,0.9999769496598908,4.497834205627441 18 | 16,dark-sweep-18,3,False,0,1,True,True,False,False,0.6902247615682603,5.701242446899414 19 | 17,eager-sweep-3,3,False,0,1,True,False,False,False,0.9999642529591248,4.443384170532227 20 | 
18,laced-sweep-1,3,False,0,1,True,False,False,False,0.999976140927932,4.460049629211426 21 | -------------------------------------------------------------------------------- /notebooks/sem_5d_sweep_kdau2seo.csv: -------------------------------------------------------------------------------- 1 | ,name,dim,permute,variant,n_mixing_layer,use_sem,nonlin_sem,force_chain,force_uniform,mcc,val_loss 2 | 0,earnest-sweep-20,5,False,0,1,True,True,False,False,0.8190186979735528,3.074489116668701 3 | 1,playful-sweep-18,5,False,0,1,True,True,False,False,0.9993895610460726,1.9437751770019531 4 | 2,glamorous-sweep-19,5,False,0,1,True,True,False,False,0.9998020879474088,1.9493937492370603 5 | 3,celestial-sweep-17,5,False,0,1,True,True,False,False,0.822084413101042,3.0058584213256836 6 | 4,effortless-sweep-15,5,False,0,1,True,True,False,False,0.9993695883932232,1.985156536102295 7 | 5,true-sweep-16,5,False,0,1,True,True,False,False,0.8185182855941202,3.023531675338745 8 | 6,swept-sweep-13,5,False,0,1,True,True,False,False,0.9994050747388996,1.9483914375305176 9 | 7,skilled-sweep-14,5,False,0,1,True,True,False,False,0.9993582123422764,1.893308401107788 10 | 8,worldly-sweep-11,5,False,0,1,True,True,False,False,0.9997244104382724,1.913358688354492 11 | 9,royal-sweep-9,5,False,0,1,True,False,False,False,0.9998361494785972,1.94585919380188 12 | 10,valiant-sweep-10,5,False,0,1,True,False,False,False,0.816690862169852,3.03879714012146 13 | 11,breezy-sweep-6,5,False,0,1,True,False,False,False,0.9997581404610733,1.9464752674102783 14 | 12,winter-sweep-7,5,False,0,1,True,False,False,False,0.8795997458112131,2.9836673736572266 15 | 13,lunar-sweep-4,5,False,0,1,True,False,False,False,0.9997457993188376,1.8816308975219729 16 | 14,glad-sweep-5,5,False,0,1,True,False,False,False,0.9995376654534928,1.968914031982422 17 | 15,honest-sweep-1,5,False,0,1,True,False,False,False,0.9998311645399612,1.9033708572387695 18 | 
16,wandering-sweep-3,5,False,0,1,True,False,False,False,0.9999524627177776,1.9334661960601809 19 | 17,quiet-sweep-8,5,False,0,1,True,False,False,False,0.9998030919100152,1.9282464981079104 20 | 18,skilled-sweep-2,5,False,0,1,True,False,False,False,0.9998855164639356,1.918325901031494 21 | -------------------------------------------------------------------------------- /scripts/exclude_nodes.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Generate an exclude list of nodes for SLURM 3 | # 4 | # Usage: 5 | # exclude_nodes.sh sbatch [...] 6 | # sbatch [...] --exclude=$(exclude_nodes.sh) [...] 7 | # 8 | # Example raw output: 9 | # λ ~/ bash exclude_nodes.sh 10 | # bg-slurmb-bm-3,slurm-bm-05,slurm-bm-08,slurm-bm-32,slurm-bm-42,slurm-bm-47,slurm-bm-55,slurm-bm-60,slurm-bm-62,slurm-bm-82% 11 | # 12 | # 13 | # Based on the MLCloud monitor at 14 | # http://134.2.168.207:3001/d/VbORxQJnk/slurm-monitoring-dashboard?orgId=1&from=now-3h&to=now&refresh=1m 15 | # 16 | # Up to date version of this script available at 17 | # https://gist.github.com/stes/52a139e260e25c72a97e2180d5be3bdb 18 | 19 | get_exclude_list() { 20 | curl -XPOST '134.2.168.207:8085/api/v2/query?orgID=1a87e58d4b097066' -sS \ 21 | -H 'Accept:application/csv' \ 22 | -H 'Content-type:application/vnd.flux' \ 23 | -H 'Authorization: Bearer 88JGdt2FvqcPXJwFQi5zGon6D0z7YP54' \ 24 | -d 'from(bucket: "tueslurm") 25 | |> range(start: -35s) 26 | |> filter(fn: (r) => r["_measurement"] == "node_state") 27 | |> filter(fn: (r) => r["_field"] == "responding") 28 | |> filter(fn: (r) => r["state"] == "down" or r["state"] == "drained") 29 | |> filter(fn: (r) => r["_value"] == 0) 30 | |> group(columns: ["_time"]) 31 | |> keep(columns: ["hostname"])' 32 | } 33 | 34 | formatted_exclude_list() { 35 | get_exclude_list \ 36 | | tail -n +2 \ 37 | | cut -f4 -d, \ 38 | | tr -d '\r' \ 39 | | sort | uniq \ 40 | | sed -r '/^\s*$/d' \ 41 | | tr '\n' ',' \ 42 | | sed -e 's/,$//g' 43 | } 
44 | echo $(formatted_exclude_list) 45 | mode=$1 46 | shift 1 47 | case $mode in 48 | sbatch) 49 | sbatch --exclude=$(formatted_exclude_list) "$@" 50 | ;; 51 | 52 | *) 53 | formatted_exclude_list 54 | ;; 55 | esac -------------------------------------------------------------------------------- /notebooks/sem_8d_sweep_v3kd7kca.csv: -------------------------------------------------------------------------------- 1 | ,name,dim,permute,variant,n_mixing_layer,use_sem,nonlin_sem,force_chain,force_uniform,mcc,val_loss 2 | 0,graceful-firebrand-766,8,False,0,1,True,True,False,False,0.9992295456708588,0.21403074264526367 3 | 1,clear-sunset-767,8,False,0,1,True,True,False,False,0.999486315004368,0.19814586639404297 4 | 2,copper-wood-768,8,False,0,1,True,True,False,False,0.9336650331449768,0.463017463684082 5 | 3,lucky-lion-769,8,False,0,1,True,True,False,False,0.9994647722676646,0.2104320526123047 6 | 4,faithful-night-770,8,False,0,1,True,True,False,False,0.9994833014056572,0.21943902969360352 7 | 5,decent-totem-771,8,False,0,1,True,True,False,False,0.9994218803476758,0.20617437362670896 8 | 6,giddy-moon-772,8,False,0,1,True,True,False,False,0.9546204229416516,0.39360618591308594 9 | 7,dashing-sky-773,8,False,0,1,True,True,False,False,0.9350217006543976,0.48402976989746094 10 | 8,efficient-sun-774,8,False,0,1,True,True,False,False,0.9994657360326052,0.2075634002685547 11 | 9,graceful-serenity-775,8,False,0,1,True,False,False,False,0.9996403885427748,0.20478010177612305 12 | 10,logical-silence-776,8,False,0,1,True,True,False,False,0.9994780825076024,0.2234654426574707 13 | 11,fanciful-vortex-777,8,False,0,1,True,False,False,False,0.9995520811012104,0.1971039772033691 14 | 12,curious-universe-778,8,False,0,1,True,False,False,False,0.9996713932249616,0.2232203483581543 15 | 13,jolly-shape-779,8,False,0,1,True,False,False,False,0.9996595352208192,0.20654773712158203 16 | 14,iconic-plant-780,8,False,0,1,True,False,False,False,0.9995704840859012,0.21356201171875 17 | 
15,solar-haze-781,8,False,0,1,True,False,False,False,0.9995237113779494,0.21808767318725583 18 | 16,colorful-eon-782,8,False,0,1,True,False,False,False,0.9996085561632492,0.19951772689819336 19 | 17,autumn-flower-783,8,False,0,1,True,False,False,False,0.999546733108815,0.2105250358581543 20 | 18,still-brook-784,8,False,0,1,True,False,False,False,0.9995925756876288,0.1993727684020996 21 | -------------------------------------------------------------------------------- /notebooks/sem_10d_sparse_sweep_t7rmrux1.csv: -------------------------------------------------------------------------------- 1 | ,name,dim,permute,variant,n_mixing_layer,use_sem,nonlin_sem,force_chain,force_uniform,mcc,val_loss 2 | 0,dandy-bush-896,10,False,0,1,True,True,False,False,0.8880686119855401,0.08689498901367188 3 | 1,vital-yogurt-897,10,False,0,1,True,True,False,False,0.8482571340203637,0.22718429565429688 4 | 2,revived-glitter-898,10,False,0,1,True,True,False,False,0.36292738693421894,4.941876411437988 5 | 3,dazzling-sea-899,10,False,0,1,True,True,False,False,0.7825259533014958,0.4925127029418946 6 | 4,hardy-terrain-900,10,False,0,1,True,True,False,False,0.7131718725702294,0.4598922729492187 7 | 5,absurd-galaxy-901,10,False,0,1,True,True,False,False,0.8253574629148996,0.24730396270751953 8 | 6,mild-firefly-902,10,False,0,1,True,True,False,False,0.4120730846524629,4.418057441711426 9 | 7,peach-sun-903,10,False,0,1,True,True,False,False,0.8316040574714311,0.205294132232666 10 | 8,hearty-monkey-904,10,False,0,1,True,True,False,False,0.9384207512809172,0.08607959747314453 11 | 9,icy-frost-905,10,False,0,1,True,False,False,False,0.8498906843335347,0.20059680938720703 12 | 10,glamorous-surf-906,10,False,0,1,True,False,False,False,0.998845625412927,0.034972190856933594 13 | 11,floral-silence-907,10,False,0,1,True,False,False,False,0.8205658296206438,0.2049999237060547 14 | 12,leafy-energy-908,10,False,0,1,True,False,False,False,0.3835599676275568,4.540416240692139 15 | 
13,lively-paper-909,10,False,0,1,True,False,False,False,0.8501652502760267,0.22983646392822263 16 | 14,gentle-microwave-910,10,False,0,1,True,False,False,False,0.8872115948071577,0.08292198181152344 17 | 15,lively-lion-911,10,False,0,1,True,False,False,False,0.9424619803608764,0.0835275650024414 18 | 16,pretty-bird-912,10,False,0,1,True,False,False,False,0.4180905353243494,4.421955108642578 19 | 17,wise-fog-913,10,False,0,1,True,False,False,False,0.9329449987886742,0.08167648315429688 20 | 18,comfy-bush-914,10,False,0,1,True,False,False,False,0.9125699159725038,0.07466506958007812 21 | -------------------------------------------------------------------------------- /notebooks/sem_5d_sparse_sweep_7n108utd.csv: -------------------------------------------------------------------------------- 1 | ,name,dim,permute,variant,n_mixing_layer,use_sem,nonlin_sem,force_chain,force_uniform,mcc,val_loss 2 | 0,dainty-sweep-20,5,False,0,1,True,True,False,False,0.8605566693153509,3.156348705291748 3 | 1,unique-sweep-19,5,False,0,1,True,True,False,False,0.9997611339139388,1.9037175178527832 4 | 2,brisk-sweep-18,5,False,0,1,True,True,False,False,0.9997199519680472,1.93845534324646 5 | 3,lemon-sweep-17,5,False,0,1,True,True,False,False,0.9984290010990629,1.939425468444824 6 | 4,vague-sweep-16,5,False,0,1,True,True,False,False,0.8597623994691783,3.0247762203216553 7 | 5,swift-sweep-15,5,False,0,1,True,True,False,False,0.9992409951087394,1.988208055496216 8 | 6,serene-sweep-14,5,False,0,1,True,True,False,False,0.9997871540199772,1.886852741241455 9 | 7,desert-sweep-13,5,False,0,1,True,True,False,False,0.9997916521885554,1.9361343383789065 10 | 8,olive-sweep-12,5,False,0,1,True,True,False,False,0.9998924572557067,1.921983242034912 11 | 9,efficient-sweep-11,5,False,0,1,True,True,False,False,0.9996489267749146,1.9088010787963867 12 | 10,dry-sweep-10,5,False,0,1,True,False,False,False,0.8660540846011828,3.0744245052337646 13 | 
11,celestial-sweep-9,5,False,0,1,True,False,False,False,0.9997190417263034,1.949199914932251 14 | 12,vague-sweep-7,5,False,0,1,True,False,False,False,0.8887474432381002,2.9533934593200684 15 | 13,sage-sweep-8,5,False,0,1,True,False,False,False,0.999764614742212,1.9274096488952637 16 | 14,lemon-sweep-6,5,False,0,1,True,False,False,False,0.8597000177644654,3.019930839538574 17 | 15,expert-sweep-5,5,False,0,1,True,False,False,False,0.9994576253931562,1.973270893096924 18 | 16,confused-sweep-4,5,False,0,1,True,False,False,False,0.9996303757641034,1.88446044921875 19 | 17,fast-sweep-2,5,False,0,1,True,False,False,False,0.9998035295351776,1.9200544357299805 20 | 18,winter-sweep-3,5,False,0,1,True,False,False,False,0.9999002216320252,1.9328291416168213 21 | 19,gentle-sweep-1,5,False,0,1,True,False,False,False,0.9998635630074012,1.903132438659668 22 | -------------------------------------------------------------------------------- /care_nl_ica/ica.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class ICAModel(nn.Module): 6 | """ 7 | Linear ICA class 8 | """ 9 | 10 | def __init__( 11 | self, dim: int, signal_model: torch.distributions.distribution.Distribution 12 | ): 13 | """ 14 | :param dim: an integer specifying the number of signals 15 | :param signal_model: class of the signal model distribution 16 | """ 17 | super().__init__() 18 | 19 | self.W = torch.nn.Parameter(torch.eye(dim)) 20 | self.signal_model = signal_model 21 | 22 | def forward(self, x: torch.Tensor): 23 | """ 24 | :param x : data tensor of (num_samples, signal_dim) 25 | """ 26 | 27 | # unmixing 28 | return torch.matmul(x, self.W) 29 | 30 | @staticmethod 31 | def _ml_objective( 32 | x: torch.Tensor, signal_model: torch.distributions.laplace.Laplace 33 | ): 34 | """ 35 | Implements the ML objective 36 | 37 | :param x: tensor to be transformed 38 | :param signal_model: Laplace distribution from torch.distributions. 
39 | """ 40 | 41 | # transform with location and scale 42 | x_tr = (x - signal_model.mean).abs() / signal_model.scale 43 | 44 | return -x_tr - torch.log(2 * signal_model.scale) 45 | 46 | def loss(self, x: torch.Tensor): 47 | """ 48 | :param x : data tensor of (num_samples, signal_dim) 49 | """ 50 | 51 | # ML objective 52 | model_entropy = self._ml_objective(self(x), self.signal_model).mean() 53 | 54 | # log of the abs determinant of the unmixing matrix 55 | # _, log_abs_det = torch.linalg.slogdet(self.W) 56 | _, log_abs_det = torch.slogdet(self.W) 57 | 58 | # as we need to minimize, there is a minus sign here 59 | return -model_entropy, -log_abs_det 60 | 61 | def dependency(self, row_idx: int, col_idx: int): 62 | return self.W[row_idx, col_idx] 63 | -------------------------------------------------------------------------------- /notebooks/sem_8d_sparse_sweep_2ykg2w21.csv: -------------------------------------------------------------------------------- 1 | ,name,dim,permute,variant,n_mixing_layer,use_sem,nonlin_sem,force_chain,force_uniform,mcc,val_loss 2 | 0,wild-snow-831,8,False,0,1,True,True,False,False,0.5665957865913724,2.985647439956665 3 | 1,chocolate-donkey-832,8,False,0,1,True,True,False,False,0.9251126760555228,0.4785151481628418 4 | 2,deft-surf-833,8,False,0,1,True,True,False,False,0.6397167261074248,2.1681180000305176 5 | 3,polished-dragon-834,8,False,0,1,True,True,False,False,0.9985103208167582,0.21297693252563477 6 | 4,floral-snow-835,8,False,0,1,True,True,False,False,0.8637280024515255,0.4936332702636719 7 | 5,curious-blaze-836,8,False,0,1,True,True,False,False,0.782662825797224,0.991201400756836 8 | 6,feasible-shape-837,8,False,0,1,True,True,False,False,0.6179069653714904,3.0603721141815186 9 | 7,fallen-dream-838,8,False,0,1,True,True,False,False,0.634959661278368,2.437565803527832 10 | 8,expert-serenity-839,8,False,0,1,True,True,False,False,0.8881299246044088,0.4492111206054687 11 | 
9,true-moon-840,8,False,0,1,True,True,False,False,0.9111807302769804,0.5042295455932617 12 | 10,celestial-morning-841,8,False,0,1,True,False,False,False,0.6847524150842897,1.8819189071655271 13 | 11,glorious-shadow-842,8,False,0,1,True,False,False,False,0.9984860055248252,0.2131686210632324 14 | 12,good-jazz-843,8,False,0,1,True,False,False,False,0.9995145437982278,0.20080232620239255 15 | 13,elated-donkey-844,8,False,0,1,True,False,False,False,0.998075153819804,0.20594167709350583 16 | 14,neat-lion-845,8,False,0,1,True,False,False,False,0.682697764186909,1.733640432357788 17 | 15,denim-yogurt-846,8,False,0,1,True,False,False,False,0.8985175284775266,0.4691781997680664 18 | 16,gallant-jazz-847,8,False,0,1,True,False,False,False,0.5814280385222736,3.046862840652466 19 | 17,cerulean-disco-848,8,False,0,1,True,False,False,False,0.5956188631893446,3.068350315093994 20 | 18,eternal-morning-849,8,False,0,1,True,False,False,False,0.99880277420109,0.2215604782104492 21 | 19,ruby-morning-850,8,False,0,1,True,False,False,False,0.999422871852962,0.197904109954834 22 | -------------------------------------------------------------------------------- /notebooks/sem_10d_sweep_i871qu61.csv: -------------------------------------------------------------------------------- 1 | ,name,dim,permute,variant,n_mixing_layer,use_sem,nonlin_sem,force_chain,force_uniform,mcc,val_loss 2 | 0,dry-firefly-1108,10,False,0,1,True,True,False,False,0.9462453311432756,0.08972740173339844 3 | 1,quiet-durian-1109,10,False,0,1,True,True,False,False,0.9461753828289016,0.0746774673461914 4 | 2,light-waterfall-1110,10,False,0,1,True,True,False,False,0.9982808458600884,0.03744983673095703 5 | 3,sparkling-armadillo-1111,10,False,0,1,True,True,False,False,0.9975028595230933,0.03430938720703125 6 | 4,northern-fog-1112,10,False,0,1,True,True,False,False,0.9473158082372166,0.08133888244628906 7 | 5,true-lion-1113,10,False,0,1,True,True,False,False,0.9980520715278508,0.03544425964355469 8 | 
6,peachy-rain-1114,10,False,0,1,True,True,False,False,0.9481769773466752,0.09392070770263672 9 | 7,dutiful-sound-1115,10,False,0,1,True,True,False,False,0.9448784720406526,0.08184337615966797 10 | 8,lemon-night-1116,10,False,0,1,True,True,False,False,0.9468860650409092,0.0773916244506836 11 | 9,stellar-monkey-1117,10,False,0,1,True,True,False,False,0.946317823246206,0.076263427734375 12 | 10,wild-sun-1118,10,False,0,1,True,False,False,False,0.945418066209334,0.07754802703857422 13 | 11,sparkling-meadow-1119,10,False,0,1,True,False,False,False,0.9410652837876068,0.07428455352783203 14 | 12,woven-moon-1120,10,False,0,1,True,False,False,False,0.946525500169108,0.08149337768554688 15 | 13,brisk-deluge-1121,10,False,0,1,True,False,False,False,0.9990280483408992,0.03194618225097656 16 | 14,grateful-spaceship-1122,10,False,0,1,True,False,False,False,0.947555533628678,0.08596515655517578 17 | 15,rare-sound-1123,10,False,0,1,True,False,False,False,0.930128627881422,0.08655261993408203 18 | 16,swept-eon-1124,10,False,0,1,True,False,False,False,0.9987026645030392,0.034996986389160156 19 | 17,exalted-capybara-1125,10,False,0,1,True,False,False,False,0.9989967585241224,0.028470993041992188 20 | 18,bright-firefly-1126,10,False,0,1,True,False,False,False,0.9992242973706812,0.02747154235839844 21 | 19,smart-serenity-1127,10,False,0,1,True,False,False,False,0.9415826500938776,0.07753467559814453 22 | -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/tools/3dident/ORIGINAL_PATENTS: -------------------------------------------------------------------------------- 1 | Additional Grant of Patent Rights Version 2 2 | 3 | "Software" means the clevr-dataset-gen software contributed by Facebook, Inc. 4 | 5 | Facebook, Inc. 
("Facebook") hereby grants to each recipient of the Software 6 | ("you") a perpetual, worldwide, royalty-free, non-exclusive, irrevocable 7 | (subject to the termination provision below) license under any Necessary 8 | Claims, to make, have made, use, sell, offer to sell, import, and otherwise 9 | transfer the Software. For avoidance of doubt, no license is granted under 10 | Facebook’s rights in any patent claims that are infringed by (i) modifications 11 | to the Software made by you or any third party or (ii) the Software in 12 | combination with any software or other technology. 13 | 14 | The license granted hereunder will terminate, automatically and without notice, 15 | if you (or any of your subsidiaries, corporate affiliates or agents) initiate 16 | directly or indirectly, or take a direct financial interest in, any Patent 17 | Assertion: (i) against Facebook or any of its subsidiaries or corporate 18 | affiliates, (ii) against any party if such Patent Assertion arises in whole or 19 | in part from any software, technology, product or service of Facebook or any of 20 | its subsidiaries or corporate affiliates, or (iii) against any party relating 21 | to the Software. Notwithstanding the foregoing, if Facebook or any of its 22 | subsidiaries or corporate affiliates files a lawsuit alleging patent 23 | infringement against you in the first instance, and you respond by filing a 24 | patent infringement counterclaim in that lawsuit against that party that is 25 | unrelated to the Software, the license granted hereunder will not terminate 26 | under section (i) of this paragraph due to such counterclaim. 27 | 28 | A "Necessary Claim" is a claim of a patent owned by Facebook that is 29 | necessarily infringed by the Software standing alone. 30 | 31 | A "Patent Assertion" is any lawsuit or other action alleging direct, indirect, 32 | or contributory infringement or inducement to infringe any patent, including a 33 | cross-claim or counterclaim. 
-------------------------------------------------------------------------------- /.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Python package 5 | 6 | on: 7 | push: 8 | branches: [ master ] 9 | pull_request: 10 | branches: [ master ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ${{ matrix.os }} 16 | strategy: 17 | fail-fast: false 18 | matrix: 19 | python-version: [3.8] 20 | os: [ubuntu-latest] 21 | 22 | steps: 23 | - uses: actions/checkout@v2 24 | with: 25 | submodules: true 26 | - name: Set up Python ${{ matrix.python-version }} 27 | uses: actions/setup-python@v2 28 | with: 29 | python-version: ${{ matrix.python-version }} 30 | 31 | - name: Python ${{ matrix.python-version }} cache 32 | uses: actions/cache@v2 33 | with: 34 | path: ${{ env.pythonLocation }} 35 | key: ${{ env.pythonLocation }}-${{ hashFiles('setup.py') }}-${{ hashFiles('dev-requirements.txt') }} 36 | 37 | - name: Install dependencies 38 | run: | 39 | python -m pip install --upgrade pip 40 | if [ -f requirements.txt ]; then pip install -r requirements.txt; pip install -e .; fi 41 | pip install --requirement tests/requirements.txt --quiet 42 | - name: Lint with flake8 43 | run: | 44 | python -m pip install flake8 45 | # stop the build if there are Python syntax errors or undefined names 46 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude cl_ica,icebeem,pytorch-flows 47 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 48 | flake8 . 
--count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --exclude cl_ica,icebeem,pytorch-flows 49 | - name: black 50 | run: | 51 | black --check --verbose ./care_nl_ica/ 52 | - name: Test with pytest 53 | run: | 54 | python -m pip install pytest 55 | python -m pytest tests 56 | 57 | 58 | 59 | 60 | -------------------------------------------------------------------------------- /care_nl_ica/metrics/metric_logger.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchmetrics 3 | 4 | 5 | class MetricLogger(object): 6 | def __init__(self): 7 | """ 8 | Initialize the metrics. 9 | """ 10 | super().__init__() 11 | 12 | self.accuracy = torchmetrics.Accuracy() 13 | self.roc_auc = torchmetrics.AUROC(num_classes=2) 14 | self.auc_score = torchmetrics.AUC() 15 | self.precision = torchmetrics.Precision() 16 | self.recall = torchmetrics.Recall() 17 | self.f1 = torchmetrics.F1() 18 | self.hamming_distance = torchmetrics.HammingDistance() 19 | self.stat_scores = torchmetrics.StatScores() # FP, FN, TP, TN 20 | 21 | def update(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> None: 22 | """ 23 | Update the metrics given the predictions and the ground truth. 24 | """ 25 | self.accuracy.update(y_pred, y_true) 26 | self.roc_auc.update(y_pred, y_true) 27 | self.auc_score.update(y_pred, y_true) 28 | self.precision.update(y_pred, y_true) 29 | self.recall.update(y_pred, y_true) 30 | self.f1.update(y_pred, y_true) 31 | self.hamming_distance.update(y_pred, y_true) 32 | self.stat_scores.update(y_pred, y_true) 33 | 34 | def compute(self) -> dict: 35 | """ 36 | Compute the metrics. 
37 | """ 38 | [tp, fp, tn, fn, sup] = self.stat_scores.compute() 39 | panel_name = "Metrics" 40 | return { 41 | f"{panel_name}/accuracy": self.accuracy.compute(), 42 | # f'{panel_name}/roc_auc': self.roc_auc.compute(), 43 | # f'{panel_name}/auc_score': self.auc_score.compute(), 44 | f"{panel_name}/precision": self.precision.compute(), 45 | f"{panel_name}/recall": self.recall.compute(), 46 | f"{panel_name}/f1": self.f1.compute(), 47 | f"{panel_name}/hamming_distance": self.hamming_distance.compute(), 48 | f"{panel_name}/tp": tp, 49 | f"{panel_name}/fp": fp, 50 | f"{panel_name}/tn": tn, 51 | f"{panel_name}/fn": fn, 52 | f"{panel_name}/support": sup, 53 | } 54 | -------------------------------------------------------------------------------- /care_nl_ica/independence/indep_check.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from care_nl_ica.independence.hsic import HSIC 4 | 5 | 6 | class IndependenceChecker(object): 7 | """ 8 | Class for encapsulating independence test-related methods 9 | """ 10 | 11 | def __init__(self, hparams) -> None: 12 | super().__init__() 13 | self.hparams = hparams 14 | 15 | self.test = HSIC(hparams.num_permutations) 16 | 17 | print("Using Bonferroni = 4") 18 | 19 | def check_bivariate_dependence(self, x1, x2): 20 | decisions = [] 21 | var_map = [1, 1, 2, 2] 22 | with torch.no_grad(): 23 | decisions.append( 24 | self.test.run_test(x1[:, 0], x2[:, 1], bonferroni=4).item() 25 | ) 26 | decisions.append( 27 | self.test.run_test(x1[:, 0], x2[:, 0], bonferroni=4).item() 28 | ) 29 | decisions.append( 30 | self.test.run_test(x1[:, 1], x2[:, 0], bonferroni=4).item() 31 | ) 32 | decisions.append( 33 | self.test.run_test(x1[:, 1], x2[:, 1], bonferroni=4).item() 34 | ) 35 | 36 | return decisions, var_map 37 | 38 | def check_multivariate_dependence( 39 | self, x1: torch.Tensor, x2: torch.Tensor 40 | ) -> torch.Tensor: 41 | """ 42 | Carries out HSIC for the multivariate case, all pairs are 
tested 43 | :param x1: tensor of the first batch of variables in the shape of (num_elem, num_dim) 44 | :param x2: tensor of the second batch of variables in the shape of (num_elem, num_dim) 45 | :return: the adjacency matrix 46 | """ 47 | num_dim = x1.shape[-1] 48 | max_edge_num = num_dim**2 49 | adjacency_matrix = torch.zeros(num_dim, num_dim).bool() 50 | 51 | print(max_edge_num) 52 | 53 | with torch.no_grad(): 54 | for i in range(num_dim): 55 | for j in range(num_dim): 56 | adjacency_matrix[i, j] = self.test.run_test( 57 | x1[:, i], x2[:, j], bonferroni=4 # max_edge_num 58 | ).item() 59 | 60 | return adjacency_matrix 61 | -------------------------------------------------------------------------------- /care_nl_ica/data/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | 5 | 6 | def to_one_hot(x, m=None): 7 | "batch one hot" 8 | if type(x) is not list: 9 | x = [x] 10 | if m is None: 11 | ml = [] 12 | for xi in x: 13 | ml += [xi.max() + 1] 14 | m = max(ml) 15 | dtp = x[0].dtype 16 | xoh = [] 17 | for i, xi in enumerate(x): 18 | xoh += [np.zeros((xi.size, int(m)), dtype=dtp)] 19 | xoh[i][np.arange(xi.size), xi.astype(np.int)] = 1 20 | return xoh 21 | 22 | 23 | def one_hot_encode(labels, n_labels=10): 24 | """ 25 | Transforms numeric labels to 1-hot encoded labels. Assumes numeric labels are in the range 0, 1, ..., n_labels-1. 26 | """ 27 | 28 | assert np.min(labels) >= 0 and np.max(labels) < n_labels 29 | 30 | y = np.zeros([labels.size, n_labels]).astype(np.float32) 31 | y[range(labels.size), labels] = 1 32 | 33 | return y 34 | 35 | 36 | def single_one_hot_encode(label, n_labels=10): 37 | """ 38 | Transforms numeric labels to 1-hot encoded labels. Assumes numeric labels are in the range 0, 1, ..., n_labels-1. 
39 | """ 40 | 41 | assert label >= 0 and label < n_labels 42 | 43 | y = np.zeros([n_labels]).astype(np.float32) 44 | y[label] = 1 45 | 46 | return y 47 | 48 | 49 | def single_one_hot_encode_rev(label, n_labels=10, start_label=0): 50 | """ 51 | Transforms numeric labels to 1-hot encoded labels. Assumes numeric labels are in the range 0, 1, ..., n_labels-1. 52 | """ 53 | assert label >= start_label and label < n_labels 54 | y = np.zeros([n_labels - start_label]).astype(np.float32) 55 | y[label - start_label] = 1 56 | return y 57 | 58 | 59 | mnist_one_hot_transform = lambda label: single_one_hot_encode(label, n_labels=10) 60 | contrastive_one_hot_transform = lambda label: single_one_hot_encode(label, n_labels=2) 61 | 62 | 63 | def make_dir(dir_name): 64 | if dir_name[-1] != "/": 65 | dir_name += "/" 66 | if not os.path.exists(dir_name): 67 | os.makedirs(dir_name) 68 | return dir_name 69 | 70 | 71 | def make_file(file_name): 72 | if not os.path.exists(file_name): 73 | open(file_name, "a").close() 74 | return file_name 75 | -------------------------------------------------------------------------------- /tests/test_graph_utils.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from care_nl_ica.graph_utils import ( 5 | indirect_causes, 6 | graph_paths, 7 | false_positive_paths, 8 | false_negative_paths, 9 | ) 10 | 11 | 12 | @pytest.fixture 13 | def three_dim_chain() -> torch.Tensor: 14 | dim = 3 15 | dep_mat = torch.tril(torch.ones(dim, dim)) 16 | zeros_in_chain = torch.tril(torch.ones_like(dep_mat), -2) 17 | dep_mat[zeros_in_chain == 1] = 0 18 | 19 | return dep_mat 20 | 21 | 22 | def test_graph_paths(three_dim_chain: torch.Tensor): 23 | gt_paths = { 24 | 0: torch.Tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]), 25 | 1: torch.Tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]), 26 | } 27 | 28 | [ 29 | torch.all(path, gt_path) 30 | for path, gt_path in 
zip(graph_paths(three_dim_chain).values(), gt_paths) 31 | ] 32 | 33 | 34 | def test_indirect_causes(three_dim_chain: torch.Tensor): 35 | # ground truth 36 | indirect_paths = torch.zeros_like(three_dim_chain) 37 | indirect_paths[-1, 0] = 1 38 | 39 | print(indirect_causes(three_dim_chain)[1]) 40 | 41 | torch.all(indirect_causes(three_dim_chain)[0] == indirect_paths) 42 | 43 | 44 | @pytest.mark.parametrize("weighted", [True, False]) 45 | def test_false_positive_paths(three_dim_chain: torch.Tensor, weighted): 46 | # false positive at [2,0] 47 | dep_mat = torch.tril(torch.rand_like(three_dim_chain)) + 1.3 48 | direct_causes = torch.tril((three_dim_chain.abs() > 1e-6).float(), -1) 49 | 50 | fp = torch.Tensor([1 if weighted is False else dep_mat[2, 0], 0]) 51 | torch.all( 52 | false_positive_paths(dep_mat, graph_paths(direct_causes), weighted=weighted) 53 | == fp 54 | ) 55 | 56 | 57 | @pytest.mark.parametrize("weighted", [True, False]) 58 | def test_false_negative_paths(three_dim_chain: torch.Tensor, weighted): 59 | gt_adjacency = torch.tril(torch.rand_like(three_dim_chain)) + 1.3 60 | direct_causes = torch.tril((gt_adjacency.abs() > 1e-6).float(), -1) 61 | 62 | # false negative at [2,0] - here three_dim_chain is the estimated value 63 | fn = torch.Tensor([1 if weighted is False else gt_adjacency[2, 0], 0]) 64 | torch.all( 65 | false_negative_paths( 66 | three_dim_chain, graph_paths(direct_causes), weighted=weighted 67 | ) 68 | == fn 69 | ) 70 | -------------------------------------------------------------------------------- /notebooks/monti_sweep_77huh2ue.csv: -------------------------------------------------------------------------------- 1 | ,name,dim,permute,variant,n_mixing_layer,use_sem,nonlin_sem,force_chain,force_uniform,mcc,val_loss 2 | 0,charmed-sweep-25,6,False,3,5,False,False,False,False,0.8379107954836166,1.854844570159912 3 | 1,fiery-sweep-24,6,False,3,5,False,False,False,False,0.8497126701689969,1.8678297996520996 4 | 
2,northern-sweep-23,6,False,3,5,False,False,False,False,0.840644201969512,1.8648576736450195 5 | 3,happy-sweep-21,6,False,3,5,False,False,False,False,0.880131429668987,2.045818567276001 6 | 4,clean-sweep-22,6,False,3,5,False,False,False,False,0.8293840628008139,1.9895012378692627 7 | 5,sweepy-sweep-20,6,False,3,4,False,False,False,False,0.8387286044061634,1.864729881286621 8 | 6,noble-sweep-17,6,False,3,4,False,False,False,False,0.8350038247753466,1.88857102394104 9 | 7,fanciful-sweep-16,6,False,3,4,False,False,False,False,0.8100973160695278,1.8514604568481443 10 | 8,golden-sweep-18,6,False,3,4,False,False,False,False,0.8343979017561596,1.8834435939788816 11 | 9,fanciful-sweep-19,6,False,3,4,False,False,False,False,0.847744241010156,1.8607234954833984 12 | 10,woven-sweep-15,6,False,3,3,False,False,False,False,0.8421843332186746,1.828907251358032 13 | 11,trim-sweep-13,6,False,3,3,False,False,False,False,0.8306750434765068,1.8489642143249512 14 | 12,lucky-sweep-14,6,False,3,3,False,False,False,False,0.9993110388035972,1.0367021560668943 15 | 13,zany-sweep-12,6,False,3,3,False,False,False,False,0.9873989134725712,1.1389002799987793 16 | 14,eager-sweep-9,6,False,3,2,False,False,False,False,0.9996324727750469,1.0297703742980957 17 | 15,wild-sweep-10,6,False,3,2,False,False,False,False,0.999326604292618,1.0480751991271973 18 | 16,vocal-sweep-11,6,False,3,3,False,False,False,False,0.9989177035541376,1.04758882522583 19 | 17,zesty-sweep-7,6,False,3,2,False,False,False,False,0.9991668196876312,1.0220303535461426 20 | 18,sage-sweep-8,6,False,3,2,False,False,False,False,0.9990669578066456,1.0225944519042969 21 | 19,good-sweep-5,6,False,3,1,False,False,False,False,0.9998565660776868,1.0176987648010254 22 | 20,cerulean-sweep-6,6,False,3,2,False,False,False,False,0.9998933114660916,1.0363197326660156 23 | 21,zany-sweep-4,6,False,3,1,False,False,False,False,0.999831198266047,1.0289840698242188 24 | 
22,fast-sweep-3,6,False,3,1,False,False,False,False,0.9998441213652508,1.017876148223877 25 | 23,fresh-sweep-2,6,False,3,1,False,False,False,False,0.9997517473969434,1.056753158569336 26 | 24,genial-sweep-1,6,False,3,1,False,False,False,False,0.9999328830460192,1.008960247039795 27 | -------------------------------------------------------------------------------- /notebooks/monti_sweep_3rqfxkyl.csv: -------------------------------------------------------------------------------- 1 | ,name,dim,permute,variant,n_mixing_layer,use_sem,nonlin_sem,force_chain,force_uniform,mcc,val_loss 2 | 0,twilight-snowball-981,6,False,3,5,False,False,False,False,0.6164149991392273,2.9853672981262207 3 | 1,atomic-sun-982,6,False,3,5,False,False,False,False,0.6815620257946304,3.0047552585601807 4 | 2,warm-disco-983,6,False,3,5,False,False,False,False,0.6343687499663566,2.9120864868164062 5 | 3,deep-gorge-984,6,False,3,5,False,False,False,False,0.5551614959805392,3.3627450466156006 6 | 4,amber-cosmos-985,6,False,3,5,False,False,False,False,0.5284008640275427,5.690755844116211 7 | 5,scarlet-planet-986,6,False,3,4,False,False,False,False,0.9876972542156532,1.0956225395202637 8 | 6,fresh-pyramid-987,6,False,3,4,False,False,False,False,0.968844345402296,1.1442599296569824 9 | 7,morning-frog-988,6,False,3,4,False,False,False,False,0.9810430198144292,1.1612944602966309 10 | 8,worldly-frost-989,6,False,3,4,False,False,False,False,0.9958163626763924,1.0813336372375488 11 | 9,avid-brook-990,6,False,3,4,False,False,False,False,0.9548101426294048,1.3085899353027344 12 | 10,expert-frog-991,6,False,3,3,False,False,False,False,0.9984634171836134,1.0292787551879885 13 | 11,pleasant-meadow-992,6,False,3,3,False,False,False,False,0.9978086505101568,1.0431675910949707 14 | 12,upbeat-moon-993,6,False,3,3,False,False,False,False,0.9916541639921048,1.0663161277770996 15 | 13,stellar-flower-994,6,False,3,2,False,False,False,False,0.9999024314825214,1.0448369979858398 16 | 
14,whole-bush-995,6,False,3,2,False,False,False,False,0.9997775690224052,1.030303955078125 17 | 15,whole-hill-996,6,False,3,1,False,False,False,False,0.9997894463259328,1.0464377403259275 18 | 16,lemon-wildflower-997,6,False,3,3,False,False,False,False,0.9980021135135936,1.086085319519043 19 | 17,flowing-dream-998,6,False,3,3,False,False,False,False,0.9971217896171256,1.066732406616211 20 | 18,swept-sun-999,6,False,3,2,False,False,False,False,0.9998097101599744,1.0740838050842283 21 | 19,driven-brook-1000,6,False,3,2,False,False,False,False,0.999663395469589,1.0601258277893066 22 | 20,fresh-tree-1001,6,False,3,2,False,False,False,False,0.9996586516190464,1.0395126342773438 23 | 21,vital-waterfall-1002,6,False,3,1,False,False,False,False,0.9996800321129156,1.0390777587890625 24 | 22,autumn-monkey-1003,6,False,3,1,False,False,False,False,0.9997013129758958,1.030467510223389 25 | 23,dry-snowflake-1004,6,False,3,1,False,False,False,False,0.9998626596067872,1.001848220825195 26 | 24,true-deluge-1005,6,False,3,1,False,False,False,False,0.999675473268248,1.0738143920898438 27 | -------------------------------------------------------------------------------- /care_nl_ica/cli.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning.loggers.wandb import WandbLogger 2 | from pytorch_lightning.utilities.cli import LightningCLI 3 | 4 | from care_nl_ica.utils import add_tags, install_package 5 | from care_nl_ica.data.datamodules import ContrastiveDataModule 6 | from care_nl_ica.runner import ContrastiveICAModule 7 | 8 | 9 | class MyLightningCLI(LightningCLI): 10 | def add_arguments_to_parser(self, parser): 11 | parser.add_argument( 12 | "--notes", 13 | type=str, 14 | default=None, 15 | help="Notes for the run on Weights and Biases", 16 | ) 17 | # todo: process notes based on args in before_instantiate_classes 18 | parser.add_argument( 19 | "--tags", 20 | type=str, 21 | nargs="*", # 0 or more values expected => creates a list 
22 | default=None, 23 | help="Tags for the run on Weights and Biases", 24 | ) 25 | 26 | parser.link_arguments("data.latent_dim", "model.latent_dim") 27 | parser.link_arguments("data.box_min", "model.box_min") 28 | parser.link_arguments("data.box_max", "model.box_max") 29 | parser.link_arguments("data.sphere_r", "model.sphere_r") 30 | parser.link_arguments("data.normalize_latents", "model.normalize_latents") 31 | 32 | def before_instantiate_classes(self) -> None: 33 | self.config[self.subcommand].trainer.logger.init_args.tags = add_tags( 34 | self.config[self.subcommand] 35 | ) 36 | 37 | def before_fit(self): 38 | if isinstance(self.trainer.logger, WandbLogger) is True: 39 | # required as the parser cannot parse the "-" symbol 40 | self.trainer.logger.__dict__["_wandb_init"][ 41 | "entity" 42 | ] = "causal-representation-learning" 43 | 44 | if self.config[self.subcommand].model.offline is True: 45 | self.trainer.logger.__dict__["_wandb_init"]["mode"] = "offline" 46 | else: 47 | self.trainer.logger.__dict__["_wandb_init"]["mode"] = "online" 48 | 49 | # todo: maybe set run in the CLI to false and call watch before? 
50 | self.trainer.logger.watch(self.model, log="all", log_freq=250) 51 | 52 | 53 | if __name__ == "__main__": 54 | install_package() 55 | cli = MyLightningCLI( 56 | ContrastiveICAModule, 57 | ContrastiveDataModule, 58 | save_config_callback=None, 59 | run=True, 60 | parser_kwargs={"parse_as_dict": False}, 61 | ) 62 | -------------------------------------------------------------------------------- /care_nl_ica/cl_ica/latent_spaces.py: -------------------------------------------------------------------------------- 1 | """Classes that combine spaces with specific probability densities.""" 2 | 3 | from typing import Callable, List 4 | from .spaces import Space 5 | import torch 6 | 7 | 8 | class LatentSpace: 9 | """Combines a topological space with a marginal and conditional density to sample from.""" 10 | 11 | def __init__( 12 | self, space: Space, sample_marginal: Callable, sample_conditional: Callable 13 | ): 14 | self.space = space 15 | self._sample_marginal = sample_marginal 16 | self._sample_conditional = sample_conditional 17 | 18 | @property 19 | def sample_conditional(self): 20 | if self._sample_conditional is None: 21 | raise RuntimeError("sample_conditional was not set") 22 | return lambda *args, **kwargs: self._sample_conditional( 23 | self.space, *args, **kwargs 24 | ) 25 | 26 | @sample_conditional.setter 27 | def sample_conditional(self, value: Callable): 28 | assert callable(value) 29 | self._sample_conditional = value 30 | 31 | @property 32 | def sample_marginal(self): 33 | if self._sample_marginal is None: 34 | raise RuntimeError("sample_marginal was not set") 35 | return lambda *args, **kwargs: self._sample_marginal( 36 | self.space, *args, **kwargs 37 | ) 38 | 39 | @sample_marginal.setter 40 | def sample_marginal(self, value: Callable): 41 | assert callable(value) 42 | self._sample_marginal = value 43 | 44 | @property 45 | def dim(self): 46 | return self.space.dim 47 | 48 | 49 | class ProductLatentSpace(LatentSpace): 50 | """A latent space which is 
def pca(x, num_comp=None, params=None, zerotolerance=1e-7):
    """Apply PCA whitening to data.

    Args:
        x: data. 2D ndarray [num_comp, num_data]
        num_comp: number of components (default: all, i.e. x.shape[0])
        params: (option) dictionary of PCA parameters {'mean':?, 'W':?, 'A':?}.
            If given, apply this previously-learned transform instead of
            fitting a new one.
        zerotolerance: (option) eigenvalues smaller than this fraction of the
            largest eigenvalue are treated as zero (rank-deficient data).
    Returns:
        x: whitened data
        params: parameters of PCA
            mean: subtracted mean
            W: whitening matrix
            A: mixing (de-whitening) matrix
    Raises:
        ValueError: if any of the leading num_comp eigenvalues is
            (near-)zero, since whitening would divide by ~0.
    """
    # Dimension
    if num_comp is None:
        num_comp = x.shape[0]

    # From learned parameters --------------------------------
    if params is not None:
        # Use previously-trained model
        print(" use learned value")
        data_pca = x - params["mean"]
        x = np.dot(params["W"], data_pca)

    # Learn from data ----------------------------------------
    else:
        # Zero mean
        xmean = np.mean(x, 1).reshape([-1, 1])
        x = x - xmean

        # Eigenvalue decomposition of the covariance matrix
        xcov = np.cov(x)
        d, V = np.linalg.eigh(xcov)  # Ascending order
        # Convert to descending order
        d = d[::-1]
        V = V[:, ::-1]

        # Do not allow (near-)zero eigenvalues within the requested components
        if np.any((d[:num_comp] / d[0]) < zerotolerance):
            raise ValueError(
                "PCA: covariance is (near-)singular within the leading "
                f"{num_comp} components (zerotolerance={zerotolerance})"
            )

        # Construct whitening and dewhitening matrices
        dsqrt = np.sqrt(d[:num_comp])
        dsqrtinv = 1 / dsqrt
        V = V[:, :num_comp]
        # Whitening
        W = np.dot(np.diag(dsqrtinv), V.transpose())  # whitening matrix
        A = np.dot(V, np.diag(dsqrt))  # de-whitening matrix
        x = np.dot(W, x)

        params = {"mean": xmean, "W": W, "A": A}

    return x, params
-z "$USER_HOME" ]; then 9 | export USER_HOME=/home/$USER 10 | fi 11 | 12 | if [ -z "$USER_ID" ]; then 13 | export USER_ID=99 14 | fi 15 | 16 | if [ -n "$USER_ENCRYPTED_PASSWORD" ]; then 17 | useradd -M -d $USER_HOME -p $USER_ENCRYPTED_PASSWORD -u $USER_ID $USER > /dev/null 18 | else 19 | useradd -M -d $USER_HOME -u $USER_ID $USER > /dev/null 20 | fi 21 | 22 | # expects a comma-separated string of the form GROUP1:GROUP1ID,GROUP2,GROUP3:GROUP3ID,... 23 | # (the GROUPID is optional, but needs to be separated from the group name by a ':') 24 | for i in $(echo $USER_GROUPS | sed "s/,/ /g") 25 | do 26 | if [[ $i == *":"* ]] 27 | then 28 | addgroup ${i%:*} # > /dev/null 29 | groupmod -g ${i#*:} ${i%:*} #> /dev/null 30 | adduser $USER ${i%:*} #> /dev/null 31 | else 32 | addgroup $i > /dev/null 33 | adduser $USER $i > /dev/null 34 | fi 35 | done 36 | 37 | # add user to sudo group 38 | adduser $USER sudo 39 | 40 | # set correct primary group 41 | if [ -n "$USER_GROUPS" ]; then 42 | group="$( cut -d ',' -f 1 <<< "$USER_GROUPS" )" 43 | if [[ $group == *":"* ]] 44 | then 45 | usermod -g ${group%:*} $USER & 46 | else 47 | usermod -g $group $USER & 48 | fi 49 | fi 50 | 51 | # set shell 52 | if [ -z "$USER_SHELL" ] 53 | then 54 | usermod -s "/bin/bash" $USER 55 | else 56 | usermod -s $USER_SHELL $USER 57 | fi 58 | 59 | if [ -n $CWD ]; then cd $CWD; fi 60 | echo "Running as user $USER" 61 | 62 | # set environment such that gpus can be used even in ssh connections 63 | echo "export CUDNN_VERSION=$CUDNN_VERSION" >> /etc/profile 64 | echo "export NVIDIA_REQUIRE_CUDA='$NVIDIA_REQUIRE_CUDA'" >> /etc/profile 65 | echo "export LIBRARY_PATH=$LIBRARY_PATH" >> /etc/profile 66 | echo "export LD_PRELOAD=$LD_PRELOAD" >> /etc/profile 67 | echo "export NVIDIA_VISIBLE_DEVICES=$NVIDIA_VISIBLE_DEVICES" >> /etc/profile 68 | echo "export LD_LIBRARY_PATH=$LD_LIBRARY_PATH" >> /etc/profile 69 | echo "export NVIDIA_DRIVER_CAPABILITIES=$NVIDIA_DRIVER_CAPABILITIES" >> /etc/profile 70 | echo "export 
def graph_paths(direct_causes: torch.Tensor) -> dict:
    """Collect the non-vanishing powers of the adjacency matrix.

    The k-th stored power encodes the paths of length k+1 in the DAG; the
    iteration stops as soon as a power is all-zero (no longer paths exist)
    or after ``num_nodes`` powers (the longest possible path in a DAG).

    :param direct_causes: lower-triangular adjacency matrix of direct causes
    :return: dict mapping the power index (as a string) to the matrix power
    """
    num_nodes = direct_causes.shape[0]
    paths = dict()
    power = direct_causes.clone()

    for exponent in range(num_nodes):
        if power.sum() == 0:
            break

        paths[str(exponent)] = power
        power = power @ direct_causes

    return paths
def dsm(energy_net, samples, sigma=1):
    """Denoising score matching loss.

    Perturbs ``samples`` with Gaussian noise of scale ``sigma`` and penalizes
    the mismatch between the (sigma^2-scaled) model score at the perturbed
    points and the injected noise.

    :param energy_net: network mapping a batch to per-sample scalar energies
    :param samples: data batch (gradients are enabled on it in-place)
    :param sigma: noise standard deviation
    :return: scalar DSM loss
    """
    samples.requires_grad_(True)
    noise = torch.randn_like(samples) * sigma
    noisy = samples + noise
    log_prob = -energy_net(noisy)
    # model score at the noisy points, scaled by sigma^2
    scaled_score = sigma**2 * autograd.grad(
        log_prob.sum(), noisy, create_graph=True
    )[0]
    per_sample = torch.norm(scaled_score + noise, dim=-1) ** 2
    return per_sample.mean() / 2.0
def dsm_score_estimation(scorenet, samples, sigma=0.01):
    """Denoising score matching for a score network.

    Matches the network's predicted score at noise-perturbed samples against
    the score of the Gaussian perturbation kernel, -(x_noisy - x) / sigma^2.

    :param scorenet: network predicting the score at (noisy) inputs
    :param samples: clean data batch
    :param sigma: noise standard deviation
    :return: scalar DSM loss
    """
    noisy = samples + torch.randn_like(samples) * sigma
    # score of the Gaussian perturbation kernel
    target = (-1 / (sigma**2)) * (noisy - samples)
    predicted = scorenet(noisy)
    # flatten per sample before comparing
    target = target.view(target.shape[0], -1)
    predicted = predicted.view(predicted.shape[0], -1)
    return 0.5 * ((predicted - target) ** 2).sum(dim=-1).mean(dim=0)
9 | callbacks: 10 | - class_path: pytorch_lightning.callbacks.EarlyStopping 11 | init_args: 12 | monitor: val_loss 13 | mode: min 14 | patience: 40 15 | auto_select_gpus: true 16 | checkpoint_callback: null 17 | enable_checkpointing: false 18 | default_root_dir: null 19 | gradient_clip_val: null 20 | gradient_clip_algorithm: null 21 | process_position: 0 22 | log_gpu_memory: null 23 | progress_bar_refresh_rate: null 24 | enable_progress_bar: true 25 | overfit_batches: 0.0 26 | track_grad_norm: -1 27 | check_val_every_n_epoch: 100 28 | fast_dev_run: false 29 | accumulate_grad_batches: null 30 | max_epochs: 5000 31 | min_epochs: null 32 | max_steps: -1 33 | min_steps: null 34 | max_time: null 35 | gpus: 1 36 | limit_train_batches: 400.0 37 | limit_val_batches: 5.0 38 | limit_test_batches: 1.0 39 | limit_predict_batches: 1.0 40 | val_check_interval: 1.0 41 | flush_logs_every_n_steps: null 42 | log_every_n_steps: 250 43 | accelerator: null 44 | strategy: null 45 | sync_batchnorm: false 46 | precision: 32 47 | enable_model_summary: true 48 | weights_summary: top 49 | weights_save_path: null 50 | num_sanity_val_steps: 2 51 | resume_from_checkpoint: null 52 | profiler: null 53 | benchmark: false 54 | deterministic: true 55 | reload_dataloaders_every_n_epochs: 0 56 | reload_dataloaders_every_epoch: false 57 | auto_lr_find: false 58 | replace_sampler_ddp: true 59 | auto_scale_batch_size: false 60 | prepare_data_per_node: null 61 | amp_backend: native 62 | amp_level: null 63 | move_metrics_to_cpu: false 64 | multiple_trainloader_mode: max_size_cycle 65 | stochastic_weight_avg: false 66 | detect_anomaly: true 67 | model: 68 | offline: true 69 | lr: 1e-4 70 | verbose: false 71 | p: 1.0 72 | tau: 1.0 73 | alpha: 0.5 74 | normalization: '' 75 | start_step: null 76 | use_bias: false 77 | data: 78 | n_mixing_layer: 1 79 | permute: false 80 | act_fct: leaky_relu 81 | use_sem: true 82 | data_gen_mode: rvs 83 | variant: 0 84 | nonlin_sem: false 85 | box_min: 0.0 86 | box_max: 1.0 87 
def setup_marginal(args):
    """Build the marginal sampler for the latent space.

    ``args.m_p`` selects the family: 1 -> Laplace, 2 -> normal, 0 -> uniform,
    anything else -> generalized normal with exponent ``m_p``. The location
    ``eta`` is the origin, except on a sphere where the first coordinate is
    set to the radius.

    :param args: namespace with device, latent_dim, space_type, sphere_r,
        m_p and m_param attributes
    :return: callable ``(space, size, device=...) -> samples``
    """
    device = args.device
    eta = torch.zeros(args.latent_dim)
    if args.space_type == "sphere":
        eta[0] = args.sphere_r

    if args.m_p == 1:

        def sample_marginal(space, size, device=device):
            return space.laplace(eta, args.m_param, size, device)

    elif args.m_p == 2:

        def sample_marginal(space, size, device=device):
            return space.normal(eta, args.m_param, size, device)

    elif args.m_p == 0:

        def sample_marginal(space, size, device=device):
            return space.uniform(size, device=device)

    else:

        def sample_marginal(space, size, device=device):
            return space.generalized_normal(
                eta, args.m_param, p=args.m_p, size=size, device=device
            )

    return sample_marginal
def laplace_log_cdf(x: torch.Tensor, signal_model: torch.distributions.laplace.Laplace):
    """Numerically stable log-CDF of a Laplace distribution.

    Follows the TensorFlow Probability implementation:
    https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/internal/special_math.py#L281

    :param x: tensor to be transformed
    :param signal_model: Laplace distribution from torch.distributions.
    :return: element-wise log CDF of ``x``
    """
    # standardize with location and scale
    standardized = (x - signal_model.mean) / signal_model.scale

    # below the location: log(0.5 * exp(z)) = log(0.5) + z
    below = torch.log(torch.tensor(0.5)) + standardized

    # at/above the location: log(1 - 0.5 * exp(-|z|)), via log1p for stability
    above = torch.log1p(-0.5 * (-standardized.abs()).exp())

    return torch.where(x < signal_model.mean, below, above)
class RescaleLayer(nn.Module):
    """Normalize the data to a hypersphere with fixed/variable radius."""

    def __init__(
        self, init_r=1.0, fixed_r=False, mode: Optional[Literal["eq", "leq"]] = "eq"
    ):
        """
        :param init_r: initial radius of the hypersphere
        :param fixed_r: if True the radius is a constant; otherwise it is learned
        :param mode: "eq" projects every sample onto the sphere surface;
            "leq" only rescales samples lying outside the ball of radius r
        """
        super().__init__()
        self.fixed_r = fixed_r
        assert mode in ("leq", "eq")
        self.mode = mode
        # fixed radius: a plain tensor (not registered as parameter/buffer);
        # learnable radius: an nn.Parameter updated by the optimizer
        if fixed_r:
            self.r = torch.ones(1, requires_grad=False) * init_r
        else:
            self.r = nn.Parameter(torch.ones(1, requires_grad=True) * init_r)

    def forward(self, x):
        if self.mode == "eq":
            # project each sample onto the sphere of radius r
            x = x / torch.norm(x, dim=-1, keepdim=True)
            x = x * self.r.to(x.device)
        elif self.mode == "leq":
            # in-place masked division: only samples whose norm exceeds r are
            # scaled back onto the ball.
            # NOTE(review): the boolean mask has shape (batch, 1) while x is
            # (batch, dim) — assumes this indexing/broadcast works for the
            # shapes used by callers; TODO confirm
            norm = torch.norm(x, dim=-1, keepdim=True)
            x[norm > self.r] /= torch.norm(x, dim=-1, keepdim=True) / self.r

        return x
def IVAE_wrapper(
    X,
    U,
    batch_size=256,
    max_iter=7e4,
    seed=0,
    n_layers=3,
    hidden_dim=20,
    lr=1e-3,
    cuda=True,
    ckpt_file="ivae.pt",
    test=False,
):
    """Train (or load) an iVAE on data X with auxiliary variables U.

    :param X: observations, ndarray of shape (num_samples, data_dim)
    :param U: auxiliary/conditioning variables aligned with X
    :param batch_size: mini-batch size for training
    :param max_iter: number of gradient steps (counted per batch, not epoch)
    :param seed: seed for torch and numpy RNGs
    :param n_layers: number of layers of the iVAE networks
    :param hidden_dim: hidden width of the iVAE networks
    :param lr: Adam learning rate
    :param cuda: run on cuda:0 if True, else CPU
    :param ckpt_file: path to save (train) or load (test) the model
    :param test: if True, skip training and load the checkpoint instead
    :return: (z, model, params) — latent codes for the full dataset, the
        trained model, and a dict of decoder/encoder/prior parameters
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    device = torch.device("cuda:0" if cuda else "cpu")
    # print('training on {}'.format(torch.cuda.get_device_name(device) if cuda else 'cpu'))

    # load data
    # print('Creating shuffled dataset..')
    dset = ConditionalDataset(X.astype(np.float32), U.astype(np.float32), device)
    loader_params = {"num_workers": 1, "pin_memory": True} if cuda else {}
    train_loader = DataLoader(
        dset, shuffle=True, batch_size=batch_size, **loader_params
    )
    data_dim, latent_dim, aux_dim = dset.get_dims()
    N = len(dset)
    max_epochs = int(max_iter // len(train_loader) + 1)

    # define model and optimizer
    # print('Defining model and optimizer..')
    model = iVAE(
        latent_dim,
        data_dim,
        aux_dim,
        activation="lrelu",
        device=device,
        n_layers=n_layers,
        hidden_dim=hidden_dim,
    )
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, factor=0.1, patience=20, verbose=True
    )

    # training loop
    if not test:
        print("Training..")
        it = 0
        model.train()
        # NOTE: the stopping condition is checked once per epoch, so training
        # can overshoot max_iter by up to one epoch of batches
        while it < max_iter:
            elbo_train = 0
            epoch = it // len(train_loader) + 1
            for _, (x, u) in enumerate(train_loader):
                it += 1
                optimizer.zero_grad()
                x, u = x.to(device), u.to(device)
                elbo, z_est = model.elbo(x, u)
                # maximize the ELBO by minimizing its negation
                elbo.mul(-1).backward()
                optimizer.step()
                elbo_train += -elbo.item()
            elbo_train /= len(train_loader)
            # plateau scheduler driven by the epoch-averaged negative ELBO
            scheduler.step(elbo_train)
            # print('epoch {}/{} \tloss: {}'.format(epoch, max_epochs, elbo_train))
        # save model checkpoint after training
        torch.save(model.state_dict(), ckpt_file)
    else:
        model = torch.load(ckpt_file, map_location=device)

    # encode the full (unshuffled) dataset with the final model
    Xt, Ut = dset.x, dset.y
    decoder_params, encoder_params, z, prior_params = model(
        Xt.to(device), Ut.to(device)
    )
    params = {
        "decoder": decoder_params,
        "encoder": encoder_params,
        "prior": prior_params,
    }

    return z, model, params
output_normalization=output_normalization, 43 | output_normalization_kwargs=output_normalization_kwargs, 44 | ) 45 | 46 | if self.hparams.verbose is True: 47 | print(f"{self.unmixing.detach()=}") 48 | 49 | self.unmixing = self.unmixing.to(hparams.device) 50 | 51 | def _setup_loss(self): 52 | hparams = self.hparams 53 | 54 | if hparams.p: 55 | self.loss = losses.LpSimCLRLoss( 56 | p=hparams.p, tau=hparams.tau, simclr_compatibility_mode=True 57 | ) 58 | else: 59 | self.loss = losses.SimCLRLoss( 60 | normalize=False, tau=hparams.tau, alpha=hparams.alpha 61 | ) 62 | 63 | def _configure_output_normalization(self): 64 | hparams = self.hparams 65 | output_normalization_kwargs = None 66 | if hparams.normalization == "learnable_box": 67 | output_normalization = "learnable_box" 68 | elif hparams.normalization == "fixed_box": 69 | output_normalization = "fixed_box" 70 | output_normalization_kwargs = dict( 71 | init_abs_bound=hparams.box_max - hparams.box_min 72 | ) 73 | elif hparams.normalization == "learnable_sphere": 74 | output_normalization = "learnable_sphere" 75 | elif hparams.normalization == "fixed_sphere": 76 | output_normalization = "fixed_sphere" 77 | output_normalization_kwargs = dict(init_r=hparams.sphere_r) 78 | elif hparams.normalization == "": 79 | print("Using no output normalization") 80 | output_normalization = None 81 | else: 82 | raise ValueError("Invalid output normalization:", hparams.normalization) 83 | return output_normalization, output_normalization_kwargs 84 | 85 | def forward(self, x): 86 | if isinstance(x, list) or isinstance(x, tuple): 87 | return tuple(map(self.unmixing, x)) 88 | else: 89 | return self.unmixing(x) 90 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |