├── .gitignore ├── LICENSE ├── README.md ├── chienn ├── __init__.py ├── data │ ├── __init__.py │ ├── edge_graph │ │ ├── __init__.py │ │ ├── collate_circle_index.py │ │ ├── get_circle_index.py │ │ └── to_edge_graph.py │ ├── featurization │ │ ├── __init__.py │ │ ├── mol_to_data.py │ │ └── smiles_to_3d_mol.py │ └── featurize.py └── model │ ├── __init__.py │ ├── chienn_layer.py │ ├── chienn_model.py │ └── utils.py ├── example.py ├── experiments ├── LICENSE ├── README.md ├── benchmark.py ├── configs │ ├── benchmarks │ │ ├── ChIRo │ │ │ ├── benchmark-bace-chiro.json │ │ │ ├── benchmark-binding_affinity-chiro.json │ │ │ ├── benchmark-binding_rank-chiro.json │ │ │ ├── benchmark-rs-chiro.json │ │ │ └── benchmark-tox21-chiro.json │ │ ├── DMPNN │ │ │ ├── benchmark-bace-dmpnn+with_tags.json │ │ │ ├── benchmark-bace-dmpnn.json │ │ │ ├── benchmark-binding_affinity-dmpnn+with_tags.json │ │ │ ├── benchmark-binding_affinity-dmpnn.json │ │ │ ├── benchmark-binding_rank-dmpnn+with_tags.json │ │ │ ├── benchmark-rs-dmpnn+with_tags.json │ │ │ ├── benchmark-tox21-dmpnn+with_tags.json │ │ │ └── benchmark-tox21-dmpnn.json │ │ ├── GPS+ChiENN │ │ │ ├── benchmark-bace-gps+chienn.json │ │ │ ├── benchmark-binding_affinity-gps+chienn.json │ │ │ ├── benchmark-binding_rank-gps+chienn.json │ │ │ ├── benchmark-rs-gps+chienn.json │ │ │ └── benchmark-tox21-gps+chienn.json │ │ ├── GPS │ │ │ ├── benchmark-bace-gps+with_tags.json │ │ │ ├── benchmark-bace-gps.json │ │ │ ├── benchmark-binding_affinity-gps+with_tags.json │ │ │ ├── benchmark-binding_affinity-gps.json │ │ │ ├── benchmark-binding_rank-gps+with_tags.json │ │ │ ├── benchmark-tox21-gps+with_tags.json │ │ │ └── benchmark-tox21-gps.json │ │ ├── SAN+ChiENN │ │ │ ├── benchmark-bace-san+chienn.json │ │ │ ├── benchmark-binding_affinity-san+chienn.json │ │ │ ├── benchmark-binding_rank-san+chienn.json │ │ │ └── benchmark-tox21-san+chienn.json │ │ ├── SAN │ │ │ ├── benchmark-bace-san+with_tags.json │ │ │ ├── benchmark-bace-san.json │ │ │ ├── 
benchmark-binding_affinity-san+with_tags.json │ │ │ ├── benchmark-binding_affinity-san.json │ │ │ ├── benchmark-binding_rank-san+with_tags.json │ │ │ ├── benchmark-tox21-san+with_tags.json │ │ │ └── benchmark-tox21-san.json │ │ └── Tetra_DMPNN │ │ │ ├── benchmark-bace-tetra_dmpnn.json │ │ │ ├── benchmark-binding_affinity-tetra_dmpnn.json │ │ │ ├── benchmark-binding_rank-tetra_dmpnn.json │ │ │ ├── benchmark-rs-tetra_dmpnn.json │ │ │ └── benchmark-tox21-tetra_dmpnn.json │ ├── datasets │ │ ├── bace.yaml │ │ ├── binding_affinity.yaml │ │ ├── binding_rank.yaml │ │ ├── rs.yaml │ │ └── tox21.yaml │ └── models │ │ ├── ChIRo │ │ ├── ChIRo.json │ │ ├── ChIRo.yaml │ │ ├── bace-ChIRo.yaml │ │ ├── binding_affinity-ChIRo.yaml │ │ ├── binding_rank-ChIRo.yaml │ │ ├── rs-ChIRo.yaml │ │ └── tox21-ChIRo.yaml │ │ ├── ChiENN │ │ ├── ChiENN.yaml │ │ ├── bace-ChiENN.yaml │ │ ├── binding_affinity-ChiENN.yaml │ │ ├── binding_rank-ChiENN.yaml │ │ ├── rs-ChiENN.yaml │ │ └── tox21-ChiENN.yaml │ │ ├── DMPNN │ │ ├── DMPNN.yaml │ │ ├── bace-DMPNN+with_tags.yaml │ │ ├── bace-DMPNN.yaml │ │ ├── binding_affinity-DMPNN+with_tags.yaml │ │ ├── binding_affinity-DMPNN.yaml │ │ ├── binding_rank-DMPNN+with_tags.yaml │ │ ├── tox21-DMPNN+with_tags.yaml │ │ └── tox21-DMPNN.yaml │ │ ├── GPS+ChiENN │ │ ├── GPS+ChiENN.yaml │ │ ├── bace-GPS+ChiENN.yaml │ │ ├── binding_affinity-GPS+ChiENN.yaml │ │ ├── binding_rank-GPS+ChiENN.yaml │ │ ├── rs-GPS+ChiENN.yaml │ │ └── tox21-GPS+ChiENN.yaml │ │ ├── GPS │ │ ├── GPS.yaml │ │ ├── bace-GPS+with_tags.yaml │ │ ├── bace-GPS.yaml │ │ ├── binding_affinity-GPS+with_tags.yaml │ │ ├── binding_affinity-GPS.yaml │ │ ├── binding_rank-GPS+with_tags.yaml │ │ ├── tox21-GPS+with_tags.yaml │ │ └── tox21-GPS.yaml │ │ ├── SAN+ChiENN │ │ ├── SAN+ChiENN.yaml │ │ ├── bace-SAN+ChiENN.yaml │ │ ├── binding_affinity-SAN+ChiENN.yaml │ │ ├── binding_rank-SAN+ChiENN.yaml │ │ ├── rs-SAN+ChiENN.yaml │ │ └── tox21-SAN+ChiENN.yaml │ │ ├── SAN │ │ ├── SAN.yaml │ │ ├── bace-SAN+with_tags.yaml │ │ ├── 
bace-SAN.yaml │ │ ├── binding_affinity-SAN+with_tags.yaml │ │ ├── binding_affinity-SAN.yaml │ │ ├── binding_rank-SAN+with_tags.yaml │ │ ├── tox21-SAN+with_tags.yaml │ │ └── tox21-SAN.yaml │ │ ├── Tetra_DMPNN │ │ ├── Tetra_DMPNN.yaml │ │ ├── bace-Tetra_DMPNN.yaml │ │ ├── binding_affinity-Tetra_DMPNN.yaml │ │ ├── binding_rank-Tetra_DMPNN.yaml │ │ ├── rs-Tetra_DMPNN.yaml │ │ └── tox21-Tetra_DMPNN.yaml │ │ └── common.yaml ├── create_dataset.py ├── graphgps │ ├── __init__.py │ ├── act │ │ ├── __init__.py │ │ └── example.py │ ├── config │ │ ├── __init__.py │ │ ├── chienn_config.py │ │ ├── custom_gnn_config.py │ │ ├── custom_model_config.py │ │ ├── dataset_config.py │ │ ├── defaults_config.py │ │ ├── example.py │ │ ├── gt_config.py │ │ ├── optimizers_config.py │ │ ├── posenc_config.py │ │ ├── pretrained_config.py │ │ ├── split_config.py │ │ ├── test_config.py │ │ └── wandb_config.py │ ├── dataset │ │ ├── __init__.py │ │ ├── binding_affinity_dataset.py │ │ ├── chiral_dataset_base.py │ │ ├── collate.py │ │ ├── csv_dataset.py │ │ ├── dataloader.py │ │ ├── ogb_dataset.py │ │ ├── rs_dataset.py │ │ ├── tdc_dataset.py │ │ └── utils.py │ ├── encoder │ │ ├── __init__.py │ │ ├── ast_encoder.py │ │ ├── composed_encoders.py │ │ ├── dummy_edge_encoder.py │ │ ├── equivstable_laplace_pos_encoder.py │ │ ├── example.py │ │ ├── geometric_node_encoder.py │ │ ├── kernel_pos_encoder.py │ │ ├── laplace_pos_encoder.py │ │ ├── linear_edge_encoder.py │ │ ├── linear_node_encoder.py │ │ ├── ppa_encoder.py │ │ ├── signnet_pos_encoder.py │ │ ├── type_dict_encoder.py │ │ └── voc_superpixels_encoder.py │ ├── finetuning.py │ ├── head │ │ ├── __init__.py │ │ ├── example.py │ │ ├── inductive_edge.py │ │ ├── inductive_node.py │ │ ├── ogb_code_graph.py │ │ └── san_graph.py │ ├── layer │ │ ├── __init__.py │ │ ├── bigbird_layer.py │ │ ├── chienn_layer_wrapper.py │ │ ├── example.py │ │ ├── gatedgcn_layer.py │ │ ├── gine_conv_layer.py │ │ ├── gps_layer.py │ │ ├── performer_layer.py │ │ ├── san2_layer.py │ │ ├── 
san_layer.py │ │ └── utils.py │ ├── loader │ │ ├── __init__.py │ │ ├── dataset │ │ │ ├── __init__.py │ │ │ ├── aqsol_molecules.py │ │ │ ├── coco_superpixels.py │ │ │ ├── malnet_tiny.py │ │ │ ├── pcqm4mv2_contact.py │ │ │ ├── peptides_functional.py │ │ │ ├── peptides_structural.py │ │ │ └── voc_superpixels.py │ │ ├── master_loader.py │ │ ├── ogbg_code2_utils.py │ │ └── split_generator.py │ ├── logger.py │ ├── loss │ │ ├── __init__.py │ │ ├── l1.py │ │ ├── multilabel_classification_loss.py │ │ ├── subtoken_prediction_loss.py │ │ └── weighted_cross_entropy.py │ ├── metric_wrapper.py │ ├── metrics_ogb.py │ ├── network │ │ ├── __init__.py │ │ ├── big_bird.py │ │ ├── chiro.py │ │ ├── custom_gnn.py │ │ ├── dmpnn.py │ │ ├── example.py │ │ ├── gps_model.py │ │ ├── performer.py │ │ ├── san_transformer.py │ │ └── utils.py │ ├── optimizer │ │ ├── __init__.py │ │ └── extra_optimizers.py │ ├── pooling │ │ ├── __init__.py │ │ └── example.py │ ├── stage │ │ ├── __init__.py │ │ └── example.py │ ├── train │ │ ├── __init__.py │ │ ├── custom_train.py │ │ └── example.py │ ├── transform │ │ ├── __init__.py │ │ ├── posenc_stats.py │ │ └── transforms.py │ └── utils.py ├── main.py ├── retrieve_grid_results.py ├── retrieve_results.py ├── setup.py └── submodules │ ├── ChIRo │ ├── LICENSE │ ├── README.md │ ├── experiment_analysis │ │ ├── analyze_RS_experiments.ipynb │ │ ├── analyze_contrastive_experiments.ipynb │ │ └── analyze_docking_experiments.ipynb │ ├── hyperopt │ │ ├── hyperopt_LD.py │ │ ├── hyperopt_RS.py │ │ └── hyperopt_docking.py │ ├── model │ │ ├── alpha_encoder.py │ │ ├── datasets_samplers.py │ │ ├── embedding_functions.py │ │ ├── gnn_3D │ │ │ ├── dimenet_pp.py │ │ │ ├── optimization_functions.py │ │ │ ├── schnet.py │ │ │ ├── spherenet.py │ │ │ ├── train_functions.py │ │ │ └── train_models.py │ │ ├── optimization_functions.py │ │ ├── params_interpreter.py │ │ ├── train_functions.py │ │ └── train_models.py │ ├── params_files │ │ ├── params_LD_ChIRo.json │ │ ├── 
params_LD_spherenet.json │ │ ├── params_RS_ChIRo.json │ │ ├── params_RS_dimenetpp.json │ │ ├── params_RS_schnet.json │ │ ├── params_RS_spherenet.json │ │ ├── params_binary_ranking_ChIRo.json │ │ ├── params_binary_ranking_spherenet.json │ │ ├── params_contrastive_ChIRo.json │ │ ├── params_contrastive_dimenetpp.json │ │ ├── params_contrastive_schnet.json │ │ └── params_contrastive_spherenet.json │ └── training_scripts │ │ ├── training_LD_classification.py │ │ ├── training_LD_classification_spherenet.py │ │ ├── training_RS_classification.py │ │ ├── training_RS_classification_dimenetpp.py │ │ ├── training_RS_classification_schnet.py │ │ ├── training_RS_classification_spherenet.py │ │ ├── training_binary_ranking.py │ │ ├── training_binary_ranking_spherenet.py │ │ ├── training_contrastive.py │ │ ├── training_contrastive_dimenetpp.py │ │ ├── training_contrastive_schnet.py │ │ └── training_contrastive_spherenet.py │ └── tetra_dmpnn │ ├── Makefile │ ├── README.md │ ├── __init__.py │ ├── devtools │ └── create_env.sh │ ├── environment.yml │ ├── features │ └── featurization.py │ ├── hyperopt.py │ ├── model │ ├── gnn.py │ ├── layers.py │ ├── parsing.py │ ├── tetra.py │ └── training.py │ ├── train.py │ └── utils.py └── images ├── fk.png └── order_example.png /.gitignore: -------------------------------------------------------------------------------- 1 | # CUSTOM 2 | .vscode/ 3 | scripts/pcqm4m/**/*.zip 4 | scripts/pcqm4m/**/*.sdf 5 | scripts/pcqm4m/**/*.xyz 6 | scripts/pcqm4m/**/*.csv 7 | !scripts/pcqm4m/**/periodic_table.csv 8 | scripts/pcqm4m/**/*.gz 9 | scripts/pcqm4m/**/*.tsv 10 | scripts/pcqm4m/pcqm4m-v2/ 11 | slurm_history/ 12 | datasets/ 13 | pretrained/ 14 | results/ 15 | vocprep/benchmark_RELEASE/ 16 | vocprep/voc_viz_files/ 17 | vocprep/VOC/benchmark_RELEASE/ 18 | vocprep/VOC/*.tgz 19 | vocprep/VOC/*.pickle 20 | vocprep/VOC/*.pkl 21 | vocprep/VOC/*.zip 22 | splits/ 23 | wandb/ 24 | __pycache__/ 25 | .idea 26 | *.log 27 | *.bak 28 | 29 | # Byte-compiled / optimized / 
DLL files 30 | __pycache__/ 31 | *.py[cod] 32 | *$py.class 33 | 34 | # C extensions 35 | *.so 36 | 37 | # Distribution / packaging 38 | .Python 39 | build/ 40 | develop-eggs/ 41 | dist/ 42 | downloads/ 43 | eggs/ 44 | .eggs/ 45 | lib/ 46 | lib64/ 47 | parts/ 48 | sdist/ 49 | var/ 50 | wheels/ 51 | pip-wheel-metadata/ 52 | share/python-wheels/ 53 | *.egg-info/ 54 | .installed.cfg 55 | *.egg 56 | MANIFEST 57 | 58 | # PyInstaller 59 | # Usually these files are written by a python script from a template 60 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 61 | *.manifest 62 | *.spec 63 | 64 | # Installer logs 65 | pip-log.txt 66 | pip-delete-this-directory.txt 67 | 68 | # Unit test / coverage reports 69 | htmlcov/ 70 | .tox/ 71 | .nox/ 72 | .coverage 73 | .coverage.* 74 | .cache 75 | nosetests.xml 76 | coverage.xml 77 | *.cover 78 | *.py,cover 79 | .hypothesis/ 80 | .pytest_cache/ 81 | 82 | # Translations 83 | *.mo 84 | *.pot 85 | 86 | # Django stuff: 87 | *.log 88 | local_settings.py 89 | db.sqlite3 90 | db.sqlite3-journal 91 | 92 | # Flask stuff: 93 | instance/ 94 | .webassets-cache 95 | 96 | # Scrapy stuff: 97 | .scrapy 98 | 99 | # Sphinx documentation 100 | docs/_build/ 101 | 102 | # PyBuilder 103 | target/ 104 | 105 | # Jupyter Notebook 106 | .ipynb_checkpoints 107 | 108 | # IPython 109 | profile_default/ 110 | ipython_config.py 111 | 112 | # pyenv 113 | .python-version 114 | 115 | # pipenv 116 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 117 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 118 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 119 | # install all needed dependencies. 120 | #Pipfile.lock 121 | 122 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 123 | __pypackages__/ 124 | 125 | # Celery stuff 126 | celerybeat-schedule 127 | celerybeat.pid 128 | 129 | # SageMath parsed files 130 | *.sage.py 131 | 132 | # Environments 133 | .env 134 | .venv 135 | env/ 136 | venv/ 137 | ENV/ 138 | env.bak/ 139 | venv.bak/ 140 | 141 | # Spyder project settings 142 | .spyderproject 143 | .spyproject 144 | 145 | # Rope project settings 146 | .ropeproject 147 | 148 | # mkdocs documentation 149 | /site 150 | 151 | # mypy 152 | .mypy_cache/ 153 | .dmypy.json 154 | dmypy.json 155 | 156 | # Pyre type checker 157 | .pyre/ 158 | 159 | # vim edit buffer 160 | *.swp 161 | 162 | 163 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Piotr Gaiński 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
# --- /chienn/__init__.py ---
from .model import *
from .data import *

# --- /chienn/data/__init__.py ---
from .featurize import smiles_to_data_with_circle_index
from .featurize import collate_with_circle_index

# --- /chienn/data/edge_graph/__init__.py ---
from .collate_circle_index import collate_circle_index
from .get_circle_index import get_circle_index
from .to_edge_graph import to_edge_graph

# --- /chienn/data/edge_graph/collate_circle_index.py ---
from typing import List

import torch
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence
from torch_geometric.data import Data


def collate_circle_index(batch: List[Data], k_neighbors: int) -> Tensor:
    """
    Collates the `circle_index` attribute of the `Data` objects in `batch` into a single tensor.

    Args:
        batch: a list of `Data` objects, each with a `circle_index` attribute holding one list of neighbor
            indices per node.
        k_neighbors: number of incoming neighbors to consider for each node (k in the paper). Must be >= 1.

    Returns:
        A long tensor of shape (total_num_nodes, circle_size) containing the indices of the (non-parallel)
        neighbors in the pre-computed order, shifted by the number of nodes of the preceding graphs in the
        batch. The first (k-1) indices for every node are repeated, e.g. for k=3, circle_index[0] may be
        (i_1, i_2, i_3, ..., i_n, i_1, i_2). Therefore, `circle_size` = `max_num_neighbors` + k - 1. Shorter
        rows are right-padded with -1.

    Raises:
        ValueError: if `k_neighbors` is smaller than 1.
    """
    if k_neighbors < 1:
        raise ValueError(f'k_neighbors must be a positive integer, got {k_neighbors}.')

    # To simplify the implementation of ChiENNLayer, we extend each `e.circle_index` with its first `k - 1` elements:
    repeated_circle_index = [_repeat_first_elements(e.circle_index, k_neighbors) for e in batch]
    return _collate_repeated(repeated_circle_index)


def _collate_repeated(circle_index_list: List[List[List[int]]]) -> Tensor:
    """
    Collates `circle_index_list` into a single tensor.

    Args:
        circle_index_list: a list of `circle_index` lists. Each of `batch_size` lists contains `num_nodes` lists of
            `circle_index` with `circle_size` indices.

    Returns:
        A long tensor of shape (total_num_nodes, max_circle_size) containing the indices defining a node order for
        every node, right-padded with -1. An empty input yields a tensor of shape (0, 0).
    """
    circle_index_tensor_list = []
    for circle_index in circle_index_list:
        # Every node contributes exactly one row, so the number of rows collected so far is the index offset of
        # the current graph's nodes in the batch (assumes len(circle_index) == num_nodes of that graph).
        offset = len(circle_index_tensor_list)
        circle_index_tensor_list.extend(
            torch.tensor(circle, dtype=torch.long) + offset for circle in circle_index
        )
    if not circle_index_tensor_list:
        # `pad_sequence` cannot handle an empty list.
        return torch.empty((0, 0), dtype=torch.long)
    return pad_sequence(circle_index_tensor_list, batch_first=True, padding_value=-1)


def _repeat_first_elements(circle_index_list: List[List[int]], k: int) -> List[List[int]]:
    """
    Extends each `circle_index` from `circle_index_list` with its first `k - 1` elements, wrapping around
    cyclically when a list is shorter than `k - 1`. Empty lists are returned unchanged.
    """
    def _repeat(circle_index: List[int]) -> List[int]:
        """
        Extends `circle_index` list with its first `k - 1` elements.
        """
        if len(circle_index) == 0:
            return circle_index
        n = len(circle_index) + k - 1
        # `circle_index * k` is always long enough: len * k >= len + k - 1 whenever len, k >= 1.
        circle_index = circle_index * k
        return circle_index[:n]

    return [_repeat(circle_index) for circle_index in circle_index_list]


# --- /chienn/data/edge_graph/to_edge_graph.py ---
from collections import defaultdict

import torch
import torch_geometric
import torch_geometric.data
from torch_geometric.data import Data
from torch_geometric.utils import is_undirected, to_undirected

from chienn.data.edge_graph.get_circle_index import get_circle_index


def to_edge_graph(data: Data) -> Data:
    """
    Converts the graph to a graph of edges.

    Every directed edge (a, b) of the graph becomes a node' with attribute
    data.x[a] | edge_attr[(a, b)] | data.x[b] and position data.pos[a] | data.pos[b]. Then every node' (c, a)
    entering atom a is connected by a directed edge' with every node' (a, b) leaving a (this includes the
    parallel node' (b, a) -> (a, b)). The edge' (c, a) -> (a, b) gets the attribute
    edge_attr[(c, a)] | data.x[a] | edge_attr[(a, b)].

    Args:
        data: torch geometric data with node attributes (x), edge attributes (edge_attr), edge indices
            (edge_index) and node positions (pos). If `edge_index` is not undirected, it is made undirected
            with `to_undirected` first.

    Returns:
        Graph of edges with attributes x, edge_index, edge_attr, pos, `parallel_node_index` (for every node'
        (a, b), the index of node' (b, a)) and `circle_index` (computed by `get_circle_index`).

    Raises:
        ValueError: if the graph has no edges (its graph of edges would have no nodes).
    """
    if data.edge_index.numel() == 0:
        raise ValueError('Cannot convert a graph without edges into a graph of edges.')

    if not is_undirected(data.edge_index):
        edge_index, edge_attr = to_undirected(edge_index=data.edge_index, edge_attr=data.edge_attr)
    else:
        edge_index, edge_attr = data.edge_index, data.edge_attr

    # Create one node' per directed edge (a, b).
    new_nodes = []
    new_nodes_to_idx = {}
    for (a, b), ab_edge_attr in zip(edge_index.T.tolist(), edge_attr):
        new_nodes_to_idx[(a, b)] = len(new_nodes)
        new_nodes.append({
            'a': a,
            'b': b,
            'a_attr': data.x[a],
            # x_{i, j} = x'_i | e'_{i, j} | x'_j.
            'node_attr': torch.cat([data.x[a], ab_edge_attr, data.x[b]]),
            'old_edge_attr': ab_edge_attr,
            'pos': torch.cat([data.pos[a], data.pos[b]]),
        })

    # For every atom, the indices of the node's (c, atom) entering it, in creation order.
    in_nodes = defaultdict(list)
    for i, node_dict in enumerate(new_nodes):
        in_nodes[node_dict['b']].append(i)

    # Connect every node' (c, a) with node' (a, b).
    new_edge_index = []
    new_edge_attr = []
    for i, node_dict in enumerate(new_nodes):
        for in_node_idx in in_nodes[node_dict['a']]:
            ca_old_edge_attr = new_nodes[in_node_idx]['old_edge_attr']
            # e_{(i, j), (j, k)} = e'_{i, j} | x'_j | e'_{j, k}:
            new_edge_attr.append(torch.cat([ca_old_edge_attr, node_dict['a_attr'], node_dict['old_edge_attr']]))
            new_edge_index.append([in_node_idx, i])

    # The graph is undirected here, so the reversed node' (b, a) exists for every node' (a, b).
    parallel_node_index = torch.tensor([new_nodes_to_idx[(d['b'], d['a'])] for d in new_nodes])

    new_data = torch_geometric.data.Data(
        x=torch.stack([d['node_attr'] for d in new_nodes]),
        edge_index=torch.tensor(new_edge_index).T,
        edge_attr=torch.stack(new_edge_attr),
        pos=torch.stack([d['pos'] for d in new_nodes]),
    )
    new_data.parallel_node_index = parallel_node_index
    new_data.circle_index = get_circle_index(new_data, clockwise=False)
    return new_data
# --- /chienn/data/featurization/__init__.py ---
from .mol_to_data import mol_to_data
from .smiles_to_3d_mol import smiles_to_3d_mol

# --- /chienn/data/featurization/smiles_to_3d_mol.py ---
import logging

from rdkit import Chem
from rdkit.Chem import AllChem


def smiles_to_3d_mol(smiles: str, max_number_of_atoms: int = 100, max_number_of_attempts: int = 5000):
    """
    Parses a SMILES string, keeps its first fragment, adds hydrogens and embeds the molecule in 3D space.

    Args:
        smiles: a SMILES string representing the molecule.
        max_number_of_atoms: maximal number of atoms in the molecule, counted before hydrogens are added.
            Molecules with more atoms are omitted.
        max_number_of_attempts: maximal number of attempts passed to each `EmbedMolecule` call.

    Returns:
        The embedded and UFF-optimized molecule (with hydrogens), or None when the SMILES cannot be parsed, the
        molecule is too large or empty, or embedding/optimization fails. A warning is logged in every such case.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        logging.warning(f'Omitting molecule {smiles} as it cannot be parsed.')
        return None

    # Desaltization: keep only the first '.'-separated fragment of the canonical SMILES.
    # NOTE(review): the first fragment is not guaranteed to be the largest one; confirm this is intended.
    smiles = Chem.MolToSmiles(mol).split('.')[0]
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        logging.warning(f'Omitting molecule {smiles} as it cannot be parsed after desaltization.')
        return None

    num_atoms = mol.GetNumAtoms()
    if num_atoms > max_number_of_atoms:
        logging.warning(f'Omitting molecule {smiles} as it contains more than {max_number_of_atoms} atoms.')
        return None
    if num_atoms == 0:
        logging.warning(f'Omitting molecule {smiles} as it contains no atoms after desaltization.')
        return None

    mol = Chem.AddHs(mol)
    res = AllChem.EmbedMolecule(mol, maxAttempts=max_number_of_attempts, randomSeed=0)
    if res < 0:  # try to embed with different method
        res = AllChem.EmbedMolecule(mol, useRandomCoords=True, maxAttempts=max_number_of_attempts,
                                    randomSeed=0)
    if res < 0:
        logging.warning(f'Omitting molecule {smiles} as cannot be embedded in 3D space properly.')
        return None
    try:
        AllChem.UFFOptimizeMolecule(mol)
    except Exception as e:
        # Deliberate best-effort: any optimization failure just drops the molecule (with a warning).
        logging.warning(
            f"Omitting molecule {smiles} as cannot be properly optimized. "
            f"The original error message was: {e}."
        )
        return None
    return mol


# --- /chienn/data/featurize.py ---
from typing import List

from torch_geometric.data import Data, Batch

from chienn.data.edge_graph.collate_circle_index import collate_circle_index
from chienn.data.edge_graph.to_edge_graph import to_edge_graph
from chienn.data.featurization.mol_to_data import mol_to_data
from chienn.data.featurization.smiles_to_3d_mol import smiles_to_3d_mol


def smiles_to_data_with_circle_index(smiles: str) -> Data:
    """
    Transforms a SMILES string into a torch_geometric Data object that can be fed into the ChiENNLayer.

    The molecule is embedded in 3D (`smiles_to_3d_mol`), featurized (`mol_to_data`) and converted into a graph
    of directed edges (`to_edge_graph`). The 3D positions are removed (set to None) from the returned object.

    Args:
        smiles: a SMILES string.

    Returns:
        Data object containing the following attributes:
            - x (num_nodes, num_node_features): node features.
            - edge_index (2, num_edges): edge index.
            - edge_attr (num_edges, num_edge_features): edge features.
            - parallel_node_index (num_nodes,): index of the node representing the reversed edge.
            - circle_index: per-node lists of neighbor indices ordered around the node; rows may differ in
              length and are padded by `collate_with_circle_index`.

    Raises:
        ValueError: if no 3D molecule can be built from `smiles` (the reason is logged by `smiles_to_3d_mol`).
    """
    mol = smiles_to_3d_mol(smiles)
    if mol is None:
        raise ValueError(f'Could not build a 3D molecule from SMILES {smiles}.')
    data = mol_to_data(mol)
    data = to_edge_graph(data)
    data.pos = None
    return data


def collate_with_circle_index(data_list: List[Data], k_neighbors: int) -> Batch:
    """
    Collates a list of Data objects into a Batch object.

    Args:
        data_list: a list of Data objects. Each Data object must contain `circle_index` attribute.
        k_neighbors: number of k consecutive neighbors to be used in the message passing step.

    Returns:
        Batch object containing the collated attributes from data objects, including `circle_index` collated
        to shape (total_num_nodes, max_circle_size).
    """
    # `circle_index` rows have varying lengths, so they are excluded from the default batching and collated
    # (offset and padded) separately.
    batch = Batch.from_data_list(data_list, exclude_keys=['circle_index'])
    batch.circle_index = collate_circle_index(data_list, k_neighbors)
    return batch


# --- /chienn/model/__init__.py ---
from .chienn_layer import ChiENNLayer

# --- /chienn/model/utils.py ---
import torch
from torch import nn


class _SelfConcat(nn.Module):
    """Concatenates the input with itself along the last dimension."""

    def forward(self, x):
        return torch.cat([x, x], dim=-1)


class _Double(nn.Module):
    """Multiplies the input by 2."""

    def forward(self, x):
        return 2 * x


def _build_single_embedding_layer(in_dim: int, out_dim: int, name: str):
    """
    Builds a single embedding layer from its name.

    Args:
        in_dim: input feature dimension (used by 'linear' and 'scalar').
        out_dim: output feature dimension (used by 'linear').
        name: 'linear' (bias-free linear map in_dim -> out_dim), 'identity', 'scalar' (linear map in_dim -> 1
            with bias), 'self_concat' (x -> x | x), 'double' (x -> 2 * x), or the name of any `torch.nn`
            attribute, which is instantiated without arguments (e.g. 'ReLU').

    Returns:
        An `nn.Module`. Every option is a module so that it can also be chained inside `nn.Sequential`.

    Raises:
        NotImplementedError: if `name` is not recognised.
    """
    if name == 'linear':
        return nn.Linear(in_dim, out_dim, bias=False)
    elif name == 'identity':
        return nn.Identity()
    elif name == 'scalar':
        return nn.Linear(in_dim, 1, bias=True)
    elif name == 'self_concat':
        return _SelfConcat()
    elif name == 'double':
        return _Double()
    elif hasattr(torch.nn, name):
        return getattr(torch.nn, name)()
    else:
        raise NotImplementedError(f'Layer name {name} is not implemented.')


def build_embedding_layer(in_dim: int, out_dim: int, name: str):
    """
    Builds an embedding layer from a name in which '+' chains several layers, e.g. 'linear+ELU'.

    Every sub-layer is built with the same `in_dim`/`out_dim`, so the chained layers must be shape-compatible
    (e.g. 'linear+linear' requires in_dim == out_dim).
    """
    sub_names = name.split('+')
    if len(sub_names) == 1:
        return _build_single_embedding_layer(in_dim, out_dim, sub_names[0])
    return nn.Sequential(*[_build_single_embedding_layer(in_dim, out_dim, sub_name) for sub_name in sub_names])


# --- /example.py ---
import torch

from chienn import smiles_to_data_with_circle_index, collate_with_circle_index
from chienn.model.chienn_model import ChiENNModel


def main():
    """Checks that ChiENN produces different outputs for the two enantiomers of lactic acid."""
    k_neighbors = 3
    model = ChiENNModel(k_neighbors=k_neighbors)

    # The two SMILES differ only in the @/@@ chirality tag.
    smiles_list = ['C[C@H](C(=O)O)O', 'C[C@@H](C(=O)O)O']
    data_list = [smiles_to_data_with_circle_index(smiles) for smiles in smiles_list]
    batch = collate_with_circle_index(data_list, k_neighbors=k_neighbors)

    output = model(batch)
    # `torch.equal` works for outputs of any shape (a plain `!=` is ambiguous for multi-element tensors).
    assert not torch.equal(output[0], output[1])


if __name__ == '__main__':
    main()


# --- /experiments/LICENSE (lines 1-12; the remaining lines follow in this dump) ---
# MIT License
#
# Copyright (c) 2022 Ladislav Rampášek, Michael Galkin, Vijay Prakash Dwivedi, Dominique Beaini
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission
notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /experiments/README.md: -------------------------------------------------------------------------------- 1 | # ChiENN - experiments 2 | 3 | This module contains code for running experiments with ChiENN and baselines models. It was adapted from [GraphGPS repository](https://github.com/rampasek/GraphGPS). 4 | 5 | 6 | ### Python environment setup with Conda 7 | 8 | ```bash 9 | conda create -n chienn-experiments python=3.9 10 | conda activate chienn-experiments 11 | 12 | # Packages used also in the chienn module: 13 | conda install pytorch=1.10 torchvision torchaudio -c pytorch -c nvidia 14 | conda install pyg=2.0.4 -c pyg -c conda-forge 15 | conda install rdkit -c conda-forge 16 | 17 | # Packages used only in the experiments module: 18 | conda install openbabel fsspec -c conda-forge 19 | pip install torchmetrics 20 | pip install performer-pytorch 21 | pip install ogb 22 | pip install tensorboardX 23 | pip install wandb 24 | pip install PyTDC 25 | pip install chainer-chemistry 26 | pip install schnetpack==1.0.1 27 | 28 | conda clean --all 29 | ``` 30 | 31 | 32 | ### Running ChiENN 33 | ```bash 34 | conda activate chienn-experiments 35 | 36 | # Running ChiENN with parameters tuned for binding_rank: 37 | python main.py --cfg configs/models/ChiENN/binding_rank-ChiENN.yaml wandb.use False 38 | ``` 39 | 40 | 41 | 
### Benchmarking ChiENN 42 | To run a benchmark that tunes the hyperparameters on a validation set and then evaluates the model on a test set, use the `benchmark.py` script and configs from `configs/benchmarks/`: 43 | ```bash 44 | conda activate chienn-experiments 45 | # Run 3 repeats with seed=0,1,2: 46 | python benchmark.py --cfg configs/benchmarks/GPS+ChiENN/benchmark-binding_rank-gps+chienn.json --repeat 3 wandb.use False 47 | ``` 48 | 49 | 50 | ### W&B logging 51 | To use W&B logging, set `wandb.use True` and have a `chienn` entity set up in your W&B account (or change it to whatever else you like by setting `wandb.entity` in `configs/models/common.yaml`). 52 | 53 | 54 | ## Citation 55 | 56 | If you find this work useful, please cite our paper: 57 | ```bibtex 58 | @article{chienn, 59 | title={{ChiENN: Embracing Molecular Chirality with Graph Neural Networks}}, 60 | author={Piotr Gaiński and Michał Koziarski and Jacek Tabor and Marek Śmieja}, 61 | year={2023} 62 | } 63 | ``` 64 | and the paper that introduced GraphGPS: 65 | ```bibtex 66 | @article{rampasek2022GPS, 67 | title={{Recipe for a General, Powerful, Scalable Graph Transformer}}, 68 | author={Ladislav Ramp\'{a}\v{s}ek and Mikhail Galkin and Vijay Prakash Dwivedi and Anh Tuan Luu and Guy Wolf and Dominique Beaini}, 69 | journal={arXiv:2205.12454}, 70 | year={2022} 71 | } 72 | ``` 73 | -------------------------------------------------------------------------------- /experiments/configs/benchmarks/ChIRo/benchmark-bace-chiro.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_subset_size": 10000, 3 | "base_config_path": "configs/models/ChIRo/bace-ChIRo.yaml", 4 | "params_grid": [ 5 | { 6 | "gnn.layers": [2, 3, 4], 7 | "model.hidden_dim": [64, 128, 256], 8 | "gnn.dropout": [0.2], 9 | "optim.base_lr": [1e-4] 10 | }, 11 | { 12 | "gnn.layers": ["@BestParam()"], 13 | "model.hidden_dim": ["@BestParam()"], 14 | "gnn.dropout": [0.0, 0.5], 15 | "optim.base_lr": [1e-4] 16 | }, 17 | { 18 | 
"gnn.layers": ["@BestParam()"], 19 | "model.hidden_dim": ["@BestParam()"], 20 | "gnn.dropout": ["@BestParam()"], 21 | "optim.base_lr": [1e-3, 1e-5] 22 | } 23 | ] 24 | } 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /experiments/configs/benchmarks/ChIRo/benchmark-binding_affinity-chiro.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_subset_size": 10000, 3 | "base_config_path": "configs/models/ChIRo/binding_affinity-ChIRo.yaml", 4 | "params_grid": [ 5 | { 6 | "gnn.layers": [2, 3, 4], 7 | "model.hidden_dim": [64, 128, 256], 8 | "gnn.dropout": [0.2], 9 | "optim.base_lr": [1e-4] 10 | }, 11 | { 12 | "gnn.layers": ["@BestParam()"], 13 | "model.hidden_dim": ["@BestParam()"], 14 | "gnn.dropout": [0.0, 0.5], 15 | "optim.base_lr": [1e-4] 16 | }, 17 | { 18 | "gnn.layers": ["@BestParam()"], 19 | "model.hidden_dim": ["@BestParam()"], 20 | "gnn.dropout": ["@BestParam()"], 21 | "optim.base_lr": [1e-3, 1e-5] 22 | } 23 | ] 24 | } 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /experiments/configs/benchmarks/ChIRo/benchmark-binding_rank-chiro.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_subset_size": 10000, 3 | "base_config_path": "configs/models/ChIRo/binding_rank-ChIRo.yaml", 4 | "params_grid": [ 5 | { 6 | "gnn.layers": [2, 3, 4], 7 | "model.hidden_dim": [64, 128, 256], 8 | "gnn.dropout": [0.2], 9 | "optim.base_lr": [1e-4] 10 | }, 11 | { 12 | "gnn.layers": ["@BestParam()"], 13 | "model.hidden_dim": ["@BestParam()"], 14 | "gnn.dropout": [0.0, 0.5], 15 | "optim.base_lr": [1e-4] 16 | }, 17 | { 18 | "gnn.layers": ["@BestParam()"], 19 | "model.hidden_dim": ["@BestParam()"], 20 | "gnn.dropout": ["@BestParam()"], 21 | "optim.base_lr": [1e-3, 1e-5] 22 | } 23 | ] 24 | } 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- 
/experiments/configs/benchmarks/ChIRo/benchmark-rs-chiro.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_subset_size": 10000, 3 | "base_config_path": "configs/models/ChIRo/rs-ChIRo.yaml", 4 | "params_grid": [ 5 | { 6 | "gnn.layers": [2, 3, 4], 7 | "model.hidden_dim": [64, 128, 256], 8 | "gnn.dropout": [0.2], 9 | "optim.base_lr": [1e-4] 10 | }, 11 | { 12 | "gnn.layers": ["@BestParam()"], 13 | "model.hidden_dim": ["@BestParam()"], 14 | "gnn.dropout": [0.0, 0.5], 15 | "optim.base_lr": [1e-4] 16 | }, 17 | { 18 | "gnn.layers": ["@BestParam()"], 19 | "model.hidden_dim": ["@BestParam()"], 20 | "gnn.dropout": ["@BestParam()"], 21 | "optim.base_lr": [1e-3, 1e-5] 22 | } 23 | ] 24 | } 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /experiments/configs/benchmarks/ChIRo/benchmark-tox21-chiro.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_subset_size": 10000, 3 | "base_config_path": "configs/models/ChIRo/tox21-ChIRo.yaml", 4 | "params_grid": [ 5 | { 6 | "gnn.layers": [2, 3, 4], 7 | "model.hidden_dim": [64, 128, 256], 8 | "gnn.dropout": [0.2], 9 | "optim.base_lr": [1e-4] 10 | }, 11 | { 12 | "gnn.layers": ["@BestParam()"], 13 | "model.hidden_dim": ["@BestParam()"], 14 | "gnn.dropout": [0.0, 0.5], 15 | "optim.base_lr": [1e-4] 16 | }, 17 | { 18 | "gnn.layers": ["@BestParam()"], 19 | "model.hidden_dim": ["@BestParam()"], 20 | "gnn.dropout": ["@BestParam()"], 21 | "optim.base_lr": [1e-3, 1e-5] 22 | } 23 | ] 24 | } 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /experiments/configs/benchmarks/DMPNN/benchmark-bace-dmpnn+with_tags.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_subset_size": 10000, 3 | "base_config_path": "configs/models/DMPNN/bace-DMPNN+with_tags.yaml", 4 | "params_grid": [ 5 | { 6 | "gnn.layers": [2, 4, 
6], 7 | "model.hidden_dim": [300, 600, 900], 8 | "gnn.dropout": [0.2], 9 | "optim.base_lr": [1e-4] 10 | }, 11 | { 12 | "gnn.layers": ["@BestParam()"], 13 | "model.hidden_dim": ["@BestParam()"], 14 | "gnn.dropout": [0.0, 0.5], 15 | "optim.base_lr": [1e-4] 16 | }, 17 | { 18 | "gnn.layers": ["@BestParam()"], 19 | "model.hidden_dim": ["@BestParam()"], 20 | "gnn.dropout": ["@BestParam()"], 21 | "optim.base_lr": [1e-3, 1e-5] 22 | } 23 | ] 24 | } 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /experiments/configs/benchmarks/DMPNN/benchmark-bace-dmpnn.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_subset_size": 10000, 3 | "base_config_path": "configs/models/DMPNN/bace-DMPNN.yaml", 4 | "params_grid": [ 5 | { 6 | "gnn.layers": [2, 4, 6], 7 | "model.hidden_dim": [300, 600, 900], 8 | "gnn.dropout": [0.2], 9 | "optim.base_lr": [1e-4] 10 | }, 11 | { 12 | "gnn.layers": ["@BestParam()"], 13 | "model.hidden_dim": ["@BestParam()"], 14 | "gnn.dropout": [0.0, 0.5], 15 | "optim.base_lr": [1e-4] 16 | }, 17 | { 18 | "gnn.layers": ["@BestParam()"], 19 | "model.hidden_dim": ["@BestParam()"], 20 | "gnn.dropout": ["@BestParam()"], 21 | "optim.base_lr": [1e-3, 1e-5] 22 | } 23 | ] 24 | } 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /experiments/configs/benchmarks/DMPNN/benchmark-binding_affinity-dmpnn+with_tags.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_subset_size": 10000, 3 | "base_config_path": "configs/models/DMPNN/binding_affinity-DMPNN+with_tags.yaml", 4 | "params_grid": [ 5 | { 6 | "gnn.layers": [2, 4, 6], 7 | "model.hidden_dim": [300, 600, 900], 8 | "gnn.dropout": [0.2], 9 | "optim.base_lr": [1e-4] 10 | }, 11 | { 12 | "gnn.layers": ["@BestParam()"], 13 | "model.hidden_dim": ["@BestParam()"], 14 | "gnn.dropout": [0.0, 0.5], 15 | "optim.base_lr": [1e-4] 16 | }, 17 | { 
18 | "gnn.layers": ["@BestParam()"], 19 | "model.hidden_dim": ["@BestParam()"], 20 | "gnn.dropout": ["@BestParam()"], 21 | "optim.base_lr": [1e-3, 1e-5] 22 | } 23 | ] 24 | } 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /experiments/configs/benchmarks/DMPNN/benchmark-binding_affinity-dmpnn.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_subset_size": 10000, 3 | "base_config_path": "configs/models/DMPNN/binding_affinity-DMPNN.yaml", 4 | "params_grid": [ 5 | { 6 | "gnn.layers": [2, 4, 6], 7 | "model.hidden_dim": [300, 600, 900], 8 | "gnn.dropout": [0.2], 9 | "optim.base_lr": [1e-4] 10 | }, 11 | { 12 | "gnn.layers": ["@BestParam()"], 13 | "model.hidden_dim": ["@BestParam()"], 14 | "gnn.dropout": [0.0, 0.5], 15 | "optim.base_lr": [1e-4] 16 | }, 17 | { 18 | "gnn.layers": ["@BestParam()"], 19 | "model.hidden_dim": ["@BestParam()"], 20 | "gnn.dropout": ["@BestParam()"], 21 | "optim.base_lr": [1e-3, 1e-5] 22 | } 23 | ] 24 | } 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /experiments/configs/benchmarks/DMPNN/benchmark-binding_rank-dmpnn+with_tags.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_subset_size": 10000, 3 | "base_config_path": "configs/models/DMPNN/binding_rank-DMPNN+with_tags.yaml", 4 | "params_grid": [ 5 | { 6 | "gnn.layers": [2, 4, 6], 7 | "model.hidden_dim": [300, 600, 900], 8 | "gnn.dropout": [0.2], 9 | "optim.base_lr": [1e-4] 10 | }, 11 | { 12 | "gnn.layers": ["@BestParam()"], 13 | "model.hidden_dim": ["@BestParam()"], 14 | "gnn.dropout": [0.0, 0.5], 15 | "optim.base_lr": [1e-4] 16 | }, 17 | { 18 | "gnn.layers": ["@BestParam()"], 19 | "model.hidden_dim": ["@BestParam()"], 20 | "gnn.dropout": ["@BestParam()"], 21 | "optim.base_lr": [1e-3, 1e-5] 22 | } 23 | ] 24 | } 25 | 26 | 27 | 28 | 
-------------------------------------------------------------------------------- /experiments/configs/benchmarks/DMPNN/benchmark-rs-dmpnn+with_tags.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_subset_size": 10000, 3 | "base_config_path": "configs/models/DMPNN/rs-DMPNN+with_tags.yaml", 4 | "params_grid": [ 5 | { 6 | "gnn.layers": [2, 4, 6], 7 | "model.hidden_dim": [300, 600, 900], 8 | "gnn.dropout": [0.2], 9 | "optim.base_lr": [1e-4] 10 | }, 11 | { 12 | "gnn.layers": ["@BestParam()"], 13 | "model.hidden_dim": ["@BestParam()"], 14 | "gnn.dropout": [0.0, 0.5], 15 | "optim.base_lr": [1e-4] 16 | }, 17 | { 18 | "gnn.layers": ["@BestParam()"], 19 | "model.hidden_dim": ["@BestParam()"], 20 | "gnn.dropout": ["@BestParam()"], 21 | "optim.base_lr": [1e-3, 1e-5] 22 | } 23 | ] 24 | } 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /experiments/configs/benchmarks/DMPNN/benchmark-tox21-dmpnn+with_tags.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_subset_size": 10000, 3 | "base_config_path": "configs/models/DMPNN/tox21-DMPNN+with_tags.yaml", 4 | "params_grid": [ 5 | { 6 | "gnn.layers": [2, 4, 6], 7 | "model.hidden_dim": [300, 600, 900], 8 | "gnn.dropout": [0.2], 9 | "optim.base_lr": [1e-4] 10 | }, 11 | { 12 | "gnn.layers": ["@BestParam()"], 13 | "model.hidden_dim": ["@BestParam()"], 14 | "gnn.dropout": [0.0, 0.5], 15 | "optim.base_lr": [1e-4] 16 | }, 17 | { 18 | "gnn.layers": ["@BestParam()"], 19 | "model.hidden_dim": ["@BestParam()"], 20 | "gnn.dropout": ["@BestParam()"], 21 | "optim.base_lr": [1e-3, 1e-5] 22 | } 23 | ] 24 | } 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /experiments/configs/benchmarks/DMPNN/benchmark-tox21-dmpnn.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_subset_size": 10000, 3 | 
"base_config_path": "configs/models/DMPNN/tox21-DMPNN.yaml", 4 | "params_grid": [ 5 | { 6 | "gnn.layers": [2, 4, 6], 7 | "model.hidden_dim": [300, 600, 900], 8 | "gnn.dropout": [0.2], 9 | "optim.base_lr": [1e-4] 10 | }, 11 | { 12 | "gnn.layers": ["@BestParam()"], 13 | "model.hidden_dim": ["@BestParam()"], 14 | "gnn.dropout": [0.0, 0.5], 15 | "optim.base_lr": [1e-4] 16 | }, 17 | { 18 | "gnn.layers": ["@BestParam()"], 19 | "model.hidden_dim": ["@BestParam()"], 20 | "gnn.dropout": ["@BestParam()"], 21 | "optim.base_lr": [1e-3, 1e-5] 22 | } 23 | ] 24 | } 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /experiments/configs/benchmarks/GPS+ChiENN/benchmark-bace-gps+chienn.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_subset_size": 10000, 3 | "base_config_path": "configs/models/GPS+ChiENN/bace-GPS+ChiENN.yaml", 4 | "params_grid": [ 5 | { 6 | "gt.layers": [3, 6, 10], 7 | "model.hidden_dim": [64, 128, 256], 8 | "gt.dropout": [0.2], 9 | "optim.base_lr": [1e-4] 10 | }, 11 | { 12 | "gt.layers": ["@BestParam()"], 13 | "model.hidden_dim": ["@BestParam()"], 14 | "gt.dropout": [0.0, 0.5], 15 | "optim.base_lr": [1e-4] 16 | }, 17 | { 18 | "gt.layers": ["@BestParam()"], 19 | "model.hidden_dim": ["@BestParam()"], 20 | "gt.dropout": ["@BestParam()"], 21 | "optim.base_lr": [1e-3, 1e-5] 22 | } 23 | ] 24 | } 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /experiments/configs/benchmarks/GPS+ChiENN/benchmark-binding_affinity-gps+chienn.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_subset_size": 10000, 3 | "base_config_path": "configs/models/GPS+ChiENN/binding_affinity-GPS+ChiENN.yaml", 4 | "params_grid": [ 5 | { 6 | "gt.layers": [3, 6, 10], 7 | "model.hidden_dim": [64, 128, 256], 8 | "gt.dropout": [0.2], 9 | "optim.base_lr": [1e-4] 10 | }, 11 | { 12 | "gt.layers": 
["@BestParam()"], 13 | "model.hidden_dim": ["@BestParam()"], 14 | "gt.dropout": [0.0, 0.5], 15 | "optim.base_lr": [1e-4] 16 | }, 17 | { 18 | "gt.layers": ["@BestParam()"], 19 | "model.hidden_dim": ["@BestParam()"], 20 | "gt.dropout": ["@BestParam()"], 21 | "optim.base_lr": [1e-3, 1e-5] 22 | } 23 | ] 24 | } 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /experiments/configs/benchmarks/GPS+ChiENN/benchmark-binding_rank-gps+chienn.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_subset_size": 10000, 3 | "base_config_path": "configs/models/GPS+ChiENN/binding_rank-GPS+ChiENN.yaml", 4 | "params_grid": [ 5 | { 6 | "gt.layers": [3, 6, 10], 7 | "model.hidden_dim": [64, 128, 256], 8 | "gt.dropout": [0.2], 9 | "optim.base_lr": [1e-4] 10 | }, 11 | { 12 | "gt.layers": ["@BestParam()"], 13 | "model.hidden_dim": ["@BestParam()"], 14 | "gt.dropout": [0.0, 0.5], 15 | "optim.base_lr": [1e-4] 16 | }, 17 | { 18 | "gt.layers": ["@BestParam()"], 19 | "model.hidden_dim": ["@BestParam()"], 20 | "gt.dropout": ["@BestParam()"], 21 | "optim.base_lr": [1e-3, 1e-5] 22 | } 23 | ] 24 | } 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /experiments/configs/benchmarks/GPS+ChiENN/benchmark-rs-gps+chienn.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_subset_size": 10000, 3 | "base_config_path": "configs/models/GPS+ChiENN/rs-GPS+ChiENN.yaml", 4 | "params_grid": [ 5 | { 6 | "gt.layers": [3, 6, 10], 7 | "model.hidden_dim": [64, 128, 256], 8 | "gt.dropout": [0.2], 9 | "optim.base_lr": [1e-4] 10 | }, 11 | { 12 | "gt.layers": ["@BestParam()"], 13 | "model.hidden_dim": ["@BestParam()"], 14 | "gt.dropout": [0.0, 0.5], 15 | "optim.base_lr": [1e-4] 16 | }, 17 | { 18 | "gt.layers": ["@BestParam()"], 19 | "model.hidden_dim": ["@BestParam()"], 20 | "gt.dropout": ["@BestParam()"], 21 | "optim.base_lr": 
[1e-3, 1e-5] 22 | } 23 | ] 24 | } 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /experiments/configs/benchmarks/GPS+ChiENN/benchmark-tox21-gps+chienn.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_subset_size": 10000, 3 | "base_config_path": "configs/models/GPS+ChiENN/tox21-GPS+ChiENN.yaml", 4 | "params_grid": [ 5 | { 6 | "gt.layers": [3, 6, 10], 7 | "model.hidden_dim": [64, 128, 256], 8 | "gt.dropout": [0.2], 9 | "optim.base_lr": [1e-4] 10 | }, 11 | { 12 | "gt.layers": ["@BestParam()"], 13 | "model.hidden_dim": ["@BestParam()"], 14 | "gt.dropout": [0.0, 0.5], 15 | "optim.base_lr": [1e-4] 16 | }, 17 | { 18 | "gt.layers": ["@BestParam()"], 19 | "model.hidden_dim": ["@BestParam()"], 20 | "gt.dropout": ["@BestParam()"], 21 | "optim.base_lr": [1e-3, 1e-5] 22 | } 23 | ] 24 | } 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /experiments/configs/benchmarks/GPS/benchmark-bace-gps+with_tags.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_subset_size": 10000, 3 | "base_config_path": "configs/models/GPS/bace-GPS+with_tags.yaml", 4 | "params_grid": [ 5 | { 6 | "gt.layers": [3, 6, 10], 7 | "model.hidden_dim": [64, 128, 256], 8 | "gt.dropout": [0.2], 9 | "optim.base_lr": [1e-4] 10 | }, 11 | { 12 | "gt.layers": ["@BestParam()"], 13 | "model.hidden_dim": ["@BestParam()"], 14 | "gt.dropout": [0.0, 0.5], 15 | "optim.base_lr": [1e-4] 16 | }, 17 | { 18 | "gt.layers": ["@BestParam()"], 19 | "model.hidden_dim": ["@BestParam()"], 20 | "gt.dropout": ["@BestParam()"], 21 | "optim.base_lr": [1e-3, 1e-5] 22 | } 23 | ] 24 | } 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /experiments/configs/benchmarks/GPS/benchmark-bace-gps.json: -------------------------------------------------------------------------------- 1 | { 2 | 
"max_subset_size": 10000, 3 | "base_config_path": "configs/models/GPS/bace-GPS.yaml", 4 | "params_grid": [ 5 | { 6 | "gt.layers": [3, 6, 10], 7 | "model.hidden_dim": [64, 128, 256], 8 | "gt.dropout": [0.2], 9 | "optim.base_lr": [1e-4] 10 | }, 11 | { 12 | "gt.layers": ["@BestParam()"], 13 | "model.hidden_dim": ["@BestParam()"], 14 | "gt.dropout": [0.0, 0.5], 15 | "optim.base_lr": [1e-4] 16 | }, 17 | { 18 | "gt.layers": ["@BestParam()"], 19 | "model.hidden_dim": ["@BestParam()"], 20 | "gt.dropout": ["@BestParam()"], 21 | "optim.base_lr": [1e-3, 1e-5] 22 | } 23 | ] 24 | } 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /experiments/configs/benchmarks/GPS/benchmark-binding_affinity-gps+with_tags.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_subset_size": 10000, 3 | "base_config_path": "configs/models/GPS/binding_affinity-GPS+with_tags.yaml", 4 | "params_grid": [ 5 | { 6 | "gt.layers": [3, 6, 10], 7 | "model.hidden_dim": [64, 128, 256], 8 | "gt.dropout": [0.2], 9 | "optim.base_lr": [1e-4] 10 | }, 11 | { 12 | "gt.layers": ["@BestParam()"], 13 | "model.hidden_dim": ["@BestParam()"], 14 | "gt.dropout": [0.0, 0.5], 15 | "optim.base_lr": [1e-4] 16 | }, 17 | { 18 | "gt.layers": ["@BestParam()"], 19 | "model.hidden_dim": ["@BestParam()"], 20 | "gt.dropout": ["@BestParam()"], 21 | "optim.base_lr": [1e-3, 1e-5] 22 | } 23 | ] 24 | } 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /experiments/configs/benchmarks/GPS/benchmark-binding_affinity-gps.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_subset_size": 10000, 3 | "base_config_path": "configs/models/GPS/binding_affinity-GPS.yaml", 4 | "params_grid": [ 5 | { 6 | "gt.layers": [3, 6, 10], 7 | "model.hidden_dim": [64, 128, 256], 8 | "gt.dropout": [0.2], 9 | "optim.base_lr": [1e-4] 10 | }, 11 | { 12 | "gt.layers": 
["@BestParam()"], 13 | "model.hidden_dim": ["@BestParam()"], 14 | "gt.dropout": [0.0, 0.5], 15 | "optim.base_lr": [1e-4] 16 | }, 17 | { 18 | "gt.layers": ["@BestParam()"], 19 | "model.hidden_dim": ["@BestParam()"], 20 | "gt.dropout": ["@BestParam()"], 21 | "optim.base_lr": [1e-3, 1e-5] 22 | } 23 | ] 24 | } 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /experiments/configs/benchmarks/GPS/benchmark-binding_rank-gps+with_tags.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_subset_size": 10000, 3 | "base_config_path": "configs/models/GPS/binding_rank-GPS+with_tags.yaml", 4 | "params_grid": [ 5 | { 6 | "gt.layers": [3, 6, 10], 7 | "model.hidden_dim": [64, 128, 256], 8 | "gt.dropout": [0.2], 9 | "optim.base_lr": [1e-4] 10 | }, 11 | { 12 | "gt.layers": ["@BestParam()"], 13 | "model.hidden_dim": ["@BestParam()"], 14 | "gt.dropout": [0.0, 0.5], 15 | "optim.base_lr": [1e-4] 16 | }, 17 | { 18 | "gt.layers": ["@BestParam()"], 19 | "model.hidden_dim": ["@BestParam()"], 20 | "gt.dropout": ["@BestParam()"], 21 | "optim.base_lr": [1e-3, 1e-5] 22 | } 23 | ] 24 | } 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /experiments/configs/benchmarks/GPS/benchmark-tox21-gps+with_tags.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_subset_size": 10000, 3 | "base_config_path": "configs/models/GPS/tox21-GPS+with_tags.yaml", 4 | "params_grid": [ 5 | { 6 | "gt.layers": [3, 6, 10], 7 | "model.hidden_dim": [64, 128, 256], 8 | "gt.dropout": [0.2], 9 | "optim.base_lr": [1e-4] 10 | }, 11 | { 12 | "gt.layers": ["@BestParam()"], 13 | "model.hidden_dim": ["@BestParam()"], 14 | "gt.dropout": [0.0, 0.5], 15 | "optim.base_lr": [1e-4] 16 | }, 17 | { 18 | "gt.layers": ["@BestParam()"], 19 | "model.hidden_dim": ["@BestParam()"], 20 | "gt.dropout": ["@BestParam()"], 21 | "optim.base_lr": [1e-3, 
1e-5] 22 | } 23 | ] 24 | } 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /experiments/configs/benchmarks/GPS/benchmark-tox21-gps.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_subset_size": 10000, 3 | "base_config_path": "configs/models/GPS/tox21-GPS.yaml", 4 | "params_grid": [ 5 | { 6 | "gt.layers": [3, 6, 10], 7 | "model.hidden_dim": [64, 128, 256], 8 | "gt.dropout": [0.2], 9 | "optim.base_lr": [1e-4] 10 | }, 11 | { 12 | "gt.layers": ["@BestParam()"], 13 | "model.hidden_dim": ["@BestParam()"], 14 | "gt.dropout": [0.0, 0.5], 15 | "optim.base_lr": [1e-4] 16 | }, 17 | { 18 | "gt.layers": ["@BestParam()"], 19 | "model.hidden_dim": ["@BestParam()"], 20 | "gt.dropout": ["@BestParam()"], 21 | "optim.base_lr": [1e-3, 1e-5] 22 | } 23 | ] 24 | } 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /experiments/configs/benchmarks/SAN+ChiENN/benchmark-bace-san+chienn.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_subset_size": 10000, 3 | "base_config_path": "configs/models/SAN+ChiENN/bace-SAN+ChiENN.yaml", 4 | "params_grid": [ 5 | { 6 | "gt.layers": [3, 6, 10], 7 | "model.hidden_dim": [64, 128, 256], 8 | "gt.dropout": [0.2], 9 | "optim.base_lr": [1e-4] 10 | }, 11 | { 12 | "gt.layers": ["@BestParam()"], 13 | "model.hidden_dim": ["@BestParam()"], 14 | "gt.dropout": [0.0, 0.5], 15 | "optim.base_lr": [1e-4] 16 | }, 17 | { 18 | "gt.layers": ["@BestParam()"], 19 | "model.hidden_dim": ["@BestParam()"], 20 | "gt.dropout": ["@BestParam()"], 21 | "optim.base_lr": [1e-3, 1e-5] 22 | } 23 | ] 24 | } 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /experiments/configs/benchmarks/SAN+ChiENN/benchmark-binding_affinity-san+chienn.json: -------------------------------------------------------------------------------- 1 | { 2 | 
"max_subset_size": 10000, 3 | "base_config_path": "configs/models/SAN+ChiENN/binding_affinity-SAN+ChiENN.yaml", 4 | "params_grid": [ 5 | { 6 | "gt.layers": [3, 6, 10], 7 | "model.hidden_dim": [64, 128, 256], 8 | "gt.dropout": [0.2], 9 | "optim.base_lr": [1e-4] 10 | }, 11 | { 12 | "gt.layers": ["@BestParam()"], 13 | "model.hidden_dim": ["@BestParam()"], 14 | "gt.dropout": [0.0, 0.5], 15 | "optim.base_lr": [1e-4] 16 | }, 17 | { 18 | "gt.layers": ["@BestParam()"], 19 | "model.hidden_dim": ["@BestParam()"], 20 | "gt.dropout": ["@BestParam()"], 21 | "optim.base_lr": [1e-3, 1e-5] 22 | } 23 | ] 24 | } 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /experiments/configs/benchmarks/SAN+ChiENN/benchmark-binding_rank-san+chienn.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_subset_size": 10000, 3 | "base_config_path": "configs/models/SAN+ChiENN/binding_rank-SAN+ChiENN.yaml", 4 | "params_grid": [ 5 | { 6 | "gt.layers": [3, 6, 10], 7 | "model.hidden_dim": [64, 128, 256], 8 | "gt.dropout": [0.2], 9 | "optim.base_lr": [1e-4] 10 | }, 11 | { 12 | "gt.layers": ["@BestParam()"], 13 | "model.hidden_dim": ["@BestParam()"], 14 | "gt.dropout": [0.0, 0.5], 15 | "optim.base_lr": [1e-4] 16 | }, 17 | { 18 | "gt.layers": ["@BestParam()"], 19 | "model.hidden_dim": ["@BestParam()"], 20 | "gt.dropout": ["@BestParam()"], 21 | "optim.base_lr": [1e-3, 1e-5] 22 | } 23 | ] 24 | } 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /experiments/configs/benchmarks/SAN+ChiENN/benchmark-tox21-san+chienn.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_subset_size": 10000, 3 | "base_config_path": "configs/models/SAN+ChiENN/tox21-SAN+ChiENN.yaml", 4 | "params_grid": [ 5 | { 6 | "gt.layers": [3, 6, 10], 7 | "model.hidden_dim": [64, 128, 256], 8 | "gt.dropout": [0.2], 9 | "optim.base_lr": [1e-4] 10 | 
}, 11 | { 12 | "gt.layers": ["@BestParam()"], 13 | "model.hidden_dim": ["@BestParam()"], 14 | "gt.dropout": [0.0, 0.5], 15 | "optim.base_lr": [1e-4] 16 | }, 17 | { 18 | "gt.layers": ["@BestParam()"], 19 | "model.hidden_dim": ["@BestParam()"], 20 | "gt.dropout": ["@BestParam()"], 21 | "optim.base_lr": [1e-3, 1e-5] 22 | } 23 | ] 24 | } 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /experiments/configs/benchmarks/SAN/benchmark-bace-san+with_tags.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_subset_size": 10000, 3 | "base_config_path": "configs/models/SAN/bace-SAN+with_tags.yaml", 4 | "params_grid": [ 5 | { 6 | "gt.layers": [3, 6, 10], 7 | "model.hidden_dim": [64, 128, 256], 8 | "gt.dropout": [0.2], 9 | "optim.base_lr": [1e-4] 10 | }, 11 | { 12 | "gt.layers": ["@BestParam()"], 13 | "model.hidden_dim": ["@BestParam()"], 14 | "gt.dropout": [0.0, 0.5], 15 | "optim.base_lr": [1e-4] 16 | }, 17 | { 18 | "gt.layers": ["@BestParam()"], 19 | "model.hidden_dim": ["@BestParam()"], 20 | "gt.dropout": ["@BestParam()"], 21 | "optim.base_lr": [1e-3, 1e-5] 22 | } 23 | ] 24 | } 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /experiments/configs/benchmarks/SAN/benchmark-bace-san.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_subset_size": 10000, 3 | "base_config_path": "configs/models/SAN/bace-SAN.yaml", 4 | "params_grid": [ 5 | { 6 | "gt.layers": [3, 6, 10], 7 | "model.hidden_dim": [64, 128, 256], 8 | "gt.dropout": [0.2], 9 | "optim.base_lr": [1e-4] 10 | }, 11 | { 12 | "gt.layers": ["@BestParam()"], 13 | "model.hidden_dim": ["@BestParam()"], 14 | "gt.dropout": [0.0, 0.5], 15 | "optim.base_lr": [1e-4] 16 | }, 17 | { 18 | "gt.layers": ["@BestParam()"], 19 | "model.hidden_dim": ["@BestParam()"], 20 | "gt.dropout": ["@BestParam()"], 21 | "optim.base_lr": [1e-3, 1e-5] 22 | } 
23 | ] 24 | } 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /experiments/configs/benchmarks/SAN/benchmark-binding_affinity-san+with_tags.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_subset_size": 10000, 3 | "base_config_path": "configs/models/SAN/binding_affinity-SAN+with_tags.yaml", 4 | "params_grid": [ 5 | { 6 | "gt.layers": [3, 6, 10], 7 | "model.hidden_dim": [64, 128, 256], 8 | "gt.dropout": [0.2], 9 | "optim.base_lr": [1e-4] 10 | }, 11 | { 12 | "gt.layers": ["@BestParam()"], 13 | "model.hidden_dim": ["@BestParam()"], 14 | "gt.dropout": [0.0, 0.5], 15 | "optim.base_lr": [1e-4] 16 | }, 17 | { 18 | "gt.layers": ["@BestParam()"], 19 | "model.hidden_dim": ["@BestParam()"], 20 | "gt.dropout": ["@BestParam()"], 21 | "optim.base_lr": [1e-3, 1e-5] 22 | } 23 | ] 24 | } 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /experiments/configs/benchmarks/SAN/benchmark-binding_affinity-san.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_subset_size": 10000, 3 | "base_config_path": "configs/models/SAN/binding_affinity-SAN.yaml", 4 | "params_grid": [ 5 | { 6 | "gt.layers": [3, 6, 10], 7 | "model.hidden_dim": [64, 128, 256], 8 | "gt.dropout": [0.2], 9 | "optim.base_lr": [1e-4] 10 | }, 11 | { 12 | "gt.layers": ["@BestParam()"], 13 | "model.hidden_dim": ["@BestParam()"], 14 | "gt.dropout": [0.0, 0.5], 15 | "optim.base_lr": [1e-4] 16 | }, 17 | { 18 | "gt.layers": ["@BestParam()"], 19 | "model.hidden_dim": ["@BestParam()"], 20 | "gt.dropout": ["@BestParam()"], 21 | "optim.base_lr": [1e-3, 1e-5] 22 | } 23 | ] 24 | } 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /experiments/configs/benchmarks/SAN/benchmark-binding_rank-san+with_tags.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "max_subset_size": 10000, 3 | "base_config_path": "configs/models/SAN/binding_rank-SAN+with_tags.yaml", 4 | "params_grid": [ 5 | { 6 | "gt.layers": [3, 6, 10], 7 | "model.hidden_dim": [64, 128, 256], 8 | "gt.dropout": [0.2], 9 | "optim.base_lr": [1e-4] 10 | }, 11 | { 12 | "gt.layers": ["@BestParam()"], 13 | "model.hidden_dim": ["@BestParam()"], 14 | "gt.dropout": [0.0, 0.5], 15 | "optim.base_lr": [1e-4] 16 | }, 17 | { 18 | "gt.layers": ["@BestParam()"], 19 | "model.hidden_dim": ["@BestParam()"], 20 | "gt.dropout": ["@BestParam()"], 21 | "optim.base_lr": [1e-3, 1e-5] 22 | } 23 | ] 24 | } 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /experiments/configs/benchmarks/SAN/benchmark-tox21-san+with_tags.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_subset_size": 10000, 3 | "base_config_path": "configs/models/SAN/tox21-SAN+with_tags.yaml", 4 | "params_grid": [ 5 | { 6 | "gt.layers": [3, 6, 10], 7 | "model.hidden_dim": [64, 128, 256], 8 | "gt.dropout": [0.2], 9 | "optim.base_lr": [1e-4] 10 | }, 11 | { 12 | "gt.layers": ["@BestParam()"], 13 | "model.hidden_dim": ["@BestParam()"], 14 | "gt.dropout": [0.0, 0.5], 15 | "optim.base_lr": [1e-4] 16 | }, 17 | { 18 | "gt.layers": ["@BestParam()"], 19 | "model.hidden_dim": ["@BestParam()"], 20 | "gt.dropout": ["@BestParam()"], 21 | "optim.base_lr": [1e-3, 1e-5] 22 | } 23 | ] 24 | } 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /experiments/configs/benchmarks/SAN/benchmark-tox21-san.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_subset_size": 10000, 3 | "base_config_path": "configs/models/SAN/tox21-SAN.yaml", 4 | "params_grid": [ 5 | { 6 | "gt.layers": [3, 6, 10], 7 | "model.hidden_dim": [64, 128, 256], 8 | "gt.dropout": [0.2], 9 
| "optim.base_lr": [1e-4] 10 | }, 11 | { 12 | "gt.layers": ["@BestParam()"], 13 | "model.hidden_dim": ["@BestParam()"], 14 | "gt.dropout": [0.0, 0.5], 15 | "optim.base_lr": [1e-4] 16 | }, 17 | { 18 | "gt.layers": ["@BestParam()"], 19 | "model.hidden_dim": ["@BestParam()"], 20 | "gt.dropout": ["@BestParam()"], 21 | "optim.base_lr": [1e-3, 1e-5] 22 | } 23 | ] 24 | } 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /experiments/configs/benchmarks/Tetra_DMPNN/benchmark-bace-tetra_dmpnn.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_subset_size": 10000, 3 | "base_config_path": "configs/models/Tetra_DMPNN/bace-Tetra_DMPNN.yaml", 4 | "params_grid": [ 5 | { 6 | "gnn.layers": [2, 4, 6], 7 | "model.hidden_dim": [300, 600, 900], 8 | "gnn.dropout": [0.2], 9 | "optim.base_lr": [1e-4] 10 | }, 11 | { 12 | "gnn.layers": ["@BestParam()"], 13 | "model.hidden_dim": ["@BestParam()"], 14 | "gnn.dropout": [0.0, 0.5], 15 | "optim.base_lr": [1e-4] 16 | }, 17 | { 18 | "gnn.layers": ["@BestParam()"], 19 | "model.hidden_dim": ["@BestParam()"], 20 | "gnn.dropout": ["@BestParam()"], 21 | "optim.base_lr": [1e-3, 1e-5] 22 | } 23 | ] 24 | } 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /experiments/configs/benchmarks/Tetra_DMPNN/benchmark-binding_affinity-tetra_dmpnn.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_subset_size": 10000, 3 | "base_config_path": "configs/models/Tetra_DMPNN/binding_affinity-Tetra_DMPNN.yaml", 4 | "params_grid": [ 5 | { 6 | "gnn.layers": [2, 4, 6], 7 | "model.hidden_dim": [300, 600, 900], 8 | "gnn.dropout": [0.2], 9 | "optim.base_lr": [1e-4] 10 | }, 11 | { 12 | "gnn.layers": ["@BestParam()"], 13 | "model.hidden_dim": ["@BestParam()"], 14 | "gnn.dropout": [0.0, 0.5], 15 | "optim.base_lr": [1e-4] 16 | }, 17 | { 18 | "gnn.layers": ["@BestParam()"], 19 | 
"model.hidden_dim": ["@BestParam()"], 20 | "gnn.dropout": ["@BestParam()"], 21 | "optim.base_lr": [1e-3, 1e-5] 22 | } 23 | ] 24 | } 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /experiments/configs/benchmarks/Tetra_DMPNN/benchmark-binding_rank-tetra_dmpnn.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_subset_size": 10000, 3 | "base_config_path": "configs/models/Tetra_DMPNN/binding_rank-Tetra_DMPNN.yaml", 4 | "params_grid": [ 5 | { 6 | "gnn.layers": [2, 4, 6], 7 | "model.hidden_dim": [300, 600, 900], 8 | "gnn.dropout": [0.2], 9 | "optim.base_lr": [1e-4] 10 | }, 11 | { 12 | "gnn.layers": ["@BestParam()"], 13 | "model.hidden_dim": ["@BestParam()"], 14 | "gnn.dropout": [0.0, 0.5], 15 | "optim.base_lr": [1e-4] 16 | }, 17 | { 18 | "gnn.layers": ["@BestParam()"], 19 | "model.hidden_dim": ["@BestParam()"], 20 | "gnn.dropout": ["@BestParam()"], 21 | "optim.base_lr": [1e-3, 1e-5] 22 | } 23 | ] 24 | } 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /experiments/configs/benchmarks/Tetra_DMPNN/benchmark-rs-tetra_dmpnn.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_subset_size": 10000, 3 | "base_config_path": "configs/models/Tetra_DMPNN/rs-Tetra_DMPNN.yaml", 4 | "params_grid": [ 5 | { 6 | "gnn.layers": [2, 4, 6], 7 | "model.hidden_dim": [300, 600, 900], 8 | "gnn.dropout": [0.2], 9 | "optim.base_lr": [1e-4] 10 | }, 11 | { 12 | "gnn.layers": ["@BestParam()"], 13 | "model.hidden_dim": ["@BestParam()"], 14 | "gnn.dropout": [0.0, 0.5], 15 | "optim.base_lr": [1e-4] 16 | }, 17 | { 18 | "gnn.layers": ["@BestParam()"], 19 | "model.hidden_dim": ["@BestParam()"], 20 | "gnn.dropout": ["@BestParam()"], 21 | "optim.base_lr": [1e-3, 1e-5] 22 | } 23 | ] 24 | } 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- 
/experiments/configs/benchmarks/Tetra_DMPNN/benchmark-tox21-tetra_dmpnn.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_subset_size": 10000, 3 | "base_config_path": "configs/models/Tetra_DMPNN/tox21-Tetra_DMPNN.yaml", 4 | "params_grid": [ 5 | { 6 | "gt.layers": [3, 6, 10], 7 | "model.hidden_dim": [64, 128, 256], 8 | "gt.dropout": [0.2], 9 | "optim.base_lr": [1e-4] 10 | }, 11 | { 12 | "gt.layers": ["@BestParam()"], 13 | "model.hidden_dim": ["@BestParam()"], 14 | "gt.dropout": [0.0, 0.5], 15 | "optim.base_lr": [1e-4] 16 | }, 17 | { 18 | "gt.layers": ["@BestParam()"], 19 | "model.hidden_dim": ["@BestParam()"], 20 | "gt.dropout": ["@BestParam()"], 21 | "optim.base_lr": [1e-3, 1e-5] 22 | } 23 | ] 24 | } 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /experiments/configs/datasets/bace.yaml: -------------------------------------------------------------------------------- 1 | metric_best: auc 2 | metric_agg: argmax 3 | dataset: 4 | format: ChIRo-CSV 5 | name: BACE 6 | task: graph 7 | task_type: classification 8 | train: 9 | sampler: full_batch 10 | val: 11 | sampler: full_batch 12 | test: 13 | sampler: full_batch 14 | model: 15 | loss_fun: cross_entropy -------------------------------------------------------------------------------- /experiments/configs/datasets/binding_affinity.yaml: -------------------------------------------------------------------------------- 1 | metric_best: mae 2 | metric_agg: argmin 3 | dataset: 4 | format: ChIRo 5 | name: binding_affinity 6 | task: graph 7 | task_type: regression 8 | single_conformer: True 9 | single_enantiomer: False 10 | train: 11 | sampler: full_batch 12 | val: 13 | sampler: full_batch 14 | test: 15 | sampler: full_batch 16 | model: 17 | loss_fun: l1 18 | -------------------------------------------------------------------------------- /experiments/configs/datasets/binding_rank.yaml: 
-------------------------------------------------------------------------------- 1 | metric_best: ranking_accuracy_0.3 # it was reported in the ChIRo paper 2 | metric_agg: argmax 3 | dataset: 4 | format: ChIRo 5 | name: binding_affinity 6 | task: graph 7 | task_type: regression_rank 8 | single_conformer: True 9 | single_enantiomer: False 10 | train: 11 | sampler: single_conformer_sampler 12 | val: 13 | sampler: full_batch 14 | test: 15 | sampler: full_batch # it is consistent with ChIRo evaluation method on RS when dataset.single_conformer=False. 16 | model: 17 | loss_fun: l1 18 | -------------------------------------------------------------------------------- /experiments/configs/datasets/rs.yaml: -------------------------------------------------------------------------------- 1 | metric_best: accuracy 2 | metric_agg: argmax 3 | dataset: 4 | format: ChIRo 5 | name: RS 6 | task: graph 7 | task_type: classification 8 | single_conformer: True 9 | train: 10 | sampler: single_conformer_sampler 11 | val: 12 | sampler: full_batch 13 | test: 14 | sampler: full_batch # it is consistent with ChIRo evaluation method on RS when dataset.single_conformer=False. 
15 | model: 16 | loss_fun: cross_entropy 17 | 18 | -------------------------------------------------------------------------------- /experiments/configs/datasets/tox21.yaml: -------------------------------------------------------------------------------- 1 | metric_best: auc 2 | metric_agg: argmax 3 | dataset: 4 | format: ChIRo-CSV 5 | name: Tox21 6 | task: graph 7 | task_type: classification_multilabel 8 | share: 9 | dim_out: 12 # must not be 2, because torch_geometric would override that value 10 | train: 11 | sampler: full_batch 12 | val: 13 | sampler: full_batch 14 | test: 15 | sampler: full_batch 16 | model: 17 | loss_fun: cross_entropy 18 | -------------------------------------------------------------------------------- /experiments/configs/models/ChIRo/ChIRo.json: -------------------------------------------------------------------------------- 1 | { 2 | "layers_dict": 3 | { 4 | "EConv_mlp_hidden_sizes": [64], 5 | "GAT_hidden_node_sizes": [32, 32], 6 | 7 | "encoder_hidden_sizes_D": [128, 128], 8 | "encoder_hidden_sizes_phi": [128, 128], 9 | "encoder_hidden_sizes_c": [128, 128], 10 | "encoder_hidden_sizes_alpha": [128, 128], 11 | 12 | "encoder_hidden_sizes_sinusoidal_shift": [256, 256], 13 | "output_mlp_hidden_sizes": [64, 64] 14 | }, 15 | 16 | 17 | "activation_dict": 18 | { 19 | "encoder_hidden_activation_D": "torch.nn.LeakyReLU(negative_slope=0.01)", 20 | "encoder_hidden_activation_phi": "torch.nn.LeakyReLU(negative_slope=0.01)", 21 | "encoder_hidden_activation_c": "torch.nn.LeakyReLU(negative_slope=0.01)", 22 | "encoder_hidden_activation_alpha": "torch.nn.LeakyReLU(negative_slope=0.01)", 23 | "encoder_hidden_activation_sinusoidal_shift": "torch.nn.LeakyReLU(negative_slope=0.01)", 24 | 25 | "encoder_output_activation_D": "torch.nn.Identity()", 26 | "encoder_output_activation_phi": "torch.nn.Identity()", 27 | "encoder_output_activation_c": "torch.nn.Identity()", 28 | "encoder_output_activation_alpha": "torch.nn.Identity()", 29 | 
"encoder_output_activation_sinusoidal_shift": "torch.nn.Identity()", 30 | 31 | "EConv_mlp_hidden_activation": "torch.nn.LeakyReLU(negative_slope=0.01)", 32 | "EConv_mlp_output_activation": "torch.nn.Identity()", 33 | 34 | "output_mlp_hidden_activation": "torch.nn.LeakyReLU(negative_slope=0.01)", 35 | "output_mlp_output_activation": "torch.nn.Identity()" 36 | }, 37 | 38 | "F_z_list": [64, 64, 64], 39 | "GAT_N_heads": 4, 40 | "EConv_bias": true, 41 | "GAT_bias": true, 42 | "encoder_biases": true, 43 | 44 | "chiral_message_passing": true, 45 | "CMP_EConv_MLP_hidden_sizes": [32], 46 | "CMP_GAT_N_heads": 2, 47 | 48 | "encoder_reduction": "sum", 49 | 50 | "output_concatenation_mode": "both" 51 | 52 | } 53 | -------------------------------------------------------------------------------- /experiments/configs/models/ChIRo/ChIRo.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | node_encoder_in_dim: 52 3 | edge_encoder_in_dim: 14 4 | chiral_tags: False 5 | model: 6 | type: ChIRo 7 | config_path: "configs/models/ChIRo/ChIRo.json" 8 | hidden_dim: 64 9 | gnn: 10 | layers: 3 11 | dropout: 0.0 12 | -------------------------------------------------------------------------------- /experiments/configs/models/ChIRo/bace-ChIRo.yaml: -------------------------------------------------------------------------------- 1 | additional_cfg_files: [configs/models/common.yaml, configs/models/ChIRo/ChIRo.yaml, configs/datasets/bace.yaml] 2 | gnn: 3 | dropout: 0.0 4 | layers: 4 5 | model: 6 | hidden_dim: 128 7 | optim: 8 | base_lr: 0.0001 9 | -------------------------------------------------------------------------------- /experiments/configs/models/ChIRo/binding_affinity-ChIRo.yaml: -------------------------------------------------------------------------------- 1 | additional_cfg_files: [configs/models/common.yaml, configs/models/ChIRo/ChIRo.yaml, configs/datasets/binding_affinity.yaml] 2 | gnn: 3 | dropout: 0.0 4 | layers: 2 5 | model: 6 | 
hidden_dim: 256 7 | optim: 8 | base_lr: 0.001 9 | -------------------------------------------------------------------------------- /experiments/configs/models/ChIRo/binding_rank-ChIRo.yaml: -------------------------------------------------------------------------------- 1 | additional_cfg_files: [configs/models/common.yaml, configs/models/ChIRo/ChIRo.yaml, configs/datasets/binding_rank.yaml] 2 | gnn: 3 | dropout: 0.0 4 | layers: 3 5 | model: 6 | hidden_dim: 256 7 | optim: 8 | base_lr: 0.0001 9 | -------------------------------------------------------------------------------- /experiments/configs/models/ChIRo/rs-ChIRo.yaml: -------------------------------------------------------------------------------- 1 | additional_cfg_files: [configs/models/common.yaml, configs/models/ChIRo/ChIRo.yaml, configs/datasets/rs.yaml] 2 | gnn: 3 | dropout: 0.0 4 | layers: 2 5 | model: 6 | hidden_dim: 64 7 | optim: 8 | base_lr: 0.001 9 | -------------------------------------------------------------------------------- /experiments/configs/models/ChIRo/tox21-ChIRo.yaml: -------------------------------------------------------------------------------- 1 | additional_cfg_files: [configs/models/common.yaml, configs/models/ChIRo/ChIRo.yaml, configs/datasets/tox21.yaml] 2 | gnn: 3 | dropout: 0.2 4 | layers: 4 5 | model: 6 | hidden_dim: 128 7 | optim: 8 | base_lr: 0.001 9 | -------------------------------------------------------------------------------- /experiments/configs/models/ChiENN/ChiENN.yaml: -------------------------------------------------------------------------------- 1 | # parameters mainly taken from zinc-GPS+RWSE.yaml 2 | dataset: 3 | pre_transform_name: edge_graph 4 | chiral_tags: False 5 | transductive: False 6 | node_encoder: True 7 | node_encoder_name: LinearNode+RWSE 8 | node_encoder_in_dim: 118 9 | node_encoder_bn: False 10 | edge_encoder: False 11 | posenc_RWSE: 12 | enable: True 13 | kernel: 14 | times_func: range(1,21) 15 | model: Linear 16 | dim_pe: 28 17 | 
raw_norm_type: BatchNorm 18 | train: 19 | mode: custom 20 | batch_size: 32 21 | eval_period: 1 22 | ckpt_period: 100 23 | model: 24 | type: GPSModel 25 | edge_decoding: dot 26 | graph_pooling: add 27 | hidden_dim: 64 28 | chienn: 29 | message: 30 | k_neighbors_embeddings_names: ['linear', 'linear', 'linear'] 31 | final_embedding_name: 'ELU+linear' 32 | aggregate: 33 | self_embedding_name: 'linear' 34 | parallel_embedding_name: 'linear' 35 | aggregation: 'sum' 36 | post_aggregation_embedding_name: 'ELU' 37 | gt: 38 | layer_type: ChiENN+None # CustomGatedGCN+Performer 39 | layers: 5 40 | n_heads: 4 41 | dropout: 0.0 42 | attn_dropout: 0.5 43 | layer_norm: False 44 | batch_norm: True 45 | gnn: 46 | head: san_graph 47 | layers_pre_mp: 0 48 | layers_post_mp: 3 # Not used when `gnn.head: san_graph` 49 | batchnorm: True 50 | act: relu 51 | dropout: 0.0 52 | agg: mean 53 | normalize_adj: False 54 | 55 | -------------------------------------------------------------------------------- /experiments/configs/models/ChiENN/bace-ChiENN.yaml: -------------------------------------------------------------------------------- 1 | additional_cfg_files: [configs/models/common.yaml, configs/models/ChiENN/ChiENN.yaml, configs/datasets/bace.yaml] 2 | gt: 3 | dropout: 0.0 4 | layers: 10 5 | model: 6 | hidden_dim: 256 7 | optim: 8 | base_lr: 1.0e-05 9 | -------------------------------------------------------------------------------- /experiments/configs/models/ChiENN/binding_affinity-ChiENN.yaml: -------------------------------------------------------------------------------- 1 | additional_cfg_files: [configs/models/common.yaml, configs/models/ChiENN/ChiENN.yaml, configs/datasets/binding_affinity.yaml] 2 | gt: 3 | dropout: 0.0 4 | layers: 10 5 | model: 6 | hidden_dim: 128 7 | optim: 8 | base_lr: 0.001 9 | -------------------------------------------------------------------------------- /experiments/configs/models/ChiENN/binding_rank-ChiENN.yaml: 
-------------------------------------------------------------------------------- 1 | additional_cfg_files: [configs/models/common.yaml, configs/models/ChiENN/ChiENN.yaml, configs/datasets/binding_rank.yaml] 2 | gt: 3 | dropout: 0.0 4 | layers: 10 5 | model: 6 | hidden_dim: 256 7 | optim: 8 | base_lr: 0.0001 9 | -------------------------------------------------------------------------------- /experiments/configs/models/ChiENN/rs-ChiENN.yaml: -------------------------------------------------------------------------------- 1 | additional_cfg_files: [configs/models/common.yaml, configs/models/ChiENN/ChiENN.yaml, configs/datasets/rs.yaml] 2 | gt: 3 | dropout: 0.0 4 | layers: 3 5 | model: 6 | hidden_dim: 256 7 | optim: 8 | base_lr: 0.001 9 | -------------------------------------------------------------------------------- /experiments/configs/models/ChiENN/tox21-ChiENN.yaml: -------------------------------------------------------------------------------- 1 | additional_cfg_files: [configs/models/common.yaml, configs/models/ChiENN/ChiENN.yaml, configs/datasets/tox21.yaml] 2 | gt: 3 | dropout: 0.0 4 | layers: 3 5 | model: 6 | hidden_dim: 128 7 | optim: 8 | base_lr: 0.001 9 | -------------------------------------------------------------------------------- /experiments/configs/models/DMPNN/DMPNN.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | node_encoder_in_dim: 52 3 | edge_encoder_in_dim: 14 4 | chiral_tags: False 5 | train: 6 | mode: custom 7 | batch_size: 32 8 | eval_period: 1 9 | ckpt_period: 100 10 | model: 11 | type: DMPNN 12 | graph_pooling: sum 13 | hidden_dim: 300 14 | gnn: 15 | layers: 3 16 | dropout: 0.0 17 | tetra: 18 | use: False 19 | -------------------------------------------------------------------------------- /experiments/configs/models/DMPNN/bace-DMPNN+with_tags.yaml: -------------------------------------------------------------------------------- 1 | additional_cfg_files: [configs/models/common.yaml, 
configs/models/DMPNN/DMPNN.yaml, configs/datasets/bace.yaml] 2 | dataset: 3 | chiral_tags: True 4 | gnn: 5 | dropout: 0.2 6 | layers: 6 7 | model: 8 | hidden_dim: 900 9 | optim: 10 | base_lr: 0.0001 11 | -------------------------------------------------------------------------------- /experiments/configs/models/DMPNN/bace-DMPNN.yaml: -------------------------------------------------------------------------------- 1 | additional_cfg_files: [configs/models/common.yaml, configs/models/DMPNN/DMPNN.yaml, configs/datasets/bace.yaml] 2 | gnn: 3 | dropout: 0.2 4 | layers: 6 5 | model: 6 | hidden_dim: 600 7 | optim: 8 | base_lr: 0.0001 9 | -------------------------------------------------------------------------------- /experiments/configs/models/DMPNN/binding_affinity-DMPNN+with_tags.yaml: -------------------------------------------------------------------------------- 1 | additional_cfg_files: [configs/models/common.yaml, configs/models/DMPNN/DMPNN.yaml, configs/datasets/binding_affinity.yaml] 2 | dataset: 3 | chiral_tags: True 4 | gnn: 5 | dropout: 0.2 6 | layers: 6 7 | model: 8 | hidden_dim: 600 9 | optim: 10 | base_lr: 0.0001 11 | -------------------------------------------------------------------------------- /experiments/configs/models/DMPNN/binding_affinity-DMPNN.yaml: -------------------------------------------------------------------------------- 1 | additional_cfg_files: [configs/models/common.yaml, configs/models/DMPNN/DMPNN.yaml, configs/datasets/binding_affinity.yaml] 2 | gnn: 3 | dropout: 0.2 4 | layers: 6 5 | model: 6 | hidden_dim: 600 7 | optim: 8 | base_lr: 0.0001 9 | -------------------------------------------------------------------------------- /experiments/configs/models/DMPNN/binding_rank-DMPNN+with_tags.yaml: -------------------------------------------------------------------------------- 1 | additional_cfg_files: [configs/models/common.yaml, configs/models/DMPNN/DMPNN.yaml, configs/datasets/binding_rank.yaml] 2 | dataset: 3 | chiral_tags: True 4 | 
gnn: 5 | dropout: 0.2 6 | layers: 6 7 | model: 8 | hidden_dim: 900 9 | optim: 10 | base_lr: 0.0001 11 | -------------------------------------------------------------------------------- /experiments/configs/models/DMPNN/tox21-DMPNN+with_tags.yaml: -------------------------------------------------------------------------------- 1 | additional_cfg_files: [configs/models/common.yaml, configs/models/DMPNN/DMPNN.yaml, configs/datasets/tox21.yaml] 2 | dataset: 3 | chiral_tags: True 4 | gnn: 5 | dropout: 0.2 6 | layers: 6 7 | model: 8 | hidden_dim: 600 9 | optim: 10 | base_lr: 0.0001 11 | -------------------------------------------------------------------------------- /experiments/configs/models/DMPNN/tox21-DMPNN.yaml: -------------------------------------------------------------------------------- 1 | additional_cfg_files: [configs/models/common.yaml, configs/models/DMPNN/DMPNN.yaml, configs/datasets/tox21.yaml] 2 | gnn: 3 | dropout: 0.2 4 | layers: 6 5 | model: 6 | hidden_dim: 600 7 | optim: 8 | base_lr: 0.0001 9 | -------------------------------------------------------------------------------- /experiments/configs/models/GPS+ChiENN/GPS+ChiENN.yaml: -------------------------------------------------------------------------------- 1 | # parameters mainly taken from zinc-GPS+RWSE.yaml 2 | dataset: 3 | pre_transform_name: edge_graph 4 | chiral_tags: False 5 | transductive: False 6 | node_encoder: True 7 | node_encoder_name: LinearNode+RWSE 8 | node_encoder_in_dim: 118 9 | node_encoder_bn: False 10 | edge_encoder: True 11 | edge_encoder_name: LinearEdge 12 | edge_encoder_in_dim: 80 13 | edge_encoder_bn: False 14 | posenc_RWSE: 15 | enable: True 16 | kernel: 17 | times_func: range(1,21) 18 | model: Linear 19 | dim_pe: 28 20 | raw_norm_type: BatchNorm 21 | train: 22 | mode: custom 23 | batch_size: 32 24 | eval_period: 1 25 | ckpt_period: 100 26 | model: 27 | type: GPSModel 28 | edge_decoding: dot 29 | graph_pooling: add 30 | hidden_dim: 64 31 | add_chienn_layer: True 32 | 
chienn: 33 | message: 34 | k_neighbors_embeddings_names: ['linear', 'linear', 'linear'] 35 | final_embedding_name: 'ELU+linear' 36 | aggregate: 37 | self_embedding_name: 'linear' 38 | parallel_embedding_name: 'linear' 39 | aggregation: 'sum' 40 | post_aggregation_embedding_name: 'ELU' 41 | gt: 42 | layer_type: GINE+Transformer # CustomGatedGCN+Performer 43 | layers: 5 44 | n_heads: 4 45 | dropout: 0.0 46 | attn_dropout: 0.5 47 | layer_norm: False 48 | batch_norm: True 49 | gnn: 50 | head: san_graph 51 | layers_pre_mp: 0 52 | layers_post_mp: 3 # Not used when `gnn.head: san_graph` 53 | batchnorm: True 54 | act: relu 55 | dropout: 0.0 56 | agg: mean 57 | normalize_adj: False 58 | -------------------------------------------------------------------------------- /experiments/configs/models/GPS+ChiENN/bace-GPS+ChiENN.yaml: -------------------------------------------------------------------------------- 1 | additional_cfg_files: [configs/models/common.yaml, configs/models/GPS+ChiENN/GPS+ChiENN.yaml, configs/datasets/bace.yaml] 2 | gt: 3 | dropout: 0.2 4 | layers: 10 5 | model: 6 | hidden_dim: 256 7 | optim: 8 | base_lr: 0.0001 9 | -------------------------------------------------------------------------------- /experiments/configs/models/GPS+ChiENN/binding_affinity-GPS+ChiENN.yaml: -------------------------------------------------------------------------------- 1 | additional_cfg_files: [configs/models/common.yaml, configs/models/GPS+ChiENN/GPS+ChiENN.yaml, configs/datasets/binding_affinity.yaml] 2 | gt: 3 | dropout: 0.0 4 | layers: 6 5 | model: 6 | hidden_dim: 256 7 | optim: 8 | base_lr: 0.0001 9 | -------------------------------------------------------------------------------- /experiments/configs/models/GPS+ChiENN/binding_rank-GPS+ChiENN.yaml: -------------------------------------------------------------------------------- 1 | additional_cfg_files: [configs/models/common.yaml, configs/models/GPS+ChiENN/GPS+ChiENN.yaml, configs/datasets/binding_rank.yaml] 2 | gt: 3 | 
dropout: 0.0 4 | layers: 10 5 | model: 6 | hidden_dim: 256 7 | optim: 8 | base_lr: 0.0001 9 | -------------------------------------------------------------------------------- /experiments/configs/models/GPS+ChiENN/rs-GPS+ChiENN.yaml: -------------------------------------------------------------------------------- 1 | additional_cfg_files: [configs/models/common.yaml, configs/models/GPS+ChiENN/GPS+ChiENN.yaml, configs/datasets/rs.yaml] 2 | gt: 3 | dropout: 0.0 4 | layers: 3 5 | model: 6 | hidden_dim: 256 7 | optim: 8 | base_lr: 0.0001 9 | -------------------------------------------------------------------------------- /experiments/configs/models/GPS+ChiENN/tox21-GPS+ChiENN.yaml: -------------------------------------------------------------------------------- 1 | additional_cfg_files: [configs/models/common.yaml, configs/models/GPS+ChiENN/GPS+ChiENN.yaml, configs/datasets/tox21.yaml] 2 | gt: 3 | dropout: 0.2 4 | layers: 10 5 | model: 6 | hidden_dim: 64 7 | optim: 8 | base_lr: 0.0001 9 | -------------------------------------------------------------------------------- /experiments/configs/models/GPS/GPS.yaml: -------------------------------------------------------------------------------- 1 | # parameters mainly taken from zinc-GPS+RWSE.yaml 2 | dataset: 3 | chiral_tags: False 4 | transductive: False 5 | node_encoder: True 6 | node_encoder_name: LinearNode+RWSE 7 | node_encoder_in_dim: 52 8 | node_encoder_bn: False 9 | edge_encoder: True 10 | edge_encoder_name: LinearEdge 11 | edge_encoder_in_dim: 14 12 | edge_encoder_bn: False 13 | posenc_RWSE: 14 | enable: True 15 | kernel: 16 | times_func: range(1,21) 17 | model: Linear 18 | dim_pe: 28 19 | raw_norm_type: BatchNorm 20 | train: 21 | mode: custom 22 | batch_size: 32 23 | eval_period: 1 24 | ckpt_period: 100 25 | model: 26 | type: GPSModel 27 | edge_decoding: dot 28 | graph_pooling: add 29 | hidden_dim: 64 30 | gt: 31 | layer_type: GINE+Transformer # CustomGatedGCN+Performer 32 | layers: 10 33 | n_heads: 4 34 | 
dropout: 0.0 35 | attn_dropout: 0.5 36 | layer_norm: False 37 | batch_norm: True 38 | gnn: 39 | head: san_graph 40 | layers_pre_mp: 0 41 | layers_post_mp: 3 # Not used when `gnn.head: san_graph` 42 | batchnorm: True 43 | act: relu 44 | dropout: 0.0 45 | agg: mean 46 | normalize_adj: False 47 | -------------------------------------------------------------------------------- /experiments/configs/models/GPS/bace-GPS+with_tags.yaml: -------------------------------------------------------------------------------- 1 | additional_cfg_files: [configs/models/common.yaml, configs/models/GPS/GPS.yaml, configs/datasets/bace.yaml] 2 | dataset: 3 | chiral_tags: True 4 | gt: 5 | dropout: 0.2 6 | layers: 10 7 | model: 8 | hidden_dim: 256 9 | optim: 10 | base_lr: 0.0001 11 | 12 | -------------------------------------------------------------------------------- /experiments/configs/models/GPS/bace-GPS.yaml: -------------------------------------------------------------------------------- 1 | additional_cfg_files: [configs/models/common.yaml, configs/models/GPS/GPS.yaml, configs/datasets/bace.yaml] 2 | gt: 3 | dropout: 0.2 4 | layers: 10 5 | model: 6 | hidden_dim: 256 7 | optim: 8 | base_lr: 0.0001 9 | 10 | -------------------------------------------------------------------------------- /experiments/configs/models/GPS/binding_affinity-GPS+with_tags.yaml: -------------------------------------------------------------------------------- 1 | additional_cfg_files: [configs/models/common.yaml, configs/models/GPS/GPS.yaml, configs/datasets/binding_affinity.yaml] 2 | dataset: 3 | chiral_tags: True 4 | gt: 5 | dropout: 0.0 6 | layers: 10 7 | model: 8 | hidden_dim: 256 9 | optim: 10 | base_lr: 0.0001 11 | -------------------------------------------------------------------------------- /experiments/configs/models/GPS/binding_affinity-GPS.yaml: -------------------------------------------------------------------------------- 1 | additional_cfg_files: [configs/models/common.yaml, 
configs/models/GPS/GPS.yaml, configs/datasets/binding_affinity.yaml] 2 | gt: 3 | dropout: 0.0 4 | layers: 10 5 | model: 6 | hidden_dim: 256 7 | optim: 8 | base_lr: 0.0001 9 | 10 | 11 | -------------------------------------------------------------------------------- /experiments/configs/models/GPS/binding_rank-GPS+with_tags.yaml: -------------------------------------------------------------------------------- 1 | additional_cfg_files: [configs/models/common.yaml, configs/models/GPS/GPS.yaml, configs/datasets/binding_rank.yaml] 2 | dataset: 3 | chiral_tags: True 4 | gt: 5 | dropout: 0.0 6 | layers: 10 7 | model: 8 | hidden_dim: 256 9 | optim: 10 | base_lr: 0.0001 11 | -------------------------------------------------------------------------------- /experiments/configs/models/GPS/tox21-GPS+with_tags.yaml: -------------------------------------------------------------------------------- 1 | additional_cfg_files: [configs/models/common.yaml, configs/models/GPS/GPS.yaml, configs/datasets/tox21.yaml] 2 | dataset: 3 | chiral_tags: True 4 | gt: 5 | dropout: 0.0 6 | layers: 6 7 | model: 8 | hidden_dim: 128 9 | optim: 10 | base_lr: 0.0001 11 | -------------------------------------------------------------------------------- /experiments/configs/models/GPS/tox21-GPS.yaml: -------------------------------------------------------------------------------- 1 | additional_cfg_files: [configs/models/common.yaml, configs/models/GPS/GPS.yaml, configs/datasets/tox21.yaml] 2 | gt: 3 | dropout: 0.2 4 | layers: 3 5 | model: 6 | hidden_dim: 128 7 | optim: 8 | base_lr: 0.0001 9 | 10 | -------------------------------------------------------------------------------- /experiments/configs/models/SAN+ChiENN/SAN+ChiENN.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | pre_transform_name: edge_graph 3 | chiral_tags: False 4 | node_encoder: True 5 | node_encoder_name: LinearNode+LapPE 6 | node_encoder_in_dim: 118 7 | node_encoder_bn: False 8 | 
edge_encoder: True 9 | edge_encoder_name: LinearEdge 10 | edge_encoder_in_dim: 80 11 | edge_encoder_bn: False 12 | posenc_LapPE: 13 | enable: True 14 | eigen: 15 | laplacian_norm: sym 16 | eigvec_norm: L2 17 | max_freqs: 10 18 | model: Transformer # DeepSet 19 | dim_pe: 8 20 | layers: 2 21 | n_heads: 4 # Only used when `posenc.model: Transformer` 22 | raw_norm_type: none 23 | train: 24 | mode: custom 25 | batch_size: 32 26 | eval_period: 1 27 | ckpt_period: 100 28 | model: 29 | type: SANTransformer 30 | edge_decoding: dot 31 | graph_pooling: add 32 | hidden_dim: 64 33 | add_chienn_layer: True 34 | chienn: 35 | message: 36 | k_neighbors_embeddings_names: ['linear', 'linear', 'linear'] 37 | final_embedding_name: 'ELU+linear' 38 | aggregate: 39 | self_embedding_name: 'linear' 40 | parallel_embedding_name: 'linear' 41 | aggregation: 'sum' 42 | post_aggregation_embedding_name: 'ELU' 43 | gt: 44 | layers: 10 45 | n_heads: 8 46 | full_graph: True 47 | gamma: 1e-5 48 | dropout: 0.0 49 | layer_norm: False 50 | batch_norm: True 51 | residual: True 52 | gnn: 53 | head: san_graph 54 | layers_pre_mp: 0 55 | layers_post_mp: 3 # Not used when `gnn.head: san_graph` 56 | batchnorm: True 57 | act: relu 58 | dropout: 0.0 59 | agg: mean 60 | normalize_adj: False 61 | -------------------------------------------------------------------------------- /experiments/configs/models/SAN+ChiENN/bace-SAN+ChiENN.yaml: -------------------------------------------------------------------------------- 1 | additional_cfg_files: [configs/models/common.yaml, configs/models/SAN+ChiENN/SAN+ChiENN.yaml, configs/datasets/bace.yaml] 2 | train: 3 | mode: custom 4 | batch_size: 16 5 | gt: 6 | dropout: 0.5 7 | layers: 10 8 | model: 9 | hidden_dim: 256 10 | optim: 11 | base_lr: 0.0001 12 | -------------------------------------------------------------------------------- /experiments/configs/models/SAN+ChiENN/binding_affinity-SAN+ChiENN.yaml: 
-------------------------------------------------------------------------------- 1 | additional_cfg_files: [configs/models/common.yaml, configs/models/SAN+ChiENN/SAN+ChiENN.yaml, configs/datasets/binding_affinity.yaml] 2 | gt: 3 | dropout: 0.2 4 | layers: 6 5 | model: 6 | hidden_dim: 128 7 | optim: 8 | base_lr: 0.0001 9 | -------------------------------------------------------------------------------- /experiments/configs/models/SAN+ChiENN/binding_rank-SAN+ChiENN.yaml: -------------------------------------------------------------------------------- 1 | additional_cfg_files: [configs/models/common.yaml, configs/models/SAN+ChiENN/SAN+ChiENN.yaml, configs/datasets/binding_rank.yaml] 2 | gt: 3 | dropout: 0.0 4 | layers: 6 5 | model: 6 | hidden_dim: 256 7 | optim: 8 | base_lr: 0.0001 9 | -------------------------------------------------------------------------------- /experiments/configs/models/SAN+ChiENN/rs-SAN+ChiENN.yaml: -------------------------------------------------------------------------------- 1 | additional_cfg_files: [configs/models/common.yaml, configs/models/SAN+ChiENN/SAN+ChiENN.yaml, configs/datasets/rs.yaml] 2 | gt: 3 | dropout: 0.2 4 | layers: 6 5 | model: 6 | hidden_dim: 256 7 | optim: 8 | base_lr: 0.0001 9 | 10 | -------------------------------------------------------------------------------- /experiments/configs/models/SAN+ChiENN/tox21-SAN+ChiENN.yaml: -------------------------------------------------------------------------------- 1 | additional_cfg_files: [configs/models/common.yaml, configs/models/SAN+ChiENN/SAN+ChiENN.yaml, configs/datasets/tox21.yaml] 2 | gt: 3 | dropout: 0.2 4 | layers: 6 5 | model: 6 | hidden_dim: 128 7 | optim: 8 | base_lr: 0.0001 9 | -------------------------------------------------------------------------------- /experiments/configs/models/SAN/SAN.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | node_encoder: True 3 | node_encoder_name: LinearNode+LapPE 4 | 
node_encoder_in_dim: 52 5 | node_encoder_bn: False 6 | edge_encoder: True 7 | edge_encoder_name: LinearEdge 8 | edge_encoder_in_dim: 14 9 | edge_encoder_bn: False 10 | chiral_tags: False 11 | posenc_LapPE: 12 | enable: True 13 | eigen: 14 | laplacian_norm: sym 15 | eigvec_norm: L2 16 | max_freqs: 10 17 | model: Transformer # DeepSet 18 | dim_pe: 8 19 | layers: 2 20 | n_heads: 4 # Only used when `posenc.model: Transformer` 21 | raw_norm_type: none 22 | train: 23 | mode: custom 24 | batch_size: 32 25 | eval_period: 1 26 | ckpt_period: 100 27 | model: 28 | type: SANTransformer 29 | edge_decoding: dot 30 | graph_pooling: add 31 | hidden_dim: 64 32 | gt: 33 | layers: 10 34 | n_heads: 8 35 | full_graph: True 36 | gamma: 1e-5 37 | dropout: 0.0 38 | layer_norm: False 39 | batch_norm: True 40 | residual: True 41 | gnn: 42 | head: san_graph 43 | layers_pre_mp: 0 44 | layers_post_mp: 3 # Not used when `gnn.head: san_graph` 45 | batchnorm: True 46 | act: relu 47 | dropout: 0.0 48 | agg: mean 49 | normalize_adj: False 50 | -------------------------------------------------------------------------------- /experiments/configs/models/SAN/bace-SAN+with_tags.yaml: -------------------------------------------------------------------------------- 1 | additional_cfg_files: [configs/models/common.yaml, configs/models/SAN/SAN.yaml, configs/datasets/bace.yaml] 2 | dataset: 3 | chiral_tags: True 4 | gt: 5 | dropout: 0.0 6 | layers: 6 7 | model: 8 | hidden_dim: 256 9 | optim: 10 | base_lr: 0.0001 11 | -------------------------------------------------------------------------------- /experiments/configs/models/SAN/bace-SAN.yaml: -------------------------------------------------------------------------------- 1 | additional_cfg_files: [configs/models/common.yaml, configs/models/SAN/SAN.yaml, configs/datasets/bace.yaml] 2 | gt: 3 | dropout: 0.0 4 | layers: 6 5 | model: 6 | hidden_dim: 256 7 | optim: 8 | base_lr: 0.0001 9 | 
-------------------------------------------------------------------------------- /experiments/configs/models/SAN/binding_affinity-SAN+with_tags.yaml: -------------------------------------------------------------------------------- 1 | additional_cfg_files: [configs/models/common.yaml, configs/models/SAN/SAN.yaml, configs/datasets/binding_affinity.yaml] 2 | dataset: 3 | chiral_tags: True 4 | gt: 5 | dropout: 0.2 6 | layers: 10 7 | model: 8 | hidden_dim: 128 9 | optim: 10 | base_lr: 0.0001 11 | -------------------------------------------------------------------------------- /experiments/configs/models/SAN/binding_affinity-SAN.yaml: -------------------------------------------------------------------------------- 1 | additional_cfg_files: [configs/models/common.yaml, configs/models/SAN/SAN.yaml, configs/datasets/binding_affinity.yaml] 2 | gt: 3 | dropout: 0.2 4 | layers: 10 5 | model: 6 | hidden_dim: 256 7 | optim: 8 | base_lr: 0.0001 9 | -------------------------------------------------------------------------------- /experiments/configs/models/SAN/binding_rank-SAN+with_tags.yaml: -------------------------------------------------------------------------------- 1 | additional_cfg_files: [configs/models/common.yaml, configs/models/SAN/SAN.yaml, configs/datasets/binding_rank.yaml] 2 | dataset: 3 | chiral_tags: True 4 | gt: 5 | dropout: 0.0 6 | layers: 10 7 | model: 8 | hidden_dim: 256 9 | optim: 10 | base_lr: 0.0001 11 | -------------------------------------------------------------------------------- /experiments/configs/models/SAN/tox21-SAN+with_tags.yaml: -------------------------------------------------------------------------------- 1 | additional_cfg_files: [configs/models/common.yaml, configs/models/SAN/SAN.yaml, configs/datasets/tox21.yaml] 2 | dataset: 3 | chiral_tags: True 4 | gt: 5 | dropout: 0.2 6 | layers: 6 7 | model: 8 | hidden_dim: 256 9 | optim: 10 | base_lr: 0.0001 11 | -------------------------------------------------------------------------------- 
/experiments/configs/models/SAN/tox21-SAN.yaml: -------------------------------------------------------------------------------- 1 | additional_cfg_files: [configs/models/common.yaml, configs/models/SAN/SAN.yaml, configs/datasets/tox21.yaml] 2 | gt: 3 | dropout: 0.2 4 | layers: 10 5 | model: 6 | hidden_dim: 256 7 | optim: 8 | base_lr: 0.0001 9 | -------------------------------------------------------------------------------- /experiments/configs/models/Tetra_DMPNN/Tetra_DMPNN.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | pre_transform_name: add_parity_atoms 3 | node_encoder_in_dim: 52 4 | edge_encoder_in_dim: 14 5 | chiral_tags: False 6 | train: 7 | mode: custom 8 | batch_size: 32 9 | eval_period: 1 10 | ckpt_period: 100 11 | model: 12 | type: DMPNN 13 | graph_pooling: sum 14 | hidden_dim: 300 15 | gnn: 16 | layers: 3 17 | dropout: 0.0 18 | tetra: 19 | use: True 20 | message: tetra_permute 21 | -------------------------------------------------------------------------------- /experiments/configs/models/Tetra_DMPNN/bace-Tetra_DMPNN.yaml: -------------------------------------------------------------------------------- 1 | additional_cfg_files: [configs/models/common.yaml, configs/models/Tetra_DMPNN/Tetra_DMPNN.yaml, configs/datasets/bace.yaml] 2 | gnn: 3 | dropout: 0.2 4 | layers: 6 5 | model: 6 | hidden_dim: 600 7 | optim: 8 | base_lr: 0.0001 9 | -------------------------------------------------------------------------------- /experiments/configs/models/Tetra_DMPNN/binding_affinity-Tetra_DMPNN.yaml: -------------------------------------------------------------------------------- 1 | additional_cfg_files: [configs/models/common.yaml, configs/models/Tetra_DMPNN/Tetra_DMPNN.yaml, configs/datasets/binding_affinity.yaml] 2 | gnn: 3 | dropout: 0.2 4 | layers: 6 5 | model: 6 | hidden_dim: 600 7 | optim: 8 | base_lr: 0.0001 9 | -------------------------------------------------------------------------------- 
/experiments/configs/models/Tetra_DMPNN/binding_rank-Tetra_DMPNN.yaml: -------------------------------------------------------------------------------- 1 | additional_cfg_files: [configs/models/common.yaml, configs/models/Tetra_DMPNN/Tetra_DMPNN.yaml, configs/datasets/binding_rank.yaml] 2 | gnn: 3 | dropout: 0.2 4 | layers: 6 5 | model: 6 | hidden_dim: 900 7 | optim: 8 | base_lr: 0.0001 9 | -------------------------------------------------------------------------------- /experiments/configs/models/Tetra_DMPNN/rs-Tetra_DMPNN.yaml: -------------------------------------------------------------------------------- 1 | additional_cfg_files: [configs/models/common.yaml, configs/models/Tetra_DMPNN/Tetra_DMPNN.yaml, configs/datasets/rs.yaml] 2 | gnn: 3 | dropout: 0.2 4 | layers: 2 5 | model: 6 | hidden_dim: 300 7 | optim: 8 | base_lr: 0.0001 -------------------------------------------------------------------------------- /experiments/configs/models/Tetra_DMPNN/tox21-Tetra_DMPNN.yaml: -------------------------------------------------------------------------------- 1 | additional_cfg_files: [configs/models/common.yaml, configs/models/Tetra_DMPNN/Tetra_DMPNN.yaml, configs/datasets/tox21.yaml] 2 | gnn: 3 | dropout: 0.2 4 | layers: 10 5 | model: 6 | hidden_dim: 128 7 | optim: 8 | base_lr: 0.0001 9 | -------------------------------------------------------------------------------- /experiments/configs/models/common.yaml: -------------------------------------------------------------------------------- 1 | out_dir: results 2 | wandb: 3 | use: True 4 | project: experiments 5 | entity: chienn 6 | optim: 7 | clip_grad_norm: True 8 | optimizer: adamW 9 | weight_decay: 1e-5 10 | base_lr: 1e-4 11 | max_epoch: 100 12 | scheduler: cosine_with_warmup 13 | num_warmup_epochs: 10 -------------------------------------------------------------------------------- /experiments/create_dataset.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 |
sys.path.append('submodules/ChIRo') 4 | sys.path.append('submodules/GeometricTransformerMolecule') 5 | from graphgps.dataset.utils import create_custom_loader 6 | from main import custom_load_cfg 7 | 8 | import graphgps # noqa, register custom modules 9 | 10 | from torch_geometric.graphgym.cmd_args import parse_args 11 | from torch_geometric.graphgym.config import (cfg, set_cfg) 12 | from torch_geometric import seed_everything 13 | 14 | if __name__ == '__main__': 15 | # Load cmd line args 16 | args = parse_args() 17 | cfg_file = args.cfg_file 18 | opts = args.opts 19 | 20 | # Load config 21 | set_cfg(cfg) 22 | custom_load_cfg(cfg=cfg, cfg_file=cfg_file, opts=opts) 23 | seed_everything(cfg.seed) 24 | create_custom_loader() 25 | -------------------------------------------------------------------------------- /experiments/graphgps/__init__.py: -------------------------------------------------------------------------------- 1 | from .act import * # noqa 2 | from .config import * # noqa 3 | from .encoder import * # noqa 4 | from .head import * # noqa 5 | from .layer import * # noqa 6 | from .loader import * # noqa 7 | from .loss import * # noqa 8 | from .network import * # noqa 9 | from .optimizer import * # noqa 10 | from .pooling import * # noqa 11 | from .stage import * # noqa 12 | from .train import * # noqa 13 | from .transform import * # noqa -------------------------------------------------------------------------------- /experiments/graphgps/act/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, basename, isfile, join 2 | import glob 3 | 4 | modules = glob.glob(join(dirname(__file__), "*.py")) 5 | __all__ = [ 6 | basename(f)[:-3] for f in modules 7 | if isfile(f) and not f.endswith('__init__.py') 8 | ] 9 | -------------------------------------------------------------------------------- /experiments/graphgps/act/example.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from torch_geometric.graphgym.config import cfg 5 | from torch_geometric.graphgym.register import register_act 6 | 7 | 8 | class SWISH(nn.Module): 9 | def __init__(self, inplace=False): 10 | super().__init__() 11 | self.inplace = inplace 12 | 13 | def forward(self, x): 14 | if self.inplace: 15 | x.mul_(torch.sigmoid(x)) 16 | return x 17 | else: 18 | return x * torch.sigmoid(x) 19 | 20 | 21 | register_act('swish', SWISH(inplace=cfg.mem.inplace)) 22 | register_act('lrelu_03', nn.LeakyReLU(0.3, inplace=cfg.mem.inplace)) 23 | -------------------------------------------------------------------------------- /experiments/graphgps/config/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, basename, isfile, join 2 | import glob 3 | 4 | modules = glob.glob(join(dirname(__file__), "*.py")) 5 | __all__ = [ 6 | basename(f)[:-3] for f in modules 7 | if isfile(f) and not f.endswith('__init__.py') 8 | ] 9 | -------------------------------------------------------------------------------- /experiments/graphgps/config/chienn_config.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.graphgym.register import register_config 2 | from yacs.config import CfgNode as CN 3 | 4 | 5 | @register_config('chienn_cfg') 6 | def chienn_cfg(cfg): 7 | """ 8 | Config option for ChiENN model. 
9 | """ 10 | cfg.chienn = CN() 11 | 12 | # Parameters for message module of ChiENN 13 | cfg.chienn.message = CN() 14 | cfg.chienn.message.k_neighbors_embeddings_names = ['linear', 'linear', 'linear'] 15 | cfg.chienn.message.final_embedding_name = 'ELU+linear' 16 | 17 | # Parameters for aggregation module of ChiENN 18 | cfg.chienn.aggregate = CN() 19 | cfg.chienn.aggregate.self_embedding_name = 'linear' 20 | cfg.chienn.aggregate.parallel_embedding_name = 'none' 21 | cfg.chienn.aggregate.aggregation = 'sum' 22 | cfg.chienn.aggregate.post_aggregation_embedding_name = 'ELU' 23 | -------------------------------------------------------------------------------- /experiments/graphgps/config/custom_gnn_config.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.graphgym.register import register_config 2 | from yacs.config import CfgNode as CN 3 | 4 | @register_config('custom_gnn') 5 | def custom_gnn_cfg(cfg): 6 | """Extending config group of GraphGym's built-in GNN for purposes of our 7 | CustomGNN network model. 8 | """ 9 | 10 | # Use residual connections between the GNN layers. 11 | cfg.gnn.residual = False 12 | 13 | # Used in DMPNN class. 14 | cfg.gnn.tetra = CN() 15 | cfg.gnn.tetra.use = False 16 | cfg.gnn.tetra.message = "tetra_permute" 17 | cfg.gnn.layers = 5 -------------------------------------------------------------------------------- /experiments/graphgps/config/custom_model_config.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.graphgym.register import register_config 2 | from yacs.config import CfgNode as CN 3 | 4 | @register_config('custom_model_cfg') 5 | def dataset_cfg(cfg): 6 | """Model-specific config options. 7 | """ 8 | 9 | # Path to config used in ChIRo model. 
10 | cfg.model.config_path = "" 11 | 12 | cfg.model.hidden_dim = 10 13 | 14 | # Method for selecting coordinates in the AtomDistance class 15 | cfg.model.coords_selection = "start" 16 | 17 | cfg.model.add_chienn_layer = False 18 | -------------------------------------------------------------------------------- /experiments/graphgps/config/dataset_config.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.graphgym.register import register_config 2 | 3 | 4 | @register_config("dataset_cfg") 5 | def dataset_cfg(cfg): 6 | """Dataset-specific config options.""" 7 | 8 | # The number of node types to expect in TypeDictNodeEncoder. 9 | cfg.dataset.node_encoder_num_types = 0 10 | 11 | # The number of edge types to expect in TypeDictEdgeEncoder. 12 | cfg.dataset.edge_encoder_num_types = 0 13 | 14 | # VOC/COCO Superpixels dataset version based on SLIC compactness parameter. 15 | cfg.dataset.slic_compactness = 10 16 | 17 | # Dimension of node attributes. Used in `LinearNodeEncoder`. 18 | cfg.dataset.node_encoder_in_dim = 0 19 | 20 | # Dimension of edge attributes. Used in `LinearEdgeEncoder`. 21 | cfg.dataset.edge_encoder_in_dim = 0 22 | 23 | # Used in ChIRo datasets. Should be set to True when the model is conformer-invariant, to save some RAM. 24 | cfg.dataset.single_conformer = True 25 | 26 | # Used in ChIRo datasets. 27 | cfg.dataset.single_enantiomer = False 28 | 29 | # Used in ChIRo datasets. Whether to use chiral tags. If set to False, chiral information will be masked. 30 | cfg.dataset.chiral_tags = True 31 | 32 | # Used in TDC (ChIRo) dataset. Type of task (e.g. "Tox", "ADME"). 33 | cfg.dataset.tdc_type = "" 34 | 35 | # Used in TDC (ChIRo) dataset. Name of assay (in case of tasks with multiple labels). 36 | cfg.dataset.tdc_assay_name = "" 37 | 38 | # Used in OGB (ChIRo) dataset. Name of the OGB dataset (e.g. "hiv", "pcba"). 39 | cfg.dataset.ogb_dataset_name = "" 40 | 41 | # Used for scaling regression labels.
42 | cfg.dataset.scale_label = 1.0 43 | 44 | cfg.dataset.min_number_of_chiral_centers = 0 45 | 46 | # Used in our datasets. 47 | cfg.dataset.pre_transform_name = "" 48 | -------------------------------------------------------------------------------- /experiments/graphgps/config/defaults_config.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.graphgym.register import register_config 2 | 3 | 4 | @register_config('overwrite_defaults') 5 | def overwrite_defaults_cfg(cfg): 6 | """Overwrite the default config values that are first set by GraphGym in 7 | torch_geometric.graphgym.config.set_cfg 8 | 9 | WARNING: At the time of writing, the order in which custom config-setting 10 | functions like this one are executed is random; see the referenced `set_cfg` 11 | Therefore never reset here config options that are custom added, only change 12 | those that exist in core GraphGym. 13 | """ 14 | 15 | # Overwrite default dataset name 16 | cfg.dataset.name = 'none' 17 | 18 | # Overwrite default rounding precision 19 | cfg.round = 5 20 | 21 | 22 | @register_config('extended_cfg') 23 | def extended_cfg(cfg): 24 | """General extended config options. 25 | """ 26 | 27 | # Additional name tag used in `run_dir` and `wandb_name` auto generation. 28 | cfg.name_tag = "" 29 | 30 | # In training, if True (and also cfg.train.enable_ckpt is True) then 31 | # always checkpoint the current best model based on validation performance, 32 | # instead, when False, follow cfg.train.eval_period checkpointing frequency. 
33 | cfg.train.ckpt_best = False 34 | 35 | # File used to load additional dataset-specific config 36 | cfg.additional_cfg_files = [] 37 | 38 | cfg.subset = None 39 | -------------------------------------------------------------------------------- /experiments/graphgps/config/example.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.graphgym.register import register_config 2 | from yacs.config import CfgNode as CN 3 | 4 | 5 | @register_config('example') 6 | def set_cfg_example(cfg): 7 | r''' 8 | This function sets the default config value for customized options 9 | :return: customized configuration use by the experiment. 10 | ''' 11 | 12 | # ----------------------------------------------------------------------- # 13 | # Customized options 14 | # ----------------------------------------------------------------------- # 15 | 16 | # example argument 17 | cfg.example_arg = 'example' 18 | 19 | # example argument group 20 | cfg.example_group = CN() 21 | 22 | # then argument can be specified within the group 23 | cfg.example_group.example_arg = 'example' 24 | -------------------------------------------------------------------------------- /experiments/graphgps/config/gt_config.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.graphgym.register import register_config 2 | from yacs.config import CfgNode as CN 3 | 4 | 5 | @register_config('cfg_gt') 6 | def set_cfg_gt(cfg): 7 | """Configuration for Graph Transformer-style models, e.g.: 8 | - Spectral Attention Network (SAN) Graph Transformer. 9 | - "vanilla" Transformer / Performer. 10 | - General Powerful Scalable (GPS) Model. 
11 | """ 12 | 13 | # Positional encodings argument group 14 | cfg.gt = CN() 15 | 16 | # Type of Graph Transformer layer to use 17 | cfg.gt.layer_type = 'SANLayer' 18 | 19 | # Number of Transformer layers in the model 20 | cfg.gt.layers = 3 21 | 22 | # Number of attention heads in the Graph Transformer 23 | cfg.gt.n_heads = 8 24 | 25 | # Size of the hidden node and edge representation 26 | cfg.gt.dim_hidden = 64 27 | 28 | # Full attention SAN transformer including all possible pairwise edges 29 | cfg.gt.full_graph = True 30 | 31 | # SAN real vs fake edge attention weighting coefficient 32 | cfg.gt.gamma = 1e-5 33 | 34 | # Histogram of in-degrees of nodes in the training set used by PNAConv. 35 | # Used when `gt.layer_type: PNAConv+...`. If empty it is precomputed during 36 | # the dataset loading process. 37 | cfg.gt.pna_degrees = [] 38 | 39 | # Dropout in feed-forward module. 40 | cfg.gt.dropout = 0.0 41 | 42 | # Dropout in self-attention. 43 | cfg.gt.attn_dropout = 0.0 44 | 45 | cfg.gt.layer_norm = False 46 | 47 | cfg.gt.batch_norm = True 48 | 49 | cfg.gt.residual = True 50 | 51 | # BigBird model/GPS-BigBird layer. 
52 | cfg.gt.bigbird = CN() 53 | 54 | cfg.gt.bigbird.attention_type = "block_sparse" 55 | 56 | cfg.gt.bigbird.chunk_size_feed_forward = 0 57 | 58 | cfg.gt.bigbird.is_decoder = False 59 | 60 | cfg.gt.bigbird.add_cross_attention = False 61 | 62 | cfg.gt.bigbird.hidden_act = "relu" 63 | 64 | cfg.gt.bigbird.max_position_embeddings = 128 65 | 66 | cfg.gt.bigbird.use_bias = False 67 | 68 | cfg.gt.bigbird.num_random_blocks = 3 69 | 70 | cfg.gt.bigbird.block_size = 3 71 | 72 | cfg.gt.bigbird.layer_norm_eps = 1e-6 73 | -------------------------------------------------------------------------------- /experiments/graphgps/config/optimizers_config.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.graphgym.register import register_config 2 | 3 | 4 | @register_config('extended_optim') 5 | def extended_optim_cfg(cfg): 6 | """Extend optimizer config group that is first set by GraphGym in 7 | torch_geometric.graphgym.config.set_cfg 8 | """ 9 | 10 | # Number of batches to accumulate gradients over before updating parameters 11 | # Requires `custom` training loop, set `train.mode: custom` 12 | cfg.optim.batch_accumulation = 1 13 | 14 | # ReduceLROnPlateau: Factor by which the learning rate will be reduced 15 | cfg.optim.reduce_factor = 0.1 16 | 17 | # ReduceLROnPlateau: #epochs without improvement after which LR gets reduced 18 | cfg.optim.schedule_patience = 10 19 | 20 | # ReduceLROnPlateau: Lower bound on the learning rate 21 | cfg.optim.min_lr = 0.0 22 | 23 | # For schedulers with warm-up phase, set the warm-up number of epochs 24 | cfg.optim.num_warmup_epochs = 50 25 | 26 | # Clip gradient norms while training 27 | cfg.optim.clip_grad_norm = False 28 | -------------------------------------------------------------------------------- /experiments/graphgps/config/posenc_config.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.graphgym.register import register_config 
2 | from yacs.config import CfgNode as CN 3 | 4 | 5 | @register_config('posenc') 6 | def set_cfg_posenc(cfg): 7 | """Extend configuration with positional encoding options. 8 | """ 9 | 10 | # Argument group for each Positional Encoding class. 11 | cfg.posenc_LapPE = CN() 12 | cfg.posenc_SignNet = CN() 13 | cfg.posenc_RWSE = CN() 14 | cfg.posenc_HKdiagSE = CN() 15 | cfg.posenc_ElstaticSE = CN() 16 | cfg.posenc_EquivStableLapPE = CN() 17 | 18 | # Common arguments to all PE types. 19 | for name in ['posenc_LapPE', 'posenc_SignNet', 20 | 'posenc_RWSE', 'posenc_HKdiagSE', 'posenc_ElstaticSE']: 21 | pecfg = getattr(cfg, name) 22 | 23 | # Use extended positional encodings 24 | pecfg.enable = False 25 | 26 | # Neural-net model type within the PE encoder: 27 | # 'DeepSet', 'Transformer', 'Linear', 'none', ... 28 | pecfg.model = 'none' 29 | 30 | # Size of Positional Encoding embedding 31 | pecfg.dim_pe = 16 32 | 33 | # Number of layers in PE encoder model 34 | pecfg.layers = 3 35 | 36 | # Number of attention heads in PE encoder when model == 'Transformer' 37 | pecfg.n_heads = 4 38 | 39 | # Number of layers to apply in LapPE encoder post its pooling stage 40 | pecfg.post_layers = 0 41 | 42 | # Choice of normalization applied to raw PE stats: 'none', 'BatchNorm' 43 | pecfg.raw_norm_type = 'none' 44 | 45 | # In addition to appending PE to the node features, pass them also as 46 | # a separate variable in the PyG graph batch object. 47 | pecfg.pass_as_var = False 48 | 49 | # Config for EquivStable LapPE 50 | cfg.posenc_EquivStableLapPE.enable = False 51 | cfg.posenc_EquivStableLapPE.raw_norm_type = 'none' 52 | 53 | # Config for Laplacian Eigen-decomposition for PEs that use it. 
54 | for name in ['posenc_LapPE', 'posenc_SignNet', 'posenc_EquivStableLapPE']: 55 | pecfg = getattr(cfg, name) 56 | pecfg.eigen = CN() 57 | 58 | # The normalization scheme for the graph Laplacian: 'none', 'sym', or 'rw' 59 | pecfg.eigen.laplacian_norm = 'sym' 60 | 61 | # The normalization scheme for the eigen vectors of the Laplacian 62 | pecfg.eigen.eigvec_norm = 'L2' 63 | 64 | # Maximum number of top smallest frequencies & eigenvectors to use 65 | pecfg.eigen.max_freqs = 10 66 | 67 | # Config for SignNet-specific options. 68 | cfg.posenc_SignNet.phi_out_dim = 4 69 | cfg.posenc_SignNet.phi_hidden_dim = 64 70 | 71 | for name in ['posenc_RWSE', 'posenc_HKdiagSE', 'posenc_ElstaticSE']: 72 | pecfg = getattr(cfg, name) 73 | 74 | # Config for Kernel-based PE specific options. 75 | pecfg.kernel = CN() 76 | 77 | # List of times to compute the heat kernel for (the time is equivalent to 78 | # the variance of the kernel) / the number of steps for random walk kernel 79 | # Can be overridden by `posenc.kernel.times_func` 80 | pecfg.kernel.times = [] 81 | 82 | # Python snippet to generate `posenc.kernel.times`, e.g. 'range(1, 17)' 83 | # If set, it will be executed via `eval()` and override posenc.kernel.times 84 | pecfg.kernel.times_func = '' 85 | 86 | # Override default, electrostatic kernel has fixed set of 10 measures. 87 | cfg.posenc_ElstaticSE.kernel.times_func = 'range(10)' 88 | -------------------------------------------------------------------------------- /experiments/graphgps/config/pretrained_config.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.graphgym.register import register_config 2 | from yacs.config import CfgNode as CN 3 | 4 | 5 | @register_config('cfg_pretrained') 6 | def set_cfg_pretrained(cfg): 7 | """Configuration options for loading a pretrained model. 
8 | """ 9 | 10 | cfg.pretrained = CN() 11 | 12 | # Directory path to a saved experiment, if set, load the model from there 13 | # and fine-tune / run inference with it on a specified dataset. 14 | cfg.pretrained.dir = "" 15 | 16 | # Discard pretrained weights of the prediction head and reinitialize. 17 | cfg.pretrained.reset_prediction_head = True 18 | 19 | # Freeze the main pretrained 'body' of the model, learning only the new head 20 | cfg.pretrained.freeze_main = False 21 | -------------------------------------------------------------------------------- /experiments/graphgps/config/split_config.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.graphgym.register import register_config 2 | 3 | 4 | @register_config('split') 5 | def set_cfg_split(cfg): 6 | """Reconfigure the default config value for dataset split options. 7 | 8 | Returns: 9 | Reconfigured split configuration use by the experiment. 10 | """ 11 | 12 | # Default to selecting the standard split that ships with the dataset 13 | cfg.dataset.split_mode = 'standard' 14 | 15 | # Choose a particular split to use if multiple splits are available 16 | cfg.dataset.split_index = 0 17 | 18 | # Dir to cache cross-validation splits 19 | cfg.dataset.split_dir = './splits' 20 | 21 | # Choose to run multiple splits in one program execution, if set, 22 | # takes the precedence over cfg.dataset.split_index for split selection 23 | cfg.run_multiple_splits = [] 24 | -------------------------------------------------------------------------------- /experiments/graphgps/config/test_config.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode 2 | from torch_geometric.graphgym.register import register_config 3 | 4 | 5 | @register_config('test_cfg') 6 | def dataset_cfg(cfg): 7 | """Test-specific config options. 
8 | """ 9 | 10 | cfg.test = CfgNode() 11 | 12 | # Sampling strategy for a test loader 13 | cfg.test.sampler = 'full_batch' 14 | -------------------------------------------------------------------------------- /experiments/graphgps/config/wandb_config.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.graphgym.register import register_config 2 | from yacs.config import CfgNode as CN 3 | 4 | 5 | @register_config('cfg_wandb') 6 | def set_cfg_wandb(cfg): 7 | """Weights & Biases tracker configuration. 8 | """ 9 | 10 | # WandB group 11 | cfg.wandb = CN() 12 | 13 | # Use wandb or not 14 | cfg.wandb.use = False 15 | 16 | # Wandb entity name, should exist beforehand 17 | cfg.wandb.entity = "gtransformers" 18 | 19 | # Wandb project name, will be created in your team if doesn't exist already 20 | cfg.wandb.project = "gtblueprint" 21 | 22 | # Optional run name 23 | cfg.wandb.base_name = "" 24 | -------------------------------------------------------------------------------- /experiments/graphgps/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gmum/ChiENN/ee3185b39e8469a8caacf3d6d45a04c4a1cfff5b/experiments/graphgps/dataset/__init__.py -------------------------------------------------------------------------------- /experiments/graphgps/dataset/chiral_dataset_base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import copy 3 | import os.path as osp 4 | 5 | import torch 6 | import torch_geometric.data 7 | from torch_geometric.data import InMemoryDataset 8 | 9 | from graphgps.dataset.utils import PRE_TRANSFORM_MAPPING, CHIRAL_MASKING_MAPPING 10 | 11 | 12 | class ChiralDatasetBase(InMemoryDataset, abc.ABC): 13 | r""" 14 | Dataset base class that 15 | """ 16 | 17 | def __init__( 18 | self, 19 | root, 20 | mask_chiral_tags, 21 | split="train", 22 | pre_transform_name=None, 23 | 
max_number_of_atoms=100, 24 | ): 25 | assert split in ["train", "val", "test"] 26 | self.mask_chiral_tags = mask_chiral_tags 27 | self.pre_transform_name = ( 28 | pre_transform_name if pre_transform_name else "default" 29 | ) 30 | pre_transform = PRE_TRANSFORM_MAPPING.get(self.pre_transform_name) 31 | self.mask_chiral_fn = CHIRAL_MASKING_MAPPING.get(self.pre_transform_name) 32 | self.max_number_of_atoms = max_number_of_atoms 33 | super().__init__( 34 | root, transform=None, pre_transform=pre_transform, pre_filter=None 35 | ) 36 | self.data, self.slices = torch.load(osp.join(self.processed_dir, f"{split}.pt")) 37 | 38 | def __getitem__(self, idx: int): 39 | """ 40 | Standard getitem with chiral tags masking. 41 | """ 42 | 43 | # for some reason it returns something different than data.Data at the very beginning of the training: 44 | data = super().__getitem__(idx) 45 | if isinstance(data, torch_geometric.data.Data): 46 | if self.mask_chiral_tags: 47 | data = self.mask_chiral_fn(data) 48 | return data 49 | -------------------------------------------------------------------------------- /experiments/graphgps/dataset/collate.py: -------------------------------------------------------------------------------- 1 | from typing import List, Any 2 | 3 | from torch_geometric.data.data import BaseData 4 | from torch_geometric.loader.dataloader import Collater 5 | 6 | from chienn.data import collate_with_circle_index 7 | 8 | 9 | class CustomCollater: 10 | def __init__(self, follow_batch=None, exclude_keys=None, n_neighbors_in_circle=None): 11 | self.collator = Collater(follow_batch, exclude_keys) 12 | self.follow_batch = follow_batch 13 | exclude_keys = exclude_keys if exclude_keys else [] 14 | self.exclude_keys = exclude_keys + ['circle_index'] 15 | self.n_neighbors_in_circle = n_neighbors_in_circle 16 | 17 | def __call__(self, batch: List[Any]): 18 | elem = batch[0] 19 | if isinstance(elem, BaseData) and hasattr(elem, 'circle_index'): 20 | return 
collate_with_circle_index(batch, self.n_neighbors_in_circle) 21 | else: 22 | return self.collator(batch) 23 | -------------------------------------------------------------------------------- /experiments/graphgps/dataset/dataloader.py: -------------------------------------------------------------------------------- 1 | from typing import Union, List, Optional 2 | 3 | import torch 4 | from torch_geometric.data import Dataset 5 | from torch_geometric.data.data import BaseData 6 | from torch_geometric.loader.dataloader import Collater 7 | 8 | from graphgps.dataset.collate import CustomCollater 9 | 10 | 11 | class CustomDataLoader(torch.utils.data.DataLoader): 12 | r"""A data loader which merges data objects from a 13 | :class:`torch_geometric.data.Dataset` to a mini-batch. 14 | Data objects can be either of type :class:`~torch_geometric.data.Data` or 15 | :class:`~torch_geometric.data.HeteroData`. 16 | 17 | Args: 18 | dataset (Dataset): The dataset from which to load the data. 19 | batch_size (int, optional): How many samples per batch to load. 20 | (default: :obj:`1`) 21 | shuffle (bool, optional): If set to :obj:`True`, the data will be 22 | reshuffled at every epoch. (default: :obj:`False`) 23 | follow_batch (List[str], optional): Creates assignment batch 24 | vectors for each key in the list. (default: :obj:`None`) 25 | exclude_keys (List[str], optional): Will exclude each key in the 26 | list. (default: :obj:`None`) 27 | **kwargs (optional): Additional arguments of 28 | :class:`torch.utils.data.DataLoader`. 
29 | """ 30 | def __init__( 31 | self, 32 | dataset: Union[Dataset, List[BaseData]], 33 | batch_size: int = 1, 34 | shuffle: bool = False, 35 | follow_batch: Optional[List[str]] = None, 36 | exclude_keys: Optional[List[str]] = None, 37 | n_neighbors_in_circle: Optional[int] = None, 38 | **kwargs, 39 | ): 40 | 41 | if 'collate_fn' in kwargs: 42 | del kwargs['collate_fn'] 43 | 44 | # Save for PyTorch Lightning: 45 | self.follow_batch = follow_batch 46 | self.exclude_keys = exclude_keys 47 | 48 | super().__init__( 49 | dataset, 50 | batch_size, 51 | shuffle, 52 | collate_fn=CustomCollater(follow_batch, exclude_keys, n_neighbors_in_circle), 53 | **kwargs, 54 | ) -------------------------------------------------------------------------------- /experiments/graphgps/encoder/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, basename, isfile, join 2 | import glob 3 | 4 | modules = glob.glob(join(dirname(__file__), "*.py")) 5 | __all__ = [ 6 | basename(f)[:-3] for f in modules 7 | if isfile(f) and not f.endswith('__init__.py') 8 | ] 9 | -------------------------------------------------------------------------------- /experiments/graphgps/encoder/ast_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.graphgym.register import (register_node_encoder, 3 | register_edge_encoder) 4 | 5 | """ 6 | === Description of the ogbg-code2 dataset === 7 | 8 | * Node Encoder code based on OGB's: 9 | https://github.com/snap-stanford/ogb/blob/master/examples/graphproppred/code2/utils.py 10 | 11 | Node Encoder config parameters are set based on the OGB example: 12 | https://github.com/snap-stanford/ogb/blob/master/examples/graphproppred/code2/main_pyg.py 13 | where the following three node features are used: 14 | 1. node type 15 | 2. node attribute 16 | 3. 
node depth 17 | 18 | nodetypes_mapping = pd.read_csv(os.path.join(dataset.root, 'mapping', 'typeidx2type.csv.gz')) 19 | nodeattributes_mapping = pd.read_csv(os.path.join(dataset.root, 'mapping', 'attridx2attr.csv.gz')) 20 | num_nodetypes = len(nodetypes_mapping['type']) 21 | num_nodeattributes = len(nodeattributes_mapping['attr']) 22 | max_depth = 20 23 | 24 | * Edge attributes are generated by `augment_edge` function dynamically: 25 | edge_attr[:,0]: whether it is AST edge (0) or next-token edge (1) 26 | edge_attr[:,1]: whether it is original direction (0) or inverse direction (1) 27 | """ 28 | 29 | num_nodetypes = 98 30 | num_nodeattributes = 10030 31 | max_depth = 20 32 | 33 | 34 | @register_node_encoder('ASTNode') 35 | class ASTNodeEncoder(torch.nn.Module): 36 | """The Abstract Syntax Tree (AST) Node Encoder used for ogbg-code2 dataset. 37 | 38 | Input: 39 | x: Default node feature. The first and second columns represent node 40 | type and node attribute. 41 | node_depth: The depth of the node in the AST. 42 | Output: 43 | emb_dim-dimensional vector 44 | """ 45 | 46 | def __init__(self, emb_dim): 47 | super().__init__() 48 | self.max_depth = max_depth 49 | 50 | self.type_encoder = torch.nn.Embedding(num_nodetypes, emb_dim) 51 | self.attribute_encoder = torch.nn.Embedding(num_nodeattributes, emb_dim) 52 | self.depth_encoder = torch.nn.Embedding(self.max_depth + 1, emb_dim) 53 | 54 | def forward(self, batch): 55 | x = batch.x 56 | depth = batch.node_depth.view(-1, ) 57 | depth[depth > self.max_depth] = self.max_depth 58 | batch.x = self.type_encoder(x[:, 0]) + self.attribute_encoder(x[:, 1]) \ 59 | + self.depth_encoder(depth) 60 | return batch 61 | 62 | 63 | @register_edge_encoder('ASTEdge') 64 | class ASTEdgeEncoder(torch.nn.Module): 65 | """The Abstract Syntax Tree (AST) Edge Encoder used for ogbg-code2 dataset.
66 | 67 | Edge attributes are generated by `augment_edge` function dynamically and 68 | are expected to be: 69 | edge_attr[:,0]: whether it is AST edge (0) for next-token edge (1) 70 | edge_attr[:,1]: whether it is original direction (0) or inverse direction (1) 71 | 72 | Args: 73 | emb_dim (int): Output edge embedding dimension 74 | """ 75 | 76 | def __init__(self, emb_dim): 77 | super().__init__() 78 | self.embedding_type = torch.nn.Embedding(2, emb_dim) 79 | self.embedding_direction = torch.nn.Embedding(2, emb_dim) 80 | 81 | def forward(self, batch): 82 | embedding = self.embedding_type(batch.edge_attr[:, 0]) + \ 83 | self.embedding_direction(batch.edge_attr[:, 1]) 84 | batch.edge_attr = embedding 85 | return batch 86 | -------------------------------------------------------------------------------- /experiments/graphgps/encoder/dummy_edge_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.graphgym.register import register_edge_encoder 3 | 4 | 5 | @register_edge_encoder('DummyEdge') 6 | class DummyEdgeEncoder(torch.nn.Module): 7 | def __init__(self, emb_dim): 8 | super().__init__() 9 | 10 | self.encoder = torch.nn.Embedding(num_embeddings=1, 11 | embedding_dim=emb_dim) 12 | # torch.nn.init.xavier_uniform_(self.encoder.weight.data) 13 | 14 | def forward(self, batch): 15 | dummy_attr = batch.edge_index.new_zeros(batch.edge_index.shape[1]) 16 | batch.edge_attr = self.encoder(dummy_attr) 17 | return batch 18 | -------------------------------------------------------------------------------- /experiments/graphgps/encoder/equivstable_laplace_pos_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch_geometric.graphgym.config import cfg 4 | from torch_geometric.graphgym.register import register_node_encoder 5 | 6 | 7 | @register_node_encoder('EquivStableLapPE') 8 | class 
EquivStableLapPENodeEncoder(torch.nn.Module): 9 | """Equivariant and Stable Laplace Positional Embedding node encoder. 10 | 11 | This encoder simply transforms the k-dim node LapPE to d-dim to be 12 | later used at the local GNN module as edge weights. 13 | Based on the approach proposed in paper https://openreview.net/pdf?id=e95i1IHcWj 14 | 15 | Args: 16 | dim_emb: Size of final node embedding 17 | """ 18 | 19 | def __init__(self, dim_emb): 20 | super().__init__() 21 | 22 | pecfg = cfg.posenc_EquivStableLapPE 23 | max_freqs = pecfg.eigen.max_freqs # Num. eigenvectors (frequencies) 24 | norm_type = pecfg.raw_norm_type.lower() # Raw PE normalization layer type 25 | 26 | if norm_type == 'batchnorm': 27 | self.raw_norm = nn.BatchNorm1d(max_freqs) 28 | else: 29 | self.raw_norm = None 30 | 31 | self.linear_encoder_eigenvec = nn.Linear(max_freqs, dim_emb) 32 | 33 | def forward(self, batch): 34 | if not (hasattr(batch, 'EigVals') and hasattr(batch, 'EigVecs')): 35 | raise ValueError("Precomputed eigen values and vectors are " 36 | f"required for {self.__class__.__name__}; set " 37 | f"config 'posenc_EquivStableLapPE.enable' to True") 38 | pos_enc = batch.EigVecs 39 | 40 | empty_mask = torch.isnan(pos_enc) # (Num nodes) x (Num Eigenvectors) 41 | pos_enc[empty_mask] = 0. 
# (Num nodes) x (Num Eigenvectors) 42 | 43 | if self.raw_norm: 44 | pos_enc = self.raw_norm(pos_enc) 45 | 46 | pos_enc = self.linear_encoder_eigenvec(pos_enc) 47 | 48 | # Keep PE separate in a variable 49 | batch.pe_EquivStableLapPE = pos_enc 50 | 51 | return batch 52 | -------------------------------------------------------------------------------- /experiments/graphgps/encoder/example.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from ogb.utils.features import get_bond_feature_dims 3 | 4 | from torch_geometric.graphgym.register import ( 5 | register_edge_encoder, 6 | register_node_encoder, 7 | ) 8 | 9 | 10 | @register_node_encoder('example') 11 | class ExampleNodeEncoder(torch.nn.Module): 12 | """ 13 | Provides an encoder for integer node features 14 | Parameters: 15 | num_classes - the number of classes for the embedding mapping to learn 16 | """ 17 | def __init__(self, emb_dim, num_classes=None): 18 | super().__init__() 19 | 20 | self.encoder = torch.nn.Embedding(num_classes, emb_dim) 21 | torch.nn.init.xavier_uniform_(self.encoder.weight.data) 22 | 23 | def forward(self, batch): 24 | # Encode just the first dimension if more exist 25 | batch.x = self.encoder(batch.x[:, 0]) 26 | 27 | return batch 28 | 29 | 30 | @register_edge_encoder('example') 31 | class ExampleEdgeEncoder(torch.nn.Module): 32 | def __init__(self, emb_dim): 33 | super().__init__() 34 | 35 | self.bond_embedding_list = torch.nn.ModuleList() 36 | full_bond_feature_dims = get_bond_feature_dims() 37 | 38 | for i, dim in enumerate(full_bond_feature_dims): 39 | emb = torch.nn.Embedding(dim, emb_dim) 40 | torch.nn.init.xavier_uniform_(emb.weight.data) 41 | self.bond_embedding_list.append(emb) 42 | 43 | def forward(self, batch): 44 | bond_embedding = 0 45 | for i in range(batch.edge_feature.shape[1]): 46 | bond_embedding += \ 47 | self.bond_embedding_list[i](batch.edge_attr[:, i]) 48 | 49 | batch.edge_attr = bond_embedding 50 | return batch 51 
| -------------------------------------------------------------------------------- /experiments/graphgps/encoder/geometric_node_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch_geometric.graphgym import cfg 4 | from torch_geometric.graphgym.register import register_node_encoder 5 | 6 | 7 | class GeometricPE(nn.Module): 8 | """ 9 | Adapted from GeometricTransformer 10 | """ 11 | 12 | def __init__(self, d_model, dims=50): 13 | super().__init__() 14 | self.net = torch.nn.Sequential(*[nn.Linear(1, dims, 15 | bias=True), torch.nn.GELU(), nn.Linear(dims, 1, 16 | bias=True), 17 | torch.nn.GELU()]) 18 | self.embed = nn.Linear(1, d_model, 19 | bias=True) 20 | 21 | def forward(self, batch): 22 | distances_mask = torch.logical_and(batch.mask.unsqueeze(-1), batch.mask.unsqueeze(-2)) 23 | x = self.net(batch.distances.unsqueeze(-1)) 24 | x = torch.masked_fill(x, batch.zero_distances_mask.unsqueeze(-1), 0.0) 25 | x = torch.masked_fill(x, ~distances_mask.unsqueeze(-1), 0.0) 26 | x = torch.sum(x, -2) 27 | x = self.embed(x) 28 | return x[batch.mask] 29 | 30 | 31 | @register_node_encoder('GeometricNode') 32 | class LinearNodeEncoder(torch.nn.Module): 33 | def __init__(self, emb_dim): 34 | super().__init__() 35 | 36 | self.encoder = torch.nn.Linear(cfg.dataset.node_encoder_in_dim, emb_dim) 37 | self.pe = GeometricPE(emb_dim) 38 | 39 | def forward(self, batch): 40 | batch.x = self.encoder(batch.x) + self.pe(batch) 41 | return batch 42 | -------------------------------------------------------------------------------- /experiments/graphgps/encoder/linear_edge_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.graphgym import cfg 3 | from torch_geometric.graphgym.register import register_edge_encoder 4 | 5 | 6 | @register_edge_encoder('LinearEdge') 7 | class LinearEdgeEncoder(torch.nn.Module): 8 | def 
__init__(self, emb_dim): 9 | super().__init__() 10 | if cfg.dataset.name in ['MNIST', 'CIFAR10']: 11 | self.in_dim = 1 12 | else: 13 | self.in_dim = cfg.dataset.edge_encoder_in_dim 14 | self.encoder = torch.nn.Linear(self.in_dim, emb_dim) 15 | 16 | def forward(self, batch): 17 | batch.edge_attr = self.encoder(batch.edge_attr.view(-1, self.in_dim)) 18 | return batch 19 | -------------------------------------------------------------------------------- /experiments/graphgps/encoder/linear_node_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.graphgym import cfg 3 | from torch_geometric.graphgym.register import register_node_encoder 4 | 5 | 6 | @register_node_encoder('LinearNode') 7 | class LinearNodeEncoder(torch.nn.Module): 8 | def __init__(self, emb_dim): 9 | super().__init__() 10 | 11 | self.encoder = torch.nn.Linear(cfg.dataset.node_encoder_in_dim, emb_dim) 12 | 13 | def forward(self, batch): 14 | batch.x = self.encoder(batch.x) 15 | return batch 16 | -------------------------------------------------------------------------------- /experiments/graphgps/encoder/ppa_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.graphgym.register import (register_node_encoder, 3 | register_edge_encoder) 4 | 5 | 6 | @register_node_encoder('PPANode') 7 | class PPANodeEncoder(torch.nn.Module): 8 | """ 9 | Uniform input node embedding for PPA that has no node features. 
10 | """ 11 | 12 | def __init__(self, emb_dim): 13 | super().__init__() 14 | self.encoder = torch.nn.Embedding(1, emb_dim) 15 | 16 | def forward(self, batch): 17 | batch.x = self.encoder(batch.x) 18 | return batch 19 | 20 | 21 | @register_edge_encoder('PPAEdge') 22 | class PPAEdgeEncoder(torch.nn.Module): 23 | def __init__(self, emb_dim): 24 | super().__init__() 25 | self.encoder = torch.nn.Linear(7, emb_dim) 26 | 27 | def forward(self, batch): 28 | batch.edge_attr = self.encoder(batch.edge_attr) 29 | return batch 30 | -------------------------------------------------------------------------------- /experiments/graphgps/encoder/voc_superpixels_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.graphgym.config import cfg 3 | from torch_geometric.graphgym.register import (register_node_encoder, 4 | register_edge_encoder) 5 | 6 | """ 7 | === Description of the VOCSuperpixels dataset === 8 | Each graph is a tuple (x, edge_attr, edge_index, y) 9 | Shape of x : [num_nodes, 14] 10 | Shape of edge_attr : [num_edges, 1] or [num_edges, 2] 11 | Shape of edge_index : [2, num_edges] 12 | Shape of y : [num_nodes] 13 | """ 14 | 15 | VOC_node_input_dim = 14 16 | # VOC_edge_input_dim = 1 or 2; defined in class VOCEdgeEncoder 17 | 18 | @register_node_encoder('VOCNode') 19 | class VOCNodeEncoder(torch.nn.Module): 20 | def __init__(self, emb_dim): 21 | super().__init__() 22 | 23 | self.encoder = torch.nn.Linear(VOC_node_input_dim, emb_dim) 24 | # torch.nn.init.xavier_uniform_(self.encoder.weight.data) 25 | 26 | def forward(self, batch): 27 | batch.x = self.encoder(batch.x) 28 | 29 | return batch 30 | 31 | 32 | @register_edge_encoder('VOCEdge') 33 | class VOCEdgeEncoder(torch.nn.Module): 34 | def __init__(self, emb_dim): 35 | super().__init__() 36 | 37 | VOC_edge_input_dim = 2 if cfg.dataset.name == 'edge_wt_region_boundary' else 1 38 | self.encoder = torch.nn.Linear(VOC_edge_input_dim, emb_dim) 39 | # 
torch.nn.init.xavier_uniform_(self.encoder.weight.data) 40 | 41 | def forward(self, batch): 42 | batch.edge_attr = self.encoder(batch.edge_attr) 43 | return batch 44 | -------------------------------------------------------------------------------- /experiments/graphgps/head/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, basename, isfile, join 2 | import glob 3 | 4 | modules = glob.glob(join(dirname(__file__), "*.py")) 5 | __all__ = [ 6 | basename(f)[:-3] for f in modules 7 | if isfile(f) and not f.endswith('__init__.py') 8 | ] 9 | -------------------------------------------------------------------------------- /experiments/graphgps/head/example.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from torch_geometric.graphgym.register import register_head 4 | 5 | 6 | @register_head('head') 7 | class ExampleNodeHead(nn.Module): 8 | '''Head of GNN, node prediction''' 9 | def __init__(self, dim_in, dim_out): 10 | super().__init__() 11 | self.layer_post_mp = nn.Linear(dim_in, dim_out, bias=True) 12 | 13 | def _apply_index(self, batch): 14 | if batch.node_label_index.shape[0] == batch.node_label.shape[0]: 15 | return batch.x[batch.node_label_index], batch.node_label 16 | else: 17 | return batch.x[batch.node_label_index], \ 18 | batch.node_label[batch.node_label_index] 19 | 20 | def forward(self, batch): 21 | batch = self.layer_post_mp(batch) 22 | pred, label = self._apply_index(batch) 23 | return pred, label 24 | -------------------------------------------------------------------------------- /experiments/graphgps/head/inductive_node.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch_geometric.graphgym.config import cfg 3 | from torch_geometric.graphgym.models.layer import new_layer_config, MLP 4 | from torch_geometric.graphgym.register import 
register_head 5 | 6 | 7 | @register_head('inductive_node') 8 | class GNNInductiveNodeHead(nn.Module): 9 | """ 10 | GNN prediction head for inductive node prediction tasks. 11 | 12 | Args: 13 | dim_in (int): Input dimension 14 | dim_out (int): Output dimension. For binary prediction, dim_out=1. 15 | """ 16 | 17 | def __init__(self, dim_in, dim_out): 18 | super(GNNInductiveNodeHead, self).__init__() 19 | self.layer_post_mp = MLP( 20 | new_layer_config(dim_in, dim_out, cfg.gnn.layers_post_mp, 21 | has_act=False, has_bias=True, cfg=cfg)) 22 | 23 | def _apply_index(self, batch): 24 | return batch.x, batch.y 25 | 26 | def forward(self, batch): 27 | batch = self.layer_post_mp(batch) 28 | pred, label = self._apply_index(batch) 29 | return pred, label 30 | -------------------------------------------------------------------------------- /experiments/graphgps/head/ogb_code_graph.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | import torch_geometric.graphgym.register as register 4 | from torch_geometric.graphgym import cfg 5 | from torch_geometric.graphgym.register import register_head 6 | 7 | 8 | @register_head('ogb_code_graph') 9 | class OGBCodeGraphHead(nn.Module): 10 | """ 11 | Sequence prediction head for ogbg-code2 graph-level prediction tasks. 12 | 13 | Args: 14 | dim_in (int): Input dimension. 15 | dim_out (int): IGNORED, kept for GraphGym framework compatibility 16 | L (int): Number of hidden layers. 
17 | """ 18 | 19 | def __init__(self, dim_in, dim_out, L=1): 20 | super().__init__() 21 | self.pooling_fun = register.pooling_dict[cfg.model.graph_pooling] 22 | self.L = L 23 | num_vocab = 5002 24 | self.max_seq_len = 5 25 | 26 | if self.L != 1: 27 | raise ValueError(f"Multilayer prediction heads are not supported.") 28 | 29 | self.graph_pred_linear_list = nn.ModuleList() 30 | for i in range(self.max_seq_len): 31 | self.graph_pred_linear_list.append(nn.Linear(dim_in, num_vocab)) 32 | 33 | def _apply_index(self, batch): 34 | return batch.pred_list, {'y_arr': batch.y_arr, 'y': batch.y} 35 | 36 | def forward(self, batch): 37 | graph_emb = self.pooling_fun(batch.x, batch.batch) 38 | 39 | pred_list = [] 40 | for i in range(self.max_seq_len): 41 | pred_list.append(self.graph_pred_linear_list[i](graph_emb)) 42 | batch.pred_list = pred_list 43 | 44 | pred, label = self._apply_index(batch) 45 | return pred, label 46 | -------------------------------------------------------------------------------- /experiments/graphgps/head/san_graph.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | import torch_geometric.graphgym.register as register 5 | from torch_geometric.graphgym import cfg 6 | from torch_geometric.graphgym.register import register_head 7 | 8 | 9 | @register_head('san_graph') 10 | class SANGraphHead(nn.Module): 11 | """ 12 | SAN prediction head for graph prediction tasks. 13 | 14 | Args: 15 | dim_in (int): Input dimension. 16 | dim_out (int): Output dimension. For binary prediction, dim_out=1. 17 | L (int): Number of hidden layers. 
18 | """ 19 | 20 | def __init__(self, dim_in, dim_out, L=2): 21 | super().__init__() 22 | self.pooling_fun = register.pooling_dict[cfg.model.graph_pooling] 23 | list_FC_layers = [ 24 | nn.Linear(dim_in // 2 ** l, dim_in // 2 ** (l + 1), bias=True) 25 | for l in range(L)] 26 | list_FC_layers.append( 27 | nn.Linear(dim_in // 2 ** L, dim_out, bias=True)) 28 | self.FC_layers = nn.ModuleList(list_FC_layers) 29 | self.L = L 30 | 31 | def _apply_index(self, batch): 32 | return batch.graph_feature, batch.y 33 | 34 | def forward(self, batch): 35 | graph_emb = self.pooling_fun(batch.x, batch.batch) 36 | for l in range(self.L): 37 | graph_emb = self.FC_layers[l](graph_emb) 38 | graph_emb = F.relu(graph_emb) 39 | graph_emb = self.FC_layers[self.L](graph_emb) 40 | batch.graph_feature = graph_emb 41 | pred, label = self._apply_index(batch) 42 | return pred, label 43 | -------------------------------------------------------------------------------- /experiments/graphgps/layer/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, basename, isfile, join 2 | import glob 3 | 4 | modules = glob.glob(join(dirname(__file__), "*.py")) 5 | __all__ = [ 6 | basename(f)[:-3] for f in modules 7 | if isfile(f) and not f.endswith('__init__.py') 8 | ] 9 | -------------------------------------------------------------------------------- /experiments/graphgps/layer/chienn_layer_wrapper.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from torch_geometric.data import Batch 4 | from torch_geometric.graphgym import cfg 5 | 6 | sys.path.append('..') 7 | from chienn import ChiENNLayer 8 | 9 | 10 | class ChiENNLayerWrapper(ChiENNLayer): 11 | """ 12 | Wrapper for ChiENNLayer that loads the parameters form the graphgym config, making the initialization easier. 
13 | """ 14 | 15 | def __init__(self, hidden_dim: int, dropout: float = 0.0, return_batch: bool = False): 16 | super().__init__( 17 | hidden_dim=hidden_dim, 18 | k_neighbors_embeddings_names=cfg.chienn.message.k_neighbors_embeddings_names, 19 | message_final_embedding_name=cfg.chienn.message.final_embedding_name, 20 | aggregation_name=cfg.chienn.aggregate.aggregation, 21 | self_embedding_name=cfg.chienn.aggregate.self_embedding_name, 22 | parallel_embedding_name=cfg.chienn.aggregate.parallel_embedding_name, 23 | post_aggregation_embedding_name=cfg.chienn.aggregate.post_aggregation_embedding_name, 24 | dropout=dropout 25 | ) 26 | self.return_batch = return_batch 27 | 28 | def forward(self, batch: Batch): 29 | x = super().forward(batch) 30 | if self.return_batch: 31 | batch.x = x 32 | return batch 33 | else: 34 | return x 35 | -------------------------------------------------------------------------------- /experiments/graphgps/layer/example.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import Parameter 4 | 5 | from torch_geometric.graphgym.config import cfg 6 | from torch_geometric.graphgym.register import register_layer 7 | from torch_geometric.nn.conv import MessagePassing 8 | from torch_geometric.nn.inits import glorot, zeros 9 | 10 | # Note: A registered GNN layer should take 'batch' as input 11 | # and 'batch' as output 12 | 13 | 14 | # Example 1: Directly define a GraphGym format Conv 15 | # take 'batch' as input and 'batch' as output 16 | @register_layer('exampleconv1') 17 | class ExampleConv1(MessagePassing): 18 | r"""Example GNN layer 19 | """ 20 | def __init__(self, in_channels, out_channels, bias=True, **kwargs): 21 | super().__init__(aggr=cfg.gnn.agg, **kwargs) 22 | 23 | self.in_channels = in_channels 24 | self.out_channels = out_channels 25 | 26 | self.weight = Parameter(torch.Tensor(in_channels, out_channels)) 27 | 28 | if bias: 29 | self.bias = 
Parameter(torch.Tensor(out_channels)) 30 | else: 31 | self.register_parameter('bias', None) 32 | 33 | self.reset_parameters() 34 | 35 | def reset_parameters(self): 36 | glorot(self.weight) 37 | zeros(self.bias) 38 | 39 | def forward(self, batch): 40 | """""" 41 | x, edge_index = batch.x, batch.edge_index 42 | x = torch.matmul(x, self.weight) 43 | 44 | batch.x = self.propagate(edge_index, x=x) 45 | 46 | return batch 47 | 48 | def message(self, x_j): 49 | return x_j 50 | 51 | def update(self, aggr_out): 52 | if self.bias is not None: 53 | aggr_out = aggr_out + self.bias 54 | return aggr_out 55 | 56 | 57 | # Example 2: First define a PyG format Conv layer 58 | # Then wrap it to become GraphGym format 59 | class ExampleConv2Layer(MessagePassing): 60 | r"""Example GNN layer 61 | """ 62 | def __init__(self, in_channels, out_channels, bias=True, **kwargs): 63 | super().__init__(aggr=cfg.gnn.agg, **kwargs) 64 | 65 | self.in_channels = in_channels 66 | self.out_channels = out_channels 67 | 68 | self.weight = Parameter(torch.Tensor(in_channels, out_channels)) 69 | 70 | if bias: 71 | self.bias = Parameter(torch.Tensor(out_channels)) 72 | else: 73 | self.register_parameter('bias', None) 74 | 75 | self.reset_parameters() 76 | 77 | def reset_parameters(self): 78 | glorot(self.weight) 79 | zeros(self.bias) 80 | 81 | def forward(self, x, edge_index): 82 | """""" 83 | x = torch.matmul(x, self.weight) 84 | 85 | return self.propagate(edge_index, x=x) 86 | 87 | def message(self, x_j): 88 | return x_j 89 | 90 | def update(self, aggr_out): 91 | if self.bias is not None: 92 | aggr_out = aggr_out + self.bias 93 | return aggr_out 94 | 95 | 96 | @register_layer('exampleconv2') 97 | class ExampleConv2(nn.Module): 98 | def __init__(self, dim_in, dim_out, bias=False, **kwargs): 99 | super().__init__() 100 | self.model = ExampleConv2Layer(dim_in, dim_out, bias=bias) 101 | 102 | def forward(self, batch): 103 | batch.x = self.model(batch.x, batch.edge_index) 104 | return batch 105 | 
-------------------------------------------------------------------------------- /experiments/graphgps/layer/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch_geometric.data import Batch 4 | from torch_geometric.graphgym import cfg 5 | from torch_geometric.utils import to_dense_batch 6 | 7 | 8 | class AtomDistance(nn.Module): 9 | """ 10 | Adds inversed attoms' distances to a batch. 11 | """ 12 | 13 | def __init__(self): 14 | super().__init__() 15 | coords_selection = cfg.model.coords_selection 16 | if coords_selection == 'start': 17 | self.coords_selection_fn = lambda x: x[:, :3] 18 | elif coords_selection == 'end': 19 | self.coords_selection_fn = lambda x: x[:, -3:] 20 | elif coords_selection == 'center': 21 | self.coords_selection_fn = lambda x: (x[:, :3] + x[:, -3:]) * 0.5 22 | else: 23 | raise NotImplemented(f'Unknown corrds_selection {coords_selection}.') 24 | 25 | def forward(self, batch: Batch) -> Batch: 26 | pos = self.coords_selection_fn(batch.pos) 27 | pos, mask = to_dense_batch(pos, batch.batch) 28 | distances = torch.cdist(pos, pos, compute_mode="donot_use_mm_for_euclid_dist") 29 | zero_distances_mask = distances <= 1e-3 # just for safety 30 | distances = 1. 
/ (distances + 1e-16) 31 | batch.distances = distances 32 | batch.zero_distances_mask = zero_distances_mask 33 | batch.mask = mask 34 | return batch 35 | -------------------------------------------------------------------------------- /experiments/graphgps/loader/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, basename, isfile, join 2 | import glob 3 | 4 | modules = glob.glob(join(dirname(__file__), "*.py")) 5 | __all__ = [ 6 | basename(f)[:-3] for f in modules 7 | if isfile(f) and not f.endswith('__init__.py') 8 | ] 9 | -------------------------------------------------------------------------------- /experiments/graphgps/loader/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gmum/ChiENN/ee3185b39e8469a8caacf3d6d45a04c4a1cfff5b/experiments/graphgps/loader/dataset/__init__.py -------------------------------------------------------------------------------- /experiments/graphgps/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, basename, isfile, join 2 | import glob 3 | 4 | modules = glob.glob(join(dirname(__file__), "*.py")) 5 | __all__ = [ 6 | basename(f)[:-3] for f in modules 7 | if isfile(f) and not f.endswith('__init__.py') 8 | ] 9 | -------------------------------------------------------------------------------- /experiments/graphgps/loss/l1.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch_geometric.graphgym.config import cfg 3 | from torch_geometric.graphgym.register import register_loss 4 | 5 | 6 | @register_loss('l1_losses') 7 | def l1_losses(pred, true): 8 | if cfg.model.loss_fun == 'l1': 9 | l1_loss = nn.L1Loss() 10 | loss = l1_loss(pred, true) 11 | return loss, pred 12 | elif cfg.model.loss_fun == 'smoothl1': 13 | l1_loss = 
nn.SmoothL1Loss() 14 | loss = l1_loss(pred, true) 15 | return loss, pred 16 | -------------------------------------------------------------------------------- /experiments/graphgps/loss/multilabel_classification_loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch_geometric.graphgym.config import cfg 3 | from torch_geometric.graphgym.register import register_loss 4 | 5 | 6 | @register_loss('multilabel_cross_entropy') 7 | def multilabel_cross_entropy(pred, true): 8 | """Multilabel cross-entropy loss. 9 | """ 10 | if cfg.dataset.task_type == 'classification_multilabel': 11 | if cfg.model.loss_fun != 'cross_entropy': 12 | raise ValueError("Only 'cross_entropy' loss_fun supported with " 13 | "'classification_multilabel' task_type.") 14 | bce_loss = nn.BCEWithLogitsLoss() 15 | is_labeled = ~true.isnan() 16 | return bce_loss(pred[is_labeled], true[is_labeled].float()), pred 17 | -------------------------------------------------------------------------------- /experiments/graphgps/loss/subtoken_prediction_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.graphgym.config import cfg 3 | from torch_geometric.graphgym.register import register_loss 4 | 5 | 6 | @register_loss('subtoken_cross_entropy') 7 | def subtoken_cross_entropy(pred_list, true): 8 | """Subtoken prediction cross-entropy loss for ogbg-code2. 
9 | """ 10 | if cfg.dataset.task_type == 'subtoken_prediction': 11 | if cfg.model.loss_fun != 'cross_entropy': 12 | raise ValueError("Only 'cross_entropy' loss_fun supported with " 13 | "'subtoken_prediction' task_type.") 14 | multicls_criterion = torch.nn.CrossEntropyLoss() 15 | loss = 0 16 | for i in range(len(pred_list)): 17 | loss += multicls_criterion(pred_list[i].to(torch.float32), true['y_arr'][:, i]) 18 | loss = loss / len(pred_list) 19 | 20 | return loss, pred_list 21 | -------------------------------------------------------------------------------- /experiments/graphgps/loss/weighted_cross_entropy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch_geometric.graphgym.config import cfg 4 | from torch_geometric.graphgym.register import register_loss 5 | 6 | 7 | @register_loss('weighted_cross_entropy') 8 | def weighted_cross_entropy(pred, true): 9 | """Weighted cross-entropy for unbalanced classes. 
10 | """ 11 | if cfg.model.loss_fun == 'weighted_cross_entropy': 12 | # calculating label weights for weighted loss computation 13 | V = true.size(0) 14 | n_classes = pred.shape[1] if pred.ndim > 1 else 2 15 | label_count = torch.bincount(true) 16 | label_count = label_count[label_count.nonzero(as_tuple=True)].squeeze() 17 | cluster_sizes = torch.zeros(n_classes, device=pred.device).long() 18 | cluster_sizes[torch.unique(true)] = label_count 19 | weight = (V - cluster_sizes).float() / V 20 | weight *= (cluster_sizes > 0).float() 21 | # multiclass 22 | if pred.ndim > 1: 23 | pred = F.log_softmax(pred, dim=-1) 24 | return F.nll_loss(pred, true, weight=weight), pred 25 | # binary 26 | else: 27 | loss = F.binary_cross_entropy_with_logits(pred, true.float(), 28 | weight=weight[true]) 29 | return loss, torch.sigmoid(pred) 30 | -------------------------------------------------------------------------------- /experiments/graphgps/network/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, basename, isfile, join 2 | import glob 3 | 4 | modules = glob.glob(join(dirname(__file__), "*.py")) 5 | __all__ = [ 6 | basename(f)[:-3] for f in modules 7 | if isfile(f) and not f.endswith('__init__.py') 8 | ] 9 | -------------------------------------------------------------------------------- /experiments/graphgps/network/big_bird.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch_geometric.graphgym.register as register 3 | from torch_geometric.graphgym.config import cfg 4 | from torch_geometric.graphgym.models.gnn import FeatureEncoder, GNNPreMP 5 | from torch_geometric.graphgym.register import register_network 6 | 7 | from graphgps.layer.bigbird_layer import BigBirdModel as BackboneBigBird 8 | 9 | 10 | @register_network('BigBird') 11 | class BigBird(torch.nn.Module): 12 | """BigBird without edge features. 
13 | This model disregards edge features and runs a linear transformer over a set of node features only. 14 | BirBird applies random sparse attention to the input sequence - the longer the sequence the closer it is to O(N) 15 | https://arxiv.org/abs/2007.14062 16 | """ 17 | 18 | def __init__(self, dim_in, dim_out): 19 | super().__init__() 20 | self.encoder = FeatureEncoder(dim_in) 21 | dim_in = self.encoder.dim_in 22 | 23 | if cfg.gnn.layers_pre_mp > 0: 24 | self.pre_mp = GNNPreMP( 25 | dim_in, cfg.gnn.dim_inner, cfg.gnn.layers_pre_mp) 26 | dim_in = cfg.gnn.dim_inner 27 | 28 | assert cfg.gt.dim_hidden == cfg.gnn.dim_inner == dim_in, \ 29 | "The inner and hidden dims must match." 30 | 31 | # Copy main Transformer hyperparams to the BigBird config. 32 | cfg.gt.bigbird.layers = cfg.gt.layers 33 | cfg.gt.bigbird.n_heads = cfg.gt.n_heads 34 | cfg.gt.bigbird.dim_hidden = cfg.gt.dim_hidden 35 | cfg.gt.bigbird.dropout = cfg.gt.dropout 36 | self.trf = BackboneBigBird( 37 | config=cfg.gt.bigbird, 38 | ) 39 | 40 | GNNHead = register.head_dict[cfg.gnn.head] 41 | self.post_mp = GNNHead(dim_in=cfg.gnn.dim_inner, dim_out=dim_out) 42 | 43 | def forward(self, batch): 44 | for module in self.children(): 45 | batch = module(batch) 46 | return batch 47 | -------------------------------------------------------------------------------- /experiments/graphgps/network/chiro.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import torch 4 | from torch_geometric.graphgym.config import cfg 5 | from torch_geometric.graphgym.register import register_network 6 | 7 | from graphgps.network.utils import get_local_structure_map 8 | from model.alpha_encoder import Encoder 9 | from model.params_interpreter import string_to_object 10 | 11 | 12 | @register_network('ChIRo') 13 | class ChIRo(torch.nn.Module): 14 | def __init__(self, dim_in, dim_out): 15 | super().__init__() 16 | with open(cfg.model.config_path, 'r') as fp: 17 | params = 
json.load(fp) 18 | for key, value in params['activation_dict'].items(): 19 | params['activation_dict'][key] = string_to_object[value] 20 | self.encoder = Encoder(F_H_embed=cfg.dataset.node_encoder_in_dim, 21 | F_E_embed=cfg.dataset.edge_encoder_in_dim, 22 | F_H=cfg.model.hidden_dim, 23 | F_H_EConv=cfg.model.hidden_dim, 24 | CMP_GAT_N_layers=cfg.gnn.layers, 25 | dropout=cfg.gnn.dropout, 26 | dim_out=dim_out, 27 | **params) 28 | 29 | def forward(self, batch): 30 | LS_map, alpha_indices = get_local_structure_map(batch.dihedral_angle_index) 31 | x = self.encoder(batch, LS_map, alpha_indices) 32 | return x[0], batch.y 33 | -------------------------------------------------------------------------------- /experiments/graphgps/network/custom_gnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch_geometric.graphgym.models.head # noqa, register module 3 | import torch_geometric.graphgym.register as register 4 | from torch_geometric.graphgym.config import cfg 5 | from torch_geometric.graphgym.models.gnn import FeatureEncoder, GNNPreMP 6 | from torch_geometric.graphgym.register import register_network 7 | 8 | from graphgps.layer.gatedgcn_layer import GatedGCNLayer 9 | from graphgps.layer.gine_conv_layer import GINEConvLayer 10 | 11 | 12 | @register_network('custom_gnn') 13 | class CustomGNN(torch.nn.Module): 14 | """ 15 | GNN model that customizes the torch_geometric.graphgym.models.gnn.GNN 16 | to support specific handling of new conv layers. 17 | """ 18 | 19 | def __init__(self, dim_in, dim_out): 20 | super().__init__() 21 | self.encoder = FeatureEncoder(dim_in) 22 | dim_in = self.encoder.dim_in 23 | 24 | if cfg.gnn.layers_pre_mp > 0: 25 | self.pre_mp = GNNPreMP( 26 | dim_in, cfg.gnn.dim_inner, cfg.gnn.layers_pre_mp) 27 | dim_in = cfg.gnn.dim_inner 28 | 29 | assert cfg.gnn.dim_inner == dim_in, \ 30 | "The inner and hidden dims must match." 
31 | 32 | conv_model = self.build_conv_model(cfg.gnn.layer_type) 33 | layers = [] 34 | for _ in range(cfg.gnn.layers_mp): 35 | layers.append(conv_model(dim_in, 36 | dim_in, 37 | dropout=cfg.gnn.dropout, 38 | residual=cfg.gnn.residual)) 39 | self.gnn_layers = torch.nn.Sequential(*layers) 40 | 41 | GNNHead = register.head_dict[cfg.gnn.head] 42 | self.post_mp = GNNHead(dim_in=cfg.gnn.dim_inner, dim_out=dim_out) 43 | 44 | def build_conv_model(self, model_type): 45 | if model_type == 'gatedgcnconv': 46 | return GatedGCNLayer 47 | elif model_type == 'gineconv': 48 | return GINEConvLayer 49 | else: 50 | raise ValueError("Model {} unavailable".format(model_type)) 51 | 52 | def forward(self, batch): 53 | for module in self.children(): 54 | batch = module(batch) 55 | return batch 56 | -------------------------------------------------------------------------------- /experiments/graphgps/network/dmpnn.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | 3 | import torch 4 | from torch_geometric.graphgym.config import cfg 5 | from torch_geometric.graphgym.register import register_network 6 | 7 | from submodules.tetra_dmpnn.model.gnn import GNN 8 | 9 | 10 | @register_network('DMPNN') 11 | class DMPNN(torch.nn.Module): 12 | def __init__(self, dim_in, dim_out, pooling=True): 13 | super().__init__() 14 | args = { 15 | 'depth': cfg.gnn.layers, 16 | 'hidden_size': cfg.model.hidden_dim, 17 | 'dropout': cfg.gnn.dropout, 18 | 'gnn_type': 'dmpnn', 19 | 'graph_pool': cfg.model.graph_pooling, 20 | 'tetra': cfg.gnn.tetra.use, 21 | 'message': cfg.gnn.tetra.message, 22 | } 23 | self.model = GNN(args=Namespace(**args), 24 | num_node_features=cfg.dataset.node_encoder_in_dim, 25 | num_edge_features=cfg.dataset.edge_encoder_in_dim, 26 | out_dim=dim_out, 27 | pooling=pooling) 28 | 29 | def forward(self, batch): 30 | x = self.model(batch) 31 | return x, batch.y 32 | 
-------------------------------------------------------------------------------- /experiments/graphgps/network/example.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import torch_geometric.graphgym.models.head # noqa, register module 6 | import torch_geometric.graphgym.register as register 7 | import torch_geometric.nn as pyg_nn 8 | from torch_geometric.graphgym.config import cfg 9 | from torch_geometric.graphgym.register import register_network 10 | 11 | 12 | @register_network('example') 13 | class ExampleGNN(torch.nn.Module): 14 | def __init__(self, dim_in, dim_out, num_layers=2, model_type='GCN'): 15 | super().__init__() 16 | conv_model = self.build_conv_model(model_type) 17 | self.convs = nn.ModuleList() 18 | self.convs.append(conv_model(dim_in, dim_in)) 19 | 20 | for _ in range(num_layers - 1): 21 | self.convs.append(conv_model(dim_in, dim_in)) 22 | 23 | GNNHead = register.head_dict[cfg.dataset.task] 24 | self.post_mp = GNNHead(dim_in=dim_in, dim_out=dim_out) 25 | 26 | def build_conv_model(self, model_type): 27 | if model_type == 'GCN': 28 | return pyg_nn.GCNConv 29 | elif model_type == 'GAT': 30 | return pyg_nn.GATConv 31 | elif model_type == "GraphSage": 32 | return pyg_nn.SAGEConv 33 | else: 34 | raise ValueError(f'Model {model_type} unavailable') 35 | 36 | def forward(self, batch): 37 | x, edge_index = batch.x, batch.edge_index 38 | 39 | for i in range(len(self.convs)): 40 | x = self.convs[i](x, edge_index) 41 | x = F.relu(x) 42 | x = F.dropout(x, p=0.1, training=self.training) 43 | 44 | batch.x = x 45 | batch = self.post_mp(batch) 46 | 47 | return batch 48 | -------------------------------------------------------------------------------- /experiments/graphgps/network/performer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch_geometric.graphgym.register as register 
3 | from torch_geometric.graphgym.config import cfg 4 | from torch_geometric.graphgym.models.gnn import FeatureEncoder, GNNPreMP 5 | from torch_geometric.graphgym.register import register_network 6 | 7 | from graphgps.layer.performer_layer import Performer as BackbonePerformer 8 | 9 | 10 | @register_network('Performer') 11 | class Performer(torch.nn.Module): 12 | """Performer without edge features. 13 | This model disregards edge features and runs a linear transformer over a set of node features only. 14 | https://arxiv.org/abs/2009.14794 15 | """ 16 | 17 | def __init__(self, dim_in, dim_out): 18 | super().__init__() 19 | self.encoder = FeatureEncoder(dim_in) 20 | dim_in = self.encoder.dim_in 21 | 22 | if cfg.gnn.layers_pre_mp > 0: 23 | self.pre_mp = GNNPreMP( 24 | dim_in, cfg.gnn.dim_inner, cfg.gnn.layers_pre_mp) 25 | dim_in = cfg.gnn.dim_inner 26 | 27 | assert cfg.gt.dim_hidden == cfg.gnn.dim_inner == dim_in, \ 28 | "The inner and hidden dims must match." 29 | 30 | self.trf = BackbonePerformer( 31 | dim=cfg.gt.dim_hidden, 32 | depth=cfg.gt.layers, 33 | heads=cfg.gt.n_heads, 34 | dim_head=cfg.gt.dim_hidden // cfg.gt.n_heads 35 | ) 36 | 37 | GNNHead = register.head_dict[cfg.gnn.head] 38 | self.post_mp = GNNHead(dim_in=cfg.gnn.dim_inner, dim_out=dim_out) 39 | 40 | def forward(self, batch): 41 | for module in self.children(): 42 | batch = module(batch) 43 | return batch 44 | -------------------------------------------------------------------------------- /experiments/graphgps/network/san_transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch_geometric.graphgym.register as register 3 | from torch_geometric.graphgym.config import cfg 4 | from torch_geometric.graphgym.models.gnn import FeatureEncoder, GNNPreMP 5 | from torch_geometric.graphgym.register import register_network 6 | 7 | from graphgps.layer.chienn_layer_wrapper import ChiENNLayerWrapper 8 | from graphgps.layer.san_layer import 
SANLayer 9 | from graphgps.layer.san2_layer import SAN2Layer 10 | 11 | 12 | @register_network('SANTransformer') 13 | class SANTransformer(torch.nn.Module): 14 | """Spectral Attention Network (SAN) Graph Transformer. 15 | https://arxiv.org/abs/2106.03893 16 | """ 17 | 18 | def __init__(self, dim_in, dim_out): 19 | super().__init__() 20 | cfg.gnn.dim_inner = cfg.model.hidden_dim 21 | cfg.gt.dim_hidden = cfg.model.hidden_dim 22 | self.encoder = FeatureEncoder(dim_in) 23 | dim_in = self.encoder.dim_in 24 | 25 | if cfg.gnn.layers_pre_mp > 0: 26 | self.pre_mp = GNNPreMP( 27 | dim_in, cfg.gnn.dim_inner, cfg.gnn.layers_pre_mp) 28 | 29 | fake_edge_emb = torch.nn.Embedding(1, cfg.gt.dim_hidden) 30 | # torch.nn.init.xavier_uniform_(fake_edge_emb.weight.data) 31 | Layer = { 32 | 'SANLayer': SANLayer, 33 | 'SAN2Layer': SAN2Layer, 34 | }.get(cfg.gt.layer_type) 35 | layers = [] 36 | for i in range(cfg.gt.layers): 37 | layers.append(Layer(gamma=cfg.gt.gamma, 38 | in_dim=cfg.gt.dim_hidden, 39 | out_dim=cfg.gt.dim_hidden, 40 | num_heads=cfg.gt.n_heads, 41 | full_graph=cfg.gt.full_graph, 42 | fake_edge_emb=fake_edge_emb, 43 | dropout=cfg.gt.dropout, 44 | layer_norm=cfg.gt.layer_norm, 45 | batch_norm=cfg.gt.batch_norm, 46 | residual=cfg.gt.residual)) 47 | if cfg.model.add_chienn_layer: 48 | layers.append(ChiENNLayerWrapper( 49 | hidden_dim=cfg.model.hidden_dim, 50 | dropout=0.0, 51 | return_batch=True)) 52 | self.trf_layers = torch.nn.Sequential(*layers) 53 | 54 | GNNHead = register.head_dict[cfg.gnn.head] 55 | self.post_mp = GNNHead(dim_in=cfg.gt.dim_hidden, dim_out=dim_out) 56 | 57 | def forward(self, batch): 58 | for module in self.children(): 59 | batch = module(batch) 60 | return batch 61 | -------------------------------------------------------------------------------- /experiments/graphgps/network/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | 6 | def 
get_local_structure_map(psi_indices): 7 | """ 8 | Adapted from ChIRo repository. 9 | """ 10 | LS_dict = OrderedDict() 11 | LS_map = torch.zeros(psi_indices.shape[1], dtype=torch.long).to(psi_indices.device) 12 | v = 0 13 | for i, indices in enumerate(psi_indices.T): 14 | tupl = (int(indices[1]), int(indices[2])) 15 | if tupl not in LS_dict: 16 | LS_dict[tupl] = v 17 | v += 1 18 | LS_map[i] = LS_dict[tupl] 19 | 20 | alpha_indices = torch.zeros((2, len(LS_dict)), dtype=torch.long) 21 | for i, tupl in enumerate(LS_dict): 22 | alpha_indices[:, i] = torch.LongTensor(tupl) 23 | 24 | return LS_map, alpha_indices.to(psi_indices.device) 25 | -------------------------------------------------------------------------------- /experiments/graphgps/optimizer/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, basename, isfile, join 2 | import glob 3 | 4 | modules = glob.glob(join(dirname(__file__), "*.py")) 5 | __all__ = [ 6 | basename(f)[:-3] for f in modules 7 | if isfile(f) and not f.endswith('__init__.py') 8 | ] 9 | -------------------------------------------------------------------------------- /experiments/graphgps/pooling/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, basename, isfile, join 2 | import glob 3 | 4 | modules = glob.glob(join(dirname(__file__), "*.py")) 5 | __all__ = [ 6 | basename(f)[:-3] for f in modules 7 | if isfile(f) and not f.endswith('__init__.py') 8 | ] 9 | -------------------------------------------------------------------------------- /experiments/graphgps/pooling/example.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.graphgym.register import register_pooling 2 | from torch_scatter import scatter 3 | 4 | 5 | @register_pooling('example') 6 | def global_example_pool(x, batch, size=None): 7 | size = batch.max().item() + 1 if size is None 
else size 8 | return scatter(x, batch, dim=0, dim_size=size, reduce='add') 9 | -------------------------------------------------------------------------------- /experiments/graphgps/stage/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, basename, isfile, join 2 | import glob 3 | 4 | modules = glob.glob(join(dirname(__file__), "*.py")) 5 | __all__ = [ 6 | basename(f)[:-3] for f in modules 7 | if isfile(f) and not f.endswith('__init__.py') 8 | ] 9 | -------------------------------------------------------------------------------- /experiments/graphgps/stage/example.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | from torch_geometric.graphgym.config import cfg 5 | from torch_geometric.graphgym.models.layer import GeneralLayer 6 | from torch_geometric.graphgym.register import register_stage 7 | 8 | 9 | def GNNLayer(dim_in, dim_out, has_act=True): 10 | return GeneralLayer(cfg.gnn.layer_type, dim_in, dim_out, has_act) 11 | 12 | 13 | @register_stage('example') 14 | class GNNStackStage(nn.Module): 15 | '''Simple Stage that stack GNN layers''' 16 | def __init__(self, dim_in, dim_out, num_layers): 17 | super().__init__() 18 | for i in range(num_layers): 19 | d_in = dim_in if i == 0 else dim_out 20 | layer = GNNLayer(d_in, dim_out) 21 | self.add_module(f'layer{i}', layer) 22 | self.dim_out = dim_out 23 | 24 | def forward(self, batch): 25 | for layer in self.children(): 26 | batch = layer(batch) 27 | if cfg.gnn.l2norm: 28 | batch.x = F.normalize(batch.x, p=2, dim=-1) 29 | return batch 30 | -------------------------------------------------------------------------------- /experiments/graphgps/train/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, basename, isfile, join 2 | import glob 3 | 4 | modules = 
glob.glob(join(dirname(__file__), "*.py")) 5 | __all__ = [ 6 | basename(f)[:-3] for f in modules 7 | if isfile(f) and not f.endswith('__init__.py') 8 | ] 9 | -------------------------------------------------------------------------------- /experiments/graphgps/train/example.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | 4 | import torch 5 | 6 | from torch_geometric.graphgym.checkpoint import ( 7 | clean_ckpt, 8 | load_ckpt, 9 | save_ckpt, 10 | ) 11 | from torch_geometric.graphgym.config import cfg 12 | from torch_geometric.graphgym.loss import compute_loss 13 | from torch_geometric.graphgym.register import register_train 14 | from torch_geometric.graphgym.utils.epoch import is_ckpt_epoch, is_eval_epoch 15 | 16 | 17 | def train_epoch(logger, loader, model, optimizer, scheduler): 18 | model.train() 19 | time_start = time.time() 20 | for batch in loader: 21 | optimizer.zero_grad() 22 | batch.to(torch.device(cfg.device)) 23 | pred, true = model(batch) 24 | loss, pred_score = compute_loss(pred, true) 25 | loss.backward() 26 | optimizer.step() 27 | logger.update_stats(true=true.detach().cpu(), 28 | pred=pred_score.detach().cpu(), loss=loss.item(), 29 | lr=scheduler.get_last_lr()[0], 30 | time_used=time.time() - time_start, 31 | params=cfg.params) 32 | time_start = time.time() 33 | scheduler.step() 34 | 35 | 36 | def eval_epoch(logger, loader, model): 37 | model.eval() 38 | time_start = time.time() 39 | for batch in loader: 40 | batch.to(torch.device(cfg.device)) 41 | pred, true = model(batch) 42 | loss, pred_score = compute_loss(pred, true) 43 | logger.update_stats(true=true.detach().cpu(), 44 | pred=pred_score.detach().cpu(), loss=loss.item(), 45 | lr=0, time_used=time.time() - time_start, 46 | params=cfg.params) 47 | time_start = time.time() 48 | 49 | 50 | @register_train('example') 51 | def train_example(loggers, loaders, model, optimizer, scheduler): 52 | start_epoch = 0 53 | if 
cfg.train.auto_resume: 54 | start_epoch = load_ckpt(model, optimizer, scheduler, 55 | cfg.train.epoch_resume) 56 | if start_epoch == cfg.optim.max_epoch: 57 | logging.info('Checkpoint found, Task already done') 58 | else: 59 | logging.info('Start from epoch %s', start_epoch) 60 | 61 | num_splits = len(loggers) 62 | for cur_epoch in range(start_epoch, cfg.optim.max_epoch): 63 | train_epoch(loggers[0], loaders[0], model, optimizer, scheduler) 64 | loggers[0].write_epoch(cur_epoch) 65 | if is_eval_epoch(cur_epoch): 66 | for i in range(1, num_splits): 67 | eval_epoch(loggers[i], loaders[i], model) 68 | loggers[i].write_epoch(cur_epoch) 69 | if is_ckpt_epoch(cur_epoch): 70 | save_ckpt(model, optimizer, scheduler, cur_epoch) 71 | for logger in loggers: 72 | logger.close() 73 | if cfg.train.ckpt_clean: 74 | clean_ckpt() 75 | 76 | logging.info('Task done, results saved in %s', cfg.run_dir) 77 | -------------------------------------------------------------------------------- /experiments/graphgps/transform/__init__.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, basename, isfile, join 2 | import glob 3 | 4 | modules = glob.glob(join(dirname(__file__), "*.py")) 5 | __all__ = [ 6 | basename(f)[:-3] for f in modules 7 | if isfile(f) and not f.endswith('__init__.py') 8 | ] 9 | -------------------------------------------------------------------------------- /experiments/graphgps/transform/transforms.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | from torch_geometric.utils import subgraph 5 | from tqdm import tqdm 6 | 7 | 8 | def pre_transform_in_memory(dataset, transform_func, show_progress=False): 9 | """Pre-transform already loaded PyG dataset object. 10 | 11 | Apply transform function to a loaded PyG dataset object so that 12 | the transformed result is persistent for the lifespan of the object. 
13 | This means the result is not saved to disk, as what PyG's `pre_transform` 14 | would do, but also the transform is applied only once and not at each 15 | data access as what PyG's `transform` hook does. 16 | 17 | Implementation is based on torch_geometric.data.in_memory_dataset.copy 18 | 19 | Args: 20 | dataset: PyG dataset object to modify 21 | transform_func: transformation function to apply to each data example 22 | show_progress: show tqdm progress bar 23 | """ 24 | if transform_func is None: 25 | return dataset 26 | 27 | data_list = [transform_func(dataset.get(i)) 28 | for i in tqdm(range(len(dataset)), 29 | disable=not show_progress, 30 | mininterval=10, 31 | miniters=len(dataset)//20)] 32 | data_list = list(filter(None, data_list)) 33 | 34 | dataset._indices = None 35 | dataset._data_list = data_list 36 | dataset.data, dataset.slices = dataset.collate(data_list) 37 | 38 | 39 | def typecast_x(data, type_str): 40 | if type_str == 'float': 41 | data.x = data.x.float() 42 | elif type_str == 'long': 43 | data.x = data.x.long() 44 | else: 45 | raise ValueError(f"Unexpected type '{type_str}'.") 46 | return data 47 | 48 | 49 | def concat_x_and_pos(data): 50 | data.x = torch.cat((data.x, data.pos), 1) 51 | return data 52 | 53 | 54 | def clip_graphs_to_size(data, size_limit=5000): 55 | if hasattr(data, 'num_nodes'): 56 | N = data.num_nodes # Explicitly given number of nodes, e.g. ogbg-ppa 57 | else: 58 | N = data.x.shape[0] # Number of nodes, including disconnected nodes. 
59 | if N <= size_limit: 60 | return data 61 | else: 62 | logging.info(f' ...clip to {size_limit} a graph of size: {N}') 63 | if hasattr(data, 'edge_attr'): 64 | edge_attr = data.edge_attr 65 | else: 66 | edge_attr = None 67 | edge_index, edge_attr = subgraph(list(range(size_limit)), 68 | data.edge_index, edge_attr) 69 | if hasattr(data, 'x'): 70 | data.x = data.x[:size_limit] 71 | data.num_nodes = size_limit 72 | else: 73 | data.num_nodes = size_limit 74 | if hasattr(data, 'node_is_attributed'): # for ogbg-code2 dataset 75 | data.node_is_attributed = data.node_is_attributed[:size_limit] 76 | data.node_dfs_order = data.node_dfs_order[:size_limit] 77 | data.node_depth = data.node_depth[:size_limit] 78 | data.edge_index = edge_index 79 | if hasattr(data, 'edge_attr'): 80 | data.edge_attr = edge_attr 81 | return data 82 | -------------------------------------------------------------------------------- /experiments/retrieve_grid_results.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from pathlib import Path 4 | 5 | import pandas as pd 6 | from torch_geometric.graphgym.cmd_args import parse_args 7 | from torch_geometric.graphgym.config import (cfg, set_cfg) 8 | 9 | from main import custom_set_out_dir, custom_load_cfg 10 | 11 | def load_json(path): 12 | with open(path, 'r') as f: 13 | return json.load(f) 14 | 15 | 16 | if __name__ == '__main__': 17 | # Load cmd line args 18 | args = parse_args() 19 | opts = args.opts 20 | repeat = args.repeat 21 | grid_config_path = Path(args.cfg_file) 22 | 23 | with open(grid_config_path, 'r') as f: 24 | grid_config = json.load(f) 25 | 26 | cfg_file = grid_config['base_config_path'] 27 | set_cfg(cfg) 28 | custom_load_cfg(cfg=cfg, cfg_file=cfg_file, opts=opts) 29 | experiment_name = grid_config_path.name.split('.')[0] 30 | grid_dir = Path(cfg.out_dir) / experiment_name 31 | 32 | metric = cfg.metric_best 33 | direction = cfg.metric_agg 34 | 35 | results_list = [] 36 | 
for param_path in os.listdir(grid_dir): 37 | if param_path.endswith('.yaml'): 38 | continue 39 | param_dir_path = grid_dir / param_path 40 | param_path = param_dir_path / 'params.json' 41 | test_path = param_dir_path / 'agg' / 'test' / 'best.json' 42 | valid_path = param_dir_path / 'agg' / 'val' / 'best.json' 43 | 44 | param_dict = load_json(param_path) 45 | test_dict = load_json(test_path) 46 | val_dict = load_json(valid_path) 47 | 48 | param_dict.update({ 49 | f'test_{metric}': round(test_dict[metric], 3), 50 | f'val_{metric}': round(val_dict[metric], 3) 51 | }) 52 | 53 | results_list.append(param_dict) 54 | 55 | df = pd.DataFrame(results_list) 56 | df = df.sort_values(by=f'val_{metric}', ascending=direction != 'argmax') 57 | df.to_csv('grid_result.csv', index=None) 58 | 59 | -------------------------------------------------------------------------------- /experiments/retrieve_results.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | 4 | from torch_geometric.graphgym.cmd_args import parse_args 5 | from torch_geometric.graphgym.config import (cfg, set_cfg) 6 | 7 | from main import custom_set_out_dir, custom_load_cfg 8 | 9 | if __name__ == '__main__': 10 | # Load cmd line args 11 | args = parse_args() 12 | # Load config file 13 | if args.cfg_file.endswith('.json'): 14 | grid_config_path = Path(args.cfg_file) 15 | with open(grid_config_path, 'r') as f: 16 | grid_config = json.load(f) 17 | cfg_file = grid_config['base_config_path'] 18 | set_cfg(cfg) 19 | custom_load_cfg(cfg=cfg, cfg_file=cfg_file, opts=args.opts) 20 | experiment_name = grid_config_path.name.split('.')[0] 21 | grid_dir = Path(cfg.out_dir) / experiment_name 22 | results_dir = grid_dir / 'final_run' / 'agg' / 'test' 23 | else: 24 | set_cfg(cfg) 25 | custom_load_cfg(cfg, args.cfg_file, args.opts) 26 | custom_set_out_dir(cfg, args.cfg_file, cfg.name_tag) 27 | results_dir = Path(cfg.out_dir) / 'agg' / 'test' 28 | 29 | 
metric = cfg.metric_best 30 | metric_std = f'{cfg.metric_best}_std' 31 | best_path = results_dir / 'best.json' 32 | last_path = results_dir / 'stats.json' 33 | with open(best_path, 'r') as fp: 34 | best = json.load(fp) 35 | 36 | with open(last_path, 'r') as fp: 37 | last = list(fp.readlines())[-1] 38 | last = json.loads(last) 39 | 40 | 41 | def reformat(m, s): 42 | return f'{round(m, 3):.3f} \u00B1 {round(s, 3):.3f}' 43 | 44 | 45 | best_str = reformat(best[metric], best[metric_std]) 46 | last_str = reformat(last[metric], last[metric_std]) 47 | print(f'{best_str}\t{last_str}') 48 | -------------------------------------------------------------------------------- /experiments/setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | from setuptools import find_packages 3 | import os 4 | 5 | # User-friendly description from README.md 6 | current_directory = os.path.dirname(os.path.abspath(__file__)) 7 | try: 8 | with open(os.path.join(current_directory, 'README.md'), encoding='utf-8') as f: 9 | long_description = f.read() 10 | except Exception: 11 | long_description = '' 12 | 13 | setup( 14 | # Name of the package 15 | name='chienn-experiments', 16 | # Packages to include into the distribution 17 | packages=find_packages('.'), 18 | # Start with a small number and increase it with 19 | # every change you make https://semver.org 20 | version='1.1.0', 21 | # Chose a license from here: https: // 22 | # help.github.com / articles / licensing - a - 23 | # repository. 
For example: MIT 24 | license='', 25 | # Short description of your library 26 | description='', 27 | # Long description of your library 28 | long_description=long_description, 29 | long_description_content_type='text/markdown', 30 | # Your name 31 | author='', 32 | # Your email 33 | author_email='', 34 | # Either the link to your github or to your website 35 | url='', 36 | # Link from which the project can be downloaded 37 | download_url='', 38 | # List of keywords 39 | keywords=[], 40 | # List of packages to install with this one 41 | install_requires=[], 42 | # https://pypi.org/classifiers/ 43 | classifiers=[] 44 | ) 45 | -------------------------------------------------------------------------------- /experiments/submodules/ChIRo/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 keiradams 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /experiments/submodules/ChIRo/model/gnn_3D/optimization_functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import math 5 | import torch_scatter 6 | 7 | def BCE_loss(y, y_hat): 8 | BCE = torch.nn.BCEWithLogitsLoss() 9 | return BCE(y_hat, y) 10 | 11 | def MSE(y, y_hat): 12 | MSE = torch.mean(torch.square(y - y_hat)) 13 | return MSE 14 | 15 | def tripletLoss(z_anchor, z_positive, z_negative, margin = 1.0, reduction = 'mean', distance_metric = 'euclidean'): 16 | if distance_metric == 'euclidean': 17 | criterion = torch.nn.TripletMarginWithDistanceLoss(distance_function=torch.nn.PairwiseDistance(p=2.0), 18 | margin=margin, 19 | swap=False, 20 | reduction=reduction) 21 | elif distance_metric == 'euclidean_normalized': 22 | criterion = torch.nn.TripletMarginWithDistanceLoss(distance_function=torch.nn.PairwiseDistance(p=2.0), 23 | margin=margin, 24 | swap=False, 25 | reduction=reduction) 26 | elif distance_metric == 'manhattan': 27 | criterion = torch.nn.TripletMarginWithDistanceLoss(distance_function=torch.nn.PairwiseDistance(p=1.0), 28 | margin=margin, 29 | swap=False, 30 | reduction=reduction) 31 | elif distance_metric == 'cosine': 32 | criterion = torch.nn.TripletMarginWithDistanceLoss(distance_function= lambda x, y: 1.0 - torch.nn.functional.cosine_similarity(x, y), 33 | margin=margin, 34 | swap=False, 35 | reduction=reduction) 36 | else: 37 | raise Exception(f'distance metric {distance_metric} is not implemented') 38 | 39 | if distance_metric == 'euclidean_normalized': 40 | z_anchor = z_anchor / torch.linalg.norm(z_anchor + 1e-10, dim=1, keepdim = True) 41 | z_positive = z_positive / torch.linalg.norm(z_positive + 1e-10, dim=1, keepdim = True) 42 | z_negative = z_negative / torch.linalg.norm(z_negative + 1e-10, dim=1, keepdim = True) 43 | 44 | loss = criterion(z_anchor, 
z_positive, z_negative) 45 | return loss 46 | 47 | 48 | 49 | -------------------------------------------------------------------------------- /experiments/submodules/ChIRo/model/optimization_functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import math 5 | import torch_scatter 6 | 7 | def BCE_loss(y, y_hat): 8 | BCE = torch.nn.BCEWithLogitsLoss() 9 | return BCE(y_hat, y) 10 | 11 | def MSE(y, y_hat): 12 | MSE = torch.mean(torch.square(y - y_hat)) 13 | return MSE 14 | 15 | def tripletLoss(z_anchor, z_positive, z_negative, margin = 1.0, reduction = 'mean', distance_metric = 'euclidean'): 16 | if distance_metric == 'euclidean': 17 | criterion = torch.nn.TripletMarginWithDistanceLoss(distance_function=torch.nn.PairwiseDistance(p=2.0), 18 | margin=margin, 19 | swap=False, 20 | reduction=reduction) 21 | elif distance_metric == 'euclidean_normalized': 22 | criterion = torch.nn.TripletMarginWithDistanceLoss(distance_function=torch.nn.PairwiseDistance(p=2.0), 23 | margin=margin, 24 | swap=False, 25 | reduction=reduction) 26 | elif distance_metric == 'manhattan': 27 | criterion = torch.nn.TripletMarginWithDistanceLoss(distance_function=torch.nn.PairwiseDistance(p=1.0), 28 | margin=margin, 29 | swap=False, 30 | reduction=reduction) 31 | elif distance_metric == 'cosine': 32 | criterion = torch.nn.TripletMarginWithDistanceLoss(distance_function= lambda x, y: 1.0 - torch.nn.functional.cosine_similarity(x, y), 33 | margin=margin, 34 | swap=False, 35 | reduction=reduction) 36 | else: 37 | raise Exception(f'distance metric {distance_metric} is not implemented') 38 | 39 | if distance_metric == 'euclidean_normalized': 40 | z_anchor = z_anchor / torch.linalg.norm(z_anchor + 1e-10, dim=1, keepdim = True) 41 | z_positive = z_positive / torch.linalg.norm(z_positive + 1e-10, dim=1, keepdim = True) 42 | z_negative = z_negative / torch.linalg.norm(z_negative + 1e-10, dim=1, 
keepdim = True) 43 | 44 | loss = criterion(z_anchor, z_positive, z_negative) 45 | return loss 46 | -------------------------------------------------------------------------------- /experiments/submodules/ChIRo/model/params_interpreter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch_geometric 4 | import numpy as np 5 | 6 | string_to_object = { 7 | "torch.nn.LeakyReLU(negative_slope=0.01)": torch.nn.LeakyReLU(negative_slope=0.01), 8 | "torch.nn.LeakyReLU()": torch.nn.LeakyReLU(), 9 | "torch.nn.Identity()": torch.nn.Identity(), 10 | "torch.nn.ReLU()": torch.nn.ReLU(), 11 | "torch.nn.Sigmoid()": torch.nn.Sigmoid(), 12 | "torch.nn.Tanh()": torch.nn.Tanh() 13 | } 14 | -------------------------------------------------------------------------------- /experiments/submodules/ChIRo/params_files/params_LD_ChIRo.json: -------------------------------------------------------------------------------- 1 | { 2 | "layers_dict": 3 | { 4 | "EConv_mlp_hidden_sizes": [64], 5 | "GAT_hidden_node_sizes": [32], 6 | 7 | "encoder_hidden_sizes_D": [32, 32], 8 | "encoder_hidden_sizes_phi": [32, 32], 9 | "encoder_hidden_sizes_c": [32, 32], 10 | "encoder_hidden_sizes_alpha": [32, 32], 11 | 12 | "encoder_hidden_sizes_sinusoidal_shift": [256, 256, 256], 13 | "output_mlp_hidden_sizes": [64, 64] 14 | }, 15 | 16 | 17 | "activation_dict": 18 | { 19 | "encoder_hidden_activation_D": "torch.nn.LeakyReLU(negative_slope=0.01)", 20 | "encoder_hidden_activation_phi": "torch.nn.LeakyReLU(negative_slope=0.01)", 21 | "encoder_hidden_activation_c": "torch.nn.LeakyReLU(negative_slope=0.01)", 22 | "encoder_hidden_activation_alpha": "torch.nn.LeakyReLU(negative_slope=0.01)", 23 | "encoder_hidden_activation_sinusoidal_shift": "torch.nn.LeakyReLU(negative_slope=0.01)", 24 | 25 | "encoder_output_activation_D": "torch.nn.Identity()", 26 | "encoder_output_activation_phi": "torch.nn.Identity()", 27 | 
"encoder_output_activation_c": "torch.nn.Identity()", 28 | "encoder_output_activation_alpha": "torch.nn.Identity()", 29 | "encoder_output_activation_sinusoidal_shift": "torch.nn.Identity()", 30 | 31 | "EConv_mlp_hidden_activation": "torch.nn.LeakyReLU(negative_slope=0.01)", 32 | "EConv_mlp_output_activation": "torch.nn.Identity()", 33 | 34 | "output_mlp_hidden_activation": "torch.nn.LeakyReLU(negative_slope=0.01)", 35 | "output_mlp_output_activation": "torch.nn.Identity()" 36 | }, 37 | 38 | "pretrained": "", 39 | "random_seed": 1, 40 | 41 | "F_z_list": [8, 8, 8], 42 | "F_H": 32, 43 | "F_H_EConv": 32, 44 | "GAT_N_heads": 2, 45 | "EConv_bias": true, 46 | "GAT_bias": true, 47 | "encoder_biases": true, 48 | "dropout": 0.0, 49 | 50 | "chiral_message_passing": false, 51 | "CMP_EConv_MLP_hidden_sizes": [256, 256, 256], 52 | "CMP_GAT_N_layers": 3, 53 | "CMP_GAT_N_heads": 2, 54 | 55 | "c_coefficient_mode": "learned", 56 | "c_coefficient_normalization": "sigmoid", 57 | "phase_shift_coefficient_mode": "learned", 58 | "auxillary_torsion_loss": 0.0018603774073415525, 59 | 60 | "encoder_reduction": "sum", 61 | 62 | "output_concatenation_mode": "both", 63 | 64 | "default_lr": 0.00012821924940469874, 65 | 66 | "num_workers": 8, 67 | "batch_size": 16, 68 | "N_epochs": 100, 69 | 70 | "CV_fold": 1, 71 | "train_datafile": "final_data_splits/chloroform_CV_dataset_optical_rotation_5_rdkit_MOL_ee_95_MW_564_150380.pkl", 72 | "validation_datafile": "final_data_splits/chloroform_CV_dataset_optical_rotation_5_rdkit_MOL_ee_95_MW_564_150380.pkl", 73 | "test_datafile": "final_data_splits/chloroform_CV_dataset_optical_rotation_5_rdkit_MOL_ee_95_MW_564_150380.pkl", 74 | 75 | "iteration_mode": "stereoisomers", 76 | "sample_1conformer": false, 77 | "select_N_enantiomers": null, 78 | 79 | "mask_coordinates": false, 80 | "stereoMask": true, 81 | 82 | "grouping": "none", 83 | "weighted_sum": true, 84 | 85 | "stratified": false, 86 | "withoutReplacement": true, 87 | 88 | "loss_function": "BCE", 89 | 
"absolute_penalty": null, 90 | "relative_penalty": null, 91 | "ranking_margin": null, 92 | 93 | "contrastive_vector": "none", 94 | "margin": null, 95 | 96 | "N_neg": 1, 97 | "N_pos": 0, 98 | 99 | "save": true 100 | } 101 | -------------------------------------------------------------------------------- /experiments/submodules/ChIRo/params_files/params_LD_spherenet.json: -------------------------------------------------------------------------------- 1 | { 2 | "pretrained": "", 3 | "random_seed": 1, 4 | 5 | "hidden_channels": 64, 6 | "out_channels": 64, 7 | "cutoff": 5.0, 8 | "num_layers": 4, 9 | "int_emb_size": 128, 10 | "basis_emb_size_dist": 8, 11 | "basis_emb_size_angle": 8, 12 | "basis_emb_size_torsion": 8, 13 | "out_emb_channels": 64, 14 | "num_spherical": 7, 15 | "num_radial": 6, 16 | "envelope_exponent": 5, 17 | "num_before_skip": 1, 18 | "num_after_skip": 2, 19 | "num_output_layers": 3, 20 | "MLP_hidden_sizes": [64, 64], 21 | 22 | "lr": 0.00047930126522077043, 23 | 24 | "num_workers": 8, 25 | "batch_size": 16, 26 | "N_epochs": 100, 27 | 28 | "CV_fold": 1, 29 | "train_datafile": "final_data_splits/chloroform_CV_dataset_optical_rotation_5_rdkit_MOL_ee_95_MW_564_150380.pkl", 30 | "validation_datafile": "final_data_splits/chloroform_CV_dataset_optical_rotation_5_rdkit_MOL_ee_95_MW_564_150380.pkl", 31 | "test_datafile": "final_data_splits/chloroform_CV_dataset_optical_rotation_5_rdkit_MOL_ee_95_MW_564_150380.pkl", 32 | 33 | "iteration_mode": "stereoisomers", 34 | "sample_1conformer": false, 35 | "select_N_enantiomers": null, 36 | 37 | "grouping": "none", 38 | "weighted_sum": true, 39 | 40 | "stratified": false, 41 | "withoutReplacement": true, 42 | 43 | "loss_function": "BCE", 44 | "absolute_penalty": null, 45 | "relative_penalty": null, 46 | "ranking_margin": null, 47 | 48 | "margin": null, 49 | 50 | "N_neg": 1, 51 | "N_pos": 0, 52 | 53 | "save": true 54 | } 55 | -------------------------------------------------------------------------------- 
/experiments/submodules/ChIRo/params_files/params_RS_ChIRo.json: -------------------------------------------------------------------------------- 1 | { 2 | "layers_dict": 3 | { 4 | "EConv_mlp_hidden_sizes": [64], 5 | "GAT_hidden_node_sizes": [32, 32], 6 | 7 | "encoder_hidden_sizes_D": [128, 128], 8 | "encoder_hidden_sizes_phi": [128, 128], 9 | "encoder_hidden_sizes_c": [128, 128], 10 | "encoder_hidden_sizes_alpha": [128, 128], 11 | 12 | "encoder_hidden_sizes_sinusoidal_shift": [256, 256], 13 | "output_mlp_hidden_sizes": [64, 64] 14 | }, 15 | 16 | 17 | "activation_dict": 18 | { 19 | "encoder_hidden_activation_D": "torch.nn.LeakyReLU(negative_slope=0.01)", 20 | "encoder_hidden_activation_phi": "torch.nn.LeakyReLU(negative_slope=0.01)", 21 | "encoder_hidden_activation_c": "torch.nn.LeakyReLU(negative_slope=0.01)", 22 | "encoder_hidden_activation_alpha": "torch.nn.LeakyReLU(negative_slope=0.01)", 23 | "encoder_hidden_activation_sinusoidal_shift": "torch.nn.LeakyReLU(negative_slope=0.01)", 24 | 25 | "encoder_output_activation_D": "torch.nn.Identity()", 26 | "encoder_output_activation_phi": "torch.nn.Identity()", 27 | "encoder_output_activation_c": "torch.nn.Identity()", 28 | "encoder_output_activation_alpha": "torch.nn.Identity()", 29 | "encoder_output_activation_sinusoidal_shift": "torch.nn.Identity()", 30 | 31 | "EConv_mlp_hidden_activation": "torch.nn.LeakyReLU(negative_slope=0.01)", 32 | "EConv_mlp_output_activation": "torch.nn.Identity()", 33 | 34 | "output_mlp_hidden_activation": "torch.nn.LeakyReLU(negative_slope=0.01)", 35 | "output_mlp_output_activation": "torch.nn.Identity()" 36 | }, 37 | 38 | "pretrained": "", 39 | "random_seed": 1, 40 | 41 | "F_z_list": [64, 64, 64], 42 | "F_H": 64, 43 | "F_H_EConv": 64, 44 | "GAT_N_heads": 4, 45 | "EConv_bias": true, 46 | "GAT_bias": true, 47 | "encoder_biases": true, 48 | "dropout": 0.0, 49 | 50 | "chiral_message_passing": false, 51 | "CMP_EConv_MLP_hidden_sizes": [32], 52 | "CMP_GAT_N_layers": 3, 53 | "CMP_GAT_N_heads": 
2, 54 | 55 | "c_coefficient_mode": "learned", 56 | "c_coefficient_normalization": "sigmoid", 57 | "phase_shift_coefficient_mode": "learned", 58 | "auxillary_torsion_loss": 0.0068641705106320325, 59 | 60 | "encoder_reduction": "sum", 61 | 62 | "output_concatenation_mode": "both", 63 | 64 | "default_lr": 0.0005694249946353567, 65 | 66 | "num_workers": 8, 67 | "batch_size": 16, 68 | "N_epochs": 100, 69 | 70 | "train_datafile": "final_data_splits/train_RS_classification_enantiomers_MOL_326865_55084_27542.pkl", 71 | "validation_datafile": "final_data_splits/validation_RS_classification_enantiomers_MOL_70099_11748_5874.pkl", 72 | "test_datafile": "final_data_splits/test_RS_classification_enantiomers_MOL_69719_11680_5840.pkl", 73 | 74 | "iteration_mode": "stereoisomers", 75 | "sample_1conformer": false, 76 | "select_N_enantiomers": null, 77 | 78 | "mask_coordinates": false, 79 | "stereoMask": true, 80 | 81 | "grouping": "none", 82 | "weighted_sum": true, 83 | 84 | "stratified": false, 85 | "withoutReplacement": true, 86 | 87 | "loss_function": "BCE", 88 | "absolute_penalty": null, 89 | "relative_penalty": null, 90 | "ranking_margin": null, 91 | 92 | "contrastive_vector": "none", 93 | "margin": null, 94 | 95 | "N_neg": 1, 96 | "N_pos": 0, 97 | 98 | "save": true 99 | } 100 | -------------------------------------------------------------------------------- /experiments/submodules/ChIRo/params_files/params_RS_dimenetpp.json: -------------------------------------------------------------------------------- 1 | { 2 | "pretrained": "", 3 | "random_seed": 1, 4 | 5 | "hidden_channels": 128, 6 | "out_channels": 32, 7 | "num_blocks": 4, 8 | "int_emb_size": 64, 9 | "basis_emb_size": 8, 10 | "out_emb_channels": 256, 11 | "num_spherical": 7, 12 | "num_radial": 6, 13 | "cutoff": 5.0, 14 | "envelope_exponent": 5, 15 | "num_before_skip": 1, 16 | "num_after_skip": 2, 17 | "num_output_layers": 3, 18 | "MLP_hidden_sizes": [64, 64], 19 | 20 | "lr": 0.0001, 21 | 22 | "num_workers": 8, 23 | 
"batch_size": 32, 24 | "N_epochs": 100, 25 | 26 | "train_datafile": "final_data_splits/train_RS_classification_enantiomers_MOL_326865_55084_27542.pkl", 27 | "validation_datafile": "final_data_splits/validation_RS_classification_enantiomers_MOL_70099_11748_5874.pkl", 28 | "test_datafile": "final_data_splits/test_RS_classification_enantiomers_MOL_69719_11680_5840.pkl", 29 | 30 | "iteration_mode": "stereoisomers", 31 | "sample_1conformer": false, 32 | "select_N_enantiomers": null, 33 | 34 | "grouping": "none", 35 | "weighted_sum": true, 36 | 37 | "stratified": false, 38 | "withoutReplacement": true, 39 | 40 | "loss_function": "BCE", 41 | "absolute_penalty": null, 42 | "relative_penalty": null, 43 | "ranking_margin": null, 44 | 45 | "margin": null, 46 | 47 | "N_neg": 1, 48 | "N_pos": 0, 49 | 50 | "save": true 51 | } 52 | -------------------------------------------------------------------------------- /experiments/submodules/ChIRo/params_files/params_RS_schnet.json: -------------------------------------------------------------------------------- 1 | { 2 | "pretrained": "", 3 | "random_seed": 1, 4 | 5 | "hidden_channels": 128, 6 | "num_filters": 128, 7 | "num_interactions": 6, 8 | "num_gaussians": 50, 9 | "cutoff": 10.0, 10 | "max_num_neighbors": 32, 11 | "out_channels": 32, 12 | "MLP_hidden_sizes": [64, 64], 13 | 14 | "lr": 0.0001, 15 | 16 | "num_workers": 8, 17 | "batch_size": 32, 18 | "N_epochs": 100, 19 | 20 | "train_datafile": "final_data_splits/train_RS_classification_enantiomers_MOL_326865_55084_27542.pkl", 21 | "validation_datafile": "final_data_splits/validation_RS_classification_enantiomers_MOL_70099_11748_5874.pkl", 22 | "test_datafile": "final_data_splits/test_RS_classification_enantiomers_MOL_69719_11680_5840.pkl", 23 | 24 | "iteration_mode": "stereoisomers", 25 | "sample_1conformer": false, 26 | "select_N_enantiomers": null, 27 | 28 | "grouping": "none", 29 | "weighted_sum": true, 30 | 31 | "stratified": false, 32 | "withoutReplacement": true, 33 | 34 | 
"loss_function": "BCE", 35 | "absolute_penalty": null, 36 | "relative_penalty": null, 37 | "ranking_margin": null, 38 | 39 | "margin": null, 40 | 41 | "N_neg": 1, 42 | "N_pos": 0, 43 | 44 | "save": true 45 | } 46 | -------------------------------------------------------------------------------- /experiments/submodules/ChIRo/params_files/params_RS_spherenet.json: -------------------------------------------------------------------------------- 1 | { 2 | "pretrained": "", 3 | "random_seed": 1, 4 | 5 | "hidden_channels": 256, 6 | "out_channels": 64, 7 | "cutoff": 5.0, 8 | "num_layers": 4, 9 | "int_emb_size": 32, 10 | "basis_emb_size_dist": 8, 11 | "basis_emb_size_angle": 8, 12 | "basis_emb_size_torsion": 8, 13 | "out_emb_channels": 128, 14 | "num_spherical": 7, 15 | "num_radial": 6, 16 | "envelope_exponent": 5, 17 | "num_before_skip": 1, 18 | "num_after_skip": 2, 19 | "num_output_layers": 3, 20 | "MLP_hidden_sizes": [256, 256, 256, 256], 21 | 22 | "lr": 0.00015420909786265446, 23 | 24 | "num_workers": 8, 25 | "batch_size": 64, 26 | "N_epochs": 100, 27 | 28 | "train_datafile": "final_data_splits/train_RS_classification_enantiomers_MOL_326865_55084_27542.pkl", 29 | "validation_datafile": "final_data_splits/validation_RS_classification_enantiomers_MOL_70099_11748_5874.pkl", 30 | "test_datafile": "final_data_splits/test_RS_classification_enantiomers_MOL_69719_11680_5840.pkl", 31 | 32 | "iteration_mode": "stereoisomers", 33 | "sample_1conformer": false, 34 | "select_N_enantiomers": null, 35 | 36 | "grouping": "none", 37 | "weighted_sum": true, 38 | 39 | "stratified": false, 40 | "withoutReplacement": true, 41 | 42 | "loss_function": "BCE", 43 | "absolute_penalty": null, 44 | "relative_penalty": null, 45 | "ranking_margin": null, 46 | 47 | "margin": null, 48 | 49 | "N_neg": 1, 50 | "N_pos": 0, 51 | 52 | "save": true 53 | } 54 | -------------------------------------------------------------------------------- 
/experiments/submodules/ChIRo/params_files/params_binary_ranking_spherenet.json: -------------------------------------------------------------------------------- 1 | { 2 | "pretrained": "", 3 | "random_seed": 1, 4 | 5 | "hidden_channels": 256, 6 | "out_channels": 32, 7 | "cutoff": 5.0, 8 | "num_layers": 5, 9 | "int_emb_size": 64, 10 | "basis_emb_size_dist": 8, 11 | "basis_emb_size_angle": 8, 12 | "basis_emb_size_torsion": 8, 13 | "out_emb_channels": 32, 14 | "num_spherical": 7, 15 | "num_radial": 6, 16 | "envelope_exponent": 5, 17 | "num_before_skip": 1, 18 | "num_after_skip": 2, 19 | "num_output_layers": 3, 20 | "MLP_hidden_sizes": [64, 64], 21 | 22 | "lr": 0.00014005162999131585, 23 | 24 | "num_workers": 8, 25 | "batch_size": 32, 26 | "N_epochs": 150, 27 | 28 | "train_datafile": "final_data_splits/train_small_enantiomers_stable_full_screen_docking_MOL_margin3_234622_48384_24192.pkl", 29 | "validation_datafile": "final_data_splits/validation_small_enantiomers_stable_full_screen_docking_MOL_margin3_49878_10368_5184.pkl", 30 | "test_datafile": "final_data_splits/test_small_enantiomers_stable_full_screen_docking_MOL_margin3_50571_10368_5184.pkl", 31 | 32 | "iteration_mode": "stereoisomers", 33 | "sample_1conformer": false, 34 | "select_N_enantiomers": null, 35 | 36 | "grouping": "none", 37 | "weighted_sum": true, 38 | 39 | "stratified": false, 40 | "withoutReplacement": true, 41 | 42 | "loss_function": "MSE_MarginRankingLoss", 43 | "absolute_penalty": 1.0, 44 | "relative_penalty": 0.0, 45 | "ranking_margin": 0.3, 46 | 47 | "margin": null, 48 | 49 | "N_neg": 1, 50 | "N_pos": 0, 51 | 52 | "save": true 53 | } 54 | -------------------------------------------------------------------------------- /experiments/submodules/ChIRo/params_files/params_contrastive_ChIRo.json: -------------------------------------------------------------------------------- 1 | { 2 | "layers_dict": 3 | { 4 | "EConv_mlp_hidden_sizes": [32, 32], 5 | "GAT_hidden_node_sizes": [64], 6 | 7 | 
"encoder_hidden_sizes_D": [64, 64], 8 | "encoder_hidden_sizes_phi": [64, 64], 9 | "encoder_hidden_sizes_c": [64, 64], 10 | "encoder_hidden_sizes_alpha": [64, 64], 11 | 12 | "encoder_hidden_sizes_sinusoidal_shift": [256, 256], 13 | "output_mlp_hidden_sizes": [] 14 | }, 15 | 16 | 17 | "activation_dict": 18 | { 19 | "encoder_hidden_activation_D": "torch.nn.LeakyReLU(negative_slope=0.01)", 20 | "encoder_hidden_activation_phi": "torch.nn.LeakyReLU(negative_slope=0.01)", 21 | "encoder_hidden_activation_c": "torch.nn.LeakyReLU(negative_slope=0.01)", 22 | "encoder_hidden_activation_alpha": "torch.nn.LeakyReLU(negative_slope=0.01)", 23 | "encoder_hidden_activation_sinusoidal_shift": "torch.nn.LeakyReLU(negative_slope=0.01)", 24 | 25 | "encoder_output_activation_D": "torch.nn.Identity()", 26 | "encoder_output_activation_phi": "torch.nn.Identity()", 27 | "encoder_output_activation_c": "torch.nn.Identity()", 28 | "encoder_output_activation_alpha": "torch.nn.Identity()", 29 | "encoder_output_activation_sinusoidal_shift": "torch.nn.Identity()", 30 | 31 | "EConv_mlp_hidden_activation": "torch.nn.LeakyReLU(negative_slope=0.01)", 32 | "EConv_mlp_output_activation": "torch.nn.Identity()", 33 | 34 | "output_mlp_hidden_activation": "torch.nn.LeakyReLU(negative_slope=0.01)", 35 | "output_mlp_output_activation": "torch.nn.Identity()" 36 | }, 37 | 38 | "pretrained": "", 39 | "random_seed": 1, 40 | 41 | "F_z_list": [2, 2, 2], 42 | "F_H": 64, 43 | "F_H_EConv": 64, 44 | "GAT_N_heads": 4, 45 | "EConv_bias": true, 46 | "GAT_bias": true, 47 | "encoder_biases": true, 48 | "dropout": 0.0, 49 | 50 | "chiral_message_passing": true, 51 | "CMP_EConv_MLP_hidden_sizes": [256, 256], 52 | "CMP_GAT_N_layers": 3, 53 | "CMP_GAT_N_heads": 2, 54 | 55 | "c_coefficient_mode": "learned", 56 | "c_coefficient_normalization": "sigmoid", 57 | "phase_shift_coefficient_mode": "learned", 58 | "auxillary_torsion_loss": 0.0008249542971659538, 59 | 60 | "encoder_reduction": "sum", 61 | 62 | "output_concatenation_mode": 
"contrastive", 63 | 64 | "default_lr": 0.0006059244630573096, 65 | 66 | "num_workers": 8, 67 | "batch_size": 32, 68 | "N_epochs": 50, 69 | 70 | "train_datafile": "final_data_splits/train_contrastive_MOL_2088008_418922_180426.pkl", 71 | "validation_datafile": "final_data_splits/validation_contrastive_MOL_450726_89786_38658.pkl", 72 | "test_datafile": "", 73 | 74 | "iteration_mode": "stereoisomers", 75 | "sample_1conformer": false, 76 | "select_N_enantiomers": null, 77 | 78 | "mask_coordinates": false, 79 | "stereoMask": true, 80 | 81 | "grouping": "none", 82 | "weighted_sum": true, 83 | 84 | "stratified": false, 85 | "withoutReplacement": true, 86 | 87 | "loss_function": "euclidean-normalized", 88 | "absolute_penalty": null, 89 | "relative_penalty": null, 90 | "ranking_margin": null, 91 | 92 | "contrastive_vector": "z_alpha", 93 | "margin": 1.0, 94 | 95 | "N_neg": 1, 96 | "N_pos": 1, 97 | 98 | "save": true 99 | 100 | } 101 | -------------------------------------------------------------------------------- /experiments/submodules/ChIRo/params_files/params_contrastive_dimenetpp.json: -------------------------------------------------------------------------------- 1 | { 2 | "pretrained": "", 3 | "random_seed": 1, 4 | 5 | "hidden_channels": 128, 6 | "out_channels": 2, 7 | "num_blocks": 4, 8 | "int_emb_size": 64, 9 | "basis_emb_size": 8, 10 | "out_emb_channels": 256, 11 | "num_spherical": 7, 12 | "num_radial": 6, 13 | "cutoff": 5.0, 14 | "envelope_exponent": 5, 15 | "num_before_skip": 1, 16 | "num_after_skip": 2, 17 | "num_output_layers": 3, 18 | "MLP_hidden_sizes": [], 19 | 20 | "lr": 0.0001, 21 | 22 | "num_workers": 8, 23 | "batch_size": 32, 24 | "N_epochs": 50, 25 | 26 | "train_datafile": "final_data_splits/train_contrastive_MOL_2088008_418922_180426.pkl", 27 | "validation_datafile": "final_data_splits/validation_contrastive_MOL_450726_89786_38658.pkl", 28 | "test_datafile": "", 29 | 30 | "iteration_mode": "stereoisomers", 31 | "sample_1conformer": false, 32 | 
"select_N_enantiomers": null, 33 | 34 | "grouping": "none", 35 | "weighted_sum": false, 36 | 37 | "stratified": false, 38 | "withoutReplacement": true, 39 | 40 | "loss_function": "euclidean-normalized", 41 | "absolute_penalty": null, 42 | "relative_penalty": null, 43 | "ranking_margin": null, 44 | 45 | "margin": 1.0, 46 | 47 | "N_neg": 1, 48 | "N_pos": 1, 49 | 50 | "save": true 51 | 52 | } 53 | -------------------------------------------------------------------------------- /experiments/submodules/ChIRo/params_files/params_contrastive_schnet.json: -------------------------------------------------------------------------------- 1 | { 2 | "pretrained": "", 3 | "random_seed": 1, 4 | 5 | "hidden_channels": 128, 6 | "num_filters": 128, 7 | "num_interactions": 6, 8 | "num_gaussians": 50, 9 | "cutoff": 10.0, 10 | "max_num_neighbors": 32, 11 | "out_channels": 2, 12 | "MLP_hidden_sizes": [], 13 | 14 | "lr": 0.0001, 15 | 16 | "num_workers": 8, 17 | "batch_size": 32, 18 | "N_epochs": 50, 19 | 20 | "train_datafile": "final_data_splits/train_contrastive_MOL_2088008_418922_180426.pkl", 21 | "validation_datafile": "final_data_splits/validation_contrastive_MOL_450726_89786_38658.pkl", 22 | "test_datafile": "", 23 | 24 | "iteration_mode": "stereoisomers", 25 | "sample_1conformer": false, 26 | "select_N_enantiomers": null, 27 | 28 | "grouping": "none", 29 | "weighted_sum": false, 30 | 31 | "stratified": false, 32 | "withoutReplacement": true, 33 | 34 | "loss_function": "euclidean-normalized", 35 | "absolute_penalty": null, 36 | "relative_penalty": null, 37 | "ranking_margin": null, 38 | 39 | "margin": 1.0, 40 | 41 | "N_neg": 1, 42 | "N_pos": 1, 43 | 44 | "save": true 45 | } 46 | -------------------------------------------------------------------------------- /experiments/submodules/ChIRo/params_files/params_contrastive_spherenet.json: -------------------------------------------------------------------------------- 1 | { 2 | "pretrained": "", 3 | "random_seed": 1, 4 | 5 | 
"hidden_channels": 128, 6 | "out_channels": 2, 7 | "cutoff": 5.0, 8 | "num_layers": 4, 9 | "int_emb_size": 64, 10 | "basis_emb_size_dist": 8, 11 | "basis_emb_size_angle": 8, 12 | "basis_emb_size_torsion": 8, 13 | "out_emb_channels": 256, 14 | "num_spherical": 7, 15 | "num_radial": 6, 16 | "envelope_exponent": 5, 17 | "num_before_skip": 1, 18 | "num_after_skip": 2, 19 | "num_output_layers": 3, 20 | "MLP_hidden_sizes": [], 21 | 22 | "lr": 0.0001, 23 | 24 | "num_workers": 8, 25 | "batch_size": 32, 26 | "N_epochs": 50, 27 | 28 | "train_datafile": "final_data_splits/train_contrastive_MOL_2088008_418922_180426.pkl", 29 | "validation_datafile": "final_data_splits/validation_contrastive_MOL_450726_89786_38658.pkl", 30 | "test_datafile": "", 31 | 32 | "iteration_mode": "stereoisomers", 33 | "sample_1conformer": false, 34 | "select_N_enantiomers": null, 35 | 36 | "grouping": "none", 37 | "weighted_sum": false, 38 | 39 | "stratified": false, 40 | "withoutReplacement": true, 41 | 42 | "loss_function": "euclidean-normalized", 43 | "absolute_penalty": null, 44 | "relative_penalty": null, 45 | "ranking_margin": null, 46 | 47 | "margin": 1.0, 48 | 49 | "N_neg": 1, 50 | "N_pos": 1, 51 | 52 | "save": true 53 | } 54 | -------------------------------------------------------------------------------- /experiments/submodules/tetra_dmpnn/Makefile: -------------------------------------------------------------------------------- 1 | ################################################################################ 2 | # 3 | # Makefile for chiral_gnn 4 | # 5 | ################################################################################ 6 | 7 | conda_env: 8 | bash devtools/create_env.sh 9 | -------------------------------------------------------------------------------- /experiments/submodules/tetra_dmpnn/README.md: -------------------------------------------------------------------------------- 1 | # Chirality-aware message passing networks 2 | Custom aggregation functions for molecules 
with tetrahedral chirality ([arXiv](https://arxiv.org/abs/2012.00094)) 3 | 4 | ## Requirements 5 | * python (version>=3.7) 6 | * pytorch (version>=1.14) 7 | * rdkit (version>=2020.03.2) 8 | * pytorch-geometric (version>=1.6.0) 9 | 10 | ## Installation 11 | First, clone the repository: 12 | `git clone https://github.com/PattanaikL/chiral_gnn` 13 | 14 | Run `make conda_env` to create the conda environment. 15 | The script will request the user to enter one of the supported CUDA versions listed here: https://pytorch.org/get-started/locally/. 16 | The script uses this CUDA version to install PyTorch and PyTorch Geometric. Alternatively, the user could manually follow the steps to install PyTorch Geometric here: https://github.com/rusty1s/pytorch_geometric/blob/master/.travis.yml. 17 | 18 | ## Usage 19 | For the toy classification task, call the `train.py` script with the following parameters defined: 20 | 21 | `python train.py --data_path data/d4_docking/d4_docking_rs.csv --split_path data/d4_docking/rs/split0.npy --task classification --log_dir ./test_run --gnn_type dmpnn --message tetra_permute_concat` 22 | 23 | 24 | To train the model with the best-performing parameters, call the `train.py` script with the following parameters defined: 25 | 26 | `python train.py --data_path data/d4_docking/d4_docking.csv --split_path data/d4_docking/full/split0.npy --log_dir ./test_run --gnn_type dmpnn --message tetra_permute_concat --global_chiral_features --chiral_features` 27 | -------------------------------------------------------------------------------- /experiments/submodules/tetra_dmpnn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gmum/ChiENN/ee3185b39e8469a8caacf3d6d45a04c4a1cfff5b/experiments/submodules/tetra_dmpnn/__init__.py -------------------------------------------------------------------------------- /experiments/submodules/tetra_dmpnn/devtools/create_env.sh: 
-------------------------------------------------------------------------------- 1 | # Developed by Kevin A. Spiekermann 2 | # This script does the following tasks: 3 | # - creates the conda 4 | # - prompts user for desired CUDA version 5 | # - installs PyTorch with specified CUDA version in the environment 6 | # - installs torch torch-geometric in the environment 7 | 8 | 9 | # get OS type 10 | unameOut="$(uname -s)" 11 | case "${unameOut}" in 12 | Linux*) machine=Linux;; 13 | Darwin*) machine=MacOS;; 14 | CYGWIN*) machine=Cygwin;; 15 | MINGW*) machine=MinGw;; 16 | *) machine="UNKNOWN:${unameOut}" 17 | esac 18 | echo "Running ${machine}..." 19 | 20 | 21 | # request user to select one of the supported CUDA versions 22 | # source: https://pytorch.org/get-started/locally/ 23 | PS3='Please enter 1, 2, 3, or 4 to specify the desired CUDA version from the options above: ' 24 | options=("9.2" "10.1" "10.2" "cpu" "Quit") 25 | select opt in "${options[@]}" 26 | do 27 | case $opt in 28 | "9.2") 29 | CUDA="cudatoolkit=9.2" 30 | CUDA_VERSION="cu92" 31 | break 32 | ;; 33 | "10.1") 34 | CUDA="cudatoolkit=10.1" 35 | CUDA_VERSION="cu101" 36 | break 37 | ;; 38 | "10.2") 39 | CUDA="cudatoolkit=10.2" 40 | CUDA_VERSION="cu102" 41 | break 42 | ;; 43 | "cpu") 44 | # "cpuonly" works for Linux and Windows 45 | CUDA="cpuonly" 46 | # Mac does not use "cpuonly" 47 | if [ $machine == "Mac" ] 48 | then 49 | CUDA=" " 50 | fi 51 | CUDA_VERSION="cpu" 52 | break 53 | ;; 54 | "Quit") 55 | exit 56 | ;; 57 | *) echo "invalid option $REPLY";; 58 | esac 59 | done 60 | 61 | echo "Creating conda environment..." 62 | echo "Running: conda env create -f environment.yml" 63 | conda env create -f environment.yml 64 | 65 | # activate the environment to install torch-geometric 66 | source activate chiral_gnn 67 | 68 | echo "Installing PyTorch with requested CUDA version..." 
69 | echo "Running: conda install pytorch torchvision $CUDA -c pytorch" 70 | conda install pytorch torchvision $CUDA -c pytorch 71 | 72 | echo "Installing torch-geometric..." 73 | echo "Using CUDA version: $CUDA_VERSION" 74 | # get PyTorch version 75 | TORCH_VERSION=$(python -c "import torch; print(torch.__version__)") 76 | echo "Using PyTorch version: $TORCH_VERSION" 77 | 78 | pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-${TORCH_VERSION}+${CUDA_VERSION}.html 79 | pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-${TORCH_VERSION}+${CUDA_VERSION}.html 80 | pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-${TORCH_VERSION}+${CUDA_VERSION}.html 81 | pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-${TORCH_VERSION}+${CUDA_VERSION}.html 82 | pip install torch-geometric 83 | -------------------------------------------------------------------------------- /experiments/submodules/tetra_dmpnn/environment.yml: -------------------------------------------------------------------------------- 1 | name: chiral_gnn 2 | channels: 3 | - rdkit 4 | - pytorch 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=main 9 | - alembic=1.4.3=pyh9f0ad1d_0 10 | - attrs=20.3.0=pyhd3deb0d_0 11 | - backports=1.0=py_2 12 | - backports.functools_lru_cache=1.6.1=py_0 13 | - blas=1.0=mkl 14 | - bzip2=1.0.8=h7b6447c_0 15 | - ca-certificates=2020.11.8=ha878542_0 16 | - cairo=1.14.12=h8948797_3 17 | - certifi=2020.11.8=py37h89c1867_0 18 | - cliff=3.5.0=pyhd8ed1ab_0 19 | - cmaes=0.7.0=pyhac0dd68_0 20 | - cmd2=0.9.22=py37hc8dfbb8_1 21 | - colorama=0.4.4=pyh9f0ad1d_0 22 | - colorlog=4.6.2=py37h89c1867_0 23 | - cudatoolkit=10.2.89=hfd86e86_1 24 | - fontconfig=2.13.0=h9420a91_0 25 | - freetype=2.10.4=h5ab3b9f_0 26 | - glib=2.66.1=h92f7085_0 27 | - icu=58.2=he6710b0_3 28 | - importlib-metadata=3.1.1=pyhd8ed1ab_0 29 | - importlib_metadata=3.1.1=hd8ed1ab_0 30 | - intel-openmp=2020.2=254 31 | - 
joblib=0.17.0=py_0 32 | - jpeg=9b=h024ee3a_2 33 | - lcms2=2.11=h396b838_0 34 | - ld_impl_linux-64=2.33.1=h53a641e_7 35 | - libboost=1.67.0=h46d08c1_4 36 | - libedit=3.1.20191231=h14c3975_1 37 | - libffi=3.3=he6710b0_2 38 | - libgcc-ng=9.1.0=hdf63c60_0 39 | - libgfortran-ng=7.3.0=hdf63c60_0 40 | - libpng=1.6.37=hbc83047_0 41 | - libstdcxx-ng=9.1.0=hdf63c60_0 42 | - libtiff=4.1.0=h2733197_1 43 | - libuuid=1.0.3=h1bed415_2 44 | - libuv=1.40.0=h7b6447c_0 45 | - libxcb=1.14=h7b6447c_0 46 | - libxml2=2.9.10=hb55368b_3 47 | - lz4-c=1.9.2=heb0550a_3 48 | - mako=1.1.3=pyh9f0ad1d_0 49 | - markupsafe=1.1.1=py37hb5d75c8_2 50 | - mkl=2020.2=256 51 | - mkl-service=2.3.0=py37he904b0f_0 52 | - mkl_fft=1.2.0=py37h23d657b_0 53 | - mkl_random=1.1.1=py37h0573a6f_0 54 | - ncurses=6.2=he6710b0_1 55 | - ninja=1.10.2=py37hff7bd54_0 56 | - numpy=1.19.2=py37h54aff64_0 57 | - numpy-base=1.19.2=py37hfa32c7d_0 58 | - olefile=0.46=py37_0 59 | - openssl=1.1.1h=h516909a_0 60 | - optuna=2.3.0=pyhd8ed1ab_0 61 | - packaging=20.4=pyh9f0ad1d_0 62 | - pandas=1.1.3=py37he6710b0_0 63 | - pbr=5.5.1=pyh9f0ad1d_0 64 | - pcre=8.44=he6710b0_0 65 | - pillow=8.0.1=py37he98fc37_0 66 | - pip=20.3=py37h06a4308_0 67 | - pixman=0.40.0=h7b6447c_0 68 | - prettytable=2.0.0=pyhd8ed1ab_0 69 | - py-boost=1.67.0=py37h04863e7_4 70 | - pyparsing=2.4.7=pyh9f0ad1d_0 71 | - pyperclip=1.8.0=pyh9f0ad1d_0 72 | - python=3.7.9=h7579374_0 73 | - python-dateutil=2.8.1=py_0 74 | - python-editor=1.0.4=py_0 75 | - python_abi=3.7=1_cp37m 76 | - pytorch=1.7.0=py3.7_cuda10.2.89_cudnn7.6.5_0 77 | - pytz=2020.4=pyhd3eb1b0_0 78 | - pyyaml=5.3.1=py37hb5d75c8_1 79 | - rdkit=2020.03.2.0=py37hc20afe1_1 80 | - readline=8.0=h7b6447c_0 81 | - scikit-learn=0.23.2=py37h0573a6f_0 82 | - scipy=1.5.2=py37h0b6359f_0 83 | - setuptools=50.3.1=py37h06a4308_1 84 | - six=1.15.0=py37h06a4308_0 85 | - sqlalchemy=1.3.20=py37h8f50634_0 86 | - sqlite=3.33.0=h62c20be_0 87 | - stevedore=3.3.0=py37h89c1867_0 88 | - threadpoolctl=2.1.0=pyh5ca1d4c_0 89 | - 
tk=8.6.10=hbc83047_0 90 | - torchaudio=0.7.0=py37 91 | - torchvision=0.8.1=py37_cu102 92 | - tqdm=4.54.0=pyhd8ed1ab_0 93 | - typing_extensions=3.7.4.3=py_0 94 | - wcwidth=0.2.5=pyh9f0ad1d_2 95 | - wheel=0.35.1=pyhd3eb1b0_0 96 | - xz=5.2.5=h7b6447c_0 97 | - yaml=0.2.5=h516909a_0 98 | - zipp=3.4.0=py_0 99 | - zlib=1.2.11=h7b6447c_3 100 | - zstd=1.4.5=h9ceee32_0 101 | -------------------------------------------------------------------------------- /experiments/submodules/tetra_dmpnn/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch 4 | import pandas as pd 5 | 6 | from features.featurization import construct_loader 7 | from utils import Standardizer, create_logger, get_loss_func 8 | 9 | from model.gnn import GNN 10 | from model.training import train, eval, test, build_lr_scheduler 11 | from model.parsing import parse_train_args 12 | 13 | args = parse_train_args() 14 | torch.manual_seed(args.seed) 15 | logger = create_logger('train', args.log_dir) 16 | 17 | train_loader, val_loader = construct_loader(args) 18 | mean = train_loader.dataset.mean 19 | std = train_loader.dataset.std 20 | stdzer = Standardizer(mean, std, args.task) 21 | 22 | # create model, optimizer, scheduler, and loss fn 23 | model = GNN(args, train_loader.dataset.num_node_features, train_loader.dataset.num_edge_features).to(args.device) 24 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 25 | scheduler = build_lr_scheduler(optimizer, args, len(train_loader.dataset)) 26 | loss = get_loss_func(args) 27 | best_val_loss = math.inf 28 | best_epoch = 0 29 | 30 | # record args, optimizer, and scheduler info 31 | logger.info('Arguments are...') 32 | for arg in vars(args): 33 | logger.info(f'{arg}: {getattr(args, arg)}') 34 | logger.info(f'\nOptimizer parameters are:\n{optimizer}\n') 35 | logger.info(f'Scheduler state dict is:') 36 | for key, value in scheduler.state_dict().items(): 37 | logger.info(f'{key}: 
{value}') 38 | logger.info('') 39 | 40 | # train 41 | logger.info("Starting training...") 42 | for epoch in range(0, args.n_epochs): 43 | train_loss, train_acc = train(model, train_loader, optimizer, loss, stdzer, args.device, scheduler, args.task) 44 | logger.info(f"Epoch {epoch}: Training Loss {train_loss}") 45 | 46 | if args.task == 'classification': 47 | logger.info(f"Epoch {epoch}: Training Classification Accuracy {train_acc}") 48 | 49 | val_loss, val_acc = eval(model, val_loader, loss, stdzer, args.device, args.task) 50 | logger.info(f"Epoch {epoch}: Validation Loss {val_loss}") 51 | 52 | if args.task == 'classification': 53 | logger.info(f"Epoch {epoch}: Validation Classification Accuracy {val_acc}") 54 | 55 | if val_loss <= best_val_loss: 56 | best_val_loss = val_loss 57 | best_epoch = epoch 58 | torch.save(model.state_dict(), os.path.join(args.log_dir, 'best_model')) 59 | logger.info(f"Best Validation Loss {best_val_loss} on Epoch {best_epoch}") 60 | 61 | # load best model 62 | model = GNN(args, train_loader.dataset.num_node_features, train_loader.dataset.num_edge_features).to(args.device) 63 | state_dict = torch.load(os.path.join(args.log_dir, 'best_model'), map_location=args.device) 64 | model.load_state_dict(state_dict) 65 | 66 | # predict test data 67 | test_loader = construct_loader(args, modes='test') 68 | preds, test_loss, test_acc, test_auc = test(model, test_loader, loss, stdzer, args.device, args.task) 69 | logger.info(f"Test Loss {test_loss}") 70 | if args.task == 'classification': 71 | logger.info(f"Test Classification Accuracy {test_acc}") 72 | logger.info(f"Test ROC AUC Score {test_auc}") 73 | 74 | # save predictions 75 | smiles = test_loader.dataset.smiles 76 | preds_path = os.path.join(args.log_dir, 'preds.csv') 77 | pd.DataFrame(list(zip(smiles, preds)), columns=['smiles', 'prediction']).to_csv(preds_path, index=False) 78 | -------------------------------------------------------------------------------- 
/experiments/submodules/tetra_dmpnn/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from argparse import Namespace 4 | from torch import nn 5 | 6 | 7 | class Standardizer: 8 | def __init__(self, mean, std, task='regression'): 9 | if task == 'regression': 10 | self.mean = mean 11 | self.std = std 12 | elif task == 'classification': 13 | self.mean = 0 14 | self.std = 1 15 | 16 | def __call__(self, x, rev=False): 17 | if rev: 18 | return (x * self.std) + self.mean 19 | return (x - self.mean) / self.std 20 | 21 | 22 | def create_logger(name: str, log_dir: str = None) -> logging.Logger: 23 | """ 24 | Creates a logger with a stream handler and file handler. 25 | 26 | :param name: The name of the logger. 27 | :param log_dir: The directory in which to save the logs. 28 | :return: The logger. 29 | """ 30 | logger = logging.getLogger(name) 31 | logger.setLevel(logging.INFO) 32 | logger.propagate = False 33 | 34 | # Set logger 35 | ch = logging.StreamHandler() 36 | ch.setLevel(logging.INFO) 37 | logger.addHandler(ch) 38 | 39 | if not os.path.exists(log_dir): 40 | os.makedirs(log_dir) 41 | 42 | fh = logging.FileHandler(os.path.join(log_dir, name + '.log')) 43 | fh.setLevel(logging.INFO) 44 | logger.addHandler(fh) 45 | 46 | return logger 47 | 48 | 49 | def get_loss_func(args: Namespace) -> nn.Module: 50 | """ 51 | Gets the loss function corresponding to a given dataset type. 52 | 53 | :param args: Namespace containing the dataset type ("classification" or "regression"). 54 | :return: A PyTorch loss function. 
55 | """ 56 | if args.task == 'classification': 57 | return nn.BCELoss(reduction='sum') 58 | 59 | if args.task == 'regression': 60 | return nn.MSELoss(reduction='sum') 61 | 62 | raise ValueError(f'Dataset type "{args.task}" not supported.') 63 | -------------------------------------------------------------------------------- /images/fk.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gmum/ChiENN/ee3185b39e8469a8caacf3d6d45a04c4a1cfff5b/images/fk.png -------------------------------------------------------------------------------- /images/order_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gmum/ChiENN/ee3185b39e8469a8caacf3d6d45a04c4a1cfff5b/images/order_example.png --------------------------------------------------------------------------------