├── .github
│   └── workflows
│       ├── cloc.yml
│       └── pylint.yml
├── README.md
├── assets
│   ├── overview_fig.png
│   ├── snp100.csv
│   └── snp100_results.png
├── configs
│   ├── data_generation
│   │   ├── perturb_data.yaml
│   │   ├── synthetic_data_linear.yaml
│   │   └── synthetic_data_nonlinear.yaml
│   ├── dream3
│   │   ├── dynotears.yaml
│   │   ├── mcd.yaml
│   │   ├── mcd_linear.yaml
│   │   ├── pcmci.yaml
│   │   ├── rhino.yaml
│   │   ├── rhino_linear.yaml
│   │   └── varlingam.yaml
│   ├── main.yaml
│   ├── netsim
│   │   ├── dynotears.yaml
│   │   ├── mcd.yaml
│   │   ├── mcd_linear.yaml
│   │   ├── pcmci.yaml
│   │   ├── rhino.yaml
│   │   ├── rhino_linear.yaml
│   │   └── varlingam.yaml
│   ├── snp100
│   │   ├── dynotears.yaml
│   │   ├── mcd.yaml
│   │   ├── mcd_linear.yaml
│   │   ├── pcmci.yaml
│   │   ├── rhino.yaml
│   │   └── varlingam.yaml
│   └── synthetic
│       ├── dynotears.yaml
│       ├── mcd.yaml
│       ├── mcd_linear.yaml
│       ├── pcmci.yaml
│       ├── rhino.yaml
│       ├── rhino_linear.yaml
│       └── varlingam.yaml
├── scripts
│   ├── generate_perturb_datasets.sh
│   ├── generate_snp100.sh
│   ├── generate_synthetic_datasets_linear.sh
│   ├── generate_synthetic_datasets_nonlinear.sh
│   ├── setup_dream3.sh
│   └── setup_netsim.sh
├── setup.py
└── src
    ├── __init__.py
    ├── baselines
    │   ├── BaselineTrainer.py
    │   ├── DYNOTEARSTrainer.py
    │   ├── PCMCITrainer.py
    │   ├── VARLiNGAMTrainer.py
    │   └── __init__.py
    ├── dataset
    │   ├── BaselineTSDataset.py
    │   └── FragmentDataset.py
    ├── model
    │   ├── BaseTrainer.py
    │   ├── MCDTrainer.py
    │   ├── RhinoTrainer.py
    │   └── generate_model.py
    ├── modules
    │   ├── CausalDecoder.py
    │   ├── LinearCausalGraph.py
    │   ├── MixtureSelectionLogits.py
    │   ├── MultiCausalDecoder.py
    │   ├── MultiEmbedding.py
    │   ├── MultiLinearCausalGraph.py
    │   ├── MultiTemporalHyperNet.py
    │   ├── TemporalConditionalSplineFlow.py
    │   ├── TemporalHyperNet.py
    │   └── adjacency_matrices
    │       ├── AdjMatrix.py
    │       ├── MultiTemporalAdjacencyMatrix.py
    │       ├── TemporalAdjacencyMatrix.py
    │       ├── ThreeWayGraphDist.py
    │       ├── TwoWayGraphDist.py
    │       └── TwoWayTemporalAdjacencyMatrix.py
    ├── train.py
    ├── training
    │   └── auglag.py
    └── utils
        ├── causality_utils.py
        ├── config_utils.py
        ├── data_gen
        │   ├── data_generation_utils.py
        │   ├── generate_perturb_syn.py
        │   ├── generate_stock.py
        │   ├── generate_synthetic_data.py
        │   ├── process_dream3.py
        │   ├── process_netsim.py
        │   └── splines.py
        ├── data_utils
        │   ├── data_format_utils.py
        │   └── dataloading_utils.py
        ├── loss_utils.py
        ├── metrics_utils.py
        ├── torch_utils.py
        └── utils.py
/.github/workflows/cloc.yml:
--------------------------------------------------------------------------------
1 | name: Count Lines of Code
2 |
3 | # Controls when the action will run. Triggers the workflow on pull request
4 | # events
5 | on: [pull_request]
6 |
7 | # A workflow run is made up of one or more jobs that can run sequentially or in parallel
8 | jobs:
9 |
10 | # This workflow contains a single job called "cloc"
11 | cloc:
12 |
13 | # The type of runner that the job will run on
14 | runs-on: ubuntu-latest
15 |
16 | # Steps represent a sequence of tasks that will be executed as part of the job
17 | steps:
18 |
19 | # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
20 | - name: Checkout repo
21 | uses: actions/checkout@v3
22 |
23 | # Installs cloc and generates a CSV report
24 | - name: Install and run cloc
25 | run: |
26 | sudo apt-get install cloc
27 | cloc src --csv --quiet --report-file=cloc_report.csv
28 |
29 | - name: Read CSV
30 | id: csv
31 | uses: juliangruber/read-file-action@v1
32 | with:
33 | path: ./cloc_report.csv
34 |
35 | - name: Create MD
36 | uses: petems/csv-to-md-table-action@master
37 | id: csv-table-output
38 | with:
39 | csvinput: ${{ steps.csv.outputs.content }}
40 |
41 | - name: Write file
42 | uses: "DamianReeves/write-file-action@master"
43 | with:
44 | path: ./cloc_report.md
45 | write-mode: overwrite
46 | contents: |
47 | ${{steps.csv-table-output.outputs.markdown-table}}
48 |
49 | - name: PR comment with file
50 | uses: thollander/actions-comment-pull-request@v2
51 | with:
52 | filePath: ./cloc_report.md
53 |
--------------------------------------------------------------------------------
/.github/workflows/pylint.yml:
--------------------------------------------------------------------------------
1 | name: Pylint
2 |
3 | on: [push]
4 |
5 | jobs:
6 | build:
7 | runs-on: ubuntu-latest
8 | strategy:
9 | matrix:
10 | python-version: ["3.8", "3.9", "3.10"]
11 | steps:
12 | - uses: actions/checkout@v4
13 | - name: Set up Python ${{ matrix.python-version }}
14 | uses: actions/setup-python@v3
15 | with:
16 | python-version: ${{ matrix.python-version }}
17 | - name: Install dependencies
18 | run: |
19 | python -m pip install --upgrade pip
20 | pip install pylint
21 | - name: Analysing the code with pylint
22 | run: |
23 | pylint --disable=E402,E731,F541,W291,E122,E127,F401,E266,E241,C901,E741,W293,F811,W503,E203,F403,F405,B007,E0401,W0221 --max-line-length=150 $(git ls-files '*.py')
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Discovering Mixtures of Structural Causal Models from Time Series Data
2 |
3 | Implementation of the paper "Discovering Mixtures of Structural Causal Models from Time Series Data", to appear at ICML 2024, Vienna.
4 |
5 | Mixture Causal Discovery (MCD) aims to infer multiple causal graphs from time-series data.
6 |
7 |
8 | ![Model Overview](assets/overview_fig.png)
9 |
10 | *Model Overview*
11 |
12 | ![Results on the S&P100 dataset](assets/snp100_results.png)
13 |
14 | *Results on the S&P100 dataset*
15 |
16 |
17 |
18 | ## Requirements
19 |
20 | - An NVIDIA GPU with CUDA 11.8 or later installed.
21 | - Make sure you have `conda` installed.
22 |
23 | ## Setup
24 |
25 | Create a conda environment and install the prerequisite packages:
26 | ```
27 | conda create -n mcd python=3.9 -y && \
28 | conda run --no-capture-output -n mcd pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 && \
29 | conda run --no-capture-output -n mcd pip3 install lightning matplotlib numpy scikit-learn seaborn \
30 | cdt wandb igraph pyro-ppl hydra-core yahoofinancials
31 | ```
32 |
33 | You also need the `graphviz` library. This library can be installed on Ubuntu systems using the command:
34 | ```
35 | sudo apt-get install -y git && \
36 | sudo apt-get install -y graphviz graphviz-dev
37 | ```
38 |
39 | For the baselines, create a separate environment:
40 | ```
41 | conda create -n baselines python=3.8 -y && \
42 | conda run --no-capture-output -n baselines pip3 install pygraphviz wandb tigramite hydra-core pyro-ppl lightning causalnex matplotlib cdt seaborn lingam
43 | ```
44 |
45 |
46 | ## Dataset generation
47 |
48 | Generate the datasets using the script files in the `scripts/` folder.
49 |
50 | - Linear synthetic dataset: Run `./scripts/generate_synthetic_datasets_linear.sh`
51 | - Nonlinear synthetic dataset: Run `./scripts/generate_synthetic_datasets_nonlinear.sh`
52 | - Netsim datasets: Run `./scripts/setup_netsim.sh`
53 | - DREAM3: Run `./scripts/setup_dream3.sh`
54 | - S&P100: Run `./scripts/generate_snp100.sh`
55 |
56 | ## Running the code
57 |
58 | Change the name of the `wandb` project in the config file (the `wandb_project` key in `configs/main.yaml`). With standard Hydra behaviour, it can presumably also be overridden on the command line by appending `wandb_project=<project_name>` to the run commands below.
59 |
60 | - Linear synthetic dataset: Run `python3 -m src.train +dataset=ER_ER_num_graphs__lag_2_dim__NoHistDep_0.5_linear_gaussian_con_1_seed_0_n_samples_1000 +synthetic=mcd_linear`. Fill in the blanks after `num_graphs_` and `dim_` in the dataset name to match the dataset you generated (see `configs/data_generation/synthetic_data_linear.yaml`).
61 | - Nonlinear synthetic dataset: Run `python3 -m src.train +dataset=ER_ER_num_graphs__lag_2_dim__HistDep_0.5_mlp_spline_product_con_1_seed_0_n_samples_1000 +synthetic=mcd`. Fill in the blanks after `num_graphs_` and `dim_` in the dataset name to match the dataset you generated (see `configs/data_generation/synthetic_data_nonlinear.yaml`).
62 | - Netsim-mixture: Run `python3 -m src.train +dataset=netsim_15_200_permuted +netsim=mcd`
63 | - DREAM3: Run `python3 -m src.train +dataset=dream3 +dream3=mcd`
64 | - S&P100: Run `python3 -m src.train +dataset=snp100 +snp100=mcd`
65 |
66 | Results are stored in the `results/` folder.
67 |
68 | ## Acknowledgement
69 |
70 | We implemented some parts of our framework using code from [Project Causica](https://github.com/microsoft/causica).
71 |
72 | ## Citation
73 |
74 | If you find this work useful, please consider citing us.
75 |
76 | ```
77 | @inproceedings{varambally2024discovering,
78 | author = {Varambally, Sumanth and Ma, Yi-An and Yu, Rose},
79 | title = {Discovering Mixtures of Structural Causal Models from Time Series Data},
80 | booktitle = {International Conference on Machine Learning, {ICML} 2024},
81 | series = {Proceedings of Machine Learning Research},
82 | year = {2024}
83 | }
84 | ```
85 |
--------------------------------------------------------------------------------
/assets/overview_fig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Rose-STL-Lab/MCD/5276806eb761545d2e99f53d5135289d7def20e3/assets/overview_fig.png
--------------------------------------------------------------------------------
/assets/snp100.csv:
--------------------------------------------------------------------------------
1 | Symbol,Name,Sector
2 | AAPL,Apple,Technology
3 | ABBV,AbbVie,Healthcare
4 | ABT,Abbott,Healthcare
5 | ACN,Accenture,Technology
6 | ADBE,Adobe,Technology
7 | AIG,American International Group,Financials
8 | AMD,AMD,Technology
9 | AMGN,Amgen,Healthcare
10 | AMT,American Tower,Real Estate
11 | AMZN,Amazon,Consumer Discretionary
12 | AVGO,Broadcom,Technology
13 | AXP,American Express,Financials
14 | BA,Boeing,Industrials
15 | BAC,Bank of America,Financials
16 | BK,BNY Mellon,Financials
17 | BKNG,Booking Holdings,Consumer Discretionary
18 | BLK,BlackRock,Financials
19 | BMY,Bristol Myers Squibb,Healthcare
20 | BRK-B,Berkshire Hathaway (Class B),Financials
21 | C,Citigroup,Financials
22 | CAT,Caterpillar,Industrials
23 | CHTR,Charter Communications,Communication Services
24 | CL,Colgate-Palmolive,Consumer Staples
25 | CMCSA,Comcast,Communication Services
26 | COF,Capital One,Financials
27 | COP,ConocoPhillips,Energy
28 | COST,Costco,Consumer Staples
29 | CRM,Salesforce,Technology
30 | CSCO,Cisco,Technology
31 | CVS,CVS Health,Healthcare
32 | CVX,Chevron,Energy
33 | DE,Deere & Company,Industrials
34 | DHR,Danaher,Healthcare
35 | DIS,Disney,Communication Services
36 | DUK,Duke Energy,Utilities
37 | EMR,Emerson,Industrials
38 | EXC,Exelon,Utilities
39 | F,Ford,Consumer Discretionary
40 | FDX,FedEx,Industrials
41 | GD,General Dynamics,Industrials
42 | GE,GE,Industrials
43 | GILD,Gilead,Healthcare
44 | GM,GM,Consumer Discretionary
45 | GOOG,Alphabet (Class C),Communication Services
46 | GOOGL,Alphabet (Class A),Communication Services
47 | GS,Goldman Sachs,Financials
48 | HD,Home Depot,Consumer Discretionary
49 | HON,Honeywell,Industrials
50 | IBM,IBM,Technology
51 | INTC,Intel,Technology
52 | JNJ,Johnson & Johnson,Healthcare
53 | JPM,JPMorgan Chase,Financials
54 | KHC,Kraft Heinz,Consumer Staples
55 | KO,Coca-Cola,Consumer Staples
56 | LIN,Linde,Materials
57 | LLY,Lilly,Healthcare
58 | LMT,Lockheed Martin,Industrials
59 | LOW,Lowe's,Consumer Discretionary
60 | MA,Mastercard,Technology
61 | MCD,McDonald's,Consumer Discretionary
62 | MDLZ,Mondelēz International,Consumer Staples
63 | MDT,Medtronic,Healthcare
64 | MET,MetLife,Financials
65 | META,Meta,Communication Services
66 | MMM,3M,Industrials
67 | MO,Altria,Consumer Staples
68 | MRK,Merck,Healthcare
69 | MS,Morgan Stanley,Financials
70 | MSFT,Microsoft,Technology
71 | NEE,NextEra Energy,Utilities
72 | NFLX,Netflix,Communication Services
73 | NKE,Nike,Consumer Discretionary
74 | NVDA,Nvidia,Technology
75 | ORCL,Oracle,Technology
76 | PEP,PepsiCo,Consumer Staples
77 | PFE,Pfizer,Healthcare
78 | PG,Procter & Gamble,Consumer Staples
79 | PM,Philip Morris International,Consumer Staples
80 | PYPL,PayPal,Technology
81 | QCOM,Qualcomm,Technology
82 | RTX,RTX Corporation,Industrials
83 | SBUX,Starbucks,Consumer Discretionary
84 | SCHW,Charles Schwab,Financials
85 | SO,Southern Company,Utilities
86 | SPG,Simon,Real Estate
87 | T,AT&T,Communication Services
88 | TGT,Target,Consumer Discretionary
89 | TMO,Thermo Fisher Scientific,Healthcare
90 | TMUS,T-Mobile,Communication Services
91 | TSLA,Tesla,Consumer Discretionary
92 | TXN,Texas Instruments,Technology
93 | UNH,UnitedHealth Group,Healthcare
94 | UNP,Union Pacific,Industrials
95 | UPS,United Parcel Service,Industrials
96 | USB,U.S. Bank,Financials
97 | V,Visa,Technology
98 | VZ,Verizon,Communication Services
99 | WFC,Wells Fargo,Financials
100 | WMT,Walmart,Consumer Staples
101 | XOM,ExxonMobil,Energy
--------------------------------------------------------------------------------
/assets/snp100_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Rose-STL-Lab/MCD/5276806eb761545d2e99f53d5135289d7def20e3/assets/snp100_results.png
--------------------------------------------------------------------------------
/configs/data_generation/perturb_data.yaml:
--------------------------------------------------------------------------------
1 |
2 | # graph settings
3 | num_timesteps: 100
4 | lag: 2
5 | history_dep_noise : true
6 | burnin_length: 100
7 | noise_level: 0.5
8 | function_type: mlp # options: mlp, spline, spline_product, conditional_spline, mlp_noise, inverse_noise_spline
9 | noise_function_type: spline_product # options: same as function_type
10 | base_noise_type: gaussian # options: gaussian, uniform
11 | num_samples: 1000
12 |
13 | disable_inst: false # whether to exclude instantaneous edges
14 | inst_graph_type: ER # options: SF, ER
15 | lag_graph_type: ER # options: SF, ER
16 | save_dir: "data/synthetic"
17 |
18 | random_seed:
19 | - 0
20 |
21 | p_array:
22 | - 0.005
23 | - 0.008
24 | - 0.01
25 | - 0.05
26 | - 0.1
27 |
28 | num_nodes: 10
29 | num_graphs:
30 | - 5
31 | - 10
--------------------------------------------------------------------------------
/configs/data_generation/synthetic_data_linear.yaml:
--------------------------------------------------------------------------------
1 |
2 | # graph settings
3 | num_timesteps: 100
4 | lag: 2
5 | history_dep_noise : false
6 | burnin_length: 100
7 | noise_level: 0.5
8 | function_type: linear # options: mlp, spline, linear, spline_product, conditional_spline, mlp_noise, inverse_noise_spline
9 | noise_function_type: spline_product # options: same as function_type
10 | base_noise_type: gaussian # options: gaussian, uniform
11 | num_samples: 1000
12 |
13 | disable_inst: false # whether to exclude instantaneous edges
14 | inst_graph_type: ER # options: SF, ER
15 | lag_graph_type: ER # options: SF, ER
16 | save_dir: "data/synthetic"
17 |
18 | random_seed:
19 | - 0
20 | # - 1
21 | # - 2
22 | # - 3
23 | # - 4
24 | # - 5
25 |
26 | num_nodes:
27 | - 5
28 | - 10
29 | - 20
30 | - 40
31 |
32 | num_graphs:
33 | - 1
34 | - 5
35 | - 10
36 | - 20
37 | # - 40
38 | # - 60
39 | # - 80
40 | # -
--------------------------------------------------------------------------------
/configs/data_generation/synthetic_data_nonlinear.yaml:
--------------------------------------------------------------------------------
1 |
2 | # graph settings
3 | num_timesteps: 100
4 | lag: 2
5 | history_dep_noise : true
6 | burnin_length: 100
7 | noise_level: 0.5
8 | function_type: mlp # options: mlp, spline, linear, spline_product, conditional_spline, mlp_noise, inverse_noise_spline
9 | noise_function_type: spline_product # options: same as function_type
10 | base_noise_type: gaussian # options: gaussian, uniform
11 | num_samples: 1000
12 |
13 | disable_inst: false # whether to exclude instantaneous edges
14 | inst_graph_type: ER # options: SF, ER
15 | lag_graph_type: ER # options: SF, ER
16 | save_dir: "data/synthetic"
17 |
18 | random_seed:
19 | - 0
20 | # - 1
21 | # - 2
22 | # - 3
23 | # - 4
24 | # - 5
25 |
26 | num_nodes:
27 | - 5
28 | - 10
29 | - 20
30 | - 40
31 |
32 | num_graphs:
33 | - 1
34 | - 5
35 | - 10
36 | - 20
37 | # - 40
38 | # - 60
39 | # - 80
40 | # -
--------------------------------------------------------------------------------
/configs/dream3/dynotears.yaml:
--------------------------------------------------------------------------------
1 | model: dynotears
2 | lag: 2
3 | dream3_size: 100
4 |
5 | trainer:
6 | _target_: src.baselines.DYNOTEARSTrainer.DYNOTEARSTrainer
7 | _partial_: true
8 |
9 | single_graph: false
10 | max_iter: 1000
11 | lambda_w: 0.1
12 | lambda_a: 0.1
13 | w_threshold: 0.01
14 | h_tol: 1e-8
15 |
16 | group_by_graph: true
--------------------------------------------------------------------------------
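Note on the config format: the `_target_`/`_partial_: true` blocks in this and the following trainer configs follow Hydra's `instantiate` convention. Below is a minimal, hypothetical sketch of how such a block is resolved; the project's actual loading code (e.g. `src/train.py`) is not reproduced in this dump and may differ.

```python
# Hypothetical sketch: resolving a `_target_`/`_partial_: true` block with Hydra.
from hydra.utils import instantiate
from omegaconf import OmegaConf

trainer_cfg = OmegaConf.create({
    "_target_": "src.baselines.DYNOTEARSTrainer.DYNOTEARSTrainer",
    "_partial_": True,
    "max_iter": 1000,
    "lambda_w": 0.1,
})

# With `_partial_: true`, instantiate() returns a functools.partial rather than an
# instance; the data-dependent constructor arguments (full_dataset, adj_matrices,
# num_nodes, lag, ...) can then be filled in once the dataset has been loaded.
trainer_factory = instantiate(trainer_cfg)
```

The same pattern applies to the `decoder`, `hypernet` and `auglag_config` blocks in the MCD and Rhino configs that follow.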
/configs/dream3/mcd.yaml:
--------------------------------------------------------------------------------
1 | # hyperparameters
2 | watch_gradients: false
3 | num_epochs: 10000
4 | model: mcd
5 | lag: 2
6 | dream3_size: 100
7 | monitor_checkpoint_based_on: likelihood
8 |
9 | hypernet:
10 | _target_: src.modules.MultiTemporalHyperNet.MultiTemporalHyperNet
11 | order: linear
12 | num_bins: 8
13 | skip_connection: true
14 | embedding_dim: 32
15 | dropout_p: 0.2
16 |
17 | # decoder options
18 | decoder:
19 | _target_: src.modules.MultiCausalDecoder.MultiCausalDecoder
20 | skip_connection: true
21 | embedding_dim: 32
22 | dropout_p: 0.2
23 |
24 | trainer:
25 | _target_: src.model.MCDTrainer.MCDTrainer
26 | _partial_: true
27 |
28 | batch_size: 64
29 | sparsity_factor: 10
30 | likelihood_loss: flow
31 | matrix_temperature: 0.25
32 | threeway_graph_dist: true
33 | training_procedure: auglag
34 | skip_auglag_epochs: 0
35 | init_logits: [-3, -3]
36 |
37 | # multi-rhino settings
38 | num_graphs: 10
39 | use_correct_mixture_index: false
40 | use_true_graph: false
41 | log_num_unique_graphs: false
42 | disable_inst: true
43 |
44 | auglag_config:
45 | _target_: src.training.auglag.AugLagLRConfig
46 |
47 | lr_update_lag: 500
48 | lr_update_lag_best: 500
49 | lr_init_dict:
50 | vardist: 0.001
51 | functional_relationships: 0.001
52 | noise_dist: 0.001
53 | mixing_probs: 0.01
54 |
55 | aggregation_period: 20
56 | lr_factor: 0.1
57 | max_lr_down: 3
58 | penalty_progress_rate: 0.65
59 | safety_rho: 1e13
60 | safety_alpha: 1e13
61 | inner_early_stopping_patience: 1500
62 | max_outer_steps: 60
63 | patience_penalty_reached: 10
64 | patience_max_rho: 3
65 | penalty_tolerance: -1
66 | max_inner_steps: 6000
67 | force_not_converged: false
68 | init_rho: 1
69 | init_alpha: 0
--------------------------------------------------------------------------------
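For orientation on the `auglag_config` block above (and the analogous blocks in the other MCD/Rhino configs): `init_rho`/`init_alpha`, `safety_rho`/`safety_alpha` and `penalty_progress_rate` are the usual knobs of an augmented Lagrangian scheduler for the acyclicity constraint used in Rhino-style causal discovery. The exact objective lives in `src/training/auglag.py` and the trainers (not reproduced in this dump); a standard form, stated here only as an assumption about what these knobs control, is

```latex
\min_{\theta}\; \mathbb{E}\big[-\log p_{\theta}(x \mid G)\big]
\;+\; \lambda_{s}\,\lVert G \rVert_{1}
\;+\; \alpha\, h(G_{\mathrm{inst}})
\;+\; \tfrac{\rho}{2}\, h(G_{\mathrm{inst}})^{2},
\qquad
h(A) \;=\; \operatorname{tr}\!\big(e^{A \odot A}\big) - d .
```

In such schemes, `rho` is typically increased (capped at `safety_rho`) when the constraint does not shrink fast enough relative to `penalty_progress_rate`, and `alpha` is updated from the current constraint value; `sparsity_factor` presumably plays the role of the sparsity weight here.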
/configs/dream3/mcd_linear.yaml:
--------------------------------------------------------------------------------
1 | # hyperparameters
2 | watch_gradients: false
3 | num_epochs: 10000
4 | model: mcd
5 | lag: 2
6 | dream3_size: 100
7 | monitor_checkpoint_based_on: likelihood
8 |
9 | # decoder options
10 | decoder:
11 | _target_: src.modules.MultiCausalDecoder.MultiCausalDecoder
12 | linear: true
13 | embedding_dim: 32
14 |
15 | trainer:
16 | _target_: src.model.MCDTrainer.MCDTrainer
17 | _partial_: true
18 |
19 | batch_size: 64
20 | sparsity_factor: 25
21 | likelihood_loss: mse
22 | matrix_temperature: 0.25
23 | threeway_graph_dist: true
24 | training_procedure: auglag
25 | skip_auglag_epochs: 0
26 | init_logits: [-3, -3]
27 |
28 | # multi-rhino settings
29 | num_graphs: 10
30 | use_correct_mixture_index: false
31 | use_true_graph: false
32 | log_num_unique_graphs: false
33 | disable_inst: true
34 |
35 | auglag_config:
36 | _target_: src.training.auglag.AugLagLRConfig
37 |
38 | lr_update_lag: 500
39 | lr_update_lag_best: 500
40 | lr_init_dict:
41 | vardist: 0.001
42 | functional_relationships: 0.001
43 | noise_dist: 0.001
44 | mixing_probs: 0.01
45 |
46 | aggregation_period: 20
47 | lr_factor: 0.1
48 | max_lr_down: 3
49 | penalty_progress_rate: 0.65
50 | safety_rho: 1e13
51 | safety_alpha: 1e13
52 | inner_early_stopping_patience: 1500
53 | max_outer_steps: 60
54 | patience_penalty_reached: 10
55 | patience_max_rho: 3
56 | penalty_tolerance: -1
57 | max_inner_steps: 6000
58 | force_not_converged: false
59 | init_rho: 1
60 | init_alpha: 0
--------------------------------------------------------------------------------
/configs/dream3/pcmci.yaml:
--------------------------------------------------------------------------------
1 | model: pcmci
2 | lag: 2
3 | dream3_size: 100
4 |
5 | trainer:
6 | _target_: src.baselines.PCMCITrainer.PCMCITrainer
7 | _partial_: true
8 |
9 | ci_test: ParCorr
10 | single_graph: true
11 | pcmci_plus: true
12 | pc_alpha: 0.01
13 |
14 | group_by_graph: false
--------------------------------------------------------------------------------
/configs/dream3/rhino.yaml:
--------------------------------------------------------------------------------
1 | # hyperparameters
2 | watch_gradients: false
3 | num_epochs: 10000
4 | model: rhino
5 | lag: 2
6 | dream3_size: 100
7 | monitor_checkpoint_based_on: likelihood
8 |
9 | hypernet:
10 | _target_: src.modules.TemporalHyperNet.TemporalHyperNet
11 | order: linear
12 | num_bins: 8
13 | skip_connection: true
14 | embedding_dim: 32
15 |
16 | # decoder options
17 | decoder:
18 | _target_: src.modules.CausalDecoder.CausalDecoder
19 | skip_connection: true
20 | embedding_dim: 32
21 |
22 | trainer:
23 | _target_: src.model.RhinoTrainer.RhinoTrainer
24 | _partial_: true
25 |
26 | batch_size: 64
27 | sparsity_factor: 25
28 | likelihood_loss: flow
29 | matrix_temperature: 0.25
30 | threeway_graph_dist: true
31 | training_procedure: auglag
32 | skip_auglag_epochs: 0
33 | init_logits: [-3, -3]
34 | disable_inst: true
35 |
36 | auglag_config:
37 | _target_: src.training.auglag.AugLagLRConfig
38 |
39 | lr_update_lag: 500
40 | lr_update_lag_best: 500
41 | lr_init_dict:
42 | vardist: 0.001
43 | functional_relationships: 0.001
44 | noise_dist: 0.001
45 |
46 | aggregation_period: 20
47 | lr_factor: 0.1
48 | max_lr_down: 3
49 | penalty_progress_rate: 0.65
50 | safety_rho: 1e13
51 | safety_alpha: 1e13
52 | inner_early_stopping_patience: 1500
53 | max_outer_steps: 60
54 | patience_penalty_reached: 10
55 | patience_max_rho: 3
56 | penalty_tolerance: -1
57 | max_inner_steps: 6000
58 | force_not_converged: false
59 | init_rho: 1
60 | init_alpha: 0
--------------------------------------------------------------------------------
/configs/dream3/rhino_linear.yaml:
--------------------------------------------------------------------------------
1 | # hyperparameters
2 | watch_gradients: false
3 | num_epochs: 10000
4 | model: rhino
5 | lag: 2
6 | dream3_size: 100
7 | monitor_checkpoint_based_on: likelihood
8 |
9 | # decoder options
10 | decoder:
11 | _target_: src.modules.CausalDecoder.CausalDecoder
12 | linear: true
13 |
14 | trainer:
15 | _target_: src.model.RhinoTrainer.RhinoTrainer
16 | _partial_: true
17 |
18 | batch_size: 64
19 | sparsity_factor: 25
20 | likelihood_loss: mse
21 | matrix_temperature: 0.25
22 | threeway_graph_dist: true
23 | training_procedure: auglag
24 | skip_auglag_epochs: 0
25 | init_logits: [-3, -3]
26 | disable_inst: true
27 |
28 | auglag_config:
29 | _target_: src.training.auglag.AugLagLRConfig
30 |
31 | lr_update_lag: 500
32 | lr_update_lag_best: 500
33 | lr_init_dict:
34 | vardist: 0.001
35 | functional_relationships: 0.001
36 | noise_dist: 0.001
37 |
38 | aggregation_period: 20
39 | lr_factor: 0.1
40 | max_lr_down: 3
41 | penalty_progress_rate: 0.65
42 | safety_rho: 1e13
43 | safety_alpha: 1e13
44 | inner_early_stopping_patience: 1500
45 | max_outer_steps: 60
46 | patience_penalty_reached: 10
47 | patience_max_rho: 3
48 | penalty_tolerance: -1
49 | max_inner_steps: 6000
50 | force_not_converged: false
51 | init_rho: 1
52 | init_alpha: 0
--------------------------------------------------------------------------------
/configs/dream3/varlingam.yaml:
--------------------------------------------------------------------------------
1 | model: varlingam
2 | lag: 2
3 | dream3_size: 100
4 |
5 | trainer:
6 | _target_: src.baselines.VARLiNGAMTrainer.VARLiNGAMTrainer
7 | _partial_: true
--------------------------------------------------------------------------------
/configs/main.yaml:
--------------------------------------------------------------------------------
1 |
2 | precision: 32
3 | wandb_project: causal_mixture_model
4 | gpu: [0]
5 | num_workers: 8
6 | dataset_dir: 'data/'
7 | random_seed: 0
8 |
--------------------------------------------------------------------------------
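How `configs/main.yaml` relates to the run commands in the README: `main.yaml` is the root config, and the `+dataset=...` / `+<benchmark>=<model>` arguments add a dataset key and pull in the selected group config (e.g. `configs/netsim/mcd.yaml`) at launch. A minimal, hypothetical entry-point sketch assuming standard Hydra conventions is shown below; the actual `src/train.py` is not reproduced in this dump and may differ.

```python
# Hypothetical entry-point sketch, assuming standard Hydra conventions;
# the real src/train.py may differ.
import hydra
from omegaconf import DictConfig, OmegaConf


@hydra.main(config_path="../configs", config_name="main", version_base=None)
def main(cfg: DictConfig) -> None:
    # A command such as
    #   python3 -m src.train +dataset=netsim_15_200_permuted +netsim=mcd
    # starts from main.yaml, adds a `dataset` key, and merges in the selected
    # group config (configs/netsim/mcd.yaml) before this function runs.
    print(OmegaConf.to_yaml(cfg))  # inspect the fully composed configuration


if __name__ == "__main__":
    main()
```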
/configs/netsim/dynotears.yaml:
--------------------------------------------------------------------------------
1 | model: dynotears
2 |
3 | lag: 2
4 | trainer:
5 | _target_: src.baselines.DYNOTEARSTrainer.DYNOTEARSTrainer
6 | _partial_: true
7 |
8 | single_graph: true
9 | max_iter: 1000
10 | lambda_w: 0.1
11 | lambda_a: 0.1
12 | w_threshold: 0.01
13 | h_tol: 1e-8
14 |
15 | group_by_graph: false
--------------------------------------------------------------------------------
/configs/netsim/mcd.yaml:
--------------------------------------------------------------------------------
1 | # hyperparameters
2 | watch_gradients: false
3 | num_epochs: 10000
4 | model: mcd
5 | lag: 2
6 | monitor_checkpoint_based_on: likelihood
7 |
8 | hypernet:
9 | _target_: src.modules.MultiTemporalHyperNet.MultiTemporalHyperNet
10 | order: linear
11 | num_bins: 8
12 | skip_connection: true
13 |
14 | # decoder options
15 | decoder:
16 | _target_: src.modules.MultiCausalDecoder.MultiCausalDecoder
17 | skip_connection: true
18 |
19 | trainer:
20 | _target_: src.model.MCDTrainer.MCDTrainer
21 | _partial_: true
22 |
23 | batch_size: 64
24 | sparsity_factor: 25
25 | likelihood_loss: flow
26 | matrix_temperature: 0.25
27 | threeway_graph_dist: true
28 | training_procedure: auglag
29 | skip_auglag_epochs: 0
30 | graph_selection_prior_lambda: 0.0
31 |
32 | # multi-rhino settings
33 | num_graphs: 20
34 | use_correct_mixture_index: false
35 | use_true_graph: false
36 | log_num_unique_graphs: false
37 |
38 | auglag_config:
39 | _target_: src.training.auglag.AugLagLRConfig
40 |
41 | lr_update_lag: 500
42 | lr_update_lag_best: 500
43 | lr_init_dict:
44 | vardist: 0.01
45 | functional_relationships: 0.001
46 | noise_dist: 0.001
47 | mixing_probs: 0.01
48 |
49 | aggregation_period: 20
50 | lr_factor: 0.1
51 | max_lr_down: 3
52 | penalty_progress_rate: 0.65
53 | safety_rho: 1e13
54 | safety_alpha: 1e13
55 | inner_early_stopping_patience: 1500
56 | max_outer_steps: 60
57 | patience_penalty_reached: 100
58 | patience_max_rho: 50
59 | penalty_tolerance: 1e-8
60 | max_inner_steps: 2000
61 | force_not_converged: false
62 | init_rho: 1
63 | init_alpha: 0
--------------------------------------------------------------------------------
/configs/netsim/mcd_linear.yaml:
--------------------------------------------------------------------------------
1 | # hyperparameters
2 | watch_gradients: false
3 | num_epochs: 10000
4 | model: mcd
5 | lag: 2
6 | monitor_checkpoint_based_on: likelihood
7 |
8 | # decoder options
9 | decoder:
10 | _target_: src.modules.MultiCausalDecoder.MultiCausalDecoder
11 | linear: true
12 |
13 | trainer:
14 | _target_: src.model.MCDTrainer.MCDTrainer
15 | _partial_: true
16 |
17 | batch_size: 64
18 | sparsity_factor: 25
19 | likelihood_loss: mse
20 | matrix_temperature: 0.25
21 | threeway_graph_dist: true
22 | training_procedure: auglag
23 | skip_auglag_epochs: 0
24 | graph_selection_prior_lambda: 0.0
25 |
26 | # multi-rhino settings
27 | num_graphs: 10
28 | use_correct_mixture_index: false
29 | use_true_graph: false
30 | log_num_unique_graphs: false
31 |
32 | auglag_config:
33 | _target_: src.training.auglag.AugLagLRConfig
34 |
35 | lr_update_lag: 500
36 | lr_update_lag_best: 500
37 | lr_init_dict:
38 | vardist: 0.01
39 | functional_relationships: 0.001
40 | noise_dist: 0.001
41 | mixing_probs: 0.01
42 |
43 | aggregation_period: 20
44 | lr_factor: 0.1
45 | max_lr_down: 3
46 | penalty_progress_rate: 0.65
47 | safety_rho: 1e13
48 | safety_alpha: 1e13
49 | inner_early_stopping_patience: 1500
50 | max_outer_steps: 60
51 | patience_penalty_reached: 100
52 | patience_max_rho: 50
53 | penalty_tolerance: 1e-8
54 | max_inner_steps: 2000
55 | force_not_converged: false
56 | init_rho: 1
57 | init_alpha: 0
--------------------------------------------------------------------------------
/configs/netsim/pcmci.yaml:
--------------------------------------------------------------------------------
1 | model: pcmci
2 | lag: 2
3 |
4 | trainer:
5 | _target_: src.baselines.PCMCITrainer.PCMCITrainer
6 | _partial_: true
7 |
8 | ci_test: ParCorr
9 | single_graph: false
10 | pcmci_plus: true
11 | pc_alpha: 0.01
12 |
13 | group_by_graph: true
--------------------------------------------------------------------------------
/configs/netsim/rhino.yaml:
--------------------------------------------------------------------------------
1 | # hyperparameters
2 | watch_gradients: false
3 | num_epochs: 10000
4 | model: rhino
5 | monitor_checkpoint_based_on: likelihood
6 | lag: 2
7 |
8 | hypernet:
9 | _target_: src.modules.TemporalHyperNet.TemporalHyperNet
10 | order: linear
11 | num_bins: 8
12 | skip_connection: true
13 |
14 | # decoder options
15 | decoder:
16 | _target_: src.modules.CausalDecoder.CausalDecoder
17 | skip_connection: true
18 |
19 | trainer:
20 | _target_: src.model.RhinoTrainer.RhinoTrainer
21 | _partial_: true
22 |
23 | batch_size: 128
24 | sparsity_factor: 25
25 |
26 | likelihood_loss: flow
27 | matrix_temperature: 0.25
28 | threeway_graph_dist: true
29 | training_procedure: auglag # options: auglag, dagma
30 | skip_auglag_epochs: 0
31 |
32 | auglag_config:
33 | _target_: src.training.auglag.AugLagLRConfig
34 |
35 | lr_update_lag: 500
36 | lr_update_lag_best: 500
37 | lr_init_dict:
38 | vardist: 0.001
39 | functional_relationships: 0.001
40 | noise_dist: 0.001
41 |
42 | aggregation_period: 20
43 | lr_factor: 0.1
44 | max_lr_down: 3
45 | penalty_progress_rate: 0.65
46 | safety_rho: 1e13
47 | safety_alpha: 1e13
48 | inner_early_stopping_patience: 1500
49 | max_outer_steps: 60
50 | patience_penalty_reached: 100
51 | patience_max_rho: 50
52 | penalty_tolerance: 1e-8
53 | max_inner_steps: 2000
54 | force_not_converged: false
55 | init_rho: 1
56 | init_alpha: 0
57 |
--------------------------------------------------------------------------------
/configs/netsim/rhino_linear.yaml:
--------------------------------------------------------------------------------
1 | # hyperparameters
2 | watch_gradients: false
3 | num_epochs: 10000
4 | model: rhino
5 | monitor_checkpoint_based_on: likelihood
6 | lag: 2
7 |
8 | # decoder options
9 | decoder:
10 | _target_: src.modules.CausalDecoder.CausalDecoder
11 | linear: true
12 |
13 | trainer:
14 | _target_: src.model.RhinoTrainer.RhinoTrainer
15 | _partial_: true
16 |
17 | batch_size: 64
18 | sparsity_factor: 25
19 |
20 | likelihood_loss: mse
21 | matrix_temperature: 0.25
22 | threeway_graph_dist: true
23 | training_procedure: auglag
24 | skip_auglag_epochs: 0
25 |
26 | auglag_config:
27 | _target_: src.training.auglag.AugLagLRConfig
28 |
29 | lr_update_lag: 500
30 | lr_update_lag_best: 500
31 | lr_init_dict:
32 | vardist: 0.001
33 | functional_relationships: 0.001
34 | noise_dist: 0.001
35 |
36 | aggregation_period: 20
37 | lr_factor: 0.1
38 | max_lr_down: 3
39 | penalty_progress_rate: 0.65
40 | safety_rho: 1e13
41 | safety_alpha: 1e13
42 | inner_early_stopping_patience: 1500
43 | max_outer_steps: 60
44 | patience_penalty_reached: 100
45 | patience_max_rho: 50
46 | penalty_tolerance: 1e-8
47 | max_inner_steps: 2000
48 | force_not_converged: false
49 | init_rho: 1
50 | init_alpha: 0
51 |
--------------------------------------------------------------------------------
/configs/netsim/varlingam.yaml:
--------------------------------------------------------------------------------
1 | model: varlingam
2 | lag: 2
3 |
4 | trainer:
5 | _target_: src.baselines.VARLiNGAMTrainer.VARLiNGAMTrainer
6 | _partial_: true
--------------------------------------------------------------------------------
/configs/snp100/dynotears.yaml:
--------------------------------------------------------------------------------
1 | model: dynotears
2 |
3 | trainer:
4 | _target_: src.baselines.DYNOTEARSTrainer.DYNOTEARSTrainer
5 | _partial_: true
6 |
7 | single_graph: true
8 | max_iter: 1000
9 | lambda_w: 0.10
10 | lambda_a: 0.10
11 | w_threshold: 0.01
12 | h_tol: 1e-8
13 | lag: 1
14 | group_by_graph: false
--------------------------------------------------------------------------------
/configs/snp100/mcd.yaml:
--------------------------------------------------------------------------------
1 | # hyperparameters
2 | watch_gradients: false
3 | num_epochs: 100000
4 | model: mcd
5 | monitor_checkpoint_based_on: likelihood
6 | val_every_n_epochs: 10
7 |
8 | hypernet:
9 | _target_: src.modules.MultiTemporalHyperNet.MultiTemporalHyperNet
10 | order: linear
11 | num_bins: 8
12 | skip_connection: true
13 |
14 | # decoder options
15 | decoder:
16 | _target_: src.modules.MultiCausalDecoder.MultiCausalDecoder
17 | skip_connection: true
18 |
19 | trainer:
20 | _target_: src.model.MCDTrainer.MCDTrainer
21 | _partial_: true
22 |
23 | shuffle: False
24 | batch_size: 64
25 | sparsity_factor: 20
26 | likelihood_loss: flow
27 | matrix_temperature: 0.25
28 | threeway_graph_dist: true
29 | training_procedure: auglag
30 | skip_auglag_epochs: 0
31 |
32 | # multi-rhino settings
33 | num_graphs: 5
34 | use_correct_mixture_index: false
35 | use_true_graph: false
36 | log_num_unique_graphs: true
37 | init_logits: [-3, -3]
38 | use_all_for_val: False
39 |
40 | auglag_config:
41 | _target_: src.training.auglag.AugLagLRConfig
42 |
43 | lr_update_lag: 500
44 | lr_update_lag_best: 500
45 | lr_init_dict:
46 | vardist: 0.01
47 | functional_relationships: 0.001
48 | noise_dist: 0.001
49 | mixing_probs: 0.01
50 |
51 | aggregation_period: 20
52 | lr_factor: 0.1
53 | max_lr_down: 3
54 | penalty_progress_rate: 0.65
55 | safety_rho: 1e13
56 | safety_alpha: 1e13
57 | inner_early_stopping_patience: 1500
58 | max_outer_steps: 60
59 | patience_penalty_reached: 5
60 | patience_max_rho: 3
61 | penalty_tolerance: 1e-5
62 | max_inner_steps: 2000
63 | force_not_converged: false
64 | init_rho: 1
65 | init_alpha: 0
--------------------------------------------------------------------------------
/configs/snp100/mcd_linear.yaml:
--------------------------------------------------------------------------------
1 | # hyperparameters
2 | watch_gradients: false
3 | num_epochs: 2000
4 | model: mcd
5 | monitor_checkpoint_based_on: likelihood
6 | val_every_n_epochs: 10
7 |
8 | # decoder options
9 | decoder:
10 | _target_: src.modules.MultiCausalDecoder.MultiCausalDecoder
11 | linear: true
12 |
13 | trainer:
14 | _target_: src.model.MCDTrainer.MCDTrainer
15 | _partial_: true
16 |
17 | batch_size: 64
18 | sparsity_factor: 10
19 | likelihood_loss: mse
20 | matrix_temperature: 0.25
21 | threeway_graph_dist: true
22 | training_procedure: auglag
23 | skip_auglag_epochs: 0
24 |
25 | # multi-rhino settings
26 | num_graphs: 2
27 | use_correct_mixture_index: false
28 | use_true_graph: false
29 | log_num_unique_graphs: false
30 | init_logits: [-3, -3]
31 | use_all_for_val: True
32 |
33 | auglag_config:
34 | _target_: src.training.auglag.AugLagLRConfig
35 |
36 | lr_update_lag: 500
37 | lr_update_lag_best: 500
38 | lr_init_dict:
39 | vardist: 0.01
40 | functional_relationships: 0.01
41 | noise_dist: 0.001
42 | mixing_probs: 0.01
43 |
44 | aggregation_period: 20
45 | lr_factor: 0.1
46 | max_lr_down: 3
47 | penalty_progress_rate: 0.65
48 | safety_rho: 1e13
49 | safety_alpha: 1e13
50 | inner_early_stopping_patience: 1500
51 | max_outer_steps: 100
52 | patience_penalty_reached: 100
53 | patience_max_rho: 50
54 | penalty_tolerance: 1e-8
55 | max_inner_steps: 6000
56 | force_not_converged: false
57 | init_rho: 1
58 | init_alpha: 0
--------------------------------------------------------------------------------
/configs/snp100/pcmci.yaml:
--------------------------------------------------------------------------------
1 | model: pcmci
2 | trainer:
3 | _target_: src.baselines.PCMCITrainer.PCMCITrainer
4 | _partial_: true
5 |
6 | ci_test: ParCorr
7 | single_graph: false
8 | pcmci_plus: true
9 | pc_alpha: 0.01
10 |
11 | group_by_graph: false
--------------------------------------------------------------------------------
/configs/snp100/rhino.yaml:
--------------------------------------------------------------------------------
1 | # hyperparameters
2 | watch_gradients: false
3 | num_epochs: 2000
4 | model: rhino
5 | monitor_checkpoint_based_on: likelihood
6 |
7 | hypernet:
8 | _target_: src.modules.TemporalHyperNet.TemporalHyperNet
9 | order: linear
10 | num_bins: 8
11 | skip_connection: true
12 |
13 | # decoder options
14 | decoder:
15 | _target_: src.modules.CausalDecoder.CausalDecoder
16 | skip_connection: true
17 |
18 | trainer:
19 | _target_: src.model.RhinoTrainer.RhinoTrainer
20 | _partial_: true
21 |
22 | batch_size: 128
23 | sparsity_factor: 10
24 |
25 | likelihood_loss: flow
26 | matrix_temperature: 0.25
27 | threeway_graph_dist: true
28 | training_procedure: auglag # options: auglag, dagma
29 | skip_auglag_epochs: 0
30 | init_logits: [-3, -3]
31 |
32 | auglag_config:
33 | _target_: src.training.auglag.AugLagLRConfig
34 |
35 | lr_update_lag: 500
36 | lr_update_lag_best: 500
37 | lr_init_dict:
38 | vardist: 0.001
39 | functional_relationships: 0.001
40 | noise_dist: 0.001
41 |
42 | aggregation_period: 20
43 | lr_factor: 0.1
44 | max_lr_down: 3
45 | penalty_progress_rate: 0.65
46 | safety_rho: 1e13
47 | safety_alpha: 1e13
48 | inner_early_stopping_patience: 1500
49 | max_outer_steps: 100
50 | patience_penalty_reached: 100
51 | patience_max_rho: 3
52 | penalty_tolerance: 1e-5
53 | max_inner_steps: 6000
54 | force_not_converged: false
55 | init_rho: 1
56 | init_alpha: 0
57 |
--------------------------------------------------------------------------------
/configs/snp100/varlingam.yaml:
--------------------------------------------------------------------------------
1 | model: varlingam
2 | trainer:
3 | _target_: src.baselines.VARLiNGAMTrainer.VARLiNGAMTrainer
4 | _partial_: true
--------------------------------------------------------------------------------
/configs/synthetic/dynotears.yaml:
--------------------------------------------------------------------------------
1 | model: dynotears
2 |
3 | trainer:
4 | _target_: src.baselines.DYNOTEARSTrainer.DYNOTEARSTrainer
5 | _partial_: true
6 |
7 | single_graph: false
8 | max_iter: 1000
9 | lambda_w: 0.05
10 | lambda_a: 0.05
11 | w_threshold: 0.01
12 | h_tol: 1e-8
13 |
14 | group_by_graph: false
--------------------------------------------------------------------------------
/configs/synthetic/mcd.yaml:
--------------------------------------------------------------------------------
1 | # hyperparameters
2 | watch_gradients: false
3 | num_epochs: 10000
4 | model: mcd
5 | monitor_checkpoint_based_on: likelihood
6 |
7 | hypernet:
8 | _target_: src.modules.MultiTemporalHyperNet.MultiTemporalHyperNet
9 | order: quadratic
10 | num_bins: 8
11 | skip_connection: true
12 |
13 | # decoder options
14 | decoder:
15 | _target_: src.modules.MultiCausalDecoder.MultiCausalDecoder
16 | skip_connection: true
17 |
18 | trainer:
19 | _target_: src.model.MCDTrainer.MCDTrainer
20 | _partial_: true
21 |
22 | batch_size: 128
23 | sparsity_factor: 5
24 | likelihood_loss: flow
25 | matrix_temperature: 0.25
26 | threeway_graph_dist: true
27 | training_procedure: auglag
28 | skip_auglag_epochs: 0
29 |
30 | # multi-rhino settings
31 | num_graphs: 10
32 | use_correct_mixture_index: false
33 | use_true_graph: false
34 | log_num_unique_graphs: false
35 |
36 | auglag_config:
37 | _target_: src.training.auglag.AugLagLRConfig
38 |
39 | lr_update_lag: 500
40 | lr_update_lag_best: 500
41 | lr_init_dict:
42 | vardist: 0.01
43 | functional_relationships: 0.01
44 | noise_dist: 0.001
45 | mixing_probs: 0.01
46 |
47 | aggregation_period: 20
48 | lr_factor: 0.1
49 | max_lr_down: 3
50 | penalty_progress_rate: 0.65
51 | safety_rho: 1e13
52 | safety_alpha: 1e13
53 | inner_early_stopping_patience: 1500
54 | max_outer_steps: 100
55 | patience_penalty_reached: 100
56 | patience_max_rho: 50
57 | penalty_tolerance: 1e-8
58 | max_inner_steps: 6000
59 | force_not_converged: false
60 | init_rho: 1
61 | init_alpha: 0
--------------------------------------------------------------------------------
/configs/synthetic/mcd_linear.yaml:
--------------------------------------------------------------------------------
1 | # hyperparameters
2 | watch_gradients: false
3 | num_epochs: 10000
4 | model: mcd
5 | monitor_checkpoint_based_on: likelihood
6 |
7 | # decoder options
8 | decoder:
9 | _target_: src.modules.MultiCausalDecoder.MultiCausalDecoder
10 | linear: true
11 |
12 | trainer:
13 | _target_: src.model.MCDTrainer.MCDTrainer
14 | _partial_: true
15 |
16 | batch_size: 128
17 | sparsity_factor: 5
18 | likelihood_loss: mse
19 | matrix_temperature: 0.25
20 | threeway_graph_dist: true
21 | training_procedure: auglag
22 | skip_auglag_epochs: 0
23 |
24 | # multi-rhino settings
25 | num_graphs: 10
26 | use_correct_mixture_index: false
27 | use_true_graph: false
28 | log_num_unique_graphs: false
29 |
30 | auglag_config:
31 | _target_: src.training.auglag.AugLagLRConfig
32 |
33 | lr_update_lag: 500
34 | lr_update_lag_best: 500
35 | lr_init_dict:
36 | vardist: 0.01
37 | functional_relationships: 0.01
38 | noise_dist: 0.001
39 | mixing_probs: 0.01
40 |
41 | aggregation_period: 20
42 | lr_factor: 0.1
43 | max_lr_down: 3
44 | penalty_progress_rate: 0.65
45 | safety_rho: 1e13
46 | safety_alpha: 1e13
47 | inner_early_stopping_patience: 1500
48 | max_outer_steps: 100
49 | patience_penalty_reached: 100
50 | patience_max_rho: 50
51 | penalty_tolerance: 1e-8
52 | max_inner_steps: 6000
53 | force_not_converged: false
54 | init_rho: 1
55 | init_alpha: 0
--------------------------------------------------------------------------------
/configs/synthetic/pcmci.yaml:
--------------------------------------------------------------------------------
1 | model: pcmci
2 | trainer:
3 | _target_: src.baselines.PCMCITrainer.PCMCITrainer
4 | _partial_: true
5 |
6 | ci_test: ParCorr
7 | single_graph: false
8 | pcmci_plus: true
9 | pc_alpha: 0.01
10 |
11 | group_by_graph: false
--------------------------------------------------------------------------------
/configs/synthetic/rhino.yaml:
--------------------------------------------------------------------------------
1 | # hyperparameters
2 | watch_gradients: false
3 | num_epochs: 10000
4 | model: rhino
5 | monitor_checkpoint_based_on: likelihood
6 |
7 | hypernet:
8 | _target_: src.modules.TemporalHyperNet.TemporalHyperNet
9 | order: quadratic
10 | num_bins: 8
11 | skip_connection: true
12 |
13 | # decoder options
14 | decoder:
15 | _target_: src.modules.CausalDecoder.CausalDecoder
16 | skip_connection: true
17 |
18 | trainer:
19 | _target_: src.model.RhinoTrainer.RhinoTrainer
20 | _partial_: true
21 |
22 | batch_size: 128
23 | sparsity_factor: 5
24 |
25 | likelihood_loss: flow
26 | matrix_temperature: 0.25
27 | threeway_graph_dist: true
28 | training_procedure: auglag # options: auglag, dagma
29 | skip_auglag_epochs: 0
30 |
31 | auglag_config:
32 | _target_: src.training.auglag.AugLagLRConfig
33 |
34 | lr_update_lag: 500
35 | lr_update_lag_best: 500
36 | lr_init_dict:
37 | vardist: 0.01
38 | functional_relationships: 0.01
39 | noise_dist: 0.001
40 | mixing_probs: 0.01
41 |
42 | aggregation_period: 20
43 | lr_factor: 0.1
44 | max_lr_down: 3
45 | penalty_progress_rate: 0.65
46 | safety_rho: 1e13
47 | safety_alpha: 1e13
48 | inner_early_stopping_patience: 1500
49 | max_outer_steps: 100
50 | patience_penalty_reached: 100
51 | patience_max_rho: 50
52 | penalty_tolerance: 1e-8
53 | max_inner_steps: 6000
54 | force_not_converged: false
55 | init_rho: 1
56 | init_alpha: 0
57 |
--------------------------------------------------------------------------------
/configs/synthetic/rhino_linear.yaml:
--------------------------------------------------------------------------------
1 | # hyperparameters
2 | watch_gradients: false
3 | num_epochs: 10000
4 | model: rhino
5 | monitor_checkpoint_based_on: likelihood
6 |
7 | # decoder options
8 | decoder:
9 | _target_: src.modules.CausalDecoder.CausalDecoder
10 | linear: true
11 |
12 | trainer:
13 | _target_: src.model.RhinoTrainer.RhinoTrainer
14 | _partial_: true
15 |
16 | batch_size: 128
17 | sparsity_factor: 5
18 |
19 | likelihood_loss: mse
20 | matrix_temperature: 0.25
21 | threeway_graph_dist: true
22 | training_procedure: auglag # options: auglag, dagma
23 | skip_auglag_epochs: 0
24 |
25 | auglag_config:
26 | _target_: src.training.auglag.AugLagLRConfig
27 |
28 | lr_update_lag: 500
29 | lr_update_lag_best: 500
30 | lr_init_dict:
31 | vardist: 0.01
32 | functional_relationships: 0.01
33 | noise_dist: 0.001
34 | mixing_probs: 0.01
35 |
36 | aggregation_period: 20
37 | lr_factor: 0.1
38 | max_lr_down: 3
39 | penalty_progress_rate: 0.65
40 | safety_rho: 1e13
41 | safety_alpha: 1e13
42 | inner_early_stopping_patience: 1500
43 | max_outer_steps: 100
44 | patience_penalty_reached: 100
45 | patience_max_rho: 50
46 | penalty_tolerance: 1e-8
47 | max_inner_steps: 6000
48 | force_not_converged: false
49 | init_rho: 1
50 | init_alpha: 0
51 |
--------------------------------------------------------------------------------
/configs/synthetic/varlingam.yaml:
--------------------------------------------------------------------------------
1 | model: varlingam
2 | trainer:
3 | _target_: src.baselines.VARLiNGAMTrainer.VARLiNGAMTrainer
4 | _partial_: true
--------------------------------------------------------------------------------
/scripts/generate_perturb_datasets.sh:
--------------------------------------------------------------------------------
1 | python3 src/utils/data_gen/generate_perturb_syn.py --config_file configs/data_generation/perturb_data.yaml
--------------------------------------------------------------------------------
/scripts/generate_snp100.sh:
--------------------------------------------------------------------------------
1 | python3 -m src.utils.data_gen.generate_stock --stock_list_file assets/snp100.csv --save_dir data/snp100/ --chunk_size 31
--------------------------------------------------------------------------------
/scripts/generate_synthetic_datasets_linear.sh:
--------------------------------------------------------------------------------
1 | python3 src/utils/data_gen/generate_synthetic_data.py --config_file configs/data_generation/synthetic_data_linear.yaml
--------------------------------------------------------------------------------
/scripts/generate_synthetic_datasets_nonlinear.sh:
--------------------------------------------------------------------------------
1 | python3 src/utils/data_gen/generate_synthetic_data.py --config_file configs/data_generation/synthetic_data_nonlinear.yaml
--------------------------------------------------------------------------------
/scripts/setup_dream3.sh:
--------------------------------------------------------------------------------
1 | mkdir temp
2 | cd temp
3 |
4 | wget https://github.com/sakhanna/SRU_for_GCI/archive/refs/heads/master.zip
5 | unzip master.zip
6 |
7 | if [ ! -d "../data/dream3/raw" ]
8 | then
9 | mkdir -p ../data/dream3/raw
10 | fi
11 |
12 | mv SRU_for_GCI-master/data/dream3/* ../data/dream3/raw/
13 | cd ../
14 | rm -r temp
15 | python3 src/utils/data_gen/process_dream3.py --dataset_dir data/dream3/raw --save_dir data/dream3/
16 |
17 |
--------------------------------------------------------------------------------
/scripts/setup_netsim.sh:
--------------------------------------------------------------------------------
1 | mkdir temp
2 | cd temp
3 |
4 | wget https://www.fmrib.ox.ac.uk/datasets/netsim/sims.tar.gz
5 | tar xvf sims.tar.gz
6 |
7 | if [ ! -d "../data/netsim/raw" ]
8 | then
9 | mkdir -p ../data/netsim/raw
10 | fi
11 |
12 |
13 | mv *.mat ../data/netsim/raw/
14 | cd ../
15 | rm -r temp
16 | python3 src/utils/data_gen/process_netsim.py --dataset_dir data/netsim/raw --save_dir data/netsim/
17 |
18 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages, setup
2 |
3 | setup(
4 | name='multi_graph_rhino',
5 | packages=find_packages(),
6 | )
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Rose-STL-Lab/MCD/5276806eb761545d2e99f53d5135289d7def20e3/src/__init__.py
--------------------------------------------------------------------------------
/src/baselines/BaselineTrainer.py:
--------------------------------------------------------------------------------
1 | import lightning.pytorch as pl
2 | from torch.utils.data import DataLoader
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 | from src.dataset.BaselineTSDataset import BaselineTSDataset
7 | from src.utils.metrics_utils import mape_loss
8 |
9 | class BaselineTrainer(pl.LightningModule):
10 | # trainer class for the baselines
11 | # they do not need training
12 | def __init__(self,
13 | full_dataset: np.array,
14 | adj_matrices: np.array,
15 | data_dim: int,
16 | num_nodes: int,
17 | lag: int,
18 | num_workers: int = 16,
19 | aggregated_graph: bool = False):
20 | super().__init__()
21 |
22 | self.num_workers = num_workers
23 | self.aggregated_graph = aggregated_graph
24 | self.data_dim = data_dim
25 | self.lag = lag
26 | self.num_nodes = num_nodes
27 | self.full_dataset_np = full_dataset
28 | self.adj_matrices_np = adj_matrices
29 | self.total_samples = full_dataset.shape[0]
30 | assert adj_matrices.shape[0] == self.total_samples
31 |
32 | self.full_dataset = BaselineTSDataset(
33 | X = self.full_dataset_np,
34 | adj_matrix = self.adj_matrices_np,
35 | lag=lag,
36 | aggregated_graph=self.aggregated_graph,
37 | return_graph_indices=True
38 | )
39 |
40 | self.batch_size = 1
41 |
42 | def compute_mse(self, x_current, x_pred):
43 | return F.mse_loss(x_current, x_pred)
44 |
45 | def compute_mape(self, x_current, x_pred):
46 | return mape_loss(x_current, x_pred)
47 |
48 | def get_full_dataloader(self) -> DataLoader:
49 | return DataLoader(self.full_dataset, batch_size=self.batch_size, num_workers=self.num_workers)
50 |
--------------------------------------------------------------------------------
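A hypothetical usage sketch for these baseline trainers (subclasses of `BaselineTrainer`, such as the `DYNOTEARSTrainer` below): since they only implement `predict_step`, they are driven through Lightning's predict loop over `get_full_dataloader()` rather than trained. The actual driver in `src/train.py` is not shown in this dump, and the data paths below are illustrative placeholders.

```python
# Hypothetical usage sketch; paths and sizes are illustrative placeholders.
import lightning.pytorch as pl
import numpy as np

from src.baselines.DYNOTEARSTrainer import DYNOTEARSTrainer

X = np.load("path/to/X.npy")             # [num_samples, timesteps, num_nodes, data_dim]
A = np.load("path/to/adj_matrices.npy")  # ground-truth graphs, one entry per sample

model = DYNOTEARSTrainer(
    full_dataset=X,
    adj_matrices=A,
    data_dim=X.shape[-1],
    num_nodes=X.shape[-2],
    lag=2,
)
pl_trainer = pl.Trainer(accelerator="cpu", devices=1, logger=False)
# Each prediction is a (predicted graphs, edge scores, ground-truth adjacency) tuple.
predictions = pl_trainer.predict(model, dataloaders=model.get_full_dataloader())
```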
/src/baselines/DYNOTEARSTrainer.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 | import numpy as np
3 | import torch
4 |
5 | # import tigramite for pcmci
6 | import networkx as nx
7 | import pandas as pd
8 |
9 | from src.baselines.BaselineTrainer import BaselineTrainer
10 | from src.modules.dynotears.dynotears import from_pandas_dynamic
11 | from src.utils.data_utils.data_format_utils import to_time_aggregated_graph_np, to_time_aggregated_scores_np, zero_out_diag_np
12 |
13 | class DYNOTEARSTrainer(BaselineTrainer):
14 |
15 | def __init__(self,
16 | full_dataset: np.array,
17 | adj_matrices: np.array,
18 | data_dim: int,
19 | num_nodes: int,
20 | lag: int,
21 | num_workers: int = 16,
22 | aggregated_graph: bool = False,
23 | single_graph: bool = True,
24 | max_iter: int = 1000,
25 | lambda_w: float = 0.1,
26 | lambda_a: float = 0.1,
27 | w_threshold: float = 0.4,
28 | h_tol: float = 1e-8,
29 | group_by_graph: bool = True,
30 | ignore_self_connections: bool = False
31 | ):
32 | self.group_by_graph = group_by_graph
33 | self.ignore_self_connections = ignore_self_connections
34 | if self.group_by_graph:
35 | self.single_graph = True
36 | print("DYNOTEARS: Group by graph option set. Overriding single graph flag to True...")
37 | else:
38 | self.single_graph = single_graph
39 | super().__init__(full_dataset=full_dataset,
40 | adj_matrices=adj_matrices,
41 | data_dim=data_dim,
42 | lag=lag,
43 | num_nodes=num_nodes,
44 | num_workers=num_workers,
45 | aggregated_graph=aggregated_graph)
46 |
47 | self.max_iter = max_iter
48 | self.lambda_w = lambda_w
49 | self.lambda_a = lambda_a
50 | self.w_threshold = w_threshold
51 | self.h_tol = h_tol
52 |
53 | if self.single_graph:
54 | self.batch_size = full_dataset.shape[0] # we want the full dataset
55 | else:
56 | self.batch_size = 1
57 |
58 | def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
59 | X, adj_matrix, graph_index = batch
60 |
61 | batch_size, timesteps, num_nodes, _ = X.shape
62 | assert num_nodes == self.num_nodes
63 | X = X.view(batch_size, timesteps, -1)
64 |
65 | X, adj_matrix, graph_index = X.cpu().numpy(), adj_matrix.cpu().numpy(), graph_index.cpu().numpy()
66 |
67 | X_list = []
68 |
69 | graphs = np.zeros((batch_size, self.lag+1, num_nodes, num_nodes))
70 | scores = np.zeros((batch_size, self.lag+1, num_nodes, num_nodes))
71 | if self.group_by_graph:
72 | n_unique_matrices = np.max(graph_index)+1
73 | else:
74 | graph_index = np.zeros((batch_size))
75 | n_unique_matrices = 1
76 |
77 | for i in range(n_unique_matrices):
78 |             X_list = []  # reset per graph group so each DYNOTEARS fit uses only that group's samples
79 | n_samples = np.sum(graph_index == i)
80 | for x in X[graph_index == i]:
81 | X_list.append(pd.DataFrame(x))
82 | learner = from_pandas_dynamic(
83 | X_list,
84 | p=self.lag,
85 | max_iter=self.max_iter,
86 | lambda_w=self.lambda_w,
87 | lambda_a=self.lambda_a,
88 | w_threshold=self.w_threshold,
89 | h_tol=self.h_tol)
90 |
91 | adj_static = nx.to_numpy_array(learner)
92 | temporal_adj_list = []
93 | for l in range(self.lag + 1):
94 | cur_adj = adj_static[l:: self.lag + 1, 0:: self.lag + 1]
95 | temporal_adj_list.append(cur_adj)
96 |
97 | # [lag+1, num_nodes, num_nodes]
98 | score = np.stack(temporal_adj_list, axis=0)
99 | # scores = np.hstack(temporal_adj_list)
100 | temporal_adj = [(score != 0).astype(int) for _ in range(n_samples)]
101 | score = [np.abs(score) for _ in range(n_samples)]
102 | graphs[i == graph_index] = np.array(temporal_adj)
103 | scores[i == graph_index] = np.array(score)
104 | if self.aggregated_graph:
105 | graphs = to_time_aggregated_graph_np(graphs)
106 | scores = to_time_aggregated_scores_np(scores)
107 | if self.ignore_self_connections:
108 | graphs = zero_out_diag_np(graphs)
109 | scores = zero_out_diag_np(scores)
110 | return torch.Tensor(graphs), torch.abs(torch.Tensor(scores)), torch.Tensor(adj_matrix)
111 |
--------------------------------------------------------------------------------
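A small illustration of the slicing in `predict_step` above, which unrolls the static graph returned by `from_pandas_dynamic` back into per-lag blocks (toy sizes, assuming the node-major / lag-minor ordering that the slicing implies):

```python
# Toy illustration of the adj_static[l::lag+1, 0::lag+1] slicing used above.
import numpy as np

lag, num_nodes = 1, 2                     # static matrix is (lag+1)*num_nodes = 4 wide
adj_static = np.arange(16).reshape(4, 4)  # stand-in for nx.to_numpy_array(learner)

temporal_adj = np.stack(
    [adj_static[l:: lag + 1, 0:: lag + 1] for l in range(lag + 1)], axis=0
)
print(temporal_adj.shape)                 # (lag+1, num_nodes, num_nodes) == (2, 2, 2)
```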
/src/baselines/PCMCITrainer.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 | from copy import deepcopy
3 |
4 | import numpy as np
5 | # import tigramite for pcmci
6 | from tigramite import data_processing as pp
7 | from tigramite.pcmci import PCMCI
8 | from tigramite.independence_tests.parcorr import ParCorr
9 | from tigramite.independence_tests.cmiknn import CMIknn
10 | import torch
11 |
12 | from src.utils.causality_utils import convert_temporal_to_static_adjacency_matrix, cpdag2dags
13 | from src.baselines.BaselineTrainer import BaselineTrainer
14 | from src.utils.data_utils.data_format_utils import to_time_aggregated_graph_np, zero_out_diag_np
15 |
16 | """
17 | Large parts adapted from https://github.com/microsoft/causica
18 | """
19 |
20 | class PCMCITrainer(BaselineTrainer):
21 |
22 | def __init__(self,
23 | full_dataset: np.array,
24 | adj_matrices: np.array,
25 | data_dim: int,
26 | num_nodes: int,
27 | lag: int,
28 | num_workers: int = 16,
29 | aggregated_graph: bool = False,
30 | ci_test: str = 'ParCorr', # options are ParCorr, CMIknn, GPDCtorch
31 | single_graph: bool = False,
32 | pcmci_plus: bool = True,
33 | pc_alpha: float = 0.01,
34 | group_by_graph: bool = False,
35 | ignore_self_connections: bool = False
36 | ):
37 | self.group_by_graph = group_by_graph
38 | self.ignore_self_connections = ignore_self_connections
39 | if self.group_by_graph:
40 | print("PCMCI: Group by graph option set. Overriding single graph flag to True...")
41 | self.single_graph = True
42 | else:
43 | self.single_graph = single_graph
44 |
45 | super().__init__(full_dataset=full_dataset,
46 | adj_matrices=adj_matrices,
47 | data_dim=data_dim,
48 | lag=lag,
49 | num_nodes=num_nodes,
50 | num_workers=num_workers,
51 | aggregated_graph=aggregated_graph)
52 |
53 | if ci_test == 'ParCorr':
54 | self.ci_test = ParCorr(significance='analytic')
55 | elif ci_test == 'CMIknn':
56 | self.ci_test = CMIknn()
57 |
58 | # self.single_graph = single_graph
59 | self.pcmci_plus = pcmci_plus
60 | self.pc_alpha = pc_alpha
61 |
62 | if self.single_graph:
63 | self.batch_size = full_dataset.shape[0] # we want the full dataset
64 | self.analysis_mode = 'multiple'
65 | else:
66 | self.batch_size = 1
67 | self.analysis_mode = 'single'
68 |
69 | def _process_adj_matrix(self, adj_matrix: np.ndarray) -> np.ndarray:
70 | """
71 | Borrowed from microsoft/causica
72 |
73 | This will process the raw output adj graphs from pcmci_plus. The raw output can contain 3 types of edges:
74 | (1) "-->" or "<--". This indicates the directed edges, and they should appear symmetrically in the matrix.
75 | (2) "o-o": This indicates the bi-directed edges, also appears symmetrically.
76 | Note: for lagged matrix, it can only contain "-->".
77 | (3) "x-x": this means the edge direction is undecided due to conflicting orientation rules. We ignore
78 | the edges in this case.
79 | Args:
80 | adj_matrix: the input raw adjacency matrix with shape [num_nodes, num_nodes, lag+1]
81 | Returns:
82 | inst_adj_matrix: np.ndarray, an inst adj matrix with shape [lag+1, num_nodes, num_nodes]
83 | """
84 | assert adj_matrix.ndim == 3
85 |
86 | adj_matrix = deepcopy(adj_matrix)
87 | # shape [lag+1, num_nodes, num_nodes]
88 | adj_matrix = np.moveaxis(adj_matrix, -1, 0)
89 | adj_matrix[adj_matrix == ""] = 0
90 | adj_matrix[adj_matrix == "<--"] = 0
91 | adj_matrix[adj_matrix == "-->"] = 1
92 | adj_matrix[adj_matrix == "o-o"] = 1
93 | adj_matrix[adj_matrix == "x-x"] = 0
94 |
95 | return adj_matrix.astype(int)
96 |
97 | def _run_pcmci(self, pcmci, tau_max, pc_alpha):
98 | if self.pcmci_plus:
99 | return pcmci.run_pcmciplus(tau_max=tau_max, pc_alpha=pc_alpha)
100 | return pcmci.run_pcmci(tau_max=tau_max, pc_alpha=pc_alpha)
101 |
102 | def _process_cpdag(self, adj_matrix: np.ndarray):
103 | """
104 |         This will process the inst cpdag (i.e. adj_matrix[0, ...]) by enumerating the DAGs in its Markov equivalence class via cpdag2dags.
105 | Args:
106 | adj_matrix: np.ndarray, a temporal adj matrix with shape [lag+1, num_nodes, num_nodes] where the inst part can be a cpdag.
107 |
108 | Returns:
109 | adj_matrix: np.ndarray with shape [num_possible_dags, lag+1, num_nodes, num_nodes]
110 | """
111 | lag_plus, num_nodes = adj_matrix.shape[0], adj_matrix.shape[1]
112 | static_temporal_graph = convert_temporal_to_static_adjacency_matrix(
113 | adj_matrix, conversion_type="auto_regressive"
114 | ) # shape[(lag+1) *nodes, (lag+1)*nodes]
115 | all_static_temp_dags = cpdag2dags(
116 | static_temporal_graph, samples=3000
117 | ) # [all_possible_dags, (lag+1)*num_nodes, (lag+1)*num_nodes]
118 | # convert back to temporal adj matrix.
119 | temp_adj_list = np.split(
120 | all_static_temp_dags[..., :, (lag_plus - 1) * num_nodes :], lag_plus, axis=1
121 | ) # list with length lag+1, each with shape [all_possible_dags, num_nodes, num_nodes]
122 | proc_adj_matrix = np.stack(
123 | list(reversed(temp_adj_list)), axis=-3
124 | ) # shape [all_possible_dags, lag+1, num_nodes, num_nodes]
125 |
126 | return proc_adj_matrix
127 |
128 |
129 | def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
130 | X, adj_matrix, graph_index = batch
131 | batch_size, timesteps, num_nodes, _ = X.shape
132 | assert num_nodes == self.num_nodes
133 | X = X.view(batch_size, timesteps, -1)
134 | X, adj_matrix, graph_index = X.numpy(), adj_matrix.numpy(), graph_index.numpy()
135 | graphs = [] #np.zeros((batch_size, self.lag+1, num_nodes, num_nodes))
136 | new_adj_matrix = []
137 | if self.group_by_graph:
138 | n_unique_matrices = np.max(graph_index)+1
139 | else:
140 | graph_index = np.zeros((batch_size))
141 | n_unique_matrices = 1
142 | for i in range(n_unique_matrices):
143 | print(f"{i}/{n_unique_matrices}")
144 | n_samples = np.sum(graph_index == i)
145 | dataframe = pp.DataFrame(X[graph_index==i], analysis_mode=self.analysis_mode)
146 | pcmci = PCMCI(
147 | dataframe=dataframe,
148 | cond_ind_test=self.ci_test,
149 | verbosity=0)
150 |
151 | results = self._run_pcmci(pcmci, self.lag, self.pc_alpha)
152 | graph = self._process_adj_matrix(results["graph"])
153 | graph = self._process_cpdag(graph)
154 | num_possible_dags = graph.shape[0]
155 | new_adj_matrix.append(np.repeat(adj_matrix[graph_index==i][0][np.newaxis, ...], n_samples*num_possible_dags, axis=0))
156 | graphs.append(np.repeat(graph, n_samples, axis=0))
157 |
158 | graphs = np.concatenate(graphs, axis=0)
159 | new_adj_matrix = np.concatenate(new_adj_matrix, axis=0)
160 | if self.aggregated_graph:
161 | graphs = to_time_aggregated_graph_np(graphs)
162 | if self.ignore_self_connections:
163 | graphs = zero_out_diag_np(graphs)
164 |
165 |         # PCMCI yields no edge confidences here, so the binary graphs double as the score output
166 |         return torch.Tensor(graphs), torch.Tensor(graphs), torch.Tensor(new_adj_matrix)
166 |
--------------------------------------------------------------------------------
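The edge-string handling in `_process_adj_matrix` above can be traced end to end on a toy raw output for 2 nodes and lag 1 (values chosen purely for illustration):

```python
import numpy as np

# Toy raw PCMCI(+) output: shape [num_nodes, num_nodes, lag+1], entries are edge strings.
raw = np.array([
    [["", "-->"], ["-->", ""]],
    [["<--", "o-o"], ["", "x-x"]],
], dtype=object)

adj = np.moveaxis(raw.copy(), -1, 0)  # -> [lag+1, num_nodes, num_nodes]
adj[adj == ""] = 0
adj[adj == "<--"] = 0                  # only the forward direction is kept
adj[adj == "-->"] = 1
adj[adj == "o-o"] = 1                  # bi-directed edge, kept
adj[adj == "x-x"] = 0                  # conflicting orientation, dropped
print(adj.astype(int))
```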
/src/baselines/VARLiNGAMTrainer.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 | import numpy as np
3 |
4 | import lingam
5 | import torch
6 |
7 | from src.utils.data_utils.data_format_utils import to_time_aggregated_graph_np
8 | from src.baselines.BaselineTrainer import BaselineTrainer
9 |
10 | class VARLiNGAMTrainer(BaselineTrainer):
11 |
12 | def __init__(self,
13 | full_dataset: np.array,
14 | adj_matrices: np.array,
15 | data_dim: int,
16 | num_nodes: int,
17 | lag: int,
18 | num_workers: int = 16,
19 | aggregated_graph: bool = False
20 | ):
21 | super().__init__(full_dataset=full_dataset,
22 | adj_matrices=adj_matrices,
23 | data_dim=data_dim,
24 | lag=lag,
25 | num_nodes=num_nodes,
26 | num_workers=num_workers,
27 | aggregated_graph=aggregated_graph)
28 |
29 | def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
30 | X, adj_matrix, _ = batch
31 |
32 | batch, timesteps, num_nodes, _ = X.shape
33 | X = X.view(batch, timesteps, -1)
34 |
35 | assert num_nodes == self.num_nodes
36 | assert batch == 1, "VARLiNGAM needs batch size 1"
37 |
38 | model_pruned = lingam.VARLiNGAM(lags=self.lag, prune=True)
39 | model_pruned.fit(X[0])
40 | graph = np.transpose(np.abs(model_pruned.adjacency_matrices_) > 0, axes=[0, 2, 1])
41 |         # pad with empty lag slices if fewer than lag+1 adjacency matrices were returned
42 |         while graph.shape[0] != (self.lag+1):
43 |             graph = np.concatenate((graph, np.zeros((1, num_nodes, num_nodes))), axis=0)
44 | graphs = [graph]
45 | if self.aggregated_graph:
46 | graphs = to_time_aggregated_graph_np(graphs)
47 | print(graphs)
48 | print(adj_matrix)
49 | return torch.Tensor(graphs), torch.Tensor(graphs), torch.Tensor(adj_matrix)
--------------------------------------------------------------------------------
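The adjacency post-processing in `predict_step` above amounts to thresholding, transposing, and zero-padding the lag dimension. A small numpy sketch of those steps, using an illustrative threshold instead of the strict `> 0` check applied to the pruned model:

```python
import numpy as np

lag, num_nodes = 2, 3
# Suppose the model returned only two coefficient matrices (inst + 1 lag) for a lag-2 model.
adjacency_matrices_ = np.random.randn(2, num_nodes, num_nodes)

# Threshold coefficients and swap the last two axes, mirroring predict_step above.
graph = np.transpose(np.abs(adjacency_matrices_) > 0.1, axes=[0, 2, 1]).astype(int)

# Pad with empty lag slices until there are lag+1 matrices, mirroring the loop above.
while graph.shape[0] != lag + 1:
    graph = np.concatenate((graph, np.zeros((1, num_nodes, num_nodes), dtype=int)), axis=0)
print(graph.shape)  # (3, 3, 3)
```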
/src/baselines/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Rose-STL-Lab/MCD/5276806eb761545d2e99f53d5135289d7def20e3/src/baselines/__init__.py
--------------------------------------------------------------------------------
/src/dataset/BaselineTSDataset.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | from src.utils.data_utils.data_format_utils import get_adj_matrix_id
3 |
4 | class BaselineTSDataset(Dataset):
5 | # Dataset class that can be used with the baselines PCMCI(+), VARLiNGAM and DYNOTEARS
6 | def __init__(self,
7 | X,
8 | adj_matrix,
9 | lag,
10 | aggregated_graph=False,
11 | return_graph_indices=False):
12 | """
13 | X: np.array of shape (n_samples, timesteps, num_nodes, data_dim)
14 | adj_matrix: np.array of shape (n_samples, lag+1, num_nodes, num_nodes)
15 | """
16 | self.lag = lag
17 | self.aggregated_graph = aggregated_graph
18 | self.X = X
19 | self.adj_matrix = adj_matrix
20 | self.return_graph_indices = return_graph_indices
21 | if self.return_graph_indices:
22 | self.unique_matrices, self.matrix_indices = get_adj_matrix_id(self.adj_matrix)
23 |
24 | def __len__(self):
25 | return self.X.shape[0]
26 |
27 | def __getitem__(self, index):
28 | if not self.return_graph_indices:
29 | return self.X[index], self.adj_matrix[index]
30 | return self.X[index], self.adj_matrix[index], self.matrix_indices[index]
31 |
--------------------------------------------------------------------------------
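A usage sketch for the dataset above, assuming the package is importable as `src` and using random data with the shapes given in the docstring:

```python
import numpy as np
from torch.utils.data import DataLoader
from src.dataset.BaselineTSDataset import BaselineTSDataset

n_samples, timesteps, num_nodes, data_dim, lag = 8, 50, 5, 1, 2
X = np.random.randn(n_samples, timesteps, num_nodes, data_dim).astype(np.float32)
adj = np.random.binomial(1, 0.2, size=(n_samples, lag + 1, num_nodes, num_nodes))

dataset = BaselineTSDataset(X, adj, lag=lag, aggregated_graph=False, return_graph_indices=False)
loader = DataLoader(dataset, batch_size=1, shuffle=False)
x_batch, adj_batch = next(iter(loader))
print(x_batch.shape, adj_batch.shape)  # torch.Size([1, 50, 5, 1]) torch.Size([1, 3, 5, 5])
```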
/src/dataset/FragmentDataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Terminology:
3 |
4 | A fragment refers to the pair of tensors X_history and x_current, where X_history represents
5 | the lag information (X(t-lag) to X(t-1)) and x_current represents the current information X(t).
6 | Note that which sample a fragment comes from is irrelevant, since all we are concerned about is the
7 | causal graph which generated the fragment.
8 | """
9 |
10 | from torch.utils.data import Dataset
11 | import torch
12 |
13 | from src.utils.data_utils.data_format_utils import convert_data_to_timelagged, convert_adj_to_timelagged
14 |
15 |
16 | class FragmentDataset(Dataset):
17 | def __init__(self,
18 | X,
19 | adj_matrix,
20 | lag,
21 | return_graph_indices=True,
22 | aggregated_graph=False):
23 | """
24 | X: np.array of shape (n_samples, timesteps, num_nodes, data_dim)
25 | adj_matrix: np.array of shape (n_samples, lag+1, num_nodes, num_nodes)
26 | """
27 | self.lag = lag
28 | self.aggregated_graph = aggregated_graph
29 | self.return_graph_indices = return_graph_indices
30 | # preprocess data
31 | self.X_history, self.x_current, self.X_indices = convert_data_to_timelagged(
32 | X, lag=lag)
33 | if self.return_graph_indices:
34 | self.adj_matrix, self.graph_indices = convert_adj_to_timelagged(
35 | adj_matrix,
36 | lag=lag,
37 | n_fragments=self.X_history.shape[0],
38 | aggregated_graph=self.aggregated_graph,
39 | return_indices=True)
40 | else:
41 | self.adj_matrix = convert_adj_to_timelagged(
42 | adj_matrix,
43 | lag=lag,
44 | n_fragments=self.X_history.shape[0],
45 | aggregated_graph=self.aggregated_graph,
46 | return_indices=False)
47 |
48 | self.X_history, self.x_current, self.adj_matrix, self.X_indices = \
49 | torch.Tensor(self.X_history), torch.Tensor(self.x_current), torch.Tensor(
50 | self.adj_matrix), torch.Tensor(self.X_indices)
51 | if self.return_graph_indices:
52 | self.graph_indices = torch.Tensor(self.graph_indices)
53 |
54 | self.num_fragments = self.X_history.shape[0]
55 |
56 | def __len__(self):
57 | return self.num_fragments
58 |
59 | def __getitem__(self, index):
60 | if self.return_graph_indices:
61 | return self.X_history[index], self.x_current[index], self.adj_matrix[index], self.X_indices[index].long(), \
62 | self.graph_indices[index].long()
63 | return self.X_history[index], self.x_current[index], self.adj_matrix[index], self.X_indices[index].long()
64 |
--------------------------------------------------------------------------------
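The terminology docstring above can be made concrete with a toy fragmenter. This is a sketch of the presumed sliding-window behaviour, not the actual `convert_data_to_timelagged` implementation, which additionally tracks sample indices and graph assignments:

```python
import numpy as np

def make_fragments(X: np.ndarray, lag: int):
    """Illustrative sliding-window fragmenter over all samples and timesteps."""
    n_samples, timesteps, num_nodes, data_dim = X.shape
    X_history, x_current = [], []
    for s in range(n_samples):
        for t in range(lag, timesteps):
            X_history.append(X[s, t - lag:t])   # X(t-lag) ... X(t-1)
            x_current.append(X[s, t])           # X(t)
    return np.stack(X_history), np.stack(x_current)

X = np.random.randn(2, 10, 4, 1)
X_history, x_current = make_fragments(X, lag=3)
print(X_history.shape, x_current.shape)  # (14, 3, 4, 1) (14, 4, 1)
```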
/src/model/BaseTrainer.py:
--------------------------------------------------------------------------------
1 |
2 | import lightning.pytorch as pl
3 | import torch
4 | from torch.utils.data import DataLoader, TensorDataset
5 | import torch.nn.functional as F
6 | import numpy as np
7 |
8 | from src.dataset.FragmentDataset import FragmentDataset
9 | from src.utils.metrics_utils import mape_loss
10 |
11 |
12 | class BaseTrainer(pl.LightningModule):
13 |
14 | def __init__(self,
15 | full_dataset: np.array,
16 | adj_matrices: np.array,
17 | data_dim: int,
18 | lag: int,
19 | num_workers: int = 16,
20 | batch_size: int = 256,
21 | aggregated_graph: bool = False,
22 | return_graph_indices: bool = True,
23 | val_frac: float = 0.2,
24 | use_all_for_val: bool = False,
25 | shuffle: bool = True):
26 | super().__init__()
27 | # self.automatic_optimization = False
28 | self.num_workers = num_workers
29 | self.batch_size = batch_size
30 | self.aggregated_graph = aggregated_graph
31 | self.data_dim = data_dim
32 | self.lag = lag
33 | self.return_graph_indices = return_graph_indices
34 |
35 | I = np.arange(full_dataset.shape[0])
36 | if shuffle:
37 | rng = np.random.default_rng()
38 | rng.shuffle(I)
39 |
40 | self.full_dataset_np = full_dataset[I]
41 | self.adj_matrices_np = adj_matrices[I]
42 |
43 | self.use_all_for_val = use_all_for_val
44 | self.val_frac = val_frac
45 | assert self.val_frac >= 0.0 and self.val_frac < 1.0, "Validation fraction should be between 0 and 1"
46 |
47 | num_samples = full_dataset.shape[0]
48 | self.total_samples = num_samples
49 | assert adj_matrices.shape[0] == num_samples
50 |
51 | if self.use_all_for_val:
52 | print("Using *all* examples for validation. Ignoring val_frac...")
53 | self.train_dataset_np = self.full_dataset_np
54 | self.val_dataset_np = self.full_dataset_np
55 | self.train_adj_np = self.adj_matrices_np
56 | self.val_adj_np = self.adj_matrices_np
57 | else:
58 | self.train_dataset_np = self.full_dataset_np[:int(
59 | (1-self.val_frac)*num_samples)]
60 | self.val_dataset_np = self.full_dataset_np[int(
61 | (1-self.val_frac)*num_samples):]
62 | self.train_adj_np = self.adj_matrices_np[:int(
63 | (1-self.val_frac)*num_samples)]
64 | self.val_adj_np = self.adj_matrices_np[int(
65 | (1-self.val_frac)*num_samples):]
66 | self.train_frag_dataset = FragmentDataset(
67 | self.train_dataset_np,
68 | self.train_adj_np,
69 | lag=lag,
70 | aggregated_graph=self.aggregated_graph,
71 | return_graph_indices=self.return_graph_indices)
72 | self.val_frag_dataset = FragmentDataset(
73 | self.val_dataset_np,
74 | self.val_adj_np,
75 | lag=lag,
76 | aggregated_graph=self.aggregated_graph,
77 | return_graph_indices=self.return_graph_indices)
78 | # self.full_frag_dataset = FragmentDataset(
79 | # self.full_dataset_np,
80 | # self.adj_matrices_np,
81 | # lag=lag,
82 | # aggregated_graph=self.aggregated_graph,
83 | # return_graph_indices=self.return_graph_indices)
84 | self.num_fragments = len(self.train_frag_dataset)
85 | self.full_dataset = TensorDataset(
86 | torch.Tensor(self.full_dataset_np),
87 | torch.Tensor(self.adj_matrices_np),
88 | torch.arange(self.full_dataset_np.shape[0]))
89 | if self.batch_size is None:
90 | # do full-batch training
91 | self.batch_size = self.num_fragments
92 |
93 | def forward(self):
94 | raise NotImplementedError
95 |
96 | def compute_loss(self, X_history, x_current, X_full, adj_matrix):
97 | raise NotImplementedError
98 |
99 | def compute_mse(self, x_current, x_pred):
100 | return F.mse_loss(x_current, x_pred)
101 |
102 | def compute_mape(self, x_current, x_pred):
103 | return mape_loss(x_current, x_pred)
104 |
105 | def training_step(self, batch, batch_idx):
106 | X_history, x_current, X_full, adj_matrix = batch
107 | loss = self.compute_loss(X_history, x_current, X_full, adj_matrix)
108 | return loss
109 |
110 | def validation_step(self, batch, batch_idx):
111 | X_history, x_current, X_full, adj_matrix = batch
112 | loss = self.compute_loss(X_history, x_current, X_full, adj_matrix)
113 | self.log("val_loss", loss)
114 |
115 | def train_dataloader(self) -> DataLoader:
116 | return DataLoader(self.train_frag_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)
117 |
118 | def val_dataloader(self) -> DataLoader:
119 | return DataLoader(self.val_frag_dataset, batch_size=self.batch_size, num_workers=self.num_workers)
120 |
121 | def get_full_dataloader(self) -> DataLoader:
122 | return DataLoader(self.full_dataset, batch_size=self.batch_size, num_workers=self.num_workers)
123 |
124 | def track_gradients(self, m, log_name):
125 | total_norm = 0
126 | for p in m.parameters():
127 | if p.grad is not None:
128 | param_norm = p.grad.data.norm(2)
129 | total_norm += param_norm.item() ** 2
130 | total_norm = total_norm ** (1. / 2)
131 | self.log(log_name, total_norm)
132 |
--------------------------------------------------------------------------------
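`track_gradients` above computes a single global L2 norm over all parameter gradients. The same computation as a standalone helper, exercised on a throwaway linear layer:

```python
import torch
from torch import nn

def total_grad_norm(module: nn.Module) -> float:
    """Global L2 norm of all parameter gradients, as in BaseTrainer.track_gradients."""
    total = 0.0
    for p in module.parameters():
        if p.grad is not None:
            total += p.grad.data.norm(2).item() ** 2
    return total ** 0.5

m = nn.Linear(3, 2)
loss = m(torch.randn(5, 3)).sum()
loss.backward()
print(total_grad_norm(m))
```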
/src/model/RhinoTrainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from sklearn.metrics import f1_score
4 |
5 | from src.utils.loss_utils import dag_penalty_notears, temporal_graph_sparsity
6 | from src.utils.data_utils.data_format_utils import to_time_aggregated_graph_np, to_time_aggregated_scores_torch, zero_out_diag_np, zero_out_diag_torch
7 | from src.utils.metrics_utils import compute_shd
8 | from src.model.BaseTrainer import BaseTrainer
9 | from src.modules.adjacency_matrices.TemporalAdjacencyMatrix import TemporalAdjacencyMatrix
10 | from src.modules.adjacency_matrices.TwoWayTemporalAdjacencyMatrix import TwoWayTemporalAdjacencyMatrix
11 | from src.training.auglag import AugLagLRConfig, AuglagLRCallback, AugLagLR, AugLagLossCalculator
12 |
13 |
14 | class RhinoTrainer(BaseTrainer):
15 |
16 | def __init__(self,
17 | full_dataset: np.array,
18 | adj_matrices: np.array,
19 | data_dim: int,
20 | lag: int,
21 | num_nodes: int,
22 | causal_decoder,
23 | tcsf,
24 | disable_inst: bool = False,
25 | likelihood_loss: str = 'flow',
26 | sparsity_factor: float = 20,
27 | num_workers: int = 16,
28 | batch_size: int = 256,
29 | matrix_temperature: float = 0.25,
30 | aggregated_graph: bool = False,
31 | ignore_self_connections: bool = False,
32 | threeway_graph_dist: bool = True,
33 | skip_auglag_epochs: int = 0,
34 | training_procedure: str = 'auglag',
35 | training_config=None,
36 | init_logits=[0, 0],
37 | use_all_for_val=False,
38 | shuffle=True):
39 |
40 | self.aggregated_graph = aggregated_graph
41 | self.ignore_self_connections = ignore_self_connections
42 | self.skip_auglag_epochs = skip_auglag_epochs
43 | self.init_logits = init_logits
44 | self.disable_inst = disable_inst
45 | super().__init__(full_dataset=full_dataset,
46 | adj_matrices=adj_matrices,
47 | data_dim=data_dim,
48 | lag=lag,
49 | num_workers=num_workers,
50 | batch_size=batch_size,
51 | aggregated_graph=self.aggregated_graph,
52 | use_all_for_val=use_all_for_val,
53 | shuffle=shuffle)
54 | self.num_nodes = num_nodes
55 | self.matrix_temperature = matrix_temperature
56 | self.threeway_graph_dist = threeway_graph_dist
57 | self.training_procedure = training_procedure
58 |
59 | print("Number of fragments:", self.num_fragments)
60 | print("Number of samples:", self.total_samples)
61 |
62 | assert likelihood_loss in ['mse', 'flow']
63 | self.likelihood_loss = likelihood_loss
64 |
65 | self.sparsity_factor = sparsity_factor
66 | self.graphs = []
67 |
68 | self.causal_decoder = causal_decoder
69 | self.tcsf = tcsf
70 |
71 | self.initialize_graph()
72 |
73 | if self.training_procedure == 'auglag':
74 | self.training_config = training_config
75 | if training_config is None:
76 | self.training_config = AugLagLRConfig()
77 | if self.skip_auglag_epochs > 0:
78 | print(
79 | f"Not performing augmented lagrangian optimization for the first {self.skip_auglag_epochs} epochs...")
80 | self.disabled_epochs = range(self.skip_auglag_epochs)
81 | else:
82 | self.disabled_epochs = None
83 | self.lr_scheduler = AugLagLR(config=self.training_config)
84 | self.loss_calc = AugLagLossCalculator(init_alpha=self.training_config.init_alpha,
85 | init_rho=self.training_config.init_rho)
86 |
87 | def initialize_graph(self):
88 | if self.threeway_graph_dist:
89 | self.adj_matrix = TemporalAdjacencyMatrix(
90 | input_dim=self.num_nodes,
91 | lag=self.lag,
92 | tau_gumbel=self.matrix_temperature,
93 | init_logits=self.init_logits,
94 | disable_inst=self.disable_inst)
95 | else:
96 | self.adj_matrix = TwoWayTemporalAdjacencyMatrix(input_dim=self.num_nodes,
97 | lag=self.lag,
98 | tau_gumbel=self.matrix_temperature,
99 | init_logits=self.init_logits,
100 | disable_inst=self.disable_inst)
101 |
102 | def forward(self):
103 | raise NotImplementedError
104 |
105 | def compute_loss_terms(self, X_history: torch.Tensor, x_current: torch.Tensor, G: torch.Tensor):
106 |
107 | # ******************* graph prior *********************
108 | graph_sparsity_term = self.sparsity_factor * \
109 | temporal_graph_sparsity(G)
110 |
111 | # dagness factors
112 | if self.training_procedure == 'auglag':
113 | dagness_penalty = dag_penalty_notears(G[0])
114 |
115 | prior_term = graph_sparsity_term
116 | prior_term /= self.num_fragments
117 |
118 | # **************** graph entropy **********************
119 |
120 | entropy_term = -self.adj_matrix.entropy()/self.num_fragments
121 |
122 | # ************* likelihood loss ***********************
123 |
124 | batch_size = X_history.shape[0]
125 | expanded_G = G.unsqueeze(0).repeat(batch_size, 1, 1, 1)
126 |
127 | X_input = torch.cat((X_history, x_current.unsqueeze(1)), dim=1)
128 | x_pred = self.causal_decoder(X_input, expanded_G)
129 |
130 | mse_loss = self.compute_mse(x_current, x_pred)
131 | mape_loss = self.compute_mape(x_current, x_pred)
132 |
133 | if self.likelihood_loss == 'mse':
134 | likelihood_term = mse_loss
135 |
136 | elif self.likelihood_loss == 'flow':
137 | batch, num_nodes, data_dim = x_current.shape
138 | log_prob = self.tcsf.log_prob(X_input=(x_current - x_pred).view(batch, num_nodes*data_dim),
139 | X_history=X_history,
140 | A=expanded_G).sum(-1)
141 | likelihood_term = -torch.mean(log_prob)
142 |
143 | loss_terms = {
144 | 'graph_sparsity': graph_sparsity_term,
145 | 'dagness_penalty': dagness_penalty,
146 | 'graph_prior': prior_term,
147 |
148 | 'graph_entropy': entropy_term,
149 |
150 | 'mse': mse_loss,
151 | 'mape': mape_loss,
152 | 'likelihood': likelihood_term
153 | }
154 | return loss_terms
155 |
156 | def compute_loss(self, X_history, x_current, idx):
157 | # sample G
158 |
159 | G = self.adj_matrix.sample_A()
160 | loss_terms = self.compute_loss_terms(
161 | X_history=X_history,
162 | x_current=x_current,
163 | G=G)
164 |
165 | total_loss = loss_terms['likelihood'] +\
166 | loss_terms['graph_prior'] +\
167 | loss_terms['graph_entropy']
168 |
169 | return total_loss, loss_terms, None
170 |
171 | def training_step(self, batch, batch_idx):
172 |
173 | X_history, x_current, _, idx, _ = batch
174 | loss, loss_terms, _ = self.compute_loss(X_history, x_current, idx)
175 |
176 | loss = self.loss_calc(
177 | loss, loss_terms['dagness_penalty']/self.num_fragments)
178 | self.log_dict(loss_terms, on_epoch=True)
179 |
180 | loss_terms['loss'] = loss
181 | return loss_terms
182 |
183 | def validation_func(self, X_history, x_current, adj_matrix, G, idx):
184 | batch_size = X_history.shape[0]
185 |
186 | loss, loss_terms, _ = self.compute_loss(X_history, x_current, idx)
187 |
188 | G = G.detach().cpu().numpy()
189 | adj_matrix = adj_matrix.detach().cpu().numpy()
190 |
191 | if self.aggregated_graph:
192 | G = to_time_aggregated_graph_np(G)
193 | if self.ignore_self_connections:
194 | G = zero_out_diag_np(G)
195 |
196 | shd_loss = compute_shd(adj_matrix, G, aggregated_graph=True)
197 | shd_loss = torch.Tensor([shd_loss])
198 | f1 = f1_score(adj_matrix.flatten(), G.flatten())
199 | else:
200 | mask = adj_matrix != G
201 |
202 | shd_loss = np.sum(mask)/batch_size
203 | shd_inst = np.sum(mask[:, 0])/batch_size
204 | shd_lag = np.sum(mask[:, 1:])/batch_size
205 |
206 | # shd_loss, shd_inst, shd_lag = compute_shd(adj_matrix, G)
207 | tp = np.logical_and(adj_matrix == 1, adj_matrix == G)
208 | fp = np.logical_and(adj_matrix != 1, G == 1)
209 | fn = np.logical_and(adj_matrix != 0, G == 0)
210 |
211 | f1_inst = 2*np.sum(tp[:, 0]) / (2*np.sum(tp[:, 0]) +
212 | np.sum(fp[:, 0]) + np.sum(fn[:, 0]))
213 | f1_lag = 2*np.sum(tp[:, 1:]) / (2*np.sum(tp[:, 1:]) +
214 | np.sum(fp[:, 1:]) + np.sum(fn[:, 1:]))
215 |
216 | # f1_inst = f1_score(get_off_diagonal(adj_matrix[:, 0]).flatten(), get_off_diagonal(G[:, 0]).flatten())
217 | # f1_lag = f1_score(adj_matrix[:, 1:].flatten(), G[:, 1:].flatten())
218 | shd_loss = torch.Tensor([shd_loss])
219 | shd_inst = torch.Tensor([shd_inst])
220 | shd_lag = torch.Tensor([shd_lag])
221 |
222 | if not self.aggregated_graph:
223 | loss_terms['shd_inst'] = shd_inst
224 | loss_terms['shd_lag'] = shd_lag
225 | loss_terms['f1_lag'] = f1_lag
226 | loss_terms['f1_inst'] = f1_inst
227 | else:
228 | loss_terms['f1'] = f1
229 |
230 | loss_terms['shd_loss'] = shd_loss
231 |
232 | loss_terms['val_loss'] = loss
233 |
234 | for key, item in loss_terms.items():
235 | self.log(key, item)
236 |
237 | return loss_terms
238 |
239 | def validation_step(self, batch, batch_idx):
240 | X_history, x_current, adj_matrix, idx, _ = batch
241 |
242 | batch_size = X_history.shape[0]
243 | G = self.adj_matrix.sample_A()
244 | expanded_G = G.unsqueeze(0).repeat(batch_size, 1, 1, 1)
245 |
246 | loss_terms = self.validation_func(
247 | X_history, x_current, adj_matrix, expanded_G, idx)
248 |
249 | probs = self.adj_matrix.get_adj_matrix(do_round=False)
250 | probs = probs.unsqueeze(0).repeat(batch_size, 1, 1, 1)
251 | if self.aggregated_graph:
252 | probs = to_time_aggregated_scores_torch(probs)
253 | if self.ignore_self_connections:
254 | probs = zero_out_diag_torch(probs)
255 |
256 | return loss_terms
257 |
258 | def configure_optimizers(self):
259 | """Set the learning rates for different sets of parameters."""
260 | modules = {
261 | "functional_relationships": self.causal_decoder,
262 | "vardist": self.adj_matrix,
263 | "noise_dist": self.tcsf,
264 | }
265 |
266 | parameter_list = [
267 | {
268 | "params": module.parameters(),
269 | "lr": self.training_config.lr_init_dict[name],
270 | "name": name,
271 | }
272 | for name, module in modules.items() if module is not None
273 | ]
274 |
275 | # Check that all modules are added to the parameter list
276 | check_modules = set(modules.values())
277 | for module in self.parameters(recurse=False):
278 | assert module in check_modules, f"Module {module} not in module list"
279 |
280 | return torch.optim.Adam(parameter_list)
281 |
282 | def configure_callbacks(self):
283 | """Create a callback for the auglag callback."""
284 | if self.training_procedure == 'auglag':
285 | return [AuglagLRCallback(self.lr_scheduler, log_auglag=True, disabled_epochs=self.disabled_epochs)]
286 | return None # should not happen
287 |
288 | def predict_step(self, batch, batch_idx, dataloader_idx=0):
289 |
290 | X_full, adj_matrix, _ = batch
291 | batch_size = X_full.shape[0]
292 |
293 | probs = self.adj_matrix.get_adj_matrix(
294 | do_round=False).unsqueeze(0).repeat(batch_size, 1, 1, 1)
295 |
296 | if self.aggregated_graph:
297 | probs = to_time_aggregated_scores_torch(probs)
298 | if self.ignore_self_connections:
299 | probs = zero_out_diag_torch(probs)
300 | # G = torch.bernoulli(probs)
301 | G = (probs >= 0.5).long()
302 | return G, probs, adj_matrix
303 |
--------------------------------------------------------------------------------
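The validation metrics in `validation_func` above reduce to element-wise comparisons of binary temporal graphs. The same arithmetic on random graphs, for reference:

```python
import numpy as np

# Toy ground-truth and predicted temporal graphs: (batch, lag+1, nodes, nodes)
adj_matrix = np.random.binomial(1, 0.3, size=(4, 3, 5, 5))
G = np.random.binomial(1, 0.3, size=(4, 3, 5, 5))
batch_size = adj_matrix.shape[0]

mask = adj_matrix != G
shd = np.sum(mask) / batch_size             # average structural Hamming distance
shd_inst = np.sum(mask[:, 0]) / batch_size  # instantaneous part
shd_lag = np.sum(mask[:, 1:]) / batch_size  # lagged part

tp = np.logical_and(adj_matrix == 1, adj_matrix == G)
fp = np.logical_and(adj_matrix != 1, G == 1)
fn = np.logical_and(adj_matrix != 0, G == 0)
f1_inst = 2 * np.sum(tp[:, 0]) / (2 * np.sum(tp[:, 0]) + np.sum(fp[:, 0]) + np.sum(fn[:, 0]))
f1_lag = 2 * np.sum(tp[:, 1:]) / (2 * np.sum(tp[:, 1:]) + np.sum(fp[:, 1:]) + np.sum(fn[:, 1:]))
print(shd, shd_inst, shd_lag, f1_inst, f1_lag)
```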
/src/model/generate_model.py:
--------------------------------------------------------------------------------
1 | from omegaconf import DictConfig
2 | from hydra.utils import instantiate
3 | from src.modules.TemporalConditionalSplineFlow import TemporalConditionalSplineFlow
4 |
5 |
6 | def generate_model(cfg: DictConfig):
7 | lag = cfg.lag
8 | num_nodes = cfg.num_nodes
9 | data_dim = cfg.data_dim
10 | num_workers = cfg.num_workers
11 | aggregated_graph = cfg.aggregated_graph
12 |
13 | if cfg.model in ['pcmci', 'varlingam', 'dynotears']:
14 | trainer = instantiate(cfg.trainer,
15 | num_workers=num_workers,
16 | lag=lag,
17 | num_nodes=num_nodes,
18 | data_dim=data_dim,
19 | aggregated_graph=aggregated_graph)
20 | else:
21 | multi_graph = cfg.model == 'mcd'
22 | if multi_graph:
23 | num_graphs = cfg.trainer.num_graphs
24 | if 'decoder' in cfg:
25 | # generate the decoder
26 | if not multi_graph:
27 | causal_decoder = instantiate(cfg.decoder,
28 | lag=lag,
29 | num_nodes=num_nodes,
30 | data_dim=data_dim)
31 | else:
32 | causal_decoder = instantiate(cfg.decoder,
33 | lag=lag,
34 | num_nodes=num_nodes,
35 | data_dim=data_dim,
36 | num_graphs=num_graphs)
37 | if 'likelihood_loss' in cfg.trainer and cfg.trainer.likelihood_loss == 'flow':
38 | # create hypernet
39 | if not multi_graph:
40 | hypernet = instantiate(cfg.hypernet,
41 | lag=lag,
42 | num_nodes=num_nodes,
43 | data_dim=data_dim)
44 | else:
45 | hypernet = instantiate(cfg.hypernet,
46 | lag=lag,
47 | num_nodes=num_nodes,
48 | data_dim=data_dim,
49 | num_graphs=num_graphs)
50 | tcsf = TemporalConditionalSplineFlow(hypernet=hypernet)
51 | else:
52 | hypernet = None
53 | tcsf = None
54 |
55 | # create auglag config
56 | if cfg.trainer.training_procedure == 'auglag':
57 | training_config = instantiate(cfg.auglag_config)
58 |
59 | if cfg.model == 'rhino':
60 | trainer = instantiate(cfg.trainer,
61 | num_workers=num_workers,
62 | lag=lag,
63 | num_nodes=num_nodes,
64 | data_dim=data_dim,
65 | causal_decoder=causal_decoder,
66 | tcsf=tcsf,
67 | training_config=training_config,
68 | aggregated_graph=aggregated_graph)
69 | elif cfg.model == 'mcd':
70 | trainer = instantiate(cfg.trainer,
71 | num_workers=num_workers,
72 | lag=lag,
73 | num_nodes=num_nodes,
74 | data_dim=data_dim,
75 | num_graphs=num_graphs,
76 | causal_decoder=causal_decoder,
77 | tcsf=tcsf,
78 | training_config=training_config,
79 | aggregated_graph=aggregated_graph)
80 |
81 | return trainer
82 |
--------------------------------------------------------------------------------
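`generate_model` above leans on Hydra's `instantiate`, which builds an object from a `_target_` path plus keyword overrides. A minimal sketch with illustrative config values, not the actual keys under configs/:

```python
from omegaconf import OmegaConf
from hydra.utils import instantiate

# _target_ points at a class in this repo; lag/num_nodes/data_dim are passed as
# keyword overrides, mirroring how generate_model builds the decoder above.
decoder_cfg = OmegaConf.create({
    "_target_": "src.modules.CausalDecoder.CausalDecoder",
    "embedding_dim": 16,
})
causal_decoder = instantiate(decoder_cfg, lag=2, num_nodes=5, data_dim=1)
print(type(causal_decoder).__name__)  # CausalDecoder
```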
/src/modules/CausalDecoder.py:
--------------------------------------------------------------------------------
1 | import lightning.pytorch as pl
2 | from torch import nn
3 | import torch
4 | from src.utils.torch_utils import generate_fully_connected
5 |
6 |
7 | class CausalDecoder(pl.LightningModule):
8 |
9 | def __init__(self,
10 | data_dim: int,
11 | lag: int,
12 | num_nodes: int,
13 | embedding_dim: int = None,
14 | skip_connection: bool = False,
15 | linear: bool = False
16 | ):
17 | super().__init__()
18 |
19 | if embedding_dim is None:
20 | embedding_dim = num_nodes * data_dim
21 |
22 | self.embedding_dim = embedding_dim
23 | self.data_dim = data_dim
24 | self.lag = lag
25 | self.num_nodes = num_nodes
26 | self.linear = linear
27 |
28 | if not self.linear:
29 | self.embeddings = nn.Parameter((
30 | torch.randn(self.lag + 1, self.num_nodes,
31 | self.embedding_dim, device=self.device) * 0.01
32 | ), requires_grad=True) # shape (lag+1, num_nodes, embedding_dim)
33 |
34 | input_dim = 2*self.embedding_dim
35 | self.nn_size = max(4 * num_nodes, self.embedding_dim, 64)
36 |
37 | self.f = generate_fully_connected(
38 | input_dim=input_dim,
39 | output_dim=num_nodes*data_dim, # potentially num_nodes
40 | hidden_dims=[self.nn_size, self.nn_size],
41 | non_linearity=nn.LeakyReLU,
42 | activation=nn.Identity,
43 | device=self.device,
44 | normalization=nn.LayerNorm,
45 | res_connection=skip_connection,
46 | )
47 |
48 | self.g = generate_fully_connected(
49 | input_dim=self.embedding_dim+self.data_dim,
50 | output_dim=self.embedding_dim,
51 | hidden_dims=[self.nn_size, self.nn_size],
52 | non_linearity=nn.LeakyReLU,
53 | activation=nn.Identity,
54 | device=self.device,
55 | normalization=nn.LayerNorm,
56 | res_connection=skip_connection,
57 | )
58 |
59 | else:
60 | self.w = nn.Parameter(
61 | torch.randn(self.lag+1, self.num_nodes, self.num_nodes, device=self.device)*0.5, requires_grad=True
62 | )
63 |
64 | def forward(self, X_input: torch.Tensor, A: torch.Tensor, embeddings: torch.Tensor = None):
65 | """
66 | Args:
67 | X_input: input data of shape (batch, lag+1, num_nodes, data_dim)
68 | A: adjacency matrix of shape (batch, (lag+1), num_nodes, num_nodes)
69 | """
70 |
71 | assert len(X_input.shape) == 4
72 |
73 | batch, L, num_nodes, data_dim = X_input.shape
74 | lag = L-1
75 |
76 | if not self.linear:
77 | # ensure we have the correct shape
78 | assert (A.shape[0] == batch and A.shape[1] == lag+1 and A.shape[2] ==
79 | num_nodes and A.shape[2] == num_nodes)
80 |
81 | if embeddings is None:
82 | E = self.embeddings.expand(
83 | X_input.shape[0], -1, -1, -1
84 | )
85 | else:
86 | E = embeddings
87 |
88 | X_in = torch.cat((X_input, E), dim=-1)
89 | # X_in: (batch, lag+1, num_nodes, embedding_dim+data_dim)
90 | X_enc = self.g(X_in)
91 |
92 | A_temp = A.flip([1])
93 |
94 | # get the parents of X
95 | X_sum = torch.einsum("blij,blio->bjo", A_temp,
96 | X_enc) # / num_nodes
97 |
98 | X_sum = torch.cat([X_sum, E[:, 0, :, :]], dim=-1)
99 | # (batch, num_nodes, embedding_dim)
100 | # pass through f network to get the predictions
101 |
102 | group_mask = torch.eye(num_nodes*data_dim).to(self.device)
103 | # (batch, num_nodes, data_dim)
104 | return torch.sum(self.f(X_sum)*group_mask, dim=-1).unsqueeze(-1)
105 | return torch.einsum("lij,blio->bjo", (self.w * A[0]).flip([0]), X_input)
106 |
--------------------------------------------------------------------------------
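The core of the decoder's forward pass is the einsum "blij,blio->bjo", which sums each node's encoded parents over all lags. A small check of that contraction against an explicit loop:

```python
import torch

# A_temp[b, l, i, j] selects edge i -> j at (flipped) lag l,
# X_enc[b, l, i, :] is the encoding of node i at that lag.
batch, L, n, d = 2, 3, 4, 8
A_temp = torch.randint(0, 2, (batch, L, n, n)).float()
X_enc = torch.randn(batch, L, n, d)

X_sum = torch.einsum("blij,blio->bjo", A_temp, X_enc)

# equivalent loop for one batch element b and one target node j:
b, j = 0, 1
manual = sum(A_temp[b, l, i, j] * X_enc[b, l, i] for l in range(L) for i in range(n))
print(torch.allclose(X_sum[b, j], manual))  # True
```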
/src/modules/LinearCausalGraph.py:
--------------------------------------------------------------------------------
1 |
2 | import lightning.pytorch as pl
3 | import torch
4 | from torch import nn
5 |
6 |
7 | class LinearCausalGraph(pl.LightningModule):
8 |
9 | def __init__(
10 | self,
11 | lag: int,
12 | input_dim: int
13 | ):
14 | """
15 | Args:
16 |             lag: number of lags in the temporal graph.
17 |             input_dim: number of nodes in the graph.
18 | """
19 | super().__init__()
20 |
21 | self.lag = lag
22 | self.w = nn.Parameter(
23 | torch.zeros(self.lag+1, input_dim, input_dim, device=self.device), requires_grad=True
24 | )
25 | self.I = torch.arange(input_dim)
26 | self.mask = torch.ones((self.lag+1, input_dim, input_dim))
27 | self.mask[0, self.I, self.I] = 0
28 | self.input_dim = input_dim
29 |
30 | def get_w(self) -> torch.Tensor:
31 | """
32 |         Returns the weight matrix. Ensures that the instantaneous matrix has zeros on the diagonal.
33 | """
34 | return self.w * self.mask.to(self.device)
35 |
--------------------------------------------------------------------------------
/src/modules/MixtureSelectionLogits.py:
--------------------------------------------------------------------------------
1 | import lightning.pytorch as pl
2 | from torch import nn
3 | import torch
4 | import torch.nn.functional as F
5 | import torch.distributions as td
6 |
7 |
8 | class MixtureSelectionLogits(pl.LightningModule):
9 |
10 | def __init__(
11 | self,
12 | num_samples: int,
13 | num_graphs: int,
14 | tau: float = 1.0
15 | ):
16 |
17 | super().__init__()
18 | self.num_graphs = num_graphs
19 | self.num_samples = num_samples
20 | self.graph_select_logits = nn.Parameter((
21 | torch.ones(self.num_graphs, self.num_samples,
22 | device=self.device) * 0.01
23 | ), requires_grad=True)
24 | self.tau = tau
25 |
26 | def manual_set_mixture_indices(self, idx, mixture_idx):
27 | """
28 | Use this function to manually set the mixture index.
29 | Mainly used for diagnostic/ablative purposes
30 | """
31 |
32 | with torch.no_grad():
33 | self.graph_select_logits[:, idx] = -10
34 | self.graph_select_logits[mixture_idx, idx] = 10
35 | self.graph_select_logits.requires_grad_(False)
36 |
37 | def set_logits(self, idx, logits):
38 | """
39 | Use this function to manually set the logits.
40 | Used in the E step of the EM implementation
41 | """
42 | with torch.no_grad():
43 | self.graph_select_logits[:, idx] = logits.transpose(0, -1)
44 |
45 | def reset_parameters(self):
46 | with torch.no_grad():
47 | self.graph_select_logits[:] = torch.ones(self.num_graphs,
48 | self.num_samples,
49 | device=self.device) * 0.01
50 |
51 | def turn_off_grad(self):
52 | self.graph_select_logits.requires_grad_(False)
53 |
54 | def turn_on_grad(self):
55 | self.graph_select_logits.requires_grad_(True)
56 |
57 | def get_probs(self, idx):
58 | return F.softmax(self.graph_select_logits[:, idx]/self.tau, dim=0)
59 |
60 | def get_mixture_indices(self, idx):
61 | return torch.argmax(self.graph_select_logits[:, idx], dim=0)
62 |
63 | def entropy(self, idx):
64 | logits = self.graph_select_logits[:, idx]/self.tau
65 | dist = td.Categorical(logits=logits.transpose(0, -1))
66 | entropy = dist.entropy().sum()
67 |
68 | return entropy/idx.shape[0]
69 |
--------------------------------------------------------------------------------
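`get_probs` above is a temperature-scaled softmax over the graph dimension; a smaller `tau` pushes a sample's assignment towards a single mixture component:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0], [0.5], [-1.0]])   # (num_graphs=3, one selected sample)
for tau in (1.0, 0.1):
    probs = F.softmax(logits / tau, dim=0)
    print(tau, probs.squeeze(-1).tolist())       # sharper assignment at tau=0.1
```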
/src/modules/MultiCausalDecoder.py:
--------------------------------------------------------------------------------
1 | import lightning.pytorch as pl
2 | from torch import nn
3 | import torch
4 | from src.modules.MultiEmbedding import MultiEmbedding
5 | from src.utils.torch_utils import generate_fully_connected
6 |
7 |
8 | class MultiCausalDecoder(pl.LightningModule):
9 |
10 | def __init__(self,
11 | data_dim: int,
12 | lag: int,
13 | num_nodes: int,
14 | num_graphs: int,
15 | embedding_dim: int = None,
16 | skip_connection: bool = False,
17 | linear: bool = False,
18 | dropout_p: float = 0.0
19 | ):
20 |
21 | super().__init__()
22 |
23 | if embedding_dim is not None:
24 | self.embedding_dim = embedding_dim
25 | else:
26 | self.embedding_dim = num_nodes * data_dim
27 |
28 | self.data_dim = data_dim
29 | self.lag = lag
30 | self.num_nodes = num_nodes
31 | self.num_graphs = num_graphs
32 | self.dropout_p = dropout_p
33 | self.linear = linear
34 |
35 | input_dim = 2*self.embedding_dim
36 |
37 | if not self.linear:
38 | self.nn_size = max(4 * num_nodes, self.embedding_dim, 64)
39 |
40 | self.f = generate_fully_connected(
41 | input_dim=input_dim,
42 | output_dim=num_nodes*data_dim, # potentially num_nodes
43 | hidden_dims=[self.nn_size, self.nn_size],
44 | non_linearity=nn.LeakyReLU,
45 | activation=nn.Identity,
46 | device=self.device,
47 | normalization=nn.LayerNorm,
48 | res_connection=skip_connection,
49 | )
50 |
51 | self.g = generate_fully_connected(
52 | input_dim=self.embedding_dim+self.data_dim,
53 | output_dim=self.embedding_dim,
54 | hidden_dims=[self.nn_size, self.nn_size],
55 | non_linearity=nn.LeakyReLU,
56 | activation=nn.Identity,
57 | device=self.device,
58 | normalization=nn.LayerNorm,
59 | res_connection=skip_connection,
60 | )
61 |
62 | self.cd_embeddings = MultiEmbedding(num_nodes=self.num_nodes,
63 | lag=self.lag,
64 | num_graphs=self.num_graphs,
65 | embedding_dim=self.embedding_dim)
66 |
67 | else:
68 | self.w = nn.Parameter(
69 | torch.randn(self.num_graphs, self.lag+1, self.num_nodes, self.num_nodes, device=self.device)*0.5, requires_grad=True
70 | )
71 |
72 | def forward(self, X_input: torch.Tensor, A: torch.Tensor):
73 | """
74 | Args:
75 | X_input: input data of shape (batch, lag+1, num_nodes, data_dim)
76 | A: adjacency matrix of shape (num_graphs, lag+1, num_nodes, num_nodes)
77 | """
78 |
79 | assert len(X_input.shape) == 4
80 |
81 | batch, L, num_nodes, data_dim = X_input.shape
82 | lag = L-1
83 |
84 | if not self.linear:
85 | E = self.cd_embeddings.get_embeddings()
86 |
87 | # reshape X to the correct shape
88 | A = A.unsqueeze(0).expand((batch, -1, -1, -1, -1))
89 | E = E.unsqueeze(0).expand((batch, -1, -1, -1, -1))
90 | X_input = X_input.unsqueeze(1).expand(
91 | (-1, self.num_graphs, -1, -1, -1))
92 |
93 | # ensure we have the correct shape
94 | assert (A.shape[0] == batch and A.shape[1] == self.num_graphs and
95 | A.shape[2] == lag + 1 and A.shape[3] == num_nodes and
96 | A.shape[4] == num_nodes)
97 | assert (E.shape[0] == batch and E.shape[1] == self.num_graphs
98 | and E.shape[2] == lag+1 and E.shape[3] == num_nodes
99 | and E.shape[4] == self.embedding_dim)
100 |
101 | X_in = torch.cat((X_input, E), dim=-1)
102 |             # X_in: (batch, num_graphs, lag+1, num_nodes, embedding_dim+data_dim)
103 | X_enc = self.g(X_in)
104 | A_temp = A.flip([2])
105 | # get the parents of X
106 | X_sum = torch.einsum("bnlij,bnlio->bnjo", A_temp, X_enc)
107 |
108 | X_sum = torch.cat([X_sum, E[:, :, 0, :, :]], dim=-1)
109 | # (batch, num_graphs, num_nodes, embedding_dim)
110 | # pass through f network to get the predictions
111 | group_mask = torch.eye(num_nodes*data_dim).to(self.device)
112 |
113 | # (batch, num_graphs, num_nodes, data_dim)
114 | return torch.sum(self.f(X_sum)*group_mask, dim=-1).unsqueeze(-1)
115 |
116 | return torch.einsum("klij,blio->bkjo", (self.w * A).flip([1]), X_input)
117 | # return self.f(X_sum)
118 |
--------------------------------------------------------------------------------
/src/modules/MultiEmbedding.py:
--------------------------------------------------------------------------------
1 | import lightning.pytorch as pl
2 | from torch import nn
3 | import torch
4 |
5 |
6 | class MultiEmbedding(pl.LightningModule):
7 |
8 | def __init__(
9 | self,
10 | num_nodes: int,
11 | lag: int,
12 | num_graphs: int,
13 | embedding_dim: int
14 | ):
15 |
16 | super().__init__()
17 | self.lag = lag
18 | # Assertion lag > 0
19 | assert lag > 0
20 | self.num_nodes = num_nodes
21 | self.num_graphs = num_graphs
22 | self.embedding_dim = embedding_dim
23 |
24 | self.lag_embeddings = nn.Parameter((
25 | torch.randn(self.num_graphs, self.lag, self.num_nodes,
26 | self.embedding_dim, device=self.device) * 0.01
27 | ), requires_grad=True)
28 |
29 | self.inst_embeddings = nn.Parameter((
30 | torch.randn(self.num_graphs, 1, self.num_nodes,
31 | self.embedding_dim, device=self.device) * 0.01
32 | ), requires_grad=True)
33 |
34 | def turn_off_inst_grad(self):
35 | self.inst_embeddings.requires_grad_(False)
36 |
37 | def turn_on_inst_grad(self):
38 | self.inst_embeddings.requires_grad_(True)
39 |
40 | def get_embeddings(self):
41 | return torch.cat((self.inst_embeddings, self.lag_embeddings), dim=1)
42 |
--------------------------------------------------------------------------------
/src/modules/MultiLinearCausalGraph.py:
--------------------------------------------------------------------------------
1 |
2 | import lightning.pytorch as pl
3 | import torch
4 | from torch import nn
5 |
6 |
7 | class MultiLinearCausalGraph(pl.LightningModule):
8 |
9 | def __init__(
10 | self,
11 | lag: int,
12 | input_dim: int,
13 | num_graphs: int
14 | ):
15 | """
16 | Args:
17 |             input_dim: number of nodes in each graph.
18 |             num_graphs: number of graphs in the mixture.
19 | """
20 | super().__init__()
21 |
22 | self.num_graphs = num_graphs
23 | self.lag = lag
24 | self.w = nn.Parameter(
25 | torch.randn(self.num_graphs, self.lag+1, input_dim, input_dim, device=self.device)*0.5, requires_grad=True
26 | )
27 | self.I = torch.arange(input_dim)
28 | self.mask = torch.ones(
29 | (self.num_graphs, self.lag+1, input_dim, input_dim))
30 | self.mask[:, 0, self.I, self.I] = 0
31 | self.input_dim = input_dim
32 |
33 | def get_w(self) -> torch.Tensor:
34 | """
35 |         Returns the weight matrix. Ensures that the instantaneous matrix has zeros on the diagonal.
36 | """
37 | return self.w * self.mask.to(self.device)
38 |
--------------------------------------------------------------------------------
/src/modules/MultiTemporalHyperNet.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Tuple
2 | import lightning.pytorch as pl
3 | from torch import nn
4 | import torch
5 | from src.modules.MultiEmbedding import MultiEmbedding
6 | from src.utils.torch_utils import generate_fully_connected
7 |
8 |
9 | class MultiTemporalHyperNet(pl.LightningModule):
10 |
11 | def __init__(self,
12 | order: str,
13 | lag: int,
14 | data_dim: int,
15 | num_nodes: int,
16 | num_graphs: int,
17 | embedding_dim: int = None,
18 | skip_connection: bool = False,
19 | num_bins: int = 8,
20 | dropout_p: float = 0.0
21 | ):
22 |
23 | super().__init__()
24 |
25 | if embedding_dim is not None:
26 | self.embedding_dim = embedding_dim
27 | else:
28 | self.embedding_dim = num_nodes * data_dim
29 |
30 | self.data_dim = data_dim
31 | self.lag = lag
32 | self.order = order
33 | self.num_bins = num_bins
34 | self.num_nodes = num_nodes
35 | self.num_graphs = num_graphs
36 | self.dropout_p = dropout_p
37 |
38 | if self.order == "quadratic":
39 | self.param_dim = [
40 | self.num_bins,
41 | self.num_bins,
42 | (self.num_bins - 1),
43 | ] # this is for quadratic order conditional spline flow
44 | elif self.order == "linear":
45 | self.param_dim = [
46 | self.num_bins,
47 | self.num_bins,
48 | (self.num_bins - 1),
49 | self.num_bins,
50 | ] # this is for linear order conditional spline flow
51 |
52 | self.total_param = sum(self.param_dim)
53 | input_dim = 2*self.embedding_dim
54 |
55 | self.nn_size = max(4 * num_nodes, self.embedding_dim, 64)
56 |
57 | self.f = generate_fully_connected(
58 | input_dim=input_dim,
59 | output_dim=self.total_param, # potentially num_nodes
60 | hidden_dims=[self.nn_size, self.nn_size],
61 | non_linearity=nn.LeakyReLU,
62 | activation=nn.Identity,
63 | device=self.device,
64 | normalization=nn.LayerNorm,
65 | res_connection=skip_connection,
66 | )
67 |
68 | self.g = generate_fully_connected(
69 | input_dim=self.embedding_dim+self.data_dim,
70 | output_dim=self.embedding_dim,
71 | hidden_dims=[self.nn_size, self.nn_size],
72 | non_linearity=nn.LeakyReLU,
73 | activation=nn.Identity,
74 | device=self.device,
75 | normalization=nn.LayerNorm,
76 | res_connection=skip_connection,
77 | )
78 |
79 | self.th_embeddings = MultiEmbedding(num_nodes=self.num_nodes,
80 | lag=self.lag,
81 | num_graphs=self.num_graphs,
82 | embedding_dim=self.embedding_dim)
83 |
84 | def forward(self, X: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, ...]:
85 | """
86 | Args:
87 |             X: A dict consisting of two keys, "A" is the adjacency matrix with shape [num_graphs, lag+1, num_nodes, num_nodes]
88 | and "X" is the history data with shape (batch, lag, num_nodes, data_dim).
89 |
90 | Returns:
91 |             A tuple of parameters with shape [batch*num_graphs, num_nodes*param_dim_each].
92 |             The length of the tuple is len(self.param_dim).
93 | """
94 |
95 | # assert "A" in X and "X" in X and len(
96 | # X) == 2, "The key for input can only contain two keys, 'A', 'X'."
97 |
98 | A = X["A"]
99 | X_in = X["X"]
100 |
101 | E = self.th_embeddings.get_embeddings()
102 | # X["embeddings"]
103 |
104 | batch, lag, num_nodes, _ = X_in.shape
105 |
106 | # reshape X to the correct shape
107 | A = A.unsqueeze(0).expand((batch, -1, -1, -1, -1))
108 | E = E.unsqueeze(0).expand((batch, -1, -1, -1, -1))
109 | X_in = X_in.unsqueeze(1).expand((-1, self.num_graphs, -1, -1, -1))
110 |
111 | # ensure we have the correct shape
112 | assert (A.shape[0] == batch and A.shape[1] == self.num_graphs and
113 | A.shape[2] == lag + 1 and A.shape[3] == num_nodes and
114 | A.shape[4] == num_nodes)
115 | assert (E.shape[0] == batch and E.shape[1] == self.num_graphs
116 | and E.shape[2] == lag+1 and E.shape[3] == num_nodes
117 | and E.shape[4] == self.embedding_dim)
118 |
119 | # shape [batch_size, num_graphs, lag, num_nodes, embedding_size]
120 | E_lag = E[:, :, 1:, :, :]
121 |
122 | # reshape X to the correct shape
123 | X_in = torch.cat((X_in, E_lag), dim=-1)
124 |
125 | # X_in: (batch, num_graphs, lag, num_nodes, embedding_dim + data_dim)
126 |
127 |         X_enc = self.g(X_in)  # (batch, num_graphs, lag, num_nodes, embedding_dim)
128 |
129 | # get the parents of X
130 | # (batch, num_graphs, num_nodes, embedding_dim)
131 | A_temp = A[:, :, 1:].flip([2])
132 |
133 | X_sum = torch.einsum("bnlij,bnlio->bnjo", A_temp, X_enc) # / num_nodes
134 |
135 | X_sum = torch.cat((X_sum, E[:, :, 0, :, :]), dim=-1)
136 |
137 | # pass through f network to get the parameters
138 | params = self.f(X_sum) # (batch, num_graphs, num_nodes, total_params)
139 |
140 | # pass multiple graph parameters along batch dimension
141 | params = params.reshape(batch*self.num_graphs, num_nodes, -1)
142 |
143 | param_list = torch.split(params, self.param_dim, dim=-1)
144 | # a list of tensor with shape [batch*num_graphs, num_nodes*each_param]
145 | return tuple(
146 | param.reshape([-1, num_nodes * param.shape[-1]]) for param in param_list)
147 |
--------------------------------------------------------------------------------
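The hypernet's output is split into the per-parameter groups listed in `param_dim` and flattened per node. The split-and-reshape mechanics in isolation, with illustrative sizes:

```python
import torch

num_bins, num_nodes, batch = 8, 4, 2
param_dim = [num_bins, num_bins, num_bins - 1]          # quadratic-order spline groups, as above
params = torch.randn(batch, num_nodes, sum(param_dim))  # hypernet output per node

param_list = torch.split(params, param_dim, dim=-1)
flat = tuple(p.reshape([-1, num_nodes * p.shape[-1]]) for p in param_list)
print([p.shape for p in flat])  # [(2, 32), (2, 32), (2, 28)]
```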
/src/modules/TemporalConditionalSplineFlow.py:
--------------------------------------------------------------------------------
1 | import lightning.pytorch as pl
2 | from torch import nn
3 | import torch
4 |
5 | import pyro.distributions as distrib
6 | from pyro.distributions.transforms.spline import ConditionalSpline
7 | from pyro.distributions import constraints
8 | from pyro.distributions.torch_transform import TransformModule
9 |
10 |
11 | class AffineDiagonalPyro(TransformModule):
12 | """
13 | This creates a diagonal affine transformation compatible with pyro transforms
14 | """
15 |
16 | domain = constraints.real
17 | codomain = constraints.real
18 | bijective = True
19 |
20 | def __init__(self, input_dim: int):
21 | super().__init__(cache_size=1)
22 | self.dim = input_dim
23 | self.a = nn.Parameter(torch.ones(input_dim), requires_grad=True)
24 | self.b = nn.Parameter(torch.zeros(input_dim), requires_grad=True)
25 |
26 | def _call(self, x: torch.Tensor) -> torch.Tensor:
27 | """
28 | Forward method
29 | Args:
30 | x: tensor with shape [batch, input_dim]
31 | Returns:
32 | Transformed inputs
33 | """
34 | return self.a.exp().unsqueeze(0) * x + self.b.unsqueeze(0)
35 |
36 | def _inverse(self, y: torch.Tensor) -> torch.Tensor:
37 | """
38 | Reverse method
39 | Args:
40 | y: tensor with shape [batch, input]
41 | Returns:
42 | Reversed input
43 | """
44 | return (-self.a).exp().unsqueeze(0) * (y - self.b.unsqueeze(0))
45 |
46 | def log_abs_det_jacobian(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
47 | _, _ = x, y
48 | return self.a.unsqueeze(0)
49 |
50 |
51 | class TemporalConditionalSplineFlow(pl.LightningModule):
52 |
53 | def __init__(self,
54 | hypernet
55 | ):
56 | super().__init__()
57 |
58 | self.hypernet = hypernet
59 | self.num_bins = self.hypernet.num_bins
60 | self.order = self.hypernet.order
61 |
62 | def log_prob(self,
63 | X_input: torch.Tensor,
64 | X_history: torch.Tensor,
65 | A: torch.Tensor,
66 | embeddings: torch.Tensor = None):
67 | """
68 | Args:
69 | X_input: input data of shape (batch, num_nodes, data_dim)
70 | X_history: input data of shape (batch, lag, num_nodes, data_dim)
71 | A: adjacency matrix of shape (batch, lag+1, num_nodes, num_nodes)
72 | embeddings: embeddings (batch, lag+1, num_nodes, embedding_dim)
73 | """
74 |
75 | assert len(X_history.shape) == 4
76 |
77 | _, _, num_nodes, data_dim = X_history.shape
78 |
79 | # if not self.trainable_embeddings:
80 | transform = nn.ModuleList(
81 | [
82 | ConditionalSpline(
83 | self.hypernet, input_dim=num_nodes*data_dim, count_bins=self.num_bins, order=self.order, bound=5.0
84 | )
85 | # AffineDiagonalPyro(input_dim=self.num_nodes*self.data_dim),
86 | # Spline(input_dim=self.num_nodes*self.data_dim, count_bins=self.num_bins, order="quadratic", bound=5.0),
87 | # AffineDiagonalPyro(input_dim=self.num_nodes*self.data_dim),
88 | # Spline(input_dim=self.num_nodes*self.data_dim, count_bins=self.num_bins, order="quadratic", bound=5.0),
89 | # AffineDiagonalPyro(input_dim=self.num_nodes*self.data_dim)
90 | ]
91 | )
92 | base_dist = distrib.Normal(
93 | torch.zeros(num_nodes*data_dim, device=self.device), torch.ones(
94 | num_nodes*data_dim, device=self.device)
95 | )
96 | # else:
97 |
98 | context_dict = {"X": X_history, "A": A, "embeddings": embeddings}
99 | flow_dist = distrib.ConditionalTransformedDistribution(
100 | base_dist, transform).condition(context_dict)
101 | return flow_dist.log_prob(X_input)
102 |
103 | def sample(self,
104 | N_samples: int,
105 | X_history: torch.Tensor,
106 | W: torch.Tensor,
107 | embeddings: torch.Tensor):
108 | assert len(X_history.shape) == 4
109 |
110 | batch, _, num_nodes, data_dim = X_history.shape
111 |
112 | base_dist = distrib.Normal(
113 | torch.zeros(num_nodes*data_dim, device=self.device), torch.ones(
114 | num_nodes*data_dim, device=self.device)
115 | )
116 |
117 |         # rebuild the conditional spline transform; it is not stored as an attribute in __init__
118 |         transform = nn.ModuleList([ConditionalSpline(
119 |             self.hypernet, input_dim=num_nodes*data_dim, count_bins=self.num_bins, order=self.order, bound=5.0)])
120 |         context_dict = {"X": X_history, "A": W, "embeddings": embeddings}
121 |         flow_dist = distrib.ConditionalTransformedDistribution(
122 |             base_dist, transform).condition(context_dict)
120 | return flow_dist.sample([N_samples, batch])
121 |
--------------------------------------------------------------------------------
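A quick sanity check that `AffineDiagonalPyro` above is invertible, i.e. `_inverse` undoes `_call`; this assumes the repo is importable as `src`:

```python
import torch
from src.modules.TemporalConditionalSplineFlow import AffineDiagonalPyro

t = AffineDiagonalPyro(input_dim=6)
x = torch.randn(3, 6)
y = t._call(x)            # y = exp(a) * x + b
x_rec = t._inverse(y)     # x = exp(-a) * (y - b)
print(torch.allclose(x, x_rec, atol=1e-6))  # True
```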
/src/modules/TemporalHyperNet.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Tuple
2 | import lightning.pytorch as pl
3 | from torch import nn
4 | import torch
5 | from src.utils.torch_utils import generate_fully_connected
6 |
7 | class TemporalHyperNet(pl.LightningModule):
8 |
9 | def __init__(self,
10 | order: str,
11 | lag: int,
12 | data_dim: int,
13 | num_nodes: int,
14 | embedding_dim: int = None,
15 | skip_connection: bool = False,
16 | num_bins: int = 8):
17 | super().__init__()
18 |
19 | if embedding_dim is None:
20 | embedding_dim = num_nodes * data_dim
21 |
22 | self.embedding_dim = embedding_dim
23 | self.data_dim = data_dim
24 | self.lag = lag
25 | self.order = order
26 | self.num_bins = num_bins
27 | self.num_nodes = num_nodes
28 |
29 | if self.order == "quadratic":
30 | self.param_dim = [
31 | self.num_bins,
32 | self.num_bins,
33 | (self.num_bins - 1),
34 | ] # this is for quadratic order conditional spline flow
35 | elif self.order == "linear":
36 | self.param_dim = [
37 | self.num_bins,
38 | self.num_bins,
39 | (self.num_bins - 1),
40 | self.num_bins,
41 | ] # this is for linear order conditional spline flow
42 |
43 | self.total_param = sum(self.param_dim)
44 | input_dim = 2*self.embedding_dim
45 | self.nn_size = max(4 * num_nodes, self.embedding_dim, 64)
46 |
47 | self.f = generate_fully_connected(
48 | input_dim=input_dim,
49 | output_dim=self.total_param, # potentially num_nodes
50 | hidden_dims=[self.nn_size, self.nn_size],
51 | non_linearity=nn.LeakyReLU,
52 | activation=nn.Identity,
53 | device=self.device,
54 | normalization=nn.LayerNorm,
55 | res_connection=skip_connection,
56 | )
57 |
58 | self.g = generate_fully_connected(
59 | input_dim=self.embedding_dim+self.data_dim,
60 | output_dim=self.embedding_dim,
61 | hidden_dims=[self.nn_size, self.nn_size],
62 | non_linearity=nn.LeakyReLU,
63 | activation=nn.Identity,
64 | device=self.device,
65 | normalization=nn.LayerNorm,
66 | res_connection=skip_connection,
67 | )
68 |
69 | self.embeddings = nn.Parameter((
70 | torch.randn(self.lag + 1, self.num_nodes,
71 | self.embedding_dim, device=self.device) * 0.01
72 | ), requires_grad=True) # shape (lag+1, num_nodes, embedding_dim)
73 |
74 | def forward(self, X: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, ...]:
75 | """
76 | Args:
77 | X: A dict consisting of two keys, "A" is the adjacency matrix with shape [batch, lag+1, num_nodes, num_nodes]
78 | and "X" is the history data with shape (batch, lag, num_nodes, data_dim).
79 |
80 | Returns:
81 |             A tuple of parameters with shape [batch, num_nodes*param_dim_each].
82 |             The length of the tuple is len(self.param_dim).
83 | """
84 |
85 | # assert "A" in X and "X" in X and len(
86 | # X) == 2, "The key for input can only contain two keys, 'A', 'X'."
87 |
88 | A = X["A"]
89 | X_in = X["X"]
90 | embeddings = X["embeddings"]
91 | batch, lag, num_nodes, _ = X_in.shape
92 |
93 | # ensure we have the correct shape
94 | assert (A.shape[0] == batch and A.shape[1] == lag +
95 | 1 and A.shape[2] == num_nodes and A.shape[3] == num_nodes)
96 |
97 | if embeddings is None:
98 | E = self.embeddings.expand(
99 | X_in.shape[0], -1, -1, -1
100 | )
101 | else:
102 | E = embeddings
103 |
104 | # shape [batch_size, lag, num_nodes, embedding_size]
105 | E_lag = E[..., 1:, :, :]
106 |
107 | X_in = torch.cat((X_in, E_lag), dim=-1)
108 | # X_in: (batch, lag, num_nodes, embedding_dim + data_dim)
109 |
110 | X_enc = self.g(X_in) # (batch, lag, num_nodes, embedding_dim)
111 |
112 | # get the parents of X
113 | # (batch, num_nodes, embedding_dim)
114 | A_temp = A[:, 1:].flip([1])
115 |
116 | X_sum = torch.einsum("blij,blio->bjo", A_temp, X_enc) # / num_nodes
117 |
118 | X_sum = torch.cat((X_sum, E[..., 0, :, :]), dim=-1)
119 |
120 | # pass through f network to get the parameters
121 | params = self.f(X_sum) # (batch, num_nodes, total_params)
122 |
123 | param_list = torch.split(params, self.param_dim, dim=-1)
124 | # a list of tensor with shape [batch, num_nodes*each_param]
125 | return tuple(
126 | param.reshape([-1, num_nodes * param.shape[-1]]) for param in param_list)
127 |
--------------------------------------------------------------------------------
/src/modules/adjacency_matrices/AdjMatrix.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | import torch
3 |
4 |
5 | class AdjMatrix(ABC):
6 | """
7 | Adjacency matrix interface for DECI
8 | """
9 |
10 | @abstractmethod
11 | def get_adj_matrix(self, do_round: bool = True) -> torch.Tensor:
12 | """
13 | Returns the adjacency matrix.
14 | """
15 | raise NotImplementedError()
16 |
17 | @abstractmethod
18 | def entropy(self) -> torch.Tensor:
19 | """
20 |         Computes the entropy of the variational distribution q over adjacency matrices.
21 | """
22 | raise NotImplementedError()
23 |
24 | @abstractmethod
25 | def sample_A(self) -> torch.Tensor:
26 | """
27 |         Samples an adjacency matrix from the variational distribution.
28 | """
29 | raise NotImplementedError()
30 |
--------------------------------------------------------------------------------
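A toy implementation of the interface above, useful for seeing what the three abstract methods are expected to return; `FixedAdjMatrix` is illustrative only and not part of the repo:

```python
import torch
from src.modules.adjacency_matrices.AdjMatrix import AdjMatrix

class FixedAdjMatrix(AdjMatrix):
    """Deterministic point-mass over a single adjacency matrix."""
    def __init__(self, A: torch.Tensor):
        self.A = A
    def get_adj_matrix(self, do_round: bool = True) -> torch.Tensor:
        return self.A.round() if do_round else self.A
    def entropy(self) -> torch.Tensor:
        return torch.tensor(0.0)   # a point mass has zero entropy
    def sample_A(self) -> torch.Tensor:
        return self.A

fixed = FixedAdjMatrix(torch.eye(3))
print(fixed.sample_A())
```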
/src/modules/adjacency_matrices/MultiTemporalAdjacencyMatrix.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional
2 |
3 | import lightning.pytorch as pl
4 | import torch
5 | import torch.distributions as td
6 | import torch.nn.functional as F
7 | from torch import nn
8 | from src.modules.adjacency_matrices.AdjMatrix import AdjMatrix
9 |
10 |
11 | class MultiTemporalAdjacencyMatrix(pl.LightningModule, AdjMatrix):
12 | def __init__(
13 | self,
14 | num_nodes: int,
15 | lag: int,
16 | num_graphs: int,
17 | tau_gumbel: float = 1.0,
18 | threeway: bool = True,
19 | init_logits: Optional[List[float]] = None,
20 | disable_inst: bool = False
21 | ):
22 | super().__init__()
23 | self.lag = lag
24 | self.tau_gumbel = tau_gumbel
25 | self.threeway = threeway
26 | self.num_nodes = num_nodes
27 | # Assertion lag > 0
28 | assert lag > 0
29 | self.num_graphs = num_graphs
30 | self.disable_inst = disable_inst
31 |
32 | if self.threeway:
33 | self.logits_inst = nn.Parameter(
34 | torch.zeros((3, self.num_graphs, (num_nodes * (num_nodes - 1)) // 2),
35 | device=self.device),
36 | requires_grad=True
37 | )
38 | self.lower_idxs = torch.unbind(
39 | torch.tril_indices(self.num_nodes, self.num_nodes,
40 | offset=-1, device=self.device), 0
41 | )
42 | else:
43 | self.logits_inst = nn.Parameter(
44 | torch.zeros((2, self.num_graphs, num_nodes, num_nodes),
45 | device=self.device),
46 | requires_grad=True
47 | )
48 |
49 | self.logits_lag = nn.Parameter(torch.zeros((2, self.num_graphs, lag, num_nodes, num_nodes),
50 | device=self.device),
51 | requires_grad=True)
52 | self.init_logits = init_logits
53 | # Set the init_logits if not None
54 | if self.init_logits is not None:
55 | if self.threeway:
56 | self.logits_inst.data[2, :, ...] = self.init_logits[0]
57 | else:
58 | self.logits_inst.data[1, :, ...] = self.init_logits[0]
59 | self.logits_lag.data[0, :, ...] = self.init_logits[1]
60 |
61 | def zero_out_diagonal(self, matrix: torch.Tensor):
62 | # matrix: (num_graphs, num_nodes, num_nodes)
63 | N = matrix.shape[1]
64 | I = torch.arange(N).to(self.device)
65 | matrix = matrix.clone()
66 | matrix[:, I, I] = 0
67 | return matrix
68 |
69 | def _triangular_vec_to_matrix(self, vec):
70 | """
71 | Given an array of shape (k, N, n(n-1)/2) where k in {2, 3}, creates a matrix of shape
72 | (N, n, n) where the lower triangular is filled from vec[0, :] and the upper
73 | triangular is filled from vec[1, :].
74 | """
75 | N = vec.shape[1]
76 | output = torch.zeros(
77 | (N, self.num_nodes, self.num_nodes), device=self.device)
78 | output[:, self.lower_idxs[0], self.lower_idxs[1]] = vec[0, :, ...]
79 | output[:, self.lower_idxs[1], self.lower_idxs[0]] = vec[1, :, ...]
80 | return output
81 |
82 | def get_adj_matrix(self, do_round: bool = False) -> torch.Tensor:
83 | """
84 | Returns the adjacency matrix.
85 | """
86 | probs = torch.zeros((self.num_graphs, self.lag + 1, self.num_nodes, self.num_nodes),
87 | device=self.device)
88 |
89 | if not self.disable_inst:
90 | inst_probs = F.softmax(self.logits_inst, dim=0)
91 | if self.threeway:
92 |                 # (3, num_graphs, n(n-1)/2) probabilities
93 | inst_probs = self._triangular_vec_to_matrix(inst_probs)
94 | else:
95 | inst_probs = self.zero_out_diagonal(inst_probs[1, ...])
96 |
97 | # Generate simultaneous adj matrix
98 |             # shape (num_graphs, num_nodes, num_nodes)
99 | probs[:, 0, ...] = inst_probs
100 |
101 | # Generate lagged adj matrix
102 |         # shape (num_graphs, lag, num_nodes, num_nodes)
103 | probs[:, 1:, ...] = F.softmax(self.logits_lag, dim=0)[1, ...]
104 | if do_round:
105 | return probs.round()
106 |
107 | return probs
108 |
109 | def entropy(self) -> torch.Tensor:
110 | """
111 |         Computes the entropy of the variational distribution q, summing the instantaneous and lagged terms over all graphs.
112 | """
113 |
114 | if not self.disable_inst:
115 | if self.threeway:
116 | dist = td.Categorical(
117 | logits=self.logits_inst[:, :].transpose(0, -1))
118 | entropies_inst = dist.entropy().sum()
119 | else:
120 | dist = td.Categorical(
121 | logits=self.logits_inst[1, ...] - self.logits_inst[0, ...])
122 | I = torch.arange(self.num_nodes)
123 | dist_diag = td.Categorical(
124 | logits=self.logits_inst[1, :, I, I] - self.logits_inst[0, :, I, I])
125 | entropies = dist.entropy()
126 | diag_entropy = dist_diag.entropy()
127 | entropies_inst = entropies.sum() - diag_entropy.sum()
128 | else:
129 | entropies_inst = 0
130 |
131 | dist_lag = td.Independent(td.Bernoulli(
132 | logits=self.logits_lag[1, :] - self.logits_lag[0, :]), 3)
133 | entropies_lag = dist_lag.entropy().sum()
134 |
135 | return entropies_lag + entropies_inst
136 |
137 | def sample_A(self) -> torch.Tensor:
138 | """
139 | This samples the adjacency matrix from the variational distribution. This uses the gumbel softmax trick and returns
140 |         hard samples. This can be done by (1) sampling the instantaneous adj matrix using self.logits_inst, and (2) sampling the lagged adj matrix using self.logits_lag.
141 | """
142 | # Create adj matrix to avoid concatenation
143 | adj_sample = torch.zeros(
144 | (self.num_graphs, self.lag + 1, self.num_nodes, self.num_nodes), device=self.device
145 |         )  # shape (num_graphs, lag+1, num_nodes, num_nodes)
146 |
147 | if not self.disable_inst:
148 | if self.threeway:
149 | # Sample instantaneous adj matrix
150 | adj_sample[:, 0, ...] = self._triangular_vec_to_matrix(
151 | F.gumbel_softmax(self.logits_inst,
152 | tau=self.tau_gumbel,
153 | hard=True,
154 | dim=0)
155 | ) # shape (N, input_dim, input_dim)
156 | else:
157 | sample = F.gumbel_softmax(
158 | self.logits_inst, tau=self.tau_gumbel, hard=True, dim=0)[1, ...]
159 | adj_sample[:, 0, ...] = self.zero_out_diagonal(sample)
160 |
161 | # Sample lagged adj matrix
162 | # shape (N, lag, input_dim, input_dim)
163 | adj_sample[:, 1:, ...] = F.gumbel_softmax(self.logits_lag,
164 | tau=self.tau_gumbel,
165 | hard=True,
166 | dim=0)[1, ...]
167 | return adj_sample
168 |
169 | def turn_off_inst_grad(self):
170 | self.logits_inst.requires_grad_(False)
171 |
172 | def turn_on_inst_grad(self):
173 | self.logits_inst.requires_grad_(True)
174 |
--------------------------------------------------------------------------------
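A minimal shape sketch for the class above (assumes the repo is importable as the `src` package; the numbers are illustrative only):

    import torch
    from src.modules.adjacency_matrices.MultiTemporalAdjacencyMatrix import MultiTemporalAdjacencyMatrix

    dist = MultiTemporalAdjacencyMatrix(num_nodes=4, lag=2, num_graphs=3)

    probs = dist.get_adj_matrix()      # edge probabilities
    sample = dist.sample_A()           # hard Gumbel-softmax sample (straight-through gradients)
    print(probs.shape, sample.shape)   # both torch.Size([3, 3, 4, 4]) = (num_graphs, lag+1, nodes, nodes)
    print(dist.entropy())              # scalar entropy of the variational distribution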
/src/modules/adjacency_matrices/TemporalAdjacencyMatrix.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional
2 |
3 | import torch
4 | import torch.distributions as td
5 | import torch.nn.functional as F
6 | from torch import nn
7 | from src.modules.adjacency_matrices.ThreeWayGraphDist import ThreeWayGraphDist
8 |
9 |
10 | class TemporalAdjacencyMatrix(ThreeWayGraphDist):
11 | """
12 |     This class adapts ThreeWayGraphDist so that it supports variational distributions over temporal adjacency matrices.
13 |
14 |     The principle is to follow the same logic as ThreeWayGraphDist. The implementation has two separate parts:
15 |     (1) a categorical distribution for the instantaneous adj (see ThreeWayGraphDist); (2) a Bernoulli distribution for the lagged
16 |     adj. Note that the lagged adj does not need to follow the ThreeWayGraphDist logic, since it allows diagonal elements
17 |     and does not have to be a DAG. Therefore, it is simpler to model it directly with a Bernoulli distribution.
18 | """
19 |
20 | def __init__(
21 | self,
22 | input_dim: int,
23 | lag: int,
24 | tau_gumbel: float = 1.0,
25 | disable_inst: bool = False,
26 | init_logits: Optional[List[float]] = None,
27 | ):
28 | """
29 | This creates an instance of variational distribution for temporal adjacency matrix.
30 | Args:
31 | input_dim: The number of nodes for adjacency matrix.
32 | lag: The lag for the temporal adj matrix. The adj matrix has the shape (lag+1, num_nodes, num_nodes).
33 |             tau_gumbel: The temperature for the gumbel softmax sampling.
34 |             disable_inst: If True, the instantaneous adjacency is disabled (its slice is kept at zero).
35 |             init_logits: The initialized logits value. If None, then use the default initialized logits (value 0). Otherwise, init_logits[0] indicates the non-existence edge logit for the instantaneous effect, and
36 |             init_logits[1] indicates the non-existence edge logit for the lagged effect. E.g. if we want a dense initialization, one choice is (-7, -0.5).
37 | """
38 | # Call parent init method
39 | super().__init__(input_dim=input_dim, tau_gumbel=tau_gumbel)
40 | # Create a separate logit for lagged adj
41 | # The logits_lag are initialized to zero with shape (2, lag, input_dim, input_dim).
42 | # logits_lag[0,...] indicates the logit prob for no edges, and logits_lag[1,...] indicates the logit for edge existence.
43 | self.lag = lag
44 | # Assertion lag > 0
45 | assert lag > 0
46 | self.logits_lag = nn.Parameter(torch.zeros(
47 | (2, lag, input_dim, input_dim), device=self.device), requires_grad=True)
48 | self.init_logits = init_logits
49 | self.disable_inst = disable_inst
50 | # Set the init_logits if not None
51 | if self.init_logits is not None:
52 | self.logits.data[2, ...] = self.init_logits[0]
53 | self.logits_lag.data[0, ...] = self.init_logits[1]
54 |
55 | def get_adj_matrix(self, do_round: bool = False) -> torch.Tensor:
56 | """
57 | This returns the temporal adjacency matrix of edge probability.
58 | Args:
59 | do_round: Whether to round the edge probabilities.
60 |
61 | Returns:
62 | The adjacency matrix with shape [lag+1, num_nodes, num_nodes].
63 | """
64 |
65 | # Create the temporal adj matrix
66 | probs = torch.zeros(self.lag + 1, self.input_dim,
67 | self.input_dim, device=self.device)
68 | # Generate simultaneous adj matrix
69 | if not self.disable_inst:
70 | probs[0, ...] = super().get_adj_matrix(
71 | do_round=do_round) # shape (input_dim, input_dim)
72 | # Generate lagged adj matrix
73 | probs[1:, ...] = F.softmax(self.logits_lag, dim=0)[
74 | 1, ...] # shape (lag, input_dim, input_dim)
75 | if do_round:
76 | return probs.round()
77 |
78 | return probs
79 |
80 | def entropy(self) -> torch.Tensor:
81 | """
82 | This computes the entropy of the variational distribution.
83 |         This can be done by (1) computing the entropy of the instantaneous adj matrix (categorical, same as ThreeWayGraphDist),
84 |         (2) computing the entropy of the lagged adj matrix (Bernoulli dist), and (3) adding them together.
85 | """
86 | # Entropy for instantaneous dist, call super().entropy
87 | if not self.disable_inst:
88 | entropies_inst = super().entropy()
89 | else:
90 | entropies_inst = 0
91 | # Entropy for lagged dist
92 | # batch_shape [lag], event_shape [num_nodes, num_nodes]
93 |
94 | dist_lag = td.Independent(td.Bernoulli(
95 | logits=self.logits_lag[1, ...] - self.logits_lag[0, ...]), 2)
96 | entropies_lag = dist_lag.entropy().sum()
97 | # entropies_lag = dist_lag.entropy().mean()
98 |
99 | return entropies_lag + entropies_inst
100 |
101 | def sample_A(self) -> torch.Tensor:
102 | """
103 | This samples the adjacency matrix from the variational distribution. This uses the gumbel softmax trick and returns
104 | hard samples. This can be done by (1) sample instantaneous adj matrix using self.logits, (2) sample lagged adj matrix using self.logits_lag.
105 | """
106 |
107 | # Create adj matrix to avoid concatenation
108 | adj_sample = torch.zeros(
109 | self.lag + 1, self.input_dim, self.input_dim, device=self.device
110 | ) # shape (lag+1, input_dim, input_dim)
111 |
112 | # Sample instantaneous adj matrix
113 | if not self.disable_inst:
114 | adj_sample[0, ...] = self._triangular_vec_to_matrix(
115 | F.gumbel_softmax(
116 | self.logits, tau=self.tau_gumbel, hard=True, dim=0)
117 | ) # shape (input_dim, input_dim)
118 | # Sample lagged adj matrix
119 | adj_sample[1:, ...] = F.gumbel_softmax(self.logits_lag, tau=self.tau_gumbel, hard=True, dim=0)[
120 | 1, ...
121 | ] # shape (lag, input_dim, input_dim)
122 | return adj_sample
123 |
--------------------------------------------------------------------------------
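A small sketch of the dense initialization mentioned in the docstring above, using init_logits = (-7, -0.5) (assumes the repo is importable as the `src` package):

    import torch
    from src.modules.adjacency_matrices.TemporalAdjacencyMatrix import TemporalAdjacencyMatrix

    dist = TemporalAdjacencyMatrix(input_dim=5, lag=2, init_logits=[-7.0, -0.5])

    probs = dist.get_adj_matrix()   # shape (lag+1, 5, 5)
    # With the "no instantaneous edge" logit at -7, the no-edge probability for each
    # instantaneous pair is close to zero (softmax over logits [0, 0, -7]); the lagged
    # edge probabilities start at sigmoid(0 - (-0.5)) ~= 0.62.
    print(probs.shape, probs[1:].mean())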
/src/modules/adjacency_matrices/ThreeWayGraphDist.py:
--------------------------------------------------------------------------------
1 |
2 | import lightning.pytorch as pl
3 | import torch
4 | import torch.distributions as td
5 | import torch.nn.functional as F
6 | from torch import nn
7 | from src.modules.adjacency_matrices.AdjMatrix import AdjMatrix
8 |
9 |
10 | class ThreeWayGraphDist(AdjMatrix, pl.LightningModule):
11 | """
12 | An alternative variational distribution for graph edges. For each pair of nodes x_i and x_j
13 | where i < j, we sample a three way categorical C_ij. If C_ij = 0, we sample the edge
14 | x_i -> x_j, if C_ij = 1, we sample the edge x_j -> x_i, and if C_ij = 2, there is no
15 | edge between these nodes. This variational distribution is faster to use than ENCO
16 | because it avoids any calls to `torch.stack`.
17 |
18 |     Sampling is performed with `F.gumbel_softmax(..., hard=True)` to give
19 | binary samples and a straight-through gradient estimator.
20 | """
21 |
22 | def __init__(
23 | self,
24 | input_dim: int,
25 | tau_gumbel: float = 1.0,
26 | ):
27 | """
28 | Args:
29 |             input_dim: number of nodes in the graph.
30 | tau_gumbel: temperature used for gumbel softmax sampling.
31 | """
32 | super().__init__()
33 | # We only use n(n-1)/2 random samples
34 | # For each edge, sample either A->B, B->A or no edge
35 | # We convert this to a proper adjacency matrix using torch.tril_indices
36 | self.logits = nn.Parameter(
37 | torch.zeros(3, (input_dim * (input_dim - 1)) // 2, device=self.device), requires_grad=True
38 | )
39 | self.tau_gumbel = tau_gumbel
40 | self.input_dim = input_dim
41 | self.lower_idxs = torch.unbind(
42 | torch.tril_indices(self.input_dim, self.input_dim,
43 | offset=-1, device=self.device), 0
44 | )
45 |
46 | def _triangular_vec_to_matrix(self, vec):
47 | """
48 | Given an array of shape (k, n(n-1)/2) where k in {2, 3}, creates a matrix of shape
49 | (n, n) where the lower triangular is filled from vec[0, :] and the upper
50 | triangular is filled from vec[1, :].
51 | """
52 | output = torch.zeros(
53 | (self.input_dim, self.input_dim), device=self.device)
54 | output[self.lower_idxs[0], self.lower_idxs[1]] = vec[0, ...]
55 | output[self.lower_idxs[1], self.lower_idxs[0]] = vec[1, ...]
56 | return output
57 |
58 | def get_adj_matrix(self, do_round: bool = False) -> torch.Tensor:
59 | """
60 | Returns the adjacency matrix of edge probabilities.
61 | """
62 | probs = F.softmax(self.logits, dim=0) # (3, n(n-1)/2) probabilities
63 | out_probs = self._triangular_vec_to_matrix(probs)
64 | if do_round:
65 | return out_probs.round()
66 | return out_probs
67 |
68 | def entropy(self) -> torch.Tensor:
69 | """
70 |         Computes the entropy of distribution q, which is a collection of n(n-1)/2 categoricals on 3 values.
71 | """
72 | dist = td.Categorical(logits=self.logits.transpose(0, -1))
73 | entropies = dist.entropy()
74 | return entropies.sum()
75 | # return entropies.mean()
76 |
77 | def sample_A(self) -> torch.Tensor:
78 | """
79 | Sample an adjacency matrix from the variational distribution. It uses the gumbel_softmax trick,
80 | and returns hard samples (straight through gradient estimator). Adjacency returned always has
81 | zeros in its diagonal (no self loops).
82 |
83 | V1: Returns one sample to be used for the whole batch.
84 | """
85 | sample = F.gumbel_softmax(
86 | self.logits, tau=self.tau_gumbel, hard=True, dim=0) # (3, n(n-1)/2) binary
87 | return self._triangular_vec_to_matrix(sample)
88 |
--------------------------------------------------------------------------------
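A small worked example of the pair-wise parameterization above on a 3-node graph, showing how a (3, n(n-1)/2) one-hot sample is turned into an adjacency matrix (standalone torch code mirroring `_triangular_vec_to_matrix`):

    import torch

    n = 3
    # torch.tril_indices(n, n, offset=-1) enumerates the pairs (1,0), (2,0), (2,1)
    lower = torch.tril_indices(n, n, offset=-1)

    # one-hot choice per pair: category 0 -> edge stored in the lower triangle,
    # category 1 -> edge stored in the upper triangle, category 2 -> no edge
    sample = torch.tensor([[1., 0., 0.],
                           [0., 1., 0.],
                           [0., 0., 1.]])

    adj = torch.zeros(n, n)
    adj[lower[0], lower[1]] = sample[0]   # lower-triangular entries
    adj[lower[1], lower[0]] = sample[1]   # upper-triangular entries
    print(adj)
    # tensor([[0., 0., 1.],
    #         [1., 0., 0.],
    #         [0., 0., 0.]])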
/src/modules/adjacency_matrices/TwoWayGraphDist.py:
--------------------------------------------------------------------------------
1 | import lightning.pytorch as pl
2 | import torch
3 | import torch.distributions as td
4 | import torch.nn.functional as F
5 | from torch import nn
6 | from src.modules.adjacency_matrices.AdjMatrix import AdjMatrix
7 |
8 |
9 | class TwoWayGraphDist(AdjMatrix, pl.LightningModule):
10 | """
11 |     Variational distribution over directed graphs with an independent edge indicator for every ordered (i, j) pair; the diagonal (no self loops) is zeroed out.
12 |     Sampling is performed with `F.gumbel_softmax(..., hard=True)` to give binary samples and a straight-through gradient estimator.
13 | """
14 |
15 | def __init__(
16 | self,
17 | input_dim: int,
18 | tau_gumbel: float = 1.0
19 | ):
20 | """
21 | Args:
22 |             input_dim: number of nodes in the graph.
23 | tau_gumbel: temperature used for gumbel softmax sampling.
24 | """
25 | super().__init__()
26 |         # One pair of logits (no edge / edge) is kept for every (i, j) entry of the
27 |         # adjacency matrix; unlike ThreeWayGraphDist, the full (input_dim, input_dim)
28 |         # grid is parameterized and the diagonal is zeroed out after sampling
29 | self.logits = nn.Parameter(
30 | torch.zeros((2, input_dim, input_dim), device=self.device), requires_grad=True
31 | )
32 | self.tau_gumbel = tau_gumbel
33 | self.input_dim = input_dim
34 |
35 | def zero_out_diagonal(self, matrix: torch.Tensor):
36 | # matrix: (num_nodes, num_nodes)
37 | N = matrix.shape[0]
38 | I = torch.arange(N).to(self.device)
39 | matrix = matrix.clone()
40 | matrix[I, I] = 0
41 | return matrix
42 |
43 | def get_adj_matrix(self, do_round: bool = False) -> torch.Tensor:
44 | """
45 | Returns the adjacency matrix of edge probabilities.
46 | """
47 | probs = F.softmax(self.logits, dim=0)[1, ...]
48 | probs = self.zero_out_diagonal(probs)
49 | # probs = F.softmax(self.logits, dim=0) # (3, n(n-1)/2) probabilities
50 |
51 | if do_round:
52 | return probs.round()
53 | return probs
54 |
55 | def entropy(self) -> torch.Tensor:
56 | """
57 |         Computes the entropy of distribution q over the off-diagonal entries (self-loops are excluded).
58 | """
59 | dist = td.Categorical(logits=self.logits[1, ...] - self.logits[0, ...])
60 | I = torch.arange(self.input_dim)
61 | dist_diag = td.Categorical(
62 | logits=self.logits[1, I, I] - self.logits[0, I, I])
63 | entropies = dist.entropy()
64 | diag_entropy = dist_diag.entropy()
65 | return entropies.sum() - diag_entropy.sum()
66 | # return entropies.mean()
67 |
68 | def sample_A(self) -> torch.Tensor:
69 | """
70 | Sample an adjacency matrix from the variational distribution. It uses the gumbel_softmax trick,
71 | and returns hard samples (straight through gradient estimator). Adjacency returned always has
72 | zeros in its diagonal (no self loops).
73 |
74 | V1: Returns one sample to be used for the whole batch.
75 | """
76 | sample = F.gumbel_softmax(
77 | self.logits, tau=self.tau_gumbel, hard=True, dim=0)[1, ...]
78 | return self.zero_out_diagonal(sample)
79 |
--------------------------------------------------------------------------------
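A minimal usage sketch for the class above (assumes the repo is importable as the `src` package):

    import torch
    from src.modules.adjacency_matrices.TwoWayGraphDist import TwoWayGraphDist

    dist = TwoWayGraphDist(input_dim=4)

    probs = dist.get_adj_matrix()   # (4, 4) edge probabilities with a zero diagonal
    sample = dist.sample_A()        # (4, 4) hard sample, also with a zero diagonal
    print(probs.diagonal())         # tensor([0., 0., 0., 0.])
    print(sample.shape)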
/src/modules/adjacency_matrices/TwoWayTemporalAdjacencyMatrix.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional
2 |
3 | import torch
4 | import torch.distributions as td
5 | import torch.nn.functional as F
6 | from torch import nn
7 | from src.modules.adjacency_matrices.TwoWayGraphDist import TwoWayGraphDist
8 |
9 |
10 | class TwoWayTemporalAdjacencyMatrix(TwoWayGraphDist):
11 | """
12 |     This class adapts TwoWayGraphDist so that it supports variational distributions over temporal adjacency matrices.
13 |
14 | """
15 |
16 | def __init__(
17 | self,
18 | input_dim: int,
19 | lag: int,
20 | tau_gumbel: float = 1.0,
21 | init_logits: Optional[List[float]] = None,
22 | disable_inst: bool = False
23 | ):
24 | """
25 | This creates an instance of variational distribution for temporal adjacency matrix.
26 | Args:
27 |             input_dim: The number of nodes for the adjacency matrix.
28 |             lag: The lag for the temporal adj matrix. The adj matrix has the shape (lag+1, num_nodes, num_nodes).
29 |             tau_gumbel: The temperature for the gumbel softmax sampling.
30 |             init_logits: The initialized logits value. If None, then use the default initialized logits (value 0). Otherwise,
31 |             init_logits[0] indicates the non-existence edge logit for the instantaneous effect, and init_logits[1] indicates the
32 |             non-existence edge logit for the lagged effect. E.g. if we want a dense initialization, one choice is (-7, -0.5).
33 |             disable_inst: If True, the instantaneous adjacency is disabled (its slice is kept at zero).
34 | """
35 | # Call parent init method, this will init a self.logits parameters for instantaneous effect.
36 | super().__init__(input_dim=input_dim, tau_gumbel=tau_gumbel)
37 | # Create a separate logit for lagged adj
38 | # The logits_lag are initialized to zero with shape (2, lag, input_dim, input_dim).
39 | # logits_lag[0,...] indicates the logit prob for no edges, and logits_lag[1,...] indicates the logit for edge existence.
40 | self.lag = lag
41 | # Assertion lag > 0
42 | assert lag > 0
43 | self.logits_lag = nn.Parameter(torch.zeros(
44 | (2, lag, input_dim, input_dim), device=self.device), requires_grad=True)
45 | self.init_logits = init_logits
46 | self.disable_inst = disable_inst
47 | # Set the init_logits if not None
48 | if self.init_logits is not None:
49 | self.logits.data[0, ...] = self.init_logits[0]
50 | self.logits_lag.data[0, ...] = self.init_logits[1]
51 |
52 | def get_adj_matrix(self, do_round: bool = False) -> torch.Tensor:
53 | """
54 | This returns the temporal adjacency matrix of edge probability.
55 | Args:
56 | do_round: Whether to round the edge probabilities.
57 |
58 | Returns:
59 | The adjacency matrix with shape [lag+1, num_nodes, num_nodes].
60 | """
61 |
62 | # Create the temporal adj matrix
63 | probs = torch.zeros(self.lag + 1, self.input_dim,
64 | self.input_dim, device=self.device)
65 | # Generate simultaneous adj matrix
66 | if not self.disable_inst:
67 | probs[0, ...] = super().get_adj_matrix(
68 | do_round=do_round) # shape (input_dim, input_dim)
69 | # Generate lagged adj matrix
70 | probs[1:, ...] = F.softmax(self.logits_lag, dim=0)[
71 | 1, ...] # shape (lag, input_dim, input_dim)
72 | if do_round:
73 | return probs.round()
74 |
75 | return probs
76 |
77 | def entropy(self) -> torch.Tensor:
78 | """
79 | This computes the entropy of the variational distribution.
80 |         This can be done by (1) computing the entropy of the instantaneous adj matrix (same as TwoWayGraphDist),
81 |         (2) computing the entropy of the lagged adj matrix (Bernoulli dist), and (3) adding them together.
82 | """
83 | # Entropy for instantaneous dist, call super().entropy
84 | if not self.disable_inst:
85 | entropies_inst = super().entropy()
86 | else:
87 | entropies_inst = 0
88 | # Entropy for lagged dist
89 | # batch_shape [lag], event_shape [num_nodes, num_nodes]
90 |
91 | dist_lag = td.Independent(td.Bernoulli(
92 | logits=self.logits_lag[1, ...] - self.logits_lag[0, ...]), 2)
93 | entropies_lag = dist_lag.entropy().sum()
94 | # entropies_lag = dist_lag.entropy().mean()
95 |
96 | return entropies_lag + entropies_inst
97 |
98 | def sample_A(self) -> torch.Tensor:
99 | """
100 | This samples the adjacency matrix from the variational distribution. This uses the gumbel softmax trick and returns
101 | hard samples. This can be done by (1) sample instantaneous adj matrix using self.logits, (2) sample lagged adj matrix using self.logits_lag.
102 | """
103 |
104 | # Create adj matrix to avoid concatenation
105 | adj_sample = torch.zeros(
106 | self.lag + 1, self.input_dim, self.input_dim, device=self.device
107 | ) # shape (lag+1, input_dim, input_dim)
108 |
109 | # Sample instantaneous adj matrix
110 | if not self.disable_inst:
111 | adj_sample[0, ...] = self.zero_out_diagonal(
112 | F.gumbel_softmax(self.logits, tau=self.tau_gumbel,
113 | hard=True, dim=0)[1, ...]
114 | ) # shape (input_dim, input_dim)
115 |
116 | # Sample lagged adj matrix
117 | adj_sample[1:, ...] = F.gumbel_softmax(self.logits_lag, tau=self.tau_gumbel, hard=True, dim=0)[
118 | 1, ...
119 | ] # shape (lag, input_dim, input_dim)
120 | return adj_sample
121 |
--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
1 | # standard libraries
2 | import time
3 | import os
4 |
5 | import hydra
6 | import numpy as np
7 | import torch
8 | from omegaconf import DictConfig
9 | from lightning.pytorch import Trainer, seed_everything
10 | from lightning.pytorch.loggers import CSVLogger, WandbLogger
11 | from lightning.pytorch.callbacks import ModelCheckpoint
12 | from lightning.pytorch.utilities.model_summary import ModelSummary
13 | from lightning.pytorch.utilities import rank_zero_only
14 | import wandb
15 |
16 | from src.utils.config_utils import add_all_attributes, add_attribute, generate_unique_name
17 | from src.utils.data_utils.dataloading_utils import load_data, get_dataset_path, create_save_name
18 | from src.utils.data_utils.data_format_utils import zero_out_diag_np
19 | from src.utils.utils import write_results_to_disk
20 | from src.utils.metrics_utils import evaluate_results
21 |
22 | from src.model.generate_model import generate_model
23 |
24 |
25 | @hydra.main(version_base=None, config_path="../configs/", config_name="main.yaml")
26 | def run(cfg: DictConfig):
27 | if 'dataset' not in cfg:
28 | raise Exception("No dataset found in the config")
29 |
30 | dataset = cfg.dataset
31 | dataset_path = get_dataset_path(dataset)
32 |
33 | if dataset_path in cfg:
34 | model_config = cfg[dataset_path]
35 | else:
36 | raise Exception(
37 | "No model found in the config. Try running with option: python3 -m src.train +dataset= +=")
38 |
39 | cfg.dataset_dir = os.path.join(cfg.dataset_dir, dataset_path)
40 |
41 | add_all_attributes(cfg, model_config)
42 | train(cfg)
43 |
44 |
45 | def train(cfg):
46 |
47 | print("Running config:")
48 | print(cfg)
49 | dataset = cfg.dataset
50 | seed = int(cfg.random_seed)
51 | # set seed
52 | seed_everything(cfg.random_seed)
53 |
54 | X, adj_matrix, aggregated_graph, lag, data_dim = load_data(
55 | cfg.dataset, cfg.dataset_dir, cfg)
56 |
57 | add_attribute(cfg, 'lag', lag)
58 | add_attribute(cfg, 'aggregated_graph', aggregated_graph)
59 | add_attribute(cfg, 'num_nodes', X.shape[2])
60 | add_attribute(cfg, 'data_dim', data_dim)
61 |
62 | # generate model
63 | model = generate_model(cfg)
64 |
65 | # pass the dataset
66 | model = model(full_dataset=X,
67 | adj_matrices=adj_matrix)
68 |
69 | model_name = cfg.model
70 |
71 | if 'use_indices' in cfg:
72 | f_path = os.path.join(cfg.dataset_dir, cfg.dataset,
73 | f'{cfg.use_indices}_seed={cfg.random_seed}.npy')
74 | mix_idx = torch.Tensor(np.load(f_path))
75 | model.set_mixture_indices(mix_idx)
76 |
77 | training_needed = cfg.model in ['rhino', 'mcd']
78 | unique_name = generate_unique_name(cfg)
79 | csv_logger = CSVLogger("logs", name=unique_name)
80 | wandb_logger = WandbLogger(
81 | name=unique_name, project=cfg.wandb_project, log_model=True)
82 |
83 | # log all hyperparameters
84 |
85 | if rank_zero_only.rank == 0:
86 | for key in cfg:
87 | wandb_logger.experiment.config[key] = cfg[key]
88 | csv_logger.log_hyperparams(cfg)
89 |
90 | if training_needed:
91 | # either val_loss or likelihood
92 | monitor_checkpoint_based_on = cfg.monitor_checkpoint_based_on
93 | ckpt_choice = 'best'
94 |
95 | checkpoint_callback = ModelCheckpoint(
96 | save_top_k=1, monitor=monitor_checkpoint_based_on, mode="min", save_last=True)
97 |
98 | if len(cfg.gpu) > 1:
99 | strategy = 'ddp_find_unused_parameters_true'
100 | else:
101 | strategy = 'auto'
102 |
103 | if 'val_every_n_epochs' in cfg:
104 | val_every_n_epochs = cfg.val_every_n_epochs
105 | else:
106 | val_every_n_epochs = 1
107 |
108 | trainer = Trainer(max_epochs=cfg.num_epochs,
109 | accelerator="gpu",
110 | devices=cfg.gpu,
111 | precision=cfg.precision,
112 | logger=[csv_logger, wandb_logger],
113 | callbacks=[checkpoint_callback],
114 | strategy=strategy,
115 | enable_progress_bar=True,
116 | check_val_every_n_epoch=val_every_n_epochs)
117 |
118 | summary = ModelSummary(model, max_depth=10)
119 | print(summary)
120 |
121 | if cfg.watch_gradients:
122 | wandb_logger.watch(model)
123 |
124 | start_time = time.time()
125 | trainer.fit(model=model)
126 | end_time = time.time()
127 |
128 | print("Model trained in", str(end_time-start_time) + "s")
129 |
130 | else:
131 | if cfg.gpu != -1:
132 | print("WARNING: GPU specified, but baseline cannot use GPU.")
133 | trainer = Trainer(logger=[csv_logger, wandb_logger],
134 | accelerator='cpu')
135 |
136 | # get predictions
137 |
138 | full_dataloader = model.get_full_dataloader()
139 | if training_needed:
140 | model.eval()
141 | L = trainer.predict(model, full_dataloader, ckpt_path=ckpt_choice)
142 | else:
143 | L = trainer.predict(model, full_dataloader)
144 |
145 | predictions = []
146 | scores = []
147 | adj_matrix = []
148 | for graph, prob, matrix in L:
149 | predictions.append(graph)
150 | scores.append(prob)
151 | adj_matrix.append(matrix)
152 |
153 | predictions = torch.concatenate(predictions, dim=0)
154 | scores = torch.concatenate(scores, dim=0)
155 | adj_matrix = torch.concatenate(adj_matrix, dim=0)
156 |
157 | predictions = predictions.detach().cpu().numpy()
158 | scores = scores.detach().cpu().numpy()
159 | adj_matrix = adj_matrix.detach().cpu().numpy()
160 |
161 | if 'dream3' in dataset:
162 | predictions = zero_out_diag_np(predictions)
163 | scores = zero_out_diag_np(scores)
164 |
165 | # save predictions
166 | if not os.path.exists(os.path.join('results', dataset)):
167 | os.makedirs(os.path.join('results', dataset))
168 |
169 | if training_needed and ckpt_choice == 'best':
170 | if model_name == 'mcd':
171 | np.save(os.path.join('results', dataset,
172 | f'{model_name}_{checkpoint_callback.best_model_score.item()}_k{cfg.trainer.num_graphs}.npy'), scores)
173 | else:
174 | np.save(os.path.join('results', dataset,
175 | f'{model_name}_{checkpoint_callback.best_model_score.item()}.npy'), scores)
176 | else:
177 | np.save(os.path.join('results', dataset, f'{model_name}.npy'), scores)
178 |
179 | if model_name == 'mcd':
180 | true_cluster_indices, pred_cluster_indices = model.get_cluster_indices()
181 | else:
182 | pred_cluster_indices = None
183 | true_cluster_indices = None
184 |
185 | metrics = evaluate_results(scores=scores,
186 | adj_matrix=adj_matrix,
187 | predictions=predictions,
188 | aggregated_graph=aggregated_graph,
189 | true_cluster_indices=true_cluster_indices,
190 | pred_cluster_indices=pred_cluster_indices)
191 | # add the dataset name and model to the csv
192 | metrics['model'] = model_name + "_seed_" + str(seed)
193 |
194 | if model_name in ['pcmci', 'dynotears']:
195 | metrics['model'] += "_singlegraph_" + \
196 | str(cfg.trainer.single_graph) + "_grouped_" + \
197 | str(cfg.trainer.group_by_graph)
198 | if model_name == 'mcd':
199 | metrics['model'] += "_trueindex_" + \
200 | str(cfg.trainer.use_correct_mixture_index)
201 | if 'use_indices' in cfg:
202 | metrics['model'] += '_' + cfg.use_indices
203 |     if (model_name in ['rhino', 'mcd']) and 'linear' in cfg.decoder and cfg.decoder.linear:
204 | metrics['model'] += '_linearmodel'
205 | if training_needed and ckpt_choice == 'best':
206 | metrics['best_loss'] = checkpoint_callback.best_model_score.item()
207 | metrics['dataset'] = dataset
208 |
209 | # write the results to wandb
210 | wandb.init()
211 | wandb.log(metrics)
212 |
213 | # write the results to the log directory
214 | write_results_to_disk(create_save_name(dataset, cfg), metrics)
215 |
216 | wandb.finish()
217 |
218 |
219 | if __name__ == "__main__":
220 | run()
221 |
--------------------------------------------------------------------------------
/src/utils/causality_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Borrowed from github.com/microsoft/causica
3 | """
4 |
5 | from typing import Any, Dict, Optional, Union
6 |
7 | import numpy as np
8 | import torch
9 |
10 | def convert_temporal_to_static_adjacency_matrix(
11 | adj_matrix: np.ndarray, conversion_type: str, fill_value: Union[float, int] = 0.0
12 | ) -> np.ndarray:
13 | """
14 |     This method converts the temporal adjacency matrix to a specified type of static adjacency.
15 | It supports two types of conversion: "full_time" and "auto_regressive".
16 | The conversion type determines the connections between time steps.
17 | "full_time" will convert the temporal adjacency matrix to a full-time static graph, where the connections between lagged time steps are preserved.
18 | "auto_regressive" will convert temporal adj to a static graph that only keeps the connections to the current time step.
19 | E.g. a temporal adj matrix with lag 2 is [A,B,C], where A,B and C are also adj matrices. "full_time" will convert this to
20 | [[A,B,C],[0,A,B],[0,0,A]]. "auto_regressive" will convert this to [[0,0,C],[0,0,B],[0,0,A]].
21 | "fill_value" is used to specify the value to fill in the converted static adjacency matrix. Default is 0, but sometimes we may want
22 | other values. E.g. if we have a temporal soft prior with prior mask, then we may want to fill the converted prior mask with value 1 instead of 0,
23 | since the converted prior mask should never disable the blocks specifying the "arrow-against-time" in converted soft prior.
24 | Args:
25 | adj_matrix: The temporal adjacency matrix with shape [lag+1, from, to] or [N, lag+1, from, to].
26 | conversion_type: The conversion type. It supports "full_time" and "auto_regressive".
27 | fill_value: The value used to fill the static adj matrix. The default is 0.
28 |
29 | Returns: static adjacency matrix with shape [(lag+1)*from, (lag+1)*to] or [N, (lag+1)*from, (lag+1)*to].
30 |
31 | """
32 | assert conversion_type in [
33 | "full_time",
34 | "auto_regressive",
35 | ], f"The conversion_type {conversion_type} is not supported."
36 | if len(adj_matrix.shape) == 3:
37 | adj_matrix = adj_matrix[None, ...] # [1, lag+1, num_node, num_node]
38 | batch_dim, n_lag, n_nodes, _ = adj_matrix.shape # n_lag is lag+1
39 | if conversion_type == "full_time":
40 | block_fill_value = np.full((n_nodes, n_nodes), fill_value)
41 | else:
42 | block_fill_value = np.full(
43 | (batch_dim, n_lag * n_nodes, (n_lag - 1) * n_nodes), fill_value)
44 |
45 | if conversion_type == "full_time":
46 | static_adj = np.sum(
47 | np.stack([np.kron(np.diag(np.ones(n_lag - i), k=i),
48 | adj_matrix[:, i, :, :]) for i in range(n_lag)], axis=1),
49 | axis=1,
50 | ) # [N, n_lag*from, n_lag*to]
51 | static_adj += np.kron(
52 | np.tril(np.ones((batch_dim, n_lag, n_lag)), k=-1), block_fill_value
53 | ) # [N, n_lag*from, n_lag*to]
54 |
55 | if conversion_type == "auto_regressive":
56 | # Flip the temporal adj and concatenate to form one block column of the static. The flipping is needed due to the
57 | # format of converted adjacency matrix. E.g. temporal adj [A,B,C], where A is the instant adj matrix. Then, the converted adj
58 | # is [[[0,0,C],[0,0,B],[0,0,A]]]. The last column is the concatenation of flipped temporal adj.
59 | block_column = np.flip(adj_matrix, axis=1).reshape(
60 | -1, n_lag * n_nodes, n_nodes
61 | ) # [N, (lag+1)*num_node, num_node]
62 | # Static graph
63 | # [N, (lag+1)*num_node, (lag+1)*num_node]
64 | static_adj = np.concatenate((block_fill_value, block_column), axis=2)
65 |
66 | return np.squeeze(static_adj)
67 |
68 |
69 | def dag_pen_np(X):
70 | assert X.shape[0] == X.shape[1]
71 | X = torch.from_numpy(X)
72 | return (torch.trace(torch.matrix_exp(X)) - X.shape[0]).item()
73 |
74 |
75 | def approximate_maximal_acyclic_subgraph(adj_matrix: np.ndarray, n_samples: int = 10) -> np.ndarray:
76 | """
77 |     Compute an (approximate) maximal acyclic subgraph of a directed non-DAG by removing at most 1/2 of the edges.
78 | See Vazirani, Vijay V. Approximation algorithms. Vol. 1. Berlin: springer, 2001, Page 7;
79 | Also Hassin, Refael, and Shlomi Rubinstein. "Approximations for the maximum acyclic subgraph problem."
80 | Information processing letters 51.3 (1994): 133-140.
81 | Args:
82 | adj_matrix: adjacency matrix of a directed graph (may contain cycles)
83 | n_samples: number of the random permutations generated. Default is 10.
84 | Returns:
85 | an adjacency matrix of the acyclic subgraph
86 | """
87 | # assign each node with a order
88 | adj_dag = np.zeros_like(adj_matrix)
89 | for _ in range(n_samples):
90 | random_order = np.expand_dims(
91 | np.random.permutation(adj_matrix.shape[0]), 0)
92 | # subgraph with only forward edges defined by the assigned order
93 | adj_forward = (
94 | (random_order.T > random_order).astype(int)) * adj_matrix
95 | # subgraph with only backward edges defined by the assigned order
96 | adj_backward = (
97 | (random_order.T < random_order).astype(int)) * adj_matrix
98 | # return the subgraph with the least deleted edges
99 | adj_dag_n = adj_forward if adj_backward.sum() < adj_forward.sum() else adj_backward
100 | if adj_dag_n.sum() > adj_dag.sum():
101 | adj_dag = adj_dag_n
102 | return adj_dag
103 |
104 |
105 | def int2binlist(i: int, n_bits: int):
106 | """
107 | Convert integer to list of ints with values in {0, 1}
108 | """
109 | assert i < 2**n_bits
110 | str_list = list(np.binary_repr(i, n_bits))
111 | return [int(i) for i in str_list]
112 |
113 |
114 | def cpdag2dags(cp_mat: np.ndarray, samples: Optional[int] = None) -> np.ndarray:
115 | """
116 | Compute all possible DAGs contained within a Markov equivalence class, given by a CPDAG
117 | Args:
118 |         cp_mat: adjacency matrix containing both forward and backward edges for those edges whose directionality is undetermined
119 | Returns:
120 |         3-dimensional tensor, where the first dimension indexes all the possible DAGs
121 | """
122 | assert len(cp_mat.shape) == 2 and cp_mat.shape[0] == cp_mat.shape[1]
123 |
124 | # matrix composed of just undetermined edges
125 | cycle_mat = (cp_mat == cp_mat.T) * cp_mat
126 | # return original matrix if there are no length-1 cycles
127 | if cycle_mat.sum() == 0:
128 | if dag_pen_np(cp_mat) != 0.0:
129 | cp_mat = approximate_maximal_acyclic_subgraph(cp_mat)
130 | return cp_mat[None, :, :]
131 |
132 | # matrix of determined edges
133 | cp_determined_subgraph = cp_mat - cycle_mat
134 |
135 | # prune cycles if the matrix of determined edges is not a dag
136 | if dag_pen_np(cp_determined_subgraph.copy()) != 0.0:
137 | cp_determined_subgraph = approximate_maximal_acyclic_subgraph(
138 | cp_determined_subgraph, 1000)
139 |
140 | # number of parent nodes for each node under the well determined matrix
141 | n_in_nodes = cp_determined_subgraph.sum(axis=0)
142 |
143 | # lower triangular version of cycles edges: only keep cycles in one direction.
144 | cycles_tril = np.tril(cycle_mat, k=-1)
145 |
146 | # indices of potential new edges
147 | undetermined_idx_mat = np.array(np.nonzero(
148 |         cycles_tril)).T  # (N_undetermined, 2)
149 |
150 | # number of undetermined edges
151 | N_undetermined = int(cycles_tril.sum())
152 |
153 | # choose random order for mask iteration
154 | max_dags = 2**N_undetermined
155 |
156 | if max_dags > 10000:
157 | print("The number of possible dags are too large (>10000), limit to 10000")
158 | max_dags = 10000
159 |
160 | if samples is None:
161 | samples = max_dags
162 | mask_indices = list(np.random.permutation(np.arange(max_dags)))
163 |
164 | # iterate over list of all potential combinations of new edges. 0 represents keeping edge from upper triangular and 1 from lower triangular
165 | dag_list: list = []
166 | while mask_indices and len(dag_list) < samples:
167 |
168 | mask_index = mask_indices.pop()
169 | mask = np.array(int2binlist(mask_index, N_undetermined))
170 |
171 | # extract list of indices which our new edges are pointing into
172 | incoming_edges = np.take_along_axis(
173 | undetermined_idx_mat, mask[:, None], axis=1).squeeze()
174 |
175 | # check if multiple edges are pointing at same node
176 | _, unique_counts = np.unique(
177 | incoming_edges, return_index=False, return_inverse=False, return_counts=True)
178 |
179 |         # check if a new collider has been created by checking if multiple edges point at the same node or if a new edge points at an existing child node
180 | new_colider = np.any(unique_counts > 1) or np.any(
181 | n_in_nodes[incoming_edges] > 0)
182 |
183 | if not new_colider:
184 | # get indices of new edges by sampling from lower triangular mat and upper triangular according to indices
185 | edge_selection = undetermined_idx_mat.copy()
186 | edge_selection[mask == 0, :] = np.fliplr(
187 | edge_selection[mask == 0, :])
188 |
189 | # add new edges to matrix and add to dag list
190 | new_dag = cp_determined_subgraph.copy()
191 | new_dag[(edge_selection[:, 0], edge_selection[:, 1])] = 1
192 |
193 | # Check for high order cycles
194 | if dag_pen_np(new_dag.copy()) == 0.0:
195 | dag_list.append(new_dag)
196 | # When all combinations of new edges create cycles, we will only keep determined ones
197 | if len(dag_list) == 0:
198 | dag_list.append(cp_determined_subgraph)
199 |
200 | return np.stack(dag_list, axis=0)
201 |
--------------------------------------------------------------------------------
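A small worked example of `convert_temporal_to_static_adjacency_matrix` from the file above, with two nodes and lag 1 (illustrative matrices; assumes the repo is importable as the `src` package):

    import numpy as np
    from src.utils.causality_utils import convert_temporal_to_static_adjacency_matrix

    inst = np.array([[0, 1],
                     [0, 0]])            # instantaneous adjacency (lag 0)
    lagged = np.array([[1, 0],
                       [0, 1]])          # lag-1 adjacency
    temporal = np.stack([inst, lagged])  # shape (lag+1, 2, 2)

    full_time = convert_temporal_to_static_adjacency_matrix(temporal, "full_time")
    auto_reg = convert_temporal_to_static_adjacency_matrix(temporal, "auto_regressive")

    # full_time  -> [[inst, lagged], [0, inst]]   (block upper-triangular, shape (4, 4))
    # auto_reg   -> [[0, lagged], [0, inst]]      (only connections into the last time step)
    print(full_time)
    print(auto_reg)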
/src/utils/config_utils.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 | from omegaconf import DictConfig, open_dict
3 | from sklearn.model_selection import ParameterGrid
4 |
5 |
6 | def generate_unique_name(config):
7 | # generate unique name based on the config
8 | run_name = config['model']+"_"+config['dataset'] + \
9 | "_"+"seed_"+str(config['random_seed'])+"_"
10 | run_name += datetime.now().strftime("%Y%m%d_%H_%M_%S")
11 | return run_name
12 |
13 |
14 | def read_optional(config, arg, default):
15 | if arg in config:
16 | return config[arg]
17 | return default
18 |
19 |
20 | def add_attribute(config: DictConfig, name, val):
21 | with open_dict(config):
22 | config[name] = val
23 |
24 |
25 | def add_all_attributes(cfg, cfg2):
26 | # add all attributes from cfg2 to cfg
27 | for key in cfg2:
28 | add_attribute(cfg, key, cfg2[key])
29 |
30 |
31 | def build_subdictionary(hyperparameters, loop_hyperparameters):
32 | """
33 |     Given a dictionary of hyperparameters (where some of the values may be lists) and a list of keys
34 |     loop_hyperparameters, build a ParameterGrid over the selected keys.
35 |
36 | """
37 | # build sub dictionary of hyperparameters
38 | subparameters = dict(
39 | (k, hyperparameters[k]) for k in loop_hyperparameters if k in hyperparameters)
40 | subparameters = dict((k, [subparameters[k]]) if not isinstance(
41 | subparameters[k], list) else (k, subparameters[k]) for k in subparameters)
42 |
43 | subparameters = ParameterGrid(subparameters)
44 |
45 | return subparameters
46 |
--------------------------------------------------------------------------------
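A short sketch of how `build_subdictionary` expands list-valued hyperparameters into a sklearn ParameterGrid while wrapping scalar values into single-element lists (assumes the repo is importable as the `src` package and scikit-learn is installed):

    from src.utils.config_utils import build_subdictionary

    hyperparameters = {"lr": [1e-3, 1e-4], "batch_size": 64, "lag": 2}
    grid = build_subdictionary(hyperparameters, loop_hyperparameters=["lr", "batch_size"])

    for params in grid:
        print(params)
    # {'batch_size': 64, 'lr': 0.001}
    # {'batch_size': 64, 'lr': 0.0001}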
/src/utils/data_gen/generate_perturb_syn.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import yaml
4 |
5 | from data_generation_utils import generate_cts_temporal_data, set_random_seed, generate_temporal_graph
6 | import cdt
7 | import numpy as np
8 | import networkx as nx
9 |
10 |
11 | def calc_dist(adj_matrix):
12 | unique_matrices = np.unique(adj_matrix, axis=0)
13 | distances = []
14 | for i in range(unique_matrices.shape[0]):
15 | for j in range(i):
16 | distances.append(cdt.metrics.SHD(
17 | unique_matrices[i], unique_matrices[j]))
18 | mean_dist = np.mean(distances)
19 | min_dist = np.min(distances)
20 | max_dist = np.max(distances)
21 | std_dist = np.std(distances)
22 |
23 | return mean_dist, std_dist, min_dist, max_dist
24 |
25 |
26 | def perturb_graph(G, p):
27 | retry_counter = 0
28 | while True:
29 | perturbation = np.random.binomial(1, p, size=G.shape)
30 | perturbed = np.logical_xor(G, perturbation).astype(int)
31 | # nxG = nx.from_numpy_matrix(perturbed[0], create_using=nx.DiGraph)
32 | nxG = nx.DiGraph(perturbed[0])
33 | if nx.is_directed_acyclic_graph(nxG):
34 | break
35 | retry_counter += 1
36 |
37 | if retry_counter >= 200000:
38 | assert False, "Cannot generate DAG, try a lower value of p"
39 |
40 | return perturbed.astype(int)
41 |
42 |
43 | def main(config_file):
44 |
45 | # read the yaml file
46 | with open(config_file, 'r', encoding="utf-8") as f:
47 | data_config = yaml.load(f, Loader=yaml.FullLoader)
48 |
49 | series_length = int(data_config["num_timesteps"])
50 | burnin_length = int(data_config["burnin_length"])
51 | num_samples = int(data_config["num_samples"])
52 | disable_inst = bool(data_config["disable_inst"])
53 | graph_type = [data_config["inst_graph_type"],
54 | data_config["lag_graph_type"]]
55 | p_array = data_config['p_array']
56 |
57 | connection_factor = 1
58 |
59 | # in this case, all inst and lagged adj are dags and have total edges per adj = 2*num_nodes
60 | # graph_type = ["SF", "SF"]
61 | # graph_config = [{"m":2, "directed":True}, {"m":2, "directed":True}]
62 |
63 | lag = int(data_config["lag"])
64 | is_history_dep = bool(data_config["history_dep_noise"])
65 |
66 | noise_level = float(data_config["noise_level"])
67 | function_type = data_config["function_type"]
68 | noise_function_type = data_config["noise_function_type"]
69 | base_noise_type = data_config["base_noise_type"]
70 |
71 | save_dir = data_config["save_dir"]
72 | N = data_config["num_nodes"]
73 | num_graphs = data_config["num_graphs"]
74 |
75 | for seed in data_config['random_seed']:
76 | seed = int(seed)
77 | set_random_seed(seed)
78 | graph_config = [
79 | {"m": N * 2 * connection_factor if not disable_inst else 0, "directed": True},
80 | {"m": N * connection_factor, "directed": True},
81 | ]
82 | G = generate_temporal_graph(
83 | N, graph_type, graph_config, lag=2).astype(int)
84 | N = int(N)
85 |
86 | for N_G in num_graphs:
87 | N_G = int(N_G)
88 | print(f"Generating dataset for N={N}, num_graphs={N_G}")
89 |
90 | for p in p_array:
91 | graphs = []
92 | for i in range(N_G):
93 | print(f"Generating graph {i}/{N_G}")
94 | Gtilde = perturb_graph(G, p)
95 | graphs.append(Gtilde)
96 |
97 | mean_dist, std_dist, min_dist, max_dist = calc_dist(
98 | np.array(graphs))
99 | print(
100 | f"Perturbation,{N},{N_G},{mean_dist},{std_dist},{min_dist},{max_dist},{p}")
101 |
102 | folder_name = f"perturb_N{N}_K{N_G}_p{p}_seed{seed}"
103 | path = os.path.join(save_dir, folder_name)
104 |
105 | generate_cts_temporal_data(
106 | path=path,
107 | num_graphs=N_G,
108 | series_length=series_length,
109 | burnin_length=burnin_length,
110 | num_samples=num_samples,
111 | num_nodes=N,
112 | graph_type=graph_type,
113 | graph_config=graph_config,
114 | lag=lag,
115 | is_history_dep=is_history_dep,
116 | noise_level=noise_level,
117 | function_type=function_type,
118 | noise_function_type=noise_function_type,
119 | base_noise_type=base_noise_type,
120 | temporal_graphs=graphs
121 | )
122 |
123 |
124 | if __name__ == "__main__":
125 | parser = argparse.ArgumentParser("Temporal Synthetic Data Generator")
126 | parser.add_argument("--config_file", type=str)
127 | args = parser.parse_args()
128 |
129 | main(config_file=args.config_file)
130 |
--------------------------------------------------------------------------------
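A self-contained sketch of the perturbation idea implemented in `perturb_graph` above: every entry of the temporal graph is flipped independently with probability p, and the draw is repeated until the instantaneous slice remains a DAG (standalone code; `perturb_until_dag` is an illustrative name, not a repo function):

    import numpy as np
    import networkx as nx

    def perturb_until_dag(G, p, max_retries=200000):
        for _ in range(max_retries):
            perturbation = np.random.binomial(1, p, size=G.shape)
            perturbed = np.logical_xor(G, perturbation).astype(int)
            # only the instantaneous slice G[0] has to stay acyclic
            if nx.is_directed_acyclic_graph(nx.DiGraph(perturbed[0])):
                return perturbed
        raise RuntimeError("Could not produce a DAG; try a lower value of p")

    G = np.zeros((3, 4, 4), dtype=int)   # (lag+1, num_nodes, num_nodes)
    G[0, 0, 1] = 1                       # one instantaneous edge 0 -> 1
    G[1, 2, 3] = 1                       # one lag-1 edge 2 -> 3
    print(perturb_until_dag(G, p=0.05).sum())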
/src/utils/data_gen/generate_stock.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from yahoofinancials import YahooFinancials
4 | import pandas as pd
5 | import numpy as np
6 | import tqdm
7 |
8 |
9 | def process_stock_price(tickers, start_date, end_date):
10 | yfs = [YahooFinancials(t) for t in tickers]
11 | X = []
12 | dates = []
13 | for yf in tqdm.tqdm(yfs, "Processing stock market data"):
14 | r = yf.get_historical_price_data(start_date, end_date, 'daily')
15 | r = dict(r)
16 | ticker = list(r.keys())[0]
17 | close_prices = np.array([t['close'] for t in r[ticker]['prices']])
18 | X.append(close_prices)
19 | dates.append([t['formatted_date'] for t in r[ticker]['prices']])
20 | X = np.array(X)
21 | # print(X.shape)
22 | dates = np.array(dates)
23 | return X, dates
24 |
25 |
26 | def get_log_returns(X):
27 | return np.diff(np.log(X))
28 |
29 |
30 | def standardize(X):
31 | mu = np.mean(X, axis=1)
32 | std = np.std(X, axis=1)
33 | mu = np.repeat(mu[:, np.newaxis], repeats=X.shape[1], axis=1)
34 | std = np.repeat(std[:, np.newaxis], repeats=X.shape[1], axis=1)
35 | return (X-mu)/std
36 |
37 |
38 | def generate_stock(args):
39 | df = pd.read_csv(args.stock_list_file)
40 | tickers = df['Symbol'].tolist()
41 |
42 | X, dates = process_stock_price(tickers, args.start_date, args.end_date)
43 | X = get_log_returns(X)
44 |
45 | X = standardize(X)
46 |
47 | D = X.shape[0]
48 | T = X.shape[1]
49 | L = args.chunk_size
50 | dat = np.zeros((D, T//L, L))
51 | date_array = []
52 |
53 | for i in range(T//L):
54 | dat[:, i] = X[:, L*i:(i+1)*L]
55 | date_array.append(dates[0][L*i])
56 | print(date_array[-1])
57 | date_array.append(args.end_date)
58 | print("Data shape:", dat.shape)
59 | dat = dat.transpose((1, 2, 0))
60 | if not os.path.exists(args.save_dir):
61 | os.makedirs(args.save_dir)
62 |
63 | np.save(os.path.join(args.save_dir, 'X.npy'), dat)
64 |
65 | for sector in df['Sector'].unique():
66 | X_sector = dat[:, :, df[df['Sector'] == sector].index.values]
67 | np.save(os.path.join(args.save_dir,
68 | f'X_{sector.replace(" ", "")}.npy'), X_sector)
69 |
70 | with open(os.path.join(args.save_dir, 'dates.csv'), 'w', encoding="utf-8") as f:
71 | for i in range(len(date_array)-1):
72 | f.write(f"{date_array[i]} to {date_array[i+1]}\n")
73 |
74 |
75 | if __name__ == '__main__':
76 | parser = argparse.ArgumentParser(
77 | "Stock data generator from Yahoo Financials")
78 | parser.add_argument("--start_date", type=str, default='2016-01-01')
79 | parser.add_argument("--end_date", type=str, default='2023-07-01')
80 | parser.add_argument('--chunk_size', type=int, default=31,
81 | help='Number of days to chunk together into one sample. Default is 31 days')
82 | parser.add_argument('--stock_list_file', type=str)
83 | parser.add_argument('--save_dir', type=str)
84 |
85 | args = parser.parse_args()
86 | generate_stock(args)
87 |
--------------------------------------------------------------------------------
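A small numeric check of the log-return and per-ticker standardization helpers above (illustrative prices; standalone numpy code mirroring get_log_returns and standardize):

    import numpy as np

    prices = np.array([[100.0, 110.0, 99.0],
                       [50.0, 45.0, 54.0]])   # (num_tickers, num_days)

    log_returns = np.diff(np.log(prices))     # what get_log_returns computes
    # first row: [ln(1.1), ln(0.9)] ~= [0.0953, -0.1054]

    mu = log_returns.mean(axis=1, keepdims=True)
    std = log_returns.std(axis=1, keepdims=True)
    standardized = (log_returns - mu) / std   # per-ticker standardization
    print(standardized.mean(axis=1))          # ~0 for each ticker
    print(standardized.std(axis=1))           # ~1 for each ticker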
/src/utils/data_gen/generate_synthetic_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import yaml
4 | from data_generation_utils import generate_cts_temporal_data, generate_name, set_random_seed
5 |
6 |
7 | def main(config_file):
8 |
9 | # read the yaml file
10 | with open(config_file, encoding="utf-8") as f:
11 | data_config = yaml.load(f, Loader=yaml.FullLoader)
12 |
13 | series_length = int(data_config["num_timesteps"])
14 | burnin_length = int(data_config["burnin_length"])
15 | num_samples = int(data_config["num_samples"])
16 | disable_inst = bool(data_config["disable_inst"])
17 | graph_type = [data_config["inst_graph_type"],
18 | data_config["lag_graph_type"]]
19 |
20 | connection_factor = 1
21 |
22 | # in this case, all inst and lagged adj are dags and have total edges per adj = 2*num_nodes
23 | # graph_type = ["SF", "SF"]
24 | # graph_config = [{"m":2, "directed":True}, {"m":2, "directed":True}]
25 |
26 | lag = int(data_config["lag"])
27 | is_history_dep = bool(data_config["history_dep_noise"])
28 |
29 | noise_level = float(data_config["noise_level"])
30 | function_type = data_config["function_type"]
31 | noise_function_type = data_config["noise_function_type"]
32 | base_noise_type = data_config["base_noise_type"]
33 |
34 | save_dir = data_config["save_dir"]
35 | num_nodes = data_config["num_nodes"]
36 | num_graphs = data_config["num_graphs"]
37 |
38 | for seed in data_config['random_seed']:
39 | seed = int(seed)
40 | set_random_seed(seed)
41 |
42 | for N in num_nodes:
43 | for N_G in num_graphs:
44 | print(f"Generating dataset for N={N}, num_graphs={N_G}")
45 | N = int(N)
46 | N_G = int(N_G)
47 | graph_config = [
48 | {"m": N * 2 * connection_factor if not disable_inst else 0,
49 | "directed": True},
50 | {"m": N * connection_factor, "directed": True},
51 | ]
52 |
53 | folder_name = generate_name(
54 | N,
55 | num_samples,
56 | graph_type,
57 | N_G,
58 | lag,
59 | is_history_dep,
60 | noise_level,
61 | function_type=function_type,
62 | noise_function_type=noise_function_type,
63 | disable_inst=disable_inst,
64 | seed=seed,
65 | connection_factor=connection_factor,
66 | base_noise_type=base_noise_type
67 | )
68 | path = os.path.join(save_dir, folder_name)
69 |
70 | generate_cts_temporal_data(
71 | path=path,
72 | num_graphs=N_G,
73 | series_length=series_length,
74 | burnin_length=burnin_length,
75 | num_samples=num_samples,
76 | num_nodes=N,
77 | graph_type=graph_type,
78 | graph_config=graph_config,
79 | lag=lag,
80 | is_history_dep=is_history_dep,
81 | noise_level=noise_level,
82 | function_type=function_type,
83 | noise_function_type=noise_function_type,
84 | base_noise_type=base_noise_type
85 | )
86 |
87 |
88 | if __name__ == "__main__":
89 | parser = argparse.ArgumentParser("Temporal Synthetic Data Generator")
90 | parser.add_argument("--config_file", type=str)
91 | args = parser.parse_args()
92 |
93 | main(config_file=args.config_file)
94 |
--------------------------------------------------------------------------------
/src/utils/data_gen/process_dream3.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import numpy as np
5 | import torch
6 | import pandas as pd
7 |
8 |
9 | def process_ts(ts, timepoints, N_subjects):
10 | N_nodes = ts.shape[1]
11 | X = np.zeros((N_subjects, timepoints, N_nodes))
12 |
13 | for i in range(N_subjects):
14 | X[i] = ts[i*timepoints: (i+1)*timepoints]
15 |
16 | return X
17 |
18 |
19 | def process_adj_matrix(net, size):
20 | A = np.zeros((size, size))
21 | for a, b, c in net.values:
22 | src = int(a[1:])-1
23 | dest = int(b[1:])-1
24 | A[src, dest] = 1
25 | return A
26 |
27 |
28 | def split_by_trajectory(X, A, T=21):
29 | time_len = X.shape[0]
30 | N = X.shape[1]
31 | num_samples = time_len // T
32 | data = np.zeros((num_samples, T, N))
33 | adj_matrix = np.zeros((num_samples, N, N))
34 | for i in range(num_samples):
35 | data[i] = X[i*T:(i+1)*T]
36 | adj_matrix[i] = A
37 |
38 | return data, adj_matrix
39 |
40 |
41 | def process_dream3(args):
42 |
43 | for size in [10, 50, 100]:
44 | X1 = torch.load(os.path.join(
45 | args.dataset_dir, 'Dream3TensorData', f'Size{size}Ecoli1.pt'))['TsData'].numpy()
46 | A1 = pd.read_table(os.path.join(
47 | args.dataset_dir, 'TrueGeneNetworks', f'InSilicoSize{size}-Ecoli1.tsv'), header=None)
48 | A1 = process_adj_matrix(A1, size)
49 | X1, A1 = split_by_trajectory(X1, A1)
50 |
51 | X2 = torch.load(os.path.join(
52 | args.dataset_dir, 'Dream3TensorData', f'Size{size}Ecoli2.pt'))['TsData'].numpy()
53 | A2 = pd.read_table(os.path.join(
54 | args.dataset_dir, 'TrueGeneNetworks', f'InSilicoSize{size}-Ecoli2.tsv'), header=None)
55 | A2 = process_adj_matrix(A2, size)
56 | X2, A2 = split_by_trajectory(X2, A2)
57 |
58 | if not os.path.exists(os.path.join(args.save_dir, f'ecoli_{size}', 'grouped_by_matrix')):
59 | os.makedirs(os.path.join(args.save_dir,
60 | f'ecoli_{size}', 'grouped_by_matrix'))
61 |
62 | np.savez(os.path.join(
63 | args.save_dir, f'ecoli_{size}', 'grouped_by_matrix', 'ecoli_1.npz'), X=X1, adj_matrix=A1)
64 | np.savez(os.path.join(
65 | args.save_dir, f'ecoli_{size}', 'grouped_by_matrix', 'ecoli_2.npz'), X=X2, adj_matrix=A2)
66 |
67 | X = np.concatenate((X1, X2), axis=0)
68 | A = np.concatenate((A1, A2), axis=0)
69 | np.savez(os.path.join(args.save_dir,
70 | f'ecoli_{size}', 'ecoli.npz'), X=X, adj_matrix=A)
71 |
72 | X11 = torch.load(os.path.join(
73 | args.dataset_dir, 'Dream3TensorData', f'Size{size}Yeast1.pt'))['TsData'].numpy()
74 | A11 = pd.read_table(os.path.join(
75 | args.dataset_dir, 'TrueGeneNetworks', f'InSilicoSize{size}-Yeast1.tsv'), header=None)
76 | A11 = process_adj_matrix(A11, size)
77 | X11, A11 = split_by_trajectory(X11, A11)
78 |
79 | X21 = torch.load(os.path.join(
80 | args.dataset_dir, 'Dream3TensorData', f'Size{size}Yeast2.pt'))['TsData'].numpy()
81 | A21 = pd.read_table(os.path.join(
82 | args.dataset_dir, 'TrueGeneNetworks', f'InSilicoSize{size}-Yeast2.tsv'), header=None)
83 | A21 = process_adj_matrix(A21, size)
84 | X21, A21 = split_by_trajectory(X21, A21)
85 |
86 | X31 = torch.load(os.path.join(
87 | args.dataset_dir, 'Dream3TensorData', f'Size{size}Yeast3.pt'))['TsData'].numpy()
88 | A31 = pd.read_table(os.path.join(
89 | args.dataset_dir, 'TrueGeneNetworks', f'InSilicoSize{size}-Yeast3.tsv'), header=None)
90 | A31 = process_adj_matrix(A31, size)
91 | X31, A31 = split_by_trajectory(X31, A31)
92 |
93 | if not os.path.exists(os.path.join(args.save_dir, f'yeast_{size}', 'grouped_by_matrix')):
94 | os.makedirs(os.path.join(args.save_dir,
95 | f'yeast_{size}', 'grouped_by_matrix'))
96 |
97 | np.savez(os.path.join(
98 | args.save_dir, f'yeast_{size}', 'grouped_by_matrix', 'yeast_1.npz'), X=X11, adj_matrix=A11)
99 | np.savez(os.path.join(
100 | args.save_dir, f'yeast_{size}', 'grouped_by_matrix', 'yeast_2.npz'), X=X21, adj_matrix=A21)
101 | np.savez(os.path.join(
102 | args.save_dir, f'yeast_{size}', 'grouped_by_matrix', 'yeast_3.npz'), X=X31, adj_matrix=A31)
103 | X = np.concatenate((X11, X21, X31), axis=0)
104 | A = np.concatenate((A11, A21, A31), axis=0)
105 | print(X.shape)
106 | print(A.shape)
107 | np.savez(os.path.join(args.save_dir,
108 | f'yeast_{size}', 'yeast.npz'), X=X, adj_matrix=A)
109 |
110 | # save combined
111 | X = np.concatenate((X1, X2, X11, X21, X31), axis=0)
112 | A = np.concatenate((A1, A2, A11, A21, A31), axis=0)
113 | print(X.shape)
114 | print(A.shape)
115 | np.savez(os.path.join(args.save_dir,
116 | f'combined_{size}.npz'), X=X, adj_matrix=A)
117 |
118 |
119 | if __name__ == "__main__":
120 | parser = argparse.ArgumentParser()
121 | parser.add_argument('--dataset_dir', type=str)
122 | parser.add_argument('--save_dir', type=str)
123 | args = parser.parse_args()
124 | process_dream3(args)
125 |
--------------------------------------------------------------------------------
/src/utils/data_gen/process_netsim.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import random
4 |
5 | import scipy.io as sio
6 | import argparse
7 | import numpy as np
8 |
9 |
10 | def process_ts(ts, timepoints, N_subjects):
11 | N_nodes = ts.shape[1]
12 | X = np.zeros((N_subjects, timepoints, N_nodes))
13 |
14 | for i in range(N_subjects):
15 | X[i] = ts[i*timepoints: (i+1)*timepoints]
16 |
17 | return X
18 |
19 |
20 | def process_adj_matrix(net):
21 | return (np.abs(np.swapaxes(net, 1, 2)) > 0).astype(int)
22 |
23 |
24 | def process_netsim(args):
25 | seed = 0
26 | simulations = range(1, 29)
27 | random.seed(seed)
28 | np.random.seed(seed)
29 | N_t_dict = {}
30 | X_dict_by_matrix = {}
31 |
32 | for i in simulations:
33 | mat = sio.loadmat(os.path.join(args.dataset_dir, f'sim{i}.mat'))
34 | timepoints = mat['Ntimepoints'][0, 0]
35 | N_subjects = mat['Nsubjects'][0, 0]
36 | ts = mat['ts']
37 | net = mat['net']
38 | N_nodes = ts.shape[1]
39 | X = process_ts(ts, timepoints, N_subjects)
40 | adj_matrix = process_adj_matrix(net)
41 |
42 | if (N_nodes, timepoints) not in N_t_dict:
43 | N_t_dict[(N_nodes, timepoints)] = {}
44 | N_t_dict[(N_nodes, timepoints)]['X'] = []
45 | N_t_dict[(N_nodes, timepoints)]['adj_matrix'] = []
46 |
47 | N_t_dict[(N_nodes, timepoints)]['X'].append(X)
48 | N_t_dict[(N_nodes, timepoints)]['adj_matrix'].append(adj_matrix)
49 |
50 | for j in range(X.shape[0]):
51 | if (adj_matrix[j].tobytes(), X.shape[1]) not in X_dict_by_matrix:
52 | X_dict_by_matrix[(adj_matrix[j].tobytes(), X.shape[1])] = []
53 | X_dict_by_matrix[(adj_matrix[j].tobytes(),
54 | X.shape[1])].append(X[j])
55 |
56 | if N_nodes == 15 and timepoints == 200:
57 | print("SIMULATION", i)
58 | for N_nodes, timepoints in N_t_dict:
59 | X = np.concatenate(N_t_dict[(N_nodes, timepoints)]['X'], axis=0)
60 | adj_matrix = np.concatenate(
61 | N_t_dict[(N_nodes, timepoints)]['adj_matrix'], axis=0)
62 | np.savez(os.path.join(
63 | args.save_dir, f'netsim_{N_nodes}_{timepoints}.npz'), X=X, adj_matrix=adj_matrix)
64 |
65 | counts = {}
66 | for key, T in X_dict_by_matrix:
67 | N = int(math.sqrt(len(np.frombuffer(key, dtype=int))))
68 | if (N, T) not in counts:
69 | counts[(N, T)] = 0
70 |
71 | adj_matrix = np.array(np.frombuffer(key, dtype=int)).reshape(N, N)
72 | if not os.path.exists(os.path.join(args.save_dir, 'grouped_by_matrix', f'{N}_{T}')):
73 | os.makedirs(os.path.join(args.save_dir,
74 | 'grouped_by_matrix', f'{N}_{T}'))
75 | X = np.array(X_dict_by_matrix[(key, T)])
76 | np.savez(os.path.join(args.save_dir, 'grouped_by_matrix',
77 | f'{N}_{T}', f'netsim_{counts[(N, T)]}.npz'), X=X, adj_matrix=adj_matrix)
78 | counts[(N, T)] += 1
79 |
80 | # add the permutations
81 | for N_nodes in [15, 50]:
82 | timepoints = 200
83 | num_graphs = 3
84 | permutation_pool = [np.random.permutation(
85 | np.arange(N_nodes)) for i in range(num_graphs)]
86 |
87 | X = []
88 | adj_matrix = []
89 |
90 | for i in range(len(N_t_dict[(N_nodes, timepoints)]['X'][0])):
91 | I = random.choice(permutation_pool)
92 | x = N_t_dict[(N_nodes, timepoints)]['X'][0][i]
93 | x = np.array(x)
94 | x = x[:, I]
95 |
96 | G = N_t_dict[(N_nodes, timepoints)]['adj_matrix'][0][i]
97 | G = G[I][:, I]
98 |
99 | X.append(x)
100 | adj_matrix.append(G)
101 |
102 | X = np.array(X)
103 | adj_matrix = np.array(adj_matrix)
104 | np.savez(os.path.join(
105 | args.save_dir, f'netsim_{N_nodes}_{timepoints}_permuted.npz'), X=X, adj_matrix=adj_matrix)
106 |
107 |
108 | if __name__ == "__main__":
109 | parser = argparse.ArgumentParser()
110 | parser.add_argument('--dataset_dir', type=str)
111 | parser.add_argument('--save_dir', type=str)
112 | args = parser.parse_args()
113 | process_netsim(args)
114 |
--------------------------------------------------------------------------------
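Likewise, a hypothetical invocation of the NetSim processing script above (sim1.mat through sim28.mat are the raw simulation files it expects; the directory names are placeholders):

    from argparse import Namespace
    from src.utils.data_gen.process_netsim import process_netsim

    # dataset_dir must contain sim1.mat ... sim28.mat; save_dir (and its
    # grouped_by_matrix/ subfolder) receives the processed .npz files.
    args = Namespace(dataset_dir='data/netsim_raw', save_dir='data/netsim')
    process_netsim(args)
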
/src/utils/data_gen/splines.py:
--------------------------------------------------------------------------------
1 | """
2 | Borrowed from github.com/microsoft/causica
3 | Needed by data_generation_utils.py
4 | """
5 |
6 | import numpy as np
7 | import torch
8 | from torch.nn import functional as F
9 |
10 | DEFAULT_MIN_BIN_WIDTH = 1e-3
11 | DEFAULT_MIN_BIN_HEIGHT = 1e-3
12 | DEFAULT_MIN_DERIVATIVE = 1e-3
13 |
14 |
15 | def searchsorted(bin_locations, inputs, eps=1e-6):
16 | bin_locations[..., -1] += eps
17 | return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
18 |
19 |
20 | def unconstrained_RQS(
21 | inputs,
22 | unnormalized_widths,
23 | unnormalized_heights,
24 | unnormalized_derivatives,
25 | inverse=False,
26 | tail_bound=1.0,
27 | min_bin_width=DEFAULT_MIN_BIN_WIDTH,
28 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
29 | min_derivative=DEFAULT_MIN_DERIVATIVE,
30 | ):
31 |
32 | inside_intvl_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
33 | outside_interval_mask = ~inside_intvl_mask
34 |
35 | outputs = torch.zeros_like(inputs)
36 | logabsdet = torch.zeros_like(inputs)
37 |
38 | unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
39 | constant = np.log(np.exp(1 - min_derivative) - 1)
40 | unnormalized_derivatives[..., 0] = constant
41 | unnormalized_derivatives[..., -1] = constant
42 |
43 | outputs[outside_interval_mask] = inputs[outside_interval_mask]
44 | logabsdet[outside_interval_mask] = 0
45 |
46 | if inside_intvl_mask.any():
47 | outputs[inside_intvl_mask], logabsdet[inside_intvl_mask] = RQS(
48 | inputs=inputs[inside_intvl_mask],
49 | unnormalized_widths=unnormalized_widths[inside_intvl_mask, :],
50 | unnormalized_heights=unnormalized_heights[inside_intvl_mask, :],
51 | unnormalized_derivatives=unnormalized_derivatives[inside_intvl_mask, :],
52 | inverse=inverse,
53 | left=-tail_bound,
54 | right=tail_bound,
55 | bottom=-tail_bound,
56 | top=tail_bound,
57 | min_bin_width=min_bin_width,
58 | min_bin_height=min_bin_height,
59 | min_derivative=min_derivative,
60 | )
61 | return outputs, logabsdet
62 |
63 |
64 | def RQS(
65 | inputs,
66 | unnormalized_widths,
67 | unnormalized_heights,
68 | unnormalized_derivatives,
69 | inverse=False,
70 | left=0.0,
71 | right=1.0,
72 | bottom=0.0,
73 | top=1.0,
74 | min_bin_width=DEFAULT_MIN_BIN_WIDTH,
75 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
76 | min_derivative=DEFAULT_MIN_DERIVATIVE,
77 | ):
78 | if torch.min(inputs) < left or torch.max(inputs) > right:
79 | raise ValueError("Input outside domain")
80 |
81 | num_bins = unnormalized_widths.shape[-1]
82 |
83 | if min_bin_width * num_bins > 1.0:
84 | raise ValueError("Minimal bin width too large for the number of bins")
85 | if min_bin_height * num_bins > 1.0:
86 | raise ValueError("Minimal bin height too large for the number of bins")
87 |
88 | widths = F.softmax(unnormalized_widths, dim=-1)
89 | widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
90 | cumwidths = torch.cumsum(widths, dim=-1)
91 | cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
92 | cumwidths = (right - left) * cumwidths + left
93 | cumwidths[..., 0] = left
94 | cumwidths[..., -1] = right
95 | widths = cumwidths[..., 1:] - cumwidths[..., :-1]
96 |
97 | derivatives = min_derivative + F.softplus(unnormalized_derivatives)
98 |
99 | heights = F.softmax(unnormalized_heights, dim=-1)
100 | heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
101 | cumheights = torch.cumsum(heights, dim=-1)
102 | cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
103 | cumheights = (top - bottom) * cumheights + bottom
104 | cumheights[..., 0] = bottom
105 | cumheights[..., -1] = top
106 | heights = cumheights[..., 1:] - cumheights[..., :-1]
107 |
108 | if inverse:
109 | bin_idx = searchsorted(cumheights, inputs)[..., None]
110 | else:
111 | bin_idx = searchsorted(cumwidths, inputs)[..., None]
112 |
113 | input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
114 | input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
115 |
116 | input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
117 | delta = heights / widths
118 | input_delta = delta.gather(-1, bin_idx)[..., 0]
119 |
120 | input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
121 | input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)
122 | input_derivatives_plus_one = input_derivatives_plus_one[..., 0]
123 |
124 | input_heights = heights.gather(-1, bin_idx)[..., 0]
125 |
126 | if inverse:
127 | a = (inputs - input_cumheights) * (
128 | input_derivatives + input_derivatives_plus_one - 2 * input_delta
129 | ) + input_heights * (input_delta - input_derivatives)
130 | b = input_heights * input_derivatives - (inputs - input_cumheights) * (
131 | input_derivatives + input_derivatives_plus_one - 2 * input_delta
132 | )
133 | c = -input_delta * (inputs - input_cumheights)
134 |
135 | discriminant = b.pow(2) - 4 * a * c
136 | assert (discriminant >= 0).all()
137 |
138 | root = (2 * c) / (-b - torch.sqrt(discriminant))
139 | outputs = root * input_bin_widths + input_cumwidths
140 |
141 | theta_one_minus_theta = root * (1 - root)
142 | denominator = input_delta + (
143 | (input_derivatives + input_derivatives_plus_one -
144 | 2 * input_delta) * theta_one_minus_theta
145 | )
146 | derivative_numerator = input_delta.pow(2) * (
147 | input_derivatives_plus_one * root.pow(2)
148 | + 2 * input_delta * theta_one_minus_theta
149 | + input_derivatives * (1 - root).pow(2)
150 | )
151 | logabsdet = torch.log(derivative_numerator) - \
152 | 2 * torch.log(denominator)
153 | return outputs, -logabsdet
154 |
155 | # else
156 | theta = (inputs - input_cumwidths) / input_bin_widths
157 | theta_one_minus_theta = theta * (1 - theta)
158 |
159 | numerator = input_heights * \
160 | (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta)
161 | denominator = input_delta + (
162 | (input_derivatives + input_derivatives_plus_one -
163 | 2 * input_delta) * theta_one_minus_theta
164 | )
165 | outputs = input_cumheights + numerator / denominator
166 |
167 | derivative_numerator = input_delta.pow(2) * (
168 | input_derivatives_plus_one * theta.pow(2)
169 | + 2 * input_delta * theta_one_minus_theta
170 | + input_derivatives * (1 - theta).pow(2)
171 | )
172 | logabsdet = torch.log(derivative_numerator) - \
173 | 2 * torch.log(denominator)
174 | return outputs, logabsdet
175 |
--------------------------------------------------------------------------------
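A small sanity-check sketch for unconstrained_RQS above: applying the forward transform and then the inverse should recover the inputs for points inside the tail bound (the batch size and bin count here are arbitrary):

    import torch
    from src.utils.data_gen.splines import unconstrained_RQS

    num_bins = 8
    x = torch.rand(5) * 2 - 1                   # points inside the default tail_bound of 1.0
    w = torch.randn(5, num_bins)                # unnormalized bin widths
    h = torch.randn(5, num_bins)                # unnormalized bin heights
    d = torch.randn(5, num_bins - 1)            # unnormalized derivatives at the interior knots

    y, logabsdet = unconstrained_RQS(x, w, h, d)
    x_rec, _ = unconstrained_RQS(y, w, h, d, inverse=True)
    print(torch.allclose(x, x_rec, atol=1e-4))  # expected: True, up to numerical error
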
/src/utils/data_utils/data_format_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | def convert_data_to_timelagged(X, lag):
6 | """
7 |     Converts data with shape (n_samples, timesteps, num_nodes, data_dim) into three arrays:
8 |     a history array of shape (n_fragments, lag, num_nodes, data_dim), an "input" array of shape
9 |     (n_fragments, num_nodes, data_dim), and an index array mapping each fragment to its source sample.
10 | """
11 | n_samples, timesteps, num_nodes, data_dim = X.shape
12 |
13 | n_fragments_per_sample = timesteps - lag
14 | n_fragments = n_samples * n_fragments_per_sample
15 | X_history = np.zeros((n_fragments, lag, num_nodes, data_dim))
16 | X_input = np.zeros((n_fragments, num_nodes, data_dim))
17 | X_indices = np.zeros((n_fragments))
18 | for i in range(n_fragments_per_sample):
19 | X_history[i*n_samples:(i+1)*n_samples] = X[:, i:i+lag, :, :]
20 | X_input[i*n_samples:(i+1)*n_samples] = X[:, i+lag, :, :]
21 | X_indices[i*n_samples:(i+1)*n_samples] = np.arange(n_samples)
22 | return X_history, X_input, X_indices
23 |
24 |
25 | def get_adj_matrix_id(A):
26 | return np.unique(A, axis=(0), return_inverse=True)
27 |
28 |
29 | def convert_adj_to_timelagged(A, lag, n_fragments, aggregated_graph=False, return_indices=True):
30 | """
31 |     Tiles adjacency matrices of shape (n_samples, lag+1, num_nodes, num_nodes), or (n_samples, num_nodes, num_nodes)
32 |     when aggregated_graph is True, across the n_fragments fragments; optionally also returns per-fragment matrix indices.
33 | """
34 |
35 | A_indices = np.zeros((n_fragments))
36 | if len(A.shape) == 4:
37 | n_samples, L, num_nodes, num_nodes = A.shape
38 | assert L == lag+1
39 | Ap = np.zeros((n_fragments, lag+1, num_nodes, num_nodes))
40 | elif aggregated_graph:
41 | n_samples, num_nodes, num_nodes = A.shape
42 | Ap = np.zeros((n_fragments, num_nodes, num_nodes))
43 | else:
44 | assert False, "invalid adjacency matrix"
45 | n_fragments_per_sample = n_fragments // n_samples
46 |
47 | _, matrix_indices = get_adj_matrix_id(A)
48 |
49 | for i in range(n_fragments_per_sample):
50 | Ap[i*n_samples:(i+1)*n_samples] = A
51 | A_indices[i*n_samples:(i+1)*n_samples] = matrix_indices
52 |
53 | if return_indices:
54 | return Ap, A_indices
55 | return Ap
56 |
57 |
58 | def to_time_aggregated_graph_np(graph):
59 | # convert graph of shape [batch, lag+1, num_nodes, num_nodes] to aggregated
60 | # graph of shape [batch, num_nodes, num_nodes]
61 | return (np.sum(graph, axis=1) > 0).astype(int)
62 |
63 |
64 | def to_time_aggregated_scores_np(graph):
65 | return np.max(graph, axis=1)
66 |
67 |
68 | def to_time_aggregated_graph_torch(graph):
69 | # convert graph of shape [batch, lag+1, num_nodes, num_nodes] to aggregated
70 | # graph of shape [batch, num_nodes, num_nodes]
71 | return (torch.sum(graph, dim=1) > 0).long()
72 |
73 |
74 | def to_time_aggregated_scores_torch(graph):
75 | # convert edge probability matrix of shape [batch, lag+1, num_nodes, num_nodes] to aggregated
76 | # matrix of shape [batch, num_nodes, num_nodes]
77 | max_val, _ = torch.max(graph, dim=1)
78 | return max_val
79 |
80 |
81 | def zero_out_diag_np(G):
82 |
83 | if len(G.shape) == 3:
84 | N = G.shape[1]
85 | I = np.arange(N)
86 | G[:, I, I] = 0
87 |
88 | elif len(G.shape) == 2:
89 | N = G.shape[0]
90 | I = np.arange(N)
91 | G[I, I] = 0
92 |
93 | return G
94 |
95 |
96 | def zero_out_diag_torch(G):
97 |
98 | if len(G.shape) == 3:
99 | N = G.shape[1]
100 | I = torch.arange(N)
101 | G[:, I, I] = 0
102 |
103 | elif len(G.shape) == 2:
104 | N = G.shape[0]
105 | I = torch.arange(N)
106 | G[I, I] = 0
107 |
108 | return G
109 |
--------------------------------------------------------------------------------
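A quick shape check for the fragmenting helpers above (sizes are arbitrary): with 10 samples of 50 timesteps and lag 2, each sample yields 48 fragments, so 480 fragments in total.

    import numpy as np
    from src.utils.data_utils.data_format_utils import (
        convert_adj_to_timelagged, convert_data_to_timelagged)

    n_samples, timesteps, num_nodes, data_dim, lag = 10, 50, 5, 1, 2
    X = np.random.randn(n_samples, timesteps, num_nodes, data_dim)
    A = np.random.randint(0, 2, size=(n_samples, lag + 1, num_nodes, num_nodes))

    X_history, X_input, X_indices = convert_data_to_timelagged(X, lag)
    Ap, A_indices = convert_adj_to_timelagged(A, lag, n_fragments=X_history.shape[0])
    print(X_history.shape, X_input.shape, Ap.shape)
    # (480, 2, 5, 1) (480, 5, 1) (480, 3, 5, 5)
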
/src/utils/data_utils/dataloading_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 |
4 |
5 | def get_dataset_path(dataset):
6 | if 'netsim' in dataset:
7 | dataset_path = 'netsim'
8 | elif 'dream3' in dataset:
9 | dataset_path = 'dream3'
10 | elif 'snp100' in dataset:
11 | dataset_path = 'snp100'
12 |     elif dataset in ['lorenz96', 'finance', 'fluxnet']:
13 | dataset_path = dataset
14 | else:
15 | dataset_path = 'synthetic'
16 |
17 | return dataset_path
18 |
19 |
20 | def create_save_name(dataset, cfg):
21 | if dataset == 'lorenz96':
22 | return f'lorenz96_N={cfg.num_nodes}_T={cfg.timesteps}_num_graphs={cfg.num_graphs}'
23 | return dataset
24 |
25 |
26 | def load_synthetic_from_folder(dataset_dir, dataset_name):
27 | X = np.load(os.path.join(dataset_dir, dataset_name, 'X.npy'))
28 | adj_matrix = np.load(os.path.join(
29 | dataset_dir, dataset_name, 'adj_matrix.npy'))
30 |
31 | return X, adj_matrix
32 |
33 |
34 | def load_netsim(dataset_dir, dataset_file):
35 | # load the files
36 | data = np.load(os.path.join(dataset_dir, dataset_file + '.npz'))
37 | X = data['X']
38 | adj_matrix = data['adj_matrix']
39 | # adj_matrix = np.transpose(adj_matrix, (0, 2, 1))
40 | return X, adj_matrix
41 |
42 |
43 | def load_dream3_combined(dataset_dir, size):
44 | data = np.load(os.path.join(dataset_dir, f'combined_{size}.npz'))
45 | X = data['X']
46 | adj_matrix = data['adj_matrix']
47 | return X, adj_matrix
48 |
49 |
50 | def load_snp100(dataset, dataset_dir):
51 | if dataset == 'snp100':
52 | X = np.load(os.path.join(dataset_dir, 'X.npy'))
53 | else:
54 | # get the sector
55 | sector = dataset.split('_')[1]
56 | X = np.load(os.path.join(dataset_dir, f'X_{sector}.npy'))
57 |
58 | D = X.shape[2]
59 | # we do not have the true adjacency matrix
60 | adj_matrix = np.zeros((X.shape[0], D, D))
61 | return X, adj_matrix
62 |
63 |
64 | def load_data(dataset, dataset_dir, config):
65 | if 'netsim' in dataset:
66 | X, adj_matrix = load_netsim(
67 | dataset_dir=dataset_dir, dataset_file=dataset)
68 | aggregated_graph = True
69 | # adj_matrix = np.transpose(adj_matrix, (0, 2, 1))
70 | # read lag from config file
71 | lag = int(config['lag'])
72 | data_dim = 1
73 | X = np.expand_dims(X, axis=-1)
74 | elif dataset == 'dream3':
75 | dream3_size = int(config['dream3_size'])
76 | X, adj_matrix = load_dream3_combined(
77 | dataset_dir=dataset_dir, size=dream3_size)
78 | lag = int(config['lag'])
79 | data_dim = 1
80 | aggregated_graph = True
81 | X = np.expand_dims(X, axis=-1)
82 | elif 'snp100' in dataset:
83 | X, adj_matrix = load_snp100(dataset=dataset, dataset_dir=dataset_dir)
84 | aggregated_graph = True
85 | lag = 1
86 | data_dim = 1
87 | X = np.expand_dims(X, axis=-1)
88 | else:
89 | X, adj_matrix = load_synthetic_from_folder(
90 | dataset_dir=dataset_dir, dataset_name=dataset)
91 | lag = adj_matrix.shape[1] - 1
92 | data_dim = 1
93 | X = np.expand_dims(X, axis=-1)
94 | aggregated_graph = False
95 | print("Loaded data of shape:", X.shape)
96 | return X, adj_matrix, aggregated_graph, lag, data_dim
97 |
--------------------------------------------------------------------------------
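A hedged sketch of loading one of the processed datasets with load_data above; the dataset name matches the .npz files written by process_netsim.py, while the data root and the lag value are placeholders:

    import os
    from src.utils.data_utils.dataloading_utils import get_dataset_path, load_data

    dataset = 'netsim_15_200'                                       # produced by process_netsim.py
    dataset_dir = os.path.join('data', get_dataset_path(dataset))   # hypothetical data root
    X, adj_matrix, aggregated_graph, lag, data_dim = load_data(
        dataset, dataset_dir, config={'lag': 2})
    print(X.shape, adj_matrix.shape, aggregated_graph, lag, data_dim)
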
/src/utils/loss_utils.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 |
4 | def temporal_graph_sparsity(G: torch.Tensor):
5 | """
6 | Args:
7 | G: (lag+1, num_nodes, num_nodes) or (batch, lag+1, num_nodes, num_nodes)
8 |
9 | Returns:
10 |         Scalar: the squared Frobenius norm of G, summed over the batch dimension if present
11 | """
12 |
13 | return torch.sum(torch.square(G))
14 |
15 |
16 | def l1_sparsity(G: torch.Tensor):
17 | return torch.sum(torch.abs(G))
18 |
19 |
20 | def dag_penalty_notears(G: torch.Tensor):
21 | """
22 | Implements the DAGness penalty from
23 | "DAGs with NO TEARS: Continuous Optimization for Structure Learning"
24 |
25 | Args:
26 | G: (num_nodes, num_nodes) or (num_graphs, num_nodes, num_nodes)
27 |
28 | """
29 | num_nodes = G.shape[-1]
30 |
31 | if len(G.shape) == 2:
32 | trace_term = torch.trace(torch.matrix_exp(G))
33 | return trace_term - num_nodes
34 | elif len(G.shape) == 3:
35 | trace_term = torch.einsum("ijj->i", torch.matrix_exp(G))
36 | return torch.sum(trace_term - num_nodes)
37 | assert False, "DAG Penalty received illegal shape"
38 |
39 |
40 | def dag_penalty_notears_sq(W: torch.Tensor):
41 | num_nodes = W.shape[-1]
42 |
43 | if len(W.shape) == 2:
44 | trace_term = torch.trace(torch.matrix_exp(W * W))
45 | return trace_term - num_nodes
46 | elif len(W.shape) == 3:
47 | trace_term = torch.einsum("ijj->i", torch.matrix_exp(W * W))
48 | return torch.sum(trace_term) - W.shape[0] * num_nodes
49 | assert False, "DAG Penalty received illegal shape"
50 |
--------------------------------------------------------------------------------
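A tiny illustration of dag_penalty_notears above: the trace-of-matrix-exponential penalty is zero for an acyclic graph and strictly positive once a cycle appears.

    import torch
    from src.utils.loss_utils import dag_penalty_notears

    G_dag = torch.tensor([[0., 1.], [0., 0.]])   # edge 0 -> 1 only: acyclic
    G_cyc = torch.tensor([[0., 1.], [1., 0.]])   # edges 0 <-> 1: a 2-cycle

    print(dag_penalty_notears(G_dag))            # tensor(0.)
    print(dag_penalty_notears(G_cyc))            # roughly tensor(1.0862), i.e. 2*cosh(1) - 2
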
/src/utils/metrics_utils.py:
--------------------------------------------------------------------------------
1 |
2 | from scipy.optimize import linear_sum_assignment
3 | from sklearn.metrics import confusion_matrix
4 | # metrics
5 | from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, precision_score, recall_score
6 | from cdt.metrics import SHD
7 | import numpy as np
8 | import torch
9 |
10 |
11 | def get_off_diagonal(A):
12 | # assumes A.shape: (batch, x, y)
13 | M = np.invert(np.eye(A.shape[1], dtype=bool))
14 | return A[:, M]
15 |
16 |
17 | def adjacency_f1(adj_matrix, predictions):
18 | # adj_matrix: (b, l, d, d) or (b, d, d)
19 | # predictions: (b, l, d, d) or (b, d, d)
20 | D = adj_matrix.shape[-1]
21 | L = np.tril_indices(D, k=-1)
22 | U = np.triu_indices(D, k=1)
23 |
24 | adj_lower = adj_matrix[..., L[0], L[1]]
25 | adj_upper = adj_matrix[..., U[0], U[1]]
26 |
27 | adj_diag = np.diagonal(adj_matrix, axis1=-2, axis2=-1).flatten()
28 | adj = np.concatenate((adj_diag, np.logical_or(
29 | adj_lower, adj_upper).flatten().astype(int)))
30 |
31 | pred_diag = np.diagonal(predictions, axis1=-2, axis2=-1).flatten()
32 | pred_lower = predictions[..., L[0], L[1]]
33 | pred_upper = predictions[..., U[0], U[1]]
34 | pred = np.concatenate((pred_diag, np.logical_or(
35 | pred_lower, pred_upper).flatten().astype(int)))
36 |
37 | return f1_score(adj, pred)
38 |
39 |
40 | def compute_shd(adj_matrix, preds, aggregated_graph=False):
41 | assert adj_matrix.shape == preds.shape, f"Dimension of adj_matrix {adj_matrix.shape} should match the predictions {preds.shape}"
42 |
43 | if not aggregated_graph:
44 | assert len(
45 | adj_matrix.shape) == 4, "Expects adj_matrix of shape (batch, lag+1, num_nodes, num_nodes)"
46 | assert len(
47 | preds.shape) == 4, "Expects preds of shape (batch, lag+1, num_nodes, num_nodes)"
48 | else:
49 | assert len(
50 | adj_matrix.shape) == 3, "Expects adj_matrix of shape (batch, num_nodes, num_nodes)"
51 | assert len(
52 | preds.shape) == 3, "Expects preds of shape (batch, num_nodes, num_nodes)"
53 |
54 | shd_score = 0
55 | if not aggregated_graph:
56 | shd_inst = 0
57 | shd_lag = 0
58 | for i in range(adj_matrix.shape[0]):
59 | for j in range(adj_matrix.shape[1]):
60 | adj_sub_matrix = adj_matrix[i, j]
61 | preds_sub_matrix = preds[i, j]
62 | shd = SHD(adj_sub_matrix, preds_sub_matrix)
63 | shd_score += shd
64 | if j == 0:
65 | shd_inst += shd
66 | else:
67 | shd_lag += shd
68 | return shd_score/adj_matrix.shape[0], shd_inst/adj_matrix.shape[0], shd_lag/adj_matrix.shape[0]
69 | for i in range(adj_matrix.shape[0]):
70 | adj_sub_matrix = adj_matrix[i]
71 | preds_sub_matrix = preds[i]
72 | shd_score += SHD(adj_sub_matrix, preds_sub_matrix)
73 | # print(SHD(adj_sub_matrix, preds_sub_matrix))
74 | return shd_score/adj_matrix.shape[0]
75 |
76 |
77 | def calculate_expected_shd(scores, adj_matrix, aggregated_graph=False, n_trials=100):
78 | totals_shd = 0
79 | for _ in range(n_trials):
80 | draw = np.random.binomial(1, scores)
81 | if aggregated_graph:
82 | shd = compute_shd(adj_matrix, draw,
83 | aggregated_graph=aggregated_graph)
84 | else:
85 | shd, _, _ = compute_shd(
86 | adj_matrix, draw, aggregated_graph=aggregated_graph)
87 | totals_shd += shd
88 |
89 | return totals_shd/n_trials
90 |
91 |
92 | def evaluate_results(scores,
93 | adj_matrix,
94 | predictions,
95 | aggregated_graph=False,
96 | true_cluster_indices=None,
97 | pred_cluster_indices=None):
98 |
99 | assert adj_matrix.shape == predictions.shape, "Dimension of adj_matrix should match the predictions"
100 |
101 | abs_scores = np.abs(scores).flatten()
102 | preds = np.abs(np.round(predictions))
103 | truth = adj_matrix.flatten()
104 |
105 | # calculate shd
106 |
107 | if aggregated_graph:
108 | shd_score = compute_shd(adj_matrix, preds, aggregated_graph)
109 | else:
110 | shd_score, shd_inst, shd_lag = compute_shd(
111 | adj_matrix, preds, aggregated_graph)
112 | f1_inst = f1_score(get_off_diagonal(adj_matrix[:, 0]).flatten(
113 | ), get_off_diagonal(predictions[:, 0]).flatten())
114 | f1_lag = f1_score(adj_matrix[:, 1:].flatten(), preds[:, 1:].flatten())
115 |
116 | f1 = f1_score(truth, preds.flatten())
117 | adj_f1 = adjacency_f1(adj_matrix, predictions)
118 |
119 | preds = preds.flatten()
120 | zero_edge_accuracy = np.sum(np.logical_and(
121 | preds == 0, truth == 0))/np.sum(truth == 0)
122 | one_edge_accuracy = np.sum(np.logical_and(
123 | preds == 1, truth == 1))/np.sum(truth == 1)
124 |
125 | accuracy = accuracy_score(truth, preds)
126 | precision = precision_score(truth, preds)
127 | recall = recall_score(truth, preds)
128 |
129 | try:
130 | rocauc = roc_auc_score(truth, abs_scores)
131 | except ValueError:
132 | rocauc = 0.5
133 |
134 | tnr = zero_edge_accuracy
135 | tpr = one_edge_accuracy
136 |
137 | print("Accuracy score:", accuracy)
138 | print("Orientation F1 score:", f1)
139 | print("Adjacency F1:", adj_f1)
140 | print("Precision score:", precision)
141 | print("Recall score:", recall)
142 | print("ROC AUC score:", rocauc)
143 |
144 | print("Accuracy on '0' edges", tnr)
145 | print("Accuracy on '1' edges", tpr)
146 | print("Structural Hamming Distance:", shd_score)
147 | if not aggregated_graph:
148 | print("Structural Hamming Distance (inst):", shd_inst)
149 | print("Structural Hamming Distance (lag):", shd_lag)
150 | print("Orientation F1 inst", f1_inst)
151 | print("Orientation F1 lag", f1_lag)
152 | eshd = calculate_expected_shd(
153 | np.abs(scores/(np.max(scores)+1e-4)), adj_matrix, aggregated_graph)
154 | print("Expected SHD:", eshd)
155 | # also return a dictionary of metrics
156 | metrics = {
157 | 'accuracy': accuracy,
158 | 'f1': f1,
159 | 'adj_f1': adj_f1,
160 | 'precision': precision,
161 | 'recall': recall,
162 | 'roc_auc': rocauc,
163 | 'tnr': tnr,
164 | 'tpr': tpr,
165 | 'shd_overall': shd_score,
166 | 'expected_shd': eshd
167 | }
168 |
169 | if not aggregated_graph:
170 | metrics['shd_inst'] = shd_inst
171 | metrics['shd_lag'] = shd_lag
172 |
173 | metrics['f1_inst'] = f1_inst
174 | metrics['f1_lag'] = f1_lag
175 |
176 | if pred_cluster_indices is not None and true_cluster_indices is not None:
177 | metrics['cluster_acc'] = cluster_accuracy(true_idx=true_cluster_indices,
178 | pred_idx=pred_cluster_indices)
179 | else:
180 | _, true_cluster_indices = np.unique(
181 | adj_matrix, return_inverse=True, axis=0)
182 | _, pred_cluster_indices = np.unique(
183 | predictions, return_inverse=True, axis=0)
184 | metrics['cluster_acc'] = cluster_accuracy(true_idx=true_cluster_indices,
185 | pred_idx=pred_cluster_indices)
186 |
187 | return metrics
188 |
189 |
190 | def mape_loss(X_true, x_pred):
191 | return torch.mean(torch.abs((X_true - x_pred) / X_true))*100
192 |
193 |
194 | def cluster_accuracy(true_idx, pred_idx):
195 |
196 | assert true_idx.shape == pred_idx.shape, "Shapes must match"
197 | # first get the confusion matrix
198 | cm = confusion_matrix(true_idx, pred_idx)
199 |
200 | # next run a linear sum assignment problem to obtain the maximum matching
201 | row_ind, col_ind = linear_sum_assignment(cm, maximize=True)
202 |
203 | # get the maximum matching
204 | return cm[row_ind, col_ind].sum()/cm.sum()
205 |
--------------------------------------------------------------------------------
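A short example of cluster_accuracy above: predicted cluster labels are matched to the true labels with a maximum-weight assignment, so a relabelled but otherwise identical clustering still scores 1.0.

    import numpy as np
    from src.utils.metrics_utils import cluster_accuracy

    true_idx = np.array([0, 0, 1, 1, 2, 2])
    pred_idx = np.array([1, 1, 0, 0, 2, 2])      # same partition, permuted labels
    print(cluster_accuracy(true_idx, pred_idx))  # 1.0
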
/src/utils/torch_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Borrowed from github.com/microsoft/causica
3 | """
4 |
5 | from typing import List, Optional, Type
6 | from torch.nn import Dropout, LayerNorm, Linear, Module, Sequential
7 | import torch
8 |
9 |
10 | class resBlock(Module):
11 | """
12 | Wraps an nn.Module, adding a skip connection to it.
13 | """
14 |
15 | def __init__(self, block: Module):
16 | """
17 | Args:
18 | block: module to which skip connection will be added. The input dimension must match the output dimension.
19 | """
20 | super().__init__()
21 | self.block = block
22 |
23 | def forward(self, x):
24 | return x + self.block(x)
25 |
26 |
27 | def generate_fully_connected(
28 | input_dim: int,
29 | output_dim: int,
30 | hidden_dims: List[int],
31 | non_linearity: Optional[Type[Module]],
32 | activation: Optional[Type[Module]],
33 | device: torch.device,
34 | p_dropout: float = 0.0,
35 | normalization: Optional[Type[LayerNorm]] = None,
36 | res_connection: bool = False,
37 | ) -> Module:
38 | """
39 | Generate a fully connected network.
40 |
41 | Args:
42 | input_dim: Int. Size of input to network.
43 | output_dim: Int. Size of output of network.
44 | hidden_dims: List of int. Sizes of internal hidden layers. i.e.
45 | [a, b] is three linear layers with shapes (input_dim, a), (a, b), (b, output_dim)
46 | non_linearity: Non linear activation function used between Linear layers.
47 | activation: Final layer activation to use.
48 | device: torch device to load weights to.
49 |         p_dropout: Float. Dropout probability at the hidden layers
50 |             (a value of 0 disables dropout).
51 | normalization: Normalisation layer to use (batchnorm, layer norm, etc). Will be placed before linear layers, excluding the input layer.
52 | res_connection : Whether to use residual connections where possible (if previous layer width matches next layer width)
53 |
54 | Returns:
55 | Sequential object containing the desired network.
56 | """
57 | layers: List[Module] = []
58 |
59 | prev_dim = input_dim
60 | for idx, hidden_dim in enumerate(hidden_dims):
61 |
62 | block: List[Module] = []
63 |
64 | if normalization is not None and idx > 0:
65 | block.append(normalization(prev_dim).to(device))
66 | block.append(Linear(prev_dim, hidden_dim).to(device))
67 |
68 | if non_linearity is not None:
69 | block.append(non_linearity())
70 | if p_dropout != 0:
71 | block.append(Dropout(p_dropout))
72 |
73 | if res_connection and (prev_dim == hidden_dim):
74 | layers.append(resBlock(Sequential(*block)))
75 | else:
76 | layers.append(Sequential(*block))
77 | prev_dim = hidden_dim
78 |
79 | if normalization is not None:
80 | layers.append(normalization(prev_dim).to(device))
81 | layers.append(Linear(prev_dim, output_dim).to(device))
82 |
83 | if activation is not None:
84 | layers.append(activation())
85 |
86 | fcnn = Sequential(*layers)
87 | return fcnn
88 |
--------------------------------------------------------------------------------
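A minimal sketch of generate_fully_connected above, with arbitrarily chosen dimensions; the residual connection only applies to the second hidden block, where the layer widths match.

    import torch
    from torch.nn import LayerNorm, ReLU
    from src.utils.torch_utils import generate_fully_connected

    net = generate_fully_connected(
        input_dim=16,
        output_dim=4,
        hidden_dims=[64, 64],
        non_linearity=ReLU,
        activation=None,                 # no final activation
        device=torch.device('cpu'),
        p_dropout=0.1,
        normalization=LayerNorm,
        res_connection=True,             # skip connection on the 64 -> 64 block
    )
    y = net(torch.randn(32, 16))         # y.shape == (32, 4)
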
/src/utils/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import csv
3 | import numpy as np
4 |
5 |
6 | def standard_scaling(X, across_samples=False):
7 | # expected X of shape (n_samples, timesteps, num_nodes, data_dim) or (n_samples, timesteps, num_nodes)
8 |
9 | if across_samples:
10 | means = np.mean(X, axis=(0, 1))[np.newaxis, np.newaxis, :]
11 | std = np.std(X, axis=(0, 1))[np.newaxis, np.newaxis, :]
12 | else:
13 | means = np.mean(X, axis=(1))[:, np.newaxis]
14 | std = np.std(X, axis=(1))[:, np.newaxis]
15 |
16 | eps = 1e-6
17 | Y = (X-means) / (std + eps)
18 |
19 | return Y
20 |
21 |
22 | def min_max_scaling(X, across_samples=False):
23 | # expected X of shape (n_samples, timesteps, num_nodes, data_dim) or (n_samples, timesteps, num_nodes)
24 |
25 | if across_samples:
26 | mins = np.amin(X, axis=(0, 1))[np.newaxis, np.newaxis, :]
27 | maxs = np.amax(X, axis=(0, 1))[np.newaxis, np.newaxis, :]
28 | else:
29 | mins = np.amin(X, axis=(1))[:, np.newaxis]
30 | maxs = np.amax(X, axis=(1))[:, np.newaxis]
31 |
32 | Y = (X-mins) / (maxs - mins) * 2 - 1
33 |
34 | return Y
35 |
36 |
37 | def write_results_to_disk(dataset, metrics):
38 | # write results to file
39 | results_dir = os.path.join('results', dataset)
40 | results_file = os.path.join(results_dir, 'results.csv')
41 | file_exists = os.path.isfile(results_file)
42 |
43 | if not os.path.exists(results_dir):
44 | os.makedirs(results_dir)
45 |
46 | with open(results_file, 'a', encoding="utf-8") as csvfile:
47 | writer = csv.DictWriter(csvfile, fieldnames=list(metrics.keys()))
48 | if not file_exists:
49 | writer.writeheader()
50 | writer.writerows([metrics])
51 |
--------------------------------------------------------------------------------