├── .gitignore ├── LICENSE ├── README.md ├── data └── cytometry │ ├── dataset_1.csv │ ├── dataset_10.csv │ ├── dataset_2.csv │ ├── dataset_3.csv │ ├── dataset_4.csv │ ├── dataset_5.csv │ ├── dataset_6.csv │ ├── dataset_7.csv │ ├── dataset_8.csv │ ├── dataset_9.csv │ ├── sachs.RData │ └── sachs_details.csv ├── experiments ├── cytometry │ ├── analyze_pvalues.ipynb │ ├── analyze_pvalues.py │ ├── figures │ │ └── learned_dag.pdf │ ├── results_paper │ │ ├── cytometry_mch_kci_pvalues_env=1.npy │ │ ├── cytometry_mch_kci_pvalues_env=10.npy │ │ ├── cytometry_mch_kci_pvalues_env=2.npy │ │ ├── cytometry_mch_kci_pvalues_env=3.npy │ │ ├── cytometry_mch_kci_pvalues_env=4.npy │ │ ├── cytometry_mch_kci_pvalues_env=5.npy │ │ ├── cytometry_mch_kci_pvalues_env=6.npy │ │ ├── cytometry_mch_kci_pvalues_env=7.npy │ │ ├── cytometry_mch_kci_pvalues_env=8.npy │ │ └── cytometry_mch_kci_pvalues_env=9.npy │ └── run_cytometry_experiment.py ├── figures │ └── pairwise_oracle_pc_simulation.pdf ├── main_simulations │ ├── exp_quick_settings.py │ ├── exp_settings.py │ ├── figures │ │ ├── bivariate_multiplic_power_plots.pdf │ │ ├── bivariate_power_plots.pdf │ │ ├── empirical_select_rates_er_others.pdf │ │ ├── empirical_select_rates_er_ours.pdf │ │ ├── oracle_rate_relplot.pdf │ │ ├── oracle_select_rates_all_models.pdf │ │ ├── oracle_select_rates_ba.pdf │ │ └── oracle_select_rates_er.pdf │ ├── plot_bivariate_identifiability.ipynb │ ├── plot_bivariate_identifiability.py │ ├── plot_empirical_power.ipynb │ ├── plot_empirical_power.py │ ├── plot_oracle_rates.ipynb │ ├── plot_oracle_rates.py │ ├── results_paper │ │ ├── DEBUG_results.csv │ │ ├── bivariate_multiplic_power_results.csv │ │ ├── bivariate_power_results_paper.csv │ │ ├── old_results │ │ │ ├── environment_convergence_results.csv │ │ │ ├── environment_convergence_results_paper.csv │ │ │ ├── oracle_rates_results_power_paper.csv │ │ │ ├── oracle_select_rates_results_old.csv │ │ │ ├── pairwise_power_results_100_samples.csv │ │ │ ├── pairwise_power_results_50_reps.csv │ │ │ ├── pairwise_power_results_envs.csv │ │ │ ├── pairwise_power_results_envs_supp.csv │ │ │ ├── pairwise_power_results_first_half.csv │ │ │ ├── pairwise_power_results_n_vars.csv │ │ │ ├── pairwise_power_results_n_vars_v2.csv │ │ │ ├── pairwise_power_results_paper_05_08.csv │ │ │ ├── pairwise_power_results_paper_subset.csv │ │ │ ├── pairwise_power_results_paper_subset_v2.csv │ │ │ ├── pairwise_power_results_samples.csv │ │ │ ├── pairwise_power_results_second_half.csv │ │ │ ├── pairwise_power_results_small.csv │ │ │ ├── pairwise_power_results_sparsity.csv │ │ │ └── soft_samples_results_paper.csv │ │ ├── oracle_rates_results_paper.csv │ │ ├── oracle_select_rates_results_paper.csv │ │ └── pairwise_power_results_paper.csv │ └── run_experiment.py ├── requirements.txt ├── teaser_sparse_oracle_pc.ipynb └── teaser_sparse_oracle_pc.py ├── requirements.txt ├── setup.py └── sparse_shift ├── __init__.py ├── causal_learn ├── GraphClass.py ├── PC.py ├── SkeletonDiscovery.py └── __init__.py ├── datasets ├── __init__.py ├── dags.py ├── simulations.py └── tests │ ├── test_dags.py │ └── test_simulations.py ├── independence_tests.py ├── kcd.py ├── methods.py ├── metrics.py ├── plotting.py ├── testing.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 
15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | .idea/ 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | .pytest_cache/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | db.sqlite3 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # pyenv 77 | .python-version 78 | 79 | # celery beat schedule file 80 | celerybeat-schedule 81 | 82 | # SageMath parsed files 83 | *.sage.py 84 | 85 | # Environments 86 | .env 87 | .venv 88 | env/ 89 | venv/ 90 | ENV/ 91 | env.bak/ 92 | venv.bak/ 93 | docs/api/generated 94 | docs/_build 95 | docs/gallery 96 | docs/tutorials 97 | docs/sample_data 98 | docs/benchmarks 99 | 100 | # Spyder project settings 101 | .spyderproject 102 | .spyproject 103 | 104 | # Rope project settings 105 | .ropeproject 106 | 107 | # mkdocs documentation 108 | /site 109 | 110 | # mypy 111 | .mypy_cache/ 112 | 113 | # other 114 | .DS_Store 115 | .vscode/ 116 | .Rhistory 117 | .Rdata 118 | 119 | # Project specific 120 | *.npy 121 | data/* 122 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Ronan Perry 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Overview 2 | 3 | Conditional independence tools for causal learning under the sparse mechanism shift hypothesis. 4 | 5 | ## Local installation 6 | 7 | From a clean python environment (e.g. 
`conda create -n test python=3.9`), 8 | 9 | ```console 10 | git clone https://github.com/rflperry/sparse_shift.git 11 | cd sparse_shift 12 | pip install -e . 13 | ``` 14 | 15 | ## Running experiments and generating figures 16 | 17 | First, navigate to the experiments directory and install the necessary packages: 18 | 19 | ```console 20 | cd experiments 21 | pip install -r requirements.txt 22 | ``` 23 | 24 | ### Teaser figure 25 | 26 | ```console 27 | cd experiments 28 | 29 | python teaser_sparse_oracle_pc.py 30 | ``` 31 | 32 | This runs the teaser experiment and generates the corresponding figure. 33 | 34 | ### Simulations 35 | 36 | ```console 37 | cd experiments/main_simulations 38 | ``` 39 | 40 | Then run the following commands to generate the results and the camera-ready figures. 41 | 42 | Bivariate power: 43 | 44 | ```console 45 | python run_experiment.py --experiment bivariate_power --quick 46 | python run_experiment.py --experiment bivariate_multiplic_power --quick 47 | 48 | python plot_bivariate_identifiability.py 49 | ``` 50 | 51 | Oracle rates: 52 | 53 | ```console 54 | python run_experiment.py --experiment oracle_rates --quick 55 | python run_experiment.py --experiment oracle_select_rates --quick 56 | 57 | python plot_oracle_rates.py 58 | 59 | ``` 60 | 61 | Empirical comparison simulations: 62 | 63 | ```console 64 | python run_experiment.py --experiment pairwise_power --quick 65 | 66 | python plot_empirical_power.py 67 | 68 | ``` 69 | 70 | Remove `--quick` and add `--n_jobs -2` to run the full paper experiments. 71 | Note that this can take a long time. 72 | 73 | ### Cytometry experiment 74 | 75 | ```console 76 | cd experiments/cytometry 77 | python run_cytometry_experiment.py --quick 78 | 79 | python analyze_pvalues.py 80 | ``` 81 | 82 | Remove `--quick` to run the full paper experiment; note that this can also take a long time.
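All of the experiment scripts above drive the same core API from the `sparse_shift` package. Below is a minimal, illustrative sketch of that workflow, not a script from this repository: the toy linear data generator is invented for the example, while `dag2cpdag`, `cpdag2dags`, and the `MinChange` arguments are copied from the experiment configurations.

```python
import numpy as np
from sparse_shift.utils import dag2cpdag, cpdag2dags
from sparse_shift.methods import MinChange

# Toy ground truth X0 -> X1 -> X2; adjacency[i, j] = 1 encodes the edge i -> j
true_dag = np.zeros((3, 3))
true_dag[0, 1] = 1
true_dag[1, 2] = 1
true_cpdag = dag2cpdag(true_dag)  # Markov equivalence class (CPDAG)
print(len(cpdag2dags(true_cpdag)), 'DAGs in the MEC')

def sample_env(n, noise_scale, seed):
    # Linear SCM; only the mechanism of X2 (its noise scale) differs across envs
    rng = np.random.default_rng(seed)
    x0 = rng.normal(size=n)
    x1 = 0.8 * x0 + rng.normal(size=n)
    x2 = 0.8 * x1 + noise_scale * rng.normal(size=n)
    return np.column_stack([x0, x1, x2])

# Same settings as the 'mch_fisherz' entry in the experiment settings files
mch = MinChange(cpdag=true_cpdag, alpha=0.05, scale_alpha=True,
                test='fisherz', test_kwargs={})
mch.add_environment(sample_env(500, noise_scale=1.0, seed=0))
mch.add_environment(sample_env(500, noise_scale=3.0, seed=1))
print(mch.get_min_cpdag(False))  # orientations shared by the minimum-shift DAGs
```

See `experiments/cytometry/run_cytometry_experiment.py` below for the same calls applied to real data.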
83 | -------------------------------------------------------------------------------- /data/cytometry/sachs.RData: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rflperry/sparse_shift/8620c5ccf7b7b28b2d2946d3f84c1b16b8fcfd39/data/cytometry/sachs.RData -------------------------------------------------------------------------------- /data/cytometry/sachs_details.csv: -------------------------------------------------------------------------------- 1 | "id","dataset","reagent","description" 2 | 1,"cd3cd28","Anti-CD3/CD28","T cell activation" 3 | 2,"cd3cd28_icam2","ICAM-2","LFA-1 signaling induction" 4 | 3,"cd3cd28_aktinhib","AKT_inhibitor","AKT inhibition" 5 | 4,"cd3cd28_g0076","G06976","PKC inhibition" 6 | 5,"cd3cd28_psitect","Psitectorigenin","PIP2 production inhibition" 7 | 6,"cd3cd28_u0126","U0126","MEK1/MEK2 inhibition" 8 | 7,"cd3cd28_ly","LY294002","AKT inhibition" 9 | 8,"pma","PMA","PKC activation" 10 | 9,"b2camp","beta2cAMP","PKA activation" 11 | 10,"cd3cd28icam2_aktinhib","simulated","simulated dataset" 12 | 11,"cd3cd28icam2_g0076","simulated","simulated dataset" 13 | 12,"cd3cd28icam2_psit","simulated","simulated dataset" 14 | 13,"cd3cd28icam2_u0126","simulated","simulated dataset" 15 | 14,"cd3cd28icam2_ly","simulated","simulated dataset" 16 | -------------------------------------------------------------------------------- /experiments/cytometry/analyze_pvalues.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # In[27]: 5 | 6 | 7 | import numpy as np 8 | from sparse_shift.utils import dag2cpdag, cpdag2dags 9 | import matplotlib.pyplot as plt 10 | from sparse_shift.plotting import plot_dag 11 | import pandas as pd 12 | 13 | 14 | # # Paper figure 15 | 16 | # In[28]: 17 | 18 | 19 | dag = np.zeros((11, 11))  # consensus ground-truth DAG from Sachs et al. over the 11 measured proteins 20 | dag[2, np.asarray([3, 4])] = 1 21 | dag[4, 3] = 1 22 | dag[8, np.asarray([10, 7, 0, 1, 9])] = 1 23 | dag[7, np.asarray([0, 1, 5, 6, 9, 10])] = 1 24 | dag[0, 1] = 1 25 | dag[1, 5] = 1 26 | dag[5, 6] = 1 27 | 28 | true_dag = dag 29 | 30 | 31 | # ## Compute pvalues 32 | 33 | # In[67]: 34 | 35 | 36 | # Shape: (graphs, variables, envs, envs) 37 | pvalues = np.load('./results_paper/cytometry_mch_kci_pvalues_env=9.npy') 38 | 39 | 40 | # In[30]: 41 | 42 | 43 | pvalues.shape 44 | 45 | 46 | # In[31]: 47 | 48 | 49 | alpha = 0.05 / pvalues.shape[1]  # Bonferroni correction across the tested variables 50 | n_changes = np.sum(pvalues <= alpha, axis=(1, 2, 3)) / 2  # the two environment axes are symmetric, so halve the count 51 | 52 | 53 | # ## Find MEC true DAG and min changes DAG 54 | 55 | # In[32]: 56 | 57 | 58 | true_cpdag = dag2cpdag(true_dag) 59 | dags = cpdag2dags(true_cpdag) 60 | 61 | 62 | # In[33]: 63 | 64 | 65 | np.where((dags == true_dag).all(axis=(1,2)))[0]  # index of the true DAG within the MEC 66 | 67 | 68 | # In[34]: 69 | 70 | 71 | np.where(n_changes == np.min(n_changes))[0]  # indices of the minimum-change DAGs 72 | 73 | 74 | # In[35]: 75 | 76 | 77 | n_changes[43] 78 | 79 | 80 | # ## Plot min changes DAG 81 | 82 | # In[36]: 83 | 84 | 85 | # Obtain labels 86 | df1 = pd.read_csv('../../data/cytometry/dataset_1.csv') 87 | labels = [l.split('.')[1] for l in df1.columns] 88 | 89 | 90 | # In[66]: 91 | 92 | 93 | idx = np.argmin(n_changes) 94 | dag = dags[idx] 95 | plot_dag( 96 | dag, 97 | highlight_edges=dag-true_dag, 98 | labels=labels, 99 | node_size=1000, 100 | figsize=(4, 5.5), 101 | ) 102 | plt.tight_layout() 103 | plt.savefig('./figures/learned_dag.pdf') 104 | plt.show() -------------------------------------------------------------------------------- /experiments/cytometry/figures/learned_dag.pdf:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/rflperry/sparse_shift/8620c5ccf7b7b28b2d2946d3f84c1b16b8fcfd39/experiments/cytometry/figures/learned_dag.pdf -------------------------------------------------------------------------------- /experiments/cytometry/results_paper/cytometry_mch_kci_pvalues_env=1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rflperry/sparse_shift/8620c5ccf7b7b28b2d2946d3f84c1b16b8fcfd39/experiments/cytometry/results_paper/cytometry_mch_kci_pvalues_env=1.npy -------------------------------------------------------------------------------- /experiments/cytometry/results_paper/cytometry_mch_kci_pvalues_env=10.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rflperry/sparse_shift/8620c5ccf7b7b28b2d2946d3f84c1b16b8fcfd39/experiments/cytometry/results_paper/cytometry_mch_kci_pvalues_env=10.npy -------------------------------------------------------------------------------- /experiments/cytometry/results_paper/cytometry_mch_kci_pvalues_env=2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rflperry/sparse_shift/8620c5ccf7b7b28b2d2946d3f84c1b16b8fcfd39/experiments/cytometry/results_paper/cytometry_mch_kci_pvalues_env=2.npy -------------------------------------------------------------------------------- /experiments/cytometry/results_paper/cytometry_mch_kci_pvalues_env=3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rflperry/sparse_shift/8620c5ccf7b7b28b2d2946d3f84c1b16b8fcfd39/experiments/cytometry/results_paper/cytometry_mch_kci_pvalues_env=3.npy -------------------------------------------------------------------------------- /experiments/cytometry/results_paper/cytometry_mch_kci_pvalues_env=4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rflperry/sparse_shift/8620c5ccf7b7b28b2d2946d3f84c1b16b8fcfd39/experiments/cytometry/results_paper/cytometry_mch_kci_pvalues_env=4.npy -------------------------------------------------------------------------------- /experiments/cytometry/results_paper/cytometry_mch_kci_pvalues_env=5.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rflperry/sparse_shift/8620c5ccf7b7b28b2d2946d3f84c1b16b8fcfd39/experiments/cytometry/results_paper/cytometry_mch_kci_pvalues_env=5.npy -------------------------------------------------------------------------------- /experiments/cytometry/results_paper/cytometry_mch_kci_pvalues_env=6.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rflperry/sparse_shift/8620c5ccf7b7b28b2d2946d3f84c1b16b8fcfd39/experiments/cytometry/results_paper/cytometry_mch_kci_pvalues_env=6.npy -------------------------------------------------------------------------------- /experiments/cytometry/results_paper/cytometry_mch_kci_pvalues_env=7.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rflperry/sparse_shift/8620c5ccf7b7b28b2d2946d3f84c1b16b8fcfd39/experiments/cytometry/results_paper/cytometry_mch_kci_pvalues_env=7.npy -------------------------------------------------------------------------------- 
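The `.npy` files in this directory are the raw p-value arrays saved by `run_cytometry_experiment.py`, one per number of environments seen. Below is a small, illustrative sketch of how such an array can be inspected; the file name, the `(graphs, variables, envs, envs)` shape, the Bonferroni threshold, and the halved change count all come from `analyze_pvalues.py`, while the axis interpretation in the comments is our reading of that analysis code.

```python
import numpy as np

# Shape: (graphs, variables, envs, envs). Entry [d, j, e1, e2] is read here as
# the p-value for a shift in variable j's conditional mechanism between
# environments e1 and e2, under the parent sets of candidate DAG d of the MEC.
pvalues = np.load('./results_paper/cytometry_mch_kci_pvalues_env=9.npy')

alpha = 0.05 / pvalues.shape[1]  # Bonferroni correction across variables
# The two environment axes are symmetric, so each detected change is counted
# twice; halving the sum gives the number of shifted mechanisms per DAG.
n_changes = np.sum(pvalues <= alpha, axis=(1, 2, 3)) / 2

best = int(np.argmin(n_changes))
print(f'DAG {best} explains the data with {n_changes[best]} mechanism changes')
```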
/experiments/cytometry/results_paper/cytometry_mch_kci_pvalues_env=8.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rflperry/sparse_shift/8620c5ccf7b7b28b2d2946d3f84c1b16b8fcfd39/experiments/cytometry/results_paper/cytometry_mch_kci_pvalues_env=8.npy -------------------------------------------------------------------------------- /experiments/cytometry/results_paper/cytometry_mch_kci_pvalues_env=9.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rflperry/sparse_shift/8620c5ccf7b7b28b2d2946d3f84c1b16b8fcfd39/experiments/cytometry/results_paper/cytometry_mch_kci_pvalues_env=9.npy -------------------------------------------------------------------------------- /experiments/cytometry/run_cytometry_experiment.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from sparse_shift.utils import dag2cpdag, cpdag2dags 4 | from sparse_shift.methods import MinChange, AugmentedPC, FullMinChanges, ParamChanges 5 | from sparse_shift.metrics import dag_true_orientations, dag_false_orientations, \ 6 | dag_precision, dag_recall, average_precision_score 7 | import argparse 8 | import logging 9 | import os 10 | 11 | # Consensus DAG from the Sachs et al. paper; its CPDAG is derived below 12 | dag = np.zeros((11, 11)) 13 | dag[2, np.asarray([3, 4])] = 1 14 | dag[4, 3] = 1 15 | dag[8, np.asarray([10, 7, 0, 1, 9])] = 1 16 | dag[7, np.asarray([0, 1, 5, 6, 9, 10])] = 1 17 | dag[0, 1] = 1 18 | dag[1, 5] = 1 19 | dag[5, 6] = 1 20 | 21 | true_dag = dag 22 | true_cpdag = dag2cpdag(true_dag) 23 | mec_size = len(cpdag2dags(true_cpdag)) 24 | total_edges = np.sum(true_dag) 25 | unoriented_edges = np.sum((true_cpdag + true_cpdag.T) == 2) // 2 26 | 27 | METHODS = [ 28 | ( 29 | 'mch_kci', 30 | 'Min changes (KCI)', 31 | MinChange, 32 | { 33 | 'alpha': 0.05, 34 | 'scale_alpha': True, 35 | 'test': 'kci', 36 | 'test_kwargs': { 37 | "KernelX": "GaussianKernel", 38 | "KernelY": "GaussianKernel", 39 | "KernelZ": "GaussianKernel", 40 | }, 41 | } 42 | ), 43 | ] 44 | 45 | 46 | def main(args): 47 | # Compute empirical results 48 | save_name, method_name, mch, hyperparams = METHODS[0] 49 | mch = mch(cpdag=true_cpdag, **hyperparams) 50 | results = [] 51 | 52 | Xs = [ 53 | np.log( 54 | pd.read_csv(f'../../data/cytometry/dataset_{i}.csv') 55 | ) for i in range(1, 10)  # datasets 1-9 are the experimentally measured conditions (see sachs_details.csv) 56 | ] 57 | 58 | if args.quick: 59 | # Quick version: only the first three environments, truncated to 100 samples each 60 | Xs = [X[:100] for X in Xs[:3]] 61 | 62 | for n_env, X in enumerate(Xs): 63 | mch.add_environment(X) 64 | 65 | min_cpdag = mch.get_min_cpdag(False) 66 | 67 | true_orients = np.round(dag_true_orientations(true_dag, min_cpdag), 4) 68 | false_orients = np.round(dag_false_orientations(true_dag, min_cpdag), 4) 69 | precision = np.round(dag_precision(true_dag, min_cpdag), 4) 70 | recall = np.round(dag_recall(true_dag, min_cpdag), 4) 71 | 72 | results += [true_orients, false_orients, precision, recall] 73 | print(n_env, ': ', np.round(precision, 4), ', ', np.round(recall, 4)) 74 | if not os.path.exists('./results/'): 75 | os.makedirs('./results/') 76 | np.save(f"./results/cytometry_{save_name}_pvalues_env={n_env+1}.npy", mch.pvalues_) 77 | 78 | 79 | if __name__ == "__main__": 80 | parser = argparse.ArgumentParser() 81 | parser.add_argument( 82 | "--quick", 83 | help="Enable to run a smaller, test version", 84 | default=False, 85 | action='store_true' 86 | ) 87 | args = parser.parse_args() 88 | 89 | main(args) 90 |
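For reference, here is a small self-contained sketch of the orientation metrics used in the evaluation loop above. The function names and the `(true_dag, cpdag)` call pattern are exactly those of `run_cytometry_experiment.py`; the toy adjacency matrices and the stated reading of the scores as precision/recall over oriented edges are illustrative assumptions.

```python
import numpy as np
from sparse_shift.metrics import dag_precision, dag_recall

# Hypothetical ground truth X0 -> X1 -> X2 (adjacency[i, j] = 1 means i -> j)
true_dag = np.zeros((3, 3))
true_dag[0, 1] = 1
true_dag[1, 2] = 1

# Hypothetical estimate: X0 -> X1 correctly oriented, X1 - X2 left
# undirected (an undirected edge carries both directions in the CPDAG)
est_cpdag = np.zeros((3, 3))
est_cpdag[0, 1] = 1
est_cpdag[1, 2] = 1
est_cpdag[2, 1] = 1

# Precision: of the edges the estimate orients, the fraction matching the
# true DAG; recall: the fraction of true edges that get correctly oriented.
print(dag_precision(true_dag, est_cpdag), dag_recall(true_dag, est_cpdag))
```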
-------------------------------------------------------------------------------- /experiments/figures/pairwise_oracle_pc_simulation.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rflperry/sparse_shift/8620c5ccf7b7b28b2d2946d3f84c1b16b8fcfd39/experiments/figures/pairwise_oracle_pc_simulation.pdf -------------------------------------------------------------------------------- /experiments/main_simulations/exp_quick_settings.py: -------------------------------------------------------------------------------- 1 | from sparse_shift.methods import MinChange, AugmentedPC, FullMinChanges, ParamChanges 2 | 3 | PARAMS_DICT = { 4 | "pairwise_power": [ 5 | { 6 | "n_variables": [6], 7 | "n_total_environments": [3], 8 | "sparsity": [1, 3, 5], 9 | 'intervention_targets': [None], 10 | "sample_size": [100], 11 | "dag_density": [0.3], 12 | "reps": [2], 13 | "data_simulator": ['cdnod'], 14 | "dag_simulator": ["er"], 15 | }, 16 | { 17 | "n_variables": [6], 18 | "n_total_environments": [5], 19 | "sparsity": [1/3], 20 | 'intervention_targets': [None], 21 | "sample_size": [100], 22 | "dag_density": [0.3, 0.5, 0.7], 23 | "reps": [2], 24 | "data_simulator": ['cdnod'], 25 | "dag_simulator": ["er"], 26 | }, 27 | { 28 | "n_variables": [3, 6, 9], 29 | "n_total_environments": [3], 30 | "sparsity": [1/3], 31 | 'intervention_targets': [None], 32 | "sample_size": [100], 33 | "dag_density": [0.3], 34 | "reps": [2], 35 | "data_simulator": ['cdnod'], 36 | "dag_simulator": ["er"], 37 | }, 38 | { 39 | "n_variables": [6], 40 | "n_total_environments": [3], 41 | "sparsity": [1/3], 42 | 'intervention_targets': [None], 43 | "sample_size": [50, 100], 44 | "dag_density": [0.3], 45 | "reps": [2], 46 | "data_simulator": ['cdnod'], 47 | "dag_simulator": ["er"], 48 | }, 49 | { 50 | "n_variables": [6], 51 | "n_total_environments": [3], 52 | "sparsity": [1/3], 53 | 'intervention_targets': [None], 54 | "sample_size": [100], 55 | "dag_density": [0.3], 56 | "reps": [2], 57 | "data_simulator": ['cdnod'], 58 | "dag_simulator": ["er"], 59 | }, 60 | ], 61 | "oracle_rates": [{ 62 | "n_variables": [4, 6, 8], 63 | "n_total_environments": [5], 64 | "sparsity": [1/5, 1/2, 4/5], 65 | 'intervention_targets': [None], 66 | "sample_size": [None], 67 | "dag_density": [0.1, 0.5, 0.9], 68 | "reps": [2], 69 | "data_simulator": [None], 70 | "dag_simulator": ["er"], 71 | }], 72 | "oracle_select_rates": [ 73 | { 74 | "n_variables": [6, 8, 10], 75 | "n_total_environments": [5], 76 | "sparsity": [0.5], 77 | 'intervention_targets': [None], 78 | "sample_size": [None], 79 | "dag_density": [0.3], 80 | "reps": [2], 81 | "data_simulator": [None], 82 | "dag_simulator": ["er", 'ba'], 83 | }, 84 | { 85 | "n_variables": [8], 86 | "n_total_environments": [5], 87 | "sparsity": [1, 3, 5, 7], 88 | 'intervention_targets': [None], 89 | "sample_size": [None], 90 | "dag_density": [0.3], 91 | "reps": [2], 92 | "data_simulator": [None], 93 | "dag_simulator": ["er", 'ba'], 94 | }, 95 | { 96 | "n_variables": [8], 97 | "n_total_environments": [5], 98 | "sparsity": [0.5], 99 | 'intervention_targets': [None], 100 | "sample_size": [None], 101 | "dag_density": [0.3], 102 | "reps": [2], 103 | "data_simulator": [None], 104 | "dag_simulator": ["er", 'ba'], 105 | }, 106 | { 107 | # Since 'ba' can't handle all of the same settings as 'er' 108 | "n_variables": [8], 109 | "n_total_environments": [5], 110 | "sparsity": [0.5], 111 | 'intervention_targets': [None], 112 | "sample_size": [None], 113 | "dag_density": [0.5, 0.9], 114 
| "reps": [2], 115 | "data_simulator": [None], 116 | "dag_simulator": ['ba'], 117 | }, 118 | { 119 | "n_variables": [8], 120 | "n_total_environments": [5], 121 | "sparsity": [0.5], 122 | 'intervention_targets': [None], 123 | "sample_size": [None], 124 | "dag_density": [0.5, 0.9, 0.1], 125 | "reps": [2], 126 | "data_simulator": [None], 127 | "dag_simulator": ["er"], 128 | }, 129 | ], 130 | "bivariate_power": [{ 131 | "n_variables": [2], 132 | "n_total_environments": [2], 133 | "sparsity": [None], 134 | 'intervention_targets': [ 135 | [[], [0]], 136 | [[], [1]], 137 | [[], []], 138 | [[], [0, 1]], 139 | ], 140 | "sample_size": [100], 141 | "dag_density": [None], 142 | "reps": [2], 143 | "data_simulator": ["cdnod"], 144 | "dag_simulator": ["complete"], 145 | }], 146 | "bivariate_multiplic_power": [{ 147 | "n_variables": [2], 148 | "n_total_environments": [2], 149 | "sparsity": [None], 150 | 'intervention_targets': [ 151 | [[], [0]], 152 | [[], [1]], 153 | [[], []], 154 | [[], [0, 1]], 155 | ], 156 | "sample_size": [100], 157 | "dag_density": [None], 158 | "reps": [2], 159 | "data_simulator": ["cdnod"], 160 | "dag_simulator": ["complete"], 161 | }] 162 | } 163 | 164 | 165 | # save name, method name, algo, hpyerparams 166 | ALL_METHODS = [ 167 | ( 168 | 'mch_kci', 169 | 'Min changes (KCI)', 170 | MinChange, 171 | { 172 | 'alpha': 0.05, 173 | 'scale_alpha': True, 174 | 'test': 'kci', 175 | 'test_kwargs': { 176 | "KernelX": "GaussianKernel", 177 | "KernelY": "GaussianKernel", 178 | "KernelZ": "GaussianKernel", 179 | }, 180 | } 181 | ), 182 | ( 183 | 'mch_lin', 184 | 'Min changes (Linear)', 185 | MinChange, 186 | { 187 | 'alpha': 0.05, 188 | 'scale_alpha': True, 189 | 'test': 'invariant_residuals', 190 | 'test_kwargs': {'method': 'linear', 'test': "whitney_levene"}, 191 | } 192 | ), 193 | ( 194 | 'mch_gam', 195 | 'Min changes (GAM)', 196 | MinChange, 197 | { 198 | 'alpha': 0.05, 199 | 'scale_alpha': True, 200 | 'test': 'invariant_residuals', 201 | 'test_kwargs': {'method': 'gam', 'test': "whitney_levene"}, 202 | } 203 | ), 204 | ( 205 | 'mch_fisherz', 206 | 'Min changes (FisherZ)', 207 | MinChange, 208 | { 209 | 'alpha': 0.05, 210 | 'scale_alpha': True, 211 | 'test': 'fisherz', 212 | 'test_kwargs': {}, 213 | } 214 | ), 215 | ( 216 | 'full_pc_kci', 217 | 'Full PC (KCI)', 218 | FullMinChanges, 219 | { 220 | 'alpha': 0.05, 221 | 'scale_alpha': True, 222 | 'test': 'kci', 223 | 'test_kwargs': {}, 224 | } 225 | ), 226 | ( 227 | 'mc', 228 | 'MC', 229 | ParamChanges, 230 | { 231 | 'alpha': 0.05, 232 | 'scale_alpha': True, 233 | } 234 | ) 235 | ] 236 | 237 | METHODS_DICT = { 238 | "DEBUG": ALL_METHODS, 239 | "pairwise_power": ALL_METHODS, 240 | "oracle_rates": [], 241 | "oracle_select_rates": [], 242 | "bivariate_power": ALL_METHODS, 243 | "bivariate_multiplic_power": ALL_METHODS, 244 | } 245 | 246 | 247 | def get_experiment_params(exp): 248 | return PARAMS_DICT[exp] 249 | 250 | 251 | def get_param_keys(exp): 252 | return list(PARAMS_DICT[exp][0].keys()) 253 | 254 | 255 | def get_experiments(): 256 | return list(PARAMS_DICT.keys()) 257 | 258 | 259 | def get_experiment_methods(exp): 260 | return METHODS_DICT[exp] 261 | -------------------------------------------------------------------------------- /experiments/main_simulations/exp_settings.py: -------------------------------------------------------------------------------- 1 | from sparse_shift.methods import MinChange, AugmentedPC, FullMinChanges, ParamChanges 2 | 3 | PARAMS_DICT = { 4 | "DEBUG": [{ 5 | "experiment": "debug", 6 | "n_variables": [3], 7 | 
"n_total_environments": [3], 8 | "sparsity": [1], 9 | 'intervention_targets': [None], 10 | "sample_size": [100], 11 | "dag_density": [0.3], 12 | "reps": [3], 13 | "data_simulator": ["cdnod"], 14 | "dag_simulator": ["er"], 15 | }], 16 | "pairwise_power": [ 17 | { 18 | "experiment": "sparsity", 19 | "n_variables": [6], 20 | "n_total_environments": [3], 21 | "sparsity": [1, 2, 3, 4, 5], 22 | 'intervention_targets': [None], 23 | "sample_size": [500], 24 | "dag_density": [0.3], 25 | "reps": [50], 26 | "data_simulator": ['cdnod'], 27 | "dag_simulator": ["er"], 28 | }, 29 | { 30 | "experiment": "dag_density", 31 | "n_variables": [6], 32 | "n_total_environments": [3], 33 | "sparsity": [1/3], 34 | 'intervention_targets': [None], 35 | "sample_size": [500], 36 | "dag_density": [0.3, 0.5, 0.7], 37 | "reps": [50], 38 | "data_simulator": ['cdnod'], 39 | "dag_simulator": ["er"], 40 | }, 41 | { 42 | "experiment": "n_variables", 43 | "n_variables": [3, 6, 9, 12], 44 | "n_total_environments": [3], 45 | "sparsity": [1/3], 46 | 'intervention_targets': [None], 47 | "sample_size": [500], 48 | "dag_density": [0.3], 49 | "reps": [50], 50 | "data_simulator": ['cdnod'], 51 | "dag_simulator": ["er"], 52 | }, 53 | { 54 | "experiment": "sample_size", 55 | "n_variables": [6], 56 | "n_total_environments": [3], 57 | "sparsity": [1/3], 58 | 'intervention_targets': [None], 59 | "sample_size": [50, 100, 200, 500, 1000, 2000], 60 | "dag_density": [0.3], 61 | "reps": [40], 62 | "data_simulator": ['cdnod'], 63 | "dag_simulator": ["er"], 64 | }, 65 | { 66 | "experiment": "n_total_environments", 67 | "n_variables": [6], 68 | "n_total_environments": [15], 69 | "sparsity": [1/3], 70 | 'intervention_targets': [None], 71 | "sample_size": [500], 72 | "dag_density": [0.3], 73 | "reps": [40], 74 | "data_simulator": ['cdnod'], 75 | "dag_simulator": ["er"], 76 | }, 77 | ], 78 | # "environment_convergence": [{ 79 | # "n_variables": [6], 80 | # "n_total_environments": [10], 81 | # "sparsity": [1, 2, 4], 82 | # 'intervention_targets': [None], 83 | # "sample_size": [500], 84 | # "dag_density": [0.3], 85 | # "reps": [20], 86 | # "data_simulator": ["cdnod"], 87 | # "dag_simulator": ["er"], 88 | # }], 89 | # "soft_samples": [{ 90 | # "n_variables": [6], 91 | # "n_total_environments": [5], 92 | # "sparsity": [1, 2, 3, 4, 5, 6], 93 | # 'intervention_targets': [None], 94 | # "sample_size": [50, 100, 200, 300, 500], 95 | # "dag_density": [0.3], 96 | # "reps": [20], 97 | # "data_simulator": ["cdnod"], 98 | # "dag_simulator": ["er"], 99 | # }], 100 | "oracle_rates": [{ 101 | "n_variables": [4, 6, 8, 10, 12], 102 | "n_total_environments": [5], 103 | "sparsity": [1/5, 1/3, 1/2, 2/3, 4/5], 104 | 'intervention_targets': [None], 105 | "sample_size": [None], 106 | "dag_density": [0.1, 0.3, 0.5, 0.7, 0.9], 107 | "reps": [20], 108 | "data_simulator": [None], 109 | "dag_simulator": ["er"], 110 | }], 111 | "oracle_select_rates": [ 112 | { 113 | "n_variables": [6, 8, 10, 12], 114 | "n_total_environments": [5], 115 | "sparsity": [0.5], 116 | 'intervention_targets': [None], 117 | "sample_size": [None], 118 | "dag_density": [0.3], 119 | "reps": [20], 120 | "data_simulator": [None], 121 | "dag_simulator": ["er", 'ba'], 122 | }, 123 | { 124 | "n_variables": [8], 125 | "n_total_environments": [5], 126 | "sparsity": [1, 2, 3, 4, 5, 6, 7], 127 | 'intervention_targets': [None], 128 | "sample_size": [None], 129 | "dag_density": [0.3], 130 | "reps": [20], 131 | "data_simulator": [None], 132 | "dag_simulator": ["er", 'ba'], 133 | }, 134 | { 135 | "n_variables": [8], 136 | 
"n_total_environments": [15], 137 | "sparsity": [0.5], 138 | 'intervention_targets': [None], 139 | "sample_size": [None], 140 | "dag_density": [0.3], 141 | "reps": [20], 142 | "data_simulator": [None], 143 | "dag_simulator": ["er", 'ba'], 144 | }, 145 | { 146 | "n_variables": [8], 147 | "n_total_environments": [5], 148 | "sparsity": [0.5], 149 | 'intervention_targets': [None], 150 | "sample_size": [None], 151 | "dag_density": [0.3, 0.5, 0.7, 0.9, 0.1], 152 | "reps": [20], 153 | "data_simulator": [None], 154 | "dag_simulator": ["er", 'ba'], 155 | }, 156 | { 157 | # Since 'ba' can't handle all of the same settings as 'er' 158 | "n_variables": [8], 159 | "n_total_environments": [5], 160 | "sparsity": [0.5], 161 | 'intervention_targets': [None], 162 | "sample_size": [None], 163 | "dag_density": [0.3, 0.5, 0.7, 0.9], 164 | "reps": [20], 165 | "data_simulator": [None], 166 | "dag_simulator": ["ba"], 167 | }, 168 | ], 169 | "bivariate_power": [{ 170 | "n_variables": [2], 171 | "n_total_environments": [2], 172 | "sparsity": [None], 173 | 'intervention_targets': [ 174 | [[], [0]], 175 | [[], [1]], 176 | [[], []], 177 | [[], [0, 1]], 178 | ], 179 | "sample_size": [500], 180 | "dag_density": [None], 181 | "reps": [50], 182 | "data_simulator": ["cdnod"], 183 | "dag_simulator": ["complete"], 184 | }], 185 | "bivariate_multiplic_power": [{ 186 | "n_variables": [2], 187 | "n_total_environments": [2], 188 | "sparsity": [None], 189 | 'intervention_targets': [ 190 | [[], [0]], 191 | [[], [1]], 192 | [[], []], 193 | [[], [0, 1]], 194 | ], 195 | "sample_size": [500], 196 | "dag_density": [None], 197 | "reps": [50], 198 | "data_simulator": ["cdnod"], 199 | "dag_simulator": ["complete"], 200 | }] 201 | } 202 | 203 | 204 | # save name, method name, algo, hpyerparams 205 | ALL_METHODS = [ 206 | ( 207 | 'mch_kci', 208 | 'Min changes (KCI)', 209 | MinChange, 210 | { 211 | 'alpha': 0.05, 212 | 'scale_alpha': True, 213 | 'test': 'kci', 214 | 'test_kwargs': { 215 | "KernelX": "GaussianKernel", 216 | "KernelY": "GaussianKernel", 217 | "KernelZ": "GaussianKernel", 218 | }, 219 | } 220 | ), 221 | ( 222 | 'mch_lin', 223 | 'Min changes (Linear)', 224 | MinChange, 225 | { 226 | 'alpha': 0.05, 227 | 'scale_alpha': True, 228 | 'test': 'invariant_residuals', 229 | 'test_kwargs': {'method': 'linear', 'test': "whitney_levene"}, 230 | } 231 | ), 232 | ( 233 | 'mch_gam', 234 | 'Min changes (GAM)', 235 | MinChange, 236 | { 237 | 'alpha': 0.05, 238 | 'scale_alpha': True, 239 | 'test': 'invariant_residuals', 240 | 'test_kwargs': {'method': 'gam', 'test': "whitney_levene"}, 241 | } 242 | ), 243 | ( 244 | 'mch_fisherz', 245 | 'Min changes (FisherZ)', 246 | MinChange, 247 | { 248 | 'alpha': 0.05, 249 | 'scale_alpha': True, 250 | 'test': 'fisherz', 251 | 'test_kwargs': {}, 252 | } 253 | ), 254 | ( 255 | 'full_pc_kci', 256 | 'Full PC (KCI)', 257 | FullMinChanges, 258 | { 259 | 'alpha': 0.05, 260 | 'scale_alpha': True, 261 | 'test': 'kci', 262 | 'test_kwargs': {}, 263 | } 264 | ), 265 | ( 266 | 'mc', 267 | 'MC', 268 | ParamChanges, 269 | { 270 | 'alpha': 0.05, 271 | 'scale_alpha': True, 272 | } 273 | ) 274 | ] 275 | 276 | METHODS_DICT = { 277 | "DEBUG": ALL_METHODS, 278 | "pairwise_power": ALL_METHODS, 279 | # "environment_convergence": ALL_METHODS, 280 | # "soft_samples": ALL_METHODS, 281 | "oracle_rates": [], 282 | "oracle_select_rates": [], 283 | "bivariate_power": ALL_METHODS, 284 | "bivariate_multiplic_power": ALL_METHODS, 285 | # ( 286 | # 'mch_kcd', 287 | # 'KCD', 288 | # MinChange, 289 | # { 290 | # 'alpha': 0.05, 291 | # 
'scale_alpha': True, 292 | # 'test': 'fisherz', 293 | # 'test_kwargs': {'n_jobs': -2, 'n_reps': 100}, 294 | # } 295 | # ), 296 | # ], 297 | } 298 | 299 | 300 | def get_experiment_params(exp): 301 | return PARAMS_DICT[exp] 302 | 303 | 304 | def get_param_keys(exp): 305 | return list(PARAMS_DICT[exp][0].keys()) 306 | 307 | 308 | def get_experiments(): 309 | return list(PARAMS_DICT.keys()) 310 | 311 | 312 | def get_experiment_methods(exp): 313 | return METHODS_DICT[exp] 314 | -------------------------------------------------------------------------------- /experiments/main_simulations/figures/bivariate_multiplic_power_plots.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rflperry/sparse_shift/8620c5ccf7b7b28b2d2946d3f84c1b16b8fcfd39/experiments/main_simulations/figures/bivariate_multiplic_power_plots.pdf -------------------------------------------------------------------------------- /experiments/main_simulations/figures/bivariate_power_plots.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rflperry/sparse_shift/8620c5ccf7b7b28b2d2946d3f84c1b16b8fcfd39/experiments/main_simulations/figures/bivariate_power_plots.pdf -------------------------------------------------------------------------------- /experiments/main_simulations/figures/empirical_select_rates_er_others.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rflperry/sparse_shift/8620c5ccf7b7b28b2d2946d3f84c1b16b8fcfd39/experiments/main_simulations/figures/empirical_select_rates_er_others.pdf -------------------------------------------------------------------------------- /experiments/main_simulations/figures/empirical_select_rates_er_ours.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rflperry/sparse_shift/8620c5ccf7b7b28b2d2946d3f84c1b16b8fcfd39/experiments/main_simulations/figures/empirical_select_rates_er_ours.pdf -------------------------------------------------------------------------------- /experiments/main_simulations/figures/oracle_rate_relplot.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rflperry/sparse_shift/8620c5ccf7b7b28b2d2946d3f84c1b16b8fcfd39/experiments/main_simulations/figures/oracle_rate_relplot.pdf -------------------------------------------------------------------------------- /experiments/main_simulations/figures/oracle_select_rates_all_models.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rflperry/sparse_shift/8620c5ccf7b7b28b2d2946d3f84c1b16b8fcfd39/experiments/main_simulations/figures/oracle_select_rates_all_models.pdf -------------------------------------------------------------------------------- /experiments/main_simulations/figures/oracle_select_rates_ba.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rflperry/sparse_shift/8620c5ccf7b7b28b2d2946d3f84c1b16b8fcfd39/experiments/main_simulations/figures/oracle_select_rates_ba.pdf -------------------------------------------------------------------------------- /experiments/main_simulations/figures/oracle_select_rates_er.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/rflperry/sparse_shift/8620c5ccf7b7b28b2d2946d3f84c1b16b8fcfd39/experiments/main_simulations/figures/oracle_select_rates_er.pdf -------------------------------------------------------------------------------- /experiments/main_simulations/plot_bivariate_identifiability.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # In[1]: 5 | 6 | 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | import seaborn as sns 10 | import pandas as pd 11 | 12 | 13 | # In[2]: 14 | 15 | 16 | SAVE_FIGURES = False 17 | RESULTS_DIR = './results_paper' 18 | 19 | 20 | # In[3]: 21 | 22 | 23 | EXPERIMENT = 'bivariate_power' 24 | TAG = '_paper' 25 | df1 = pd.read_csv(f'{RESULTS_DIR}/{EXPERIMENT}_results{TAG}.csv', sep=', ', engine='python') 26 | df1['Experiment'] = 'Additive' 27 | 28 | EXPERIMENT = 'bivariate_multiplic_power' 29 | TAG = '' 30 | df2 = pd.read_csv(f'{RESULTS_DIR}/{EXPERIMENT}_results{TAG}.csv', sep=', ', engine='python') 31 | df2['Experiment'] = 'Multiplic.' 32 | 33 | df = pd.concat([df1, df2]) 34 | 35 | 36 | # In[4]: 37 | 38 | 39 | plot_df = df 40 | 41 | x_var_rename_dict = { 42 | 'sample_size': '# Samples', 43 | 'Number of environments': '# Environments', 44 | 'Fraction of shifting mechanisms': 'Shift fraction', 45 | 'dag_density': 'Edge density', 46 | 'n_variables': '# Variables', 47 | } 48 | 49 | plot_df = df.rename( 50 | x_var_rename_dict, axis=1 51 | ).rename( 52 | {'Method': 'Test', 'Soft': 'Score'}, axis=1 53 | ).replace( 54 | { 55 | 'er': 'Erdos-Renyi', 56 | 'ba': 'Hub', 57 | 'PC (pool all)': 'Full PC (oracle)', 58 | 'Full PC (KCI)': r'Pooled PC (KCI) [25]', 59 | 'Min changes (oracle)': 'MSS (oracle)', 60 | 'Min changes (KCI)': 'MSS (KCI)', 61 | 'Min changes (GAM)': 'MSS (GAM)', 62 | 'Min changes (Linear)': 'MSS (Linear)', 63 | 'Min changes (FisherZ)': 'MSS (FisherZ)', 64 | 'MC': r'MC [11]', 65 | False: 'Hard', 66 | True: 'Soft', 67 | } 68 | ) 69 | 70 | plot_df = plot_df.loc[ 71 | (~plot_df['Test'].isin(['Full PC (oracle)', 'MSS (oracle)'])) & 72 | (plot_df['# Environments'] == 2) & 73 | (plot_df['Score'] == 'Hard') 74 | ] 75 | 76 | plot_df = plot_df.replace({ 77 | '[[];[0]]': 'P(X1)', 78 | '[[];[1]]': 'P(X2|X1)', 79 | '[[];[]]': 'Neither', 80 | '[[];[0;1]]': 'Both', 81 | }) 82 | 83 | 84 | # In[5]: 85 | 86 | 87 | sns.set_context('paper') 88 | fig, axes = plt.subplots(1, 4, sharey=True, sharex=True, figsize=(7.5, 2.5)) 89 | 90 | intv_targets = ['P(X1)', 'P(X2|X1)', 'Neither', 'Both'] 91 | ax_var = 'intervention_targets' 92 | x_var = 'Precision' # 'False orientation rate' # 93 | y_var = 'Recall' # 'True orientation rate'# 94 | hue = 'Test' 95 | 96 | for targets, ax in zip(intv_targets, axes.flatten()): 97 | mean_df = plot_df[plot_df[ax_var] == targets].groupby('Test').mean().reset_index() 98 | std_df = plot_df[plot_df[ax_var] == targets].groupby('Test')[['Precision', 'Recall']].std().reset_index() 99 | std_df = std_df.rename( 100 | {'Precision': 'Precision std', 'Recall': 'Recall std'}, axis=1 101 | )  # kept for reference; the standard deviations are not plotted 102 | 103 | g = sns.scatterplot( 104 | data=mean_df, 105 | x=x_var, 106 | y=y_var, 107 | hue=hue, 108 | ax=ax, 109 | palette=[ 110 | sns.color_palette("tab10")[i] 111 | for i in [2, 3, 4, 5, 7, 6] # 3, 4, 5, 112 | ], 113 | hue_order=[ 114 | 'MSS (KCI)', 115 | 'MSS (GAM)', 116 | 'MSS (FisherZ)', 117 | 'MSS (Linear)', 118 | 'Pooled PC (KCI) [25]', 119 | 'MC [11]', 120 | ], 121 | legend='full', 122 | s=100 123 | ) 124
| ax.set_title(f'Shift in {targets}') 125 | 126 | plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) 127 | for ax in axes[:-1]: 128 | ax.get_legend().remove() 129 | plt.tight_layout() 130 | if SAVE_FIGURES: 131 | plt.savefig('./figures/bivariate_power_plots.pdf') 132 | plt.show() 133 | 134 | 135 | # ## multiplicative 136 | 137 | # In[7]: 138 | 139 | 140 | EXPERIMENT = 'bivariate_multiplic_power' 141 | TAG = '' 142 | df = pd.read_csv(f'{RESULTS_DIR}/{EXPERIMENT}_results{TAG}.csv', sep=', ', engine='python') 143 | 144 | 145 | # In[8]: 146 | 147 | 148 | plot_df = df 149 | 150 | x_var_rename_dict = { 151 | 'sample_size': '# Samples', 152 | 'Number of environments': '# Environments', 153 | 'Fraction of shifting mechanisms': 'Shift fraction', 154 | 'dag_density': 'Edge density', 155 | 'n_variables': '# Variables', 156 | } 157 | 158 | plot_df = df.rename( 159 | x_var_rename_dict, axis=1 160 | ).rename( 161 | {'Method': 'Test', 'Soft': 'Score'}, axis=1 162 | ).replace( 163 | { 164 | 'er': 'Erdos-Renyi', 165 | 'ba': 'Hub', 166 | 'PC (pool all)': 'Full PC (oracle)', 167 | 'Full PC (KCI)': r'Pooled PC (KCI) [25]', 168 | 'Min changes (oracle)': 'MSS (oracle)', 169 | 'Min changes (KCI)': 'MSS (KCI)', 170 | 'Min changes (GAM)': 'MSS (GAM)', 171 | 'Min changes (Linear)': 'MSS (Linear)', 172 | 'Min changes (FisherZ)': 'MSS (FisherZ)', 173 | 'MC': r'MC [11]', 174 | False: 'Hard', 175 | True: 'Soft', 176 | } 177 | ) 178 | 179 | plot_df = plot_df.loc[ 180 | (~plot_df['Test'].isin(['Full PC (oracle)', 'MSS (oracle)'])) & 181 | (plot_df['# Environments'] == 2) & 182 | (plot_df['Score'] == 'Hard') 183 | ] 184 | 185 | plot_df = plot_df.replace({ 186 | '[[];[0]]': 'P(X1)', 187 | '[[];[1]]': 'P(X2|X1)', 188 | '[[];[]]': 'Neither', 189 | '[[];[0;1]]': 'Both', 190 | }) 191 | 192 | 193 | # In[9]: 194 | 195 | 196 | sns.set_context('paper') 197 | fig, axes = plt.subplots(1, 4, sharey=True, sharex=True, figsize=(7.5, 2.5)) 198 | 199 | intv_targets = ['P(X1)', 'P(X2|X1)', 'Neither', 'Both'] 200 | ax_var = 'intervention_targets' 201 | x_var = 'Precision' # 'False orientation rate' # 202 | y_var = 'Recall' # 'True orientation rate'# 203 | hue = 'Test' 204 | 205 | for targets, ax in zip(intv_targets, axes.flatten()): 206 | mean_df = plot_df[plot_df[ax_var] == targets].groupby('Test').mean().reset_index() 207 | std_df = plot_df[plot_df[ax_var] == targets].groupby('Test')[['Precision', 'Recall']].std().reset_index() 208 | std_df = std_df.rename( 209 | {'Precision': 'Precision std', 'Recall': 'Recall std'}, axis=1 210 | )  # kept for reference; the standard deviations are not plotted 211 | 212 | g = sns.scatterplot( 213 | data=mean_df, 214 | x=x_var, 215 | y=y_var, 216 | hue=hue, 217 | ax=ax, 218 | palette=[ 219 | sns.color_palette("tab10")[i] 220 | for i in [2, 3, 4, 5, 7, 6] # 3, 4, 5, 221 | ], 222 | hue_order=[ 223 | 'MSS (KCI)', 224 | 'MSS (GAM)', 225 | 'MSS (FisherZ)', 226 | 'MSS (Linear)', 227 | 'Pooled PC (KCI) [25]', 228 | 'MC [11]', 229 | ], 230 | legend='full', 231 | s=100 232 | ) 233 | ax.set_title(f'Shift in {targets}') 234 | 235 | plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
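# Note: with legend='full', seaborn attaches a legend to every axis. The
# plt.legend call above redraws the final axis' legend outside the plot area,
# and the loop below removes the duplicate legends from the other axes.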
236 | for ax in axes[:-1]: 237 | ax.get_legend().remove() 238 | plt.tight_layout() 239 | if SAVE_FIGURES: 240 | plt.savefig('./figures/bivariate_multiplic_power_plots.pdf') 241 | plt.show() 242 | 243 | -------------------------------------------------------------------------------- /experiments/main_simulations/plot_empirical_power.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # In[1]: 5 | 6 | 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | import seaborn as sns 10 | import pandas as pd 11 | 12 | 13 | # In[2]: 14 | 15 | 16 | SAVE_FIGURES = False 17 | RESULTS_DIR = './results_paper' 18 | 19 | 20 | # In[3]: 21 | 22 | 23 | EXPERIMENT = 'pairwise_power' 24 | 25 | tag = '_paper' 26 | 27 | df = pd.read_csv(f'{RESULTS_DIR}/{EXPERIMENT}_results{tag}.csv', sep=',', engine='python') 28 | 29 | 30 | # In[4]: 31 | 32 | 33 | df = df.loc[df['Precision'].notna(), :] 34 | 35 | df['Fraction of shifting mechanisms'] = df['sparsity'] / df['n_variables'] 36 | 37 | df['F1'] = 2 * df['Recall'] * df['Precision'] / (df['Recall'] + df['Precision']) 38 | 39 | x_var_rename_dict = { 40 | 'sample_size': '# Samples', 41 | 'Number of environments': '# Environments', 42 | 'Fraction of shifting mechanisms': 'Shift fraction', 43 | 'dag_density': 'Edge density', 44 | 'n_variables': '# Variables', 45 | } 46 | 47 | plot_df = df.rename( 48 | x_var_rename_dict, axis=1 49 | ).rename( 50 | {'Method': r'$\bf{Test}$', 'Soft': r'$\bf{Score}$'}, axis=1 51 | ).replace( 52 | { 53 | 'er': 'Erdos-Renyi', 54 | 'ba': 'Hub', 55 | 'PC (pool all)': 'Full PC (oracle)', 56 | 'Full PC (KCI)': r'Pooled PC (KCI) [25]', 57 | 'Min changes (oracle)': 'MSS (oracle)', 58 | 'Min changes (KCI)': 'MSS (KCI)', 59 | 'Min changes (GAM)': 'MSS (GAM)', 60 | 'Min changes (Linear)': 'MSS (Linear)', 61 | 'Min changes (FisherZ)': 'MSS (FisherZ)', 62 | 'MC': r'MC [11]', 63 | False: 'Hard', 64 | True: 'Soft', 65 | } 66 | ) 67 | 68 | 69 | # In[5]: 70 | 71 | 72 | sns.set_context('paper') 73 | 74 | grid_vars = list(x_var_rename_dict.values()) 75 | metrics = ['Recall', 'Precision', 'F1'] 76 | 77 | n_settings = [5, 3, 4, 6, 1] 78 | 79 | indices = [ 80 | (a, b) for a, b in zip(np.cumsum([0] + n_settings)[:-1], np.cumsum(n_settings)) 81 | ] 82 | indices = indices[-2:] + indices[:-2] 83 | 84 | fig, axes = plt.subplots( 85 | len(metrics), 86 | len(grid_vars), 87 | sharey='row', sharex='col', 88 | figsize=(1.5*len(grid_vars), 3) 89 | ) 90 | 91 | for row, metric in zip(axes, metrics): 92 | for g_var, (lb, ub), ax in zip(grid_vars, indices, row): 93 | plot_df_ax = plot_df[ 94 | (plot_df['params_index'] >= lb) 95 | & (plot_df['params_index'] < ub) 96 | & (~plot_df[r'$\bf{Test}$'].isin([ 97 | 'Full PC (oracle)', 98 | 'MSS (oracle)', 99 | 'MSS (GAM)', 100 | 'MSS (FisherZ)', 101 | 'MSS (Linear)' 102 | ])) 103 | ] 104 | plot_df_ax = pd.concat( 105 | ( 106 | plot_df_ax[plot_df_ax[r'$\bf{Score}$'] == 'Hard'], 107 | plot_df_ax[(plot_df_ax[r'$\bf{Score}$'] == 'Soft') & (plot_df_ax[r'$\bf{Test}$'] == 'MSS (KCI)')] 108 | ), 109 | ignore_index=True 110 | ) 111 | 112 | if g_var != '# Environments': 113 | plot_df_ax = plot_df_ax[ 114 | # IMPORTANT! 
otherwise average over all number of environments 115 | plot_df_ax['# Environments'] == plot_df_ax['# Environments'].max() 116 | ] 117 | 118 | sns.lineplot( 119 | data=plot_df_ax, 120 | x=g_var, 121 | y=metric, 122 | hue=r'$\bf{Test}$', 123 | style=r'$\bf{Score}$', 124 | ax=ax, 125 | palette=[ 126 | sns.color_palette("tab10")[i] 127 | for i in [2, 7, 6] 128 | ], 129 | legend='full', 130 | style_order=['Hard', 'Soft'], 131 | lw=2, 132 | ) 133 | 134 | xmin = plot_df_ax[g_var].min() 135 | xmax = plot_df_ax[g_var].max() 136 | if xmax > 1: 137 | ax.set_xticks([ 138 | xmin, 139 | int(xmin + (xmax - xmin) / 2), 140 | xmax, 141 | ]) 142 | else: 143 | ax.set_xticks([ 144 | np.round(xmin, 1), 145 | np.round(xmin + (xmax - xmin) / 2 , 1), 146 | np.round(xmax, 1), 147 | ]) 148 | 149 | leg_idx = 4 150 | 151 | axes = np.concatenate(axes) 152 | 153 | for i in range(len(axes)): 154 | if i == 0: 155 | axes[i].set_xscale('log') 156 | if i == leg_idx: 157 | axes[i].legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) 158 | plt.setp(axes[i].get_legend().get_title(), fontsize=22) 159 | else: 160 | try: 161 | axes[i].get_legend().remove() 162 | except: 163 | pass 164 | 165 | plt.tight_layout() 166 | plt.subplots_adjust(hspace=0.15) 167 | if SAVE_FIGURES: 168 | plt.savefig(f'./figures/empirical_select_rates_er_others.pdf') 169 | plt.show() 170 | 171 | 172 | # ## Soft 173 | 174 | # In[6]: 175 | 176 | 177 | # sns.set_context('paper') 178 | 179 | fig, axes = plt.subplots( 180 | len(metrics), 181 | len(grid_vars), 182 | sharey='row', sharex='col', 183 | figsize=(1.5*len(grid_vars), 3.5) 184 | ) 185 | 186 | for row, metric in zip(axes, metrics): 187 | # for row, metric in zip(axes, ['Recall', 'Precision', 'Average precision']): 188 | for g_var, (lb, ub), ax in zip(grid_vars, indices, row): 189 | plot_df_ax = plot_df[ 190 | (plot_df['params_index'] >= lb) 191 | & (plot_df['params_index'] < ub) 192 | & (plot_df[r'$\bf{Test}$'].isin(['MSS (KCI)', 'MSS (GAM)', 'MSS (FisherZ)']))#, 'MC'])) 193 | ] 194 | 195 | if g_var != '# Environments': 196 | plot_df_ax = plot_df_ax[ 197 | # IMPORTANT! otherwise average over all number of environments 198 | plot_df_ax['# Environments'] == plot_df_ax['# Environments'].max() 199 | ] 200 | 201 | sns.lineplot( 202 | data=plot_df_ax, 203 | x=g_var, 204 | y=metric, 205 | hue=r'$\bf{Test}$', 206 | style=r'$\bf{Score}$', 207 | ax=ax, 208 | palette=[ 209 | sns.color_palette("tab10")[i] 210 | for i in [2, 3, 4] 211 | ], 212 | legend='full', 213 | style_order=['Hard', 'Soft'], 214 | ) 215 | 216 | xmin = plot_df_ax[g_var].min() 217 | xmax = plot_df_ax[g_var].max() 218 | if xmax > 1: 219 | ax.set_xticks([ 220 | xmin, 221 | int(xmin + (xmax - xmin) / 2), 222 | xmax, 223 | ]) 224 | else: 225 | ax.set_xticks([ 226 | np.round(xmin, 1), 227 | np.round(xmin + (xmax - xmin) / 2 , 1), 228 | np.round(xmax, 1), 229 | ]) 230 | 231 | leg_idx = 4 232 | axes = np.concatenate(axes) 233 | for i in range(len(axes)): 234 | if i == 0: 235 | axes[i].set_xscale('log') 236 | if i == leg_idx: 237 | axes[i].legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) 
238 | else: 239 | axes[i].get_legend().remove() 240 | 241 | plt.tight_layout() 242 | plt.subplots_adjust(hspace=0.15) 243 | if SAVE_FIGURES: 244 | plt.savefig(f'./figures/empirical_select_rates_er_ours.pdf') 245 | plt.show() 246 | 247 | 248 | # 249 | -------------------------------------------------------------------------------- /experiments/main_simulations/plot_oracle_rates.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # In[1]: 5 | 6 | 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | import seaborn as sns 10 | import pandas as pd 11 | 12 | 13 | # In[8]: 14 | 15 | 16 | SAVE_FIGURES = False 17 | RESULTS_DIR = './results_paper' 18 | 19 | 20 | # ## Full rates 21 | 22 | # In[9]: 23 | 24 | 25 | EXPERIMENT = 'oracle_rates' 26 | 27 | df = pd.read_csv(f'{RESULTS_DIR}/{EXPERIMENT}_results_paper.csv', sep=', ', engine='python') 28 | 29 | df = df.rename( 30 | { 31 | 'sample_size': '# Samples', 32 | 'Number of environments': '# Environments', 33 | 'sparsity': 'Shift fraction', 34 | 'dag_density': 'Density', 35 | 'n_variables': 'Vars', 36 | 'Method': 'Test', 37 | }, axis=1 38 | ).replace( 39 | { 40 | 'Full PC (oracle)': 'Pooled PC (oracle) [25]', 41 | 'Min changes (oracle)': 'MSS (oracle)', 42 | } 43 | ) 44 | 45 | 46 | # In[10]: 47 | 48 | 49 | sns.set_context('talk') 50 | 51 | for n_env in [5]: 52 | g = sns.relplot( 53 | data=df[ 54 | (df['# Environments'] == n_env) 55 | ], 56 | x='Shift fraction', 57 | y='Recall', 58 | hue='Test', 59 | row='Density', 60 | col='Vars', 61 | kind='line', 62 | height=2, 63 | aspect=2, 64 | legend=None,#'full', 65 | ) 66 | 67 | if SAVE_FIGURES: 68 | plt.savefig('./figures/oracle_rate_relplot.pdf') 69 | plt.show() 70 | 71 | 72 | # ## Marginal rates 73 | 74 | # In[11]: 75 | 76 | 77 | EXPERIMENT = 'oracle_select_rates' 78 | 79 | df = pd.read_csv(f'{RESULTS_DIR}/{EXPERIMENT}_results_paper.csv', sep=', ', engine='python') 80 | 81 | df['Fraction of shifting mechanisms'] = df['sparsity'] / df['n_variables'] 82 | 83 | df = df.rename( 84 | { 85 | 'sample_size': '# Samples', 86 | 'Number of environments': '# Environments', 87 | 'Fraction of shifting mechanisms': 'Shift fraction', 88 | 'dag_density': 'Edge density', 89 | 'n_variables': '# Variables', 90 | }, axis=1 91 | ).replace( 92 | { 93 | 'Full PC (oracle)': 'Pooled PC (oracle) [25]', 94 | 'Min changes (oracle)': 'MSS (oracle)', 95 | } 96 | ) 97 | 98 | 99 | # In[12]: 100 | 101 | 102 | sns.set_context('notebook') 103 | 104 | plot_df = df 105 | 106 | grid_vars = [ 107 | '# Environments', 'Shift fraction', 'Edge density', '# Variables' 108 | ] 109 | 110 | # Indices are based on the size of the parameter sets tested on 111 | indices = [ 112 | (22, 36), (8, 22), (36, 46), (0, 8), 113 | ] 114 | 115 | for graph_model in plot_df['dag_simulator'].unique(): 116 | fig, axes = plt.subplots(1, 4, sharey=True, figsize=(12, 2.5)) 117 | 118 | for g_var, (lb, ub), ax in zip(grid_vars, indices, axes.flatten()): 119 | plot_df_ax = plot_df[ 120 | (plot_df['params_index'] >= lb) 121 | & (plot_df['params_index'] < ub) 122 | & (plot_df['dag_simulator'] == graph_model) 123 | # IMPORTANT! 
otherwise average over all number of environments 124 | & ( 125 | (plot_df['# Environments'] == 3) 126 | if not (g_var == '# Environments') else True) 127 | ] 128 | 129 | sns.lineplot( 130 | data=plot_df_ax, 131 | x=g_var, 132 | y='Recall', 133 | hue='Method', 134 | ax=ax, 135 | hue_order=['MSS (oracle)', 'Pooled PC (oracle) [25]'], 136 | palette=[ 137 | sns.color_palette("tab10")[i] 138 | for i in [1, 0] # 3, 4, 5, 139 | ], 140 | legend='full', 141 | ) 142 | 143 | ax.set_xticks([ 144 | np.round(plot_df_ax[g_var].min(), 1), 145 | np.round(plot_df_ax[g_var].max(), 1) 146 | ]) 147 | 148 | plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) 149 | plt.setp(axes[-1].get_legend().get_texts(), fontsize=12) 150 | for ax in axes[:-1]: 151 | ax.get_legend().remove() 152 | 153 | plt.ylim([0.4, 1.03]) 154 | plt.tight_layout() 155 | if SAVE_FIGURES: 156 | plt.savefig(f'./figures/oracle_select_rates_{graph_model}.pdf') 157 | plt.show() 158 | 159 | 160 | # In[13]: 161 | 162 | 163 | sns.set_context('notebook') 164 | 165 | plot_df = df 166 | 167 | grid_vars = [ 168 | '# Environments', 'Shift fraction', 'Edge density', '# Variables' 169 | ] 170 | 171 | indices = [ 172 | (22, 36), (8, 22), (36, 46), (0, 8), 173 | ] 174 | 175 | fig, axes = plt.subplots(1, 4, sharey=True, figsize=(12, 2.5)) 176 | 177 | for g_var, (lb, ub), ax in zip(grid_vars, indices, axes.flatten()): 178 | plot_df_ax = plot_df[ 179 | (plot_df['params_index'] >= lb) 180 | & (plot_df['params_index'] < ub) 181 | # IMPORTANT! otherwise average over all number of environments 182 | & ( 183 | (plot_df['# Environments'] == 3) 184 | if not (g_var == '# Environments') else True) 185 | ] 186 | 187 | sns.lineplot( 188 | data=plot_df_ax.replace( 189 | {'er': 'Erdos-Renyi', 'ba': 'Hub'}, 190 | ).rename( 191 | {'dag_simulator': r'$\bf{DAG\ Model}$', 'Method': r'$\bf{Test}$'}, 192 | axis=1), 193 | x=g_var, 194 | y='Recall', 195 | hue=r'$\bf{Test}$', 196 | # style="Test", 197 | ax=ax, 198 | hue_order=['MSS (oracle)', 'Pooled PC (oracle) [25]'], 199 | palette=[ 200 | sns.color_palette("tab10")[i] 201 | for i in [1, 0] # 3, 4, 5, 202 | ], 203 | style=r'$\bf{DAG\ Model}$', 204 | legend='full', 205 | ) 206 | 207 | ax.set_xticks([ 208 | np.round(plot_df_ax[g_var].min(), 1), 209 | np.round(plot_df_ax[g_var].max(), 1) 210 | ]) 211 | 212 | plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) 
213 | plt.setp(axes[-1].get_legend().get_texts(), fontsize=12) 214 | for ax in axes[:-1]: 215 | ax.get_legend().remove() 216 | 217 | plt.ylim([0.4, 1.03]) 218 | plt.tight_layout() 219 | if SAVE_FIGURES: 220 | plt.savefig(f'./figures/oracle_select_rates_all_models.pdf') 221 | plt.show() 222 | 223 | -------------------------------------------------------------------------------- /experiments/main_simulations/results_paper/DEBUG_results.csv: -------------------------------------------------------------------------------- 1 | params_index, n_variables, n_total_environments, sparsity, intervention_targets, sample_size, dag_density, reps, data_simulator, dag_simulator, Method, Soft, Number of environments, Rep, Number of possible DAGs, MEC size, MEC total edges, MEC unoriented edges, True orientation rate, False orientation rate, Precision, Recall, Average precision 2 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Full PC (oracle), False, 1, 0, 6, 6, 3, 3, 0.0, 0.0, 1, 0.0, 0.0 3 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (oracle), False, 1, 0, 6, 6, 3, 3, 0.0, 0.0, 1, 0.0, 0.0 4 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Full PC (oracle), False, 2, 0, 2, 6, 3, 3, 0.6667, 0.0, 1.0, 0.6667, 0.6667 5 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (oracle), False, 2, 0, 2, 6, 3, 3, 0.6667, 0.0, 1.0, 0.6667, 0.6667 6 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Full PC (oracle), False, 3, 0, 1, 6, 3, 3, 1.0, 0.0, 1.0, 1.0, 1.0 7 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (oracle), False, 3, 0, 1, 6, 3, 3, 1.0, 0.0, 1.0, 1.0, 1.0 8 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (KCI), True, 1, 0, 6, 6, 3, 3, 0.0, 0.0, 1, 0.0, 0.0 9 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (KCI), False, 1, 0, 6, 6, 3, 3, 0.0, 0.0, 1, 0.0, 0.0 10 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (KCI), True, 2, 0, 1, 6, 3, 3, 0.3333, 0.6667, 0.3333, 0.3333, 0.8889 11 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (KCI), False, 2, 0, 2, 6, 3, 3, 0.3333, 0.0, 1.0, 0.3333, 0.8889 12 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (KCI), True, 3, 0, 1, 6, 3, 3, 0.3333, 0.6667, 0.3333, 0.3333, 0.8889 13 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (KCI), False, 3, 0, 3, 6, 3, 3, 0.3333, 0.0, 1.0, 0.3333, 0.8889 14 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (Linear), True, 1, 0, 6, 6, 3, 3, 0.0, 0.0, 1, 0.0, 0.0 15 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (Linear), False, 1, 0, 6, 6, 3, 3, 0.0, 0.0, 1, 0.0, 0.0 16 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (Linear), True, 2, 0, 1, 6, 3, 3, 1.0, 0.0, 1.0, 1.0, 0.7778 17 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (Linear), False, 2, 0, 4, 6, 3, 3, 0.0, 0.0, 1, 0.0, 0.7778 18 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (Linear), True, 3, 0, 1, 6, 3, 3, 0.6667, 0.3333, 0.6667, 0.6667, 0.8889 19 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (Linear), False, 3, 0, 2, 6, 3, 3, 0.6667, 0.0, 1.0, 0.6667, 0.8889 20 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (GAM), True, 1, 0, 6, 6, 3, 3, 0.0, 0.0, 1, 0.0, 0.0 21 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (GAM), False, 1, 0, 6, 6, 3, 3, 0.0, 0.0, 1, 0.0, 0.0 22 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (GAM), True, 2, 0, 1, 6, 3, 3, 1.0, 0.0, 1.0, 1.0, 0.7778 23 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (GAM), False, 2, 0, 4, 6, 3, 3, 0.0, 0.0, 1, 0.0, 0.7778 24 | 0, 3, 3, 1, None, 100, 0.3, 20, 
cdnod, er, Min changes (GAM), True, 3, 0, 1, 6, 3, 3, 1.0, 0.0, 1.0, 1.0, 0.6667 25 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (GAM), False, 3, 0, 2, 6, 3, 3, 0.3333, 0.3333, 0.5, 0.3333, 0.6667 26 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (FisherZ), True, 1, 0, 6, 6, 3, 3, 0.0, 0.0, 1, 0.0, 0.0 27 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (FisherZ), False, 1, 0, 6, 6, 3, 3, 0.0, 0.0, 1, 0.0, 0.0 28 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (FisherZ), True, 2, 0, 1, 6, 3, 3, 0.3333, 0.6667, 0.3333, 0.3333, 0.3889 29 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (FisherZ), False, 2, 0, 6, 6, 3, 3, 0.0, 0.0, 1, 0.0, 0.3889 30 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (FisherZ), True, 3, 0, 1, 6, 3, 3, 0.6667, 0.3333, 0.6667, 0.6667, 0.5 31 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (FisherZ), False, 3, 0, 2, 6, 3, 3, 0.3333, 0.3333, 0.5, 0.3333, 0.5 32 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Full PC (KCI), True, 1, 0, 6, 6, 3, 3, 0.0, 0.0, 1, 0.0, 0.0 33 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Full PC (KCI), False, 1, 0, 6, 6, 3, 3, 0.0, 0.0, 1, 0.0, 0.0 34 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Full PC (KCI), True, 2, 0, 1, 6, 3, 3, 0.3333, 0.6667, 0.3333, 0.3333, 0.8889 35 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Full PC (KCI), False, 2, 0, 2, 6, 3, 3, 0.3333, 0.0, 1.0, 0.3333, 0.8889 36 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Full PC (KCI), True, 3, 0, 1, 6, 3, 3, 1.0, 0.0, 1.0, 1.0, 0.8889 37 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Full PC (KCI), False, 3, 0, 2, 6, 3, 3, 0.0, 0.6667, 0.0, 0.0, 0.8889 38 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, MC, True, 1, 0, 6, 6, 3, 3, 0.0, 0.0, 1, 0.0, 0.0 39 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, MC, False, 1, 0, 6, 6, 3, 3, 0.0, 0.0, 1, 0.0, 0.0 40 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, MC, True, 2, 0, 1, 6, 3, 3, 1.0, 0.0, 1.0, 1.0, 0.8333 41 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, MC, False, 2, 0, 4, 6, 3, 3, 0.0, 0.0, 1, 0.0, 0.8333 42 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, MC, True, 3, 0, 1, 6, 3, 3, 1.0, 0.0, 1.0, 1.0, 0.8889 43 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, MC, False, 3, 0, 1, 6, 3, 3, 0.6667, 0.3333, 0.6667, 0.6667, 0.8889 44 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Full PC (oracle), False, 1, 1, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 0.0 45 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (oracle), False, 1, 1, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 0.0 46 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Full PC (oracle), False, 2, 1, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 0.0 47 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (oracle), False, 2, 1, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 0.0 48 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Full PC (oracle), False, 3, 1, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 0.0 49 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (oracle), False, 3, 1, 1, 2, 1, 1, 1.0, 0.0, 1.0, 1.0, 1.0 50 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (KCI), True, 1, 1, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 0.0 51 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (KCI), False, 1, 1, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 0.0 52 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (KCI), True, 2, 1, 1, 2, 1, 1, 0.0, 1.0, 0.0, 0.0, 1.0 53 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (KCI), False, 2, 1, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 1.0 54 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (KCI), True, 3, 1, 1, 2, 1, 1, 0.0, 1.0, 0.0, 0.0, 1.0 55 | 0, 3, 3, 
1, None, 100, 0.3, 20, cdnod, er, Min changes (KCI), False, 3, 1, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 1.0 56 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (Linear), True, 1, 1, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 0.0 57 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (Linear), False, 1, 1, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 0.0 58 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (Linear), True, 2, 1, 1, 2, 1, 1, 1.0, 0.0, 1.0, 1.0, 1.0 59 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (Linear), False, 2, 1, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 1.0 60 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (Linear), True, 3, 1, 1, 2, 1, 1, 1.0, 0.0, 1.0, 1.0, 1.0 61 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (Linear), False, 3, 1, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 1.0 62 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (GAM), True, 1, 1, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 0.0 63 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (GAM), False, 1, 1, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 0.0 64 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (GAM), True, 2, 1, 1, 2, 1, 1, 0.0, 1.0, 0.0, 0.0, 1.0 65 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (GAM), False, 2, 1, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 1.0 66 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (GAM), True, 3, 1, 1, 2, 1, 1, 1.0, 0.0, 1.0, 1.0, 1.0 67 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (GAM), False, 3, 1, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 1.0 68 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (FisherZ), True, 1, 1, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 0.0 69 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (FisherZ), False, 1, 1, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 0.0 70 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (FisherZ), True, 2, 1, 1, 2, 1, 1, 0.0, 1.0, 0.0, 0.0, 0.0 71 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (FisherZ), False, 2, 1, 1, 2, 1, 1, 0.0, 1.0, 0.0, 0.0, 0.0 72 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (FisherZ), True, 3, 1, 1, 2, 1, 1, 0.0, 1.0, 0.0, 0.0, 1.0 73 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (FisherZ), False, 3, 1, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 1.0 74 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Full PC (KCI), True, 1, 1, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 0.0 75 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Full PC (KCI), False, 1, 1, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 0.0 76 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Full PC (KCI), True, 2, 1, 1, 2, 1, 1, 0.0, 1.0, 0.0, 0.0, 0.0 77 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Full PC (KCI), False, 2, 1, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 0.0 78 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Full PC (KCI), True, 3, 1, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 0.0 79 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Full PC (KCI), False, 3, 1, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 0.0 80 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, MC, True, 1, 1, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 0.0 81 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, MC, False, 1, 1, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 0.0 82 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, MC, True, 2, 1, 1, 2, 1, 1, 0.0, 1.0, 0.0, 0.0, 0.0 83 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, MC, False, 2, 1, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 0.0 84 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, MC, True, 3, 1, 1, 2, 1, 1, 0.0, 1.0, 0.0, 0.0, 1.0 85 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, MC, False, 3, 1, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 1.0 86 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Full PC (oracle), False, 1, 2, 2, 2, 1, 1, 0.0, 0.0, 
1, 0.0, 0.0 87 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (oracle), False, 1, 2, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 0.0 88 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Full PC (oracle), False, 2, 2, 1, 2, 1, 1, 1.0, 0.0, 1.0, 1.0, 1.0 89 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (oracle), False, 2, 2, 1, 2, 1, 1, 1.0, 0.0, 1.0, 1.0, 1.0 90 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Full PC (oracle), False, 3, 2, 1, 2, 1, 1, 1.0, 0.0, 1.0, 1.0, 1.0 91 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (oracle), False, 3, 2, 1, 2, 1, 1, 1.0, 0.0, 1.0, 1.0, 1.0 92 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (KCI), True, 1, 2, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 0.0 93 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (KCI), False, 1, 2, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 0.0 94 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (KCI), True, 2, 2, 1, 2, 1, 1, 0.0, 1.0, 0.0, 0.0, 1.0 95 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (KCI), False, 2, 2, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 1.0 96 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (KCI), True, 3, 2, 1, 2, 1, 1, 1.0, 0.0, 1.0, 1.0, 1.0 97 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (KCI), False, 3, 2, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 1.0 98 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (Linear), True, 1, 2, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 0.0 99 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (Linear), False, 1, 2, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 0.0 100 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (Linear), True, 2, 2, 1, 2, 1, 1, 0.0, 1.0, 0.0, 0.0, 0.0 101 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (Linear), False, 2, 2, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 0.0 102 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (Linear), True, 3, 2, 1, 2, 1, 1, 1.0, 0.0, 1.0, 1.0, 1.0 103 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (Linear), False, 3, 2, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 1.0 104 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (GAM), True, 1, 2, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 0.0 105 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (GAM), False, 1, 2, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 0.0 106 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (GAM), True, 2, 2, 1, 2, 1, 1, 0.0, 1.0, 0.0, 0.0, 0.0 107 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (GAM), False, 2, 2, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 0.0 108 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (GAM), True, 3, 2, 1, 2, 1, 1, 1.0, 0.0, 1.0, 1.0, 1.0 109 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (GAM), False, 3, 2, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 1.0 110 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (FisherZ), True, 1, 2, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 0.0 111 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (FisherZ), False, 1, 2, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 0.0 112 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (FisherZ), True, 2, 2, 1, 2, 1, 1, 0.0, 1.0, 0.0, 0.0, 0.0 113 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (FisherZ), False, 2, 2, 1, 2, 1, 1, 0.0, 1.0, 0.0, 0.0, 0.0 114 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (FisherZ), True, 3, 2, 1, 2, 1, 1, 0.0, 1.0, 0.0, 0.0, 0.0 115 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (FisherZ), False, 3, 2, 1, 2, 1, 1, 0.0, 1.0, 0.0, 0.0, 0.0 116 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Full PC (KCI), True, 1, 2, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 0.0 117 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, 
er, Full PC (KCI), False, 1, 2, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 0.0 118 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Full PC (KCI), True, 2, 2, 1, 2, 1, 1, 1.0, 0.0, 1.0, 1.0, 1.0 119 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Full PC (KCI), False, 2, 2, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 1.0 120 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Full PC (KCI), True, 3, 2, 1, 2, 1, 1, 1.0, 0.0, 1.0, 1.0, 1.0 121 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Full PC (KCI), False, 3, 2, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 1.0 122 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, MC, True, 1, 2, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 0.0 123 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, MC, False, 1, 2, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 0.0 124 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, MC, True, 2, 2, 1, 2, 1, 1, 0.0, 1.0, 0.0, 0.0, 0.0 125 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, MC, False, 2, 2, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 0.0 126 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, MC, True, 3, 2, 1, 2, 1, 1, 1.0, 0.0, 1.0, 1.0, 1.0 127 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, MC, False, 3, 2, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 1.0 128 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Full PC (oracle), False, 1, 3, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 0.0 129 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (oracle), False, 1, 3, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 0.0 130 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Full PC (oracle), False, 2, 3, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 0.0 131 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (oracle), False, 2, 3, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 0.0 132 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Full PC (oracle), False, 3, 3, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 0.0 133 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (oracle), False, 3, 3, 1, 2, 1, 1, 1.0, 0.0, 1.0, 1.0, 1.0 134 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (KCI), True, 1, 3, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 0.0 135 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (KCI), False, 1, 3, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 0.0 136 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (KCI), True, 2, 3, 1, 2, 1, 1, 1.0, 0.0, 1.0, 1.0, 1.0 137 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (KCI), False, 2, 3, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 1.0 138 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (KCI), True, 3, 3, 1, 2, 1, 1, 0.0, 1.0, 0.0, 0.0, 1.0 139 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (KCI), False, 3, 3, 1, 2, 1, 1, 0.0, 1.0, 0.0, 0.0, 1.0 140 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (Linear), True, 1, 3, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 0.0 141 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (Linear), False, 1, 3, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 0.0 142 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (Linear), True, 2, 3, 1, 2, 1, 1, 0.0, 1.0, 0.0, 0.0, 0.0 143 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (Linear), False, 2, 3, 1, 2, 1, 1, 0.0, 1.0, 0.0, 0.0, 0.0 144 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (Linear), True, 3, 3, 1, 2, 1, 1, 0.0, 1.0, 0.0, 0.0, 1.0 145 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (Linear), False, 3, 3, 1, 2, 1, 1, 0.0, 1.0, 0.0, 0.0, 1.0 146 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (GAM), True, 1, 3, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 0.0 147 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (GAM), False, 1, 3, 2, 2, 1, 1, 0.0, 0.0, 1, 0.0, 0.0 148 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (GAM), True, 2, 3, 1, 2, 1, 1, 0.0, 1.0, 0.0, 
0.0, 0.0
149 | 0, 3, 3, 1, None, 100, 0.3, 20, cdnod, er, Min changes (GAM), False, 2, 3, 1, 2, 1, 1, 0.0, 1.0, 0.0, 0.0, 0.0
150 |
--------------------------------------------------------------------------------
/experiments/main_simulations/run_experiment.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 | import logging
4 | import pickle
5 | import itertools
6 | 
7 | import numpy as np
8 | import pandas as pd
9 | from tqdm import tqdm
10 | from joblib import Parallel, delayed
11 | 
12 | from sparse_shift.datasets import (
13 |     sample_cdnod_sim,
14 |     sample_topological,
15 |     erdos_renyi_dag,
16 |     connected_erdos_renyi_dag,
17 |     barabasi_albert_dag,
18 |     complete_dag,
19 | )
20 | from sparse_shift.plotting import plot_dag
21 | from sparse_shift.testing import test_mechanism_shifts, test_mechanism
22 | from sparse_shift.methods import FullPC, PairwisePC, MinChangeOracle, MinChange, FullMinChanges
23 | from sparse_shift.metrics import dag_true_orientations, dag_false_orientations, \
24 |     dag_precision, dag_recall, average_precision_score
25 | from sparse_shift.utils import dag2cpdag, cpdag2dags
26 | 
27 | import os
28 | import warnings
29 | 
30 | warnings.simplefilter("ignore")
31 | os.environ["PYTHONWARNINGS"] = "ignore"  # Also affects subprocesses
32 | 
33 | 
34 | def _sample_dag(dag_simulator, n_variables, dag_density, seed=None):
35 |     """
36 |     Samples a DAG from a specified distribution, redrawing until its
37 |     Markov equivalence class contains more than one DAG.
38 |     """
39 |     for _ in range(100):
40 |         if dag_simulator == "er":
41 |             dag = erdos_renyi_dag(n_variables, dag_density, seed=seed)
42 |         elif dag_simulator == "ba":
43 |             dag = barabasi_albert_dag(n_variables, dag_density, seed=seed)
44 |         elif dag_simulator == 'complete':
45 |             dag = complete_dag(n_variables)
46 |         else:
47 |             raise ValueError(f"DAG simulator {dag_simulator} not a valid option")
48 |         if len(cpdag2dags(dag2cpdag(dag))) > 1:
49 |             return dag
50 |         # Don't keep already solved MECs: redraw with a freshly derived seed
51 |         np.random.seed(seed)
52 |         seed = int(1000 * np.random.uniform())
53 | 
54 |     raise ValueError(
55 |         f"Cannot sample a DAG in these settings with nontrivial MEC "
56 |         f"({[dag_simulator, n_variables, dag_density]})"
57 |     )
58 | 
59 | 
60 | def _sample_interventions(n_variables, n_total_environments, sparsity, seed=None):
61 |     np.random.seed(seed)
62 |     if isinstance(sparsity, float):
63 |         sparsity = np.round(n_variables * sparsity).astype(int)
64 |     sampled_targets = [
65 |         np.random.choice(n_variables, sparsity, replace=False)
66 |         for _ in range(n_total_environments)
67 |     ]
68 |     return sampled_targets
69 | 
70 | 
71 | def _sample_datasets(data_simulator, sample_size, dag, intervention_targets, seed=None):
72 |     """
73 |     Samples multi-environment data from a specified distribution
74 |     """
75 |     if data_simulator == "cdnod":
76 |         np.random.seed(seed)
77 |         domain_seed = int(1000 * np.random.uniform())
78 |         Xs = [
79 |             sample_cdnod_sim(
80 |                 dag,
81 |                 sample_size,
82 |                 intervention_targets=targets,
83 |                 base_random_state=seed,
84 |                 domain_random_state=domain_seed + i,
85 |             )
86 |             for i, targets in enumerate(intervention_targets)
87 |         ]
88 |     else:
89 |         raise ValueError(f"Data simulator {data_simulator} not a valid option")
90 | 
91 |     return Xs
92 | 
93 | 
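# A hypothetical end-to-end sketch (not executed anywhere) of how the three
# samplers above compose; shapes assume the "er" and "cdnod" simulators:
#
#     dag = _sample_dag("er", n_variables=5, dag_density=0.3, seed=0)
#     targets = _sample_interventions(5, n_total_environments=3, sparsity=2, seed=0)
#     Xs = _sample_datasets("cdnod", 100, dag, targets, seed=0)
#
# This yields one (100, 5) dataset per environment, all drawn from the same
# DAG but each with its own pair of intervened mechanisms.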
94 | def main(args):
95 |     # Determine experimental settings
96 |     if args.quick:
97 |         from exp_quick_settings import get_experiment_params, get_param_keys
98 |     else:
99 |         from exp_settings import get_experiment_params, get_param_keys
100 | 
101 |     # Initialize logging details
102 |     logging.basicConfig(
103 |         filename="./logging.log",
104 |         format="%(asctime)s:%(levelname)s:%(message)s",
105 |         level=logging.INFO,
106 |     )
107 |     logging.info("NEW RUN:")
108 |     logging.info(f"Args: {args}")
109 |     logging.info("Experimental settings:")
110 |     logging.info(get_experiment_params(args.experiment))
111 | 
112 |     # Create results csv header
113 |     header = np.hstack(
114 |         [
115 |             ["params_index"],
116 |             get_param_keys(args.experiment),
117 |             ["Method", "Soft", "Number of environments", "Rep"],
118 |             ["Number of possible DAGs", "MEC size", "MEC total edges", "MEC unoriented edges"],
119 |             ["True orientation rate", "False orientation rate", "Precision", "Recall", "Average precision"],
120 |         ]
121 |     )
122 |     if not os.path.exists('./results/'):
123 |         os.makedirs('./results/')
124 |     write_file = open(f"./results/{args.experiment}_results.csv", "w+")
125 |     write_file.write(", ".join(header) + "\n")
126 |     write_file.flush()
127 | 
128 |     # Construct parameter grids
129 |     param_dicts = get_experiment_params(args.experiment)
130 |     prior_indices = 0
131 |     logging.info(f'{len(param_dicts)} total parameter dictionaries')
132 |     for params_dict in param_dicts:
133 |         param_keys, param_values = zip(*params_dict.items())
134 |         params_grid = [dict(zip(param_keys, v)) for v in itertools.product(*param_values)]
135 | 
136 |         # Iterate over all combinations in the grid
137 |         logging.info(f'{len(params_grid)} total parameter combinations')
138 | 
139 |         for i, params in enumerate(params_grid):
140 |             logging.info(f"Params {i} / {len(params_grid)}")
141 |             run_experimental_setting(
142 |                 args=args,
143 |                 params_index=i + prior_indices,
144 |                 write_file=write_file,
145 |                 **params,
146 |             )
147 | 
148 |         prior_indices += len(params_grid)
149 |     logging.info('Complete')
150 | 
151 | 
152 | def run_experimental_setting(
153 |     args,
154 |     params_index,
155 |     write_file,
156 |     n_variables,
157 |     n_total_environments,
158 |     sparsity,
159 |     intervention_targets,
160 |     sample_size,
161 |     dag_density,
162 |     reps,
163 |     data_simulator,
164 |     dag_simulator,
165 | ):
166 | 
167 |     # Determine experimental settings
168 |     if args.quick:
169 |         from exp_quick_settings import get_experiment_methods
170 |     else:
171 |         from exp_settings import get_experiment_methods
172 | 
173 |     name = args.experiment
174 | 
175 | 
176 |     if sparsity is not None and sparsity > n_variables:
177 |         logging.info(f"Skipping: sparsity {sparsity} greater than n_variables {n_variables}")
178 |         return
179 | 
180 |     experimental_params = [
181 |         params_index,
182 |         n_variables,
183 |         n_total_environments,
184 |         sparsity,
185 |         intervention_targets,
186 |         sample_size,
187 |         dag_density,
188 |         reps,
189 |         data_simulator,
190 |         dag_simulator,
191 |     ]
192 |     experimental_params = [str(val).replace(", ", ";") for val in experimental_params]
193 | 
194 |     def _run_rep(rep, write):
195 |         results = []
196 |         # Get DAG
197 |         true_dag = _sample_dag(dag_simulator, n_variables, dag_density, seed=rep)
198 |         true_cpdag = dag2cpdag(true_dag)
199 |         mec_size = len(cpdag2dags(true_cpdag))
200 |         total_edges = np.sum(true_dag)
201 |         unoriented_edges = np.sum((true_cpdag + true_cpdag.T) == 2) // 2
202 | 
203 |         # Get interventions
204 |         if intervention_targets is None:
205 |             sampled_targets = _sample_interventions(
206 |                 n_variables, n_total_environments, sparsity, seed=rep
207 |             )
208 |         else:
209 |             sampled_targets = intervention_targets
210 | 
211 |         # Compute oracle results
212 |         fpc_oracle = FullPC(true_dag)
213 |         mch_oracle = MinChangeOracle(true_dag)
214 | 
215 |         for
n_env, intv_targets in enumerate(sampled_targets): 216 | n_env += 1 217 | fpc_oracle.add_environment(intv_targets) 218 | mch_oracle.add_environment(intv_targets) 219 | 220 | cpdag = fpc_oracle.get_mec_cpdag() 221 | 222 | true_orients = np.round(dag_true_orientations(true_dag, cpdag), 4) 223 | false_orients = np.round(dag_false_orientations(true_dag, cpdag), 4) 224 | precision = np.round(dag_precision(true_dag, cpdag), 4) 225 | recall = np.round(dag_recall(true_dag, cpdag), 4) 226 | ap = recall 227 | 228 | result = ", ".join( 229 | map( 230 | str, 231 | experimental_params + [ 232 | "Full PC (oracle)", 233 | False, 234 | n_env, 235 | rep, 236 | len(fpc_oracle.get_mec_dags()), 237 | mec_size, 238 | total_edges, 239 | unoriented_edges, 240 | true_orients, 241 | false_orients, 242 | precision, 243 | recall, 244 | ap, 245 | ], 246 | ) 247 | ) + "\n" 248 | if write: 249 | write_file.write(result) 250 | write_file.flush() 251 | else: 252 | results.append(result) 253 | 254 | cpdag = mch_oracle.get_min_cpdag() 255 | 256 | true_orients = np.round(dag_true_orientations(true_dag, cpdag), 4) 257 | false_orients = np.round(dag_false_orientations(true_dag, cpdag), 4) 258 | precision = np.round(dag_precision(true_dag, cpdag), 4) 259 | recall = np.round(dag_recall(true_dag, cpdag), 4) 260 | ap = recall 261 | 262 | result = ", ".join( 263 | map( 264 | str, 265 | experimental_params + [ 266 | "Min changes (oracle)", 267 | False, 268 | n_env, 269 | rep, 270 | len(mch_oracle.get_min_dags()), 271 | mec_size, 272 | total_edges, 273 | unoriented_edges, 274 | true_orients, 275 | false_orients, 276 | precision, 277 | recall, 278 | ap, 279 | ], 280 | ) 281 | ) + "\n" 282 | if write: 283 | write_file.write(result) 284 | write_file.flush() 285 | else: 286 | results.append(result) 287 | 288 | del fpc_oracle, mch_oracle 289 | 290 | # Sample dataset 291 | if data_simulator is None: 292 | return results 293 | 294 | Xs = _sample_datasets( 295 | data_simulator, sample_size, true_dag, sampled_targets, seed=rep 296 | ) 297 | 298 | # Compute empirical results 299 | for save_name, method_name, mch, hyperparams in get_experiment_methods( 300 | args.experiment 301 | ): 302 | mch = mch(cpdag=true_cpdag, **hyperparams) 303 | 304 | for n_env, X in enumerate(Xs): 305 | n_env += 1 306 | mch.add_environment(X) 307 | 308 | for soft in [True, False]: 309 | min_cpdag = mch.get_min_cpdag(soft) 310 | 311 | true_orients = np.round(dag_true_orientations(true_dag, min_cpdag), 4) 312 | false_orients = np.round(dag_false_orientations(true_dag, min_cpdag), 4) 313 | precision = np.round(dag_precision(true_dag, min_cpdag), 4) 314 | recall = np.round(dag_recall(true_dag, min_cpdag), 4) 315 | 316 | if hasattr(mch, 'pvalues_'): 317 | ap = np.round(average_precision_score(true_dag, mch.pvalues_), 4) 318 | else: 319 | ap = None 320 | 321 | result = ", ".join( 322 | map( 323 | str, 324 | experimental_params + [ 325 | method_name, 326 | soft, 327 | n_env, 328 | rep, 329 | len(mch.get_min_dags(soft)), 330 | mec_size, 331 | total_edges, 332 | unoriented_edges, 333 | true_orients, 334 | false_orients, 335 | precision, 336 | recall, 337 | ap, 338 | ], 339 | ) 340 | ) + "\n" 341 | if write: 342 | write_file.write(result) 343 | write_file.flush() 344 | else: 345 | results.append(result) 346 | 347 | # Save pvalues 348 | if not os.path.exists(f'./results/pvalue_mats/{name}/'): 349 | os.makedirs(f'./results/pvalue_mats/{name}/') 350 | if hasattr(mch, 'pvalues_'): 351 | np.save( 352 | 
f"./results/pvalue_mats/{name}/{name}_{save_name}_pvalues_params={params_index}_rep={rep}.npy", 353 | mch.pvalues_, 354 | ) 355 | 356 | return results 357 | 358 | rep_shift = 0 359 | if args.jobs is not None: 360 | results = Parallel( 361 | n_jobs=args.jobs, 362 | )( 363 | delayed(_run_rep)(rep + rep_shift, False) for rep in range(reps) 364 | ) 365 | for result in np.concatenate(results): 366 | write_file.write(result) 367 | write_file.flush() 368 | else: 369 | for rep in tqdm(range(reps)): 370 | _run_rep(rep + rep_shift, write=True) 371 | 372 | 373 | if __name__ == "__main__": 374 | parser = argparse.ArgumentParser() 375 | parser.add_argument( 376 | "--experiment", 377 | help="experiment parameters to run", 378 | ) 379 | parser.add_argument( 380 | "--jobs", 381 | help="Number of jobs to run in parallel", 382 | default=None, 383 | type=int, 384 | ) 385 | parser.add_argument( 386 | "--quick", 387 | help="Enable to run a smaller, test version", 388 | default=False, 389 | action='store_true' 390 | ) 391 | args = parser.parse_args() 392 | 393 | main(args) 394 | -------------------------------------------------------------------------------- /experiments/requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | matplotlib 3 | seaborn 4 | -------------------------------------------------------------------------------- /experiments/teaser_sparse_oracle_pc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # In[1]: 5 | 6 | 7 | import numpy as np 8 | import pandas as pd 9 | import seaborn as sns 10 | import matplotlib.pyplot as plt 11 | 12 | 13 | # # Experiment 14 | # 15 | # We asume the triangle DAG with edges 16 | # 1. X1 -> X2 17 | # 2. X2 -> X3 18 | # 3. X1 -> X3 19 | # 20 | # which we know up to its Markov equivalence class but obtain natural interventions on variables with some degree of sparsity. 21 | # 22 | # We begin with a domain of the original graph, then sample additional domans with some number of interventions. We either then pool all together, or consider pairwise pooling taking the union across pools. 
23 | 
24 | # In[2]:
25 | 
26 | 
27 | # Meek's rules giving the orientations are commented above each entry
28 | intervention_dict = {
29 |     # 2 v-structures
30 |     (1,): [1, 3],
31 |     # 2 v-structures, acyclic
32 |     (2,): [1, 2, 3],
33 |     # 2 v-structures
34 |     (3,): [2, 3],
35 |     # 2 v-structures
36 |     (1, 2): [2, 3],
37 |     # 2 v-structures
38 |     (2, 3): [1, 3],
39 |     # 2 v-structures, acyclic
40 |     (1, 3): [1, 2, 3],
41 |     # Nothing
42 |     (1, 2, 3): [],
43 | }
44 | 
45 | 
46 | # In[3]:
47 | 
48 | 
49 | class FullPC:
50 |     def __init__(self):
51 |         self.domains_ = []
52 |         self.interv_edges_ = set()
53 | 
54 |     def add_domain(self, interventions):
55 |         self.interv_edges_.update(interventions)
56 |         self.domains_.append(interventions)
57 | 
58 |     def get_learned_edges(self):
59 |         if len(self.domains_) == 1:
60 |             return []
61 |         else:
62 |             return intervention_dict[tuple(sorted(self.interv_edges_))]
63 | 
64 | class PairwisePC:
65 |     def __init__(self):
66 |         self.domains_ = []
67 |         self.learned_edges_ = []
68 | 
69 |     def add_domain(self, interventions):
70 |         for prior_domain in self.domains_:
71 |             self.learned_edges_.append(intervention_dict[tuple(
72 |                 sorted(np.unique(np.hstack((prior_domain, interventions))))
73 |             )])
74 | 
75 |         self.domains_.append(interventions)
76 | 
77 |     def get_learned_edges(self):
78 |         if len(self.domains_) == 1:
79 |             return []
80 |         else:
81 |             return np.unique(np.hstack(self.learned_edges_)).astype(int)
82 | 
83 | 
84 | # ## Experiments
85 | 
86 | # In[5]:
87 | 
88 | 
89 | results_mat = []
90 | n_environments = 15
91 | n_reps = 200
92 | 
93 | for rep in range(n_reps):
94 |     for sparsity in [1, 2, 3]:
95 |         fpc = FullPC()
96 |         ppc = PairwisePC()
97 |         for n_env in range(1, n_environments+1):
98 |             interventions = tuple(np.random.choice([1, 2, 3], sparsity, replace=False))
99 |             fpc.add_domain(interventions)
100 |             ppc.add_domain(interventions)
101 |             results_mat.append([
102 |                 rep, 'Full PC', sparsity, n_env, len(fpc.get_learned_edges())
103 |             ])
104 |             results_mat.append([
105 |                 rep, 'Pairwise PC', sparsity, n_env, len(ppc.get_learned_edges())
106 |             ])
107 | 
108 | 
109 | # In[6]:
110 | 
111 | 
112 | df = pd.DataFrame(
113 |     results_mat,
114 |     columns=['Rep', 'Method', 'Sparsity', 'Number of environments', 'Learned edges'])
115 | 
116 | df['Fraction edges learned'] = df['Learned edges'] / 3
117 | 
118 | 
119 | # In[7]:
120 | 
121 | 
122 | plt.figure(figsize=(4, 3))
123 | sns.lineplot(
124 |     data=df,
125 |     x='Number of environments',
126 |     y='Fraction edges learned',
127 |     hue='Method',
128 |     style='Sparsity',
129 |     ci=None,
130 |     # palette='gist_heat'
131 | )
132 | plt.title('Sparse shifts can provide identifiability')
133 | plt.xticks([1, 2, 5, 10, 15])
134 | plt.tight_layout()
135 | # plt.savefig('./figures/pairwise_oracle_pc_simulation.pdf')
136 | plt.show()
137 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | scikit-learn
3 | matplotlib
4 | networkx
5 | joblib
6 | causal-learn
7 | causaldag
8 | hyppo
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages, setup
2 | 
3 | 
4 | VERSION = 1.0
5 | PACKAGE_NAME = "sparse_shift"
6 | DESCRIPTION = "Conditional independence tools for causal learning under the sparse mechanism shift hypothesis."
7 | with open("README.md", "r") as f: 8 | LONG_DESCRIPTION = f.read() 9 | AUTHOR = ("Ronan Perry",) 10 | AUTHOR_EMAIL = "rflperry@gmail.com" 11 | with open('requirements.txt') as f: 12 | REQUIRED_PACKAGES = f.read().splitlines() 13 | 14 | setup( 15 | name=PACKAGE_NAME, 16 | version=VERSION, 17 | description=DESCRIPTION, 18 | long_description=LONG_DESCRIPTION, 19 | author=AUTHOR, 20 | author_email=AUTHOR_EMAIL, 21 | install_requires=REQUIRED_PACKAGES, 22 | license="MIT", 23 | packages=find_packages(), 24 | ) -------------------------------------------------------------------------------- /sparse_shift/__init__.py: -------------------------------------------------------------------------------- 1 | from .kcd import KCD, KCDCV 2 | from .utils import dags2mechanisms, dag2cpdag, cpdag2dags 3 | from .metrics import dag_true_orientations, dag_false_orientations, dag_precision, dag_recall 4 | from .methods import * 5 | -------------------------------------------------------------------------------- /sparse_shift/causal_learn/GraphClass.py: -------------------------------------------------------------------------------- 1 | import io 2 | import warnings 3 | from itertools import permutations 4 | from typing import List, Tuple 5 | 6 | import matplotlib.image as mpimg 7 | import matplotlib.pyplot as plt 8 | import networkx as nx 9 | import numpy as np 10 | import pandas as pd 11 | 12 | from causallearn.graph.Edge import Edge 13 | from causallearn.graph.Endpoint import Endpoint 14 | from causallearn.graph.GeneralGraph import GeneralGraph 15 | from causallearn.graph.GraphNode import GraphNode 16 | from causallearn.graph.Node import Node 17 | from causallearn.utils.GraphUtils import GraphUtils 18 | from causallearn.utils.PCUtils.Helper import list_union, powerset 19 | 20 | 21 | class CausalGraph: 22 | def __init__(self, no_of_var: int=None, G: GeneralGraph=None): 23 | if G is None: 24 | node_names: List[str] = [("X%d" % (i + 1)) for i in range(no_of_var)] 25 | nodes: List[Node] = [] 26 | for name in node_names: 27 | node = GraphNode(name) 28 | nodes.append(node) 29 | self.G: GeneralGraph = GeneralGraph(nodes) 30 | 31 | for i in range(no_of_var): 32 | for j in range(i + 1, no_of_var): 33 | self.G.add_edge(Edge(nodes[i], nodes[j], Endpoint.TAIL, Endpoint.TAIL)) 34 | else: 35 | self.G = G 36 | nodes = G.get_nodes() 37 | no_of_var = len(nodes) 38 | 39 | self.data = None # store the data 40 | self.test = None # store the name of the conditional independence test 41 | self.corr_mat = None # store the correlation matrix of the data 42 | self.sepset = np.empty((no_of_var, no_of_var), object) # store the collection of sepsets 43 | self.definite_UC = [] # store the list of definite unshielded colliders 44 | self.definite_non_UC = [] # store the list of definite unshielded non-colliders 45 | self.PC_elapsed = -1 # store the elapsed time of running PC 46 | self.redundant_nodes = [] # store the list of redundant nodes (for subgraphs) 47 | self.nx_graph = nx.DiGraph() # store the directed graph 48 | self.nx_skel = nx.Graph() # store the undirected graph 49 | self.labels = {} 50 | self.prt_m = {} # store the parents of missingness indicators 51 | self.mvpc = False 52 | self.cardinalities = None # only works when self.data is discrete, i.e. 
self.test is chisq or gsq 53 | self.is_discrete = False 54 | self.citest_cache = dict() 55 | self.data_hash_key = None 56 | self.ci_test_hash_key = None 57 | 58 | def set_ind_test(self, indep_test, mvpc=False): 59 | """Set the conditional independence test that will be used""" 60 | # assert name_of_test in ["Fisher_Z", "Chi_sq", "G_sq"] 61 | self.mvpc = mvpc 62 | self.test = indep_test 63 | self.ci_test_hash_key = hash(indep_test) 64 | 65 | def ci_test(self, i: int, j: int, S) -> float: 66 | """Define the conditional independence test""" 67 | # assert i != j and not i in S and not j in S 68 | if self.mvpc: 69 | return self.test(self.data, self.nx_skel, self.prt_m, i, j, S, self.data.shape[0]) 70 | 71 | i, j = (i, j) if (i < j) else (j, i) 72 | ijS_key = (i, j, frozenset(S), self.data_hash_key, self.ci_test_hash_key) 73 | if ijS_key in self.citest_cache: 74 | return self.citest_cache[ijS_key] 75 | # if discrete, assert self.test is chisq or gsq, pass into cardinalities 76 | pValue = self.test(self.data, i, j, S, self.cardinalities) if self.is_discrete \ 77 | else self.test(self.data, i, j, S) 78 | self.citest_cache[ijS_key] = pValue 79 | return pValue 80 | 81 | def neighbors(self, i: int): 82 | """Find the neighbors of node i in adjmat""" 83 | return np.where(self.G.graph[i, :] != 0)[0] 84 | 85 | def max_degree(self) -> int: 86 | """Return the maximum number of edges connected to a node in adjmat""" 87 | return max(np.sum(self.G.graph != 0, axis=1)) 88 | 89 | def find_arrow_heads(self) -> List[Tuple[int, int]]: 90 | """Return the list of i o-> j in adjmat as (i, j)""" 91 | L = np.where(self.G.graph == 1) 92 | return list(zip(L[1], L[0])) 93 | 94 | def find_tails(self) -> List[Tuple[int, int]]: 95 | """Return the list of i --o j in adjmat as (j, i)""" 96 | L = np.where(self.G.graph == -1) 97 | return list(zip(L[1], L[0])) 98 | 99 | def find_undirected(self) -> List[Tuple[int, int]]: 100 | """Return the list of undirected edge i --- j in adjmat as (i, j) [with symmetry]""" 101 | return [(edge[0], edge[1]) for edge in self.find_tails() if self.G.graph[edge[0], edge[1]] == -1] 102 | 103 | def find_fully_directed(self) -> List[Tuple[int, int]]: 104 | """Return the list of directed edges i --> j in adjmat as (i, j)""" 105 | return [(edge[0], edge[1]) for edge in self.find_arrow_heads() if self.G.graph[edge[0], edge[1]] == -1] 106 | 107 | def find_bi_directed(self) -> List[Tuple[int, int]]: 108 | """Return the list of bidirected edges i <-> j in adjmat as (i, j) [with symmetry]""" 109 | return [(edge[1], edge[0]) for edge in self.find_arrow_heads() if ( 110 | self.G.graph[edge[1], edge[0]] == Endpoint.ARROW.value and self.G.graph[ 111 | edge[0], edge[1]] == Endpoint.ARROW.value)] 112 | 113 | def find_adj(self): 114 | """Return the list of adjacencies i --- j in adjmat as (i, j) [with symmetry]""" 115 | return list(self.find_tails() + self.find_arrow_heads()) 116 | 117 | def is_undirected(self, i, j) -> bool: 118 | """Return True if i --- j holds in adjmat and False otherwise""" 119 | return self.G.graph[i, j] == -1 and self.G.graph[j, i] == -1 120 | 121 | def is_fully_directed(self, i, j) -> bool: 122 | """Return True if i --> j holds in adjmat and False otherwise""" 123 | return self.G.graph[i, j] == -1 and self.G.graph[j, i] == 1 124 | 125 | def find_unshielded_triples(self) -> List[Tuple[int, int, int]]: 126 | """Return the list of unshielded triples i o-o j o-o k in adjmat as (i, j, k)""" 127 | return [(pair[0][0], pair[0][1], pair[1][1]) for pair in permutations(self.find_adj(), 2) 128 | if 
pair[0][1] == pair[1][0] and pair[0][0] != pair[1][1] and self.G.graph[pair[0][0], pair[1][1]] == 0] 129 | 130 | def find_triangles(self) -> List[Tuple[int, int, int]]: 131 | """Return the list of triangles i o-o j o-o k o-o i in adjmat as (i, j, k) [with symmetry]""" 132 | Adj = self.find_adj() 133 | return [(pair[0][0], pair[0][1], pair[1][1]) for pair in permutations(Adj, 2) 134 | if pair[0][1] == pair[1][0] and pair[0][0] != pair[1][1] and (pair[0][0], pair[1][1]) in Adj] 135 | 136 | def find_kites(self) -> List[Tuple[int, int, int, int]]: 137 | """Return the list of non-ambiguous kites i o-o j o-o l o-o k o-o i o-o l in adjmat \ 138 | (where j and k are non-adjacent) as (i, j, k, l) [with asymmetry j < k]""" 139 | return [(pair[0][0], pair[0][1], pair[1][1], pair[0][2]) for pair in permutations(self.find_triangles(), 2) 140 | if pair[0][0] == pair[1][0] and pair[0][2] == pair[1][2] 141 | and pair[0][1] < pair[1][1] and self.G.graph[pair[0][1], pair[1][1]] == 0] 142 | 143 | def find_cond_sets(self, i: int, j: int) -> List[Tuple[int]]: 144 | """return the list of conditioning sets of the neighbors of i or j in adjmat""" 145 | neigh_x = self.neighbors(i) 146 | neigh_y = self.neighbors(j) 147 | pow_neigh_x = powerset(neigh_x) 148 | pow_neigh_y = powerset(neigh_y) 149 | return list_union(pow_neigh_x, pow_neigh_y) 150 | 151 | def find_cond_sets_with_mid(self, i: int, j: int, k: int) -> List[Tuple[int]]: 152 | """return the list of conditioning sets of the neighbors of i or j in adjmat which contains k""" 153 | return [S for S in self.find_cond_sets(i, j) if k in S] 154 | 155 | def find_cond_sets_without_mid(self, i: int, j: int, k: int) -> List[Tuple[int]]: 156 | """return the list of conditioning sets of the neighbors of i or j which in adjmat does not contain k""" 157 | return [S for S in self.find_cond_sets(i, j) if k not in S] 158 | 159 | def rearrange(self, PATH): 160 | """Rearrange adjmat according to the data imported at PATH""" 161 | raw_col_names = list(pd.read_csv(PATH, sep='\t').columns) 162 | var_indices = [] 163 | for name in raw_col_names: 164 | var_indices.append(int(name.split('X')[1]) - 1) 165 | new_indices = np.zeros_like(var_indices) 166 | for i in range(1, len(new_indices)): 167 | new_indices[var_indices[i]] = range(len(new_indices))[i] 168 | output = self.adjmat[:, new_indices] 169 | output = output[new_indices, :] 170 | self.adjmat = output 171 | 172 | def to_nx_graph(self): 173 | """Convert adjmat into a networkx.Digraph object named nx_graph""" 174 | nodes = range(len(self.G.graph)) 175 | self.labels = {i: self.G.nodes[i].get_name() for i in nodes} 176 | self.nx_graph.add_nodes_from(nodes) 177 | undirected = self.find_undirected() 178 | directed = self.find_fully_directed() 179 | bidirected = self.find_bi_directed() 180 | for (i, j) in undirected: 181 | self.nx_graph.add_edge(i, j, color='g') # Green edge: undirected edge 182 | for (i, j) in directed: 183 | self.nx_graph.add_edge(i, j, color='b') # Blue edge: directed edge 184 | for (i, j) in bidirected: 185 | self.nx_graph.add_edge(i, j, color='r') # Red edge: bidirected edge 186 | 187 | def to_nx_skeleton(self): 188 | """Convert adjmat into its skeleton (a networkx.Graph object) named nx_skel""" 189 | nodes = range(len(self.G.graph)) 190 | self.nx_skel.add_nodes_from(nodes) 191 | adj = [(i, j) for (i, j) in self.find_adj() if i < j] 192 | for (i, j) in adj: 193 | self.nx_skel.add_edge(i, j, color='g') # Green edge: undirected edge 194 | 195 | def draw_nx_graph(self, skel=False): 196 | """Draw nx_graph if skel = 
False and draw nx_skel otherwise"""
197 |         if not skel:
198 |             print("Green: undirected; Blue: directed; Red: bi-directed\n")
199 |         warnings.filterwarnings("ignore", category=UserWarning)
200 |         g_to_be_drawn = self.nx_skel if skel else self.nx_graph
201 |         edges = g_to_be_drawn.edges()
202 |         colors = [g_to_be_drawn[u][v]['color'] for u, v in edges]
203 |         pos = nx.circular_layout(g_to_be_drawn)
204 |         nx.draw(g_to_be_drawn, pos=pos, with_labels=True, labels=self.labels, edge_color=colors)
205 |         plt.draw()
206 |         plt.show()
207 | 
208 |     def draw_pydot_graph(self):
209 |         """Render the graph via pydot and display it (also writes result.png)"""
210 |         warnings.filterwarnings("ignore", category=UserWarning)
211 |         pyd = GraphUtils.to_pydot(self.G)
212 |         tmp_png = pyd.create_png(f="png")
213 |         pyd.write_png("result.png")
214 |         fp = io.BytesIO(tmp_png)
215 |         img = mpimg.imread(fp, format='png')
216 |         plt.axis('off')
217 |         plt.imshow(img)
218 |         plt.show()
219 | 
--------------------------------------------------------------------------------
/sparse_shift/causal_learn/PC.py:
--------------------------------------------------------------------------------
1 | """
2 | Code acquired and modified from the causal-learn github package
3 | https://github.com/cmu-phil/causal-learn
4 | """
5 | 
6 | from __future__ import annotations
7 | 
8 | import time
9 | import warnings
10 | from itertools import combinations, permutations
11 | from typing import Dict, List, Tuple
12 | 
13 | import networkx as nx
14 | from numpy import ndarray
15 | 
16 | from causallearn.graph.GraphClass import CausalGraph
17 | from causallearn.utils.PCUtils.BackgroundKnowledge import BackgroundKnowledge
18 | from causallearn.utils.cit import *
19 | from causallearn.utils.PCUtils import Helper, Meek, UCSepset
20 | from .SkeletonDiscovery import augmented_skeleton_discovery as skeleton_discovery
21 | from causallearn.utils.PCUtils.BackgroundKnowledgeOrientUtils import \
22 |     orient_by_background_knowledge
23 | 
24 | 
25 | def augmented_pc(data: ndarray, alpha=0.05, indep_test=fisherz, stable: bool = True, uc_rule: int = 0, uc_priority: int = 2,
26 |                  mvpc: bool = False, correction_name: str = 'MV_Crtn_Fisher_Z',
27 |                  background_knowledge: BackgroundKnowledge | None = None, verbose: bool = False, show_progress: bool = True,
28 |                  cg: CausalGraph = None):
29 |     if data.shape[0] < data.shape[1]:
30 |         warnings.warn("The number of features is much larger than the sample size!")
31 | 
32 |     if mvpc:  # missing value PC
33 |         if indep_test == fisherz:
34 |             indep_test = mv_fisherz
35 |         return mvpc_alg(data=data, alpha=alpha, indep_test=indep_test, correction_name=correction_name, stable=stable,
36 |                         uc_rule=uc_rule, uc_priority=uc_priority, background_knowledge=background_knowledge,
37 |                         verbose=verbose,
38 |                         show_progress=show_progress)
39 |     else:
40 |         return augmented_pc_alg(data=data, alpha=alpha, indep_test=indep_test, stable=stable, uc_rule=uc_rule,
41 |                                 uc_priority=uc_priority, background_knowledge=background_knowledge, verbose=verbose,
42 |                                 show_progress=show_progress, cg=cg)
43 | 
44 | 
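# A hypothetical invocation of the wrapper above, warm-starting from a
# previously learned CausalGraph `cg0` (the main modification relative to
# upstream causal-learn), with X an (n_samples, n_features) array:
#
#     cg = augmented_pc(X, alpha=0.05, indep_test=fisherz, cg=cg0)
#     adj = cg.G.graph  # adj[j, i] == 1 and adj[i, j] == -1 encodes i --> j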
45 | def augmented_pc_alg(data: ndarray, alpha: float, indep_test, stable: bool, uc_rule: int, uc_priority: int,
46 |                      background_knowledge: BackgroundKnowledge | None = None,
47 |                      verbose: bool = False,
48 |                      show_progress: bool = True, cg: CausalGraph = None) -> CausalGraph:
49 |     """
50 |     Perform Peter-Clark (PC) algorithm for causal discovery
51 | 
52 |     Parameters
53 |     ----------
54 |     data : data set (numpy ndarray), shape (n_samples, n_features). The input data, where n_samples is the number of samples and n_features is the number of features.
55 |     alpha : float, desired significance level of independence tests (p_value) in (0,1)
56 |     indep_test : the function of the independence test being used
57 |             [fisherz, chisq, gsq, kci]
58 |            - fisherz: Fisher's Z conditional independence test
59 |            - chisq: Chi-squared conditional independence test
60 |            - gsq: G-squared conditional independence test
61 |            - kci: Kernel-based conditional independence test
62 |     stable : run stabilized skeleton discovery if True (default = True)
63 |     uc_rule : how unshielded colliders are oriented
64 |            0: run uc_sepset
65 |            1: run maxP
66 |            2: run definiteMaxP
67 |     uc_priority : rule of resolving conflicts between unshielded colliders
68 |            -1: whatever is default in uc_rule
69 |            0: overwrite
70 |            1: orient bi-directed
71 |            2: prioritize existing colliders
72 |            3: prioritize stronger colliders
73 |            4: prioritize stronger* colliders
74 |     background_knowledge : background knowledge
75 |     verbose : True iff verbose output should be printed.
76 |     show_progress : True iff the algorithm progress should be shown in the console.
77 | 
78 |     Returns
79 |     -------
80 |     cg : a CausalGraph object, where cg.G.graph[j,i]=1 and cg.G.graph[i,j]=-1 indicates i --> j,
81 |         cg.G.graph[i,j] = cg.G.graph[j,i] = -1 indicates i --- j,
82 |         cg.G.graph[i,j] = cg.G.graph[j,i] = 1 indicates i <-> j.
83 | 
84 |     """
85 | 
86 |     start = time.time()
87 |     cg_1 = skeleton_discovery(data, alpha, indep_test, stable,
88 |                               background_knowledge=background_knowledge, verbose=verbose,
89 |                               show_progress=show_progress, cg=cg)
90 | 
91 |     if background_knowledge is not None:
92 |         orient_by_background_knowledge(cg_1, background_knowledge)
93 | 
94 |     if uc_rule == 0:
95 |         if uc_priority != -1:
96 |             cg_2 = UCSepset.uc_sepset(cg_1, uc_priority, background_knowledge=background_knowledge)
97 |         else:
98 |             cg_2 = UCSepset.uc_sepset(cg_1, background_knowledge=background_knowledge)
99 |         cg = Meek.meek(cg_2, background_knowledge=background_knowledge)
100 | 
101 |     elif uc_rule == 1:
102 |         if uc_priority != -1:
103 |             cg_2 = UCSepset.maxp(cg_1, uc_priority, background_knowledge=background_knowledge)
104 |         else:
105 |             cg_2 = UCSepset.maxp(cg_1, background_knowledge=background_knowledge)
106 |         cg = Meek.meek(cg_2, background_knowledge=background_knowledge)
107 | 
108 |     elif uc_rule == 2:
109 |         if uc_priority != -1:
110 |             cg_2 = UCSepset.definite_maxp(cg_1, alpha, uc_priority, background_knowledge=background_knowledge)
111 |         else:
112 |             cg_2 = UCSepset.definite_maxp(cg_1, alpha, background_knowledge=background_knowledge)
113 |         cg_before = Meek.definite_meek(cg_2, background_knowledge=background_knowledge)
114 |         cg = Meek.meek(cg_before, background_knowledge=background_knowledge)
115 |     else:
116 |         raise ValueError("uc_rule should be in [0, 1, 2]")
117 |     end = time.time()
118 | 
119 |     cg.PC_elapsed = end - start
120 | 
121 |     return cg
122 | 
123 | 
124 | def mvpc_alg(data: ndarray, alpha: float, indep_test, correction_name: str, stable: bool, uc_rule: int,
125 |              uc_priority: int, background_knowledge: BackgroundKnowledge | None = None,
126 |              verbose: bool = False,
127 |              show_progress: bool = True) -> CausalGraph:
128 |     """
129 |     Perform missing value Peter-Clark (PC) algorithm for causal discovery
130 | 
131 |     Parameters
132 |     ----------
133 |     data : data set (numpy ndarray), shape (n_samples, n_features). The input data, where n_samples is the number of samples and n_features is the number of features.
134 |     alpha : float, desired significance level of independence tests (p_value) in (0,1)
135 |     indep_test : name of the test-wise deletion independence test being used
136 |            [mv_fisherz, mv_g_sq]
137 |            - mv_fisherz: Fisher's Z conditional independence test
138 |            - mv_g_sq: G-squared conditional independence test (TODO: under development)
139 |     correction_name : name of the missingness correction
140 |            [MV_Crtn_Fisher_Z, MV_Crtn_G_sq, MV_DRW_Fisher_Z, MV_DRW_G_sq]
141 |            - "MV_Crtn_Fisher_Z": Permutation based correction method
142 |            - "MV_Crtn_G_sq": G-squared conditional independence test (TODO: under development)
143 |            - "MV_DRW_Fisher_Z": density ratio weighting based correction method (TODO: under development)
144 |            - "MV_DRW_G_sq": G-squared conditional independence test (TODO: under development)
145 |     stable : run stabilized skeleton discovery if True (default = True)
146 |     uc_rule : how unshielded colliders are oriented
147 |            0: run uc_sepset
148 |            1: run maxP
149 |            2: run definiteMaxP
150 |     uc_priority : rule of resolving conflicts between unshielded colliders
151 |            -1: whatever is default in uc_rule
152 |            0: overwrite
153 |            1: orient bi-directed
154 |            2: prioritize existing colliders
155 |            3: prioritize stronger colliders
156 |            4: prioritize stronger* colliders
157 |     background_knowledge: background knowledge
158 |     verbose : True iff verbose output should be printed.
159 |     show_progress : True iff the algorithm progress should be shown in the console.
160 | 
161 |     Returns
162 |     -------
163 |     cg : a CausalGraph object, where cg.G.graph[j,i]=1 and cg.G.graph[i,j]=-1 indicates i --> j,
164 |         cg.G.graph[i,j] = cg.G.graph[j,i] = -1 indicates i --- j,
165 |         cg.G.graph[i,j] = cg.G.graph[j,i] = 1 indicates i <-> j.
166 | 
167 |     """
168 | 
169 |     start = time.time()
170 | 
171 |     ## Step 1: detect the direct causes of missingness indicators
172 |     prt_m = get_prt_mpairs(data, alpha, indep_test, stable)
173 |     # print('Finish detecting the parents of missingness indicators.
') 174 | 175 | ## Step 2: 176 | ## a) Run PC algorithm with the 1st step skeleton; 177 | cg_pre = skeleton_discovery(data, alpha, indep_test, stable, 178 | background_knowledge=background_knowledge, 179 | verbose=verbose, show_progress=show_progress) 180 | if background_knowledge is not None: 181 | orient_by_background_knowledge(cg_pre, background_knowledge) 182 | 183 | cg_pre.to_nx_skeleton() 184 | # print('Finish skeleton search with test-wise deletion.') 185 | 186 | ## b) Correction of the extra edges 187 | cg_corr = skeleton_correction(data, alpha, correction_name, cg_pre, prt_m, stable) 188 | # print('Finish missingness correction.') 189 | 190 | if background_knowledge is not None: 191 | orient_by_background_knowledge(cg_corr, background_knowledge) 192 | 193 | ## Step 3: Orient the edges 194 | if uc_rule == 0: 195 | if uc_priority != -1: 196 | cg_2 = UCSepset.uc_sepset(cg_corr, uc_priority, background_knowledge=background_knowledge) 197 | else: 198 | cg_2 = UCSepset.uc_sepset(cg_corr, background_knowledge=background_knowledge) 199 | cg = Meek.meek(cg_2, background_knowledge=background_knowledge) 200 | 201 | elif uc_rule == 1: 202 | if uc_priority != -1: 203 | cg_2 = UCSepset.maxp(cg_corr, uc_priority, background_knowledge=background_knowledge) 204 | else: 205 | cg_2 = UCSepset.maxp(cg_corr, background_knowledge=background_knowledge) 206 | cg = Meek.meek(cg_2, background_knowledge=background_knowledge) 207 | 208 | elif uc_rule == 2: 209 | if uc_priority != -1: 210 | cg_2 = UCSepset.definite_maxp(cg_corr, alpha, uc_priority, background_knowledge=background_knowledge) 211 | else: 212 | cg_2 = UCSepset.definite_maxp(cg_corr, alpha, background_knowledge=background_knowledge) 213 | cg_before = Meek.definite_meek(cg_2, background_knowledge=background_knowledge) 214 | cg = Meek.meek(cg_before, background_knowledge=background_knowledge) 215 | else: 216 | raise ValueError("uc_rule should be in [0, 1, 2]") 217 | end = time.time() 218 | 219 | cg.PC_elapsed = end - start 220 | 221 | return cg 222 | 223 | 224 | ####################################################################################################################### 225 | ## *********** Functions for Step 1 *********** 226 | def get_prt_mpairs(data: ndarray, alpha: float, indep_test, stable: bool = True) -> Dict[str, list]: 227 | """ 228 | Detect the parents of missingness indicators 229 | If a missingness indicator has no parent, it will not be included in the result 230 | :param data: data set (numpy ndarray) 231 | :param alpha: desired significance level in (0, 1) (float) 232 | :param indep_test: name of the test-wise deletion independence test being used 233 | - "MV_Fisher_Z": Fisher's Z conditional independence test 234 | - "MV_G_sq": G-squared conditional independence test (TODO: under development) 235 | :param stable: run stabilized skeleton discovery if True (default = True) 236 | :return: 237 | cg: a CausalGraph object 238 | """ 239 | prt_m = {'prt': [], 'm': []} 240 | 241 | ## Get the index of missingness indicators 242 | m_indx = get_mindx(data) 243 | 244 | ## Get the index of parents of missingness indicators 245 | # If the missingness indicator has no parent, then it will not be collected in prt_m 246 | for r in m_indx: 247 | prt_r = detect_parent(r, data, alpha, indep_test, stable) 248 | if isempty(prt_r): 249 | pass 250 | else: 251 | prt_m['prt'].append(prt_r) 252 | prt_m['m'].append(r) 253 | return prt_m 254 | 255 | 256 | def isempty(prt_r) -> bool: 257 | """Test whether the parent of a missingness indicator is empty""" 
258 |     return len(prt_r) == 0
259 | 
260 | 
261 | def get_mindx(data: ndarray) -> List[int]:
262 |     """Detect the missingness indicators
263 |     :param data: data set (numpy ndarray)
264 |     :return:
265 |         m_indx: list, the index of missingness indicators
266 |     """
267 | 
268 |     m_indx = []
269 |     _, ncol = np.shape(data)
270 |     for i in range(ncol):
271 |         if np.isnan(data[:, i]).any():
272 |             m_indx.append(i)
273 |     return m_indx
274 | 
275 | 
276 | def detect_parent(r: int, data_: ndarray, alpha: float, indep_test, stable: bool = True) -> ndarray:
277 |     """Detect the parents of a missingness indicator
278 |     :param r: the missingness indicator
279 |     :param data_: data set (numpy ndarray)
280 |     :param alpha: desired significance level in (0, 1) (float)
281 |     :param indep_test: name of the test-wise deletion independence test being used
282 |            - "MV_Fisher_Z": Fisher's Z conditional independence test
283 |            - "MV_G_sq": G-squared conditional independence test (TODO: under development)
284 |     :param stable: run stabilized skeleton discovery if True (default = True)
285 |     :return:
286 |         prt: parents of the missingness indicator r
287 |     """
288 |     ## TODO: in the test-wise deletion CI test, when testing between a binary and a continuous variable,
289 |     # the binary variable may take only one value after deletion.
290 |     # This is because the assumption is violated.
291 | 
292 |     ## *********** Adaptation 0 ***********
293 |     # Copy to avoid changing the original data
294 |     data = data_.copy()
295 |     ## *********** End ***********
296 | 
297 |     assert type(data) == np.ndarray
298 |     assert 0 < alpha < 1
299 | 
300 |     ## *********** Adaptation 1 ***********
301 |     # data
302 |     ## Replace the variable r with its missingness indicator
303 |     ## If r is not a missingness indicator, return [].
304 |     data[:, r] = np.isnan(data[:, r]).astype(float)  # True is missing; False is not missing
305 |     if sum(data[:, r]) == 0 or sum(data[:, r]) == len(data[:, r]):
306 |         return np.empty(0)
307 |     ## *********** End ***********
308 | 
309 |     no_of_var = data.shape[1]
310 |     cg = CausalGraph(no_of_var)
311 |     cg.data = data
312 |     cg.set_ind_test(indep_test)
313 |     cg.corr_mat = np.corrcoef(data, rowvar=False) if indep_test == fisherz else []
314 | 
315 |     node_ids = range(no_of_var)
316 |     pair_of_variables = list(permutations(node_ids, 2))
317 | 
318 |     depth = -1
319 |     while cg.max_degree() - 1 > depth:
320 |         depth += 1
321 |         edge_removal = []
322 |         for (x, y) in pair_of_variables:
323 | 
324 |             ## *********** Adaptation 2 ***********
325 |             # the skeleton search
326 |             ## Only test which variable is the neighbor of r
327 |             if x != r:
328 |                 continue
329 |             ## *********** End ***********
330 | 
331 |             Neigh_x = cg.neighbors(x)
332 |             if y not in Neigh_x:
333 |                 continue
334 |             else:
335 |                 Neigh_x = np.delete(Neigh_x, np.where(Neigh_x == y))
336 | 
337 |             if len(Neigh_x) >= depth:
338 |                 for S in combinations(Neigh_x, depth):
339 |                     p = cg.ci_test(x, y, S)
340 |                     if p > alpha:
341 |                         if not stable:  # Unstable: Remove x---y right away
342 |                             edge1 = cg.G.get_edge(cg.G.nodes[x], cg.G.nodes[y])
343 |                             if edge1 is not None:
344 |                                 cg.G.remove_edge(edge1)
345 |                             edge2 = cg.G.get_edge(cg.G.nodes[y], cg.G.nodes[x])
346 |                             if edge2 is not None:
347 |                                 cg.G.remove_edge(edge2)
348 |                         else:  # Stable: x---y will be removed only
349 |                             edge_removal.append((x, y))  # after all conditioning sets at
350 |                             edge_removal.append((y, x))  # depth l have been considered
351 |                         Helper.append_value(cg.sepset, x, y, S)
352 |                         Helper.append_value(cg.sepset, y, x, S)
353 |                         break
354 | 
355 |         for (x, y) in list(set(edge_removal)):
356 |             edge1 = cg.G.get_edge(cg.G.nodes[x], cg.G.nodes[y])
357 |             if edge1 is not None:
358 |                 cg.G.remove_edge(edge1)
359 | 
360 |     ## *********** Adaptation 3 ***********
361 |     ## extract the parent of r from the graph
362 |     cg.to_nx_skeleton()
363 |     cg_skel_adj = nx.to_numpy_array(cg.nx_skel).astype(int)
364 |     prt = get_parent(r, cg_skel_adj)
365 |     ## *********** End ***********
366 | 
367 |     return prt
368 | 
369 | 
370 | def get_parent(r: int, cg_skel_adj: ndarray) -> ndarray:
371 |     """Get the neighbors of the missingness indicator that are its parents
372 |     :param r: the missingness indicator index
373 |     :param cg_skel_adj: adjacency matrix of a causal skeleton
374 |     :return:
375 |         prt: list, parents of the missingness indicator r
376 |     """
377 |     num_var = len(cg_skel_adj[0, :])
378 |     indx = np.array([i for i in range(num_var)])
379 |     prt = indx[cg_skel_adj[r, :] == 1]
380 |     return prt
381 | 
382 | 
383 | ## *********** END ***********
384 | #######################################################################################################################
385 | 
386 | def skeleton_correction(data: ndarray, alpha: float, test_with_correction_name: str, init_cg: CausalGraph, prt_m: dict,
387 |                         stable: bool = True) -> CausalGraph:
388 |     """Perform correction of the test-wise deletion skeleton
389 |     :param data: data set (numpy ndarray)
390 |     :param alpha: desired significance level in (0, 1) (float)
391 |     :param test_with_correction_name: name of the independence test being used
392 |            - "MV_Crtn_Fisher_Z": Fisher's Z conditional independence test
393 |            - "MV_Crtn_G_sq": G-squared conditional independence test
394 |     :param stable: run stabilized skeleton discovery if True (default = True)
395 |     :return:
396 |         cg: a CausalGraph object
397 |     """
398 | 
type(data) == np.ndarray 400 | assert 0 < alpha < 1 401 | assert test_with_correction_name in ["MV_Crtn_Fisher_Z", "MV_Crtn_G_sq"] 402 | 403 | ## *********** Adaption 1 *********** 404 | no_of_var = data.shape[1] 405 | 406 | ## Initialize the graph with the result of test-wise deletion skeletion search 407 | cg = init_cg 408 | 409 | cg.data = data 410 | if test_with_correction_name in ["MV_Crtn_Fisher_Z", "MV_Crtn_G_sq"]: 411 | cg.set_ind_test(mc_fisherz, True) 412 | # No need of the correlation matrix if using test-wise deletion test 413 | cg.corr_mat = np.corrcoef(data, rowvar=False) if test_with_correction_name == "MV_Crtn_Fisher_Z" else [] 414 | cg.prt_m = prt_m 415 | ## *********** Adaption 1 *********** 416 | 417 | node_ids = range(no_of_var) 418 | pair_of_variables = list(permutations(node_ids, 2)) 419 | 420 | depth = -1 421 | while cg.max_degree() - 1 > depth: 422 | depth += 1 423 | edge_removal = [] 424 | for (x, y) in pair_of_variables: 425 | Neigh_x = cg.neighbors(x) 426 | if y not in Neigh_x: 427 | continue 428 | else: 429 | Neigh_x = np.delete(Neigh_x, np.where(Neigh_x == y)) 430 | 431 | if len(Neigh_x) >= depth: 432 | for S in combinations(Neigh_x, depth): 433 | p = cg.ci_test(x, y, S) 434 | if p > alpha: 435 | if not stable: # Unstable: Remove x---y right away 436 | edge1 = cg.G.get_edge(cg.G.nodes[x], cg.G.nodes[y]) 437 | if edge1 is not None: 438 | cg.G.remove_edge(edge1) 439 | edge2 = cg.G.get_edge(cg.G.nodes[y], cg.G.nodes[x]) 440 | if edge2 is not None: 441 | cg.G.remove_edge(edge2) 442 | else: # Stable: x---y will be removed only 443 | edge_removal.append((x, y)) # after all conditioning sets at 444 | edge_removal.append((y, x)) # depth l have been considered 445 | Helper.append_value(cg.sepset, x, y, S) 446 | Helper.append_value(cg.sepset, y, x, S) 447 | break 448 | 449 | for (x, y) in list(set(edge_removal)): 450 | edge1 = cg.G.get_edge(cg.G.nodes[x], cg.G.nodes[y]) 451 | if edge1 is not None: 452 | cg.G.remove_edge(edge1) 453 | 454 | return cg 455 | 456 | 457 | ####################################################################################################################### 458 | 459 | # *********** Evaluation util *********** 460 | 461 | def get_adjacancy_matrix(g: CausalGraph) -> ndarray: 462 | return nx.to_numpy_array(g.nx_graph).astype(int) 463 | 464 | 465 | def matrix_diff(cg1: CausalGraph, cg2: CausalGraph) -> (float, List[Tuple[int, int]]): 466 | adj1 = get_adjacancy_matrix(cg1) 467 | adj2 = get_adjacancy_matrix(cg2) 468 | count = 0 469 | diff_ls = [] 470 | for i in range(len(adj1[:, ])): 471 | for j in range(len(adj2[:, ])): 472 | if adj1[i, j] != adj2[i, j]: 473 | diff_ls.append((i, j)) 474 | count += 1 475 | return count / 2, diff_ls 476 | -------------------------------------------------------------------------------- /sparse_shift/causal_learn/SkeletonDiscovery.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code aquired and modified from the causal-learn github package 3 | https://github.com/cmu-phil/causal-learn 4 | """ 5 | 6 | 7 | from __future__ import annotations 8 | 9 | from itertools import combinations 10 | 11 | import numpy as np 12 | from numpy import ndarray 13 | from tqdm.auto import tqdm 14 | 15 | from causallearn.graph.GraphClass import CausalGraph 16 | from causallearn.utils.PCUtils.BackgroundKnowledge import BackgroundKnowledge 17 | from causallearn.utils.cit import chisq, gsq 18 | from causallearn.utils.PCUtils.Helper import append_value 19 | 20 | 21 | def 
augmented_skeleton_discovery(data: ndarray, alpha: float, indep_test, stable: bool = True, 22 | background_knowledge: BackgroundKnowledge | None = None, verbose: bool = False, 23 | show_progress: bool = True, cg: CausalGraph = None) -> CausalGraph: 24 | """ 25 | Perform skeleton discovery 26 | Parameters 27 | ---------- 28 | data : data set (numpy ndarray), shape (n_samples, n_features). The input data, where n_samples is the number of 29 | samples and n_features is the number of features. 30 | alpha: float, desired significance level of independence tests (p_value) in (0,1) 31 | indep_test : the function of the independence test being used 32 | [fisherz, chisq, gsq, mv_fisherz, kci] 33 | - fisherz: Fisher's Z conditional independence test 34 | - chisq: Chi-squared conditional independence test 35 | - gsq: G-squared conditional independence test 36 | - mv_fisherz: Missing-value Fishers'Z conditional independence test 37 | - kci: Kernel-based conditional independence test 38 | stable : run stabilized skeleton discovery if True (default = True) 39 | background_knowledge : background knowledge 40 | verbose : True iff verbose output should be printed. 41 | show_progress : True iff the algorithm progress should be show in console. 42 | Returns 43 | ------- 44 | cg : a CausalGraph object. Where cg.G.graph[j,i]=0 and cg.G.graph[i,j]=1 indicates i -> j , 45 | cg.G.graph[i,j] = cg.G.graph[j,i] = -1 indicates i -- j, 46 | cg.G.graph[i,j] = cg.G.graph[j,i] = 1 indicates i <-> j. 47 | """ 48 | 49 | assert type(data) == np.ndarray 50 | assert 0 < alpha < 1 51 | 52 | import copy 53 | cg = copy.deepcopy(cg) 54 | 55 | no_of_var = data.shape[1] 56 | nodes = cg.G.get_nodes() 57 | 58 | # no_aug_skel = cg 59 | # if cg is None: 60 | # cg = CausalGraph(no_of_var) 61 | cg.set_ind_test(indep_test) 62 | cg.data_hash_key = hash(str(data)) 63 | if indep_test == chisq or indep_test == gsq: 64 | # if dealing with discrete data, data is numpy.ndarray with n rows m columns, 65 | # for each column, translate the discrete values to int indexs starting from 0, 66 | # e.g. [45, 45, 6, 7, 6, 7] -> [2, 2, 0, 1, 0, 1] 67 | # ['apple', 'apple', 'pear', 'peach', 'pear'] -> [0, 0, 2, 1, 2] 68 | # in old code, its presumed that discrete `data` is already indexed, 69 | # but here we make sure it's in indexed form, so allow more user input e.g. 'apple' .. 
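        # Editor's sketch (illustrative, not from the original file) of the
        # indexing performed by `_unique` below:
        #   >>> np.unique([45, 45, 6, 7, 6, 7], return_inverse=True)[1]
        #   array([2, 2, 0, 1, 0, 1])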
70 |         def _unique(column):
71 |             return np.unique(column, return_inverse=True)[1]
72 | 
73 |         cg.is_discrete = True
74 |         cg.data = np.apply_along_axis(_unique, 0, data).astype(np.int64)
75 |         cg.cardinalities = np.max(cg.data, axis=0) + 1
76 |     else:
77 |         cg.data = data
78 | 
79 |     depth = -1
80 |     pbar = tqdm(total=no_of_var) if show_progress else None
81 |     while cg.max_degree() - 1 > depth:
82 |         depth += 1
83 |         edge_removal = []
84 |         if show_progress:
85 |             pbar.reset()
86 | 
87 |         # Just test the last variable (augmented)
88 |         for x in range(no_of_var):
89 | 
90 |             if show_progress:
91 |                 pbar.update()
92 |             if show_progress:
93 |                 pbar.set_description(f'Depth={depth}, working on node {x}')
94 |             Neigh_x = cg.neighbors(x)
95 |             if len(Neigh_x) < depth - 1:
96 |                 continue
97 |             for y in [no_of_var - 1]:  # range(no_of_var):
98 |                 knowledge_ban_edge = False
99 |                 sepsets = set()
100 |                 if background_knowledge is not None and (
101 |                         background_knowledge.is_forbidden(cg.G.nodes[x], cg.G.nodes[y])
102 |                         and background_knowledge.is_forbidden(cg.G.nodes[y], cg.G.nodes[x])):
103 |                     knowledge_ban_edge = True
104 |                 if knowledge_ban_edge:
105 |                     if not stable:
106 |                         edge1 = cg.G.get_edge(cg.G.nodes[x], cg.G.nodes[y])
107 |                         if edge1 is not None:
108 |                             cg.G.remove_edge(edge1)
109 |                         edge2 = cg.G.get_edge(cg.G.nodes[y], cg.G.nodes[x])
110 |                         if edge2 is not None:
111 |                             cg.G.remove_edge(edge2)
112 |                         append_value(cg.sepset, x, y, ())
113 |                         append_value(cg.sepset, y, x, ())
114 |                         break
115 |                     else:
116 |                         edge_removal.append((x, y))  # after all conditioning sets at
117 |                         edge_removal.append((y, x))  # depth l have been considered
118 | 
119 |                 Neigh_x_noy = np.delete(Neigh_x, np.where(Neigh_x == y))
120 | 
121 |                 for S in combinations(Neigh_x_noy, depth):
122 |                     # if x != no_of_var-1 and y != no_of_var-1:
123 |                     #     p = cg.G.is_dseparated_from(nodes[x], nodes[y], [nodes[s] for s in S] + [nodes[-1]])
124 |                     #     p = int(p)
125 |                     # else:
126 |                     p = cg.ci_test(x, y, S)
127 |                     if p > alpha:
128 |                         if verbose:
129 |                             print('%d ind %d | %s with p-value %f\n' % (x, y, S, p))
130 |                         if not stable:
131 |                             edge1 = cg.G.get_edge(cg.G.nodes[x], cg.G.nodes[y])
132 |                             if edge1 is not None:
133 |                                 cg.G.remove_edge(edge1)
134 |                             edge2 = cg.G.get_edge(cg.G.nodes[y], cg.G.nodes[x])
135 |                             # print(x, y, S, edge1, edge2)
136 |                             if edge2 is not None:
137 |                                 cg.G.remove_edge(edge2)
138 |                             append_value(cg.sepset, x, y, S)
139 |                             append_value(cg.sepset, y, x, S)
140 |                             break
141 |                         else:
142 |                             edge_removal.append((x, y))  # after all conditioning sets at
143 |                             edge_removal.append((y, x))  # depth l have been considered
144 |                             for s in S:
145 |                                 sepsets.add(s)
146 |                     else:
147 |                         if verbose:
148 |                             print('%d dep %d | %s with p-value %f\n' % (x, y, S, p))
149 |                 append_value(cg.sepset, x, y, tuple(sepsets))
150 |                 append_value(cg.sepset, y, x, tuple(sepsets))
151 | 
152 |             if show_progress:
153 |                 pbar.refresh()
154 | 
155 |         for (x, y) in list(set(edge_removal)):
156 |             edge1 = cg.G.get_edge(cg.G.nodes[x], cg.G.nodes[y])
157 |             if edge1 is not None:
158 |                 cg.G.remove_edge(edge1)
159 | 
160 |     if show_progress:
161 |         pbar.close()
162 | 
163 |     return cg
--------------------------------------------------------------------------------
/sparse_shift/causal_learn/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rflperry/sparse_shift/8620c5ccf7b7b28b2d2946d3f84c1b16b8fcfd39/sparse_shift/causal_learn/__init__.py
--------------------------------------------------------------------------------
/sparse_shift/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .simulations import sample_topological, sample_nonlinear_icp_sim, \
2 |     sample_cdnod_sim
3 | from .dags import erdos_renyi_dag, connected_erdos_renyi_dag, barabasi_albert_dag, complete_dag
--------------------------------------------------------------------------------
/sparse_shift/datasets/dags.py:
--------------------------------------------------------------------------------
1 | import networkx as nx
2 | import numpy as np
3 | 
4 | 
5 | def _graph2dag(graph):
6 |     """Converts an nx.Graph to a directed, acyclic form. Returns the adjacency matrix"""
7 |     adj = nx.adj_matrix(graph).todense()
8 |     adj = adj + adj.T
9 |     adj = (adj != 0).astype(int)
10 |     adj = np.tril(adj)
11 | 
12 |     assert nx.is_directed_acyclic_graph(nx.from_numpy_matrix(adj, create_using=nx.DiGraph))
13 | 
14 |     return adj
15 | 
16 | 
17 | def erdos_renyi_dag(n, p, seed=None):
18 |     """
19 |     Simulates an Erdos-Renyi random DAG on n vertices
20 |     with expected degree p. Each node has the same expected
21 |     degree.
22 | 
23 |     If p is an integer, it is the expected
24 |     number of connected edges. Else, it is the expected degree
25 |     fraction relative to n.
26 |     """
27 |     if p > 1 or isinstance(p, int):
28 |         p = p / (n - 1)
29 |     G = nx.erdos_renyi_graph(n, p, seed, directed=False)
30 |     return _graph2dag(G)
31 | 
32 | 
33 | def connected_erdos_renyi_dag(n, p, seed=None):
34 |     """
35 |     Simulates an Erdos-Renyi random DAG on n vertices
36 |     with expected degree p. Each node has the same expected
37 |     degree, and the graph is guaranteed to be connected, with
38 |     a deterministic number of edges.
39 | 
40 |     If p is an integer, it is the expected
41 |     number of connected edges. Else, it is the expected degree
42 |     fraction relative to n.
43 |     """
44 |     if p <= 1 and isinstance(p, float):
45 |         p = p * n
46 |     if int(p) != p:
47 |         import warnings
48 |         warnings.warn(f'Number of neighbors {p:.1f} will be rounded')
49 | 
50 |     G = nx.connected_watts_strogatz_graph(
51 |         n, k=round(p), p=1 - 1/n, seed=seed
52 |     )
53 |     return _graph2dag(G)
54 | 
55 | 
56 | def barabasi_albert_dag(n, p, seed=None):
57 |     """
58 |     Simulates a Barabasi-Albert DAG on n vertices
59 |     with expected degree p. The degree distribution follows
60 |     a power law, and the graph is guaranteed to be connected.
61 | 
62 |     If p is an integer, it is the expected
63 |     number of connected edges. Else, it is the expected degree
64 |     fraction relative to n. Note: p must be <= 0.5,
65 |     or the integer equivalent, to be guaranteed to succeed on all graphs.
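    A hypothetical usage sketch (editor's illustration; only the output
    shape is asserted here):

        >>> adj = barabasi_albert_dag(n=10, p=2, seed=0)
        >>> adj.shape
        (10, 10)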
66 | """ 67 | if p > 1 or isinstance(p, int): 68 | p = p / (n - 1) 69 | 70 | # BA model input m leads to K=(1+...+m) + m*(n-m) total edges 71 | # p = K 72 | m = 0.5*(2*n - 1 - np.sqrt(4*n**2 - 4*n + 1 - 4*p*n**2 + 4*p*n)) 73 | if int(m) != m: 74 | import warnings 75 | warnings.warn(f'Number of neighbors {m:.1f} will be rounded') 76 | 77 | G = nx.barabasi_albert_graph(n, round(m), seed) 78 | return _graph2dag(G) 79 | 80 | 81 | def complete_dag(n, p=None, seed=None): 82 | """ 83 | Returns a complete DAG over n variables 84 | """ 85 | G = np.ones((n, n)) - np.eye(n) 86 | return np.tril(G) 87 | -------------------------------------------------------------------------------- /sparse_shift/datasets/simulations.py: -------------------------------------------------------------------------------- 1 | """Simulated causal datasets""" 2 | 3 | import numpy as np 4 | import networkx as nx 5 | from functools import partial 6 | 7 | 8 | def sample_topological(n, equations, dag, noise, random_state=None): 9 | """ 10 | Samples from a Structural Causal Model (SCM) in topological order 11 | 12 | Parameters 13 | ---------- 14 | n : int 15 | Number of observations to sample 16 | 17 | equations : list of callables 18 | List of SCM equations, each a function accepting two parameters 19 | - All variables, of course only parents will be used. 20 | - An exogenous noise variable. 21 | 22 | noise : callable or list of callables 23 | Exogenous noise for each structural equation. If a single callable, 24 | then the same function will be used for all equations. 25 | 26 | random_state : int, optional 27 | Seed for reproducible randomness. 28 | 29 | Returns 30 | ------- 31 | np.ndarray, shape (n, len(equations)) 32 | Sampled observational data 33 | """ 34 | np.random.seed(random_state) 35 | n_vars = len(equations) 36 | X = np.zeros((n_vars, n)) 37 | 38 | topological_order = list(nx.topological_sort(nx.DiGraph(dag))) 39 | 40 | if not callable(noise): 41 | assert len(equations) == len( 42 | noise 43 | ), f"Must provide the same number of structural \ 44 | equations as noise variables. Provided {len(equations)} and \ 45 | {len(noise)}" 46 | 47 | for i in topological_order: 48 | f = equations[i] 49 | if not callable(noise): 50 | u = np.asarray([noise[i]() for _ in range(n)]) 51 | else: 52 | u = np.asarray([noise() for _ in range(n)]) 53 | X[i] = f(X, u) 54 | 55 | return X.T 56 | 57 | 58 | def _icp_base_func( 59 | X, u, parents, function, f_join, intervened, intervention_func, pre_intervention, 60 | ): 61 | """Helper function for icp simulations""" 62 | # X shape (m_features, n_samples) 63 | X = X * parents[:, np.newaxis] 64 | X = X[parents != 0] 65 | X = function(X) 66 | if intervened: 67 | if pre_intervention: 68 | return f_join(intervention_func(X), axis=0) + u 69 | else: 70 | return intervention_func(f_join(X, axis=0) + u) 71 | else: 72 | return f_join(X, axis=0) + u 73 | 74 | 75 | def sample_nonlinear_icp_sim( 76 | dag, 77 | n_samples, 78 | nonlinearity="id", 79 | noise_df=2, 80 | combination="additive", 81 | intervention_targets=None, 82 | intervention="soft", 83 | intervention_shift=0, 84 | intervention_scale=1, 85 | intervention_pct=None, 86 | random_state=None, 87 | pre_intervention=False, 88 | lambda_noise=None, 89 | ): 90 | """ 91 | Simulates data from a given dag according to the simulation design 92 | in Heinz-Deml et al. 2018 93 | 94 | Parameters 95 | ---------- 96 | dag : numpy.ndarray, shape (m, m) 97 | Weighted adjacency matrix. 98 | dag[i, j] != 0 if there is an edge from Xi -> Xj. 
The edge weight 99 | dag[i, j] will weight Xi in the computation of Xj, 100 | 101 | n_samples : int 102 | Number of training samples 103 | 104 | nonlinearity : {'id', 'relu', 'sqrt', 'sin', 'cubic'} or callable 105 | Nonlinear function of parent value. 106 | 107 | noise_df : int, nonnegative, default=100 108 | The degrees of freedom of the t-distribution from which 109 | the noise variable is sampled. Larger values are more 110 | similar to a Gaussian distribution. 111 | 112 | combination : {'additive', 'multiplicative'} 113 | How the functions of the variable's parents are combined. 114 | 115 | intervention_targets : list of features, optional 116 | Variables to intervene on. 117 | 118 | intervention : {'soft', 'hard'} 119 | Type of intervention. A 'soft' intervention adds noise. 120 | A 'hard' intervention is only the noise. The noise is a 121 | t-distribution with `noise_df` degrees of freedom, and 122 | shifted and scaled as specified. 123 | 124 | intervention_shift : float, default=0 125 | Shifted mean applied to noise data. 126 | 127 | intervention_scale : float, default=1 128 | Scale applied to noise data, pre shift 129 | 130 | intervention_pct : float or int, optional 131 | If `float`, the likelihood any given variable is intervened on. 132 | If `int`, the number of targets to intervene on. 133 | 134 | random_state : int, optional 135 | Allows reproducibility of randomness. 136 | 137 | Returns 138 | ------- 139 | numpy.ndarray : shape (n_samples, dag.shape[0]) 140 | Simulated data 141 | 142 | Notes 143 | ----- 144 | Heinz-Deml et al. 2018 considers the following settings: 145 | 146 | n_samples : {100, 200, 500, 2000, 5000} 147 | noise_df : {2, 3, 5, 10, 20, 50, 100} 148 | intervention_shift : {0, 0.1, 0.2, 0.5, 1, 2, 5, 10} 149 | intervention_scale : {0, 0.1, 0.2, 0.5, 1, 2, 5, 10} 150 | 151 | """ 152 | m = dag.shape[0] 153 | np.random.seed(random_state) 154 | 155 | if combination == "additive": 156 | f_join = np.sum 157 | elif combination == "multiplicative": 158 | f_join = np.prod 159 | 160 | # Choose nonlinearity 161 | if nonlinearity == "id": 162 | nonlinearity_func = lambda X: X 163 | elif nonlinearity == "relu": 164 | nonlinearity_func = lambda X: np.maximum(X, 0) 165 | elif nonlinearity == "sqrt": 166 | nonlinearity_func = lambda X: np.sin(X) * np.sqrt(np.abs(X)) 167 | elif nonlinearity == "sin": 168 | nonlinearity_func = lambda X: np.sin(2 * np.pi * X) 169 | elif nonlinearity == "cubic": 170 | nonlinearity_func = lambda X: X ** 3 171 | elif callable(nonlinearity): 172 | nonlinearity_func = nonlinearity 173 | else: 174 | raise ValueError(f"Nonlinearity invalid: {nonlinearity}") 175 | 176 | # Choose intervention targets 177 | if intervention_targets is None: 178 | if isinstance(intervention_pct, float): 179 | intervention_targets = [ 180 | i for i in range(m) if np.random.uniform() < intervention_pct 181 | ] 182 | elif isinstance(intervention_pct, int): 183 | intervention_targets = np.random.choice( 184 | m, size=(intervention_pct), replace=False 185 | ) 186 | else: 187 | intervention_targets = [] 188 | elif isinstance(intervention_targets, int): 189 | intervention_targets = [intervention_targets] 190 | 191 | # Intervention function 192 | if intervention == "hard": 193 | intervention_func = ( 194 | lambda X: intervention_scale 195 | * (np.random.standard_t(df=noise_df, size=X.shape)) 196 | + intervention_shift 197 | ) 198 | elif intervention == "soft": 199 | intervention_func = ( 200 | lambda X: X 201 | + intervention_scale * (np.random.standard_t(df=noise_df, size=X.shape)) 
202 | + intervention_shift 203 | ) 204 | elif callable(intervention): 205 | intervention_func = intervention 206 | else: 207 | raise ValueError(f"Invalid intervention: {intervention}") 208 | 209 | if lambda_noise is None: 210 | noise = lambda: np.random.standard_t(df=noise_df) 211 | else: 212 | noise = lambda: lambda_noise(np.random.standard_t(df=noise_df)) 213 | 214 | equations = [ 215 | partial( 216 | _icp_base_func, 217 | parents=parents, 218 | function=nonlinearity_func, 219 | f_join=f_join, 220 | intervened=(i in intervention_targets), 221 | intervention_func=intervention_func, 222 | pre_intervention=pre_intervention, 223 | ) 224 | for i, parents in enumerate(dag.T) 225 | ] 226 | 227 | X = sample_topological(n_samples, equations, dag, noise, random_state) 228 | 229 | return X 230 | 231 | 232 | def _cdnod_base_func(X, u, parents, coefs, functions, noise_scale, noise_shift, additive): 233 | """Helper function for icp simulations""" 234 | # X shape (m_features, n_samples) 235 | # X = X * parents[:, np.newaxis] 236 | n_samples = X.shape[1] 237 | X = X[parents != 0] 238 | # X shape (m_parents, n_samples) 239 | X = np.asarray([b * f(x) for b, f, x in zip(coefs, functions, X)]) 240 | if additive: 241 | return np.sum(X, axis=0) + noise_scale * (u + noise_shift) 242 | else: 243 | if sum(parents) == 0: 244 | return np.random.normal(0, 1, (n_samples,)) 245 | return np.sum(X, axis=0) * np.abs(noise_scale * (u + noise_shift)) 246 | 247 | 248 | def sample_cdnod_sim( 249 | dag, 250 | n_samples, 251 | functions=[ 252 | np.tanh, 253 | np.sinc, 254 | lambda x: x ** 2, 255 | lambda x: x ** 3, 256 | ], 257 | intervention_targets=None, 258 | intervention_pct=None, 259 | base_random_state=None, 260 | domain_random_state=None, 261 | ): 262 | """ 263 | Simulates data from a given dag according to the simulation design 264 | in Huang et al. 2020 265 | 266 | Parameters 267 | ---------- 268 | dag : numpy.ndarray, shape (m, m) 269 | Weighted adjacency matrix. 270 | dag[i, j] != 0 if there is an edge from Xi -> Xj. The edge weight 271 | dag[i, j] will weight Xi in the computation of Xj, 272 | 273 | n_samples : int 274 | Number of training samples 275 | 276 | functions : list of callables 277 | Possible functions of parent variables to be sampled and summed 278 | when simulating the SCM. 279 | 280 | intervention_targets : list of features, optional 281 | Variables to intervene on. 282 | 283 | intervention_pct : float or int, optional 284 | If `float`, the likelihood any given variable is intervened on. 285 | If `int`, the number of targets to intervene on. 
286 | 287 | base_random_state : int, optional 288 | Allows reproducibility of randomness of underlying SCM 289 | 290 | domain_random_state : int, optional 291 | Allows reproducibility of randomness of intervention 292 | 293 | Returns 294 | ------- 295 | numpy.ndarray : shape (n_samples, dag.shape[0]) 296 | Simulated data 297 | 298 | Notes 299 | ----- 300 | 301 | """ 302 | m = dag.shape[0] 303 | dag = (dag != 0).astype(int) 304 | base_seed = np.random.RandomState(base_random_state) 305 | domain_seed = np.random.RandomState(domain_random_state) 306 | 307 | # Choose intervention targets 308 | if intervention_targets is None: 309 | if isinstance(intervention_pct, float): 310 | intervention_targets = [ 311 | i for i in range(m) if domain_seed.uniform() < intervention_pct 312 | ] 313 | elif isinstance(intervention_pct, int): 314 | intervention_targets = domain_seed.choice( 315 | m, size=(intervention_pct), replace=False 316 | ) 317 | else: 318 | intervention_targets = [] 319 | elif isinstance(intervention_targets, int): 320 | intervention_targets = [intervention_targets] 321 | 322 | # If intervention on variable, use domain specific seed. Otherwise shared base seed 323 | functions = [ 324 | base_seed.choice(functions, size=(np.sum(parents))) 325 | for i, parents in enumerate(dag.T) 326 | ] 327 | additives = [ 328 | False 329 | # True #base_seed.choice([True, False]) 330 | for i, parents in enumerate(dag.T) 331 | ] 332 | 333 | base_equations = [ 334 | partial( 335 | _cdnod_base_func, 336 | parents=parents, 337 | coefs=[1]*sum(parents), 338 | # coefs=base_seed.uniform(0.5, 2.5, size=(np.sum(parents))), 339 | functions=functions[i], 340 | noise_scale=1, 341 | noise_shift=0, 342 | additive=additives[i], 343 | ) 344 | for i, parents in enumerate(dag.T) 345 | ] 346 | 347 | domain_equations = [ 348 | partial( 349 | _cdnod_base_func, 350 | parents=parents, 351 | coefs=[1]*sum(parents), 352 | # coefs=domain_seed.uniform(0.5, 2.5, size=(np.sum(parents))), 353 | functions=functions[i], 354 | # functions=domain_seed.choice(functions, size=(np.sum(parents))), 355 | # noise_scale=domain_seed.uniform(1, 3), 356 | noise_scale=1, 357 | noise_shift=0, 358 | additive=additives[i], 359 | ) 360 | for i, parents in enumerate(dag.T) 361 | ] 362 | 363 | base_noises = [base_seed.choice( 364 | [ 365 | # lambda: np.random.normal(0, 1), 366 | # lambda: np.random.uniform(-0.5, 0.5), 367 | lambda: 1 / np.random.uniform(1, 3), 368 | lambda: np.random.uniform(1, 3), 369 | ]) 370 | for i in range(m)] 371 | 372 | domain_noises = [domain_seed.choice( 373 | [ 374 | # lambda: np.random.normal(0, 1), 375 | # lambda: np.random.uniform(-0.5, 0.5), 376 | lambda: 1 / np.random.uniform(1, 3), 377 | lambda: np.random.uniform(1, 3), 378 | ]) 379 | for i in range(m)] 380 | 381 | equations = [ 382 | domain if i in intervention_targets else base 383 | for i, (base, domain) in enumerate(zip( 384 | base_equations, domain_equations)) 385 | ] 386 | 387 | noises = [ 388 | domain if i in intervention_targets else base 389 | for i, (base, domain) in enumerate(zip( 390 | base_noises, domain_noises)) 391 | ] 392 | 393 | X = sample_topological(n_samples, equations, dag, noises, domain_random_state) 394 | 395 | return X 396 | -------------------------------------------------------------------------------- /sparse_shift/datasets/tests/test_dags.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | import networkx as nx 4 | from sparse_shift.datasets import erdos_renyi_dag, \ 5 | 
connected_erdos_renyi_dag, barabasi_albert_dag
6 | 
7 | 
8 | def test_are_dags():
9 |     n = 12
10 |     p = 0.5
11 |     G = erdos_renyi_dag(n, p)
12 |     np.testing.assert_array_equal(G, np.tril(G))
13 | 
14 |     G = connected_erdos_renyi_dag(n, p)
15 |     np.testing.assert_array_equal(G, np.tril(G))
16 | 
17 |     G = barabasi_albert_dag(n, p)
18 |     np.testing.assert_array_equal(G, np.tril(G))
19 | 
20 | 
21 | def test_connected_er_constant():
22 |     n = 12
23 |     p = 0.5
24 |     m = np.sum(connected_erdos_renyi_dag(n, p))
25 | 
26 |     for _ in range(3):
27 |         assert np.sum(connected_erdos_renyi_dag(n, p)) == m
28 | 
29 | 
30 | def test_ba_degree():
31 |     n = 12
32 |     p = 0.5
33 | 
34 |     G1 = connected_erdos_renyi_dag(n, p)
35 |     G2 = barabasi_albert_dag(n, p)
36 | 
37 |     assert np.sum(G1.shape[0]) == np.sum(G2.shape[0])
--------------------------------------------------------------------------------
/sparse_shift/datasets/tests/test_simulations.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sparse_shift.datasets import sample_nonlinear_icp_sim
3 | import pytest
4 | 
5 | 
6 | @pytest.mark.parametrize("nonlinearity", ["id", "relu", "sqrt", "sin"])
7 | @pytest.mark.parametrize("noise_df", [2])
8 | @pytest.mark.parametrize("combination", ["additive", "multiplicative"])
9 | @pytest.mark.parametrize("intervention", ["soft", "hard"])
10 | @pytest.mark.parametrize("intervention_shift", [0, 1])
11 | @pytest.mark.parametrize("intervention_scale", [0, 1])
12 | @pytest.mark.parametrize("intervention_targets", [None, 1, [2, 3]])
13 | @pytest.mark.parametrize("intervention_pct", [None, 0, 0.1])
14 | def test_icp_sim_params_work(
15 |     nonlinearity,
16 |     noise_df,
17 |     combination,
18 |     intervention_targets,
19 |     intervention,
20 |     intervention_shift,
21 |     intervention_scale,
22 |     intervention_pct,
23 | ):
24 | 
25 |     dag = np.asarray([[0, 0, 0, 0], [1, 0, 0, 0], [0, 1, 0, 0], [1, 1, 1, 0]])
26 |     X = sample_nonlinear_icp_sim(
27 |         dag=dag,
28 |         n_samples=100,
29 |         nonlinearity=nonlinearity,
30 |         noise_df=noise_df,
31 |         combination=combination,
32 |         intervention_targets=intervention_targets,
33 |         intervention=intervention,
34 |         intervention_shift=intervention_shift,
35 |         intervention_scale=intervention_scale,
36 |         intervention_pct=intervention_pct,
37 |         random_state=None,
38 |     )
39 | 
40 |     assert isinstance(X, np.ndarray)
41 |     assert X.shape == (100, dag.shape[0])
--------------------------------------------------------------------------------
/sparse_shift/independence_tests.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn.metrics import r2_score
3 | 
4 | 
5 | def invariant_residual_test(
6 |     X,
7 |     Y,
8 |     z,
9 |     method="gam",
10 |     test="ks",
11 |     method_kwargs={},
12 |     return_model=False,
13 |     combine_pvalues=True,
14 | ):
15 |     r"""
16 |     Calculates the 2-sample invariant residual test p-value.
17 | 
18 |     Parameters
19 |     ----------
20 |     X : ndarray, shape (n, p)
21 |         Features to condition on
22 |     Y : ndarray, shape (n,)
23 |         Target or outcome features
24 |     z : list or ndarray, shape (n,)
25 |         List of zeros and ones indicating which samples belong to
26 |         which groups.
27 |     method : {"forest", "gam", "linear"}, default="gam"
28 |         Method to predict the target given the covariates
29 |     test : {"whitney_levene", "ks"}, default="ks"
30 |         Test of the residuals between the groups
31 |     method_kwargs : dict
32 |         Named arguments to pass to the prediction method.
33 |     return_model : boolean, default=False
34 |         If True, returns the fitted model
35 |     combine_pvalues : bool, default=True
36 |         If True, returns the minimum of the corrected p-values.
37 | 
38 |     Returns
39 |     -------
40 |     pvalue : float
41 |         The computed test p-value.
42 |     r2 : float
43 |         r2 score of the regression fit
44 |     model : object
45 |         Fitted regression model, if return_model is True
46 |     """
47 | 
48 |     if method == "forest":
49 |         from sklearn.ensemble import RandomForestRegressor
50 | 
51 |         predictor = RandomForestRegressor(max_features="sqrt", **method_kwargs)
52 |     elif method == "gam":
53 |         from sklearn.linear_model import LinearRegression
54 |         from sklearn.preprocessing import SplineTransformer
55 |         from sklearn.pipeline import Pipeline
56 |         from sklearn.model_selection import GridSearchCV
57 | 
58 |         pipe = Pipeline(
59 |             steps=[
60 |                 ("spline", SplineTransformer(n_knots=4, degree=3)),
61 |                 ("linear", LinearRegression(**method_kwargs)),
62 |             ]
63 |         )
64 |         param_grid = {
65 |             "spline__n_knots": [3, 5, 7, 9],
66 |         }
67 |         predictor = GridSearchCV(
68 |             pipe, param_grid, n_jobs=-2, refit=True,
69 |             scoring="neg_mean_squared_error"
70 |         )
71 |     elif method == "linear":
72 |         from sklearn.linear_model import LinearRegression
73 | 
74 |         predictor = LinearRegression(**method_kwargs)
75 |     else:
76 |         raise ValueError(f"Method {method} not a valid option.")
77 | 
78 |     predictor = predictor.fit(X, Y)
79 |     Y_pred = predictor.predict(X)
80 |     residuals = Y - Y_pred
81 |     r2 = r2_score(Y, Y_pred)
82 | 
83 |     if test == "whitney_levene":
84 |         from scipy.stats import mannwhitneyu
85 |         from scipy.stats import levene
86 | 
87 |         _, mean_pval = mannwhitneyu(
88 |             residuals[np.asarray(z, dtype=bool)],
89 |             residuals[np.asarray(1 - z, dtype=bool)],
90 |         )
91 |         _, var_pval = levene(
92 |             residuals[np.asarray(z, dtype=bool)],
93 |             residuals[np.asarray(1 - z, dtype=bool)],
94 |         )
95 |         # Correct for multiple tests
96 |         if combine_pvalues:
97 |             pval = min(mean_pval * 2, var_pval * 2, 1)
98 |         else:
99 |             pval = (min(mean_pval * 2, 1), min(var_pval * 2, 1))
100 |     elif test == "ks":
101 |         from scipy.stats import kstest
102 | 
103 |         _, pval = kstest(
104 |             residuals[np.asarray(z, dtype=bool)],
105 |             residuals[np.asarray(1 - z, dtype=bool)],
106 |         )
107 |     else:
108 |         raise ValueError(f"Test {test} not a valid option.")
109 | 
110 |     if return_model:
111 |         return pval, r2, predictor
112 |     else:
113 |         return pval, r2
--------------------------------------------------------------------------------
/sparse_shift/kcd.py:
--------------------------------------------------------------------------------
1 | """Kernel Conditional Discrepancy test"""
2 | 
3 | # Author: Ronan Pery
4 | 
5 | import numpy as np
6 | from scipy.stats import norm
7 | from scipy.optimize import minimize_scalar
8 | from sklearn.linear_model import LogisticRegression
9 | from sklearn.metrics import pairwise_distances
10 | from sklearn.metrics.pairwise import pairwise_kernels
11 | from sklearn.model_selection import StratifiedShuffleSplit
12 | from joblib import Parallel, delayed
13 | from .utils import check_2d
14 | 
15 | 
16 | def _compute_kern(X, Y=None, metric="rbf", n_jobs=None, sigma=None):
17 |     """Computes an RBF kernel matrix using the median l2 distance as the bandwidth"""
18 |     X = check_2d(X)
19 |     Y = check_2d(Y)
20 |     if sigma is None:
21 |         l2 = pairwise_distances(X, metric="l2", n_jobs=n_jobs)
22 |         n = l2.shape[0]
23 |         # compute median of off-diagonal elements
24 |         med = np.median(
25 |             np.lib.stride_tricks.as_strided(
26 |                 l2, (n - 1, n + 1), (l2.itemsize
* (n + 1), l2.itemsize) 27 | )[:, 1:] 28 | ) 29 | # prevents division by zero when used on label vectors 30 | med = med if med else 1 31 | else: 32 | med = sigma 33 | gamma = 1.0 / (2 * (med ** 2)) 34 | return pairwise_kernels(X, Y=Y, metric=metric, n_jobs=n_jobs, gamma=gamma), med 35 | 36 | 37 | def _compute_reg_bound(K): 38 | n = K.shape[0] 39 | evals = np.linalg.svd(K, compute_uv=False, hermitian=True) 40 | res = minimize_scalar( 41 | lambda reg: np.sum(evals ** 2 / (evals + reg) ** 2) / n + reg, 42 | bounds=(0.0001, 1000), 43 | method="bounded", 44 | ) 45 | return res.x 46 | 47 | 48 | class KCD: 49 | """ 50 | Kernel Conditional Discrepancy test. 51 | 52 | Parameters 53 | ---------- 54 | compute_distance : str, callable, or None, default: "euclidean" or "gaussian" 55 | A function that computes the distance among the samples within each 56 | data matrix. 57 | Valid strings for ``compute_distance`` are, as defined in 58 | :func:`sklearn.metrics.pairwise_distances`, 59 | 60 | - From scikit-learn: [``"euclidean"``, ``"cityblock"``, ``"cosine"``, 61 | ``"l1"``, ``"l2"``, ``"manhattan"``] See the documentation for 62 | :mod:`scipy.spatial.distance` for details 63 | on these metrics. 64 | - From scipy.spatial.distance: [``"braycurtis"``, ``"canberra"``, 65 | ``"chebyshev"``, ``"correlation"``, ``"dice"``, ``"hamming"``, 66 | ``"jaccard"``, ``"kulsinski"``, ``"mahalanobis"``, ``"minkowski"``, 67 | ``"rogerstanimoto"``, ``"russellrao"``, ``"seuclidean"``, 68 | ``"sokalmichener"``, ``"sokalsneath"``, ``"sqeuclidean"``, 69 | ``"yule"``] See the documentation for :mod:`scipy.spatial.distance` for 70 | details on these metrics. 71 | 72 | Alternatively, this function computes the kernel similarity among the 73 | samples within each data matrix. 74 | Valid strings for ``compute_kernel`` are, as defined in 75 | :func:`sklearn.metrics.pairwise.pairwise_kernels`, 76 | 77 | [``"additive_chi2"``, ``"chi2"``, ``"linear"``, ``"poly"``, 78 | ``"polynomial"``, ``"rbf"``, 79 | ``"laplacian"``, ``"sigmoid"``, ``"cosine"``] 80 | 81 | Note ``"rbf"`` and ``"gaussian"`` are the same metric. 82 | regs : (float, float), default=None 83 | Amount of regularization for inverting the kernel matrices 84 | of the two classes. 85 | If None, chooses the value that minimizes the upper bound 86 | on the mean squared prediction error. 87 | n_jobs : int, optional 88 | Number of jobs to run computations in parallel. 89 | **kwargs 90 | Arbitrary keyword arguments for ``compute_distance``. 91 | 92 | Attributes 93 | ---------- 94 | sigma_x_ : float 95 | median l2 distance between conditional features X 96 | sigma_y_ : float 97 | median l2 distance between target features Y 98 | regs_ : (float, float) 99 | Regularization used to invert kernel matrices in X 100 | propensity_reg_ : float 101 | Regularization parameter computed for propensity scores 102 | null_dist_ : list 103 | Null distribution of test statistics after calling test. 104 | e_hat_ : list 105 | Propensity score probabilities of samples being in group 1. 106 | 107 | Notes 108 | ----- 109 | Per [1], the regularization level should scale with n_samples**b, 110 | where b is in (0, 0.5) 111 | 112 | References 113 | ---------- 114 | [1] J. Park, U. Shalit, B. Schölkopf, and K. Muandet, “Conditional Distributional 115 | Treatment Effect with Kernel Conditional Mean Embeddings and U-Statistic 116 | Regression,” arXiv:2102.08208, Jun. 2021. 
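    Examples
    --------
    A minimal usage sketch (editor's illustration; X, Y, z are assumed to
    be arrays of shape (n, p), (n, q), and (n,) respectively):

        >>> kcd = KCD(n_jobs=1)
        >>> stat, pvalue = kcd.test(X, Y, z, reps=1000, random_state=0)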
117 | """ 118 | 119 | def __init__(self, compute_distance=None, regs=None, n_jobs=None, **kwargs): 120 | self.compute_distance = compute_distance 121 | self.regs = regs 122 | self.kwargs = kwargs 123 | self.n_jobs = n_jobs 124 | 125 | def statistic(self, X, Y, z): 126 | r""" 127 | Calulates the 2-sample test statistic. 128 | 129 | Parameters 130 | ---------- 131 | X : ndarray, shape (n, p) 132 | Features to condition on 133 | Y : ndarray, shape (n, q) 134 | Target or outcome features 135 | z : list or ndarray, length n 136 | List of zeros and ones indicating which samples belong to 137 | which groups. 138 | 139 | Returns 140 | ------- 141 | stat : float 142 | The computed statistic 143 | """ 144 | K, sigma_x = _compute_kern(X, n_jobs=self.n_jobs) 145 | L, sigma_y = _compute_kern(Y, n_jobs=self.n_jobs) 146 | self.sigma_x_ = sigma_x 147 | self.sigma_y_ = sigma_y 148 | 149 | return self._statistic(K, L, z) 150 | 151 | def _get_inverse_kernels(self, K, z): 152 | """Helper function to compute W matrices""" 153 | # Compute W matrices from z 154 | K0 = K[np.array(1 - z, dtype=bool)][:, np.array(1 - z, dtype=bool)] 155 | K1 = K[np.array(z, dtype=bool)][:, np.array(z, dtype=bool)] 156 | 157 | if not hasattr(self, "regs_"): 158 | self.regs_ = self.regs 159 | if self.regs_ is None: 160 | self.regs_ = (_compute_reg_bound(K0), _compute_reg_bound(K1)) 161 | 162 | W0 = np.linalg.inv(K0 + self.regs_[0] * np.identity(int(np.sum(1 - z)))) 163 | W1 = np.linalg.inv(K1 + self.regs_[1] * np.identity(int(np.sum(z)))) 164 | 165 | return W0, W1 166 | 167 | def _statistic(self, K, L, z): 168 | """Helper function for efficient permutation calculations""" 169 | # Compute W matrices from z 170 | W0, W1 = self._get_inverse_kernels(K, z) 171 | 172 | # Compute L kernels 173 | L0 = L[np.array(1 - z, dtype=bool)][:, np.array(1 - z, dtype=bool)] 174 | L1 = L[np.array(z, dtype=bool)][:, np.array(z, dtype=bool)] 175 | L01 = L[np.array(1 - z, dtype=bool)][:, np.array(z, dtype=bool)] 176 | 177 | # Compute test statistic using traces 178 | # Simplified to avoid repeat computations. W symmetric 179 | KW0 = K[:, np.array(1 - z, dtype=bool)] @ W0 180 | KW1 = K[:, np.array(z, dtype=bool)] @ W1 181 | first = np.trace(KW0.T @ KW0 @ L0) 182 | second = np.trace(KW1.T @ KW0 @ L01) 183 | third = np.trace(KW1.T @ KW1 @ L1) 184 | 185 | return (first - 2 * second + third) / K.shape[0] 186 | 187 | def witness(self, X, Y, z, X_wit, Y_wit): 188 | r""" 189 | Calulates the witness function on a set of points 190 | 191 | Parameters 192 | ---------- 193 | X : ndarray, shape (n, p) 194 | Features to condition on 195 | Y : ndarray, shape (n, q) 196 | Target or outcome features 197 | z : list or ndarray, length n 198 | List of zeros and ones indicating which samples belong to 199 | which groups. 
200 | X_wit : ndarray, shape (m, p) 201 | Features to compute the witness distance to 202 | Y_wit : ndarray, shape (l, q) 203 | Target or outcome features for the witness distance 204 | 205 | Returns 206 | ------- 207 | dists : ndarray, shape (l, m) 208 | The computed distances for all X_wit, Y_wit points 209 | """ 210 | K, sigma_x = _compute_kern(X, n_jobs=self.n_jobs) 211 | self.sigma_x_ = sigma_x 212 | 213 | # Compute W matrices from z 214 | W0, W1 = self._get_inverse_kernels(K, z) 215 | del K 216 | 217 | # Witness distances 218 | K, _ = _compute_kern(X, X_wit, n_jobs=self.n_jobs, sigma=sigma_x) 219 | L, sigma_y = _compute_kern(Y, Y_wit, n_jobs=self.n_jobs) 220 | self.sigma_y_ = sigma_y 221 | 222 | K0 = K[np.array(1 - z, dtype=bool)] 223 | K1 = K[np.array(z, dtype=bool)] 224 | L0 = L[np.array(1 - z, dtype=bool)] 225 | L1 = L[np.array(z, dtype=bool)] 226 | 227 | return (K1.T @ W1 @ L1 - K0.T @ W0 @ L0).T 228 | 229 | # def conditional_dmat(self, X, Y): 230 | # """Pairwise distances between conditional features, w.r.t. conditional dist""" 231 | # K, _ = _compute_kern(X, n_jobs=self.n_jobs) 232 | # L, _ = _compute_kern(Y, n_jobs=self.n_jobs) 233 | 234 | # W = np.linalg.inv(K + self.reg * np.identity(K.shape[0])) 235 | 236 | # XY = K.T @ W @ L 237 | # return pairwise_distances(XY, metric="l2", n_jobs=self.n_jobs) 238 | 239 | # def U_regress(self, X, Y, X_predict, alpha=1): 240 | # """ 241 | # Generalized kernel ridge regression for U-statistic regression at X_predict 242 | # """ 243 | # X = check_2d(X) 244 | # Y = check_2d(Y) 245 | # X_predict = check_2d(X_predict) 246 | # X_predict = np.tile(X_predict, 2) 247 | # n = X.shape[0] 248 | # p = X.shape[1] 249 | # x_pairs = np.zeros((n ** 2, 2 * p)) 250 | # x_pairs[:, range(p)] = np.repeat(X, n, axis=0) 251 | # x_pairs[:, range(p, 2 * p)] = np.tile(X, (n, 1)) 252 | # h = np.reshape(0.5 * (np.power(Y, 2) + np.power(Y, 2).T - 253 | # 2 * np.matmul(Y, Y.T)), -1) 254 | 255 | # var = KernelRidge( 256 | # alpha=alpha, kernel='rbf', gamma=1.0 / (2 * (self.sigma_x_ ** 2)) 257 | # ).fit(x_pairs, h).predict(X_predict) 258 | 259 | # return var 260 | 261 | def test(self, X, Y, z, reps=1000, random_state=None, fast_pvalue=None): 262 | r""" 263 | Calculates the *k*-sample test statistic and p-value. 264 | 265 | Parameters 266 | ---------- 267 | X : ndarray, shape (n, p) 268 | Features to condition on 269 | Y : ndarray, shape (n, q) 270 | Target or outcome features 271 | z : list or ndarray, length n 272 | List of zeros and ones indicating which samples belong to 273 | which groups. 274 | reps : int, default: 1000 275 | The number of replications used to estimate the null distribution 276 | when using the permutation test used to calculate the p-value. 277 | random_state : int, RandomState instance or None, default=None 278 | Controls randomness in propensity score estimation. 279 | fast_pvalue : Ignored 280 | Ignored 281 | 282 | Returns 283 | ------- 284 | stat : float 285 | The computed *k*-sample statistic. 286 | pvalue : float 287 | The computed *k*-sample p-value. 
288 | """ 289 | # Construct kernel matrices for X, Y TODO efficient storage 290 | K, sigma_x = _compute_kern(X, n_jobs=self.n_jobs) 291 | L, sigma_y = _compute_kern(Y, n_jobs=self.n_jobs) 292 | self.sigma_x_ = sigma_x 293 | self.sigma_y_ = sigma_y 294 | 295 | # Compute W matrices from z 296 | stat = self._statistic(K, L, z) 297 | 298 | # Compute proensity scores 299 | self.propensity_reg_ = _compute_reg_bound(K) 300 | # Note: for stability should maybe exclude samples w/ prob < 1/reps 301 | self.e_hat_ = ( 302 | LogisticRegression( 303 | n_jobs=self.n_jobs, 304 | penalty="l2", 305 | warm_start=True, 306 | solver="lbfgs", 307 | random_state=random_state, 308 | C=1 / (2 * self.propensity_reg_), 309 | ) 310 | .fit(K, z) 311 | .predict_proba(K)[:, 1] 312 | ) 313 | 314 | # Parallelization storage cost of kernel matrices 315 | self.null_dist_ = np.array( 316 | Parallel(n_jobs=self.n_jobs)( 317 | [ 318 | delayed(self._statistic)(K, L, np.random.binomial(1, self.e_hat_)) 319 | for _ in range(reps) 320 | ] 321 | ) 322 | ) 323 | pvalue = (1 + np.sum(self.null_dist_ >= stat)) / (1 + reps) 324 | 325 | return stat, pvalue 326 | 327 | 328 | class KCDCV: 329 | """ 330 | Kernel Conditional Discrepancy test with cross validation for hyperparamter 331 | selection. 332 | 333 | Parameters 334 | ---------- 335 | compute_distance : str, callable, or None, default: "euclidean" or "gaussian" 336 | A function that computes the distance among the samples within each 337 | data matrix. 338 | Valid strings for ``compute_distance`` are, as defined in 339 | :func:`sklearn.metrics.pairwise_distances`, 340 | 341 | - From scikit-learn: [``"euclidean"``, ``"cityblock"``, ``"cosine"``, 342 | ``"l1"``, ``"l2"``, ``"manhattan"``] See the documentation for 343 | :mod:`scipy.spatial.distance` for details 344 | on these metrics. 345 | - From scipy.spatial.distance: [``"braycurtis"``, ``"canberra"``, 346 | ``"chebyshev"``, ``"correlation"``, ``"dice"``, ``"hamming"``, 347 | ``"jaccard"``, ``"kulsinski"``, ``"mahalanobis"``, ``"minkowski"``, 348 | ``"rogerstanimoto"``, ``"russellrao"``, ``"seuclidean"``, 349 | ``"sokalmichener"``, ``"sokalsneath"``, ``"sqeuclidean"``, 350 | ``"yule"``] See the documentation for :mod:`scipy.spatial.distance` for 351 | details on these metrics. 352 | 353 | Alternatively, this function computes the kernel similarity among the 354 | samples within each data matrix. 355 | Valid strings for ``compute_kernel`` are, as defined in 356 | :func:`sklearn.metrics.pairwise.pairwise_kernels`, 357 | 358 | [``"additive_chi2"``, ``"chi2"``, ``"linear"``, ``"poly"``, 359 | ``"polynomial"``, ``"rbf"``, 360 | ``"laplacian"``, ``"sigmoid"``, ``"cosine"``] 361 | 362 | Note ``"rbf"`` and ``"gaussian"`` are the same metric. 363 | regs : list, shape (n_regs,), default=(0.01, 0.1, 1.0, 10, 100) 364 | List of kernel regularization values to try. Larger values correspond 365 | to larger regularization. 366 | n_jobs : int, optional 367 | Number of jobs to run computations in parallel. 368 | **kwargs 369 | Arbitrary keyword arguments for ``compute_distance``. 
370 | 371 | Attributes 372 | ---------- 373 | sigma_X_ : float 374 | median l2 distance between training features X 375 | sigma_Y_ : float 376 | median l2 distance between all training targets Y 377 | train_idx_ : numpy.ndarray 378 | Indices of training subset 379 | test_idx_ : numpy.ndarray 380 | Indices of test subset 381 | reg_snrs_ : list 382 | Training test statistic SNRs for each regularization value 383 | reg_opt_ : float 384 | Regularization value with maximum SNR on the training data 385 | stat_ : float 386 | Test data test statistic 387 | stat_snr_ : float 388 | Test data SNR 389 | power_alphas_ : numpy.ndarray 390 | False positive rate values for power calculation 391 | analytic_powers_ : numpy.ndarray 392 | Analytic true positive (power) for varying alpha levels 393 | analytic_pvalue_ : float 394 | Analytic pvalue based on a normal distribution approximation 395 | e_hat_ : list 396 | Propensity score probabilities, P(Z=1 | X) 397 | null_stats_ : list 398 | Null distribution of test statistics on permuted test Y 399 | null_vars_ : numpy.ndarray 400 | Estimated variance of null distribution test statistics 401 | perm_pvalue : float 402 | Pvalue of the test data statistic based on a permutation test 403 | 404 | References 405 | ---------- 406 | [1] J. Park, U. Shalit, B. Schölkopf, and K. Muandet, “Conditional 407 | Distributional Treatment Effect with Kernel Conditional Mean 408 | Embeddings and U-Statistic Regression,” arXiv:2102.08208, Jun. 2021. 409 | [2] J. M. Kübler, W. Jitkrittum, B. Schölkopf, and K. Muandet, “A Witness 410 | Two-Sample Test,” arXiv:2102.05573, Feb. 2022. 411 | """ 412 | 413 | def __init__( 414 | self, 415 | compute_distance=None, 416 | regs=(0.01, 0.1, 1.0, 10, 100), 417 | n_jobs=None, 418 | **kwargs 419 | ): 420 | self.compute_distance = compute_distance 421 | self.regs = regs 422 | self.kwargs = kwargs 423 | self.n_jobs = n_jobs 424 | 425 | def _optimize_params(self, X, Y, z, test_fraction=None, random_state=None): 426 | """ 427 | Optimizes the regularization amount for the kernel matrix inversion 428 | as detailed in [2]. Splits the data based on labels z and computes 429 | the witness function using the training data. The optimal amount of 430 | regularization is that which maximizes the signal to noise ratio of 431 | the difference between the two distributions. 432 | """ 433 | X = check_2d(X) 434 | Y = check_2d(Y) 435 | self.n_samples_, self.n_features_ = X.shape 436 | 437 | # Split data into train/test 438 | self.train_idx_, self.test_idx_ = next( 439 | StratifiedShuffleSplit( 440 | n_splits=1, 441 | test_size=0.5 if test_fraction is None else test_fraction, 442 | random_state=random_state, 443 | ).split(np.zeros(self.n_samples_), z) 444 | ) 445 | self.train_idx_ = np.sort(self.train_idx_) 446 | self.test_idx_ = np.sort(self.test_idx_) 447 | 448 | self.train_idx_ = np.arange(self.n_samples_) 449 | self.test_idx_ = np.arange(self.test_idx_) 450 | 451 | # compute K0, K1 on all data. 
split so as to compute separate sigmas 452 | _, self.sigma_X_ = _compute_kern(X[self.train_idx_], n_jobs=self.n_jobs) 453 | 454 | K00, _ = _compute_kern( 455 | X[self.train_idx_][np.array(1 - z[self.train_idx_], dtype=bool)], 456 | n_jobs=self.n_jobs, 457 | sigma=self.sigma_X_, 458 | ) 459 | K01, _ = _compute_kern( 460 | X[self.train_idx_][np.array(1 - z[self.train_idx_], dtype=bool)], 461 | X[self.train_idx_][np.array(z[self.train_idx_], dtype=bool)], 462 | sigma=self.sigma_X_, 463 | n_jobs=self.n_jobs, 464 | ) 465 | K0 = np.hstack((K00, K01)) 466 | 467 | K11, _ = _compute_kern( 468 | X[self.train_idx_][np.array(z[self.train_idx_], dtype=bool)], 469 | n_jobs=self.n_jobs, 470 | sigma=self.sigma_X_, 471 | ) 472 | K10, _ = _compute_kern( 473 | X[self.train_idx_][np.array(z[self.train_idx_], dtype=bool)], 474 | X[self.train_idx_][np.array(1 - z[self.train_idx_], dtype=bool)], 475 | sigma=self.sigma_X_, 476 | n_jobs=self.n_jobs, 477 | ) 478 | K1 = np.hstack((K11, K10)) 479 | 480 | # compute l0, l1 on training data 481 | _, self.sigma_Y_ = _compute_kern( 482 | Y[self.train_idx_], # [np.array(1 - z[self.train_idx_], dtype=bool)], 483 | n_jobs=self.n_jobs, 484 | ) 485 | L0, _ = _compute_kern( 486 | Y[self.train_idx_][np.array(1 - z[self.train_idx_], dtype=bool)], 487 | Y[self.train_idx_], # [np.array(1 - z[self.train_idx_], dtype=bool)], 488 | n_jobs=self.n_jobs, 489 | sigma=self.sigma_Y_, 490 | ) 491 | 492 | L1, _ = _compute_kern( 493 | Y[self.train_idx_][np.array(z[self.train_idx_], dtype=bool)], 494 | Y[self.train_idx_], # [np.array(z[self.train_idx_], dtype=bool)], 495 | n_jobs=self.n_jobs, 496 | sigma=self.sigma_Y_, 497 | ) 498 | 499 | # indices of groups 0 and 1 in L0, L1 500 | idx = z[self.train_idx_] 501 | 502 | # Iterate over reg 503 | self.reg_snrs_ = [] 504 | for reg in self.regs: # could consider separate reg parameters 505 | W0 = np.linalg.inv(K00 + reg * np.identity(K00.shape[0])) 506 | K0W0 = K0.T @ W0 507 | W1 = np.linalg.inv(K11 + reg * np.identity(K11.shape[0])) 508 | K1W1 = K1.T @ W1 509 | 510 | stat, pooled_var = self._statistic(K0W0, L0, K1W1, L1, idx) 511 | 512 | self.reg_snrs_.append(stat / np.sqrt(pooled_var)) 513 | 514 | # identify optimal reg 515 | self.reg_opt_ = self.regs[np.argmax(np.abs(self.reg_snrs_))] 516 | h_sign = np.sign(self.reg_snrs_[np.argmax(np.abs(self.reg_snrs_))]) 517 | 518 | W0 = np.linalg.inv(K00 + self.reg_opt_ * np.identity(K00.shape[0])) 519 | W1 = np.linalg.inv(K11 + self.reg_opt_ * np.identity(K11.shape[0])) 520 | 521 | return W0, W1, h_sign 522 | 523 | def test(self, X, Y, z, reps=1000, random_state=None, fast_pvalue=False): 524 | r""" 525 | Calculates the *k*-sample test statistic and p-value using the optimal 526 | regularization value which maximizes the test statistic SNR. 527 | 528 | Parameters 529 | ---------- 530 | X : ndarray, shape (n, p) 531 | Features to condition on 532 | Y : ndarray, shape (n, q) 533 | Target or outcome features 534 | z : list or ndarray, length n 535 | List of zeros and ones indicating which samples belong to 536 | which groups. 537 | reps : int, default: 1000 538 | The number of replications used to estimate the null distribution 539 | when using the permutation test used to calculate the p-value. 540 | random_state : int, RandomState instance or None, default=None 541 | Controls the randomness of the data splitting. 542 | fast_pvalue : boolean, default=False 543 | If True, the analytic form of the pvalue is computed using a 544 | normal distribution approximation. Valid with larger sample sizes. 
545 | 546 | Returns 547 | ------- 548 | stat : float 549 | The computed test statistic. 550 | pvalue : float 551 | The computed test p-value. 552 | """ 553 | W0, W1, h_sign = self._optimize_params(X, Y, z, random_state=random_state) 554 | 555 | K, _ = _compute_kern(X, sigma=self.sigma_X_, n_jobs=self.n_jobs,) 556 | 557 | K0W0 = ( 558 | K[self.train_idx_][np.array(1 - z[self.train_idx_], dtype=bool)][ 559 | :, self.test_idx_ 560 | ].T 561 | @ W0 562 | ) 563 | K1W1 = ( 564 | K[self.train_idx_][np.array(z[self.train_idx_], dtype=bool)][ 565 | :, self.test_idx_ 566 | ].T 567 | @ W1 568 | ) 569 | 570 | # compute l0, l1 from test to train 571 | L0, _ = _compute_kern( # shape (tr0, te) 572 | Y[self.train_idx_][np.array(1 - z[self.train_idx_], dtype=bool)], 573 | Y[self.test_idx_], 574 | sigma=self.sigma_Y_, 575 | n_jobs=self.n_jobs, 576 | ) 577 | L1, _ = _compute_kern( # shape (tr1, te) 578 | Y[self.train_idx_][np.array(z[self.train_idx_], dtype=bool)], 579 | Y[self.test_idx_], 580 | sigma=self.sigma_Y_, 581 | n_jobs=self.n_jobs, 582 | ) 583 | 584 | # binary indices for test stat 585 | idx = z[self.test_idx_] 586 | n0 = np.array(1 - z[self.test_idx_], dtype=bool).sum() 587 | n1 = np.array(z[self.test_idx_], dtype=bool).sum() 588 | 589 | # h_sign ensures learned kernel expectation for Y1 > Y0 590 | self.stat_, pooled_var = self._statistic(K0W0, L0, K1W1, L1, idx, sign=h_sign) 591 | self.stat_snr_ = self.stat_ / np.sqrt(pooled_var) 592 | 593 | # Compute analytic power, per [2] 594 | self.power_alphas_ = np.linspace(0.05, 1, 21) 595 | self.analytic_powers_ = 1 - np.asarray( 596 | [ 597 | norm.cdf(norm.ppf(1 - alpha) - self.stat_snr_ * np.sqrt(n0 + n1)) 598 | for alpha in self.power_alphas_ 599 | ] 600 | ) 601 | 602 | # permutation test via permute l0, l1 per propensity scores 603 | # Note: for stability should maybe exclude samples w/ prob < 1/reps 604 | # Is trained on entire dataset, but then subsetted in the test step 605 | if fast_pvalue: # per [2] 606 | self.analytic_pvalue_ = 1 - norm.cdf(self.stat_snr_ * np.sqrt(n0 + n1)) 607 | return self.stat_, self.analytic_pvalue_ 608 | else: 609 | self.e_hat_ = ( 610 | LogisticRegression( 611 | n_jobs=self.n_jobs, 612 | penalty="l2", 613 | warm_start=True, 614 | solver="lbfgs", 615 | C=1 / (2 * self.reg_opt_), 616 | random_state=random_state, 617 | ) 618 | .fit(K, z) 619 | .predict_proba(K)[:, 1] 620 | ) 621 | 622 | M00 = K0W0.T @ K0W0 623 | M01 = K0W0.T @ K1W1 624 | M11 = K1W1.T @ K1W1 625 | 626 | h0_list = M00 @ L0 - M01 @ L1 627 | h1_list = M01.T @ L0 - M11 @ L1 628 | # Parallelization, storage cost of kernel matrices 629 | self.null_stats_, self.null_vars_ = np.array( 630 | list( 631 | zip( 632 | *Parallel(n_jobs=self.n_jobs)( 633 | [ 634 | delayed(self._permute_statistic)( 635 | h0_list, 636 | h1_list, 637 | # K0W0, L0, K1W1, L1, 638 | np.random.binomial(1, self.e_hat_[self.test_idx_]), 639 | sign=h_sign, 640 | ) 641 | for _ in range(reps) 642 | ] 643 | ) 644 | ) 645 | ) 646 | ) 647 | self.perm_pvalue_ = (1 + np.sum(self.null_stats_ >= self.stat_)) / ( 648 | 1 + reps 649 | ) 650 | 651 | return self.stat_, self.perm_pvalue_ 652 | 653 | def _permute_statistic(self, h0_list, h1_list, idx, sign=1): 654 | h0_list = h0_list[:, np.array(1 - idx, dtype=bool)] 655 | h1_list = h1_list[:, np.array(idx, dtype=bool)] 656 | 657 | c = len(h0_list) / (len(h0_list) + len(h1_list)) 658 | pooled_var = np.var(h0_list) / c + np.var(h1_list) / (1 - c) 659 | stat = (np.mean(h1_list) - np.mean(h0_list)) * sign 660 | return stat, pooled_var 661 | 662 | def _statistic(self, 
K0W0, L0, K1W1, L1, idx, sign=1):
663 |         # test statistic and variance
664 |         idx0 = np.array(1 - idx, dtype=bool)
665 |         idx1 = np.array(idx, dtype=bool)
666 | 
667 |         M00 = K0W0.T @ K0W0
668 |         M01 = K0W0.T @ K1W1
669 |         M11 = K1W1.T @ K1W1
670 | 
671 |         h0_list = M00 @ L0[:, idx0] - M01 @ L1[:, idx0]
672 |         h1_list = M01.T @ L0[:, idx1] - M11 @ L1[:, idx1]
673 | 
674 |         c = len(h0_list) / (len(h0_list) + len(h1_list))
675 |         pooled_var = np.var(h0_list) / c + np.var(h1_list) / (1 - c)
676 |         stat = (np.mean(h1_list) - np.mean(h0_list)) * sign
677 |         return stat, pooled_var
--------------------------------------------------------------------------------
/sparse_shift/methods.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from causaldag import DAG
3 | from sparse_shift.utils import dag2cpdag, cpdag2dags
4 | from sparse_shift.testing import test_dag_shifts
5 | 
6 | 
7 | class FullPC:
8 |     """
9 |     Pools all the data and computes the oracle PC algorithm result.
10 |     """
11 |     def __init__(self, dag):
12 |         self.domains_ = []
13 |         self.interv_targets_ = set()
14 |         self.dag = dag  # adj matrix
15 | 
16 |     def add_environment(self, interventions):
17 |         self.interv_targets_.update(interventions)
18 |         self.domains_.append(interventions)
19 | 
20 |     def get_mec_dags(self):
21 |         if len(self.domains_) == 1:
22 |             return cpdag2dags(dag2cpdag(self.dag))
23 |         else:
24 |             intv_cpdag = dag2cpdag(self.dag, list(self.interv_targets_))
25 |             return cpdag2dags(intv_cpdag)
26 | 
27 |     def get_mec_cpdag(self):
28 |         if len(self.domains_) == 1:
29 |             return dag2cpdag(self.dag)
30 |         else:
31 |             intv_cpdag = dag2cpdag(self.dag, list(self.interv_targets_))
32 |             return intv_cpdag
33 | 
34 | 
35 | class PairwisePC:
36 |     """
37 |     Oracle evaluation of the PC algorithm on all pairs of environments, orienting edges
38 |     in the final answer if any pair orients an edge.
39 |     """
40 |     def __init__(self, dag):
41 |         self.interv_targets_ = []
42 |         self.dag = dag  # adj matrix
43 |         self.union_cpdag_ = np.zeros(dag.shape)
44 | 
45 |     def add_environment(self, interventions):
46 |         for prior_targets in self.interv_targets_:
47 |             pairwise_targets = np.unique(np.hstack((prior_targets, interventions))).astype(int)
48 |             intv_cpdag = dag2cpdag(self.dag, pairwise_targets)
49 |             self.union_cpdag_ += intv_cpdag
50 | 
51 |         self.interv_targets_.append(interventions)
52 | 
53 |     def get_mec_dags(self):
54 |         if len(self.interv_targets_) == 1:
55 |             return cpdag2dags(dag2cpdag(self.dag))
56 |         else:
57 |             cpdag = (self.union_cpdag_ >= np.max(self.union_cpdag_)).astype(int)
58 |             return cpdag2dags(cpdag)
59 | 
60 |     def get_mec_cpdag(self):
61 |         if len(self.interv_targets_) == 1:
62 |             return dag2cpdag(self.dag)
63 |         else:
64 |             cpdag = (self.union_cpdag_ >= np.max(self.union_cpdag_)).astype(int)
65 |             return cpdag
66 | 
67 | 
68 | class MinChangeOracle:
69 |     """
70 |     Oracle test of the number of mechanism changes each DAG in a Markov equivalence
71 |     class experiences.
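    A hypothetical usage sketch (editor's illustration; `dag` is a binary
    adjacency matrix, and each environment is given by its intervened indices):

        >>> oracle = MinChangeOracle(dag)
        >>> oracle.add_environment([])       # observational environment
        >>> oracle.add_environment([0, 2])   # interventions on X0 and X2
        >>> remaining_dags = oracle.get_min_dags()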
72 | """ 73 | def __init__(self, dag): 74 | self.interv_targets_ = [] 75 | self.dag = dag # adj matrix 76 | self.min_dags_ = np.asarray(cpdag2dags(dag2cpdag(dag))) 77 | 78 | def add_environment(self, interventions): 79 | for prior_targets in self.interv_targets_: 80 | n_changes = np.zeros(len(self.min_dags_)) 81 | pairwise_targets = np.unique(np.hstack((prior_targets, interventions))).astype(int) 82 | 83 | n_vars = self.dag.shape[0] 84 | aug_dag_adj = np.zeros((n_vars+1, n_vars+1)) 85 | aug_dag_adj[:-1, :-1] = self.dag 86 | aug_dag_adj[-1][pairwise_targets] = 1 87 | aug_dag = DAG().from_amat(aug_dag_adj) 88 | 89 | for i, dag in enumerate(self.min_dags_): 90 | n_changes[i] += self._num_changes(aug_dag, dag, pairwise_targets) 91 | min_idx = np.where(n_changes == min(n_changes))[0] 92 | self.min_dags_ = self.min_dags_[np.asarray(min_idx)] 93 | 94 | self.interv_targets_.append(interventions) 95 | 96 | def get_min_dags(self): 97 | return self.min_dags_ 98 | 99 | def get_min_cpdag(self): 100 | cpdag = (np.sum(self.min_dags_, axis=0) > 0).astype(int) 101 | return cpdag 102 | 103 | def _num_changes(self, true_aug_dag, dag_adj, targets): 104 | n_vars = dag_adj.shape[0] 105 | d_seps = [ 106 | true_aug_dag.dsep(n_vars, i, np.where(dag_adj.T[i] != 0)[0]) 107 | for i in range(n_vars) 108 | ] 109 | num_changes = n_vars - np.sum(d_seps) 110 | return num_changes 111 | 112 | 113 | class MinChange: 114 | """ 115 | Computes the number of pairwise mechanism changes in all DAGs in a given 116 | Markov equivalence class across given environment datasets 117 | """ 118 | def __init__(self, cpdag, test='kci', alpha=0.05, scale_alpha=True, test_kwargs={}): 119 | self.cpdag = cpdag 120 | self.test = test 121 | self.alpha = alpha 122 | self.scale_alpha = scale_alpha 123 | self.test_kwargs = test_kwargs 124 | self.dags_ = np.asarray(cpdag2dags(cpdag)) 125 | self.n_vars_ = cpdag.shape[0] 126 | self.alpha_ = alpha 127 | if scale_alpha: 128 | self.alpha_ /= self.n_vars_ # account for false positive rate within dag 129 | self.n_envs_ = 0 130 | self.n_dags_ = self.dags_.shape[0] 131 | self.Xs_ = [] 132 | 133 | def add_environment(self, X): 134 | X = np.asarray(X) 135 | if self.n_envs_ == 0: 136 | self.pvalues_ = np.ones((self.n_dags_, self.n_vars_, 1, 1)) 137 | else: 138 | old_changes = self.pvalues_.copy() 139 | self.pvalues_ = np.ones((self.n_dags_, self.n_vars_, self.n_envs_+1, self.n_envs_+1)) 140 | self.pvalues_[:, :, :self.n_envs_, :self.n_envs_] = old_changes 141 | 142 | for env, prior_X in enumerate(self.Xs_): 143 | try: 144 | pvalues = test_dag_shifts( # shape (n_dags, n_mech, 2, 2) 145 | Xs=[prior_X, X], 146 | dags=self.dags_, 147 | test=self.test, 148 | test_kwargs=self.test_kwargs) 149 | self.pvalues_[:, :, -1, env] = pvalues[:, :, 0, 1] 150 | self.pvalues_[:, :, env, -1] = pvalues[:, :, 0, 1] 151 | except ValueError as e: 152 | print(e) 153 | self.pvalues_[:, :, -1, env] = 1 154 | self.pvalues_[:, :, env, -1] = 1 155 | 156 | 157 | self.n_envs_ += 1 158 | self.Xs_.append(X) 159 | 160 | @property 161 | def n_dag_changes_(self): 162 | return np.sum(self.pvalues_ <= self.alpha_, axis=(1, 2, 3)) / 2 163 | 164 | @property 165 | def soft_scores_(self): 166 | scores = np.sum(1 - self.pvalues_, axis=(1, 2, 3)) 167 | 168 | return scores 169 | 170 | def get_min_dags(self, soft=False): 171 | if soft: 172 | scores = self.soft_scores_ 173 | min_idx = np.where(scores == np.min(scores))[0] 174 | else: 175 | min_idx = np.where(self.n_dag_changes_ == np.min(self.n_dag_changes_))[0] 176 | return self.dags_[min_idx] 177 | 178 | 
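    # A usage sketch on hypothetical (n_samples, n_vars) arrays X1, X2:
    #
    #     mc = MinChange(cpdag, test='kci', alpha=0.05)
    #     mc.add_environment(X1)
    #     mc.add_environment(X2)
    #     hard_dags = mc.get_min_dags()           # count pvalues <= alpha_
    #     soft_dags = mc.get_min_dags(soft=True)  # rank by sum of (1 - pvalue)
    #
    # pvalues_ has shape (n_dags, n_vars, n_envs, n_envs) and is symmetric in
    # the two environment axes, hence the division by 2 in n_dag_changes_;
    # with scale_alpha=True the working threshold alpha_ is alpha / n_vars.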
def get_min_cpdag(self, soft=False): 179 | min_dags = self.get_min_dags(soft=soft) 180 | cpdag = (np.sum(min_dags, axis=0) > 0).astype(int) 181 | return cpdag 182 | 183 | 184 | class FullMinChanges: 185 | """ 186 | Computes the number of mechanism changes in all DAGs in a given 187 | Markov equivalence class across given environment datasets 188 | """ 189 | def __init__(self, cpdag, test='kci', alpha=0.05, scale_alpha=True, test_kwargs={}): 190 | self.cpdag = cpdag 191 | self.test = test 192 | self.alpha = alpha 193 | self.scale_alpha = scale_alpha 194 | self.test_kwargs = test_kwargs 195 | self.dags_ = np.asarray(cpdag2dags(cpdag)) 196 | self.n_vars_ = cpdag.shape[0] 197 | self.alpha_ = alpha 198 | if scale_alpha: 199 | self.alpha_ /= self.n_vars_ # account for false positive rate within dag 200 | self.n_envs_ = 0 201 | self.n_dags_ = self.dags_.shape[0] 202 | self.Xs_ = [] 203 | 204 | def add_environment(self, X): 205 | X = np.asarray(X) 206 | self.Xs_.append(X) 207 | self.n_envs_ += 1 208 | 209 | if self.n_envs_ == 1: 210 | self.pvalues_ = np.ones((self.n_dags_, self.n_vars_)) 211 | return 212 | 213 | self.pvalues_ = test_dag_shifts( # shape (n_dags, n_mech, 2, 2) 214 | Xs=self.Xs_, 215 | dags=self.dags_, 216 | test=self.test, 217 | test_kwargs=self.test_kwargs, 218 | pairwise=False) 219 | 220 | 221 | @property 222 | def n_dag_changes_(self): 223 | return np.sum(self.pvalues_ <= self.alpha_, axis=(1)) 224 | 225 | @property 226 | def soft_scores_(self): 227 | scores = np.sum(1 - self.pvalues_, axis=(1)) 228 | 229 | return scores 230 | 231 | def get_min_dags(self, soft=False): 232 | if soft: 233 | scores = self.soft_scores_ 234 | min_idx = np.where(scores == np.min(scores))[0] 235 | else: 236 | min_idx = np.where(self.n_dag_changes_ == np.min(self.n_dag_changes_))[0] 237 | return self.dags_[min_idx] 238 | 239 | def get_min_cpdag(self, soft=False): 240 | min_dags = self.get_min_dags(soft=soft) 241 | cpdag = (np.sum(min_dags, axis=0) > 0).astype(int) 242 | return cpdag 243 | 244 | 245 | class ParamChanges: 246 | """ 247 | Computes the number of parameter changes in pairwise mechanism changes in all DAGs 248 | in a given Markov equivalence class across given environment datasets 249 | """ 250 | def __init__(self, cpdag, test='linear_params', alpha=0.05, scale_alpha=True, test_kwargs={}): 251 | self.cpdag = cpdag 252 | self.test = test 253 | self.alpha = alpha 254 | self.scale_alpha = scale_alpha 255 | self.test_kwargs = test_kwargs 256 | self.dags_ = np.asarray(cpdag2dags(cpdag)) 257 | self.n_vars_ = cpdag.shape[0] 258 | self.alpha_ = alpha 259 | if scale_alpha: 260 | self.alpha_ /= self.n_vars_ # account for false positive rate within dag 261 | self.n_envs_ = 0 262 | self.n_dags_ = self.dags_.shape[0] 263 | self.Xs_ = [] 264 | 265 | def add_environment(self, X): 266 | assert self.test 267 | X = np.asarray(X) 268 | if self.n_envs_ == 0: 269 | self.pvalues_ = np.ones((self.n_dags_, self.n_vars_, 1, 1, 2)) 270 | else: 271 | old_changes = self.pvalues_.copy() 272 | self.pvalues_ = np.ones((self.n_dags_, self.n_vars_, self.n_envs_+1, self.n_envs_+1, 2)) 273 | self.pvalues_[:, :, :self.n_envs_, :self.n_envs_, :] = old_changes 274 | 275 | for env, prior_X in enumerate(self.Xs_): 276 | try: 277 | pvalues = test_dag_shifts( # shape (n_dags, n_mech, 2, 2) 278 | Xs=[prior_X, X], 279 | dags=self.dags_, 280 | test=self.test, 281 | test_kwargs=self.test_kwargs) 282 | self.pvalues_[:, :, -1, env, :] = pvalues[:, :, 0, 1, :] 283 | self.pvalues_[:, :, env, -1, :] = pvalues[:, :, 0, 1, :] 284 | except 
ValueError as e: 285 | print(e) 286 | self.pvalues_[:, :, -1, env, :] = 1 287 | self.pvalues_[:, :, env, -1, :] = 1 288 | 289 | self.n_envs_ += 1 290 | self.Xs_.append(X) 291 | 292 | @property 293 | def n_dag_changes_(self): 294 | return np.sum(self.pvalues_ <= self.alpha_, axis=(1, 2, 3, 4)) / 2 295 | 296 | @property 297 | def soft_scores_(self): 298 | scores = np.sum(1 - self.pvalues_, axis=(1, 2, 3, 4)) 299 | 300 | return scores 301 | 302 | def get_min_dags(self, soft=False): 303 | if soft: 304 | scores = self.soft_scores_ 305 | min_idx = np.where(scores == np.min(scores))[0] 306 | else: 307 | min_idx = np.where(self.n_dag_changes_ == np.min(self.n_dag_changes_))[0] 308 | return self.dags_[min_idx] 309 | 310 | def get_min_cpdag(self, soft=False): 311 | min_dags = self.get_min_dags(soft=soft) 312 | cpdag = (np.sum(min_dags, axis=0) > 0).astype(int) 313 | return cpdag 314 | 315 | 316 | def _construct_augmented_cpdag(cpdag): 317 | from sparse_shift.utils import create_causal_learn_cpdag 318 | from sparse_shift.causal_learn.GraphClass import CausalGraph 319 | from causallearn.graph.GraphNode import GraphNode 320 | from causallearn.graph.Edge import Edge 321 | from causallearn.graph.Endpoint import Endpoint 322 | 323 | n_x_vars = cpdag.shape[0] 324 | cl_cpdag = create_causal_learn_cpdag(cpdag) 325 | cl_cpdag.add_node(GraphNode(f'X{n_x_vars+1}')) 326 | 327 | nodes = cl_cpdag.get_nodes() 328 | 329 | for i in range(n_x_vars): 330 | cl_cpdag.add_edge(Edge(nodes[i], nodes[-1], Endpoint.ARROW, Endpoint.TAIL)) 331 | 332 | cg = CausalGraph(G=cl_cpdag) 333 | 334 | return cg 335 | 336 | 337 | class AugmentedPC: 338 | """ 339 | Runs the PC algorithm on an augmented graph, starting from a known MEC (optional). 340 | """ 341 | def __init__(self, cpdag, test='kci', alpha=0.05, test_kwargs={}, verbose=False): 342 | self.cpdag = cpdag 343 | self.test = test 344 | self.alpha = alpha 345 | self.test_kwargs = test_kwargs 346 | self.dags_ = np.asarray(cpdag2dags(cpdag)) 347 | self.n_vars_ = cpdag.shape[0] 348 | self.alpha_ = alpha 349 | self.n_envs_ = 0 350 | self.n_dags_ = self.dags_.shape[0] 351 | self.aug_cpdag_ = _construct_augmented_cpdag(cpdag) 352 | self.Xs_ = [] 353 | self.verbose = verbose 354 | 355 | def add_environment(self, X): 356 | self.Xs_.append(np.asarray(X)) 357 | 358 | if len(self.Xs_) == 1: 359 | self.learned_cpdag_ = self.cpdag 360 | return 361 | 362 | data = np.block([ 363 | [self.Xs_[e], np.reshape([e] * self.Xs_[e].shape[0], (-1, 1))] 364 | for e in range(len(self.Xs_)) 365 | ]) 366 | 367 | if self.test == 'fisherz': 368 | from causallearn.utils.cit import fisherz 369 | test_func = fisherz 370 | elif self.test == 'kci': 371 | from causallearn.utils.cit import kci 372 | test_func = kci 373 | else: 374 | raise ValueError(f'Test {self.test} not implemented.') 375 | 376 | 377 | # Run meek rules on found edges 378 | from sparse_shift.causal_learn.SkeletonDiscovery import augmented_skeleton_discovery 379 | cg_skel_disc = augmented_skeleton_discovery(data, self.alpha_, test_func, 380 | stable=True, 381 | background_knowledge=None, verbose=self.verbose, 382 | show_progress=self.verbose, cg=self.aug_cpdag_) 383 | 384 | from causallearn.utils.PCUtils import Meek 385 | cg_meek = Meek.meek(cg_skel_disc, background_knowledge=None) 386 | 387 | adj = np.zeros(cg_meek.G.graph.shape) 388 | adj[cg_meek.G.graph > 0] = 1 389 | adj[np.abs(cg_meek.G.graph + cg_meek.G.graph.T) > 0] = 1 390 | 391 | self.learned_cpdag_ = adj[:-1, :-1] 392 | 393 | def get_min_dags(self, soft=None): 394 | """For experiment 
compliance"""
395 |         return self.get_dags()
396 | 
397 |     def get_min_cpdag(self, soft=None):
398 |         """For experiment compliance"""
399 |         return self.get_cpdag()
400 | 
401 |     def get_dags(self, soft=None):
402 |         return cpdag2dags(self.get_cpdag())
403 | 
404 |     def get_cpdag(self, soft=None):
405 |         return self.learned_cpdag_
406 | 
--------------------------------------------------------------------------------
/sparse_shift/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | 
4 | def dag_true_orientations(true_dag, cpdag):
5 |     """Number of correctly oriented edges / number of edges"""
6 |     # np.testing.assert_array_equal(true_dag, np.tril(true_dag))
7 |     tp = len(np.where((true_dag + cpdag - cpdag.T) == 2)[0])
8 |     n_edges = np.sum(true_dag)
9 |     return tp / n_edges
10 | 
11 | 
12 | def dag_false_orientations(true_dag, cpdag):
13 |     """Number of falsely oriented edges / number of edges"""
14 |     # np.testing.assert_array_equal(true_dag, np.tril(true_dag))
15 |     fp = len(np.where((true_dag + cpdag.T - cpdag) == 2)[0])
16 |     n_edges = np.sum(true_dag)
17 |     return fp / n_edges
18 | 
19 | 
20 | def dag_precision(true_dag, cpdag):
21 |     tp = len(np.where((true_dag + cpdag - cpdag.T) == 2)[0])
22 |     fp = len(np.where((true_dag + cpdag.T - cpdag) == 2)[0])
23 |     return tp / (tp + fp) if (tp + fp) > 0 else 1
24 | 
25 | 
26 | def dag_recall(true_dag, cpdag):
27 |     tp = len(np.where((true_dag + cpdag - cpdag.T) == 2)[0])
28 |     return tp / np.sum(true_dag)
29 | 
30 | 
31 | def average_precision_score(true_dag, pvalues_mat):
32 |     """
33 |     Computes average precision score from pvalue thresholds
34 |     """
35 |     from sparse_shift.utils import dag2cpdag, cpdag2dags
36 |     thresholds = np.unique(pvalues_mat)
37 |     dags = np.asarray(cpdag2dags(dag2cpdag(true_dag)))
38 | 
39 |     # ap_score = 0
40 |     # prior_recall = 0
41 | 
42 |     precisions = []
43 |     recalls = []
44 | 
45 |     for t in thresholds:
46 |         axis = tuple(np.arange(1, pvalues_mat.ndim))
47 |         n_changes = np.sum(pvalues_mat <= t, axis=axis) / 2
48 |         min_idx = np.where(n_changes == np.min(n_changes))[0]
49 |         cpdag = (np.sum(dags[min_idx], axis=0) > 0).astype(int)
50 |         precisions.append(dag_precision(true_dag, cpdag))
51 |         recalls.append(dag_recall(true_dag, cpdag))
52 | 
53 |         # ap_score += precision * (recall - prior_recall)
54 |         # prior_recall = recall
55 | 
56 |     # if len(thresholds) == 1:
57 |     #     ap_score = precisions[0] * recalls[0]
58 |     # else:
59 |     sort_idx = np.argsort(recalls)
60 |     recalls = np.asarray(recalls)[sort_idx]
61 |     precisions = np.asarray(precisions)[sort_idx]
62 |     ap_score = (np.diff(recalls, prepend=0) * precisions).sum()
63 | 
64 |     return ap_score
65 | 
--------------------------------------------------------------------------------
/sparse_shift/plotting.py:
--------------------------------------------------------------------------------
1 | """Tools for visualizing DAGs and more"""
2 | # Authors: Ronan Perry
3 | # License: MIT
4 | 
5 | import networkx as nx
6 | import matplotlib.pyplot as plt
7 | 
8 | 
9 | def plot_dag(
10 |     adj,
11 |     topological_sort=True,
12 |     parent_adj=False,
13 |     layout="circular",
14 |     figsize=(5, 5),
15 |     title=None,
16 |     highlight_nodes=None,
17 |     highlight_edges=None,
18 |     labels=None,
19 |     node_size=500,
20 | ):
21 |     if parent_adj:
22 |         adj = adj.T
23 | 
24 |     G = nx.from_numpy_array(adj, create_using=nx.DiGraph)
25 |     assert nx.is_directed_acyclic_graph(G)
26 | 
27 |     if layout == "circular":
28 |         pos = nx.circular_layout(G)
29 |     else:
30 |         pos = None
31 |         # raise
NotImplementedError(f'layout {layout} not a valid mode yet') 32 | 33 | edge_options = { 34 | "width": 3, 35 | "arrowstyle": "-|>", 36 | "arrowsize": 20, 37 | "alpha": 0.5, 38 | "arrows": True, 39 | "node_size": node_size, 40 | } 41 | 42 | labeldict = {} 43 | for i in range(adj.shape[0]): 44 | if labels is None: 45 | labeldict[i] = f"X{i+1}" 46 | else: 47 | labeldict[i] = labels[i] 48 | 49 | if highlight_edges is not None: 50 | if parent_adj: 51 | highlight_edges = highlight_edges.T 52 | black_edges = [edge for edge in G.edges() if ((adj[edge] == 1) and (highlight_edges[edge] == 0))] 53 | red_edges = [edge for edge in G.edges() if ((adj[edge] == 1) and (highlight_edges[edge] != 0))] 54 | else: 55 | red_edges = [] 56 | black_edges = G.edges() 57 | 58 | if highlight_nodes is not None: 59 | node_colors = ['white' if highlight_nodes[node] == 0 else 'red' for node in G.nodes()] 60 | else: 61 | node_colors = 'white' 62 | 63 | fig, ax = plt.subplots(figsize=figsize) 64 | 65 | nx.draw_networkx_nodes( 66 | G, pos=pos, 67 | node_color=node_colors, 68 | ax=ax, 69 | node_size=node_size, alpha=0.5, 70 | ) 71 | nx.draw_networkx_labels(G, pos, ax=ax, labels=labeldict) 72 | nx.draw_networkx_edges(G, pos, edgelist=red_edges, edge_color='r', ax=ax, **edge_options) 73 | nx.draw_networkx_edges(G, pos, edgelist=black_edges, edge_color='black', ax=ax, **edge_options) 74 | 75 | ax = plt.gca() 76 | ax.collections[0].set_edgecolor("#000000") 77 | if title is not None: 78 | plt.title(title) 79 | plt.box(False) 80 | -------------------------------------------------------------------------------- /sparse_shift/testing.py: -------------------------------------------------------------------------------- 1 | """Tests for sparse shifts""" 2 | # Authors: Ronan Perry 3 | # License: MIT 4 | 5 | import numpy as np 6 | from hyppo.ksample import MMD 7 | from sparse_shift import KCD 8 | from sparse_shift.independence_tests import invariant_residual_test 9 | from sparse_shift.utils import dags2mechanisms 10 | from causallearn.utils.cit import fisherz, kci 11 | 12 | 13 | def test_dag_shifts(Xs, dags, test='kci', test_kwargs={}, pairwise=True): 14 | """ 15 | Tests pairwise mechanism equality across a set of dags 16 | 17 | Parameters 18 | ---------- 19 | Xs : list of np.ndarray, shape (E, n_e, m) 20 | List of observations from each environment 21 | dags : np.ndarray, shape (d, m, m) 22 | List of adjacency matrices 23 | test : {'invariant_residuals', 'kcd', 'kci', 'fisherz'}, default='kci' 24 | Test for equality of distribution 25 | test_kwargs : dict, optional 26 | Dictionary of named arguments for the independence test 27 | 28 | Returns 29 | ------- 30 | np.ndarray 31 | pvalues for each pairwise test 32 | """ 33 | E = len(Xs) 34 | M = dags[0].shape[0] 35 | mech_dict = dags2mechanisms(dags) 36 | pvalues_dict = {} 37 | for m, mechanisms in mech_dict.items(): 38 | pvalues_dict[m] = {} 39 | for parents in mechanisms: 40 | if pairwise: 41 | pvalues = test_mechanism( 42 | Xs, m, parents, test, test_kwargs 43 | ) 44 | pvalues_dict[m][tuple(parents)] = pvalues 45 | else: 46 | pvalue = test_pooled_mechanism( 47 | Xs, m, parents, test, test_kwargs 48 | ) 49 | pvalues_dict[m][tuple(parents)] = pvalue 50 | 51 | if pairwise: 52 | if test == 'linear_params': 53 | dag_pvalues = np.zeros((len(dags), M, E, E, 2)) 54 | else: 55 | dag_pvalues = np.zeros((len(dags), M, E, E)) 56 | else: 57 | dag_pvalues = np.zeros((len(dags), M)) 58 | 59 | for i, dag in enumerate(dags): 60 | for m, parents in enumerate(dag.T): # transpose to get parents 61 | 
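            # dag_pvalues[i, m] holds the (E, E) pairwise p-value matrix for
            # mechanism m under DAG i ((E, E, 2) for 'linear_params'; a single
            # pooled p-value when pairwise=False). Mechanisms are keyed by
            # parent set, so DAGs that share a parent set reuse the same
            # test result from pvalues_dict.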
            dag_pvalues[i, m] = pvalues_dict[m][tuple(parents)]
62 | 
63 |     return dag_pvalues
64 | 
65 | 
66 | def test_mechanism_shifts(Xs, dag, test='kci', test_kwargs={}, alpha=0.05, pairwise=True):
67 |     """
68 |     Tests pairwise mechanism equality
69 | 
70 |     Parameters
71 |     ----------
72 |     Xs : list of np.ndarray, shape (E, n_e, m)
73 |         List of observations from each environment
74 |     dag : np.ndarray, shape (m, m)
75 |         Adjacency matrix
76 |     test : {'invariant_residuals', 'kcd', 'kci', 'fisherz', 'linear_params'}, default='kci'
77 |         Test for equality of distribution
78 |     test_kwargs : dict, optional
79 |         Dictionary of named arguments for the independence test
80 | 
81 |     Returns
82 |     -------
83 |     int : total number of shifts
84 |     np.ndarray, shape (m, e, e)
85 |         pvalues for each pairwise test
86 |     """
87 |     E = len(Xs)
88 |     parent_graph = np.asarray(dag).T
89 |     M = parent_graph.shape[0]
90 | 
91 |     pvalues = np.ones((M, E, E))
92 | 
93 |     for m in range(M):
94 |         pvalues[m] = test_mechanism(
95 |             Xs, m, parent_graph[m], test, test_kwargs
96 |         )
97 | 
98 |     num_shifts = np.sum(pvalues <= alpha) // 2
99 | 
100 |     return num_shifts, pvalues
101 | 
102 | 
103 | def test_pooled_mechanism(Xs, m, parents, test='kci', test_kwargs={}):
104 |     parents = np.asarray(parents).astype(bool)
105 | 
106 |     if test == 'fisherz':
107 |         assert len(Xs) > 1
108 |         # Test X \indep E | PA_X
109 |         data = np.block([
110 |             [np.reshape([e] * Xs[e].shape[0], (-1, 1)), Xs[e]]
111 |             for e in range(len(Xs))
112 |         ])
113 |         condition_set = tuple(np.where(parents > 0)[0] + 1)
114 |         pvalue = fisherz(data, 0, m+1, condition_set)
115 |     elif test == 'kci':
116 |         assert len(Xs) > 1
117 |         # Test X \indep E | PA_X
118 |         data = np.block([
119 |             [np.reshape([e] * Xs[e].shape[0], (-1, 1)), Xs[e]]
120 |             for e in range(len(Xs))
121 |         ])
122 |         condition_set = tuple(np.where(parents > 0)[0] + 1)
123 |         pvalue = kci(data, 0, m+1, condition_set)
124 |     else:
125 |         raise ValueError(f'Test {test} not implemented.')
126 | 
127 |     return pvalue
128 | 
129 | 
130 | def test_mechanism(Xs, m, parents, test='kci', test_kwargs={}):
131 |     """Tests a mechanism"""
132 | 
133 |     E = len(Xs)
134 |     parents = np.asarray(parents).astype(bool)
135 |     if test == 'linear_params':
136 |         pvalues = np.ones((E, E, 2))
137 |     else:
138 |         pvalues = np.ones((E, E))
139 | 
140 |     for e1 in range(E):
141 |         for e2 in range(e1 + 1, E):
142 |             if sum(parents) == 0:
143 |                 stat, pvalue = MMD().test(
144 |                     Xs[e1][:, m].reshape(-1, 1),
145 |                     Xs[e2][:, m].reshape(-1, 1),
146 |                 )
147 |             else:
148 |                 if test == 'kcd':
149 |                     assert len(Xs) == 2
150 |                     _, pvalue = KCD(n_jobs=test_kwargs['n_jobs']).test(
151 |                         np.vstack((Xs[e1][:, parents], Xs[e2][:, parents])),
152 |                         np.concatenate((Xs[e1][:, m], Xs[e2][:, m])),
153 |                         np.asarray([0] * Xs[e1].shape[0] + [1] * Xs[e2].shape[0]),
154 |                         reps=test_kwargs['n_reps'],
155 |                     )
156 |                 elif test == 'invariant_residuals':
157 |                     assert len(Xs) == 2
158 |                     pvalue, *_ = invariant_residual_test(
159 |                         np.vstack((Xs[e1][:, parents], Xs[e2][:, parents])),
160 |                         np.concatenate((Xs[e1][:, m], Xs[e2][:, m])),
161 |                         np.asarray([0] * Xs[e1].shape[0] + [1] * Xs[e2].shape[0]),
162 |                         **test_kwargs
163 |                     )
164 |                 elif test == 'fisherz':
165 |                     assert len(Xs) > 1
166 |                     # Test X \indep E | PA_X
167 |                     data = np.block([
168 |                         [np.reshape([0] * Xs[e1].shape[0], (-1, 1)), Xs[e1]],
169 |                         [np.reshape([1] * Xs[e2].shape[0], (-1, 1)), Xs[e2]]
170 |                     ])
171 |                     condition_set = tuple(np.where(parents > 0)[0] + 1)
172 |                     pvalue = fisherz(data, 0, m+1, condition_set)
173 |                 elif test == 'kci':
174 |                     assert len(Xs) > 1
175 |                     # Test X \indep E | PA_X
176 |                     data =
np.block([ 177 | [np.reshape([0] * Xs[e1].shape[0], (-1, 1)), Xs[e1]], 178 | [np.reshape([1] * Xs[e2].shape[0], (-1, 1)), Xs[e2]] 179 | ]) 180 | condition_set = tuple(np.where(parents > 0)[0] + 1) 181 | pvalue = kci(data, 0, m+1, condition_set) 182 | elif test == 'linear_params': 183 | assert len(Xs) == 2 184 | pvalue, *_ = invariant_residual_test( 185 | np.vstack((Xs[e1][:, parents], Xs[e2][:, parents])), 186 | np.concatenate((Xs[e1][:, m], Xs[e2][:, m])), 187 | np.asarray([0] * Xs[e1].shape[0] + [1] * Xs[e2].shape[0]), 188 | combine_pvalues=False, 189 | **test_kwargs 190 | ) 191 | else: 192 | raise ValueError(f'Test {test} not implemented.') 193 | pvalues[e1, e2] = pvalue 194 | pvalues[e2, e1] = pvalue 195 | 196 | return pvalues 197 | -------------------------------------------------------------------------------- /sparse_shift/utils.py: -------------------------------------------------------------------------------- 1 | """Tools for other functions and methods""" 2 | import numpy as np 3 | from causaldag import DAG, PDAG 4 | 5 | 6 | def check_2d(X): 7 | if X is not None and X.ndim == 1: 8 | X = X.reshape(-1, 1) 9 | return X 10 | 11 | 12 | def dags2mechanisms(dags): 13 | """ 14 | Returns a dictionary of variable: mechanisms from 15 | a list of DAGs. 16 | """ 17 | m = len(dags[0]) 18 | mech_dict = {i: [] for i in range(m)} 19 | for dag in dags: 20 | for i, mech in enumerate(dag.T): # Transpose to get parents 21 | mech_dict[i].append(mech) 22 | 23 | # remove duplicates 24 | for i in range(m): 25 | mech_dict[i] = np.unique(mech_dict[i], axis=0) 26 | 27 | return mech_dict 28 | 29 | 30 | def create_causal_learn_dag(G): 31 | """Converts directed adj matrix G to causal graph""" 32 | from causallearn.graph.Dag import Dag 33 | from causallearn.graph.GraphNode import GraphNode 34 | 35 | n_vars = G.shape[0] 36 | node_names = [("X%d" % (i + 1)) for i in range(n_vars)] 37 | nodes = [GraphNode(name) for name in node_names] 38 | 39 | cl_dag = Dag(nodes) 40 | for i in range(n_vars): 41 | for j in range(n_vars): 42 | if G[i, j] != 0: 43 | cl_dag.add_directed_edge(nodes[i], nodes[j]) 44 | 45 | return cl_dag 46 | 47 | 48 | def create_causal_learn_cpdag(G): 49 | """Converts adj mat of cpdag to a causal learn graph object""" 50 | from causallearn.graph.Edge import Edge 51 | from causallearn.graph.Endpoint import Endpoint 52 | from causallearn.graph.GeneralGraph import GeneralGraph 53 | from causallearn.graph.GraphNode import GraphNode 54 | 55 | n_vars = G.shape[0] 56 | node_names = [("X%d" % (i + 1)) for i in range(n_vars)] 57 | nodes = [GraphNode(name) for name in node_names] 58 | 59 | cl_cpdag = GeneralGraph(nodes) 60 | 61 | for i in range(n_vars): 62 | for j in range(i + 1, n_vars): 63 | if G[i, j] == 1 and G[j, i] == 1: 64 | cl_cpdag.add_edge(Edge(nodes[i], nodes[j], Endpoint.TAIL, Endpoint.TAIL)) 65 | elif G[i, j] == 1 and G[j, i] == 0: 66 | cl_cpdag.add_edge(Edge(nodes[i], nodes[j], Endpoint.TAIL, Endpoint.ARROW)) 67 | elif G[i, j] == 0 and G[j, i] == 1: 68 | cl_cpdag.add_edge(Edge(nodes[i], nodes[j], Endpoint.ARROW, Endpoint.TAIL)) 69 | 70 | return cl_cpdag 71 | 72 | 73 | def dag2cpdag(adj, targets=None): 74 | """Converts an adjacency matrix to the cpdag adjacency matrix, with potential interventions""" 75 | dag = DAG().from_amat(adj) 76 | cpdag = dag.cpdag() 77 | if targets is None: 78 | return cpdag.to_amat()[0] 79 | else: 80 | return dag.interventional_cpdag( 81 | [targets], cpdag=cpdag 82 | ).to_amat()[0] 83 | 84 | def cpdag2dags(adj): 85 | """Converts a cpdag adjacency matrix to a list of all 
dags""" 86 | adj = np.asarray(adj) 87 | dags_elist = list(PDAG().from_amat(adj).all_dags()) 88 | dags = [] 89 | for elist in dags_elist: 90 | G = np.zeros(adj.shape) 91 | elist = np.asarray(list(elist)) 92 | if len(elist) > 0: 93 | G[elist[:, 0], elist[:, 1]] = 1 94 | dags.append(G) 95 | 96 | return dags 97 | 98 | """ 99 | Useful causal-learn utils for reference 100 | 101 | # Orients edges in a pdag, to find a dag (not necessarily possible) 102 | causallearn.utils.PDAG2DAG import pdag2dag 103 | 104 | # Returns the CPDAG of a DAG (the MEC!) 105 | from causallearn.utils.DAG2CPDAG import dag2cpdag 106 | 107 | # Checks if two dags are in the same MEC 108 | from causallearn.utils.MECCheck import mec_check 109 | 110 | # Runs meek's orientation rules over a DAG, with optional background 111 | # knowledge. definite_meek examines definite unshielded triples 112 | from causallearn.utils.PCUtils.Meek import meek, definite_meek 113 | 114 | """ --------------------------------------------------------------------------------
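A minimal end-to-end sketch tying the utilities above together (the 3-node chain DAG is illustrative; plot_dag is from sparse_shift/plotting.py):

    import numpy as np
    from sparse_shift.utils import dag2cpdag, cpdag2dags
    from sparse_shift.plotting import plot_dag

    dag = np.zeros((3, 3))
    dag[0, 1] = dag[1, 2] = 1             # X1 -> X2 -> X3
    cpdag = dag2cpdag(dag)                # no v-structures, so fully undirected
    dags = cpdag2dags(cpdag)              # the 3 DAGs in the MEC
    icpdag = dag2cpdag(dag, targets=[1])  # intervening on X2 orients both edges
    plot_dag(dag, labels=['X1', 'X2', 'X3'], title='Chain DAG')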