├── .github └── workflows │ ├── cloc.yml │ └── pylint.yml ├── README.md ├── assets ├── overview_fig.png ├── snp100.csv └── snp100_results.png ├── configs ├── data_generation │ ├── perturb_data.yaml │ ├── synthetic_data_linear.yaml │ └── synthetic_data_nonlinear.yaml ├── dream3 │ ├── dynotears.yaml │ ├── mcd.yaml │ ├── mcd_linear.yaml │ ├── pcmci.yaml │ ├── rhino.yaml │ ├── rhino_linear.yaml │ └── varlingam.yaml ├── main.yaml ├── netsim │ ├── dynotears.yaml │ ├── mcd.yaml │ ├── mcd_linear.yaml │ ├── pcmci.yaml │ ├── rhino.yaml │ ├── rhino_linear.yaml │ └── varlingam.yaml ├── snp100 │ ├── dynotears.yaml │ ├── mcd.yaml │ ├── mcd_linear.yaml │ ├── pcmci.yaml │ ├── rhino.yaml │ └── varlingam.yaml └── synthetic │ ├── dynotears.yaml │ ├── mcd.yaml │ ├── mcd_linear.yaml │ ├── pcmci.yaml │ ├── rhino.yaml │ ├── rhino_linear.yaml │ └── varlingam.yaml ├── scripts ├── generate_perturb_datasets.sh ├── generate_snp100.sh ├── generate_synthetic_datasets_linear.sh ├── generate_synthetic_datasets_nonlinear.sh ├── setup_dream3.sh └── setup_netsim.sh ├── setup.py └── src ├── __init__.py ├── baselines ├── BaselineTrainer.py ├── DYNOTEARSTrainer.py ├── PCMCITrainer.py ├── VARLiNGAMTrainer.py └── __init__.py ├── dataset ├── BaselineTSDataset.py └── FragmentDataset.py ├── model ├── BaseTrainer.py ├── MCDTrainer.py ├── RhinoTrainer.py └── generate_model.py ├── modules ├── CausalDecoder.py ├── LinearCausalGraph.py ├── MixtureSelectionLogits.py ├── MultiCausalDecoder.py ├── MultiEmbedding.py ├── MultiLinearCausalGraph.py ├── MultiTemporalHyperNet.py ├── TemporalConditionalSplineFlow.py ├── TemporalHyperNet.py └── adjacency_matrices │ ├── AdjMatrix.py │ ├── MultiTemporalAdjacencyMatrix.py │ ├── TemporalAdjacencyMatrix.py │ ├── ThreeWayGraphDist.py │ ├── TwoWayGraphDist.py │ └── TwoWayTemporalAdjacencyMatrix.py ├── train.py ├── training └── auglag.py └── utils ├── causality_utils.py ├── config_utils.py ├── data_gen ├── data_generation_utils.py ├── generate_perturb_syn.py ├── generate_stock.py ├── generate_synthetic_data.py ├── process_dream3.py ├── process_netsim.py └── splines.py ├── data_utils ├── data_format_utils.py └── dataloading_utils.py ├── loss_utils.py ├── metrics_utils.py ├── torch_utils.py └── utils.py /.github/workflows/cloc.yml: -------------------------------------------------------------------------------- 1 | name: Count Lines of Code 2 | 3 | # Controls when the action will run. 
Triggers the workflow on push or pull request 4 | # events but only for the main branch 5 | on: [pull_request] 6 | 7 | # A workflow run is made up of one or more jobs that can run sequentially or in parallel 8 | jobs: 9 | 10 | # This workflow contains a single job called "build" 11 | cloc: 12 | 13 | # The type of runner that the job will run on 14 | runs-on: ubuntu-latest 15 | 16 | # Steps represent a sequence of tasks that will be executed as part of the job 17 | steps: 18 | 19 | # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it 20 | - name: Checkout repo 21 | uses: actions/checkout@v3 22 | 23 | # Runs djdefi/cloc-action 24 | - name: Install and run cloc 25 | run: | 26 | sudo apt-get install cloc 27 | cloc src --csv --quiet --report-file=cloc_report.csv 28 | 29 | - name: Read CSV 30 | id: csv 31 | uses: juliangruber/read-file-action@v1 32 | with: 33 | path: ./cloc_report.csv 34 | 35 | - name: Create MD 36 | uses: petems/csv-to-md-table-action@master 37 | id: csv-table-output 38 | with: 39 | csvinput: ${{ steps.csv.outputs.content }} 40 | 41 | - name: Write file 42 | uses: "DamianReeves/write-file-action@master" 43 | with: 44 | path: ./cloc_report.md 45 | write-mode: overwrite 46 | contents: | 47 | ${{steps.csv-table-output.outputs.markdown-table}} 48 | 49 | - name: PR comment with file 50 | uses: thollander/actions-comment-pull-request@v2 51 | with: 52 | filePath: ./cloc_report.md 53 | -------------------------------------------------------------------------------- /.github/workflows/pylint.yml: -------------------------------------------------------------------------------- 1 | name: Pylint 2 | 3 | on: [push] 4 | 5 | jobs: 6 | build: 7 | runs-on: ubuntu-latest 8 | strategy: 9 | matrix: 10 | python-version: ["3.8", "3.9", "3.10"] 11 | steps: 12 | - uses: actions/checkout@v4 13 | - name: Set up Python ${{ matrix.python-version }} 14 | uses: actions/setup-python@v3 15 | with: 16 | python-version: ${{ matrix.python-version }} 17 | - name: Install dependencies 18 | run: | 19 | python -m pip install --upgrade pip 20 | pip install pylint 21 | - name: Analysing the code with pylint 22 | run: | 23 | pylint --disable=E402,E731,F541,W291,E122,E127,F401,E266,E241,C901,E741,W293,F811,W503,E203,F403,F405,B007,E0401,W0221 --max-line-length=150 $(git ls-files '*.py') -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Discovering Mixtures of Structural Causal Models from Time Series Data 2 | 3 | Implementation of the paper "Discovering Mixtures of Structural Causal Models from Time Series Data", to appear at ICML 2024, Vienna. 4 | 5 | Mixture Causal Discovery (MCD) aims to infer multiple causal graphs from time-series data. 6 | 7 |
8 | <img src="assets/overview_fig.png" alt="Overview"> 9 | <em>Model Overview</em> 10 | 11 | 12 | <img src="assets/snp100_results.png" alt="S&P100 Results"> 13 | <em>Results on the S&P100 dataset</em> 14 | 15 |
16 | 17 | 18 | ## Requirements 19 | 20 | - NVIDIA GPU with CUDA 11.8 or later installed. 21 | - Make sure you have `conda` installed. 22 | 23 | ## Setup 24 | 25 | Create a conda environment and install the prerequisite packages: 26 | ``` 27 | conda create -n mcd python=3.9 -y && \ 28 | conda run --no-capture-output -n mcd pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 && \ 29 | conda run --no-capture-output -n mcd pip3 install lightning matplotlib numpy scikit-learn seaborn \ 30 | cdt wandb igraph pyro-ppl hydra-core yahoofinancials 31 | ``` 32 | 33 | You also need the `graphviz` library, which can be installed on Ubuntu systems with: 34 | ``` 35 | sudo apt-get install -y git && \ 36 | sudo apt-get install -y graphviz graphviz-dev 37 | ``` 38 | 39 | For the baselines, create a separate environment: 40 | ``` 41 | conda create -n baselines python=3.8 -y && \ 42 | conda run --no-capture-output -n baselines pip3 install pygraphviz wandb tigramite hydra-core pyro-ppl lightning causalnex matplotlib cdt seaborn lingam 43 | ``` 44 | 45 | 46 | ## Dataset generation 47 | 48 | Generate the datasets using the script files in the `scripts/` folder. 49 | 50 | - Linear synthetic dataset: Run `./scripts/generate_synthetic_datasets_linear.sh` 51 | - Nonlinear synthetic dataset: Run `./scripts/generate_synthetic_datasets_nonlinear.sh` 52 | - Netsim datasets: Run `./scripts/setup_netsim.sh` 53 | - DREAM3: Run `./scripts/setup_dream3.sh` 54 | - S&P100: Run `./scripts/generate_snp100.sh` 55 | 56 | ## Running the code 57 | 58 | Change the name of the `wandb` project in `configs/main.yaml` (the `wandb_project` field). 59 | 60 | - Linear synthetic dataset: Run `python3 -m src.train +dataset=ER_ER_num_graphs_<num_graphs>_lag_2_dim_<num_nodes>_NoHistDep_0.5_linear_gaussian_con_1_seed_0_n_samples_1000 +synthetic=mcd_linear`. Change `<num_graphs>` and `<num_nodes>` to the desired setting. 61 | - Nonlinear synthetic dataset: Run `python3 -m src.train +dataset=ER_ER_num_graphs_<num_graphs>_lag_2_dim_<num_nodes>_HistDep_0.5_mlp_spline_product_con_1_seed_0_n_samples_1000 +synthetic=mcd`. Change `<num_graphs>` and `<num_nodes>` to the desired setting. 62 | - Netsim-mixture: Run `python3 -m src.train +dataset=netsim_15_200_permuted +netsim=mcd` 63 | - DREAM3: Run `python3 -m src.train +dataset=dream3 +dream3=mcd` 64 | - S&P100: Run `python3 -m src.train +dataset=snp100 +snp100=mcd` 65 | 66 | Results are stored in the `results/` folder. 67 | 68 | ## Acknowledgement 69 | 70 | We implemented some parts of our framework using code from [Project Causica](https://github.com/microsoft/causica). 71 | 72 | ## Citation 73 | 74 | If you find this work useful, please consider citing us.
75 | 76 | ``` 77 | @inproceedings{varambally2024discovering, 78 | author = {Varambally, Sumanth and Ma, Yi-An and Yu, Rose}, 79 | title = {Discovering Mixtures of Structural Causal Models from Time Series Data}, 80 | booktitle = {International Conference on Machine Learning, {ICML} 2024}, 81 | series = {Proceedings of Machine Learning Research}, 82 | year = {2024} 83 | } 84 | ``` 85 | -------------------------------------------------------------------------------- /assets/overview_fig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/MCD/5276806eb761545d2e99f53d5135289d7def20e3/assets/overview_fig.png -------------------------------------------------------------------------------- /assets/snp100.csv: -------------------------------------------------------------------------------- 1 | Symbol,Name,Sector 2 | AAPL,Apple,Technology 3 | ABBV,AbbVie,Healthcare 4 | ABT,Abbott,Healthcare 5 | ACN,Accenture,Technology 6 | ADBE,Adobe,Technology 7 | AIG,American International Group,Financials 8 | AMD,AMD,Technology 9 | AMGN,Amgen,Healthcare 10 | AMT,American Tower,Real Estate 11 | AMZN,Amazon,Consumer Discretionary 12 | AVGO,Broadcom,Technology 13 | AXP,American Express,Financials 14 | BA,Boeing,Industrials 15 | BAC,Bank of America,Financials 16 | BK,BNY Mellon,Financials 17 | BKNG,Booking Holdings,Consumer Discretionary 18 | BLK,BlackRock,Financials 19 | BMY,Bristol Myers Squibb,Healthcare 20 | BRK-B,Berkshire Hathaway (Class B),Financials 21 | C,Citigroup,Financials 22 | CAT,Caterpillar,Industrials 23 | CHTR,Charter Communications,Communication Services 24 | CL,Colgate-Palmolive,Consumer Staples 25 | CMCSA,Comcast,Communication Services 26 | COF,Capital One,Financials 27 | COP,ConocoPhillips,Energy 28 | COST,Costco,Consumer Staples 29 | CRM,Salesforce,Technology 30 | CSCO,Cisco,Technology 31 | CVS,CVS Health,Healthcare 32 | CVX,Chevron,Energy 33 | DE,Deere & Company,Industrials 34 | DHR,Danaher,Healthcare 35 | DIS,Disney,Communication Services 36 | DUK,Duke Energy,Utilities 37 | EMR,Emerson,Industrials 38 | EXC,Exelon,Utilities 39 | F,Ford,Consumer Discretionary 40 | FDX,FedEx,Industrials 41 | GD,General Dynamics,Industrials 42 | GE,GE,Industrials 43 | GILD,Gilead,Healthcare 44 | GM,GM,Consumer Discretionary 45 | GOOG,Alphabet (Class C),Communication Services 46 | GOOGL,Alphabet (Class A),Communication Services 47 | GS,Goldman Sachs,Financials 48 | HD,Home Depot,Consumer Discretionary 49 | HON,Honeywell,Industrials 50 | IBM,IBM,Technology 51 | INTC,Intel,Technology 52 | JNJ,Johnson & Johnson,Healthcare 53 | JPM,JPMorgan Chase,Financials 54 | KHC,Kraft Heinz,Consumer Staples 55 | KO,Coca-Cola,Consumer Staples 56 | LIN,Linde,Materials 57 | LLY,Lilly,Healthcare 58 | LMT,Lockheed Martin,Industrials 59 | LOW,Lowe's,Consumer Discretionary 60 | MA,Mastercard,Technology 61 | MCD,McDonald's,Consumer Discretionary 62 | MDLZ,Mondelēz International,Consumer Staples 63 | MDT,Medtronic,Healthcare 64 | MET,MetLife,Financials 65 | META,Meta,Communication Services 66 | MMM,3M,Industrials 67 | MO,Altria,Consumer Staples 68 | MRK,Merck,Healthcare 69 | MS,Morgan Stanley,Financials 70 | MSFT,Microsoft,Technology 71 | NEE,NextEra Energy,Utilities 72 | NFLX,Netflix,Communication Services 73 | NKE,Nike,Consumer Discretionary 74 | NVDA,Nvidia,Technology 75 | ORCL,Oracle,Technology 76 | PEP,PepsiCo,Consumer Staples 77 | PFE,Pfizer,Healthcare 78 | PG,Procter & Gamble,Consumer Staples 79 | PM,Philip Morris International,Consumer Staples 80 | 
PYPL,PayPal,Technology 81 | QCOM,Qualcomm,Technology 82 | RTX,RTX Corporation,Industrials 83 | SBUX,Starbucks,Consumer Discretionary 84 | SCHW,Charles Schwab,Financials 85 | SO,Southern Company,Utilities 86 | SPG,Simon,Real Estate 87 | T,AT&T,Communication Services 88 | TGT,Target,Consumer Discretionary 89 | TMO,Thermo Fisher Scientific,Healthcare 90 | TMUS,T-Mobile,Communication Services 91 | TSLA,Tesla,Consumer Discretionary 92 | TXN,Texas Instruments,Technology 93 | UNH,UnitedHealth Group,Healthcare 94 | UNP,Union Pacific,Industrials 95 | UPS,United Parcel Service,Industrials 96 | USB,U.S. Bank,Financials 97 | V,Visa,Technology 98 | VZ,Verizon,Communication Services 99 | WFC,Wells Fargo,Financials 100 | WMT,Walmart,Consumer Staples 101 | XOM,ExxonMobil,Energy -------------------------------------------------------------------------------- /assets/snp100_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/MCD/5276806eb761545d2e99f53d5135289d7def20e3/assets/snp100_results.png -------------------------------------------------------------------------------- /configs/data_generation/perturb_data.yaml: -------------------------------------------------------------------------------- 1 | 2 | # graph settings 3 | num_timesteps: 100 4 | lag: 2 5 | history_dep_noise : true 6 | burnin_length: 100 7 | noise_level: 0.5 8 | function_type: mlp # options: mlp, spline, spline_product, conditional_spline, mlp_noise, inverse_noise_spline 9 | noise_function_type: spline_product # options: same as function_type 10 | base_noise_type: gaussian # options: gaussian, uniform 11 | num_samples: 1000 12 | 13 | disable_inst: false # whether to exclude instantaneous edges 14 | inst_graph_type: ER # options: SF, ER 15 | lag_graph_type: ER # options: SF, ER 16 | save_dir: "data/synthetic" 17 | 18 | random_seed: 19 | - 0 20 | 21 | p_array: 22 | - 0.005 23 | - 0.008 24 | - 0.01 25 | - 0.05 26 | - 0.1 27 | 28 | num_nodes: 10 29 | num_graphs: 30 | - 5 31 | - 10 -------------------------------------------------------------------------------- /configs/data_generation/synthetic_data_linear.yaml: -------------------------------------------------------------------------------- 1 | 2 | # graph settings 3 | num_timesteps: 100 4 | lag: 2 5 | history_dep_noise : false 6 | burnin_length: 100 7 | noise_level: 0.5 8 | function_type: linear # options: mlp, spline, linear, spline_product, conditional_spline, mlp_noise, inverse_noise_spline 9 | noise_function_type: spline_product # options: same as function_type 10 | base_noise_type: gaussian # options: gaussian, uniform 11 | num_samples: 1000 12 | 13 | disable_inst: false # whether to exclude instantaneous edges 14 | inst_graph_type: ER # options: SF, ER 15 | lag_graph_type: ER # options: SF, ER 16 | save_dir: "data/synthetic" 17 | 18 | random_seed: 19 | - 0 20 | # - 1 21 | # - 2 22 | # - 3 23 | # - 4 24 | # - 5 25 | 26 | num_nodes: 27 | - 5 28 | - 10 29 | - 20 30 | - 40 31 | 32 | num_graphs: 33 | - 1 34 | - 5 35 | - 10 36 | - 20 37 | # - 40 38 | # - 60 39 | # - 80 40 | # - -------------------------------------------------------------------------------- /configs/data_generation/synthetic_data_nonlinear.yaml: -------------------------------------------------------------------------------- 1 | 2 | # graph settings 3 | num_timesteps: 100 4 | lag: 2 5 | history_dep_noise : true 6 | burnin_length: 100 7 | noise_level: 0.5 8 | function_type: mlp # options: mlp, spline, linear, spline_product, 
conditional_spline, mlp_noise, inverse_noise_spline 9 | noise_function_type: spline_product # options: same as function_type 10 | base_noise_type: gaussian # options: gaussian, uniform 11 | num_samples: 1000 12 | 13 | disable_inst: false # whether to exclude instantaneous edges 14 | inst_graph_type: ER # options: SF, ER 15 | lag_graph_type: ER # options: SF, ER 16 | save_dir: "data/synthetic" 17 | 18 | random_seed: 19 | - 0 20 | # - 1 21 | # - 2 22 | # - 3 23 | # - 4 24 | # - 5 25 | 26 | num_nodes: 27 | - 5 28 | - 10 29 | - 20 30 | - 40 31 | 32 | num_graphs: 33 | - 1 34 | - 5 35 | - 10 36 | - 20 37 | # - 40 38 | # - 60 39 | # - 80 40 | # - -------------------------------------------------------------------------------- /configs/dream3/dynotears.yaml: -------------------------------------------------------------------------------- 1 | model: dynotears 2 | lag: 2 3 | dream3_size: 100 4 | 5 | trainer: 6 | _target_: src.baselines.DYNOTEARSTrainer.DYNOTEARSTrainer 7 | _partial_: true 8 | 9 | single_graph: false 10 | max_iter: 1000 11 | lambda_w: 0.1 12 | lambda_a: 0.1 13 | w_threshold: 0.01 14 | h_tol: 1e-8 15 | 16 | group_by_graph: true -------------------------------------------------------------------------------- /configs/dream3/mcd.yaml: -------------------------------------------------------------------------------- 1 | # hyperparameters 2 | watch_gradients: false 3 | num_epochs: 10000 4 | model: mcd 5 | lag: 2 6 | dream3_size: 100 7 | monitor_checkpoint_based_on: likelihood 8 | 9 | hypernet: 10 | _target_: src.modules.MultiTemporalHyperNet.MultiTemporalHyperNet 11 | order: linear 12 | num_bins: 8 13 | skip_connection: true 14 | embedding_dim: 32 15 | dropout_p: 0.2 16 | 17 | # decoder options 18 | decoder: 19 | _target_: src.modules.MultiCausalDecoder.MultiCausalDecoder 20 | skip_connection: true 21 | embedding_dim: 32 22 | dropout_p: 0.2 23 | 24 | trainer: 25 | _target_: src.model.MCDTrainer.MCDTrainer 26 | _partial_: true 27 | 28 | batch_size: 64 29 | sparsity_factor: 10 30 | likelihood_loss: flow 31 | matrix_temperature: 0.25 32 | threeway_graph_dist: true 33 | training_procedure: auglag 34 | skip_auglag_epochs: 0 35 | init_logits: [-3, -3] 36 | 37 | # multi-rhino settings 38 | num_graphs: 10 39 | use_correct_mixture_index: false 40 | use_true_graph: false 41 | log_num_unique_graphs: false 42 | disable_inst: true 43 | 44 | auglag_config: 45 | _target_: src.training.auglag.AugLagLRConfig 46 | 47 | lr_update_lag: 500 48 | lr_update_lag_best: 500 49 | lr_init_dict: 50 | vardist: 0.001 51 | functional_relationships: 0.001 52 | noise_dist: 0.001 53 | mixing_probs: 0.01 54 | 55 | aggregation_period: 20 56 | lr_factor: 0.1 57 | max_lr_down: 3 58 | penalty_progress_rate: 0.65 59 | safety_rho: 1e13 60 | safety_alpha: 1e13 61 | inner_early_stopping_patience: 1500 62 | max_outer_steps: 60 63 | patience_penalty_reached: 10 64 | patience_max_rho: 3 65 | penalty_tolerance: -1 66 | max_inner_steps: 6000 67 | force_not_converged: false 68 | init_rho: 1 69 | init_alpha: 0 -------------------------------------------------------------------------------- /configs/dream3/mcd_linear.yaml: -------------------------------------------------------------------------------- 1 | # hyperparameters 2 | watch_gradients: false 3 | num_epochs: 10000 4 | model: mcd 5 | lag: 2 6 | dream3_size: 100 7 | monitor_checkpoint_based_on: likelihood 8 | 9 | # decoder options 10 | decoder: 11 | _target_: src.modules.MultiCausalDecoder.MultiCausalDecoder 12 | linear: true 13 | embedding_dim: 32 14 | 15 | trainer: 16 | _target_: 
src.model.MCDTrainer.MCDTrainer 17 | _partial_: true 18 | 19 | batch_size: 64 20 | sparsity_factor: 25 21 | likelihood_loss: mse 22 | matrix_temperature: 0.25 23 | threeway_graph_dist: true 24 | training_procedure: auglag 25 | skip_auglag_epochs: 0 26 | init_logits: [-3, -3] 27 | 28 | # multi-rhino settings 29 | num_graphs: 10 30 | use_correct_mixture_index: false 31 | use_true_graph: false 32 | log_num_unique_graphs: false 33 | disable_inst: true 34 | 35 | auglag_config: 36 | _target_: src.training.auglag.AugLagLRConfig 37 | 38 | lr_update_lag: 500 39 | lr_update_lag_best: 500 40 | lr_init_dict: 41 | vardist: 0.001 42 | functional_relationships: 0.001 43 | noise_dist: 0.001 44 | mixing_probs: 0.01 45 | 46 | aggregation_period: 20 47 | lr_factor: 0.1 48 | max_lr_down: 3 49 | penalty_progress_rate: 0.65 50 | safety_rho: 1e13 51 | safety_alpha: 1e13 52 | inner_early_stopping_patience: 1500 53 | max_outer_steps: 60 54 | patience_penalty_reached: 10 55 | patience_max_rho: 3 56 | penalty_tolerance: -1 57 | max_inner_steps: 6000 58 | force_not_converged: false 59 | init_rho: 1 60 | init_alpha: 0 -------------------------------------------------------------------------------- /configs/dream3/pcmci.yaml: -------------------------------------------------------------------------------- 1 | model: pcmci 2 | lag: 2 3 | dream3_size: 100 4 | 5 | trainer: 6 | _target_: src.baselines.PCMCITrainer.PCMCITrainer 7 | _partial_: true 8 | 9 | ci_test: ParCorr 10 | single_graph: true 11 | pcmci_plus: true 12 | pc_alpha: 0.01 13 | 14 | group_by_graph: false -------------------------------------------------------------------------------- /configs/dream3/rhino.yaml: -------------------------------------------------------------------------------- 1 | # hyperparameters 2 | watch_gradients: false 3 | num_epochs: 10000 4 | model: rhino 5 | lag: 2 6 | dream3_size: 100 7 | monitor_checkpoint_based_on: likelihood 8 | 9 | hypernet: 10 | _target_: src.modules.TemporalHyperNet.TemporalHyperNet 11 | order: linear 12 | num_bins: 8 13 | skip_connection: true 14 | embedding_dim: 32 15 | 16 | # decoder options 17 | decoder: 18 | _target_: src.modules.CausalDecoder.CausalDecoder 19 | skip_connection: true 20 | embedding_dim: 32 21 | 22 | trainer: 23 | _target_: src.model.RhinoTrainer.RhinoTrainer 24 | _partial_: true 25 | 26 | batch_size: 64 27 | sparsity_factor: 25 28 | likelihood_loss: flow 29 | matrix_temperature: 0.25 30 | threeway_graph_dist: true 31 | training_procedure: auglag 32 | skip_auglag_epochs: 0 33 | init_logits: [-3, -3] 34 | disable_inst: true 35 | 36 | auglag_config: 37 | _target_: src.training.auglag.AugLagLRConfig 38 | 39 | lr_update_lag: 500 40 | lr_update_lag_best: 500 41 | lr_init_dict: 42 | vardist: 0.001 43 | functional_relationships: 0.001 44 | noise_dist: 0.001 45 | 46 | aggregation_period: 20 47 | lr_factor: 0.1 48 | max_lr_down: 3 49 | penalty_progress_rate: 0.65 50 | safety_rho: 1e13 51 | safety_alpha: 1e13 52 | inner_early_stopping_patience: 1500 53 | max_outer_steps: 60 54 | patience_penalty_reached: 10 55 | patience_max_rho: 3 56 | penalty_tolerance: -1 57 | max_inner_steps: 6000 58 | force_not_converged: false 59 | init_rho: 1 60 | init_alpha: 0 -------------------------------------------------------------------------------- /configs/dream3/rhino_linear.yaml: -------------------------------------------------------------------------------- 1 | # hyperparameters 2 | watch_gradients: false 3 | num_epochs: 10000 4 | model: rhino 5 | lag: 2 6 | dream3_size: 100 7 | monitor_checkpoint_based_on: 
likelihood 8 | 9 | # decoder options 10 | decoder: 11 | _target_: src.modules.CausalDecoder.CausalDecoder 12 | linear: true 13 | 14 | trainer: 15 | _target_: src.model.RhinoTrainer.RhinoTrainer 16 | _partial_: true 17 | 18 | batch_size: 64 19 | sparsity_factor: 25 20 | likelihood_loss: mse 21 | matrix_temperature: 0.25 22 | threeway_graph_dist: true 23 | training_procedure: auglag 24 | skip_auglag_epochs: 0 25 | init_logits: [-3, -3] 26 | disable_inst: true 27 | 28 | auglag_config: 29 | _target_: src.training.auglag.AugLagLRConfig 30 | 31 | lr_update_lag: 500 32 | lr_update_lag_best: 500 33 | lr_init_dict: 34 | vardist: 0.001 35 | functional_relationships: 0.001 36 | noise_dist: 0.001 37 | 38 | aggregation_period: 20 39 | lr_factor: 0.1 40 | max_lr_down: 3 41 | penalty_progress_rate: 0.65 42 | safety_rho: 1e13 43 | safety_alpha: 1e13 44 | inner_early_stopping_patience: 1500 45 | max_outer_steps: 60 46 | patience_penalty_reached: 10 47 | patience_max_rho: 3 48 | penalty_tolerance: -1 49 | max_inner_steps: 6000 50 | force_not_converged: false 51 | init_rho: 1 52 | init_alpha: 0 -------------------------------------------------------------------------------- /configs/dream3/varlingam.yaml: -------------------------------------------------------------------------------- 1 | model: varlingam 2 | lag: 2 3 | dream3_size: 100 4 | 5 | trainer: 6 | _target_: src.baselines.VARLiNGAMTrainer.VARLiNGAMTrainer 7 | _partial_: true -------------------------------------------------------------------------------- /configs/main.yaml: -------------------------------------------------------------------------------- 1 | 2 | precision: 32 3 | wandb_project: causal_mixture_model 4 | gpu: [0] 5 | num_workers: 8 6 | dataset_dir: 'data/' 7 | random_seed: 0 8 | -------------------------------------------------------------------------------- /configs/netsim/dynotears.yaml: -------------------------------------------------------------------------------- 1 | model: dynotears 2 | 3 | lag: 2 4 | trainer: 5 | _target_: src.baselines.DYNOTEARSTrainer.DYNOTEARSTrainer 6 | _partial_: true 7 | 8 | single_graph: true 9 | max_iter: 1000 10 | lambda_w: 0.1 11 | lambda_a: 0.1 12 | w_threshold: 0.01 13 | h_tol: 1e-8 14 | 15 | group_by_graph: false -------------------------------------------------------------------------------- /configs/netsim/mcd.yaml: -------------------------------------------------------------------------------- 1 | # hyperparameters 2 | watch_gradients: false 3 | num_epochs: 10000 4 | model: mcd 5 | lag: 2 6 | monitor_checkpoint_based_on: likelihood 7 | 8 | hypernet: 9 | _target_: src.modules.MultiTemporalHyperNet.MultiTemporalHyperNet 10 | order: linear 11 | num_bins: 8 12 | skip_connection: true 13 | 14 | # decoder options 15 | decoder: 16 | _target_: src.modules.MultiCausalDecoder.MultiCausalDecoder 17 | skip_connection: true 18 | 19 | trainer: 20 | _target_: src.model.MCDTrainer.MCDTrainer 21 | _partial_: true 22 | 23 | batch_size: 64 24 | sparsity_factor: 25 25 | likelihood_loss: flow 26 | matrix_temperature: 0.25 27 | threeway_graph_dist: true 28 | training_procedure: auglag 29 | skip_auglag_epochs: 0 30 | graph_selection_prior_lambda: 0.0 31 | 32 | # multi-rhino settings 33 | num_graphs: 20 34 | use_correct_mixture_index: false 35 | use_true_graph: false 36 | log_num_unique_graphs: false 37 | 38 | auglag_config: 39 | _target_: src.training.auglag.AugLagLRConfig 40 | 41 | lr_update_lag: 500 42 | lr_update_lag_best: 500 43 | lr_init_dict: 44 | vardist: 0.01 45 | functional_relationships: 0.001 46 | 
noise_dist: 0.001 47 | mixing_probs: 0.01 48 | 49 | aggregation_period: 20 50 | lr_factor: 0.1 51 | max_lr_down: 3 52 | penalty_progress_rate: 0.65 53 | safety_rho: 1e13 54 | safety_alpha: 1e13 55 | inner_early_stopping_patience: 1500 56 | max_outer_steps: 60 57 | patience_penalty_reached: 100 58 | patience_max_rho: 50 59 | penalty_tolerance: 1e-8 60 | max_inner_steps: 2000 61 | force_not_converged: false 62 | init_rho: 1 63 | init_alpha: 0 -------------------------------------------------------------------------------- /configs/netsim/mcd_linear.yaml: -------------------------------------------------------------------------------- 1 | # hyperparameters 2 | watch_gradients: false 3 | num_epochs: 10000 4 | model: mcd 5 | lag: 2 6 | monitor_checkpoint_based_on: likelihood 7 | 8 | # decoder options 9 | decoder: 10 | _target_: src.modules.MultiCausalDecoder.MultiCausalDecoder 11 | linear: true 12 | 13 | trainer: 14 | _target_: src.model.MCDTrainer.MCDTrainer 15 | _partial_: true 16 | 17 | batch_size: 64 18 | sparsity_factor: 25 19 | likelihood_loss: mse 20 | matrix_temperature: 0.25 21 | threeway_graph_dist: true 22 | training_procedure: auglag 23 | skip_auglag_epochs: 0 24 | graph_selection_prior_lambda: 0.0 25 | 26 | # multi-rhino settings 27 | num_graphs: 10 28 | use_correct_mixture_index: false 29 | use_true_graph: false 30 | log_num_unique_graphs: false 31 | 32 | auglag_config: 33 | _target_: src.training.auglag.AugLagLRConfig 34 | 35 | lr_update_lag: 500 36 | lr_update_lag_best: 500 37 | lr_init_dict: 38 | vardist: 0.01 39 | functional_relationships: 0.001 40 | noise_dist: 0.001 41 | mixing_probs: 0.01 42 | 43 | aggregation_period: 20 44 | lr_factor: 0.1 45 | max_lr_down: 3 46 | penalty_progress_rate: 0.65 47 | safety_rho: 1e13 48 | safety_alpha: 1e13 49 | inner_early_stopping_patience: 1500 50 | max_outer_steps: 60 51 | patience_penalty_reached: 100 52 | patience_max_rho: 50 53 | penalty_tolerance: 1e-8 54 | max_inner_steps: 2000 55 | force_not_converged: false 56 | init_rho: 1 57 | init_alpha: 0 -------------------------------------------------------------------------------- /configs/netsim/pcmci.yaml: -------------------------------------------------------------------------------- 1 | model: pcmci 2 | lag: 2 3 | 4 | trainer: 5 | _target_: src.baselines.PCMCITrainer.PCMCITrainer 6 | _partial_: true 7 | 8 | ci_test: ParCorr 9 | single_graph: false 10 | pcmci_plus: true 11 | pc_alpha: 0.01 12 | 13 | group_by_graph: true -------------------------------------------------------------------------------- /configs/netsim/rhino.yaml: -------------------------------------------------------------------------------- 1 | # hyperparameters 2 | watch_gradients: false 3 | num_epochs: 10000 4 | model: rhino 5 | monitor_checkpoint_based_on: likelihood 6 | lag: 2 7 | 8 | hypernet: 9 | _target_: src.modules.TemporalHyperNet.TemporalHyperNet 10 | order: linear 11 | num_bins: 8 12 | skip_connection: true 13 | 14 | # decoder options 15 | decoder: 16 | _target_: src.modules.CausalDecoder.CausalDecoder 17 | skip_connection: true 18 | 19 | trainer: 20 | _target_: src.model.RhinoTrainer.RhinoTrainer 21 | _partial_: true 22 | 23 | batch_size: 128 24 | sparsity_factor: 25 25 | 26 | likelihood_loss: flow 27 | matrix_temperature: 0.25 28 | threeway_graph_dist: true 29 | training_procedure: auglag # options: auglag, dagma 30 | skip_auglag_epochs: 0 31 | 32 | auglag_config: 33 | _target_: src.training.auglag.AugLagLRConfig 34 | 35 | lr_update_lag: 500 36 | lr_update_lag_best: 500 37 | lr_init_dict: 38 | vardist: 0.001 
39 | functional_relationships: 0.001 40 | noise_dist: 0.001 41 | 42 | aggregation_period: 20 43 | lr_factor: 0.1 44 | max_lr_down: 3 45 | penalty_progress_rate: 0.65 46 | safety_rho: 1e13 47 | safety_alpha: 1e13 48 | inner_early_stopping_patience: 1500 49 | max_outer_steps: 60 50 | patience_penalty_reached: 100 51 | patience_max_rho: 50 52 | penalty_tolerance: 1e-8 53 | max_inner_steps: 2000 54 | force_not_converged: false 55 | init_rho: 1 56 | init_alpha: 0 57 | -------------------------------------------------------------------------------- /configs/netsim/rhino_linear.yaml: -------------------------------------------------------------------------------- 1 | # hyperparameters 2 | watch_gradients: false 3 | num_epochs: 10000 4 | model: rhino 5 | monitor_checkpoint_based_on: likelihood 6 | lag: 2 7 | 8 | # decoder options 9 | decoder: 10 | _target_: src.modules.CausalDecoder.CausalDecoder 11 | linear: true 12 | 13 | trainer: 14 | _target_: src.model.RhinoTrainer.RhinoTrainer 15 | _partial_: true 16 | 17 | batch_size: 64 18 | sparsity_factor: 25 19 | 20 | likelihood_loss: mse 21 | matrix_temperature: 0.25 22 | threeway_graph_dist: true 23 | training_procedure: auglag 24 | skip_auglag_epochs: 0 25 | 26 | auglag_config: 27 | _target_: src.training.auglag.AugLagLRConfig 28 | 29 | lr_update_lag: 500 30 | lr_update_lag_best: 500 31 | lr_init_dict: 32 | vardist: 0.001 33 | functional_relationships: 0.001 34 | noise_dist: 0.001 35 | 36 | aggregation_period: 20 37 | lr_factor: 0.1 38 | max_lr_down: 3 39 | penalty_progress_rate: 0.65 40 | safety_rho: 1e13 41 | safety_alpha: 1e13 42 | inner_early_stopping_patience: 1500 43 | max_outer_steps: 60 44 | patience_penalty_reached: 100 45 | patience_max_rho: 50 46 | penalty_tolerance: 1e-8 47 | max_inner_steps: 2000 48 | force_not_converged: false 49 | init_rho: 1 50 | init_alpha: 0 51 | -------------------------------------------------------------------------------- /configs/netsim/varlingam.yaml: -------------------------------------------------------------------------------- 1 | model: varlingam 2 | lag: 2 3 | 4 | trainer: 5 | _target_: src.baselines.VARLiNGAMTrainer.VARLiNGAMTrainer 6 | _partial_: true -------------------------------------------------------------------------------- /configs/snp100/dynotears.yaml: -------------------------------------------------------------------------------- 1 | model: dynotears 2 | 3 | trainer: 4 | _target_: src.baselines.DYNOTEARSTrainer.DYNOTEARSTrainer 5 | _partial_: true 6 | 7 | single_graph: true 8 | max_iter: 1000 9 | lambda_w: 0.10 10 | lambda_a: 0.10 11 | w_threshold: 0.01 12 | h_tol: 1e-8 13 | lag: 1 14 | group_by_graph: false -------------------------------------------------------------------------------- /configs/snp100/mcd.yaml: -------------------------------------------------------------------------------- 1 | # hyperparameters 2 | watch_gradients: false 3 | num_epochs: 100000 4 | model: mcd 5 | monitor_checkpoint_based_on: likelihood 6 | val_every_n_epochs: 10 7 | 8 | hypernet: 9 | _target_: src.modules.MultiTemporalHyperNet.MultiTemporalHyperNet 10 | order: linear 11 | num_bins: 8 12 | skip_connection: true 13 | 14 | # decoder options 15 | decoder: 16 | _target_: src.modules.MultiCausalDecoder.MultiCausalDecoder 17 | skip_connection: true 18 | 19 | trainer: 20 | _target_: src.model.MCDTrainer.MCDTrainer 21 | _partial_: true 22 | 23 | shuffle: False 24 | batch_size: 64 25 | sparsity_factor: 20 26 | likelihood_loss: flow 27 | matrix_temperature: 0.25 28 | threeway_graph_dist: true 29 | 
training_procedure: auglag 30 | skip_auglag_epochs: 0 31 | 32 | # multi-rhino settings 33 | num_graphs: 5 34 | use_correct_mixture_index: false 35 | use_true_graph: false 36 | log_num_unique_graphs: true 37 | init_logits: [-3, -3] 38 | use_all_for_val: False 39 | 40 | auglag_config: 41 | _target_: src.training.auglag.AugLagLRConfig 42 | 43 | lr_update_lag: 500 44 | lr_update_lag_best: 500 45 | lr_init_dict: 46 | vardist: 0.01 47 | functional_relationships: 0.001 48 | noise_dist: 0.001 49 | mixing_probs: 0.01 50 | 51 | aggregation_period: 20 52 | lr_factor: 0.1 53 | max_lr_down: 3 54 | penalty_progress_rate: 0.65 55 | safety_rho: 1e13 56 | safety_alpha: 1e13 57 | inner_early_stopping_patience: 1500 58 | max_outer_steps: 60 59 | patience_penalty_reached: 5 60 | patience_max_rho: 3 61 | penalty_tolerance: 1e-5 62 | max_inner_steps: 2000 63 | force_not_converged: false 64 | init_rho: 1 65 | init_alpha: 0 -------------------------------------------------------------------------------- /configs/snp100/mcd_linear.yaml: -------------------------------------------------------------------------------- 1 | # hyperparameters 2 | watch_gradients: false 3 | num_epochs: 2000 4 | model: mcd 5 | monitor_checkpoint_based_on: likelihood 6 | val_every_n_epochs: 10 7 | 8 | # decoder options 9 | decoder: 10 | _target_: src.modules.MultiCausalDecoder.MultiCausalDecoder 11 | linear: true 12 | 13 | trainer: 14 | _target_: src.model.MCDTrainer.MCDTrainer 15 | _partial_: true 16 | 17 | batch_size: 64 18 | sparsity_factor: 10 19 | likelihood_loss: mse 20 | matrix_temperature: 0.25 21 | threeway_graph_dist: true 22 | training_procedure: auglag 23 | skip_auglag_epochs: 0 24 | 25 | # multi-rhino settings 26 | num_graphs: 2 27 | use_correct_mixture_index: false 28 | use_true_graph: false 29 | log_num_unique_graphs: false 30 | init_logits: [-3, -3] 31 | use_all_for_val: True 32 | 33 | auglag_config: 34 | _target_: src.training.auglag.AugLagLRConfig 35 | 36 | lr_update_lag: 500 37 | lr_update_lag_best: 500 38 | lr_init_dict: 39 | vardist: 0.01 40 | functional_relationships: 0.01 41 | noise_dist: 0.001 42 | mixing_probs: 0.01 43 | 44 | aggregation_period: 20 45 | lr_factor: 0.1 46 | max_lr_down: 3 47 | penalty_progress_rate: 0.65 48 | safety_rho: 1e13 49 | safety_alpha: 1e13 50 | inner_early_stopping_patience: 1500 51 | max_outer_steps: 100 52 | patience_penalty_reached: 100 53 | patience_max_rho: 50 54 | penalty_tolerance: 1e-8 55 | max_inner_steps: 6000 56 | force_not_converged: false 57 | init_rho: 1 58 | init_alpha: 0 -------------------------------------------------------------------------------- /configs/snp100/pcmci.yaml: -------------------------------------------------------------------------------- 1 | model: pcmci 2 | trainer: 3 | _target_: src.baselines.PCMCITrainer.PCMCITrainer 4 | _partial_: true 5 | 6 | ci_test: ParCorr 7 | single_graph: false 8 | pcmci_plus: true 9 | pc_alpha: 0.01 10 | 11 | group_by_graph: false -------------------------------------------------------------------------------- /configs/snp100/rhino.yaml: -------------------------------------------------------------------------------- 1 | # hyperparameters 2 | watch_gradients: false 3 | num_epochs: 2000 4 | model: rhino 5 | monitor_checkpoint_based_on: likelihood 6 | 7 | hypernet: 8 | _target_: src.modules.TemporalHyperNet.TemporalHyperNet 9 | order: linear 10 | num_bins: 8 11 | skip_connection: true 12 | 13 | # decoder options 14 | decoder: 15 | _target_: src.modules.CausalDecoder.CausalDecoder 16 | skip_connection: true 17 | 18 | trainer: 19 
| _target_: src.model.RhinoTrainer.RhinoTrainer 20 | _partial_: true 21 | 22 | batch_size: 128 23 | sparsity_factor: 10 24 | 25 | likelihood_loss: flow 26 | matrix_temperature: 0.25 27 | threeway_graph_dist: true 28 | training_procedure: auglag # options: auglag, dagma 29 | skip_auglag_epochs: 0 30 | init_logits: [-3, -3] 31 | 32 | auglag_config: 33 | _target_: src.training.auglag.AugLagLRConfig 34 | 35 | lr_update_lag: 500 36 | lr_update_lag_best: 500 37 | lr_init_dict: 38 | vardist: 0.001 39 | functional_relationships: 0.001 40 | noise_dist: 0.001 41 | 42 | aggregation_period: 20 43 | lr_factor: 0.1 44 | max_lr_down: 3 45 | penalty_progress_rate: 0.65 46 | safety_rho: 1e13 47 | safety_alpha: 1e13 48 | inner_early_stopping_patience: 1500 49 | max_outer_steps: 100 50 | patience_penalty_reached: 100 51 | patience_max_rho: 3 52 | penalty_tolerance: 1e-5 53 | max_inner_steps: 6000 54 | force_not_converged: false 55 | init_rho: 1 56 | init_alpha: 0 57 | -------------------------------------------------------------------------------- /configs/snp100/varlingam.yaml: -------------------------------------------------------------------------------- 1 | model: varlingam 2 | trainer: 3 | _target_: src.baselines.VARLiNGAMTrainer.VARLiNGAMTrainer 4 | _partial_: true -------------------------------------------------------------------------------- /configs/synthetic/dynotears.yaml: -------------------------------------------------------------------------------- 1 | model: dynotears 2 | 3 | trainer: 4 | _target_: src.baselines.DYNOTEARSTrainer.DYNOTEARSTrainer 5 | _partial_: true 6 | 7 | single_graph: false 8 | max_iter: 1000 9 | lambda_w: 0.05 10 | lambda_a: 0.05 11 | w_threshold: 0.01 12 | h_tol: 1e-8 13 | 14 | group_by_graph: false -------------------------------------------------------------------------------- /configs/synthetic/mcd.yaml: -------------------------------------------------------------------------------- 1 | # hyperparameters 2 | watch_gradients: false 3 | num_epochs: 10000 4 | model: mcd 5 | monitor_checkpoint_based_on: likelihood 6 | 7 | hypernet: 8 | _target_: src.modules.MultiTemporalHyperNet.MultiTemporalHyperNet 9 | order: quadratic 10 | num_bins: 8 11 | skip_connection: true 12 | 13 | # decoder options 14 | decoder: 15 | _target_: src.modules.MultiCausalDecoder.MultiCausalDecoder 16 | skip_connection: true 17 | 18 | trainer: 19 | _target_: src.model.MCDTrainer.MCDTrainer 20 | _partial_: true 21 | 22 | batch_size: 128 23 | sparsity_factor: 5 24 | likelihood_loss: flow 25 | matrix_temperature: 0.25 26 | threeway_graph_dist: true 27 | training_procedure: auglag 28 | skip_auglag_epochs: 0 29 | 30 | # multi-rhino settings 31 | num_graphs: 10 32 | use_correct_mixture_index: false 33 | use_true_graph: false 34 | log_num_unique_graphs: false 35 | 36 | auglag_config: 37 | _target_: src.training.auglag.AugLagLRConfig 38 | 39 | lr_update_lag: 500 40 | lr_update_lag_best: 500 41 | lr_init_dict: 42 | vardist: 0.01 43 | functional_relationships: 0.01 44 | noise_dist: 0.001 45 | mixing_probs: 0.01 46 | 47 | aggregation_period: 20 48 | lr_factor: 0.1 49 | max_lr_down: 3 50 | penalty_progress_rate: 0.65 51 | safety_rho: 1e13 52 | safety_alpha: 1e13 53 | inner_early_stopping_patience: 1500 54 | max_outer_steps: 100 55 | patience_penalty_reached: 100 56 | patience_max_rho: 50 57 | penalty_tolerance: 1e-8 58 | max_inner_steps: 6000 59 | force_not_converged: false 60 | init_rho: 1 61 | init_alpha: 0 -------------------------------------------------------------------------------- 
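The model and baseline configs above wire up their components through Hydra's `_target_`/`_partial_` keys rather than hard-coded constructors. As a rough illustration, the following is a minimal sketch (not code from this repo) of how a trainer block such as the one in `configs/synthetic/mcd.yaml` becomes a trainer factory; it assumes Hydra >= 1.2 and that the package has been installed (via `setup.py`) so that `src` is importable. The call-time arguments in the final comment are assumptions, since the actual wiring lives in `src/train.py`.

```python
# Sketch only: instantiate a `_target_`/`_partial_` trainer block with Hydra.
from functools import partial

import hydra.utils
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/synthetic/mcd.yaml")

# With `_partial_: true`, instantiate() returns a functools.partial of
# src.model.MCDTrainer.MCDTrainer with the YAML values (batch_size,
# sparsity_factor, num_graphs, ...) already bound as keyword arguments.
trainer_factory = hydra.utils.instantiate(cfg.trainer)
assert isinstance(trainer_factory, partial)

# The remaining constructor arguments (dataset tensors, dimensions, ...) are
# supplied at call time by the training script; these names are assumptions.
# trainer = trainer_factory(full_dataset=..., adj_matrices=..., num_nodes=...)
```

The same pattern applies to the baseline configs, whose `trainer._target_` points at the classes under `src/baselines/`.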
/configs/synthetic/mcd_linear.yaml: -------------------------------------------------------------------------------- 1 | # hyperparameters 2 | watch_gradients: false 3 | num_epochs: 10000 4 | model: mcd 5 | monitor_checkpoint_based_on: likelihood 6 | 7 | # decoder options 8 | decoder: 9 | _target_: src.modules.MultiCausalDecoder.MultiCausalDecoder 10 | linear: true 11 | 12 | trainer: 13 | _target_: src.model.MCDTrainer.MCDTrainer 14 | _partial_: true 15 | 16 | batch_size: 128 17 | sparsity_factor: 5 18 | likelihood_loss: mse 19 | matrix_temperature: 0.25 20 | threeway_graph_dist: true 21 | training_procedure: auglag 22 | skip_auglag_epochs: 0 23 | 24 | # multi-rhino settings 25 | num_graphs: 10 26 | use_correct_mixture_index: false 27 | use_true_graph: false 28 | log_num_unique_graphs: false 29 | 30 | auglag_config: 31 | _target_: src.training.auglag.AugLagLRConfig 32 | 33 | lr_update_lag: 500 34 | lr_update_lag_best: 500 35 | lr_init_dict: 36 | vardist: 0.01 37 | functional_relationships: 0.01 38 | noise_dist: 0.001 39 | mixing_probs: 0.01 40 | 41 | aggregation_period: 20 42 | lr_factor: 0.1 43 | max_lr_down: 3 44 | penalty_progress_rate: 0.65 45 | safety_rho: 1e13 46 | safety_alpha: 1e13 47 | inner_early_stopping_patience: 1500 48 | max_outer_steps: 100 49 | patience_penalty_reached: 100 50 | patience_max_rho: 50 51 | penalty_tolerance: 1e-8 52 | max_inner_steps: 6000 53 | force_not_converged: false 54 | init_rho: 1 55 | init_alpha: 0 -------------------------------------------------------------------------------- /configs/synthetic/pcmci.yaml: -------------------------------------------------------------------------------- 1 | model: pcmci 2 | trainer: 3 | _target_: src.baselines.PCMCITrainer.PCMCITrainer 4 | _partial_: true 5 | 6 | ci_test: ParCorr 7 | single_graph: false 8 | pcmci_plus: true 9 | pc_alpha: 0.01 10 | 11 | group_by_graph: false -------------------------------------------------------------------------------- /configs/synthetic/rhino.yaml: -------------------------------------------------------------------------------- 1 | # hyperparameters 2 | watch_gradients: false 3 | num_epochs: 10000 4 | model: rhino 5 | monitor_checkpoint_based_on: likelihood 6 | 7 | hypernet: 8 | _target_: src.modules.TemporalHyperNet.TemporalHyperNet 9 | order: quadratic 10 | num_bins: 8 11 | skip_connection: true 12 | 13 | # decoder options 14 | decoder: 15 | _target_: src.modules.CausalDecoder.CausalDecoder 16 | skip_connection: true 17 | 18 | trainer: 19 | _target_: src.model.RhinoTrainer.RhinoTrainer 20 | _partial_: true 21 | 22 | batch_size: 128 23 | sparsity_factor: 5 24 | 25 | likelihood_loss: flow 26 | matrix_temperature: 0.25 27 | threeway_graph_dist: true 28 | training_procedure: auglag # options: auglag, dagma 29 | skip_auglag_epochs: 0 30 | 31 | auglag_config: 32 | _target_: src.training.auglag.AugLagLRConfig 33 | 34 | lr_update_lag: 500 35 | lr_update_lag_best: 500 36 | lr_init_dict: 37 | vardist: 0.01 38 | functional_relationships: 0.01 39 | noise_dist: 0.001 40 | mixing_probs: 0.01 41 | 42 | aggregation_period: 20 43 | lr_factor: 0.1 44 | max_lr_down: 3 45 | penalty_progress_rate: 0.65 46 | safety_rho: 1e13 47 | safety_alpha: 1e13 48 | inner_early_stopping_patience: 1500 49 | max_outer_steps: 100 50 | patience_penalty_reached: 100 51 | patience_max_rho: 50 52 | penalty_tolerance: 1e-8 53 | max_inner_steps: 6000 54 | force_not_converged: false 55 | init_rho: 1 56 | init_alpha: 0 57 | -------------------------------------------------------------------------------- 
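The per-dataset directories under `configs/` (such as `configs/synthetic/` above) are Hydra config groups, selected on the command line with the `+group=option` syntax used in the README (for example `+synthetic=mcd`). The snippet below is a hedged sketch of the equivalent programmatic composition with Hydra's `compose` API; the dataset name follows the naming convention of the generation configs but is made up for illustration, and the relative `config_path` is an assumption.

```python
# Sketch only: compose configs/main.yaml with a model config group, mirroring
# `python3 -m src.train +dataset=... +synthetic=rhino`. Assumes Hydra >= 1.2.
from hydra import compose, initialize
from omegaconf import OmegaConf

with initialize(config_path="configs", version_base=None):
    cfg = compose(
        config_name="main",
        overrides=[
            # `+dataset=...` adds a plain `dataset` key (there is no
            # configs/dataset/ group); `+synthetic=rhino` pulls in
            # configs/synthetic/rhino.yaml.
            "+dataset=ER_ER_num_graphs_5_lag_2_dim_10_HistDep_0.5_mlp_spline_product_con_1_seed_0_n_samples_1000",
            "+synthetic=rhino",
            "random_seed=1",  # overrides the default in configs/main.yaml
        ],
    )

print(OmegaConf.to_yaml(cfg))  # the merged configuration seen by src.train
```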
/configs/synthetic/rhino_linear.yaml: -------------------------------------------------------------------------------- 1 | # hyperparameters 2 | watch_gradients: false 3 | num_epochs: 10000 4 | model: rhino 5 | monitor_checkpoint_based_on: likelihood 6 | 7 | # decoder options 8 | decoder: 9 | _target_: src.modules.CausalDecoder.CausalDecoder 10 | linear: true 11 | 12 | trainer: 13 | _target_: src.model.RhinoTrainer.RhinoTrainer 14 | _partial_: true 15 | 16 | batch_size: 128 17 | sparsity_factor: 5 18 | 19 | likelihood_loss: mse 20 | matrix_temperature: 0.25 21 | threeway_graph_dist: true 22 | training_procedure: auglag # options: auglag, dagma 23 | skip_auglag_epochs: 0 24 | 25 | auglag_config: 26 | _target_: src.training.auglag.AugLagLRConfig 27 | 28 | lr_update_lag: 500 29 | lr_update_lag_best: 500 30 | lr_init_dict: 31 | vardist: 0.01 32 | functional_relationships: 0.01 33 | noise_dist: 0.001 34 | mixing_probs: 0.01 35 | 36 | aggregation_period: 20 37 | lr_factor: 0.1 38 | max_lr_down: 3 39 | penalty_progress_rate: 0.65 40 | safety_rho: 1e13 41 | safety_alpha: 1e13 42 | inner_early_stopping_patience: 1500 43 | max_outer_steps: 100 44 | patience_penalty_reached: 100 45 | patience_max_rho: 50 46 | penalty_tolerance: 1e-8 47 | max_inner_steps: 6000 48 | force_not_converged: false 49 | init_rho: 1 50 | init_alpha: 0 51 | -------------------------------------------------------------------------------- /configs/synthetic/varlingam.yaml: -------------------------------------------------------------------------------- 1 | model: varlingam 2 | trainer: 3 | _target_: src.baselines.VARLiNGAMTrainer.VARLiNGAMTrainer 4 | _partial_: true -------------------------------------------------------------------------------- /scripts/generate_perturb_datasets.sh: -------------------------------------------------------------------------------- 1 | python3 src/utils/data_gen/generate_perturb_syn.py --config_file configs/data_generation/perturb_data.yaml -------------------------------------------------------------------------------- /scripts/generate_snp100.sh: -------------------------------------------------------------------------------- 1 | python3 -m src.utils.data_gen.generate_stock --stock_list_file assets/snp100.csv --save_dir data/snp100/ --chunk_size 31 -------------------------------------------------------------------------------- /scripts/generate_synthetic_datasets_linear.sh: -------------------------------------------------------------------------------- 1 | python3 src/utils/data_gen/generate_synthetic_data.py --config_file configs/data_generation/synthetic_data_linear.yaml -------------------------------------------------------------------------------- /scripts/generate_synthetic_datasets_nonlinear.sh: -------------------------------------------------------------------------------- 1 | python3 src/utils/data_gen/generate_synthetic_data.py --config_file configs/data_generation/synthetic_data_nonlinear.yaml -------------------------------------------------------------------------------- /scripts/setup_dream3.sh: -------------------------------------------------------------------------------- 1 | mkdir temp 2 | cd temp 3 | 4 | wget https://github.com/sakhanna/SRU_for_GCI/archive/refs/heads/master.zip 5 | unzip master.zip 6 | 7 | if [ ! 
-d "../data/dream3/raw" ] 8 | then 9 | mkdir -p ../data/dream3/raw 10 | fi 11 | 12 | mv SRU_for_GCI-master/data/dream3/* ../data/dream3/raw/ 13 | cd ../ 14 | rm -r temp 15 | python3 src/utils/data_gen/process_dream3.py --dataset_dir data/dream3/raw --save_dir data/dream3/ 16 | 17 | -------------------------------------------------------------------------------- /scripts/setup_netsim.sh: -------------------------------------------------------------------------------- 1 | mkdir temp 2 | cd temp 3 | 4 | wget https://www.fmrib.ox.ac.uk/datasets/netsim/sims.tar.gz 5 | tar xvf sims.tar.gz 6 | 7 | if [ ! -d "../data/netsim/raw" ] 8 | then 9 | mkdir -p ../data/netsim/raw 10 | fi 11 | 12 | 13 | mv *.mat ../data/netsim/raw/ 14 | cd ../ 15 | rm -r temp 16 | python3 src/utils/data_gen/process_netsim.py --dataset_dir data/netsim/raw --save_dir data/netsim/ 17 | 18 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | name='multi_graph_rhino', 5 | packages=find_packages(), 6 | ) -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/MCD/5276806eb761545d2e99f53d5135289d7def20e3/src/__init__.py -------------------------------------------------------------------------------- /src/baselines/BaselineTrainer.py: -------------------------------------------------------------------------------- 1 | import lightning.pytorch as pl 2 | from torch.utils.data import DataLoader 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | from src.dataset.BaselineTSDataset import BaselineTSDataset 7 | from src.utils.metrics_utils import mape_loss 8 | 9 | class BaselineTrainer(pl.LightningModule): 10 | # trainer class for the baselines 11 | # they do not need training 12 | def __init__(self, 13 | full_dataset: np.array, 14 | adj_matrices: np.array, 15 | data_dim: int, 16 | num_nodes: int, 17 | lag: int, 18 | num_workers: int = 16, 19 | aggregated_graph: bool = False): 20 | super().__init__() 21 | 22 | self.num_workers = num_workers 23 | self.aggregated_graph = aggregated_graph 24 | self.data_dim = data_dim 25 | self.lag = lag 26 | self.num_nodes = num_nodes 27 | self.full_dataset_np = full_dataset 28 | self.adj_matrices_np = adj_matrices 29 | self.total_samples = full_dataset.shape[0] 30 | assert adj_matrices.shape[0] == self.total_samples 31 | 32 | self.full_dataset = BaselineTSDataset( 33 | X = self.full_dataset_np, 34 | adj_matrix = self.adj_matrices_np, 35 | lag=lag, 36 | aggregated_graph=self.aggregated_graph, 37 | return_graph_indices=True 38 | ) 39 | 40 | self.batch_size = 1 41 | 42 | def compute_mse(self, x_current, x_pred): 43 | return F.mse_loss(x_current, x_pred) 44 | 45 | def compute_mape(self, x_current, x_pred): 46 | return mape_loss(x_current, x_pred) 47 | 48 | def get_full_dataloader(self) -> DataLoader: 49 | return DataLoader(self.full_dataset, batch_size=self.batch_size, num_workers=self.num_workers) 50 | -------------------------------------------------------------------------------- /src/baselines/DYNOTEARSTrainer.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | import numpy as np 3 | import torch 4 | 5 | # import tigramite for pcmci 6 | import networkx as nx 7 | import 
pandas as pd 8 | 9 | from src.baselines.BaselineTrainer import BaselineTrainer 10 | from src.modules.dynotears.dynotears import from_pandas_dynamic 11 | from src.utils.data_utils.data_format_utils import to_time_aggregated_graph_np, to_time_aggregated_scores_np, zero_out_diag_np 12 | 13 | class DYNOTEARSTrainer(BaselineTrainer): 14 | 15 | def __init__(self, 16 | full_dataset: np.array, 17 | adj_matrices: np.array, 18 | data_dim: int, 19 | num_nodes: int, 20 | lag: int, 21 | num_workers: int = 16, 22 | aggregated_graph: bool = False, 23 | single_graph: bool = True, 24 | max_iter: int = 1000, 25 | lambda_w: float = 0.1, 26 | lambda_a: float = 0.1, 27 | w_threshold: float = 0.4, 28 | h_tol: float = 1e-8, 29 | group_by_graph: bool = True, 30 | ignore_self_connections: bool = False 31 | ): 32 | self.group_by_graph = group_by_graph 33 | self.ignore_self_connections = ignore_self_connections 34 | if self.group_by_graph: 35 | self.single_graph = True 36 | print("DYNOTEARS: Group by graph option set. Overriding single graph flag to True...") 37 | else: 38 | self.single_graph = single_graph 39 | super().__init__(full_dataset=full_dataset, 40 | adj_matrices=adj_matrices, 41 | data_dim=data_dim, 42 | lag=lag, 43 | num_nodes=num_nodes, 44 | num_workers=num_workers, 45 | aggregated_graph=aggregated_graph) 46 | 47 | self.max_iter = max_iter 48 | self.lambda_w = lambda_w 49 | self.lambda_a = lambda_a 50 | self.w_threshold = w_threshold 51 | self.h_tol = h_tol 52 | 53 | if self.single_graph: 54 | self.batch_size = full_dataset.shape[0] # we want the full dataset 55 | else: 56 | self.batch_size = 1 57 | 58 | def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: 59 | X, adj_matrix, graph_index = batch 60 | 61 | batch_size, timesteps, num_nodes, _ = X.shape 62 | assert num_nodes == self.num_nodes 63 | X = X.view(batch_size, timesteps, -1) 64 | 65 | X, adj_matrix, graph_index = X.cpu().numpy(), adj_matrix.cpu().numpy(), graph_index.cpu().numpy() 66 | 67 | X_list = [] 68 | 69 | graphs = np.zeros((batch_size, self.lag+1, num_nodes, num_nodes)) 70 | scores = np.zeros((batch_size, self.lag+1, num_nodes, num_nodes)) 71 | if self.group_by_graph: 72 | n_unique_matrices = np.max(graph_index)+1 73 | else: 74 | graph_index = np.zeros((batch_size)) 75 | n_unique_matrices = 1 76 | 77 | for i in range(n_unique_matrices): 78 | 79 | n_samples = np.sum(graph_index == i) 80 | for x in X[graph_index == i]: 81 | X_list.append(pd.DataFrame(x)) 82 | learner = from_pandas_dynamic( 83 | X_list, 84 | p=self.lag, 85 | max_iter=self.max_iter, 86 | lambda_w=self.lambda_w, 87 | lambda_a=self.lambda_a, 88 | w_threshold=self.w_threshold, 89 | h_tol=self.h_tol) 90 | 91 | adj_static = nx.to_numpy_array(learner) 92 | temporal_adj_list = [] 93 | for l in range(self.lag + 1): 94 | cur_adj = adj_static[l:: self.lag + 1, 0:: self.lag + 1] 95 | temporal_adj_list.append(cur_adj) 96 | 97 | # [lag+1, num_nodes, num_nodes] 98 | score = np.stack(temporal_adj_list, axis=0) 99 | # scores = np.hstack(temporal_adj_list) 100 | temporal_adj = [(score != 0).astype(int) for _ in range(n_samples)] 101 | score = [np.abs(score) for _ in range(n_samples)] 102 | graphs[i == graph_index] = np.array(temporal_adj) 103 | scores[i == graph_index] = np.array(score) 104 | if self.aggregated_graph: 105 | graphs = to_time_aggregated_graph_np(graphs) 106 | scores = to_time_aggregated_scores_np(scores) 107 | if self.ignore_self_connections: 108 | graphs = zero_out_diag_np(graphs) 109 | scores = zero_out_diag_np(scores) 110 | return 
torch.Tensor(graphs), torch.abs(torch.Tensor(scores)), torch.Tensor(adj_matrix) 111 | -------------------------------------------------------------------------------- /src/baselines/PCMCITrainer.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | from copy import deepcopy 3 | 4 | import numpy as np 5 | # import tigramite for pcmci 6 | from tigramite import data_processing as pp 7 | from tigramite.pcmci import PCMCI 8 | from tigramite.independence_tests.parcorr import ParCorr 9 | from tigramite.independence_tests.cmiknn import CMIknn 10 | import torch 11 | 12 | from src.utils.causality_utils import convert_temporal_to_static_adjacency_matrix, cpdag2dags 13 | from src.baselines.BaselineTrainer import BaselineTrainer 14 | from src.utils.data_utils.data_format_utils import to_time_aggregated_graph_np, zero_out_diag_np 15 | 16 | """ 17 | Large parts adapted from https://github.com/microsoft/causica 18 | """ 19 | 20 | class PCMCITrainer(BaselineTrainer): 21 | 22 | def __init__(self, 23 | full_dataset: np.array, 24 | adj_matrices: np.array, 25 | data_dim: int, 26 | num_nodes: int, 27 | lag: int, 28 | num_workers: int = 16, 29 | aggregated_graph: bool = False, 30 | ci_test: str = 'ParCorr', # options are ParCorr, CMIknn, GPDCtorch 31 | single_graph: bool = False, 32 | pcmci_plus: bool = True, 33 | pc_alpha: float = 0.01, 34 | group_by_graph: bool = False, 35 | ignore_self_connections: bool = False 36 | ): 37 | self.group_by_graph = group_by_graph 38 | self.ignore_self_connections = ignore_self_connections 39 | if self.group_by_graph: 40 | print("PCMCI: Group by graph option set. Overriding single graph flag to True...") 41 | self.single_graph = True 42 | else: 43 | self.single_graph = single_graph 44 | 45 | super().__init__(full_dataset=full_dataset, 46 | adj_matrices=adj_matrices, 47 | data_dim=data_dim, 48 | lag=lag, 49 | num_nodes=num_nodes, 50 | num_workers=num_workers, 51 | aggregated_graph=aggregated_graph) 52 | 53 | if ci_test == 'ParCorr': 54 | self.ci_test = ParCorr(significance='analytic') 55 | elif ci_test == 'CMIknn': 56 | self.ci_test = CMIknn() 57 | 58 | # self.single_graph = single_graph 59 | self.pcmci_plus = pcmci_plus 60 | self.pc_alpha = pc_alpha 61 | 62 | if self.single_graph: 63 | self.batch_size = full_dataset.shape[0] # we want the full dataset 64 | self.analysis_mode = 'multiple' 65 | else: 66 | self.batch_size = 1 67 | self.analysis_mode = 'single' 68 | 69 | def _process_adj_matrix(self, adj_matrix: np.ndarray) -> np.ndarray: 70 | """ 71 | Borrowed from microsoft/causica 72 | 73 | This will process the raw output adj graphs from pcmci_plus. The raw output can contain 3 types of edges: 74 | (1) "-->" or "<--". This indicates the directed edges, and they should appear symmetrically in the matrix. 75 | (2) "o-o": This indicates the bi-directed edges, also appears symmetrically. 76 | Note: for lagged matrix, it can only contain "-->". 77 | (3) "x-x": this means the edge direction is un-decided due to conflicting orientation rules. We ignores 78 | the edges in this case. 
79 | Args: 80 | inst_matrix: the input raw inst matrix with shape [num_nodes, num_nodes, lag+1] 81 | Returns: 82 | inst_adj_matrix: np.ndarray, an inst adj matrix with shape [lag+1, num_nodes, num_nodes] 83 | """ 84 | assert adj_matrix.ndim == 3 85 | 86 | adj_matrix = deepcopy(adj_matrix) 87 | # shape [lag+1, num_nodes, num_nodes] 88 | adj_matrix = np.moveaxis(adj_matrix, -1, 0) 89 | adj_matrix[adj_matrix == ""] = 0 90 | adj_matrix[adj_matrix == "<--"] = 0 91 | adj_matrix[adj_matrix == "-->"] = 1 92 | adj_matrix[adj_matrix == "o-o"] = 1 93 | adj_matrix[adj_matrix == "x-x"] = 0 94 | 95 | return adj_matrix.astype(int) 96 | 97 | def _run_pcmci(self, pcmci, tau_max, pc_alpha): 98 | if self.pcmci_plus: 99 | return pcmci.run_pcmciplus(tau_max=tau_max, pc_alpha=pc_alpha) 100 | return pcmci.run_pcmci(tau_max=tau_max, pc_alpha=pc_alpha) 101 | 102 | def _process_cpdag(self, adj_matrix: np.ndarray): 103 | """ 104 | This will process the inst cpdag (i.e. adj_matrix[0, ...]) according to the mec_mode. It supports "enumerate" and "truth" 105 | Args: 106 | adj_matrix: np.ndarray, a temporal adj matrix with shape [lag+1, num_nodes, num_nodes] where the inst part can be a cpdag. 107 | 108 | Returns: 109 | adj_matrix: np.ndarray with shape [num_possible_dags, lag+1, num_nodes, num_nodes] 110 | """ 111 | lag_plus, num_nodes = adj_matrix.shape[0], adj_matrix.shape[1] 112 | static_temporal_graph = convert_temporal_to_static_adjacency_matrix( 113 | adj_matrix, conversion_type="auto_regressive" 114 | ) # shape[(lag+1) *nodes, (lag+1)*nodes] 115 | all_static_temp_dags = cpdag2dags( 116 | static_temporal_graph, samples=3000 117 | ) # [all_possible_dags, (lag+1)*num_nodes, (lag+1)*num_nodes] 118 | # convert back to temporal adj matrix. 119 | temp_adj_list = np.split( 120 | all_static_temp_dags[..., :, (lag_plus - 1) * num_nodes :], lag_plus, axis=1 121 | ) # list with length lag+1, each with shape [all_possible_dags, num_nodes, num_nodes] 122 | proc_adj_matrix = np.stack( 123 | list(reversed(temp_adj_list)), axis=-3 124 | ) # shape [all_possible_dags, lag+1, num_nodes, num_nodes] 125 | 126 | return proc_adj_matrix 127 | 128 | 129 | def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: 130 | X, adj_matrix, graph_index = batch 131 | batch_size, timesteps, num_nodes, _ = X.shape 132 | assert num_nodes == self.num_nodes 133 | X = X.view(batch_size, timesteps, -1) 134 | X, adj_matrix, graph_index = X.numpy(), adj_matrix.numpy(), graph_index.numpy() 135 | graphs = [] #np.zeros((batch_size, self.lag+1, num_nodes, num_nodes)) 136 | new_adj_matrix = [] 137 | if self.group_by_graph: 138 | n_unique_matrices = np.max(graph_index)+1 139 | else: 140 | graph_index = np.zeros((batch_size)) 141 | n_unique_matrices = 1 142 | for i in range(n_unique_matrices): 143 | print(f"{i}/{n_unique_matrices}") 144 | n_samples = np.sum(graph_index == i) 145 | dataframe = pp.DataFrame(X[graph_index==i], analysis_mode=self.analysis_mode) 146 | pcmci = PCMCI( 147 | dataframe=dataframe, 148 | cond_ind_test=self.ci_test, 149 | verbosity=0) 150 | 151 | results = self._run_pcmci(pcmci, self.lag, self.pc_alpha) 152 | graph = self._process_adj_matrix(results["graph"]) 153 | graph = self._process_cpdag(graph) 154 | num_possible_dags = graph.shape[0] 155 | new_adj_matrix.append(np.repeat(adj_matrix[graph_index==i][0][np.newaxis, ...], n_samples*num_possible_dags, axis=0)) 156 | graphs.append(np.repeat(graph, n_samples, axis=0)) 157 | 158 | graphs = np.concatenate(graphs, axis=0) 159 | new_adj_matrix = 
np.concatenate(new_adj_matrix, axis=0) 160 | if self.aggregated_graph: 161 | graphs = to_time_aggregated_graph_np(graphs) 162 | if self.ignore_self_connections: 163 | graphs = zero_out_diag_np(graphs) 164 | 165 | return torch.Tensor(graphs), torch.Tensor(graphs), torch.Tensor(new_adj_matrix) 166 | -------------------------------------------------------------------------------- /src/baselines/VARLiNGAMTrainer.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | import numpy as np 3 | 4 | import lingam 5 | import torch 6 | 7 | from src.utils.data_utils.data_format_utils import to_time_aggregated_graph_np 8 | from src.baselines.BaselineTrainer import BaselineTrainer 9 | 10 | class VARLiNGAMTrainer(BaselineTrainer): 11 | 12 | def __init__(self, 13 | full_dataset: np.array, 14 | adj_matrices: np.array, 15 | data_dim: int, 16 | num_nodes: int, 17 | lag: int, 18 | num_workers: int = 16, 19 | aggregated_graph: bool = False 20 | ): 21 | super().__init__(full_dataset=full_dataset, 22 | adj_matrices=adj_matrices, 23 | data_dim=data_dim, 24 | lag=lag, 25 | num_nodes=num_nodes, 26 | num_workers=num_workers, 27 | aggregated_graph=aggregated_graph) 28 | 29 | def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: 30 | X, adj_matrix, _ = batch 31 | 32 | batch, timesteps, num_nodes, _ = X.shape 33 | X = X.view(batch, timesteps, -1) 34 | 35 | assert num_nodes == self.num_nodes 36 | assert batch == 1, "VARLiNGAM needs batch size 1" 37 | 38 | model_pruned = lingam.VARLiNGAM(lags=self.lag, prune=True) 39 | model_pruned.fit(X[0]) 40 | graph = np.transpose(np.abs(model_pruned.adjacency_matrices_) > 0, axes=[0, 2, 1]) 41 | if graph.shape[0] != (self.lag+1): 42 | while graph.shape[0] != (self.lag+1): 43 | graph = np.concatenate((graph, np.zeros((1, num_nodes, num_nodes) )), axis=0) 44 | graphs = [graph] 45 | if self.aggregated_graph: 46 | graphs = to_time_aggregated_graph_np(graphs) 47 | print(graphs) 48 | print(adj_matrix) 49 | return torch.Tensor(graphs), torch.Tensor(graphs), torch.Tensor(adj_matrix) -------------------------------------------------------------------------------- /src/baselines/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/MCD/5276806eb761545d2e99f53d5135289d7def20e3/src/baselines/__init__.py -------------------------------------------------------------------------------- /src/dataset/BaselineTSDataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from src.utils.data_utils.data_format_utils import get_adj_matrix_id 3 | 4 | class BaselineTSDataset(Dataset): 5 | # Dataset class that can be used with the baselines PCMCI(+), VARLiNGAM and DYNOTEARS 6 | def __init__(self, 7 | X, 8 | adj_matrix, 9 | lag, 10 | aggregated_graph=False, 11 | return_graph_indices=False): 12 | """ 13 | X: np.array of shape (n_samples, timesteps, num_nodes, data_dim) 14 | adj_matrix: np.array of shape (n_samples, lag+1, num_nodes, num_nodes) 15 | """ 16 | self.lag = lag 17 | self.aggregated_graph = aggregated_graph 18 | self.X = X 19 | self.adj_matrix = adj_matrix 20 | self.return_graph_indices = return_graph_indices 21 | if self.return_graph_indices: 22 | self.unique_matrices, self.matrix_indices = get_adj_matrix_id(self.adj_matrix) 23 | 24 | def __len__(self): 25 | return self.X.shape[0] 26 | 27 | def __getitem__(self, index): 28 | if not 
self.return_graph_indices: 29 | return self.X[index], self.adj_matrix[index] 30 | return self.X[index], self.adj_matrix[index], self.matrix_indices[index] 31 | -------------------------------------------------------------------------------- /src/dataset/FragmentDataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Terminology: 3 | 4 | A fragment refers to the pair of tensors X_history and x_current, where X_history represents 5 | the lag information (X(t-lag) to X(t-1)) and x_current represents the current information X(t). 6 | Note that which sample a fragment comes from is irrelevant, since all we are concerned about is the 7 | causal graph which generated the fragment. 8 | """ 9 | 10 | from torch.utils.data import Dataset 11 | import torch 12 | 13 | from src.utils.data_utils.data_format_utils import convert_data_to_timelagged, convert_adj_to_timelagged 14 | 15 | 16 | class FragmentDataset(Dataset): 17 | def __init__(self, 18 | X, 19 | adj_matrix, 20 | lag, 21 | return_graph_indices=True, 22 | aggregated_graph=False): 23 | """ 24 | X: np.array of shape (n_samples, timesteps, num_nodes, data_dim) 25 | adj_matrix: np.array of shape (n_samples, lag+1, num_nodes, num_nodes) 26 | """ 27 | self.lag = lag 28 | self.aggregated_graph = aggregated_graph 29 | self.return_graph_indices = return_graph_indices 30 | # preprocess data 31 | self.X_history, self.x_current, self.X_indices = convert_data_to_timelagged( 32 | X, lag=lag) 33 | if self.return_graph_indices: 34 | self.adj_matrix, self.graph_indices = convert_adj_to_timelagged( 35 | adj_matrix, 36 | lag=lag, 37 | n_fragments=self.X_history.shape[0], 38 | aggregated_graph=self.aggregated_graph, 39 | return_indices=True) 40 | else: 41 | self.adj_matrix = convert_adj_to_timelagged( 42 | adj_matrix, 43 | lag=lag, 44 | n_fragments=self.X_history.shape[0], 45 | aggregated_graph=self.aggregated_graph, 46 | return_indices=False) 47 | 48 | self.X_history, self.x_current, self.adj_matrix, self.X_indices = \ 49 | torch.Tensor(self.X_history), torch.Tensor(self.x_current), torch.Tensor( 50 | self.adj_matrix), torch.Tensor(self.X_indices) 51 | if self.return_graph_indices: 52 | self.graph_indices = torch.Tensor(self.graph_indices) 53 | 54 | self.num_fragments = self.X_history.shape[0] 55 | 56 | def __len__(self): 57 | return self.num_fragments 58 | 59 | def __getitem__(self, index): 60 | if self.return_graph_indices: 61 | return self.X_history[index], self.x_current[index], self.adj_matrix[index], self.X_indices[index].long(), \ 62 | self.graph_indices[index].long() 63 | return self.X_history[index], self.x_current[index], self.adj_matrix[index], self.X_indices[index].long() 64 | -------------------------------------------------------------------------------- /src/model/BaseTrainer.py: -------------------------------------------------------------------------------- 1 | 2 | import lightning.pytorch as pl 3 | import torch 4 | from torch.utils.data import DataLoader, TensorDataset 5 | import torch.nn.functional as F 6 | import numpy as np 7 | 8 | from src.dataset.FragmentDataset import FragmentDataset 9 | from src.utils.metrics_utils import mape_loss 10 | 11 | 12 | class BaseTrainer(pl.LightningModule): 13 | 14 | def __init__(self, 15 | full_dataset: np.array, 16 | adj_matrices: np.array, 17 | data_dim: int, 18 | lag: int, 19 | num_workers: int = 16, 20 | batch_size: int = 256, 21 | aggregated_graph: bool = False, 22 | return_graph_indices: bool = True, 23 | val_frac: float = 0.2, 24 | use_all_for_val: bool = 
False, 25 | shuffle: bool = True): 26 | super().__init__() 27 | # self.automatic_optimization = False 28 | self.num_workers = num_workers 29 | self.batch_size = batch_size 30 | self.aggregated_graph = aggregated_graph 31 | self.data_dim = data_dim 32 | self.lag = lag 33 | self.return_graph_indices = return_graph_indices 34 | 35 | I = np.arange(full_dataset.shape[0]) 36 | if shuffle: 37 | rng = np.random.default_rng() 38 | rng.shuffle(I) 39 | 40 | self.full_dataset_np = full_dataset[I] 41 | self.adj_matrices_np = adj_matrices[I] 42 | 43 | self.use_all_for_val = use_all_for_val 44 | self.val_frac = val_frac 45 | assert self.val_frac >= 0.0 and self.val_frac < 1.0, "Validation fraction should be between 0 and 1" 46 | 47 | num_samples = full_dataset.shape[0] 48 | self.total_samples = num_samples 49 | assert adj_matrices.shape[0] == num_samples 50 | 51 | if self.use_all_for_val: 52 | print("Using *all* examples for validation. Ignoring val_frac...") 53 | self.train_dataset_np = self.full_dataset_np 54 | self.val_dataset_np = self.full_dataset_np 55 | self.train_adj_np = self.adj_matrices_np 56 | self.val_adj_np = self.adj_matrices_np 57 | else: 58 | self.train_dataset_np = self.full_dataset_np[:int( 59 | (1-self.val_frac)*num_samples)] 60 | self.val_dataset_np = self.full_dataset_np[int( 61 | (1-self.val_frac)*num_samples):] 62 | self.train_adj_np = self.adj_matrices_np[:int( 63 | (1-self.val_frac)*num_samples)] 64 | self.val_adj_np = self.adj_matrices_np[int( 65 | (1-self.val_frac)*num_samples):] 66 | self.train_frag_dataset = FragmentDataset( 67 | self.train_dataset_np, 68 | self.train_adj_np, 69 | lag=lag, 70 | aggregated_graph=self.aggregated_graph, 71 | return_graph_indices=self.return_graph_indices) 72 | self.val_frag_dataset = FragmentDataset( 73 | self.val_dataset_np, 74 | self.val_adj_np, 75 | lag=lag, 76 | aggregated_graph=self.aggregated_graph, 77 | return_graph_indices=self.return_graph_indices) 78 | # self.full_frag_dataset = FragmentDataset( 79 | # self.full_dataset_np, 80 | # self.adj_matrices_np, 81 | # lag=lag, 82 | # aggregated_graph=self.aggregated_graph, 83 | # return_graph_indices=self.return_graph_indices) 84 | self.num_fragments = len(self.train_frag_dataset) 85 | self.full_dataset = TensorDataset( 86 | torch.Tensor(self.full_dataset_np), 87 | torch.Tensor(self.adj_matrices_np), 88 | torch.arange(self.full_dataset_np.shape[0])) 89 | if self.batch_size is None: 90 | # do full-batch training 91 | self.batch_size = self.num_fragments 92 | 93 | def forward(self): 94 | raise NotImplementedError 95 | 96 | def compute_loss(self, X_history, x_current, X_full, adj_matrix): 97 | raise NotImplementedError 98 | 99 | def compute_mse(self, x_current, x_pred): 100 | return F.mse_loss(x_current, x_pred) 101 | 102 | def compute_mape(self, x_current, x_pred): 103 | return mape_loss(x_current, x_pred) 104 | 105 | def training_step(self, batch, batch_idx): 106 | X_history, x_current, X_full, adj_matrix = batch 107 | loss = self.compute_loss(X_history, x_current, X_full, adj_matrix) 108 | return loss 109 | 110 | def validation_step(self, batch, batch_idx): 111 | X_history, x_current, X_full, adj_matrix = batch 112 | loss = self.compute_loss(X_history, x_current, X_full, adj_matrix) 113 | self.log("val_loss", loss) 114 | 115 | def train_dataloader(self) -> DataLoader: 116 | return DataLoader(self.train_frag_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True) 117 | 118 | def val_dataloader(self) -> DataLoader: 119 | return DataLoader(self.val_frag_dataset, 
batch_size=self.batch_size, num_workers=self.num_workers) 120 | 121 | def get_full_dataloader(self) -> DataLoader: 122 | return DataLoader(self.full_dataset, batch_size=self.batch_size, num_workers=self.num_workers) 123 | 124 | def track_gradients(self, m, log_name): 125 | total_norm = 0 126 | for p in m.parameters(): 127 | if p.grad is not None: 128 | param_norm = p.grad.data.norm(2) 129 | total_norm += param_norm.item() ** 2 130 | total_norm = total_norm ** (1. / 2) 131 | self.log(log_name, total_norm) 132 | -------------------------------------------------------------------------------- /src/model/RhinoTrainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from sklearn.metrics import f1_score 4 | 5 | from src.utils.loss_utils import dag_penalty_notears, temporal_graph_sparsity 6 | from src.utils.data_utils.data_format_utils import to_time_aggregated_graph_np, to_time_aggregated_scores_torch, zero_out_diag_np, zero_out_diag_torch 7 | from src.utils.metrics_utils import compute_shd 8 | from src.model.BaseTrainer import BaseTrainer 9 | from src.modules.adjacency_matrices.TemporalAdjacencyMatrix import TemporalAdjacencyMatrix 10 | from src.modules.adjacency_matrices.TwoWayTemporalAdjacencyMatrix import TwoWayTemporalAdjacencyMatrix 11 | from src.training.auglag import AugLagLRConfig, AuglagLRCallback, AugLagLR, AugLagLossCalculator 12 | 13 | 14 | class RhinoTrainer(BaseTrainer): 15 | 16 | def __init__(self, 17 | full_dataset: np.array, 18 | adj_matrices: np.array, 19 | data_dim: int, 20 | lag: int, 21 | num_nodes: int, 22 | causal_decoder, 23 | tcsf, 24 | disable_inst: bool = False, 25 | likelihood_loss: str = 'flow', 26 | sparsity_factor: float = 20, 27 | num_workers: int = 16, 28 | batch_size: int = 256, 29 | matrix_temperature: float = 0.25, 30 | aggregated_graph: bool = False, 31 | ignore_self_connections: bool = False, 32 | threeway_graph_dist: bool = True, 33 | skip_auglag_epochs: int = 0, 34 | training_procedure: str = 'auglag', 35 | training_config=None, 36 | init_logits=[0, 0], 37 | use_all_for_val=False, 38 | shuffle=True): 39 | 40 | self.aggregated_graph = aggregated_graph 41 | self.ignore_self_connections = ignore_self_connections 42 | self.skip_auglag_epochs = skip_auglag_epochs 43 | self.init_logits = init_logits 44 | self.disable_inst = disable_inst 45 | super().__init__(full_dataset=full_dataset, 46 | adj_matrices=adj_matrices, 47 | data_dim=data_dim, 48 | lag=lag, 49 | num_workers=num_workers, 50 | batch_size=batch_size, 51 | aggregated_graph=self.aggregated_graph, 52 | use_all_for_val=use_all_for_val, 53 | shuffle=shuffle) 54 | self.num_nodes = num_nodes 55 | self.matrix_temperature = matrix_temperature 56 | self.threeway_graph_dist = threeway_graph_dist 57 | self.training_procedure = training_procedure 58 | 59 | print("Number of fragments:", self.num_fragments) 60 | print("Number of samples:", self.total_samples) 61 | 62 | assert likelihood_loss in ['mse', 'flow'] 63 | self.likelihood_loss = likelihood_loss 64 | 65 | self.sparsity_factor = sparsity_factor 66 | self.graphs = [] 67 | 68 | self.causal_decoder = causal_decoder 69 | self.tcsf = tcsf 70 | 71 | self.initialize_graph() 72 | 73 | if self.training_procedure == 'auglag': 74 | self.training_config = training_config 75 | if training_config is None: 76 | self.training_config = AugLagLRConfig() 77 | if self.skip_auglag_epochs > 0: 78 | print( 79 | f"Not performing augmented lagrangian optimization for the first {self.skip_auglag_epochs} 
epochs...") 80 | self.disabled_epochs = range(self.skip_auglag_epochs) 81 | else: 82 | self.disabled_epochs = None 83 | self.lr_scheduler = AugLagLR(config=self.training_config) 84 | self.loss_calc = AugLagLossCalculator(init_alpha=self.training_config.init_alpha, 85 | init_rho=self.training_config.init_rho) 86 | 87 | def initialize_graph(self): 88 | if self.threeway_graph_dist: 89 | self.adj_matrix = TemporalAdjacencyMatrix( 90 | input_dim=self.num_nodes, 91 | lag=self.lag, 92 | tau_gumbel=self.matrix_temperature, 93 | init_logits=self.init_logits, 94 | disable_inst=self.disable_inst) 95 | else: 96 | self.adj_matrix = TwoWayTemporalAdjacencyMatrix(input_dim=self.num_nodes, 97 | lag=self.lag, 98 | tau_gumbel=self.matrix_temperature, 99 | init_logits=self.init_logits, 100 | disable_inst=self.disable_inst) 101 | 102 | def forward(self): 103 | raise NotImplementedError 104 | 105 | def compute_loss_terms(self, X_history: torch.Tensor, x_current: torch.Tensor, G: torch.Tensor): 106 | 107 | # ******************* graph prior ********************* 108 | graph_sparsity_term = self.sparsity_factor * \ 109 | temporal_graph_sparsity(G) 110 | 111 | # dagness factors 112 | if self.training_procedure == 'auglag': 113 | dagness_penalty = dag_penalty_notears(G[0]) 114 | 115 | prior_term = graph_sparsity_term 116 | prior_term /= self.num_fragments 117 | 118 | # **************** graph entropy ********************** 119 | 120 | entropy_term = -self.adj_matrix.entropy()/self.num_fragments 121 | 122 | # ************* likelihood loss *********************** 123 | 124 | batch_size = X_history.shape[0] 125 | expanded_G = G.unsqueeze(0).repeat(batch_size, 1, 1, 1) 126 | 127 | X_input = torch.cat((X_history, x_current.unsqueeze(1)), dim=1) 128 | x_pred = self.causal_decoder(X_input, expanded_G) 129 | 130 | mse_loss = self.compute_mse(x_current, x_pred) 131 | mape_loss = self.compute_mape(x_current, x_pred) 132 | 133 | if self.likelihood_loss == 'mse': 134 | likelihood_term = mse_loss 135 | 136 | elif self.likelihood_loss == 'flow': 137 | batch, num_nodes, data_dim = x_current.shape 138 | log_prob = self.tcsf.log_prob(X_input=(x_current - x_pred).view(batch, num_nodes*data_dim), 139 | X_history=X_history, 140 | A=expanded_G).sum(-1) 141 | likelihood_term = -torch.mean(log_prob) 142 | 143 | loss_terms = { 144 | 'graph_sparsity': graph_sparsity_term, 145 | 'dagness_penalty': dagness_penalty, 146 | 'graph_prior': prior_term, 147 | 148 | 'graph_entropy': entropy_term, 149 | 150 | 'mse': mse_loss, 151 | 'mape': mape_loss, 152 | 'likelihood': likelihood_term 153 | } 154 | return loss_terms 155 | 156 | def compute_loss(self, X_history, x_current, idx): 157 | # sample G 158 | 159 | G = self.adj_matrix.sample_A() 160 | loss_terms = self.compute_loss_terms( 161 | X_history=X_history, 162 | x_current=x_current, 163 | G=G) 164 | 165 | total_loss = loss_terms['likelihood'] +\ 166 | loss_terms['graph_prior'] +\ 167 | loss_terms['graph_entropy'] 168 | 169 | return total_loss, loss_terms, None 170 | 171 | def training_step(self, batch, batch_idx): 172 | 173 | X_history, x_current, _, idx, _ = batch 174 | loss, loss_terms, _ = self.compute_loss(X_history, x_current, idx) 175 | 176 | loss = self.loss_calc( 177 | loss, loss_terms['dagness_penalty']/self.num_fragments) 178 | self.log_dict(loss_terms, on_epoch=True) 179 | 180 | loss_terms['loss'] = loss 181 | return loss_terms 182 | 183 | def validation_func(self, X_history, x_current, adj_matrix, G, idx): 184 | batch_size = X_history.shape[0] 185 | 186 | loss, loss_terms, _ = 
self.compute_loss(X_history, x_current, idx) 187 | 188 | G = G.detach().cpu().numpy() 189 | adj_matrix = adj_matrix.detach().cpu().numpy() 190 | 191 | if self.aggregated_graph: 192 | G = to_time_aggregated_graph_np(G) 193 | if self.ignore_self_connections: 194 | G = zero_out_diag_np(G) 195 | 196 | shd_loss = compute_shd(adj_matrix, G, aggregated_graph=True) 197 | shd_loss = torch.Tensor([shd_loss]) 198 | f1 = f1_score(adj_matrix.flatten(), G.flatten()) 199 | else: 200 | mask = adj_matrix != G 201 | 202 | shd_loss = np.sum(mask)/batch_size 203 | shd_inst = np.sum(mask[:, 0])/batch_size 204 | shd_lag = np.sum(mask[:, 1:])/batch_size 205 | 206 | # shd_loss, shd_inst, shd_lag = compute_shd(adj_matrix, G) 207 | tp = np.logical_and(adj_matrix == 1, adj_matrix == G) 208 | fp = np.logical_and(adj_matrix != 1, G == 1) 209 | fn = np.logical_and(adj_matrix != 0, G == 0) 210 | 211 | f1_inst = 2*np.sum(tp[:, 0]) / (2*np.sum(tp[:, 0]) + 212 | np.sum(fp[:, 0]) + np.sum(fn[:, 0])) 213 | f1_lag = 2*np.sum(tp[:, 1:]) / (2*np.sum(tp[:, 1:]) + 214 | np.sum(fp[:, 1:]) + np.sum(fn[:, 1:])) 215 | 216 | # f1_inst = f1_score(get_off_diagonal(adj_matrix[:, 0]).flatten(), get_off_diagonal(G[:, 0]).flatten()) 217 | # f1_lag = f1_score(adj_matrix[:, 1:].flatten(), G[:, 1:].flatten()) 218 | shd_loss = torch.Tensor([shd_loss]) 219 | shd_inst = torch.Tensor([shd_inst]) 220 | shd_lag = torch.Tensor([shd_lag]) 221 | 222 | if not self.aggregated_graph: 223 | loss_terms['shd_inst'] = shd_inst 224 | loss_terms['shd_lag'] = shd_lag 225 | loss_terms['f1_lag'] = f1_lag 226 | loss_terms['f1_inst'] = f1_inst 227 | else: 228 | loss_terms['f1'] = f1 229 | 230 | loss_terms['shd_loss'] = shd_loss 231 | 232 | loss_terms['val_loss'] = loss 233 | 234 | for key, item in loss_terms.items(): 235 | self.log(key, item) 236 | 237 | return loss_terms 238 | 239 | def validation_step(self, batch, batch_idx): 240 | X_history, x_current, adj_matrix, idx, _ = batch 241 | 242 | batch_size = X_history.shape[0] 243 | G = self.adj_matrix.sample_A() 244 | expanded_G = G.unsqueeze(0).repeat(batch_size, 1, 1, 1) 245 | 246 | loss_terms = self.validation_func( 247 | X_history, x_current, adj_matrix, expanded_G, idx) 248 | 249 | probs = self.adj_matrix.get_adj_matrix(do_round=False) 250 | probs = probs.unsqueeze(0).repeat(batch_size, 1, 1, 1) 251 | if self.aggregated_graph: 252 | probs = to_time_aggregated_scores_torch(probs) 253 | if self.ignore_self_connections: 254 | probs = zero_out_diag_torch(probs) 255 | 256 | return loss_terms 257 | 258 | def configure_optimizers(self): 259 | """Set the learning rates for different sets of parameters.""" 260 | modules = { 261 | "functional_relationships": self.causal_decoder, 262 | "vardist": self.adj_matrix, 263 | "noise_dist": self.tcsf, 264 | } 265 | 266 | parameter_list = [ 267 | { 268 | "params": module.parameters(), 269 | "lr": self.training_config.lr_init_dict[name], 270 | "name": name, 271 | } 272 | for name, module in modules.items() if module is not None 273 | ] 274 | 275 | # Check that all modules are added to the parameter list 276 | check_modules = set(modules.values()) 277 | for module in self.parameters(recurse=False): 278 | assert module in check_modules, f"Module {module} not in module list" 279 | 280 | return torch.optim.Adam(parameter_list) 281 | 282 | def configure_callbacks(self): 283 | """Create a callback for the auglag callback.""" 284 | if self.training_procedure == 'auglag': 285 | return [AuglagLRCallback(self.lr_scheduler, log_auglag=True, disabled_epochs=self.disabled_epochs)] 286 | return 
None # should not happen 287 | 288 | def predict_step(self, batch, batch_idx, dataloader_idx=0): 289 | 290 | X_full, adj_matrix, _ = batch 291 | batch_size = X_full.shape[0] 292 | 293 | probs = self.adj_matrix.get_adj_matrix( 294 | do_round=False).unsqueeze(0).repeat(batch_size, 1, 1, 1) 295 | 296 | if self.aggregated_graph: 297 | probs = to_time_aggregated_scores_torch(probs) 298 | if self.ignore_self_connections: 299 | probs = zero_out_diag_torch(probs) 300 | # G = torch.bernoulli(probs) 301 | G = (probs >= 0.5).long() 302 | return G, probs, adj_matrix 303 | -------------------------------------------------------------------------------- /src/model/generate_model.py: -------------------------------------------------------------------------------- 1 | from omegaconf import DictConfig 2 | from hydra.utils import instantiate 3 | from src.modules.TemporalConditionalSplineFlow import TemporalConditionalSplineFlow 4 | 5 | 6 | def generate_model(cfg: DictConfig): 7 | lag = cfg.lag 8 | num_nodes = cfg.num_nodes 9 | data_dim = cfg.data_dim 10 | num_workers = cfg.num_workers 11 | aggregated_graph = cfg.aggregated_graph 12 | 13 | if cfg.model in ['pcmci', 'varlingam', 'dynotears']: 14 | trainer = instantiate(cfg.trainer, 15 | num_workers=num_workers, 16 | lag=lag, 17 | num_nodes=num_nodes, 18 | data_dim=data_dim, 19 | aggregated_graph=aggregated_graph) 20 | else: 21 | multi_graph = cfg.model == 'mcd' 22 | if multi_graph: 23 | num_graphs = cfg.trainer.num_graphs 24 | if 'decoder' in cfg: 25 | # generate the decoder 26 | if not multi_graph: 27 | causal_decoder = instantiate(cfg.decoder, 28 | lag=lag, 29 | num_nodes=num_nodes, 30 | data_dim=data_dim) 31 | else: 32 | causal_decoder = instantiate(cfg.decoder, 33 | lag=lag, 34 | num_nodes=num_nodes, 35 | data_dim=data_dim, 36 | num_graphs=num_graphs) 37 | if 'likelihood_loss' in cfg.trainer and cfg.trainer.likelihood_loss == 'flow': 38 | # create hypernet 39 | if not multi_graph: 40 | hypernet = instantiate(cfg.hypernet, 41 | lag=lag, 42 | num_nodes=num_nodes, 43 | data_dim=data_dim) 44 | else: 45 | hypernet = instantiate(cfg.hypernet, 46 | lag=lag, 47 | num_nodes=num_nodes, 48 | data_dim=data_dim, 49 | num_graphs=num_graphs) 50 | tcsf = TemporalConditionalSplineFlow(hypernet=hypernet) 51 | else: 52 | hypernet = None 53 | tcsf = None 54 | 55 | # create auglag config 56 | if cfg.trainer.training_procedure == 'auglag': 57 | training_config = instantiate(cfg.auglag_config) 58 | 59 | if cfg.model == 'rhino': 60 | trainer = instantiate(cfg.trainer, 61 | num_workers=num_workers, 62 | lag=lag, 63 | num_nodes=num_nodes, 64 | data_dim=data_dim, 65 | causal_decoder=causal_decoder, 66 | tcsf=tcsf, 67 | training_config=training_config, 68 | aggregated_graph=aggregated_graph) 69 | elif cfg.model == 'mcd': 70 | trainer = instantiate(cfg.trainer, 71 | num_workers=num_workers, 72 | lag=lag, 73 | num_nodes=num_nodes, 74 | data_dim=data_dim, 75 | num_graphs=num_graphs, 76 | causal_decoder=causal_decoder, 77 | tcsf=tcsf, 78 | training_config=training_config, 79 | aggregated_graph=aggregated_graph) 80 | 81 | return trainer 82 | -------------------------------------------------------------------------------- /src/modules/CausalDecoder.py: -------------------------------------------------------------------------------- 1 | import lightning.pytorch as pl 2 | from torch import nn 3 | import torch 4 | from src.utils.torch_utils import generate_fully_connected 5 | 6 | 7 | class CausalDecoder(pl.LightningModule): 8 | 9 | def __init__(self, 10 | data_dim: int, 11 | lag: int, 12 | 
num_nodes: int, 13 | embedding_dim: int = None, 14 | skip_connection: bool = False, 15 | linear: bool = False 16 | ): 17 | super().__init__() 18 | 19 | if embedding_dim is None: 20 | embedding_dim = num_nodes * data_dim 21 | 22 | self.embedding_dim = embedding_dim 23 | self.data_dim = data_dim 24 | self.lag = lag 25 | self.num_nodes = num_nodes 26 | self.linear = linear 27 | 28 | if not self.linear: 29 | self.embeddings = nn.Parameter(( 30 | torch.randn(self.lag + 1, self.num_nodes, 31 | self.embedding_dim, device=self.device) * 0.01 32 | ), requires_grad=True) # shape (lag+1, num_nodes, embedding_dim) 33 | 34 | input_dim = 2*self.embedding_dim 35 | self.nn_size = max(4 * num_nodes, self.embedding_dim, 64) 36 | 37 | self.f = generate_fully_connected( 38 | input_dim=input_dim, 39 | output_dim=num_nodes*data_dim, # potentially num_nodes 40 | hidden_dims=[self.nn_size, self.nn_size], 41 | non_linearity=nn.LeakyReLU, 42 | activation=nn.Identity, 43 | device=self.device, 44 | normalization=nn.LayerNorm, 45 | res_connection=skip_connection, 46 | ) 47 | 48 | self.g = generate_fully_connected( 49 | input_dim=self.embedding_dim+self.data_dim, 50 | output_dim=self.embedding_dim, 51 | hidden_dims=[self.nn_size, self.nn_size], 52 | non_linearity=nn.LeakyReLU, 53 | activation=nn.Identity, 54 | device=self.device, 55 | normalization=nn.LayerNorm, 56 | res_connection=skip_connection, 57 | ) 58 | 59 | else: 60 | self.w = nn.Parameter( 61 | torch.randn(self.lag+1, self.num_nodes, self.num_nodes, device=self.device)*0.5, requires_grad=True 62 | ) 63 | 64 | def forward(self, X_input: torch.Tensor, A: torch.Tensor, embeddings: torch.Tensor = None): 65 | """ 66 | Args: 67 | X_input: input data of shape (batch, lag+1, num_nodes, data_dim) 68 | A: adjacency matrix of shape (batch, (lag+1), num_nodes, num_nodes) 69 | """ 70 | 71 | assert len(X_input.shape) == 4 72 | 73 | batch, L, num_nodes, data_dim = X_input.shape 74 | lag = L-1 75 | 76 | if not self.linear: 77 | # ensure we have the correct shape 78 | assert (A.shape[0] == batch and A.shape[1] == lag+1 and A.shape[2] == 79 | num_nodes and A.shape[2] == num_nodes) 80 | 81 | if embeddings is None: 82 | E = self.embeddings.expand( 83 | X_input.shape[0], -1, -1, -1 84 | ) 85 | else: 86 | E = embeddings 87 | 88 | X_in = torch.cat((X_input, E), dim=-1) 89 | # X_in: (batch, lag+1, num_nodes, embedding_dim+data_dim) 90 | X_enc = self.g(X_in) 91 | 92 | A_temp = A.flip([1]) 93 | 94 | # get the parents of X 95 | X_sum = torch.einsum("blij,blio->bjo", A_temp, 96 | X_enc) # / num_nodes 97 | 98 | X_sum = torch.cat([X_sum, E[:, 0, :, :]], dim=-1) 99 | # (batch, num_nodes, embedding_dim) 100 | # pass through f network to get the predictions 101 | 102 | group_mask = torch.eye(num_nodes*data_dim).to(self.device) 103 | # (batch, num_nodes, data_dim) 104 | return torch.sum(self.f(X_sum)*group_mask, dim=-1).unsqueeze(-1) 105 | return torch.einsum("lij,blio->bjo", (self.w * A[0]).flip([0]), X_input) 106 | -------------------------------------------------------------------------------- /src/modules/LinearCausalGraph.py: -------------------------------------------------------------------------------- 1 | 2 | import lightning.pytorch as pl 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class LinearCausalGraph(pl.LightningModule): 8 | 9 | def __init__( 10 | self, 11 | lag: int, 12 | input_dim: int 13 | ): 14 | """ 15 | Args: 16 | input_dim: dimension. 17 | tau_gumbel: temperature used for gumbel softmax sampling. 
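lag: The lag of the model; the learned weight tensor w has shape (lag+1, input_dim, input_dim), with slice 0 holding the instantaneous weights (whose diagonal is masked to zero).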
18 | """ 19 | super().__init__() 20 | 21 | self.lag = lag 22 | self.w = nn.Parameter( 23 | torch.zeros(self.lag+1, input_dim, input_dim, device=self.device), requires_grad=True 24 | ) 25 | self.I = torch.arange(input_dim) 26 | self.mask = torch.ones((self.lag+1, input_dim, input_dim)) 27 | self.mask[0, self.I, self.I] = 0 28 | self.input_dim = input_dim 29 | 30 | def get_w(self) -> torch.Tensor: 31 | """ 32 | Returns the matrix. Ensures that the instantaneous matrix has zero in the diagonals 33 | """ 34 | return self.w * self.mask.to(self.device) 35 | -------------------------------------------------------------------------------- /src/modules/MixtureSelectionLogits.py: -------------------------------------------------------------------------------- 1 | import lightning.pytorch as pl 2 | from torch import nn 3 | import torch 4 | import torch.nn.functional as F 5 | import torch.distributions as td 6 | 7 | 8 | class MixtureSelectionLogits(pl.LightningModule): 9 | 10 | def __init__( 11 | self, 12 | num_samples: int, 13 | num_graphs: int, 14 | tau: float = 1.0 15 | ): 16 | 17 | super().__init__() 18 | self.num_graphs = num_graphs 19 | self.num_samples = num_samples 20 | self.graph_select_logits = nn.Parameter(( 21 | torch.ones(self.num_graphs, self.num_samples, 22 | device=self.device) * 0.01 23 | ), requires_grad=True) 24 | self.tau = tau 25 | 26 | def manual_set_mixture_indices(self, idx, mixture_idx): 27 | """ 28 | Use this function to manually set the mixture index. 29 | Mainly used for diagnostic/ablative purposes 30 | """ 31 | 32 | with torch.no_grad(): 33 | self.graph_select_logits[:, idx] = -10 34 | self.graph_select_logits[mixture_idx, idx] = 10 35 | self.graph_select_logits.requires_grad_(False) 36 | 37 | def set_logits(self, idx, logits): 38 | """ 39 | Use this function to manually set the logits. 
40 | Used in the E step of the EM implementation 41 | """ 42 | with torch.no_grad(): 43 | self.graph_select_logits[:, idx] = logits.transpose(0, -1) 44 | 45 | def reset_parameters(self): 46 | with torch.no_grad(): 47 | self.graph_select_logits[:] = torch.ones(self.num_graphs, 48 | self.num_samples, 49 | device=self.device) * 0.01 50 | 51 | def turn_off_grad(self): 52 | self.graph_select_logits.requires_grad_(False) 53 | 54 | def turn_on_grad(self): 55 | self.graph_select_logits.requires_grad_(True) 56 | 57 | def get_probs(self, idx): 58 | return F.softmax(self.graph_select_logits[:, idx]/self.tau, dim=0) 59 | 60 | def get_mixture_indices(self, idx): 61 | return torch.argmax(self.graph_select_logits[:, idx], dim=0) 62 | 63 | def entropy(self, idx): 64 | logits = self.graph_select_logits[:, idx]/self.tau 65 | dist = td.Categorical(logits=logits.transpose(0, -1)) 66 | entropy = dist.entropy().sum() 67 | 68 | return entropy/idx.shape[0] 69 | -------------------------------------------------------------------------------- /src/modules/MultiCausalDecoder.py: -------------------------------------------------------------------------------- 1 | import lightning.pytorch as pl 2 | from torch import nn 3 | import torch 4 | from src.modules.MultiEmbedding import MultiEmbedding 5 | from src.utils.torch_utils import generate_fully_connected 6 | 7 | 8 | class MultiCausalDecoder(pl.LightningModule): 9 | 10 | def __init__(self, 11 | data_dim: int, 12 | lag: int, 13 | num_nodes: int, 14 | num_graphs: int, 15 | embedding_dim: int = None, 16 | skip_connection: bool = False, 17 | linear: bool = False, 18 | dropout_p: float = 0.0 19 | ): 20 | 21 | super().__init__() 22 | 23 | if embedding_dim is not None: 24 | self.embedding_dim = embedding_dim 25 | else: 26 | self.embedding_dim = num_nodes * data_dim 27 | 28 | self.data_dim = data_dim 29 | self.lag = lag 30 | self.num_nodes = num_nodes 31 | self.num_graphs = num_graphs 32 | self.dropout_p = dropout_p 33 | self.linear = linear 34 | 35 | input_dim = 2*self.embedding_dim 36 | 37 | if not self.linear: 38 | self.nn_size = max(4 * num_nodes, self.embedding_dim, 64) 39 | 40 | self.f = generate_fully_connected( 41 | input_dim=input_dim, 42 | output_dim=num_nodes*data_dim, # potentially num_nodes 43 | hidden_dims=[self.nn_size, self.nn_size], 44 | non_linearity=nn.LeakyReLU, 45 | activation=nn.Identity, 46 | device=self.device, 47 | normalization=nn.LayerNorm, 48 | res_connection=skip_connection, 49 | ) 50 | 51 | self.g = generate_fully_connected( 52 | input_dim=self.embedding_dim+self.data_dim, 53 | output_dim=self.embedding_dim, 54 | hidden_dims=[self.nn_size, self.nn_size], 55 | non_linearity=nn.LeakyReLU, 56 | activation=nn.Identity, 57 | device=self.device, 58 | normalization=nn.LayerNorm, 59 | res_connection=skip_connection, 60 | ) 61 | 62 | self.cd_embeddings = MultiEmbedding(num_nodes=self.num_nodes, 63 | lag=self.lag, 64 | num_graphs=self.num_graphs, 65 | embedding_dim=self.embedding_dim) 66 | 67 | else: 68 | self.w = nn.Parameter( 69 | torch.randn(self.num_graphs, self.lag+1, self.num_nodes, self.num_nodes, device=self.device)*0.5, requires_grad=True 70 | ) 71 | 72 | def forward(self, X_input: torch.Tensor, A: torch.Tensor): 73 | """ 74 | Args: 75 | X_input: input data of shape (batch, lag+1, num_nodes, data_dim) 76 | A: adjacency matrix of shape (num_graphs, lag+1, num_nodes, num_nodes) 77 | """ 78 | 79 | assert len(X_input.shape) == 4 80 | 81 | batch, L, num_nodes, data_dim = X_input.shape 82 | lag = L-1 83 | 84 | if not self.linear: 85 | E = 
self.cd_embeddings.get_embeddings() 86 | 87 | # reshape X to the correct shape 88 | A = A.unsqueeze(0).expand((batch, -1, -1, -1, -1)) 89 | E = E.unsqueeze(0).expand((batch, -1, -1, -1, -1)) 90 | X_input = X_input.unsqueeze(1).expand( 91 | (-1, self.num_graphs, -1, -1, -1)) 92 | 93 | # ensure we have the correct shape 94 | assert (A.shape[0] == batch and A.shape[1] == self.num_graphs and 95 | A.shape[2] == lag + 1 and A.shape[3] == num_nodes and 96 | A.shape[4] == num_nodes) 97 | assert (E.shape[0] == batch and E.shape[1] == self.num_graphs 98 | and E.shape[2] == lag+1 and E.shape[3] == num_nodes 99 | and E.shape[4] == self.embedding_dim) 100 | 101 | X_in = torch.cat((X_input, E), dim=-1) 102 | # X_in: (batch, lag+1, num_nodes, embedding_dim+data_dim) 103 | X_enc = self.g(X_in) 104 | A_temp = A.flip([2]) 105 | # get the parents of X 106 | X_sum = torch.einsum("bnlij,bnlio->bnjo", A_temp, X_enc) 107 | 108 | X_sum = torch.cat([X_sum, E[:, :, 0, :, :]], dim=-1) 109 | # (batch, num_graphs, num_nodes, embedding_dim) 110 | # pass through f network to get the predictions 111 | group_mask = torch.eye(num_nodes*data_dim).to(self.device) 112 | 113 | # (batch, num_graphs, num_nodes, data_dim) 114 | return torch.sum(self.f(X_sum)*group_mask, dim=-1).unsqueeze(-1) 115 | 116 | return torch.einsum("klij,blio->bkjo", (self.w * A).flip([1]), X_input) 117 | # return self.f(X_sum) 118 | -------------------------------------------------------------------------------- /src/modules/MultiEmbedding.py: -------------------------------------------------------------------------------- 1 | import lightning.pytorch as pl 2 | from torch import nn 3 | import torch 4 | 5 | 6 | class MultiEmbedding(pl.LightningModule): 7 | 8 | def __init__( 9 | self, 10 | num_nodes: int, 11 | lag: int, 12 | num_graphs: int, 13 | embedding_dim: int 14 | ): 15 | 16 | super().__init__() 17 | self.lag = lag 18 | # Assertion lag > 0 19 | assert lag > 0 20 | self.num_nodes = num_nodes 21 | self.num_graphs = num_graphs 22 | self.embedding_dim = embedding_dim 23 | 24 | self.lag_embeddings = nn.Parameter(( 25 | torch.randn(self.num_graphs, self.lag, self.num_nodes, 26 | self.embedding_dim, device=self.device) * 0.01 27 | ), requires_grad=True) 28 | 29 | self.inst_embeddings = nn.Parameter(( 30 | torch.randn(self.num_graphs, 1, self.num_nodes, 31 | self.embedding_dim, device=self.device) * 0.01 32 | ), requires_grad=True) 33 | 34 | def turn_off_inst_grad(self): 35 | self.inst_embeddings.requires_grad_(False) 36 | 37 | def turn_on_inst_grad(self): 38 | self.inst_embeddings.requires_grad_(True) 39 | 40 | def get_embeddings(self): 41 | return torch.cat((self.inst_embeddings, self.lag_embeddings), dim=1) 42 | -------------------------------------------------------------------------------- /src/modules/MultiLinearCausalGraph.py: -------------------------------------------------------------------------------- 1 | 2 | import lightning.pytorch as pl 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class MultiLinearCausalGraph(pl.LightningModule): 8 | 9 | def __init__( 10 | self, 11 | lag: int, 12 | input_dim: int, 13 | num_graphs: int 14 | ): 15 | """ 16 | Args: 17 | input_dim: dimension. 18 | tau_gumbel: temperature used for gumbel softmax sampling. 
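lag: The lag of the model. num_graphs: The number of graphs in the mixture; the weight tensor w has shape (num_graphs, lag+1, input_dim, input_dim), with the instantaneous diagonals masked to zero.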
19 | """ 20 | super().__init__() 21 | 22 | self.num_graphs = num_graphs 23 | self.lag = lag 24 | self.w = nn.Parameter( 25 | torch.randn(self.num_graphs, self.lag+1, input_dim, input_dim, device=self.device)*0.5, requires_grad=True 26 | ) 27 | self.I = torch.arange(input_dim) 28 | self.mask = torch.ones( 29 | (self.num_graphs, self.lag+1, input_dim, input_dim)) 30 | self.mask[:, 0, self.I, self.I] = 0 31 | self.input_dim = input_dim 32 | 33 | def get_w(self) -> torch.Tensor: 34 | """ 35 | Returns the matrix. Ensures that the instantaneous matrix has zero in the diagonals 36 | """ 37 | return self.w * self.mask.to(self.device) 38 | -------------------------------------------------------------------------------- /src/modules/MultiTemporalHyperNet.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Tuple 2 | import lightning.pytorch as pl 3 | from torch import nn 4 | import torch 5 | from src.modules.MultiEmbedding import MultiEmbedding 6 | from src.utils.torch_utils import generate_fully_connected 7 | 8 | 9 | class MultiTemporalHyperNet(pl.LightningModule): 10 | 11 | def __init__(self, 12 | order: str, 13 | lag: int, 14 | data_dim: int, 15 | num_nodes: int, 16 | num_graphs: int, 17 | embedding_dim: int = None, 18 | skip_connection: bool = False, 19 | num_bins: int = 8, 20 | dropout_p: float = 0.0 21 | ): 22 | 23 | super().__init__() 24 | 25 | if embedding_dim is not None: 26 | self.embedding_dim = embedding_dim 27 | else: 28 | self.embedding_dim = num_nodes * data_dim 29 | 30 | self.data_dim = data_dim 31 | self.lag = lag 32 | self.order = order 33 | self.num_bins = num_bins 34 | self.num_nodes = num_nodes 35 | self.num_graphs = num_graphs 36 | self.dropout_p = dropout_p 37 | 38 | if self.order == "quadratic": 39 | self.param_dim = [ 40 | self.num_bins, 41 | self.num_bins, 42 | (self.num_bins - 1), 43 | ] # this is for quadratic order conditional spline flow 44 | elif self.order == "linear": 45 | self.param_dim = [ 46 | self.num_bins, 47 | self.num_bins, 48 | (self.num_bins - 1), 49 | self.num_bins, 50 | ] # this is for linear order conditional spline flow 51 | 52 | self.total_param = sum(self.param_dim) 53 | input_dim = 2*self.embedding_dim 54 | 55 | self.nn_size = max(4 * num_nodes, self.embedding_dim, 64) 56 | 57 | self.f = generate_fully_connected( 58 | input_dim=input_dim, 59 | output_dim=self.total_param, # potentially num_nodes 60 | hidden_dims=[self.nn_size, self.nn_size], 61 | non_linearity=nn.LeakyReLU, 62 | activation=nn.Identity, 63 | device=self.device, 64 | normalization=nn.LayerNorm, 65 | res_connection=skip_connection, 66 | ) 67 | 68 | self.g = generate_fully_connected( 69 | input_dim=self.embedding_dim+self.data_dim, 70 | output_dim=self.embedding_dim, 71 | hidden_dims=[self.nn_size, self.nn_size], 72 | non_linearity=nn.LeakyReLU, 73 | activation=nn.Identity, 74 | device=self.device, 75 | normalization=nn.LayerNorm, 76 | res_connection=skip_connection, 77 | ) 78 | 79 | self.th_embeddings = MultiEmbedding(num_nodes=self.num_nodes, 80 | lag=self.lag, 81 | num_graphs=self.num_graphs, 82 | embedding_dim=self.embedding_dim) 83 | 84 | def forward(self, X: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, ...]: 85 | """ 86 | Args: 87 | X: A dict consisting of two keys, "A" is the adjacency matrix with shape [batch, lag+1, num_nodes, num_nodes] 88 | and "X" is the history data with shape (batch, lag, num_nodes, data_dim). 89 | 90 | Returns: 91 | A tuple of parameters with shape [N_batch, num_cts_node*param_dim_each]. 
92 | The length of tuple is len(self.param_dim), 93 | """ 94 | 95 | # assert "A" in X and "X" in X and len( 96 | # X) == 2, "The key for input can only contain two keys, 'A', 'X'." 97 | 98 | A = X["A"] 99 | X_in = X["X"] 100 | 101 | E = self.th_embeddings.get_embeddings() 102 | # X["embeddings"] 103 | 104 | batch, lag, num_nodes, _ = X_in.shape 105 | 106 | # reshape X to the correct shape 107 | A = A.unsqueeze(0).expand((batch, -1, -1, -1, -1)) 108 | E = E.unsqueeze(0).expand((batch, -1, -1, -1, -1)) 109 | X_in = X_in.unsqueeze(1).expand((-1, self.num_graphs, -1, -1, -1)) 110 | 111 | # ensure we have the correct shape 112 | assert (A.shape[0] == batch and A.shape[1] == self.num_graphs and 113 | A.shape[2] == lag + 1 and A.shape[3] == num_nodes and 114 | A.shape[4] == num_nodes) 115 | assert (E.shape[0] == batch and E.shape[1] == self.num_graphs 116 | and E.shape[2] == lag+1 and E.shape[3] == num_nodes 117 | and E.shape[4] == self.embedding_dim) 118 | 119 | # shape [batch_size, num_graphs, lag, num_nodes, embedding_size] 120 | E_lag = E[:, :, 1:, :, :] 121 | 122 | # reshape X to the correct shape 123 | X_in = torch.cat((X_in, E_lag), dim=-1) 124 | 125 | # X_in: (batch, num_graphs, lag, num_nodes, embedding_dim + data_dim) 126 | 127 | X_enc = self.g(X_in) # (batch, lag, num_nodes, embedding_dim) 128 | 129 | # get the parents of X 130 | # (batch, num_graphs, num_nodes, embedding_dim) 131 | A_temp = A[:, :, 1:].flip([2]) 132 | 133 | X_sum = torch.einsum("bnlij,bnlio->bnjo", A_temp, X_enc) # / num_nodes 134 | 135 | X_sum = torch.cat((X_sum, E[:, :, 0, :, :]), dim=-1) 136 | 137 | # pass through f network to get the parameters 138 | params = self.f(X_sum) # (batch, num_graphs, num_nodes, total_params) 139 | 140 | # pass multiple graph parameters along batch dimension 141 | params = params.reshape(batch*self.num_graphs, num_nodes, -1) 142 | 143 | param_list = torch.split(params, self.param_dim, dim=-1) 144 | # a list of tensor with shape [batch*num_graphs, num_nodes*each_param] 145 | return tuple( 146 | param.reshape([-1, num_nodes * param.shape[-1]]) for param in param_list) 147 | -------------------------------------------------------------------------------- /src/modules/TemporalConditionalSplineFlow.py: -------------------------------------------------------------------------------- 1 | import lightning.pytorch as pl 2 | from torch import nn 3 | import torch 4 | 5 | import pyro.distributions as distrib 6 | from pyro.distributions.transforms.spline import ConditionalSpline 7 | from pyro.distributions import constraints 8 | from pyro.distributions.torch_transform import TransformModule 9 | 10 | 11 | class AffineDiagonalPyro(TransformModule): 12 | """ 13 | This creates a diagonal affine transformation compatible with pyro transforms 14 | """ 15 | 16 | domain = constraints.real 17 | codomain = constraints.real 18 | bijective = True 19 | 20 | def __init__(self, input_dim: int): 21 | super().__init__(cache_size=1) 22 | self.dim = input_dim 23 | self.a = nn.Parameter(torch.ones(input_dim), requires_grad=True) 24 | self.b = nn.Parameter(torch.zeros(input_dim), requires_grad=True) 25 | 26 | def _call(self, x: torch.Tensor) -> torch.Tensor: 27 | """ 28 | Forward method 29 | Args: 30 | x: tensor with shape [batch, input_dim] 31 | Returns: 32 | Transformed inputs 33 | """ 34 | return self.a.exp().unsqueeze(0) * x + self.b.unsqueeze(0) 35 | 36 | def _inverse(self, y: torch.Tensor) -> torch.Tensor: 37 | """ 38 | Reverse method 39 | Args: 40 | y: tensor with shape [batch, input] 41 | Returns: 42 | Reversed 
input 43 | """ 44 | return (-self.a).exp().unsqueeze(0) * (y - self.b.unsqueeze(0)) 45 | 46 | def log_abs_det_jacobian(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 47 | _, _ = x, y 48 | return self.a.unsqueeze(0) 49 | 50 | 51 | class TemporalConditionalSplineFlow(pl.LightningModule): 52 | 53 | def __init__(self, 54 | hypernet 55 | ): 56 | super().__init__() 57 | 58 | self.hypernet = hypernet 59 | self.num_bins = self.hypernet.num_bins 60 | self.order = self.hypernet.order 61 | 62 | def log_prob(self, 63 | X_input: torch.Tensor, 64 | X_history: torch.Tensor, 65 | A: torch.Tensor, 66 | embeddings: torch.Tensor = None): 67 | """ 68 | Args: 69 | X_input: input data of shape (batch, num_nodes, data_dim) 70 | X_history: input data of shape (batch, lag, num_nodes, data_dim) 71 | A: adjacency matrix of shape (batch, lag+1, num_nodes, num_nodes) 72 | embeddings: embeddings (batch, lag+1, num_nodes, embedding_dim) 73 | """ 74 | 75 | assert len(X_history.shape) == 4 76 | 77 | _, _, num_nodes, data_dim = X_history.shape 78 | 79 | # if not self.trainable_embeddings: 80 | transform = nn.ModuleList( 81 | [ 82 | ConditionalSpline( 83 | self.hypernet, input_dim=num_nodes*data_dim, count_bins=self.num_bins, order=self.order, bound=5.0 84 | ) 85 | # AffineDiagonalPyro(input_dim=self.num_nodes*self.data_dim), 86 | # Spline(input_dim=self.num_nodes*self.data_dim, count_bins=self.num_bins, order="quadratic", bound=5.0), 87 | # AffineDiagonalPyro(input_dim=self.num_nodes*self.data_dim), 88 | # Spline(input_dim=self.num_nodes*self.data_dim, count_bins=self.num_bins, order="quadratic", bound=5.0), 89 | # AffineDiagonalPyro(input_dim=self.num_nodes*self.data_dim) 90 | ] 91 | ) 92 | base_dist = distrib.Normal( 93 | torch.zeros(num_nodes*data_dim, device=self.device), torch.ones( 94 | num_nodes*data_dim, device=self.device) 95 | ) 96 | # else: 97 | 98 | context_dict = {"X": X_history, "A": A, "embeddings": embeddings} 99 | flow_dist = distrib.ConditionalTransformedDistribution( 100 | base_dist, transform).condition(context_dict) 101 | return flow_dist.log_prob(X_input) 102 | 103 | def sample(self, 104 | N_samples: int, 105 | X_history: torch.Tensor, 106 | W: torch.Tensor, 107 | embeddings: torch.Tensor): 108 | assert len(X_history.shape) == 4 109 | 110 | batch, _, num_nodes, data_dim = X_history.shape 111 | 112 | base_dist = distrib.Normal( 113 | torch.zeros(num_nodes*data_dim, device=self.device), torch.ones( 114 | num_nodes*data_dim, device=self.device) 115 | ) 116 | 117 | context_dict = {"X": X_history, "A": W, "embeddings": embeddings} 118 | flow_dist = distrib.ConditionalTransformedDistribution( 119 | base_dist, self.transform).condition(context_dict) 120 | return flow_dist.sample([N_samples, batch]) 121 | -------------------------------------------------------------------------------- /src/modules/TemporalHyperNet.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Tuple 2 | import lightning.pytorch as pl 3 | from torch import nn 4 | import torch 5 | from src.utils.torch_utils import generate_fully_connected 6 | 7 | class TemporalHyperNet(pl.LightningModule): 8 | 9 | def __init__(self, 10 | order: str, 11 | lag: int, 12 | data_dim: int, 13 | num_nodes: int, 14 | embedding_dim: int = None, 15 | skip_connection: bool = False, 16 | num_bins: int = 8): 17 | super().__init__() 18 | 19 | if embedding_dim is None: 20 | embedding_dim = num_nodes * data_dim 21 | 22 | self.embedding_dim = embedding_dim 23 | self.data_dim = data_dim 24 | self.lag = 
lag 25 | self.order = order 26 | self.num_bins = num_bins 27 | self.num_nodes = num_nodes 28 | 29 | if self.order == "quadratic": 30 | self.param_dim = [ 31 | self.num_bins, 32 | self.num_bins, 33 | (self.num_bins - 1), 34 | ] # this is for quadratic order conditional spline flow 35 | elif self.order == "linear": 36 | self.param_dim = [ 37 | self.num_bins, 38 | self.num_bins, 39 | (self.num_bins - 1), 40 | self.num_bins, 41 | ] # this is for linear order conditional spline flow 42 | 43 | self.total_param = sum(self.param_dim) 44 | input_dim = 2*self.embedding_dim 45 | self.nn_size = max(4 * num_nodes, self.embedding_dim, 64) 46 | 47 | self.f = generate_fully_connected( 48 | input_dim=input_dim, 49 | output_dim=self.total_param, # potentially num_nodes 50 | hidden_dims=[self.nn_size, self.nn_size], 51 | non_linearity=nn.LeakyReLU, 52 | activation=nn.Identity, 53 | device=self.device, 54 | normalization=nn.LayerNorm, 55 | res_connection=skip_connection, 56 | ) 57 | 58 | self.g = generate_fully_connected( 59 | input_dim=self.embedding_dim+self.data_dim, 60 | output_dim=self.embedding_dim, 61 | hidden_dims=[self.nn_size, self.nn_size], 62 | non_linearity=nn.LeakyReLU, 63 | activation=nn.Identity, 64 | device=self.device, 65 | normalization=nn.LayerNorm, 66 | res_connection=skip_connection, 67 | ) 68 | 69 | self.embeddings = nn.Parameter(( 70 | torch.randn(self.lag + 1, self.num_nodes, 71 | self.embedding_dim, device=self.device) * 0.01 72 | ), requires_grad=True) # shape (lag+1, num_nodes, embedding_dim) 73 | 74 | def forward(self, X: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, ...]: 75 | """ 76 | Args: 77 | X: A dict consisting of two keys, "A" is the adjacency matrix with shape [batch, lag+1, num_nodes, num_nodes] 78 | and "X" is the history data with shape (batch, lag, num_nodes, data_dim). 79 | 80 | Returns: 81 | A tuple of parameters with shape [N_batch, num_cts_node*param_dim_each]. 82 | The length of tuple is len(self.param_dim), 83 | """ 84 | 85 | # assert "A" in X and "X" in X and len( 86 | # X) == 2, "The key for input can only contain two keys, 'A', 'X'." 
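# Note: besides "A" and "X", the context dict built in TemporalConditionalSplineFlow also carries an
# "embeddings" entry; it may be None, in which case this module's own self.embeddings are used below.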
87 | 88 | A = X["A"] 89 | X_in = X["X"] 90 | embeddings = X["embeddings"] 91 | batch, lag, num_nodes, _ = X_in.shape 92 | 93 | # ensure we have the correct shape 94 | assert (A.shape[0] == batch and A.shape[1] == lag + 95 | 1 and A.shape[2] == num_nodes and A.shape[3] == num_nodes) 96 | 97 | if embeddings is None: 98 | E = self.embeddings.expand( 99 | X_in.shape[0], -1, -1, -1 100 | ) 101 | else: 102 | E = embeddings 103 | 104 | # shape [batch_size, lag, num_nodes, embedding_size] 105 | E_lag = E[..., 1:, :, :] 106 | 107 | X_in = torch.cat((X_in, E_lag), dim=-1) 108 | # X_in: (batch, lag, num_nodes, embedding_dim + data_dim) 109 | 110 | X_enc = self.g(X_in) # (batch, lag, num_nodes, embedding_dim) 111 | 112 | # get the parents of X 113 | # (batch, num_nodes, embedding_dim) 114 | A_temp = A[:, 1:].flip([1]) 115 | 116 | X_sum = torch.einsum("blij,blio->bjo", A_temp, X_enc) # / num_nodes 117 | 118 | X_sum = torch.cat((X_sum, E[..., 0, :, :]), dim=-1) 119 | 120 | # pass through f network to get the parameters 121 | params = self.f(X_sum) # (batch, num_nodes, total_params) 122 | 123 | param_list = torch.split(params, self.param_dim, dim=-1) 124 | # a list of tensor with shape [batch, num_nodes*each_param] 125 | return tuple( 126 | param.reshape([-1, num_nodes * param.shape[-1]]) for param in param_list) 127 | -------------------------------------------------------------------------------- /src/modules/adjacency_matrices/AdjMatrix.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import torch 3 | 4 | 5 | class AdjMatrix(ABC): 6 | """ 7 | Adjacency matrix interface for DECI 8 | """ 9 | 10 | @abstractmethod 11 | def get_adj_matrix(self, do_round: bool = True) -> torch.Tensor: 12 | """ 13 | Returns the adjacency matrix. 14 | """ 15 | raise NotImplementedError() 16 | 17 | @abstractmethod 18 | def entropy(self) -> torch.Tensor: 19 | """ 20 | Computes the entropy of distribution q. In this case 0. 21 | """ 22 | raise NotImplementedError() 23 | 24 | @abstractmethod 25 | def sample_A(self) -> torch.Tensor: 26 | """ 27 | Returns the adjacency matrix. 
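Unlike get_adj_matrix, which returns edge probabilities, implementations draw a (hard) sample of the graph from the variational distribution, e.g. via the Gumbel-softmax trick.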
28 | """ 29 | raise NotImplementedError() 30 | -------------------------------------------------------------------------------- /src/modules/adjacency_matrices/MultiTemporalAdjacencyMatrix.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import lightning.pytorch as pl 4 | import torch 5 | import torch.distributions as td 6 | import torch.nn.functional as F 7 | from torch import nn 8 | from src.modules.adjacency_matrices.AdjMatrix import AdjMatrix 9 | 10 | 11 | class MultiTemporalAdjacencyMatrix(pl.LightningModule, AdjMatrix): 12 | def __init__( 13 | self, 14 | num_nodes: int, 15 | lag: int, 16 | num_graphs: int, 17 | tau_gumbel: float = 1.0, 18 | threeway: bool = True, 19 | init_logits: Optional[List[float]] = None, 20 | disable_inst: bool = False 21 | ): 22 | super().__init__() 23 | self.lag = lag 24 | self.tau_gumbel = tau_gumbel 25 | self.threeway = threeway 26 | self.num_nodes = num_nodes 27 | # Assertion lag > 0 28 | assert lag > 0 29 | self.num_graphs = num_graphs 30 | self.disable_inst = disable_inst 31 | 32 | if self.threeway: 33 | self.logits_inst = nn.Parameter( 34 | torch.zeros((3, self.num_graphs, (num_nodes * (num_nodes - 1)) // 2), 35 | device=self.device), 36 | requires_grad=True 37 | ) 38 | self.lower_idxs = torch.unbind( 39 | torch.tril_indices(self.num_nodes, self.num_nodes, 40 | offset=-1, device=self.device), 0 41 | ) 42 | else: 43 | self.logits_inst = nn.Parameter( 44 | torch.zeros((2, self.num_graphs, num_nodes, num_nodes), 45 | device=self.device), 46 | requires_grad=True 47 | ) 48 | 49 | self.logits_lag = nn.Parameter(torch.zeros((2, self.num_graphs, lag, num_nodes, num_nodes), 50 | device=self.device), 51 | requires_grad=True) 52 | self.init_logits = init_logits 53 | # Set the init_logits if not None 54 | if self.init_logits is not None: 55 | if self.threeway: 56 | self.logits_inst.data[2, :, ...] = self.init_logits[0] 57 | else: 58 | self.logits_inst.data[1, :, ...] = self.init_logits[0] 59 | self.logits_lag.data[0, :, ...] = self.init_logits[1] 60 | 61 | def zero_out_diagonal(self, matrix: torch.Tensor): 62 | # matrix: (num_graphs, num_nodes, num_nodes) 63 | N = matrix.shape[1] 64 | I = torch.arange(N).to(self.device) 65 | matrix = matrix.clone() 66 | matrix[:, I, I] = 0 67 | return matrix 68 | 69 | def _triangular_vec_to_matrix(self, vec): 70 | """ 71 | Given an array of shape (k, N, n(n-1)/2) where k in {2, 3}, creates a matrix of shape 72 | (N, n, n) where the lower triangular is filled from vec[0, :] and the upper 73 | triangular is filled from vec[1, :]. 74 | """ 75 | N = vec.shape[1] 76 | output = torch.zeros( 77 | (N, self.num_nodes, self.num_nodes), device=self.device) 78 | output[:, self.lower_idxs[0], self.lower_idxs[1]] = vec[0, :, ...] 79 | output[:, self.lower_idxs[1], self.lower_idxs[0]] = vec[1, :, ...] 80 | return output 81 | 82 | def get_adj_matrix(self, do_round: bool = False) -> torch.Tensor: 83 | """ 84 | Returns the adjacency matrix. 85 | """ 86 | probs = torch.zeros((self.num_graphs, self.lag + 1, self.num_nodes, self.num_nodes), 87 | device=self.device) 88 | 89 | if not self.disable_inst: 90 | inst_probs = F.softmax(self.logits_inst, dim=0) 91 | if self.threeway: 92 | # (3, n(n-1)/2) probabilities 93 | inst_probs = self._triangular_vec_to_matrix(inst_probs) 94 | else: 95 | inst_probs = self.zero_out_diagonal(inst_probs[1, ...]) 96 | 97 | # Generate simultaneous adj matrix 98 | # shape (input_dim, input_dim) 99 | probs[:, 0, ...] 
= inst_probs 100 | 101 | # Generate lagged adj matrix 102 | # shape (lag, input_dim, input_dim) 103 | probs[:, 1:, ...] = F.softmax(self.logits_lag, dim=0)[1, ...] 104 | if do_round: 105 | return probs.round() 106 | 107 | return probs 108 | 109 | def entropy(self) -> torch.Tensor: 110 | """ 111 | Computes the entropy of distribution q. In this case 0. 112 | """ 113 | 114 | if not self.disable_inst: 115 | if self.threeway: 116 | dist = td.Categorical( 117 | logits=self.logits_inst[:, :].transpose(0, -1)) 118 | entropies_inst = dist.entropy().sum() 119 | else: 120 | dist = td.Categorical( 121 | logits=self.logits_inst[1, ...] - self.logits_inst[0, ...]) 122 | I = torch.arange(self.num_nodes) 123 | dist_diag = td.Categorical( 124 | logits=self.logits_inst[1, :, I, I] - self.logits_inst[0, :, I, I]) 125 | entropies = dist.entropy() 126 | diag_entropy = dist_diag.entropy() 127 | entropies_inst = entropies.sum() - diag_entropy.sum() 128 | else: 129 | entropies_inst = 0 130 | 131 | dist_lag = td.Independent(td.Bernoulli( 132 | logits=self.logits_lag[1, :] - self.logits_lag[0, :]), 3) 133 | entropies_lag = dist_lag.entropy().sum() 134 | 135 | return entropies_lag + entropies_inst 136 | 137 | def sample_A(self) -> torch.Tensor: 138 | """ 139 | This samples the adjacency matrix from the variational distribution. This uses the gumbel softmax trick and returns 140 | hard samples. This can be done by (1) sample instantaneous adj matrix using self.logits, (2) sample lagged adj matrix using self.logits_lag. 141 | """ 142 | # Create adj matrix to avoid concatenation 143 | adj_sample = torch.zeros( 144 | (self.num_graphs, self.lag + 1, self.num_nodes, self.num_nodes), device=self.device 145 | ) # shape ( lag+1, input_dim, input_dim) 146 | 147 | if not self.disable_inst: 148 | if self.threeway: 149 | # Sample instantaneous adj matrix 150 | adj_sample[:, 0, ...] = self._triangular_vec_to_matrix( 151 | F.gumbel_softmax(self.logits_inst, 152 | tau=self.tau_gumbel, 153 | hard=True, 154 | dim=0) 155 | ) # shape (N, input_dim, input_dim) 156 | else: 157 | sample = F.gumbel_softmax( 158 | self.logits_inst, tau=self.tau_gumbel, hard=True, dim=0)[1, ...] 159 | adj_sample[:, 0, ...] = self.zero_out_diagonal(sample) 160 | 161 | # Sample lagged adj matrix 162 | # shape (N, lag, input_dim, input_dim) 163 | adj_sample[:, 1:, ...] = F.gumbel_softmax(self.logits_lag, 164 | tau=self.tau_gumbel, 165 | hard=True, 166 | dim=0)[1, ...] 167 | return adj_sample 168 | 169 | def turn_off_inst_grad(self): 170 | self.logits_inst.requires_grad_(False) 171 | 172 | def turn_on_inst_grad(self): 173 | self.logits_inst.requires_grad_(True) 174 | -------------------------------------------------------------------------------- /src/modules/adjacency_matrices/TemporalAdjacencyMatrix.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import torch.distributions as td 5 | import torch.nn.functional as F 6 | from torch import nn 7 | from src.modules.adjacency_matrices.ThreeWayGraphDist import ThreeWayGraphDist 8 | 9 | 10 | class TemporalAdjacencyMatrix(ThreeWayGraphDist): 11 | """ 12 | This class adapts the ThreeWayGraphDist s.t. it supports the variational distributions for temporal adjacency matrix. 13 | 14 | The principle is to follow the logic as ThreeWayGraphDist. The implementation has two separate part: 15 | (1) categorical distribution for instantaneous adj (see ThreeWayGraphDist); (2) Bernoulli distribution for lagged 16 | adj. 
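The resulting adjacency tensor has shape (lag+1, num_nodes, num_nodes), where slice 0 holds the instantaneous graph and slices 1..lag hold the lagged graphs.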
Note that for lagged adj, we do not need to follow the logic from ThreeWayGraphDist, since lagged adj allows diagonal elements 17 | and does not have to be a DAG. Therefore, it is simpler to directly model it with Bernoulli distribution. 18 | """ 19 | 20 | def __init__( 21 | self, 22 | input_dim: int, 23 | lag: int, 24 | tau_gumbel: float = 1.0, 25 | disable_inst: bool = False, 26 | init_logits: Optional[List[float]] = None, 27 | ): 28 | """ 29 | This creates an instance of variational distribution for temporal adjacency matrix. 30 | Args: 31 | input_dim: The number of nodes for adjacency matrix. 32 | lag: The lag for the temporal adj matrix. The adj matrix has the shape (lag+1, num_nodes, num_nodes). 33 | tau_gumbel: The temperature for the gumbel softmax sampling. 34 | init_logits: The initialized logits value. If None, then use the default initlized logits (value 0). Otherwise, 35 | init_logits[0] indicates the non-existence edge logit for instantaneous effect, and init_logits[1] indicates the 36 | non-existence edge logit for lagged effect. E.g. if we want a dense initialization, one choice is (-7, -0.5) 37 | """ 38 | # Call parent init method 39 | super().__init__(input_dim=input_dim, tau_gumbel=tau_gumbel) 40 | # Create a separate logit for lagged adj 41 | # The logits_lag are initialized to zero with shape (2, lag, input_dim, input_dim). 42 | # logits_lag[0,...] indicates the logit prob for no edges, and logits_lag[1,...] indicates the logit for edge existence. 43 | self.lag = lag 44 | # Assertion lag > 0 45 | assert lag > 0 46 | self.logits_lag = nn.Parameter(torch.zeros( 47 | (2, lag, input_dim, input_dim), device=self.device), requires_grad=True) 48 | self.init_logits = init_logits 49 | self.disable_inst = disable_inst 50 | # Set the init_logits if not None 51 | if self.init_logits is not None: 52 | self.logits.data[2, ...] = self.init_logits[0] 53 | self.logits_lag.data[0, ...] = self.init_logits[1] 54 | 55 | def get_adj_matrix(self, do_round: bool = False) -> torch.Tensor: 56 | """ 57 | This returns the temporal adjacency matrix of edge probability. 58 | Args: 59 | do_round: Whether to round the edge probabilities. 60 | 61 | Returns: 62 | The adjacency matrix with shape [lag+1, num_nodes, num_nodes]. 63 | """ 64 | 65 | # Create the temporal adj matrix 66 | probs = torch.zeros(self.lag + 1, self.input_dim, 67 | self.input_dim, device=self.device) 68 | # Generate simultaneous adj matrix 69 | if not self.disable_inst: 70 | probs[0, ...] = super().get_adj_matrix( 71 | do_round=do_round) # shape (input_dim, input_dim) 72 | # Generate lagged adj matrix 73 | probs[1:, ...] = F.softmax(self.logits_lag, dim=0)[ 74 | 1, ...] # shape (lag, input_dim, input_dim) 75 | if do_round: 76 | return probs.round() 77 | 78 | return probs 79 | 80 | def entropy(self) -> torch.Tensor: 81 | """ 82 | This computes the entropy of the variational distribution. 83 | This can be done by (1) compute the entropy of instantaneous adj matrix(categorical, same as ThreeWayGraphDist), 84 | (2) compute the entropy of lagged adj matrix (Bernoulli dist), and (3) add them together. 85 | """ 86 | # Entropy for instantaneous dist, call super().entropy 87 | if not self.disable_inst: 88 | entropies_inst = super().entropy() 89 | else: 90 | entropies_inst = 0 91 | # Entropy for lagged dist 92 | # batch_shape [lag], event_shape [num_nodes, num_nodes] 93 | 94 | dist_lag = td.Independent(td.Bernoulli( 95 | logits=self.logits_lag[1, ...] 
- self.logits_lag[0, ...]), 2) 96 | entropies_lag = dist_lag.entropy().sum() 97 | # entropies_lag = dist_lag.entropy().mean() 98 | 99 | return entropies_lag + entropies_inst 100 | 101 | def sample_A(self) -> torch.Tensor: 102 | """ 103 | This samples the adjacency matrix from the variational distribution. This uses the gumbel softmax trick and returns 104 | hard samples. This can be done by (1) sample instantaneous adj matrix using self.logits, (2) sample lagged adj matrix using self.logits_lag. 105 | """ 106 | 107 | # Create adj matrix to avoid concatenation 108 | adj_sample = torch.zeros( 109 | self.lag + 1, self.input_dim, self.input_dim, device=self.device 110 | ) # shape (lag+1, input_dim, input_dim) 111 | 112 | # Sample instantaneous adj matrix 113 | if not self.disable_inst: 114 | adj_sample[0, ...] = self._triangular_vec_to_matrix( 115 | F.gumbel_softmax( 116 | self.logits, tau=self.tau_gumbel, hard=True, dim=0) 117 | ) # shape (input_dim, input_dim) 118 | # Sample lagged adj matrix 119 | adj_sample[1:, ...] = F.gumbel_softmax(self.logits_lag, tau=self.tau_gumbel, hard=True, dim=0)[ 120 | 1, ... 121 | ] # shape (lag, input_dim, input_dim) 122 | return adj_sample 123 | -------------------------------------------------------------------------------- /src/modules/adjacency_matrices/ThreeWayGraphDist.py: -------------------------------------------------------------------------------- 1 | 2 | import lightning.pytorch as pl 3 | import torch 4 | import torch.distributions as td 5 | import torch.nn.functional as F 6 | from torch import nn 7 | from src.modules.adjacency_matrices.AdjMatrix import AdjMatrix 8 | 9 | 10 | class ThreeWayGraphDist(AdjMatrix, pl.LightningModule): 11 | """ 12 | An alternative variational distribution for graph edges. For each pair of nodes x_i and x_j 13 | where i < j, we sample a three way categorical C_ij. If C_ij = 0, we sample the edge 14 | x_i -> x_j, if C_ij = 1, we sample the edge x_j -> x_i, and if C_ij = 2, there is no 15 | edge between these nodes. This variational distribution is faster to use than ENCO 16 | because it avoids any calls to `torch.stack`. 17 | 18 | Sampling is performed with `torch.gumbel_softmax(..., hard=True)` to give 19 | binary samples and a straight-through gradient estimator. 20 | """ 21 | 22 | def __init__( 23 | self, 24 | input_dim: int, 25 | tau_gumbel: float = 1.0, 26 | ): 27 | """ 28 | Args: 29 | input_dim: dimension. 30 | tau_gumbel: temperature used for gumbel softmax sampling. 31 | """ 32 | super().__init__() 33 | # We only use n(n-1)/2 random samples 34 | # For each edge, sample either A->B, B->A or no edge 35 | # We convert this to a proper adjacency matrix using torch.tril_indices 36 | self.logits = nn.Parameter( 37 | torch.zeros(3, (input_dim * (input_dim - 1)) // 2, device=self.device), requires_grad=True 38 | ) 39 | self.tau_gumbel = tau_gumbel 40 | self.input_dim = input_dim 41 | self.lower_idxs = torch.unbind( 42 | torch.tril_indices(self.input_dim, self.input_dim, 43 | offset=-1, device=self.device), 0 44 | ) 45 | 46 | def _triangular_vec_to_matrix(self, vec): 47 | """ 48 | Given an array of shape (k, n(n-1)/2) where k in {2, 3}, creates a matrix of shape 49 | (n, n) where the lower triangular is filled from vec[0, :] and the upper 50 | triangular is filled from vec[1, :]. 51 | """ 52 | output = torch.zeros( 53 | (self.input_dim, self.input_dim), device=self.device) 54 | output[self.lower_idxs[0], self.lower_idxs[1]] = vec[0, ...] 55 | output[self.lower_idxs[1], self.lower_idxs[0]] = vec[1, ...] 
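        # Worked example (illustrative): with input_dim = 3, torch.tril_indices(3, 3, offset=-1)
        # yields the index pairs (1, 0), (2, 0), (2, 1). The first assignment above therefore
        # fills the strictly lower triangle output[1, 0], output[2, 0], output[2, 1] from vec[0, :],
        # and the second fills the mirrored upper triangle output[0, 1], output[0, 2], output[1, 2]
        # from vec[1, :]. The "no edge" category vec[2, :] is never written, so those entries and
        # the diagonal stay zero.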
56 | return output 57 | 58 | def get_adj_matrix(self, do_round: bool = False) -> torch.Tensor: 59 | """ 60 | Returns the adjacency matrix of edge probabilities. 61 | """ 62 | probs = F.softmax(self.logits, dim=0) # (3, n(n-1)/2) probabilities 63 | out_probs = self._triangular_vec_to_matrix(probs) 64 | if do_round: 65 | return out_probs.round() 66 | return out_probs 67 | 68 | def entropy(self) -> torch.Tensor: 69 | """ 70 | Computes the entropy of distribution q, which is a collection of n(n-1) categoricals on 3 values. 71 | """ 72 | dist = td.Categorical(logits=self.logits.transpose(0, -1)) 73 | entropies = dist.entropy() 74 | return entropies.sum() 75 | # return entropies.mean() 76 | 77 | def sample_A(self) -> torch.Tensor: 78 | """ 79 | Sample an adjacency matrix from the variational distribution. It uses the gumbel_softmax trick, 80 | and returns hard samples (straight through gradient estimator). Adjacency returned always has 81 | zeros in its diagonal (no self loops). 82 | 83 | V1: Returns one sample to be used for the whole batch. 84 | """ 85 | sample = F.gumbel_softmax( 86 | self.logits, tau=self.tau_gumbel, hard=True, dim=0) # (3, n(n-1)/2) binary 87 | return self._triangular_vec_to_matrix(sample) 88 | -------------------------------------------------------------------------------- /src/modules/adjacency_matrices/TwoWayGraphDist.py: -------------------------------------------------------------------------------- 1 | import lightning.pytorch as pl 2 | import torch 3 | import torch.distributions as td 4 | import torch.nn.functional as F 5 | from torch import nn 6 | from src.modules.adjacency_matrices.AdjMatrix import AdjMatrix 7 | 8 | 9 | class TwoWayGraphDist(AdjMatrix, pl.LightningModule): 10 | """ 11 | Sampling is performed with `torch.gumbel_softmax(..., hard=True)` to give 12 | binary samples and a straight-through gradient estimator. 13 | """ 14 | 15 | def __init__( 16 | self, 17 | input_dim: int, 18 | tau_gumbel: float = 1.0 19 | ): 20 | """ 21 | Args: 22 | input_dim: dimension. 23 | tau_gumbel: temperature used for gumbel softmax sampling. 24 | """ 25 | super().__init__() 26 | # We only use n(n-1)/2 random samples 27 | # For each edge, sample either A->B, B->A or no edge 28 | # We convert this to a proper adjacency matrix using torch.tril_indices 29 | self.logits = nn.Parameter( 30 | torch.zeros((2, input_dim, input_dim), device=self.device), requires_grad=True 31 | ) 32 | self.tau_gumbel = tau_gumbel 33 | self.input_dim = input_dim 34 | 35 | def zero_out_diagonal(self, matrix: torch.Tensor): 36 | # matrix: (num_nodes, num_nodes) 37 | N = matrix.shape[0] 38 | I = torch.arange(N).to(self.device) 39 | matrix = matrix.clone() 40 | matrix[I, I] = 0 41 | return matrix 42 | 43 | def get_adj_matrix(self, do_round: bool = False) -> torch.Tensor: 44 | """ 45 | Returns the adjacency matrix of edge probabilities. 46 | """ 47 | probs = F.softmax(self.logits, dim=0)[1, ...] 48 | probs = self.zero_out_diagonal(probs) 49 | # probs = F.softmax(self.logits, dim=0) # (3, n(n-1)/2) probabilities 50 | 51 | if do_round: 52 | return probs.round() 53 | return probs 54 | 55 | def entropy(self) -> torch.Tensor: 56 | """ 57 | Computes the entropy of distribution q, which is a collection of n(n-1) categoricals on 3 values. 58 | """ 59 | dist = td.Categorical(logits=self.logits[1, ...] 
- self.logits[0, ...]) 60 | I = torch.arange(self.input_dim) 61 | dist_diag = td.Categorical( 62 | logits=self.logits[1, I, I] - self.logits[0, I, I]) 63 | entropies = dist.entropy() 64 | diag_entropy = dist_diag.entropy() 65 | return entropies.sum() - diag_entropy.sum() 66 | # return entropies.mean() 67 | 68 | def sample_A(self) -> torch.Tensor: 69 | """ 70 | Sample an adjacency matrix from the variational distribution. It uses the gumbel_softmax trick, 71 | and returns hard samples (straight through gradient estimator). Adjacency returned always has 72 | zeros in its diagonal (no self loops). 73 | 74 | V1: Returns one sample to be used for the whole batch. 75 | """ 76 | sample = F.gumbel_softmax( 77 | self.logits, tau=self.tau_gumbel, hard=True, dim=0)[1, ...] 78 | return self.zero_out_diagonal(sample) 79 | -------------------------------------------------------------------------------- /src/modules/adjacency_matrices/TwoWayTemporalAdjacencyMatrix.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import torch.distributions as td 5 | import torch.nn.functional as F 6 | from torch import nn 7 | from src.modules.adjacency_matrices.TwoWayGraphDist import TwoWayGraphDist 8 | 9 | 10 | class TwoWayTemporalAdjacencyMatrix(TwoWayGraphDist): 11 | """ 12 | This class adapts the TwoWayGraphDist s.t. it supports the variational distributions for temporal adjacency matrix. 13 | 14 | """ 15 | 16 | def __init__( 17 | self, 18 | input_dim: int, 19 | lag: int, 20 | tau_gumbel: float = 1.0, 21 | init_logits: Optional[List[float]] = None, 22 | disable_inst: bool = False 23 | ): 24 | """ 25 | This creates an instance of variational distribution for temporal adjacency matrix. 26 | Args: 27 | device: Device used. 28 | input_dim: The number of nodes for adjacency matrix. 29 | lag: The lag for the temporal adj matrix. The adj matrix has the shape (lag+1, num_nodes, num_nodes). 30 | tau_gumbel: The temperature for the gumbel softmax sampling. 31 | init_logits: The initialized logits value. If None, then use the default initlized logits (value 0). Otherwise, 32 | init_logits[0] indicates the non-existence edge logit for instantaneous effect, and init_logits[1] indicates the 33 | non-existence edge logit for lagged effect. E.g. if we want a dense initialization, one choice is (-7, -0.5) 34 | """ 35 | # Call parent init method, this will init a self.logits parameters for instantaneous effect. 36 | super().__init__(input_dim=input_dim, tau_gumbel=tau_gumbel) 37 | # Create a separate logit for lagged adj 38 | # The logits_lag are initialized to zero with shape (2, lag, input_dim, input_dim). 39 | # logits_lag[0,...] indicates the logit prob for no edges, and logits_lag[1,...] indicates the logit for edge existence. 40 | self.lag = lag 41 | # Assertion lag > 0 42 | assert lag > 0 43 | self.logits_lag = nn.Parameter(torch.zeros( 44 | (2, lag, input_dim, input_dim), device=self.device), requires_grad=True) 45 | self.init_logits = init_logits 46 | self.disable_inst = disable_inst 47 | # Set the init_logits if not None 48 | if self.init_logits is not None: 49 | self.logits.data[0, ...] = self.init_logits[0] 50 | self.logits_lag.data[0, ...] = self.init_logits[1] 51 | 52 | def get_adj_matrix(self, do_round: bool = False) -> torch.Tensor: 53 | """ 54 | This returns the temporal adjacency matrix of edge probability. 55 | Args: 56 | do_round: Whether to round the edge probabilities. 
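            Note: when disable_inst is True, the instantaneous slice (index 0) is left at zero
            and only the lagged slices are populated.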
57 | 58 | Returns: 59 | The adjacency matrix with shape [lag+1, num_nodes, num_nodes]. 60 | """ 61 | 62 | # Create the temporal adj matrix 63 | probs = torch.zeros(self.lag + 1, self.input_dim, 64 | self.input_dim, device=self.device) 65 | # Generate simultaneous adj matrix 66 | if not self.disable_inst: 67 | probs[0, ...] = super().get_adj_matrix( 68 | do_round=do_round) # shape (input_dim, input_dim) 69 | # Generate lagged adj matrix 70 | probs[1:, ...] = F.softmax(self.logits_lag, dim=0)[ 71 | 1, ...] # shape (lag, input_dim, input_dim) 72 | if do_round: 73 | return probs.round() 74 | 75 | return probs 76 | 77 | def entropy(self) -> torch.Tensor: 78 | """ 79 | This computes the entropy of the variational distribution. 80 | This can be done by (1) compute the entropy of instantaneous adj matrix(categorical, same as ThreeWayGraphDist), 81 | (2) compute the entropy of lagged adj matrix (Bernoulli dist), and (3) add them together. 82 | """ 83 | # Entropy for instantaneous dist, call super().entropy 84 | if not self.disable_inst: 85 | entropies_inst = super().entropy() 86 | else: 87 | entropies_inst = 0 88 | # Entropy for lagged dist 89 | # batch_shape [lag], event_shape [num_nodes, num_nodes] 90 | 91 | dist_lag = td.Independent(td.Bernoulli( 92 | logits=self.logits_lag[1, ...] - self.logits_lag[0, ...]), 2) 93 | entropies_lag = dist_lag.entropy().sum() 94 | # entropies_lag = dist_lag.entropy().mean() 95 | 96 | return entropies_lag + entropies_inst 97 | 98 | def sample_A(self) -> torch.Tensor: 99 | """ 100 | This samples the adjacency matrix from the variational distribution. This uses the gumbel softmax trick and returns 101 | hard samples. This can be done by (1) sample instantaneous adj matrix using self.logits, (2) sample lagged adj matrix using self.logits_lag. 102 | """ 103 | 104 | # Create adj matrix to avoid concatenation 105 | adj_sample = torch.zeros( 106 | self.lag + 1, self.input_dim, self.input_dim, device=self.device 107 | ) # shape (lag+1, input_dim, input_dim) 108 | 109 | # Sample instantaneous adj matrix 110 | if not self.disable_inst: 111 | adj_sample[0, ...] = self.zero_out_diagonal( 112 | F.gumbel_softmax(self.logits, tau=self.tau_gumbel, 113 | hard=True, dim=0)[1, ...] 114 | ) # shape (input_dim, input_dim) 115 | 116 | # Sample lagged adj matrix 117 | adj_sample[1:, ...] = F.gumbel_softmax(self.logits_lag, tau=self.tau_gumbel, hard=True, dim=0)[ 118 | 1, ... 
119 | ] # shape (lag, input_dim, input_dim) 120 | return adj_sample 121 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | # standard libraries 2 | import time 3 | import os 4 | 5 | import hydra 6 | import numpy as np 7 | import torch 8 | from omegaconf import DictConfig 9 | from lightning.pytorch import Trainer, seed_everything 10 | from lightning.pytorch.loggers import CSVLogger, WandbLogger 11 | from lightning.pytorch.callbacks import ModelCheckpoint 12 | from lightning.pytorch.utilities.model_summary import ModelSummary 13 | from lightning.pytorch.utilities import rank_zero_only 14 | import wandb 15 | 16 | from src.utils.config_utils import add_all_attributes, add_attribute, generate_unique_name 17 | from src.utils.data_utils.dataloading_utils import load_data, get_dataset_path, create_save_name 18 | from src.utils.data_utils.data_format_utils import zero_out_diag_np 19 | from src.utils.utils import write_results_to_disk 20 | from src.utils.metrics_utils import evaluate_results 21 | 22 | from src.model.generate_model import generate_model 23 | 24 | 25 | @hydra.main(version_base=None, config_path="../configs/", config_name="main.yaml") 26 | def run(cfg: DictConfig): 27 | if 'dataset' not in cfg: 28 | raise Exception("No dataset found in the config") 29 | 30 | dataset = cfg.dataset 31 | dataset_path = get_dataset_path(dataset) 32 | 33 | if dataset_path in cfg: 34 | model_config = cfg[dataset_path] 35 | else: 36 | raise Exception( 37 | "No model found in the config. Try running with option: python3 -m src.train +dataset= +=") 38 | 39 | cfg.dataset_dir = os.path.join(cfg.dataset_dir, dataset_path) 40 | 41 | add_all_attributes(cfg, model_config) 42 | train(cfg) 43 | 44 | 45 | def train(cfg): 46 | 47 | print("Running config:") 48 | print(cfg) 49 | dataset = cfg.dataset 50 | seed = int(cfg.random_seed) 51 | # set seed 52 | seed_everything(cfg.random_seed) 53 | 54 | X, adj_matrix, aggregated_graph, lag, data_dim = load_data( 55 | cfg.dataset, cfg.dataset_dir, cfg) 56 | 57 | add_attribute(cfg, 'lag', lag) 58 | add_attribute(cfg, 'aggregated_graph', aggregated_graph) 59 | add_attribute(cfg, 'num_nodes', X.shape[2]) 60 | add_attribute(cfg, 'data_dim', data_dim) 61 | 62 | # generate model 63 | model = generate_model(cfg) 64 | 65 | # pass the dataset 66 | model = model(full_dataset=X, 67 | adj_matrices=adj_matrix) 68 | 69 | model_name = cfg.model 70 | 71 | if 'use_indices' in cfg: 72 | f_path = os.path.join(cfg.dataset_dir, cfg.dataset, 73 | f'{cfg.use_indices}_seed={cfg.random_seed}.npy') 74 | mix_idx = torch.Tensor(np.load(f_path)) 75 | model.set_mixture_indices(mix_idx) 76 | 77 | training_needed = cfg.model in ['rhino', 'mcd'] 78 | unique_name = generate_unique_name(cfg) 79 | csv_logger = CSVLogger("logs", name=unique_name) 80 | wandb_logger = WandbLogger( 81 | name=unique_name, project=cfg.wandb_project, log_model=True) 82 | 83 | # log all hyperparameters 84 | 85 | if rank_zero_only.rank == 0: 86 | for key in cfg: 87 | wandb_logger.experiment.config[key] = cfg[key] 88 | csv_logger.log_hyperparams(cfg) 89 | 90 | if training_needed: 91 | # either val_loss or likelihood 92 | monitor_checkpoint_based_on = cfg.monitor_checkpoint_based_on 93 | ckpt_choice = 'best' 94 | 95 | checkpoint_callback = ModelCheckpoint( 96 | save_top_k=1, monitor=monitor_checkpoint_based_on, mode="min", save_last=True) 97 | 98 | if len(cfg.gpu) > 1: 99 | strategy = 
'ddp_find_unused_parameters_true' 100 | else: 101 | strategy = 'auto' 102 | 103 | if 'val_every_n_epochs' in cfg: 104 | val_every_n_epochs = cfg.val_every_n_epochs 105 | else: 106 | val_every_n_epochs = 1 107 | 108 | trainer = Trainer(max_epochs=cfg.num_epochs, 109 | accelerator="gpu", 110 | devices=cfg.gpu, 111 | precision=cfg.precision, 112 | logger=[csv_logger, wandb_logger], 113 | callbacks=[checkpoint_callback], 114 | strategy=strategy, 115 | enable_progress_bar=True, 116 | check_val_every_n_epoch=val_every_n_epochs) 117 | 118 | summary = ModelSummary(model, max_depth=10) 119 | print(summary) 120 | 121 | if cfg.watch_gradients: 122 | wandb_logger.watch(model) 123 | 124 | start_time = time.time() 125 | trainer.fit(model=model) 126 | end_time = time.time() 127 | 128 | print("Model trained in", str(end_time-start_time) + "s") 129 | 130 | else: 131 | if cfg.gpu != -1: 132 | print("WARNING: GPU specified, but baseline cannot use GPU.") 133 | trainer = Trainer(logger=[csv_logger, wandb_logger], 134 | accelerator='cpu') 135 | 136 | # get predictions 137 | 138 | full_dataloader = model.get_full_dataloader() 139 | if training_needed: 140 | model.eval() 141 | L = trainer.predict(model, full_dataloader, ckpt_path=ckpt_choice) 142 | else: 143 | L = trainer.predict(model, full_dataloader) 144 | 145 | predictions = [] 146 | scores = [] 147 | adj_matrix = [] 148 | for graph, prob, matrix in L: 149 | predictions.append(graph) 150 | scores.append(prob) 151 | adj_matrix.append(matrix) 152 | 153 | predictions = torch.concatenate(predictions, dim=0) 154 | scores = torch.concatenate(scores, dim=0) 155 | adj_matrix = torch.concatenate(adj_matrix, dim=0) 156 | 157 | predictions = predictions.detach().cpu().numpy() 158 | scores = scores.detach().cpu().numpy() 159 | adj_matrix = adj_matrix.detach().cpu().numpy() 160 | 161 | if 'dream3' in dataset: 162 | predictions = zero_out_diag_np(predictions) 163 | scores = zero_out_diag_np(scores) 164 | 165 | # save predictions 166 | if not os.path.exists(os.path.join('results', dataset)): 167 | os.makedirs(os.path.join('results', dataset)) 168 | 169 | if training_needed and ckpt_choice == 'best': 170 | if model_name == 'mcd': 171 | np.save(os.path.join('results', dataset, 172 | f'{model_name}_{checkpoint_callback.best_model_score.item()}_k{cfg.trainer.num_graphs}.npy'), scores) 173 | else: 174 | np.save(os.path.join('results', dataset, 175 | f'{model_name}_{checkpoint_callback.best_model_score.item()}.npy'), scores) 176 | else: 177 | np.save(os.path.join('results', dataset, f'{model_name}.npy'), scores) 178 | 179 | if model_name == 'mcd': 180 | true_cluster_indices, pred_cluster_indices = model.get_cluster_indices() 181 | else: 182 | pred_cluster_indices = None 183 | true_cluster_indices = None 184 | 185 | metrics = evaluate_results(scores=scores, 186 | adj_matrix=adj_matrix, 187 | predictions=predictions, 188 | aggregated_graph=aggregated_graph, 189 | true_cluster_indices=true_cluster_indices, 190 | pred_cluster_indices=pred_cluster_indices) 191 | # add the dataset name and model to the csv 192 | metrics['model'] = model_name + "_seed_" + str(seed) 193 | 194 | if model_name in ['pcmci', 'dynotears']: 195 | metrics['model'] += "_singlegraph_" + \ 196 | str(cfg.trainer.single_graph) + "_grouped_" + \ 197 | str(cfg.trainer.group_by_graph) 198 | if model_name == 'mcd': 199 | metrics['model'] += "_trueindex_" + \ 200 | str(cfg.trainer.use_correct_mixture_index) 201 | if 'use_indices' in cfg: 202 | metrics['model'] += '_' + cfg.use_indices 203 | if (model_name == ['rhino', 
'mcd']) and 'linear' in cfg.decoder and cfg.decoder.linear: 204 | metrics['model'] += '_linearmodel' 205 | if training_needed and ckpt_choice == 'best': 206 | metrics['best_loss'] = checkpoint_callback.best_model_score.item() 207 | metrics['dataset'] = dataset 208 | 209 | # write the results to wandb 210 | wandb.init() 211 | wandb.log(metrics) 212 | 213 | # write the results to the log directory 214 | write_results_to_disk(create_save_name(dataset, cfg), metrics) 215 | 216 | wandb.finish() 217 | 218 | 219 | if __name__ == "__main__": 220 | run() 221 | -------------------------------------------------------------------------------- /src/utils/causality_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Borrowed from github.com/microsoft/causica 3 | """ 4 | 5 | from typing import Any, Dict, Optional, Union 6 | 7 | import numpy as np 8 | import torch 9 | 10 | def convert_temporal_to_static_adjacency_matrix( 11 | adj_matrix: np.ndarray, conversion_type: str, fill_value: Union[float, int] = 0.0 12 | ) -> np.ndarray: 13 | """ 14 | This method convert the temporal adjacency matrix to a specified type of static adjacency. 15 | It supports two types of conversion: "full_time" and "auto_regressive". 16 | The conversion type determines the connections between time steps. 17 | "full_time" will convert the temporal adjacency matrix to a full-time static graph, where the connections between lagged time steps are preserved. 18 | "auto_regressive" will convert temporal adj to a static graph that only keeps the connections to the current time step. 19 | E.g. a temporal adj matrix with lag 2 is [A,B,C], where A,B and C are also adj matrices. "full_time" will convert this to 20 | [[A,B,C],[0,A,B],[0,0,A]]. "auto_regressive" will convert this to [[0,0,C],[0,0,B],[0,0,A]]. 21 | "fill_value" is used to specify the value to fill in the converted static adjacency matrix. Default is 0, but sometimes we may want 22 | other values. E.g. if we have a temporal soft prior with prior mask, then we may want to fill the converted prior mask with value 1 instead of 0, 23 | since the converted prior mask should never disable the blocks specifying the "arrow-against-time" in converted soft prior. 24 | Args: 25 | adj_matrix: The temporal adjacency matrix with shape [lag+1, from, to] or [N, lag+1, from, to]. 26 | conversion_type: The conversion type. It supports "full_time" and "auto_regressive". 27 | fill_value: The value used to fill the static adj matrix. The default is 0. 28 | 29 | Returns: static adjacency matrix with shape [(lag+1)*from, (lag+1)*to] or [N, (lag+1)*from, (lag+1)*to]. 30 | 31 | """ 32 | assert conversion_type in [ 33 | "full_time", 34 | "auto_regressive", 35 | ], f"The conversion_type {conversion_type} is not supported." 36 | if len(adj_matrix.shape) == 3: 37 | adj_matrix = adj_matrix[None, ...] 
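    # Illustrative example: with lag 1 the temporal matrix is [A, B], where A is the
    # instantaneous and B the lag-1 adjacency. "full_time" then yields the block matrix
    # [[A, B], [fill, A]], while "auto_regressive" yields [[fill, B], [fill, A]], i.e. only
    # the connections into the current time step are kept.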
# [1, lag+1, num_node, num_node] 38 | batch_dim, n_lag, n_nodes, _ = adj_matrix.shape # n_lag is lag+1 39 | if conversion_type == "full_time": 40 | block_fill_value = np.full((n_nodes, n_nodes), fill_value) 41 | else: 42 | block_fill_value = np.full( 43 | (batch_dim, n_lag * n_nodes, (n_lag - 1) * n_nodes), fill_value) 44 | 45 | if conversion_type == "full_time": 46 | static_adj = np.sum( 47 | np.stack([np.kron(np.diag(np.ones(n_lag - i), k=i), 48 | adj_matrix[:, i, :, :]) for i in range(n_lag)], axis=1), 49 | axis=1, 50 | ) # [N, n_lag*from, n_lag*to] 51 | static_adj += np.kron( 52 | np.tril(np.ones((batch_dim, n_lag, n_lag)), k=-1), block_fill_value 53 | ) # [N, n_lag*from, n_lag*to] 54 | 55 | if conversion_type == "auto_regressive": 56 | # Flip the temporal adj and concatenate to form one block column of the static. The flipping is needed due to the 57 | # format of converted adjacency matrix. E.g. temporal adj [A,B,C], where A is the instant adj matrix. Then, the converted adj 58 | # is [[[0,0,C],[0,0,B],[0,0,A]]]. The last column is the concatenation of flipped temporal adj. 59 | block_column = np.flip(adj_matrix, axis=1).reshape( 60 | -1, n_lag * n_nodes, n_nodes 61 | ) # [N, (lag+1)*num_node, num_node] 62 | # Static graph 63 | # [N, (lag+1)*num_node, (lag+1)*num_node] 64 | static_adj = np.concatenate((block_fill_value, block_column), axis=2) 65 | 66 | return np.squeeze(static_adj) 67 | 68 | 69 | def dag_pen_np(X): 70 | assert X.shape[0] == X.shape[1] 71 | X = torch.from_numpy(X) 72 | return (torch.trace(torch.matrix_exp(X)) - X.shape[0]).item() 73 | 74 | 75 | def approximate_maximal_acyclic_subgraph(adj_matrix: np.ndarray, n_samples: int = 10) -> np.ndarray: 76 | """ 77 | Compute an (approximate) maximal acyclic subgraph of a directed non-dag but removing at most 1/2 of the edges 78 | See Vazirani, Vijay V. Approximation algorithms. Vol. 1. Berlin: springer, 2001, Page 7; 79 | Also Hassin, Refael, and Shlomi Rubinstein. "Approximations for the maximum acyclic subgraph problem." 80 | Information processing letters 51.3 (1994): 133-140. 81 | Args: 82 | adj_matrix: adjacency matrix of a directed graph (may contain cycles) 83 | n_samples: number of the random permutations generated. Default is 10. 
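        Each sampled permutation defines a node ordering; keeping only the edges pointing
        forward (or only those pointing backward) with respect to that ordering is always
        acyclic, and the larger of the two subgraphs retains at least half of the edges.
        The densest such subgraph found across the sampled orderings is returned.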
84 | Returns: 85 | an adjacency matrix of the acyclic subgraph 86 | """ 87 | # assign each node with a order 88 | adj_dag = np.zeros_like(adj_matrix) 89 | for _ in range(n_samples): 90 | random_order = np.expand_dims( 91 | np.random.permutation(adj_matrix.shape[0]), 0) 92 | # subgraph with only forward edges defined by the assigned order 93 | adj_forward = ( 94 | (random_order.T > random_order).astype(int)) * adj_matrix 95 | # subgraph with only backward edges defined by the assigned order 96 | adj_backward = ( 97 | (random_order.T < random_order).astype(int)) * adj_matrix 98 | # return the subgraph with the least deleted edges 99 | adj_dag_n = adj_forward if adj_backward.sum() < adj_forward.sum() else adj_backward 100 | if adj_dag_n.sum() > adj_dag.sum(): 101 | adj_dag = adj_dag_n 102 | return adj_dag 103 | 104 | 105 | def int2binlist(i: int, n_bits: int): 106 | """ 107 | Convert integer to list of ints with values in {0, 1} 108 | """ 109 | assert i < 2**n_bits 110 | str_list = list(np.binary_repr(i, n_bits)) 111 | return [int(i) for i in str_list] 112 | 113 | 114 | def cpdag2dags(cp_mat: np.ndarray, samples: Optional[int] = None) -> np.ndarray: 115 | """ 116 | Compute all possible DAGs contained within a Markov equivalence class, given by a CPDAG 117 | Args: 118 | cp_mat: adjacency matrix containing both forward and backward edges for edges for which directionality is undetermined 119 | Returns: 120 | 3 dimensional tensor, where the first indexes all the possible DAGs 121 | """ 122 | assert len(cp_mat.shape) == 2 and cp_mat.shape[0] == cp_mat.shape[1] 123 | 124 | # matrix composed of just undetermined edges 125 | cycle_mat = (cp_mat == cp_mat.T) * cp_mat 126 | # return original matrix if there are no length-1 cycles 127 | if cycle_mat.sum() == 0: 128 | if dag_pen_np(cp_mat) != 0.0: 129 | cp_mat = approximate_maximal_acyclic_subgraph(cp_mat) 130 | return cp_mat[None, :, :] 131 | 132 | # matrix of determined edges 133 | cp_determined_subgraph = cp_mat - cycle_mat 134 | 135 | # prune cycles if the matrix of determined edges is not a dag 136 | if dag_pen_np(cp_determined_subgraph.copy()) != 0.0: 137 | cp_determined_subgraph = approximate_maximal_acyclic_subgraph( 138 | cp_determined_subgraph, 1000) 139 | 140 | # number of parent nodes for each node under the well determined matrix 141 | n_in_nodes = cp_determined_subgraph.sum(axis=0) 142 | 143 | # lower triangular version of cycles edges: only keep cycles in one direction. 144 | cycles_tril = np.tril(cycle_mat, k=-1) 145 | 146 | # indices of potential new edges 147 | undetermined_idx_mat = np.array(np.nonzero( 148 | cycles_tril)).T # (N_undedetermined, 2) 149 | 150 | # number of undetermined edges 151 | N_undetermined = int(cycles_tril.sum()) 152 | 153 | # choose random order for mask iteration 154 | max_dags = 2**N_undetermined 155 | 156 | if max_dags > 10000: 157 | print("The number of possible dags are too large (>10000), limit to 10000") 158 | max_dags = 10000 159 | 160 | if samples is None: 161 | samples = max_dags 162 | mask_indices = list(np.random.permutation(np.arange(max_dags))) 163 | 164 | # iterate over list of all potential combinations of new edges. 
0 represents keeping edge from upper triangular and 1 from lower triangular 165 | dag_list: list = [] 166 | while mask_indices and len(dag_list) < samples: 167 | 168 | mask_index = mask_indices.pop() 169 | mask = np.array(int2binlist(mask_index, N_undetermined)) 170 | 171 | # extract list of indices which our new edges are pointing into 172 | incoming_edges = np.take_along_axis( 173 | undetermined_idx_mat, mask[:, None], axis=1).squeeze() 174 | 175 | # check if multiple edges are pointing at same node 176 | _, unique_counts = np.unique( 177 | incoming_edges, return_index=False, return_inverse=False, return_counts=True) 178 | 179 | # check if new colider has been created by checkig if multiple edges point at same node or if new edge points at existing child node 180 | new_colider = np.any(unique_counts > 1) or np.any( 181 | n_in_nodes[incoming_edges] > 0) 182 | 183 | if not new_colider: 184 | # get indices of new edges by sampling from lower triangular mat and upper triangular according to indices 185 | edge_selection = undetermined_idx_mat.copy() 186 | edge_selection[mask == 0, :] = np.fliplr( 187 | edge_selection[mask == 0, :]) 188 | 189 | # add new edges to matrix and add to dag list 190 | new_dag = cp_determined_subgraph.copy() 191 | new_dag[(edge_selection[:, 0], edge_selection[:, 1])] = 1 192 | 193 | # Check for high order cycles 194 | if dag_pen_np(new_dag.copy()) == 0.0: 195 | dag_list.append(new_dag) 196 | # When all combinations of new edges create cycles, we will only keep determined ones 197 | if len(dag_list) == 0: 198 | dag_list.append(cp_determined_subgraph) 199 | 200 | return np.stack(dag_list, axis=0) 201 | -------------------------------------------------------------------------------- /src/utils/config_utils.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from omegaconf import DictConfig, open_dict 3 | from sklearn.model_selection import ParameterGrid 4 | 5 | 6 | def generate_unique_name(config): 7 | # generate unique name based on the config 8 | run_name = config['model']+"_"+config['dataset'] + \ 9 | "_"+"seed_"+str(config['random_seed'])+"_" 10 | run_name += datetime.now().strftime("%Y%m%d_%H_%M_%S") 11 | return run_name 12 | 13 | 14 | def read_optional(config, arg, default): 15 | if arg in config: 16 | return config[arg] 17 | return default 18 | 19 | 20 | def add_attribute(config: DictConfig, name, val): 21 | with open_dict(config): 22 | config[name] = val 23 | 24 | 25 | def add_all_attributes(cfg, cfg2): 26 | # add all attributes from cfg2 to cfg 27 | for key in cfg2: 28 | add_attribute(cfg, key, cfg2[key]) 29 | 30 | 31 | def build_subdictionary(hyperparameters, loop_hyperparameters): 32 | """ 33 | Given dictionary of hyperparameters (where some of the values may be lists) and a list of keys 34 | loop_hyperparameters, build a ParameterGrid 35 | 36 | """ 37 | # build sub dictionary of hyperparameters 38 | subparameters = dict( 39 | (k, hyperparameters[k]) for k in loop_hyperparameters if k in hyperparameters) 40 | subparameters = dict((k, [subparameters[k]]) if not isinstance( 41 | subparameters[k], list) else (k, subparameters[k]) for k in subparameters) 42 | 43 | subparameters = ParameterGrid(subparameters) 44 | 45 | return subparameters 46 | -------------------------------------------------------------------------------- /src/utils/data_gen/generate_perturb_syn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import 
yaml 4 | 5 | from data_generation_utils import generate_cts_temporal_data, set_random_seed, generate_temporal_graph 6 | import cdt 7 | import numpy as np 8 | import networkx as nx 9 | 10 | 11 | def calc_dist(adj_matrix): 12 | unique_matrices = np.unique(adj_matrix, axis=0) 13 | distances = [] 14 | for i in range(unique_matrices.shape[0]): 15 | for j in range(i): 16 | distances.append(cdt.metrics.SHD( 17 | unique_matrices[i], unique_matrices[j])) 18 | mean_dist = np.mean(distances) 19 | min_dist = np.min(distances) 20 | max_dist = np.max(distances) 21 | std_dist = np.std(distances) 22 | 23 | return mean_dist, std_dist, min_dist, max_dist 24 | 25 | 26 | def perturb_graph(G, p): 27 | retry_counter = 0 28 | while True: 29 | perturbation = np.random.binomial(1, p, size=G.shape) 30 | perturbed = np.logical_xor(G, perturbation).astype(int) 31 | # nxG = nx.from_numpy_matrix(perturbed[0], create_using=nx.DiGraph) 32 | nxG = nx.DiGraph(perturbed[0]) 33 | if nx.is_directed_acyclic_graph(nxG): 34 | break 35 | retry_counter += 1 36 | 37 | if retry_counter >= 200000: 38 | assert False, "Cannot generate DAG, try a lower value of p" 39 | 40 | return perturbed.astype(int) 41 | 42 | 43 | def main(config_file): 44 | 45 | # read the yaml file 46 | with open(config_file, 'r', encoding="utf-8") as f: 47 | data_config = yaml.load(f, Loader=yaml.FullLoader) 48 | 49 | series_length = int(data_config["num_timesteps"]) 50 | burnin_length = int(data_config["burnin_length"]) 51 | num_samples = int(data_config["num_samples"]) 52 | disable_inst = bool(data_config["disable_inst"]) 53 | graph_type = [data_config["inst_graph_type"], 54 | data_config["lag_graph_type"]] 55 | p_array = data_config['p_array'] 56 | 57 | connection_factor = 1 58 | 59 | # in this case, all inst and lagged adj are dags and have total edges per adj = 2*num_nodes 60 | # graph_type = ["SF", "SF"] 61 | # graph_config = [{"m":2, "directed":True}, {"m":2, "directed":True}] 62 | 63 | lag = int(data_config["lag"]) 64 | is_history_dep = bool(data_config["history_dep_noise"]) 65 | 66 | noise_level = float(data_config["noise_level"]) 67 | function_type = data_config["function_type"] 68 | noise_function_type = data_config["noise_function_type"] 69 | base_noise_type = data_config["base_noise_type"] 70 | 71 | save_dir = data_config["save_dir"] 72 | N = data_config["num_nodes"] 73 | num_graphs = data_config["num_graphs"] 74 | 75 | for seed in data_config['random_seed']: 76 | seed = int(seed) 77 | set_random_seed(seed) 78 | graph_config = [ 79 | {"m": N * 2 * connection_factor if not disable_inst else 0, "directed": True}, 80 | {"m": N * connection_factor, "directed": True}, 81 | ] 82 | G = generate_temporal_graph( 83 | N, graph_type, graph_config, lag=2).astype(int) 84 | N = int(N) 85 | 86 | for N_G in num_graphs: 87 | N_G = int(N_G) 88 | print(f"Generating dataset for N={N}, num_graphs={N_G}") 89 | 90 | for p in p_array: 91 | graphs = [] 92 | for i in range(N_G): 93 | print(f"Generating graph {i}/{N_G}") 94 | Gtilde = perturb_graph(G, p) 95 | graphs.append(Gtilde) 96 | 97 | mean_dist, std_dist, min_dist, max_dist = calc_dist( 98 | np.array(graphs)) 99 | print( 100 | f"Perturbation,{N},{N_G},{mean_dist},{std_dist},{min_dist},{max_dist},{p}") 101 | 102 | folder_name = f"perturb_N{N}_K{N_G}_p{p}_seed{seed}" 103 | path = os.path.join(save_dir, folder_name) 104 | 105 | generate_cts_temporal_data( 106 | path=path, 107 | num_graphs=N_G, 108 | series_length=series_length, 109 | burnin_length=burnin_length, 110 | num_samples=num_samples, 111 | num_nodes=N, 112 | 
graph_type=graph_type, 113 | graph_config=graph_config, 114 | lag=lag, 115 | is_history_dep=is_history_dep, 116 | noise_level=noise_level, 117 | function_type=function_type, 118 | noise_function_type=noise_function_type, 119 | base_noise_type=base_noise_type, 120 | temporal_graphs=graphs 121 | ) 122 | 123 | 124 | if __name__ == "__main__": 125 | parser = argparse.ArgumentParser("Temporal Synthetic Data Generator") 126 | parser.add_argument("--config_file", type=str) 127 | args = parser.parse_args() 128 | 129 | main(config_file=args.config_file) 130 | -------------------------------------------------------------------------------- /src/utils/data_gen/generate_stock.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from yahoofinancials import YahooFinancials 4 | import pandas as pd 5 | import numpy as np 6 | import tqdm 7 | 8 | 9 | def process_stock_price(tickers, start_date, end_date): 10 | yfs = [YahooFinancials(t) for t in tickers] 11 | X = [] 12 | dates = [] 13 | for yf in tqdm.tqdm(yfs, "Processing stock market data"): 14 | r = yf.get_historical_price_data(start_date, end_date, 'daily') 15 | r = dict(r) 16 | ticker = list(r.keys())[0] 17 | close_prices = np.array([t['close'] for t in r[ticker]['prices']]) 18 | X.append(close_prices) 19 | dates.append([t['formatted_date'] for t in r[ticker]['prices']]) 20 | X = np.array(X) 21 | # print(X.shape) 22 | dates = np.array(dates) 23 | return X, dates 24 | 25 | 26 | def get_log_returns(X): 27 | return np.diff(np.log(X)) 28 | 29 | 30 | def standardize(X): 31 | mu = np.mean(X, axis=1) 32 | std = np.std(X, axis=1) 33 | mu = np.repeat(mu[:, np.newaxis], repeats=X.shape[1], axis=1) 34 | std = np.repeat(std[:, np.newaxis], repeats=X.shape[1], axis=1) 35 | return (X-mu)/std 36 | 37 | 38 | def generate_stock(args): 39 | df = pd.read_csv(args.stock_list_file) 40 | tickers = df['Symbol'].tolist() 41 | 42 | X, dates = process_stock_price(tickers, args.start_date, args.end_date) 43 | X = get_log_returns(X) 44 | 45 | X = standardize(X) 46 | 47 | D = X.shape[0] 48 | T = X.shape[1] 49 | L = args.chunk_size 50 | dat = np.zeros((D, T//L, L)) 51 | date_array = [] 52 | 53 | for i in range(T//L): 54 | dat[:, i] = X[:, L*i:(i+1)*L] 55 | date_array.append(dates[0][L*i]) 56 | print(date_array[-1]) 57 | date_array.append(args.end_date) 58 | print("Data shape:", dat.shape) 59 | dat = dat.transpose((1, 2, 0)) 60 | if not os.path.exists(args.save_dir): 61 | os.makedirs(args.save_dir) 62 | 63 | np.save(os.path.join(args.save_dir, 'X.npy'), dat) 64 | 65 | for sector in df['Sector'].unique(): 66 | X_sector = dat[:, :, df[df['Sector'] == sector].index.values] 67 | np.save(os.path.join(args.save_dir, 68 | f'X_{sector.replace(" ", "")}.npy'), X_sector) 69 | 70 | with open(os.path.join(args.save_dir, 'dates.csv'), 'w', encoding="utf-8") as f: 71 | for i in range(len(date_array)-1): 72 | f.write(f"{date_array[i]} to {date_array[i+1]}\n") 73 | 74 | 75 | if __name__ == '__main__': 76 | parser = argparse.ArgumentParser( 77 | "Stock data generator from Yahoo Financials") 78 | parser.add_argument("--start_date", type=str, default='2016-01-01') 79 | parser.add_argument("--end_date", type=str, default='2023-07-01') 80 | parser.add_argument('--chunk_size', type=int, default=31, 81 | help='Number of days to chunk together into one sample. 
Default is 31 days') 82 | parser.add_argument('--stock_list_file', type=str) 83 | parser.add_argument('--save_dir', type=str) 84 | 85 | args = parser.parse_args() 86 | generate_stock(args) 87 | -------------------------------------------------------------------------------- /src/utils/data_gen/generate_synthetic_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import yaml 4 | from data_generation_utils import generate_cts_temporal_data, generate_name, set_random_seed 5 | 6 | 7 | def main(config_file): 8 | 9 | # read the yaml file 10 | with open(config_file, encoding="utf-8") as f: 11 | data_config = yaml.load(f, Loader=yaml.FullLoader) 12 | 13 | series_length = int(data_config["num_timesteps"]) 14 | burnin_length = int(data_config["burnin_length"]) 15 | num_samples = int(data_config["num_samples"]) 16 | disable_inst = bool(data_config["disable_inst"]) 17 | graph_type = [data_config["inst_graph_type"], 18 | data_config["lag_graph_type"]] 19 | 20 | connection_factor = 1 21 | 22 | # in this case, all inst and lagged adj are dags and have total edges per adj = 2*num_nodes 23 | # graph_type = ["SF", "SF"] 24 | # graph_config = [{"m":2, "directed":True}, {"m":2, "directed":True}] 25 | 26 | lag = int(data_config["lag"]) 27 | is_history_dep = bool(data_config["history_dep_noise"]) 28 | 29 | noise_level = float(data_config["noise_level"]) 30 | function_type = data_config["function_type"] 31 | noise_function_type = data_config["noise_function_type"] 32 | base_noise_type = data_config["base_noise_type"] 33 | 34 | save_dir = data_config["save_dir"] 35 | num_nodes = data_config["num_nodes"] 36 | num_graphs = data_config["num_graphs"] 37 | 38 | for seed in data_config['random_seed']: 39 | seed = int(seed) 40 | set_random_seed(seed) 41 | 42 | for N in num_nodes: 43 | for N_G in num_graphs: 44 | print(f"Generating dataset for N={N}, num_graphs={N_G}") 45 | N = int(N) 46 | N_G = int(N_G) 47 | graph_config = [ 48 | {"m": N * 2 * connection_factor if not disable_inst else 0, 49 | "directed": True}, 50 | {"m": N * connection_factor, "directed": True}, 51 | ] 52 | 53 | folder_name = generate_name( 54 | N, 55 | num_samples, 56 | graph_type, 57 | N_G, 58 | lag, 59 | is_history_dep, 60 | noise_level, 61 | function_type=function_type, 62 | noise_function_type=noise_function_type, 63 | disable_inst=disable_inst, 64 | seed=seed, 65 | connection_factor=connection_factor, 66 | base_noise_type=base_noise_type 67 | ) 68 | path = os.path.join(save_dir, folder_name) 69 | 70 | generate_cts_temporal_data( 71 | path=path, 72 | num_graphs=N_G, 73 | series_length=series_length, 74 | burnin_length=burnin_length, 75 | num_samples=num_samples, 76 | num_nodes=N, 77 | graph_type=graph_type, 78 | graph_config=graph_config, 79 | lag=lag, 80 | is_history_dep=is_history_dep, 81 | noise_level=noise_level, 82 | function_type=function_type, 83 | noise_function_type=noise_function_type, 84 | base_noise_type=base_noise_type 85 | ) 86 | 87 | 88 | if __name__ == "__main__": 89 | parser = argparse.ArgumentParser("Temporal Synthetic Data Generator") 90 | parser.add_argument("--config_file", type=str) 91 | args = parser.parse_args() 92 | 93 | main(config_file=args.config_file) 94 | -------------------------------------------------------------------------------- /src/utils/data_gen/process_dream3.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import numpy as np 5 | import torch 6 | import 
pandas as pd 7 | 8 | 9 | def process_ts(ts, timepoints, N_subjects): 10 | N_nodes = ts.shape[1] 11 | X = np.zeros((N_subjects, timepoints, N_nodes)) 12 | 13 | for i in range(N_subjects): 14 | X[i] = ts[i*timepoints: (i+1)*timepoints] 15 | 16 | return X 17 | 18 | 19 | def process_adj_matrix(net, size): 20 | A = np.zeros((size, size)) 21 | for a, b, c in net.values: 22 | src = int(a[1:])-1 23 | dest = int(b[1:])-1 24 | A[src, dest] = 1 25 | return A 26 | 27 | 28 | def split_by_trajectory(X, A, T=21): 29 | time_len = X.shape[0] 30 | N = X.shape[1] 31 | num_samples = time_len // T 32 | data = np.zeros((num_samples, T, N)) 33 | adj_matrix = np.zeros((num_samples, N, N)) 34 | for i in range(num_samples): 35 | data[i] = X[i*T:(i+1)*T] 36 | adj_matrix[i] = A 37 | 38 | return data, adj_matrix 39 | 40 | 41 | def process_dream3(args): 42 | 43 | for size in [10, 50, 100]: 44 | X1 = torch.load(os.path.join( 45 | args.dataset_dir, 'Dream3TensorData', f'Size{size}Ecoli1.pt'))['TsData'].numpy() 46 | A1 = pd.read_table(os.path.join( 47 | args.dataset_dir, 'TrueGeneNetworks', f'InSilicoSize{size}-Ecoli1.tsv'), header=None) 48 | A1 = process_adj_matrix(A1, size) 49 | X1, A1 = split_by_trajectory(X1, A1) 50 | 51 | X2 = torch.load(os.path.join( 52 | args.dataset_dir, 'Dream3TensorData', f'Size{size}Ecoli2.pt'))['TsData'].numpy() 53 | A2 = pd.read_table(os.path.join( 54 | args.dataset_dir, 'TrueGeneNetworks', f'InSilicoSize{size}-Ecoli2.tsv'), header=None) 55 | A2 = process_adj_matrix(A2, size) 56 | X2, A2 = split_by_trajectory(X2, A2) 57 | 58 | if not os.path.exists(os.path.join(args.save_dir, f'ecoli_{size}', 'grouped_by_matrix')): 59 | os.makedirs(os.path.join(args.save_dir, 60 | f'ecoli_{size}', 'grouped_by_matrix')) 61 | 62 | np.savez(os.path.join( 63 | args.save_dir, f'ecoli_{size}', 'grouped_by_matrix', 'ecoli_1.npz'), X=X1, adj_matrix=A1) 64 | np.savez(os.path.join( 65 | args.save_dir, f'ecoli_{size}', 'grouped_by_matrix', 'ecoli_2.npz'), X=X2, adj_matrix=A2) 66 | 67 | X = np.concatenate((X1, X2), axis=0) 68 | A = np.concatenate((A1, A2), axis=0) 69 | np.savez(os.path.join(args.save_dir, 70 | f'ecoli_{size}', 'ecoli.npz'), X=X, adj_matrix=A) 71 | 72 | X11 = torch.load(os.path.join( 73 | args.dataset_dir, 'Dream3TensorData', f'Size{size}Yeast1.pt'))['TsData'].numpy() 74 | A11 = pd.read_table(os.path.join( 75 | args.dataset_dir, 'TrueGeneNetworks', f'InSilicoSize{size}-Yeast1.tsv'), header=None) 76 | A11 = process_adj_matrix(A11, size) 77 | X11, A11 = split_by_trajectory(X11, A11) 78 | 79 | X21 = torch.load(os.path.join( 80 | args.dataset_dir, 'Dream3TensorData', f'Size{size}Yeast2.pt'))['TsData'].numpy() 81 | A21 = pd.read_table(os.path.join( 82 | args.dataset_dir, 'TrueGeneNetworks', f'InSilicoSize{size}-Yeast2.tsv'), header=None) 83 | A21 = process_adj_matrix(A21, size) 84 | X21, A21 = split_by_trajectory(X21, A21) 85 | 86 | X31 = torch.load(os.path.join( 87 | args.dataset_dir, 'Dream3TensorData', f'Size{size}Yeast3.pt'))['TsData'].numpy() 88 | A31 = pd.read_table(os.path.join( 89 | args.dataset_dir, 'TrueGeneNetworks', f'InSilicoSize{size}-Yeast3.tsv'), header=None) 90 | A31 = process_adj_matrix(A31, size) 91 | X31, A31 = split_by_trajectory(X31, A31) 92 | 93 | if not os.path.exists(os.path.join(args.save_dir, f'yeast_{size}', 'grouped_by_matrix')): 94 | os.makedirs(os.path.join(args.save_dir, 95 | f'yeast_{size}', 'grouped_by_matrix')) 96 | 97 | np.savez(os.path.join( 98 | args.save_dir, f'yeast_{size}', 'grouped_by_matrix', 'yeast_1.npz'), X=X11, adj_matrix=A11) 99 | np.savez(os.path.join( 100 | 
args.save_dir, f'yeast_{size}', 'grouped_by_matrix', 'yeast_2.npz'), X=X21, adj_matrix=A21) 101 | np.savez(os.path.join( 102 | args.save_dir, f'yeast_{size}', 'grouped_by_matrix', 'yeast_3.npz'), X=X31, adj_matrix=A31) 103 | X = np.concatenate((X11, X21, X31), axis=0) 104 | A = np.concatenate((A11, A21, A31), axis=0) 105 | print(X.shape) 106 | print(A.shape) 107 | np.savez(os.path.join(args.save_dir, 108 | f'yeast_{size}', 'yeast.npz'), X=X, adj_matrix=A) 109 | 110 | # save combined 111 | X = np.concatenate((X1, X2, X11, X21, X31), axis=0) 112 | A = np.concatenate((A1, A2, A11, A21, A31), axis=0) 113 | print(X.shape) 114 | print(A.shape) 115 | np.savez(os.path.join(args.save_dir, 116 | f'combined_{size}.npz'), X=X, adj_matrix=A) 117 | 118 | 119 | if __name__ == "__main__": 120 | parser = argparse.ArgumentParser() 121 | parser.add_argument('--dataset_dir', type=str) 122 | parser.add_argument('--save_dir', type=str) 123 | args = parser.parse_args() 124 | process_dream3(args) 125 | -------------------------------------------------------------------------------- /src/utils/data_gen/process_netsim.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import random 4 | 5 | import scipy.io as sio 6 | import argparse 7 | import numpy as np 8 | 9 | 10 | def process_ts(ts, timepoints, N_subjects): 11 | N_nodes = ts.shape[1] 12 | X = np.zeros((N_subjects, timepoints, N_nodes)) 13 | 14 | for i in range(N_subjects): 15 | X[i] = ts[i*timepoints: (i+1)*timepoints] 16 | 17 | return X 18 | 19 | 20 | def process_adj_matrix(net): 21 | return (np.abs(np.swapaxes(net, 1, 2)) > 0).astype(int) 22 | 23 | 24 | def process_netsim(args): 25 | seed = 0 26 | simulations = range(1, 29) 27 | random.seed(seed) 28 | np.random.seed(seed) 29 | N_t_dict = {} 30 | X_dict_by_matrix = {} 31 | 32 | for i in simulations: 33 | mat = sio.loadmat(os.path.join(args.dataset_dir, f'sim{i}.mat')) 34 | timepoints = mat['Ntimepoints'][0, 0] 35 | N_subjects = mat['Nsubjects'][0, 0] 36 | ts = mat['ts'] 37 | net = mat['net'] 38 | N_nodes = ts.shape[1] 39 | X = process_ts(ts, timepoints, N_subjects) 40 | adj_matrix = process_adj_matrix(net) 41 | 42 | if (N_nodes, timepoints) not in N_t_dict: 43 | N_t_dict[(N_nodes, timepoints)] = {} 44 | N_t_dict[(N_nodes, timepoints)]['X'] = [] 45 | N_t_dict[(N_nodes, timepoints)]['adj_matrix'] = [] 46 | 47 | N_t_dict[(N_nodes, timepoints)]['X'].append(X) 48 | N_t_dict[(N_nodes, timepoints)]['adj_matrix'].append(adj_matrix) 49 | 50 | for j in range(X.shape[0]): 51 | if (adj_matrix[j].tobytes(), X.shape[1]) not in X_dict_by_matrix: 52 | X_dict_by_matrix[(adj_matrix[j].tobytes(), X.shape[1])] = [] 53 | X_dict_by_matrix[(adj_matrix[j].tobytes(), 54 | X.shape[1])].append(X[j]) 55 | 56 | if N_nodes == 15 and timepoints == 200: 57 | print("SIMULATION", i) 58 | for N_nodes, timepoints in N_t_dict: 59 | X = np.concatenate(N_t_dict[(N_nodes, timepoints)]['X'], axis=0) 60 | adj_matrix = np.concatenate( 61 | N_t_dict[(N_nodes, timepoints)]['adj_matrix'], axis=0) 62 | np.savez(os.path.join( 63 | args.save_dir, f'netsim_{N_nodes}_{timepoints}.npz'), X=X, adj_matrix=adj_matrix) 64 | 65 | counts = {} 66 | for key, T in X_dict_by_matrix: 67 | N = int(math.sqrt(len(np.frombuffer(key, dtype=int)))) 68 | if (N, T) not in counts: 69 | counts[(N, T)] = 0 70 | 71 | adj_matrix = np.array(np.frombuffer(key, dtype=int)).reshape(N, N) 72 | if not os.path.exists(os.path.join(args.save_dir, 'grouped_by_matrix', f'{N}_{T}')): 73 | os.makedirs(os.path.join(args.save_dir, 74 
| 'grouped_by_matrix', f'{N}_{T}')) 75 | X = np.array(X_dict_by_matrix[(key, T)]) 76 | np.savez(os.path.join(args.save_dir, 'grouped_by_matrix', 77 | f'{N}_{T}', f'netsim_{counts[(N, T)]}.npz'), X=X, adj_matrix=adj_matrix) 78 | counts[(N, T)] += 1 79 | 80 | # add the permutations 81 | for N_nodes in [15, 50]: 82 | timepoints = 200 83 | num_graphs = 3 84 | permutation_pool = [np.random.permutation( 85 | np.arange(N_nodes)) for i in range(num_graphs)] 86 | 87 | X = [] 88 | adj_matrix = [] 89 | 90 | for i in range(len(N_t_dict[(N_nodes, timepoints)]['X'][0])): 91 | I = random.choice(permutation_pool) 92 | x = N_t_dict[(N_nodes, timepoints)]['X'][0][i] 93 | x = np.array(x) 94 | x = x[:, I] 95 | 96 | G = N_t_dict[(N_nodes, timepoints)]['adj_matrix'][0][i] 97 | G = G[I][:, I] 98 | 99 | X.append(x) 100 | adj_matrix.append(G) 101 | 102 | X = np.array(X) 103 | adj_matrix = np.array(adj_matrix) 104 | np.savez(os.path.join( 105 | args.save_dir, f'netsim_{N_nodes}_{timepoints}_permuted.npz'), X=X, adj_matrix=adj_matrix) 106 | 107 | 108 | if __name__ == "__main__": 109 | parser = argparse.ArgumentParser() 110 | parser.add_argument('--dataset_dir', type=str) 111 | parser.add_argument('--save_dir', type=str) 112 | args = parser.parse_args() 113 | process_netsim(args) 114 | -------------------------------------------------------------------------------- /src/utils/data_gen/splines.py: -------------------------------------------------------------------------------- 1 | """ 2 | Borrowed from github.com/microsoft/causica 3 | Needed by data_generation_utils.py 4 | """ 5 | 6 | import numpy as np 7 | import torch 8 | from torch.nn import functional as F 9 | 10 | DEFAULT_MIN_BIN_WIDTH = 1e-3 11 | DEFAULT_MIN_BIN_HEIGHT = 1e-3 12 | DEFAULT_MIN_DERIVATIVE = 1e-3 13 | 14 | 15 | def searchsorted(bin_locations, inputs, eps=1e-6): 16 | bin_locations[..., -1] += eps 17 | return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1 18 | 19 | 20 | def unconstrained_RQS( 21 | inputs, 22 | unnormalized_widths, 23 | unnormalized_heights, 24 | unnormalized_derivatives, 25 | inverse=False, 26 | tail_bound=1.0, 27 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 28 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 29 | min_derivative=DEFAULT_MIN_DERIVATIVE, 30 | ): 31 | 32 | inside_intvl_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) 33 | outside_interval_mask = ~inside_intvl_mask 34 | 35 | outputs = torch.zeros_like(inputs) 36 | logabsdet = torch.zeros_like(inputs) 37 | 38 | unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) 39 | constant = np.log(np.exp(1 - min_derivative) - 1) 40 | unnormalized_derivatives[..., 0] = constant 41 | unnormalized_derivatives[..., -1] = constant 42 | 43 | outputs[outside_interval_mask] = inputs[outside_interval_mask] 44 | logabsdet[outside_interval_mask] = 0 45 | 46 | if inside_intvl_mask.any(): 47 | outputs[inside_intvl_mask], logabsdet[inside_intvl_mask] = RQS( 48 | inputs=inputs[inside_intvl_mask], 49 | unnormalized_widths=unnormalized_widths[inside_intvl_mask, :], 50 | unnormalized_heights=unnormalized_heights[inside_intvl_mask, :], 51 | unnormalized_derivatives=unnormalized_derivatives[inside_intvl_mask, :], 52 | inverse=inverse, 53 | left=-tail_bound, 54 | right=tail_bound, 55 | bottom=-tail_bound, 56 | top=tail_bound, 57 | min_bin_width=min_bin_width, 58 | min_bin_height=min_bin_height, 59 | min_derivative=min_derivative, 60 | ) 61 | return outputs, logabsdet 62 | 63 | 64 | def RQS( 65 | inputs, 66 | unnormalized_widths, 67 | unnormalized_heights, 68 | 
unnormalized_derivatives, 69 | inverse=False, 70 | left=0.0, 71 | right=1.0, 72 | bottom=0.0, 73 | top=1.0, 74 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 75 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 76 | min_derivative=DEFAULT_MIN_DERIVATIVE, 77 | ): 78 | if torch.min(inputs) < left or torch.max(inputs) > right: 79 | raise ValueError("Input outside domain") 80 | 81 | num_bins = unnormalized_widths.shape[-1] 82 | 83 | if min_bin_width * num_bins > 1.0: 84 | raise ValueError("Minimal bin width too large for the number of bins") 85 | if min_bin_height * num_bins > 1.0: 86 | raise ValueError("Minimal bin height too large for the number of bins") 87 | 88 | widths = F.softmax(unnormalized_widths, dim=-1) 89 | widths = min_bin_width + (1 - min_bin_width * num_bins) * widths 90 | cumwidths = torch.cumsum(widths, dim=-1) 91 | cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0) 92 | cumwidths = (right - left) * cumwidths + left 93 | cumwidths[..., 0] = left 94 | cumwidths[..., -1] = right 95 | widths = cumwidths[..., 1:] - cumwidths[..., :-1] 96 | 97 | derivatives = min_derivative + F.softplus(unnormalized_derivatives) 98 | 99 | heights = F.softmax(unnormalized_heights, dim=-1) 100 | heights = min_bin_height + (1 - min_bin_height * num_bins) * heights 101 | cumheights = torch.cumsum(heights, dim=-1) 102 | cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0) 103 | cumheights = (top - bottom) * cumheights + bottom 104 | cumheights[..., 0] = bottom 105 | cumheights[..., -1] = top 106 | heights = cumheights[..., 1:] - cumheights[..., :-1] 107 | 108 | if inverse: 109 | bin_idx = searchsorted(cumheights, inputs)[..., None] 110 | else: 111 | bin_idx = searchsorted(cumwidths, inputs)[..., None] 112 | 113 | input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] 114 | input_bin_widths = widths.gather(-1, bin_idx)[..., 0] 115 | 116 | input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] 117 | delta = heights / widths 118 | input_delta = delta.gather(-1, bin_idx)[..., 0] 119 | 120 | input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] 121 | input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx) 122 | input_derivatives_plus_one = input_derivatives_plus_one[..., 0] 123 | 124 | input_heights = heights.gather(-1, bin_idx)[..., 0] 125 | 126 | if inverse: 127 | a = (inputs - input_cumheights) * ( 128 | input_derivatives + input_derivatives_plus_one - 2 * input_delta 129 | ) + input_heights * (input_delta - input_derivatives) 130 | b = input_heights * input_derivatives - (inputs - input_cumheights) * ( 131 | input_derivatives + input_derivatives_plus_one - 2 * input_delta 132 | ) 133 | c = -input_delta * (inputs - input_cumheights) 134 | 135 | discriminant = b.pow(2) - 4 * a * c 136 | assert (discriminant >= 0).all() 137 | 138 | root = (2 * c) / (-b - torch.sqrt(discriminant)) 139 | outputs = root * input_bin_widths + input_cumwidths 140 | 141 | theta_one_minus_theta = root * (1 - root) 142 | denominator = input_delta + ( 143 | (input_derivatives + input_derivatives_plus_one - 144 | 2 * input_delta) * theta_one_minus_theta 145 | ) 146 | derivative_numerator = input_delta.pow(2) * ( 147 | input_derivatives_plus_one * root.pow(2) 148 | + 2 * input_delta * theta_one_minus_theta 149 | + input_derivatives * (1 - root).pow(2) 150 | ) 151 | logabsdet = torch.log(derivative_numerator) - \ 152 | 2 * torch.log(denominator) 153 | return outputs, -logabsdet 154 | 155 | # else 156 | theta = (inputs - input_cumwidths) / input_bin_widths 157 | theta_one_minus_theta = theta * 
(1 - theta) 158 | 159 | numerator = input_heights * \ 160 | (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta) 161 | denominator = input_delta + ( 162 | (input_derivatives + input_derivatives_plus_one - 163 | 2 * input_delta) * theta_one_minus_theta 164 | ) 165 | outputs = input_cumheights + numerator / denominator 166 | 167 | derivative_numerator = input_delta.pow(2) * ( 168 | input_derivatives_plus_one * theta.pow(2) 169 | + 2 * input_delta * theta_one_minus_theta 170 | + input_derivatives * (1 - theta).pow(2) 171 | ) 172 | logabsdet = torch.log(derivative_numerator) - \ 173 | 2 * torch.log(denominator) 174 | return outputs, logabsdet 175 | -------------------------------------------------------------------------------- /src/utils/data_utils/data_format_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def convert_data_to_timelagged(X, lag): 6 | """ 7 | Converts data with shape (n_samples, timesteps, num_nodes, data_dim) to two tensors, 8 | one history tensor of shape (n_fragments, lag, num_nodes, data_dim) and 9 | another "input" tensor of shape (n_fragments, num_nodes, data_dim) 10 | """ 11 | n_samples, timesteps, num_nodes, data_dim = X.shape 12 | 13 | n_fragments_per_sample = timesteps - lag 14 | n_fragments = n_samples * n_fragments_per_sample 15 | X_history = np.zeros((n_fragments, lag, num_nodes, data_dim)) 16 | X_input = np.zeros((n_fragments, num_nodes, data_dim)) 17 | X_indices = np.zeros((n_fragments)) 18 | for i in range(n_fragments_per_sample): 19 | X_history[i*n_samples:(i+1)*n_samples] = X[:, i:i+lag, :, :] 20 | X_input[i*n_samples:(i+1)*n_samples] = X[:, i+lag, :, :] 21 | X_indices[i*n_samples:(i+1)*n_samples] = np.arange(n_samples) 22 | return X_history, X_input, X_indices 23 | 24 | 25 | def get_adj_matrix_id(A): 26 | return np.unique(A, axis=(0), return_inverse=True) 27 | 28 | 29 | def convert_adj_to_timelagged(A, lag, n_fragments, aggregated_graph=False, return_indices=True): 30 | """ 31 | Converts adjacency matrix with shape (n_samples, (lag+1), num_nodes, num_nodes) to shape 32 | (n_fragments, lag+1, num_nodes, num_nodes) 33 | """ 34 | 35 | A_indices = np.zeros((n_fragments)) 36 | if len(A.shape) == 4: 37 | n_samples, L, num_nodes, num_nodes = A.shape 38 | assert L == lag+1 39 | Ap = np.zeros((n_fragments, lag+1, num_nodes, num_nodes)) 40 | elif aggregated_graph: 41 | n_samples, num_nodes, num_nodes = A.shape 42 | Ap = np.zeros((n_fragments, num_nodes, num_nodes)) 43 | else: 44 | assert False, "invalid adjacency matrix" 45 | n_fragments_per_sample = n_fragments // n_samples 46 | 47 | _, matrix_indices = get_adj_matrix_id(A) 48 | 49 | for i in range(n_fragments_per_sample): 50 | Ap[i*n_samples:(i+1)*n_samples] = A 51 | A_indices[i*n_samples:(i+1)*n_samples] = matrix_indices 52 | 53 | if return_indices: 54 | return Ap, A_indices 55 | return Ap 56 | 57 | 58 | def to_time_aggregated_graph_np(graph): 59 | # convert graph of shape [batch, lag+1, num_nodes, num_nodes] to aggregated 60 | # graph of shape [batch, num_nodes, num_nodes] 61 | return (np.sum(graph, axis=1) > 0).astype(int) 62 | 63 | 64 | def to_time_aggregated_scores_np(graph): 65 | return np.max(graph, axis=1) 66 | 67 | 68 | def to_time_aggregated_graph_torch(graph): 69 | # convert graph of shape [batch, lag+1, num_nodes, num_nodes] to aggregated 70 | # graph of shape [batch, num_nodes, num_nodes] 71 | return (torch.sum(graph, dim=1) > 0).long() 72 | 73 | 74 | def 
to_time_aggregated_scores_torch(graph): 75 | # convert edge probability matrix of shape [batch, lag+1, num_nodes, num_nodes] to aggregated 76 | # matrix of shape [batch, num_nodes, num_nodes] 77 | max_val, _ = torch.max(graph, dim=1) 78 | return max_val 79 | 80 | 81 | def zero_out_diag_np(G): 82 | 83 | if len(G.shape) == 3: 84 | N = G.shape[1] 85 | I = np.arange(N) 86 | G[:, I, I] = 0 87 | 88 | elif len(G.shape) == 2: 89 | N = G.shape[0] 90 | I = np.arange(N) 91 | G[I, I] = 0 92 | 93 | return G 94 | 95 | 96 | def zero_out_diag_torch(G): 97 | 98 | if len(G.shape) == 3: 99 | N = G.shape[1] 100 | I = torch.arange(N) 101 | G[:, I, I] = 0 102 | 103 | elif len(G.shape) == 2: 104 | N = G.shape[0] 105 | I = torch.arange(N) 106 | G[I, I] = 0 107 | 108 | return G 109 | -------------------------------------------------------------------------------- /src/utils/data_utils/dataloading_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | 5 | def get_dataset_path(dataset): 6 | if 'netsim' in dataset: 7 | dataset_path = 'netsim' 8 | elif 'dream3' in dataset: 9 | dataset_path = 'dream3' 10 | elif 'snp100' in dataset: 11 | dataset_path = 'snp100' 12 | elif dataset in ['lorenz96', 'finance', 'fluxnet']: 13 | dataset_path = dataset 14 | else: 15 | dataset_path = 'synthetic' 16 | 17 | return dataset_path 18 | 19 | 20 | def create_save_name(dataset, cfg): 21 | if dataset == 'lorenz96': 22 | return f'lorenz96_N={cfg.num_nodes}_T={cfg.timesteps}_num_graphs={cfg.num_graphs}' 23 | return dataset 24 | 25 | 26 | def load_synthetic_from_folder(dataset_dir, dataset_name): 27 | X = np.load(os.path.join(dataset_dir, dataset_name, 'X.npy')) 28 | adj_matrix = np.load(os.path.join( 29 | dataset_dir, dataset_name, 'adj_matrix.npy')) 30 | 31 | return X, adj_matrix 32 | 33 | 34 | def load_netsim(dataset_dir, dataset_file): 35 | # load the files 36 | data = np.load(os.path.join(dataset_dir, dataset_file + '.npz')) 37 | X = data['X'] 38 | adj_matrix = data['adj_matrix'] 39 | # adj_matrix = np.transpose(adj_matrix, (0, 2, 1)) 40 | return X, adj_matrix 41 | 42 | 43 | def load_dream3_combined(dataset_dir, size): 44 | data = np.load(os.path.join(dataset_dir, f'combined_{size}.npz')) 45 | X = data['X'] 46 | adj_matrix = data['adj_matrix'] 47 | return X, adj_matrix 48 | 49 | 50 | def load_snp100(dataset, dataset_dir): 51 | if dataset == 'snp100': 52 | X = np.load(os.path.join(dataset_dir, 'X.npy')) 53 | else: 54 | # get the sector 55 | sector = dataset.split('_')[1] 56 | X = np.load(os.path.join(dataset_dir, f'X_{sector}.npy')) 57 | 58 | D = X.shape[2] 59 | # we do not have the true adjacency matrix 60 | adj_matrix = np.zeros((X.shape[0], D, D)) 61 | return X, adj_matrix 62 | 63 | 64 | def load_data(dataset, dataset_dir, config): 65 | if 'netsim' in dataset: 66 | X, adj_matrix = load_netsim( 67 | dataset_dir=dataset_dir, dataset_file=dataset) 68 | aggregated_graph = True 69 | # adj_matrix = np.transpose(adj_matrix, (0, 2, 1)) 70 | # read lag from config file 71 | lag = int(config['lag']) 72 | data_dim = 1 73 | X = np.expand_dims(X, axis=-1) 74 | elif dataset == 'dream3': 75 | dream3_size = int(config['dream3_size']) 76 | X, adj_matrix = load_dream3_combined( 77 | dataset_dir=dataset_dir, size=dream3_size) 78 | lag = int(config['lag']) 79 | data_dim = 1 80 | aggregated_graph = True 81 | X = np.expand_dims(X, axis=-1) 82 | elif 'snp100' in dataset: 83 | X, adj_matrix = load_snp100(dataset=dataset, dataset_dir=dataset_dir) 84 | aggregated_graph = True 85 |
lag = 1 86 | data_dim = 1 87 | X = np.expand_dims(X, axis=-1) 88 | else: 89 | X, adj_matrix = load_synthetic_from_folder( 90 | dataset_dir=dataset_dir, dataset_name=dataset) 91 | lag = adj_matrix.shape[1] - 1 92 | data_dim = 1 93 | X = np.expand_dims(X, axis=-1) 94 | aggregated_graph = False 95 | print("Loaded data of shape:", X.shape) 96 | return X, adj_matrix, aggregated_graph, lag, data_dim 97 | -------------------------------------------------------------------------------- /src/utils/loss_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | 4 | def temporal_graph_sparsity(G: torch.Tensor): 5 | """ 6 | Args: 7 | G: (lag+1, num_nodes, num_nodes) or (batch, lag+1, num_nodes, num_nodes) 8 | 9 | Returns: 10 | Square of Frobenius norm of G (batch, ) 11 | """ 12 | 13 | return torch.sum(torch.square(G)) 14 | 15 | 16 | def l1_sparsity(G: torch.Tensor): 17 | return torch.sum(torch.abs(G)) 18 | 19 | 20 | def dag_penalty_notears(G: torch.Tensor): 21 | """ 22 | Implements the DAGness penalty from 23 | "DAGs with NO TEARS: Continuous Optimization for Structure Learning" 24 | 25 | Args: 26 | G: (num_nodes, num_nodes) or (num_graphs, num_nodes, num_nodes) 27 | 28 | """ 29 | num_nodes = G.shape[-1] 30 | 31 | if len(G.shape) == 2: 32 | trace_term = torch.trace(torch.matrix_exp(G)) 33 | return trace_term - num_nodes 34 | elif len(G.shape) == 3: 35 | trace_term = torch.einsum("ijj->i", torch.matrix_exp(G)) 36 | return torch.sum(trace_term - num_nodes) 37 | assert False, "DAG Penalty received illegal shape" 38 | 39 | 40 | def dag_penalty_notears_sq(W: torch.Tensor): 41 | num_nodes = W.shape[-1] 42 | 43 | if len(W.shape) == 2: 44 | trace_term = torch.trace(torch.matrix_exp(W * W)) 45 | return trace_term - num_nodes 46 | elif len(W.shape) == 3: 47 | trace_term = torch.einsum("ijj->i", torch.matrix_exp(W * W)) 48 | return torch.sum(trace_term) - W.shape[0] * num_nodes 49 | assert False, "DAG Penalty received illegal shape" 50 | -------------------------------------------------------------------------------- /src/utils/metrics_utils.py: -------------------------------------------------------------------------------- 1 | 2 | from scipy.optimize import linear_sum_assignment 3 | from sklearn.metrics import confusion_matrix 4 | # metrics 5 | from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, precision_score, recall_score 6 | from cdt.metrics import SHD 7 | import numpy as np 8 | import torch 9 | 10 | 11 | def get_off_diagonal(A): 12 | # assumes A.shape: (batch, x, y) 13 | M = np.invert(np.eye(A.shape[1], dtype=bool)) 14 | return A[:, M] 15 | 16 | 17 | def adjacency_f1(adj_matrix, predictions): 18 | # adj_matrix: (b, l, d, d) or (b, d, d) 19 | # predictions: (b, l, d, d) or (b, d, d) 20 | D = adj_matrix.shape[-1] 21 | L = np.tril_indices(D, k=-1) 22 | U = np.triu_indices(D, k=1) 23 | 24 | adj_lower = adj_matrix[..., L[0], L[1]] 25 | adj_upper = adj_matrix[..., U[0], U[1]] 26 | 27 | adj_diag = np.diagonal(adj_matrix, axis1=-2, axis2=-1).flatten() 28 | adj = np.concatenate((adj_diag, np.logical_or( 29 | adj_lower, adj_upper).flatten().astype(int))) 30 | 31 | pred_diag = np.diagonal(predictions, axis1=-2, axis2=-1).flatten() 32 | pred_lower = predictions[..., L[0], L[1]] 33 | pred_upper = predictions[..., U[0], U[1]] 34 | pred = np.concatenate((pred_diag, np.logical_or( 35 | pred_lower, pred_upper).flatten().astype(int))) 36 | 37 | return f1_score(adj, pred) 38 | 39 | 40 | def compute_shd(adj_matrix, preds, aggregated_graph=False): 41 | 
assert adj_matrix.shape == preds.shape, f"Dimension of adj_matrix {adj_matrix.shape} should match the predictions {preds.shape}" 42 | 43 | if not aggregated_graph: 44 | assert len( 45 | adj_matrix.shape) == 4, "Expects adj_matrix of shape (batch, lag+1, num_nodes, num_nodes)" 46 | assert len( 47 | preds.shape) == 4, "Expects preds of shape (batch, lag+1, num_nodes, num_nodes)" 48 | else: 49 | assert len( 50 | adj_matrix.shape) == 3, "Expects adj_matrix of shape (batch, num_nodes, num_nodes)" 51 | assert len( 52 | preds.shape) == 3, "Expects preds of shape (batch, num_nodes, num_nodes)" 53 | 54 | shd_score = 0 55 | if not aggregated_graph: 56 | shd_inst = 0 57 | shd_lag = 0 58 | for i in range(adj_matrix.shape[0]): 59 | for j in range(adj_matrix.shape[1]): 60 | adj_sub_matrix = adj_matrix[i, j] 61 | preds_sub_matrix = preds[i, j] 62 | shd = SHD(adj_sub_matrix, preds_sub_matrix) 63 | shd_score += shd 64 | if j == 0: 65 | shd_inst += shd 66 | else: 67 | shd_lag += shd 68 | return shd_score/adj_matrix.shape[0], shd_inst/adj_matrix.shape[0], shd_lag/adj_matrix.shape[0] 69 | for i in range(adj_matrix.shape[0]): 70 | adj_sub_matrix = adj_matrix[i] 71 | preds_sub_matrix = preds[i] 72 | shd_score += SHD(adj_sub_matrix, preds_sub_matrix) 73 | # print(SHD(adj_sub_matrix, preds_sub_matrix)) 74 | return shd_score/adj_matrix.shape[0] 75 | 76 | 77 | def calculate_expected_shd(scores, adj_matrix, aggregated_graph=False, n_trials=100): 78 | totals_shd = 0 79 | for _ in range(n_trials): 80 | draw = np.random.binomial(1, scores) 81 | if aggregated_graph: 82 | shd = compute_shd(adj_matrix, draw, 83 | aggregated_graph=aggregated_graph) 84 | else: 85 | shd, _, _ = compute_shd( 86 | adj_matrix, draw, aggregated_graph=aggregated_graph) 87 | totals_shd += shd 88 | 89 | return totals_shd/n_trials 90 | 91 | 92 | def evaluate_results(scores, 93 | adj_matrix, 94 | predictions, 95 | aggregated_graph=False, 96 | true_cluster_indices=None, 97 | pred_cluster_indices=None): 98 | 99 | assert adj_matrix.shape == predictions.shape, "Dimension of adj_matrix should match the predictions" 100 | 101 | abs_scores = np.abs(scores).flatten() 102 | preds = np.abs(np.round(predictions)) 103 | truth = adj_matrix.flatten() 104 | 105 | # calculate shd 106 | 107 | if aggregated_graph: 108 | shd_score = compute_shd(adj_matrix, preds, aggregated_graph) 109 | else: 110 | shd_score, shd_inst, shd_lag = compute_shd( 111 | adj_matrix, preds, aggregated_graph) 112 | f1_inst = f1_score(get_off_diagonal(adj_matrix[:, 0]).flatten( 113 | ), get_off_diagonal(predictions[:, 0]).flatten()) 114 | f1_lag = f1_score(adj_matrix[:, 1:].flatten(), preds[:, 1:].flatten()) 115 | 116 | f1 = f1_score(truth, preds.flatten()) 117 | adj_f1 = adjacency_f1(adj_matrix, predictions) 118 | 119 | preds = preds.flatten() 120 | zero_edge_accuracy = np.sum(np.logical_and( 121 | preds == 0, truth == 0))/np.sum(truth == 0) 122 | one_edge_accuracy = np.sum(np.logical_and( 123 | preds == 1, truth == 1))/np.sum(truth == 1) 124 | 125 | accuracy = accuracy_score(truth, preds) 126 | precision = precision_score(truth, preds) 127 | recall = recall_score(truth, preds) 128 | 129 | try: 130 | rocauc = roc_auc_score(truth, abs_scores) 131 | except ValueError: 132 | rocauc = 0.5 133 | 134 | tnr = zero_edge_accuracy 135 | tpr = one_edge_accuracy 136 | 137 | print("Accuracy score:", accuracy) 138 | print("Orientation F1 score:", f1) 139 | print("Adjacency F1:", adj_f1) 140 | print("Precision score:", precision) 141 | print("Recall score:", recall) 142 | print("ROC AUC score:", rocauc) 143 | 
144 | print("Accuracy on '0' edges", tnr) 145 | print("Accuracy on '1' edges", tpr) 146 | print("Structural Hamming Distance:", shd_score) 147 | if not aggregated_graph: 148 | print("Structural Hamming Distance (inst):", shd_inst) 149 | print("Structural Hamming Distance (lag):", shd_lag) 150 | print("Orientation F1 inst", f1_inst) 151 | print("Orientation F1 lag", f1_lag) 152 | eshd = calculate_expected_shd( 153 | np.abs(scores/(np.max(scores)+1e-4)), adj_matrix, aggregated_graph) 154 | print("Expected SHD:", eshd) 155 | # also return a dictionary of metrics 156 | metrics = { 157 | 'accuracy': accuracy, 158 | 'f1': f1, 159 | 'adj_f1': adj_f1, 160 | 'precision': precision, 161 | 'recall': recall, 162 | 'roc_auc': rocauc, 163 | 'tnr': tnr, 164 | 'tpr': tpr, 165 | 'shd_overall': shd_score, 166 | 'expected_shd': eshd 167 | } 168 | 169 | if not aggregated_graph: 170 | metrics['shd_inst'] = shd_inst 171 | metrics['shd_lag'] = shd_lag 172 | 173 | metrics['f1_inst'] = f1_inst 174 | metrics['f1_lag'] = f1_lag 175 | 176 | if pred_cluster_indices is not None and true_cluster_indices is not None: 177 | metrics['cluster_acc'] = cluster_accuracy(true_idx=true_cluster_indices, 178 | pred_idx=pred_cluster_indices) 179 | else: 180 | _, true_cluster_indices = np.unique( 181 | adj_matrix, return_inverse=True, axis=0) 182 | _, pred_cluster_indices = np.unique( 183 | predictions, return_inverse=True, axis=0) 184 | metrics['cluster_acc'] = cluster_accuracy(true_idx=true_cluster_indices, 185 | pred_idx=pred_cluster_indices) 186 | 187 | return metrics 188 | 189 | 190 | def mape_loss(X_true, x_pred): 191 | return torch.mean(torch.abs((X_true - x_pred) / X_true))*100 192 | 193 | 194 | def cluster_accuracy(true_idx, pred_idx): 195 | 196 | assert true_idx.shape == pred_idx.shape, "Shapes must match" 197 | # first get the confusion matrix 198 | cm = confusion_matrix(true_idx, pred_idx) 199 | 200 | # next run a linear sum assignment problem to obtain the maximum matching 201 | row_ind, col_ind = linear_sum_assignment(cm, maximize=True) 202 | 203 | # get the maximum matching 204 | return cm[row_ind, col_ind].sum()/cm.sum() 205 | -------------------------------------------------------------------------------- /src/utils/torch_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Borrowed from github.com/microsoft/causica 3 | """ 4 | 5 | from typing import List, Optional, Type 6 | from torch.nn import Dropout, LayerNorm, Linear, Module, Sequential 7 | import torch 8 | 9 | 10 | class resBlock(Module): 11 | """ 12 | Wraps an nn.Module, adding a skip connection to it. 13 | """ 14 | 15 | def __init__(self, block: Module): 16 | """ 17 | Args: 18 | block: module to which skip connection will be added. The input dimension must match the output dimension. 19 | """ 20 | super().__init__() 21 | self.block = block 22 | 23 | def forward(self, x): 24 | return x + self.block(x) 25 | 26 | 27 | def generate_fully_connected( 28 | input_dim: int, 29 | output_dim: int, 30 | hidden_dims: List[int], 31 | non_linearity: Optional[Type[Module]], 32 | activation: Optional[Type[Module]], 33 | device: torch.device, 34 | p_dropout: float = 0.0, 35 | normalization: Optional[Type[LayerNorm]] = None, 36 | res_connection: bool = False, 37 | ) -> Module: 38 | """ 39 | Generate a fully connected network. 40 | 41 | Args: 42 | input_dim: Int. Size of input to network. 43 | output_dim: Int. Size of output of network. 44 | hidden_dims: List of int. Sizes of internal hidden layers. i.e. 
45 | [a, b] is three linear layers with shapes (input_dim, a), (a, b), (b, output_dim) 46 | non_linearity: Non linear activation function used between Linear layers. 47 | activation: Final layer activation to use. 48 | device: torch device to load weights to. 49 | p_dropout: Float. Dropout probability at the hidden layers. 50 | normalization: Normalisation layer to use (batchnorm, layer norm, etc). 51 | Will be placed before linear layers, excluding the input layer. 52 | res_connection: Whether to use residual connections where possible (if previous layer width matches next layer width) 53 | 54 | Returns: 55 | Sequential object containing the desired network. 56 | """ 57 | layers: List[Module] = [] 58 | 59 | prev_dim = input_dim 60 | for idx, hidden_dim in enumerate(hidden_dims): 61 | 62 | block: List[Module] = [] 63 | 64 | if normalization is not None and idx > 0: 65 | block.append(normalization(prev_dim).to(device)) 66 | block.append(Linear(prev_dim, hidden_dim).to(device)) 67 | 68 | if non_linearity is not None: 69 | block.append(non_linearity()) 70 | if p_dropout != 0: 71 | block.append(Dropout(p_dropout)) 72 | 73 | if res_connection and (prev_dim == hidden_dim): 74 | layers.append(resBlock(Sequential(*block))) 75 | else: 76 | layers.append(Sequential(*block)) 77 | prev_dim = hidden_dim 78 | 79 | if normalization is not None: 80 | layers.append(normalization(prev_dim).to(device)) 81 | layers.append(Linear(prev_dim, output_dim).to(device)) 82 | 83 | if activation is not None: 84 | layers.append(activation()) 85 | 86 | fcnn = Sequential(*layers) 87 | return fcnn 88 | -------------------------------------------------------------------------------- /src/utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import numpy as np 4 | 5 | 6 | def standard_scaling(X, across_samples=False): 7 | # expected X of shape (n_samples, timesteps, num_nodes, data_dim) or (n_samples, timesteps, num_nodes) 8 | 9 | if across_samples: 10 | means = np.mean(X, axis=(0, 1))[np.newaxis, np.newaxis, :] 11 | std = np.std(X, axis=(0, 1))[np.newaxis, np.newaxis, :] 12 | else: 13 | means = np.mean(X, axis=(1))[:, np.newaxis] 14 | std = np.std(X, axis=(1))[:, np.newaxis] 15 | 16 | eps = 1e-6 17 | Y = (X-means) / (std + eps) 18 | 19 | return Y 20 | 21 | 22 | def min_max_scaling(X, across_samples=False): 23 | # expected X of shape (n_samples, timesteps, num_nodes, data_dim) or (n_samples, timesteps, num_nodes) 24 | 25 | if across_samples: 26 | mins = np.amin(X, axis=(0, 1))[np.newaxis, np.newaxis, :] 27 | maxs = np.amax(X, axis=(0, 1))[np.newaxis, np.newaxis, :] 28 | else: 29 | mins = np.amin(X, axis=(1))[:, np.newaxis] 30 | maxs = np.amax(X, axis=(1))[:, np.newaxis] 31 | 32 | Y = (X-mins) / (maxs - mins + 1e-6) * 2 - 1  # small eps guards against zero range (constant features) 33 | 34 | return Y 35 | 36 | 37 | def write_results_to_disk(dataset, metrics): 38 | # write results to file 39 | results_dir = os.path.join('results', dataset) 40 | results_file = os.path.join(results_dir, 'results.csv') 41 | file_exists = os.path.isfile(results_file) 42 | 43 | if not os.path.exists(results_dir): 44 | os.makedirs(results_dir) 45 | 46 | with open(results_file, 'a', encoding="utf-8") as csvfile: 47 | writer = csv.DictWriter(csvfile, fieldnames=list(metrics.keys())) 48 | if not file_exists: 49 | writer.writeheader() 50 | writer.writerows([metrics]) 51 | --------------------------------------------------------------------------------
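The repository is shipped as a flat collection of utilities, so a short end-to-end illustration may help tie the pieces above together. The sketch below is illustrative only and is not part of the repository: it assumes the `src` package is importable (e.g. after an editable install via the provided `setup.py`), and it uses randomly generated toy arrays in place of real datasets and real model output, purely to show the expected shapes and call signatures of `standard_scaling`, `convert_data_to_timelagged`, and `evaluate_results`.

```python
# Hypothetical usage sketch -- the toy shapes and random "predictions" are assumptions,
# not outputs of MCD or of any baseline in this repository.
import numpy as np

from src.utils.utils import standard_scaling
from src.utils.data_utils.data_format_utils import convert_data_to_timelagged
from src.utils.metrics_utils import evaluate_results

np.random.seed(0)

# toy data: 4 samples, 50 timesteps, 5 nodes, scalar observations
X = np.random.randn(4, 50, 5, 1)
X = standard_scaling(X)  # per-sample standardisation over the time axis

# build (history, input) fragments used for time-lagged training
lag = 2
X_history, X_input, X_indices = convert_data_to_timelagged(X, lag=lag)
print(X_history.shape, X_input.shape)  # (192, 2, 5, 1) (192, 5, 1)

# pretend ground-truth and predicted temporal graphs of shape (batch, lag+1, d, d)
adj_matrix = np.random.binomial(1, 0.2, size=(4, lag + 1, 5, 5))
scores = np.random.rand(4, lag + 1, 5, 5)   # edge probabilities from some model
predictions = (scores > 0.5).astype(int)    # thresholded graphs

metrics = evaluate_results(scores, adj_matrix, predictions, aggregated_graph=False)
print(metrics['f1'], metrics['shd_overall'])
```

On real data, `load_data` in `dataloading_utils.py` returns `X`, the ground-truth `adj_matrix`, the `aggregated_graph` flag, the `lag`, and the data dimension, which can replace the toy arrays in the same pipeline.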