├── .gitignore ├── .readthedocs.yaml ├── .travis.yml ├── LICENSE ├── MANIFEST.in ├── README.md ├── benchmarks └── xgraph │ ├── RandomSelection.py │ ├── Readme.md │ ├── config │ ├── config.yaml │ ├── datasets │ │ ├── ba_2motifs.yaml │ │ ├── ba_community.yaml │ │ ├── ba_lrp.yaml │ │ ├── ba_shapes.yaml │ │ ├── bace.yaml │ │ ├── bbbp.yaml │ │ ├── clintox.yaml │ │ ├── graph_sst2.yaml │ │ ├── graph_sst5.yaml │ │ ├── tox21.yaml │ │ ├── tree_cycle.yaml │ │ ├── tree_grid.yaml │ │ └── twitter.yaml │ ├── explainers │ │ ├── deep_lift.yaml │ │ ├── gnn_explainer.yaml │ │ ├── gnn_gi.yaml │ │ ├── gnn_lrp.yaml │ │ ├── grad_cam.yaml │ │ ├── pgexplainer.yaml │ │ ├── random_explainer.yaml │ │ └── subgraphx.yaml │ └── models │ │ └── gcn.yaml │ ├── dataset.py │ ├── deep_lift.py │ ├── gnnNets.py │ ├── gnn_explainer.py │ ├── gnn_gi.py │ ├── gnn_lrp.py │ ├── grad_cam.py │ ├── imgs │ ├── fidelity.png │ ├── fidelity_inv.png │ └── xgraph.jpg │ ├── pgexplainer_edges.py │ ├── random_explain.py │ ├── subgraphx.py │ ├── train_gnns.py │ └── utils.py ├── dig ├── __init__.py ├── auggraph │ ├── __init__.py │ ├── dataset │ │ ├── __init__.py │ │ └── aug_dataset.py │ └── method │ │ ├── GraphAug │ │ ├── __init__.py │ │ ├── aug │ │ │ ├── __init__.py │ │ │ ├── aug_encoder.py │ │ │ ├── aug_utils.py │ │ │ ├── augmenter.py │ │ │ ├── edge_per.py │ │ │ ├── node_drop.py │ │ │ └── node_fm.py │ │ ├── constants │ │ │ ├── __init__.py │ │ │ ├── conf_params.py │ │ │ └── enums.py │ │ ├── model │ │ │ ├── __init__.py │ │ │ ├── classifier.py │ │ │ ├── genet.py │ │ │ ├── gmnet.py │ │ │ └── reward_generator.py │ │ ├── runner_aug_cls.py │ │ ├── runner_generator.py │ │ └── runner_reward_gen.py │ │ ├── SMixup │ │ ├── __init__.py │ │ ├── model │ │ │ ├── GCN.py │ │ │ ├── GIN.py │ │ │ ├── GMNET.py │ │ │ ├── GraphMatching.py │ │ │ └── __init__.py │ │ ├── smixup.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── sinkhorn.py │ │ │ └── utils.py │ │ └── __init__.py ├── fairgraph │ ├── __init__.py │ ├── dataset │ │ ├── __init__.py │ │ 
└── fairgraph_dataset.py │ ├── method │ │ ├── Graphair │ │ │ ├── GCN.py │ │ │ ├── __init__.py │ │ │ ├── aug_module.py │ │ │ ├── classifier.py │ │ │ └── graphair.py │ │ ├── __init__.py │ │ └── run.py │ └── utils │ │ ├── __init__.py │ │ └── utils.py ├── ggraph │ ├── __init__.py │ ├── dataset │ │ ├── PygDataset.py │ │ ├── __init__.py │ │ ├── config.csv │ │ └── ggraph_dataset.py │ ├── evaluation │ │ ├── __init__.py │ │ └── metric.py │ ├── method │ │ ├── GraphAF │ │ │ ├── __init__.py │ │ │ ├── graphaf.py │ │ │ ├── model │ │ │ │ ├── __init__.py │ │ │ │ ├── graphaf.py │ │ │ │ ├── graphflow.py │ │ │ │ ├── graphflow_con_rl.py │ │ │ │ ├── graphflow_rl.py │ │ │ │ ├── model_utils.py │ │ │ │ ├── rgcn.py │ │ │ │ └── st_net.py │ │ │ └── train_utils.py │ │ ├── GraphDF │ │ │ ├── __init__.py │ │ │ ├── graphdf.py │ │ │ ├── model │ │ │ │ ├── __init__.py │ │ │ │ ├── df_utils.py │ │ │ │ ├── disgraphaf.py │ │ │ │ ├── graphflow.py │ │ │ │ ├── graphflow_con_rl.py │ │ │ │ ├── graphflow_rl.py │ │ │ │ ├── rgcn.py │ │ │ │ └── st_net.py │ │ │ └── train_utils.py │ │ ├── GraphEBM │ │ │ ├── __init__.py │ │ │ ├── energy_func.py │ │ │ ├── graphebm.py │ │ │ └── util.py │ │ ├── JTVAE │ │ │ ├── __init__.py │ │ │ ├── fast_jtnn │ │ │ │ ├── __init__.py │ │ │ │ ├── chemutils.py │ │ │ │ ├── datautils.py │ │ │ │ ├── jtmpn.py │ │ │ │ ├── jtmpn_bo.py │ │ │ │ ├── jtnn_dec.py │ │ │ │ ├── jtnn_enc.py │ │ │ │ ├── jtnn_enc_bo.py │ │ │ │ ├── jtnn_vae.py │ │ │ │ ├── jtnn_vae_bo.py │ │ │ │ ├── jtprop_vae.py │ │ │ │ ├── mol_tree.py │ │ │ │ ├── mpn.py │ │ │ │ ├── nnutils.py │ │ │ │ ├── sascorer.py │ │ │ │ └── vocab.py │ │ │ └── jtvae.py │ │ ├── __init__.py │ │ └── generator.py │ └── utils │ │ ├── __init__.py │ │ ├── environment.py │ │ ├── fpscores.pkl.gz │ │ ├── gen_mol_from_one_shot_tensor.py │ │ └── sascorer.py ├── ggraph3D │ ├── __init__.py │ ├── dataset │ │ ├── __init__.py │ │ └── ggraph3D_dataset.py │ ├── evaluation │ │ ├── __init__.py │ │ └── metric.py │ ├── method │ │ ├── G_SphereNet │ │ │ ├── __init__.py │ │ │ 
├── gspherenet.py │ │ │ └── model │ │ │ │ ├── __init__.py │ │ │ │ ├── att.py │ │ │ │ ├── features.py │ │ │ │ ├── geometric_computing.py │ │ │ │ ├── net_utils.py │ │ │ │ ├── spherenet.py │ │ │ │ └── sphgen.py │ │ └── __init__.py │ └── utils │ │ ├── __init__.py │ │ ├── eval_bond_mmd_utils.py │ │ ├── eval_prop_utils.py │ │ └── eval_validity_utils.py ├── lsgraph │ ├── __init__.py │ ├── dataset │ │ ├── __init__.py │ │ ├── get_data.py │ │ └── loader.py │ └── method │ │ ├── FM.py │ │ ├── GraphFMOB │ │ ├── __init__.py │ │ ├── csrc │ │ │ ├── cpu │ │ │ │ ├── relabel_cpu.cpp │ │ │ │ └── relabel_cpu.h │ │ │ ├── cuda │ │ │ │ ├── sync_cuda.cu │ │ │ │ └── sync_cuda.h │ │ │ ├── relabel.cpp │ │ │ ├── sync.cpp │ │ │ └── thread.h │ │ ├── history.py │ │ ├── loader.py │ │ ├── metis.py │ │ ├── models │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── gcn.py │ │ │ ├── gcn2.py │ │ │ ├── pna.py │ │ │ └── pna_jk.py │ │ ├── pool.py │ │ └── utils.py │ │ └── __init__.py ├── oodgraph │ ├── __init__.py │ ├── good_arxiv.py │ ├── good_cbas.py │ ├── good_cmnist.py │ ├── good_cora.py │ ├── good_hiv.py │ ├── good_motif.py │ ├── good_pcba.py │ └── good_zinc.py ├── sslgraph │ ├── __init__.py │ ├── dataset │ │ ├── TUDataset.py │ │ ├── __init__.py │ │ ├── datasets.py │ │ └── feat_expansion.py │ ├── evaluation │ │ ├── __init__.py │ │ ├── eval_graph.py │ │ └── eval_node.py │ ├── method │ │ ├── __init__.py │ │ └── contrastive │ │ │ ├── __init__.py │ │ │ ├── model │ │ │ ├── __init__.py │ │ │ ├── contrastive.py │ │ │ ├── grace.py │ │ │ ├── graphcl.py │ │ │ ├── infograph.py │ │ │ ├── mvgrl.py │ │ │ └── pgrace.py │ │ │ ├── objectives │ │ │ ├── __init__.py │ │ │ ├── infonce.py │ │ │ └── jse.py │ │ │ └── views_fn │ │ │ ├── __init__.py │ │ │ ├── combination.py │ │ │ ├── feature.py │ │ │ ├── sample.py │ │ │ └── structure.py │ └── utils │ │ ├── __init__.py │ │ ├── adaptive.py │ │ ├── encoders.py │ │ └── seed.py ├── threedgraph │ ├── __init__.py │ ├── dataset │ │ ├── ECdataset.py │ │ ├── FOLDdataset.py │ │ ├── 
PygMD17.py │ │ ├── PygQM93D.py │ │ ├── README.md │ │ └── __init__.py │ ├── evaluation │ │ ├── __init__.py │ │ └── eval.py │ ├── method │ │ ├── __init__.py │ │ ├── comenet │ │ │ ├── __init__.py │ │ │ ├── comenet.py │ │ │ ├── features.py │ │ │ └── ocp │ │ │ │ ├── ComENetIS2REResults.jpg │ │ │ │ ├── IS2RETrainedModelWeights.pt │ │ │ │ ├── README.md │ │ │ │ ├── comenet-ocp.py │ │ │ │ ├── comenet.yml │ │ │ │ └── utils.py │ │ ├── dimenetpp │ │ │ ├── __init__.py │ │ │ ├── dimenetpp.py │ │ │ └── features.py │ │ ├── pronet │ │ │ ├── __init__.py │ │ │ ├── features.py │ │ │ └── pronet.py │ │ ├── run.py │ │ ├── schnet │ │ │ ├── __init__.py │ │ │ └── schnet.py │ │ └── spherenet │ │ │ ├── __init__.py │ │ │ ├── features.py │ │ │ └── spherenet.py │ └── utils │ │ ├── __init__.py │ │ └── geometric_computing.py ├── version.py └── xgraph │ ├── __init__.py │ ├── dataset │ ├── __init__.py │ ├── mol_dataset.py │ ├── nlp_dataset.py │ ├── syn_dataset.py │ └── utils_dataset.py │ ├── evaluation │ ├── Readme.md │ ├── __init__.py │ ├── defi.py │ └── metrics.py │ ├── method │ ├── __init__.py │ ├── base_explainer.py │ ├── deeplift.py │ ├── flowx.py │ ├── gnn_gi.py │ ├── gnn_lrp.py │ ├── gnnexplainer.py │ ├── gradcam.py │ ├── pgexplainer.py │ ├── shapley.py │ ├── subgraphx.py │ └── utils │ │ ├── __init__.py │ │ └── symmetric_edge_mask.py │ ├── models │ ├── __init__.py │ ├── ext │ │ ├── __init__.py │ │ └── deeplift │ │ │ ├── __init__.py │ │ │ ├── deep_lift.py │ │ │ └── layer_deep_lift.py │ ├── gradient_utils.py │ ├── model_manager.py │ ├── models.py │ └── utils.py │ └── utils │ ├── __init__.py │ ├── compatibility.py │ └── init.py ├── docs ├── .DS_Store ├── Makefile ├── environment.yaml ├── imgs │ ├── DIG-logo.jpg │ ├── DIG-overview.png │ └── GOOD-datasets.png ├── make.bat └── source │ ├── 3dgraph │ ├── dataset.rst │ ├── evaluation.rst │ ├── method.rst │ └── utils.rst │ ├── auggraph │ ├── dataset.rst │ └── method.rst │ ├── conf.py │ ├── contribution │ └── instruction.rst │ ├── fairgraph │ ├── 
dataset.rst │ └── method.rst │ ├── ggraph │ ├── dataset.rst │ ├── evaluation.rst │ ├── method.rst │ └── utils.rst │ ├── ggraph3d │ ├── dataset.rst │ ├── evaluation.rst │ ├── method.rst │ └── utils.rst │ ├── index.rst │ ├── intro │ ├── installation.rst │ └── introduction.rst │ ├── oodgraph │ └── good.rst │ ├── sslgraph │ ├── dataset.rst │ ├── evaluation.rst │ ├── method.rst │ └── utils.rst │ ├── tutorials │ ├── fairgraph.rst │ ├── graphdf.rst │ ├── imgs │ │ ├── subgraphx_explanation.png │ │ └── subgraphx_ori_graph.png │ ├── oodgraph.rst │ ├── sslgraph.rst │ ├── subgraphx.rst │ └── threedgraph.rst │ └── xgraph │ ├── dataset.rst │ ├── evaluation.rst │ ├── method.rst │ └── utils.rst ├── examples ├── README.md ├── auggraph │ ├── GraphAug │ │ ├── conf │ │ │ ├── aug_cls_conf.py │ │ │ ├── generator_conf.py │ │ │ └── reward_gen_conf.py │ │ ├── run_aug_cls.py │ │ ├── run_generator.py │ │ └── run_reward_gen.py │ └── SMixup │ │ └── run.py ├── fairgraph │ └── Graphair │ │ ├── run_graphair_nba.py │ │ └── run_graphair_pokec.py ├── ggraph │ ├── GraphAF │ │ ├── README.md │ │ ├── config │ │ │ ├── const_prop_opt_graphaf_config_dict.json │ │ │ ├── prop_opt_plogp_config_dict.json │ │ │ ├── prop_opt_qed_config_dict.json │ │ │ ├── rand_gen_qm9_config_dict.json │ │ │ └── rand_gen_zinc250k_config_dict.json │ │ ├── figs │ │ │ └── graphaf.png │ │ ├── run_const_prop_opt.py │ │ ├── run_prop_opt.py │ │ └── run_rand_gen.py │ ├── GraphDF │ │ ├── README.md │ │ ├── config │ │ │ ├── const_prop_opt_graphaf_config_dict.json │ │ │ ├── const_prop_opt_jt_config_dict.json │ │ │ ├── prop_opt_plogp_config_dict.json │ │ │ ├── prop_opt_qed_config_dict.json │ │ │ ├── rand_gen_moses_config_dict.json │ │ │ ├── rand_gen_qm9_config_dict.json │ │ │ └── rand_gen_zinc250k_config_dict.json │ │ ├── figs │ │ │ ├── .DS_Store │ │ │ └── graphdf.png │ │ ├── run_const_prop_opt.py │ │ ├── run_prop_opt.py │ │ └── run_rand_gen.py │ ├── GraphEBM │ │ ├── README.md │ │ ├── compositional_gen.ipynb │ │ ├── figs │ │ │ └── 
graphebm_training.png │ │ ├── goal-directed_gen.ipynb │ │ └── randn_gen.ipynb │ └── JTVAE │ │ ├── cons_optim.ipynb │ │ └── rand_gen.ipynb ├── ggraph3D │ └── G_SphereNet │ │ ├── README.md │ │ ├── config_dict.json │ │ ├── figs │ │ └── gspherenet.png │ │ ├── run_prop_opt.py │ │ ├── run_rand_gen.py │ │ └── target_bond_lengths.dict ├── lsgraph │ ├── GraphFMIB │ │ └── reddit_example.py │ └── GraphFMOB │ │ ├── GraphFMOB.py │ │ └── conf │ │ ├── config.yaml │ │ ├── dataset │ │ ├── amazon.yaml │ │ ├── arxiv.yaml │ │ ├── flickr.yaml │ │ ├── ppi.yaml │ │ ├── products.yaml │ │ ├── reddit.yaml │ │ └── yelp.yaml │ │ └── model │ │ ├── gcn.yaml │ │ ├── gcn2.yaml │ │ ├── pna.yaml │ │ └── pna_jk.yaml ├── oodgraph │ └── good_datasets.ipynb ├── sslgraph │ ├── example_gca.ipynb │ ├── example_grace.ipynb │ ├── example_graphcl.ipynb │ ├── example_graphcl_grid_search.ipynb │ ├── example_infograph.ipynb │ └── example_mvgrl.ipynb ├── threedgraph │ ├── run_ProNet.py │ ├── threedgraph.ipynb │ └── xyz_to_dat.ipynb └── xgraph │ ├── deeplift.ipynb │ ├── flowx.ipynb │ ├── gnn_lrp.ipynb │ ├── gnnexplainer.ipynb │ ├── gradcam.ipynb │ ├── pgexplainer.ipynb │ └── subgraphx.ipynb ├── imgs ├── DIG-logo.jpg └── DIG-overview.png ├── script └── conda.sh ├── setup.cfg ├── setup.py ├── test ├── ggraph │ ├── dataset │ │ ├── test_QM9.py │ │ ├── test_ZINC250k.py │ │ └── test_ZINC800.py │ ├── evaluation │ │ ├── test_ConstPropOptEvaluator.py │ │ ├── test_PropOptEvaluator.py │ │ └── test_RandGenEvaluator.py │ └── utils │ │ ├── test_environment.py │ │ └── test_gen_mol_from_one_shot_tensor.py ├── oodgraph │ └── test_good_datasets.py ├── sslgraph │ ├── dataset │ │ ├── test_TUDatasetExt.py │ │ ├── test_get_dataset.py │ │ └── test_get_node_dataset.py │ └── evaluation │ │ ├── test_GraphSemisupervised.py │ │ ├── test_GraphUnsupervised.py │ │ ├── test_NodeUnsupervised.py │ │ └── test_nce_more_view.py ├── threedgraph │ ├── dataset │ │ ├── test_MD17.py │ │ └── test_QM93D.py │ └── evaluation │ │ └── test_ThreeDEvaluator.py 
└── xgraph │ ├── dataset │ ├── test_BA_LRP.py │ ├── test_MarginalDataset.py │ ├── test_MoleculeDataset.py │ └── test_SynGraphDataset.py │ └── evaluation │ └── test_metrics.py └── tutorials └── KDD2022 ├── 3dgraph_code_tutorial.ipynb ├── DIG-Tutorial-KDD22.pdf ├── README.md ├── ggraph_code_tutorial.ipynb ├── sslgraph_code_tutorial.ipynb └── xgraph_code_tutorial.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | .gitignore 2 | .idea/ 3 | .vscode/ 4 | benchmarks/xgraph/checkpoints/ 5 | benchmarks/xgraph/*.zip 6 | benchmarks/xgraph/*.sh 7 | benchmarks/xgraph/results/ 8 | jsons/ 9 | jsons.zip 10 | benchmarks/xgraph/subgraphx_timer.py 11 | experiment_curve.py 12 | /dig/xgraph/dataset/ba_lrp/ 13 | /docs/build/ 14 | /fix_bug.py 15 | .eggs/ 16 | **/__pycache__/ 17 | build/ 18 | **/*.egg-info/ 19 | **/*.egg 20 | dig/auggraph/dataset/tudatasets/* 21 | examples/auggraph/GraphAug/results -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # Read the Docs configuration file 2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 3 | 4 | # Required 5 | version: 2 6 | 7 | # Build documentation in the docs/ directory with Sphinx 8 | sphinx: 9 | configuration: docs/source/conf.py 10 | 11 | 12 | # Optionally set the version of Python and requirements required to build your docs 13 | #python: 14 | # version: 3.8 15 | # system_packages: true 16 | # install: 17 | # - requirements: docs/environment.txt 18 | # - method: setuptools 19 | # path: . 
20 | 21 | conda: 22 | environment: docs/environment.yaml 23 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: shell 2 | 3 | os: 4 | - linux 5 | 6 | vm: 7 | size: x-large 8 | 9 | branches: 10 | only: 11 | - dig 12 | 13 | env: 14 | jobs: 15 | - TORCH_VERSION=1.8.0 PYTHON_VERSION=3.8 IDX=cpu 16 | 17 | install: 18 | - source script/conda.sh 19 | - conda create --yes -n test python="${PYTHON_VERSION}" 20 | - source activate test 21 | - conda install pytorch==${TORCH_VERSION} ${TOOLKIT} -c pytorch -c conda-forge --yes 22 | - conda install -c conda-forge rdkit --yes 23 | - pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-${TORCH_VERSION}+${IDX}.html 24 | - pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-${TORCH_VERSION}+${IDX}.html 25 | - pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-${TORCH_VERSION}+${IDX}.html 26 | - pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-${TORCH_VERSION}+${IDX}.html 27 | - pip install furo 28 | - pip install numpy 29 | - pip install sphinx==3.5.4 30 | - pip install sphinx_rtd_theme==0.5.2 31 | - pip install torch-geometric==1.7.0 32 | - pip install git+https://github.com/Chilipp/autodocsumm.git 33 | - pip install captum==0.2.0 34 | - pip install cilog 35 | - pip install typed-argument-parser==1.5.4 36 | - pip install tensorboard 37 | - pip install codecov 38 | - python setup.py install 39 | 40 | script: 41 | - travis_wait 60 python setup.py test 42 | 43 | after_success: 44 | - codecov 45 | 46 | notifications: 47 | email: false 48 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | include LICENSE 3 | include dig/ggraph/dataset/config.csv 4 | include 
# NOTE: this file is a flattened multi-file dump. The dump lines replaced by this
# block also carried text belonging to neighbouring files; that text is kept
# verbatim in the comment lines marked "[dump]" so that nothing is lost.
# [dump] dig/ggraph/utils/fpscores.pkl.gz 5 | 6 | recursive-exclude test * 7 | recursive-exclude docs * 8 | recursive-exclude benchmarks *
# [dump] -------------------------------------------------------------------------------- /benchmarks/xgraph/RandomSelection.py: --------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torch_geometric.utils.loop import add_self_loops
from dig.xgraph.models.utils import subgraph
from dig.xgraph.method.base_explainer import ExplainerBase


class RandomSelectorExplainer(ExplainerBase):
    """Baseline explainer that scores every edge with a uniformly random value.

    The random scores are drawn once per call with ``torch.rand`` over the
    self-loop-augmented edge list and reused for every class label, so the
    result is a reference point that carries no information about the model.
    """

    def __init__(self, model: nn.Module, explain_graph: bool = False):
        super().__init__(model=model, explain_graph=explain_graph)

    def forward(self, x, edge_index, **kwargs):
        """Produce random edge masks and evaluate the model under them.

        Args:
            x: Node feature tensor; ``x.shape[0]`` is used as the node count.
            edge_index: Edge list tensor; self-loops are added before masking.
            **kwargs: Recognised keys:
                ``num_classes`` (int, required) - one mask is returned per class;
                ``sparsity`` - forwarded to the inherited ``control_sparsity``;
                ``node_idx`` (tensor) - required when ``explain_graph`` is False;
                a 0-dim tensor is reshaped to 1-D.
                In node mode the full ``kwargs`` are also forwarded to the
                inherited ``eval_related_pred``.

        Returns:
            ``(edge_masks, hard_edge_masks, related_preds)``: the raw random
            masks (the same tensor repeated per class), the sparsified masks,
            and whatever ``eval_related_pred`` returns.

        Raises:
            ValueError: if ``num_classes`` is missing, or if ``node_idx`` is
                missing in node-explanation mode.
        """
        super().forward(x, edge_index)
        self.model.eval()

        num_classes = kwargs.get('num_classes')
        if num_classes is None:
            # Previously surfaced as an opaque TypeError from range(None).
            raise ValueError("forward() requires the 'num_classes' keyword argument")
        sparsity = kwargs.get('sparsity')

        self_loop_edge_index, _ = add_self_loops(edge_index, num_nodes=x.shape[0])

        if self.explain_graph:
            # Graph mode evaluates without forwarding the caller's kwargs.
            pred_kwargs = {}
        else:
            node_idx = kwargs.get('node_idx')
            # Validate BEFORE dereferencing: the original asserted only after
            # calling node_idx.dim(), so the check could never fire.
            if node_idx is None:
                raise ValueError(
                    "forward() requires the 'node_idx' keyword argument "
                    "when explain_graph is False")
            if not node_idx.dim():
                node_idx = node_idx.reshape(-1)
            node_idx = node_idx.to(self.device)

            # Stored on the instance as in the original implementation;
            # presumably read by inherited helpers - confirm in ExplainerBase.
            _, _, _, self.hard_edge_mask = subgraph(
                node_idx, self.__num_hops__, self_loop_edge_index,
                relabel_nodes=True, num_nodes=None, flow=self.__flow__())
            pred_kwargs = kwargs

        # One random draw, shared by every class label.
        edge_mask = torch.rand(self_loop_edge_index.shape[1])
        edge_masks = [edge_mask for _ in range(num_classes)]

        self.__clear_masks__()
        self.__set_masks__(x, self_loop_edge_index)
        try:
            hard_edge_masks = [
                self.control_sparsity(edge_mask, sparsity=sparsity).sigmoid().to(self.device)
                for _ in range(num_classes)
            ]
            with torch.no_grad():
                related_preds = self.eval_related_pred(
                    x, edge_index, hard_edge_masks, **pred_kwargs)
        finally:
            # Always remove the masks from the model, even if evaluation fails.
            self.__clear_masks__()

        return edge_masks, hard_edge_masks, related_preds


# [dump] -------------------------------------------------------------------------------- /benchmarks/xgraph/config/config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - models: gcn 3 | - datasets: bbbp 4 | - explainers: subgraphx 5 | random_seed: 0 6 | device_id: 0 7 | record_filename: none
# [dump] -------------------------------------------------------------------------------- /benchmarks/xgraph/config/datasets/ba_2motifs.yaml: -------------------------------------------------------------------------------- 1 | dataset_root: '/tmp/datasets' 2 | dataset_name: 'ba_2motifs' 3 | random_split_flag: True 4 | data_split_ratio: [0.8, 0.1, 0.1] 5 | seed: 2 6 |
# [dump] -------------------------------------------------------------------------------- /benchmarks/xgraph/config/datasets/ba_community.yaml: -------------------------------------------------------------------------------- 1 | dataset_root: '/tmp/datasets' 2 | dataset_name: 'ba_community'
# [dump] -------------------------------------------------------------------------------- /benchmarks/xgraph/config/datasets/ba_lrp.yaml: -------------------------------------------------------------------------------- 1 | dataset_root: '/tmp/datasets' 2 | random_split_flag: True 3 | data_split_ratio: [0.8, 0.1, 0.1] 4 | dataset_name: 'ba_lrp' 5 | seed: 2 6 |
-------------------------------------------------------------------------------- /benchmarks/xgraph/config/datasets/ba_shapes.yaml: -------------------------------------------------------------------------------- 1 | dataset_root: '/tmp/datasets' 2 | dataset_name: 'ba_shapes' -------------------------------------------------------------------------------- /benchmarks/xgraph/config/datasets/bace.yaml: -------------------------------------------------------------------------------- 1 | dataset_root: '/tmp/datasets' 2 | random_split_flag: True 3 | data_split_ratio: [0.8, 0.1, 0.1] 4 | dataset_name: 'bace' 5 | seed: 2 6 | -------------------------------------------------------------------------------- /benchmarks/xgraph/config/datasets/bbbp.yaml: -------------------------------------------------------------------------------- 1 | dataset_root: '/tmp/datasets' 2 | dataset_name: 'bbbp' 3 | random_split_flag: True 4 | data_split_ratio: [0.8, 0.1, 0.1] 5 | seed: 2 6 | -------------------------------------------------------------------------------- /benchmarks/xgraph/config/datasets/clintox.yaml: -------------------------------------------------------------------------------- 1 | dataset_root: '/tmp/datasets' 2 | random_split_flag: True 3 | data_split_ratio: [0.8, 0.1, 0.1] 4 | dataset_name: 'clintox' 5 | seed: 2 6 | -------------------------------------------------------------------------------- /benchmarks/xgraph/config/datasets/graph_sst2.yaml: -------------------------------------------------------------------------------- 1 | dataset_root: '/tmp/datasets' 2 | dataset_name: 'graph_sst2' 3 | random_split_flag: False 4 | data_split_ratio: None 5 | seed: 2 6 | -------------------------------------------------------------------------------- /benchmarks/xgraph/config/datasets/graph_sst5.yaml: -------------------------------------------------------------------------------- 1 | dataset_root: '/tmp/datasets' 2 | dataset_name: 'graph_sst5' 3 | random_split_flag: False 4 | 
data_split_ratio: None 5 | seed: 2 6 | -------------------------------------------------------------------------------- /benchmarks/xgraph/config/datasets/tox21.yaml: -------------------------------------------------------------------------------- 1 | dataset_root: '/tmp/datasets' 2 | random_split_flag: True 3 | data_split_ratio: [0.8, 0.1, 0.1] 4 | dataset_name: 'tox21' 5 | seed: 2 6 | -------------------------------------------------------------------------------- /benchmarks/xgraph/config/datasets/tree_cycle.yaml: -------------------------------------------------------------------------------- 1 | dataset_root: '/tmp/datasets' 2 | dataset_name: 'tree_cycle' 3 | -------------------------------------------------------------------------------- /benchmarks/xgraph/config/datasets/tree_grid.yaml: -------------------------------------------------------------------------------- 1 | dataset_root: '/tmp/datasets' 2 | dataset_name: 'tree_grid' -------------------------------------------------------------------------------- /benchmarks/xgraph/config/datasets/twitter.yaml: -------------------------------------------------------------------------------- 1 | dataset_root: '/tmp/datasets' 2 | dataset_name: 'twitter' 3 | random_split_flag: False 4 | data_split_ratio: None 5 | seed: 2 6 | -------------------------------------------------------------------------------- /benchmarks/xgraph/config/explainers/deep_lift.yaml: -------------------------------------------------------------------------------- 1 | explanation_result_dir: '' 2 | sparsity: 0.5 3 | 4 | param: 5 | bbbp: 6 | none: None 7 | graph_sst2: 8 | none: None 9 | graph_sst5: 10 | none: None 11 | twitter: 12 | none: None 13 | ba_shapes: 14 | none: None 15 | ba_2motifs: 16 | none: None -------------------------------------------------------------------------------- /benchmarks/xgraph/config/explainers/gnn_explainer.yaml: -------------------------------------------------------------------------------- 1 | 
explanation_result_dir: '' 2 | sparsity: 0.5 3 | param: 4 | bbbp: 5 | lr: 0.01 6 | epochs: 1000 7 | coff_size: 0.001 8 | coff_ent: 1e-5 9 | ba_shapes: 10 | lr: 0.01 11 | epochs: 100 12 | coff_size: 0.1 13 | coff_ent: 0.1 14 | graph_sst2: 15 | lr: 0.01 16 | epochs: 1000 17 | coff_size: 0.001 18 | coff_ent: 1e-5 19 | graph_sst5: 20 | lr: 0.01 21 | epochs: 1000 22 | coff_size: 0.001 23 | coff_ent: 1e-5 24 | twitter: 25 | lr: 0.01 26 | epochs: 1000 27 | coff_size: 0.001 28 | coff_ent: 1e-5 29 | ba_2motifs: 30 | lr: 0.02 31 | epochs: 100 32 | coff_size: 0.001 33 | coff_ent: 1e-5 -------------------------------------------------------------------------------- /benchmarks/xgraph/config/explainers/gnn_gi.yaml: -------------------------------------------------------------------------------- 1 | explanation_result_dir: '' 2 | sparsity: 0.5 3 | param: 4 | bbbp: 5 | none: None 6 | graph_sst2: 7 | none: None 8 | graph_sst5: 9 | none: None 10 | twitter: 11 | none: None 12 | ba_shapes: 13 | none: None 14 | ba_2motifs: 15 | none: None 16 | -------------------------------------------------------------------------------- /benchmarks/xgraph/config/explainers/gnn_lrp.yaml: -------------------------------------------------------------------------------- 1 | explanation_result_dir: '' 2 | sparsity: 0.5 3 | param: 4 | bbbp: 5 | none: None 6 | graph_sst2: 7 | none: None 8 | graph_sst5: 9 | none: None 10 | twitter: 11 | none: None 12 | ba_shapes: 13 | none: None 14 | ba_2motifs: 15 | none: None 16 | -------------------------------------------------------------------------------- /benchmarks/xgraph/config/explainers/grad_cam.yaml: -------------------------------------------------------------------------------- 1 | explanation_result_dir: '' 2 | sparsity: 0.5 3 | 4 | param: 5 | bbbp: 6 | none: None 7 | graph_sst2: 8 | none: None 9 | graph_sst5: 10 | none: None 11 | twitter: 12 | none: None 13 | ba_shapes: 14 | none: None 15 | ba_2motifs: 16 | none: None 17 | 
-------------------------------------------------------------------------------- /benchmarks/xgraph/config/explainers/pgexplainer.yaml: -------------------------------------------------------------------------------- 1 | explainer_saving_dir: '' 2 | explainer_saving_name: 'pgexplainer.pth' 3 | explanation_result_dir: '' 4 | sparsity: 0.5 5 | 6 | param: 7 | bbbp: 8 | ex_learning_rate: 0.05 9 | ex_epochs: 20 10 | coff_size: 0.01 11 | coff_ent: 0.01 12 | t0: 5.0 13 | t1: 1.0 14 | undirected: True 15 | sample_bias: 0 16 | ba_shapes: 17 | ex_learning_rate: 0.001 18 | ex_epochs: 10 19 | coff_size: 0.01 20 | coff_ent: 0.01 21 | t0: 5.0 22 | t1: 0.1 23 | undirected: True 24 | sample_bias: 0 25 | graph_sst2: 26 | ex_learning_rate: 3e-3 27 | ex_epochs: 20 28 | coff_size: 0.01 29 | coff_ent: 5e-4 30 | t0: 5.0 31 | t1: 1.0 32 | undirected: True 33 | sample_bias: 0 34 | graph_sst5: 35 | ex_learning_rate: 3e-3 36 | ex_epochs: 20 37 | coff_size: 0.01 38 | coff_ent: 5e-4 39 | t0: 5.0 40 | t1: 1.0 41 | undirected: True 42 | sample_bias: 0 43 | twitter: 44 | ex_learning_rate: 0.005 45 | ex_epochs: 20 46 | coff_size: 0.01 47 | coff_ent: 5e-4 48 | t0: 5.0 49 | t1: 1.0 50 | undirected: True 51 | sample_bias: 0 52 | ba_2motifs: 53 | ex_learning_rate: 3e-3 54 | ex_epochs: 20 55 | coff_size: 0.03 56 | coff_ent: 5e-4 57 | t0: 5.0 58 | t1: 1.0 59 | undirected: True 60 | sample_bias: 0 61 | -------------------------------------------------------------------------------- /benchmarks/xgraph/config/explainers/random_explainer.yaml: -------------------------------------------------------------------------------- 1 | explanation_result_dir: '' 2 | sparsity: 0.5 3 | param: 4 | bbbp: 5 | none: None 6 | graph_sst2: 7 | none: None 8 | graph_sst5: 9 | none: None 10 | twitter: 11 | none: None 12 | ba_shapes: 13 | none: None 14 | ba_2motifs: 15 | none: None 16 | -------------------------------------------------------------------------------- /benchmarks/xgraph/config/explainers/subgraphx.yaml: 
-------------------------------------------------------------------------------- 1 | explanation_result_dir: '' 2 | max_ex_size: 5 3 | 4 | param: 5 | bbbp: 6 | rollout: 20 7 | high2low: False 8 | c_puct: 10.0 9 | min_atoms: 5 10 | expand_atoms: 12 11 | reward_method: 'mc_l_shapley' 12 | subgraph_building_method: 'split' 13 | verbose: True 14 | ba_shapes: 15 | rollout: 20 16 | high2low: True 17 | c_puct: 10.0 18 | min_atoms: 5 19 | expand_atoms: 20 20 | reward_method: 'Nc_mc_l_shapley' 21 | subgraph_building_method: 'split' 22 | verbose: True 23 | graph_sst2: 24 | rollout: 20 25 | high2low: False 26 | c_puct: 10.0 27 | min_atoms: 5 28 | expand_atoms: 12 29 | reward_method: 'mc_l_shapley' 30 | subgraph_building_method: 'split' 31 | verbose: True 32 | graph_sst5: 33 | rollout: 20 34 | high2low: False 35 | c_puct: 10.0 36 | min_atoms: 5 37 | expand_atoms: 12 38 | reward_method: 'mc_l_shapley' 39 | subgraph_building_method: 'split' 40 | verbose: True 41 | twitter: 42 | rollout: 20 43 | high2low: False 44 | c_puct: 5.0 45 | min_atoms: 5 46 | expand_atoms: 20 47 | reward_method: 'mc_l_shapley' 48 | subgraph_building_method: 'split' 49 | verbose: True 50 | ba_2motifs: 51 | rollout: 20 52 | high2low: True 53 | c_puct: 10.0 54 | min_atoms: 5 55 | expand_atoms: 20 56 | reward_method: 'mc_l_shapley' 57 | subgraph_building_method: 'split' 58 | verbose: True -------------------------------------------------------------------------------- /benchmarks/xgraph/imgs/fidelity.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divelab/DIG/21476b079c9226f38915dcd082b5c2ee0cddaac8/benchmarks/xgraph/imgs/fidelity.png -------------------------------------------------------------------------------- /benchmarks/xgraph/imgs/fidelity_inv.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/divelab/DIG/21476b079c9226f38915dcd082b5c2ee0cddaac8/benchmarks/xgraph/imgs/fidelity_inv.png -------------------------------------------------------------------------------- /benchmarks/xgraph/imgs/xgraph.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divelab/DIG/21476b079c9226f38915dcd082b5c2ee0cddaac8/benchmarks/xgraph/imgs/xgraph.jpg -------------------------------------------------------------------------------- /dig/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divelab/DIG/21476b079c9226f38915dcd082b5c2ee0cddaac8/dig/__init__.py -------------------------------------------------------------------------------- /dig/auggraph/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divelab/DIG/21476b079c9226f38915dcd082b5c2ee0cddaac8/dig/auggraph/__init__.py -------------------------------------------------------------------------------- /dig/auggraph/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .aug_dataset import DegreeTrans, AUG_trans, Subset, TripleSet 2 | 3 | __all__ = [ 4 | 'DegreeTrans', 5 | 'AUG_trans', 6 | 'Subset', 7 | 'TripleSet' 8 | ] 9 | -------------------------------------------------------------------------------- /dig/auggraph/method/GraphAug/__init__.py: -------------------------------------------------------------------------------- 1 | from .runner_reward_gen import RunnerRewardGen 2 | from .runner_generator import RunnerGenerator 3 | from .runner_aug_cls import RunnerAugCls 4 | 5 | __all__ = [ 6 | 'RunnerRewardGen', 7 | 'RunnerGenerator', 8 | 'RunnerAugCls' 9 | ] 10 | -------------------------------------------------------------------------------- /dig/auggraph/method/GraphAug/aug/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .node_fm import NodeFM 2 | from .node_drop import NodeDrop 3 | from .edge_per import EdgePer 4 | from .augmenter import Augmenter -------------------------------------------------------------------------------- /dig/auggraph/method/GraphAug/constants/__init__.py: -------------------------------------------------------------------------------- 1 | from .conf_params import * 2 | from .enums import * 3 | -------------------------------------------------------------------------------- /dig/auggraph/method/GraphAug/constants/conf_params.py: -------------------------------------------------------------------------------- 1 | # Author: Youzhi Luo (yzluo@tamu.edu) 2 | # Updated by: Anmol Anand (aanand@tamu.edu) 3 | 4 | BATCH_SIZE = 'batch_size' 5 | INITIAL_LR = 'start_lr' 6 | GENERATOR_STEPS = 'g_steps' 7 | TEST_INTERVAL = 'test_interval' 8 | LOSS_INTERVAL = 'loss_interval' 9 | FACTOR = 'factor' 10 | PATIENCE = 'patience' 11 | MIN_LR = 'min_lr' 12 | MAX_NUM_EPOCHS = 'max_num_epochs' 13 | REWARD_GEN_STATE_PATH = 'dis_model_path' 14 | BASELINE = 'baseline' 15 | MOVING_RATIO = 'moving_ratio' 16 | SAVE_MODEL = 'save_model' 17 | GENERATOR_PARAMS = 'gen_param' 18 | 19 | NUM_LAYERS = 'num_layers' 20 | HID_DIM = 'hid_dim' 21 | MAX_NUM_AUG = 'max_num_aug' 22 | USE_STOP_AUG = 'use_stop_aug' 23 | UNIFORM = 'uniform' 24 | RNN_INPUT = 'rnn_input' 25 | AUG_TYPE_PARAMS = 'aug_type_param_dict' 26 | 27 | TEMPERATURE = 'temperature' 28 | TRAINING = 'training' 29 | MAGNITUDE = 'magnitude' 30 | NODE_FEAT_DIM = 'node_feat_dim' 31 | 32 | HIDDEN_UNITS = 'hidden' 33 | MODEL_TYPE = 'model_type' 34 | POOL_TYPE = 'pool_type' 35 | FUSE_TYPE = 'fuse_type' 36 | PRE_TRAIN_PATH = 'pre_train_path' 37 | IN_DIMENSION = 'in_dim' 38 | EDGE_IN_DIMENSION = 'edge_in_dim' 39 | REWARD_GEN_PARAMS = 'dis_param' 40 | 41 | MODEL_NAME = 'model_name' 42 | DROPOUT = 'dropout' 43 | AUG_MODEL_PATH = 'aug_model_path' 44 | NUM_CLASSES = 
class AugType(Enum):
    """Identifiers of the augmentation types; each value is the string key."""
    NODE_FM = 'node_fm'
    NODE_DROP = 'node_drop'
    # Historical mixed-case spelling; kept as the canonical member so that
    # existing references, `.name` and iteration order do not change.
    EDGE_Per = 'edge_per'
    # Upper-case spelling consistent with the other members. It repeats the
    # value above, so Enum treats it as an alias:
    # `AugType.EDGE_PER is AugType.EDGE_Per`, and it is skipped by iteration.
    EDGE_PER = 'edge_per'
class RewardGenModel(torch.nn.Module):
    """Pairwise graph scorer: encode two graphs, fuse the embeddings, emit a score.

    The encoder (GMNet or GENet, chosen by ``model_type``) returns one
    embedding per input. How the two embeddings are combined is chosen by
    ``fuse_type``:

    * ``ADD`` / ``MULTIPLY`` / ``ABS_DIFF``: element-wise fusion fed to an MLP
      head that ends in a sigmoid.
    * ``CONCAT``: concatenation along dim 1, fed to an MLP head (input width
      ``2 * hidden``) that ends in a sigmoid.
    * ``COSINE``: each embedding is projected by the head separately and the
      result is ``(1 + cosine_similarity) / 2``.

    Args:
        in_dim: Forwarded to the encoder as its first positional argument.
        num_layers: Forwarded to the encoder.
        hidden: Forwarded to the encoder; also sets the widths of the head.
        pool_type: Forwarded to the encoder as ``pool_type``.
        model_type: ``RewardGenModelType`` member selecting the encoder.
        fuse_type: ``FuseType`` member selecting the fusion rule.
        **kwargs: Extra keyword arguments forwarded to the encoder.

    Raises:
        ValueError: If ``model_type`` is not a supported encoder type.
    """

    def __init__(self, in_dim, num_layers, hidden, pool_type=PoolType.SUM, model_type=RewardGenModelType.GMNET, fuse_type=FuseType.ABS_DIFF, **kwargs):
        super(RewardGenModel, self).__init__()
        if model_type == RewardGenModelType.GMNET:
            self.reward_gen_encoder = GMNet(in_dim, num_layers, hidden, pool_type=pool_type, **kwargs)
        elif model_type == RewardGenModelType.GENET:
            self.reward_gen_encoder = GENet(in_dim, num_layers, hidden, pool_type=pool_type, **kwargs)
        else:
            # Previously the encoder attribute was silently left undefined and
            # the failure only surfaced as an AttributeError in forward().
            raise ValueError('Unknown model_type {}'.format(model_type))

        self.fuse_type = fuse_type
        if fuse_type == FuseType.CONCAT:
            self.pred_head = nn.Sequential(
                nn.Linear(2 * hidden, 2 * hidden),
                nn.ReLU(),
                nn.Linear(2 * hidden, 1),
                nn.Sigmoid(),
            )
        elif fuse_type == FuseType.COSINE:
            # Projection head only; the score itself comes from the cosine
            # similarity computed in forward().
            self.pred_head = nn.Sequential(
                nn.Linear(hidden, 2 * hidden),
                nn.ReLU(),
                nn.Linear(2 * hidden, hidden)
            )
            self.cos = torch.nn.CosineSimilarity(dim=1)
        else:
            in_hidden = hidden
            self.pred_head = nn.Sequential(
                nn.Linear(in_hidden, 2 * hidden),
                nn.ReLU(),
                nn.Linear(2 * hidden, 1),
                nn.Sigmoid(),
            )

    def forward(self, data1, data2):
        """Score the pair ``(data1, data2)``.

        Returns:
            The output of the sigmoid head for the fused embedding, or
            ``(1 + cos) / 2`` of the projected embeddings for ``COSINE``.

        Raises:
            ValueError: If ``self.fuse_type`` is not a supported fusion type.
        """
        embed1, embed2 = self.reward_gen_encoder(data1, data2)

        if self.fuse_type == FuseType.COSINE:
            embed1, embed2 = self.pred_head(embed1), self.pred_head(embed2)
            prob = (1.0 + self.cos(embed1, embed2)) / 2.0
            return prob

        if self.fuse_type == FuseType.ADD:
            pair_embed = embed1 + embed2
        elif self.fuse_type == FuseType.MULTIPLY:
            pair_embed = embed1 * embed2
        elif self.fuse_type == FuseType.CONCAT:
            pair_embed = torch.cat((embed1, embed2), dim=1)
        elif self.fuse_type == FuseType.ABS_DIFF:
            pair_embed = torch.abs(embed1 - embed2)
        else:
            # Without this branch an unsupported fuse type reached the head
            # with `pair_embed` unbound and raised UnboundLocalError.
            raise ValueError('Unknown fuse_type {}'.format(self.fuse_type))

        prob = self.pred_head(pair_embed)
        return prob
def triplet_loss(x_1, y, x_2, z, loss_type='margin', margin=1.0):
    """Triplet loss over the pairs (x_1, y) and (x_2, z).

    The anchor x may have a different representation in each pair (as in a
    matching model), hence the two separate arguments x_1 and x_2.

    Args:
      x_1: [N, D] float tensor, anchor representation paired with y.
      y: [N, D] float tensor.
      x_2: [N, D] float tensor, anchor representation paired with z.
      z: [N, D] float tensor.
      loss_type: 'margin' or 'hamming'.
      margin: float scalar, only used by the margin loss.

    Returns:
      loss: [N] float tensor, one value per triplet.

    Raises:
      ValueError: if loss_type is neither 'margin' nor 'hamming'.
    """
    if loss_type == 'margin':
        # relu(margin + d(x_1, y) - d(x_2, z)) with d the squared distance.
        dist_xy = euclidean_distance(x_1, y)
        dist_xz = euclidean_distance(x_2, z)
        return torch.relu(margin + dist_xy - dist_xz)

    if loss_type == 'hamming':
        # Pull sim(x_1, y) towards +1 and sim(x_2, z) towards -1.
        sim_xy = approximate_hamming_similarity(x_1, y)
        sim_xz = approximate_hamming_similarity(x_2, z)
        return 0.125 * ((sim_xy - 1) ** 2 + (sim_xz + 1) ** 2)

    raise ValueError('Unknown loss_type %s' % loss_type)
from .fairgraph_dataset import POKEC, NBA

# Public names of the package. `__all__` must be spelled with two trailing
# underscores and must contain the names as strings; the previous `__all_`
# list of class objects was ignored by `from ... import *`.
__all__ = ['POKEC', 'NBA']
class GCN(nn.Module):
    """GCN_Body encoder followed by a two-layer MLP head.

    Args:
        in_feats, n_hidden, out_feats, dropout, nlayer: Passed to GCN_Body.
        nclass: Output width of the MLP head.
    """

    def __init__(self, in_feats, n_hidden, out_feats, nclass, dropout = 0.2, nlayer = 2):
        super().__init__()
        self.body = GCN_Body(in_feats, n_hidden, out_feats, dropout, nlayer)
        head_layers = [
            nn.Linear(out_feats, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, nclass),
        ]
        self.fc = nn.Sequential(*head_layers)

    def forward(self, g, x):
        """Return ``(head_output, body_output)`` for inputs ``g`` and ``x``."""
        embedding = self.body(g, x)
        prediction = self.fc(embedding)
        return prediction, embedding
def scipysp_to_pytorchsp(sp_mx):
    """Convert a SciPy sparse matrix to a float32 PyTorch sparse COO tensor.

    Args:
        sp_mx: A SciPy sparse matrix; converted to COO format if it is not
            already in that format.

    Returns:
        A sparse COO ``torch.Tensor`` of dtype float32 with the same shape
        and the same stored entries as ``sp_mx``.
    """
    if not sp.isspmatrix_coo(sp_mx):
        sp_mx = sp_mx.tocoo()
    # Row 0 holds the row indices and row 1 the column indices, which is the
    # (2, nnz) layout sparse_coo_tensor expects.
    indices = torch.as_tensor(np.vstack((sp_mx.row, sp_mx.col)), dtype=torch.long)
    values = torch.tensor(sp_mx.data, dtype=torch.float32)
    shape = tuple(int(d) for d in sp_mx.shape)
    # torch.sparse_coo_tensor replaces the legacy torch.sparse.FloatTensor
    # constructor used previously.
    return torch.sparse_coo_tensor(indices, values, size=shape)
(output[idx].squeeze()>0).type_as(labels).cpu().numpy() 33 | parity = abs(sum(pred_y[idx_s0])/sum(idx_s0)-sum(pred_y[idx_s1])/sum(idx_s1)) 34 | equality = abs(sum(pred_y[idx_s0_y1])/sum(idx_s0_y1)-sum(pred_y[idx_s1_y1])/sum(idx_s1_y1)) 35 | 36 | return parity,equality -------------------------------------------------------------------------------- /dig/ggraph/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divelab/DIG/21476b079c9226f38915dcd082b5c2ee0cddaac8/dig/ggraph/__init__.py -------------------------------------------------------------------------------- /dig/ggraph/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .PygDataset import PygDataset 2 | from .ggraph_dataset import QM9, ZINC250k, ZINC800, MOSES 3 | 4 | __all__ = [ 5 | 'QM9', 6 | 'ZINC250k', 7 | 'ZINC800', 8 | 'MOSES', 9 | 'PygDataset' 10 | ] -------------------------------------------------------------------------------- /dig/ggraph/dataset/config.csv: -------------------------------------------------------------------------------- 1 | ,zinc250k,zinc_800_graphaf,zinc_800_jt,zinc250k_property,qm9_property,moses,qm9,0,1,2,3,4,5,6 2 | smile,smiles,smiles,smiles,smile,smile,smiles,SMILES1 3 | prop_list,['qed'],['penalized_logp'],['penalized_logp'],"['qed', 'penalized_logp']","['qed', 'penalized_logp']",[],[] 4 | 
url,https://raw.githubusercontent.com/divelab/DIG_storage/main/ggraph/zinc250k.csv,https://raw.githubusercontent.com/divelab/DIG_storage/main/ggraph/zinc_800_graphaf.csv,https://raw.githubusercontent.com/divelab/DIG_storage/main/ggraph/zinc_800_jt.csv,https://raw.githubusercontent.com/divelab/DIG_storage/main/ggraph/zinc250k_property.csv,https://raw.githubusercontent.com/divelab/DIG_storage/main/ggraph/qm9_property.csv,https://raw.githubusercontent.com/divelab/DIG_storage/main/ggraph/moses.csv,https://raw.githubusercontent.com/divelab/DIG_storage/main/ggraph/qm9.csv 5 | num_max_node,38,38,38,38,9,38,9 6 | atom_list,"[6, 7, 8, 9, 15, 16, 17, 35, 53]","[6, 7, 8, 9, 15, 16, 17, 35, 53]","[6, 7, 8, 9, 15, 16, 17, 35, 53]","[6, 7, 8, 9, 15, 16, 17, 35, 53]","[6, 7, 8, 9]","[6, 7, 8, 9, 15, 16, 17, 35, 53]","[6, 7, 8, 9]" -------------------------------------------------------------------------------- /dig/ggraph/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .metric import RandGenEvaluator, PropOptEvaluator, ConstPropOptEvaluator 2 | 3 | __all__ = [ 4 | 'RandGenEvaluator', 5 | 'PropOptEvaluator', 6 | 'ConstPropOptEvaluator' 7 | ] -------------------------------------------------------------------------------- /dig/ggraph/method/GraphAF/__init__.py: -------------------------------------------------------------------------------- 1 | from .graphaf import GraphAF 2 | 3 | __all__ = [ 4 | 'GraphAF' 5 | ] 6 | -------------------------------------------------------------------------------- /dig/ggraph/method/GraphAF/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .graphflow import * 2 | from .graphflow_con_rl import * 3 | from .graphflow_rl import * 4 | -------------------------------------------------------------------------------- /dig/ggraph/method/GraphAF/train_utils.py: 
class DataIterator(object):
    """Iterator that cycles over a dataloader, restarting it after each pass.

    Args:
        dataloader: Any iterable that can be iterated repeatedly.
    """

    def __init__(self, dataloader):
        self.iterator = self.one_shot_iterator(dataloader)

    def __iter__(self):
        # Lets the object be used wherever an iterator is expected
        # (`iter(obj)`, `for` loops, itertools helpers).
        return self

    def __next__(self):
        data = next(self.iterator)
        return data

    @staticmethod
    def one_shot_iterator(dataloader):
        '''
        Transform a PyTorch Dataloader into python iterator

        The generator restarts the dataloader whenever a pass finishes. It
        stops (so `next()` raises StopIteration) if a pass yields nothing;
        otherwise an empty dataloader would spin forever.
        '''
        while True:
            produced = False
            for data in dataloader:
                produced = True
                yield data
            if not produced:
                return
def rescale_adj(adj, dim='all'):
    """Normalize an adjacency tensor by its entry sums.

    Args:
        adj: Adjacency tensor.
        dim: ``'view'`` divides every entry by the sum over the last
            dimension of its row; ``'all'`` divides by the sum taken over
            dimensions 1 and 2. Zero sums give a scale of 0 instead of inf.

    Returns:
        The rescaled tensor, broadcast-compatible with ``adj``.

    Raises:
        ValueError: If ``dim`` is neither ``'view'`` nor ``'all'``.
    """
    if dim == 'view':
        out_degree = adj.sum(dim=-1)
        # Integer tensors cannot be raised to a negative power; promote them
        # the same way the 'all' branch does. Floating/complex inputs keep
        # their dtype so existing results are unchanged.
        if not (out_degree.is_floating_point() or out_degree.is_complex()):
            out_degree = out_degree.float()
        out_degree_inv = out_degree.pow(-1)
        out_degree_inv[out_degree_inv == float('inf')] = 0
        return out_degree_inv.unsqueeze(-1) * adj

    if dim == 'all':
        num_neighbors = adj.sum(dim=(1, 2)).float()
        num_neighbors_inv = num_neighbors.pow(-1)
        num_neighbors_inv[num_neighbors_inv == float('inf')] = 0
        return num_neighbors_inv.unsqueeze(1).unsqueeze(2) * adj

    # Previously an unsupported value fell through to `return adj_prime`
    # with the name unbound.
    raise ValueError("Unknown dim %r, expected 'view' or 'all'" % (dim,))
def create_var(tensor, requires_grad=None):
    """Wrap ``tensor`` in a Variable and move it to the GPU when one exists.

    Args:
        tensor: The tensor to wrap.
        requires_grad: If given, passed to ``Variable`` as ``requires_grad``;
            if ``None`` the Variable is created without that argument.

    Returns:
        The wrapped tensor, on CUDA if ``torch.cuda.is_available()`` and left
        on its current device otherwise.
    """
    if requires_grad is None:
        var = Variable(tensor)
    else:
        var = Variable(tensor, requires_grad=requires_grad)
    # The unconditional `.cuda()` used before raised on machines without
    # CUDA; with a GPU present the behavior is unchanged.
    if torch.cuda.is_available():
        return var.cuda()
    return var
def GRU(x, h_nei, W_z, W_r, U_r, W_h):
    """GRU-style update of ``x`` from the neighbor states ``h_nei``.

    ``W_z``, ``W_r``, ``U_r`` and ``W_h`` are callables applied to tensors.
    ``h_nei`` is summed over dim 1 both before and after gating.

    Returns:
        ``(1 - z) * sum_h + z * pre_h`` where ``z`` is the update gate,
        ``sum_h`` the summed neighbor state and ``pre_h`` the candidate.
    """
    hidden_size = x.size(-1)
    sum_h = h_nei.sum(dim=1)

    # Update gate from the input and the summed neighbor state.
    z = torch.sigmoid(W_z(torch.cat([x, sum_h], dim=1)))

    # One reset gate per neighbor: the projected input is broadcast over dim 1.
    r = torch.sigmoid(W_r(x).view(-1, 1, hidden_size) + U_r(h_nei))

    # Candidate state from the input and the reset-gated neighbor sum.
    sum_gated_h = (r * h_nei).sum(dim=1)
    pre_h = torch.tanh(W_h(torch.cat([x, sum_gated_h], dim=1)))

    return (1.0 - z) * sum_h + z * pre_h
class Vocab(object):
    """Vocabulary of SMILES strings with index lookup and per-entry atom slots.

    ``benzynes`` and ``penzynes`` are class-level lists. Unless
    ``bayesian_optimization`` is true, constructing a Vocab REPLACES them on
    the class (not on the instance) with lists derived from ``smiles_list``.
    """
    benzynes = ['C1=CC=CC=C1', 'C1=CC=NC=C1', 'C1=CC=NN=C1', 'C1=CN=CC=N1',
                'C1=CN=CN=C1', 'C1=CN=NC=N1', 'C1=CN=NN=C1', 'C1=NC=NC=N1', 'C1=NN=CN=N1']
    penzynes = ['C1=C[NH]C=C1', 'C1=C[NH]C=N1', 'C1=C[NH]N=C1', 'C1=C[NH]N=N1', 'C1=COC=C1', 'C1=COC=N1', 'C1=CON=C1', 'C1=CSC=C1', 'C1=CSC=N1', 'C1=CSN=C1', 'C1=CSN=N1',
                'C1=NN=C[NH]1', 'C1=NN=CO1', 'C1=NN=CS1', 'C1=N[NH]C=N1', 'C1=N[NH]N=C1', 'C1=N[NH]N=N1', 'C1=NN=N[NH]1', 'C1=NN=NS1', 'C1=NOC=N1', 'C1=NON=C1', 'C1=NSC=N1', 'C1=NSN=C1']

    def __init__(self, smiles_list, bayesian_optimization=False):
        self.vocab = smiles_list
        self.vmap = {smiles: idx for idx, smiles in enumerate(smiles_list)}
        # Calls the module-level get_slots function, not the method below.
        self.slots = [get_slots(smiles) for smiles in smiles_list]
        if bayesian_optimization:
            return

        def with_atom_count(num_atoms):
            # Entries containing at least two '=' whose parsed molecule has
            # exactly `num_atoms` atoms.
            return [s for s in smiles_list
                    if s.count('=') >= 2 and Chem.MolFromSmiles(s).GetNumAtoms() == num_atoms]

        Vocab.benzynes = with_atom_count(6) + ['C1=CCNCC1']
        Vocab.penzynes = with_atom_count(5) + ['C1=NCCN1', 'C1=NNCC1']

    def get_index(self, smiles):
        """Return the position of ``smiles`` in the vocabulary (KeyError if absent)."""
        return self.vmap[smiles]

    def get_smiles(self, idx):
        """Return the vocabulary entry stored at position ``idx``."""
        return self.vocab[idx]

    def get_slots(self, idx):
        """Return a deep copy of the slots for entry ``idx``."""
        return copy.deepcopy(self.slots[idx])

    def size(self):
        """Return the number of vocabulary entries."""
        return len(self.vocab)
class Generator():
    r"""
    Base class of the graph generation methods. A concrete method subclasses it and overrides
    the task functions it supports; every function left untouched raises :obj:`NotImplementedError`.
    """

    def train_rand_gen(self, loader, *args, **kwargs):
        r"""
        Train the model for the random generation task.

        Args:
            loader: The data loader for loading training samples.
        """

        raise NotImplementedError("The function train_rand_gen is not implemented!")

    def run_rand_gen(self, *args, **kwargs):
        r"""
        Generate graphs for the random generation task.
        """

        raise NotImplementedError("The function run_rand_gen is not implemented!")

    def train_prop_opt(self, *args, **kwargs):
        r"""
        Train the model for the property optimization task.
        """

        raise NotImplementedError("The function train_prop_opt is not implemented!")

    def run_prop_opt(self, *args, **kwargs):
        r"""
        Generate graphs for the property optimization task.
        """

        raise NotImplementedError("The function run_prop_opt is not implemented!")

    def train_const_prop_opt(self, loader, *args, **kwargs):
        r"""
        Train the model for the constrained optimization task.

        Args:
            loader: The data loader for loading training samples.
        """

        raise NotImplementedError("The function train_const_prop_opt is not implemented!")

    def run_const_prop_opt(self, *args, **kwargs):
        r"""
        Optimize molecules for the constrained optimization task.
        """

        raise NotImplementedError("The function run_const_prop_opt is not implemented!")
class MH_ATT(nn.Module):
    r"""
    Multi-head dot-product attention in which each query attends only over the key/value
    rows that carry a batch id also present in ``query_batch``.

    Args:
        n_att_heads (int): Number of attention heads.
        q_dim (int): Feature size of the query input.
        k_dim (int): Feature size of the key input.
        v_dim (int): Feature size of the value input.
        out_dim (int): Total output feature size; split evenly across the heads, so it
            must be divisible by :obj:`n_att_heads`.

    Raises:
        ValueError: If :obj:`out_dim` is not divisible by :obj:`n_att_heads`.
    """

    def __init__(self, n_att_heads=4, q_dim=128, k_dim=128, v_dim=128, out_dim=128):
        super(MH_ATT, self).__init__()
        # Without this check the `view(-1, n_att_heads, d_k)` calls in `forward` either fail
        # with an opaque shape error or, for some row counts, silently mix rows together.
        if out_dim % n_att_heads != 0:
            raise ValueError(
                'out_dim ({}) must be divisible by n_att_heads ({})'.format(out_dim, n_att_heads))
        self.n_att_heads = n_att_heads
        self.d_k = out_dim // n_att_heads
        self.q_proj = nn.Linear(q_dim, out_dim, bias=True)
        self.k_proj = nn.Linear(k_dim, out_dim, bias=True)
        self.v_proj = nn.Linear(v_dim, out_dim, bias=True)
        self.out_proj = nn.Linear(out_dim, out_dim, bias=True)

    def forward(self, query, key, value, query_batch, key_value_batch):
        r"""
        Compute the attention output for every query row.

        Args:
            query: Query features, one row per query.
            key: Key features, one row per key/value item.
            value: Value features, one row per key/value item.
            query_batch: Batch id of every query row.
            key_value_batch: Batch id of every key/value row.

        Returns:
            Tensor with one row per query and ``d_k * n_att_heads`` columns.
        """
        query_proj = self.q_proj(query).view(-1, self.n_att_heads, self.d_k)
        key_proj = self.k_proj(key).view(-1, self.n_att_heads, self.d_k)
        value_proj = self.v_proj(value).view(-1, self.n_att_heads, self.d_k)

        n_querys = query_proj.shape[0]

        # same_batch[i, j] is True when key/value row i and query j share a batch id.
        # Built once and reused for both reductions below.
        same_batch = key_value_batch[:, None] == query_batch[None, :]

        # Drop key/value rows whose batch id matches no query.
        key_value_mask = same_batch.sum(dim=-1) > 0
        key_proj, value_proj = key_proj[key_value_mask], value_proj[key_value_mask]

        # Repeat every query once per matching key/value row so that queries and keys can be
        # multiplied row by row.
        # NOTE(review): the row-wise pairing presumably relies on each batch id appearing at
        # most once in query_batch and on key_value_batch being grouped in the same order --
        # confirm against the caller.
        query_num_nodes = same_batch.sum(dim=0)
        query_proj = torch.repeat_interleave(query_proj, query_num_nodes, dim=0)

        # A plain Python scalar keeps the result in the dtype/device of the projections and
        # avoids allocating a float64 tensor on every call.
        scaled_dots = torch.sum(query_proj * key_proj, dim=-1) / (self.d_k ** 0.5)
        new_query_batch = torch.repeat_interleave(
            torch.arange(n_querys, device=query_num_nodes.device), query_num_nodes, dim=0)
        # Normalise the scores over the rows that belong to the same query.
        att_scores = softmax(scaled_dots, index=new_query_batch, num_nodes=n_querys)
        att_outs = scatter(value_proj * att_scores[:, :, None], new_query_batch, dim=0,
                           dim_size=n_querys)
        outs = self.out_proj(att_outs.view(n_querys, self.d_k * self.n_att_heads))

        return outs
def compute_prop(atomic_number, position, prop_name):
    """
    Calculate the quantum property score of the given molecular geometry with `PySCF <https://pyscf.org/>`_.

    Args:
        atomic_number (numpy array): the numpy array indicating the atomic number of atoms in the molecular geometry.
        position (numpy array): the numpy array indicating the coordinates of atoms in the molecular geometry.
        prop_name (string): the name of quantum property, 'gap' for HOMO-LUMO gap, 'alpha' for isotropic polarizability.

    Raises:
        ValueError: if :obj:`prop_name` is neither 'gap' nor 'alpha'.

    :rtype:
        :class:`float`
    """
    # Reject unsupported names before doing any work; previously such a call fell through
    # both branches and failed with an UnboundLocalError on the return statement.
    if prop_name not in ('gap', 'alpha'):
        raise ValueError(
            "Unsupported prop_name {!r}; expected 'gap' or 'alpha'.".format(prop_name))

    ptb = Chem.GetPeriodicTable()
    # PySCF geometry format: one [element symbol, coordinates] entry per atom.
    geom = [[ptb.GetElementSymbol(int(z)), position[i]] for i, z in enumerate(atomic_number)]

    if prop_name == 'gap':
        return geom2gap(geom)
    return geom2alpha(geom)
-------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | std::tuple, 6 | torch::Tensor> 7 | relabel_one_hop_cpu(torch::Tensor rowptr, torch::Tensor col, 8 | torch::optional optional_value, 9 | torch::Tensor idx, bool bipartite); 10 | -------------------------------------------------------------------------------- /dig/lsgraph/method/GraphFMOB/csrc/cuda/sync_cuda.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | void synchronize_cuda(); 6 | void read_async_cuda(torch::Tensor src, 7 | torch::optional optional_offset, 8 | torch::optional optional_count, 9 | torch::Tensor index, torch::Tensor dst, 10 | torch::Tensor buffer); 11 | void write_async_cuda(torch::Tensor src, torch::Tensor offset, 12 | torch::Tensor count, torch::Tensor dst); 13 | -------------------------------------------------------------------------------- /dig/lsgraph/method/GraphFMOB/csrc/relabel.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | namespace py = pybind11; 5 | #include "cpu/relabel_cpu.h" 6 | 7 | #ifdef _WIN32 8 | PyMODINIT_FUNC PyInit__relabel(void) { return NULL; } 9 | #endif 10 | 11 | std::tuple, 12 | torch::Tensor> 13 | relabel_one_hop(torch::Tensor rowptr, torch::Tensor col, 14 | torch::optional optional_value, 15 | torch::Tensor idx, bool bipartite) { 16 | if (rowptr.device().is_cuda()) { 17 | #ifdef WITH_CUDA 18 | AT_ERROR("No CUDA version supported"); 19 | #else 20 | AT_ERROR("Not compiled with CUDA support"); 21 | #endif 22 | } 23 | else { 24 | return relabel_one_hop_cpu(rowptr, col, optional_value, idx, bipartite); 25 | } 26 | } 27 | 28 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 29 | m.def("relabel_one_hop", &relabel_one_hop, "relabel_one_hop"); 30 | } -------------------------------------------------------------------------------- 
/dig/lsgraph/method/GraphFMOB/csrc/sync.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #ifdef WITH_CUDA 6 | #include "cuda/sync_cuda.h" 7 | #endif 8 | 9 | #ifdef _WIN32 10 | PyMODINIT_FUNC PyInit__async(void) { return NULL; } 11 | #endif 12 | 13 | void synchronize() { 14 | #ifdef WITH_CUDA 15 | synchronize_cuda(); 16 | #else 17 | AT_ERROR("Not compiled with CUDA support"); 18 | #endif 19 | } 20 | 21 | void read_async(torch::Tensor src, 22 | torch::optional optional_offset, 23 | torch::optional optional_count, 24 | torch::Tensor index, torch::Tensor dst, torch::Tensor buffer) { 25 | #ifdef WITH_CUDA 26 | read_async_cuda(src, optional_offset, optional_count, index, dst, buffer); 27 | #else 28 | AT_ERROR("Not compiled with CUDA support"); 29 | #endif 30 | } 31 | 32 | void write_async(torch::Tensor src, torch::Tensor offset, torch::Tensor count, 33 | torch::Tensor dst) { 34 | #ifdef WITH_CUDA 35 | write_async_cuda(src, offset, count, dst); 36 | #else 37 | AT_ERROR("Not compiled with CUDA support"); 38 | #endif 39 | } 40 | 41 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 42 | m.def("synchronize", &synchronize, "synchronize"); 43 | m.def("read_async", &read_async, "read_async"); 44 | m.def("write_async", &write_async, "write_async"); 45 | } -------------------------------------------------------------------------------- /dig/lsgraph/method/GraphFMOB/csrc/thread.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | // A simple C++11 Thread Pool implementation with `num_workers=1`. 
class History(torch.nn.Module):
    r"""A historical embedding storage module.

    Keeps a ``(num_embeddings, embedding_dim)`` buffer :obj:`emb` on a fixed storage device.
    Moving the module (e.g. via ``.to(...)``) does not move the buffer; it only changes the
    device that :meth:`pull` returns its results on.

    Args:
        num_embeddings (int): Number of rows in the storage.
        embedding_dim (int): Number of columns in the storage.
        device (optional): Device the storage is allocated on. :obj:`None` or ``'cpu'``
            keeps it in host memory.
    """
    def __init__(self, num_embeddings: int, embedding_dim: int, device=None):
        super().__init__()

        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim

        # Host storage is pinned when possible. Pinning needs a CUDA runtime, so fall back
        # to ordinary memory on CPU-only machines instead of failing in the constructor.
        on_cpu = device is None or str(device) == 'cpu'
        pin_memory = on_cpu and torch.cuda.is_available()
        self.emb = torch.empty(num_embeddings, embedding_dim, device=device,
                               pin_memory=pin_memory)

        self._device = torch.device('cpu')

        self.reset_parameters()

    def reset_parameters(self):
        """Zero the whole storage."""
        self.emb.fill_(0)

    def _apply(self, fn):
        # Set the `_device` of the module without transferring `self.emb`:
        # probe `fn` with a dummy tensor to learn the target device.
        self._device = fn(torch.zeros(1)).device
        return self

    @torch.no_grad()
    def pull(self, n_id: Optional[Tensor] = None) -> Tensor:
        """Return the rows selected by :obj:`n_id` (all rows if :obj:`None`), moved to the
        module's current device."""
        out = self.emb
        if n_id is not None:
            assert n_id.device == self.emb.device
            out = out.index_select(0, n_id)
        return out.to(device=self._device)

    @torch.no_grad()
    def push(self, x, n_id: Optional[Tensor] = None,
             offset: Optional[Tensor] = None, count: Optional[Tensor] = None):
        """Write :obj:`x` into the storage.

        Args:
            x: Rows to store.
            n_id: Destination row indices. If :obj:`None`, :obj:`x` must cover the whole
                storage and overwrites it.
            offset, count: If both are given, :obj:`x` is written as consecutive chunks:
                chunk ``k`` takes the next ``count[k]`` rows of :obj:`x` and is stored
                starting at row ``offset[k]``.

        Raises:
            ValueError: If :obj:`n_id` is :obj:`None` and :obj:`x` does not have
                ``num_embeddings`` rows.
        """
        if n_id is None and x.size(0) != self.num_embeddings:
            raise ValueError(
                f'Expected x to have {self.num_embeddings} rows when n_id is None, '
                f'got {x.size(0)}')

        elif n_id is None and x.size(0) == self.num_embeddings:
            self.emb.copy_(x)

        elif offset is None or count is None:
            assert n_id.device == self.emb.device
            self.emb[n_id] = x.to(self.emb.device)

        else:  # Push in chunks:
            src_o = 0
            x = x.to(self.emb.device)
            for dst_o, c in zip(offset.tolist(), count.tolist()):
                self.emb[dst_o:dst_o + c] = x[src_o:src_o + c]
                src_o += c

    def forward(self, *args, **kwargs):
        """"""
        raise NotImplementedError

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.num_embeddings}, '
                f'{self.embedding_dim}, emb_device={self.emb.device}, '
                f'device={self._device})')
def metis(adj_t: SparseTensor, num_parts: int, recursive: bool = False,
          log: bool = True) -> Tuple[Tensor, Tensor]:
    r"""Partitions the sparse adjacency matrix :obj:`adj_t` into :obj:`num_parts`
    clusters and returns the node permutation :obj:`perm` that groups nodes by cluster
    together with the cluster boundaries :obj:`ptr`. With :obj:`num_parts <= 1` no
    partitioning is run and the identity permutation with a single slice is returned."""

    # Only timed (and reported) when logging is requested.
    started = time.perf_counter() if log else None
    if log:
        print(f'Computing METIS partitioning with {num_parts} parts...',
              end=' ', flush=True)

    num_nodes = adj_t.size(0)

    if num_parts > 1:
        rowptr, col, _ = adj_t.csr()
        assignment = partition_fn(rowptr, col, None, num_parts, recursive)
        # Sorting the cluster ids yields the permutation that makes clusters contiguous.
        assignment, perm = assignment.sort()
        ptr = torch.ops.torch_sparse.ind2ptr(assignment, num_parts)
    else:
        perm = torch.arange(num_nodes)
        ptr = torch.tensor([0, num_nodes])

    if log:
        print(f'Done! [{time.perf_counter() - started:.2f}s]')

    return perm, ptr
def compute_micro_f1(logits: Tensor, y: Tensor,
                     mask: Optional[Tensor] = None) -> float:
    r"""Computes the micro-averaged F1 score of :obj:`logits` against :obj:`y`.

    Args:
        logits: Model outputs, one row per sample.
        y: Targets. A 1-D tensor is treated as class indices and scored as the fraction
            of rows whose :obj:`argmax` equals the target; otherwise :obj:`y` is treated
            as multi-label targets (positive where ``y > 0.5``) and compared with the
            predictions ``logits > 0``.
        mask: Optional selector applied to both :obj:`logits` and :obj:`y` first.

    Returns:
        The score as a Python float; :obj:`0.` whenever it is undefined (nothing
        selected, or no positive prediction/target in the multi-label case).
    """
    if mask is not None:
        logits, y = logits[mask], y[mask]

    if y.dim() == 1:
        # An empty selection used to raise ZeroDivisionError here while the multi-label
        # branch returned 0.; report 0. in both cases.
        if y.size(0) == 0:
            return 0.
        return int(logits.argmax(dim=-1).eq(y).sum()) / y.size(0)

    y_pred = logits > 0
    y_true = y > 0.5

    tp = int((y_true & y_pred).sum())
    fp = int((~y_true & y_pred).sum())
    fn = int((y_true & ~y_pred).sum())

    try:
        precision = tp / (tp + fp)
        recall = tp / (tp + fn)
        return 2 * (precision * recall) / (precision + recall)
    except ZeroDivisionError:
        return 0.
def gen_masks(y: Tensor, train_per_class: int = 20, val_per_class: int = 30,
              num_splits: int = 20) -> Tuple[Tensor, Tensor, Tensor]:
    r"""Draws :obj:`num_splits` random train/val/test splits over the labels :obj:`y`.

    For every class and every split, the class members are shuffled; the first
    :obj:`train_per_class` go to training, the next :obj:`val_per_class` to validation,
    and everything that is in neither set is test.

    Returns:
        Three boolean masks of shape ``(y.size(0), num_splits)``: train, val and test.
    """
    train_mask = torch.zeros(y.size(0), num_splits, dtype=torch.bool)
    val_mask = torch.zeros_like(train_mask)
    val_end = train_per_class + val_per_class

    for label in range(int(y.max()) + 1):
        members = (y == label).nonzero(as_tuple=False).view(-1)
        # One independent shuffle of the class members per split (column).
        shuffles = [torch.randperm(members.size(0)) for _ in range(num_splits)]
        shuffled = members[torch.stack(shuffles, dim=1)]

        train_mask.scatter_(0, shuffled[:train_per_class], True)
        val_mask.scatter_(0, shuffled[train_per_class:val_end], True)

    test_mask = ~train_mask & ~val_mask

    return train_mask, val_mask, test_mask
5 | - Node prediction datasets: GOOD-Cora, GOOD-Arxiv, GOOD-CBAS. 6 | """ 7 | 8 | from .good_hiv import GOODHIV 9 | from .good_arxiv import GOODArxiv 10 | from .good_pcba import GOODPCBA 11 | from .good_cmnist import GOODCMNIST 12 | from .good_cora import GOODCora 13 | from .good_cbas import GOODCBAS 14 | from .good_motif import GOODMotif 15 | from .good_zinc import GOODZINC 16 | 17 | __all__ = [ 18 | 'GOODCBAS', 19 | 'GOODZINC', 20 | 'GOODHIV', 21 | 'GOODCMNIST', 22 | 'GOODArxiv', 23 | 'GOODPCBA', 24 | 'GOODMotif', 25 | 'GOODCora' 26 | ] 27 | 28 | -------------------------------------------------------------------------------- /dig/sslgraph/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divelab/DIG/21476b079c9226f38915dcd082b5c2ee0cddaac8/dig/sslgraph/__init__.py -------------------------------------------------------------------------------- /dig/sslgraph/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import get_dataset, get_node_dataset 2 | from .TUDataset import TUDatasetExt 3 | 4 | __all__ = [ 5 | 'get_dataset', 6 | 'get_node_dataset', 7 | 'TUDatasetExt' 8 | ] 9 | -------------------------------------------------------------------------------- /dig/sslgraph/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .eval_node import NodeUnsupervised 2 | from .eval_graph import GraphUnsupervised, GraphSemisupervised 3 | 4 | __all__ = [ 5 | "GraphUnsupervised", 6 | "GraphSemisupervised", 7 | "NodeUnsupervised" 8 | ] 9 | -------------------------------------------------------------------------------- /dig/sslgraph/method/__init__.py: -------------------------------------------------------------------------------- 1 | from .contrastive.model import Contrastive, GraphCL, GRACE, InfoGraph, MVGRL, NodeMVGRL, pGRACE 2 | 3 | __all__ = [ 4 | 'Contrastive', 5 | 
class GRACE(Contrastive):
    r"""
    Contrastive learning method proposed in the paper `Deep Graph Contrastive Representation
    Learning <https://arxiv.org/abs/2006.04131>`_. See the GRACE benchmark code in the DIG
    repository for an example of usage.

    *Alias*: :obj:`dig.sslgraph.method.contrastive.model.`:obj:`GRACE`.

    Args:
        dim (int): The embedding dimension.
        dropE_rate_1, dropE_rate_2 (float): The ratios of the edge dropping augmentation for
            view 1 and view 2, respectively. Numbers between [0,1).
        maskN_rate_1, maskN_rate_2 (float): The ratios of the node masking augmentation for
            view 1 and view 2, respectively. Numbers between [0,1).
        **kwargs (optional): Additional arguments of :class:`dig.sslgraph.method.Contrastive`.
    """

    def __init__(self, dim, dropE_rate_1, dropE_rate_2, maskN_rate_1, maskN_rate_2,
                 **kwargs):

        # Each view drops edges first and then masks node attributes, with its own ratios.
        view_rates = ((dropE_rate_1, maskN_rate_1), (dropE_rate_2, maskN_rate_2))
        views_fn = [
            Sequential([EdgePerturbation(drop=True, ratio=drop_rate),
                        NodeAttrMask(mask_ratio=mask_rate)])
            for drop_rate, mask_rate in view_rates
        ]

        super().__init__(objective='NCE',
                         views_fn=views_fn,
                         graph_level=False,
                         node_level=True,
                         z_n_dim=dim,
                         proj_n='MLP',
                         **kwargs)

    def train(self, encoders, data_loader, optimizer, epochs, per_epoch_out=False):
        # GRACE removes projection heads after pre-training, so only the encoder of each
        # (encoder, projection head) pair produced by the base class is passed on.
        trained = super().train(encoders, data_loader, optimizer, epochs, per_epoch_out)
        for encoder, _proj_head in trained:
            yield encoder
class RandomView():
    r"""Generate views by applying a randomly chosen transformation (augmentation) to every
    graph of a batch; the choice is made independently for each graph. Class objects
    callable via method :meth:`views_fn`.

    Args:
        candidates (list): A list of callable view generation functions (classes).
    """

    def __init__(self, candidates):
        self.candidates = candidates

    def __call__(self, data):
        return self.views_fn(data)

    def views_fn(self, batch_data):
        r"""Method to be called when :class:`RandomView` object is called.

        Args:
            batch_data (:class:`torch_geometric.data.Batch`): The input batched graphs.

        :rtype: :class:`torch_geometric.data.Batch`.
        """
        # Split the batch, transform each graph with its own random pick, and re-batch.
        transformed_list = [
            random.choice(self.candidates)(graph)
            for graph in batch_data.to_data_list()
        ]
        return Batch.from_data_list(transformed_list)
def setup_seed(seed):
    r"""To setup seed for reproducible experiments.

    Seeds Python's :obj:`random`, NumPy and PyTorch (CPU and all CUDA devices), and puts
    cuDNN into deterministic mode.

    Args:
        seed (int): The number used as seed.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # With benchmarking enabled cuDNN may pick a different algorithm from run to run, which
    # defeats the purpose of seeding; turn it off together with `deterministic`.
    torch.backends.cudnn.benchmark = False
Here are the units for each property. 9 | 10 | | | mu | alpha | homo | lumo | gap | r2 | zpve | U0 | U | H | G | Cv | std. MAE | 11 | | ------------------------ | ------ | ------ | ---- | ---- | ---- | ----- | ---- | ---- | ---- | ---- | ---- | ------ | -------- | 12 | | Unit in the dataset | D | a_0^3 | eV | eV | eV | a_0^2 | eV | eV | eV | eV | eV | cal / mol K | | 13 | | Unit of the reported MAE | D | a_0^3 | meV | meV | meV | a_0^2 | meV | meV | meV | meV | meV | cal / mol K | % | 14 | 15 | 16 | ## MD17 17 | 18 | MD17 is a collection of eight molecular dynamics simulations for small organic molecules [(paper)](https://advances.sciencemag.org/content/3/5/e1603015.short). 19 | 20 | The units for energy and force are kcal / mol and kcal / mol A. 21 | 22 | 23 | ## ECdataset and FOLDdataset 24 | 25 | For ECdataset and FOLDdataset, please download datasets from [here](https://github.com/phermosilla/IEConv_proteins#download-the-preprocessed-datasets) (Protein function and Scope 1.75) to a path. Then set the parameter `root='path'` to load and process the data. 
class ThreeDEvaluator:
    r"""
    Evaluator for the 3D datasets, including QM9, MD17.
    The reported metric is the Mean Absolute Error between predictions and targets.
    """
    def __init__(self):
        pass

    def eval(self, input_dict):
        r"""Run evaluation.

        Args:
            input_dict (dict): A python dict with the following items: :obj:`y_true` and :obj:`y_pred`.
                Both must be of the same type (either numpy.ndarray or torch.Tensor) and have the same shape.

        :rtype: :class:`dict` (a python dict with item :obj:`mae`)
        """
        assert('y_pred' in input_dict)
        assert('y_true' in input_dict)

        y_pred = input_dict['y_pred']
        y_true = input_dict['y_true']

        assert((isinstance(y_true, np.ndarray) and isinstance(y_pred, np.ndarray))
               or
               (isinstance(y_true, torch.Tensor) and isinstance(y_pred, torch.Tensor)))
        assert(y_true.shape == y_pred.shape)

        # NumPy inputs are reduced with NumPy, tensors with torch; both yield a Python float.
        if not isinstance(y_true, torch.Tensor):
            return {'mae': float(np.mean(np.absolute(y_pred - y_true)))}

        abs_err = torch.abs(y_pred - y_true)
        return {'mae': torch.mean(abs_err).cpu().item()}
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/divelab/DIG/21476b079c9226f38915dcd082b5c2ee0cddaac8/dig/threedgraph/method/comenet/ocp/IS2RETrainedModelWeights.pt -------------------------------------------------------------------------------- /dig/threedgraph/method/comenet/ocp/README.md: -------------------------------------------------------------------------------- 1 | # ComENet: Towards Complete and Efficient Message Passing for 3D Molecular Graphs 2 | 3 | Limei Wang, Yi Liu, Yuchao Lin, Haoran Liu, Shuiwang Ji 4 | 5 | [paper](https://openreview.net/forum?id=mCzMqeWSFJ), [arxiv](https://arxiv.org/abs/2206.08515) 6 | 7 | ## Experiments 8 | **Task.** The [Open Catalyst 2020 (OC20)](https://arxiv.org/abs/2010.09990) dataset includes three tasks, namely Structure to Energy and Forces (S2EF), Initial Structure to Relaxed Structure (IS2RS), and Initial Structure to Relaxed Energy (IS2RE). We focus on the IS2RE task, which is the most common task in catalysis as the relaxed energies are often correlated with catalyst activity and selectivity. 9 | 10 | **Usage.** We use [the original OC20 framework](https://github.com/Open-Catalyst-Project/ocp) to train ComENet model. Please put file `comenet-ocp.py` and `utils.py` in a new directory `comenet` of `ocpmodels` directory, and put `comenet.yml` into the corresponding configs directory (e.g. https://github.com/Open-Catalyst-Project/ocp/tree/main/configs/is2re/all). 11 | 12 | **Setting.** We use [**direct**](https://arxiv.org/abs/2010.09990) method in our paper without IS2RS auxiliary loss (e.g. [noisy nodes](https://openreview.net/forum?id=1wVvweK3oIb)). Specifically, during training, our model takes the initial structure as input, and the loss is the error between the predicted energy and the ground truth energy. During testing, our model takes an unseen structure as input and outputs the predicted energy. 
13 | 14 | **Training and inference time.** ComENet can be trained in under 20 minutes per epoch on a single Nvidia GeForce RTX 2080 Ti GPU. The total training time is less than 1 day, and predictions take less than one minute per test/validation split. 15 | 16 | **Results.** 17 | Here are the results. 18 | ![OC20 IS2RE Performance](ComENetIS2REResults.jpg) 19 | 20 | **Trained model.** We also provided the model checkpoint (`IS2RETrainedModelWeights.pt`) in this folder. 21 | 22 | ## Citing 23 | 24 | If you use ComENet in your work, please consider citing: 25 | 26 | ```bibtex 27 | @inproceedings{ 28 | wang2022comenet, 29 | title={Com{EN}et: Towards Complete and Efficient Message Passing for 3D Molecular Graphs}, 30 | author={Limei Wang and Yi Liu and Yuchao Lin and Haoran Liu and Shuiwang Ji}, 31 | booktitle={Advances in Neural Information Processing Systems}, 32 | year={2022} 33 | } 34 | ``` -------------------------------------------------------------------------------- /dig/threedgraph/method/comenet/ocp/comenet.yml: -------------------------------------------------------------------------------- 1 | includes: 2 | - configs/is2re/all/base.yml 3 | 4 | model: 5 | name: ocpmodels.models.comenet.comenet.ComENet 6 | hidden_channels: 256 7 | num_blocks: 4 8 | cutoff: 6.0 9 | num_radial: 3 10 | num_spherical: 2 11 | hetero: False 12 | num_output_layers: 3 13 | use_pbc: True 14 | otf_graph: False 15 | 16 | optim: 17 | batch_size: 32 18 | eval_batch_size: 32 19 | num_workers: 4 20 | scheduler: CyclicLR 21 | lr_initial: 0.001 22 | base_lr: 0.000005 23 | max_lr: 0.001 24 | step_size_up: 57500 25 | mode: triangular2 26 | cycle_momentum: False 27 | amsgrad: True 28 | max_epochs: 50 29 | loss_energy: mae 30 | -------------------------------------------------------------------------------- /dig/threedgraph/method/dimenetpp/__init__.py: -------------------------------------------------------------------------------- 1 | from .dimenetpp import DimeNetPP 2 | from .features import 
dist_emb, angle_emb 3 | 4 | __all__ = [ 5 | 'DimeNetPP' 6 | ] -------------------------------------------------------------------------------- /dig/threedgraph/method/pronet/__init__.py: -------------------------------------------------------------------------------- 1 | from .pronet import ProNet 2 | 3 | __all__ = [ 4 | 'ProNet' 5 | ] -------------------------------------------------------------------------------- /dig/threedgraph/method/schnet/__init__.py: -------------------------------------------------------------------------------- 1 | from .schnet import SchNet 2 | 3 | __all__ = [ 4 | 'SchNet' 5 | ] 6 | -------------------------------------------------------------------------------- /dig/threedgraph/method/spherenet/__init__.py: -------------------------------------------------------------------------------- 1 | from .spherenet import SphereNet 2 | from .features import dist_emb, angle_emb, torsion_emb 3 | 4 | __all__ = [ 5 | 'SphereNet' 6 | ] -------------------------------------------------------------------------------- /dig/threedgraph/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .geometric_computing import xyz_to_dat 2 | 3 | __all__ = [ 4 | 'xyz_to_dat' 5 | ] 6 | 7 | 8 | -------------------------------------------------------------------------------- /dig/version.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | __version__ = '1.0.0' 4 | debug = False 5 | ROOT_DIR = os.path.abspath(os.path.dirname(__file__)) 6 | -------------------------------------------------------------------------------- /dig/xgraph/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divelab/DIG/21476b079c9226f38915dcd082b5c2ee0cddaac8/dig/xgraph/__init__.py -------------------------------------------------------------------------------- /dig/xgraph/dataset/__init__.py: 
class MarginalSubgraphDataset(Dataset):
    """ Dataset of paired subgraphs used to calculate marginal contributions.

    Item ``idx`` is the pair of graphs obtained by applying
    ``subgraph_build_func`` to the source graph with row ``idx`` of
    ``exclude_mask`` and row ``idx`` of ``include_mask`` respectively.
    NOTE(review): presumably each mask row selects the nodes of one coalition
    without / with the player under evaluation - confirm against the caller.
    """
    def __init__(self, data, exclude_mask, include_mask, subgraph_build_func) -> object:
        # Cache the pieces of the source graph that every item is built from.
        self.num_nodes = data.num_nodes
        self.X = data.x
        self.edge_index = data.edge_index
        self.device = self.X.device
        self.label = data.y

        def _to_float_mask(mask):
            # Masks are stored as float32 tensors on the same device as the features.
            return torch.tensor(mask).type(torch.float32).to(self.device)

        self.exclude_mask = _to_float_mask(exclude_mask)
        self.include_mask = _to_float_mask(include_mask)
        self.subgraph_build_func = subgraph_build_func

    def __len__(self):
        # One item per mask row.
        return self.exclude_mask.size(0)

    def __getitem__(self, idx):
        build = self.subgraph_build_func
        node_feats, edges = self.X, self.edge_index
        x_without, edges_without = build(node_feats, edges, self.exclude_mask[idx])
        x_with, edges_with = build(node_feats, edges, self.include_mask[idx])
        return (Data(x=x_without, edge_index=edges_without),
                Data(x=x_with, edge_index=edges_with))
-------------------------------------------------------------------------------- 1 | # Metrics 2 | 3 | This part provides the three metrics described in the paper: 4 | 5 | * Fidelity+ 6 | * Fidelity- 7 | * Sparsity 8 | 9 | ## Quick Usage 10 | 11 | ### Premise: 12 | * your model 13 | * graphs (node features and edges) 14 | * explanation 15 | * your target sparsity (sparsity control) 16 | 17 | ### Format of Explanation 18 | 19 | An explanation (a mask) is a list in which each element 20 | is the importance mask for one class. 21 | 22 | Each importance mask is a PyTorch tensor of size `num_edges`. 23 | 24 | ### Evaluate 25 | 26 | Given the inputs above, you can use the `ExplanationProcessor` and 27 | `XCollector` classes to obtain the results on the metrics. 28 | 29 | #### Class ExplanationProcessor 30 | 31 | It is for evaluating your model w/o the explanation. This class will generate 32 | related probabilities for further metric calculations. 33 | 34 | #### Class XCollector 35 | 36 | This class collects the corresponding model outputs (the probabilities of 37 | the various labels) and then computes Fidelity+ (fidelity) and Fidelity- (fidelity_inv). 38 | 39 | An example is given in `test`. Please refer to the `test/xgraph/test_metrics` function, 40 | which provides detailed comments. 
def symmetric_edge_mask_indirect_graph(edge_index: 'torch.Tensor', edge_mask: 'torch.Tensor') -> 'torch.Tensor':
    """ Make the given edge_mask symmetric, provided the graph is undirected.

    The graph counts as undirected ("indirect" in this function's name) when
    every edge ``(u, v)`` in ``edge_index`` also appears as ``(v, u)``. In that
    case each edge receives the average of its own mask value and that of its
    reverse edge. Otherwise ``edge_mask`` is returned unchanged.

    Args:
        edge_index (torch.Tensor): edges of the target graph, shape ``[2, num_edges]``
        edge_mask (torch.Tensor): edge mask provided by the explainer for the graph,
            one value per edge

    Returns:
        torch.Tensor: the symmetrized mask (on ``edge_index``'s device), or the
        original ``edge_mask`` for a directed graph.
    """
    # check whether every edge has its reverse edge in the graph
    def _is_indirect() -> bool:
        with torch.no_grad():
            edge_index_src = edge_index.detach().unsqueeze(2)  # shape: 2 N 1
            edge_index_rev = edge_index_src[[1, 0]].transpose(1, 2)  # 2 1 N

            # eq[0][a][b] & eq[1][a][b]  <=>  edge b is the reverse of edge a
            eq = edge_index_src - edge_index_rev == 0  # 2 N N
            rev_exist = (eq[0] * eq[1]).sum(1)  # N
            return torch.all(rev_exist > 0).item()

    if _is_indirect():
        edge_mask = edge_mask.to(edge_index.device)

        # Size the dense matrix by the largest node id rather than by the number
        # of distinct ids: ids need not be contiguous or start at 0 (e.g. an
        # extracted subgraph), and counting distinct ids then makes the indices
        # fall outside the matrix.
        num_nodes = int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0
        edge_mask_asym = torch.sparse_coo_tensor(edge_index,
                                                 edge_mask, (num_nodes, num_nodes)).to_dense()
        edge_mask_sym = (edge_mask_asym + edge_mask_asym.T) / 2
        edge_mask = edge_mask_sym[edge_index[0], edge_index[1]]

    return edge_mask
def load_model(name) -> torch.nn.Module:
    """Instantiate the model class called ``name`` from ``dig.xgraph.models.models``.

    Args:
        name: Name of a class defined in (or imported into) the ``models`` module.

    Returns:
        torch.nn.Module: A new instance created with no constructor arguments.

    Exits the process with status 1 if no class of that name exists.
    """
    classes = [x for x in dir(models) if isclass(getattr(models, x))]
    # Explicit check instead of `assert` inside a bare `except`: asserts are
    # stripped under `python -O`, and the bare except also hid unrelated errors.
    if name not in classes:
        print('#E#Model of given name does not exist.')
        # Non-zero status so callers and shell scripts can detect the failure.
        sys.exit(1)

    model = getattr(models, name)()
    print(f'#IN#{model}')

    return model


def config_model(model: torch.nn.Module, args, mode: str) -> None:
    """Move ``model`` to ``args.device`` and optionally restore a checkpoint.

    Args:
        model: The model to configure; modified in place.
        args: Object providing ``device`` and, depending on ``mode``, ``tr_ctn``,
            ``ckpt_dir``, ``model_name`` and ``test_ckpt``. ``args.ctn_epoch`` is
            written when training is resumed.
        mode: ``'train'`` resumes from ``<ckpt_dir>/<model_name>_last.ckpt`` when
            ``args.tr_ctn`` is truthy; ``'test'`` or ``'explain'`` loads
            ``args.test_ckpt``. Any other value loads nothing.

    Exits the process with status 1 if the test checkpoint file is missing.
    """
    model.to(args.device)
    # NOTE(review): the model is left in training mode for every `mode`,
    # including 'test' / 'explain' - presumably callers switch to eval()
    # themselves; confirm.
    model.train()

    # load checkpoint
    if mode == 'train' and args.tr_ctn:
        ckpt = torch.load(os.path.join(args.ckpt_dir, f'{args.model_name}_last.ckpt'))
        model.load_state_dict(ckpt['state_dict'])
        args.ctn_epoch = ckpt['epoch'] + 1
        print(f'#IN#Continue training from Epoch {ckpt["epoch"]}...')

    if mode in ('test', 'explain'):
        try:
            ckpt = torch.load(args.test_ckpt)
        except FileNotFoundError:
            print(f'#E#Checkpoint not found at {os.path.abspath(args.test_ckpt)}')
            # sys.exit rather than the builtin exit(), which only exists when
            # the `site` module has been loaded.
            sys.exit(1)
        model.load_state_dict(ckpt['state_dict'])
        print(f'#IN#Loading best Checkpoint {ckpt["epoch"]}...')
def compatible_state_dict(state_dict):
    """Return a copy of ``state_dict`` with conv weight keys adapted to torch_geometric >= 2.

    When the installed torch_geometric major version is 2 or newer, keys of the
    form ``conv1.weight`` / ``convs.<digit>.weight`` are additionally stored as
    ``conv1.lin.weight`` / ``convs.<digit>.lin.weight`` with the value
    transposed. The original key and value are kept as well, and every other
    entry is copied unchanged.

    Args:
        state_dict: Mapping from parameter name to tensor.

    Returns:
        OrderedDict: The augmented state dict.
    """
    # Parse the whole major component; indexing the first character would
    # misread a two-digit major version such as "10.0.0" as 1.
    rename_conv_weights = int(__version__.split('.')[0]) >= 2

    comp_state_dict = OrderedDict()
    for key, value in state_dict.items():
        comp_key = key
        comp_value = value
        if rename_conv_weights:
            # Raw strings: '\g<1>' in a normal string literal is an invalid
            # escape sequence. Dots are escaped so only literal '.' separators
            # match.
            # NOTE(review): only single-digit `convs.<i>` indices are matched,
            # as in the original pattern - confirm whether deeper stacks occur.
            comp_key = re.sub(r'conv(1|s\.[0-9])\.weight', r'conv\g<1>.lin.weight', key)
            if comp_key != key:
                comp_value = value.T
        if comp_key != key:
            # keep the entry under its original name as well
            comp_state_dict[key] = value
        comp_state_dict[comp_key] = comp_value
    return comp_state_dict


def fix_random_seed(random_seed: int):
    r"""
    Fix multiple random seeds including python, numpy, torch, torch.cuda, and torch.backends.

    Besides seeding the generators, this switches cuDNN to deterministic mode
    and disables its benchmark mode.

    Args:
        random_seed (int): The random seed.
    """
    random.seed(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/environment.yaml: -------------------------------------------------------------------------------- 1 | name: 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - python==3.8 7 | - pip>=21.0 8 | - pip: 9 | - rdkit-pypi 10 | - https://download.pytorch.org/whl/cpu/torch-1.10.0%2Bcpu-cp38-cp38-linux_x86_64.whl 11 | - furo 12 | - numpy 13 | - sphinx 14 | - sphinx_rtd_theme==0.5.2 15 | - https://data.pyg.org/whl/torch-1.10.0%2Bcpu/torch_scatter-2.0.9-cp38-cp38-linux_x86_64.whl 16 | - https://data.pyg.org/whl/torch-1.10.0%2Bcpu/torch_sparse-0.6.13-cp38-cp38-linux_x86_64.whl 17 | - https://data.pyg.org/whl/torch-1.10.0%2Bcpu/torch_cluster-1.6.0-cp38-cp38-linux_x86_64.whl 18 | - https://data.pyg.org/whl/torch-1.10.0%2Bcpu/torch_spline_conv-1.2.1-cp38-cp38-linux_x86_64.whl 19 | - torch-geometric==2.1.0 20 | - git+https://github.com/Chilipp/autodocsumm.git 21 | - captum==0.2.0 22 | - munch 23 | - gdown 24 | - cilog 25 | - typed-argument-parser==1.5.4 26 | - tensorboard 27 | - sympy 28 | - pyscf>=1.7.6 29 | - hydra-core 30 | - pygmtools 31 | - pyro-ppl 32 | - networkx 33 | -------------------------------------------------------------------------------- /docs/imgs/DIG-logo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divelab/DIG/21476b079c9226f38915dcd082b5c2ee0cddaac8/docs/imgs/DIG-logo.jpg -------------------------------------------------------------------------------- /docs/imgs/DIG-overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divelab/DIG/21476b079c9226f38915dcd082b5c2ee0cddaac8/docs/imgs/DIG-overview.png -------------------------------------------------------------------------------- /docs/imgs/GOOD-datasets.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/divelab/DIG/21476b079c9226f38915dcd082b5c2ee0cddaac8/docs/imgs/GOOD-datasets.png -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/3dgraph/dataset.rst: -------------------------------------------------------------------------------- 1 | dig.threedgraph.dataset 2 | ====== 3 | Dataset interfaces under :obj:`dig.threedgraph.dataset`. 4 | 5 | .. 
automodule:: dig.threedgraph.dataset 6 | :members: 7 | :special-members: 8 | :autosummary: 9 | :autosummary-no-nesting: -------------------------------------------------------------------------------- /docs/source/3dgraph/evaluation.rst: -------------------------------------------------------------------------------- 1 | dig.threedgraph.evaluation 2 | ====== 3 | Evaluation interfaces under :obj:`dig.threedgraph.evaluation`. 4 | 5 | .. automodule:: dig.threedgraph.evaluation 6 | :members: 7 | :special-members: -------------------------------------------------------------------------------- /docs/source/3dgraph/method.rst: -------------------------------------------------------------------------------- 1 | dig.threedgraph.method 2 | ====== 3 | Method classes under :obj:`dig.threedgraph.method`. 4 | 5 | .. automodule:: dig.threedgraph.method 6 | :members: 7 | :special-members: 8 | :autosummary: 9 | :autosummary-no-nesting: 10 | -------------------------------------------------------------------------------- /docs/source/3dgraph/utils.rst: -------------------------------------------------------------------------------- 1 | dig.threedgraph.utils 2 | ====== 3 | 4 | Utilities under :obj:`dig.threedgraph.utils`. 5 | 6 | .. automodule:: dig.threedgraph.utils 7 | :members: 8 | :special-members: -------------------------------------------------------------------------------- /docs/source/auggraph/dataset.rst: -------------------------------------------------------------------------------- 1 | dig.auggraph.dataset 2 | ==================== 3 | Dataset interfaces under :obj:`dig.auggraph.dataset`. 4 | 5 | .. automodule:: dig.auggraph.dataset 6 | :members: 7 | :special-members: 8 | -------------------------------------------------------------------------------- /docs/source/auggraph/method.rst: -------------------------------------------------------------------------------- 1 | dig.auggraph.method 2 | =================== 3 | 4 | .. 
contents:: Graph Augmentation Methods 5 | :local: 6 | 7 | GraphAug 8 | -------- 9 | .. currentmodule:: dig.auggraph.method.GraphAug 10 | 11 | An augmentation method for graph datasets under :obj:`dig.auggraph.method.GraphAug` 12 | implemented from the paper 13 | `Automated Data Augmentations for Graph Classification `_. 15 | 16 | .. automodule:: dig.auggraph.method.GraphAug 17 | :members: 18 | :special-members: 19 | :autosummary: 20 | :autosummary-no-nesting: 21 | 22 | S-Mixup 23 | -------- 24 | .. currentmodule:: dig.auggraph.method.SMixup 25 | 26 | The S-Mixup from the `"Graph Mixup with Soft Alignments" `_ paper. 27 | 28 | .. automodule:: dig.auggraph.method.SMixup 29 | :members: 30 | :special-members: 31 | :autosummary: 32 | :autosummary-no-nesting: -------------------------------------------------------------------------------- /docs/source/fairgraph/dataset.rst: -------------------------------------------------------------------------------- 1 | dig.fairgraph.dataset 2 | ===================== 3 | 4 | .. contents:: Fairgraph Datasets 5 | :local: 6 | 7 | .. currentmodule:: dig.fairgraph.dataset.fairgraph_dataset.POKEC 8 | 9 | .. automodule:: dig.fairgraph.dataset.fairgraph_dataset 10 | :members: 11 | :special-members: 12 | :autosummary: 13 | :autosummary-no-nesting: -------------------------------------------------------------------------------- /docs/source/fairgraph/method.rst: -------------------------------------------------------------------------------- 1 | dig.fairgraph.method 2 | ==================== 3 | 4 | .. contents:: Fairgraph Methods 5 | :local: 6 | 7 | Graphair 8 | -------- 9 | .. currentmodule:: dig.fairgraph.method.Graphair 10 | 11 | A fair graph representation method for graph datasets under :obj:`dig.fairgraph.dataset.fairgraph_dataset` 12 | implemented from the paper 13 | `LEARNING FAIR GRAPH REPRESENTATIONS VIA AUTOMATED DATA AUGMENTATIONS `_. 14 | 15 | .. 
automodule:: dig.fairgraph.method.Graphair.graphair 16 | :members: 17 | :special-members: 18 | :exclude-members: forward 19 | :autosummary: 20 | :autosummary-no-nesting: 21 | 22 | Runner 23 | ------ 24 | 25 | .. currentmodule:: dig.fairgraph.method.run 26 | 27 | .. automodule:: dig.fairgraph.method.run 28 | :members: 29 | :special-members: 30 | :autosummary: 31 | :autosummary-no-nesting: -------------------------------------------------------------------------------- /docs/source/ggraph/dataset.rst: -------------------------------------------------------------------------------- 1 | dig.ggraph.dataset 2 | ========= 3 | Dataset interfaces under :obj:`dig.ggraph.dataset`. 4 | 5 | .. automodule:: dig.ggraph.dataset 6 | :members: 7 | :special-members: 8 | :autosummary: 9 | :autosummary-no-nesting: -------------------------------------------------------------------------------- /docs/source/ggraph/evaluation.rst: -------------------------------------------------------------------------------- 1 | dig.ggraph.evaluation 2 | ============ 3 | Evaluation interfaces under :obj:`dig.ggraph.evaluation`. 4 | 5 | .. automodule:: dig.ggraph.evaluation 6 | :members: 7 | :special-members: 8 | :autosummary: 9 | :autosummary-no-nesting: -------------------------------------------------------------------------------- /docs/source/ggraph/method.rst: -------------------------------------------------------------------------------- 1 | dig.ggraph.method 2 | ======== 3 | 4 | 5 | Method classes under :obj:`dig.ggraph.method`. 6 | 7 | .. automodule:: dig.ggraph.method 8 | :members: 9 | :special-members: 10 | :autosummary: 11 | :autosummary-no-nesting: 12 | -------------------------------------------------------------------------------- /docs/source/ggraph/utils.rst: -------------------------------------------------------------------------------- 1 | dig.ggraph.utils 2 | ====== 3 | Utilities under :obj:`dig.ggraph.utils`. 4 | 5 | .. 
automodule:: dig.ggraph.utils 6 | :members: 7 | :special-members: 8 | :autosummary: 9 | :autosummary-no-nesting: -------------------------------------------------------------------------------- /docs/source/ggraph3d/dataset.rst: -------------------------------------------------------------------------------- 1 | dig.ggraph3D.dataset 2 | ========= 3 | Dataset interfaces under :obj:`dig.ggraph3D.dataset`. 4 | 5 | .. automodule:: dig.ggraph3D.dataset 6 | :members: 7 | :special-members: 8 | :autosummary: 9 | :autosummary-no-nesting: -------------------------------------------------------------------------------- /docs/source/ggraph3d/evaluation.rst: -------------------------------------------------------------------------------- 1 | dig.ggraph3D.evaluation 2 | ============ 3 | Evaluation interfaces under :obj:`dig.ggraph3D.evaluation`. 4 | 5 | .. automodule:: dig.ggraph3D.evaluation 6 | :members: 7 | :special-members: 8 | :autosummary: 9 | :autosummary-no-nesting: -------------------------------------------------------------------------------- /docs/source/ggraph3d/method.rst: -------------------------------------------------------------------------------- 1 | dig.ggraph3D.method 2 | ======== 3 | 4 | 5 | Method classes under :obj:`dig.ggraph3D.method`. 6 | 7 | .. automodule:: dig.ggraph3D.method 8 | :members: 9 | :special-members: 10 | :autosummary: 11 | :autosummary-no-nesting: 12 | -------------------------------------------------------------------------------- /docs/source/ggraph3d/utils.rst: -------------------------------------------------------------------------------- 1 | dig.ggraph3D.utils 2 | ====== 3 | Utilities under :obj:`dig.ggraph3D.utils`. 4 | 5 | .. 
automodule:: dig.ggraph3D.utils 6 | :members: 7 | :special-members: 8 | :autosummary: 9 | :autosummary-no-nesting: -------------------------------------------------------------------------------- /docs/source/intro/installation.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ====== 3 | 4 | Please follow the steps below to install DIG: Dive into Graphs. 5 | 6 | .. note:: 7 | We recommend you to create a virtual environment with `conda `_ and install DIG: Dive into Graphs in the virtual environment. 8 | 9 | 10 | Install from pip 11 | -------- 12 | The key dependencies of DIG: Dive into Graphs are PyTorch (>=1.10.0), PyTorch Geometric (>=2.0.0), and RDKit. 13 | 14 | #. Install `PyTorch `_ (>=1.10.0). 15 | 16 | .. code-block:: none 17 | 18 | $ python -c "import torch; print(torch.__version__)" 19 | >>> 1.10.0 20 | 21 | 22 | #. Install `PyTorch Geometric `_ (>=2.0.0). 23 | 24 | .. code-block:: none 25 | 26 | $ python -c "import torch_geometric; print(torch_geometric.__version__)" 27 | >>> 2.0.0 28 | 29 | 30 | #. Install DIG: Dive into Graphs. 31 | 32 | .. code-block:: none 33 | 34 | pip install dive-into-graphs 35 | 36 | 37 | After installation, you can check the version. You have successfully installed DIG: Dive into Graphs if no error occurs. 38 | 39 | .. code-block:: none 40 | 41 | $ python 42 | >>> from dig.version import __version__ 43 | >>> print(__version__) 44 | 45 | Install from source 46 | -------- 47 | If you want to try the latest features that have not been released yet, you can install dig from source. 48 | 49 | .. code-block:: none 50 | 51 | git clone https://github.com/divelab/DIG.git 52 | cd DIG 53 | pip install . 
-------------------------------------------------------------------------------- /docs/source/intro/introduction.rst: -------------------------------------------------------------------------------- 1 | Introduction 2 | ============ 3 | 4 | DIG includes unified implementations of **data interfaces**, **common algorithms**, and **evaluation metrics** for several advanced tasks. Our goal is to enable researchers to easily implement and benchmark algorithms. Currently, we consider the following research directions. 5 | 6 | * **Graph Augmentation**: :obj:`dig.auggraph` 7 | * **Graph Generation**: :obj:`dig.ggraph` 8 | * **Self-supervised Learning on Graphs**: :obj:`dig.sslgraph` 9 | * **Explainability of Graph Neural Networks**: :obj:`dig.xgraph` 10 | * **Deep Learning on 3D Graphs**: :obj:`dig.threedgraph` 11 | * **Fair Graph Representations**: :obj:`dig.fairgraph` 12 | 13 | 14 | We provide a hands-on tutorial for each direction to help you to get started with DIG: 15 | 16 | * `Tutorial for Graph Generation `_ 17 | * `Tutorial for Self-supervised Learning on Graphs `_ 18 | * `Tutorial for Explainability of Graph Neural Networks `_ 19 | * `Tutorial for Deep Learning on 3D Graphs `_ 20 | * `Tutorial for Learning Fair Graph Representations `_ 21 | 22 | 23 | You can also refer to our provided `examples `_ about how to use APIs in DIG. 24 | 25 | .. image:: ../../imgs/DIG-overview.png 26 | :width: 100% 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | -------------------------------------------------------------------------------- /docs/source/oodgraph/good.rst: -------------------------------------------------------------------------------- 1 | dig.oodgraph 2 | ================== 3 | Graph OOD (GOOD) Dataset interfaces under :obj:`dig.oodgraph`. 4 | 5 | Please refer to `the GOOD project `_ for more details. 6 | 7 | .. 
automodule:: dig.oodgraph 8 | :members: GOODHIV, GOODPCBA, GOODZINC, GOODCMNIST, GOODMotif, GOODCora, GOODArxiv, GOODCBAS 9 | :special-members: 10 | :autosummary: 11 | :autosummary-no-nesting: 12 | -------------------------------------------------------------------------------- /docs/source/sslgraph/dataset.rst: -------------------------------------------------------------------------------- 1 | dig.sslgraph.dataset 2 | ========= 3 | Dataset interfaces under :obj:`dig.sslgraph.dataset`. 4 | 5 | .. automodule:: dig.sslgraph.dataset 6 | :members: 7 | :special-members: 8 | -------------------------------------------------------------------------------- /docs/source/sslgraph/evaluation.rst: -------------------------------------------------------------------------------- 1 | dig.sslgraph.evaluation 2 | ============ 3 | Evaluation interfaces under :obj:`dig.sslgraph.evaluation`. 4 | 5 | .. automodule:: dig.sslgraph.evaluation 6 | :members: 7 | :special-members: 8 | :autosummary: 9 | :autosummary-no-nesting: -------------------------------------------------------------------------------- /docs/source/sslgraph/method.rst: -------------------------------------------------------------------------------- 1 | dig.sslgraph.method 2 | ======== 3 | 4 | .. contents:: Contents 5 | :local: 6 | 7 | Pre-implemented Contrastive Methods 8 | ----------------------------------- 9 | .. currentmodule:: dig.sslgraph.method 10 | 11 | Contrastive method classes under :obj:`dig.sslgraph.method`, or alias :obj:`dig.sslgraph.method.contrastive.model`. 12 | 13 | .. automodule:: dig.sslgraph.method 14 | :members: 15 | :special-members: 16 | :autosummary: 17 | :autosummary-no-nesting: 18 | 19 | 20 | Contrastive Objectives 21 | ---------------------- 22 | .. currentmodule:: dig.sslgraph.method.contrastive.objectives 23 | 24 | Contrastive objective functions under :obj:`dig.sslgraph.method.contrastive.objectives`. 25 | 26 | .. 
automodule:: dig.sslgraph.method.contrastive.objectives 27 | :members: 28 | :special-members: 29 | 30 | 31 | Transformations for Views Generation 32 | ------------------------------------ 33 | .. currentmodule:: dig.sslgraph.method.contrastive.views_fn 34 | 35 | Views generation classes under :obj:`dig.sslgraph.method.contrastive.views_fn`. 36 | 37 | .. automodule:: dig.sslgraph.method.contrastive.views_fn 38 | :members: 39 | :special-members: 40 | :autosummary: 41 | :autosummary-no-nesting: -------------------------------------------------------------------------------- /docs/source/sslgraph/utils.rst: -------------------------------------------------------------------------------- 1 | dig.sslgraph.utils 2 | ====== 3 | Utilities under :obj:`dig.sslgraph.utils`. 4 | 5 | .. automodule:: dig.sslgraph.utils 6 | :members: 7 | :special-members: 8 | :autosummary: 9 | :autosummary-no-nesting: -------------------------------------------------------------------------------- /docs/source/tutorials/imgs/subgraphx_explanation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divelab/DIG/21476b079c9226f38915dcd082b5c2ee0cddaac8/docs/source/tutorials/imgs/subgraphx_explanation.png -------------------------------------------------------------------------------- /docs/source/tutorials/imgs/subgraphx_ori_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divelab/DIG/21476b079c9226f38915dcd082b5c2ee0cddaac8/docs/source/tutorials/imgs/subgraphx_ori_graph.png -------------------------------------------------------------------------------- /docs/source/tutorials/oodgraph.rst: -------------------------------------------------------------------------------- 1 | Tutorial for Graph OOD (GOOD) 2 | ===================== 3 | 4 | This module includes datasets from `the GOOD project `_. 
**GOOD** (Graph OOD) is a graph out-of-distribution (OOD) algorithm benchmarking library built on PyTorch and PyG 5 | that makes developing and benchmarking OOD algorithms easy. 6 | 7 | Currently, this module contains 8 datasets with 14 domain selections. When combined with covariate, concept, and no shifts, we obtain 42 different splits. 8 | We provide `performance results `_ on 7 commonly used baseline methods (ERM, IRM, VREx, GroupDRO, Coral, DANN, Mixup) with 10 random runs. 9 | This results in 294 dataset-model combinations in total. Our results show significant performance gaps between in-distribution and OOD settings. 10 | This GOOD benchmark is a growing project and is expected to expand in quantity and variety of resources as the area develops. 11 | 12 | .. image:: ../../imgs/GOOD-datasets.png 13 | :width: 680 14 | :alt: GOOD datasets 15 | 16 | The dataset loading example can be found `here `_. -------------------------------------------------------------------------------- /docs/source/xgraph/dataset.rst: -------------------------------------------------------------------------------- 1 | dig.xgraph.dataset 2 | ================== 3 | Dataset interfaces under :obj:`dig.xgraph.dataset`. 4 | 5 | .. automodule:: dig.xgraph.dataset 6 | :members: MoleculeDataset, SentiGraphDataset, BA_LRP, SynGraphDataset 7 | :special-members: 8 | :autosummary: 9 | :autosummary-no-nesting: 10 | -------------------------------------------------------------------------------- /docs/source/xgraph/evaluation.rst: -------------------------------------------------------------------------------- 1 | dig.xgraph.evaluation 2 | ===================== 3 | Evaluation interfaces under :obj:`dig.xgraph.evaluation`. 4 | 5 | ..
automodule:: dig.xgraph.evaluation 6 | :members: 7 | :special-members: 8 | :autosummary: 9 | :autosummary-no-nesting: 10 | -------------------------------------------------------------------------------- /docs/source/xgraph/method.rst: -------------------------------------------------------------------------------- 1 | dig.xgraph.method 2 | ======== 3 | Methods interfaces under :obj:`dig.xgraph.method`. 4 | 5 | .. automodule:: dig.xgraph.method 6 | :members: 7 | :special-members: 8 | :autosummary: 9 | :autosummary-no-nesting: 10 | :exclude-members: MCTS 11 | -------------------------------------------------------------------------------- /docs/source/xgraph/utils.rst: -------------------------------------------------------------------------------- 1 | dig.xgraph.utils 2 | ======== 3 | Methods interfaces under :obj:`dig.xgraph.method`. 4 | 5 | .. automodule:: dig.xgraph.method 6 | :members: MCTS 7 | :special-members: 8 | :autosummary: 9 | :autosummary-no-nesting: 10 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | ## How to run an .ipynb Jupyter Notebook from terminal? 2 | 3 | Some of our implementations in `benchmarks` are in `.ipynb` format so that they can be used as tutorials. 4 | 5 | If you don't want to use the "interactive" feature of Jupyter Notebook, you can also run the `.ipynb` file directly from your terminal as the following example. 
6 | 7 | ```shell script 8 | $ ipython 9 | In [1]: run ./sslgraph/example_graphcl.ipynb 10 | ``` 11 | -------------------------------------------------------------------------------- /examples/auggraph/GraphAug/conf/aug_cls_conf.py: -------------------------------------------------------------------------------- 1 | # Author: Youzhi Luo (yzluo@tamu.edu) 2 | # Updated by: Anmol Anand (aanand@tamu.edu) 3 | 4 | import copy 5 | from dig.auggraph.method.GraphAug.constants import * 6 | 7 | base_aug_cls_conf = { 8 | MODEL_NAME: CLSModelType.GCN, 9 | NUM_LAYERS: 4, 10 | HIDDEN_UNITS: 128, 11 | DROPOUT: 0.5, 12 | POOL_TYPE: PoolType.MEAN, 13 | BATCH_SIZE: 32, 14 | INITIAL_LR: 0.001, 15 | FACTOR: 0.5, 16 | PATIENCE: 5, 17 | MIN_LR: 0.0000001, 18 | MAX_NUM_EPOCHS: 100, 19 | AUG_MODEL_PATH: None 20 | } 21 | 22 | aug_cls_conf = {} 23 | 24 | aug_cls_conf[DatasetName.NCI1] = copy.deepcopy(base_aug_cls_conf) 25 | aug_cls_conf[DatasetName.NCI1][NUM_LAYERS] = 3 26 | aug_cls_conf[DatasetName.NCI1][PATIENCE] = 100 27 | 28 | aug_cls_conf[DatasetName.COLLAB] = copy.deepcopy(base_aug_cls_conf) 29 | 30 | aug_cls_conf[DatasetName.MUTAG] = copy.deepcopy(base_aug_cls_conf) 31 | aug_cls_conf[DatasetName.MUTAG][BATCH_SIZE] = 16 32 | aug_cls_conf[DatasetName.MUTAG][PATIENCE] = 100 33 | 34 | aug_cls_conf[DatasetName.PROTEINS] = copy.deepcopy(base_aug_cls_conf) 35 | aug_cls_conf[DatasetName.PROTEINS][NUM_LAYERS] = 3 36 | 37 | aug_cls_conf[DatasetName.IMDB_BINARY] = copy.deepcopy(base_aug_cls_conf) 38 | aug_cls_conf[DatasetName.IMDB_BINARY][HIDDEN_UNITS] = 64 39 | 40 | aug_cls_conf[DatasetName.NCI109] = copy.deepcopy(base_aug_cls_conf) 41 | aug_cls_conf[DatasetName.NCI109][PATIENCE] = 100 42 | -------------------------------------------------------------------------------- /examples/auggraph/GraphAug/conf/generator_conf.py: -------------------------------------------------------------------------------- 1 | # Author: Youzhi Luo (yzluo@tamu.edu) 2 | # Updated by: Anmol Anand (aanand@tamu.edu) 3 | 4 
| import copy 5 | from dig.auggraph.method.GraphAug.constants import * 6 | 7 | base_generator_conf = { 8 | BATCH_SIZE: 32, 9 | INITIAL_LR: 1e-4, 10 | GENERATOR_STEPS: 1, 11 | TEST_INTERVAL: 1, 12 | MAX_NUM_EPOCHS: 200, 13 | REWARD_GEN_STATE_PATH: None, 14 | BASELINE: BaselineType.MEAN, 15 | MOVING_RATIO: 0.1, 16 | SAVE_MODEL: True, 17 | GENERATOR_PARAMS: { 18 | NUM_LAYERS: 3, 19 | HID_DIM: 64, 20 | MAX_NUM_AUG: 8, 21 | USE_STOP_AUG: False, 22 | UNIFORM: False, 23 | RNN_INPUT: RnnInputType.VIRTUAL, 24 | AUG_TYPE_PARAMS: { 25 | AugType.NODE_FM.value: { 26 | HID_DIM: 64, TEMPERATURE: 1.0, TRAINING: True, MAGNITUDE: 0.05 27 | }, 28 | AugType.NODE_DROP.value: { 29 | HID_DIM: 64, TEMPERATURE: 1.0, TRAINING: True, MAGNITUDE: 0.05 30 | }, 31 | AugType.EDGE_Per.value: { 32 | HID_DIM: 64, TEMPERATURE: 1.0, TRAINING: True, MAGNITUDE: 0.05 33 | } 34 | } 35 | } 36 | } 37 | 38 | generator_conf = {} 39 | 40 | generator_conf[DatasetName.NCI1] = copy.deepcopy(base_generator_conf) 41 | 42 | generator_conf[DatasetName.COLLAB] = copy.deepcopy(base_generator_conf) 43 | generator_conf[DatasetName.COLLAB][BATCH_SIZE] = 8 44 | 45 | generator_conf[DatasetName.NCI109] = copy.deepcopy(base_generator_conf) 46 | 47 | generator_conf[DatasetName.MUTAG] = copy.deepcopy(base_generator_conf) 48 | generator_conf[DatasetName.MUTAG][BATCH_SIZE] = 16 49 | 50 | generator_conf[DatasetName.PROTEINS] = copy.deepcopy(base_generator_conf) 51 | generator_conf[DatasetName.PROTEINS][GENERATOR_PARAMS][NUM_LAYERS] = 6 52 | 53 | generator_conf[DatasetName.IMDB_BINARY] = copy.deepcopy(base_generator_conf) 54 | generator_conf[DatasetName.IMDB_BINARY][GENERATOR_PARAMS][NUM_LAYERS] = 6 55 | -------------------------------------------------------------------------------- /examples/auggraph/GraphAug/conf/reward_gen_conf.py: -------------------------------------------------------------------------------- 1 | # Author: Youzhi Luo (yzluo@tamu.edu) 2 | # Updated by: Anmol Anand (aanand@tamu.edu) 3 | 4 | import copy 5 | from 
dig.auggraph.method.GraphAug.constants import * 6 | 7 | base_reward_gen_conf = { 8 | BATCH_SIZE: 32, 9 | INITIAL_LR: 1e-4, 10 | FACTOR: 0.5, 11 | PATIENCE: 5, 12 | MIN_LR: 1e-7, 13 | MAX_NUM_EPOCHS: 200, 14 | PRE_TRAIN_PATH: None, 15 | REWARD_GEN_PARAMS: { 16 | NUM_LAYERS: 5, 17 | HIDDEN_UNITS: 256, 18 | MODEL_TYPE: RewardGenModelType.GMNET, 19 | POOL_TYPE: PoolType.SUM, 20 | FUSE_TYPE: FuseType.ABS_DIFF 21 | } 22 | } 23 | 24 | reward_gen_conf = {} 25 | 26 | reward_gen_conf[DatasetName.NCI1] = copy.deepcopy(base_reward_gen_conf) 27 | 28 | reward_gen_conf[DatasetName.COLLAB] = copy.deepcopy(base_reward_gen_conf) 29 | reward_gen_conf[DatasetName.COLLAB][BATCH_SIZE] = 8 30 | reward_gen_conf[DatasetName.COLLAB][MAX_NUM_EPOCHS] = 120 31 | 32 | reward_gen_conf[DatasetName.MUTAG] = copy.deepcopy(base_reward_gen_conf) 33 | reward_gen_conf[DatasetName.MUTAG][MAX_NUM_EPOCHS] = 230 34 | 35 | reward_gen_conf[DatasetName.PROTEINS] = copy.deepcopy(base_reward_gen_conf) 36 | reward_gen_conf[DatasetName.PROTEINS][MAX_NUM_EPOCHS] = 420 37 | reward_gen_conf[DatasetName.PROTEINS][REWARD_GEN_PARAMS][NUM_LAYERS] = 6 38 | 39 | reward_gen_conf[DatasetName.IMDB_BINARY] = copy.deepcopy(base_reward_gen_conf) 40 | reward_gen_conf[DatasetName.IMDB_BINARY][MAX_NUM_EPOCHS] = 320 41 | reward_gen_conf[DatasetName.IMDB_BINARY][REWARD_GEN_PARAMS][NUM_LAYERS] = 6 42 | 43 | reward_gen_conf[DatasetName.NCI109] = copy.deepcopy(base_reward_gen_conf) 44 | -------------------------------------------------------------------------------- /examples/auggraph/GraphAug/run_aug_cls.py: -------------------------------------------------------------------------------- 1 | # Author: Youzhi Luo (yzluo@tamu.edu) 2 | # Updated by: Anmol Anand (aanand@tamu.edu) 3 | 4 | import argparse 5 | from dig.auggraph.method.GraphAug.runner_aug_cls import RunnerAugCls 6 | from dig.auggraph.method.GraphAug.constants import * 7 | from examples.auggraph.GraphAug.conf.generator_conf import generator_conf 8 | from 
examples.auggraph.GraphAug.conf.aug_cls_conf import aug_cls_conf 9 | 10 | dataset_name = DatasetName.IMDB_BINARY 11 | conf = aug_cls_conf[dataset_name] 12 | conf[GENERATOR_PARAMS] = generator_conf[dataset_name][GENERATOR_PARAMS] 13 | generator_checkpoint = generator_conf[dataset_name][MAX_NUM_EPOCHS] - 1 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--data_root_path', type=str, default='../../../dig/auggraph/dataset/tudatasets', 16 | help='The directory with all graph datasets') 17 | parser.add_argument('--aug_model_path', type=str, 18 | default='./results/generator_results/{}/max_augs_{}/{}.pt'.format(dataset_name.value, conf[GENERATOR_PARAMS][MAX_NUM_AUG], str(generator_checkpoint).zfill(4)), 19 | help='The directory where generator states are stored after each epoch.') 20 | parser.add_argument('--aug_cls_results_path', type=str, 21 | default='./results/aug_cls_results/{}'.format(conf[MODEL_NAME].value), 22 | help='The directory where classification results will be stored after each epoch.') 23 | args = parser.parse_args() 24 | conf[AUG_MODEL_PATH] = args.aug_model_path 25 | 26 | runner = RunnerAugCls(args.data_root_path, dataset_name, conf) 27 | for _ in range(5): 28 | runner.train_test(args.aug_cls_results_path) 29 | -------------------------------------------------------------------------------- /examples/auggraph/GraphAug/run_generator.py: -------------------------------------------------------------------------------- 1 | # Author: Youzhi Luo (yzluo@tamu.edu) 2 | # Updated by: Anmol Anand (aanand@tamu.edu) 3 | 4 | import argparse 5 | from examples.auggraph.GraphAug.conf.generator_conf import generator_conf 6 | from examples.auggraph.GraphAug.conf.reward_gen_conf import reward_gen_conf 7 | from dig.auggraph.method.GraphAug.runner_generator import RunnerGenerator 8 | from dig.auggraph.method.GraphAug.constants import * 9 | 10 | dataset_name = DatasetName.IMDB_BINARY 11 | conf = generator_conf[dataset_name] 12 | conf[REWARD_GEN_PARAMS] = 
reward_gen_conf[dataset_name][REWARD_GEN_PARAMS] 13 | model_type = conf[REWARD_GEN_PARAMS][MODEL_TYPE] 14 | last_checkpoint = reward_gen_conf[dataset_name][MAX_NUM_EPOCHS] - 1 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--data_root_path', type=str, default='../../../dig/auggraph/dataset/tudatasets', 17 | help='The directory with all graph datasets') 18 | parser.add_argument('--generator_results_path', type=str, 19 | default='./results/generator_results', 20 | help='The directory where generator states will be stored after each epoch.') 21 | parser.add_argument('--reward_gen_state_path', type=str, 22 | default='./results/reward_gen_results/{}/{}/{}.pt'.format(dataset_name.value, model_type.value, str(last_checkpoint).zfill(4)), 23 | help='File path for final training state of reward generation model') 24 | args = parser.parse_args() 25 | conf[REWARD_GEN_STATE_PATH] = args.reward_gen_state_path 26 | 27 | runner = RunnerGenerator(args.data_root_path, dataset_name, conf) 28 | runner.train_test(args.generator_results_path) 29 | -------------------------------------------------------------------------------- /examples/auggraph/GraphAug/run_reward_gen.py: -------------------------------------------------------------------------------- 1 | # Author: Youzhi Luo (yzluo@tamu.edu) 2 | # Updated by: Anmol Anand (aanand@tamu.edu) 3 | 4 | import argparse 5 | from examples.auggraph.GraphAug.conf.reward_gen_conf import reward_gen_conf 6 | from dig.auggraph.method.GraphAug.runner_reward_gen import RunnerRewardGen 7 | from dig.auggraph.method.GraphAug.constants import * 8 | 9 | dataset_name = DatasetName.IMDB_BINARY 10 | conf = reward_gen_conf[dataset_name] 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--data_root_path', type=str, default='../../../dig/auggraph/dataset/tudatasets', 13 | help='The directory with all graph datasets') 14 | parser.add_argument('--reward_gen_results_path', type=str, default='./results/reward_gen_results', 15 | 
help='The directory where reward gen states will be stored after each epoch.') 16 | args = parser.parse_args() 17 | 18 | runner_reward_gen = RunnerRewardGen(args.data_root_path, dataset_name, conf) 19 | runner_reward_gen.train_test(args.reward_gen_results_path) 20 | -------------------------------------------------------------------------------- /examples/auggraph/SMixup/run.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from dig.auggraph.method.SMixup.smixup import smixup 3 | 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument('--dataset', type=str, default='IMDBB', 6 | choices=['IMDBB','PROTEINS', 'MUTAG', "REDDITB", 'IMDBM', 'REDDITM5', 'REDDITM12', 'NCI1']) 7 | parser.add_argument('--GMNET_nlayers', type=int, default=5, 8 | help='Number of layers of GMNET') 9 | parser.add_argument('--GMNET_hidden', type=int, default=100, 10 | help='Number of hidden units of GMNET') 11 | parser.add_argument('--GMNET_bs', type=int, default=32, 12 | help='Batch size of training GMNET') 13 | parser.add_argument('--GMNET_lr', type=float, default=1e-3, 14 | help='Initial learning rate of GMNET') 15 | parser.add_argument('--GMNET_epochs', type=int, default=10, 16 | help='Number of epochs to train GMNET') 17 | parser.add_argument('--batch_size', type=int, default=32, 18 | help='Batch size during training') 19 | parser.add_argument('--model', type = str, default='GIN', 20 | choices=['GCN', 'GIN']) 21 | parser.add_argument('--nlayers', type = int, default = 4, 22 | help='Number of GNN layers.') 23 | parser.add_argument('--hidden', type=int, default=32, 24 | help='Number of hidden units.') 25 | parser.add_argument('--dropout', type = float, default = 0.2, 26 | help='Dropout ratio.') 27 | parser.add_argument('--lr', type=float, default=1e-3, 28 | help='Initial learning rate.') 29 | parser.add_argument('--epochs', type=int, default=10, 30 | help='Number of epochs to train the classifier.') 31 | 
parser.add_argument('--alpha', type = float, default = 1.0, 32 | help='mixup ratio.') 33 | parser.add_argument('--ckpt_path', type=str, default='../../../../test/ckpts/', 34 | help='Location for saving checkpoints') 35 | 36 | args = parser.parse_args() 37 | 38 | GMNET_conf = {} 39 | GMNET_conf['nlayers'] = args.GMNET_nlayers 40 | GMNET_conf['nhidden'] = args.GMNET_hidden 41 | GMNET_conf['bs'] = args.GMNET_bs 42 | GMNET_conf['lr'] = args.GMNET_lr 43 | GMNET_conf['epochs'] = args.GMNET_epochs 44 | 45 | runner = smixup('../../../../test/datasets', args.dataset, GMNET_conf) 46 | 47 | runner.train_test(args.batch_size, args.model, cls_nlayers=args.nlayers, 48 | cls_hidden=args.hidden, cls_dropout=args.dropout, cls_lr=args.lr, 49 | cls_epochs=args.epochs, alpha=args.alpha, ckpt_path=args.ckpt_path,) 50 | -------------------------------------------------------------------------------- /examples/fairgraph/Graphair/run_graphair_nba.py: -------------------------------------------------------------------------------- 1 | from dig.fairgraph.method import run 2 | from dig.fairgraph.dataset import POKEC, NBA 3 | import torch 4 | 5 | # Load the dataset 6 | nba = NBA() 7 | 8 | # Train and evaluate 9 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 10 | run_fairgraph = run() 11 | run_fairgraph.run(device,dataset=nba,model='Graphair',epochs=10000,test_epochs=500, 12 | lr=1e-3,weight_decay=1e-5) -------------------------------------------------------------------------------- /examples/fairgraph/Graphair/run_graphair_pokec.py: -------------------------------------------------------------------------------- 1 | from dig.fairgraph.method import run 2 | from dig.fairgraph.dataset import POKEC, NBA 3 | import torch 4 | 5 | # Load the dataset and split 6 | # pokec = POKEC(dataset_sample='pokec_z') # you may also choose 'pokec_n' 7 | pokec = POKEC(dataset_sample='pokec_n') 8 | 9 | # Train and evaluate 10 | device = torch.device("cuda" if torch.cuda.is_available() else 
"cpu") 11 | run_fairgraph = run() 12 | run_fairgraph.run(device,dataset=pokec,model='Graphair',epochs=10_000,test_epochs=500, 13 | lr=1e-3,weight_decay=1e-5) -------------------------------------------------------------------------------- /examples/ggraph/GraphAF/config/const_prop_opt_graphaf_config_dict.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": { 3 | "smile": "smiles", 4 | "prop_list": "['penalized_logp']", 5 | "url": "https://raw.githubusercontent.com/divelab/DIG_storage/main/ggraph/zinc_800_graphaf.csv", 6 | "num_max_node": "38", 7 | "atom_list": "[6, 7, 8, 9, 15, 16, 17, 35, 53]" 8 | }, 9 | "model": { 10 | "max_size": 38, 11 | "edge_unroll": 12, 12 | "node_dim": 9, 13 | "bond_dim": 4, 14 | "num_flow_layer": 12, 15 | "num_rgcn_layer": 3, 16 | "nhid": 128, 17 | "nout": 128, 18 | "deq_coeff": 0.9, 19 | "st_type": "exp", 20 | "use_gpu": true, 21 | "use_df": false, 22 | "rl_conf_dict": { 23 | "modify_size": 5, 24 | "penalty": true, 25 | "update_iters": 4, 26 | "reward_type": "imp", 27 | "reward_decay": 0.9, 28 | "exp_temperature": 3.0, 29 | "exp_bias": 4.0, 30 | "linear_coeff": 1.0, 31 | "plogp_coeff": 0.33333, 32 | "moving_coeff": 0.99, 33 | "no_baseline": true, 34 | "split_batch": false, 35 | "divide_loss": true, 36 | "atom_list": [ 37 | 6, 38 | 7, 39 | 8, 40 | 9, 41 | 15, 42 | 16, 43 | 17, 44 | 35, 45 | 53 46 | ], 47 | "temperature": 0.75, 48 | "batch_size": 64, 49 | "max_size_rl": 38 50 | } 51 | }, 52 | "lr": 0.0001, 53 | "weight_decay": 0, 54 | "batch_size": 64, 55 | "max_iters": 300, 56 | "warm_up": 24, 57 | "pretrain_model": "ckpt/dense_gen_net_10.pth", 58 | "dense_gen_model": "saved_ckpts/const_prop_opt/checkpoint277.pth", 59 | "save_interval": 20, 60 | "save_dir": "const_prop_opt_graphaf", 61 | "num_max_node": 38, 62 | "temperature": 0.75, 63 | "atom_list": [ 64 | 6, 65 | 7, 66 | 8, 67 | 9, 68 | 15, 69 | 16, 70 | 17, 71 | 35, 72 | 53 73 | ], 74 | "repeat_time": 200, 75 | "min_optim_time": 50 76 
| } -------------------------------------------------------------------------------- /examples/ggraph/GraphAF/config/prop_opt_plogp_config_dict.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": { 3 | "max_size": 38, 4 | "edge_unroll": 12, 5 | "node_dim": 9, 6 | "bond_dim": 4, 7 | "num_flow_layer": 12, 8 | "num_rgcn_layer": 3, 9 | "nhid": 128, 10 | "nout": 128, 11 | "deq_coeff": 0.9, 12 | "st_type": "exp", 13 | "use_gpu": true, 14 | "use_df": false, 15 | "rl_conf_dict": { 16 | "penalty": true, 17 | "update_iters": 4, 18 | "property_type": "plogp", 19 | "reward_type": "exp", 20 | "not_save_demon": true, 21 | "reward_decay": 0.9, 22 | "exp_temperature": 3.0, 23 | "exp_bias": 4.0, 24 | "linear_coeff": 2.0, 25 | "split_batch": false, 26 | "moving_coeff": 0.99, 27 | "no_baseline": true, 28 | "divide_loss": true, 29 | "atom_list": [ 30 | 6, 31 | 7, 32 | 8, 33 | 9, 34 | 15, 35 | 16, 36 | 17, 37 | 35, 38 | 53 39 | ], 40 | "temperature": 0.75, 41 | "batch_size": 8, 42 | "max_size_rl": 48 43 | } 44 | }, 45 | "lr": 0.0001, 46 | "weight_decay": 0, 47 | "max_iters": 200, 48 | "warm_up": 0, 49 | "pretrain_model": "ckpt/dense_gen_net_10.pth", 50 | "dense_gen_model": "saved_ckpts/prop_opt/dense_gen_net_10.pth", 51 | "save_interval": 20, 52 | "save_dir": "prop_opt_plogp", 53 | "num_min_node": 1, 54 | "num_max_node": 48, 55 | "temperature": 0.75, 56 | "atom_list": [ 57 | 6, 58 | 7, 59 | 8, 60 | 9, 61 | 15, 62 | 16, 63 | 17, 64 | 35, 65 | 53 66 | ] 67 | } -------------------------------------------------------------------------------- /examples/ggraph/GraphAF/config/prop_opt_qed_config_dict.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": { 3 | "max_size": 38, 4 | "edge_unroll": 12, 5 | "node_dim": 9, 6 | "bond_dim": 4, 7 | "num_flow_layer": 12, 8 | "num_rgcn_layer": 3, 9 | "nhid": 128, 10 | "nout": 128, 11 | "deq_coeff": 0.9, 12 | "st_type": "exp", 13 | "use_gpu": true, 14 
| "use_df": false, 15 | "rl_conf_dict": { 16 | "penalty": true, 17 | "update_iters": 4, 18 | "property_type": "qed", 19 | "reward_type": "linear", 20 | "not_save_demon": true, 21 | "reward_decay": 0.97, 22 | "exp_temperature": 3.0, 23 | "exp_bias": 4.0, 24 | "linear_coeff": 2.0, 25 | "split_batch": false, 26 | "moving_coeff": 0.99, 27 | "no_baseline": true, 28 | "divide_loss": true, 29 | "atom_list": [ 30 | 6, 31 | 7, 32 | 8, 33 | 9, 34 | 15, 35 | 16, 36 | 17, 37 | 35, 38 | 53 39 | ], 40 | "temperature": 0.75, 41 | "batch_size": 8, 42 | "max_size_rl": 48 43 | } 44 | }, 45 | "lr": 0.0001, 46 | "weight_decay": 0, 47 | "max_iters": 200, 48 | "warm_up": 0, 49 | "pretrain_model": "ckpt/dense_gen_net_10.pth", 50 | "dense_gen_model": "saved_ckpts/prop_opt/dense_gen_net_10.pth", 51 | "save_interval": 20, 52 | "save_dir": "prop_opt_qed", 53 | "num_min_node": 1, 54 | "num_max_node": 48, 55 | "temperature": 0.75, 56 | "atom_list": [ 57 | 6, 58 | 7, 59 | 8, 60 | 9, 61 | 15, 62 | 16, 63 | 17, 64 | 35, 65 | 53 66 | ] 67 | } -------------------------------------------------------------------------------- /examples/ggraph/GraphAF/config/rand_gen_qm9_config_dict.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": { 3 | "smile": "smile", 4 | "prop_list": "[]", 5 | "url": "https://raw.githubusercontent.com/divelab/DIG_storage/main/ggraph/qm9_property.csv", 6 | "num_max_node": "20", 7 | "num_bond_type": "3", 8 | "atom_list": "[6, 7, 8, 9]" 9 | }, 10 | "model": { 11 | "max_size": 20, 12 | "edge_unroll": 12, 13 | "node_dim": 4, 14 | "bond_dim": 4, 15 | "num_flow_layer": 12, 16 | "num_rgcn_layer": 3, 17 | "nhid": 128, 18 | "nout": 128, 19 | "deq_coeff": 0.9, 20 | "st_type": "exp", 21 | "use_gpu": true, 22 | "use_df": false 23 | }, 24 | "lr": 0.001, 25 | "weight_decay": 0, 26 | "batch_size": 32, 27 | "max_epochs": 10, 28 | "save_interval": 1, 29 | "save_dir": "rand_gen_qm9", 30 | "num_min_node": 7, 31 | "num_max_node": 25, 32 | 
"temperature": 0.6, 33 | "atom_list": [ 34 | 6, 35 | 7, 36 | 8, 37 | 9 38 | ] 39 | } -------------------------------------------------------------------------------- /examples/ggraph/GraphAF/config/rand_gen_zinc250k_config_dict.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": { 3 | "smile": "smile", 4 | "prop_list": "[]", 5 | "url": "https://raw.githubusercontent.com/divelab/DIG_storage/main/ggraph/zinc250k_property.csv", 6 | "num_max_node": "38", 7 | "num_bond_type": "3", 8 | "atom_list": "[6, 7, 8, 9, 15, 16, 17, 35, 53]" 9 | }, 10 | "model": { 11 | "max_size": 38, 12 | "edge_unroll": 12, 13 | "node_dim": 9, 14 | "bond_dim": 4, 15 | "num_flow_layer": 12, 16 | "num_rgcn_layer": 3, 17 | "nhid": 128, 18 | "nout": 128, 19 | "deq_coeff": 0.9, 20 | "st_type": "exp", 21 | "use_gpu": true, 22 | "use_df": false 23 | }, 24 | "lr": 0.001, 25 | "weight_decay": 0, 26 | "batch_size": 32, 27 | "max_epochs": 10, 28 | "save_interval": 1, 29 | "save_dir": "rand_gen_zinc250k", 30 | "num_min_node": 7, 31 | "num_max_node": 25, 32 | "temperature": 0.6, 33 | "atom_list": [ 34 | 6, 35 | 7, 36 | 8, 37 | 9, 38 | 15, 39 | 16, 40 | 17, 41 | 35, 42 | 53 43 | ] 44 | } -------------------------------------------------------------------------------- /examples/ggraph/GraphAF/figs/graphaf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divelab/DIG/21476b079c9226f38915dcd082b5c2ee0cddaac8/examples/ggraph/GraphAF/figs/graphaf.png -------------------------------------------------------------------------------- /examples/ggraph/GraphAF/run_const_prop_opt.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | from rdkit import RDLogger 4 | from torch_geometric.loader import DenseDataLoader 5 | from dig.ggraph.method import GraphAF 6 | from dig.ggraph.evaluation import ConstPropOptEvaluator 7 | from 
dig.ggraph.dataset import ZINC800 8 | 9 | 10 | 11 | RDLogger.DisableLog('rdApp.*') 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--data', type=str, default='graphaf', choices=['graphaf'], help='dataset name') 15 | parser.add_argument('--model_path', type=str, default='./saved_ckpts/const_prop_opt/const_prop_opt_graphaf.pth', help='The path to the saved model file') 16 | parser.add_argument('--train', action='store_true', default=False, help='specify it to be true if you are running training') 17 | 18 | args = parser.parse_args() 19 | 20 | if args.data == 'graphaf': 21 | with open('config/cons_optim_graphaf_config_dict.json') as f: 22 | conf = json.load(f) 23 | dataset = ZINC800(method='graphaf', conf_dict=conf['data'], one_shot=False, use_aug=False) 24 | else: 25 | print('Only graphaf datasets are supported!') 26 | exit() 27 | 28 | runner = GraphAF() 29 | 30 | if args.train: 31 | loader = DenseDataLoader(dataset, batch_size=conf['batch_size'], shuffle=True) 32 | runner.train_cons_optim(loader, conf['lr'], conf['weight_decay'], conf['max_iters'], conf['warm_up'], conf['model'], conf['pretrain_model'], conf['save_interval'], conf['save_dir']) 33 | else: 34 | mols_0, mols_2, mols_4, mols_6 = runner.run_cons_optim(dataset, conf['model'], args.model_path, conf['repeat_time'], conf['min_optim_time'], conf['num_max_node'], conf['temperature'], conf['atom_list']) 35 | smiles = [data.smile for data in dataset] 36 | evaluator = ConstPropOptEvaluator() 37 | input_dict = {'mols_0': mols_0, 'mols_2': mols_2, 'mols_4': mols_4, 'mols_6': mols_6, 'inp_smiles':smiles} 38 | 39 | print('Evaluating...') 40 | results = evaluator.eval(input_dict) 41 | -------------------------------------------------------------------------------- /examples/ggraph/GraphAF/run_prop_opt.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | from rdkit import RDLogger 4 | from dig.ggraph.method import GraphAF 5 | 
from dig.ggraph.evaluation import PropOptEvaluator 6 | 7 | RDLogger.DisableLog('rdApp.*') 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--prop', type=str, default='plogp', choices=['plogp', 'qed'], help='property name') 11 | parser.add_argument('--model_path', type=str, default='./saved_ckpts/prop_opt/prop_opt_plogp.pth', help='The path to the saved model file') 12 | parser.add_argument('--num_mols', type=int, default=100, help='The number of molecules to be generated') 13 | parser.add_argument('--train', action='store_true', default=False, help='specify it to be true if you are running training') 14 | 15 | args = parser.parse_args() 16 | 17 | if args.prop == 'plogp': 18 | with open('config/prop_optim_plogp_config_dict.json') as f: 19 | conf = json.load(f) 20 | elif args.prop == 'qed': 21 | with open('config/prop_optim_qed_config_dict.json') as f: 22 | conf = json.load(f) 23 | else: 24 | print('Only plogp and qed properties are supported!') 25 | exit() 26 | 27 | runner = GraphAF() 28 | 29 | if args.train: 30 | runner.train_prop_optim(conf['lr'], conf['weight_decay'], conf['max_iters'], conf['warm_up'], conf['model'], conf['pretrain_model'], conf['save_interval'], conf['save_dir']) 31 | else: 32 | mols = runner.run_prop_optim(conf['model'], args.model_path, args.num_mols, conf['num_min_node'], conf['num_max_node'], conf['temperature'], conf['atom_list']) 33 | evaluator = PropOptEvaluator(prop_name=args.prop) 34 | input_dict = {'mols': mols} 35 | 36 | print('Evaluating...') 37 | results = evaluator.eval(input_dict) 38 | -------------------------------------------------------------------------------- /examples/ggraph/GraphAF/run_rand_gen.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | from rdkit import RDLogger 4 | from torch_geometric.loader import DenseDataLoader 5 | from dig.ggraph.dataset import QM9, ZINC250k, MOSES 6 | from dig.ggraph.method import GraphAF 7 | from 
dig.ggraph.evaluation import RandGenEvaluator 8 | 9 | 10 | RDLogger.DisableLog('rdApp.*') 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--data', type=str, default='qm9', choices=['qm9', 'zinc250k'], help='dataset name') 14 | parser.add_argument('--model_path', type=str, default='./saved_ckpts/rand_gen/rand_gen_qm9.pth', help='The path to the saved model file') 15 | parser.add_argument('--num_mols', type=int, default=100, help='The number of molecules to be generated') 16 | parser.add_argument('--train', action='store_true', default=False, help='specify it to be true if you are running training') 17 | 18 | args = parser.parse_args() 19 | 20 | if args.data == 'qm9': 21 | with open('config/rand_gen_qm9_config_dict.json') as f: 22 | conf = json.load(f) 23 | dataset = QM9(conf_dict=conf['data'], one_shot=False, use_aug=True) 24 | elif args.data == 'zinc250k': 25 | with open('config/rand_gen_zinc250k_config_dict.json') as f: 26 | conf = json.load(f) 27 | dataset = ZINC250k(conf_dict=conf['data'], one_shot=False, use_aug=True) 28 | else: 29 | print("Only qm9 and zinc250k datasets are supported!") 30 | exit() 31 | 32 | runner = GraphAF() 33 | 34 | if args.train: 35 | loader = DenseDataLoader(dataset, batch_size=conf['batch_size'], shuffle=True) 36 | runner.train_rand_gen(loader, conf['lr'], conf['weight_decay'], conf['max_epochs'], conf['model'], conf['save_interval'], conf['save_dir']) 37 | else: 38 | mols, pure_valids = runner.run_rand_gen(conf['model'], args.model_path, args.num_mols, conf['num_min_node'], conf['num_max_node'], conf['temperature'], conf['atom_list']) 39 | smiles = [data.smile for data in dataset] 40 | evaluator = RandGenEvaluator() 41 | input_dict = {'mols': mols, 'train_smiles': smiles} 42 | 43 | print('Evaluating...') 44 | results = evaluator.eval(input_dict) 45 | 46 | print("Valid Ratio without valency check: {:.2f}%".format(sum(pure_valids) / args.num_mols * 100)) 47 | 
-------------------------------------------------------------------------------- /examples/ggraph/GraphDF/config/const_prop_opt_graphaf_config_dict.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": { 3 | "smile": "smiles", 4 | "prop_list": "['penalized_logp']", 5 | "url": "https://raw.githubusercontent.com/divelab/DIG_storage/main/ggraph/zinc_800_graphaf.csv", 6 | "num_max_node": "38", 7 | "atom_list": "[6, 7, 8, 9, 15, 16, 17, 35, 53]" 8 | }, 9 | "model": { 10 | "max_size": 38, 11 | "edge_unroll": 12, 12 | "node_dim": 9, 13 | "bond_dim": 4, 14 | "num_flow_layer": 12, 15 | "num_rgcn_layer": 3, 16 | "nhid": 128, 17 | "nout": 128, 18 | "use_gpu": true, 19 | "rl_conf_dict": { 20 | "modify_size": 5, 21 | "penalty": true, 22 | "update_iters": 4, 23 | "reward_type": "imp", 24 | "reward_decay": 0.9, 25 | "exp_temperature": 3.0, 26 | "exp_bias": 4.0, 27 | "linear_coeff": 1.0, 28 | "moving_coeff": 0.99, 29 | "no_baseline": true, 30 | "atom_list": [ 31 | 6, 32 | 7, 33 | 8, 34 | 9, 35 | 15, 36 | 16, 37 | 17, 38 | 35, 39 | 53 40 | ], 41 | "temperature": [ 42 | 1.0, 43 | 1.0 44 | ], 45 | "batch_size": 16, 46 | "max_size_rl": 38 47 | } 48 | }, 49 | "lr": 0.0001, 50 | "weight_decay": 0, 51 | "batch_size": 16, 52 | "max_iters": 200, 53 | "warm_up": 24, 54 | "pretrain_model": "saved_ckpts/const_prop_opt/pretrain_graphaf.pth", 55 | "save_interval": 20, 56 | "save_dir": "const_prop_opt_graphaf", 57 | "num_max_node": 38, 58 | "temperature": [ 59 | 1.0, 60 | 1.0 61 | ], 62 | "atom_list": [ 63 | 6, 64 | 7, 65 | 8, 66 | 9, 67 | 15, 68 | 16, 69 | 17, 70 | 35, 71 | 53 72 | ], 73 | "repeat_time": 200, 74 | "min_optim_time": 50 75 | } -------------------------------------------------------------------------------- /examples/ggraph/GraphDF/config/const_prop_opt_jt_config_dict.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": { 3 | "smile": "smiles", 4 | "prop_list": 
"['penalized_logp']", 5 | "url": "https://raw.githubusercontent.com/divelab/DIG_storage/main/ggraph/zinc_800_jt.csv", 6 | "num_max_node": "38", 7 | "atom_list": "[6, 7, 8, 9, 15, 16, 17, 35, 53]" 8 | }, 9 | "model": { 10 | "max_size": 38, 11 | "edge_unroll": 12, 12 | "node_dim": 9, 13 | "bond_dim": 4, 14 | "num_flow_layer": 12, 15 | "num_rgcn_layer": 3, 16 | "nhid": 128, 17 | "nout": 128, 18 | "use_gpu": true, 19 | "rl_conf_dict": { 20 | "modify_size": 5, 21 | "penalty": true, 22 | "update_iters": 4, 23 | "reward_type": "imp", 24 | "reward_decay": 0.9, 25 | "exp_temperature": 3.0, 26 | "exp_bias": 4.0, 27 | "linear_coeff": 1.0, 28 | "moving_coeff": 0.99, 29 | "no_baseline": true, 30 | "atom_list": [ 31 | 6, 32 | 7, 33 | 8, 34 | 9, 35 | 15, 36 | 16, 37 | 17, 38 | 35, 39 | 53 40 | ], 41 | "temperature": [ 42 | 1.0, 43 | 1.0 44 | ], 45 | "batch_size": 16, 46 | "max_size_rl": 38 47 | } 48 | }, 49 | "lr": 0.0001, 50 | "weight_decay": 0, 51 | "batch_size": 16, 52 | "max_iters": 200, 53 | "warm_up": 24, 54 | "pretrain_model": "saved_ckpts/const_prop_opt/pretrain_jt.pth", 55 | "save_interval": 20, 56 | "save_dir": "const_prop_opt_jt", 57 | "num_max_node": 38, 58 | "temperature": [ 59 | 1.0, 60 | 1.0 61 | ], 62 | "atom_list": [ 63 | 6, 64 | 7, 65 | 8, 66 | 9, 67 | 15, 68 | 16, 69 | 17, 70 | 35, 71 | 53 72 | ], 73 | "repeat_time": 200, 74 | "min_optim_time": 50 75 | } -------------------------------------------------------------------------------- /examples/ggraph/GraphDF/config/prop_opt_plogp_config_dict.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": { 3 | "max_size": 48, 4 | "edge_unroll": 12, 5 | "node_dim": 9, 6 | "bond_dim": 4, 7 | "num_flow_layer": 12, 8 | "num_rgcn_layer": 3, 9 | "nhid": 128, 10 | "nout": 128, 11 | "use_gpu": true, 12 | "rl_conf_dict": { 13 | "penalty": true, 14 | "update_iters": 4, 15 | "property_type": "plogp", 16 | "reward_type": "exp", 17 | "not_save_demon": true, 18 | "reward_decay": 0.9, 
19 | "exp_temperature": 3.0, 20 | "exp_bias": 4.0, 21 | "linear_coeff": 2.0, 22 | "split_batch": false, 23 | "moving_coeff": 0.99, 24 | "no_baseline": true, 25 | "atom_list": [ 26 | 6, 27 | 7, 28 | 8, 29 | 9, 30 | 15, 31 | 16, 32 | 17, 33 | 35, 34 | 53 35 | ], 36 | "temperature": [ 37 | 0.8, 38 | 0.1 39 | ], 40 | "batch_size": 8, 41 | "max_size_rl": 48 42 | } 43 | }, 44 | "lr": 0.0001, 45 | "weight_decay": 0, 46 | "max_iters": 200, 47 | "warm_up": 0, 48 | "pretrain_model": "saved_ckpts/prop_opt/pretrain_plogp.pth", 49 | "save_interval": 20, 50 | "save_dir": "prop_opt_plogp", 51 | "num_min_node": 1, 52 | "num_max_node": 48, 53 | "temperature": [ 54 | 0.8, 55 | 0.1 56 | ], 57 | "atom_list": [ 58 | 6, 59 | 7, 60 | 8, 61 | 9, 62 | 15, 63 | 16, 64 | 17, 65 | 35, 66 | 53 67 | ] 68 | } -------------------------------------------------------------------------------- /examples/ggraph/GraphDF/config/prop_opt_qed_config_dict.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": { 3 | "max_size": 48, 4 | "edge_unroll": 12, 5 | "node_dim": 9, 6 | "bond_dim": 4, 7 | "num_flow_layer": 12, 8 | "num_rgcn_layer": 3, 9 | "nhid": 128, 10 | "nout": 128, 11 | "use_gpu": true, 12 | "rl_conf_dict": { 13 | "penalty": true, 14 | "update_iters": 4, 15 | "property_type": "qed", 16 | "reward_type": "linear", 17 | "not_save_demon": true, 18 | "reward_decay": 0.9, 19 | "exp_temperature": 3.0, 20 | "exp_bias": 4.0, 21 | "linear_coeff": 2.0, 22 | "split_batch": false, 23 | "moving_coeff": 0.99, 24 | "no_baseline": true, 25 | "atom_list": [ 26 | 6, 27 | 7, 28 | 8, 29 | 9, 30 | 15, 31 | 16, 32 | 17, 33 | 35, 34 | 53 35 | ], 36 | "temperature": [ 37 | 0.8, 38 | 0.1 39 | ], 40 | "batch_size": 8, 41 | "max_size_rl": 48 42 | } 43 | }, 44 | "lr": 0.0001, 45 | "weight_decay": 0, 46 | "max_iters": 200, 47 | "warm_up": 0, 48 | "pretrain_model": "saved_ckpts/prop_opt/pretrain_qed.pth", 49 | "save_interval": 20, 50 | "save_dir": "prop_opt_qed", 51 | 
"num_min_node": 1, 52 | "num_max_node": 48, 53 | "temperature": [ 54 | 0.8, 55 | 0.1 56 | ], 57 | "atom_list": [ 58 | 6, 59 | 7, 60 | 8, 61 | 9, 62 | 15, 63 | 16, 64 | 17, 65 | 35, 66 | 53 67 | ] 68 | } -------------------------------------------------------------------------------- /examples/ggraph/GraphDF/config/rand_gen_moses_config_dict.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": { 3 | "smile": "smiles", 4 | "prop_list": "[]", 5 | "url": "https://raw.githubusercontent.com/divelab/DIG_storage/main/ggraph/moses.csv", 6 | "num_max_node": "30", 7 | "atom_list": "[6, 7, 8, 9, 16, 17, 35]" 8 | }, 9 | "model": { 10 | "max_size": 30, 11 | "edge_unroll": 12, 12 | "node_dim": 7, 13 | "bond_dim": 4, 14 | "num_flow_layer": 12, 15 | "num_rgcn_layer": 3, 16 | "nhid": 128, 17 | "nout": 128, 18 | "use_gpu": true 19 | }, 20 | "lr": 0.001, 21 | "weight_decay": 0, 22 | "batch_size": 32, 23 | "max_epochs": 10, 24 | "save_interval": 1, 25 | "save_dir": "rand_gen_moses", 26 | "num_min_node": 7, 27 | "num_max_node": 25, 28 | "temperature": [ 29 | 0.3, 30 | 0.3 31 | ], 32 | "atom_list": [ 33 | 6, 34 | 7, 35 | 8, 36 | 9, 37 | 16, 38 | 17, 39 | 35 40 | ] 41 | } -------------------------------------------------------------------------------- /examples/ggraph/GraphDF/config/rand_gen_qm9_config_dict.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": { 3 | "smile": "smile", 4 | "prop_list": "[]", 5 | "url": "https://raw.githubusercontent.com/divelab/DIG_storage/main/ggraph/qm9_property.csv", 6 | "num_max_node": "20", 7 | "atom_list": "[6, 7, 8, 9]" 8 | }, 9 | "model": { 10 | "max_size": 20, 11 | "edge_unroll": 12, 12 | "node_dim": 4, 13 | "bond_dim": 4, 14 | "num_flow_layer": 12, 15 | "num_rgcn_layer": 3, 16 | "nhid": 128, 17 | "nout": 128, 18 | "use_gpu": true 19 | }, 20 | "lr": 0.001, 21 | "weight_decay": 0, 22 | "batch_size": 32, 23 | "max_epochs": 10, 24 | "save_interval": 
1, 25 | "save_dir": "rand_gen_qm9", 26 | "num_min_node": 7, 27 | "num_max_node": 20, 28 | "temperature": [ 29 | 0.35, 30 | 0.23 31 | ], 32 | "atom_list": [ 33 | 6, 34 | 7, 35 | 8, 36 | 9 37 | ] 38 | } 39 | -------------------------------------------------------------------------------- /examples/ggraph/GraphDF/config/rand_gen_zinc250k_config_dict.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": { 3 | "smile": "smile", 4 | "prop_list": "[]", 5 | "url": "https://raw.githubusercontent.com/divelab/DIG_storage/main/ggraph/zinc250k_property.csv", 6 | "num_max_node": "38", 7 | "atom_list": "[6, 7, 8, 9, 15, 16, 17, 35, 53]" 8 | }, 9 | "model": { 10 | "max_size": 38, 11 | "edge_unroll": 12, 12 | "node_dim": 9, 13 | "bond_dim": 4, 14 | "num_flow_layer": 12, 15 | "num_rgcn_layer": 3, 16 | "nhid": 128, 17 | "nout": 128, 18 | "use_gpu": true 19 | }, 20 | "lr": 0.001, 21 | "weight_decay": 0, 22 | "batch_size": 32, 23 | "max_epochs": 10, 24 | "save_interval": 1, 25 | "save_dir": "rand_gen_zinc250k", 26 | "num_min_node": 7, 27 | "num_max_node": 25, 28 | "temperature": [ 29 | 0.35, 30 | 0.2 31 | ], 32 | "atom_list": [ 33 | 6, 34 | 7, 35 | 8, 36 | 9, 37 | 15, 38 | 16, 39 | 17, 40 | 35, 41 | 53 42 | ] 43 | } -------------------------------------------------------------------------------- /examples/ggraph/GraphDF/figs/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divelab/DIG/21476b079c9226f38915dcd082b5c2ee0cddaac8/examples/ggraph/GraphDF/figs/.DS_Store -------------------------------------------------------------------------------- /examples/ggraph/GraphDF/figs/graphdf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divelab/DIG/21476b079c9226f38915dcd082b5c2ee0cddaac8/examples/ggraph/GraphDF/figs/graphdf.png 
-------------------------------------------------------------------------------- /examples/ggraph/GraphDF/run_const_prop_opt.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | from rdkit import RDLogger 4 | from torch_geometric.loader import DenseDataLoader 5 | from dig.ggraph.method import GraphDF 6 | from dig.ggraph.evaluation import ConstPropOptEvaluator 7 | from dig.ggraph.dataset import ZINC800 8 | 9 | RDLogger.DisableLog('rdApp.*') 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--data', type=str, default='graphaf', choices=['graphaf', 'jt'], help='dataset name') 13 | parser.add_argument('--model_path', type=str, default='./saved_ckpts/const_prop_opt/const_prop_opt_graphaf.pth', help='The path to the saved model file') 14 | parser.add_argument('--train', action='store_true', default=False, help='specify it to be true if you are running training') 15 | 16 | args = parser.parse_args() 17 | 18 | if args.data == 'graphaf': 19 | with open('config/const_prop_opt_graphaf_config_dict.json') as f: 20 | conf = json.load(f) 21 | dataset = ZINC800(method='graphaf', conf_dict=conf['data'], one_shot=False, use_aug=False) 22 | elif args.data == 'jt': 23 | with open('config/const_prop_opt_jt_config_dict.json') as f: 24 | conf = json.load(f) 25 | dataset = ZINC800(method='jt', conf_dict=conf['data'], one_shot=False, use_aug=False) 26 | else: 27 | print('Only graphaf and jt datasets are supported!') 28 | exit() 29 | 30 | runner = GraphDF() 31 | 32 | if args.train: 33 | loader = DenseDataLoader(dataset, batch_size=conf['batch_size'], shuffle=True) 34 | runner.train_const_prop_opt(loader, conf['lr'], conf['weight_decay'], conf['max_iters'], conf['warm_up'], conf['model'], conf['pretrain_model'], conf['save_interval'], conf['save_dir']) 35 | else: 36 | mols_0, mols_2, mols_4, mols_6 = runner.run_const_prop_opt(dataset, conf['model'], args.model_path, conf['repeat_time'], 
conf['min_optim_time'], conf['num_max_node'], conf['temperature'], conf['atom_list']) 37 | smiles = [data.smile for data in dataset] 38 | evaluator = ConstPropOptEvaluator() 39 | input_dict = {'mols_0': mols_0, 'mols_2': mols_2, 'mols_4': mols_4, 'mols_6': mols_6, 'inp_smiles':smiles} 40 | 41 | print('Evaluating...') 42 | results = evaluator.eval(input_dict) -------------------------------------------------------------------------------- /examples/ggraph/GraphDF/run_prop_opt.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | from rdkit import RDLogger 4 | from dig.ggraph.method import GraphDF 5 | from dig.ggraph.evaluation import PropOptEvaluator 6 | 7 | RDLogger.DisableLog('rdApp.*') 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--prop', type=str, default='plogp', choices=['plogp', 'qed'], help='property name') 11 | parser.add_argument('--model_path', type=str, default='./saved_ckpts/prop_opt/prop_opt_plogp.pth', help='The path to the saved model file') 12 | parser.add_argument('--num_mols', type=int, default=100, help='The number of molecules to be generated') 13 | parser.add_argument('--train', action='store_true', default=False, help='specify it to be true if you are running training') 14 | 15 | args = parser.parse_args() 16 | 17 | if args.prop == 'plogp': 18 | with open('config/prop_opt_plogp_config_dict.json') as f: 19 | conf = json.load(f) 20 | elif args.prop == 'qed': 21 | with open('config/prop_opt_qed_config_dict.json') as f: 22 | conf = json.load(f) 23 | else: 24 | print('Only plogp and qed properties are supported!') 25 | exit() 26 | 27 | runner = GraphDF() 28 | 29 | if args.train: 30 | runner.train_prop_opt(conf['lr'], conf['weight_decay'], conf['max_iters'], conf['warm_up'], conf['model'], conf['pretrain_model'], conf['save_interval'], conf['save_dir']) 31 | else: 32 | mols = runner.run_prop_opt(conf['model'], args.model_path, args.num_mols, 
conf['num_min_node'], conf['num_max_node'], conf['temperature'], conf['atom_list']) 33 | evaluator = PropOptEvaluator() 34 | input_dict = {'mols': mols} 35 | 36 | print('Evaluating...') 37 | results = evaluator.eval(input_dict) -------------------------------------------------------------------------------- /examples/ggraph/GraphDF/run_rand_gen.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | from rdkit import RDLogger 4 | from torch_geometric.loader import DenseDataLoader 5 | from dig.ggraph.dataset import QM9, ZINC250k, MOSES 6 | from dig.ggraph.method import GraphDF 7 | from dig.ggraph.evaluation import RandGenEvaluator 8 | 9 | RDLogger.DisableLog('rdApp.*') 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--data', type=str, default='qm9', choices=['qm9', 'zinc250k', 'moses'], help='dataset name') 13 | parser.add_argument('--model_path', type=str, default='./saved_ckpts/rand_gen/rand_gen_qm9.pth', help='The path to the saved model file') 14 | parser.add_argument('--num_mols', type=int, default=100, help='The number of molecules to be generated') 15 | parser.add_argument('--train', action='store_true', default=False, help='specify it to be true if you are running training') 16 | 17 | args = parser.parse_args() 18 | 19 | if args.data == 'qm9': 20 | with open('config/rand_gen_qm9_config_dict.json') as f: 21 | conf = json.load(f) 22 | dataset = QM9(conf_dict=conf['data'], one_shot=False, use_aug=True) 23 | elif args.data == 'zinc250k': 24 | with open('config/rand_gen_zinc250k_config_dict.json') as f: 25 | conf = json.load(f) 26 | dataset = ZINC250k(conf_dict=conf['data'], one_shot=False, use_aug=True) 27 | elif args.data == 'moses': 28 | with open('config/rand_gen_moses_config_dict.json') as f: 29 | conf = json.load(f) 30 | dataset = MOSES(conf_dict=conf['data'], one_shot=False, use_aug=True) 31 | else: 32 | print("Only qm9, zinc250k and moses datasets are supported!") 
33 | exit() 34 | 35 | runner = GraphDF() 36 | 37 | if args.train: 38 | loader = DenseDataLoader(dataset, batch_size=conf['batch_size'], shuffle=True) 39 | runner.train_rand_gen(loader, conf['lr'], conf['weight_decay'], conf['max_epochs'], conf['model'], conf['save_interval'], conf['save_dir']) 40 | else: 41 | mols, pure_valids = runner.run_rand_gen(conf['model'], args.model_path, args.num_mols, conf['num_min_node'], conf['num_max_node'], conf['temperature'], conf['atom_list']) 42 | smiles = [data.smile for data in dataset] 43 | evaluator = RandGenEvaluator() 44 | input_dict = {'mols': mols, 'train_smiles': smiles} 45 | 46 | print('Evaluating...') 47 | results = evaluator.eval(input_dict) 48 | 49 | print("Valid Ratio without valency check: {:.2f}%".format(sum(pure_valids) / args.num_mols * 100)) -------------------------------------------------------------------------------- /examples/ggraph/GraphEBM/README.md: -------------------------------------------------------------------------------- 1 | # GraphEBM 2 | 3 | This is an official implementation for [GraphEBM: Molecular Graph Generation with Energy-Based Models](https://arxiv.org/abs/2102.00546). 4 | 5 | ![](./figs/graphebm_training.png) 6 | 7 | 8 | ### Examples 9 | 10 | 1. Random Generation: `randn_gen.ipynb` 11 | 1. Goal-Directed Generation: `goal-directed_gen.ipynb` 12 | 1. Compositional Generation: `compositional_gen.ipynb` 13 | 14 | 15 | ### Citation 16 | ``` 17 | @article{liu2021graphebm, 18 | title={{GraphEBM}: Molecular Graph Generation with Energy-Based Models}, 19 | author={Meng Liu and Keqiang Yan and Bora Oztekin and Shuiwang Ji}, 20 | journal={arXiv preprint arXiv:2102.00546}, 21 | year={2021} 22 | } 23 | ``` 24 | 25 | ### Acknowledgement 26 | Our implementation is based on [MoFlow](https://github.com/calvin-zcx/moflow) and [IGEBM](https://github.com/rosinality/igebm-pytorch). Thanks a lot for their awesome works. 
27 | -------------------------------------------------------------------------------- /examples/ggraph/GraphEBM/figs/graphebm_training.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divelab/DIG/21476b079c9226f38915dcd082b5c2ee0cddaac8/examples/ggraph/GraphEBM/figs/graphebm_training.png -------------------------------------------------------------------------------- /examples/ggraph3D/G_SphereNet/README.md: -------------------------------------------------------------------------------- 1 | # G-SphereNet 2 | 3 | This is the official implementation for [An Autoregressive Flow Model for 3D Molecular Geometry Generation from Scratch](https://openreview.net/forum?id=C03Ajc-NS5W). 4 | 5 | ![](./figs/gspherenet.png) 6 | 7 | ## Usage 8 | 9 | ### Random Generation 10 | 11 | You can use our trained models or train the model from scratch: 12 | ```shell script 13 | $ cd examples/ggraph3D/G_SphereNet 14 | $ CUDA_VISIBLE_DEVICES=${your_gpu_id} python run_rand_gen.py --train 15 | ``` 16 | To generate molecular geometries using our trained model and evaluate the performance, first download models from [this link](https://github.com/divelab/DIG_storage/tree/main/ggraph3D/G-SphereNet), then: 17 | ```shell script 18 | $ cd examples/ggraph3D/G_SphereNet 19 | $ CUDA_VISIBLE_DEVICES=${your_gpu_id} python run_rand_gen.py --num_mols=1000 --model_path=${path_to_the_model} 20 | ``` 21 | 22 | Note that the chemical validity results in our paper were obtained with rdkit version [2020.03.3.0](https://anaconda.org/mjohnson541/rdkit). However, running the pip installation command of the dig package will automatically install the latest rdkit package, in which case you may get a lower chemical validity on molecular geometries generated by our trained model. 
23 | 24 | ### Targeted Molecule Discovery 25 | 26 | For targeted molecule discovery, we aim to generate molecular geometries with desirable properties (*i.e.*, low HOMO-LUMO gap or high isotropic polarizability in this work). You can use our trained models or train the model from scratch: 27 | ```shell script 28 | $ cd examples/ggraph3D/G_SphereNet 29 | $ CUDA_VISIBLE_DEVICES=${your_gpu_id} python run_prop_opt.py --train --prop=gap 30 | $ CUDA_VISIBLE_DEVICES=${your_gpu_id} python run_prop_opt.py --train --prop=alpha 31 | ``` 32 | 33 | To generate molecules using our trained model and evaluate the performance, first download models from [this link](https://github.com/divelab/DIG_storage/tree/main/ggraph3D/G-SphereNet), then: 34 | ```shell script 35 | $ cd examples/ggraph3D/G_SphereNet 36 | $ CUDA_VISIBLE_DEVICES=${your_gpu_id} python run_prop_opt.py --num_mols=100 --model_path=${path_to_the_model} --prop=gap 37 | $ CUDA_VISIBLE_DEVICES=${your_gpu_id} python run_prop_opt.py --num_mols=100 --model_path=${path_to_the_model} --prop=alpha 38 | ``` 39 | 40 | ### Citation 41 | ``` 42 | @inproceedings{ 43 | luo2022an, 44 | title={An Autoregressive Flow Model for 3D Molecular Geometry Generation from Scratch}, 45 | author={Youzhi Luo and Shuiwang Ji}, 46 | booktitle={International Conference on Learning Representations}, 47 | year={2022}, 48 | url={https://openreview.net/forum?id=C03Ajc-NS5W} 49 | } 50 | ``` 51 | -------------------------------------------------------------------------------- /examples/ggraph3D/G_SphereNet/config_dict.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": { 3 | "cutoff": 5.0, 4 | "num_node_types": 5, 5 | "num_layers": 4, 6 | "hidden_channels": 128, 7 | "int_emb_size": 64, 8 | "basis_emb_size": 8, 9 | "out_emb_channels": 256, 10 | "num_spherical": 7, 11 | "num_radial": 6, 12 | "num_flow_layers": 6, 13 | "deq_coeff": 0.9, 14 | "use_gpu": true, 15 | "n_att_heads": 4 16 | }, 17 | "lr": 0.0001, 18 | 
"weight_decay": 0.0, 19 | "save_interval": 1, 20 | "batch_size": 64, 21 | "max_epochs": 100, 22 | "chunk_size": 1000, 23 | "num_min_node": 2, 24 | "num_max_node": 35, 25 | "temperature": [ 26 | 0.5, 27 | 0.3, 28 | 0.4, 29 | 1.0 30 | ], 31 | "focus_th": 0.5 32 | } -------------------------------------------------------------------------------- /examples/ggraph3D/G_SphereNet/figs/gspherenet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divelab/DIG/21476b079c9226f38915dcd082b5c2ee0cddaac8/examples/ggraph3D/G_SphereNet/figs/gspherenet.png -------------------------------------------------------------------------------- /examples/ggraph3D/G_SphereNet/run_prop_opt.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | from torch.utils.data import DataLoader 4 | from dig.ggraph3D.method import G_SphereNet 5 | from dig.ggraph3D.evaluation import PropOptEvaluator 6 | from dig.ggraph3D.dataset import QM93DGEN, collate_fn 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--prop', type=str, default='gap', choices=['gap', 'alpha'], help='property name') 10 | parser.add_argument('--model_path', type=str, default='./G_SphereNet/prop_opt_gap.pth', help='The path to the saved model file') 11 | parser.add_argument('--num_mols', type=int, default=100, help='The number of molecule geometries to be generated') 12 | parser.add_argument('--train', action='store_true', default=False, help='specify it to be true if you are running training') 13 | 14 | args = parser.parse_args() 15 | 16 | with open('config_dict.json') as f: 17 | conf = json.load(f) 18 | 19 | runner = G_SphereNet() 20 | 21 | if args.train: 22 | dataset = QM93DGEN() 23 | idxs = dataset.get_idx_split('{}_opt'.format(args.prop)) 24 | train_set = dataset[idxs['train']] 25 | loader = DataLoader(train_set, batch_size=conf['batch_size'], shuffle=True, collate_fn=collate_fn) 
26 | runner.train(loader, lr=conf['lr'], wd=conf['weight_decay'], max_epochs=conf['max_epochs'], model_conf_dict=conf['model'], checkpoint_path=None, save_interval=conf['save_interval'], save_dir='prop_opt_{}'.format(args.prop)) 27 | else: 28 | mol_dicts = runner.generate(model_conf_dict=conf['model'], checkpoint_path=args.model_path, n_mols=args.num_mols, chunk_size=conf['chunk_size'], num_min_node=conf['num_min_node'], num_max_node=conf['num_max_node'], temperature=conf['temperature'], focus_th=conf['focus_th']) 29 | good_threshold = 4.5 if args.prop == 'gap' else 91 30 | evaluator = PropOptEvaluator(prop_name=args.prop, good_threshold=good_threshold) 31 | 32 | print('Evaluating...') 33 | results = evaluator.eval(mol_dicts) -------------------------------------------------------------------------------- /examples/ggraph3D/G_SphereNet/run_rand_gen.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pickle 3 | import argparse 4 | import torch 5 | from torch.utils.data import DataLoader 6 | from dig.ggraph3D.dataset import QM93DGEN, collate_fn 7 | from dig.ggraph3D.method import G_SphereNet 8 | from dig.ggraph3D.evaluation import RandGenEvaluator 9 | 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--model_path', type=str, default='./G_SphereNet/rand_gen.pth', help='The path to the saved model file') 13 | parser.add_argument('--num_mols', type=int, default=1000, help='The number of molecule geometries to be generated') 14 | parser.add_argument('--train', action='store_true', default=False, help='specify it to be true if you are running training') 15 | 16 | args = parser.parse_args() 17 | 18 | with open('config_dict.json') as f: 19 | conf = json.load(f) 20 | 21 | runner = G_SphereNet() 22 | 23 | if args.train: 24 | dataset = QM93DGEN() 25 | idxs = dataset.get_idx_split('rand_gen') 26 | train_set = dataset[idxs['train']] 27 | loader = DataLoader(train_set, batch_size=conf['batch_size'], 
shuffle=True, collate_fn=collate_fn) 28 | runner.train(loader, lr=conf['lr'], wd=conf['weight_decay'], max_epochs=conf['max_epochs'], model_conf_dict=conf['model'], checkpoint_path=None, save_interval=conf['save_interval'], save_dir='rand_gen') 29 | else: 30 | with torch.no_grad(): 31 | mol_dicts = runner.generate(model_conf_dict=conf['model'], checkpoint_path=args.model_path, n_mols=args.num_mols, chunk_size=conf['chunk_size'], num_min_node=conf['num_min_node'], num_max_node=conf['num_max_node'], temperature=conf['temperature'], focus_th=conf['focus_th']) 32 | evaluator = RandGenEvaluator() 33 | 34 | print('Evaluating chemical validity...') 35 | results = evaluator.eval_validity(mol_dicts) 36 | 37 | print('Evaluating MMD distances of bond length distributions...') 38 | with open('target_bond_lengths.dict','rb') as f: 39 | target_bond_dists = pickle.load(f) 40 | input_dict = {'mol_dicts': mol_dicts, 'target_bond_dists': target_bond_dists} 41 | results = evaluator.eval_bond_mmd(input_dict) 42 | -------------------------------------------------------------------------------- /examples/ggraph3D/G_SphereNet/target_bond_lengths.dict: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divelab/DIG/21476b079c9226f38915dcd082b5c2ee0cddaac8/examples/ggraph3D/G_SphereNet/target_bond_lengths.dict -------------------------------------------------------------------------------- /examples/lsgraph/GraphFMOB/conf/config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - model: pna 3 | - dataset: yelp 4 | device: 7 5 | root: '/data/haiyang/datasets' 6 | log_every: 1 7 | -------------------------------------------------------------------------------- /examples/lsgraph/GraphFMOB/conf/dataset/amazon.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | name: amazon 
-------------------------------------------------------------------------------- /examples/lsgraph/GraphFMOB/conf/dataset/arxiv.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | name: arxiv 3 | -------------------------------------------------------------------------------- /examples/lsgraph/GraphFMOB/conf/dataset/flickr.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | name: flickr 3 | -------------------------------------------------------------------------------- /examples/lsgraph/GraphFMOB/conf/dataset/ppi.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | name: ppi 3 | -------------------------------------------------------------------------------- /examples/lsgraph/GraphFMOB/conf/dataset/products.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | name: products 3 | -------------------------------------------------------------------------------- /examples/lsgraph/GraphFMOB/conf/dataset/reddit.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | name: reddit 3 | -------------------------------------------------------------------------------- /examples/lsgraph/GraphFMOB/conf/dataset/yelp.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | name: yelp 3 | -------------------------------------------------------------------------------- /examples/lsgraph/GraphFMOB/conf/model/gcn.yaml: -------------------------------------------------------------------------------- 1 | name: GCN 2 | norm: true 3 | loop: true 4 | params: 5 | # 0.9511 6 | reddit: 7 | architecture: 8 | num_layers: 2 9 | hidden_channels: 256 10 | dropout: 0.5 11 | drop_input: false 12 | batch_norm: false 13 | residual: false 14 | 
num_parts: 200 15 | batch_size: 100 16 | max_steps: 2 17 | pool_size: 2 18 | num_workers: 0 19 | lr: 0.05 20 | reg_weight_decay: 0.0 21 | nonreg_weight_decay: 0.0 22 | grad_norm: none 23 | epochs: 400 24 | gamma: 0.5 25 | 26 | # 0.5449 27 | flickr: 28 | architecture: 29 | num_layers: 2 30 | hidden_channels: 512 31 | dropout: 0.5 32 | drop_input: true 33 | batch_norm: true 34 | residual: false 35 | num_parts: 24 36 | batch_size: 12 37 | max_steps: 2 38 | pool_size: 2 39 | num_workers: 0 40 | lr: 0.01 41 | reg_weight_decay: 0 42 | nonreg_weight_decay: 0 43 | grad_norm: 2.0 44 | epochs: 1000 45 | gamma: 0.1 46 | 47 | # ... 48 | yelp: 49 | architecture: 50 | num_layers: 2 51 | hidden_channels: 512 52 | dropout: 0.0 53 | drop_input: false 54 | batch_norm: false 55 | residual: true 56 | linear: false 57 | num_parts: 40 58 | batch_size: 5 59 | max_steps: 4 60 | pool_size: 2 61 | num_workers: 0 62 | lr: 0.01 63 | reg_weight_decay: 0 64 | nonreg_weight_decay: 0 65 | grad_norm: null 66 | epochs: 500 67 | gamma: 0.3 68 | 69 | # 0.7171 70 | arxiv: 71 | architecture: 72 | num_layers: 3 73 | hidden_channels: 256 74 | dropout: 0.5 75 | drop_input: false 76 | batch_norm: true 77 | residual: false 78 | num_parts: 80 79 | batch_size: 40 80 | max_steps: 2 81 | pool_size: 2 82 | num_workers: 0 83 | lr: 0.01 84 | reg_weight_decay: 0 85 | nonreg_weight_decay: 0 86 | grad_norm: none 87 | epochs: 400 88 | gamma: 0.1 89 | 90 | # 91 | products: 92 | architecture: 93 | num_layers: 3 94 | hidden_channels: 256 95 | dropout: 0.3 96 | drop_input: false 97 | batch_norm: false 98 | residual: false 99 | num_parts: 7 100 | batch_size: 1 101 | max_steps: 4 102 | pool_size: 1 103 | num_workers: 0 104 | lr: 0.005 105 | reg_weight_decay: 0 106 | nonreg_weight_decay: 0 107 | grad_norm: 2.0 108 | epochs: 300 109 | gamma: 0.5 110 | -------------------------------------------------------------------------------- /examples/lsgraph/GraphFMOB/conf/model/pna.yaml: 
-------------------------------------------------------------------------------- 1 | name: PNA 2 | norm: false 3 | loop: false 4 | params: 5 | 6 | # 0.7296 7 | arxiv: 8 | architecture: 9 | num_layers: 3 10 | hidden_channels: 256 11 | aggregators: ['mean'] 12 | scalers: ['identity', 'amplification'] 13 | dropout: 0.5 14 | drop_input: false 15 | batch_norm: true 16 | residual: false 17 | num_parts: 40 18 | batch_size: 20 19 | max_steps: 2 20 | pool_size: 2 21 | num_workers: 0 22 | lr: 0.005 23 | reg_weight_decay: 0.0 24 | nonreg_weight_decay: 0.0 25 | grad_norm: null 26 | epochs: 300 27 | gamma: 0.3 28 | 29 | flickr: 30 | architecture: 31 | num_layers: 4 32 | hidden_channels: 64 33 | aggregators: ['mean', 'max'] 34 | scalers: ['identity', 'amplification'] 35 | dropout: 0.5 36 | drop_input: true 37 | batch_norm: true 38 | residual: false 39 | num_parts: 24 40 | batch_size: 12 41 | max_steps: 2 42 | pool_size: 2 43 | num_workers: 0 44 | lr: 0.005 45 | reg_weight_decay: 0 46 | nonreg_weight_decay: 0 47 | grad_norm: null 48 | epochs: 800 49 | gamma: 0.5 50 | 51 | # 0.6450 52 | yelp: 53 | architecture: 54 | num_layers: 3 55 | hidden_channels: 512 56 | aggregators: ['mean'] 57 | scalers: ['identity', 'amplification'] 58 | dropout: 0.1 59 | drop_input: false 60 | batch_norm: false 61 | residual: false 62 | num_parts: 40 63 | batch_size: 5 64 | max_steps: 4 65 | pool_size: 2 66 | num_workers: 0 67 | lr: 0.005 68 | reg_weight_decay: 0.0 69 | nonreg_weight_decay: 0.0 70 | grad_norm: 1.0 71 | epochs: 400 72 | gamma: 0.3 73 | -------------------------------------------------------------------------------- /examples/lsgraph/GraphFMOB/conf/model/pna_jk.yaml: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /imgs/DIG-logo.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/divelab/DIG/21476b079c9226f38915dcd082b5c2ee0cddaac8/imgs/DIG-logo.jpg -------------------------------------------------------------------------------- /imgs/DIG-overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/divelab/DIG/21476b079c9226f38915dcd082b5c2ee0cddaac8/imgs/DIG-overview.png -------------------------------------------------------------------------------- /script/conda.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ "${TRAVIS_OS_NAME}" = "linux" ]; then 4 | wget -nv https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh 5 | chmod +x miniconda.sh 6 | ./miniconda.sh -b 7 | PATH=/home/travis/miniconda3/bin:${PATH} 8 | fi 9 | 10 | conda update --yes conda 11 | 12 | conda create --yes -n test python="${PYTHON_VERSION}" 13 | 14 | if [ "${TRAVIS_OS_NAME}" = "linux" ]; then 15 | export TOOLKIT=cpuonly 16 | fi 17 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [aliases] 2 | test=pytest 3 | 4 | [tool:pytest] 5 | addopts = --capture=no --cov -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | from dig.version import __version__ 3 | 4 | with open("README.md", "r") as fh: 5 | long_description = fh.read() 6 | 7 | setup_requires = ['pytest-runner'] 8 | tests_require = ['pytest', 'pytest-cov', 'mock'] 9 | 10 | 11 | setuptools.setup( 12 | name="dive_into_graphs", 13 | version=__version__, 14 | author="DIVE Lab@TAMU", 15 | author_email="sji@tamu.edu", 16 | # entry_points={ 17 | # 'console_scripts': [ 18 | # 'dig=dig.xxx.xxx' 19 | # ] 20 | # }, 21 | maintainer="DIVE Lab@TAMU", 22 | # 
maintainer_email="xxx", 23 | license="GPLv3", 24 | # keywords="xxx", 25 | description="DIG: Dive into Graphs is a turnkey library for graph deep learning research.", 26 | long_description=long_description, 27 | long_description_content_type="text/markdown", 28 | url="https://github.com/divelab/DIG", 29 | packages=setuptools.find_packages(), 30 | classifiers=[ 31 | "Programming Language :: Python :: 3", 32 | "License :: OSI Approved :: GNU General Public License v3 (GPLv3)", 33 | "Operating System :: OS Independent", 34 | ], 35 | install_requires=['scipy', 36 | 'cilog', 37 | 'typed-argument-parser==1.7.2', 38 | 'captum==0.2.0', 39 | 'munch', 40 | 'gdown', 41 | 'shap', 42 | 'IPython', 43 | 'tqdm', 44 | 'rdkit-pypi', 45 | 'pandas', 46 | 'sympy', 47 | 'pyscf>=1.7.6', 48 | 'hydra-core'], 49 | python_requires='>=3.6', 50 | setup_requires=setup_requires, 51 | tests_require=tests_require, 52 | extras_require={'test': tests_require}, 53 | include_package_data=True 54 | ) 55 | -------------------------------------------------------------------------------- /test/ggraph/dataset/test_QM9.py: -------------------------------------------------------------------------------- 1 | from dig.ggraph.dataset import QM9 2 | import shutil 3 | 4 | def test_qm9(): 5 | root = './dataset/QM9' 6 | 7 | dataset = QM9(root, prop_name='penalized_logp') 8 | 9 | assert len(dataset) == 133885 10 | assert dataset.num_features == 4 11 | assert dataset.__repr__() == 'qm9_property(133885)' 12 | assert len(dataset.get_split_idx()) == 2 13 | 14 | assert len(dataset[0]) == 6 15 | assert dataset[0].x.size() == (9, 4) 16 | assert dataset[0].y.size() == (1,) 17 | assert dataset[0].adj.size() == (4, 9, 9) 18 | assert dataset[0].bfs_perm_origin.size() == (9,) 19 | assert dataset[0].num_atom.size() == (1,) 20 | 21 | dataset = QM9(root, one_shot=True, prop_name='penalized_logp') 22 | 23 | assert len(dataset) == 133885 24 | assert dataset.__repr__() == 'qm9_property(133885)' 25 | assert len(dataset.get_split_idx()) 
== 2 26 | 27 | assert len(dataset[0]) == 5 28 | assert dataset[0].x.size() == (5, 9) 29 | assert dataset[0].y.size() == (1,) 30 | assert dataset[0].adj.size() == (4, 9, 9) 31 | assert dataset[0].num_atom.size() == (1,) 32 | 33 | shutil.rmtree(root) 34 | 35 | if __name__ == '__main__': 36 | test_qm9() 37 | -------------------------------------------------------------------------------- /test/ggraph/dataset/test_ZINC250k.py: -------------------------------------------------------------------------------- 1 | from dig.ggraph.dataset import ZINC250k 2 | import shutil 3 | 4 | def test_zinc250k(): 5 | root = './dataset/ZINC250k' 6 | 7 | dataset = ZINC250k(root, prop_name='penalized_logp') 8 | 9 | assert len(dataset) == 249455 10 | assert dataset.num_features == 9 11 | assert dataset.__repr__() == 'zinc250k_property(249455)' 12 | assert len(dataset.get_split_idx()) == 2 13 | 14 | assert len(dataset[0]) == 6 15 | assert dataset[0].x.size() == (38, 9) 16 | assert dataset[0].y.size() == (1,) 17 | assert dataset[0].adj.size() == (4, 38, 38) 18 | assert dataset[0].bfs_perm_origin.size() == (38,) 19 | assert dataset[0].num_atom.size() == (1,) 20 | 21 | dataset = ZINC250k(root, one_shot=True, prop_name='penalized_logp') 22 | 23 | assert len(dataset) == 249455 24 | assert dataset.__repr__() == 'zinc250k_property(249455)' 25 | assert len(dataset.get_split_idx()) == 2 26 | 27 | assert len(dataset[0]) == 5 28 | assert dataset[0].x.size() == (10, 38) 29 | assert dataset[0].y.size() == (1,) 30 | assert dataset[0].adj.size() == (4, 38, 38) 31 | assert dataset[0].num_atom.size() == (1,) 32 | 33 | shutil.rmtree(root) 34 | 35 | if __name__ == '__main__': 36 | test_zinc250k() -------------------------------------------------------------------------------- /test/ggraph/dataset/test_ZINC800.py: -------------------------------------------------------------------------------- 1 | from dig.ggraph.dataset import ZINC800 2 | import shutil 3 | 4 | def test_zinc800(): 5 | root = './dataset/ZINC800' 
6 | 7 | dataset = ZINC800(root) 8 | 9 | assert len(dataset) == 800 10 | assert dataset.num_features == 9 11 | assert dataset.__repr__() == 'zinc_800_jt(800)' 12 | assert dataset.get_split_idx() is None 13 | 14 | assert len(dataset[0]) == 6 15 | assert dataset[0].x.size() == (38, 9) 16 | assert dataset[0].y.size() == (1,) 17 | assert dataset[0].adj.size() == (4, 38, 38) 18 | assert dataset[0].bfs_perm_origin.size() == (38,) 19 | assert dataset[0].num_atom.size() == (1,) 20 | 21 | dataset = ZINC800(root, one_shot=True) 22 | 23 | assert len(dataset) == 800 24 | assert dataset.__repr__() == 'zinc_800_jt(800)' 25 | assert dataset.get_split_idx() is None 26 | 27 | assert len(dataset[0]) == 5 28 | assert dataset[0].x.size() == (10, 38) 29 | assert dataset[0].y.size() == (1,) 30 | assert dataset[0].adj.size() == (4, 38, 38) 31 | assert dataset[0].num_atom.size() == (1,) 32 | 33 | shutil.rmtree(root) 34 | 35 | if __name__ == '__main__': 36 | test_zinc800() -------------------------------------------------------------------------------- /test/ggraph/evaluation/test_ConstPropOptEvaluator.py: -------------------------------------------------------------------------------- 1 | from dig.ggraph.evaluation import ConstPropOptEvaluator 2 | from rdkit import Chem 3 | import shutil 4 | 5 | def test_ConstPropOptEvaluator(): 6 | smile = 'C' 7 | mol = Chem.MolFromSmiles(smile) 8 | res_dict = {'inp_smiles': smile, 'mols_0':[mol], 'mols_2': [mol], 'mols_4': [mol], 'mols_6': [mol]} 9 | 10 | evaluator = ConstPropOptEvaluator() 11 | results = evaluator.eval(res_dict) 12 | 13 | assert results == {0: (100.0, 1.0, 0.0, 0.0, 0.0), 2: (100.0, 1.0, 0.0, 0.0, 0.0), 4: (100.0, 1.0, 0.0, 0.0, 0.0), 6: (100.0, 1.0, 0.0, 0.0, 0.0)} 14 | 15 | if __name__ == '__main__': 16 | test_ConstPropOptEvaluator() -------------------------------------------------------------------------------- /test/ggraph/evaluation/test_PropOptEvaluator.py: 
-------------------------------------------------------------------------------- 1 | from dig.ggraph.evaluation import PropOptEvaluator 2 | from rdkit import Chem 3 | import shutil 4 | 5 | def test_PropOptEvaluator(): 6 | smiles = ['C', 'N', 'O'] 7 | mols = [] 8 | for s in smiles: 9 | mol = Chem.MolFromSmiles(s) 10 | mols.append(mol) 11 | res_dict = {'mols':mols} 12 | evaluator = PropOptEvaluator() 13 | results = evaluator.eval(res_dict) 14 | 15 | assert results == {1: ('O', -5.496546478798415), 2: ('N', -5.767617318560561), 3: ('C', -6.229620227953575)} 16 | 17 | 18 | if __name__ == '__main__': 19 | test_PropOptEvaluator() -------------------------------------------------------------------------------- /test/ggraph/evaluation/test_RandGenEvaluator.py: -------------------------------------------------------------------------------- 1 | from dig.ggraph.evaluation import RandGenEvaluator 2 | from rdkit import Chem 3 | import shutil 4 | 5 | def test_RandGenEvaluator(): 6 | smile = 'CCCS(=O)c1ccc2[nH]c(=NC(=O)OC)[nH]c2c1' 7 | mol = Chem.MolFromSmiles(smile) 8 | res_dict = {'mols':[mol], 'train_smiles': [smile]} 9 | evaluator = RandGenEvaluator() 10 | results = evaluator.eval(res_dict) 11 | 12 | assert results == {'valid_ratio': 100.0, 'unique_ratio': 100.0, 'novel_ratio': 0.0} 13 | 14 | 15 | if __name__ == '__main__': 16 | test_RandGenEvaluator() -------------------------------------------------------------------------------- /test/ggraph/utils/test_environment.py: -------------------------------------------------------------------------------- 1 | from rdkit import Chem 2 | from dig.ggraph.utils import steric_strain_filter, zinc_molecule_filter, convert_radical_electrons_to_hydrogens, check_valency 3 | 4 | def test_environment(): 5 | mol = Chem.MolFromSmiles('C') 6 | mol = convert_radical_electrons_to_hydrogens(mol) 7 | 8 | assert steric_strain_filter(mol) 9 | assert zinc_molecule_filter(mol) 10 | assert check_valency(mol) 11 | 12 | if __name__ == '__main__': 13 | 
test_environment() -------------------------------------------------------------------------------- /test/ggraph/utils/test_gen_mol_from_one_shot_tensor.py: -------------------------------------------------------------------------------- 1 | from dig.ggraph.utils import gen_mol_from_one_shot_tensor 2 | import torch 3 | from rdkit import Chem 4 | 5 | def test_gen_mol_from_one_shot_tensor(): 6 | adj = torch.ones(1, 4, 6, 6) 7 | x = torch.ones(1, 4, 6) 8 | atomic_num_list = [6, 7, 8, 0] 9 | gen_mols = gen_mol_from_one_shot_tensor(adj, x, atomic_num_list) 10 | gen_smiles = Chem.MolToSmiles(gen_mols[0], isomericSmiles=True) 11 | 12 | assert gen_smiles=='C123C45C16C24C356' 13 | 14 | if __name__ == '__main__': 15 | test_gen_mol_from_one_shot_tensor() -------------------------------------------------------------------------------- /test/oodgraph/test_good_datasets.py: -------------------------------------------------------------------------------- 1 | from dig.oodgraph import GOODCBAS, GOODCMNIST, GOODCora, GOODHIV, GOODMotif, GOODPCBA, GOODZINC, GOODArxiv 2 | import pytest 3 | import shutil 4 | 5 | dataset_domain = { 6 | 'GOODHIV': ['scaffold', 'size'], 7 | 'GOODPCBA': ['scaffold', 'size'], 8 | 'GOODZINC': ['scaffold', 'size'], 9 | 'GOODCMNIST': ['color'], 10 | 'GOODMotif': ['basis', 'size'], 11 | 'GOODCora': ['word', 'degree'], 12 | 'GOODArxiv': ['time', 'degree'], 13 | 'GOODCBAS': ['color'] 14 | } 15 | 16 | @pytest.mark.parametrize('dataset_name', list(dataset_domain.keys())) 17 | def test_dataset(dataset_name): 18 | root = 'datasets' 19 | for shift_type in ['no_shift', 'covariate', 'concept']: 20 | for domain in dataset_domain[dataset_name]: 21 | dataset, meta_info = eval(dataset_name).load(root, domain, shift=shift_type) 22 | assert dataset is not None 23 | assert meta_info is not None 24 | shutil.rmtree(root) -------------------------------------------------------------------------------- /test/sslgraph/dataset/test_TUDatasetExt.py: 
-------------------------------------------------------------------------------- 1 | import shutil 2 | from dig.sslgraph.dataset import TUDatasetExt 3 | 4 | def test_TUDatasetExt(): 5 | ## semisupervised 6 | # NCI1 7 | root = './dataset/TUDataset' 8 | dataset = TUDatasetExt(root, name='NCI1', task='semisupervised') 9 | assert len(dataset) == 4110 10 | assert dataset.num_features == 37 11 | 12 | assert dataset[0].x.size() == (21, 37) 13 | assert dataset[0].y.size() == (1,) 14 | assert dataset[0].edge_index.size() == (2, 42) 15 | 16 | shutil.rmtree(root) 17 | 18 | # # REDDIT-BINARY 19 | # dataset = TUDatasetExt(root, name='REDDIT-BINARY', task='semisupervised') 20 | # assert len(dataset) == 2000 21 | # assert dataset.num_features == 0 22 | 23 | # assert dataset[0].y.size() == (1,) 24 | # assert dataset[0].edge_index.size() == (2, 480) 25 | 26 | # shutil.rmtree(root) 27 | 28 | ## unsupervised 29 | # NCI1 30 | root = './dataset/TUDataset' 31 | dataset = TUDatasetExt(root, name='NCI1', task='unsupervised') 32 | assert len(dataset) == 4110 33 | assert dataset.num_features == 37 34 | 35 | assert dataset[0].x.size() == (21, 37) 36 | assert dataset[0].y.size() == (1,) 37 | assert dataset[0].edge_index.size() == (2, 62) 38 | 39 | shutil.rmtree(root) 40 | 41 | # # REDDIT-BINARY 42 | # dataset = TUDatasetExt(root, name='REDDIT-BINARY', task='unsupervised') 43 | # assert len(dataset) == 2000 44 | # assert dataset.num_features == 1 45 | 46 | # assert dataset[0].x.size() == (218, 1) 47 | # assert dataset[0].y.size() == (1,) 48 | # assert dataset[0].edge_index.size() == (2, 697) 49 | 50 | # shutil.rmtree(root) 51 | 52 | if __name__ == '__main__': 53 | test_TUDatasetExt() -------------------------------------------------------------------------------- /test/sslgraph/dataset/test_get_node_dataset.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | from dig.sslgraph.dataset import get_node_dataset 3 | 4 | 5 | def 
test_get_node_dataset(): 6 | root = './dataset/Planetoid' 7 | # Cora 8 | dataset = get_node_dataset('Cora', norm_feat=False, root=root) 9 | assert len(dataset) == 1 10 | assert dataset.num_features == 1433 11 | 12 | assert dataset[0].x.size() == (2708, 1433) 13 | assert dataset[0].y.size() == (2708,) 14 | assert dataset[0].edge_index.size() == (2, 10556) 15 | 16 | shutil.rmtree(root) 17 | 18 | # CiteSeer 19 | dataset = get_node_dataset('CiteSeer', norm_feat=False, root=root) 20 | assert len(dataset) == 1 21 | assert dataset.num_features == 3703 22 | 23 | assert dataset[0].x.size() == (3327, 3703) 24 | assert dataset[0].y.size() == (3327,) 25 | assert dataset[0].edge_index.size() == (2, 9104) 26 | 27 | shutil.rmtree(root) 28 | 29 | # PubMed 30 | dataset = get_node_dataset('PubMed', norm_feat=False, root=root) 31 | assert len(dataset) == 1 32 | assert dataset.num_features == 500 33 | 34 | assert dataset[0].x.size() == (19717, 500) 35 | assert dataset[0].y.size() == (19717,) 36 | assert dataset[0].edge_index.size() == (2, 88648) 37 | 38 | shutil.rmtree(root) 39 | 40 | if __name__ == '__main__': 41 | test_get_node_dataset() -------------------------------------------------------------------------------- /test/sslgraph/evaluation/test_GraphUnsupervised.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | from dig.sslgraph.evaluation import GraphUnsupervised 3 | from dig.sslgraph.utils import Encoder 4 | from dig.sslgraph.dataset import get_dataset 5 | from dig.sslgraph.method import GraphCL, InfoGraph, MVGRL 6 | 7 | 8 | def test_GraphUnsupervised(): 9 | root = './dataset' 10 | dataset = get_dataset('MUTAG', task='unsupervised', root=root) 11 | embed_dim = 8 12 | 13 | encoder = Encoder(feat_dim=dataset[0].x.shape[1], hidden_dim=embed_dim, n_layers=2, gnn='gin', bn=True) 14 | graphcl = GraphCL(embed_dim*2, aug_1=None, aug_2='random2', tau=0.2) 15 | evaluator = GraphUnsupervised(dataset, log_interval=1, p_epoch=1) 16 | 
test_mean, test_std = evaluator.evaluate(learning_model=graphcl, encoder=encoder) 17 | 18 | assert test_mean <= 1.0 and test_mean >= 0.0 19 | assert test_std is not None 20 | 21 | embed_dim = 8 22 | encoder = Encoder(feat_dim=dataset[0].x.shape[1], hidden_dim=embed_dim, n_layers=2, gnn='gin', node_level=True) 23 | infograph = InfoGraph(embed_dim*2, embed_dim) 24 | evaluator = GraphUnsupervised(dataset, log_interval=1, p_epoch=1) 25 | test_mean, test_std = evaluator.evaluate(learning_model=infograph, encoder=encoder) 26 | 27 | assert test_mean <= 1.0 and test_mean >= 0.0 28 | assert test_std is not None 29 | 30 | encoder_adj = Encoder(feat_dim=dataset[0].x.shape[1], hidden_dim=embed_dim, 31 | n_layers=2, gnn='gcn', node_level=True, act='prelu') 32 | encoder_diff = Encoder(feat_dim=dataset[0].x.shape[1], hidden_dim=embed_dim, 33 | n_layers=2, gnn='gcn', node_level=True, act='prelu', edge_weight=True) 34 | mvgrl = MVGRL(embed_dim*2, embed_dim) 35 | evaluator = GraphUnsupervised(dataset, log_interval=1, p_epoch=1) 36 | test_mean, test_std = evaluator.evaluate(learning_model=mvgrl, encoder=[encoder_adj, encoder_diff]) 37 | 38 | assert test_mean <= 1.0 and test_mean >= 0.0 39 | assert test_std is not None 40 | 41 | 42 | shutil.rmtree(root) 43 | 44 | if __name__ == '__main__': 45 | test_GraphUnsupervised() 46 | -------------------------------------------------------------------------------- /test/sslgraph/evaluation/test_NodeUnsupervised.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | from dig.sslgraph.evaluation import NodeUnsupervised 3 | from dig.sslgraph.dataset import get_node_dataset 4 | from dig.sslgraph.utils import Encoder 5 | from dig.sslgraph.method import GRACE, GraphCL, NodeMVGRL 6 | 7 | 8 | def test_NodeUnsupervised(): 9 | root = './dataset' 10 | dataset = get_node_dataset('cora', root=root) 11 | embed_dim = 8 12 | 13 | encoder = Encoder(feat_dim=dataset[0].x.shape[1], hidden_dim=embed_dim, 14 | 
n_layers=2, gnn='gcn', node_level=True, graph_level=False) 15 | grace = GRACE(dim=embed_dim, dropE_rate_1=0.2, dropE_rate_2=0.4, 16 | maskN_rate_1=0.3, maskN_rate_2=0.4, tau=0.4) 17 | 18 | evaluator = NodeUnsupervised(dataset, log_interval=1) 19 | evaluator.setup_train_config(p_lr=0.0005, p_epoch=1, p_weight_decay=1e-5, comp_embed_on='cpu') 20 | test_mean = evaluator.evaluate(learning_model=grace, encoder=encoder) 21 | 22 | assert test_mean <= 1.0 and test_mean >= 0.0 23 | 24 | encoder_1 = Encoder(feat_dim=dataset[0].x.shape[1], hidden_dim=embed_dim, 25 | n_layers=2, gnn='gcn', node_level=True, graph_level=True) 26 | encoder_2 = Encoder(feat_dim=dataset[0].x.shape[1], hidden_dim=embed_dim, 27 | n_layers=2, gnn='gcn', node_level=True, graph_level=True) 28 | mvgrl = NodeMVGRL(z_dim=embed_dim * 2, z_n_dim=embed_dim, diffusion_type='heat') 29 | 30 | evaluator = NodeUnsupervised(dataset, log_interval=1) 31 | evaluator.setup_train_config(p_lr=0.0005, p_epoch=1, p_weight_decay=1e-5, comp_embed_on='cpu') 32 | test_mean = evaluator.evaluate(learning_model=mvgrl, encoder=[encoder_1, encoder_2]) 33 | 34 | assert test_mean <= 1.0 and test_mean >= 0.0 35 | 36 | shutil.rmtree(root) 37 | 38 | 39 | if __name__ == '__main__': 40 | test_NodeUnsupervised() -------------------------------------------------------------------------------- /test/sslgraph/evaluation/test_nce_more_view.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from dig.sslgraph.method.contrastive.objectives import NCE_loss, JSE_loss 3 | 4 | def test_nce_more_view(): 5 | zs_1 = torch.randn((32, 32)) 6 | zs_n_1 = torch.randn((32, 32)) 7 | zs_2 = torch.randn((32, 32)) 8 | zs_n_2 = torch.randn((32, 32)) 9 | zs_3 = torch.randn((32, 32)) 10 | zs_n_3 = torch.randn((32, 32)) 11 | sigma = torch.tensor([[0,1,0],[1,0,1],[1,0,1]]) 12 | 13 | loss = NCE_loss(zs=[zs_1,zs_2,zs_3],zs_n=[zs_n_1,zs_n_2,zs_n_3],batch=True,sigma=sigma) 14 | assert loss is not None 15 | loss = 
NCE_loss(zs=[zs_1,zs_2,zs_3],sigma=sigma) 16 | assert loss is not None 17 | 18 | loss = JSE_loss(zs=[zs_1,zs_2,zs_3],zs_n=[zs_n_1,zs_n_2,zs_n_3],batch=torch.zeros(32).long(),sigma=sigma) 19 | assert loss is not None 20 | loss = JSE_loss(zs=[zs_1,zs_2,zs_3],sigma=sigma) 21 | assert loss is not None 22 | 23 | if __name__ == '__main__': 24 | test_nce_more_view() -------------------------------------------------------------------------------- /test/threedgraph/dataset/test_QM93D.py: -------------------------------------------------------------------------------- 1 | from dig.threedgraph.dataset import QM93D 2 | import shutil 3 | 4 | def test_QM93D(): 5 | root = './dataset' 6 | 7 | dataset = QM93D(root=root) 8 | target='mu' 9 | dataset.data.y = dataset.data[target] 10 | 11 | assert len(dataset) == 130831 12 | assert dataset.__repr__() == 'QM93D(130831)' 13 | 14 | assert len(dataset[0]) == 15 15 | assert dataset[0].y.size() == (1,) 16 | assert dataset[0].z.size() == (5,) 17 | assert dataset[0].pos.size() == (5,3) 18 | assert dataset[0].Cv.size() == (1,) 19 | assert dataset[0].G.size() == (1,) 20 | assert dataset[0].H.size() == (1,) 21 | assert dataset[0].U.size() == (1,) 22 | assert dataset[0].U0.size() == (1,) 23 | assert dataset[0].alpha.size() == (1,) 24 | assert dataset[0].gap.size() == (1,) 25 | assert dataset[0].homo.size() == (1,) 26 | assert dataset[0].lumo.size() == (1,) 27 | assert dataset[0].mu.size() == (1,) 28 | assert dataset[0].r2.size() == (1,) 29 | assert dataset[0].zpve.size() == (1,) 30 | 31 | split_idx = dataset.get_idx_split(len(dataset.data.y), train_size=1000, valid_size=10000, seed=42) 32 | assert split_idx['train'][0] == 112526 33 | assert split_idx['valid'][0] == 120798 34 | assert split_idx['test'][0] == 107901 35 | 36 | shutil.rmtree(root) 37 | -------------------------------------------------------------------------------- /test/threedgraph/evaluation/test_ThreeDEvaluator.py: 
-------------------------------------------------------------------------------- 1 | from dig.threedgraph.evaluation import ThreeDEvaluator 2 | import numpy as np 3 | import torch 4 | import math 5 | 6 | 7 | def test_ThreeDEvaluator(): 8 | input_dict = {'y_true':np.array([1.0, -0.5]), 'y_pred':np.array([0.6, 0.0])} 9 | evaluator = ThreeDEvaluator() 10 | result = evaluator.eval(input_dict) 11 | assert len(result) == 1 12 | assert type(result['mae']) == float 13 | assert result['mae'] == 0.45 14 | 15 | input_dict = {'y_true':torch.Tensor([1.0, -0.5]), 'y_pred':torch.Tensor([0.6, 0.0])} 16 | evaluator = ThreeDEvaluator() 17 | result = evaluator.eval(input_dict) 18 | assert len(result) == 1 19 | assert type(result['mae']) == float 20 | # https://discuss.pytorch.org/t/item-gives-different-value-than-the-tensor-itself/101826 21 | assert torch.Tensor([result['mae']]) == torch.Tensor([0.45]) 22 | 23 | 24 | if __name__ == '__main__': 25 | test_ThreeDEvaluator() 26 | -------------------------------------------------------------------------------- /test/xgraph/dataset/test_BA_LRP.py: -------------------------------------------------------------------------------- 1 | from dig.xgraph.dataset import BA_LRP 2 | import shutil 3 | 4 | 5 | def test_BA_LRP(): 6 | root = 'datasets' 7 | dataset = BA_LRP(root) 8 | gen_data1 = BA_LRP.gen_class1() 9 | gen_data2 = BA_LRP.gen_class2() 10 | 11 | assert len(dataset) == 20000 12 | assert dataset.data.edge_index.size() == (2, 840000) 13 | assert dataset.data.x.size() == (400000, 1) 14 | assert dataset.data.y.size() == (20000, 1) 15 | 16 | data = dataset[0] 17 | assert data.edge_index.size() == (2, 38) 18 | assert data.x.size() == (20, 1) 19 | assert data.y.size() == (1, 1) 20 | assert gen_data1.edge_index.size() == (2, 38) 21 | assert gen_data1.x.size() == (20, 1) 22 | assert gen_data1.y.size() == (19, 1) 23 | assert gen_data2.edge_index.size() == (2, 46) 24 | assert gen_data2.x.size() == (20, 1) 25 | assert gen_data2.y.size() == (23, 1) 26 | 
27 | shutil.rmtree(root) 28 | 29 | 30 | if __name__ == '__main__': 31 | test_BA_LRP() 32 | -------------------------------------------------------------------------------- /test/xgraph/dataset/test_MarginalDataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch_geometric.data import Data 4 | from dig.xgraph.dataset import MarginalSubgraphDataset 5 | 6 | 7 | def graph_build_zero_filling(X, edge_index, node_mask: torch.Tensor): 8 | """ subgraph building through masking the unselected nodes with zero features """ 9 | ret_x = X * node_mask.unsqueeze(1) 10 | return ret_x, edge_index 11 | 12 | 13 | def test_MarginalSubgraphDataset(): 14 | num_mask = 10 15 | num_nodes = 6 16 | x = torch.ones(num_nodes, 10) 17 | edge_index = torch.LongTensor([[0, 1, 1, 2, 2, 3, 4, 5, 5], 18 | [2, 2, 5, 0, 1, 5, 5, 1, 3]]) 19 | y = torch.LongTensor([0, 1, 1, 0, 0, 1]) 20 | data = Data(x=x, edge_index=edge_index, y=y) 21 | 22 | node_indices = list(range(num_nodes)) 23 | coalition = [1, 2] 24 | coalition_placeholder = num_nodes 25 | set_include_masks = [] 26 | set_exclude_masks = [] 27 | for mask_idx in range(num_mask): 28 | subset_nodes_from = [node for node in node_indices if node not in coalition] 29 | random_nodes_permutation = np.array(subset_nodes_from + [coalition_placeholder]) 30 | random_nodes_permutation = np.random.permutation(random_nodes_permutation) 31 | split_idx = np.where(random_nodes_permutation == coalition_placeholder)[0][0] 32 | selected_nodes = random_nodes_permutation[:split_idx] 33 | set_exclude_mask = np.zeros(num_nodes) 34 | set_exclude_mask[selected_nodes] = 1.0 35 | set_include_mask = set_exclude_mask.copy() 36 | set_include_mask[coalition] = 1.0 37 | 38 | set_exclude_masks.append(set_exclude_mask) 39 | set_include_masks.append(set_include_mask) 40 | 41 | exclude_mask = np.stack(set_exclude_masks, axis=0) 42 | include_mask = np.stack(set_include_masks, axis=0) 43 | 
marginal_dataloader = \ 44 | MarginalSubgraphDataset(data=data, 45 | exclude_mask=exclude_mask, 46 | include_mask=include_mask, 47 | subgraph_build_func=graph_build_zero_filling) 48 | 49 | exclude_data, include_data = marginal_dataloader[0] 50 | assert exclude_data.x.shape == (6, 10) 51 | assert exclude_data.edge_index.shape == (2, 9) 52 | 53 | assert include_data.x.shape == (6, 10) 54 | assert exclude_data.edge_index.shape == (2, 9) 55 | 56 | assert marginal_dataloader.X.shape == (6, 10) 57 | assert marginal_dataloader.edge_index.shape == (2, 9) 58 | 59 | 60 | if __name__ == '__main__': 61 | test_MarginalSubgraphDataset() 62 | -------------------------------------------------------------------------------- /test/xgraph/dataset/test_MoleculeDataset.py: -------------------------------------------------------------------------------- 1 | from dig.xgraph.dataset import MoleculeDataset 2 | import shutil 3 | 4 | 5 | def test_MoleculeDataset(): 6 | root = 'datasets' 7 | dataset_names = ['mutag', 'bbbp', 'Tox21', 'bace'] 8 | dataset_length = [188, 2039, 7831, 1513] 9 | dataset_x_shape = [(3371, 7), (49068, 9), (145459, 9), (51577, 9)] 10 | dataset_edge_index_shape = [(2, 7442), (2, 105842), (2, 302190), (2, 111536)] 11 | dataset_y_shape = [(188, ), (2039, 1), (7831, 12), (1513, 1)] 12 | 13 | first_data_x_shape = [(17, 7), (20, 9), (16, 9), (32, 9)] 14 | first_data_edge_index_shape = [(2, 38), (2, 40), (2, 34), (2, 70)] 15 | first_data_y_shape = [(1, ), (1, 1), (1, 12), (1, 1)] 16 | 17 | for dataset_idx, name in enumerate(dataset_names): 18 | dataset = MoleculeDataset(root, name) 19 | 20 | assert len(dataset) == dataset_length[dataset_idx] 21 | assert dataset.data.x.size() == dataset_x_shape[dataset_idx] 22 | assert dataset.data.edge_index.size() == dataset_edge_index_shape[dataset_idx] 23 | assert dataset.data.y.size() == dataset_y_shape[dataset_idx] 24 | 25 | data = dataset[0] 26 | assert data.x.size() == first_data_x_shape[dataset_idx] 27 | assert data.edge_index.size() 
== first_data_edge_index_shape[dataset_idx] 28 | assert data.y.size() == first_data_y_shape[dataset_idx] 29 | 30 | shutil.rmtree(root) 31 | 32 | 33 | if __name__ == '__main__': 34 | test_MoleculeDataset() 35 | -------------------------------------------------------------------------------- /test/xgraph/dataset/test_SynGraphDataset.py: -------------------------------------------------------------------------------- 1 | from dig.xgraph.dataset import SynGraphDataset 2 | import shutil 3 | 4 | 5 | def test_SynGraphDataset(): 6 | root = 'datasets' 7 | dataset_names = ['ba_shapes', 'ba_community', 'tree_grid', 'tree_cycle', 'ba_2motifs'] 8 | dataset_length = [1, 1, 1, 1, 1000] 9 | dataset_x_shape = [(700, 10), (1400, 10), (1231, 10), (871, 10), (25000, 10)] 10 | dataset_edge_index_shape = [(2, 4110), (2, 8920), (2, 3130), (2, 1942), (2, 50960)] 11 | dataset_y_shape = [(700, ), (1400, ), (1231, ), (871, ), (1000, )] 12 | 13 | for dataset_idx, name in enumerate(dataset_names): 14 | dataset = SynGraphDataset(root, name) 15 | 16 | assert len(dataset) == dataset_length[dataset_idx] 17 | assert dataset.data.x.size() == dataset_x_shape[dataset_idx] 18 | assert dataset.data.edge_index.size() == dataset_edge_index_shape[dataset_idx] 19 | assert dataset.data.y.size() == dataset_y_shape[dataset_idx] 20 | 21 | if name == 'ba_2motifs': 22 | data = dataset[0] 23 | assert data.x.size() == (25, 10) 24 | assert data.edge_index.size() == (2, 50) 25 | assert data.y.size() == (1, ) 26 | 27 | shutil.rmtree(root) 28 | 29 | 30 | if __name__ == '__main__': 31 | test_SynGraphDataset() 32 | -------------------------------------------------------------------------------- /test/xgraph/evaluation/test_metrics.py: -------------------------------------------------------------------------------- 1 | from dig.xgraph.evaluation import XCollector, ExplanationProcessor, control_sparsity 2 | from torch_geometric.nn.conv import GCNConv 3 | from torch_geometric.utils.random import barabasi_albert_graph 4 | 
def test_metrics():
    """Run random edge-mask explanations through the processor and check the metrics.

    A ``GCNConv`` layer stands in for the model. Ten graphs are generated
    with ``barabasi_albert_graph(10, 3)``; for each one, a random mask per
    class is passed through ``control_sparsity`` and handed to the
    ``ExplanationProcessor`` together with the shared ``XCollector``. The
    test then prints the collected fidelity, inverse fidelity and sparsity
    and asserts that none of them is ``None``.
    """
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    # Model under explanation.
    model = GCNConv(in_channels=1, out_channels=2).to(device)

    # Sparsity level shared by the collector and the mask generation below.
    sparsity = 0.5

    # Collector of the results and the processor that evaluates explanations.
    x_collector = XCollector(sparsity)
    x_processor = ExplanationProcessor(model=model, device=device)

    # Two-class setting, ten explanations in total.
    num_classes = 2
    for _ in range(10):
        # Ten-node BA graph whose node features are all ones.
        edge_index = barabasi_albert_graph(10, 3)
        node_features = torch.ones((10, 1), dtype=torch.float)
        # y is assumed to be the ground-truth label, set to 1.
        data = Data(x=node_features, edge_index=edge_index, y=torch.tensor([1.]))

        # One random score per edge for every class, passed through
        # control_sparsity with the configured sparsity.
        num_edges = edge_index.shape[1]
        masks = []
        for _class_idx in range(num_classes):
            edge_scores = torch.randn(num_edges, device=device)
            masks.append(control_sparsity(edge_scores, sparsity))

        # Evaluate the explanation; results accumulate in the collector.
        x_processor(data, masks, x_collector)

    # Report the aggregated metrics.
    report = (f'Fidelity: {x_collector.fidelity:.4f}\n'
              f'Fidelity_inv: {x_collector.fidelity_inv:.4f}\n'
              f'Sparsity: {x_collector.sparsity:.4f}')
    print(report)

    assert x_collector.fidelity is not None
    assert x_collector.fidelity_inv is not None
    assert x_collector.sparsity is not None


if __name__ == '__main__':
    test_metrics()
--------------------------------------------------------------------------------