├── cdmir ├── __init__.py ├── tests │ ├── __init__.py │ ├── tests_graph │ │ ├── __init__.py │ │ ├── test_digraph.py │ │ └── test_grpah_transform.py │ ├── testdata │ │ ├── graph_data │ │ │ ├── dag.1.txt │ │ │ ├── pdag.1.txt │ │ │ ├── dag.2.txt │ │ │ ├── pdag.2.txt │ │ │ ├── dag.3.txt │ │ │ ├── pdag.3.txt │ │ │ ├── dag.4.txt │ │ │ ├── pdag.4.txt │ │ │ ├── dag.5.txt │ │ │ ├── pdag.5.txt │ │ │ ├── dag.6.txt │ │ │ ├── pdag.6.txt │ │ │ ├── dag.7.txt │ │ │ ├── pdag.7.txt │ │ │ ├── dag.8.txt │ │ │ ├── pdag.8.txt │ │ │ ├── dag.9.txt │ │ │ ├── pdag.9.txt │ │ │ ├── dag.10.txt │ │ │ ├── pdag.10.txt │ │ │ ├── dag.11.txt │ │ │ ├── pdag.11.txt │ │ │ ├── dag.12.txt │ │ │ ├── pdag.12.txt │ │ │ ├── dag.13.txt │ │ │ ├── pdag.13.txt │ │ │ ├── dag.14.txt │ │ │ ├── pdag.14.txt │ │ │ ├── dag.15.txt │ │ │ ├── pdag.15.txt │ │ │ ├── dag.16.txt │ │ │ ├── pdag.16.txt │ │ │ ├── dag.17.txt │ │ │ ├── pdag.17.txt │ │ │ ├── dag.18.txt │ │ │ ├── pdag.18.txt │ │ │ ├── dag.19.txt │ │ │ ├── pdag.19.txt │ │ │ ├── dag.20.txt │ │ │ ├── pdag.20.txt │ │ │ ├── dag.21.txt │ │ │ ├── pdag.21.txt │ │ │ ├── dag.22.txt │ │ │ ├── pdag.22.txt │ │ │ ├── dag.23.txt │ │ │ ├── pdag.23.txt │ │ │ ├── dag.24.txt │ │ │ ├── pdag.24.txt │ │ │ ├── dag.25.txt │ │ │ ├── pdag.25.txt │ │ │ ├── dag.26.txt │ │ │ ├── pdag.26.txt │ │ │ ├── dag.27.txt │ │ │ ├── pdag.27.txt │ │ │ ├── dag.28.txt │ │ │ ├── pdag.28.txt │ │ │ ├── dag.29.txt │ │ │ ├── pdag.29.txt │ │ │ ├── dag.30.txt │ │ │ ├── pdag.30.txt │ │ │ ├── dag.31.txt │ │ │ ├── pdag.31.txt │ │ │ ├── dag.32.txt │ │ │ └── pdag.32.txt │ │ ├── dag.3.txt │ │ ├── cpdag.3.txt │ │ ├── cpdag.5.txt │ │ ├── dag.5.txt │ │ ├── cpdag.2.txt │ │ ├── dag.2.txt │ │ ├── cpdag.1.txt │ │ ├── cpdag.4.txt │ │ ├── dag.1.txt │ │ └── dag.4.txt │ ├── test_log1p.py │ ├── test_linear.py │ ├── test_polynomial.py │ ├── test_anm.py │ ├── test_fisherz.py │ ├── test_ica_lingam.py │ ├── test_gaussian.py │ ├── test_GeneralMarginalScore.py │ ├── test_GeneralCVScore.py │ ├── test_MultiCVScore.py │ ├── test_plot_graph.py 
│ ├── test_kci.py │ ├── test_desp.py │ ├── test_hawkes_simulator.py │ ├── test_datasets_utils.py │ ├── test_tensorrank.py │ ├── test_graph_evaluation.py │ └── test_pc.py ├── datasets │ ├── __init__.py │ └── utils.py ├── discovery │ ├── __init__.py │ ├── constraint │ │ ├── PBSCM │ │ │ └── __init__.py │ │ ├── PBSCM_PGF │ │ │ ├── __init__.py │ │ │ └── CCARankTest.py │ │ ├── __init__.py │ │ └── pc.py │ ├── funtional_based │ │ ├── SHP │ │ │ ├── __init__.py │ │ │ └── Generate_Hawkes_data_from_tick.py │ │ ├── one_component │ │ │ └── __init__.py │ │ ├── LearningHierarchicalStructure │ │ │ ├── __init__.py │ │ │ ├── indTest │ │ │ │ ├── __init__.py │ │ │ │ ├── FisherTest.py │ │ │ │ ├── TestObject.py │ │ │ │ ├── independence.py │ │ │ │ ├── HSICtestImpure.py │ │ │ │ ├── HSICPermutationTestObject.py │ │ │ │ ├── HSICSpectralTestObject.py │ │ │ │ ├── fastHSIC.py │ │ │ │ ├── HSIC2.py │ │ │ │ └── HSIC.py │ │ │ ├── requirements.txt │ │ │ ├── README.md │ │ │ └── GIN2.py │ │ ├── __init__.py │ │ └── lingam_based │ │ │ └── __init__.py │ └── Tensor_Rank │ │ ├── LearnCausalCluster.py │ │ └── Gtest.py ├── effect │ ├── __init__.py │ ├── LASER │ │ └── __init__.py │ └── DoublyRobust │ │ └── src │ │ ├── __init__.py │ │ ├── run.sh │ │ └── layers.py ├── transfer_learning │ └── __init__.py ├── visual │ ├── __init__.py │ ├── graph_layout.py │ └── plot_graph.py ├── utils │ ├── independence │ │ ├── functional │ │ │ ├── __init__.py │ │ │ ├── fisherz.py │ │ │ ├── kci.py │ │ │ └── HSIC.py │ │ ├── __init__.py │ │ ├── kci.py │ │ ├── dsep.py │ │ ├── basic_independence.py │ │ ├── fisherz.py │ │ └── _base.py │ ├── __init__.py │ ├── metrics │ │ ├── __init__.py │ │ └── graph_evaluation.py │ ├── kernel │ │ ├── __init__.py │ │ ├── linear.py │ │ ├── polynomial.py │ │ ├── _base.py │ │ └── gaussian.py │ ├── local_score │ │ ├── __init__.py │ │ ├── _base.py │ │ ├── bic_score.py │ │ └── bdeu_score.py │ └── adapters.py └── graph │ ├── __init__.py │ ├── mark.py │ ├── edge.py │ ├── dag2cpdag.py │ ├── pdag2dag.py │ └── 
digraph.py ├── docs ├── requirements.txt ├── source │ ├── effect_methods │ │ ├── DoublyRobust │ │ │ └── index.rst │ │ ├── OTCI │ │ │ └── index.rst │ │ └── index.rst │ ├── discovery_methods │ │ ├── tensor_rank │ │ │ ├── index.rst │ │ │ └── Tensor_rank │ │ │ │ └── tensor_rank.rst │ │ ├── LaHiCaSI │ │ │ ├── index.rst │ │ │ └── LaHiCaSI.rst │ │ ├── functional_based │ │ │ ├── index.rst │ │ │ ├── ANM │ │ │ │ └── anm.rst │ │ │ ├── SHP │ │ │ │ └── shp.rst │ │ │ └── OLC │ │ │ │ └── olc.rst │ │ ├── constraint │ │ │ ├── index.rst │ │ │ ├── PBSCM │ │ │ │ └── pbscm.rst │ │ │ ├── PBSCM_PGF │ │ │ │ └── pbscm_pgf.rst │ │ │ └── pc │ │ │ │ └── pc.rst │ │ └── index.rst │ ├── utilities_index │ │ ├── index.rst │ │ └── graph_operations │ │ │ ├── index.rst │ │ │ ├── pdag2dag │ │ │ └── pdag2dag.rst │ │ │ └── dag2cpdag │ │ │ └── dag2cpdag.rst │ ├── index.rst │ ├── conf.py │ └── getting_started.rst ├── Makefile └── make.bat ├── images └── causal-discovery.png ├── .readthedocs.yaml ├── requirements.txt ├── setup.py ├── .github └── workflows │ └── python-package.yml ├── README.md └── .gitignore /cdmir/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cdmir/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cdmir/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cdmir/discovery/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cdmir/effect/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /cdmir/effect/LASER/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cdmir/transfer_learning/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cdmir/discovery/constraint/PBSCM/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cdmir/discovery/constraint/PBSCM_PGF/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cdmir/discovery/funtional_based/SHP/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cdmir/effect/DoublyRobust/src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx 2 | sphinx-rtd-theme -------------------------------------------------------------------------------- /cdmir/discovery/funtional_based/one_component/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cdmir/discovery/constraint/__init__.py: -------------------------------------------------------------------------------- 1 | from .pc import PC 2 | 
-------------------------------------------------------------------------------- /cdmir/tests/tests_graph/__init__.py: -------------------------------------------------------------------------------- 1 | from .test_graph import TestGraph -------------------------------------------------------------------------------- /cdmir/discovery/funtional_based/LearningHierarchicalStructure/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cdmir/visual/__init__.py: -------------------------------------------------------------------------------- 1 | from .graph_layout import circular_layout 2 | from .plot_graph import plot_graph 3 | -------------------------------------------------------------------------------- /images/causal-discovery.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DMIRLAB-Group/CDMIR/HEAD/images/causal-discovery.png -------------------------------------------------------------------------------- /cdmir/discovery/funtional_based/__init__.py: -------------------------------------------------------------------------------- 1 | from .lingam_based import * 2 | 3 | __all__ = [ 4 | "ICA_LINGAM" 5 | ] -------------------------------------------------------------------------------- /cdmir/utils/independence/functional/__init__.py: -------------------------------------------------------------------------------- 1 | from .fisherz import fisherz, fisherz_from_corr 2 | from .kci import KCI 3 | -------------------------------------------------------------------------------- /cdmir/discovery/funtional_based/lingam_based/__init__.py: -------------------------------------------------------------------------------- 1 | from .ica_lingam import ICA_LINGAM 2 | 3 | __all__ = [ 4 | "ICA_LINGAM" 5 | ] -------------------------------------------------------------------------------- 
/cdmir/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .adapters import data_form_converter_for_class_method 2 | from .independence import ConditionalIndependentTest 3 | -------------------------------------------------------------------------------- /docs/source/effect_methods/DoublyRobust/index.rst: -------------------------------------------------------------------------------- 1 | DoublyRobust 2 | ================= 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | 7 | DoublyRobust 8 | -------------------------------------------------------------------------------- /cdmir/discovery/funtional_based/LearningHierarchicalStructure/indTest/__init__.py: -------------------------------------------------------------------------------- 1 | # from .HSIC2 import * 2 | # 3 | # __all__ = [ 4 | # "HSIC2", 5 | # ] 6 | -------------------------------------------------------------------------------- /docs/source/effect_methods/OTCI/index.rst: -------------------------------------------------------------------------------- 1 | OTCI (Optimal Transport Causal Inference) 2 | ================= 3 | 4 | .. 
toctree:: 5 | :maxdepth: 2 6 | 7 | OTCI -------------------------------------------------------------------------------- /cdmir/discovery/funtional_based/LearningHierarchicalStructure/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | pandas 3 | networkx 4 | kerpy (option) 5 | HSIC 6 | matplotlib 7 | scipy 8 | -------------------------------------------------------------------------------- /cdmir/graph/__init__.py: -------------------------------------------------------------------------------- 1 | from .edge import Edge 2 | from .graph import Graph 3 | from .mark import Mark 4 | from .pdag import PDAG 5 | from .digraph import DiGraph 6 | -------------------------------------------------------------------------------- /docs/source/discovery_methods/tensor_rank/index.rst: -------------------------------------------------------------------------------- 1 | Tensor_Rank 2 | ================= 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | 7 | Tensor_rank/tensor_rank -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/dag.1.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X10 --> X16 6 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/pdag.1.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. 
X10 --> X16 6 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/dag.2.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X8 --> X18 6 | 2. X16 --> X10 7 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/pdag.2.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X8 --> X18 6 | 2. X10 --- X16 7 | -------------------------------------------------------------------------------- /cdmir/utils/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .graph_evaluation import ( 2 | arrow_evaluation, 3 | directed_edge_evaluation, 4 | graph_equal, 5 | shd, 6 | skeleton_evaluation, 7 | ) -------------------------------------------------------------------------------- /docs/source/discovery_methods/LaHiCaSI/index.rst: -------------------------------------------------------------------------------- 1 | LaHiCaSI (Latent Hierarchical Causal Structure Learning) 2 | ================= 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | 7 | LaHiCaSI -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/dag.3.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X3 --> X13 6 | 2. X18 --> X8 7 | 3. 
X16 --> X10 8 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/pdag.3.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X3 --> X13 6 | 2. X8 --- X18 7 | 3. X10 --- X16 8 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/dag.4.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X13 --> X3 6 | 2. X8 --> X9 7 | 3. X8 --> X18 8 | 4. X16 --> X10 9 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/pdag.4.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X3 --- X13 6 | 2. X8 --> X9 7 | 3. X8 --- X18 8 | 4. X10 --- X16 9 | -------------------------------------------------------------------------------- /docs/source/utilities_index/index.rst: -------------------------------------------------------------------------------- 1 | Utilities 2 | ========= 3 | 4 | In this section, we introduce utility functions and helper tools in CDMIR. 5 | 6 | .. 
toctree:: 7 | :maxdepth: 3 8 | 9 | graph_operations/index 10 | datasets 11 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | build: 4 | os: ubuntu-24.04 5 | tools: 6 | python: "3.12" 7 | 8 | sphinx: 9 | configuration: docs/source/conf.py 10 | 11 | python: 12 | install: 13 | - requirements: docs/requirements.txt 14 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/dag.5.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X3 --> X9 6 | 2. X3 --> X13 7 | 3. X8 --> X9 8 | 4. X8 --> X18 9 | 5. X16 --> X10 10 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/pdag.5.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X3 --> X9 6 | 2. X3 --- X13 7 | 3. X8 --> X9 8 | 4. X8 --- X18 9 | 5. 
X10 --- X16 10 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy<=1.26.4 2 | pandas 3 | scipy>=1.8.1 4 | scikit-learn 5 | torch>=1.7.1 6 | networkx 7 | matplotlib 8 | setuptools 9 | igraph 10 | lingam 11 | pgmpy<1.0.0 12 | tensorly==0.8.1 13 | tqdm>=4 14 | KDEpy>=1.1.12,<2 15 | statsmodels==0.14.4 -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/dag.6.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X3 --> X9 6 | 2. X3 --> X13 7 | 3. X8 --> X9 8 | 4. X8 --> X18 9 | 5. X16 --> X10 10 | 6. X15 --> X19 11 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/pdag.6.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X3 --> X9 6 | 2. X3 --- X13 7 | 3. X8 --> X9 8 | 4. X8 --- X18 9 | 5. X10 --- X16 10 | 6. 
X15 --> X19 11 | -------------------------------------------------------------------------------- /cdmir/utils/independence/__init__.py: -------------------------------------------------------------------------------- 1 | from .basic_independence import ConditionalIndependentTest 2 | from .dsep import Dsep 3 | from .fisherz import FisherZ 4 | # from .kci import KCI 5 | from .kernel_based import KCI 6 | 7 | __all__ = [ 8 | "KCI", 9 | "FisherZ" 10 | ] -------------------------------------------------------------------------------- /docs/source/effect_methods/index.rst: -------------------------------------------------------------------------------- 1 | Effect methods 2 | ================= 3 | 4 | In this section, we introduce effect methods implemented in CDMIR. 5 | 6 | .. toctree:: 7 | :maxdepth: 2 8 | 9 | LASER/LASER 10 | DoublyRobust/DoublyRobust 11 | OTCI/OTCI 12 | 13 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/dag.7.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X3 --> X9 6 | 2. X3 --> X13 7 | 3. X8 --> X9 8 | 4. X8 --> X18 9 | 5. X8 --> X19 10 | 6. X16 --> X10 11 | 7. 
X15 --> X19 12 | -------------------------------------------------------------------------------- /cdmir/utils/kernel/__init__.py: -------------------------------------------------------------------------------- 1 | from ._base import BaseKernel 2 | from .gaussian import GaussianKernel 3 | from .linear import LinearKernel 4 | from .polynomial import PolynomialKernel 5 | 6 | __all__ = [ 7 | "GaussianKernel", "LinearKernel", "PolynomialKernel", "BaseKernel" 8 | ] 9 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/pdag.7.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X3 --> X9 6 | 2. X3 --- X13 7 | 3. X8 --> X9 8 | 4. X8 --- X18 9 | 5. X8 --> X19 10 | 6. X10 --- X16 11 | 7. X15 --> X19 12 | -------------------------------------------------------------------------------- /docs/source/discovery_methods/functional_based/index.rst: -------------------------------------------------------------------------------- 1 | Functional-based 2 | ================= 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | 7 | ANM/anm 8 | 9 | .. toctree:: 10 | :maxdepth: 2 11 | 12 | SHP/shp 13 | 14 | .. toctree:: 15 | :maxdepth: 2 16 | 17 | OLC/olc -------------------------------------------------------------------------------- /docs/source/discovery_methods/constraint/index.rst: -------------------------------------------------------------------------------- 1 | Constraint-based 2 | ================= 3 | .. toctree:: 4 | :maxdepth: 2 5 | 6 | PBSCM/pbscm 7 | 8 | .. toctree:: 9 | :maxdepth: 2 10 | 11 | PBSCM_PGF/pbscm_pgf 12 | 13 | .. 
toctree:: 14 | :maxdepth: 2 15 | 16 | pc/pc -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/dag.8.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X1 --> X13 6 | 2. X3 --> X9 7 | 3. X13 --> X3 8 | 4. X8 --> X9 9 | 5. X8 --> X18 10 | 6. X8 --> X19 11 | 7. X16 --> X10 12 | 8. X15 --> X19 13 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/pdag.8.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X1 --> X13 6 | 2. X3 --> X9 7 | 3. X3 --- X13 8 | 4. X8 --> X9 9 | 5. X8 --- X18 10 | 6. X8 --> X19 11 | 7. X10 --- X16 12 | 8. X15 --> X19 13 | -------------------------------------------------------------------------------- /docs/source/utilities_index/graph_operations/index.rst: -------------------------------------------------------------------------------- 1 | Graph operations 2 | ================ 3 | 4 | In this section, we would like to introduce graph operations in causal-learn. 5 | 6 | Contents: 7 | 8 | .. toctree:: 9 | :maxdepth: 2 10 | 11 | dag2cpdag/dag2cpdag 12 | pdag2dag/pdag2dag -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/dag.9.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X1 --> X13 6 | 2. X2 --> X13 7 | 3. X3 --> X9 8 | 4. X13 --> X3 9 | 5. X8 --> X9 10 | 6. X8 --> X18 11 | 7. X8 --> X19 12 | 8. X16 --> X10 13 | 9. 
X15 --> X19 14 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/pdag.9.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X1 --> X13 6 | 2. X2 --> X13 7 | 3. X3 --> X9 8 | 4. X3 --- X13 9 | 5. X8 --> X9 10 | 6. X8 --- X18 11 | 7. X8 --> X19 12 | 8. X10 --- X16 13 | 9. X15 --> X19 14 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/dag.10.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X1 --> X13 6 | 2. X2 --> X13 7 | 3. X3 --> X9 8 | 4. X13 --> X3 9 | 5. X5 --> X20 10 | 6. X8 --> X9 11 | 7. X8 --> X18 12 | 8. X8 --> X19 13 | 9. X16 --> X10 14 | 10. X15 --> X19 15 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/pdag.10.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X1 --> X13 6 | 2. X2 --> X13 7 | 3. X3 --> X9 8 | 4. X13 --> X3 9 | 5. X5 --> X20 10 | 6. X8 --> X9 11 | 7. X8 --- X18 12 | 8. X8 --> X19 13 | 9. X10 --- X16 14 | 10. X15 --> X19 15 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/dag.11.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X1 --> X13 6 | 2. X2 --> X13 7 | 3. X3 --> X9 8 | 4. X13 --> X3 9 | 5. X5 --> X20 10 | 6. X8 --> X9 11 | 7. X8 --> X18 12 | 8. 
X8 --> X19 13 | 9. X16 --> X10 14 | 10. X15 --> X19 15 | 11. X16 --> X20 16 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/pdag.11.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X1 --> X13 6 | 2. X2 --> X13 7 | 3. X3 --> X9 8 | 4. X13 --> X3 9 | 5. X5 --> X20 10 | 6. X8 --> X9 11 | 7. X8 --- X18 12 | 8. X8 --> X19 13 | 9. X10 --- X16 14 | 10. X15 --> X19 15 | 11. X16 --> X20 16 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/dag.12.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X1 --> X13 6 | 2. X2 --> X13 7 | 3. X3 --> X9 8 | 4. X13 --> X3 9 | 5. X5 --> X20 10 | 6. X6 --> X7 11 | 7. X8 --> X9 12 | 8. X8 --> X18 13 | 9. X8 --> X19 14 | 10. X16 --> X10 15 | 11. X15 --> X19 16 | 12. X16 --> X20 17 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/pdag.12.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X1 --> X13 6 | 2. X2 --> X13 7 | 3. X3 --> X9 8 | 4. X13 --> X3 9 | 5. X5 --> X20 10 | 6. X6 --> X7 11 | 7. X8 --> X9 12 | 8. X8 --- X18 13 | 9. X8 --> X19 14 | 10. X10 --- X16 15 | 11. X15 --> X19 16 | 12. 
X16 --> X20 17 | -------------------------------------------------------------------------------- /cdmir/tests/test_log1p.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | from numpy import isclose, log, log1p 4 | from numpy.random import uniform 5 | 6 | 7 | class TestLog1p(TestCase): 8 | def test_case(self): 9 | x = uniform(-0.8, 0.8, size=(10000,)) 10 | y1 = 0.5 * log1p(2 * x / (1 - x)) 11 | y2 = 0.5 * log((1 + x) / (1 - x)) 12 | assert isclose(y1, y2).all() 13 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/dag.13.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X1 --> X13 6 | 2. X2 --> X13 7 | 3. X3 --> X9 8 | 4. X13 --> X3 9 | 5. X5 --> X20 10 | 6. X7 --> X6 11 | 7. X8 --> X9 12 | 8. X8 --> X18 13 | 9. X8 --> X19 14 | 10. X12 --> X10 15 | 11. X10 --> X16 16 | 12. X15 --> X19 17 | 13. X16 --> X20 18 | -------------------------------------------------------------------------------- /docs/source/discovery_methods/index.rst: -------------------------------------------------------------------------------- 1 | Discovery methods 2 | ================= 3 | 4 | In this section, we introduce discovery methods implemented in CDMIR. 5 | 6 | .. toctree:: 7 | :maxdepth: 2 8 | 9 | Constraint-based methods 10 | Functional-based methods 11 | Tensor-Rank methods 12 | LaHiCaSI 13 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/pdag.13.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X1 --> X13 6 | 2. X2 --> X13 7 | 3. X3 --> X9 8 | 4. X13 --> X3 9 | 5. 
X5 --> X20 10 | 6. X6 --- X7 11 | 7. X8 --> X9 12 | 8. X8 --- X18 13 | 9. X8 --> X19 14 | 10. X12 --> X10 15 | 11. X10 --- X16 16 | 12. X15 --> X19 17 | 13. X16 --> X20 18 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/dag.14.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X1 --> X13 6 | 2. X2 --> X13 7 | 3. X3 --> X6 8 | 4. X3 --> X9 9 | 5. X13 --> X3 10 | 6. X5 --> X20 11 | 7. X6 --> X7 12 | 8. X8 --> X9 13 | 9. X8 --> X18 14 | 10. X8 --> X19 15 | 11. X10 --> X12 16 | 12. X16 --> X10 17 | 13. X15 --> X19 18 | 14. X16 --> X20 19 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/pdag.14.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X1 --> X13 6 | 2. X2 --> X13 7 | 3. X3 --> X6 8 | 4. X3 --> X9 9 | 5. X13 --> X3 10 | 6. X5 --> X20 11 | 7. X6 --- X7 12 | 8. X8 --> X9 13 | 9. X8 --- X18 14 | 10. X8 --> X19 15 | 11. X10 --- X12 16 | 12. X10 --- X16 17 | 13. X15 --> X19 18 | 14. 
X16 --> X20 19 | -------------------------------------------------------------------------------- /cdmir/tests/test_linear.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from cdmir.utils.kernel._base import BaseKernel 4 | from cdmir.utils.kernel.linear import LinearKernel 5 | def test_linear(): 6 | 7 | arr1=np.array([1,2,3]) 8 | arr2=np.array([1,1,1]) 9 | 10 | l=LinearKernel() 11 | l(arr1,arr2) 12 | #l.__call__(arr1,arr2) 13 | print(arr1.dot(arr2.T)) 14 | print(l.__call__(arr1,arr2)) 15 | 16 | test_linear() -------------------------------------------------------------------------------- /cdmir/graph/mark.py: -------------------------------------------------------------------------------- 1 | from enum import IntEnum 2 | 3 | 4 | class Mark(IntEnum): 5 | Tail = -1 6 | Null = 0 7 | Arrow = 1 8 | Circle = 2 9 | 10 | def __str__(self): 11 | return self.name 12 | 13 | @staticmethod 14 | def pdag_marks(): 15 | return [Mark.Tail, Mark.Arrow] 16 | 17 | @staticmethod 18 | def pag_marks(): 19 | return [Mark.Tail, Mark.Arrow, Mark.Circle] 20 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/dag.15.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X1 --> X13 6 | 2. X2 --> X13 7 | 3. X3 --> X6 8 | 4. X3 --> X9 9 | 5. X13 --> X3 10 | 6. X5 --> X4 11 | 7. X5 --> X20 12 | 8. X6 --> X7 13 | 9. X8 --> X9 14 | 10. X8 --> X18 15 | 11. X8 --> X19 16 | 12. X10 --> X12 17 | 13. X16 --> X10 18 | 14. X15 --> X19 19 | 15. 
X16 --> X20 20 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/pdag.15.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X1 --> X13 6 | 2. X2 --> X13 7 | 3. X3 --> X6 8 | 4. X3 --> X9 9 | 5. X13 --> X3 10 | 6. X5 --> X4 11 | 7. X5 --> X20 12 | 8. X6 --> X7 13 | 9. X8 --> X9 14 | 10. X8 --- X18 15 | 11. X8 --> X19 16 | 12. X10 --- X12 17 | 13. X10 --- X16 18 | 14. X15 --> X19 19 | 15. X16 --> X20 20 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/dag.3.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10 3 | 4 | Graph Edges: 5 | 1. X3 --> X1 6 | 2. X8 --> X1 7 | 3. X9 --> X1 8 | 4. X3 --> X2 9 | 5. X8 --> X2 10 | 6. X2 --> X9 11 | 7. X3 --> X5 12 | 8. X3 --> X7 13 | 9. X3 --> X9 14 | 10. X3 --> X10 15 | 11. X8 --> X4 16 | 12. X4 --> X9 17 | 13. X10 --> X4 18 | 14. X5 --> X6 19 | 15. X5 --> X7 20 | 16. X8 --> X5 21 | 17. X5 --> X9 22 | 18. X6 --> X7 23 | 19. X8 --> X6 24 | -------------------------------------------------------------------------------- /docs/source/utilities_index/graph_operations/pdag2dag/pdag2dag.rst: -------------------------------------------------------------------------------- 1 | PDAG2DAG 2 | ============== 3 | 4 | Convert a PDAG to its corresponding DAG. 5 | 6 | Usage 7 | -------- 8 | .. code-block:: python 9 | 10 | from cdmir.graph import pdag2dag 11 | DAG = pdag2dag(G) 12 | 13 | Parameters 14 | --------------------- 15 | **G**: Partially Directed Acyclic Graph. 16 | 17 | Returns 18 | -------------- 19 | **DAG**: Directed Acyclic Graph. 
20 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/cpdag.3.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10 3 | 4 | Graph Edges: 5 | 1. X3 --> X1 6 | 2. X8 --> X1 7 | 3. X9 --> X1 8 | 4. X3 --> X2 9 | 5. X8 --> X2 10 | 6. X2 --> X9 11 | 7. X3 --> X5 12 | 8. X3 --> X7 13 | 9. X3 --> X9 14 | 10. X3 --- X10 15 | 11. X8 --> X4 16 | 12. X4 --> X9 17 | 13. X10 --> X4 18 | 14. X5 --> X6 19 | 15. X5 --> X7 20 | 16. X8 --> X5 21 | 17. X5 --> X9 22 | 18. X6 --> X7 23 | 19. X8 --> X6 24 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/dag.16.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X1 --> X13 6 | 2. X2 --> X13 7 | 3. X3 --> X6 8 | 4. X3 --> X9 9 | 5. X13 --> X3 10 | 6. X5 --> X4 11 | 7. X5 --> X20 12 | 8. X6 --> X7 13 | 9. X8 --> X9 14 | 10. X8 --> X14 15 | 11. X8 --> X18 16 | 12. X8 --> X19 17 | 13. X10 --> X12 18 | 14. X16 --> X10 19 | 15. X15 --> X19 20 | 16. X16 --> X20 21 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/pdag.16.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X1 --> X13 6 | 2. X2 --> X13 7 | 3. X3 --> X6 8 | 4. X3 --> X9 9 | 5. X13 --> X3 10 | 6. X4 --- X5 11 | 7. X5 --> X20 12 | 8. X6 --> X7 13 | 9. X8 --> X9 14 | 10. X8 --> X14 15 | 11. X8 --- X18 16 | 12. X8 --> X19 17 | 13. X10 --- X12 18 | 14. X10 --- X16 19 | 15. X15 --> X19 20 | 16. 
X16 --> X20 21 | -------------------------------------------------------------------------------- /docs/source/utilities_index/graph_operations/dag2cpdag/dag2cpdag.rst: -------------------------------------------------------------------------------- 1 | DAG2CPDAG 2 | ============== 3 | 4 | Convert a DAG to its corresponding CPDAG. 5 | 6 | Usage 7 | -------- 8 | .. code-block:: python 9 | 10 | from cdmir.graph import dag2cpdag 11 | CPDAG = dag2cpdag(G) 12 | 13 | Parameters 14 | --------------------- 15 | **G**: Directed Acyclic Graph. 16 | 17 | Returns 18 | -------------- 19 | **CPDAG**: Completed Partially Directed Acyclic Graph. 20 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/cpdag.5.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10 3 | 4 | Graph Edges: 5 | 1. X3 --> X1 6 | 2. X4 --> X1 7 | 3. X1 --> X7 8 | 4. X8 --> X1 9 | 5. X3 --> X2 10 | 6. X4 --> X2 11 | 7. X6 --> X2 12 | 8. X2 --> X10 13 | 9. X3 --> X5 14 | 10. X3 --> X10 15 | 11. X4 --> X5 16 | 12. X4 --- X6 17 | 13. X4 --> X8 18 | 14. X4 --> X10 19 | 15. X5 --> X7 20 | 16. X5 --> X8 21 | 17. X6 --> X8 22 | 18. X8 --> X9 23 | 19. X10 --> X8 24 | 20. X10 --> X9 25 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/dag.5.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10 3 | 4 | Graph Edges: 5 | 1. X3 --> X1 6 | 2. X4 --> X1 7 | 3. X1 --> X7 8 | 4. X8 --> X1 9 | 5. X3 --> X2 10 | 6. X4 --> X2 11 | 7. X6 --> X2 12 | 8. X2 --> X10 13 | 9. X3 --> X5 14 | 10. X3 --> X10 15 | 11. X4 --> X5 16 | 12. X4 --> X6 17 | 13. X4 --> X8 18 | 14. X4 --> X10 19 | 15. X5 --> X7 20 | 16. X5 --> X8 21 | 17. X6 --> X8 22 | 18. X8 --> X9 23 | 19. X10 --> X8 24 | 20. 
X10 --> X9 25 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/dag.17.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X1 --> X13 6 | 2. X2 --> X13 7 | 3. X3 --> X6 8 | 4. X3 --> X9 9 | 5. X13 --> X3 10 | 6. X5 --> X4 11 | 7. X5 --> X20 12 | 8. X6 --> X7 13 | 9. X7 --> X14 14 | 10. X8 --> X9 15 | 11. X8 --> X14 16 | 12. X8 --> X18 17 | 13. X8 --> X19 18 | 14. X10 --> X12 19 | 15. X16 --> X10 20 | 16. X15 --> X19 21 | 17. X16 --> X20 22 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/pdag.17.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X1 --> X13 6 | 2. X2 --> X13 7 | 3. X3 --> X6 8 | 4. X3 --> X9 9 | 5. X13 --> X3 10 | 6. X4 --- X5 11 | 7. X5 --> X20 12 | 8. X6 --> X7 13 | 9. X7 --> X14 14 | 10. X8 --> X9 15 | 11. X8 --> X14 16 | 12. X8 --- X18 17 | 13. X8 --> X19 18 | 14. X10 --- X12 19 | 15. X10 --- X16 20 | 16. X15 --> X19 21 | 17. 
X16 --> X20 22 | -------------------------------------------------------------------------------- /cdmir/tests/test_polynomial.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from cdmir.utils.kernel.polynomial import PolynomialKernel 4 | 5 | def test_polynomial(): #多项式核函数 (const+x.dot(y.T))^2 6 | arr1=np.array([[1,2,3]]) 7 | arr2=np.array([[1,1,1]]) 8 | 9 | np.random.seed(1) 10 | xs=np.random.randn(1,3) 11 | ys=np.random.randn(1,3) 12 | 13 | pk=PolynomialKernel() 14 | print(pk.__call__(arr1,arr2)) 15 | print(pk.__call__(xs,ys)) 16 | 17 | test_polynomial() 18 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/cpdag.2.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10 3 | 4 | Graph Edges: 5 | 1. X1 --> X4 6 | 2. X9 --> X1 7 | 3. X10 --> X1 8 | 4. X2 --> X4 9 | 5. X2 --> X5 10 | 6. X2 --> X8 11 | 7. X3 --> X5 12 | 8. X3 --> X6 13 | 9. X3 --> X8 14 | 10. X3 --- X9 15 | 11. X5 --> X4 16 | 12. X4 --> X6 17 | 13. X8 --> X4 18 | 14. X10 --> X4 19 | 15. X5 --- X8 20 | 16. X7 --> X6 21 | 17. X9 --> X6 22 | 18. X10 --> X6 23 | 19. X7 --> X10 24 | 20. X8 --> X10 25 | 21. X9 --> X10 26 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/dag.2.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10 3 | 4 | Graph Edges: 5 | 1. X1 --> X4 6 | 2. X9 --> X1 7 | 3. X10 --> X1 8 | 4. X2 --> X4 9 | 5. X2 --> X5 10 | 6. X2 --> X8 11 | 7. X3 --> X5 12 | 8. X3 --> X6 13 | 9. X3 --> X8 14 | 10. X3 --> X9 15 | 11. X5 --> X4 16 | 12. X4 --> X6 17 | 13. X8 --> X4 18 | 14. X10 --> X4 19 | 15. X5 --> X8 20 | 16. X7 --> X6 21 | 17. X9 --> X6 22 | 18. X10 --> X6 23 | 19. X7 --> X10 24 | 20. X8 --> X10 25 | 21. 
X9 --> X10 26 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/dag.18.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X1 --> X13 6 | 2. X2 --> X13 7 | 3. X3 --> X6 8 | 4. X3 --> X9 9 | 5. X13 --> X3 10 | 6. X5 --> X4 11 | 7. X5 --> X20 12 | 8. X6 --> X7 13 | 9. X7 --> X14 14 | 10. X8 --> X9 15 | 11. X8 --> X14 16 | 12. X8 --> X18 17 | 13. X8 --> X19 18 | 14. X12 --> X10 19 | 15. X10 --> X16 20 | 16. X12 --> X15 21 | 17. X15 --> X19 22 | 18. X16 --> X20 23 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/pdag.18.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X1 --> X13 6 | 2. X2 --> X13 7 | 3. X3 --> X6 8 | 4. X3 --> X9 9 | 5. X13 --> X3 10 | 6. X4 --- X5 11 | 7. X5 --> X20 12 | 8. X6 --> X7 13 | 9. X7 --> X14 14 | 10. X8 --> X9 15 | 11. X8 --> X14 16 | 12. X8 --- X18 17 | 13. X8 --> X19 18 | 14. X10 --- X12 19 | 15. X10 --- X16 20 | 16. X12 --> X15 21 | 17. X15 --> X19 22 | 18. X16 --> X20 23 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/cpdag.1.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10 3 | 4 | Graph Edges: 5 | 1. X2 --> X1 6 | 2. X3 --> X1 7 | 3. X6 --> X1 8 | 4. X1 --> X9 9 | 5. X4 --> X2 10 | 6. X2 --> X8 11 | 7. X2 --> X9 12 | 8. X10 --> X2 13 | 9. X3 --> X4 14 | 10. X3 --> X5 15 | 11. X3 --> X10 16 | 12. X4 --> X8 17 | 13. X10 --> X4 18 | 14. X6 --> X5 19 | 15. X5 --> X8 20 | 16. X5 --> X9 21 | 17. X5 --- X10 22 | 18. X6 --- X7 23 | 19. X6 --> X8 24 | 20. 
X6 --> X9 25 | 21. X6 --> X10 26 | 22. X10 --> X8 27 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/cpdag.4.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10 3 | 4 | Graph Edges: 5 | 1. X2 --> X1 6 | 2. X4 --> X1 7 | 3. X7 --> X1 8 | 4. X8 --> X1 9 | 5. X9 --> X1 10 | 6. X10 --> X1 11 | 7. X2 --- X6 12 | 8. X2 --> X7 13 | 9. X2 --- X9 14 | 10. X5 --> X3 15 | 11. X3 --> X7 16 | 12. X9 --> X3 17 | 13. X3 --> X10 18 | 14. X4 --> X7 19 | 15. X5 --> X10 20 | 16. X6 --> X7 21 | 17. X6 --- X9 22 | 18. X6 --> X10 23 | 19. X8 --> X7 24 | 20. X10 --> X7 25 | 21. X8 --- X9 26 | 22. X9 --> X10 27 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/dag.1.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10 3 | 4 | Graph Edges: 5 | 1. X2 --> X1 6 | 2. X3 --> X1 7 | 3. X6 --> X1 8 | 4. X1 --> X9 9 | 5. X4 --> X2 10 | 6. X2 --> X8 11 | 7. X2 --> X9 12 | 8. X10 --> X2 13 | 9. X3 --> X4 14 | 10. X3 --> X5 15 | 11. X3 --> X10 16 | 12. X4 --> X8 17 | 13. X10 --> X4 18 | 14. X6 --> X5 19 | 15. X5 --> X8 20 | 16. X5 --> X9 21 | 17. X10 --> X5 22 | 18. X6 --> X7 23 | 19. X6 --> X8 24 | 20. X6 --> X9 25 | 21. X6 --> X10 26 | 22. X10 --> X8 27 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/dag.4.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10 3 | 4 | Graph Edges: 5 | 1. X2 --> X1 6 | 2. X4 --> X1 7 | 3. X7 --> X1 8 | 4. X8 --> X1 9 | 5. X9 --> X1 10 | 6. X10 --> X1 11 | 7. X6 --> X2 12 | 8. X2 --> X7 13 | 9. X9 --> X2 14 | 10. X5 --> X3 15 | 11. X3 --> X7 16 | 12. X9 --> X3 17 | 13. X3 --> X10 18 | 14. X4 --> X7 19 | 15. X5 --> X10 20 | 16. X6 --> X7 21 | 17. 
X9 --> X6 22 | 18. X6 --> X10 23 | 19. X8 --> X7 24 | 20. X10 --> X7 25 | 21. X9 --> X8 26 | 22. X9 --> X10 27 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/dag.19.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X1 --> X13 6 | 2. X2 --> X13 7 | 3. X3 --> X6 8 | 4. X3 --> X9 9 | 5. X3 --> X10 10 | 6. X13 --> X3 11 | 7. X5 --> X4 12 | 8. X5 --> X20 13 | 9. X6 --> X7 14 | 10. X7 --> X14 15 | 11. X8 --> X9 16 | 12. X8 --> X14 17 | 13. X8 --> X18 18 | 14. X8 --> X19 19 | 15. X10 --> X12 20 | 16. X10 --> X16 21 | 17. X12 --> X15 22 | 18. X15 --> X19 23 | 19. X16 --> X20 24 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/pdag.19.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X1 --> X13 6 | 2. X2 --> X13 7 | 3. X3 --> X6 8 | 4. X3 --> X9 9 | 5. X3 --> X10 10 | 6. X13 --> X3 11 | 7. X4 --- X5 12 | 8. X5 --> X20 13 | 9. X6 --> X7 14 | 10. X7 --> X14 15 | 11. X8 --> X9 16 | 12. X8 --> X14 17 | 13. X8 --- X18 18 | 14. X8 --> X19 19 | 15. X10 --- X12 20 | 16. X10 --- X16 21 | 17. X12 --- X15 22 | 18. X15 --> X19 23 | 19. 
X16 --> X20 24 | -------------------------------------------------------------------------------- /cdmir/utils/local_score/__init__.py: -------------------------------------------------------------------------------- 1 | from ._base import BaseLocalScoreFunction 2 | from .bdeu_score import BDeuScore 3 | from .bic_score import BICScore 4 | from .cross_validated_base import GeneralCVScore, MultiCVScore 5 | from .marginal_base import GeneralMarginalScore, MultiMarginalScore 6 | 7 | 8 | __all__ = [ 9 | "BaseLocalScoreFunction", 10 | "BICScore", 11 | "BDeuScore", 12 | "GeneralCVScore", 13 | "MultiCVScore", 14 | "GeneralMarginalScore", 15 | "MultiMarginalScore" 16 | ] 17 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/dag.20.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X4 --> X1 6 | 2. X1 --> X13 7 | 3. X2 --> X13 8 | 4. X3 --> X6 9 | 5. X3 --> X9 10 | 6. X3 --> X10 11 | 7. X13 --> X3 12 | 8. X4 --> X5 13 | 9. X5 --> X20 14 | 10. X6 --> X7 15 | 11. X7 --> X14 16 | 12. X8 --> X9 17 | 13. X8 --> X14 18 | 14. X8 --> X18 19 | 15. X8 --> X19 20 | 16. X10 --> X12 21 | 17. X10 --> X16 22 | 18. X12 --> X15 23 | 19. X15 --> X19 24 | 20. X16 --> X20 25 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/pdag.20.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X4 --> X1 6 | 2. X1 --> X13 7 | 3. X2 --> X13 8 | 4. X3 --> X6 9 | 5. X3 --> X9 10 | 6. X3 --> X10 11 | 7. X13 --> X3 12 | 8. X4 --- X5 13 | 9. X5 --> X20 14 | 10. X6 --> X7 15 | 11. X7 --> X14 16 | 12. X8 --> X9 17 | 13. X8 --> X14 18 | 14. X8 --- X18 19 | 15. X8 --> X19 20 | 16. 
X10 --> X12 21 | 17. X10 --> X16 22 | 18. X12 --> X15 23 | 19. X15 --> X19 24 | 20. X16 --> X20 25 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/dag.21.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X1 --> X4 6 | 2. X1 --> X13 7 | 3. X2 --> X13 8 | 4. X3 --> X6 9 | 5. X3 --> X9 10 | 6. X3 --> X10 11 | 7. X13 --> X3 12 | 8. X4 --> X5 13 | 9. X5 --> X7 14 | 10. X5 --> X20 15 | 11. X6 --> X7 16 | 12. X7 --> X14 17 | 13. X8 --> X9 18 | 14. X8 --> X14 19 | 15. X8 --> X18 20 | 16. X8 --> X19 21 | 17. X10 --> X12 22 | 18. X10 --> X16 23 | 19. X12 --> X15 24 | 20. X15 --> X19 25 | 21. X16 --> X20 26 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/pdag.21.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X1 --- X4 6 | 2. X1 --> X13 7 | 3. X2 --> X13 8 | 4. X3 --> X6 9 | 5. X3 --> X9 10 | 6. X3 --> X10 11 | 7. X13 --> X3 12 | 8. X4 --- X5 13 | 9. X5 --> X7 14 | 10. X5 --> X20 15 | 11. X6 --> X7 16 | 12. X7 --> X14 17 | 13. X8 --> X9 18 | 14. X8 --> X14 19 | 15. X8 --- X18 20 | 16. X8 --> X19 21 | 17. X10 --> X12 22 | 18. X10 --> X16 23 | 19. X12 --> X15 24 | 20. X15 --> X19 25 | 21. 
X16 --> X20 26 | -------------------------------------------------------------------------------- /cdmir/utils/independence/kci.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable 2 | 3 | from cdmir.utils.independence.basic_independence import ConditionalIndependentTest 4 | from cdmir.utils.independence.functional import kci 5 | 6 | 7 | class KCI(ConditionalIndependentTest): 8 | def __init__(self, data, var_names=None): 9 | super().__init__(data, var_names=var_names) 10 | self._num_records = data.shape[0] 11 | 12 | def cal_stats(self, x: int, y: int, z: Iterable[int] = None): 13 | return kci(self._data[:, x], self._data[:, y], self._data[:, z]) 14 | -------------------------------------------------------------------------------- /cdmir/visual/graph_layout.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | from numpy import cos, linspace, pi, sin 4 | 5 | from cdmir.graph import Graph 6 | 7 | 8 | def circular_layout(graph: Graph, scale=0.45, sort_node=False) -> Dict: 9 | node_list = [str(node) for node in graph.nodes] 10 | 11 | n = len(node_list) 12 | 13 | if sort_node: 14 | node_list = sorted(node_list) 15 | angle = pi / 2 - linspace(0, 2 * pi, n, endpoint=False) 16 | 17 | return {node: (scale * cos(angle[i]), scale * sin(angle[i])) for i, node in enumerate(node_list)} 18 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/dag.22.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X1 --> X4 6 | 2. X1 --> X13 7 | 3. X2 --> X13 8 | 4. X3 --> X6 9 | 5. X3 --> X9 10 | 6. X3 --> X10 11 | 7. X13 --> X3 12 | 8. X4 --> X5 13 | 9. X5 --> X7 14 | 10. X5 --> X20 15 | 11. X6 --> X7 16 | 12. X7 --> X14 17 | 13. X8 --> X9 18 | 14. 
X8 --> X14 19 | 15. X8 --> X18 20 | 16. X8 --> X19 21 | 17. X10 --> X12 22 | 18. X10 --> X16 23 | 19. X11 --> X15 24 | 20. X12 --> X15 25 | 21. X15 --> X19 26 | 22. X16 --> X20 27 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/pdag.22.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X1 --- X4 6 | 2. X1 --> X13 7 | 3. X2 --> X13 8 | 4. X3 --> X6 9 | 5. X3 --> X9 10 | 6. X3 --> X10 11 | 7. X13 --> X3 12 | 8. X4 --- X5 13 | 9. X5 --> X7 14 | 10. X5 --> X20 15 | 11. X6 --> X7 16 | 12. X7 --> X14 17 | 13. X8 --> X9 18 | 14. X8 --> X14 19 | 15. X8 --- X18 20 | 16. X8 --> X19 21 | 17. X10 --> X12 22 | 18. X10 --> X16 23 | 19. X11 --> X15 24 | 20. X12 --> X15 25 | 21. X15 --> X19 26 | 22. X16 --> X20 27 | -------------------------------------------------------------------------------- /cdmir/utils/kernel/linear.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from numpy import ndarray 4 | 5 | from cdmir.utils.kernel._base import BaseKernel 6 | 7 | 8 | class LinearKernel(BaseKernel): 9 | 10 | def __init__(self): 11 | super().__init__() 12 | 13 | def __call__(self, xs: ndarray, ys: ndarray, *args, **kwargs): 14 | return self._BaseKernel__kernel(xs, ys, self.__kernel_func) #self.__kernel改成self._BaseKernel__kernel() 15 | 16 | def __kernel_func(self, x: ndarray, y: ndarray): 17 | return x.dot(y.T) #dot()矩阵乘法运算 一维的时候就是两个数字的乘积 y.T表示y的转置 18 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/dag.23.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. 
X1 --> X4 6 | 2. X1 --> X13 7 | 3. X2 --> X13 8 | 4. X3 --> X6 9 | 5. X3 --> X9 10 | 6. X3 --> X10 11 | 7. X13 --> X3 12 | 8. X4 --> X5 13 | 9. X4 --> X15 14 | 10. X5 --> X7 15 | 11. X5 --> X20 16 | 12. X6 --> X7 17 | 13. X7 --> X14 18 | 14. X8 --> X9 19 | 15. X8 --> X14 20 | 16. X8 --> X18 21 | 17. X8 --> X19 22 | 18. X10 --> X12 23 | 19. X10 --> X16 24 | 20. X11 --> X15 25 | 21. X12 --> X15 26 | 22. X15 --> X19 27 | 23. X16 --> X20 28 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/pdag.23.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X1 --- X4 6 | 2. X1 --> X13 7 | 3. X2 --> X13 8 | 4. X3 --> X6 9 | 5. X3 --> X9 10 | 6. X3 --> X10 11 | 7. X13 --> X3 12 | 8. X4 --- X5 13 | 9. X4 --> X15 14 | 10. X5 --> X7 15 | 11. X5 --> X20 16 | 12. X6 --> X7 17 | 13. X7 --> X14 18 | 14. X8 --> X9 19 | 15. X8 --> X14 20 | 16. X8 --- X18 21 | 17. X8 --> X19 22 | 18. X10 --> X12 23 | 19. X10 --> X16 24 | 20. X11 --> X15 25 | 21. X12 --> X15 26 | 22. X15 --> X19 27 | 23. X16 --> X20 28 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/dag.24.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X1 --> X4 6 | 2. X1 --> X13 7 | 3. X2 --> X13 8 | 4. X3 --> X6 9 | 5. X3 --> X9 10 | 6. X3 --> X10 11 | 7. X13 --> X3 12 | 8. X4 --> X5 13 | 9. X4 --> X14 14 | 10. X4 --> X15 15 | 11. X5 --> X7 16 | 12. X5 --> X20 17 | 13. X6 --> X7 18 | 14. X7 --> X14 19 | 15. X8 --> X9 20 | 16. X8 --> X14 21 | 17. X8 --> X18 22 | 18. X8 --> X19 23 | 19. X10 --> X12 24 | 20. X10 --> X16 25 | 21. X11 --> X15 26 | 22. X12 --> X15 27 | 23. X15 --> X19 28 | 24. 
X16 --> X20 29 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/pdag.24.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X1 --- X4 6 | 2. X1 --> X13 7 | 3. X2 --> X13 8 | 4. X3 --> X6 9 | 5. X3 --> X9 10 | 6. X3 --> X10 11 | 7. X13 --> X3 12 | 8. X4 --- X5 13 | 9. X4 --> X14 14 | 10. X4 --> X15 15 | 11. X5 --> X7 16 | 12. X5 --> X20 17 | 13. X6 --> X7 18 | 14. X7 --> X14 19 | 15. X8 --> X9 20 | 16. X8 --> X14 21 | 17. X8 --- X18 22 | 18. X8 --> X19 23 | 19. X10 --> X12 24 | 20. X10 --> X16 25 | 21. X11 --> X15 26 | 22. X12 --> X15 27 | 23. X15 --> X19 28 | 24. X16 --> X20 29 | -------------------------------------------------------------------------------- /cdmir/tests/test_anm.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | from cdmir.discovery.funtional_based.anm.ANM import ANM 7 | 8 | 9 | class TestANM(TestCase): 10 | 11 | def test_anm_using_simulation(self): 12 | # simulated data y = 3^x + e 13 | np.random.seed(2025) 14 | X = np.random.uniform(size=10000) 15 | Y = np.power(X, 3) + np.random.uniform(size=10000) 16 | anm = ANM() 17 | nonindepscore_forward, nonindepscore_backward = anm.cause_or_effect(X, Y) 18 | 19 | assert nonindepscore_forward < nonindepscore_backward 20 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/dag.25.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X1 --> X4 6 | 2. X1 --> X13 7 | 3. X2 --> X13 8 | 4. X3 --> X6 9 | 5. X3 --> X9 10 | 6. X3 --> X10 11 | 7. 
X13 --> X3 12 | 8. X4 --> X5 13 | 9. X4 --> X14 14 | 10. X4 --> X15 15 | 11. X5 --> X7 16 | 12. X5 --> X20 17 | 13. X6 --> X7 18 | 14. X7 --> X14 19 | 15. X8 --> X9 20 | 16. X8 --> X14 21 | 17. X8 --> X18 22 | 18. X8 --> X19 23 | 19. X10 --> X12 24 | 20. X10 --> X16 25 | 21. X11 --> X14 26 | 22. X11 --> X15 27 | 23. X12 --> X15 28 | 24. X15 --> X19 29 | 25. X16 --> X20 30 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/pdag.25.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X1 --- X4 6 | 2. X1 --> X13 7 | 3. X2 --> X13 8 | 4. X3 --> X6 9 | 5. X3 --> X9 10 | 6. X3 --> X10 11 | 7. X13 --> X3 12 | 8. X4 --- X5 13 | 9. X4 --> X14 14 | 10. X4 --> X15 15 | 11. X5 --> X7 16 | 12. X5 --> X20 17 | 13. X6 --> X7 18 | 14. X7 --> X14 19 | 15. X8 --> X9 20 | 16. X8 --> X14 21 | 17. X8 --- X18 22 | 18. X8 --> X19 23 | 19. X10 --> X12 24 | 20. X10 --> X16 25 | 21. X11 --> X14 26 | 22. X11 --> X15 27 | 23. X12 --> X15 28 | 24. X15 --> X19 29 | 25. X16 --> X20 30 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/dag.26.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X1 --> X4 6 | 2. X1 --> X13 7 | 3. X2 --> X13 8 | 4. X3 --> X6 9 | 5. X3 --> X9 10 | 6. X3 --> X10 11 | 7. X13 --> X3 12 | 8. X4 --> X5 13 | 9. X4 --> X6 14 | 10. X4 --> X14 15 | 11. X4 --> X15 16 | 12. X5 --> X7 17 | 13. X5 --> X20 18 | 14. X6 --> X7 19 | 15. X7 --> X14 20 | 16. X8 --> X9 21 | 17. X8 --> X14 22 | 18. X8 --> X18 23 | 19. X8 --> X19 24 | 20. X10 --> X12 25 | 21. X10 --> X16 26 | 22. X11 --> X14 27 | 23. X11 --> X15 28 | 24. X12 --> X15 29 | 25. 
X15 --> X19 30 | 26. X16 --> X20 31 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/pdag.26.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X1 --- X4 6 | 2. X1 --> X13 7 | 3. X2 --> X13 8 | 4. X3 --> X6 9 | 5. X3 --> X9 10 | 6. X3 --> X10 11 | 7. X13 --> X3 12 | 8. X4 --- X5 13 | 9. X4 --> X6 14 | 10. X4 --> X14 15 | 11. X4 --> X15 16 | 12. X5 --> X7 17 | 13. X5 --> X20 18 | 14. X6 --> X7 19 | 15. X7 --> X14 20 | 16. X8 --> X9 21 | 17. X8 --> X14 22 | 18. X8 --- X18 23 | 19. X8 --> X19 24 | 20. X10 --> X12 25 | 21. X10 --> X16 26 | 22. X11 --> X14 27 | 23. X11 --> X15 28 | 24. X12 --> X15 29 | 25. X15 --> X19 30 | 26. X16 --> X20 31 | -------------------------------------------------------------------------------- /cdmir/discovery/funtional_based/LearningHierarchicalStructure/indTest/FisherTest.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from scipy.stats import chi2 4 | 5 | 6 | # independent test by Fisher'method 7 | def FisherTest(pvals, alph=0.01): 8 | Fisher_Stat = 0 9 | L = len(pvals) 10 | for i in range(0, L): 11 | if pvals[i] == 0: 12 | TP = 1e-05 13 | else: 14 | TP = pvals[i] 15 | 16 | Fisher_Stat = Fisher_Stat - 2 * math.log(TP) 17 | 18 | Fisher_pval = 1 - chi2.cdf(Fisher_Stat, 2 * L) 19 | 20 | if Fisher_pval > alph: 21 | return True, Fisher_pval 22 | else: 23 | return False, Fisher_pval 24 | -------------------------------------------------------------------------------- /cdmir/tests/test_fisherz.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from cdmir.utils.independence import FisherZ 4 | 5 | def test_fisherz(): 6 | np.random.seed(10) 7 | X = np.random.randn(300, 1) 8 | X_prime = np.random.randn(300, 
1) 9 | Y = X + 0.5 * np.random.randn(300, 1) 10 | Z = Y + 0.5 * np.random.randn(300, 1) 11 | data = np.hstack((X, X_prime, Y, Z)) 12 | 13 | f = FisherZ(data=data) 14 | p_value, stat = f.cal_stats(0, 3, [2]) 15 | assert p_value > 0.01 16 | 17 | p_value, stat = f.cal_stats(0, 3, None) 18 | assert p_value < 0.01 19 | 20 | p_value, stat = f.cal_stats(3, 1, None) 21 | assert p_value > 0.01 22 | 23 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/dag.27.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X1 --> X4 6 | 2. X1 --> X8 7 | 3. X1 --> X13 8 | 4. X2 --> X13 9 | 5. X3 --> X6 10 | 6. X3 --> X9 11 | 7. X3 --> X10 12 | 8. X13 --> X3 13 | 9. X4 --> X5 14 | 10. X4 --> X6 15 | 11. X4 --> X14 16 | 12. X4 --> X15 17 | 13. X5 --> X7 18 | 14. X5 --> X20 19 | 15. X6 --> X7 20 | 16. X7 --> X14 21 | 17. X8 --> X9 22 | 18. X8 --> X14 23 | 19. X8 --> X18 24 | 20. X8 --> X19 25 | 21. X10 --> X12 26 | 22. X10 --> X16 27 | 23. X11 --> X14 28 | 24. X11 --> X15 29 | 25. X12 --> X15 30 | 26. X15 --> X19 31 | 27. X16 --> X20 32 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/pdag.27.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X1 --- X4 6 | 2. X1 --> X8 7 | 3. X1 --> X13 8 | 4. X2 --> X13 9 | 5. X3 --> X6 10 | 6. X3 --> X9 11 | 7. X3 --> X10 12 | 8. X13 --> X3 13 | 9. X4 --- X5 14 | 10. X4 --> X6 15 | 11. X4 --> X14 16 | 12. X4 --> X15 17 | 13. X5 --> X7 18 | 14. X5 --> X20 19 | 15. X6 --> X7 20 | 16. X7 --> X14 21 | 17. X8 --> X9 22 | 18. X8 --> X14 23 | 19. X8 --- X18 24 | 20. X8 --> X19 25 | 21. X10 --> X12 26 | 22. X10 --> X16 27 | 23. 
X11 --> X14 28 | 24. X11 --> X15 29 | 25. X12 --> X15 30 | 26. X15 --> X19 31 | 27. X16 --> X20 32 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/dag.28.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X1 --> X4 6 | 2. X1 --> X8 7 | 3. X1 --> X13 8 | 4. X1 --> X20 9 | 5. X2 --> X13 10 | 6. X3 --> X6 11 | 7. X3 --> X9 12 | 8. X3 --> X10 13 | 9. X13 --> X3 14 | 10. X4 --> X5 15 | 11. X4 --> X6 16 | 12. X4 --> X14 17 | 13. X4 --> X15 18 | 14. X5 --> X7 19 | 15. X5 --> X20 20 | 16. X6 --> X7 21 | 17. X7 --> X14 22 | 18. X8 --> X9 23 | 19. X8 --> X14 24 | 20. X8 --> X18 25 | 21. X8 --> X19 26 | 22. X10 --> X12 27 | 23. X10 --> X16 28 | 24. X11 --> X14 29 | 25. X11 --> X15 30 | 26. X12 --> X15 31 | 27. X15 --> X19 32 | 28. X16 --> X20 33 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/pdag.28.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X1 --- X4 6 | 2. X1 --- X8 7 | 3. X1 --> X13 8 | 4. X1 --> X20 9 | 5. X2 --> X13 10 | 6. X3 --> X6 11 | 7. X3 --> X9 12 | 8. X3 --> X10 13 | 9. X13 --> X3 14 | 10. X4 --- X5 15 | 11. X4 --> X6 16 | 12. X4 --> X14 17 | 13. X4 --> X15 18 | 14. X5 --> X7 19 | 15. X5 --> X20 20 | 16. X6 --> X7 21 | 17. X7 --> X14 22 | 18. X8 --> X9 23 | 19. X8 --> X14 24 | 20. X8 --- X18 25 | 21. X8 --> X19 26 | 22. X10 --> X12 27 | 23. X10 --> X16 28 | 24. X11 --> X14 29 | 25. X11 --> X15 30 | 26. X12 --> X15 31 | 27. X15 --> X19 32 | 28. 
from collections import namedtuple

from .mark import Mark


class Edge(namedtuple('Edge', ['node_u', 'node_v', 'mark_u', 'mark_v'])):
    """An edge between ``node_u`` and ``node_v`` with a mark at each endpoint.

    Marks default to a tail at ``node_u`` and an arrowhead at ``node_v``,
    i.e. the plain directed edge ``u --> v``.
    """
    __slots__ = ()

    def __new__(cls, node_u, node_v, mark_u=Mark.Tail, mark_v=Mark.Arrow):
        return super().__new__(cls, node_u, node_v, mark_u, mark_v)

    def __str__(self):
        # Render e.g. 'u --> v', 'u o-o v' or 'u <-> v' from the two marks.
        left = _lmark2ascii[self.mark_u]
        right = _rmark2ascii[self.mark_v]
        return f'{self.node_u} {left}-{right} {self.node_v}'


# ASCII glyph used for the left / right endpoint when printing an edge.
_lmark2ascii = {
    Mark.Tail: '-',
    Mark.Arrow: '<',
    Mark.Circle: 'o'
}
_rmark2ascii = {
    Mark.Tail: '-',
    Mark.Arrow: '>',
    Mark.Circle: 'o'
}
from numpy import ndarray
from pandas import DataFrame


def data_form_converter_for_class_method(func):
    """Decorator for class methods with signature ``(self, data, var_names=None, ...)``.

    Normalises the inputs before calling the wrapped method:

    * ``DataFrame`` input with ``var_names=None``: the column labels become
      ``var_names`` and the frame is converted via ``to_numpy()``.
    * ``ndarray`` input with ``var_names=None``: ``var_names`` defaults to
      ``[0, 1, ..., n-1]`` where ``n`` is the number of columns.
    * Explicit ``var_names`` must have one entry per column.
    """
    from functools import wraps

    @wraps(func)  # preserve the wrapped method's name/docstring for introspection
    def wrapper(self, data, var_names=None, *args, **kwargs):
        assert isinstance(data, (DataFrame, ndarray))
        n = data.shape[1]
        if var_names is None:
            if isinstance(data, DataFrame):
                var_names = list(data.columns)
                data = data.to_numpy()
            else:
                var_names = list(range(n))
        else:
            assert len(var_names) == n
        # Call positionally: the previous keyword form func(self=..., data=...)
        # raised TypeError whenever extra positional *args were supplied.
        return func(self, data, var_names, *args, **kwargs)

    return wrapper
X1 --> X3 6 | 2. X1 --> X4 7 | 3. X1 --> X8 8 | 4. X1 --> X13 9 | 5. X1 --> X20 10 | 6. X2 --> X3 11 | 7. X2 --> X13 12 | 8. X3 --> X6 13 | 9. X3 --> X9 14 | 10. X3 --> X10 15 | 11. X13 --> X3 16 | 12. X4 --> X5 17 | 13. X4 --> X6 18 | 14. X4 --> X14 19 | 15. X4 --> X15 20 | 16. X5 --> X7 21 | 17. X5 --> X20 22 | 18. X6 --> X7 23 | 19. X7 --> X14 24 | 20. X8 --> X9 25 | 21. X8 --> X14 26 | 22. X8 --> X18 27 | 23. X8 --> X19 28 | 24. X10 --> X12 29 | 25. X10 --> X16 30 | 26. X11 --> X14 31 | 27. X11 --> X15 32 | 28. X12 --> X15 33 | 29. X15 --> X19 34 | 30. X16 --> X20 35 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/pdag.30.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X1 --> X3 6 | 2. X1 --- X4 7 | 3. X1 --- X8 8 | 4. X1 --> X13 9 | 5. X1 --> X20 10 | 6. X2 --> X3 11 | 7. X2 --> X13 12 | 8. X3 --> X6 13 | 9. X3 --> X9 14 | 10. X3 --> X10 15 | 11. X13 --> X3 16 | 12. X4 --- X5 17 | 13. X4 --> X6 18 | 14. X4 --> X14 19 | 15. X4 --> X15 20 | 16. X5 --> X7 21 | 17. X5 --> X20 22 | 18. X6 --> X7 23 | 19. X7 --> X14 24 | 20. X8 --> X9 25 | 21. X8 --> X14 26 | 22. X8 --- X18 27 | 23. X8 --> X19 28 | 24. X10 --> X12 29 | 25. X10 --> X16 30 | 26. X11 --> X14 31 | 27. X11 --> X15 32 | 28. X12 --> X15 33 | 29. X15 --> X19 34 | 30. X16 --> X20 35 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 
6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/dag.31.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X1 --> X3 6 | 2. X1 --> X4 7 | 3. X1 --> X8 8 | 4. X1 --> X13 9 | 5. X1 --> X20 10 | 6. X2 --> X3 11 | 7. X2 --> X13 12 | 8. X2 --> X20 13 | 9. X3 --> X6 14 | 10. X3 --> X9 15 | 11. X3 --> X10 16 | 12. X3 --> X13 17 | 13. X4 --> X5 18 | 14. X4 --> X6 19 | 15. X4 --> X14 20 | 16. X4 --> X15 21 | 17. X5 --> X7 22 | 18. X5 --> X20 23 | 19. X6 --> X7 24 | 20. X7 --> X14 25 | 21. X8 --> X9 26 | 22. X8 --> X14 27 | 23. X8 --> X18 28 | 24. X8 --> X19 29 | 25. X10 --> X12 30 | 26. X10 --> X16 31 | 27. X11 --> X14 32 | 28. X11 --> X15 33 | 29. X12 --> X15 34 | 30. X15 --> X19 35 | 31. X16 --> X20 36 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/pdag.31.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X1 --> X3 6 | 2. X1 --- X4 7 | 3. X1 --- X8 8 | 4. X1 --> X13 9 | 5. X1 --> X20 10 | 6. X2 --> X3 11 | 7. X2 --> X13 12 | 8. X2 --> X20 13 | 9. X3 --> X6 14 | 10. X3 --> X9 15 | 11. X3 --> X10 16 | 12. 
X3 --- X13 17 | 13. X4 --- X5 18 | 14. X4 --> X6 19 | 15. X4 --> X14 20 | 16. X4 --> X15 21 | 17. X5 --> X7 22 | 18. X5 --> X20 23 | 19. X6 --> X7 24 | 20. X7 --> X14 25 | 21. X8 --> X9 26 | 22. X8 --> X14 27 | 23. X8 --- X18 28 | 24. X8 --> X19 29 | 25. X10 --> X12 30 | 26. X10 --> X16 31 | 27. X11 --> X14 32 | 28. X11 --> X15 33 | 29. X12 --> X15 34 | 30. X15 --> X19 35 | 31. X16 --> X20 36 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/dag.32.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X1 --> X3 6 | 2. X1 --> X4 7 | 3. X1 --> X8 8 | 4. X1 --> X13 9 | 5. X1 --> X20 10 | 6. X2 --> X3 11 | 7. X2 --> X13 12 | 8. X2 --> X20 13 | 9. X3 --> X6 14 | 10. X3 --> X9 15 | 11. X3 --> X10 16 | 12. X3 --> X13 17 | 13. X3 --> X20 18 | 14. X4 --> X5 19 | 15. X4 --> X6 20 | 16. X4 --> X14 21 | 17. X4 --> X15 22 | 18. X5 --> X7 23 | 19. X5 --> X20 24 | 20. X6 --> X7 25 | 21. X7 --> X14 26 | 22. X8 --> X9 27 | 23. X8 --> X14 28 | 24. X8 --> X18 29 | 25. X8 --> X19 30 | 26. X10 --> X12 31 | 27. X10 --> X16 32 | 28. X11 --> X14 33 | 29. X11 --> X15 34 | 30. X12 --> X15 35 | 31. X15 --> X19 36 | 32. X16 --> X20 37 | -------------------------------------------------------------------------------- /cdmir/tests/testdata/graph_data/pdag.32.txt: -------------------------------------------------------------------------------- 1 | Graph Nodes: 2 | X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15;X16;X17;X18;X19;X20 3 | 4 | Graph Edges: 5 | 1. X1 --> X3 6 | 2. X1 --- X4 7 | 3. X1 --- X8 8 | 4. X1 --> X13 9 | 5. X1 --> X20 10 | 6. X2 --> X3 11 | 7. X2 --> X13 12 | 8. X2 --> X20 13 | 9. X3 --> X6 14 | 10. X3 --> X9 15 | 11. X3 --> X10 16 | 12. X3 --- X13 17 | 13. X3 --> X20 18 | 14. X4 --- X5 19 | 15. X4 --> X6 20 | 16. X4 --> X14 21 | 17. X4 --> X15 22 | 18. 
from typing import Iterable

from cdmir.graph import DiGraph
from cdmir.utils.independence.basic_independence import ConditionalIndependentTest


class Dsep(ConditionalIndependentTest):
    """Oracle 'independence test' that answers from a known ground-truth DAG.

    Useful for testing constraint-based algorithms without sampling noise:
    the pseudo p-value is 1.0 when x and y are d-separated given z in
    ``true_graph`` (i.e. conditionally independent) and 0.0 otherwise.
    """

    def __init__(self, data, var_names=None, true_graph: DiGraph = None):
        super().__init__(data, var_names=var_names)
        assert true_graph is not None
        self.true_graph = true_graph

    def cal_stats(self, x_id: int, y_id: int, z_ids: Iterable[int] = None):
        """Return ``(p_value, stat)`` for x ⫫ y | z decided by d-separation.

        ``z_ids`` may be None (matching ``FisherZ.cal_stats``), which is
        treated as the empty conditioning set; previously this crashed.
        """
        if z_ids is None:
            z_ids = []  # be consistent with FisherZ, whose callers pass None
        zs_name = [self.var_names[z_id] for z_id in z_ids]
        if self.true_graph.is_d_separate(self.var_names[x_id], self.var_names[y_id], zs_name):
            return 1.0, 1.0  # d-separated -> independent (high pseudo p-value)
        else:
            return 0.0, 0.0  # d-connected -> dependent (low pseudo p-value)
import numpy as np

from cdmir.utils.kernel.gaussian import GaussianKernel

#arr1 = np.array([[1, 2, 3, 4], [1, 3, 4, 5]])
#arr2 = np.array([[1, 2, 3, 4], [3, 4, 5, 6]])

def test_gaussian_case1(): #shape(x)[1]==shape(y)[1]
    """Smoke test: evaluate the Gaussian kernel on two matrices that share
    the feature dimension (20). Only checks the call completes; there are
    no assertions on the returned Gram matrix.
    """
    np.random.seed(10)
    xs=np.random.randn(20,20)
    ys=np.random.randn(30,20)

    gk=GaussianKernel()
    gk(xs,ys)
    print(gk.__call__(xs,ys))
    # print(gk(xs,ys))


def test_gaussian_case2():#y is none
    """Smoke test: kernel of xs with itself (ys=None)."""
    np.random.seed(10)
    xs = np.random.randn(20, 20)
    #ys = np.array([])

    gk = GaussianKernel()
    gk(xs,ys=None)
    print(gk.__call__(xs,ys=None))
    # print(gk(xs, ys=None))

# NOTE(review): this executes at import time as well as under pytest —
# presumably leftover debugging; confirm before removing.
test_gaussian_case1()
# test_gaussian_case2()
import numpy as np
from pandas import DataFrame
from cdmir.utils.local_score.marginal_base import GeneralMarginalScore

def test_score_function():
    """GeneralMarginalScore returns a float score on a small data set."""
    # Tiny Iris-like sample: 7 rows x 4 feature columns.
    data = np.array([[5.1, 3.5, 1.4, 0.2],
                     [4.9, 3, 1.4, 0.2],
                     [4.7, 3.2, 1.3, 0.2],
                     [4.6, 3.1, 1.5, 0.2],
                     [5.4, 3.9, 1.7, 0.4],
                     [4.9, 3.1, 1.5, 0.1],
                     [5.8, 4, 1.2, 0.2]])
    data_frame = DataFrame(data, columns=[f'Feature_{i}' for i in range(1, 5)])

    generalMarginalscore = GeneralMarginalScore(data_frame)

    # Indices of the variable under test and of its parent variables.
    variable_index = 2
    parent_indices = [0, 1]
    score = generalMarginalscore(variable_index, parent_indices)

    assert isinstance(score, float)
    assert score is not None, "GeneralMarginalScore calculation failed."
-------------------------------------------------------------------------------- /cdmir/discovery/funtional_based/LearningHierarchicalStructure/indTest/TestObject.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | from scipy.stats import norm as normaldist 4 | 5 | 6 | class TestObject(object): 7 | def __init__(self, test_type, streaming=False, freeze_data=False): 8 | self.test_type = test_type 9 | self.streaming = streaming 10 | self.freeze_data = freeze_data 11 | if self.freeze_data: 12 | self.generate_data() 13 | assert not self.streaming 14 | 15 | @abstractmethod 16 | def compute_Zscore(self): 17 | raise NotImplementedError 18 | 19 | @abstractmethod 20 | def generate_data(self): 21 | raise NotImplementedError 22 | 23 | def compute_pvalue(self): 24 | Z_score = self.compute_Zscore() 25 | pvalue = normaldist.sf(Z_score) 26 | return pvalue 27 | 28 | def perform_test(self, alpha): 29 | pvalue = self.compute_pvalue() 30 | return pvalue < alpha 31 | -------------------------------------------------------------------------------- /cdmir/discovery/funtional_based/LearningHierarchicalStructure/indTest/independence.py: -------------------------------------------------------------------------------- 1 | import cdmir.discovery.funtional_based.LearningHierarchicalStructure.mutual as MI 2 | import numpy as np 3 | import pandas as pd 4 | from sklearn import metrics 5 | 6 | 7 | # estimating mutual information by sklearn 8 | def independent11(x1, y1): 9 | x = x1.copy() 10 | y = y1.copy() 11 | length = len(x) 12 | x = list(x) 13 | y = list(y) 14 | 15 | result_NMI = metrics.normalized_mutual_info_score(x, y) 16 | print(result_NMI) 17 | return result_NMI 18 | 19 | 20 | # estimating mutual information by Non-parametric computation of entropy and mutual-information 21 | def independent(x1, y1): 22 | x = x1.copy() 23 | y = y1.copy() 24 | length = len(x) 25 | x = x.reshape(length, 1) 26 | y = y.reshape(length, 
import numpy as np
from pandas import DataFrame
from cdmir.utils.local_score.cross_validated_base import GeneralCVScore

def test_score_function():
    """GeneralCVScore returns a float score on a small data set."""
    # Tiny Iris-like sample: 7 rows x 4 feature columns.
    data = np.array([[5.1, 3.5, 1.4, 0.2],
                     [4.9, 3, 1.4, 0.2],
                     [4.7, 3.2, 1.3, 0.2],
                     [4.6, 3.1, 1.5, 0.2],
                     [5.4, 3.9, 1.7, 0.4],
                     [4.9, 3.1, 1.5, 0.1],
                     [5.8, 4, 1.2, 0.2]])
    data_frame = DataFrame(data, columns=[f'Feature_{i}' for i in range(1, 5)])

    # Regularisation weight and number of cross-validation folds.
    lambda_value = 0.01
    k_fold = 10
    general_cv_score = GeneralCVScore(data_frame, lambda_value, k_fold)

    # Indices of the variable under test and of its parent variables.
    variable_index = 2
    parent_indices = [0, 1]
    score = general_cv_score(variable_index, parent_indices)

    assert isinstance(score, float)
    assert score is not None, "GeneralCVScore calculation failed."
import numpy as np
from pandas import DataFrame
from cdmir.utils.local_score.cross_validated_base import MultiCVScore

def test_score_function():
    """MultiCVScore returns a float score on a small data set."""
    # Tiny Iris-like sample: 7 rows x 4 feature columns.
    data = np.array([[5.1, 3.5, 1.4, 0.2],
                     [4.9, 3, 1.4, 0.2],
                     [4.7, 3.2, 1.3, 0.2],
                     [4.6, 3.1, 1.5, 0.2],
                     [5.4, 3.9, 1.7, 0.4],
                     [4.9, 3.1, 1.5, 0.1],
                     [5.8, 4, 1.2, 0.2]])
    data_frame = DataFrame(data, columns=[f'Feature_{i}' for i in range(1, 5)])

    # Regularisation weight, number of CV folds, and a label mapping
    # (presumably variable index -> dimension label — TODO confirm against
    # MultiCVScore's documentation).
    lambda_value = 0.01
    k_fold = 10
    d_label = {0: 0, 1: 1, 2: 2}
    multi_cv_score = MultiCVScore(data_frame, lambda_value, k_fold, d_label)

    # Indices of the variable under test and of its parent variables.
    variable_index = 2
    parent_indices = [0, 1]
    score = multi_cv_score(variable_index, parent_indices)

    assert isinstance(score, float)
    assert score is not None, "MultiCVScore calculation failed."
from itertools import combinations, permutations

from cdmir.graph import DiGraph, PDAG


def dag2cpdag(dag: DiGraph):
    """Convert a DAG into the PDAG representing its Markov equivalence class.

    Strategy: start from the complete undirected graph over the DAG's nodes,
    delete every edge whose endpoints are d-separated by some conditioning
    set (recording separators in ``sep_set``), then orient v-structures via
    ``rule0`` and propagate orientations with the Meek rules.

    Parameters
    ----------
    dag : DiGraph
        Ground-truth directed acyclic graph (used as a d-separation oracle).

    Returns
    -------
    PDAG
        Partially directed graph after rule0 + Meek orientation.
    """
    n = dag.number_of_nodes()
    pdag = PDAG(list(dag.nodes))
    pdag.create_complete_undirected_graph()
    # sep_set is keyed on ordered pairs so both (u, v) and (v, u) exist.
    sep_set = {(node_u, node_v): set() for node_u, node_v in permutations(pdag.nodes, 2)}

    for node_u, node_v in combinations(pdag.nodes, 2):
        for condition_size in range(0, n - 1):
            for nodes_z in combinations(pdag.node_set - {node_u, node_v}, condition_size):
                if dag.is_d_separate(node_u, node_v, nodes_z):
                    # NOTE(review): separators from *all* separating sets are
                    # unioned here, and the search keeps running after the
                    # edge is removed — confirm rule0 only tests membership.
                    sep_set[(node_u, node_v)] |= set(nodes_z)
                    sep_set[(node_v, node_u)] |= set(nodes_z)
                    if pdag.is_connected(node_u, node_v):
                        pdag.remove_edge(node_u, node_v)

    # Orient v-structures using the separating sets, then apply Meek's rules.
    pdag.rule0(sep_set=sep_set, verbose=True)

    pdag.orient_by_meek_rules(verbose=True)

    return pdag
from __future__ import annotations

from typing import Iterable

from numpy import ndarray
from pandas import DataFrame


class BaseLocalScoreFunction(object):
    """Base class for local score functions used by score-based search.

    Stores the data matrix as a plain ndarray and memoises scores per
    (variable, parent set) so repeated queries during structure search
    are answered from the cache.
    """

    def __init__(self, data: ndarray | DataFrame, *args, **kwargs):
        # Accept either a numpy array or a pandas DataFrame.
        if isinstance(data, ndarray):
            self.data = data
        else:
            self.data = data.values
        # cache_dict: (i, tuple(parent_i)) -> score
        self.cache_dict = dict()

    def _score(self, i: int, parent_i: Iterable[int], score_function):
        """Memoised dispatch to ``score_function(i, parent_i)``.

        The iterable is materialised once so one-shot iterators are not
        consumed before scoring, and the tuple itself is the cache key:
        the previous key ``hash(str(...))`` could alias two distinct
        parent sets whose string hashes collide, silently returning the
        wrong cached score.
        """
        parents = tuple(parent_i)
        cache_key = (i, parents)
        if cache_key not in self.cache_dict:
            self.cache_dict[cache_key] = score_function(i, parents)
        return self.cache_dict[cache_key]

    def _score_function(self, i: int, parent_i: Iterable[int]):
        """Compute the local score of node i given its parents; subclass hook."""
        raise NotImplementedError()

    def __call__(self, i: int, parent_i: Iterable[int], *args, **kwargs):
        """Score node ``i`` with parent set ``parent_i`` (cached)."""
        return self._score(i, parent_i, self._score_function)
note:: 11 | 12 | CDMIR is under active development. For source code and usage examples, 13 | please refer to our `GitHub repository `_. 14 | 15 | Contents 16 | ======== 17 | 18 | .. toctree:: 19 | :maxdepth: 1 20 | :caption: Getting started 21 | 22 | getting_started 23 | 24 | .. toctree:: 25 | :maxdepth: 1 26 | :caption: Discovery methods 27 | 28 | discovery_methods/index 29 | 30 | .. toctree:: 31 | :maxdepth: 1 32 | :caption: Effect methods 33 | 34 | effect_methods/index 35 | 36 | .. toctree:: 37 | :maxdepth: 1 38 | :caption: Utilities 39 | 40 | utilities_index/index 41 | -------------------------------------------------------------------------------- /docs/source/discovery_methods/functional_based/ANM/anm.rst: -------------------------------------------------------------------------------- 1 | ANM (Additive Noise Model) 2 | ============================ 3 | 4 | Introduction 5 | ------------ 6 | 7 | ANM is a functional-based causal discovery method that identifies causal relationships by testing the independence between the cause and the noise term in an additive noise model. 8 | 9 | Usage 10 | ----- 11 | 12 | .. code-block:: python 13 | 14 | from cdmir.discovery.funtional_based.anm.ANM import ANM 15 | # Initialize ANM model 16 | anm = ANM() 17 | # Test causal direction 18 | nonindepscore_forward, nonindepscore_backward = anm.cause_or_effect(X, Y) 19 | 20 | 21 | Parameters 22 | ---------- 23 | 24 | - **x**: Input data array of shape (n,) or (n, 1) 25 | - **y**: Output data array of shape (n,) or (n, 1) 26 | 27 | Returns 28 | ------- 29 | 30 | - **nonindepscore_forward**: HSIC statistic in the X→Y direction 31 | - **nonindepscore_backward**: HSIC statistic in the Y→X direction 32 | 33 | 34 | References 35 | ---------- 36 | 37 | .. [1] Hoyer, Patrik O., et al. "Nonlinear causal discovery with additive noise models." Advances in Neural Information Processing Systems. 2008. 
38 | -------------------------------------------------------------------------------- /cdmir/discovery/funtional_based/LearningHierarchicalStructure/README.md: -------------------------------------------------------------------------------- 1 | ## Some General Identification Results for Linear Latent Hierarchical Causal Structure -IJCAI2023 2 | 3 | ## Overview 4 | This project learns the causal structure of the linear latent hierarchical model, including both the relations among latent variables and those between latent and observed variables. 5 | 6 | 7 | 8 | ## Main Function 9 | 10 | Causal_Discovery_in_LHM.py : Causal_Discovery_LHM(data, alpha=0.01) 11 | 12 | Input: 13 | data: DataFrame (pandas) 14 | the observational data set 15 | alpha: float 16 | the signification level of independence 17 | 18 | Output: 19 | LatentIndex: dic 20 | the relations between each latent and their direct measured set 21 | Graph (selected) 22 | the Causal graph of hierarchical structure 23 | 24 | 25 | 26 | One may use the "test_LHS.py" to test our method, in which a latent tree structure is simulated. 27 | 28 | 29 | ## Notes 30 | Our method relies heavily on independence tests. One may carefully adjust some parameters, like kernel width, in the kerpy.GaussianKernel, to ensure accuracy. 
import sys

from setuptools import find_packages, setup

if sys.version_info < (3, 6):
    sys.exit("Sorry, Python < 3.6 is not supported.")

# Best-effort long description; fall back to empty when README is absent.
try:
    long_description = open("README.md").read()
except IOError:
    long_description = ""

setup(
    name="CDMIR",
    version="0.1.1",
    description="A pip package",
    license="GPL",
    author="DMIRLab",
    packages=find_packages(exclude=["tests", "tests.*"]),
    install_requires=[
        'numpy',
        'pandas',
        'scipy>=1.7.3',
        'scikit-learn',
        'torch>=1.7.1',
        'networkx',
        'matplotlib',
    ],
    # Fixed: the setuptools keyword is 'extras_require'; the previous
    # 'extra_require' was silently ignored, so `pip install CDMIR[hawkes]`
    # never installed anything.
    extras_require={
        "hawkes": [
            "tick; python_version < '3.12'"
        ]
    },
    long_description=long_description,
    long_description_content_type='text/markdown',
    classifiers=[
        'Operating System :: OS Independent',
        "Programming Language :: Python",
        "Programming Language :: Python :: 3.7",
        "Programming Language :: Python :: 3.8",
    ],
    python_requires='>=3.6',
)
# Configuration file for the Sphinx documentation builder.
#
# For the full list of built-in configuration values, see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html

# -- Project information -----------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
import os
import sys
# NOTE(review): this adds docs/ to sys.path; for autodoc to import the
# `cdmir` package the repository root ('../..') is usually needed instead —
# confirm against how the docs build resolves imports.
sys.path.insert(0, os.path.abspath('..'))

project = 'CDMIR'
copyright = '2025, Wei Chen'
author = 'Wei Chen'
release = '0.1.0'

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration

extensions = [
    "sphinx.ext.autodoc",
    "sphinx.ext.napoleon",
    "sphinx.ext.viewcode",
]

templates_path = ['_templates']
exclude_patterns = []



# -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output

# Requires the sphinx_rtd_theme package at build time.
html_theme = "sphinx_rtd_theme"
html_static_path = ['_static']
import numpy as np

from cdmir.utils.independence.kernel_based import KCI
from cdmir.utils.kernel import GaussianKernel


def test_kci():
    """Smoke test: KCI runs on linear-Gaussian data and yields a valid p-value.

    Previously this module executed the KCI call at import time with no test
    function and no assertions (the intended test was commented out), so
    pytest collected nothing and the work ran on every import.
    """
    np.random.seed(10)
    X = np.random.randn(300, 1)
    X_prime = np.random.randn(300, 1)  # independent of X, Y and Z
    Y = X + 0.5 * np.random.randn(300, 1)
    Z = Y + 0.5 * np.random.randn(300, 1)
    data = np.hstack((X, X_prime, Y, Z))

    kernel = GaussianKernel(width_strategy=GaussianKernel.WidthStrategyEnum.empirical_kci)
    kci = KCI(data=data, kernel_x=kernel, kernel_y=kernel)
    # X ⫫ Z | Y holds in this generative model; at minimum the returned
    # value must be a valid probability.
    p_value, stat = kci(0, 3, 2)
    assert 0.0 <= p_value <= 1.0
*args, **kwargs): 15 | super().__init__(data, *args, **kwargs) 16 | self.cov = corrcoef(data.T) 17 | self.sample_count = shape(data)[0] 18 | if not kwargs.__contains__('lambda_value'): 19 | self.lambda_value = 1 20 | else: 21 | self.lambda_value = kwargs["lambda_value"] 22 | 23 | def _score_function(self, i: int, parent_i: Iterable[int]): 24 | parent_i = list(parent_i) 25 | 26 | if len(parent_i) == 0: 27 | return self.sample_count * log(self.cov[i, i]) 28 | 29 | yX = mat(self.cov[ix_([i], parent_i)]) 30 | XX = mat(self.cov[ix_(parent_i, parent_i)]) 31 | H = log(self.cov[i, i] - yX * inv(XX) * yX.T) 32 | 33 | return -(self.sample_count * H + log(self.sample_count) * len(parent_i) * self.lambda_value).item() 34 | 35 | def __call__(self, i: int, parent_i: Iterable[int], *args, **kwargs): 36 | return self._score(i, parent_i, self._score_function) 37 | 38 | -------------------------------------------------------------------------------- /cdmir/tests/test_desp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from cdmir.discovery.constraint.pc import PC 4 | from cdmir.graph.digraph import DiGraph 5 | from cdmir.graph.edge import Edge 6 | from cdmir.utils.independence import Dsep, FisherZ 7 | 8 | def test_desp(): 9 | np.random.seed(10) 10 | X = np.random.randn(300, 1) # 生成 300*1 的随机矩阵 11 | X_prime = np.random.randn(300, 1) #生成和X不一样的300*1的随机数矩阵 12 | Y = X + 0.5 * np.random.randn(300, 1) 13 | Z = Y + 0.5 * np.random.randn(300, 1) 14 | data =np.hstack((X,X_prime,Y,Z)) #将数组按水平方向堆叠起来 15 | 16 | cg=PC() 17 | cg.fit(data,indep_cls=FisherZ) 18 | print(cg.causal_graph) 19 | # print(type(cg.causal_graph)) 20 | node_list=list(range(data.shape[1])) 21 | 22 | dag=DiGraph(range(len(node_list))) 23 | edges=[(0, 2),(2, 3)] 24 | for edge in edges: 25 | dag.add_edge(Edge(*edge)) #(*)以元组的形式传入参数 26 | 27 | d = Dsep(data=data,true_graph=dag) 28 | 29 | # 0 && 2, 2 && 3 dependent 30 | a,b = d.cal_stats(0,2,[3]) 31 | print(a,b) 32 | 
if a==0.0 and b==0.0: 33 | print("a and b dependent") 34 | elif a==1.0 and b==1.0: 35 | print("a and b independent") 36 | #assert a==0.0 and b==0.0 37 | 38 | a, b = d.cal_stats(0, 3, [1,2]) 39 | print(a,b) 40 | if a==1.0 and b==1.0 : 41 | print("a and b independent") 42 | #assert a==1.0 and b==1.0 43 | 44 | test_desp() -------------------------------------------------------------------------------- /cdmir/utils/independence/functional/fisherz.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable 2 | 3 | from numpy import array, concatenate, corrcoef, ix_, log1p, ndarray, reshape, sqrt 4 | from numpy.linalg import inv 5 | from scipy.stats import norm 6 | 7 | 8 | def fisherz(x, y, Z=None): 9 | if Z is None: 10 | data = array([x, y]).T 11 | else: 12 | x = reshape(x, (-1, 1)) if x.ndim == 1 else x 13 | y = reshape(y, (-1, 1)) if y.ndim == 1 else y 14 | Z = reshape(Z, (-1, 1)) if Z.ndim == 1 else Z 15 | data = concatenate((x, y, Z), axis=1) 16 | corr = corrcoef(data, rowvar=False) 17 | num_records = data.shape[0] 18 | if Z is None: 19 | return fisherz_from_corr(corr, num_records, 0, 1) 20 | else: 21 | return fisherz_from_corr(corr, num_records, 0, 1, list(range(2, 2 + Z.shape[1]))) 22 | 23 | 24 | def fisherz_from_corr(corr: ndarray, num_records: int, x_id: int, y_id: int, z_ids: Iterable[int] = None): 25 | z_ids = [] if z_ids is None else z_ids 26 | var = [x_id, y_id] + z_ids 27 | sub_corr = corr[ix_(var, var)] 28 | inv_mat = inv(sub_corr) 29 | stats = -inv_mat[0, 1] / sqrt(abs(inv_mat[0, 0] * inv_mat[1, 1])) 30 | abs_stats = min(0.9999999, abs(stats)) 31 | z = 1 / 2 * log1p(2 * abs_stats / (1 - abs_stats)) 32 | X = sqrt(num_records - len(z_ids) - 3) * abs(z) 33 | pval = 2 * (1 - norm.cdf(abs(X))) 34 | return pval, stats 35 | -------------------------------------------------------------------------------- /.github/workflows/python-package.yml: 
-------------------------------------------------------------------------------- 1 | name: Python package 2 | 3 | on: 4 | push: 5 | branches: [ main, development ] 6 | pull_request: 7 | branches: [ main, development ] 8 | 9 | jobs: 10 | build: 11 | 12 | runs-on: ${{ matrix.os }} 13 | strategy: 14 | fail-fast: false 15 | matrix: 16 | python-version: [ "3.7", "3.8" ] 17 | os: [ ubuntu-latest ] 18 | steps: 19 | - uses: actions/checkout@v4 20 | - name: Set up Python ${{ matrix.python-version }} 21 | uses: actions/setup-python@v5 22 | with: 23 | python-version: ${{ matrix.python-version }} 24 | - name: Setup SWIG 25 | uses: mmomtchev/setup-swig@v3 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | python -m pip install flake8 pytest 30 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 31 | - name: Lint with flake8 32 | run: | 33 | # stop the build if there are Python syntax errors or undefined names 34 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 35 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 36 | flake8 . 
--count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 37 | - name: Test with pytest 38 | run: | 39 | python -m pytest -------------------------------------------------------------------------------- /cdmir/tests/test_hawkes_simulator.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | from cdmir.datasets.simlulators import HawkesSimulator 7 | from cdmir.datasets.utils import erdos_renyi, generate_lag_transitions 8 | 9 | 10 | class TestHawkesSimulator(TestCase): 11 | 12 | def test_INSEM_data(self): 13 | sample_size = 10000 14 | lambda_x = 1 15 | theta = 0.5 16 | lambda_e = 1 17 | seed = 42 18 | insem_data = HawkesSimulator.INSEM_data(sample_size, lambda_x, theta, lambda_e, seed) 19 | assert isinstance(insem_data, pd.DataFrame) 20 | assert insem_data.shape[1] == 3 21 | 22 | def test_generate_data(self): 23 | mu_range_str = '0,1' 24 | alpha_range_str = '0,1' 25 | n = 3 26 | sample_size = 30000 27 | out_degree_rate = 1.5 28 | NE_num = 40 29 | decay = 0.1 30 | seed = 42 31 | event_table, edge_mat, alpha, mu, events = HawkesSimulator.generate_data( 32 | n, mu_range_str, alpha_range_str, sample_size, out_degree_rate, NE_num, decay, seed) 33 | 34 | assert isinstance(event_table, pd.DataFrame) 35 | assert event_table.shape[1] == 3 # Three columns: seq_id, time_stamp, event_type 36 | assert edge_mat.shape == (n, n) 37 | assert alpha.shape == (n, n) 38 | assert mu.shape[0] == n 39 | assert len(events) == NE_num 40 | -------------------------------------------------------------------------------- /docs/source/discovery_methods/constraint/PBSCM/pbscm.rst: -------------------------------------------------------------------------------- 1 | PBSCM (Poisson Branching Structural Causal Model) 2 | ============================================= 3 | 4 | Introduction 5 | ------------ 6 | 7 | PBSCM is a functional-based causal discovery method 
that identifies causal relationships using high-order cumulants with path analysis. It assumes that the causal relationships can be modeled as a Poisson branching process. 8 | 9 | 10 | Usage 11 | ----- 12 | 13 | .. code-block:: python 14 | 15 | 16 | from cdmir.discovery.funtional_based.PBSCM.PB_SCM import PB_SCM 17 | 18 | # Initialize PBSCM model 19 | pbscm = PB_SCM(data) 20 | 21 | # Get causal graph 22 | causal_graph = pbscm.get_causal_graph(alpha=0.04, max_order=4, threshold=0) 23 | 24 | 25 | Parameters 26 | ---------- 27 | 28 | **PB_SCM Class Initialization** 29 | 30 | - **data**: Input data matrix (n_samples x n_variables) 31 | 32 | **get_causal_graph** 33 | 34 | 35 | - **alpha**: Confidence level (default: 0.04) 36 | - **max_order**: The maximum order of Lambda_k (default: 4) 37 | - **threshold**: Threshold when bootstrap test fails (default: 0) 38 | 39 | Returns 40 | ------- 41 | 42 | - **causal_graph**: Causal matrix of the data 43 | 44 | References 45 | ---------- 46 | 47 | .. [1] Qiao J, Xiang Y, Chen Z, et al. Causal discovery from poisson branching structural causal model using high-order cumulant with path analysis[C]//Proceedings of the AAAI Conference on Artificial Intelligence. 2024, 38(18): 20524-20531. 48 | -------------------------------------------------------------------------------- /docs/source/discovery_methods/constraint/PBSCM_PGF/pbscm_pgf.rst: -------------------------------------------------------------------------------- 1 | PBSCM_PGF (Poisson Branching Structural Causal Model using Probability Generating Function) 2 | ========================================================================================= 3 | 4 | Introduction 5 | ------------ 6 | 7 | PBSCM_PGF is a causal discovery method that identifies causal relationships using Probability Generating Functions (PGF). It extends the Poisson Branching Structural Causal Model (PBSCM) by leveraging PGFs to analyze the structure of causal relationships in data. 
8 | 9 | 10 | Usage 11 | ----- 12 | 13 | .. code-block:: python 14 | 15 | from cdmir.effect.PBSCM_PGF.PB_SCM_PGF import PBSCM_PGF 16 | 17 | # Initialize PBSCM_PGF model 18 | pbscm_pgf = PBSCM_PGF(data) 19 | 20 | # Learn the causal graph 21 | causal_graph = pbscm_pgf.get_causal_graph() 22 | 23 | 24 | 25 | 26 | Parameters 27 | ---------- 28 | 29 | **PBSCM_PGF Class Initialization** 30 | 31 | - **data**: Input data matrix with shape (n_samples, n_features) 32 | 33 | 34 | Returns 35 | ------- 36 | 37 | - **causal_graph**: Causal matrix of the data, where causal_graph[i][j] = 1 indicates a causal relationship from variable i to variable j, and 0 otherwise 38 | 39 | 40 | References 41 | ---------- 42 | 43 | .. [1] Xiang Y, Qiao J, Liang Z, et al. On the identifiability of poisson branching structural causal model using probability generating function[J]. Advances in Neural Information Processing Systems, 2024, 37: 11664-11699. 44 | -------------------------------------------------------------------------------- /cdmir/utils/kernel/_base.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from numpy import eye, ndarray, shape 4 | from numpy.linalg import pinv 5 | 6 | 7 | class BaseKernel(object): 8 | 9 | def __init__(self, *args, **kwargs): 10 | self.cache_dict = dict() 11 | 12 | def __call__(self, xs: ndarray, ys: ndarray, *args, **kwargs): 13 | return self.__kernel(xs, ys, self.__kernel_func) 14 | 15 | def __kernel(self, xs: ndarray, ys: ndarray, kernel_func, *args, **kwargs): 16 | dict_key = hash(str((xs, ys))) 17 | if self.cache_dict.__contains__(dict_key): 18 | res = self.cache_dict[dict_key] 19 | else: 20 | res = kernel_func(xs, ys) 21 | self.cache_dict[dict_key] = res 22 | 23 | return res 24 | 25 | def __kernel_func(self, data_x: ndarray, data_y: ndarray): 26 | raise NotImplementedError() 27 | 28 | @staticmethod 29 | def center_kernel_matrix(K: ndarray): 30 | n = shape(K)[0] 31 | 
K_colsums = K.sum(axis=0) 32 | K_allsum = K_colsums.sum() 33 | return K - (K_colsums[None, :] + K_colsums[:, None]) / n + (K_allsum / n ** 2) 34 | 35 | @staticmethod 36 | def center_kernel_matrix_regression(K: ndarray, Kz: ndarray, epsilon: float): 37 | """ 38 | Centers the kernel matrix via a centering matrix R=I-Kz(Kz+\epsilonI)^{-1} and returns RKR 39 | """ 40 | n = shape(K)[0] 41 | Rz = epsilon * pinv(Kz + epsilon * eye(n)) 42 | return Rz.dot(K.dot(Rz)), Rz 43 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CDMIR: Causal Discovery, Modeling and Reasoning 2 | 3 | CDMIR is a python package for causal discovery, modeling and reasoning that is developed from [DMIR Lab](https://dmir.gdut.edu.cn/), GDUT. 4 | 5 | The package is under active development. Feedbacks (issues, suggestiong, etc.) are highly encouraged. 6 | ![causal-discovery](./images/causal-discovery.png) 7 | 8 | ## Requirements 9 | 10 | - Python3 11 | - numpy<=1.26.4 12 | - pandas 13 | - scipy>=1.8.1 14 | - scikit-learn 15 | - torch>=1.7.1 16 | - networkx 17 | - matplotlib 18 | - tick 19 | - igraph 20 | - lingam 21 | - pgmpy<1.0.0 22 | - tensorly=0.8.1 23 | - tqdm>=4 24 | - KDEpy>=1.1.12,<2 25 | - statsmodels=0.14.4 26 | 27 | 28 | ## Contributors 29 | 30 | Team Leaders: Ruichu Cai, Zhifeng Hao 31 | 32 | Coordinators: Wei Chen 33 | 34 | Thanks to the following developers for this project: 35 | 36 | - [@chenweiDelight](https://github.com/chenweiDelight) 37 | - [@zhi-yi-huang](https://github.com/zhi-yi-huang) 38 | - [@miumiujiang12138](https://github.com/miumiujiang12138) 39 | - [@WeilinChen507](https://github.com/WeilinChen507) 40 | - [@wean2016](https://github.com/wean2016) 41 | - [@Jie-Qiao](https://github.com/Jie-Qiao) 42 | - [@jinshi201](https://github.com/jinshi201) 43 | - [@kanseaveg](https://github.com/kanseaveg) 44 | - [@Boyle-Coffee](https://github.com/) 45 | 46 | 
Please feel free to open an issue if you find anything unexpected. And please create pull requests, perhaps after passing unittests in 'tests/', if you would like to contribute to cdmir. We are always targeting to make our community better! 47 | -------------------------------------------------------------------------------- /cdmir/tests/test_datasets_utils.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | import numpy as np 3 | 4 | from cdmir.datasets.utils import erdos_renyi, generate_lag_transitions, np2nx 5 | 6 | 7 | def test_erdos_renyi(): 8 | n_nodes = 10 9 | n_edges = 20 10 | seed = 10 11 | dag = erdos_renyi(n_nodes, n_edges, seed=seed) 12 | assert isinstance(dag, np.ndarray) 13 | assert dag.shape == (n_nodes, n_nodes) 14 | assert ((dag == 0.0) + (dag == 1.0)).all() 15 | assert nx.is_directed_acyclic_graph(np2nx(dag, create_using=nx.DiGraph)) 16 | 17 | 18 | def test_erdos_renyi_weights(): 19 | n_nodes = 10 20 | n_edges = 20 21 | seed = 10 22 | weight_range = (0.5, 2.0) 23 | dag = erdos_renyi(n_nodes, n_edges, weight_range, seed) 24 | assert isinstance(dag, np.ndarray) 25 | assert dag.shape == (n_nodes, n_nodes) 26 | assert ((np.abs(dag) == 0.0) + ((np.abs(dag) >= weight_range[0]) * (np.abs(dag) < weight_range[1]))).all() 27 | assert nx.is_directed_acyclic_graph(np2nx(dag, create_using=nx.DiGraph)) 28 | 29 | 30 | def test_erdos_renyi_error(): 31 | n_nodes = -10 32 | n_edges = 20 33 | seed = 10 34 | try: 35 | dag = erdos_renyi(n_nodes, n_edges, seed=seed) 36 | except Exception as e: 37 | assert isinstance(e, AssertionError) 38 | 39 | 40 | def test_generate_lag_transitions(): 41 | n_nodes = 10 42 | max_lag = 3 43 | seed = 10 44 | transitions = generate_lag_transitions(n_nodes, max_lag, seed=seed) 45 | assert isinstance(transitions, np.ndarray) 46 | assert transitions.shape == (max_lag, n_nodes, n_nodes) 47 | -------------------------------------------------------------------------------- 
# /cdmir/discovery/funtional_based/LearningHierarchicalStructure/indTest/HSICtestImpure.py:
# -------------------------------------------------------------------------------
# Name:        Module 1
# Purpose:     Ad-hoc experiments that run the GIN omega estimation on
#              simulated impure latent-variable data.
#
# Author:      YY
#
# Created:     25/10/2021
# Copyright:   (c) YY 2021
# Licence:
# -------------------------------------------------------------------------------
import numpy as np
import pandas as pd


def main22():
    """Run getomega on the CaseIV benchmark data for two conditioning sets.

    Uses X = {x11, x6, x12} against Z = {x4, x6} and Z = {x4, x12}; the
    omega vectors are computed for their side effects / inspection only.
    """
    # Imported lazily, and via the full package path (matching GIN2.py's own
    # imports), so importing this module neither pulls in the heavy simulation
    # code nor depends on a particular working directory.  The original
    # top-level `import LearningHierarchicalStructure.GIN2` only resolves when
    # run from inside the package directory.
    import cdmir.discovery.funtional_based.LearningHierarchicalStructure.GIN2 as GIN
    import cdmir.discovery.funtional_based.LearningHierarchicalStructure.Paper_simulation as SD

    data = SD.CaseIV(20000)
    # print(data.columns)
    # X11 X4 X6->X12
    X = ['x11', 'x6', 'x12']
    Z = ['x4', 'x6']
    GIN.getomega(data, X, Z)

    X = ['x11', 'x6', 'x12']
    Z = ['x4', 'x12']
    GIN.getomega(data, X, Z)


def main():
    """Simulate a 2-latent / 4-observed linear model and run getomega twice.

    L -> {x1, x2}, L -> L2 -> {x3, x4}, plus an x1 -> x2 edge; every edge
    coefficient is drawn by ToBij().
    """
    # Lazy, package-qualified import — see the note in main22().
    import cdmir.discovery.funtional_based.LearningHierarchicalStructure.GIN2 as GIN

    Num = 30000
    L = np.random.uniform(size=Num)
    L2 = np.random.uniform(size=Num) + L * ToBij()
    x1 = np.random.uniform(size=Num) * 0.2 + L * ToBij()
    x2 = np.random.uniform(size=Num) * 0.2 + L * ToBij() + x1 * ToBij()
    x3 = np.random.uniform(size=Num) * 0.2 + L2 * ToBij()
    x4 = np.random.uniform(size=Num) * 0.2 + L2 * ToBij()
    data = pd.DataFrame(np.array([x1, x2, x3, x4]).T, columns=['x1', 'x2', 'x3', 'x4'])

    X = ['x1', 'x2', 'x3']
    Z = ['x1', 'x4']
    GIN.getomega(data, X, Z)

    X = ['x1', 'x2', 'x3']
    Z = ['x2', 'x4']
    GIN.getomega(data, X, Z)


def ToBij():
    """Draw a random edge coefficient b with 0.5 <= |b| < 2, rounded to 3 dp.

    The magnitude is ten + s with ten in {0, 1}; when ten == 0, s is resampled
    until s >= 0.5, which bounds |b| away from zero.  The sign is flipped when
    np.random.randint(0, 10) > 5, i.e. with probability 0.4.
    """
    ten = np.random.randint(0, 2)
    s = np.random.random()
    while abs(s) < 0.5 and ten == 0:
        s = np.random.random()
    result = ten + s
    if np.random.randint(0, 10) > 5:
        result = -1 * result
    return round(result, 3)


if __name__ == '__main__':
    main()
# -------------------------------------------------------------------------------
/cdmir/effect/DoublyRobust/src/run.sh: -------------------------------------------------------------------------------- 1 | 2 | if [ $1 -eq 1 ]; then 3 | for id in {0,1,2,3,4} ; do 4 | CUDA_VISIBLE_DEVICES=1 python main.py --model=TargetedModel_DoubleBSpline --dataset=BC --expID=$id --flipRate=1 --lr_1step=0.0001 --lr_2step=0.01 --num_grid=20 --beta=20 --epochs=160 --tr_knots=0.1 --alpha=0.5 --gamma=1.0 5 | done 6 | fi 7 | 8 | if [ $1 -eq 2 ]; then 9 | for id in {0,1,2,3,4} ; do 10 | CUDA_VISIBLE_DEVICES=1 python main.py --model=TargetedModel_DoubleBSpline --dataset=BC_hete --expID=$id --flipRate=1 --lr_1step=0.001 --lr_2step=0.0001 --num_grid=20 --beta=20 --epochs=160 --tr_knots=0.05 --alpha=0.5 --gamma=1.0 11 | done 12 | fi 13 | 14 | if [ $1 -eq 3 ]; then 15 | for id in {0,1,2,3,4} ; do 16 | CUDA_VISIBLE_DEVICES=1 python main.py --model=TargetedModel_DoubleBSpline --dataset=Flickr --expID=$id --flipRate=1 --lr_1step=0.0001 --lr_2step=0.0001 --num_grid=20 --beta=20 --epochs=160 --tr_knots=0.25 --alpha=0.5 --gamma=0.5 17 | done 18 | fi 19 | 20 | if [ $1 -eq 4 ]; then 21 | for id in {0,1,2,3,4} ; do 22 | CUDA_VISIBLE_DEVICES=1 python main.py --model=TargetedModel_DoubleBSpline --dataset=Flickr_hete --expID=$id --flipRate=1 --lr_1step=0.001 --lr_2step=0.001 --num_grid=20 --beta=20 --epochs=160 --tr_knots=0.2 --alpha=0.5 --gamma=1.0 23 | done 24 | fi 25 | 26 | 27 | if [ $1 -eq 5 ]; then 28 | for id in {0,1,2,3,4} ; do 29 | CUDA_VISIBLE_DEVICES=1 python main.py --model=TargetedModel_DoubleBSpline --dataset=BC_hete_z --expID=$id --flipRate=1 --lr_1step=0.001 --lr_2step=0.001 --num_grid=20 --beta=20 --epochs=160 --tr_knots=0.2 --alpha=1. --gamma=1.0 30 | done 31 | fi 32 | 33 | if [ $1 -eq 6 ]; then 34 | for id in {0,1,2,3,4} ; do 35 | CUDA_VISIBLE_DEVICES=1 python main.py --model=TargetedModel_DoubleBSpline --dataset=Flickr_hete_z --expID=$id --flipRate=1 --lr_1step=0.001 --lr_2step=0.001 --num_grid=20 --beta=20 --epochs=160 --tr_knots=0.2 --alpha=1. 
--gamma=1.0 36 | done 37 | fi 38 | 39 | -------------------------------------------------------------------------------- /cdmir/discovery/funtional_based/LearningHierarchicalStructure/GIN2.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import lingam.hsic as hsic 4 | 5 | import cdmir.discovery.funtional_based.LearningHierarchicalStructure.indTest.independence as ID 6 | 7 | 8 | # GIN by fast HSIC 9 | # X=['X1','X2'] 10 | # Z=['X3'] 11 | # data.type=Pandas.DataFrame 12 | def GIN(X, Z, data, alpha=0.05): 13 | omega = getomega(data, X, Z) 14 | tdata = data[X] 15 | result = np.dot(omega, tdata.T) 16 | for i in Z: 17 | temp = np.array(data[i]) 18 | 19 | pvalue = hsic.hsic_test_gamma(result.T, temp, alpha) 20 | 21 | if pvalue > alpha: 22 | flag = True 23 | else: 24 | flag = False 25 | 26 | if not flag: 27 | return False 28 | 29 | return True 30 | 31 | 32 | # mthod 1: estimating mutual information by k nearest neighbors (density estimation) 33 | # mthod 2: estimating mutual information by sklearn package 34 | def GIN_MI(X, Z, data, method='1'): 35 | omega = getomega(data, X, Z) 36 | tdata = data[X] 37 | result = np.dot(omega, tdata.T) 38 | MIS = 0 39 | for i in Z: 40 | 41 | temp = np.array(data[i]) 42 | if method == '1': 43 | mi = ID.independent(result.T, temp) 44 | else: 45 | mi = ID.independent11(result.T, temp) 46 | MIS += mi 47 | MIS = MIS / len(Z) 48 | 49 | return MIS 50 | 51 | 52 | def getomega(data, X, Z): 53 | cov_m = np.cov(data, rowvar=False) 54 | col = list(data.columns) 55 | Xlist = [] 56 | Zlist = [] 57 | for i in X: 58 | t = col.index(i) 59 | Xlist.append(t) 60 | for i in Z: 61 | t = col.index(i) 62 | Zlist.append(t) 63 | B = cov_m[Xlist] 64 | B = B[:, Zlist] 65 | A = B.T 66 | u, s, v = np.linalg.svd(A) 67 | lens = len(X) 68 | omega = v.T[:, lens - 1] 69 | omegalen = len(omega) 70 | omega = omega.reshape(1, omegalen) 71 | 72 | return omega 
-------------------------------------------------------------------------------- /cdmir/utils/independence/basic_independence.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable 2 | 3 | from cdmir.utils.adapters import data_form_converter_for_class_method 4 | 5 | 6 | def _stringize_list(z: Iterable[int] = None): 7 | if z is None: 8 | return '' 9 | return ', '.join([str(s) for s in z]) 10 | 11 | 12 | def _get_cache_key(x: int, y: int, z: Iterable[int] = None): 13 | if x > y: 14 | x, y = y, x 15 | if z is None: 16 | return f'I({x}, {y})' 17 | else: 18 | sSlist = sorted(list(z)) 19 | return f'I({x}, {y} | {_stringize_list(sSlist)})' 20 | 21 | 22 | class ConditionalIndependentTest(object): 23 | 24 | @data_form_converter_for_class_method 25 | def __init__(self, data, var_names=None): 26 | self._data, self.var_names = data, var_names 27 | self.name_id = {name: i for i, name in enumerate(self.var_names)} 28 | self.cache = dict() 29 | 30 | def _get_name_id(self, name): 31 | return self.name_id[name] 32 | 33 | def test(self, x, y, z: Iterable = None): 34 | x_id, y_id = map(self._get_name_id, (x, y)) 35 | z_ids = list(map(self._get_name_id, z)) if z is not None else None 36 | return self.itest(x_id, y_id, z_ids) 37 | 38 | def itest(self, x_id: int, y_id: int, z_ids: Iterable[int] = None): 39 | key = _get_cache_key(x_id, y_id, z_ids) 40 | if self.cache.get(key): 41 | value = self.cache[key] 42 | else: 43 | value = self.cal_stats(x_id, y_id, z_ids) 44 | self.cache[key] = value 45 | # print(f'{key}: {value}') 46 | return value 47 | 48 | def cal_stats(self, x: int, y: int, z: Iterable[int] = None): 49 | ''' 50 | Calculate a Statistic with associated p-value. 
51 | Parameters 52 | ---------- 53 | x : int 54 | variable index 55 | y : int 56 | variable index 57 | S : tuple 58 | variables index 59 | Returns 60 | ------- 61 | ''' 62 | raise NotImplementedError 63 | -------------------------------------------------------------------------------- /docs/source/getting_started.rst: -------------------------------------------------------------------------------- 1 | ============= 2 | Getting started 3 | ============= 4 | 5 | 6 | Installation 7 | ^^^^^^^^^^^^ 8 | 9 | Requirements 10 | ------------ 11 | 12 | * python 3 (>=3.7) 13 | * numpy 14 | * networkx 15 | * pandas 16 | * scipy 17 | * scikit-learn 18 | * statsmodels 19 | * pydot 20 | * torch>=1.7.1 21 | * tick 22 | * igraph 23 | * lingam 24 | * pgmpy<1.0.0 25 | * tensorly=0.8.1 26 | * tqdm>=4 27 | * KDEpy>=1.1.12,<2 28 | 29 | 30 | 31 | (For visualization) 32 | 33 | * matplotlib 34 | * graphviz 35 | * pygraphviz (might not support the most recent Mac) 36 | 37 | 38 | Install via PyPI 39 | ------------ 40 | 41 | To use CDMIR, we could install it using `pip `_: 42 | 43 | .. code-block:: console 44 | 45 | (.venv) $ pip install CDMIR 46 | 47 | 48 | Install from source 49 | ------------ 50 | 51 | For development version, please kindly refer to our `GitHub Repository `_. 52 | 53 | 54 | Running examples 55 | ^^^^^^^^^^^^ 56 | 57 | For search methods in causal discovery, there are various running examples in the 'tests' directory in our `GitHub Repository `_, 58 | such as test_pc.py and test_desp.py. 59 | 60 | For the implemented modules, such as independent test methods, we provide unit tests for the convenience of developing your own methods. 
61 | 62 | 63 | Contributors 64 | ^^^^^^^^^^^^ 65 | 66 | **Team Leaders**: Ruichu Cai, Zhifeng Hao 67 | 68 | **Coordinators**: Wei Chen 69 | 70 | **Developers**: 71 | 72 | - [@chenweiDelight](https://github.com/chenweiDelight) 73 | - [@zhi-yi-huang](https://github.com/zhi-yi-huang) 74 | - [@miumiujiang12138](https://github.com/miumiujiang12138) 75 | - [@WeilinChen507](https://github.com/WeilinChen507) 76 | - [@wean2016](https://github.com/wean2016) 77 | - [@Jie-Qiao](https://github.com/Jie-Qiao) 78 | - [@jinshi201](https://github.com/jinshi201) 79 | - [@kanseaveg](https://github.com/kanseaveg) 80 | - [@Boyle-Coffee](https://github.com/) 81 | 82 | 83 | **Quality control**: Wei Chen, Zhiyi Huang, Qian Huang 84 | 85 | -------------------------------------------------------------------------------- /cdmir/utils/independence/fisherz.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Iterable, List, Tuple 4 | 5 | from numpy import corrcoef 6 | 7 | from cdmir.utils.independence._base import BaseConditionalIndependenceTest 8 | from cdmir.utils.independence.basic_independence import ConditionalIndependentTest 9 | from cdmir.utils.independence.functional import fisherz_from_corr 10 | 11 | 12 | class FisherZ(ConditionalIndependentTest): 13 | def __init__(self, data, var_names=None): 14 | super().__init__(data, var_names=var_names) 15 | self._num_records = data.shape[0] 16 | self._corr = corrcoef(self._data, rowvar=False) 17 | 18 | def cal_stats(self, x_id: int, y_id: int, z_ids: Iterable[int] = None): 19 | return fisherz_from_corr(corr=self._corr, num_records=self._num_records, x_id=x_id, y_id=y_id, z_ids=z_ids) 20 | 21 | 22 | 23 | # class FisherZ(BaseConditionalIndependenceTest): 24 | # def __init__(self, data): 25 | # super().__init__(data) 26 | # self._num_records = data.shape[0] 27 | # self._corr = corrcoef(self._data, rowvar=False) 28 | # 29 | # def 
__compute_p_value_with_condition(self, x_ids: List[int], y_ids: List[int], z_ids: List[int]) -> Tuple[float, float | ndarray | None]: 30 | # stat, p_value = fisherz_from_corr(corr=self._corr, num_records=self._num_records, x=x_ids[0], y=y_ids[0], S=z_ids) 31 | # return p_value, stat 32 | # 33 | # def __compute_p_value_without_condition(self, x_ids: List[int], y_ids: List[int]) -> Tuple[float, float | ndarray | None]: 34 | # stat, p_value = fisherz_from_corr(corr=self._corr, num_records=self._num_records, x=x_ids[0], y=y_ids[0], S=None) 35 | # return p_value, stat 36 | # 37 | # def __call__(self, xs: int | str | List[int | str] | ndarray, 38 | # ys: int | str | List[int | str] | ndarray, 39 | # zs: int | str | List[int | str] | ndarray | None = None, *args, **kwargs) -> Tuple[float, float | ndarray | None]: 40 | # return self._compute_p_value(xs, ys, zs, self.__compute_p_value_without_condition, self.__compute_p_value_with_condition) 41 | 42 | -------------------------------------------------------------------------------- /docs/source/discovery_methods/functional_based/SHP/shp.rst: -------------------------------------------------------------------------------- 1 | SHP (Structural Hawkes Processes) 2 | ================================= 3 | 4 | Introduction 5 | ------------ 6 | 7 | SHP is a functional-based causal discovery method for learning causal structure from discrete-time event sequences. It extends the traditional Hawkes process by incorporating structural learning to identify causal relationships between event types. 8 | 9 | Usage 10 | ----- 11 | 12 | .. 
code-block:: python 13 | 14 | from cdmir.discovery.funtional_based.SHP.SHP import SHP 15 | 16 | # Initialize SHP model 17 | shp = SHP( 18 | event_table=event_table, 19 | decay=0.35, 20 | time_interval=5, 21 | penalty='BIC', 22 | reg=0.85, 23 | seed=2025 24 | ) 25 | 26 | # Train the model with Hill Climb search 27 | likelihood, alpha, mu = shp.train_model(hill_climb=True) 28 | 29 | Parameters 30 | ---------- 31 | 32 | **SHP Class Initialization** 33 | 34 | - **event_table**: Event data table containing columns 'seq_id', 'time_stamp', and 'event_type' 35 | - **decay**: The decay coefficient of the exponential kernel 36 | - **time_interval**: Time interval size for discretizing timestamps (default: None) 37 | - **init_structure**: Initial adjacency matrix for causal structure (default: None) 38 | - **penalty**: Penalty method, either 'BIC' or 'AIC' (default: 'BIC') 39 | - **seed**: Random seed for reproducibility (default: None) 40 | - **reg**: Regularization parameter (default: 3.0) 41 | 42 | **train_model Parameters** 43 | 44 | - **hill_climb**: Whether to use Hill Climb search for structure learning (default: True) 45 | 46 | Returns 47 | ------- 48 | 49 | - **likelihood**: Likelihood value of the learned model 50 | - **alpha**: Matrix of causal effect parameters between event types, where :math:`\alpha_{ij}` represents the causal effect from event type :math:`i` to :math:`j` 51 | - **mu**: Vector of base intensity parameters for each event type, where :math:`\mu_j` is the base intensity for event type :math:`j` 52 | 53 | References 54 | ---------- 55 | 56 | .. [1] Qiao J, Cai R, Wu S, et al. Structural hawkes processes for learning causal structure from discrete-time event sequences[J]. arXiv preprint arXiv:2305.05986, 2023. 
57 | -------------------------------------------------------------------------------- /cdmir/discovery/funtional_based/LearningHierarchicalStructure/indTest/HSICPermutationTestObject.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on 17 Nov 2015 3 | 4 | @author: qinyi 5 | ''' 6 | import time 7 | 8 | from cdmir.discovery.funtional_based.LearningHierarchicalStructure.indTest.HSICTestObject import HSICTestObject 9 | 10 | 11 | class HSICPermutationTestObject(HSICTestObject): 12 | 13 | def __init__(self, num_samples, data_generator=None, kernelX=None, kernelY=None, kernelX_use_median=False, 14 | kernelY_use_median=False, num_rfx=None, num_rfy=None, rff=False, 15 | induce_set=False, num_inducex=None, num_inducey=None, num_shuffles=1000, unbiased=True): 16 | HSICTestObject.__init__(self, num_samples, data_generator=data_generator, kernelX=kernelX, kernelY=kernelY, 17 | kernelX_use_median=kernelX_use_median, kernelY_use_median=kernelY_use_median, 18 | num_rfx=num_rfx, num_rfy=num_rfy, rff=rff, induce_set=induce_set, 19 | num_inducex=num_inducex, num_inducey=num_inducey) 20 | self.num_shuffles = num_shuffles 21 | self.unbiased = unbiased 22 | 23 | def compute_pvalue_with_time_tracking(self, data_x=None, data_y=None): 24 | if data_x is None and data_y is None: 25 | if not self.streaming and not self.freeze_data: 26 | start = time.perf_counter() 27 | self.generate_data() 28 | data_generating_time = time.perf_counter() - start 29 | data_x = self.data_x 30 | data_y = self.data_y 31 | else: 32 | data_generating_time = 0. 33 | else: 34 | data_generating_time = 0. 
35 | 36 | print('Permutation data generating time passed: ', data_generating_time) 37 | 38 | hsic_statistic, null_samples, _, _, _, _, _ = self.HSICmethod(unbiased=self.unbiased, 39 | num_shuffles=self.num_shuffles, 40 | data_x=data_x, data_y=data_y) 41 | pvalue = (1 + sum(null_samples > hsic_statistic)) / float(1 + self.num_shuffles) 42 | 43 | return pvalue, data_generating_time 44 | -------------------------------------------------------------------------------- /docs/source/discovery_methods/constraint/pc/pc.rst: -------------------------------------------------------------------------------- 1 | PC (Peter-Clark Algorithm) 2 | ========================== 3 | 4 | Introduction 5 | ------------ 6 | 7 | PC is a constraint-based causal discovery algorithm that infers causal relationships between variables from observational data. It starts with a complete undirected graph and iteratively removes edges based on conditional independence tests, then applies a set of rules to orient edges, resulting in a Partially Directed Acyclic Graph (PDAG) that represents causal relationships. 8 | 9 | Usage 10 | ----- 11 | 12 | .. 
code-block:: python

    from cdmir.discovery.constraint.pc import PC
    from cdmir.utils.independence import FisherZ

    # Initialize PC algorithm with default parameters
    pc = PC(alpha=0.05, verbose=False)

    # Fit the model to data, passing a concrete conditional independence
    # test class (any subclass of ConditionalIndependentTest, e.g. FisherZ)
    pc.fit(data, var_names, FisherZ)

    # Access results
    causal_graph = pc.causal_graph
    skeleton = pc.skeleton
    sep_set = pc.sep_set

Parameters
----------

PC Class Parameters:

- alpha: Significance level for independence tests (default: 0.05)
- adjacency_search_method: Function for adjacency search phase (default: adjacency_search)
- verbose: Whether to print algorithm progress (default: False)

fit() Method Parameters:

- data: Input dataset containing variable observations
- var_names: List of variable names corresponding to the columns in data
- indep_cls: Concrete conditional independence test class implementing the ConditionalIndependentTest interface (e.g. FisherZ)
- args: Positional arguments passed to the independence test constructor
- kwargs: Keyword arguments passed to the independence test constructor

Returns
-------

- causal_graph: Partially Directed Acyclic Graph (PDAG) representing inferred causal relationships
- skeleton: Undirected graph representing the skeleton of causal relationships
- sep_set: Separation sets for node pairs, stored as a dictionary where keys are node pairs and values are sets of separating nodes

References
----------

[1] Spirtes, P., Glymour, C. N., Scheines, R., & Heckerman, D. (2000). Causation, prediction, and search. MIT press.
# pgmpy<1.0.0 tensorly<0.9.0
import cdmir.discovery.Tensor_Rank.LearnCausalCluster as LCC
import cdmir.discovery.Tensor_Rank.DiscretePC as PC
import random
import pandas as pd

from cdmir.datasets.pgmdata import Gdata2
from importlib import resources


'''
A toy example illustrating the tensor rank condition for learning discrete latent variable models with a three-pure-children structure is presented.
The proposed two-stage algorithm first identifies causal clusters from the observed variables and then infers the d-separation relationships among the latent variables.

Reference:
[1] Chen Z, Cai R, Xie F, et al. Learning Discrete Latent Variable Structures with Tensor Rank Conditions[C]//The Thirty-eighth Annual Conference on Neural Information Processing Systems.

'''
def main():
    """Run both stages of the tensor-rank toy example."""
    # Causal cluster learning by the tensor rank condition.
    # Ground truth: [['O1a', 'O1b', 'O1c'], ['O2a', 'O2b', 'O2c']]
    test1()

    # Causal skeleton learning among the latents. Ground truth: L1-L2-L3
    test2()


def test1():
    """Stage 1: recover the causal clusters from simulated discrete data."""
    data = Gdata2(100000)

    print(data.columns)

    Cluster = LCC.LearnCausalCluster(data)

    print('##########-----------------------------------------')
    print('The learned causal cluster is : ', Cluster)
    print('##########-----------------------------------------')
    assert Cluster == [['O1a', 'O1b', 'O1c'], ['O2a', 'O2b', 'O2c']]


def test2():
    """Stage 2: recover the causal skeleton among the latents L1, L2, L3."""
    # importlib.resources replaces the deprecated pkg_resources API for
    # locating package data files (available since Python 3.9).
    resource = resources.files('cdmir') / 'tests' / 'testdata' / 'out.csv'
    with resources.as_file(resource) as csv_path:
        data = pd.read_csv(csv_path)

    labels = ['L1', 'L2', 'L3']

    cluster = {'L1': ['O1a', 'O1b', 'O1c'],
               'L2': ['O2a', 'O2b', 'O2c'],
               'L3': ['O3a', 'O3b', 'O3c']}

    p = PC.test(data, labels, cluster)

    m = p + 0  # force a plain numeric copy of the adjacency matrix

    print('###########-----------------------------------------')
    print('The adjacent matrix among L1, L2 and L3 is: ')
    print(m)
    print('###########-----------------------------------------')
    assert m.tolist() == [[0, 1, 0], [1, 0, 1], [0, 1, 0]]


if __name__ == '__main__':
    main()
Return True if reject 18 | def test(self, pcols, qcols, r=1): 19 | 20 | p = len(pcols) 21 | q = len(qcols) 22 | X = self.data[:, pcols] 23 | Y = self.data[:, qcols] 24 | cca = CanCorr(X, Y, tolerance=1e-8) 25 | l = cca.cancorr[r:] 26 | 27 | testStat = 0 28 | for li in l: 29 | testStat += log(1 - pow(li, 2)) 30 | testStat = testStat * -(self.n - 0.5*(p+q+3)) 31 | 32 | dfreedom = (p-r) * (q-r) 33 | criticalValue = chi2.ppf(1-self.alpha, dfreedom) 34 | # print(f"testStat: {testStat}, crit: {criticalValue}") 35 | 36 | return testStat > criticalValue 37 | 38 | def test_my(self, matrix, r=1): 39 | 40 | U, S, Vt = np.linalg.svd(matrix, full_matrices=False) 41 | p = 2 42 | q = 2 43 | 44 | l = S[r:] 45 | testStat = 0 46 | for li in l: 47 | testStat += log(1 - pow(li, 2)) 48 | # print(S, pow(l[0], 2), testStat) 49 | testStat = testStat * -(self.n - 0.5*(p+q+3)) 50 | 51 | dfreedom = (p-r) * (q-r) 52 | criticalValue = chi2.ppf(1-self.alpha, dfreedom) 53 | # print(f"testStat: {testStat}, crit: {criticalValue}") 54 | 55 | return testStat > criticalValue, [testStat, S[1], pow(l[0], 2)] 56 | 57 | 58 | def prob(self, pcols, qcols, r=1): 59 | p = len(pcols) 60 | q = len(qcols) 61 | X = self.data[:, pcols] 62 | Y = self.data[:, qcols] 63 | 64 | cca = CanCorr(X, Y) 65 | l = cca.cancorr[r:] 66 | 67 | testStat = 0 68 | for li in l: 69 | testStat += log(1 - pow(li, 2)) 70 | testStat = testStat * -(self.n - 0.5*(p+q+3)) 71 | 72 | dfreedom = (p-r) * (q-r) 73 | criticalValue = chi2.ppf(1-self.alpha, dfreedom) 74 | #print(f"testStat: {testStat}, crit: {criticalValue}") 75 | 76 | return 1-chi2.cdf(testStat, dfreedom) 77 | 78 | 79 | 80 | 81 | -------------------------------------------------------------------------------- /docs/source/discovery_methods/LaHiCaSI/LaHiCaSI.rst: -------------------------------------------------------------------------------- 1 | LaHiCaSI (Latent Hierarchical Causal Structure Learning) 2 | ========================================================== 3 | 4 | 
Introduction 5 | ------------ 6 | 7 | LaHiCaSI is a causal discovery method that focuses on learning hierarchical causal structures in the presence of latent variables. It operates in two main phases: first locating latent variables by identifying causal clusters, and then inferring the causal structure among these latent variables. 8 | 9 | Usage 10 | ----- 11 | 12 | .. code-block:: python 13 | 14 | from cdmir.discovery.LaHiCaSl.LaHiCaSl import Latent_Hierarchical_Causal_Structure_Learning 15 | import pandas as pd 16 | import numpy as np 17 | 18 | # Load or generate your dataset 19 | # Example: Generate random data with 10 variables and 1000 samples 20 | data = pd.DataFrame(np.random.randn(1000, 10), columns=[f'X{i}' for i in range(10)]) 21 | 22 | # Set significance level 23 | alpha = 0.05 24 | 25 | # Run LaHiCaSI algorithm 26 | Latent_Hierarchical_Causal_Structure_Learning(data, alpha) 27 | 28 | Parameters 29 | ---------- 30 | 31 | - **data**: Dataset of observed variables, typically a pandas DataFrame or numpy array. 32 | - **alpha**: Statistical significance threshold (default: 0.05), used to determine the significance of causal relationships during the learning process. 33 | 34 | Returns 35 | ------- 36 | 37 | The function prints the resulting causal structure in the form of an adjacency matrix. It also generates intermediate results during the two-phase learning process. 38 | 39 | Algorithm Overview 40 | ------------------ 41 | 42 | LaHiCaSI consists of two main phases: 43 | 44 | 1. **Phase I: Locate latent variables** 45 | - **Stage I-S1**: Identify global causal clusters using `IdentifyGlobalCausalClusters` 46 | - **Stage I-S2**: Determine latent variables by merging clusters using `Determine_Latent_Variables` 47 | - **Stage I-S3**: Update active data and cluster information using `UpdateActiveData` 48 | 49 | 2. 
'''
Created on 15 Nov 2015

@author: qinyi
'''
import time

import numpy as np
from cdmir.discovery.funtional_based.LearningHierarchicalStructure.indTest.HSICTestObject import HSICTestObject


class HSICSpectralTestObject(HSICTestObject):
    """HSIC independence test whose null distribution is simulated from the
    spectral (eigenvalue) approximation of the HSIC statistic."""

    def __init__(self, num_samples, data_generator=None,
                 kernelX=None, kernelY=None, kernelX_use_median=False, kernelY_use_median=False,
                 rff=False, num_rfx=None, num_rfy=None, induce_set=False, num_inducex=None, num_inducey=None,
                 num_nullsims=1000, unbiased=False):
        # Kernel and random-feature options are handled by the base class;
        # only the spectral-simulation knobs are kept here.
        HSICTestObject.__init__(self, num_samples, data_generator=data_generator, kernelX=kernelX, kernelY=kernelY,
                                kernelX_use_median=kernelX_use_median, kernelY_use_median=kernelY_use_median,
                                num_rfx=num_rfx, num_rfy=num_rfy, rff=rff,
                                induce_set=induce_set, num_inducex=num_inducex, num_inducey=num_inducey)
        self.num_nullsims = num_nullsims
        self.unbiased = unbiased

    def get_null_samples_with_spectral_approach(self, Mx, My):
        """Draw ``num_nullsims`` samples from the approximate null: a
        spectrum-weighted sum of chi-squared(1) variables."""
        lambdax, lambday = self.get_spectrum_on_data(Mx, My)
        null_samples = np.zeros(self.num_nullsims)
        for sim in range(self.num_nullsims):
            # Squared standard normals are chi-squared(1) draws.
            chi2_draws = np.random.randn(len(lambdax), len(lambday)) ** 2
            if self.unbiased:
                chi2_draws = chi2_draws - 1  # centre for the unbiased statistic
            null_samples[sim] = np.dot(lambdax.T, np.dot(chi2_draws, lambday))
        return null_samples

    def compute_pvalue_with_time_tracking(self, data_x=None, data_y=None):
        """Return ``(pvalue, data_generating_time)`` using the spectral null.

        When no data is supplied and the object is neither streaming nor
        frozen, a fresh sample is drawn first (its drawing time is reported).
        """
        elapsed = 0.
        if data_x is None and data_y is None:
            if not self.streaming and not self.freeze_data:
                tic = time.perf_counter()
                self.generate_data()
                elapsed = time.perf_counter() - tic
                data_x = self.data_x
                data_y = self.data_y
        hsic_statistic, _, _, _, Mx, My, _ = self.HSICmethod(unbiased=self.unbiased, data_x=data_x, data_y=data_y)
        null_samples = self.get_null_samples_with_spectral_approach(Mx, My)
        # Add-one smoothing; the statistic is rescaled by the sample size to
        # match the scale of the simulated null.
        pvalue = (1 + sum(null_samples > self.num_samples * hsic_statistic)) / float(1 + self.num_nullsims)
        return pvalue, elapsed
import logging

from copy import deepcopy
from unittest import TestCase

from cdmir.graph import DiGraph, Edge, Graph, Mark, PDAG
from cdmir.graph.dag2cpdag import dag2cpdag
from cdmir.graph.pdag2dag import pdag2dag

logging.basicConfig(level=logging.DEBUG,
                    format=' %(levelname)s :: %(message)s',
                    datefmt='%m/%d/%Y %I:%M:%S %p')

def txt2graph(filename: str) -> Graph:
    """Parse a Tetrad-style text file into a Graph.

    The file lists node names after a 'Nodes:' header line (separated by
    ';'), and one edge per numbered line such as '1. A --> B'.
    """
    g = Graph()
    with open(filename, "r") as file:
        expecting_nodes = False
        for raw_line in file:
            line = raw_line.strip()
            words = line.split()
            if len(words) > 1 and words[1] == 'Nodes:':
                # The node list follows on the next non-empty line.
                expecting_nodes = True
            elif len(line) > 0 and expecting_nodes:
                expecting_nodes = False
                for node in line.split(';'):
                    g.add_node(node)
            elif len(words) > 0 and words[0].endswith('.'):
                # Edge lines look like: '<k>. <node1> <mark><mark> <node2>'
                expecting_nodes = False
                node1, arrow, node2 = words[1], words[2], words[3]
                g.add_edge(Edge(node1, node2, to_endpotin(arrow[0]), to_endpotin(arrow[-1])))
    return g

def to_endpotin(s: str) -> Mark:
    """Map a Tetrad endpoint character to a Mark.

    NOTE(review): the historic typo in the name ('endpotin') is kept so that
    existing callers keep working.
    """
    if s == 'o':
        return Mark.Circle
    if s == '>':
        return Mark.Arrow
    if s == '-':
        return Mark.Tail
    raise NotImplementedError


def graph_compare(G1, G2) -> bool:
    """Return True iff the graphs have identical node sets and edge sets."""
    if G1.node_set != G2.node_set:
        return False
    return set(G1.edges) == set(G2.edges)

class Test_graph_transform(TestCase):

    def test_dag2cpdag(self):
        """dag2cpdag must reproduce the stored ground-truth CPDAGs."""
        case_count = 5
        for i in range(1, case_count + 1):
            g = txt2graph(f'cdmir/tests/testdata/dag.{i}.txt')
            dag = DiGraph(g.node_list)
            # Normalise every edge so that it points tail -> arrow.
            dag.add_edges([e if e.mark_u == Mark.Tail else Edge(e.node_v, e.node_u) for e in g.edges])
            cpdag = dag2cpdag(dag)
            truth_cpdag = txt2graph(f'cdmir/tests/testdata/cpdag.{i}.txt')
            assert graph_compare(cpdag, truth_cpdag)

    def test_pdag2dag(self):
        """pdag2dag must reproduce the stored ground-truth DAG extensions."""
        case_count = 32
        for i in range(1, case_count + 1):
            g = txt2graph(f'cdmir/tests/testdata/graph_data/pdag.{i}.txt')
            pdag = PDAG(g.node_list)
            pdag.add_edges(g.edges)
            dag = pdag2dag(pdag)
            truth_dag = txt2graph(f'cdmir/tests/testdata/graph_data/dag.{i}.txt')
            assert graph_compare(dag, truth_dag)
assert np.isclose(directed_edge_eval['f1'], 1.0) 55 | 56 | shd_eval = shd(g1, g2) 57 | assert shd_eval == 0 58 | 59 | def test_case3(self): 60 | g1 = PDAG([1, 2, 3]) 61 | g1.add_edge(Edge(1, 2, Mark.Tail, Mark.Arrow)) 62 | g2 = PDAG([1, 2, 3]) 63 | g2.add_edge(Edge(1, 2, Mark.Tail, Mark.Arrow)) 64 | g2.add_edge(Edge(1, 3, Mark.Tail, Mark.Arrow)) 65 | 66 | arrow_eval = arrow_evaluation(g1, g2) 67 | assert np.isclose(arrow_eval['precision'], 0.5) 68 | assert np.isclose(arrow_eval['recall'], 1.0) 69 | assert np.isclose(arrow_eval['f1'], 2/3) 70 | 71 | directed_edge_eval = directed_edge_evaluation(g1, g2) 72 | assert np.isclose(directed_edge_eval['precision'], 0.5) 73 | assert np.isclose(directed_edge_eval['recall'], 1.0) 74 | assert np.isclose(directed_edge_eval['f1'], 2/3) 75 | 76 | shd_eval = shd(g1, g2) 77 | assert shd_eval == 1 78 | 79 | 80 | class TestAssertion(TestCase): 81 | def test_graph_type(self): 82 | g1 = DiGraph([1, 2, 3]) 83 | g2 = PDAG([1, 2, 3]) 84 | with self.assertRaises(AssertionError): 85 | _ = graph_equal(g1, g2) 86 | 87 | def test_graph_node(self): 88 | g1 = PDAG([1, 2, 3]) 89 | g2 = PDAG([3, 2, 1]) 90 | 91 | with self.assertRaises(AssertionError): 92 | _ = graph_equal(g1, g2) 93 | 94 | -------------------------------------------------------------------------------- /cdmir/utils/local_score/bdeu_score.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from math import lgamma 4 | from typing import Iterable 5 | 6 | from numpy import asarray, log, ndarray, shape, unique 7 | from pandas import DataFrame 8 | 9 | from ._base import BaseLocalScoreFunction 10 | 11 | 12 | class BDeuScore(BaseLocalScoreFunction): 13 | 14 | def __init__(self, data: ndarray | DataFrame, *args, **kwargs): 15 | super().__init__(data, *args, **kwargs) 16 | if not kwargs.__contains__('sample_prior'): 17 | self.sample_prior = 1 18 | else: 19 | self.sample_prior = kwargs["sample_prior"] 20 | if not 
kwargs.__contains__('structure_prior'): 21 | self.structure_prior = 1 22 | else: 23 | self.structure_prior = kwargs["structure_prior"] 24 | self.r_i_map = {i: len(unique(asarray(self.data[:, i]))) for i in range(shape(self.data)[1])} 25 | 26 | def _score_function(self, i: int, parent_i: Iterable[int]): 27 | # calculate the local score with BDeu for the discrete case 28 | # 29 | # INPUT 30 | # i: current index 31 | # PAi: parent indexes 32 | # OUTPUT: 33 | # local BDeu score 34 | 35 | parent_i = list(parent_i) 36 | # calculate q_{i} 37 | q_i = 1 38 | for pa in parent_i: 39 | q_i *= self.r_i_map[pa] 40 | 41 | # calculate N_{ij} 42 | names = ['x{}'.format(var) for var in range(shape(self.data)[1])] 43 | Data_pd = DataFrame(self.data, columns=names) 44 | parant_names = ['x{}'.format(var) for var in parent_i] 45 | Data_pd_group_Nij = Data_pd.groupby(parant_names) 46 | Nij_map = {key: len(Data_pd_group_Nij.indices.get(key)) for key in Data_pd_group_Nij.indices.keys()} 47 | Nij_map_keys_list = list(Nij_map.keys()) 48 | 49 | # calculate N_{ijk} 50 | Nijk_map = {ij: Data_pd_group_Nij.get_group(ij).groupby('x{}'.format(i)).apply(len).reset_index() for ij in 51 | Nij_map.keys()} 52 | for v in Nijk_map.values(): 53 | v.columns = ['x{}'.format(i), 'times'] 54 | 55 | BDeu_score = 0 56 | # first term 57 | vm = shape(self.data)[0] - 1 58 | BDeu_score += len(parent_i) * log(self.structure_prior / vm) + (vm - len(parent_i)) * log(1 - (self.structure_prior / vm)) 59 | 60 | # second term 61 | for pa in range(len(Nij_map_keys_list)): 62 | Nij = Nij_map.get(Nij_map_keys_list[pa]) 63 | first_term = lgamma(self.sample_prior / q_i) - lgamma(Nij + self.sample_prior / q_i) 64 | 65 | second_term = 0 66 | Nijk_list = Nijk_map.get(Nij_map_keys_list[pa])['times'].to_numpy() 67 | for Nijk in Nijk_list: 68 | second_term += lgamma(Nijk + self.sample_prior / (self.r_i_map[i] * q_i)) - lgamma(self.sample_prior / (self.r_i_map[i] * q_i)) 69 | 70 | BDeu_score += first_term + second_term 71 | 72 | 
return BDeu_score 73 | 74 | def __call__(self, i: int, parent_i: Iterable[int], *args, **kwargs): 75 | return self._score(i, parent_i, self._score_function) 76 | 77 | -------------------------------------------------------------------------------- /cdmir/discovery/funtional_based/LearningHierarchicalStructure/indTest/fastHSIC.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append("./indTest") 4 | import numpy as np 5 | import pandas as pd 6 | 7 | from .HSICTestObject import HSICTestObject 8 | from .HSICBlockTestObject import HSICBlockTestObject 9 | from .HSICSpectralTestObject import HSICSpectralTestObject 10 | from numpy import concatenate, loadtxt, reshape, savetxt, shape, shape, transpose 11 | from kerpy.kerpy.GaussianKernel import GaussianKernel 12 | 13 | 14 | def test(alph=0.05): 15 | x = np.random.randn(1000) 16 | y = np.random.randn(1000) 17 | 18 | lens = len(x) 19 | x = x.reshape(lens, 1) 20 | y = y.reshape(lens, 1) 21 | 22 | kernelY = GaussianKernel(float(0.15)) 23 | kernelX = GaussianKernel(float(0.15)) 24 | 25 | num_samples = lens 26 | 27 | myspectralobject = HSICSpectralTestObject(num_samples, kernelX=kernelX, kernelY=kernelY, 28 | kernelX_use_median=False, kernelY_use_median=False, 29 | rff=True, num_rfx=20, num_rfy=20, num_nullsims=1000) 30 | 31 | pvalue = myspectralobject.compute_pvalue(x, y) 32 | 33 | if pvalue > alph: 34 | return True 35 | else: 36 | return False 37 | 38 | 39 | def test2(alph=0.08): 40 | x = np.random.randn(1000) 41 | y = np.random.randn(1000) 42 | 43 | lens = len(x) 44 | x = x.reshape(lens, 1) 45 | y = y.reshape(lens, 1) 46 | kernelX = GaussianKernel() 47 | kernelY = GaussianKernel() 48 | 49 | num_samples = lens 50 | 51 | myblockobject = HSICBlockTestObject(num_samples, kernelX=kernelX, kernelY=kernelY, 52 | kernelX_use_median=False, kernelY_use_median=False, 53 | blocksize=80, nullvarmethod='permutation') 54 | 55 | pvalue = myblockobject.compute_pvalue(x, y) 56 
import copy
from cdmir.utils import data_form_converter_for_class_method
from cdmir.utils.independence import ConditionalIndependentTest

from cdmir.discovery.constraint.adjacency_search import adjacency_search


class PC(object):
    """PC (Peter-Clark) constraint-based causal discovery.

    Attributes
    ----------
    skeleton : undirected graph produced by the adjacency-search phase
    causal_graph : partially directed acyclic graph (PDAG) after orientation
    sep_set : separation sets keyed by node pair
    alpha : significance level used by the independence tests
    indep_test : the instantiated conditional-independence test
    verbose : whether progress is printed
    """

    def __init__(self,
                 alpha: float = 0.05,
                 adjacency_search_method=adjacency_search,
                 verbose: bool = False
                 ):
        """Create a PC estimator.

        :param alpha: significance level for conditional independence tests (default: 0.05)
        :param adjacency_search_method: callable used for the skeleton phase (default: adjacency_search)
        :param verbose: whether to print algorithm progress (default: False)
        """
        self.skeleton = None
        self.causal_graph = None
        self.sep_set = None
        self.alpha = alpha
        self.indep_test = None
        self.adjacency_search_method = adjacency_search_method
        self.verbose = verbose

    @data_form_converter_for_class_method
    def fit(self, data, var_names, indep_cls, *args, **kwargs):
        """Learn a causal PDAG from data.

        Runs the adjacency search to obtain a skeleton and separation sets,
        then orients v-structures (rule 0) and applies Meek's rules.

        :param data: input dataset
        :param var_names: variable names, one per column of ``data``
        :param indep_cls: ConditionalIndependentTest subclass to instantiate
        :param args: positional arguments for ``indep_cls``
        :param kwargs: keyword arguments for ``indep_cls``
        """
        assert issubclass(indep_cls, ConditionalIndependentTest)
        self.indep_test = indep_cls(data, var_names, *args, **kwargs)
        self.causal_graph, self.sep_set = self.adjacency_search_method(
            self.indep_test, self.indep_test.var_names, self.alpha, verbose=self.verbose
        )
        # Keep an unoriented copy before edge orientation mutates the graph.
        self.skeleton = copy.deepcopy(self.causal_graph)
        self.causal_graph.rule0(self.sep_set, self.verbose)
        self.causal_graph.orient_by_meek_rules(self.verbose)

    def set_alpha(self, alpha):
        """Set the significance level for conditional independence tests."""
        self.alpha = alpha

    def get_alpha(self):
        """Return the current significance level."""
        return self.alpha

    def set_verbose(self, verbose):
        """Enable or disable progress output during algorithm execution."""
        self.verbose = verbose

    def get_verbose(self):
        """Return the current verbosity setting."""
        return self.verbose
node_v)) 53 | pc = PC(verbose=True) 54 | pc.fit(data, indep_cls=Dsep, true_graph=true_graph) 55 | 56 | cpdag = dag2cpdag(true_graph) 57 | assert graph_equal(pc.causal_graph, cpdag) 58 | 59 | def gen_numpy_dataset(self): 60 | random.seed(3407) 61 | np.random.seed(3407) 62 | sample_size = 100000 63 | X1 = normal(size=(sample_size, 1)) 64 | X2 = X1 + normal(size=(sample_size, 1)) 65 | X3 = X1 + normal(size=(sample_size, 1)) 66 | X4 = X2 + X3 + normal(size=(sample_size, 1)) 67 | X = np.hstack((X1, X2, X3, X4)) 68 | X = stats.zscore(X, ddof=1, axis=0) 69 | return X 70 | 71 | def gen_pandas_dataset(self): 72 | random.seed(3407) 73 | np.random.seed(3407) 74 | sample_size = 100000 75 | X1 = normal(size=(sample_size, 1)) 76 | X2 = normal(size=(sample_size, 1)) 77 | X3 = X1 + X2 + 0.3 * normal(size=(sample_size, 1)) 78 | X4 = X1 + X3 + 0.3 * normal(size=(sample_size, 1)) 79 | X5 = 0.5 * X1 + 0.5 * X2 + X3 + X4 + 0.3 * normal(size=(sample_size, 1)) 80 | 81 | X = np.hstack((X1, X2, X3, X4, X5)) 82 | X = stats.zscore(X, ddof=1, axis=0) 83 | return pd.DataFrame(X, columns=['A', 'B', 'C', 'D', 'E']) 84 | -------------------------------------------------------------------------------- /cdmir/utils/independence/functional/kci.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy import exp 3 | from scipy import stats 4 | from scipy.linalg import eigh 5 | from scipy.spatial.distance import cdist, pdist, squareform 6 | 7 | 8 | class KCI: 9 | """ 10 | Simplified Kernel-based Conditional Independence test (unconditional version) 11 | Uses only Gaussian kernels for independence testing 12 | """ 13 | 14 | def __init__(self, null_samples=1000, use_gamma_approx=True): 15 | """ 16 | Initialize KCI test 17 | 18 | Parameters: 19 | null_samples: number of samples for null distribution (if not using gamma approximation) 20 | use_gamma_approx: whether to use gamma approximation for p-value (faster) 21 | """ 22 | 
self.null_samples = null_samples 23 | self.use_gamma_approx = use_gamma_approx 24 | 25 | def test(self, X, Y): 26 | """ 27 | Perform independence test between X and Y 28 | 29 | Parameters: 30 | X: input data matrix (n_samples x n_features_x) 31 | Y: input data matrix (n_samples x n_features_y) 32 | 33 | Returns: 34 | p_value: p-value for independence test 35 | test_stat: test statistic value 36 | """ 37 | n = X.shape[0] 38 | 39 | # Standardize data 40 | X = stats.zscore(X, ddof=1, axis=0) 41 | Y = stats.zscore(Y, ddof=1, axis=0) 42 | X[np.isnan(X)] = 0 43 | Y[np.isnan(Y)] = 0 44 | 45 | # Compute Gaussian kernel matrices 46 | Kx = self._gaussian_kernel(X) 47 | Ky = self._gaussian_kernel(Y) 48 | 49 | # Center kernel matrices 50 | H = np.eye(n) - np.ones((n, n)) / n 51 | Kx = H @ Kx @ H 52 | Ky = H @ Ky @ H 53 | 54 | # Compute test statistic (HSIC) 55 | test_stat = np.sum(Kx * Ky) 56 | 57 | # Compute p-value 58 | p_value = self._gamma_approx_pvalue(Kx, Ky, test_stat) 59 | 60 | return p_value, test_stat 61 | 62 | def _gaussian_kernel(self, data, width=1.): 63 | """ 64 | Compute Gaussian kernel matrix 65 | 66 | Parameters: 67 | data: input data matrix 68 | 69 | Returns: 70 | kernel_matrix: Gaussian kernel matrix 71 | """ 72 | n, d = data.shape 73 | 74 | sq_dists = squareform(pdist(data, 'sqeuclidean')) 75 | K = exp(-0.5 * sq_dists) 76 | 77 | return K 78 | 79 | def _gamma_approx_pvalue(self, Kx, Ky, test_stat): 80 | """ 81 | Compute p-value using Gamma distribution approximation 82 | 83 | Parameters: 84 | Kx: centered kernel matrix for X 85 | Ky: centered kernel matrix for Y 86 | test_stat: test statistic value 87 | 88 | Returns: 89 | p_value: approximated p-value 90 | """ 91 | n = Kx.shape[0] 92 | 93 | # Compute moments for Gamma approximation 94 | mean_approx = np.trace(Kx) * np.trace(Ky) / n 95 | var_approx = 2 * np.sum(Kx ** 2) * np.sum(Ky ** 2) / n ** 2 96 | 97 | # Gamma distribution parameters 98 | k = mean_approx ** 2 / var_approx 99 | theta = var_approx / 
from copy import deepcopy

import numpy as np

from cdmir.graph import DiGraph, Edge
from cdmir.graph.mark import Mark
from cdmir.graph.pdag import PDAG


# function Gd = PDAG2DAG(G) % transform a PDAG to DAG
def pdag2dag(pdag: PDAG):
    """Extend a PDAG to a DAG that agrees with all of its directed edges.

    Repeatedly finds a node with no outgoing directed edge whose neighbors
    satisfy the adjacency condition checked by `check2`, orients its
    undirected edges towards it, and removes it from consideration.
    NOTE(review): this looks like the Dor & Tarsi (1992) extension
    algorithm — confirm against the original paper.

    :param pdag: partially directed acyclic graph to extend.
    :return: a DiGraph that is a consistent extension of `pdag`.
    """
    nodes = pdag.node_list
    # first create a DAG that contains all the directed edges in PDAG
    pdag_copy = deepcopy(pdag)
    edges = pdag_copy.edges
    for edge in edges:
        # Keep only strictly directed edges (Tail--Arrow in either
        # orientation); everything else is removed from the working copy.
        if not ((edge.mark_u == Mark.Arrow and edge.mark_v == Mark.Tail) or (edge.mark_u == Mark.Tail and edge.mark_v == Mark.Arrow)):
            pdag_copy.remove_edge(edge.node_u, edge.node_v)

    pdag_p = deepcopy(pdag)
    # index whether the ith node has been removed. 1: removed; 0: not
    inde = np.zeros(pdag_p.number_of_nodes(), dtype=np.dtype(int))
    while 0 in inde:
        for i in range(pdag_p.number_of_nodes()):
            if inde[i] == 0:
                sign = 0
                # Condition 1: Xi has no out-going directed edge into the
                # still-remaining subgraph.
                if len(np.intersect1d(np.where([pdag_p.is_arrow(pdag_p.node_list[i], pdag_p.node_list[i_parent]) for i_parent in range(pdag_p.number_of_nodes())])[0], np.where(inde == 0)[0])) == 0:  # Xi has no out-going edges
                    sign = sign + 1
                    # Neighbors of Xi in P: joined to Xi by an undirected
                    # (Tail--Tail) edge and not yet removed.
                    Nx = np.intersect1d(np.intersect1d(np.where([pdag_p.is_tail(pdag_p.node_list[i], pdag_p.node_list[i_tail]) for i_tail in range(pdag_p.number_of_nodes())])[0], np.where([pdag_p.is_tail(pdag_p.node_list[i_tail], pdag_p.node_list[i]) for i_tail in range(pdag_p.number_of_nodes())])[0]), np.where(inde == 0)[0])  # find the neighbors of Xi in P
                    # Adjacents of Xi in P: connected to Xi by any edge
                    # (incoming or outgoing arrow), plus the neighbors below.
                    Ax = np.intersect1d(np.union1d(np.where([pdag_p.is_arrow(pdag_p.node_list[i_parent], pdag_p.node_list[i]) for i_parent in range(pdag_p.number_of_nodes())])[0], np.where([pdag_p.is_arrow(pdag_p.node_list[i], pdag_p.node_list[i_parent]) for i_parent in range(pdag_p.number_of_nodes())])[0]), np.where(inde == 0)[0])  # find the adjacent of Xi in P
                    Ax = np.union1d(Ax, Nx)
                    if len(Nx) > 0:
                        # Condition 2: every neighbor of Xi is adjacent to
                        # all other adjacents of Xi.
                        if check2(pdag_p, Nx, Ax):  # according to the original paper
                            sign = sign + 1
                    else:
                        sign = sign + 1
                if sign == 2:
                    # for each undirected edge Y-X in PDAG, insert a directed edge Y->X in G
                    for index in np.intersect1d(np.where([pdag_p.is_tail(pdag_p.node_list[i], pdag_p.node_list[i_tail]) for i_tail in range(pdag_p.number_of_nodes())])[0], np.where([pdag_p.is_tail(pdag_p.node_list[i_tail], pdag_p.node_list[i]) for i_tail in range(pdag_p.number_of_nodes())])[0]):
                        if not pdag_copy.is_connected(nodes[index], nodes[i]):
                            pdag_copy.add_edge(Edge(nodes[index], nodes[i], Mark.Tail, Mark.Arrow), overwrite=False)
                    inde[i] = 1

    # Convert the fully oriented working graph into a DiGraph, normalising
    # reversed edges so each is stored as node_u -> node_v (Tail--Arrow).
    d = DiGraph(pdag_copy.node_list)
    for edge in pdag_copy.edges:
        if edge.mark_u == Mark.Arrow and edge.mark_v == Mark.Tail:
            edge = Edge(edge.node_v, edge.node_u, edge.mark_v, edge.mark_u)
        d.add_edge(edge, overwrite=True)

    return d


def check2(G, Nx, Ax):
    """Return 1 if every node in Nx is adjacent (in G) to every other node
    of Ax, and 0 otherwise. Nx/Ax are index arrays into G.node_list."""
    s = 1
    for i in range(len(Nx)):
        # All adjacents of Xi except the current neighbor itself.
        j = np.delete(Ax, np.where(Ax == Nx[i])[0])
        # If any required adjacency is missing, the condition fails.
        if len(np.where([not G.is_connected(G.node_list[Nx[i]], G.node_list[jj]) for jj in j])[0]) != 0:
            s = 0
            break
    return s
# method 2
def test2(alph=0.08):
    """Demo: HSIC block test on two independent N(0,1) samples.

    NOTE(review): `alph` is accepted but never used — the raw p-value is
    returned instead of an accept/reject decision. Confirm with callers.
    """
    x = np.random.randn(1000)
    y = np.random.randn(1000)

    lens = len(x)
    # HSIC test objects expect column vectors (n_samples x 1).
    x = x.reshape(lens, 1)
    y = y.reshape(lens, 1)
    kernelY = GaussianKernel(float(0.45))
    kernelX = GaussianKernel(float(0.45))
    num_samples = lens
    myblockobject = HSICBlockTestObject(num_samples, kernelX=kernelX, kernelY=kernelY,
                                        kernelX_use_median=False, kernelY_use_median=False,
                                        blocksize=80, nullvarmethod='permutation')

    pvalue = myblockobject.compute_pvalue(x, y)

    return pvalue

# HSIC test by return hsic pval
def INtest(x, y, alph=0.01):
    """HSIC spectral test between samples x and y; returns the p-value.

    :param x: 1-D array-like of samples (reshaped to n x 1 internally).
    :param y: 1-D array-like of samples (reshaped to n x 1 internally).
    :param alph: unused — kept for API compatibility; the p-value is
        returned so the caller applies its own threshold.
    :return: p-value of the HSIC spectral independence test.
    """
    lens = len(x)
    x = x.reshape(lens, 1)
    y = y.reshape(lens, 1)
    # Kernel widths are chosen by the median heuristic inside the test object.
    kernelX = GaussianKernel()
    kernelY = GaussianKernel()

    num_samples = lens

    myspectralobject = HSICSpectralTestObject(num_samples, kernelX=kernelX, kernelY=kernelY,
                                              kernelX_use_median=True, kernelY_use_median=True,
                                              rff=True, num_rfx=30, num_rfy=30, num_nullsims=1000)
    pvalue = myspectralobject.compute_pvalue(x, y)

    return pvalue


def INtest2(x, y, alph=0.01):
    """HSIC block test between samples x and y; returns the p-value.

    Same contract as `INtest`, but uses the (cheaper) block estimator with
    a permutation-based null variance. `alph` is likewise unused.
    """
    lens = len(x)
    x = x.reshape(lens, 1)
    y = y.reshape(lens, 1)
    kernelX = GaussianKernel()
    kernelY = GaussianKernel()
    num_samples = lens

    myblockobject = HSICBlockTestObject(num_samples, kernelX=kernelX, kernelY=kernelY,
                                        kernelX_use_median=True, kernelY_use_median=True,
                                        blocksize=200, nullvarmethod='permutation')

    pvalue = myblockobject.compute_pvalue(x, y)

    return pvalue
learning discrete latent variable models with a three-pure-children structure. It uses tensor rank conditions to identify causal clusters from observed variables and then infers d-separation relationships among latent variables. 8 | 9 | This method consists of three main components: 10 | 11 | 1. **LCC (LearnCausalCluster)**: Identifies causal clusters from observed variables using tensor rank conditions 12 | 2. **Gtest**: Performs goodness of fit tests to determine tensor ranks 13 | 3. **DiscretePC**: Learns causal skeleton relationships among latent variables 14 | 15 | Usage 16 | ----- 17 | 18 | .. code-block:: python 19 | 20 | from cdmir.discovery.Tensor_Rank.LearnCausalCluster import LearnCausalCluster 21 | import cdmir.discovery.Tensor_Rank.DiscretePC as PC 22 | import pandas as pd 23 | from cdmir.datasets.pgmdata import Gdata2 24 | import pkg_resources 25 | 26 | # Example 1: Learn causal clusters 27 | data = Gdata2(100000) 28 | clusters = LearnCausalCluster(data, LSupp=2) 29 | print("Learned causal clusters:", clusters) 30 | 31 | # Example 2: Learn causal skeleton 32 | csv_path = pkg_resources.resource_filename('cdmir', 'tests/testdata/out.csv') 33 | data = pd.read_csv(csv_path) 34 | labels = ['L1', 'L2', 'L3'] 35 | cluster = {'L1': ['O1a', 'O1b', 'O1c'], 'L2': ['O2a', 'O2b', 'O2c'], 'L3': ['O3a', 'O3b', 'O3c']} 36 | adjacency_matrix = PC.test(data, labels, cluster) 37 | print("Causal adjacency matrix of latent variables:", adjacency_matrix) 38 | 39 | Parameters 40 | ---------- 41 | 42 | **LCC (LearnCausalCluster) Parameters:** 43 | 44 | - **data**: ndarray. 45 | 46 | Input data containing observed variables. 47 | - **LSupp**: int, optional, default: 2. 48 | 49 | Support set size for hidden variables. 50 | - **alpha**: float, optional, default: 0.05. 51 | 52 | Confidence level for the goodness of fit test. 53 | 54 | **Gtest (test_goodness_of_fit) Parameters:** 55 | 56 | - **tensor**: ndarray. 57 | 58 | The original four-way tensor to test. 
59 | - **rank**: int. 60 | 61 | The rank of the CP decomposition. 62 | 63 | **DiscretePC (test) Parameters:** 64 | 65 | - **data1**: ndarray. 66 | 67 | Data for all observed variables. 68 | - **la**: list. 69 | 70 | List of hidden variable names. 71 | - **cluster**: dict. 72 | 73 | Causal clustering composed of observed variables corresponding to hidden variables. 74 | - **alpha**: float, optional, default: 0.2. 75 | 76 | Significance level for conditional independent test. 77 | 78 | Returns 79 | ------- 80 | 81 | **LCC (LearnCausalCluster) Returns:** 82 | 83 | - **CausalCluster**: list. 84 | 85 | List of identified causal clusters, where each cluster is a list of variable names. 86 | 87 | **Gtest (test_goodness_of_fit) Returns:** 88 | 89 | - **chi_square_p_value**: float. 90 | 91 | P-value from the Chi-square goodness of fit test. 92 | 93 | **DiscretePC (test) Returns:** 94 | 95 | - **adjacency_matrix**: ndarray. 96 | 97 | Causal adjacency matrix of variables, where True indicates a direct causal relationship. 98 | 99 | References 100 | ---------- 101 | 102 | .. [1] Chen Z, Cai R, Xie F, et al. Learning Discrete Latent Variable Structures with Tensor Rank Conditions[C]//The Thirty-eighth Annual Conference on Neural Information Processing Systems. 103 | -------------------------------------------------------------------------------- /cdmir/utils/independence/functional/HSIC.py: -------------------------------------------------------------------------------- 1 | """ 2 | python implementation of Hilbert Schmidt Independence Criterion 3 | hsic_gam implements the HSIC test using a Gamma approximation 4 | Python 2.7.12 5 | Gretton, A., Fukumizu, K., Teo, C. H., Song, L., Scholkopf, B., 6 | & Smola, A. J. (2007). A kernel statistical test of independence. 7 | In Advances in neural information processing systems (pp. 585-592). 
def rbf_dot(pattern1, pattern2, deg):
    """Gaussian (RBF) Gram matrix between the rows of two sample matrices.

    Entry (i, j) is exp(-||p1_i - p2_j||^2 / (2 * deg^2)), where `deg` is
    the kernel width.
    """
    n1 = pattern1.shape[0]
    n2 = pattern2.shape[0]

    # Row-wise squared norms, as column vectors.
    sq1 = np.sum(pattern1 * pattern1, 1).reshape(n1, 1)
    sq2 = np.sum(pattern2 * pattern2, 1).reshape(n2, 1)

    # Pairwise squared distances via ||a||^2 + ||b||^2 - 2 a.b.
    sq_dist = np.tile(sq1, (1, n2)) + np.tile(sq2.T, (n1, 1)) - 2 * np.dot(pattern1, pattern2.T)

    # Gaussian kernel of the distances.
    return np.exp(-sq_dist / 2 / (deg ** 2))
def test(x, y, alpha=0.05):
    """HSIC independence decision using the Gamma approximation.

    :param x: 1-D numpy array of samples (reshaped to a column vector).
    :param y: 1-D numpy array of samples, same length as x.
    :param alpha: significance level forwarded to `hsic_gam`.
    :return: True if independence is NOT rejected (testStat <= thresh),
        False if it is rejected (testStat > thresh).
    """
    lens = len(x)
    x1 = x.reshape(lens, 1)
    y1 = y.reshape(lens, 1)
    testStat, thresh = hsic_gam(x1, y1, alpha)

    # BUG FIX: the original `if < / elif >` pair silently returned None when
    # testStat == thresh; the boundary now counts as "not rejected".
    return bool(testStat <= thresh)
def rbf_dot(pattern1, pattern2, deg):
    """Gaussian (RBF) Gram matrix between the rows of pattern1 and pattern2.

    Entry (i, j) is exp(-||p1_i - p2_j||^2 / (2 * deg^2)); `deg` is the
    kernel width.
    """
    size1 = pattern1.shape
    size2 = pattern2.shape

    # Row-wise squared norms, as column vectors.
    G = np.sum(pattern1 * pattern1, 1).reshape(size1[0], 1)
    H = np.sum(pattern2 * pattern2, 1).reshape(size2[0], 1)

    Q = np.tile(G, (1, size2[0]))
    R = np.tile(H.T, (size1[0], 1))

    # Pairwise squared distances via ||a||^2 + ||b||^2 - 2 a.b
    # (H is reused as a scratch variable here).
    H = Q + R - 2 * np.dot(pattern1, pattern2.T)

    # Gaussian kernel of the distances.
    H = np.exp(-H / 2 / (deg ** 2))

    return H
def independent(x, y, alpha=0.05):
    """HSIC independence decision using the Gamma approximation.

    :param x: 1-D numpy array of samples (reshaped to a column vector).
    :param y: 1-D numpy array of samples, same length as x.
    :param alpha: significance level forwarded to `hsic_gam`.
    :return: True if independence is NOT rejected (testStat <= thresh),
        False if it is rejected (testStat > thresh).
    """
    lens = len(x)
    x1 = x.reshape(lens, 1)
    y1 = y.reshape(lens, 1)
    testStat, thresh = hsic_gam(x1, y1, alpha)

    # BUG FIX: the original `if < / elif >` pair silently returned None when
    # testStat == thresh; the boundary now counts as "not rejected".
    return bool(testStat <= thresh)
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
    def __kernel_func(self, x: ndarray, y: ndarray):
        """Evaluate the Gaussian kernel between the rows of x and y.

        If y is None the (n, n) self-kernel of x is returned; otherwise the
        (n_x, n_y) cross-kernel. NOTE(review): self.width multiplies the
        squared distances directly, so it acts as a precision (theta =
        1/width**2) — consistent with the width-strategy helpers below,
        which return theta rather than a bandwidth.
        """
        if y is None:
            # Pairwise squared distances within x, converted to a square matrix.
            sq_dists = squareform(pdist(x, 'sqeuclidean'))
        else:
            # x and y must have the same number of columns (feature dimension).
            assert (shape(x)[1] == shape(y)[1])
            # Squared distances between the two sample sets.
            sq_dists = cdist(x, y, 'sqeuclidean')
        k = exp(-0.5 * sq_dists * self.width)
        return k
    @staticmethod
    def __cal_kernel_width_by_median(data_x):
        """Median-heuristic bandwidth selection.

        Computes the median pairwise Euclidean distance (on a random
        subsample of at most 1000 rows for speed), sets
        width = sqrt(2) * median, and returns theta = 1 / width**2 —
        i.e. a precision, not a width, matching __kernel_func's usage.
        """
        n = shape(data_x)[0]
        if n > 1000:
            # Subsample to keep the O(n^2) distance computation cheap.
            data_x = data_x[permutation(n)[:1000], :]
        dists = squareform(pdist(data_x, 'euclidean'))
        # Median over strictly positive distances (excludes the diagonal).
        median_dist = median(dists[dists > 0])
        width = sqrt(2.) * median_dist
        theta = 1.0 / (width ** 2)
        return theta
def plt_edge(axes,
             pos_u,
             pos_v,
             mark_u,
             mark_v,
             node_radius,
             circle_mark_radius,
             ):
    """Draw one edge between two node centers on `axes`.

    Endpoints are pulled in by `node_radius` so the line starts on each
    node's rim; Circle marks are drawn as small hollow circles at the
    endpoints, and Arrow marks become FancyArrowPatch arrowheads.
    """
    assert circle_mark_radius > 0

    # Unit direction from u to v, scaled by the node radius.
    dx = pos_v[0] - pos_u[0]
    dy = pos_v[1] - pos_u[1]
    dis = sqrt(dx ** 2 + dy ** 2)
    offset_x = node_radius * dx / dis
    offset_y = node_radius * dy / dis

    # Shrink both endpoints onto the node circles' rims.
    pos_u = (pos_u[0] + offset_x, pos_u[1] + offset_y)
    pos_v = (pos_v[0] - offset_x, pos_v[1] - offset_y)

    # Step size reserved for an optional circle mark at each end.
    offset_x = 2 * circle_mark_radius * dx / dis
    offset_y = 2 * circle_mark_radius * dy / dis

    if mark_u == Mark.Circle:
        # Hollow circle at the u end, then advance the endpoint past it.
        axes.add_patch(
            patches.Circle((pos_u[0] + offset_x / 2, pos_u[1] + offset_y / 2), circle_mark_radius, facecolor='white',
                           edgecolor='black'))
        pos_u = (pos_u[0] + offset_x, pos_u[1] + offset_y)
    if mark_v == Mark.Circle:
        # Hollow circle at the v end, then pull the endpoint back past it.
        axes.add_patch(
            patches.Circle((pos_v[0] - offset_x / 2, pos_v[1] - offset_y / 2), circle_mark_radius, facecolor='white',
                           edgecolor='black'))
        pos_v = (pos_v[0] - offset_x, pos_v[1] - offset_y)

    # Map the (mark_u, mark_v) pair onto a matplotlib arrow style.
    if mark_u == Mark.Arrow:
        if mark_v == Mark.Arrow:
            arrow_style = '<|-|>'
        else:
            arrow_style = '<|-'
    elif mark_v == Mark.Arrow:
        arrow_style = '-|>'
    else:
        arrow_style = '-'
    axes.add_patch(patches.FancyArrowPatch(pos_u, pos_v,
                                           edgecolor='black', facecolor='black',
                                           arrowstyle=arrow_style, mutation_scale=12, shrinkA=0, shrinkB=0))
first axis only 39 | perm_mat = np.random.permutation(np.eye(mat.shape[0])) 40 | return perm_mat.T @ mat @ perm_mat 41 | 42 | 43 | def _random_acyclic_orientation(graph): 44 | dag = np.tril(_random_permutation(graph), k=-1) 45 | dag_perm = _random_permutation(dag) 46 | return dag_perm 47 | 48 | 49 | def _adj2weights(adj_mat, mat_dim, w_range): 50 | uni_mat = np.random.uniform(low=w_range[0], high=w_range[1], size=[mat_dim, mat_dim]) 51 | uni_mat[np.random.rand(mat_dim, mat_dim) < 0.5] *= -1 # reverse 50% of the weights 52 | weight_mat = (adj_mat != 0).astype(float) * uni_mat 53 | return weight_mat 54 | 55 | 56 | def check_data(inputs): 57 | # 检查数据类型,一律转换为 ndarray 58 | assert isinstance(inputs, (np.ndarray, pd.DataFrame)), "plearse input ndarray or dataframe" 59 | if isinstance(inputs, pd.DataFrame): 60 | return pd2np(inputs) 61 | return inputs 62 | 63 | 64 | def erdos_renyi(n_nodes: int, n_edges: int, weight_range: Union[Sequence[float], None] = None, 65 | seed: Optional[int] = None): 66 | assert n_nodes > 0, "The numbers of nodes must be larger than 0" 67 | set_random_seed(seed) 68 | # erdos renyi 69 | egdes_prob = (n_edges * 2) / (n_nodes ** 2) 70 | nx_graph = nx.erdos_renyi_graph(n=n_nodes, p=egdes_prob, seed=seed) 71 | np_graph = nx2np(nx_graph) 72 | np_dag = _random_acyclic_orientation(np_graph) 73 | if weight_range is None: 74 | return np_dag 75 | else: 76 | weights = _adj2weights(np_dag, n_nodes, weight_range) 77 | return weights 78 | 79 | 80 | def _generate_uniform_mat(n_nodes, cond_thresh): 81 | """ 82 | generate a random matrix by sampling each element uniformly at random 83 | check condition number versus a condition threshold 84 | """ 85 | A = np.random.uniform(0, 2, (n_nodes, n_nodes)) - 1 86 | for i in range(n_nodes): 87 | A[:, i] /= np.sqrt(((A[:, i]) ** 2).sum()) 88 | 89 | while np.linalg.cond(A) > cond_thresh: 90 | # generate a new A matrix! 
def calculate_joint_distribution2(df, var1, var2, var3, var4):
    """Compute the joint frequency (count) tensor of four discrete variables.

    Note: despite the historical name, this returns raw *counts*, not a
    normalised probability tensor -- the downstream chi-square goodness-of-fit
    test needs observed frequencies.  (The original also computed a normalised
    probability tensor and discarded it; that dead code has been removed.)

    :param df: pandas.DataFrame containing the observed samples.
    :param var1: column name of the first variable (tensor axis 0).
    :param var2: column name of the second variable (tensor axis 1).
    :param var3: column name of the third variable (tensor axis 2).
    :param var4: column name of the fourth variable (tensor axis 3).
    :return: 4-way ndarray of joint counts; each axis is indexed by the
        sorted unique values of the corresponding column.
    """
    # Joint frequency table: rows indexed by (var1, var2, var3), columns by var4.
    joint_freq = pd.crosstab(index=[df[var1], df[var2], df[var3]], columns=df[var4])

    tensor_shape = (len(df[var1].unique()), len(df[var2].unique()),
                    len(df[var3].unique()), len(df[var4].unique()))
    joint_freq_tensor = np.zeros(tensor_shape)

    # Fill the tensor cell by cell; value combinations absent from the
    # crosstab keep their initial count of zero.
    for i, val1 in enumerate(sorted(df[var1].unique())):
        for j, val2 in enumerate(sorted(df[var2].unique())):
            for k, val3 in enumerate(sorted(df[var3].unique())):
                for l, val4 in enumerate(sorted(df[var4].unique())):
                    if (val1, val2, val3) in joint_freq.index and val4 in joint_freq.columns:
                        joint_freq_tensor[i, j, k, l] = joint_freq.loc[(val1, val2, val3), val4]

    return joint_freq_tensor
def LearnCausalCluster(data, LSupp=2, alhpa=0.05):
    """Identify causal clusters among the observed variables.

    For each 3-combination of observed variables, the tensor rank of the
    joint distribution with every remaining variable as a fourth mode is
    tested; a triple whose every test is non-significant is a cluster.

    :param data: pandas.DataFrame of observed (discrete) variables.
    :param LSupp: int, support size of the hidden variable (CP rank tested).
    :param alhpa: float, significance level of the rank goodness-of-fit test.
        (Parameter name keeps the original spelling for caller compatibility.)
    :return: list of 3-element lists, each a learned causal cluster.
    """
    indexs = list(data.columns)

    combinations_list = list(combinations(indexs, 3))

    CausalCluster = []

    for clist in combinations_list:
        # Remaining variables serve as the fourth tensor mode.
        tempindex = [name for name in indexs if name not in clist]

        # Fix: the pass/fail flag is now local to each candidate triple; the
        # original shared one flag across iterations, so a single failed test
        # could suppress all subsequently found clusters.
        is_cluster = True
        for v4 in tempindex:
            v1, v2, v3 = clist

            # Joint probability tensor over (v1, v2, v3, v4).
            tensor = calculate_joint_distribution2(data, v1, v2, v3, v4)

            # Goodness-of-fit test of the tensor rank against LSupp.
            pval = Gtest.test_goodness_of_fit(tensor, LSupp)

            print(clist, v4, pval)

            if pval < alhpa:
                is_cluster = False
                break

        print('-------------------')
        if is_cluster:
            print('This is a causal cluster: ', clist)
            CausalCluster.append(list(clist))
            print('-------------------')

    return CausalCluster
def skeleton_evaluation(true_graph: Graph, est_graph: Graph):
    """Score the estimated skeleton (undirected adjacency) against the truth.

    Both graphs must have the same type and node set.  Every unordered node
    pair is labelled connected / not connected, and precision, recall and F1
    of the estimated labels are returned in a dict.
    """
    _miss_match_graph_type(true_graph, est_graph)
    _miss_match_graph_nodes(true_graph, est_graph)

    def pairwise_adjacency(graph):
        # One boolean per unordered node pair, in combinations() order.
        return [graph.is_connected(u, v) for u, v in combinations(graph.nodes, 2)]

    y_true = pairwise_adjacency(true_graph)
    y_pred = pairwise_adjacency(est_graph)

    return {
        'precision': precision_score(y_true, y_pred),
        'recall': recall_score(y_true, y_pred),
        'f1': f1_score(y_true, y_pred)
    }
def get_performance(fitted, real, threshold=0, drop_diag=True):
    """Compare an estimated weighted adjacency against a binary ground truth.

    Entries of ``|fitted|`` strictly above *threshold* count as predicted
    edges.  With ``drop_diag=True`` the diagonal of both matrices is zeroed
    before scoring.  Graph inputs are converted via ``to_numpy()``.

    :return: one-row DataFrame with columns F1, Precision, Recall, threshold.
    """
    if isinstance(fitted, Graph):
        fitted = fitted.to_numpy()
    if isinstance(real, Graph):
        real = real.to_numpy()

    fitted = np.abs(fitted)
    if drop_diag:
        # Remove self-loops from both matrices before scoring.
        fitted = fitted - np.diag(np.diag(fitted))
        real = real - np.diag(np.diag(real))

    y_true = real.ravel()
    y_pred = np.array(fitted.ravel() > threshold)

    scores = (
        f1_score(y_true=y_true, y_pred=y_pred),
        precision_score(y_true=y_true, y_pred=y_pred),
        recall_score(y_true=y_true, y_pred=y_pred),
        threshold,
    )
    result = pd.DataFrame(columns=['F1', "Precision", "Recall", "threshold"])
    result.loc[0] = np.array(scores)
    return result
Based on the paper "Causal Discovery with Latent Confounders Based on Higher-Order Cumulants", this algorithm identifies causal relationships and latent confounders by leveraging the properties of higher-order cumulants and conditional independence tests. 8 | 9 | Usage 10 | ----- 11 | 12 | .. code-block:: python 13 | 14 | import numpy as np 15 | from cdmir.discovery.funtional_based.one_component.olc import olc 16 | 17 | # Generate or load data 18 | # Example: 1000 samples, 5 variables 19 | data = np.random.randn(1000, 5) 20 | 21 | # Set significance thresholds 22 | alpha = 0.05 # Primary significance level 23 | beta = 0.01 # Secondary significance level for more stringent tests 24 | 25 | # Run OLC algorithm 26 | adjmat, coef = olc(data, alpha=alpha, beta=beta, verbose=False) 27 | 28 | # Print results 29 | print("Adjacency Matrix:") 30 | print(adjmat) 31 | print("\nCoefficient Matrix:") 32 | print(coef) 33 | 34 | Parameters 35 | ---------- 36 | 37 | - **data**: Input data matrix of shape (n_samples, n_variables), where rows represent samples and columns represent variables. 38 | - **alpha**: Significance threshold for initial edge orientation tests (default: 0.05). 39 | - **beta**: Significance threshold for more stringent tests involving higher-order cumulants (default: 0.01). 40 | - **verbose**: If True, prints detailed information during the algorithm execution (default: False). 41 | 42 | Returns 43 | ------- 44 | 45 | - **adjmat**: Adjacency matrix of the discovered causal graph. The matrix has shape (n_variables + n_latents, n_variables + n_latents), where: 46 | - 0: No edge 47 | - 1: Directed edge 48 | - 2: Undirected edge (ambiguous direction) 49 | - Latent variables are indexed from n_variables onwards. 50 | 51 | - **coef**: Coefficient matrix of the discovered causal relationships. It has the same shape as adjmat and contains the estimated coefficients for each directed edge. 
52 | 53 | Algorithm Overview 54 | ------------------ 55 | 56 | OLC follows a structured approach to causal discovery with latent confounder detection: 57 | 58 | 1. **Initialization** 59 | - Create an undirected graph (UDG) with all possible edges 60 | - Create an empty directed graph (CG) for causal relationships 61 | - Initialize KCI (Kernel-based Conditional Independence) test for independence testing 62 | 63 | 2. **Edge Orientation Phase** 64 | - Test edge orientations using linear regression and KCI tests 65 | - Remove edges and orient them in the directed graph based on significance tests 66 | - Normalize residuals and update data 67 | 68 | 3. **Clique Detection and Latent Confounder Detection** 69 | - Identify cliques in the undirected graph 70 | - Use surrogate regression to handle complex relationships 71 | - Apply higher-order cumulant (4th order) analysis to detect latent confounders 72 | - Update the adjacency matrix with detected latent confounders 73 | 74 | 4. **Refinement** 75 | - Iteratively refine the graph structure using conditional independence tests 76 | - Update surrogate variables and exogenous variables 77 | - Adjust edge orientations based on cumulant-based tests 78 | 79 | Key Techniques 80 | -------------- 81 | 82 | - **Higher-Order Cumulants**: Uses 4th order cumulants to detect latent confounders that cannot be identified using traditional covariance-based methods. 83 | 84 | - **KCI Tests**: Employs Kernel-based Conditional Independence tests for robust independence testing between variables and residuals. 85 | 86 | - **Surrogate Regression**: Implements surrogate regression to handle complex causal relationships involving multiple variables. 87 | 88 | - **Fisher's Combination Test**: Combines multiple p-values to enhance statistical power. 89 | 90 | References 91 | ---------- 92 | 93 | .. [1] Cai R, Huang Z, Chen W, et al. 
def frobenius_norm_error(tensor, reconstructed_tensor):
    """Relative Frobenius-norm error ||T - T_hat||_F / ||T||_F.

    Both inputs are cast to float64 before the norms are taken, so integer
    count tensors are handled correctly.
    """
    observed = tensor.astype(np.float64)
    approx = reconstructed_tensor.astype(np.float64)

    residual_norm = np.linalg.norm((observed - approx).ravel(), ord=2)
    observed_norm = np.linalg.norm(observed.ravel(), ord=2)
    return residual_norm / observed_norm
def t_test_goodness_of_fit(observed, expected):
    """One-sample t-test of the residuals (observed - expected) against mean 0.

    Returns the (t statistic, p-value) pair from scipy's ``ttest_1samp``.
    """
    diffs = observed.astype(np.float64) - expected.astype(np.float64)
    t_stat, p_val = stats.ttest_1samp(diffs.ravel(), 0)
    return t_stat, p_val
76 | """ 77 | reconstructed_tensor, _ = compute_cp_decomposition(tensor, rank) 78 | relative_error = frobenius_norm_error(tensor, reconstructed_tensor) 79 | 80 | # Use the original tensor data as the observed counts 81 | observed = tensor 82 | expected = reconstructed_tensor 83 | 84 | chi_square_statistic, chi_square_p_value = chi_square_goodness_of_fit(observed, expected) 85 | t_statistic, t_p_value = t_test_goodness_of_fit(observed, expected) 86 | 87 | return chi_square_p_value 88 | 89 | return { 90 | 'Frobenius Norm Error': relative_error, 91 | 'Chi-square Statistic': chi_square_statistic, 92 | 'Chi-square P-value': chi_square_p_value, 93 | 'T-test Statistic': t_statistic, 94 | 'T-test P-value': t_p_value 95 | } 96 | -------------------------------------------------------------------------------- /cdmir/utils/independence/_base.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import List, Tuple 4 | 5 | from numpy import ndarray 6 | from pandas import DataFrame 7 | 8 | 9 | class BaseConditionalIndependenceTest(object): 10 | """ 11 | Base model for conditional independence test (CIT). 
12 | """ 13 | 14 | @staticmethod 15 | def __init_input_data(data: ndarray | DataFrame): 16 | if type(data) == ndarray: 17 | var_dim = data.shape[1] 18 | var_names = ["x%d" % i for i in range(var_dim)] 19 | var_values = data 20 | else: 21 | var_names = data.columns 22 | var_values = data.values 23 | name_dict = {var_names[index]: index for index in range(len(var_names))} 24 | name_to_index = lambda name: name_dict[name] 25 | index_to_name = lambda index: var_names[index] 26 | return var_values, var_names, name_to_index, index_to_name 27 | 28 | def __init__(self, data: ndarray | DataFrame, *args, **kwargs): 29 | self._data, self._names, self._name_to_index, self._index_to_name = self.__init_input_data(data) 30 | self.cache_dict = dict() 31 | 32 | def __input_to_list(self, input: int | str | List[int | str] | ndarray | None) -> List[int] | None: 33 | if input is None: 34 | final_res = None 35 | elif type(input) == int: 36 | final_res = [input] 37 | elif type(input) == str: 38 | final_res = [self._name_to_index(input)] 39 | elif type(input) == list: 40 | if len(input) == 0: 41 | final_res = None 42 | elif type(input[0]) == int: 43 | final_res = input 44 | elif type(input[0]) == str: 45 | final_res = [self._name_to_index(name) for name in input] 46 | else: 47 | raise Exception("data type should be int or str!") 48 | else: 49 | if len(input) == 0: 50 | final_res = None 51 | else: 52 | final_res = input.tolist() 53 | return final_res 54 | 55 | def _compute_p_value(self, xs: int | str | List[int | str] | ndarray, 56 | ys: int | str | List[int | str] | ndarray, 57 | zs: int | str | List[int | str] | ndarray | None = None, 58 | compute_p_value_without_condition_func=None, 59 | compute_p_value_with_condition_func=None) -> Tuple[float, float | ndarray | None]: 60 | x_ids = self.__input_to_list(xs) 61 | y_ids = self.__input_to_list(ys) 62 | z_ids = self.__input_to_list(zs) 63 | hash_key = hash(str((x_ids, y_ids, z_ids))) 64 | if self.cache_dict.__contains__(hash_key): 65 | 
    def add_edge(self, edge: Edge, overwrite=False):
        """Insert a directed edge u -> v into the graph.

        Only tail->arrow edges are legal in a DiGraph: the edge must carry a
        Tail mark at its u endpoint and an Arrow mark at its v endpoint;
        ``check_mark`` rejects anything else before the base-class insertion.

        :param edge: the Edge to insert (interpreted as u -> v).
        :param overwrite: forwarded to the base Graph implementation;
            presumably allows replacing an existing edge between the same
            nodes -- behaviour defined outside this file, confirm there.
        """
        self.check_mark(edge.mark_u, [Mark.Tail])
        self.check_mark(edge.mark_v, [Mark.Arrow])
        super().add_edge(edge, overwrite=overwrite)
    def get_reachable_nodes(self, x, z: Iterable = None):
        '''
        Yield every node reachable from x along an active trail given the
        conditioning set z ("Bayes-ball" / Reachable procedure).

        Koller, D., & Friedman, N. (2009). Probabilistic Graphical Models: Principles and Techniques (1st ed.). The MIT Press.

        Parameters
        ----------
        x
            Source node; it is itself yielded (first) when not in z.
        z : Iterable
            Conditioning (observed) nodes.
            NOTE(review): the default None would crash at deque(z) below;
            callers such as is_d_separate always pass an iterable --
            consider defaulting to ().

        Returns
        -------
        generator
            Nodes d-connected to x given z.  A node may be yielded twice
            (once per traversal direction), since visitation is tracked per
            (node, direction) pair.

        '''

        # Phase I: Insert all ancestors of z into a.
        # These are needed later to decide whether a collider is "opened"
        # by an observed node or one of its observed descendants.
        node_l_queue = deque(z)
        a = set()
        node_l_visited = {node: False for node in self.nodes}
        for node in z:
            node_l_visited[node] = True

        # Standard BFS towards parents; a ends up holding z and all its ancestors.
        while len(node_l_queue) != 0:
            y = node_l_queue.popleft()
            if y not in a:
                for pa_y in self.get_parents(y):
                    if not node_l_visited[pa_y]:
                        node_l_queue.append(pa_y)
                        node_l_visited[pa_y] = True
                a |= {y}

        # Phase II: Traverse active trails starting from x

        # 0 means trailing up through y (arrived from a child, edge points into y).
        # 1 means trailing down through y (arrived from a parent, edge leaves y).
        node_direction_l_queue = deque([(x, 1)])
        node_direction_l_visited = {(node, direction): False for node, direction in product(self.nodes, [0, 1])}
        node_direction_l_visited[(x, 1)] = True

        def push_queue(key):
            # Enqueue a (node, direction) state at most once.
            if not node_direction_l_visited[key]:
                node_direction_l_queue.append(key)
                node_direction_l_visited[key] = True

        while len(node_direction_l_queue) != 0:
            y, direction = node_direction_l_queue.popleft()
            # Any visited non-observed node is reachable (d-connected to x).
            if y not in z:
                yield y
            if direction == 1 and y not in z:
                # Trailing down through an unobserved y: both continuing to
                # its parents and descending to its children keep the trail active.
                for pa_y in self.get_parents(y):
                    push_queue((pa_y, 1))
                for ch_y in self.get_children(y):
                    push_queue((ch_y, 0))

            elif direction == 0:
                # Trailing up through y (chain / fork case) is active only
                # when y is unobserved.
                if y not in z:
                    for ch_y in self.get_children(y):
                        push_queue((ch_y, 0))
                # Collider case: the v-structure at y is opened when y is in
                # a, i.e. y is in z or has a descendant in z.
                if y in a:
                    for pa_y in self.get_parents(y):
                        push_queue((pa_y, 1))
def check_and_create(path: str):
    """Create directory *path* if it does not already exist.

    Uses ``os.makedirs(..., exist_ok=True)``, which (a) is race-free
    compared with the original exists()-then-mkdir check and (b) also
    creates any missing parent directories.

    :param path: directory path to ensure exists.
    """
    os.makedirs(path, exist_ok=True)
def INSEM_data(sample_size=10000, lambda_x=1, theta=0.5, lambda_e=1, seed=None):
    """Generate event-sequence data from an instantaneous SEM:
    X ~ Poisson(lambda_x), Y = BinomialThinning(X, theta) + Poisson(lambda_e).

    Each sample i becomes one discrete time step: X[i] events of type 0 and
    Y[i] events of type 1 are emitted with time_stamp i, all in sequence 0.

    :param sample_size: number of time steps (samples) to simulate.
    :param lambda_x: Poisson rate of the cause variable X.
    :param theta: per-event retention probability of the thinning X -> Y.
    :param lambda_e: Poisson rate of the exogenous noise added to Y.
    :param seed: seed for the local RandomState (None -> nondeterministic).
    :return: DataFrame with columns ['seq_id', 'time_stamp', 'event_type'].
    """
    rand_state = np.random.RandomState(seed=seed)
    # X[i] ~ Poisson(lambda_x): number of type-0 events at time step i.
    X = rand_state.poisson(lambda_x, sample_size)
    Y = np.zeros(sample_size, dtype='int')

    def operator(X):
        # Binomial thinning: keep each of the X parent events with prob. theta.
        return sum(rand_state.binomial(1, theta, X))

    for i in range(sample_size):
        Y[i] = operator(X[i])
    # Exogenous Poisson noise events added to Y.
    Y = Y + rand_state.poisson(lambda_e, sample_size)
    t = 0
    df = pd.DataFrame()
    # Emit X[i] rows [seq_id=0, time_stamp=t, event_type=0] per time step.
    # NOTE(review): t is advanced once per *sample*, so all events of sample
    # i share time_stamp i, keeping X and Y aligned in time (which the
    # instantaneous X -> Y dependence requires) -- confirm this matches the
    # indentation of the original source.
    for i, n in enumerate(X):
        term = []
        for j in range(n):
            term.append([0, t, 0])
        t = t + 1
        df = pd.concat([df, pd.DataFrame(term)])
    t = 0
    # Same emission for Y, restarting the clock at 0 so both event types
    # cover the same time range.
    for i, n in enumerate(Y):
        term = []
        for j in range(n):
            term.append([0, t, 1])
        t = t + 1
        df = pd.concat([df, pd.DataFrame(term)])
    df.columns = ['seq_id', 'time_stamp', 'event_type']
    return df
class GraphConvolution(Module):
    """
    Simple GCN layer, similar to https://arxiv.org/abs/1609.02907

    Computes ``output = adj @ (input @ weight) [+ bias]`` for a (sparse)
    adjacency matrix ``adj`` and node-feature matrix ``input``.  Typically
    used as a building block in graph neural networks for tasks like node
    classification, link prediction, and graph representation learning.

    :param in_features (int): Dimension of input node features
    :param out_features (int): Dimension of output node features after convolution
    :param weight (torch.nn.Parameter): Learnable weight matrix of shape (in_features, out_features)
    :param bias (torch.nn.Parameter or None): Optional learnable bias vector of shape (out_features,)
    """

    def __init__(self, in_features, out_features, bias=True):
        """
        Initialize graph convolutional layer.

        :param in_features (int): Number of input features per node
        :param out_features (int): Number of output features per node after convolution
        :param bias (bool, optional): Whether to include a learnable bias term. Defaults to True.
        """
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Assigning a Parameter attribute on an nn.Module registers it
        # automatically; the original's extra register_parameter() calls
        # after these assignments were redundant and have been removed.
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            # Register the name explicitly so `self.bias` exists (as None).
            self.register_parameter('bias', None)
        # Fill weight/bias with proper initial values.
        self.reset_parameters()

    def reset_parameters(self):
        """
        Initialize learnable parameters using a uniform distribution.

        Weights (and bias, if present) are drawn uniformly from
        [-1/sqrt(out_features), 1/sqrt(out_features)], following the
        reference GCN implementation; this helps stabilize the variance of
        activations during training.
        """
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input, adj):
        """
        Implements the core GCN operation: H^{(l+1)} = Ã H^{(l)} W^{(l)} + b.

        :param input (torch.Tensor): Input node feature matrix of shape
            (N, in_features), where N is the number of nodes
        :param adj (torch.Tensor): Sparse adjacency matrix of shape (N, N)
        Returns:
            torch.Tensor: Output feature matrix of shape (N, out_features)
        """
        # Step 1: Feature transformation (H * W)
        support = torch.mm(input, self.weight)
        # Step 2: Neighbor aggregation (A * (H * W)); spmm handles sparse adj
        output = torch.spmm(adj, support)
        # Step 3: Add bias term (if present)
        if self.bias is not None:
            return output + self.bias
        return output

    def __repr__(self):
        """
        String representation for debugging, e.g. 'GraphConvolution (16 -> 8)'.
        """
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'