├── data ├── __init__.py ├── sr_utils.py ├── datasets │ ├── plot_ringtree_dataset.py │ ├── __init__.py │ ├── test_ocean.py │ ├── test_ringtransfer.py │ ├── test_zinc.py │ ├── test_flow.py │ ├── ocean.py │ ├── flow.py │ ├── plot_flow_dataset.py │ ├── cluster.py │ ├── dummy.py │ ├── ringlookup.py │ ├── ringtransfer.py │ ├── ring_utils.py │ ├── ogb.py │ ├── sr.py │ ├── csl.py │ ├── zinc.py │ └── tu.py ├── test_parallel.py ├── parallel.py ├── perm_utils.py ├── test_data.py ├── test_dataset.py ├── test_tu_utils.py └── helper_test.py ├── exp ├── __init__.py ├── launch_tu_tuning.sh ├── scripts │ ├── mpsn-ocean.sh │ ├── mpsn-flow.sh │ ├── gnn-inv-ocean.sh │ ├── gnn-inv-flow.sh │ ├── mpsn-sr-base.sh │ ├── mpsn-sr.sh │ ├── cwn-sr-base.sh │ ├── cwn-sr.sh │ ├── mpsn-redditb.sh │ ├── cwn-nci109.sh │ ├── cwn-csl.sh │ ├── cin++-nci109.sh │ ├── cwn-molhiv.sh │ ├── cwn-molhiv-small.sh │ ├── cwn-zinc.sh │ ├── cin++-molhiv.sh │ ├── cwn-zinc-full.sh │ ├── cwn-zinc-small.sh │ ├── cin++-molhiv-small.sh │ ├── cin++-zinc.sh │ ├── cwn-zinc-full-small.sh │ ├── cin++-zinc-small.sh │ ├── cin++-zinc-500k.sh │ ├── cin++-pep-f.sh │ └── cin++-pep-s.sh ├── prepare_tu_tuning.py ├── test_run_exp.py ├── tuning_configurations │ └── template.yml ├── run_tu_tuning.py ├── prepare_sr_tests.py ├── run_tu_exp.py ├── run_ring_exp.py ├── evaluate_sr_cwn_emb_mag.py ├── run_sr_exp.py ├── count_rings.py ├── plot_sr_cwn_results.py ├── run_mol_exp.py ├── test_sr.py └── train_utils.py ├── mp ├── __init__.py ├── test_permutation.py ├── nn.py ├── cell_mp_inspector.py ├── ring_exp_models.py ├── test_orientation.py └── test_layers.py ├── datasets ├── CSL │ ├── .gitignore │ └── splits │ │ ├── CSL_test.txt │ │ ├── CSL_val.txt │ │ └── CSL_train.txt ├── SR_graphs │ ├── .gitignore │ ├── raw │ │ ├── sr16622.g6 │ │ ├── sr281264.g6 │ │ ├── sr261034.g6 │ │ ├── sr251256.g6 │ │ ├── sr291467.g6 │ │ └── sr401224.g6 │ └── README.md └── .gitignore ├── definitions.py ├── figures ├── cwn.png ├── sphere.jpeg ├── glue_disks.jpeg ├── empty_tetrahderon.jpeg └── cell_complex_molecule.png ├── graph-tool_install.sh ├── pyG_install.sh ├── requirements.txt ├── conftest.py ├── LICENSE ├── .github └── workflows │ └── python-package.yml ├── .gitignore └── README.md /data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /exp/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mp/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /datasets/CSL/.gitignore: -------------------------------------------------------------------------------- 1 | /* 2 | !/splits/ 3 | !/.gitignore 4 | -------------------------------------------------------------------------------- /datasets/SR_graphs/.gitignore: -------------------------------------------------------------------------------- 1 | /* 2 | !/raw/ 3 | !/.gitignore 4 | !/README.md 5 | -------------------------------------------------------------------------------- /definitions.py: -------------------------------------------------------------------------------- 1 | import os 2 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) -------------------------------------------------------------------------------- /figures/cwn.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/twitter-research/cwn/HEAD/figures/cwn.png -------------------------------------------------------------------------------- /figures/sphere.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twitter-research/cwn/HEAD/figures/sphere.jpeg -------------------------------------------------------------------------------- /figures/glue_disks.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/twitter-research/cwn/HEAD/figures/glue_disks.jpeg -------------------------------------------------------------------------------- /graph-tool_install.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | conda install -c conda-forge -y graph-tool==2.44 4 | -------------------------------------------------------------------------------- /datasets/SR_graphs/raw/sr16622.g6: -------------------------------------------------------------------------------- 1 | >>graph6<< -------------------------------------------------------------------------------- /datasets/SR_graphs/README.md: -------------------------------------------------------------------------------- 4 | > Given integers `l`, `u`, every two adjacent nodes have `l` common neighbours and every two non-adjacent nodes have `u` common neighbours. 5 | 6 | SR graphs in the same family share the same parameters `n`, `d`, `l`, `u`, with `n` the number of nodes in each graph. 7 | Two non-isomorphic SR graphs in the same family cannot be distinguished by the standard WL test, and not even by the more 8 | powerful 3-WL. 9 | 10 | In `./raw`, each family is stored in `g6` format and is named as `sr<n><d><l><u>.g6` (two digits for `<n>`). These data 11 | were originally obtained from `http://users.cecs.anu.edu.au/~bdm/data/graphs.html`.
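A minimal sketch (not part of the repository) of how one of these families could be read and its parameters checked with `networkx`; the path and the helper below are illustrative assumptions, and the repository's own loader is `load_sr_graph_dataset` in `data/datasets/sr.py`:

```python
# Sketch: read an SR family and verify the (n, d, l, u) parameters it advertises.
# Assumes networkx is installed and the repository root is the working directory.
import itertools
import networkx as nx

graphs = nx.read_graph6('datasets/SR_graphs/raw/sr16622.g6')  # list of non-isomorphic graphs
g = graphs[0]
n = g.number_of_nodes()
degrees = {deg for _, deg in g.degree()}  # a single value d for a regular graph

def common_neighbours(a, b):
    return len(set(g[a]) & set(g[b]))

l_vals = {common_neighbours(a, b) for a, b in g.edges()}
u_vals = {common_neighbours(a, b)
          for a, b in itertools.combinations(g.nodes(), 2) if not g.has_edge(a, b)}
print(n, degrees, l_vals, u_vals)  # e.g. 16 {6} {2} {2} for the sr16622 family
```

Every graph in a family yields the same four values, which is exactly what makes these families hard instances for WL-style tests.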
12 | -------------------------------------------------------------------------------- /exp/scripts/cwn-csl.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python -m exp.run_mol_exp \ 4 | --start_seed 0 \ 5 | --stop_seed 19 \ 6 | --folds 5 \ 7 | --exp_name=cwn-csl \ 8 | --dataset CSL \ 9 | --train_eval_period 25 \ 10 | --epochs 300 \ 11 | --batch_size 12 \ 12 | --drop_rate 0.0 \ 13 | --graph_norm ln \ 14 | --drop_position lin2 \ 15 | --emb_dim 160 \ 16 | --max_dim 2 \ 17 | --final_readout sum \ 18 | --init_method sum \ 19 | --lr 5e-4 \ 20 | --model embed_sparse_cin \ 21 | --nonlinearity relu \ 22 | --num_layers 3 \ 23 | --readout mean \ 24 | --max_ring_size=8 \ 25 | --lr_scheduler='ReduceLROnPlateau' \ 26 | --lr_scheduler_min 1e-6 \ 27 | --lr_scheduler_patience 20 \ 28 | --early_stop \ 29 | --use_edge_features \ 30 | --device=0 \ 31 | --use_coboundaries True \ 32 | --preproc_jobs 32 33 | -------------------------------------------------------------------------------- /exp/scripts/cin++-nci109.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python -m exp.run_tu_exp \ 4 | --device 0 \ 5 | --exp_name cin++-nci109 \ 6 | --dataset NCI109 \ 7 | --train_eval_period 50 \ 8 | --epochs 150 \ 9 | --batch_size 32 \ 10 | --drop_rate 0.0 \ 11 | --drop_position lin2 \ 12 | --emb_dim 64 \ 13 | --max_dim 2 \ 14 | --final_readout sum \ 15 | --init_method mean \ 16 | --jump_mode 'cat' \ 17 | --lr 0.001 \ 18 | --graph_norm bn \ 19 | --model cin++ \ 20 | --include_down_adj \ 21 | --nonlinearity relu \ 22 | --num_layers 4 \ 23 | --readout sum \ 24 | --max_ring_size 6 \ 25 | --task_type classification \ 26 | --eval_metric accuracy \ 27 | --lr_scheduler 'StepLR' \ 28 | --lr_scheduler_decay_rate 0.5 \ 29 | --lr_scheduler_decay_steps 20 \ 30 | --use_coboundaries True \ 31 | --dump_curves \ 32 | --preproc_jobs 4 -------------------------------------------------------------------------------- /datasets/SR_graphs/raw/sr251256.g6: -------------------------------------------------------------------------------- 1 | X}rM\QTeLEuUlQY[I\IgtWfTxCrhEOZo{FfwEew`LXMbWp}JDtM 2 | X}rM\QWh[jUYkiYYJMBDfSrXsLRgdOMusE{{JQhgxiMDSxVEplU 3 | X}rM\Qhd[fU`kdUSjDjHjWqxtOJdOgMxwEzYFRPgtdMDctfIpsu 4 | X}rUTEdmTpQxbkSxHkZceZBLtObTQHHmsbM{EFxwrroAtQtBtBr 5 | X}rUTEdmTpQybiSwhkjgeYbLrHBX`HI\wavYEFxwrroAt`tBsdr 6 | X}rUTIbmLqQybiTWh[jcXZCtqpBYPHE]wcuyEFxwrroAtQtBtBr 7 | X}rU\adeSetTjKWNJEYNR]PLjPBgUGVTkK^YKbipMcxbk`{DlXF 8 | X}r^SQbkLJQjesS[jLJQhPxTcZZcZ?S|krEYBlQolTZDhWuNKKN 9 | X}r^SQbkLQqjdwS[jLJJHQtTcZZcZ?L\krEYBlQostZDhWuNKKN 10 | X}vEKeLlTTUXjKXXK[ZKo]EXZAqxI``\khVUD]VOVptAqtXBbrL 11 | X}vEKiJklYUXjKXWk[jKo]EXYdQyD`amkgmuD]VGVpuAqtUBbrR 12 | X}vEKiJlTTUUjQWxKkZKo]EXZBQxH`a]kguuC}NGZquAptYBdrJ 13 | X}ve[IJd\FUImBSYmFJRIWfLLRYmJ?Tl[ZKYHhsokuXdhS{NKKN 14 | X~rM[Edd\RRBkpOxlMIqZXDxTWZdQGX][fI[EsholidDRdVNGk] 15 | X~rM[Ihi[eQVlQPTkTiqrWwxXgZgiGX^KfWqEdu?nG\dRP}NGk] 16 | -------------------------------------------------------------------------------- /exp/scripts/cwn-molhiv.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python -m exp.run_mol_exp \ 4 | --device 0 \ 5 | --start_seed 0 \ 6 | --stop_seed 9 \ 7 | --exp_name cwn-molhiv \ 8 | --dataset MOLHIV \ 9 | --model ogb_embed_sparse_cin \ 10 | --use_coboundaries True \ 11 | --indrop_rate 0.0 \ 12 | --drop_rate 0.5 \ 13 | --graph_norm bn \ 14 | --drop_position lin2 \ 15 | --nonlinearity relu \ 16 | --readout mean \ 
17 | --final_readout sum \ 18 | --lr 0.0001 \ 19 | --lr_scheduler None \ 20 | --num_layers 2 \ 21 | --emb_dim 64 \ 22 | --batch_size 128 \ 23 | --epochs 150 \ 24 | --num_workers 2 \ 25 | --preproc_jobs 32 \ 26 | --task_type bin_classification \ 27 | --eval_metric ogbg-molhiv \ 28 | --max_dim 2 \ 29 | --max_ring_size 6 \ 30 | --init_method sum \ 31 | --train_eval_period 10 \ 32 | --use_edge_features \ 33 | --dump_curves 34 | -------------------------------------------------------------------------------- /exp/scripts/cwn-molhiv-small.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python -m exp.run_mol_exp \ 4 | --device 0 \ 5 | --start_seed 0 \ 6 | --stop_seed 9 \ 7 | --exp_name cwn-molhiv-small \ 8 | --dataset MOLHIV \ 9 | --model ogb_embed_sparse_cin \ 10 | --use_coboundaries True \ 11 | --indrop_rate 0.0 \ 12 | --drop_rate 0.5 \ 13 | --graph_norm bn \ 14 | --drop_position lin2 \ 15 | --nonlinearity relu \ 16 | --readout mean \ 17 | --final_readout sum \ 18 | --lr 0.0001 \ 19 | --lr_scheduler None \ 20 | --num_layers 2 \ 21 | --emb_dim 48 \ 22 | --batch_size 128 \ 23 | --epochs 150 \ 24 | --num_workers 2 \ 25 | --preproc_jobs 32 \ 26 | --task_type bin_classification \ 27 | --eval_metric ogbg-molhiv \ 28 | --max_dim 2 \ 29 | --max_ring_size 6 \ 30 | --init_method sum \ 31 | --train_eval_period 10 \ 32 | --use_edge_features \ 33 | --dump_curves 34 | -------------------------------------------------------------------------------- /exp/scripts/cwn-zinc.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python -m exp.run_mol_exp \ 4 | --device 0 \ 5 | --start_seed 0 \ 6 | --stop_seed 9 \ 7 | --exp_name cwn-zinc \ 8 | --dataset ZINC \ 9 | --train_eval_period 20 \ 10 | --epochs 1000 \ 11 | --batch_size 128 \ 12 | --drop_rate 0.0 \ 13 | --drop_position lin2 \ 14 | --emb_dim 128 \ 15 | --max_dim 2 \ 16 | --final_readout sum \ 17 | --init_method sum \ 18 | --lr 0.001 \ 19 | --graph_norm bn \ 20 | --model embed_sparse_cin \ 21 | --nonlinearity relu \ 22 | --num_layers 4 \ 23 | --readout sum \ 24 | --max_ring_size 18 \ 25 | --task_type regression \ 26 | --eval_metric mae \ 27 | --minimize \ 28 | --lr_scheduler 'ReduceLROnPlateau' \ 29 | --use_coboundaries True \ 30 | --use_edge_features \ 31 | --early_stop \ 32 | --lr_scheduler_patience 20 \ 33 | --dump_curves \ 34 | --preproc_jobs 32 35 | -------------------------------------------------------------------------------- /exp/scripts/cin++-molhiv.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python -m exp.run_mol_exp \ 4 | --device 0 \ 5 | --start_seed 0 \ 6 | --stop_seed 9 \ 7 | --exp_name cin++-molhiv \ 8 | --dataset MOLHIV \ 9 | --model ogb_embed_cin++ \ 10 | --include_down_adj \ 11 | --use_coboundaries True \ 12 | --indrop_rate 0.0 \ 13 | --drop_rate 0.5 \ 14 | --graph_norm bn \ 15 | --drop_position lin2 \ 16 | --nonlinearity relu \ 17 | --readout mean \ 18 | --final_readout sum \ 19 | --lr 0.0001 \ 20 | --lr_scheduler None \ 21 | --num_layers 2 \ 22 | --emb_dim 64 \ 23 | --batch_size 128 \ 24 | --epochs 150 \ 25 | --num_workers 2 \ 26 | --preproc_jobs 32 \ 27 | --task_type bin_classification \ 28 | --eval_metric ogbg-molhiv \ 29 | --max_dim 2 \ 30 | --max_ring_size 6 \ 31 | --init_method sum \ 32 | --train_eval_period 10 \ 33 | --use_edge_features \ 34 | --dump_curves 35 | -------------------------------------------------------------------------------- 
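Each of the scripts in `exp/scripts/` simply invokes one of the experiment entry points (`exp.run_mol_exp`, `exp.run_tu_exp`, ...) as a Python module with a fixed set of flags, so they are intended to be launched from the repository root, e.g. `sh ./exp/scripts/cwn-molhiv.sh` with the project's conda environment active; the environment setup is an assumption here rather than something the scripts enforce.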
/exp/scripts/cwn-zinc-full.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python -m exp.run_mol_exp \ 4 | --device 0 \ 5 | --start_seed 0 \ 6 | --stop_seed 3 \ 7 | --exp_name cwn-zinc-full \ 8 | --dataset ZINC-FULL \ 9 | --train_eval_period 25 \ 10 | --epochs 150 \ 11 | --batch_size 128 \ 12 | --drop_rate 0.0 \ 13 | --drop_position lin2 \ 14 | --emb_dim 128 \ 15 | --max_dim 2 \ 16 | --final_readout sum \ 17 | --init_method sum \ 18 | --lr 0.001 \ 19 | --graph_norm bn \ 20 | --model embed_sparse_cin \ 21 | --nonlinearity relu \ 22 | --num_layers 4 \ 23 | --readout sum \ 24 | --max_ring_size 18 \ 25 | --task_type regression \ 26 | --eval_metric mae \ 27 | --minimize \ 28 | --lr_scheduler 'ReduceLROnPlateau' \ 29 | --use_coboundaries True \ 30 | --use_edge_features \ 31 | --early_stop \ 32 | --lr_scheduler_patience 5 \ 33 | --dump_curves \ 34 | --preproc_jobs 32 35 | -------------------------------------------------------------------------------- /exp/scripts/cwn-zinc-small.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python -m exp.run_mol_exp \ 4 | --device 0 \ 5 | --start_seed 0 \ 6 | --stop_seed 9 \ 7 | --exp_name cwn-zinc-small \ 8 | --dataset ZINC \ 9 | --train_eval_period 20 \ 10 | --epochs 1000 \ 11 | --batch_size 128 \ 12 | --drop_rate 0.0 \ 13 | --drop_position lin2 \ 14 | --emb_dim 48 \ 15 | --max_dim 2 \ 16 | --final_readout sum \ 17 | --init_method sum \ 18 | --lr 0.001 \ 19 | --graph_norm bn \ 20 | --model embed_sparse_cin \ 21 | --nonlinearity relu \ 22 | --num_layers 2 \ 23 | --readout sum \ 24 | --max_ring_size 18 \ 25 | --task_type regression \ 26 | --eval_metric mae \ 27 | --minimize \ 28 | --lr_scheduler 'ReduceLROnPlateau' \ 29 | --use_coboundaries True \ 30 | --use_edge_features \ 31 | --early_stop \ 32 | --lr_scheduler_patience 20 \ 33 | --dump_curves \ 34 | --preproc_jobs 32 35 | -------------------------------------------------------------------------------- /exp/scripts/cin++-molhiv-small.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python -m exp.run_mol_exp \ 4 | --device 0 \ 5 | --start_seed 0 \ 6 | --stop_seed 9 \ 7 | --exp_name cin++-molhiv-small \ 8 | --dataset MOLHIV \ 9 | --model ogb_embed_cin++ \ 10 | --include_down_adj \ 11 | --use_coboundaries True \ 12 | --indrop_rate 0.0 \ 13 | --drop_rate 0.5 \ 14 | --graph_norm bn \ 15 | --drop_position lin2 \ 16 | --nonlinearity relu \ 17 | --readout mean \ 18 | --final_readout sum \ 19 | --lr 0.0001 \ 20 | --lr_scheduler None \ 21 | --num_layers 2 \ 22 | --emb_dim 48 \ 23 | --batch_size 128 \ 24 | --epochs 150 \ 25 | --num_workers 2 \ 26 | --preproc_jobs 32 \ 27 | --task_type bin_classification \ 28 | --eval_metric ogbg-molhiv \ 29 | --max_dim 2 \ 30 | --max_ring_size 6 \ 31 | --init_method sum \ 32 | --train_eval_period 10 \ 33 | --use_edge_features \ 34 | --dump_curves 35 | -------------------------------------------------------------------------------- /exp/scripts/cin++-zinc.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python -m exp.run_mol_exp \ 4 | --device 0 \ 5 | --start_seed 0 \ 6 | --stop_seed 9 \ 7 | --exp_name cin++-zinc \ 8 | --dataset ZINC \ 9 | --train_eval_period 20 \ 10 | --epochs 1000 \ 11 | --batch_size 128 \ 12 | --drop_rate 0.0 \ 13 | --drop_position lin2 \ 14 | --emb_dim 128 \ 15 | --max_dim 2 \ 16 | --final_readout sum \ 17 | --init_method sum \ 
18 | --lr 0.001 \ 19 | --graph_norm bn \ 20 | --model embed_cin++ \ 21 | --include_down_adj \ 22 | --nonlinearity relu \ 23 | --num_layers 4 \ 24 | --readout sum \ 25 | --max_ring_size 18 \ 26 | --task_type regression \ 27 | --eval_metric mae \ 28 | --minimize \ 29 | --lr_scheduler 'ReduceLROnPlateau' \ 30 | --use_coboundaries True \ 31 | --use_edge_features \ 32 | --early_stop \ 33 | --lr_scheduler_patience 20 \ 34 | --dump_curves \ 35 | --preproc_jobs 32 36 | -------------------------------------------------------------------------------- /exp/scripts/cwn-zinc-full-small.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python -m exp.run_mol_exp \ 4 | --device 0 \ 5 | --start_seed 0 \ 6 | --stop_seed 3 \ 7 | --exp_name cwn-zinc-full-small \ 8 | --dataset ZINC-FULL \ 9 | --train_eval_period 25 \ 10 | --epochs 150 \ 11 | --batch_size 128 \ 12 | --drop_rate 0.0 \ 13 | --drop_position lin2 \ 14 | --emb_dim 48 \ 15 | --max_dim 2 \ 16 | --final_readout sum \ 17 | --init_method sum \ 18 | --lr 0.001 \ 19 | --graph_norm bn \ 20 | --model embed_sparse_cin \ 21 | --nonlinearity relu \ 22 | --num_layers 2 \ 23 | --readout sum \ 24 | --max_ring_size 18 \ 25 | --task_type regression \ 26 | --eval_metric mae \ 27 | --minimize \ 28 | --lr_scheduler 'ReduceLROnPlateau' \ 29 | --use_coboundaries True \ 30 | --use_edge_features \ 31 | --early_stop \ 32 | --lr_scheduler_patience 5 \ 33 | --dump_curves \ 34 | --preproc_jobs 32 35 | -------------------------------------------------------------------------------- /exp/scripts/cin++-zinc-small.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python -m exp.run_mol_exp \ 4 | --device 0 \ 5 | --start_seed 0 \ 6 | --stop_seed 9 \ 7 | --exp_name cin++-zinc \ 8 | --dataset ZINC \ 9 | --train_eval_period 20 \ 10 | --epochs 1000 \ 11 | --batch_size 128 \ 12 | --drop_rate 0.0 \ 13 | --drop_position lin2 \ 14 | --emb_dim 48 \ 15 | --max_dim 2 \ 16 | --final_readout sum \ 17 | --init_method sum \ 18 | --lr 0.001 \ 19 | --graph_norm bn \ 20 | --model embed_cin++ \ 21 | --include_down_adj \ 22 | --nonlinearity relu \ 23 | --num_layers 2 \ 24 | --readout sum \ 25 | --max_ring_size 18 \ 26 | --task_type regression \ 27 | --eval_metric mae \ 28 | --minimize \ 29 | --lr_scheduler 'ReduceLROnPlateau' \ 30 | --use_coboundaries True \ 31 | --use_edge_features \ 32 | --early_stop \ 33 | --lr_scheduler_patience 20 \ 34 | --dump_curves \ 35 | --preproc_jobs 32 36 | -------------------------------------------------------------------------------- /exp/scripts/cin++-zinc-500k.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python -m exp.run_mol_exp \ 4 | --device 0 \ 5 | --start_seed 0 \ 6 | --stop_seed 9 \ 7 | --exp_name cin++-zinc-500k \ 8 | --dataset ZINC \ 9 | --train_eval_period 20 \ 10 | --epochs 1000 \ 11 | --batch_size 128 \ 12 | --drop_rate 0.0 \ 13 | --drop_position lin2 \ 14 | --emb_dim 64 \ 15 | --max_dim 2 \ 16 | --final_readout sum \ 17 | --init_method sum \ 18 | --lr 0.001 \ 19 | --graph_norm bn \ 20 | --model embed_cin++ \ 21 | --include_down_adj \ 22 | --nonlinearity relu \ 23 | --num_layers 3 \ 24 | --readout sum \ 25 | --max_ring_size 18 \ 26 | --task_type regression \ 27 | --eval_metric mae \ 28 | --minimize \ 29 | --lr_scheduler 'ReduceLROnPlateau' \ 30 | --use_coboundaries True \ 31 | --use_edge_features \ 32 | --early_stop \ 33 | --lr_scheduler_patience 20 \ 34 | --dump_curves \ 35 | 
--preproc_jobs 32 36 | -------------------------------------------------------------------------------- /exp/scripts/cin++-pep-f.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python -m exp.run_mol_exp \ 4 | --device 0 \ 5 | --start_seed 0 \ 6 | --stop_seed 3 \ 7 | --exp_name cwn-pep-f-500k \ 8 | --dataset PEPTIDES-F \ 9 | --model ogb_embed_cin++ \ 10 | --include_down_adj \ 11 | --use_coboundaries True \ 12 | --indrop_rate 0.0 \ 13 | --drop_rate 0.15 \ 14 | --graph_norm bn \ 15 | --drop_position lin2 \ 16 | --nonlinearity relu \ 17 | --readout sum \ 18 | --final_readout sum \ 19 | --lr 0.001 \ 20 | --num_layers 3 \ 21 | --emb_dim 64 \ 22 | --batch_size 128 \ 23 | --epochs 1000 \ 24 | --num_workers 0 \ 25 | --preproc_jobs 32 \ 26 | --task_type bin_classification \ 27 | --eval_metric ap \ 28 | --max_dim 2 \ 29 | --max_ring_size 8 \ 30 | --lr_scheduler 'ReduceLROnPlateau' \ 31 | --init_method sum \ 32 | --train_eval_period 10 \ 33 | --use_edge_features \ 34 | --lr_scheduler_patience 15 \ 35 | --dump_curves 36 | -------------------------------------------------------------------------------- /exp/scripts/cin++-pep-s.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python -m exp.run_mol_exp \ 4 | --device 0 \ 5 | --start_seed 0 \ 6 | --stop_seed 3 \ 7 | --exp_name cwn-pep-s-500k \ 8 | --dataset PEPTIDES-S \ 9 | --model ogb_embed_cin++ \ 10 | --include_down_adj \ 11 | --use_coboundaries True \ 12 | --indrop_rate 0.0 \ 13 | --drop_rate 0.0 \ 14 | --graph_norm bn \ 15 | --drop_position lin2 \ 16 | --nonlinearity relu \ 17 | --readout mean \ 18 | --final_readout sum \ 19 | --lr 0.001 \ 20 | --num_layers 3 \ 21 | --emb_dim 64 \ 22 | --batch_size 128 \ 23 | --epochs 1000 \ 24 | --num_workers 0 \ 25 | --preproc_jobs 32 \ 26 | --task_type regression \ 27 | --eval_metric mae \ 28 | --max_dim 2 \ 29 | --max_ring_size 8 \ 30 | --lr_scheduler 'ReduceLROnPlateau' \ 31 | --init_method sum \ 32 | --minimize \ 33 | --early_stop \ 34 | --train_eval_period 10 \ 35 | --use_edge_features \ 36 | --lr_scheduler_patience 20 \ 37 | --dump_curves 38 | -------------------------------------------------------------------------------- /data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from data.datasets.dataset import ComplexDataset, InMemoryComplexDataset 2 | from data.datasets.sr import SRDataset, load_sr_graph_dataset 3 | from data.datasets.cluster import ClusterDataset 4 | from data.datasets.tu import TUDataset, load_tu_graph_dataset 5 | from data.datasets.flow import FlowDataset 6 | from data.datasets.ocean import OceanDataset 7 | from data.datasets.zinc import ZincDataset, load_zinc_graph_dataset 8 | from data.datasets.dummy import DummyDataset, DummyMolecularDataset 9 | from data.datasets.csl import CSLDataset 10 | from data.datasets.ogb import OGBDataset, load_ogb_graph_dataset 11 | from data.datasets.peptides_functional import PeptidesFunctionalDataset, load_pep_f_graph_dataset 12 | from data.datasets.peptides_structural import PeptidesStructuralDataset, load_pep_s_graph_dataset 13 | from data.datasets.ringtransfer import RingTransferDataset, load_ring_transfer_dataset 14 | from data.datasets.ringlookup import RingLookupDataset, load_ring_lookup_dataset 15 | 16 | -------------------------------------------------------------------------------- /exp/test_run_exp.py: 
-------------------------------------------------------------------------------- 1 | from exp.parser import get_parser 2 | from exp.run_exp import main 3 | 4 | def get_args_for_dummym(): 5 | args = list() 6 | args += ['--use_coboundaries', 'True'] 7 | args += ['--graph_norm', 'id'] 8 | args += ['--lr_scheduler', 'None'] 9 | args += ['--num_layers', '3'] 10 | args += ['--emb_dim', '8'] 11 | args += ['--batch_size', '3'] 12 | args += ['--epochs', '1'] 13 | args += ['--dataset', 'DUMMYM'] 14 | args += ['--max_ring_size', '5'] 15 | args += ['--exp_name', 'dummym_test'] 16 | args += ['--readout_dims', '0', '2'] 17 | return args 18 | 19 | def test_run_exp_on_dummym(): 20 | parser = get_parser() 21 | args = get_args_for_dummym() 22 | parsed_args = parser.parse_args(args) 23 | curves = main(parsed_args) 24 | # On this dataset the splits all coincide; we assert 25 | # that the final performance is the same on all of them. 26 | assert curves['last_train'] == curves['last_val'] 27 | assert curves['last_train'] == curves['last_test'] -------------------------------------------------------------------------------- /conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | def pytest_addoption(parser): 5 | parser.addoption( 6 | "--runslow", action="store_true", default=False, help="run slow tests", 7 | ) 8 | parser.addoption( 9 | "--rundata", action="store_true", default=False, help="run tests using datasets", 10 | ) 11 | 12 | 13 | def pytest_configure(config): 14 | config.addinivalue_line("markers", "slow: mark test as slow to run") 15 | config.addinivalue_line("markers", "data: mark test as using a dataset") 16 | 17 | 18 | def pytest_collection_modifyitems(config, items): 19 | skip_slow = pytest.mark.skip(reason="need --runslow option to run") 20 | skip_data = pytest.mark.skip(reason="need --rundata option to run") 21 | 22 | if not config.getoption("--runslow"): 23 | for item in items: 24 | if "slow" in item.keywords: 25 | item.add_marker(skip_slow) 26 | 27 | if not config.getoption("--rundata"): 28 | for item in items: 29 | if "data" in item.keywords: 30 | item.add_marker(skip_data) 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 The CWN Project Authors 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /exp/tuning_configurations/template.yml: -------------------------------------------------------------------------------- 1 | # Just a basic template that can be customised to set-up grids on TUDatasets. 2 | # TODO: add support to readout only on specific dimensions with arg "--readout_dims". 3 | dataset: TU_dataset_name 4 | epochs: 5 | - 150 6 | batch_size: 7 | - 32 8 | - 128 9 | drop_position: 10 | - lin2 11 | - final_readout 12 | - lin1 13 | drop_rate: 14 | - 0.0 15 | - 0.5 16 | emb_dim: 17 | - 16 18 | - 32 19 | - 64 20 | final_readout: 21 | - sum 22 | init_method: 23 | - sum 24 | - mean 25 | jump_mode: 26 | - cat 27 | lr: 28 | - 0.0005 29 | - 0.001 30 | - 0.003 31 | - 0.01 32 | lr_scheduler: 33 | - StepLR 34 | lr_scheduler_decay_rate: 35 | - 0.5 36 | - 0.9 37 | lr_scheduler_decay_steps: 38 | - 50 39 | - 20 40 | max_dim: 41 | # If supplying a max_ring_size, max_dim should be set to 2. 42 | - 2 43 | max_ring_size: 44 | # Remove this one if you want to tune an MPSN. 45 | - 6 46 | model: 47 | - sparse_cin 48 | use_coboundaries: 49 | - True 50 | - False 51 | nonlinearity: 52 | - relu 53 | num_layers: 54 | - 3 55 | - 4 56 | readout: 57 | - mean 58 | # Use sum for bio datasets. 59 | # - sum 60 | train_eval_period: 61 | - 50 -------------------------------------------------------------------------------- /data/datasets/test_ocean.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | from data.datasets.ocean_utils import load_ocean_dataset 4 | 5 | 6 | @pytest.mark.data 7 | def test_ocean_dataset_generation(): 8 | train, test, _ = load_ocean_dataset() 9 | assert len(train) == 160 10 | assert len(test) == 40 11 | 12 | for cochain in train + test: 13 | # checks the upper/lower orientation features are consistent 14 | # in shape with the upper/lower indices 15 | assert len(cochain.upper_orient) == cochain.upper_index.size(1) 16 | assert len(cochain.lower_orient) == cochain.lower_index.size(1) 17 | # checks the upper and lower indices are consistent with the number of edges 18 | assert cochain.upper_index.max() < cochain.x.size(0), print(cochain.upper_index.max(), 19 | cochain.x.size(0)) 20 | assert cochain.lower_index.max() < cochain.x.size(0), print(cochain.lower_index.max(), 21 | cochain.x.size(0)) 22 | 23 | # checks the values for orientations are either +1 (coherent) or -1 (not coherent) 24 | assert (torch.sum(cochain.upper_orient == 1) 25 | + torch.sum(cochain.upper_orient == -1) == cochain.upper_orient.numel()) 26 | assert (torch.sum(cochain.lower_orient == 1) 27 | + torch.sum(cochain.lower_orient == -1) == cochain.lower_orient.numel()) 28 | -------------------------------------------------------------------------------- /datasets/CSL/splits/CSL_train.txt: -------------------------------------------------------------------------------- 1 | 62,69,88,115,43,128,11,55,108,84,137,74,48,104,72,81,24,34,90,12,134,37,99,58,63,25,22,7,4,30,103,21,87,28,148,86,68,44,118,145,36,114,95,54,42,144,18,20,124,8,85,126,122,138,49,73,5,92,141,113,45,109,120,146,51,59,149,89,41,15,46,139,2,116,123,78,61,97,121,112,125,76,101,119,6,1,67,100,31,29 2 | 63,65,95,49,55,141,135,105,134,122,61,66,132,89,118,37,99,92,91,142,18,0,93,139,15,116,9,148,3,30,53,43,138,35,84,67,12,51,117,22,125,25,69,74,81,39,120,70,10,17,45,137,124,75,11,87,94,131,140,29,107,58,2,20,82,96,119,83,23,28,112,36,8,130,60,56,115,42,50,59,31,76,100,128,78,44,114,103,145,6 3 | 
42,137,47,57,10,87,148,83,61,23,107,32,109,73,14,123,46,72,38,37,51,48,24,127,21,92,141,36,132,6,13,136,84,19,98,91,63,120,134,16,106,18,60,50,112,80,125,66,22,135,131,94,4,3,1,79,70,102,110,145,96,33,56,146,108,93,58,143,119,97,15,129,27,8,81,40,149,115,55,75,9,30,78,74,104,89,35,62,133,114 4 | 70,90,65,46,44,87,128,114,98,136,33,80,116,149,84,118,7,47,5,103,21,34,57,37,134,25,100,142,3,60,79,97,69,135,8,66,144,45,139,12,148,102,107,35,95,1,108,124,18,82,16,59,143,76,111,10,126,43,77,29,109,9,26,94,73,120,49,6,140,20,83,74,48,110,133,51,72,32,39,130,24,127,91,54,123,85,30,27,62,112 5 | 55,26,13,23,46,16,101,136,91,125,111,80,25,9,14,128,56,15,109,100,133,20,40,82,143,21,88,113,38,77,45,139,62,67,130,124,93,10,35,110,73,144,116,44,58,142,107,85,36,83,131,106,64,7,70,138,105,127,108,66,97,79,34,0,147,5,52,61,47,72,92,48,17,1,33,68,24,129,32,78,123,102,95,149,145,11,41,76,57,99 6 | -------------------------------------------------------------------------------- /exp/run_tu_tuning.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import os 3 | import copy 4 | import yaml 5 | import argparse 6 | from definitions import ROOT_DIR 7 | from exp.parser import get_parser 8 | from exp.run_tu_exp import exp_main 9 | 10 | __max_devices__ = 8 11 | 12 | if __name__ == "__main__": 13 | 14 | parser = argparse.ArgumentParser(description='CWN tuning.') 15 | parser.add_argument('--conf', type=str, help='path to yaml configuration') 16 | parser.add_argument('--code', type=str, help='tuning name') 17 | parser.add_argument('--idx', type=int, help='selection index') 18 | t_args = parser.parse_args() 19 | 20 | # parse grid from yaml 21 | with open(t_args.conf, 'r') as handle: 22 | conf = yaml.safe_load(handle) 23 | dataset = conf['dataset'] 24 | hyper_list = list() 25 | hyper_values = list() 26 | for key in conf: 27 | if key == 'dataset': 28 | continue 29 | hyper_list.append(key) 30 | hyper_values.append(conf[key]) 31 | grid = itertools.product(*hyper_values) 32 | exp_queue = list() 33 | for h, hypers in enumerate(grid): 34 | if h % __max_devices__ == (t_args.idx % __max_devices__): 35 | exp_queue.append((h, hypers)) 36 | 37 | # form args 38 | base_args = [ 39 | '--device', str(t_args.idx), 40 | '--task_type', 'classification', 41 | '--eval_metric', 'accuracy', 42 | '--dataset', dataset, 43 | '--result_folder', os.path.join(ROOT_DIR, 'exp', 'results', '{}_tuning_{}'.format(dataset, t_args.code))] 44 | 45 | for exp in exp_queue: 46 | args = copy.copy(base_args) 47 | addendum = ['--exp_name', str(exp[0])] 48 | hypers = exp[1] 49 | for name, value in zip(hyper_list, hypers): 50 | addendum.append('--{}'.format(name)) 51 | addendum.append('{}'.format(value)) 52 | args += addendum 53 | exp_main(args) 54 | 55 | -------------------------------------------------------------------------------- /exp/prepare_sr_tests.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import pickle 4 | 5 | from data.data_loading import load_dataset, load_graph_dataset 6 | from data.perm_utils import permute_graph, generate_permutation_matrices 7 | from definitions import ROOT_DIR 8 | 9 | __families__ = [ 10 | 'sr16622', 11 | 'sr251256', 12 | 'sr261034', 13 | 'sr281264', 14 | 'sr291467', 15 | 'sr351668', 16 | 'sr351899', 17 | 'sr361446', 18 | 'sr401224' 19 | ] 20 | 21 | def prepare(family, jobs, max_ring_size, permute, init, seed): 22 | root = os.path.join(ROOT_DIR, 'datasets') 23 | raw_dir = os.path.join(root, 
'SR_graphs', 'raw') 24 | _ = load_dataset(family, max_dim=2, max_ring_size=max_ring_size, n_jobs=jobs, init_method=init) 25 | if permute: 26 | graphs, _, _, _, _ = load_graph_dataset(family) 27 | permuted_graphs = list() 28 | for graph in graphs: 29 | perm = generate_permutation_matrices(graph.num_nodes, 1, seed=seed)[0] 30 | permuted_graph = permute_graph(graph, perm) 31 | permuted_graphs.append((permuted_graph.edge_index, permuted_graph.num_nodes)) 32 | with open(os.path.join(raw_dir, f'{family}p{seed}.pkl'), 'wb') as handle: 33 | pickle.dump(permuted_graphs, handle) 34 | _ = load_dataset(f'{family}p{seed}', max_dim=2, max_ring_size=max_ring_size, n_jobs=jobs, init_method=init) 35 | 36 | if __name__ == "__main__": 37 | 38 | # Standard args 39 | passed_args = sys.argv[1:] 40 | jobs = int(passed_args[0]) 41 | max_ring_size = int(passed_args[1]) 42 | permute = passed_args[2].lower() 43 | init_method = passed_args[3].lower() 44 | assert max_ring_size > 3 45 | 46 | # Execute 47 | for family in __families__: 48 | print('\n==============================================================') 49 | print(f'[i] Preprocessing on family {family}...') 50 | prepare(family, jobs, max_ring_size, permute=='y', init_method, 43) 51 | -------------------------------------------------------------------------------- /mp/test_permutation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from data.utils import compute_ring_2complex 4 | from data.perm_utils import permute_graph, generate_permutation_matrices 5 | from data.dummy_complexes import get_mol_testing_complex_list, convert_to_graph 6 | from data.complex import ComplexBatch 7 | from mp.models import SparseCIN 8 | 9 | def test_sparse_cin0_perm_invariance_on_dummy_mol_complexes(): 10 | 11 | # Generate reference graph list 12 | dummy_complexes = get_mol_testing_complex_list() 13 | dummy_graphs = [convert_to_graph(complex) for complex in dummy_complexes] 14 | for graph in dummy_graphs: 15 | graph.edge_attr = None 16 | # (We convert back to complexes to regenerate signals on edges and rings, fixing max_k to 7) 17 | dummy_complexes = [compute_ring_2complex(graph.x, graph.edge_index, None, graph.num_nodes, max_k=7, 18 | include_down_adj=False, init_method='sum', init_edges=True, init_rings=True) 19 | for graph in dummy_graphs] 20 | 21 | # Instantiate model 22 | model = SparseCIN(num_input_features=1, num_classes=16, num_layers=3, hidden=32, use_coboundaries=True, nonlinearity='elu') 23 | model.eval() 24 | 25 | # Compute reference complex embeddings 26 | embeddings = [model.forward(ComplexBatch.from_complex_list([comp], max_dim=comp.dimension)) for comp in dummy_complexes] 27 | 28 | # Test invariance for multiple random permutations 29 | for comp_emb, graph in zip(embeddings, dummy_graphs): 30 | permutations = generate_permutation_matrices(graph.num_nodes, 5) 31 | for perm in permutations: 32 | permuted_graph = permute_graph(graph, perm) 33 | permuted_comp = compute_ring_2complex(permuted_graph.x, permuted_graph.edge_index, None, permuted_graph.num_nodes, 34 | max_k=7, include_down_adj=False, init_method='sum', init_edges=True, init_rings=True) 35 | permuted_emb = model.forward(ComplexBatch.from_complex_list([permuted_comp], max_dim=permuted_comp.dimension)) 36 | assert torch.allclose(comp_emb, permuted_emb, atol=1e-6) 37 | -------------------------------------------------------------------------------- /data/datasets/test_ringtransfer.py: 
-------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | from data.datasets.ring_utils import generate_ring_transfer_graph_dataset 4 | from data.utils import convert_graph_dataset_with_rings 5 | from data.datasets import RingTransferDataset 6 | from definitions import ROOT_DIR 7 | 8 | 9 | def test_ringtree_dataset_generation(): 10 | dataset = generate_ring_transfer_graph_dataset(nodes=10, samples=100, classes=5) 11 | labels = dict() 12 | for data in dataset: 13 | assert data.edge_index[0].min() == 0 14 | assert data.edge_index[1].min() == 0 15 | assert data.edge_index[0].max() == 9 16 | assert data.edge_index[1].max() == 9 17 | assert data.x.size(0) == 10 18 | assert data.x.size(1) == 5 19 | 20 | label = data.y.item() 21 | if label not in labels: 22 | labels[label] = 0 23 | labels[label] += 1 24 | 25 | assert list(range(5)) == list(sorted(labels.keys())) 26 | assert {20} == set(labels.values()) 27 | 28 | 29 | def test_ringtree_dataset_conversion(): 30 | dataset = generate_ring_transfer_graph_dataset(nodes=10, samples=10, classes=5) 31 | complexes, _, _ = convert_graph_dataset_with_rings(dataset, max_ring_size=10, 32 | include_down_adj=False, init_rings=True) 33 | 34 | for complex in complexes: 35 | assert 2 in complex.cochains 36 | assert complex.cochains[2].num_cells == 1 37 | assert complex.cochains[1].num_cells == 10 38 | assert complex.cochains[0].num_cells == 10 39 | assert complex.nodes.x.size(0) == 10 40 | assert complex.nodes.x.size(1) == 5 41 | assert complex.edges.x.size(0) == 10 42 | assert complex.edges.x.size(1) == 5 43 | assert complex.two_cells.x.size(0) == 1 44 | assert complex.two_cells.x.size(1) == 5 45 | 46 | 47 | def test_ring_transfer_dataset_loading(): 48 | # Test everything runs without errors. 49 | root = osp.join(ROOT_DIR, 'datasets', 'RING-TRANSFER') 50 | dataset = RingTransferDataset(root=root, train=20, test=10) 51 | dataset.get(0) 52 | -------------------------------------------------------------------------------- /.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | name: Build and Test 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - main 7 | push: 8 | branches: 9 | - main 10 | schedule: 11 | # Run the tests at 00:00 each day 12 | - cron: "0 0 * * *" 13 | 14 | jobs: 15 | build: 16 | 17 | runs-on: ubuntu-latest 18 | strategy: 19 | matrix: 20 | python-version: [3.8] 21 | defaults: 22 | run: 23 | shell: bash -l {0} 24 | 25 | steps: 26 | - uses: actions/checkout@v2 27 | - name: cache conda 28 | uses: actions/cache@v2 29 | env: 30 | # Increase this value to reset cache if etc/example-environment.yml has not changed 31 | CACHE_NUMBER: 0 32 | with: 33 | path: ~/conda_pkgs_dir 34 | key: 35 | ${{ runner.os }}-conda-${{ env.CACHE_NUMBER }}-${{ 36 | hashFiles('requirements.txt') }} 37 | - uses: conda-incubator/setup-miniconda@v2 38 | with: 39 | activate-environment: test 40 | python-version: 3.8 41 | use-only-tar-bz2: true # IMPORTANT: This needs to be set for caching to work properly! 
42 | - name: Set up env 43 | run: | 44 | conda activate test 45 | conda install pip 46 | - name: Cache pip 47 | uses: actions/cache@v2 48 | with: 49 | # This path is specific to Ubuntu 50 | path: ~/.cache/pip 51 | # Look to see if there is a cache hit for the corresponding requirements file 52 | key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }} 53 | restore-keys: | 54 | ${{ runner.os }}-pip- 55 | ${{ runner.os }}- 56 | - name: Install graph-tool 57 | run: | 58 | conda install -c conda-forge -y graph-tool==2.44 59 | - name: pytorch 60 | run: | 61 | conda install -y pytorch=1.7.0 torchvision cudatoolkit=10.2 -c pytorch --update-deps 62 | - name: Install pyG 63 | run: | 64 | ./pyG_install.sh cu102 65 | - name: Install dependencies 66 | run: | 67 | pip install -r requirements.txt 68 | - name: Test with pytest 69 | run: | 70 | pytest -v 71 | -------------------------------------------------------------------------------- /mp/nn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch_geometric.nn import global_mean_pool, global_add_pool 4 | from torch.nn import BatchNorm1d as BN, LayerNorm as LN, Identity 5 | 6 | 7 | def get_nonlinearity(nonlinearity, return_module=True): 8 | if nonlinearity == 'relu': 9 | module = torch.nn.ReLU 10 | function = F.relu 11 | elif nonlinearity == 'elu': 12 | module = torch.nn.ELU 13 | function = F.elu 14 | elif nonlinearity == 'id': 15 | module = torch.nn.Identity 16 | function = lambda x: x 17 | elif nonlinearity == 'sigmoid': 18 | module = torch.nn.Sigmoid 19 | function = F.sigmoid 20 | elif nonlinearity == 'tanh': 21 | module = torch.nn.Tanh 22 | function = torch.tanh 23 | else: 24 | raise NotImplementedError('Nonlinearity {} is not currently supported.'.format(nonlinearity)) 25 | if return_module: 26 | return module 27 | return function 28 | 29 | 30 | def get_pooling_fn(readout): 31 | if readout == 'sum': 32 | return global_add_pool 33 | elif readout == 'mean': 34 | return global_mean_pool 35 | else: 36 | raise NotImplementedError('Readout {} is not currently supported.'.format(readout)) 37 | 38 | 39 | def get_graph_norm(norm): 40 | if norm == 'bn': 41 | return BN 42 | elif norm == 'ln': 43 | return LN 44 | elif norm == 'id': 45 | return Identity 46 | else: 47 | raise ValueError(f'Graph Normalisation {norm} not currently supported') 48 | 49 | 50 | def pool_complex(xs, data, max_dim, readout_type): 51 | pooling_fn = get_pooling_fn(readout_type) 52 | # All complexes have nodes so we can extract the batch size from cochains[0] 53 | batch_size = data.cochains[0].batch.max() + 1 54 | # The MP output is of shape [message_passing_dim, batch_size, feature_dim] 55 | pooled_xs = torch.zeros(max_dim+1, batch_size, xs[0].size(-1), 56 | device=batch_size.device) 57 | for i in range(len(xs)): 58 | # It's very important that size is supplied. 
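# Without an explicit size, global_add_pool / global_mean_pool would infer the number of
# complexes from cochains[i].batch.max() + 1; if the trailing complexes in the batch have
# no cells of dimension i, that inferred value is smaller than batch_size and the
# assignment into the fixed-size pooled_xs tensor below would fail.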
59 | pooled_xs[i, :, :] = pooling_fn(xs[i], data.cochains[i].batch, size=batch_size) 60 | return pooled_xs 61 | -------------------------------------------------------------------------------- /mp/cell_mp_inspector.py: -------------------------------------------------------------------------------- 1 | """ 2 | Based on https://github.com/rusty1s/pytorch_geometric/blob/76d61eaa9fc8702aa25f29dfaa5134a169d0f1f6/torch_geometric/nn/conv/utils/inspector.py 3 | 4 | MIT License 5 | 6 | Copyright (c) 2020 Matthias Fey 7 | Copyright (c) 2021 The CWN Project Authors 8 | 9 | Permission is hereby granted, free of charge, to any person obtaining a copy 10 | of this software and associated documentation files (the "Software"), to deal 11 | in the Software without restriction, including without limitation the rights 12 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 13 | copies of the Software, and to permit persons to whom the Software is 14 | furnished to do so, subject to the following conditions: 15 | 16 | The above copyright notice and this permission notice shall be included in 17 | all copies or substantial portions of the Software. 18 | 19 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 20 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 21 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 22 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 23 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 24 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 25 | THE SOFTWARE. 26 | """ 27 | 28 | import inspect 29 | from collections import OrderedDict 30 | from typing import Dict, Any, Callable 31 | from torch_geometric.nn.conv.utils.inspector import Inspector 32 | 33 | 34 | class CellularInspector(Inspector): 35 | """Wrapper of the PyTorch Geometric Inspector so to adapt it to our use cases.""" 36 | 37 | def __implements__(self, cls, func_name: str) -> bool: 38 | if cls.__name__ == 'CochainMessagePassing': 39 | return False 40 | if func_name in cls.__dict__.keys(): 41 | return True 42 | return any(self.__implements__(c, func_name) for c in cls.__bases__) 43 | 44 | def inspect(self, func: Callable, pop_first_n: int = 0) -> Dict[str, Any]: 45 | params = inspect.signature(func).parameters 46 | params = OrderedDict(params) 47 | for _ in range(pop_first_n): 48 | params.popitem(last=False) 49 | self.params[func.__name__] = params 50 | -------------------------------------------------------------------------------- /data/perm_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from scipy import sparse as sp 5 | from torch_geometric.data import Data 6 | 7 | def permute_graph(graph: Data, P: np.ndarray) -> Data: 8 | 9 | # TODO: support edge features and their permutation 10 | assert graph.edge_attr is None 11 | 12 | # Check validity of permutation matrix 13 | n = graph.x.size(0) 14 | if not is_valid_permutation_matrix(P, n): 15 | raise AssertionError 16 | 17 | # Apply permutation to features 18 | x = graph.x.numpy() 19 | x_perm = torch.FloatTensor(P @ x) 20 | 21 | # Apply perm to labels, if per-node 22 | if graph.y is None: 23 | y_perm = None 24 | elif graph.y.size(0) == n: 25 | y = graph.y.numpy() 26 | y_perm = torch.tensor(P @ y) 27 | else: 28 | y_perm = graph.y.clone().detach() 29 | 30 | # Apply permutation to adjacencies, if any 31 | if 
graph.edge_index.size(1) > 0: 32 | inps = (np.ones(graph.edge_index.size(1)), (graph.edge_index[0].numpy(), graph.edge_index[1].numpy())) 33 | A = sp.csr_matrix(inps, shape=(n,n)) 34 | P = sp.csr_matrix(P) 35 | A_perm = P.dot(A).dot(P.transpose()).tocoo() 36 | edge_index_perm = torch.LongTensor(np.vstack((A_perm.row, A_perm.col))) 37 | else: 38 | edge_index_perm = graph.edge_index.clone().detach() 39 | 40 | # Instantiate new graph 41 | graph_perm = Data(x=x_perm, edge_index=edge_index_perm, y=y_perm) 42 | 43 | return graph_perm 44 | 45 | def is_valid_permutation_matrix(P: np.ndarray, n: int): 46 | valid = True 47 | valid &= P.ndim == 2 48 | valid &= P.shape[0] == n 49 | valid &= np.all(P.sum(0) == np.ones(n)) 50 | valid &= np.all(P.sum(1) == np.ones(n)) 51 | valid &= np.all(P.max(0) == np.ones(n)) 52 | valid &= np.all(P.max(1) == np.ones(n)) 53 | if n > 1: 54 | valid &= np.all(P.min(0) == np.zeros(n)) 55 | valid &= np.all(P.min(1) == np.zeros(n)) 56 | valid &= not np.array_equal(P, np.eye(n)) 57 | return valid 58 | 59 | def generate_permutation_matrices(size, amount=10, seed=43): 60 | 61 | Ps = list() 62 | random_state = np.random.RandomState(seed) 63 | count = 0 64 | while count < amount: 65 | I = np.eye(size) 66 | perm = random_state.permutation(size) 67 | P = I[perm] 68 | if is_valid_permutation_matrix(P, size): 69 | Ps.append(P) 70 | count += 1 71 | 72 | return Ps -------------------------------------------------------------------------------- /data/test_data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from data.dummy_complexes import get_house_complex 4 | 5 | 6 | def test_up_and_down_feature_extraction_on_house_complex(): 7 | house_complex = get_house_complex() 8 | 9 | v_cochain_params = house_complex.get_cochain_params(dim=0) 10 | v_up_attr = v_cochain_params.kwargs['up_attr'] 11 | expected_v_up_attr = torch.tensor([[1], [1], [4], [4], [2], [2], [3], [3], [6], [6], [5], [5]], 12 | dtype=torch.float) 13 | assert torch.equal(expected_v_up_attr, v_up_attr) 14 | 15 | e_cochain_params = house_complex.get_cochain_params(dim=1) 16 | e_up_attr = e_cochain_params.kwargs['up_attr'] 17 | expected_e_up_attr = torch.tensor([[1], [1], [1], [1], [1], [1]], dtype=torch.float) 18 | assert torch.equal(expected_e_up_attr, e_up_attr) 19 | 20 | e_down_attr = e_cochain_params.kwargs['down_attr'] 21 | expected_e_down_attr = torch.tensor([[2], [2], [1], [1], [3], [3], [3], [3], [4], [4], [4], [4], 22 | [3], [3], [4], [4], [5], [5]], dtype=torch.float) 23 | assert torch.equal(expected_e_down_attr, e_down_attr) 24 | 25 | t_cochain_params = house_complex.get_cochain_params(dim=2) 26 | t_up_attr = t_cochain_params.kwargs['up_attr'] 27 | assert t_up_attr is None 28 | 29 | t_down_attr = t_cochain_params.kwargs['down_attr'] 30 | assert t_down_attr is None 31 | 32 | 33 | def test_get_all_cochain_params_with_max_dim_one_and_no_top_features(): 34 | house_complex = get_house_complex() 35 | 36 | params = house_complex.get_all_cochain_params(max_dim=1, include_top_features=False) 37 | assert len(params) == 2 38 | 39 | v_cochain_params, e_cochain_params = params 40 | 41 | v_up_attr = v_cochain_params.kwargs['up_attr'] 42 | expected_v_up_attr = torch.tensor([[1], [1], [4], [4], [2], [2], [3], [3], [6], [6], [5], [5]], 43 | dtype=torch.float) 44 | assert torch.equal(expected_v_up_attr, v_up_attr) 45 | 46 | e_up_attr = e_cochain_params.kwargs['up_attr'] 47 | assert e_up_attr is None 48 | assert e_cochain_params.up_index is not None 49 | assert 
e_cochain_params.up_index.size(1) == 6 50 | 51 | e_down_attr = e_cochain_params.kwargs['down_attr'] 52 | expected_e_down_attr = torch.tensor([[2], [2], [1], [1], [3], [3], [3], [3], [4], [4], [4], [4], 53 | [3], [3], [4], [4], [5], [5]], dtype=torch.float) 54 | assert torch.equal(expected_e_down_attr, e_down_attr) 55 | -------------------------------------------------------------------------------- /data/datasets/test_zinc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os.path as osp 3 | import pytest 4 | 5 | from data.data_loading import load_dataset 6 | from data.helper_test import (check_edge_index_are_the_same, 7 | check_edge_attr_are_the_same, get_rings, 8 | get_complex_rings) 9 | from torch_geometric.datasets import ZINC 10 | 11 | 12 | @pytest.mark.slow 13 | def test_zinc_splits_are_retained(): 14 | dataset1 = load_dataset("ZINC", max_ring_size=7, use_edge_features=True) 15 | dataset1_train = dataset1.get_split('train') 16 | dataset1_valid = dataset1.get_split('valid') 17 | dataset1_test = dataset1.get_split('test') 18 | 19 | raw_dir = osp.join(dataset1.root, 'raw') 20 | dataset2_train = ZINC(raw_dir, subset=True, split='train') 21 | dataset2_valid = ZINC(raw_dir, subset=True, split='val') 22 | dataset2_test = ZINC(raw_dir, subset=True, split='test') 23 | 24 | datasets1 = [dataset1_train, dataset1_valid, dataset1_test] 25 | datasets2 = [dataset2_train, dataset2_valid, dataset2_test] 26 | datasets = zip(datasets1, datasets2) 27 | 28 | for datas1, datas2 in datasets: 29 | for i, _ in enumerate(datas1): 30 | data1, data2 = datas1[i], datas2[i] 31 | 32 | assert torch.equal(data1.y, data2.y) 33 | assert torch.equal(data1.cochains[0].x, data2.x) 34 | assert data1.cochains[1].x.size(0) == (data2.edge_index.size(1) // 2) 35 | check_edge_index_are_the_same(data1.cochains[0].upper_index, data2.edge_index) 36 | check_edge_attr_are_the_same(data1.cochains[1].boundary_index, 37 | data1.cochains[1].x, data2.edge_index, data2.edge_attr) 38 | 39 | 40 | @pytest.mark.slow 41 | def test_we_find_only_the_induced_cycles_on_zinc(): 42 | max_ring = 7 43 | dataset = load_dataset("ZINC", max_ring_size=max_ring, use_edge_features=True) 44 | # Check only on validation to save time. I've also run once on the whole dataset and passes. 
45 | dataset = dataset.get_split('valid') 46 | 47 | for complex in dataset: 48 | nx_rings = get_rings(complex.nodes.num_cells, complex.nodes.upper_index, 49 | max_ring=max_ring) 50 | if 2 not in complex.cochains: 51 | assert len(nx_rings) == 0 52 | continue 53 | 54 | complex_rings = get_complex_rings(complex.cochains[2].boundary_index, complex.edges.boundary_index) 55 | assert len(complex_rings) > 0 56 | assert len(nx_rings) == complex.cochains[2].num_cells 57 | assert nx_rings == complex_rings 58 | 59 | 60 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | *.DS_Store 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | 133 | # PyCharm 134 | .idea/ 135 | .vscode/ 136 | 137 | 138 | *.gz 139 | exp/results/ 140 | *.pt 141 | 142 | -------------------------------------------------------------------------------- /data/datasets/test_flow.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from scipy.spatial import Delaunay 5 | from data.datasets.flow_utils import load_flow_dataset, create_hole, is_inside_rectangle 6 | 7 | 8 | def test_create_hole(): 9 | # This seed contains some edge cases. 10 | np.random.seed(4) 11 | points = np.random.uniform(size=(400, 2)) 12 | tri = Delaunay(points) 13 | 14 | hole1 = np.array([[0.2, 0.2], [0.4, 0.4]]) 15 | points, triangles = create_hole(points, tri.simplices, hole1) 16 | 17 | assert triangles.max() == len(points) - 1 18 | assert triangles.min() == 0 19 | 20 | # Check all points are outside the hole 21 | for i in range(len(points)): 22 | assert not is_inside_rectangle(points[i], hole1) 23 | 24 | # Double check each point appears in some triangle. 25 | for i in range(len(points)): 26 | assert np.sum(triangles == i) > 0 27 | 28 | 29 | def test_flow_util_dataset_loading(): 30 | # Fix seed for reproducibility 31 | np.random.seed(0) 32 | 33 | train, test, _ = load_flow_dataset(num_points=300, num_train=20, num_test=10) 34 | assert len(train) == 20 35 | assert len(test) == 10 36 | 37 | label_count = {0: 0, 1: 0} 38 | 39 | for cochain in train + test: 40 | # checks x values (flow direction) are either +1 or -1 41 | assert (torch.sum(cochain.x == 1) + torch.sum(cochain.x == -1) 42 | == torch.count_nonzero(cochain.x)) 43 | 44 | # checks the upper/lower orientation features are consistent 45 | # in shape with the upper/lower indices 46 | assert len(cochain.upper_orient) == cochain.upper_index.size(1) 47 | assert len(cochain.lower_orient) == cochain.lower_index.size(1) 48 | # checks the upper and lower indices are consistent with the number of edges 49 | assert cochain.upper_index.max() < cochain.x.size(0), print(cochain.upper_index.max(), 50 | cochain.x.size(0)) 51 | assert cochain.lower_index.max() < cochain.x.size(0), print(cochain.lower_index.max(), 52 | cochain.x.size(0)) 53 | 54 | # checks the values for orientations are either +1 (coherent) or -1 (not coherent) 55 | assert (torch.sum(cochain.upper_orient == 1) 56 | + torch.sum(cochain.upper_orient == -1) == cochain.upper_orient.numel()) 57 | assert (torch.sum(cochain.lower_orient == 1) 58 | + torch.sum(cochain.lower_orient == -1) == cochain.lower_orient.numel()) 59 | 60 | label_count[cochain.y.item()] += 1 61 | 62 | # checks distribution of labels 63 | assert label_count[0] == 20 // 2 + 10 // 2 64 | assert label_count[1] == 20 // 2 + 10 // 2 65 | -------------------------------------------------------------------------------- /data/datasets/ocean.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import os.path 
as osp 3 | 4 | from data.datasets import InMemoryComplexDataset 5 | from data.datasets.ocean_utils import load_ocean_dataset 6 | 7 | 8 | # TODO: Set up a cochain dataset structure or make complex dataset better support cochain-only data. 9 | # TODO: Refactor the dataset to use the latest storage formatting. 10 | class OceanDataset(InMemoryComplexDataset): 11 | """A real-world dataset for edge-flow classification. 12 | 13 | The dataset is adapted from https://arxiv.org/abs/1807.05044 14 | """ 15 | 16 | def __init__(self, root, name, load_graph=False, train_orient='default', 17 | test_orient='default'): 18 | self.name = name 19 | self._num_classes = 2 20 | self._train_orient = train_orient 21 | self._test_orient = test_orient 22 | 23 | super(OceanDataset, self).__init__(root, max_dim=1, 24 | num_classes=self._num_classes, include_down_adj=True) 25 | 26 | with open(self.processed_paths[0], 'rb') as handle: 27 | train = pickle.load(handle) 28 | 29 | with open(self.processed_paths[1], 'rb') as handle: 30 | val = pickle.load(handle) 31 | 32 | self.__data_list__ = train + val 33 | 34 | self.G = None 35 | if load_graph: 36 | with open(self.processed_paths[2], 'rb') as handle: 37 | self.G = pickle.load(handle) 38 | 39 | self.train_ids = list(range(len(train))) 40 | self.val_ids = list(range(len(train), len(train) + len(val))) 41 | self.test_ids = None 42 | 43 | @property 44 | def processed_dir(self): 45 | """This is overwritten, so the cellular complex data is placed in another folder""" 46 | return osp.join(self.root, f'complex_{self._train_orient}_{self._test_orient}') 47 | 48 | @property 49 | def processed_file_names(self): 50 | return ['train_{}_complex_list.pkl'.format(self.name), 51 | 'val_{}_complex_list.pkl'.format(self.name), 52 | '{}_graph.pkl'.format(self.name)] 53 | 54 | def process(self): 55 | train, val, G = load_ocean_dataset(self._train_orient, self._test_orient) 56 | 57 | train_path = self.processed_paths[0] 58 | print(f"Saving train dataset to {train_path}") 59 | with open(train_path, 'wb') as handle: 60 | pickle.dump(train, handle) 61 | 62 | val_path = self.processed_paths[1] 63 | print(f"Saving val dataset to {val_path}") 64 | with open(val_path, 'wb') as handle: 65 | pickle.dump(val, handle) 66 | 67 | graph_path = self.processed_paths[2] 68 | with open(graph_path, 'wb') as handle: 69 | pickle.dump(G, handle) 70 | 71 | def len(self): 72 | """Override method to make the class work with deprecated storage""" 73 | return len(self.__data_list__) 74 | -------------------------------------------------------------------------------- /data/datasets/flow.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import os.path as osp 3 | 4 | from data.datasets import InMemoryComplexDataset 5 | from data.datasets.flow_utils import load_flow_dataset 6 | 7 | 8 | # TODO: Set up a cochain dataset structure or make complex dataset better support cochain-only data. 9 | # TODO: Make this dataset use the new storage system.
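# Illustrative usage only (the root path and sample counts below are assumptions, mirroring the
# plotting script in this repository):
#   dataset = FlowDataset(osp.join(ROOT_DIR, 'datasets', 'FLOW'), 'FLOW', num_points=1000,
#                         train_samples=1000, val_samples=200, load_graph=True)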
10 | class FlowDataset(InMemoryComplexDataset): 11 | """A synthetic dataset for edge-flow classification.""" 12 | 13 | def __init__(self, root, name, num_points, train_samples, val_samples, 14 | load_graph=False, train_orient='default', test_orient='default', n_jobs=2): 15 | self.name = name 16 | self._num_classes = 2 17 | self._num_points = num_points 18 | self._train_samples = train_samples 19 | self._val_samples = val_samples 20 | self._train_orient = train_orient 21 | self._test_orient = test_orient 22 | self._n_jobs = n_jobs 23 | 24 | super(FlowDataset, self).__init__(root, max_dim=1, 25 | num_classes=self._num_classes, include_down_adj=True) 26 | 27 | with open(self.processed_paths[0], 'rb') as handle: 28 | self.__data_list__ = pickle.load(handle) 29 | 30 | self.G = None 31 | if load_graph: 32 | with open(self.processed_paths[1], 'rb') as handle: 33 | self.G = pickle.load(handle) 34 | 35 | self.train_ids = list(range(train_samples)) 36 | self.val_ids = list(range(train_samples, train_samples + val_samples)) 37 | self.test_ids = None 38 | 39 | @property 40 | def processed_dir(self): 41 | """This is overwritten, so the cellular complex data is placed in another folder""" 42 | return osp.join(self.root, 43 | f'flow{self._num_points}_orient_{self._train_orient}_{self._test_orient}') 44 | 45 | @property 46 | def processed_file_names(self): 47 | return ['{}_complex_list.pkl'.format(self.name), '{}_graph.pkl'.format(self.name)] 48 | 49 | def process(self): 50 | train, val, G = load_flow_dataset(num_points=self._num_points, 51 | num_train=self._train_samples, num_test=self._val_samples, 52 | train_orientation=self._train_orient, test_orientation=self._test_orient, 53 | n_jobs=self._n_jobs) 54 | 55 | cochains = train + val 56 | path = self.processed_paths[0] 57 | print(f"Saving dataset in {path}...") 58 | with open(path, 'wb') as handle: 59 | pickle.dump(cochains, handle) 60 | 61 | graph_path = self.processed_paths[1] 62 | with open(graph_path, 'wb') as handle: 63 | pickle.dump(G, handle) 64 | 65 | @property 66 | def raw_file_names(self): 67 | return "" 68 | 69 | def download(self): 70 | pass 71 | 72 | def len(self): 73 | """Override method to make the class work with deprecated stoarage""" 74 | return len(self.__data_list__) 75 | -------------------------------------------------------------------------------- /datasets/SR_graphs/raw/sr291467.g6: -------------------------------------------------------------------------------- 1 | \}rE^yceKLtLlGeduRQhMYLLKZZapjTSLUKw[kQXIyLbHtkCzStI`{y@}FOWzcBzHozFg 2 | \}rEnYSbLSTVeek[ulB`YUiH[QrWqzLKhbRHU\EGRVzCfT]@vLMKdZgfGsigmUclBxQq[ 3 | \}rFMqidTgvDbYbdu]BHsULThIykJzwW@jPPTwaP`frcg{zEaxyB\Kx@vDiXEyXhCzcks 4 | \}rFMyoiLcqseVblTlBG{TI\YsBoMqUqxiam\[D?JlzDKmYFEZcCyq\@puexMeDuKrHrc 5 | \}rFUihddgvDbYbdu]BPsTLTeIyqJzwW@f`HUscP`frcg{zEaxyBlJT@nE[XEydhCzbKs 6 | \}rMNeceLDqtedmQtmAbwYIt[pRTKZcpLsExRuE@Qldb`lMC{itHep]?{^cWUZmXLXPfo 7 | \}rMNeockxsqayjHvIbHwSdtiaywbRRadXMX\i@_TlNcZXYMBZLHRknAk^oGnIBtBuLPk 8 | \}rNEmSeLXUMifkiPyJ_rQlLipIfSrIkh[bF^BE?Q^vCstUDc{\@zRUcnEw`qJHxEu`c{ 9 | \}rNEmo`k\uQmBe\TkQhhYK\iRIwkjbSTJLX\iD?FtnAjXiJJO{EbmLCg}fovBKuHtIps 10 | \}rNMqobdRqYdhhZVEQXFST\ctYVSrcppycE\Q`OF{mdBtuFBs\DmQwcw\Qgz`HrIp[Yk 11 | \}rNUio`{rt`hkehUGycyYJd`pywdylE`jGd\XAOXjN_ZvNBYrWLBkwbW\QpMigfMm?m[ 12 | \}rNUmBhKpUMkijErbb_NXN@WyJT`ijQthL[Tcn?Frf`szIBXlhLSpofPdWhMTPvEtPNW 13 | \}rNUmBiKhULkkjFRbR_NXN@WyJJ`qtQliK{TSv?Frf`szIBXlhMTPoepTWhMUHvEtINW 14 | \}rNUmBiKhULkkjFRbR_NYN@SyJLQir`tdS{Uav?Frf`susBXuTLSpofPdWhMTPvEtPNW 15 | 
\}rNUmBiKhUMkijErbb_NYN@SyJJQqt`leK{USv?Frf`sykBXmdMaioeigwhMdHvErQNW 16 | \}rUNUgdLFrdhke`vIJ`qQetYkJaMrWdXLsR\gL@DNNd`nIE[jc@vInaZMl?^E`lLR`pk 17 | \}rVE]gclhTRc{bVVcRPbUBtpUY\gAloteBy]Y@_JulbPtVGyq[LDXk`ktwWubamHrLRK 18 | \}rVE]gdLWuda|c{VPIpbWjTdqYssRckdLLX^GH_Kv\dNG]Bhk]KPx^BHvoGnISlHuLPk 19 | \}rVE]oakxsqmDczTkRPJP\TiwJcUarctUQm^BD?W^|AueMHVasKWzK`ZMUwuMIyHyHrc 20 | \}reMuHkK\VBdmjHUYIpRWmLhqI[iYskdsbX\RD?c^{c[wnJElSD^Ak?}MYwyFLUIrIXs 21 | \}reMuWglFrdbdlKVGZOtTLdikAeeY\BxwTU\YE?dVm`hyVI[jcHRZpaZfKHkLHtB{O\k 22 | \}reMuoiDLrcefb\U]BPhSjd[YRgiREetM`{]Y@_dfN_|ijIYq[AvMdAjLsX[LDuLPlPk 23 | \}ree]`eLhrakkazTorDiYG|YWzWkajgpUJJ]Y@_LZlf@lMA{g~DXjU@jXsHeMcjHuLPk 24 | \}rmeU`clPrJfIlcs{ahhYFTJqJFQqi[dqIl^KD?JfzEKwyEikkHfH\@o~ow]qEuMPYNc 25 | \}vENMahLWrUeUkXrtAhsRLTRURgbZTWhiJJ]WM?FVxeRcvI[hsDg{w`hlxO]kPmLRST[ 26 | \}vevMgiKaolcjlEpub`XTG|K]AKryFk\wTTYYf?UZdcY\NL`lHFUJDbkhLH\EKtK}Co{ 27 | \}zEmYWglgo}kRa{vdAbbYdLqeIq[rKqdTwSZGpWH^}civEEeqlI^@w@xLdW|HEmKo|P[ 28 | \}zUvMckCTpJ`\fJUUR`iYIhSXiSliFU|[eYXVF?NXLaW~IJSpoFeaLfGldp\Bc}K}CYw 29 | \}zUvMgkChpQ_~fLSlaXXST\QiiLMjA{tkXF]Kf?YNXc]LsLdKdLWoufPclW]s`{EyUP[ 30 | \}zVMq``lPqRb[fBu[aoxXJPEbyb[yUqhbUJ^KE@`fj`ivEAxZTM`oma~?YXiRRXJal`[ 31 | \}zVUmobCoqV_zjKsnB`RQUtEsyKmjA\t[iK]L?wZFFc]TeLTQTLashdpclWupcyExVAk 32 | \}~EMmPhCXtNkchirURHdUItQ]JhQjFgxiDeRUMGF\laixeBh]XK[W|BrCwXUUQrNPcbk 33 | \}~VEMQbDBtQeZdlSnA@}\@lX`rLQqYwliosSueHINx@[twDlI\MEZd@yWVheYIuNI_yk 34 | \~rEM]WhLTS^d[ecv`ITqRD\hqRciYxSTpMX\gL?L\]crYMD\MKIPzNAX]tOvIEmJUHpk 35 | \~rE~YWdKKsJaZjLUTR_ZRK\gjQdkIIt|i[MTXe`KNT`\fEL]@LFatEfITLXNBM[EtSp[ 36 | \~rE~YceKSsL`jlMRIqpLTD\EqiRqjA{\kJX[jD_iVddPleJ[HNMkcWaz`JW^IsJKoxrS 37 | \~rF]ySiCXqTglhUrMb`bQ]dE[yJMqSNls[E]IaXQZFai[mFiSTJkKieXj`PNDWzEtahw 38 | \~rU^QccSTqiblgzVFA@}Y`lL`qrcjJWpqXYTeLPoNwaslsBVJXG{itCyWfWuYEyNI_yk 39 | \~veMUSiCFpqejk\SnA@}[`lLarWwjFQtXWwTkTHINxBKlwBlI\KejdAxgVhUeIuNI_yk 40 | \~zE]mcdKcpHa^lJTUaW]WilDsrBMYcr|w[EVKeG[jFa\deLP]FJiTHFIUgWvEEzHuaiw 41 | \~zUUMPhCTuKilbVQvA@}R`lwPrIwqxQpbguS{dOkNxDKlkB[jdKfItCyWNguYKuNI_yk 42 | -------------------------------------------------------------------------------- /exp/run_tu_exp.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import copy 4 | import time 5 | import numpy as np 6 | from exp.parser import get_parser 7 | from exp.run_exp import main 8 | 9 | # python3 -m exp.run_tu_exp --dataset IMDBBINARY --model cin --drop_rate 0.0 --lr 0.0001 --max_dim 2 --emb_dim 32 --dump_curves --epochs 30 --num_layers 1 --lr_scheduler StepLR --lr_scheduler_decay_steps 5 10 | 11 | __num_folds__ = 10 12 | 13 | 14 | def print_summary(summary): 15 | msg = '' 16 | for k, v in summary.items(): 17 | msg += f'Fold {k:1d}: {v:.3f}\n' 18 | print(msg) 19 | 20 | 21 | def exp_main(passed_args): 22 | 23 | parser = get_parser() 24 | args = parser.parse_args(copy.copy(passed_args)) 25 | 26 | # run each experiment separately and gather results 27 | results = list() 28 | for i in range(__num_folds__): 29 | current_args = copy.copy(passed_args) + ['--fold', str(i)] 30 | parsed_args = parser.parse_args(current_args) 31 | curves = main(parsed_args) 32 | results.append(curves) 33 | 34 | # aggregate results 35 | val_curves = np.asarray([curves['val'] for curves in results]) 36 | avg_val_curve = val_curves.mean(axis=0) 37 | best_index = np.argmax(avg_val_curve) 38 | mean_perf = avg_val_curve[best_index] 39 | std_perf = val_curves.std(axis=0)[best_index] 40 | 41 | print(" ===== Mean performance per fold ======") 42 | perf_per_fold = val_curves.mean(1) 43 
| perf_per_fold = {i: perf_per_fold[i] for i in range(len(perf_per_fold))} 44 | print_summary(perf_per_fold) 45 | 46 | print(" ===== Max performance per fold ======") 47 | perf_per_fold = np.max(val_curves, axis=1) 48 | perf_per_fold = {i: perf_per_fold[i] for i in range(len(perf_per_fold))} 49 | print_summary(perf_per_fold) 50 | 51 | print(" ===== Median performance per fold ======") 52 | perf_per_fold = np.median(val_curves, axis=1) 53 | perf_per_fold = {i: perf_per_fold[i] for i in range(len(perf_per_fold))} 54 | print_summary(perf_per_fold) 55 | 56 | print(" ===== Performance on best epoch ======") 57 | perf_per_fold = val_curves[:, best_index] 58 | perf_per_fold = {i: perf_per_fold[i] for i in range(len(perf_per_fold))} 59 | print_summary(perf_per_fold) 60 | 61 | print(" ===== Final result ======") 62 | msg = ( 63 | f'Dataset: {args.dataset}\n' 64 | f'Accuracy: {mean_perf} ± {std_perf}\n' 65 | f'Best epoch: {best_index}\n' 66 | '-------------------------------\n') 67 | print(msg) 68 | 69 | # additionally write msg and configuration on file 70 | msg += str(args) 71 | filename = os.path.join(args.result_folder, f'{args.dataset}-{args.exp_name}/result.txt') 72 | print('Writing results at: {}'.format(filename)) 73 | with open(filename, 'w') as handle: 74 | handle.write(msg) 75 | 76 | if __name__ == "__main__": 77 | 78 | # standard args 79 | passed_args = sys.argv[1:] 80 | assert 'fold' not in passed_args 81 | exp_main(passed_args) 82 | -------------------------------------------------------------------------------- /data/datasets/plot_flow_dataset.py: -------------------------------------------------------------------------------- 1 | import seaborn as sns 2 | import matplotlib.pyplot as plt 3 | import os 4 | 5 | from data.datasets import FlowDataset 6 | from definitions import ROOT_DIR 7 | 8 | sns.set_style('white') 9 | sns.color_palette("tab10") 10 | 11 | 12 | def plot_arrow(p1, p2, color='red'): 13 | plt.arrow(p1[0], p1[1], p2[0] - p1[0], p2[1] - p1[1], color=color, 14 | shape='full', lw=3, length_includes_head=True, head_width=.01, zorder=10) 15 | 16 | 17 | def visualise_flow_dataset(): 18 | root = os.path.join(ROOT_DIR, 'datasets') 19 | name = 'FLOW' 20 | dataset = FlowDataset(os.path.join(root, name), name, num_points=1000, train_samples=1000, 21 | val_samples=200, classes=3, load_graph=True) 22 | G = dataset.G 23 | edge_to_tuple = G.graph['edge_to_tuple'] 24 | triangles = G.graph['triangles'] 25 | points = G.graph['points'] 26 | 27 | plt.figure(figsize=(10, 8)) 28 | plt.triplot(points[:, 0], points[:, 1], triangles) 29 | plt.plot(points[:, 0], points[:, 1], 'o') 30 | 31 | for i, cochain in enumerate([dataset[180], dataset[480]]): 32 | colors = ['red', 'navy', 'purple'] 33 | color = colors[i] 34 | 35 | x = cochain.x 36 | # 37 | # source_edge = 92 38 | # source_points = edge_to_tuple[source_edge] 39 | # plot_arrow(points[source_points[0]], points[source_points[1]], color='black') 40 | 41 | path_length = 0 42 | for i in range(len(x)): 43 | flow = x[i].item() 44 | if flow == 0: 45 | continue 46 | path_length += 1 47 | 48 | nodes1 = edge_to_tuple[i] 49 | if flow > 0: 50 | p1, p2 = points[nodes1[0]], points[nodes1[1]] 51 | else: 52 | p1, p2 = points[nodes1[1]], points[nodes1[0]], 53 | 54 | plt.arrow(p1[0], p1[1], p2[0] - p1[0], p2[1] - p1[1], color=color, 55 | shape='full', lw=3, length_includes_head=True, head_width=.01, zorder=10) 56 | 57 | # lower_index = cochain.lower_index 58 | # for i in range(lower_index.size(1)): 59 | # n1, n2 = lower_index[0, i].item(), lower_index[1, i].item() 
60 | # if n1 == source_edge: 61 | # source_points = edge_to_tuple[n2] 62 | # orient = cochain.lower_orient[i].item() 63 | # color = 'green' if orient == 1.0 else 'yellow' 64 | # plot_arrow(points[source_points[0]], points[source_points[1]], color=color) 65 | 66 | # upper_index = cochain.upper_index 67 | # for i in range(upper_index.size(1)): 68 | # n1, n2 = upper_index[0, i].item(), upper_index[1, i].item() 69 | # if n1 == source_edge: 70 | # source_points = edge_to_tuple[n2] 71 | # orient = cochain.upper_orient[i].item() 72 | # color = 'green' if orient == 1.0 else 'yellow' 73 | # plot_arrow(points[source_points[0]], points[source_points[1]], color=color) 74 | 75 | plt.show() 76 | 77 | 78 | if __name__ == "__main__": 79 | visualise_flow_dataset() 80 | -------------------------------------------------------------------------------- /exp/run_ring_exp.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import copy 4 | import subprocess 5 | import numpy as np 6 | 7 | from exp.parser import get_parser 8 | from exp.run_exp import main 9 | 10 | RING_SIZES = list(range(10, 32, 2)) 11 | 12 | 13 | def exp_main(passed_args): 14 | # Extract the commit sha so we can check the code that was used for each experiment 15 | sha = subprocess.check_output(["git", "describe", "--always"]).strip().decode() 16 | 17 | parser = get_parser() 18 | args = parser.parse_args(copy.copy(passed_args)) 19 | assert args.max_ring_size is None 20 | 21 | # run each experiment separately and gather results 22 | train_results = {fold: [] for fold in range(len(RING_SIZES))} 23 | val_results = {fold: [] for fold in range(len(RING_SIZES))} 24 | for seed in range(args.start_seed, args.stop_seed + 1): 25 | # We use the ring_size as a "fold" for the dataset. 26 | # This is just a hack to save the results properly using our usual infrastructure. 
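# For example, with RING_SIZES = list(range(10, 32, 2)), fold 0 corresponds to a maximum ring
# size of 10, fold 5 to 20 and fold 10 (the last fold) to 30.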
27 | for fold in range(len(RING_SIZES)): 28 | max_ring_size = RING_SIZES[fold] 29 | num_layers = 3 if args.model == 'ring_sparse_cin' else max_ring_size // 2 30 | current_args = (copy.copy(passed_args) + ['--fold', str(fold)] + 31 | ['--max_ring_size', str(max_ring_size)] + 32 | ['--num_layers', str(num_layers)] + 33 | ['--seed', str(seed)]) 34 | parsed_args = parser.parse_args(current_args) 35 | # Check that the default parameter value (5) was overwritten 36 | assert parsed_args.num_layers == num_layers 37 | curves = main(parsed_args) 38 | 39 | # Extract results 40 | train_results[fold].append(curves['last_train']) 41 | val_results[fold].append(curves['last_val']) 42 | 43 | msg = ( 44 | f"========= Final result ==========\n" 45 | f'Dataset: {args.dataset}\n' 46 | f'SHA: {sha}\n' 47 | f'----------- Train ----------\n') 48 | 49 | for fold, results in train_results.items(): 50 | mean = np.mean(results) 51 | std = np.std(results) 52 | msg += f'Ring size: {RING_SIZES[fold]} {mean}+-{std}\n' 53 | 54 | msg += f'----------- Test ----------\n' 55 | 56 | for fold, results in val_results.items(): 57 | mean = np.mean(results) 58 | std = np.std(results) 59 | msg += f'Ring size: {RING_SIZES[fold]} {mean}+-{std}\n' 60 | 61 | print(msg) 62 | 63 | # additionally write msg and configuration on file 64 | msg += str(args) 65 | filename = os.path.join(args.result_folder, f'{args.dataset}-{args.exp_name}/result.txt') 66 | print('Writing results at: {}'.format(filename)) 67 | with open(filename, 'w') as handle: 68 | handle.write(msg) 69 | 70 | 71 | if __name__ == "__main__": 72 | passed_args = sys.argv[1:] 73 | assert '--fold' not in passed_args 74 | assert '--seed' not in passed_args 75 | exp_main(passed_args) 76 | -------------------------------------------------------------------------------- /data/datasets/cluster.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | from data.datasets import InMemoryComplexDataset 4 | from data.utils import convert_graph_dataset_with_gudhi 5 | from torch_geometric.datasets import GNNBenchmarkDataset 6 | 7 | 8 | class ClusterDataset(InMemoryComplexDataset): 9 | """This is the Cluster dataset from the Benchmarking GNNs paper. 10 | 11 | The dataset contains multiple graphs and we have to do node classification on all these graphs. 12 | """ 13 | 14 | def __init__(self, root, transform=None, 15 | pre_transform=None, pre_filter=None, max_dim=2): 16 | self.name = 'CLUSTER' 17 | super(ClusterDataset, self).__init__(root, transform, pre_transform, pre_filter, 18 | max_dim=max_dim) 19 | 20 | self.max_dim = max_dim 21 | 22 | self._data_list, idx = self.load_dataset() 23 | self.train_ids = idx[0] 24 | self.val_ids = idx[1] 25 | self.test_ids = idx[2] 26 | 27 | @property 28 | def raw_file_names(self): 29 | name = self.name 30 | # The processed graph files are our raw files. 31 | # I've obtained this from inside the GNNBenchmarkDataset class 32 | return [f'{name}_train.pt', f'{name}_val.pt', f'{name}_test.pt'] 33 | 34 | @property 35 | def processed_file_names(self): 36 | return ['complex_train.pkl', 'complex_val.pkl', 'complex_test.pkl'] 37 | 38 | def download(self): 39 | # Instantiating this will download and process the graph dataset. 
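# The instance itself is discarded: only the side effect of downloading and processing the
# graph files matters here; process() below re-instantiates the dataset split by split.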
40 | GNNBenchmarkDataset('./datasets/', 'CLUSTER') 41 | 42 | def load_dataset(self): 43 | """Load the dataset from here and process it if it doesn't exist""" 44 | data_list, idx = [], [] 45 | start = 0 46 | for path in self.processed_paths: 47 | with open(path, 'rb') as handle: 48 | data_list.extend(pickle.load(handle)) 49 | idx.append(list(range(start, len(data_list)))) 50 | start = len(data_list) 51 | return data_list, idx 52 | 53 | def process(self): 54 | # At this stage, the graph dataset is already downloaded and processed 55 | print(f"Processing cellular complex dataset for {self.name}") 56 | train_data = GNNBenchmarkDataset('./datasets/', 'CLUSTER', split='train') 57 | val_data = GNNBenchmarkDataset('./datasets/', 'CLUSTER', split='val') 58 | test_data = GNNBenchmarkDataset('./datasets/', 'CLUSTER', split='test') 59 | 60 | # For testing 61 | # train_data = list(train_data)[:3] 62 | # val_data = list(val_data)[:3] 63 | # test_data = list(test_data)[:3] 64 | 65 | print("Converting the train dataset with gudhi...") 66 | train_complexes, _, _ = convert_graph_dataset_with_gudhi(train_data, 67 | expansion_dim=self.max_dim, include_down_adj=self.include_down_adj) 68 | print("Converting the validation dataset with gudhi...") 69 | val_complexes, _, _ = convert_graph_dataset_with_gudhi(val_data, expansion_dim=self.max_dim, include_down_adj=self.include_down_adj) 70 | print("Converting the test dataset with gudhi...") 71 | test_complexes, _, _ = convert_graph_dataset_with_gudhi(test_data, 72 | expansion_dim=self.max_dim) 73 | complexes = [train_complexes, val_complexes, test_complexes] 74 | 75 | for i, path in enumerate(self.processed_paths): 76 | with open(path, 'wb') as handle: 77 | pickle.dump(complexes[i], handle) 78 | -------------------------------------------------------------------------------- /data/datasets/dummy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from data.datasets import InMemoryComplexDataset 4 | from data.dummy_complexes import get_testing_complex_list, get_mol_testing_complex_list 5 | 6 | 7 | class DummyDataset(InMemoryComplexDataset): 8 | """A dummy dataset using a list of hand-crafted cell complexes with many edge cases.""" 9 | 10 | def __init__(self, root): 11 | self.name = 'DUMMY' 12 | super(DummyDataset, self).__init__(root, max_dim=3, num_classes=2, 13 | init_method=None, include_down_adj=True, cellular=False) 14 | self.data, self.slices = torch.load(self.processed_paths[0]) 15 | self.train_ids = list(range(self.len())) 16 | self.val_ids = list(range(self.len())) 17 | self.test_ids = list(range(self.len())) 18 | 19 | @property 20 | def processed_file_names(self): 21 | name = self.name 22 | return [f'{name}_complex_list.pt'] 23 | 24 | @property 25 | def raw_file_names(self): 26 | # The processed graph files are our raw files. 27 | # They are obtained when running the initial data conversion S2V_to_PyG. 
28 | return [] 29 | 30 | def download(self): 31 | return 32 | 33 | @staticmethod 34 | def factory(): 35 | complexes = get_testing_complex_list() 36 | for c, complex in enumerate(complexes): 37 | complex.y = torch.LongTensor([c % 2]) 38 | return complexes 39 | 40 | def process(self): 41 | print("Instantiating complexes...") 42 | complexes = self.factory() 43 | torch.save(self.collate(complexes, self.max_dim), self.processed_paths[0]) 44 | 45 | 46 | class DummyMolecularDataset(InMemoryComplexDataset): 47 | """A dummy dataset using a list of hand-crafted molecular cell complexes with many edge cases.""" 48 | 49 | def __init__(self, root, remove_2feats=False): 50 | self.name = 'DUMMYM' 51 | self.remove_2feats = remove_2feats 52 | super(DummyMolecularDataset, self).__init__(root, max_dim=2, num_classes=2, 53 | init_method=None, include_down_adj=True, cellular=True) 54 | self.data, self.slices = torch.load(self.processed_paths[0]) 55 | self.train_ids = list(range(self.len())) 56 | self.val_ids = list(range(self.len())) 57 | self.test_ids = list(range(self.len())) 58 | 59 | @property 60 | def processed_file_names(self): 61 | name = self.name 62 | remove_2feats = self.remove_2feats 63 | fn = f'{name}_complex_list' 64 | if remove_2feats: 65 | fn += '_removed_2feats' 66 | fn += '.pt' 67 | return [fn] 68 | 69 | @property 70 | def raw_file_names(self): 71 | # The processed graph files are our raw files. 72 | # They are obtained when running the initial data conversion S2V_to_PyG. 73 | return [] 74 | 75 | def download(self): 76 | return 77 | 78 | @staticmethod 79 | def factory(remove_2feats=False): 80 | complexes = get_mol_testing_complex_list() 81 | for c, complex in enumerate(complexes): 82 | if remove_2feats: 83 | if 2 in complex.cochains: 84 | complex.cochains[2].x = None 85 | complex.y = torch.LongTensor([c % 2]) 86 | return complexes 87 | 88 | def process(self): 89 | print("Instantiating complexes...") 90 | complexes = self.factory(self.remove_2feats) 91 | torch.save(self.collate(complexes, self.max_dim), self.processed_paths[0]) 92 | -------------------------------------------------------------------------------- /data/datasets/ringlookup.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os.path as osp 3 | 4 | from data.datasets import InMemoryComplexDataset 5 | from data.datasets.ring_utils import generate_ringlookup_graph_dataset 6 | from data.utils import convert_graph_dataset_with_rings 7 | 8 | 9 | class RingLookupDataset(InMemoryComplexDataset): 10 | """A dataset where the task is to perform dictionary lookup on the features 11 | of a set of nodes forming a ring. The feature of each node is composed of a key and a value 12 | and one must assign to a target node the value of the key its feature encodes. 
13 | """ 14 | 15 | def __init__(self, root, nodes=10): 16 | self.name = 'RING-LOOKUP' 17 | self._nodes = nodes 18 | 19 | super(RingLookupDataset, self).__init__( 20 | root, None, None, None, max_dim=2, cellular=True, num_classes=nodes-1) 21 | 22 | self.data, self.slices = torch.load(self.processed_paths[0]) 23 | idx = torch.load(self.processed_paths[1]) 24 | 25 | self.train_ids = idx[0] 26 | self.val_ids = idx[1] 27 | self.test_ids = idx[2] 28 | 29 | @property 30 | def processed_dir(self): 31 | """This is overwritten, so the cellular complex data is placed in another folder""" 32 | return osp.join(self.root, 'complex') 33 | 34 | @property 35 | def processed_file_names(self): 36 | return [f'ringlookup-n{self._nodes}.pkl', f'idx-n{self._nodes}.pkl'] 37 | 38 | @property 39 | def raw_file_names(self): 40 | # No raw files, but must be implemented 41 | return [] 42 | 43 | def download(self): 44 | # Nothing to download, but must be implemented 45 | pass 46 | 47 | def process(self): 48 | train = generate_ringlookup_graph_dataset(self._nodes, samples=10000) 49 | val = generate_ringlookup_graph_dataset(self._nodes, samples=1000) 50 | dataset = train + val 51 | 52 | train_ids = list(range(len(train))) 53 | val_ids = list(range(len(train), len(train) + len(val))) 54 | print("Converting dataset to a cell complex...") 55 | 56 | complexes, _, _ = convert_graph_dataset_with_rings( 57 | dataset, 58 | max_ring_size=self._nodes, 59 | include_down_adj=False, 60 | init_edges=True, 61 | init_rings=True, 62 | n_jobs=4) 63 | 64 | for complex in complexes: 65 | # Add mask for the target node. 66 | mask = torch.zeros(complex.nodes.num_cells, dtype=torch.bool) 67 | mask[0] = 1 68 | setattr(complex.cochains[0], 'mask', mask) 69 | 70 | # Make HOF zero 71 | complex.edges.x = torch.zeros_like(complex.edges.x) 72 | complex.two_cells.x = torch.zeros_like(complex.two_cells.x) 73 | assert complex.two_cells.num_cells == 1 74 | 75 | path = self.processed_paths[0] 76 | print(f'Saving processed dataset in {path}....') 77 | torch.save(self.collate(complexes, 2), path) 78 | 79 | idx = [train_ids, val_ids, None] 80 | 81 | path = self.processed_paths[1] 82 | print(f'Saving idx in {path}....') 83 | torch.save(idx, path) 84 | 85 | 86 | def load_ring_lookup_dataset(nodes=10): 87 | train = generate_ringlookup_graph_dataset(nodes, samples=10000) 88 | val = generate_ringlookup_graph_dataset(nodes, samples=1000) 89 | dataset = train + val 90 | 91 | train_ids = list(range(len(train))) 92 | val_ids = list(range(len(train), len(train) + len(val))) 93 | 94 | return dataset, train_ids, val_ids, None 95 | -------------------------------------------------------------------------------- /data/datasets/ringtransfer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os.path as osp 3 | 4 | from data.datasets import InMemoryComplexDataset 5 | from data.datasets.ring_utils import generate_ring_transfer_graph_dataset 6 | from data.utils import convert_graph_dataset_with_rings 7 | 8 | 9 | class RingTransferDataset(InMemoryComplexDataset): 10 | """A dataset where the task is to transfer features from a source node to a target node 11 | placed on the other side of a ring. 
12 | """ 13 | 14 | def __init__(self, root, nodes=10, train=5000, test=500): 15 | self.name = 'RING-TRANSFER' 16 | self._nodes = nodes 17 | self._num_classes = 5 18 | self._train = train 19 | self._test = test 20 | 21 | super(RingTransferDataset, self).__init__(root, None, None, None, 22 | max_dim=2, cellular=True, num_classes=self._num_classes) 23 | 24 | self.data, self.slices = torch.load(self.processed_paths[0]) 25 | idx = torch.load(self.processed_paths[1]) 26 | 27 | self.train_ids = idx[0] 28 | self.val_ids = idx[1] 29 | self.test_ids = idx[2] 30 | 31 | @property 32 | def processed_dir(self): 33 | """This is overwritten, so the cellular complex data is placed in another folder""" 34 | return osp.join(self.root, 'complex') 35 | 36 | @property 37 | def processed_file_names(self): 38 | return [f'ringtree-n{self._nodes}.pkl', f'idx-n{self._nodes}.pkl'] 39 | 40 | @property 41 | def raw_file_names(self): 42 | # No raw files, but must be implemented 43 | return [] 44 | 45 | def download(self): 46 | # Nothing to download, but must be implemented 47 | pass 48 | 49 | def process(self): 50 | train = generate_ring_transfer_graph_dataset(self._nodes, classes=self._num_classes, 51 | samples=self._train) 52 | val = generate_ring_transfer_graph_dataset(self._nodes, classes=self._num_classes, 53 | samples=self._test) 54 | dataset = train + val 55 | 56 | train_ids = list(range(len(train))) 57 | val_ids = list(range(len(train), len(train) + len(val))) 58 | print("Converting dataset to a cell complex...") 59 | 60 | complexes, _, _ = convert_graph_dataset_with_rings( 61 | dataset, 62 | max_ring_size=self._nodes, 63 | include_down_adj=False, 64 | init_edges=True, 65 | init_rings=True, 66 | n_jobs=4) 67 | 68 | for complex in complexes: 69 | # Add mask for the target node. 
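# Node 0 is the node whose input features were zeroed out when the ring graph was generated,
# i.e. the node that must recover the label assigned to the node on the opposite side of the ring.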
70 | mask = torch.zeros(complex.nodes.num_cells, dtype=torch.bool) 71 | mask[0] = 1 72 | setattr(complex.cochains[0], 'mask', mask) 73 | 74 | # Make HOF zero 75 | complex.edges.x = torch.zeros_like(complex.edges.x) 76 | complex.two_cells.x = torch.zeros_like(complex.two_cells.x) 77 | 78 | path = self.processed_paths[0] 79 | print(f'Saving processed dataset in {path}....') 80 | torch.save(self.collate(complexes, 2), path) 81 | 82 | idx = [train_ids, val_ids, None] 83 | 84 | path = self.processed_paths[1] 85 | print(f'Saving idx in {path}....') 86 | torch.save(idx, path) 87 | 88 | 89 | def load_ring_transfer_dataset(nodes=10, train=5000, test=500, classes=5): 90 | train = generate_ring_transfer_graph_dataset(nodes, classes=classes, samples=train) 91 | val = generate_ring_transfer_graph_dataset(nodes, classes=classes, samples=test) 92 | dataset = train + val 93 | 94 | train_ids = list(range(len(train))) 95 | val_ids = list(range(len(train), len(train) + len(val))) 96 | 97 | return dataset, train_ids, val_ids, None 98 | -------------------------------------------------------------------------------- /datasets/SR_graphs/raw/sr401224.g6: -------------------------------------------------------------------------------- 1 | g}aCCMIbC?S@cAPEGTAQOOdDCaQBIIAdPi@?V?G?WXJ@PXQAP_{?LbKCdBAOcohCGUGcSPh@``KGIggkCTPOdAod?CpSEO@PEa?UbAeO?{Gcx?gDDHY?cIFG@``_X`_EX`_ 2 | g}aCCMIbC?S@cAPEGTAQOOdDCaQBIIAdPi@?V?G?WXJ@PXQAPbK?L_{CdBAGcohAGUGccPh@a`KGIggkCTPOdAod?CpSEO@PEa?UbAeO?{Gcx?gDDHY?cIFG?Xe?X`_WWX_ 3 | g}aCCMIbC?S@cAPEGTAQOOdDCaQBIIAdPi@?V?G?WXJ@PXQAPbK@`_{CdBAGcohAGUGccPh@a`@gIgg`cTPOdAocBCpSE?pPEaAObAeOGcGcx?gDDHY?cIFG?Xe?Xk?WWWW 4 | g}aCKMGaK?Q@gAPIGRAQOPDDAaQBHIAePi@?V?G?WTT@PgiAP_{@HbKCbAaOdOpCGYGSSPXA``CcSggcgJPO`Sod?CWsE?wsDaAGD_iO@XDCx?cHGhY?gEFG@``_Xh?EX`O 5 | g}aCKMGaK?Q@gARAGRABPPD@CaISCqQHHa`?SSO?DYK_JXK?deK?XhKBP?kPQOSbHEDCGSgg_OdCEgXAgFOWkA`YCF@ChQGgfCBACciGQPI`Y?Q_EOt?`OeQC`aAhb?oX@S 6 | g}aCKQH`c?Q@gAOMGFA@oQ`DDAYIHAEEDY@?Z?G?IUP_XceADqcADhWD`ASGeOadIaIO?[HE?GocagHAgdOSaaQeCDHCUQHDIGaiAcScRGCgiaGaHHQPG`EID@IcHaaGLOc 7 | g}aCKQH`c?Q@gAOMGFA@oR@DBAYEIAIDD[?_Y_O?IUP_XceADpWADicDPASOiOabIaIO?[HE?GocagHAgdOScaQdCChCUaHHHGciASWcJGCgiaGaHHQPG`EIC`KaHaaOKpC 8 | g}aCKQH`c?Q@gAOMGFA@oR@DBAYIHAEEDY@?Z?G?IUP_XceADpcADiWDPASOiOabIaIO?[HE?GocagHAgdOScaQdCChCUaHDIGaiAcScRGCgiaGaHHQPG`EIC`KaHaaOKpC 9 | g}aCKQH`c?S@cAOMGFA@oQ`DDAYIDAEIDY@?Z?G?IUP_XceADqcADhWD`ASGeOadEaI_?kHD?GocagHAgdOSaaQeCDHCUQGhHHCiCSWaJGGgiQGaDHQ`G`EID@IcHaaGLOc 10 | g}aCKQH`c?S@cAOMGFA@oR@DBAYEEAIHD[?_Y_O?IUP_XceADpWADicDPASOiOabEaI_?kHD?GocagHAgdOScaQdCChCUaGdIHAiCcSaRGGgiQGaDHQ`G`EIC`KaHaaOKpC 11 | g}aCKQH`c?S@cAOMGFA@oR@DBAYIDAEIDY@?Z?G?IUP_XceADpcADiWDPASOiOabEaI_?kHD?GocagHAgdOScaQdCChCUaGhHHCiCSWaJGGgiQGaDHQ`G`EIC`KaHaaOKpC 12 | g}aCSUC_KGSAcAPQHDACoOWdKCqOeADIDPO_X_A?KcZ?JdQ@Pg[?b`kAUKAHBObAFCCcck_I`@Oq@o?kI`GGWaDSRGgk__ooWqDKAKh_HQGSt?hGI@yAMOEPCH`_pi?KX_o 13 | g}aKCEBbC?Q@gAOMGFA@oQ`DDAYQHAKDDi@?TOO?QWb_pQU?dpc?TqWDPDAOiPHBIQP_?UHH?GpCabPHGdCkQSccIGSiE@WpDHSGCcdAh@EGq`CCHHPQGGeIAOk`Hc_aSoK 14 | g}aKCEEaS?S@cAOMGFA@oQ_dDCYSHAIDDT@?Y_O?QSd_paM?XqS?dpgAXCa@WPPFIQPO?UHI?HGcabGpGdDSaQgcHGXII@_pDGiOaacbOcEDQPGCHGpaGGeIAOq`Hd?aQoK 15 | g}aKCEH`c?Q@gAOMH`AEOQDDAcaAdQHChY@?XOO?SUW_hg[AP`k@`a[DBAKOhOocGiGgSPWcP`@cXGg`gEpOiCoSBDOkI?opEcAHCKl?KACcwOcHHDT?aEFG@H`_Xk?KX_W 16 | g}aKCEI`S?S@cAOMGFA@oQ_dDCYKDAQHDT@?Y_O?KWd`HQM?Tqc?hpWAXCa@WPPFIQP_?UHH?HGcabGpGdDSaSgcDGWiIA_hHGkGbASbHCEDQQGCHGp`GGeKAHQ_hc_cIoS 17 | g}aKCEK_s?Q@gAOMHHAKOQDDAcaGdQBChY@?XOO?SUW_hg[AP`k@`a[DBAKObOocGiGgSSWcP`@cXGg`gEpOiCoSBDOkI?oofAA`CKl?KACgs_`HHDT?aEFG@H`_Xk?KX_W 18 
| g}aKCYA_c@S@cCOsHEAGoOWdCdIOUASGdPG_X_G?KTJ?JdQ@Da[?p`kBPBAHB_bADIOcci_Ia@OqHD?kI`GGcaq@BEG`TCobJAD@CHYB?bEDL?R@I@yAMOEW@H`SPg_KXgO 19 | g}aKKQ@_KCSAcAPQHSAApOch@RAWGYGW`W_@WgO?ScJ`BhE?Jwg@HhSAeHG@DOL@EECcS`gJG@K_hC?hCcgKRDG`MHOeHS?eHIGaDGPITPCpCegCI@{AKWF@C_koPgb?SWo 20 | g}aKKQA_S@S@cCOwGbIGoPE@IAIPGiON@b?GSWC?PgR_LQk@`c[@PqcBP@I@E_e@EKGg_hGcQ?aYWCp@eSDGQIahSGdHDY@HDGocAQi`?_DPGhGRIAsGYBEc@Gj_`wA_WF_ 21 | g}aKSAB`CAQ@gCOeGkICoOe@GcIQCiON@T?OXGA?QSR`EaM@BlC@`igAi@g@HOe@FAOc?hPCa?dCRCxEGabWeDCgOCLOeT@@UHJOCIdO`SCKlAAWCoodgGF@D?i_Xs@_T?w 22 | g}aKSAB`CAS@cCOeGkICoOe@HAQPGYON@S_GXOC?QSR`EaM@Bkg@`jCCi@W?hOi@FAOc?hPCa@DCQcweGbBWeDCgOCLOeT?ahWGTCDWcl?GKkaAWCoodgGF@D?i_Xs@_T?w 23 | g}aKSEC_KAS@cCOsGdIOoQW@CSICSYON@Q_OXOC?ISj?UaY@`pK?XgkDD?[Oq_IbDEOcGaWc_OiGRGXDCQoWWIgUCE`L?QPWEcAKDAhDHCCJU?cQGSm?gcEoCPa?xbAGX@o 24 | g}aK[A@_KAQAgAQWGRACqOe@GcIICYON@T?OXGG?QcR_daYABdK@DwgAhCg@H_RDIIHGOXOaQ@EGRGWdCe?wQbGRaGpDEWPEJCC_A`hB@HCghJ?AHCt?H`f@H?iaPo_cXPO 25 | g}aK[A@_S@QAgAOeGkICoOe@HCIPCiON@d?GTGC?QWb`EPM@BlG@DicCiAW?hOY@IQPCOWpAa?eCQdXDGbAwSLAcgHCpED?eDgGBcSdP?UCooiS?GhPDg?f@D@K`Ho`_KoK 26 | g}aK[A@_S@SAcAOeGkICoOe@DAIHGiON@h?GUGC?Icb_d`MABlGADicDHAW?p_Y@IQPCOWpAa?eGiGXDCUOWSLAgWHCpDH?eHOdBcSUCIUCoohW?GhPEc?e`D@KoHg`_KqC 27 | g~aCCEG`CGoWaGOah@aGKOSh?tABOaS_dE_gROP@[B_AkB`CkQaESHaCM?N@?[[PCJSGWPIoQ__nAH?wGJCgYGAsQD_aSU@QQQGQCQQcGdHHGc_fHOogOoDAbAPCHDPIGGb 28 | g~aCCQH_cGO`_cPSGSaOSPSDDOaOTIDI@GJ@OhE@[B_AXDDEDDQDGoeCGof@@oUPCqDOWRHAQOoiCh?hDPDGSaTCQEDIGU@AKUM?CF?ew?GP{?@LGM[?AqDo?e?fH[?U?UK 29 | -------------------------------------------------------------------------------- /data/datasets/ring_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import random 4 | 5 | from torch_geometric.data import Data 6 | from sklearn.preprocessing import LabelBinarizer 7 | 8 | 9 | # TODO: Add a graph dataset for ring lookup. 10 | def generate_ring_lookup_graph(nodes): 11 | """This generates a dictionary lookup ring. 
No longer being used for now.""" 12 | # Assign all the other nodes in the ring a unique key and value 13 | keys = np.arange(1, nodes) 14 | vals = np.random.permutation(nodes - 1) 15 | 16 | oh_keys = np.array(LabelBinarizer().fit_transform(keys)) 17 | oh_vals = np.array(LabelBinarizer().fit_transform(vals)) 18 | oh_all = np.concatenate((oh_keys, oh_vals), axis=-1) 19 | x = np.empty((nodes, oh_all.shape[1])) 20 | x[1:, :] = oh_all 21 | 22 | # Assign the source node one of these random keys and set the value to -1 23 | key_idx = random.randint(0, nodes - 2) 24 | val = vals[key_idx] 25 | 26 | x[0, :] = 0 27 | x[0, :oh_keys.shape[1]] = oh_keys[key_idx] 28 | 29 | x = torch.tensor(x, dtype=torch.float32) 30 | 31 | edge_index = [] 32 | for i in range(nodes-1): 33 | edge_index.append([i, i + 1]) 34 | edge_index.append([i + 1, i]) 35 | 36 | # Add the edges that close the ring 37 | edge_index.append([0, nodes - 1]) 38 | edge_index.append([nodes - 1, 0]) 39 | 40 | edge_index = np.array(edge_index, dtype=np.long).T 41 | edge_index = torch.tensor(edge_index, dtype=torch.long) 42 | 43 | # Create a mask for the target node of the graph 44 | mask = torch.zeros(nodes, dtype=torch.bool) 45 | mask[0] = 1 46 | 47 | # Add the label of the graph as a graph label 48 | y = torch.tensor([val], dtype=torch.long) 49 | return Data(x=x, edge_index=edge_index, mask=mask, y=y) 50 | 51 | 52 | def generate_ringlookup_graph_dataset(nodes, samples=10000): 53 | # Generate the dataset 54 | dataset = [] 55 | for i in range(samples): 56 | graph = generate_ring_lookup_graph(nodes) 57 | dataset.append(graph) 58 | return dataset 59 | 60 | 61 | def generate_ring_transfer_graph(nodes, target_label): 62 | opposite_node = nodes // 2 63 | 64 | # Initialise the feature matrix with a constant feature vector 65 | # TODO: Modify the experiment to use another random constant feature per graph 66 | x = np.ones((nodes, len(target_label))) 67 | 68 | x[0, :] = 0.0 69 | x[opposite_node, :] = target_label 70 | x = torch.tensor(x, dtype=torch.float32) 71 | 72 | edge_index = [] 73 | for i in range(nodes-1): 74 | edge_index.append([i, i + 1]) 75 | edge_index.append([i + 1, i]) 76 | 77 | # Add the edges that close the ring 78 | edge_index.append([0, nodes - 1]) 79 | edge_index.append([nodes - 1, 0]) 80 | 81 | edge_index = np.array(edge_index, dtype=np.long).T 82 | edge_index = torch.tensor(edge_index, dtype=torch.long) 83 | 84 | # Create a mask for the target node of the graph 85 | mask = torch.zeros(nodes, dtype=torch.bool) 86 | mask[0] = 1 87 | 88 | # Add the label of the graph as a graph label 89 | y = torch.tensor([np.argmax(target_label)], dtype=torch.long) 90 | return Data(x=x, edge_index=edge_index, mask=mask, y=y) 91 | 92 | 93 | def generate_ring_transfer_graph_dataset(nodes, classes=5, samples=10000): 94 | # Generate the dataset 95 | dataset = [] 96 | samples_per_class = samples // classes 97 | for i in range(samples): 98 | label = i // samples_per_class 99 | target_class = np.zeros(classes) 100 | target_class[label] = 1.0 101 | graph = generate_ring_transfer_graph(nodes, target_class) 102 | dataset.append(graph) 103 | return dataset 104 | -------------------------------------------------------------------------------- /data/test_dataset.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import os 3 | 4 | from data.data_loading import load_graph_dataset 5 | from data.datasets import TUDataset, DummyMolecularDataset, DummyDataset 6 | from data.utils import 
compute_clique_complex_with_gudhi, compute_ring_2complex 7 | from data.helper_test import compare_complexes, compare_complexes_without_2feats 8 | from definitions import ROOT_DIR 9 | 10 | 11 | def validate_data_retrieval(dataset, graph_list, exp_dim, include_down_adj, ring_size=None): 12 | 13 | assert len(dataset) == len(graph_list) 14 | for i in range(len(graph_list)): 15 | graph = graph_list[i] 16 | yielded = dataset[i] 17 | if ring_size is not None: 18 | expected = compute_ring_2complex(graph.x, graph.edge_index, None, 19 | graph.num_nodes, y=graph.y, 20 | max_k=ring_size, include_down_adj=include_down_adj, 21 | init_rings=True) 22 | else: 23 | expected = compute_clique_complex_with_gudhi(graph.x, graph.edge_index, 24 | graph.num_nodes, expansion_dim=exp_dim, 25 | y=graph.y, include_down_adj=include_down_adj) 26 | compare_complexes(yielded, expected, include_down_adj) 27 | 28 | 29 | @pytest.mark.data 30 | def test_data_retrieval_on_proteins(): 31 | dataset = TUDataset(os.path.join(ROOT_DIR, 'datasets', 'PROTEINS'), 'PROTEINS', max_dim=3, 32 | num_classes=2, fold=0, degree_as_tag=False, init_method='sum', include_down_adj=True) 33 | graph_list, train_ids, val_ids, _, num_classes = load_graph_dataset('PROTEINS', fold=0) 34 | assert dataset.include_down_adj 35 | assert dataset.num_classes == num_classes 36 | validate_data_retrieval(dataset, graph_list, 3, True) 37 | validate_data_retrieval(dataset[train_ids], [graph_list[i] for i in train_ids], 3, True) 38 | validate_data_retrieval(dataset[val_ids], [graph_list[i] for i in val_ids], 3, True) 39 | return 40 | 41 | 42 | @pytest.mark.data 43 | def test_data_retrieval_on_proteins_with_rings(): 44 | dataset = TUDataset(os.path.join(ROOT_DIR, 'datasets', 'PROTEINS'), 'PROTEINS', max_dim=2, 45 | num_classes=2, fold=0, degree_as_tag=False, init_method='sum', include_down_adj=True, 46 | max_ring_size=6) 47 | graph_list, train_ids, val_ids, _, num_classes = load_graph_dataset('PROTEINS', fold=0) 48 | assert dataset.include_down_adj 49 | assert dataset.num_classes == num_classes 50 | # Reducing to val_ids only, to save some time. 
Uncomment the lines below to test on the whole set 51 | # validate_data_retrieval(dataset, graph_list, 2, True, 6) 52 | # validate_data_retrieval(dataset[train_ids], [graph_list[i] for i in train_ids], 2, True, 6) 53 | validate_data_retrieval(dataset[val_ids], [graph_list[i] for i in val_ids], 2, True, 6) 54 | 55 | 56 | def test_dummy_dataset_data_retrieval(): 57 | 58 | complexes = DummyDataset.factory() 59 | dataset = DummyDataset(os.path.join(ROOT_DIR, 'datasets', 'DUMMY')) 60 | assert len(complexes) == len(dataset) 61 | for i in range(len(dataset)): 62 | compare_complexes(dataset[i], complexes[i], True) 63 | 64 | 65 | def test_dummy_mol_dataset_data_retrieval(): 66 | 67 | complexes = DummyMolecularDataset.factory(False) 68 | dataset = DummyMolecularDataset(os.path.join(ROOT_DIR, 'datasets', 'DUMMYM'), False) 69 | assert len(complexes) == len(dataset) 70 | for i in range(len(dataset)): 71 | compare_complexes(dataset[i], complexes[i], True) 72 | 73 | 74 | def test_dummy_mol_dataset_data_retrieval_without_2feats(): 75 | 76 | complexes = DummyMolecularDataset.factory(True) 77 | dataset = DummyMolecularDataset(os.path.join(ROOT_DIR, 'datasets', 'DUMMYM'), True) 78 | assert len(complexes) == len(dataset) 79 | for i in range(len(dataset)): 80 | compare_complexes_without_2feats(dataset[i], complexes[i], True) 81 | -------------------------------------------------------------------------------- /exp/evaluate_sr_cwn_emb_mag.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import numpy as np 5 | import random 6 | from definitions import ROOT_DIR 7 | 8 | from exp.prepare_sr_tests import prepare 9 | from mp.models import MessagePassingAgnostic, SparseCIN 10 | from data.data_loading import DataLoader, load_dataset 11 | 12 | __families__ = [ 13 | 'sr16622', 14 | 'sr251256', 15 | 'sr261034', 16 | 'sr281264', 17 | 'sr291467', 18 | 'sr351668', 19 | 'sr351899', 20 | 'sr361446', 21 | 'sr401224' 22 | ] 23 | 24 | def compute_embeddings(family, baseline, seed): 25 | 26 | # Set the seed for everything 27 | torch.manual_seed(seed) 28 | torch.cuda.manual_seed(seed) 29 | torch.cuda.manual_seed_all(seed) 30 | np.random.seed(seed) 31 | random.seed(seed) 32 | 33 | # Perform the check in double precision 34 | torch.set_default_dtype(torch.float64) 35 | 36 | # Please set the parameters below to the ones used in SR experiments. 
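# (For the baseline model, hidden is overridden to 256 further below.)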
37 | hidden = 16 38 | num_layers = 3 39 | max_ring_size = 6 40 | use_coboundaries = True 41 | nonlinearity = 'elu' 42 | graph_norm = 'id' 43 | readout = 'sum' 44 | final_readout = 'sum' 45 | readout_dims = (0,1,2) 46 | init = 'sum' 47 | jobs = 64 48 | device = torch.device("cuda:" + str(0)) if torch.cuda.is_available() else torch.device("cpu") 49 | 50 | # Build and dump dataset if needed 51 | prepare(family, jobs, max_ring_size, False, init, None) 52 | 53 | # Load reference dataset 54 | complexes = load_dataset(family, max_dim=2, max_ring_size=max_ring_size, init_method=init) 55 | data_loader = DataLoader(complexes, batch_size=8, shuffle=False, num_workers=16, max_dim=2) 56 | 57 | # Instantiate model 58 | if not baseline: 59 | model = SparseCIN(num_input_features=1, num_classes=complexes.num_classes, num_layers=num_layers, hidden=hidden, 60 | use_coboundaries=use_coboundaries, nonlinearity=nonlinearity, graph_norm=graph_norm, 61 | readout=readout, final_readout=final_readout, readout_dims=readout_dims) 62 | else: 63 | hidden = 256 64 | model = MessagePassingAgnostic(num_input_features=1, num_classes=complexes.num_classes, hidden=hidden, 65 | nonlinearity=nonlinearity, readout=readout) 66 | model = model.to(device) 67 | model.eval() 68 | 69 | # Compute complex embeddings 70 | with torch.no_grad(): 71 | embeddings = list() 72 | for batch in data_loader: 73 | batch.nodes.x = batch.nodes.x.double() 74 | batch.edges.x = batch.edges.x.double() 75 | batch.two_cells.x = batch.two_cells.x.double() 76 | out = model.forward(batch.to(device)) 77 | embeddings.append(out) 78 | embeddings = torch.cat(embeddings, 0) # n x d 79 | assert embeddings.size(1) == complexes.num_classes 80 | 81 | return embeddings 82 | 83 | if __name__ == "__main__": 84 | 85 | # Standard args 86 | passed_args = sys.argv[1:] 87 | baseline = (passed_args[0].lower() == 'true') 88 | max_ring_size = int(passed_args[1]) 89 | assert max_ring_size > 3 90 | 91 | # Execute 92 | msg = f'Model: {"CIN" if not baseline else "MLP-sum"}({max_ring_size})' 93 | print(msg) 94 | for family in __families__: 95 | text = f'\n======================== {family}' 96 | msg += text+'\n' 97 | print(text) 98 | for seed in range(5): 99 | embeddings = compute_embeddings(family, baseline, seed) 100 | text = f'seed {seed}: {torch.max(torch.abs(embeddings)):.2f}' 101 | msg += text+'\n' 102 | print(text) 103 | path = os.path.join(ROOT_DIR, 'exp', 'results') 104 | if baseline: 105 | path = os.path.join(path, f'sr-base-{max_ring_size}.txt') 106 | else: 107 | path = os.path.join(path, f'sr-{max_ring_size}.txt') 108 | with open(path, 'w') as handle: 109 | handle.write(msg) 110 | -------------------------------------------------------------------------------- /data/test_tu_utils.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import os 3 | import numpy as np 4 | import torch 5 | import random 6 | from data.tu_utils import get_fold_indices, load_data, S2V_to_PyG 7 | from torch_geometric.utils import degree 8 | from definitions import ROOT_DIR 9 | 10 | 11 | @pytest.fixture 12 | def imdbbinary_graphs(): 13 | data, num_classes = load_data(os.path.join(ROOT_DIR, 'datasets', 'IMDBBINARY', 'raw'), 'IMDBBINARY', True) 14 | graph_list = [S2V_to_PyG(datum) for datum in data] 15 | return graph_list 16 | 17 | @pytest.fixture 18 | def imdbbinary_nonattributed_graphs(): 19 | data, num_classes = load_data(os.path.join(ROOT_DIR, 'datasets', 'IMDBBINARY', 'raw'), 'IMDBBINARY', False) 20 | graph_list = [S2V_to_PyG(datum) for datum 
in data] 21 | return graph_list 22 | 23 | @pytest.fixture 24 | def proteins_graphs(): 25 | data, num_classes = load_data(os.path.join(ROOT_DIR, 'datasets', 'PROTEINS', 'raw'), 'PROTEINS', True) 26 | graph_list = [S2V_to_PyG(datum) for datum in data] 27 | return graph_list 28 | 29 | 30 | def validate_degree_as_tag(graphs): 31 | 32 | degree_set = set() 33 | degrees = dict() 34 | for g, graph in enumerate(graphs): 35 | d = degree(graph.edge_index[0]) 36 | d = d.numpy().astype(int).tolist() 37 | degree_set |= set(d) 38 | degrees[g] = d 39 | encoder = {deg: d for d, deg in enumerate(sorted(degree_set))} 40 | for g, graph in enumerate(graphs): 41 | feats = graph.x 42 | edge_index = graph.edge_index 43 | assert feats.shape[1] == len(encoder) 44 | row_sum = torch.sum(feats, 1) 45 | assert torch.equal(row_sum, torch.ones(feats.shape[0])) 46 | tags = torch.argmax(feats, 1) 47 | d = degrees[g] 48 | encoded = torch.LongTensor([encoder[deg] for deg in d]) 49 | assert torch.equal(tags, encoded), '{}\n{}'.format(tags, encoded) 50 | 51 | 52 | def validate_get_fold_indices(graphs): 53 | 54 | seeds = [0, 42, 43, 666] 55 | folds = list(range(10)) 56 | 57 | prev_train = None 58 | prev_test = None 59 | for fold in folds: 60 | for seed in seeds: 61 | torch.manual_seed(43) 62 | np.random.seed(43) 63 | random.seed(43) 64 | train_idx_0, test_idx_0 = get_fold_indices(graphs, seed, fold) 65 | torch.manual_seed(0) 66 | np.random.seed(0) 67 | random.seed(0) 68 | train_idx_1, test_idx_1 = get_fold_indices(graphs, seed, fold) 69 | # check the splitting procedure is deterministic and robust w.r.t. global seeds 70 | assert np.all(np.equal(train_idx_0, train_idx_1)) 71 | assert np.all(np.equal(test_idx_0, test_idx_1)) 72 | # check test and train form a partition 73 | assert len(set(train_idx_0) & set(test_idx_0)) == 0 74 | assert len(set(train_idx_0) | set(test_idx_0)) == len(graphs) 75 | # check idxs are different across seeds 76 | if prev_train is not None: 77 | assert np.any(~np.equal(train_idx_0, prev_train)) 78 | assert np.any(~np.equal(test_idx_0, prev_test)) 79 | prev_train = train_idx_0 80 | prev_test = test_idx_0 81 | 82 | 83 | def validate_constant_scalar_features(graphs): 84 | 85 | for graph in graphs: 86 | feats = graph.x 87 | assert feats.shape[1] 88 | expected = torch.ones(feats.shape[0], 1) 89 | assert torch.equal(feats, expected) 90 | 91 | 92 | @pytest.mark.data 93 | def test_get_fold_indices_on_imdbbinary(imdbbinary_graphs): 94 | validate_get_fold_indices(imdbbinary_graphs) 95 | 96 | 97 | @pytest.mark.data 98 | def test_degree_as_tag_on_imdbbinary(imdbbinary_graphs): 99 | validate_degree_as_tag(imdbbinary_graphs) 100 | 101 | 102 | @pytest.mark.data 103 | def test_constant_scalar_features_on_imdbbinary_without_tags(imdbbinary_nonattributed_graphs): 104 | validate_constant_scalar_features(imdbbinary_nonattributed_graphs) 105 | 106 | 107 | @pytest.mark.data 108 | def test_degree_as_tag_on_proteins(proteins_graphs): 109 | validate_degree_as_tag(proteins_graphs) 110 | -------------------------------------------------------------------------------- /exp/run_sr_exp.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import copy 4 | import time 5 | import numpy as np 6 | import subprocess 7 | 8 | from definitions import ROOT_DIR 9 | from exp.parser import get_parser 10 | from exp.run_exp import main 11 | 12 | # python3 -m exp.run_sr_exp --task_type isomorphism --eval_metric isomorphism --untrained --model sparse_cin --nonlinearity id --emb_dim 16 
--readout sum --num_layers 5 13 | # python3 -m exp.run_sr_exp --task_type isomorphism --eval_metric isomorphism --untrained --model gin --nonlinearity id --emb_dim 16 --readout sum --num_layers 5 14 | #--jump_mode None 15 | 16 | __families__ = [ 17 | 'sr16622', 18 | 'sr251256', 19 | 'sr261034', 20 | 'sr281264', 21 | 'sr291467', 22 | 'sr351668', 23 | 'sr351899', 24 | 'sr361446', 25 | 'sr401224' 26 | ] 27 | 28 | __max_dim__ = [ 29 | 3, 30 | 4, 31 | 3, 32 | 6, 33 | 4, 34 | 4, 35 | 6, 36 | 3, 37 | 3] 38 | 39 | if __name__ == "__main__": 40 | 41 | # Extract the commit sha so we can check the code that was used for each experiment 42 | sha = subprocess.check_output(["git", "describe", "--always"]).strip().decode() 43 | 44 | # standard args 45 | passed_args = sys.argv[1:] 46 | assert '--seed' not in passed_args 47 | assert '--dataset' not in passed_args 48 | assert '--readout_dims' not in passed_args 49 | parser = get_parser() 50 | args = parser.parse_args(copy.copy(passed_args)) 51 | 52 | # set result folder 53 | folder_name = f'SR-{args.exp_name}' 54 | if '--max_ring_size' in passed_args: 55 | folder_name += f'-{args.max_ring_size}' 56 | result_folder = os.path.join(args.result_folder, folder_name) 57 | passed_args += ['--result_folder', result_folder] 58 | 59 | # run each experiment separately and gather results 60 | results = [list() for _ in __families__] 61 | for f, family in enumerate(__families__): 62 | for seed in range(args.start_seed, args.stop_seed + 1): 63 | print(f'[i] family {family}, seed {seed}') 64 | current_args = copy.copy(passed_args) + ['--dataset', family, '--seed', str(seed)] 65 | if '--max_dim' not in passed_args: 66 | if '--max_ring_size' not in passed_args: 67 | current_args += ['--max_dim', str(__max_dim__[f])] 68 | max_dim = __max_dim__[f] 69 | else: 70 | current_args += ['--max_dim', str(2)] 71 | max_dim = 2 72 | else: 73 | assert '--max_ring_size' not in passed_args 74 | max_dim = args.max_dim 75 | readout_dims = [str(i) for i in range(max_dim + 1)] 76 | readout_dims = ['--readout_dims'] + readout_dims 77 | current_args += readout_dims 78 | parsed_args = parser.parse_args(current_args) 79 | curves = main(parsed_args) 80 | results[f].append(curves) 81 | 82 | msg = ( 83 | f"========= Final result ==========\n" 84 | f'Datasets: SR\n' 85 | f'SHA: {sha}\n') 86 | for f, family in enumerate(__families__): 87 | curves = results[f] 88 | test_perfs = [curve['last_test'] for curve in curves] 89 | assert len(test_perfs) == args.stop_seed + 1 - args.start_seed 90 | mean = np.mean(test_perfs) 91 | std_err = np.std(test_perfs) / float(len(test_perfs)) 92 | minim = np.min(test_perfs) 93 | maxim = np.max(test_perfs) 94 | msg += ( 95 | f'------------------ {family} ------------------\n' 96 | f'Mean failure rate: {mean}\n' 97 | f'StdErr failure rate: {std_err}\n' 98 | f'Min failure rate: {minim}\n' 99 | f'Max failure rate: {maxim}\n' 100 | '-----------------------------------------------\n') 101 | print(msg) 102 | 103 | # additionally write msg and configuration on file 104 | msg += str(args) 105 | filename = os.path.join(result_folder, 'result.txt') 106 | print('Writing results at: {}'.format(filename)) 107 | with open(filename, 'w') as handle: 108 | handle.write(msg) 109 | -------------------------------------------------------------------------------- /exp/count_rings.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | import argparse 4 | import time 5 | 6 | from data.parallel import ProgressParallel 7 | 
from data.data_loading import load_graph_dataset 8 | from data.utils import get_rings 9 | from joblib import delayed 10 | 11 | parser = argparse.ArgumentParser(description='Ring counting experiment.') 12 | parser.add_argument('--dataset', type=str, default="ZINC", 13 | help='dataset name (default: ZINC)') 14 | parser.add_argument('--n_jobs', type=int, default=4, 15 | help='Number of jobs to use') 16 | parser.add_argument('--max_ring_size', type=int, default=12, 17 | help='maximum ring size to look for') 18 | 19 | 20 | def get_ring_count_for_graph(edge_index, max_ring, keys): 21 | rings = get_rings(edge_index, max_k=max_ring) 22 | rings_per_graph = {key: 0 for key in keys} 23 | for ring in rings: 24 | k = len(ring) 25 | rings_per_graph[k] += 1 26 | return rings_per_graph 27 | 28 | 29 | def combine_all_cards(*cards): 30 | keys = cards[0].keys() 31 | ring_cards = {key: [] for key in keys} 32 | 33 | for card in cards: 34 | for k in keys: 35 | ring_cards[k].append(card[k]) 36 | return ring_cards 37 | 38 | 39 | def get_ring_counts(dataset, max_ring, jobs): 40 | start = time.time() 41 | keys = list(range(3, max_ring+1)) 42 | 43 | parallel = ProgressParallel(n_jobs=jobs, use_tqdm=True, total=len(dataset)) 44 | # It is important we supply a numpy array here. tensors seem to slow joblib down significantly. 45 | cards = parallel(delayed(get_ring_count_for_graph)( 46 | graph.edge_index.numpy(), max_ring, keys) for graph in dataset) 47 | 48 | end = time.time() 49 | print(f'Done ({end - start:.2f} secs).') 50 | return combine_all_cards(*cards) 51 | 52 | 53 | def combine_all_counts(*stats): 54 | all_stats = dict() 55 | 56 | for k in stats[0].keys(): 57 | all_stats[k] = [] 58 | 59 | for stat in stats: 60 | for k, v in stat.items(): 61 | # Extend the list 62 | all_stats[k] += v 63 | return all_stats 64 | 65 | 66 | def print_stats(stats): 67 | for k in stats: 68 | min = np.min(stats[k]) 69 | max = np.max(stats[k]) 70 | mean = np.mean(stats[k]) 71 | med = np.median(stats[k]) 72 | sum = np.sum(stats[k]) 73 | nz = np.count_nonzero(stats[k]) 74 | print( 75 | f'Ring {k:02d} => Min: {min:.3f}, Max: {max:.3f}, Mean:{mean:.3f}, Median: {med:.3f}, ' 76 | f'Sum: {sum:05d}, Non-zero: {nz:05d}') 77 | 78 | 79 | def exp_main(passed_args): 80 | args = parser.parse_args(passed_args) 81 | 82 | print('----==== {} ====----'.format(args.dataset)) 83 | graph_list, train_ids, val_ids, test_ids, _ = load_graph_dataset(args.dataset) 84 | graph_list = list(graph_list) # Needed to bring OGB in the right format 85 | 86 | train = [graph_list[i] for i in train_ids] 87 | val = [graph_list[i] for i in val_ids] 88 | test = None 89 | if test_ids is not None: 90 | test = [graph_list[i] for i in test_ids] 91 | 92 | print("Counting rings on the training set ....") 93 | print("First, it will take a while to set up the processes...") 94 | train_stats = get_ring_counts(train, args.max_ring_size, args.n_jobs) 95 | print("Counting rings on the validation set ....") 96 | val_stats = get_ring_counts(val, args.max_ring_size, args.n_jobs) 97 | 98 | test_stats = None 99 | if test is not None: 100 | print("Counting rings on the test set ....") 101 | test_stats = get_ring_counts(test, args.max_ring_size, args.n_jobs) 102 | all_stats = combine_all_counts(train_stats, val_stats, test_stats) 103 | else: 104 | all_stats = combine_all_counts(train_stats, val_stats) 105 | 106 | print("=============== Train ================") 107 | print_stats(train_stats) 108 | print("=============== Validation ================") 109 | print_stats(val_stats) 110 | if test is 
not None: 111 | print("=============== Test ================") 112 | print_stats(test_stats) 113 | print("=============== Whole Dataset ================") 114 | print_stats(all_stats) 115 | 116 | 117 | if __name__ == "__main__": 118 | passed_args = sys.argv[1:] 119 | exp_main(passed_args) 120 | -------------------------------------------------------------------------------- /data/datasets/ogb.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os.path as osp 3 | 4 | from data.utils import convert_graph_dataset_with_rings 5 | from data.datasets import InMemoryComplexDataset 6 | from ogb.graphproppred import PygGraphPropPredDataset 7 | 8 | 9 | class OGBDataset(InMemoryComplexDataset): 10 | """This is OGB graph-property prediction. This are graph-wise classification tasks.""" 11 | 12 | def __init__(self, root, name, max_ring_size, use_edge_features=False, transform=None, 13 | pre_transform=None, pre_filter=None, init_method='sum', 14 | include_down_adj=False, simple=False, n_jobs=2): 15 | self.name = name 16 | self._max_ring_size = max_ring_size 17 | self._use_edge_features = use_edge_features 18 | self._simple = simple 19 | self._n_jobs = n_jobs 20 | super(OGBDataset, self).__init__(root, transform, pre_transform, pre_filter, 21 | max_dim=2, init_method=init_method, 22 | include_down_adj=include_down_adj, cellular=True) 23 | self.data, self.slices, idx, self.num_tasks = self.load_dataset() 24 | self.train_ids = idx['train'] 25 | self.val_ids = idx['valid'] 26 | self.test_ids = idx['test'] 27 | 28 | @property 29 | def raw_file_names(self): 30 | name = self.name.replace('-', '_') # Replacing is to follow OGB folder naming convention 31 | # The processed graph files are our raw files. 32 | return [f'{name}/processed/geometric_data_processed.pt'] 33 | 34 | @property 35 | def processed_file_names(self): 36 | return [f'{self.name}_complex.pt', f'{self.name}_idx.pt', f'{self.name}_tasks.pt'] 37 | 38 | @property 39 | def processed_dir(self): 40 | """Overwrite to change name based on edge and simple feats""" 41 | directory = super(OGBDataset, self).processed_dir 42 | suffix1 = f"_{self._max_ring_size}rings" if self._cellular else "" 43 | suffix2 = "-E" if self._use_edge_features else "" 44 | suffix3 = "-S" if self._simple else "" 45 | return directory + suffix1 + suffix2 + suffix3 46 | 47 | def download(self): 48 | # Instantiating this will download and process the graph dataset. 49 | dataset = PygGraphPropPredDataset(self.name, self.raw_dir) 50 | 51 | def load_dataset(self): 52 | """Load the dataset from here and process it if it doesn't exist""" 53 | print("Loading dataset from disk...") 54 | data, slices = torch.load(self.processed_paths[0]) 55 | idx = torch.load(self.processed_paths[1]) 56 | tasks = torch.load(self.processed_paths[2]) 57 | return data, slices, idx, tasks 58 | 59 | def process(self): 60 | 61 | # At this stage, the graph dataset is already downloaded and processed 62 | dataset = PygGraphPropPredDataset(self.name, self.raw_dir) 63 | split_idx = dataset.get_idx_split() 64 | if self._simple: # Only retain the top two node/edge features 65 | print('Using simple features') 66 | dataset.data.x = dataset.data.x[:,:2] 67 | dataset.data.edge_attr = dataset.data.edge_attr[:,:2] 68 | 69 | # NB: the init method would basically have no effect if 70 | # we use edge features and do not initialize rings. 
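        # (Presumably this is because, with `init_edges=True`, edge features are taken
        # directly from `edge_attr`, and with `init_rings=False` the 2-cells receive no
        # initial features, so there is nothing left for the `init_method` reduction to build.)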
71 | print(f"Converting the {self.name} dataset to a cell complex...") 72 | complexes, _, _ = convert_graph_dataset_with_rings( 73 | dataset, 74 | max_ring_size=self._max_ring_size, 75 | include_down_adj=self.include_down_adj, 76 | init_method=self._init_method, 77 | init_edges=self._use_edge_features, 78 | init_rings=False, 79 | n_jobs=self._n_jobs) 80 | 81 | print(f'Saving processed dataset in {self.processed_paths[0]}...') 82 | torch.save(self.collate(complexes, self.max_dim), self.processed_paths[0]) 83 | 84 | print(f'Saving idx in {self.processed_paths[1]}...') 85 | torch.save(split_idx, self.processed_paths[1]) 86 | 87 | print(f'Saving num_tasks in {self.processed_paths[2]}...') 88 | torch.save(dataset.num_tasks, self.processed_paths[2]) 89 | 90 | 91 | def load_ogb_graph_dataset(root, name): 92 | raw_dir = osp.join(root, 'raw') 93 | dataset = PygGraphPropPredDataset(name, raw_dir) 94 | idx = dataset.get_idx_split() 95 | 96 | return dataset, idx['train'], idx['valid'], idx['test'] 97 | -------------------------------------------------------------------------------- /exp/plot_sr_cwn_results.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import matplotlib 4 | matplotlib.use('Agg') 5 | import numpy as np 6 | import seaborn as sns 7 | sns.set_style("whitegrid", {'legend.frameon': False}) 8 | 9 | from matplotlib import cm 10 | from matplotlib import pyplot as plt 11 | from definitions import ROOT_DIR 12 | 13 | def run(exps, codenames, plot_name): 14 | 15 | # Meta 16 | family_names = [ 17 | 'SR(16,6,2,2)', 18 | 'SR(25,12,5,6)', 19 | 'SR(26,10,3,4)', 20 | 'SR(28,12,6,4)', 21 | 'SR(29,14,6,7)', 22 | 'SR(35,16,6,8)', 23 | 'SR(35,18,9,9)', 24 | 'SR(36,14,4,6)', 25 | 'SR(40,12,2,4)'] 26 | 27 | # Retrieve results 28 | base_path = os.path.join(ROOT_DIR, 'exp', 'results') 29 | results = list() 30 | for e, exp_path in enumerate(exps): 31 | path = os.path.join(base_path, exp_path, 'result.txt') 32 | results.append(dict()) 33 | with open(path, 'r') as handle: 34 | found = False 35 | f = 0 36 | for line in handle: 37 | if not found: 38 | if line.strip().startswith('Mean'): 39 | mean = float(line.strip().split(':')[1].strip()) 40 | found = True 41 | else: 42 | continue 43 | else: 44 | std = float(line.strip().split(':')[1].strip()) 45 | results[-1][family_names[f]] = (mean, std) 46 | f += 1 47 | found = False 48 | assert f == len(family_names) 49 | 50 | # Set colours 51 | colors = cm.get_cmap('tab20c').colors[1:4] + cm.get_cmap('tab20c').colors[5:9] 52 | matplotlib.rc('axes', edgecolor='black', lw=0.25) 53 | a = np.asarray([83, 115, 171])/255.0 +0.0 54 | b = np.asarray([209, 135, 92])/255.0 +0.0 55 | colors = [a, a +0.13, a +0.2, b, b +0.065, b +0.135] 56 | 57 | # Set plotting 58 | num_families = len(family_names) 59 | num_experiments = len(results) 60 | sep = 1.75 61 | width = 0.7 62 | disp = num_experiments * width + sep 63 | xa = np.asarray([i*disp for i in range(num_families)]) 64 | xs = [xa + i*width for i in range(num_experiments//2)] + [xa + i*width + sep*0.25 for i in range(num_experiments//2, num_experiments)] 65 | plt.rcParams['ytick.right'] = plt.rcParams['ytick.labelright'] = True 66 | plt.rcParams['ytick.left'] = plt.rcParams['ytick.labelleft'] = False 67 | print(sns.axes_style()) 68 | matplotlib.rc('axes', edgecolor='#c4c4c4', linewidth=0.9) 69 | 70 | # Plot 71 | plt.figure(dpi=300, figsize=(9,6.6)) 72 | plt.grid(axis='x', alpha=0.0) 73 | for r, res in enumerate(results): 74 | x = xs[r] 75 | y = [10+res[family][0] for 
family in sorted(res)] 76 | yerr = [res[family][1] for family in sorted(res)] 77 | plt.bar(x, y, yerr=yerr, bottom=-10, color=colors[r], width=width, 78 | label=codenames[r], ecolor='grey', error_kw={'lw': 0.75, 'capsize':0.7}, 79 | edgecolor='white') 80 | # hatch=('//' if r<3 else '\\\\')) 81 | plt.axhline(y=1.0, color='indianred', lw=1.5, label='3WL') 82 | plt.ylim([-0.000005, 2]) 83 | plt.yscale(matplotlib.scale.SymmetricalLogScale(axis='y', linthresh=0.00001)) 84 | plt.xticks(xa+3*width, family_names, fontsize=12, rotation=315, ha='left') 85 | plt.yticks([0.0, 0.00001, 0.0001, 0.001, 0.01, 0.1, 1.0], fontsize=12) 86 | handles, labels = plt.gca().get_legend_handles_labels() 87 | order = [1, 4, 2, 5, 3, 6] + [0] 88 | plt.legend([handles[idx] for idx in order],[labels[idx] for idx in order], fontsize=10, loc='upper center', ncol=4, bbox_to_anchor=(0.5, 1.15)) 89 | plt.xlabel('Family', fontsize=15) 90 | plt.ylabel('Failure rate', fontsize=15, labelpad=-580, rotation=270) 91 | plt.tight_layout() 92 | plt.savefig(f'./sr_exp_{plot_name}.pdf', bbox_inches='tight', pad_inches=0.1) 93 | plt.close() 94 | 95 | if __name__ == '__main__': 96 | 97 | # Standard args 98 | passed_args = sys.argv[1:] 99 | codenames = list() 100 | exps = list() 101 | plot_name = passed_args[0] 102 | for a, arg in enumerate(passed_args[1:]): 103 | if a % 2 == 0: 104 | exps.append(arg) 105 | else: 106 | codenames.append(arg) 107 | assert len(codenames) == len(exps) == 6 108 | run(exps, codenames, plot_name) 109 | -------------------------------------------------------------------------------- /data/datasets/sr.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pickle 4 | 5 | from data.sr_utils import load_sr_dataset 6 | from data.utils import compute_clique_complex_with_gudhi, compute_ring_2complex 7 | from data.utils import convert_graph_dataset_with_rings, convert_graph_dataset_with_gudhi 8 | from data.datasets import InMemoryComplexDataset 9 | from definitions import ROOT_DIR 10 | from torch_geometric.data import Data 11 | 12 | import os.path as osp 13 | import errno 14 | 15 | 16 | def makedirs(path): 17 | try: 18 | os.makedirs(osp.expanduser(osp.normpath(path))) 19 | except OSError as e: 20 | if e.errno != errno.EEXIST and osp.isdir(path): 21 | raise e 22 | 23 | 24 | def load_sr_graph_dataset(name, root=os.path.join(ROOT_DIR, 'datasets'), prefer_pkl=False): 25 | raw_dir = os.path.join(root, 'SR_graphs', 'raw') 26 | load_from = os.path.join(raw_dir, '{}.g6'.format(name)) 27 | load_from_pkl = os.path.join(raw_dir, '{}.pkl'.format(name)) 28 | if prefer_pkl and osp.exists(load_from_pkl): 29 | print(f"Loading SR graph {name} from pickle dump...") 30 | with open(load_from_pkl, 'rb') as handle: 31 | data = pickle.load(handle) 32 | else: 33 | data = load_sr_dataset(load_from) 34 | graphs = list() 35 | for datum in data: 36 | edge_index, num_nodes = datum 37 | x = torch.ones(num_nodes, 1, dtype=torch.float32) 38 | graph = Data(x=x, edge_index=edge_index, y=None, edge_attr=None, num_nodes=num_nodes) 39 | graphs.append(graph) 40 | train_ids = list(range(len(graphs))) 41 | val_ids = list(range(len(graphs))) 42 | test_ids = list(range(len(graphs))) 43 | return graphs, train_ids, val_ids, test_ids 44 | 45 | 46 | class SRDataset(InMemoryComplexDataset): 47 | """A dataset of complexes obtained by lifting Strongly Regular graphs.""" 48 | 49 | def __init__(self, root, name, max_dim=2, num_classes=16, train_ids=None, val_ids=None, test_ids=None, 50 | 
include_down_adj=False, max_ring_size=None, n_jobs=2, init_method='sum'): 51 | self.name = name 52 | self._num_classes = num_classes 53 | self._n_jobs = n_jobs 54 | assert max_ring_size is None or max_ring_size > 3 55 | self._max_ring_size = max_ring_size 56 | cellular = (max_ring_size is not None) 57 | if cellular: 58 | assert max_dim == 2 59 | super(SRDataset, self).__init__(root, max_dim=max_dim, num_classes=num_classes, 60 | include_down_adj=include_down_adj, cellular=cellular, init_method=init_method) 61 | 62 | self.data, self.slices = torch.load(self.processed_paths[0]) 63 | 64 | self.train_ids = list(range(self.len())) if train_ids is None else train_ids 65 | self.val_ids = list(range(self.len())) if val_ids is None else val_ids 66 | self.test_ids = list(range(self.len())) if test_ids is None else test_ids 67 | 68 | @property 69 | def processed_dir(self): 70 | """This is overwritten, so the cellular complex data is placed in another folder""" 71 | directory = super(SRDataset, self).processed_dir 72 | suffix = f"_{self._max_ring_size}rings" if self._cellular else "" 73 | suffix += f"_down_adj" if self.include_down_adj else "" 74 | return directory + suffix 75 | 76 | @property 77 | def processed_file_names(self): 78 | return ['{}_complex_list.pt'.format(self.name)] 79 | 80 | def process(self): 81 | 82 | graphs, _, _, _ = load_sr_graph_dataset(self.name, prefer_pkl=True) 83 | exp_dim = self.max_dim 84 | if self._cellular: 85 | print(f"Converting the {self.name} dataset to a cell complex...") 86 | complexes, max_dim, num_features = convert_graph_dataset_with_rings( 87 | graphs, 88 | max_ring_size=self._max_ring_size, 89 | include_down_adj=self.include_down_adj, 90 | init_method=self._init_method, 91 | init_edges=True, 92 | init_rings=True, 93 | n_jobs=self._n_jobs) 94 | else: 95 | print(f"Converting the {self.name} dataset with gudhi...") 96 | complexes, max_dim, num_features = convert_graph_dataset_with_gudhi( 97 | graphs, 98 | expansion_dim=exp_dim, 99 | include_down_adj=self.include_down_adj, 100 | init_method=self._init_method) 101 | 102 | if self._max_ring_size is not None: 103 | assert max_dim <= 2 104 | if max_dim != self.max_dim: 105 | self.max_dim = max_dim 106 | makedirs(self.processed_dir) 107 | 108 | # Now we save in opt format. 
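        # (self.collate(...) packs the list of complexes into a single (data, slices) pair,
        # which __init__ above loads back via torch.load(self.processed_paths[0]).)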
109 | path = self.processed_paths[0] 110 | torch.save(self.collate(complexes, self.max_dim), path) 111 | -------------------------------------------------------------------------------- /exp/run_mol_exp.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import copy 4 | import numpy as np 5 | import subprocess 6 | 7 | from exp.parser import get_parser 8 | from exp.run_exp import main 9 | from itertools import product 10 | 11 | 12 | def exp_main(passed_args): 13 | # Extract the commit sha so we can check the code that was used for each experiment 14 | sha = subprocess.check_output(["git", "describe", "--always"]).strip().decode() 15 | 16 | parser = get_parser() 17 | args = parser.parse_args(copy.copy(passed_args)) 18 | assert args.stop_seed >= args.start_seed 19 | 20 | # run each experiment separately and gather results 21 | results = list() 22 | if args.folds is None: 23 | for seed in range(args.start_seed, args.stop_seed + 1): 24 | current_args = copy.copy(passed_args) + ['--seed', str(seed)] 25 | parsed_args = parser.parse_args(current_args) 26 | curves = main(parsed_args) 27 | results.append(curves) 28 | else: 29 | # Used by CSL only to run experiments across both seeds and folds 30 | assert args.dataset == 'CSL' 31 | for seed, fold in product(range(args.start_seed, args.stop_seed + 1), range(args.folds)): 32 | current_args = copy.copy(passed_args) + ['--seed', str(seed)] + ['--fold', str(fold)] 33 | parsed_args = parser.parse_args(current_args) 34 | curves = main(parsed_args) 35 | results.append(curves) 36 | 37 | # Extract results 38 | train_curves = [curves['train'] for curves in results] 39 | val_curves = [curves['val'] for curves in results] 40 | test_curves = [curves['test'] for curves in results] 41 | best_idx = [curves['best'] for curves in results] 42 | last_train = [curves['last_train'] for curves in results] 43 | last_val = [curves['last_val'] for curves in results] 44 | last_test = [curves['last_test'] for curves in results] 45 | 46 | # Extract results at the best validation epoch. 
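    # ('best' holds, for each run, the index of the epoch with the best validation score;
    # the train/val/test curves are indexed at that epoch below.)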
47 | best_epoch_train_results = [train_curves[i][best] for i, best in enumerate(best_idx)] 48 | best_epoch_train_results = np.array(best_epoch_train_results, dtype=np.float) 49 | best_epoch_val_results = [val_curves[i][best] for i, best in enumerate(best_idx)] 50 | best_epoch_val_results = np.array(best_epoch_val_results, dtype=np.float) 51 | best_epoch_test_results = [test_curves[i][best] for i, best in enumerate(best_idx)] 52 | best_epoch_test_results = np.array(best_epoch_test_results, dtype=np.float) 53 | 54 | # Compute stats for the best validation epoch 55 | mean_train_perf = np.mean(best_epoch_train_results) 56 | std_train_perf = np.std(best_epoch_train_results, ddof=1) # ddof=1 makes the estimator unbiased 57 | mean_val_perf = np.mean(best_epoch_val_results) 58 | std_val_perf = np.std(best_epoch_val_results, ddof=1) # ddof=1 makes the estimator unbiased 59 | mean_test_perf = np.mean(best_epoch_test_results) 60 | std_test_perf = np.std(best_epoch_test_results, ddof=1) # ddof=1 makes the estimator unbiased 61 | min_perf = np.min(best_epoch_test_results) 62 | max_perf = np.max(best_epoch_test_results) 63 | 64 | # Compute stats for the last epoch 65 | mean_final_train_perf = np.mean(last_train) 66 | std_final_train_perf = np.std(last_train, ddof=1) 67 | mean_final_val_perf = np.mean(last_val) 68 | std_final_val_perf = np.std(last_val, ddof=1) 69 | mean_final_test_perf = np.mean(last_test) 70 | std_final_test_perf = np.std(last_test, ddof=1) 71 | final_test_min = np.min(last_test) 72 | final_test_max = np.max(last_test) 73 | 74 | msg = ( 75 | f"========= Final result ==========\n" 76 | f'Dataset: {args.dataset}\n' 77 | f'SHA: {sha}\n' 78 | f'----------- Best epoch ----------\n' 79 | f'Train: {mean_train_perf} ± {std_train_perf}\n' 80 | f'Valid: {mean_val_perf} ± {std_val_perf}\n' 81 | f'Test: {mean_test_perf} ± {std_test_perf}\n' 82 | f'Test Min: {min_perf}\n' 83 | f'Test Max: {max_perf}\n' 84 | f'----------- Last epoch ----------\n' 85 | f'Train: {mean_final_train_perf} ± {std_final_train_perf}\n' 86 | f'Valid: {mean_final_val_perf} ± {std_final_val_perf}\n' 87 | f'Test: {mean_final_test_perf} ± {std_final_test_perf}\n' 88 | f'Test Min: {final_test_min}\n' 89 | f'Test Max: {final_test_max}\n' 90 | f'---------------------------------\n\n') 91 | print(msg) 92 | 93 | # additionally write msg and configuration on file 94 | msg += str(args) 95 | filename = os.path.join(args.result_folder, f'{args.dataset}-{args.exp_name}/result.txt') 96 | print('Writing results at: {}'.format(filename)) 97 | with open(filename, 'w') as handle: 98 | handle.write(msg) 99 | 100 | 101 | if __name__ == "__main__": 102 | passed_args = sys.argv[1:] 103 | assert '--seed' not in passed_args 104 | assert '--fold' not in passed_args 105 | exp_main(passed_args) 106 | -------------------------------------------------------------------------------- /mp/ring_exp_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from mp.layers import SparseCINConv 4 | from mp.nn import get_nonlinearity, get_graph_norm 5 | from data.complex import ComplexBatch 6 | from torch.nn import Linear, Sequential 7 | from torch_geometric.nn import GINConv 8 | 9 | 10 | class RingSparseCIN(torch.nn.Module): 11 | """ 12 | A simple cellular version of GIN employed for Ring experiments. 
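    The prediction is read out from the target node of each complex in the batch
    (selected through ``data.nodes.mask`` in ``forward``).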
13 | 14 | This model is based on 15 | https://github.com/rusty1s/pytorch_geometric/blob/master/benchmark/kernel/gin.py 16 | """ 17 | 18 | def __init__(self, num_input_features, num_classes, num_layers, hidden, 19 | max_dim: int = 2, nonlinearity='relu', train_eps=False, use_coboundaries=False, 20 | graph_norm='id'): 21 | super(RingSparseCIN, self).__init__() 22 | 23 | self.max_dim = max_dim 24 | self.convs = torch.nn.ModuleList() 25 | self.nonlinearity = nonlinearity 26 | self.init_layer = Linear(num_input_features, num_input_features) 27 | act_module = get_nonlinearity(nonlinearity, return_module=True) 28 | self.graph_norm = get_graph_norm(graph_norm) 29 | 30 | for i in range(num_layers): 31 | layer_dim = num_input_features if i == 0 else hidden 32 | self.convs.append( 33 | SparseCINConv(up_msg_size=layer_dim, down_msg_size=layer_dim, 34 | boundary_msg_size=layer_dim, passed_msg_boundaries_nn=None, passed_msg_up_nn=None, 35 | passed_update_up_nn=None, passed_update_boundaries_nn=None, 36 | train_eps=train_eps, max_dim=self.max_dim, 37 | hidden=hidden, act_module=act_module, layer_dim=layer_dim, 38 | graph_norm=self.graph_norm, use_coboundaries=use_coboundaries)) 39 | self.lin1 = Linear(hidden, num_classes) 40 | 41 | def reset_parameters(self): 42 | self.init_layer.reset_parameters() 43 | for conv in self.convs: 44 | conv.reset_parameters() 45 | self.lin1.reset_parameters() 46 | 47 | def forward(self, data: ComplexBatch, include_partial=False): 48 | xs = None 49 | res = {} 50 | 51 | data.nodes.x = self.init_layer(data.nodes.x) 52 | for c, conv in enumerate(self.convs): 53 | params = data.get_all_cochain_params(max_dim=self.max_dim, include_down_features=False) 54 | xs = conv(*params) 55 | data.set_xs(xs) 56 | 57 | if include_partial: 58 | for k in range(len(xs)): 59 | res[f"layer{c}_{k}"] = xs[k] 60 | 61 | x = xs[0] 62 | # Extract the target node from each graph 63 | mask = data.nodes.mask 64 | x = self.lin1(x[mask]) 65 | 66 | if include_partial: 67 | res['out'] = x 68 | return x, res 69 | 70 | return x 71 | 72 | def __repr__(self): 73 | return self.__class__.__name__ 74 | 75 | 76 | class RingGIN(torch.nn.Module): 77 | def __init__(self, num_features, num_layers, hidden, num_classes, nonlinearity='relu', 78 | graph_norm='bn'): 79 | super(RingGIN, self).__init__() 80 | self.nonlinearity = nonlinearity 81 | conv_nonlinearity = get_nonlinearity(nonlinearity, return_module=True) 82 | self.init_linear = Linear(num_features, num_features) 83 | self.graph_norm = get_graph_norm(graph_norm) 84 | 85 | # BN is needed to make GIN work empirically beyond 2 layers for the ring experiments. 
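        # (Each GIN MLP below therefore interleaves Linear -> graph_norm -> nonlinearity twice.)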
86 | self.conv1 = GINConv( 87 | Sequential( 88 | Linear(num_features, hidden), 89 | self.graph_norm(hidden), 90 | conv_nonlinearity(), 91 | Linear(hidden, hidden), 92 | self.graph_norm(hidden), 93 | conv_nonlinearity(), 94 | ), train_eps=False) 95 | 96 | self.convs = torch.nn.ModuleList() 97 | for i in range(num_layers - 1): 98 | self.convs.append( 99 | GINConv( 100 | Sequential( 101 | Linear(hidden, hidden), 102 | self.graph_norm(hidden), 103 | conv_nonlinearity(), 104 | Linear(hidden, hidden), 105 | self.graph_norm(hidden), 106 | conv_nonlinearity(), 107 | ), train_eps=False)) 108 | self.lin1 = Linear(hidden, num_classes) 109 | 110 | def reset_parameters(self): 111 | self.init_linear.reset_parameters() 112 | self.conv1.reset_parameters() 113 | for conv in self.convs: 114 | conv.reset_parameters() 115 | self.lin1.reset_parameters() 116 | 117 | def forward(self, data): 118 | act = get_nonlinearity(self.nonlinearity, return_module=False) 119 | x, edge_index, mask = data.x, data.edge_index, data.mask 120 | x = self.init_linear(x) 121 | x = act(self.conv1(x, edge_index)) 122 | for conv in self.convs: 123 | x = conv(x, edge_index) 124 | # Select the target node of each graph in the batch 125 | x = x[mask] 126 | x = self.lin1(x) 127 | return x 128 | 129 | def __repr__(self): 130 | return self.__class__.__name__ 131 | 132 | -------------------------------------------------------------------------------- /data/datasets/csl.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import numpy as np 3 | import torch 4 | 5 | from data.datasets import InMemoryComplexDataset 6 | from data.utils import convert_graph_dataset_with_rings 7 | from torch_geometric.datasets import GNNBenchmarkDataset 8 | from torch_geometric.utils import remove_self_loops 9 | 10 | 11 | class CSLDataset(InMemoryComplexDataset): 12 | """This is the CSL (Circular Skip Link) dataset from the Benchmarking GNNs paper. 13 | 14 | The dataset contains 10 isomorphism classes of regular graphs that must be classified. 
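    Graphs are lifted to two-dimensional cell complexes, with rings of size up to
    ``max_ring_size`` attached as two-cells.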
15 | """ 16 | 17 | def __init__(self, root, transform=None, 18 | pre_transform=None, pre_filter=None, max_ring_size=6, fold=0, init_method='sum', 19 | n_jobs=2): 20 | self.name = 'CSL' 21 | self._max_ring_size = max_ring_size 22 | self._n_jobs = n_jobs 23 | super(CSLDataset, self).__init__(root, transform, pre_transform, pre_filter, 24 | max_dim=2, cellular=True, init_method=init_method, 25 | num_classes=10) 26 | 27 | assert 0 <= fold <= 4 28 | self.fold = fold 29 | 30 | self.data, self.slices = self.load_dataset() 31 | 32 | self.num_node_type = 1 33 | self.num_edge_type = 1 34 | 35 | # These cross-validation splits have been taken from 36 | # https://github.com/graphdeeplearning/benchmarking-gnns/tree/master/data/CSL 37 | train_filename = osp.join(self.root, 'splits', 'CSL_train.txt') 38 | valid_filename = osp.join(self.root, 'splits', 'CSL_val.txt') 39 | test_filename = osp.join(self.root, 'splits', 'CSL_test.txt') 40 | 41 | self.train_ids = np.loadtxt(train_filename, dtype=int, delimiter=',')[fold].tolist() 42 | self.val_ids = np.loadtxt(valid_filename, dtype=int, delimiter=',')[fold].tolist() 43 | self.test_ids = np.loadtxt(test_filename, dtype=int, delimiter=',')[fold].tolist() 44 | 45 | # Make sure the split ratios are as expected (3:1:1) 46 | assert len(self.train_ids) == 3 * len(self.test_ids) 47 | assert len(self.val_ids) == len(self.test_ids) 48 | # Check all splits contain numbers that are smaller than the total number of graphs 49 | assert max(self.train_ids) < 150 50 | assert max(self.val_ids) < 150 51 | assert max(self.test_ids) < 150 52 | 53 | @property 54 | def raw_file_names(self): 55 | return ['data.pt'] 56 | 57 | @property 58 | def processed_file_names(self): 59 | return ['complexes.pt'] 60 | 61 | def download(self): 62 | # Instantiating this will download and process the graph dataset. 
63 | GNNBenchmarkDataset(self.raw_dir, 'CSL') 64 | 65 | def load_dataset(self): 66 | """Load the dataset from here and process it if it doesn't exist""" 67 | print("Loading dataset from disk...") 68 | data, slices = torch.load(self.processed_paths[0]) 69 | return data, slices 70 | 71 | def process(self): 72 | # At this stage, the graph dataset is already downloaded and processed 73 | print(f"Processing cell complex dataset for {self.name}") 74 | # This dataset has no train / val / test splits and we must use cross-validation 75 | data = GNNBenchmarkDataset(self.raw_dir, 'CSL') 76 | assert len(data) == 150 77 | 78 | # Check that indeed there are no features 79 | assert data[0].x is None 80 | assert data[0].edge_attr is None 81 | 82 | print("Populating graph with features") 83 | # Initialise everything with zero as in the Benchmarking GNNs code 84 | # https://github.com/graphdeeplearning/benchmarking-gnns/blob/ef8bd8c7d2c87948bc1bdd44099a52036e715cd0/data/CSL.py#L144 85 | new_data = [] 86 | for i, datum in enumerate(data): 87 | edge_index = datum.edge_index 88 | num_nodes = datum.num_nodes 89 | # Make sure we have no self-loops in this dataset 90 | edge_index, _ = remove_self_loops(edge_index) 91 | num_edges = edge_index.size(1) 92 | 93 | vx = torch.zeros((num_nodes, 1), dtype=torch.long) 94 | edge_attr = torch.zeros(num_edges, dtype=torch.long) 95 | setattr(datum, 'edge_index', edge_index) 96 | setattr(datum, 'x', vx) 97 | setattr(datum, 'edge_attr', edge_attr) 98 | new_data.append(datum) 99 | 100 | assert new_data[0].x is not None 101 | assert new_data[0].edge_attr is not None 102 | 103 | print("Converting the train dataset to a cell complex...") 104 | complexes, _, _ = convert_graph_dataset_with_rings( 105 | new_data, 106 | max_ring_size=self._max_ring_size, 107 | include_down_adj=False, 108 | init_edges=True, 109 | init_rings=False, 110 | n_jobs=self._n_jobs) 111 | 112 | path = self.processed_paths[0] 113 | print(f'Saving processed dataset in {path}....') 114 | torch.save(self.collate(complexes, 2), path) 115 | 116 | @property 117 | def processed_dir(self): 118 | """Overwrite to change name based on edges""" 119 | directory = super(CSLDataset, self).processed_dir 120 | suffix1 = f"_{self._max_ring_size}rings" if self._cellular else "" 121 | return directory + suffix1 122 | -------------------------------------------------------------------------------- /data/datasets/zinc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os.path as osp 3 | 4 | from data.utils import convert_graph_dataset_with_rings 5 | from data.datasets import InMemoryComplexDataset 6 | from torch_geometric.datasets import ZINC 7 | 8 | 9 | class ZincDataset(InMemoryComplexDataset): 10 | """This is ZINC from the Benchmarking GNNs paper. 
This is a graph regression task.""" 11 | 12 | def __init__(self, root, max_ring_size, use_edge_features=False, transform=None, 13 | pre_transform=None, pre_filter=None, subset=True, 14 | include_down_adj=False, n_jobs=2): 15 | self.name = 'ZINC' 16 | self._max_ring_size = max_ring_size 17 | self._use_edge_features = use_edge_features 18 | self._subset = subset 19 | self._n_jobs = n_jobs 20 | super(ZincDataset, self).__init__(root, transform, pre_transform, pre_filter, 21 | max_dim=2, cellular=True, 22 | include_down_adj=include_down_adj, num_classes=1) 23 | 24 | self.data, self.slices, idx = self.load_dataset() 25 | self.train_ids = idx[0] 26 | self.val_ids = idx[1] 27 | self.test_ids = idx[2] 28 | 29 | self.num_node_type = 28 30 | self.num_edge_type = 4 31 | 32 | @property 33 | def raw_file_names(self): 34 | return ['train.pt', 'val.pt', 'test.pt'] 35 | 36 | @property 37 | def processed_file_names(self): 38 | name = self.name 39 | return [f'{name}_complex.pt', f'{name}_idx.pt'] 40 | 41 | def download(self): 42 | # Instantiating this will download and process the graph dataset. 43 | ZINC(self.raw_dir, subset=self._subset) 44 | 45 | def load_dataset(self): 46 | """Load the dataset from here and process it if it doesn't exist""" 47 | print("Loading dataset from disk...") 48 | data, slices = torch.load(self.processed_paths[0]) 49 | idx = torch.load(self.processed_paths[1]) 50 | return data, slices, idx 51 | 52 | def process(self): 53 | # At this stage, the graph dataset is already downloaded and processed 54 | print(f"Processing cell complex dataset for {self.name}") 55 | train_data = ZINC(self.raw_dir, subset=self._subset, split='train') 56 | val_data = ZINC(self.raw_dir, subset=self._subset, split='val') 57 | test_data = ZINC(self.raw_dir, subset=self._subset, split='test') 58 | 59 | data_list = [] 60 | idx = [] 61 | start = 0 62 | print("Converting the train dataset to a cell complex...") 63 | train_complexes, _, _ = convert_graph_dataset_with_rings( 64 | train_data, 65 | max_ring_size=self._max_ring_size, 66 | include_down_adj=self.include_down_adj, 67 | init_edges=self._use_edge_features, 68 | init_rings=False, 69 | n_jobs=self._n_jobs) 70 | data_list += train_complexes 71 | idx.append(list(range(start, len(data_list)))) 72 | start = len(data_list) 73 | print("Converting the validation dataset to a cell complex...") 74 | val_complexes, _, _ = convert_graph_dataset_with_rings( 75 | val_data, 76 | max_ring_size=self._max_ring_size, 77 | include_down_adj=self.include_down_adj, 78 | init_edges=self._use_edge_features, 79 | init_rings=False, 80 | n_jobs=self._n_jobs) 81 | data_list += val_complexes 82 | idx.append(list(range(start, len(data_list)))) 83 | start = len(data_list) 84 | print("Converting the test dataset to a cell complex...") 85 | test_complexes, _, _ = convert_graph_dataset_with_rings( 86 | test_data, 87 | max_ring_size=self._max_ring_size, 88 | include_down_adj=self.include_down_adj, 89 | init_edges=self._use_edge_features, 90 | init_rings=False, 91 | n_jobs=self._n_jobs) 92 | data_list += test_complexes 93 | idx.append(list(range(start, len(data_list)))) 94 | 95 | path = self.processed_paths[0] 96 | print(f'Saving processed dataset in {path}....') 97 | torch.save(self.collate(data_list, 2), path) 98 | 99 | path = self.processed_paths[1] 100 | print(f'Saving idx in {path}....') 101 | torch.save(idx, path) 102 | 103 | @property 104 | def processed_dir(self): 105 | """Overwrite to change name based on edges""" 106 | directory = super(ZincDataset, self).processed_dir 107 | suffix0 
= "_full" if self._subset is False else "" 108 | suffix1 = f"_{self._max_ring_size}rings" if self._cellular else "" 109 | suffix2 = "-E" if self._use_edge_features else "" 110 | return directory + suffix0 + suffix1 + suffix2 111 | 112 | 113 | def load_zinc_graph_dataset(root, subset=True): 114 | raw_dir = osp.join(root, 'ZINC', 'raw') 115 | 116 | train_data = ZINC(raw_dir, subset=subset, split='train') 117 | val_data = ZINC(raw_dir, subset=subset, split='val') 118 | test_data = ZINC(raw_dir, subset=subset, split='test') 119 | data = train_data + val_data + test_data 120 | 121 | if subset: 122 | assert len(train_data) == 10000 123 | assert len(val_data) == 1000 124 | assert len(test_data) == 1000 125 | else: 126 | assert len(train_data) == 220011 127 | assert len(val_data) == 24445 128 | assert len(test_data) == 5000 129 | 130 | idx = [] 131 | start = 0 132 | idx.append(list(range(start, len(train_data)))) 133 | start = len(train_data) 134 | idx.append(list(range(start, start + len(val_data)))) 135 | start = len(train_data) + len(val_data) 136 | idx.append(list(range(start, start + len(test_data)))) 137 | 138 | return data, idx[0], idx[1], idx[2] 139 | 140 | 141 | -------------------------------------------------------------------------------- /exp/test_sr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import random 4 | import pytest 5 | 6 | from data.data_loading import DataLoader, load_dataset 7 | from exp.prepare_sr_tests import prepare 8 | from mp.models import MessagePassingAgnostic, SparseCIN 9 | 10 | def _get_cwn_sr_embeddings(family, seed, baseline=False): 11 | 12 | # Set the seed for everything 13 | torch.manual_seed(seed) 14 | torch.cuda.manual_seed(seed) 15 | torch.cuda.manual_seed_all(seed) 16 | np.random.seed(seed) 17 | random.seed(seed) 18 | 19 | # Please set the parameters below to the ones used in SR experiments. 20 | # If so, if tests pass then the experiments are deemed sound. 
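    # (The check embeds each SR complex and a node-permuted copy of it, and then asserts
    # that the two embeddings coincide -- see _validate_self_iso_on_sr below.)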
21 | hidden = 16 22 | num_layers = 3 23 | max_ring_size = 6 24 | use_coboundaries = True 25 | nonlinearity = 'elu' 26 | graph_norm = 'id' 27 | readout = 'sum' 28 | final_readout = 'sum' 29 | readout_dims = (0,1,2) 30 | init = 'sum' 31 | jobs = 64 32 | prepare_seed = 43 33 | device = torch.device("cuda:" + str(0)) if torch.cuda.is_available() else torch.device("cpu") 34 | 35 | # Build and dump dataset if needed 36 | prepare(family, jobs, max_ring_size, True, init, prepare_seed) 37 | 38 | # Load reference dataset 39 | complexes = load_dataset(family, max_dim=2, max_ring_size=max_ring_size, init_method=init) 40 | permuted_complexes = load_dataset(f'{family}p{prepare_seed}', max_dim=2, max_ring_size=max_ring_size, init_method=init) 41 | 42 | # Instantiate model 43 | if not baseline: 44 | model = SparseCIN(num_input_features=1, num_classes=complexes.num_classes, num_layers=num_layers, hidden=hidden, 45 | use_coboundaries=use_coboundaries, nonlinearity=nonlinearity, graph_norm=graph_norm, 46 | readout=readout, final_readout=final_readout, readout_dims=readout_dims) 47 | else: 48 | hidden = 256 49 | model = MessagePassingAgnostic(num_input_features=1, num_classes=complexes.num_classes, hidden=hidden, 50 | nonlinearity=nonlinearity, readout=readout) 51 | 52 | model = model.to(device) 53 | model.eval() 54 | 55 | # Compute reference complex embeddings 56 | data_loader = DataLoader(complexes, batch_size=8, shuffle=False, num_workers=16, max_dim=2) 57 | data_loader_perm = DataLoader(permuted_complexes, batch_size=8, shuffle=False, num_workers=16, max_dim=2) 58 | 59 | with torch.no_grad(): 60 | embeddings = list() 61 | perm_embeddings = list() 62 | for batch in data_loader: 63 | batch.nodes.x = batch.nodes.x.double() 64 | batch.edges.x = batch.edges.x.double() 65 | batch.two_cells.x = batch.two_cells.x.double() 66 | out = model.forward(batch.to(device)) 67 | embeddings.append(out) 68 | for batch in data_loader_perm: 69 | batch.nodes.x = batch.nodes.x.double() 70 | batch.edges.x = batch.edges.x.double() 71 | batch.two_cells.x = batch.two_cells.x.double() 72 | out = model.forward(batch.to(device)) 73 | perm_embeddings.append(out) 74 | embeddings = torch.cat(embeddings, 0) # n x d 75 | perm_embeddings = torch.cat(perm_embeddings, 0) # n x d 76 | assert embeddings.size(0) == perm_embeddings.size(0) 77 | assert embeddings.size(1) == perm_embeddings.size(1) == complexes.num_classes 78 | 79 | return embeddings, perm_embeddings 80 | 81 | def _validate_self_iso_on_sr(embeddings, perm_embeddings): 82 | eps = 0.01 83 | for i in range(embeddings.size(0)): 84 | preds = torch.stack((embeddings[i], perm_embeddings[i]), 0) 85 | assert preds.size(0) == 2 86 | assert preds.size(1) == embeddings.size(1) 87 | dist = torch.pdist(preds, p=2).item() 88 | assert dist <= eps 89 | 90 | def _validate_magnitude_embeddings(embeddings): 91 | # At (5)e8, the fp64 granularity is still (2**29 - 2**28) / (2**52) ≈ 0.000000059604645 92 | # The fact that we work in such a (safe) range can also be verified by running the following: 93 | # a = torch.DoubleTensor([2.5e8]) 94 | # d = torch.DoubleTensor([5.0e8]) 95 | # b = torch.nextafter(a, d) 96 | # print(b - a) 97 | # >>> tensor([2.9802e-08], dtype=torch.float64) 98 | thresh = torch.DoubleTensor([5.0*1e8]) 99 | apex = torch.max(torch.abs(embeddings)).cpu() 100 | print(apex) 101 | assert apex.dtype == torch.float64 102 | assert torch.all(apex < thresh) 103 | 104 | @pytest.mark.slow 105 | @pytest.mark.parametrize("family", ['sr16622', 'sr251256', 'sr261034', 'sr281264', 'sr291467', 
'sr351668', 'sr351899', 'sr361446', 'sr401224']) 106 | def test_sparse_cin0_self_isomorphism(family): 107 | # Perform the check in double precision 108 | torch.set_default_dtype(torch.float64) 109 | for seed in range(5): 110 | embeddings, perm_embeddings = _get_cwn_sr_embeddings(family, seed) 111 | _validate_magnitude_embeddings(embeddings) 112 | _validate_magnitude_embeddings(perm_embeddings) 113 | _validate_self_iso_on_sr(embeddings, perm_embeddings) 114 | # Revert back to float32 for other tests 115 | torch.set_default_dtype(torch.float32) 116 | 117 | @pytest.mark.slow 118 | @pytest.mark.parametrize("family", ['sr16622', 'sr251256', 'sr261034', 'sr281264', 'sr291467', 'sr351668', 'sr351899', 'sr361446', 'sr401224']) 119 | def test_cwn_baseline_self_isomorphism(family): 120 | # Perform the check in double precision 121 | torch.set_default_dtype(torch.float64) 122 | for seed in range(5): 123 | embeddings, perm_embeddings = _get_cwn_sr_embeddings(family, seed, baseline=True) 124 | _validate_magnitude_embeddings(embeddings) 125 | _validate_magnitude_embeddings(perm_embeddings) 126 | _validate_self_iso_on_sr(embeddings, perm_embeddings) 127 | # Revert back to float32 for other tests 128 | torch.set_default_dtype(torch.float32) 129 | -------------------------------------------------------------------------------- /mp/test_orientation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from data.datasets.flow import load_flow_dataset 5 | from mp.models import EdgeOrient, EdgeMPNN 6 | from mp.layers import OrientedConv 7 | from data.complex import CochainBatch 8 | from data.data_loading import DataLoader 9 | from data.datasets.flow_utils import build_cochain 10 | 11 | 12 | def generate_oriented_flow_pair(): 13 | # This is the complex from slide 19 of https://crisbodnar.github.io/files/mml_talk.pdf 14 | B1 = np.array([ 15 | [-1, -1, 0, 0, 0, 0], 16 | [+1, 0, -1, 0, 0, +1], 17 | [ 0, +1, 0, -1, 0, -1], 18 | [ 0, 0, +1, +1, -1, 0], 19 | [ 0, 0, 0, 0, +1, 0], 20 | ]) 21 | 22 | B2 = np.array([ 23 | [-1, 0], 24 | [+1, 0], 25 | [ 0, +1], 26 | [ 0, -1], 27 | [ 0, 0], 28 | [+1, +1], 29 | ]) 30 | 31 | x = np.array([[1.0], [0.0], [0.0], [1.0], [1.0], [-1.0]]) 32 | id = np.identity(x.shape[0]) 33 | T2 = np.diag([+1.0, +1.0, +1.0, +1.0, -1.0, -1.0]) 34 | 35 | cochain1 = build_cochain(B1, B2, id, x, 0) 36 | cochain2 = build_cochain(B1, B2, T2, x, 0) 37 | return cochain1, cochain2, torch.tensor(T2, dtype=torch.float) 38 | 39 | 40 | def test_edge_orient_model_on_flow_dataset_with_batching(): 41 | dataset, _, _ = load_flow_dataset(num_points=100, num_train=20, num_test=2) 42 | 43 | np.random.seed(4) 44 | data_loader = DataLoader(dataset, batch_size=16) 45 | model = EdgeOrient(num_input_features=1, num_classes=2, num_layers=2, hidden=5) 46 | # We use the model in eval mode to test its inference behavior. 
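    # (eval() disables dropout, so the batched and per-cochain predictions should match.)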
47 | model.eval() 48 | 49 | batched_preds = [] 50 | for batch in data_loader: 51 | batched_pred = model.forward(batch) 52 | batched_preds.append(batched_pred) 53 | batched_preds = torch.cat(batched_preds, dim=0) 54 | 55 | preds = [] 56 | for cochain in dataset: 57 | pred = model.forward(CochainBatch.from_cochain_list([cochain])) 58 | preds.append(pred) 59 | preds = torch.cat(preds, dim=0) 60 | 61 | assert (preds.size() == batched_preds.size()) 62 | assert torch.allclose(preds, batched_preds, atol=1e-5) 63 | 64 | 65 | def test_edge_orient_conv_is_orientation_equivariant(): 66 | cochain1, cochain2, T2 = generate_oriented_flow_pair() 67 | assert torch.equal(cochain1.lower_index, cochain2.lower_index) 68 | assert torch.equal(cochain1.upper_index, cochain2.upper_index) 69 | 70 | layer = OrientedConv(dim=1, up_msg_size=1, down_msg_size=1, update_up_nn=None, 71 | update_down_nn=None, update_nn=None, act_fn=None) 72 | 73 | out_up1, out_down1, _ = layer.propagate(cochain1.upper_index, cochain1.lower_index, None, x=cochain1.x, 74 | up_attr=cochain1.upper_orient.view(-1, 1), down_attr=cochain1.lower_orient.view(-1, 1)) 75 | out_up2, out_down2, _ = layer.propagate(cochain2.upper_index, cochain2.lower_index, None, x=cochain2.x, 76 | up_attr=cochain2.upper_orient.view(-1, 1), down_attr=cochain2.lower_orient.view(-1, 1)) 77 | 78 | assert torch.equal(T2 @ out_up1, out_up2) 79 | assert torch.equal(T2 @ out_down1, out_down2) 80 | assert torch.equal(T2 @ (cochain1.x + out_up1 + out_down1), cochain2.x + out_up2 + out_down2) 81 | 82 | 83 | def test_edge_orient_model_with_tanh_is_orientation_equivariant_and_invariant_at_readout(): 84 | cochain1, cochain2, T2 = generate_oriented_flow_pair() 85 | assert torch.equal(cochain1.lower_index, cochain2.lower_index) 86 | assert torch.equal(cochain1.upper_index, cochain2.upper_index) 87 | 88 | model = EdgeOrient(num_input_features=1, num_classes=2, num_layers=2, hidden=5, 89 | nonlinearity='tanh', dropout_rate=0.0) 90 | model.eval() 91 | 92 | final1, pred1 = model.forward(CochainBatch.from_cochain_list([cochain1]), include_partial=True) 93 | final2, pred2 = model.forward(CochainBatch.from_cochain_list([cochain2]), include_partial=True) 94 | # Check equivariant. 95 | assert torch.equal(T2 @ pred1, pred2) 96 | # Check invariant after readout. 97 | assert torch.equal(final1, final2) 98 | 99 | 100 | def test_edge_orient_model_with_id_is_orientation_equivariant_and_invariant_at_readout(): 101 | cochain1, cochain2, T2 = generate_oriented_flow_pair() 102 | assert torch.equal(cochain1.lower_index, cochain2.lower_index) 103 | assert torch.equal(cochain1.upper_index, cochain2.upper_index) 104 | 105 | model = EdgeOrient(num_input_features=1, num_classes=2, num_layers=2, hidden=5, 106 | nonlinearity='id', dropout_rate=0.0) 107 | model.eval() 108 | 109 | final1, pred1 = model.forward(CochainBatch.from_cochain_list([cochain1]), include_partial=True) 110 | final2, pred2 = model.forward(CochainBatch.from_cochain_list([cochain2]), include_partial=True) 111 | # Check equivariant. 112 | assert torch.equal(T2 @ pred1, pred2) 113 | # Check invariant after readout. 
114 | assert torch.equal(final1, final2) 115 | 116 | 117 | def test_edge_orient_model_with_relu_is_not_orientation_equivariant_or_invariant(): 118 | cochain1, cochain2, T2 = generate_oriented_flow_pair() 119 | assert torch.equal(cochain1.lower_index, cochain2.lower_index) 120 | assert torch.equal(cochain1.upper_index, cochain2.upper_index) 121 | 122 | model = EdgeOrient(num_input_features=1, num_classes=2, num_layers=2, hidden=5, 123 | nonlinearity='relu', dropout_rate=0.0) 124 | model.eval() 125 | 126 | _, pred1 = model.forward(CochainBatch.from_cochain_list([cochain1]), include_partial=True) 127 | _, pred2 = model.forward(CochainBatch.from_cochain_list([cochain2]), include_partial=True) 128 | # Check not equivariant. 129 | assert not torch.equal(T2 @ pred1, pred2) 130 | # Check not invariant. 131 | assert not torch.equal(pred1, pred2) 132 | 133 | 134 | def test_edge_mpnn_model_is_orientation_invariant(): 135 | cochain1, cochain2, T2 = generate_oriented_flow_pair() 136 | assert torch.equal(cochain1.lower_index, cochain2.lower_index) 137 | assert torch.equal(cochain1.upper_index, cochain2.upper_index) 138 | 139 | model = EdgeMPNN(num_input_features=1, num_classes=2, num_layers=2, hidden=5, 140 | nonlinearity='id', dropout_rate=0.0) 141 | model.eval() 142 | 143 | _, pred1 = model.forward(CochainBatch.from_cochain_list([cochain1]), include_partial=True) 144 | _, pred2 = model.forward(CochainBatch.from_cochain_list([cochain2]), include_partial=True) 145 | 146 | # Check the model is orientation invariant. 147 | assert torch.equal(pred1, pred2) 148 | -------------------------------------------------------------------------------- /mp/test_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | 4 | from mp.layers import ( 5 | DummyCellularMessagePassing, CINConv, OrientedConv, InitReduceConv, EmbedVEWithReduce) 6 | from data.dummy_complexes import get_house_complex, get_molecular_complex 7 | from torch import nn 8 | from data.datasets.flow import load_flow_dataset 9 | 10 | 11 | def test_dummy_cellular_message_passing_with_down_msg(): 12 | house_complex = get_house_complex() 13 | v_params = house_complex.get_cochain_params(dim=0) 14 | e_params = house_complex.get_cochain_params(dim=1) 15 | t_params = house_complex.get_cochain_params(dim=2) 16 | 17 | dsmp = DummyCellularMessagePassing() 18 | v_x, e_x, t_x = dsmp.forward(v_params, e_params, t_params) 19 | 20 | expected_v_x = torch.tensor([[12], [9], [25], [25], [23]], dtype=torch.float) 21 | assert torch.equal(v_x, expected_v_x) 22 | 23 | expected_e_x = torch.tensor([[10], [20], [47], [22], [42], [37]], dtype=torch.float) 24 | assert torch.equal(e_x, expected_e_x) 25 | 26 | expected_t_x = torch.tensor([[1]], dtype=torch.float) 27 | assert torch.equal(t_x, expected_t_x) 28 | 29 | 30 | def test_dummy_cellular_message_passing_with_boundary_msg(): 31 | house_complex = get_house_complex() 32 | v_params = house_complex.get_cochain_params(dim=0) 33 | e_params = house_complex.get_cochain_params(dim=1) 34 | t_params = house_complex.get_cochain_params(dim=2) 35 | 36 | dsmp = DummyCellularMessagePassing(use_boundary_msg=True, use_down_msg=False) 37 | v_x, e_x, t_x = dsmp.forward(v_params, e_params, t_params) 38 | 39 | expected_v_x = torch.tensor([[12], [9], [25], [25], [23]], dtype=torch.float) 40 | assert torch.equal(v_x, expected_v_x) 41 | 42 | expected_e_x = torch.tensor([[4], [7], [23], [9], [25], [24]], dtype=torch.float) 43 | assert torch.equal(e_x, expected_e_x) 44 
| 45 | expected_t_x = torch.tensor([[15]], dtype=torch.float) 46 | assert torch.equal(t_x, expected_t_x) 47 | 48 | 49 | def test_dummy_cellular_message_passing_on_molecular_cell_complex(): 50 | molecular_complex = get_molecular_complex() 51 | v_params = molecular_complex.get_cochain_params(dim=0) 52 | e_params = molecular_complex.get_cochain_params(dim=1) 53 | ring_params = molecular_complex.get_cochain_params(dim=2) 54 | 55 | dsmp = DummyCellularMessagePassing(use_boundary_msg=True, use_down_msg=True) 56 | v_x, e_x, ring_x = dsmp.forward(v_params, e_params, ring_params) 57 | 58 | expected_v_x = torch.tensor([[12], [24], [24], [15], [25], [31], [47], [24]], 59 | dtype=torch.float) 60 | assert torch.equal(v_x, expected_v_x) 61 | 62 | expected_e_x = torch.tensor([[35], [79], [41], [27], [66], [70], [92], [82], [53]], 63 | dtype=torch.float) 64 | assert torch.equal(e_x, expected_e_x) 65 | 66 | # The first cell feature is given by 1[x] + 0[up] + (2+2)[down] + (1+2+3+4)[boundaries] = 15 67 | # The 2nd cell is given by 2[x] + 0[up] + (1+2)[down] + (2+5+6+7+8)[boundaries] = 33 68 | expected_ring_x = torch.tensor([[15], [33]], dtype=torch.float) 69 | assert torch.equal(ring_x, expected_ring_x) 70 | 71 | 72 | def test_cin_conv_training(): 73 | msg_net = nn.Sequential(nn.Linear(2, 1)) 74 | update_net = nn.Sequential(nn.Linear(1, 3)) 75 | 76 | cin_conv = CINConv(1, 1, msg_net, msg_net, update_net, 0.05) 77 | 78 | all_params_before = [] 79 | for p in cin_conv.parameters(): 80 | all_params_before.append(p.clone().data) 81 | assert len(all_params_before) > 0 82 | 83 | house_complex = get_house_complex() 84 | 85 | v_params = house_complex.get_cochain_params(dim=0) 86 | e_params = house_complex.get_cochain_params(dim=1) 87 | t_params = house_complex.get_cochain_params(dim=2) 88 | 89 | yv = house_complex.get_labels(dim=0) 90 | ye = house_complex.get_labels(dim=1) 91 | yt = house_complex.get_labels(dim=2) 92 | y = torch.cat([yv, ye, yt]) 93 | 94 | optimizer = optim.SGD(cin_conv.parameters(), lr=0.001) 95 | optimizer.zero_grad() 96 | 97 | out_v, out_e, out_t = cin_conv.forward(v_params, e_params, t_params) 98 | out = torch.cat([out_v, out_e, out_t], dim=0) 99 | 100 | criterion = nn.CrossEntropyLoss() 101 | loss = criterion(out, y) 102 | loss.backward() 103 | optimizer.step() 104 | 105 | all_params_after = [] 106 | for p in cin_conv.parameters(): 107 | all_params_after.append(p.clone().data) 108 | assert len(all_params_after) == len(all_params_before) 109 | 110 | # Check that parameters have been updated. 
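    # (A single SGD step with lr=0.001 was taken above, so each tensor should now differ
    # from its snapshot.)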
111 | for i, _ in enumerate(all_params_before): 112 | assert not torch.equal(all_params_before[i], all_params_after[i]) 113 | 114 | 115 | def test_orient_conv_on_flow_dataset(): 116 | import numpy as np 117 | 118 | np.random.seed(4) 119 | update_up = nn.Sequential(nn.Linear(1, 4)) 120 | update_down = nn.Sequential(nn.Linear(1, 4)) 121 | update = nn.Sequential(nn.Linear(1, 4)) 122 | 123 | train, _, G = load_flow_dataset(num_points=400, num_train=3, num_test=3) 124 | number_of_edges = G.number_of_edges() 125 | 126 | model = OrientedConv(1, 1, 1, update_up_nn=update_up, update_down_nn=update_down, 127 | update_nn=update, act_fn=torch.tanh) 128 | model.eval() 129 | 130 | out = model.forward(train[0]) 131 | assert out.size(0) == number_of_edges 132 | assert out.size(1) == 4 133 | 134 | 135 | def test_init_reduce_conv_on_house_complex(): 136 | house_complex = get_house_complex() 137 | v_params = house_complex.get_cochain_params(dim=0) 138 | e_params = house_complex.get_cochain_params(dim=1) 139 | t_params = house_complex.get_cochain_params(dim=2) 140 | 141 | conv = InitReduceConv(reduce='add') 142 | 143 | ex = conv.forward(v_params.x, e_params.boundary_index) 144 | expected_ex = torch.tensor([[3], [5], [7], [5], [9], [8]], dtype=torch.float) 145 | assert torch.equal(expected_ex, ex) 146 | 147 | tx = conv.forward(e_params.x, t_params.boundary_index) 148 | expected_tx = torch.tensor([[14]], dtype=torch.float) 149 | assert torch.equal(expected_tx, tx) 150 | 151 | 152 | def test_embed_with_reduce_layer_on_house_complex(): 153 | house_complex = get_house_complex() 154 | cochains = house_complex.cochains 155 | params = house_complex.get_all_cochain_params() 156 | 157 | embed_layer = nn.Embedding(num_embeddings=32, embedding_dim=10) 158 | init_reduce = InitReduceConv() 159 | conv = EmbedVEWithReduce(embed_layer, None, init_reduce) 160 | 161 | # Simulate the lack of features in these dimensions. 
162 | params[1].x = None 163 | params[2].x = None 164 | 165 | xs = conv.forward(*params) 166 | 167 | assert len(xs) == 3 168 | assert xs[0].dim() == 2 169 | assert xs[0].size(0) == cochains[0].num_cells 170 | assert xs[0].size(1) == 10 171 | assert xs[1].size(0) == cochains[1].num_cells 172 | assert xs[1].size(1) == 10 173 | assert xs[2].size(0) == cochains[2].num_cells 174 | assert xs[2].size(1) == 10 175 | 176 | 177 | -------------------------------------------------------------------------------- /data/datasets/tu.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pickle 4 | import numpy as np 5 | from definitions import ROOT_DIR 6 | 7 | from data.tu_utils import load_data, S2V_to_PyG, get_fold_indices 8 | from data.utils import convert_graph_dataset_with_gudhi, convert_graph_dataset_with_rings 9 | from data.datasets import InMemoryComplexDataset 10 | 11 | 12 | def load_tu_graph_dataset(name, root=os.path.join(ROOT_DIR, 'datasets'), degree_as_tag=False, fold=0, seed=0): 13 | raw_dir = os.path.join(root, name, 'raw') 14 | load_from = os.path.join(raw_dir, '{}_graph_list_degree_as_tag_{}.pkl'.format(name, degree_as_tag)) 15 | if os.path.isfile(load_from): 16 | with open(load_from, 'rb') as handle: 17 | graph_list = pickle.load(handle) 18 | else: 19 | data, num_classes = load_data(raw_dir, name, degree_as_tag) 20 | print('Converting graph data into PyG format...') 21 | graph_list = [S2V_to_PyG(datum) for datum in data] 22 | with open(load_from, 'wb') as handle: 23 | pickle.dump(graph_list, handle) 24 | train_filename = os.path.join(raw_dir, '10fold_idx', 'train_idx-{}.txt'.format(fold + 1)) 25 | test_filename = os.path.join(raw_dir, '10fold_idx', 'test_idx-{}.txt'.format(fold + 1)) 26 | if os.path.isfile(train_filename) and os.path.isfile(test_filename): 27 | # NB: we consider the loaded test indices as val_ids ones and set test_ids to None 28 | # to make it more convenient to work with the training pipeline 29 | train_ids = np.loadtxt(train_filename, dtype=int).tolist() 30 | val_ids = np.loadtxt(test_filename, dtype=int).tolist() 31 | else: 32 | train_ids, val_ids = get_fold_indices(graph_list, seed, fold) 33 | test_ids = None 34 | return graph_list, train_ids, val_ids, test_ids 35 | 36 | 37 | class TUDataset(InMemoryComplexDataset): 38 | """A dataset of complexes obtained by lifting graphs from TUDatasets.""" 39 | 40 | def __init__(self, root, name, max_dim=2, num_classes=2, degree_as_tag=False, fold=0, 41 | init_method='sum', seed=0, include_down_adj=False, max_ring_size=None): 42 | self.name = name 43 | self.degree_as_tag = degree_as_tag 44 | assert max_ring_size is None or max_ring_size > 3 45 | self._max_ring_size = max_ring_size 46 | cellular = (max_ring_size is not None) 47 | if cellular: 48 | assert max_dim == 2 49 | 50 | super(TUDataset, self).__init__(root, max_dim=max_dim, num_classes=num_classes, 51 | init_method=init_method, include_down_adj=include_down_adj, cellular=cellular) 52 | 53 | self.data, self.slices = torch.load(self.processed_paths[0]) 54 | 55 | self.fold = fold 56 | self.seed = seed 57 | train_filename = os.path.join(self.raw_dir, '10fold_idx', 'train_idx-{}.txt'.format(fold + 1)) 58 | test_filename = os.path.join(self.raw_dir, '10fold_idx', 'test_idx-{}.txt'.format(fold + 1)) 59 | if os.path.isfile(train_filename) and os.path.isfile(test_filename): 60 | # NB: we consider the loaded test indices as val_ids ones and set test_ids to None 61 | # to make it more convenient to work with the 
training pipeline 62 | self.train_ids = np.loadtxt(train_filename, dtype=int).tolist() 63 | self.val_ids = np.loadtxt(test_filename, dtype=int).tolist() 64 | else: 65 | train_ids, val_ids = get_fold_indices(self, self.seed, self.fold) 66 | self.train_ids = train_ids 67 | self.val_ids = val_ids 68 | self.test_ids = None 69 | # TODO: Add this later to our zip 70 | # tune_train_filename = os.path.join(self.raw_dir, 'tests_train_split.txt'.format(fold + 1)) 71 | # self.tune_train_ids = np.loadtxt(tune_train_filename, dtype=int).tolist() 72 | # tune_test_filename = os.path.join(self.raw_dir, 'tests_val_split.txt'.format(fold + 1)) 73 | # self.tune_val_ids = np.loadtxt(tune_test_filename, dtype=int).tolist() 74 | # self.tune_test_ids = None 75 | 76 | @property 77 | def processed_dir(self): 78 | """This is overwritten, so the cellular complex data is placed in another folder""" 79 | directory = super(TUDataset, self).processed_dir 80 | suffix = f"_{self._max_ring_size}rings" if self._cellular else "" 81 | suffix += f"_down_adj" if self.include_down_adj else "" 82 | return directory + suffix 83 | 84 | @property 85 | def processed_file_names(self): 86 | return ['{}_complex_list.pt'.format(self.name)] 87 | 88 | @property 89 | def raw_file_names(self): 90 | # The processed graph files are our raw files. 91 | # They are obtained when running the initial data conversion S2V_to_PyG. 92 | return ['{}_graph_list_degree_as_tag_{}.pkl'.format(self.name, self.degree_as_tag)] 93 | 94 | def download(self): 95 | # This will process the raw data into a list of PyG Data objs. 96 | data, num_classes = load_data(self.raw_dir, self.name, self.degree_as_tag) 97 | self._num_classes = num_classes 98 | print('Converting graph data into PyG format...') 99 | graph_list = [S2V_to_PyG(datum) for datum in data] 100 | with open(self.raw_paths[0], 'wb') as handle: 101 | pickle.dump(graph_list, handle) 102 | 103 | def process(self): 104 | with open(self.raw_paths[0], 'rb') as handle: 105 | graph_list = pickle.load(handle) 106 | 107 | if self._cellular: 108 | print("Converting the dataset accounting for rings...") 109 | complexes, _, _ = convert_graph_dataset_with_rings(graph_list, max_ring_size=self._max_ring_size, 110 | include_down_adj=self.include_down_adj, 111 | init_method=self._init_method, 112 | init_edges=True, init_rings=True) 113 | else: 114 | print("Converting the dataset with gudhi...") 115 | # TODO: eventually remove the following comment 116 | # What about the init_method here? 
Adding now, although I remember we had handled this 117 | complexes, _, _ = convert_graph_dataset_with_gudhi(graph_list, expansion_dim=self.max_dim, 118 | include_down_adj=self.include_down_adj, 119 | init_method=self._init_method) 120 | 121 | torch.save(self.collate(complexes, self.max_dim), self.processed_paths[0]) 122 | 123 | def get_tune_idx_split(self): 124 | raise NotImplementedError('Not implemented yet') 125 | # idx_split = { 126 | # 'train': self.tune_train_ids, 127 | # 'valid': self.tune_val_ids, 128 | # 'test': self.tune_test_ids} 129 | # return idx_split 130 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CW Networks 2 | 3 | [![example workflow](https://github.com/twitter-research/cwn/actions/workflows/python-package.yml/badge.svg)](https://github.com/twitter-research/cwn/actions) 4 | [![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)](https://github.com/twitter-research/cwn/blob/main/LICENSE) 5 | 6 | 7 | 8 | This repository contains the official code used for the papers 9 | **[Weisfeiler and Lehman Go Cellular: CW Networks](https://arxiv.org/abs/2106.12575) (NeurIPS 2021)** 10 | and **[Weisfeiler and Lehman Go Topological: Message Passing Simplicial Networks](https://arxiv.org/abs/2103.03212) (ICML 2021)** 11 | 12 |

*(figure omitted from this text dump; see the `figures/` directory)*
15 | 16 | *Graph Neural Networks (GNNs) are limited in their expressive power, struggle with long-range 17 | interactions and lack a principled way to model higher-order structures. These problems can be 18 | attributed to the strong coupling between the computational graph and the input graph structure. 19 | The recently proposed Message Passing Simplicial Networks naturally decouple these elements 20 | by performing message passing on the clique complex of the graph. Nevertheless, 21 | these models are severely constrained by the rigid combinatorial structure of 22 | Simplicial Complexes (SCs). In this work, we extend recent theoretical results on SCs to 23 | regular Cell Complexes, topological objects that flexibly subsume SCs and graphs. 24 | We show that this generalisation provides a powerful set of graph "lifting" transformations, 25 | each leading to a unique hierarchical message passing procedure. The resulting methods, 26 | which we collectively call CW Networks (CWNs), are strictly more powerful than the WL test and, 27 | in certain cases, not less powerful than the 3-WL test. In particular, we demonstrate the 28 | effectiveness of one such scheme, based on rings, when applied to molecular graph problems. 29 | The proposed architecture benefits from provably larger expressivity than commonly used GNNs, 30 | principled modelling of higher-order signals and from compressing the distances between nodes. 31 | We demonstrate that our model achieves state-of-the-art results on a variety of molecular datasets.* 32 | 33 | ## Installation 34 | 35 | 36 | 37 | We use `Python 3.8` and `PyTorch 1.7.0` on `CUDA 10.2` for this project. 38 | Please open a terminal window and follow these steps to prepare the virtual environment needed to run any experiment. 39 | 40 | Create the environment: 41 | ```shell 42 | conda create --name cwn python=3.8 43 | conda activate cwn 44 | conda install pip # Make sure the environment pip is used 45 | ``` 46 | 47 | Install dependencies: 48 | ```shell 49 | sh graph-tool_install.sh 50 | conda install -y pytorch=1.7.0 torchvision cudatoolkit=10.2 -c pytorch 51 | sh pyG_install.sh cu102 52 | pip install -r requirements.txt 53 | ``` 54 | 55 | ### Testing 56 | 57 | We suggest running all tests in the repository to verify everything is in place. Run: 58 | ```shell 59 | pytest -v . 60 | ``` 61 | All tests should pass. Note that some tests are skipped since they rely on external 62 | datasets or take a long time to run. We periodically run these additional tests manually. 63 | 64 | ## Experiments 65 | 66 | We prepared individual scripts for each experiment. The results are written in the 67 | `exp/results/` directory and are also displayed in the terminal once the training is 68 | complete. Before the training starts, the scripts will download / preprocess the corresponding graph datasets 69 | and perform the appropriate graph-lifting procedure (this might take a while). 70 | 71 | ### Molecular benchmarks 72 | 73 |

74 | 75 |

76 | 77 | To run an experiment on a molecular benchmark with a CWN, execute: 78 | ```shell 79 | sh exp/scripts/cwn-.sh 80 | ``` 81 | with `` one amongst `zinc`, `zinc-full`, `molhiv`. 82 | 83 | Imposing the parameter budget: it is sufficient to add the suffix `-small` to the `` placeholder: 84 | ```shell 85 | sh exp/scripts/cwn--small.sh 86 | ``` 87 | For example, `sh exp/scripts/cwn-zinc-small.sh` will run the training on ZINC with parameter budget. 88 | 89 | ### Distinguishing SR graphs 90 | 91 | To run an experiment on the SR benchmark with a CWN, run: 92 | ```shell 93 | sh exp/scripts/cwn-sr.sh 94 | ``` 95 | replacing `` with a value amongst `4`, `5`, `6` (`` is the maximum ring size employed in the lifting procedure). The results, for each family, will be written under `exp/results/SR-cwn-sr-/`. 96 | 97 | The following command will run the MLP-sum (strong) baseline on the same ring-lifted graphs: 98 | ```shell 99 | sh exp/scripts/cwn-sr-base.sh 100 | ``` 101 | 102 | In order to run these experiment with clique-complex lifting (MPSNs), run: 103 | ```shell 104 | sh exp/scripts/mpsn-sr.sh 105 | ``` 106 | Clique-lifting is applied up to dimension `k-1`, with `k` the maximum clique-size in the family. 107 | 108 | The MLP-sum baseline on clique-complexes is run with: 109 | ```shell 110 | sh exp/scripts/mpsn-sr-base.sh 111 | ``` 112 | 113 | ### Circular Skip Link (CSL) Experiments 114 | 115 | To run the experiments on the CSL dataset (5 folds x 20 seeds), run the following script: 116 | ```shell 117 | sh exp/scripts/cwn-csl.sh 118 | ``` 119 | 120 | ### Trajectory classification 121 | 122 | For the Ocean Dataset experiments, the data must be downloaded from [here](https://github.com/nglaze00/SCoNe_GCN/blob/master/ocean_drifters_data/dataBuoys.jld2). 123 | The file must be placed in `datasets/OCEAN/raw/`. 124 | 125 | For running the experiments use the following scripts: 126 | ```shell 127 | sh ./exp/scripts/mpsn-flow.sh [id/relu/tanh] 128 | sh ./exp/scripts/mpsn-ocean.sh [id/relu/tanh] 129 | sh ./exp/scripts/gnn-inv-flow.sh 130 | sh ./exp/scripts/gnn-inv-ocean.sh 131 | ``` 132 | 133 | ### TUDatasets 134 | 135 | For experiments on TUDatasets first download the raw data from [here](https://www.dropbox.com/s/2ekun30wxyxpcr7/datasets.zip?dl=0). 136 | Please place the downloaded archive on the root of the repository and unzip it (e.g. `unzip ./datasets.zip`). 137 | 138 | Here we provide the scripts to run CWN on NCI109 and MPSN on REDDITBINARY. This script can be customised to run additional experiments on other datasets. 
139 | ```shell 140 | sh ./exp/scripts/cwn-nci109.sh 141 | sh ./exp/scripts/mpsn-redditb.sh 142 | ``` 143 | 144 | ### Credits 145 | 146 | For attribution in academic contexts, please cite the following papers 147 | 148 | ``` 149 | @inproceedings{pmlr-v139-bodnar21a, 150 | title = {Weisfeiler and {Lehman} Go Topological: Message Passing Simplicial Networks}, 151 | author = {Bodnar, Cristian and Frasca, Fabrizio and Wang, Yuguang and Otter, Nina and Montufar, Guido F and Li{\'o}, Pietro and Bronstein, Michael}, 152 | booktitle = {Proceedings of the 38th International Conference on Machine Learning}, 153 | pages = {1026--1037}, 154 | year = {2021}, 155 | editor = {Meila, Marina and Zhang, Tong}, 156 | volume = {139}, 157 | series = {Proceedings of Machine Learning Research}, 158 | month = {18--24 Jul}, 159 | publisher = {PMLR}, 160 | } 161 | ``` 162 | 163 | ``` 164 | @inproceedings{neurips-bodnar2021b, 165 | title = {Weisfeiler and {Lehman} Go Cellular: {CW} Networks}, 166 | author = {Bodnar, Cristian and Frasca, Fabrizio and Otter, Nina and Wang, Yuguang and Li\`{o}, Pietro and Montufar, Guido F and Bronstein, Michael}, 167 | booktitle = {Advances in Neural Information Processing Systems}, 168 | editor = {M. Ranzato and A. Beygelzimer and Y. Dauphin and P.S. Liang and J. Wortman Vaughan}, 169 | pages = {2625--2640}, 170 | publisher = {Curran Associates, Inc.}, 171 | volume = {34}, 172 | year = {2021} 173 | } 174 | ``` 175 | 176 | ## TODOs 177 | 178 | - [ ] Add support for coboundary adjacencies. 179 | - [ ] Refactor the way empty cochains are handled for batching. 180 | - [ ] Remove redundant parameters from the models 181 | (e.g. msg_up_nn in the top dimension.) 182 | - [ ] Refactor data classes so to remove setters for `__num_xxx_cells__` like attributes. 183 | - [ ] Address other TODOs left in the code. 184 | -------------------------------------------------------------------------------- /exp/train_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import logging 5 | from tqdm import tqdm 6 | from sklearn import metrics as met 7 | from data.complex import ComplexBatch 8 | from ogb.graphproppred import Evaluator as OGBEvaluator 9 | 10 | cls_criterion = torch.nn.CrossEntropyLoss() 11 | bicls_criterion = torch.nn.BCEWithLogitsLoss() 12 | reg_criterion = torch.nn.L1Loss() 13 | msereg_criterion = torch.nn.MSELoss() 14 | 15 | 16 | def train(model, device, loader, optimizer, task_type='classification', ignore_unlabeled=False): 17 | """ 18 | Performs one training epoch, i.e. one optimization pass over the batches of a data loader. 19 | """ 20 | 21 | if task_type == 'classification': 22 | loss_fn = cls_criterion 23 | elif task_type == 'bin_classification': 24 | loss_fn = bicls_criterion 25 | elif task_type == 'regression': 26 | loss_fn = reg_criterion 27 | elif task_type == 'mse_regression': 28 | loss_fn = msereg_criterion 29 | else: 30 | raise NotImplementedError('Training on task type {} not yet supported.'.format(task_type)) 31 | 32 | curve = list() 33 | model.train() 34 | num_skips = 0 35 | for step, batch in enumerate(tqdm(loader, desc="Training iteration")): 36 | batch = batch.to(device) 37 | if isinstance(batch, ComplexBatch): 38 | num_samples = batch.cochains[0].x.size(0) 39 | for dim in range(1, batch.dimension+1): 40 | num_samples = min(num_samples, batch.cochains[dim].num_cells) 41 | else: 42 | # This is graph. 
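            # For a plain PyG graph batch, the number of nodes serves as a proxy for the
            # batch size. Together with the per-dimension minimum computed above for
            # complexes, this value drives the small-batch checks below: BatchNorm
            # statistics are unreliable when only a handful of samples are present.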
43 | num_samples = batch.x.size(0) 44 | 45 | if num_samples <= 1: 46 | # Skip batch if it only comprises one sample (could cause problems with BN) 47 | num_skips += 1 48 | if float(num_skips) / len(loader) >= 0.25: 49 | logging.warning("Warning! 25% of the batches were skipped this epoch") 50 | continue 51 | 52 | # (DEBUG) 53 | if num_samples < 10: 54 | logging.warning("Warning! BatchNorm applied on a batch " 55 | "with only {} samples".format(num_samples)) 56 | 57 | optimizer.zero_grad() 58 | pred = model(batch) 59 | if isinstance(loss_fn, torch.nn.CrossEntropyLoss): 60 | targets = batch.y.view(-1,) 61 | else: 62 | targets = batch.y.to(torch.float32).view(pred.shape) 63 | 64 | # In some ogbg-mol* datasets we may have null targets. 65 | # When the cross entropy loss is used and targets are of shape (N,) 66 | # the maks is broadcasted automatically to the shape of the predictions. 67 | mask = ~torch.isnan(targets) 68 | loss = loss_fn(pred[mask], targets[mask]) 69 | 70 | loss.backward() 71 | optimizer.step() 72 | curve.append(loss.detach().cpu().item()) 73 | 74 | return curve 75 | 76 | 77 | def infer(model, device, loader): 78 | """ 79 | Runs inference over all the batches of a data loader. 80 | """ 81 | model.eval() 82 | y_pred = list() 83 | for step, batch in enumerate(tqdm(loader, desc="Inference iteration")): 84 | batch = batch.to(device) 85 | with torch.no_grad(): 86 | pred = model(batch) 87 | y_pred.append(pred.detach().cpu()) 88 | y_pred = torch.cat(y_pred, dim=0).numpy() 89 | return y_pred 90 | 91 | 92 | def eval(model, device, loader, evaluator, task_type): 93 | """ 94 | Evaluates a model over all the batches of a data loader. 95 | """ 96 | 97 | if task_type == 'classification': 98 | loss_fn = cls_criterion 99 | elif task_type == 'bin_classification': 100 | loss_fn = bicls_criterion 101 | elif task_type == 'regression': 102 | loss_fn = reg_criterion 103 | elif task_type == 'mse_regression': 104 | loss_fn = msereg_criterion 105 | else: 106 | loss_fn = None 107 | 108 | model.eval() 109 | y_true = [] 110 | y_pred = [] 111 | losses = [] 112 | for step, batch in enumerate(tqdm(loader, desc="Eval iteration")): 113 | 114 | # Cast features to double precision if that is used 115 | if torch.get_default_dtype() == torch.float64: 116 | for dim in range(batch.dimension + 1): 117 | batch.cochains[dim].x = batch.cochains[dim].x.double() 118 | assert batch.cochains[dim].x.dtype == torch.float64, batch.cochains[dim].x.dtype 119 | 120 | batch = batch.to(device) 121 | with torch.no_grad(): 122 | pred = model(batch) 123 | 124 | if task_type != 'isomorphism': 125 | if isinstance(loss_fn, torch.nn.CrossEntropyLoss): 126 | targets = batch.y.view(-1,) 127 | y_true.append(batch.y.detach().cpu()) 128 | else: 129 | targets = batch.y.to(torch.float32).view(pred.shape) 130 | y_true.append(batch.y.view(pred.shape).detach().cpu()) 131 | mask = ~torch.isnan(targets) # In some ogbg-mol* datasets we may have null targets. 
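                # Restricting both predictions and targets to the mask lets graphs with
                # partially missing labels still contribute their observed tasks to the
                # loss, instead of being dropped or producing NaN values.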
132 | loss = loss_fn(pred[mask], targets[mask]) 133 | losses.append(loss.detach().cpu().item()) 134 | else: 135 | assert loss_fn is None 136 | 137 | y_pred.append(pred.detach().cpu()) 138 | 139 | y_true = torch.cat(y_true, dim=0).numpy() if len(y_true) > 0 else None 140 | y_pred = torch.cat(y_pred, dim=0).numpy() 141 | 142 | input_dict = {'y_pred': y_pred, 'y_true': y_true} 143 | mean_loss = float(np.mean(losses)) if len(losses) > 0 else np.nan 144 | return evaluator.eval(input_dict), mean_loss 145 | 146 | 147 | class Evaluator(object): 148 | 149 | def __init__(self, metric, **kwargs): 150 | if metric == 'isomorphism': 151 | self.eval_fn = self._isomorphism 152 | self.eps = kwargs.get('eps', 0.01) 153 | self.p_norm = kwargs.get('p', 2) 154 | elif metric == 'accuracy': 155 | self.eval_fn = self._accuracy 156 | elif metric == 'ap': 157 | self.eval_fn = self._ap 158 | elif metric == 'mae': 159 | self.eval_fn = self._mae 160 | elif metric.startswith('ogbg-mol'): 161 | self._ogb_evaluator = OGBEvaluator(metric) 162 | self._key = self._ogb_evaluator.eval_metric 163 | self.eval_fn = self._ogb 164 | else: 165 | raise NotImplementedError('Metric {} is not yet supported.'.format(metric)) 166 | 167 | def eval(self, input_dict): 168 | return self.eval_fn(input_dict) 169 | 170 | def _isomorphism(self, input_dict): 171 | # NB: here we return the failure percentage... the smaller the better! 172 | preds = input_dict['y_pred'] 173 | assert preds is not None 174 | assert preds.dtype == np.float64 175 | preds = torch.tensor(preds, dtype=torch.float64) 176 | mm = torch.pdist(preds, p=self.p_norm) 177 | wrong = (mm < self.eps).sum().item() 178 | metric = wrong / mm.shape[0] 179 | return metric 180 | 181 | def _accuracy(self, input_dict, **kwargs): 182 | y_true = input_dict['y_true'] 183 | y_pred = np.argmax(input_dict['y_pred'], axis=1) 184 | assert y_true is not None 185 | assert y_pred is not None 186 | metric = met.accuracy_score(y_true, y_pred) 187 | return metric 188 | 189 | def _ap(self, input_dict, **kwargs): 190 | y_true = input_dict['y_true'] 191 | y_pred = input_dict['y_pred'] 192 | assert y_true is not None 193 | assert y_pred is not None 194 | metric = met.average_precision_score(y_true, y_pred) 195 | return metric 196 | 197 | 198 | def _mae(self, input_dict, **kwargs): 199 | y_true = input_dict['y_true'] 200 | y_pred = input_dict['y_pred'] 201 | assert y_true is not None 202 | assert y_pred is not None 203 | metric = met.mean_absolute_error(y_true, y_pred) 204 | return metric 205 | 206 | def _ogb(self, input_dict, **kwargs): 207 | assert 'y_true' in input_dict 208 | assert input_dict['y_true'] is not None 209 | assert 'y_pred' in input_dict 210 | assert input_dict['y_pred'] is not None 211 | return self._ogb_evaluator.eval(input_dict)[self._key] 212 | -------------------------------------------------------------------------------- /data/helper_test.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import torch 3 | import networkx as nx 4 | 5 | from torch_geometric.utils import convert 6 | from torch_geometric.data import Data 7 | 8 | 9 | def check_edge_index_are_the_same(upper_index, edge_index): 10 | """Checks that two edge/cell indexes are the same.""" 11 | # These two tensors should have the same content but in different order. 
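    # Each column is canonicalised as a sorted (u, v) pair and the two resulting sets
    # are compared, making the check invariant to column order and edge direction.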
12 | assert upper_index.size() == edge_index.size() 13 | num_edges = edge_index.size(1) 14 | 15 | edge_set1 = set() 16 | edge_set2 = set() 17 | for i in range(num_edges): 18 | e1, e2 = edge_index[0, i].item(), edge_index[1, i].item() 19 | edge1 = tuple(sorted([e1, e2])) 20 | edge_set1.add(edge1) 21 | 22 | e1, e2 = upper_index[0, i].item(), upper_index[1, i].item() 23 | edge2 = tuple(sorted([e1, e2])) 24 | edge_set2.add(edge2) 25 | 26 | assert edge_set1 == edge_set2 27 | 28 | 29 | def get_table(boundary_index): 30 | """Indexes each cell based on the boundary index.""" 31 | elements = boundary_index.size(1) 32 | id_to_cell = dict() 33 | for i in range(elements): 34 | cell_id = boundary_index[1, i].item() 35 | boundary = boundary_index[0, i].item() 36 | if cell_id not in id_to_cell: 37 | id_to_cell[cell_id] = [] 38 | id_to_cell[cell_id].append(boundary) 39 | return id_to_cell 40 | 41 | 42 | def check_edge_attr_are_the_same(boundary_index, ex, edge_index, edge_attr): 43 | """Checks that a pairs of edge attributes are identical.""" 44 | # The maximum node that has an edge must be the same. 45 | assert boundary_index[0, :].max() == edge_index.max() 46 | # The number of edges present in both tensors should be the same. 47 | assert boundary_index.size(1) == edge_index.size(1) 48 | 49 | id_to_edge = get_table(boundary_index) 50 | 51 | edge_to_id = dict() 52 | for edge_idx, edge in id_to_edge.items(): 53 | edge_to_id[tuple(sorted(edge))] = edge_idx 54 | 55 | edges = boundary_index.size(1) 56 | for i in range(edges): 57 | e1, e2 = edge_index[0, i].item(), edge_index[1, i].item() 58 | edge = tuple(sorted([e1, e2])) 59 | 60 | edge_attr1 = ex[edge_to_id[edge]].squeeze() 61 | edge_attr2 = edge_attr[i].squeeze() 62 | 63 | # NB: edge feats may be multidimensional, so we cannot 64 | # generally use the `==` operator here 65 | assert torch.equal(edge_attr1, edge_attr2) 66 | 67 | 68 | def get_rings(n, edge_index, max_ring): 69 | """Extracts the induced cycles from a graph using networkx.""" 70 | x = torch.zeros((n, 1)) 71 | data = Data(x, edge_index=edge_index) 72 | graph = convert.to_networkx(data) 73 | 74 | def is_cycle_edge(i1, i2, cycle): 75 | if i2 == i1 + 1: 76 | return True 77 | if i1 == 0 and i2 == len(cycle) - 1: 78 | return True 79 | return False 80 | 81 | def is_chordless(cycle): 82 | for (i1, v1), (i2, v2) in itertools.combinations(enumerate(cycle), 2): 83 | if not is_cycle_edge(i1, i2, cycle) and graph.has_edge(v1, v2): 84 | return False 85 | return True 86 | 87 | nx_rings = set() 88 | for cycle in nx.simple_cycles(graph): 89 | # Because we need to use a DiGraph for this method, it will also return each edge 90 | # as a cycle. So we skip these together with cycles above the maximum length. 91 | if len(cycle) <= 2 or len(cycle) > max_ring: 92 | continue 93 | # We skip the cycles with chords 94 | if not is_chordless(cycle): 95 | continue 96 | # Store the cycle in a canonical form 97 | nx_rings.add(tuple(sorted(cycle))) 98 | 99 | return nx_rings 100 | 101 | 102 | def get_complex_rings(r_boundary_index, e_boundary_index): 103 | """Extracts the vertices that are part of a ring.""" 104 | # Construct the edge and ring tables 105 | id_to_ring = get_table(r_boundary_index) 106 | id_to_edge = get_table(e_boundary_index) 107 | 108 | rings = set() 109 | for ring, edges in id_to_ring.items(): 110 | # Compose the two tables to extract the vertices in the ring. 111 | vertices = [vertex for edge in edges for vertex in id_to_edge[edge]] 112 | # Eliminate duplicates. 
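        # (Each vertex of a ring lies on exactly two of its boundary edges, so the
        # composition above lists every vertex twice; the set below collapses the repeats.)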
113 | vertices = set(vertices) 114 | # Store the ring in sorted order. 115 | rings.add(tuple(sorted(vertices))) 116 | return rings 117 | 118 | 119 | def compare_complexes(yielded, expected, include_down_adj): 120 | """Checks that two cell complexes are the same.""" 121 | assert yielded.dimension == expected.dimension 122 | assert torch.equal(yielded.y, expected.y) 123 | for dim in range(expected.dimension + 1): 124 | y_cochain = yielded.cochains[dim] 125 | e_cochain = expected.cochains[dim] 126 | assert y_cochain.num_cells == e_cochain.num_cells 127 | assert y_cochain.num_cells_up == e_cochain.num_cells_up 128 | assert y_cochain.num_cells_up == e_cochain.num_cells_up 129 | assert y_cochain.num_cells_down == e_cochain.num_cells_down, dim 130 | assert torch.equal(y_cochain.x, e_cochain.x) 131 | if dim > 0: 132 | assert torch.equal(y_cochain.boundary_index, e_cochain.boundary_index) 133 | if include_down_adj: 134 | if y_cochain.lower_index is None: 135 | assert e_cochain.lower_index is None 136 | assert y_cochain.shared_boundaries is None 137 | assert e_cochain.shared_boundaries is None 138 | else: 139 | assert torch.equal(y_cochain.lower_index, e_cochain.lower_index) 140 | assert torch.equal(y_cochain.shared_boundaries, e_cochain.shared_boundaries) 141 | else: 142 | assert y_cochain.boundary_index is None and e_cochain.boundary_index is None 143 | assert y_cochain.lower_index is None and e_cochain.lower_index is None 144 | assert y_cochain.shared_boundaries is None and e_cochain.shared_boundaries is None 145 | if dim < expected.dimension: 146 | if y_cochain.upper_index is None: 147 | assert e_cochain.upper_index is None 148 | assert y_cochain.shared_coboundaries is None 149 | assert e_cochain.shared_coboundaries is None 150 | else: 151 | assert torch.equal(y_cochain.upper_index, e_cochain.upper_index) 152 | assert torch.equal(y_cochain.shared_coboundaries, e_cochain.shared_coboundaries) 153 | else: 154 | assert y_cochain.upper_index is None and e_cochain.upper_index is None 155 | assert y_cochain.shared_coboundaries is None and e_cochain.shared_coboundaries is None 156 | 157 | 158 | def compare_complexes_without_2feats(yielded, expected, include_down_adj): 159 | """Checks that two cell complexes are the same, except for the features of the 2-cells.""" 160 | 161 | assert yielded.dimension == expected.dimension 162 | assert torch.equal(yielded.y, expected.y) 163 | for dim in range(expected.dimension + 1): 164 | y_cochain = yielded.cochains[dim] 165 | e_cochain = expected.cochains[dim] 166 | assert y_cochain.num_cells == e_cochain.num_cells 167 | assert y_cochain.num_cells_up == e_cochain.num_cells_up 168 | assert y_cochain.num_cells_up == e_cochain.num_cells_up 169 | assert y_cochain.num_cells_down == e_cochain.num_cells_down, dim 170 | if dim > 0: 171 | assert torch.equal(y_cochain.boundary_index, e_cochain.boundary_index) 172 | if include_down_adj: 173 | if y_cochain.lower_index is None: 174 | assert e_cochain.lower_index is None 175 | assert y_cochain.shared_boundaries is None 176 | assert e_cochain.shared_boundaries is None 177 | else: 178 | assert torch.equal(y_cochain.lower_index, e_cochain.lower_index) 179 | assert torch.equal(y_cochain.shared_boundaries, e_cochain.shared_boundaries) 180 | else: 181 | assert y_cochain.boundary_index is None and e_cochain.boundary_index is None 182 | assert y_cochain.lower_index is None and e_cochain.lower_index is None 183 | assert y_cochain.shared_boundaries is None and e_cochain.shared_boundaries is None 184 | if dim < expected.dimension: 185 | if 
y_cochain.upper_index is None: 186 | assert e_cochain.upper_index is None 187 | assert y_cochain.shared_coboundaries is None 188 | assert e_cochain.shared_coboundaries is None 189 | else: 190 | assert torch.equal(y_cochain.upper_index, e_cochain.upper_index) 191 | assert torch.equal(y_cochain.shared_coboundaries, e_cochain.shared_coboundaries) 192 | else: 193 | assert y_cochain.upper_index is None and e_cochain.upper_index is None 194 | assert y_cochain.shared_coboundaries is None and e_cochain.shared_coboundaries is None 195 | if dim != 2: 196 | assert torch.equal(y_cochain.x, e_cochain.x) 197 | else: 198 | assert y_cochain.x is None and e_cochain.x is None 199 | --------------------------------------------------------------------------------
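As a quick, illustrative sanity check (this snippet is not part of the repository and assumes it is run from the repository root, so that `data.helper_test` is importable), `get_rings` can be exercised on a small hand-built graph:

```python
import torch
from data.helper_test import get_rings

# A hexagon 0-1-2-3-4-5 with an extra chord (0, 3). The chord splits it into two
# chordless 4-cycles, while the full 6-cycle is skipped because it has a chord.
edges = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 0), (0, 3)]

# get_rings converts the graph to a directed networkx graph, so both orientations
# of every edge are included in the edge index.
edge_index = torch.tensor(
    [[u for u, v in edges] + [v for u, v in edges],
     [v for u, v in edges] + [u for u, v in edges]],
    dtype=torch.long)

print(get_rings(6, edge_index, max_ring=6))
# Expected output: {(0, 1, 2, 3), (0, 3, 4, 5)}
```

The resulting set of sorted vertex tuples is in the same canonical form produced by `get_complex_rings`, which is what allows the two to be compared when checking a lifted complex against its underlying graph.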