├── LICENSE
├── README.md
├── configs
│   ├── molhiv
│   │   └── egt_110m.yaml
│   ├── molpcba
│   │   └── egt_110m.yaml
│   ├── pcqm4m
│   │   └── egt_47m.yaml
│   └── pcqm4mv2
│       ├── egt_110m.yaml
│       ├── egt_47m.yaml
│       └── egt_90m.yaml
├── do_evaluations.py
├── environment.yml
├── lib
│   ├── __init__.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── dataset_base.py
│   │   ├── graph_dataset
│   │   │   ├── __init__.py
│   │   │   ├── graph_dataset.py
│   │   │   ├── stack_with_pad.py
│   │   │   ├── structural_dataset.py
│   │   │   └── svd_encodings_dataset.py
│   │   ├── molhiv
│   │   │   ├── __init__.py
│   │   │   └── data.py
│   │   ├── molpcba
│   │   │   ├── __init__.py
│   │   │   └── data.py
│   │   ├── pcqm4m
│   │   │   ├── __init__.py
│   │   │   └── data.py
│   │   └── pcqm4mv2
│   │       ├── __init__.py
│   │       └── data.py
│   ├── models
│   │   ├── egt.py
│   │   ├── egt_layers.py
│   │   ├── egt_molgraph.py
│   │   ├── molhiv
│   │   │   ├── __init__.py
│   │   │   └── model.py
│   │   ├── molpcba
│   │   │   ├── __init__.py
│   │   │   └── model.py
│   │   ├── pcqm4m
│   │   │   ├── __init__.py
│   │   │   └── model.py
│   │   └── pcqm4mv2
│   │       ├── __init__.py
│   │       └── model.py
│   ├── training
│   │   ├── execute.py
│   │   ├── schemes
│   │   │   ├── egt_mol_training.py
│   │   │   ├── egt_training.py
│   │   │   ├── molhiv
│   │   │   │   ├── __init__.py
│   │   │   │   └── scheme.py
│   │   │   ├── molpcba
│   │   │   │   ├── __init__.py
│   │   │   │   └── scheme.py
│   │   │   ├── pcqm4m
│   │   │   │   ├── __init__.py
│   │   │   │   └── scheme.py
│   │   │   └── pcqm4mv2
│   │   │       ├── __init__.py
│   │   │       └── scheme.py
│   │   ├── testing.py
│   │   ├── training.py
│   │   └── training_mixins.py
│   └── utils
│       └── dotdict
│           ├── __init__.py
│           └── dotdict.py
├── make_predictions.py
└── run_training.py

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Md Shamim Hussain
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Edge-augmented Graph Transformer (PyTorch)
2 |
3 | ## News
4 | * 02/09/2024 - The preprint of our paper ["Triplet Interaction Improves Graph Transformers: Accurate Molecular Graph Learning with Triplet Graph Transformers"](https://arxiv.org/abs/2402.04538) is now available on arXiv. We achieved SOTA results on PCQM4Mv2, OC20 IS2RE, QM9, MOLPCBA, and LIT-PCBA. We will include our new results and methods, along with model weights, soon at .
5 | * 11/23/2023 - We achieved SOTA results (again) on the [PCQM4M-V2](https://ogb.stanford.edu/docs/lsc/) dataset by incorporating triangular attention and 3D molecular structure.
You can find the new implementation at and a [technical report](https://github.com/shamim-hussain/egt_triangular/blob/master/Report.pdf) (full paper coming soon!).
6 | * 06/21/2022 - The trained checkpoints on the [PCQM4M-V2](https://ogb.stanford.edu/docs/lsc/) dataset have been released. They are available at . For additional information, see the ["Download Trained Model Checkpoints"](#download-trained-model-checkpoints) section below.
7 | * 06/05/2022 - The [accepted preprint](https://arxiv.org/abs/2108.03348) of our KDD '22 paper is now available on arXiv. It includes discussions on dynamic centrality scalers, random masking, attention dropout, and other details about the latest experiments and results. Note that the title has been changed to **"Global Self-Attention as a Replacement for Graph Convolution"**.
8 | * 05/18/2022 - Our paper "Global Self-Attention as a Replacement for Graph Convolution" has been accepted at [KDD'22](https://kdd.org/kdd2022/). The preprint on arXiv will be updated soon with the latest version of the paper.
9 |
10 | ## Introduction
11 |
12 | This is the official **PyTorch** implementation of the **Edge-augmented Graph Transformer (EGT)** as described in , which augments the Transformer architecture with residual edge channels. The resultant architecture can directly process graph-structured data. For a **TensorFlow** implementation, see: .
13 |
14 | This implementation focuses on the [OGB-Mol](https://ogb.stanford.edu/docs/graphprop/) and [OGB-LSC](https://ogb.stanford.edu/docs/lsc/) datasets. (The OGB-Mol models utilize transfer learning from the PCQM4Mv2 dataset.)
15 |
16 | ## Results
17 |
18 | Dataset | #layers | #params | Metric | Valid | Test |
19 | --------------|---------|---------|----------------|-----------------|----------------|
20 | PCQM4M | 18 | 47.4M | MAE | 0.1225 | -- |
21 | PCQM4M-V2 | 18 | 47.4M | MAE | 0.0883 | -- |
22 | PCQM4M-V2 | 24 | 89.3M | MAE | 0.0857 | 0.0862 |
23 | OGBG-MolPCBA | 30 | 110.8M | Avg. Precision | 0.3021 ± 0.0053 | 0.2961 ± 0.0024|
24 | OGBG-MolHIV | 30 | 110.8M | ROC-AUC | 0.8060 ± 0.0065 | 0.8051 ± 0.0030|
25 |
26 | ## Download Trained Model Checkpoints
27 |
28 | The trained model checkpoints on the PCQM4M-V2 dataset are available at . Individual *zip* files are downloadable. The extracted folders can be put under the *models/pcqm4mv2* directory. See the *config_input.yaml* file contained within each folder for the training configurations.
29 |
30 | We found that the results can be further improved by freezing the node channel layers and training the edge channel layers for a few additional epochs. The corresponding tuned models are given the suffix **-T** and achieve better results than their untuned counterparts. However, the effect of this tuning on transfer learning has not yet been studied, which is why we include checkpoints for both tuned and untuned models. The available checkpoints are listed in the table below.
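As a minimal sketch of how a downloaded checkpoint might be used (the folder name *egt_90m* and the direct use of the bundled *config_input.yaml* are illustrative assumptions, not a documented workflow; see ["Run Training and Evaluations"](#run-training-and-evaluations) for the general command format):
```
# assuming the EGT-90M zip was extracted to models/pcqm4mv2/egt_90m
python make_predictions.py models/pcqm4mv2/egt_90m/config_input.yaml 'evaluate_on: ["val"]'
python do_evaluations.py models/pcqm4mv2/egt_90m/config_input.yaml 'evaluate_on: ["val"]'
```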
31 |
32 | Model | #layers | #params | Valid MAE | Test MAE | Comment |
33 | -----------------|---------|---------|-----------------|----------------|----------------------------------------|
34 | EGT-48M-SIMPLE | 18 | 47.2M | 0.0872 | -- | EGT-Simple (lightweight variant of EGT)|
35 | EGT-48M-SIMPLE-T | 18 | 47.2M | 0.0860 | -- | Tuned version of above |
36 | EGT-90M | 24 | 89.3M | 0.0869 | 0.0872 | **Submitted to the leaderboard** |
37 | EGT-90M-T | 24 | 89.3M | **0.0857** | **0.0862** | **Submitted tuned version of above** |
38 | EGT-110M | 30 | 110.8M | 0.0870 | -- | **Used for transfer learning** |
39 | EGT-110M-T | 30 | 110.8M | 0.0859 | -- | Tuned version of above |
40 |
41 | ## Requirements
42 |
43 | * `python >= 3.7`
44 | * `pytorch >= 1.6.0`
45 | * `numpy >= 1.18.4`
46 | * `numba >= 0.50.1`
47 | * `ogb >= 1.3.2`
48 | * `rdkit >= 2019.03.1`
49 | * `yaml >= 5.3.1`
50 |
51 | ## Run Training and Evaluations
52 |
53 | You can specify the training/prediction/evaluation configurations by creating a `yaml` config file and/or by passing a series of `yaml`-readable arguments. (Any additional config passed as an argument will override the config specified in the file.)
54 |
55 | * To run training: ```python run_training.py [config_file.yaml] ['config1: value1'] ['config2: value2'] ...```
56 | * To make predictions: ```python make_predictions.py [config_file.yaml] ['config1: value1'] ['config2: value2'] ...```
57 | * To perform evaluations: ```python do_evaluations.py [config_file.yaml] ['config1: value1'] ['config2: value2'] ...```
58 |
59 | Config files for the results can be found in the *configs* directory. Examples:
60 | ```
61 | python run_training.py configs/pcqm4m/egt_47m.yaml
62 | python run_training.py 'scheme: pcqm4m' 'model_height: 6'
63 | python make_predictions.py configs/pcqm4m/egt_47m.yaml 'evaluate_on: ["val"]'
64 | ```
65 |
66 | ### More About Training
67 |
68 | Once training is started, a model folder will be created in the *models* directory, under the specified dataset name. This folder will contain a copy of the input config file, for the convenience of resuming training/evaluation. It will also contain a *config.yaml* listing all configs, including unspecified default values, used for the training. Training will be checkpointed per epoch. In the case of any interruption, you can resume training by running *run_training.py* with the *config.yaml* file again.
69 |
70 | ### Configs
71 | There are many different configurations. The only **required** configuration is `scheme`, which specifies the training scheme. If the other configurations are not specified, default values will be assumed for them. Here are some of the commonly used configurations:
72 |
73 | `scheme`: pcqm4m/pcqm4mv2/molpcba/molhiv.
74 |
75 | `dataset_path`: Where the downloaded OGB datasets will be saved.
76 |
77 | `model_name`: Serves as an identifier for the model; it also specifies the default paths of the model directory, weight files, etc.
78 |
79 | `save_path`: The training process will create a model directory containing the logs, checkpoints, configs, model summary and predictions/evaluations. By default it creates a folder at *models/*, but this can be changed via this config.
80 |
81 | `cache_dir`: The first time training/evaluation is run, the data will be cached. The default path is *cache_data/*, but it can be changed via this config.
82 |
83 | `distributed`: In a multi-GPU setting you can set it to true for distributed training. Note that the batch size should also be adjusted accordingly.
84 |
85 | `batch_size`: Batch size. In the case of distributed training, it is the local batch size, so the total batch size = batch_size × number of available GPUs.
86 |
87 | `num_epochs`: Maximum number of epochs.
88 |
89 | `max_lr`: Maximum learning rate.
90 |
91 | `min_lr`: Minimum learning rate.
92 |
93 | `lr_warmup_steps`: Initial linear learning rate warmup steps.
94 |
95 | `lr_total_steps`: Total number of gradient updates to be performed, including linear warmup and cosine decay.
96 |
97 | `model_height`: The number of layers *L*.
98 |
99 | `node_width`: The dimensionality of the node channels *d_h*.
100 |
101 | `edge_width`: The dimensionality of the edge channels *d_e*.
102 |
103 | `num_heads`: The number of attention heads. Default is 8.
104 |
105 | `node_ffn_multiplier`: FFN multiplier for node channels.
106 |
107 | `edge_ffn_multiplier`: FFN multiplier for edge channels.
108 |
109 | `virtual_nodes`: Number of virtual nodes. 0 (default) results in global average pooling being used instead of virtual nodes.
110 |
111 | `upto_hop`: Clipping value of the input distance matrix.
112 |
113 | `attn_dropout`: Dropout rate for the attention matrix.
114 |
115 | `node_dropout`: Dropout rate for the node channel's MHA and FFN blocks.
116 |
117 | `edge_dropout`: Dropout rate for the edge channel's MHA and FFN blocks.
118 |
119 | `sel_svd_features`: Rank of the SVD encodings *r*.
120 |
121 | `svd_calculated_dim`: Number of left and right singular vectors calculated and cached for SVD encodings.
122 |
123 | `svd_output_dim`: Number of left and right singular vectors used as SVD encodings.
124 |
125 | `svd_random_neg`: Whether to randomly flip the signs of the singular vectors. Default: true.
126 |
127 | `pretrained_weights_file`: Used to specify the learned weights of an already-trained model, e.g., for transfer learning.
128 |
129 | ## Python Environment
130 |
131 | The Anaconda environment in which the experiments were conducted is specified in the `environment.yml` file.
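As a sketch, one way to recreate it with conda is shown below; this is not part of the documented workflow, and the environment name `egt_pytorch` is an arbitrary assumption. Note that the file as shipped pins `name: base` and a machine-specific `prefix`, which the `-n` flag overrides.
```
conda env create -f environment.yml -n egt_pytorch
conda activate egt_pytorch
```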
132 | 133 | ## Citation 134 | 135 | Please cite the following paper if you find the code useful: 136 | ``` 137 | @article{hussain2021global, 138 | title={Global Self-Attention as a Replacement for Graph Convolution}, 139 | author={Hussain, Md Shamim and Zaki, Mohammed J and Subramanian, Dharmashankar}, 140 | journal={arXiv preprint arXiv:2108.03348}, 141 | year={2021} 142 | } 143 | ``` 144 | -------------------------------------------------------------------------------- /configs/molhiv/egt_110m.yaml: -------------------------------------------------------------------------------- 1 | scheme: molhiv 2 | model_name: egt_110m 3 | distributed: false # Set = true for multi-gpu 4 | batch_size: 12 # For 6 GPUs: 12//6=2 5 | model_height: 30 6 | node_width: 768 7 | edge_width: 64 8 | num_heads: 32 9 | num_epochs: 1000 10 | max_lr: 0.0001 11 | attn_dropout: 0.3 12 | lr_warmup_steps: 1000 13 | lr_total_steps: 3000 14 | node_ffn_multiplier: 1.0 15 | edge_ffn_multiplier: 1.0 16 | upto_hop: 16 17 | dataloader_workers: 1 # For multi-process data fetch 18 | scale_degree: true 19 | num_virtual_nodes: 4 20 | svd_random_neg: true 21 | pretrained_weights_file: models/pcqm4mv2/egt_110m/checkpoint/model_state 22 | # ^ For transfer learning from PCQM4Mv2 -------------------------------------------------------------------------------- /configs/molpcba/egt_110m.yaml: -------------------------------------------------------------------------------- 1 | scheme: molpcba 2 | model_name: egt_110m 3 | distributed: false # Set = true for multi-gpu 4 | batch_size: 16 # For 8 GPUs: 16//8=2 5 | model_height: 30 6 | node_width: 768 7 | edge_width: 64 8 | num_heads: 32 9 | num_epochs: 1000 10 | max_lr: 0.0001 11 | attn_dropout: 0.3 12 | lr_warmup_steps: 20000 13 | lr_total_steps: 200000 14 | node_ffn_multiplier: 1.0 15 | edge_ffn_multiplier: 1.0 16 | upto_hop: 16 17 | dataloader_workers: 1 # For multi-process data fetch 18 | scale_degree: true 19 | num_virtual_nodes: 4 20 | svd_random_neg: true 21 | pretrained_weights_file: models/pcqm4mv2/egt_110m/checkpoint/model_state 22 | # ^ For transfer learning from PCQM4Mv2 -------------------------------------------------------------------------------- /configs/pcqm4m/egt_47m.yaml: -------------------------------------------------------------------------------- 1 | scheme: pcqm4m 2 | model_name: egt_47m 3 | distributed: false # Set = true for multi-gpu 4 | batch_size: 512 # For 8 GPUs: 512//8=64 5 | model_height: 18 6 | node_width: 640 7 | edge_width: 64 8 | num_heads: 32 9 | num_epochs: 1000 10 | max_lr: 0.0001 11 | attn_dropout: 0.3 12 | lr_warmup_steps: 200000 13 | lr_total_steps: 1000000 14 | node_ffn_multiplier: 1.0 15 | edge_ffn_multiplier: 1.0 16 | upto_hop: 16 17 | dataloader_workers: 1 # For multi-process data fetch 18 | scale_degree: true 19 | num_virtual_nodes: 4 20 | svd_random_neg: true -------------------------------------------------------------------------------- /configs/pcqm4mv2/egt_110m.yaml: -------------------------------------------------------------------------------- 1 | scheme: pcqm4mv2 2 | model_name: egt_110m 3 | distributed: false # Set = true for multi-gpu 4 | batch_size: 512 # For 8 GPUs: 512//8=64 5 | model_height: 30 6 | node_width: 768 7 | edge_width: 64 8 | num_heads: 32 9 | num_epochs: 1000 10 | max_lr: 8.0e-05 11 | attn_dropout: 0.3 12 | lr_warmup_steps: 240000 13 | lr_total_steps: 1000000 14 | node_ffn_multiplier: 1.0 15 | edge_ffn_multiplier: 1.0 16 | upto_hop: 16 17 | dataloader_workers: 1 # For multi-process data fetch 18 | scale_degree: true 
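# ^ scale_degree applies a log(1 + degree) scaler to the attended node values,
#   where degree is the sum of the attention gates (see EGT_Layer._egt in
#   lib/models/egt_layers.py); virtual nodes are excluded from the scaling.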
19 | num_virtual_nodes: 4 20 | svd_random_neg: true 21 | -------------------------------------------------------------------------------- /configs/pcqm4mv2/egt_47m.yaml: -------------------------------------------------------------------------------- 1 | scheme: pcqm4mv2 2 | model_name: egt_47m 3 | distributed: false # Set = true for multi-gpu 4 | batch_size: 512 # For 8 GPUs: 512//8=64 5 | model_height: 18 6 | node_width: 640 7 | edge_width: 64 8 | num_heads: 32 9 | num_epochs: 1000 10 | max_lr: 0.0001 11 | attn_dropout: 0.3 12 | lr_warmup_steps: 200000 13 | lr_total_steps: 1000000 14 | node_ffn_multiplier: 1.0 15 | edge_ffn_multiplier: 1.0 16 | upto_hop: 16 17 | dataloader_workers: 1 # For multi-process data fetch 18 | scale_degree: true 19 | num_virtual_nodes: 4 20 | svd_random_neg: true -------------------------------------------------------------------------------- /configs/pcqm4mv2/egt_90m.yaml: -------------------------------------------------------------------------------- 1 | scheme: pcqm4mv2 2 | model_name: egt_90m 3 | distributed: false # Set = true for multi-gpu 4 | batch_size: 512 # For 8 GPUs: 512//8=64 5 | model_height: 24 6 | node_width: 768 7 | edge_width: 64 8 | num_heads: 32 9 | num_epochs: 1000 10 | max_lr: 0.0001 11 | attn_dropout: 0.3 12 | lr_warmup_steps: 200000 13 | lr_total_steps: 1000000 14 | node_ffn_multiplier: 1.0 15 | edge_ffn_multiplier: 1.0 16 | upto_hop: 16 17 | dataloader_workers: 1 # For multi-process data fetch 18 | scale_degree: true 19 | num_virtual_nodes: 4 20 | svd_random_neg: true -------------------------------------------------------------------------------- /do_evaluations.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from lib.training.execute import get_configs_from_args, execute 3 | 4 | if __name__ == '__main__': 5 | config = get_configs_from_args(sys.argv) 6 | execute('evaluate', config) 7 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: base 2 | channels: 3 | - pytorch 4 | - anaconda 5 | - dglteam 6 | - conda-forge 7 | - defaults 8 | dependencies: 9 | - _ipyw_jlab_nb_ext_conf=0.1.0=py38_0 10 | - alabaster=0.7.12=py_0 11 | - anaconda=2020.07=py38_0 12 | - anaconda-client=1.7.2=py38_0 13 | - anaconda-navigator=1.9.12=py38_0 14 | - anaconda-project=0.8.4=py_0 15 | - appdirs=1.4.4=pyh9f0ad1d_0 16 | - argh=0.26.2=py38_0 17 | - asn1crypto=1.3.0=py38_0 18 | - astroid=2.4.2=py38_0 19 | - astropy=4.0.1.post1=py38he774522_1 20 | - atomicwrites=1.4.0=py_0 21 | - attrs=19.3.0=py_0 22 | - audioread=2.1.9=py38haa244fe_0 23 | - autopep8=1.5.3=py_0 24 | - babel=2.8.0=py_0 25 | - backcall=0.2.0=py_0 26 | - backports=1.0=py_2 27 | - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0 28 | - backports.shutil_get_terminal_size=1.0.0=py38_2 29 | - backports.tempfile=1.0=py_0 30 | - backports.weakref=1.0.post1=py38h32f6830_1002 31 | - bcrypt=3.1.7=py38he774522_1 32 | - beautifulsoup4=4.9.1=py38_0 33 | - bitarray=1.4.0=py38he774522_0 34 | - bkcharts=0.2=py38_0 35 | - blas=1.0=mkl 36 | - bleach=3.1.5=py_0 37 | - blosc=1.19.0=h7bd577a_0 38 | - bokeh=2.1.1=py38_0 39 | - boost=1.74.0=py38h1266d08_3 40 | - boost-cpp=1.74.0=h54f0996_1 41 | - boto=2.49.0=py38_0 42 | - bottleneck=1.3.2=py38h2a96729_1 43 | - brotlipy=0.7.0=py38he774522_1000 44 | - bzip2=1.0.8=he774522_0 45 | - ca-certificates=2020.6.24=0 46 | - cairo=1.16.0=h63a05c6_1001 47 | - certifi=2020.6.20=py38_0 48 | - 
cffi=1.14.0=py38h7a1dbc1_0 49 | - chardet=3.0.4=py38_1003 50 | - click=7.1.2=py_0 51 | - cloudpickle=1.5.0=py_0 52 | - clyent=1.2.2=py38_1 53 | - colorama=0.4.3=py_0 54 | - comtypes=1.1.7=py38_1001 55 | - conda=4.11.0=py38haa95532_0 56 | - conda-build=3.18.11=py38_1 57 | - conda-env=2.6.0=1 58 | - conda-package-handling=1.7.3=py38h31c79cd_0 59 | - conda-verify=3.4.2=py_1 60 | - console_shortcut=0.1.1=4 61 | - contextlib2=0.6.0.post1=py_0 62 | - cryptography=2.9.2=py38h7a1dbc1_0 63 | - cudatoolkit=11.0.221=h74a9793_0 64 | - curl=7.71.1=h2a8f88b_1 65 | - cycler=0.10.0=py38_0 66 | - cython=0.29.21=py38ha925a31_0 67 | - cytoolz=0.10.1=py38he774522_0 68 | - dask=2.20.0=py_0 69 | - dask-core=2.20.0=py_0 70 | - decorator=4.4.2=py_0 71 | - defusedxml=0.6.0=py_0 72 | - dgl=0.5.3=py38_0 73 | - diff-match-patch=20200713=py_0 74 | - distributed=2.20.0=py38_0 75 | - docutils=0.16=py38_1 76 | - entrypoints=0.3=py38_0 77 | - et_xmlfile=1.0.1=py_1001 78 | - fastcache=1.1.0=py38he774522_0 79 | - filelock=3.0.12=py_0 80 | - flake8=3.8.3=py_0 81 | - flask=1.1.2=py_0 82 | - freetype=2.10.2=hd328e21_0 83 | - fsspec=0.7.4=py_0 84 | - future=0.18.2=py38_1 85 | - get_terminal_size=1.0.0=h38e98db_0 86 | - gevent=20.6.2=py38he774522_0 87 | - gitdb=4.0.7=pyhd8ed1ab_0 88 | - gitpython=3.1.18=pyhd8ed1ab_0 89 | - glob2=0.7=py_0 90 | - gmpy2=2.0.8=py38h7edee0f_3 91 | - graphviz=2.38=hfd603c8_2 92 | - greenlet=0.4.16=py38he774522_0 93 | - h5py=2.10.0=py38h5e291fa_0 94 | - hdf5=1.10.4=h7ebc959_0 95 | - heapdict=1.0.1=py_0 96 | - html5lib=1.1=py_0 97 | - icc_rt=2019.0.0=h0cc432a_1 98 | - icu=58.2=ha925a31_3 99 | - idna=2.10=py_0 100 | - imageio=2.9.0=py_0 101 | - imagesize=1.2.0=py_0 102 | - importlib-metadata=1.7.0=py38_0 103 | - importlib_metadata=1.7.0=0 104 | - intel-openmp=2020.1=216 105 | - intervaltree=3.0.2=py_1 106 | - ipykernel=5.3.2=py38h5ca1d4c_0 107 | - ipython=7.16.1=py38h5ca1d4c_0 108 | - ipython_genutils=0.2.0=py38_0 109 | - ipywidgets=7.5.1=py_0 110 | - isort=4.3.21=py38_0 111 | - itsdangerous=1.1.0=py_0 112 | - jdcal=1.4.1=py_0 113 | - jedi=0.17.1=py38_0 114 | - jinja2=2.11.2=py_0 115 | - joblib=0.16.0=py_0 116 | - jpeg=9b=hb83a4c4_2 117 | - json5=0.9.5=py_0 118 | - jsonschema=3.2.0=py38_0 119 | - jupyter=1.0.0=py38_7 120 | - jupyter_client=6.1.6=py_0 121 | - jupyter_console=6.1.0=py_0 122 | - jupyter_core=4.6.3=py38_0 123 | - jupyterlab=2.1.5=py_0 124 | - jupyterlab-git=0.22.3=pyhd8ed1ab_0 125 | - jupyterlab_server=1.2.0=py_0 126 | - keyring=21.2.1=py38_0 127 | - kiwisolver=1.2.0=py38h74a9793_0 128 | - krb5=1.18.2=hc04afaa_0 129 | - lazy-object-proxy=1.4.3=py38he774522_0 130 | - libarchive=3.4.2=h5e25573_0 131 | - libcurl=7.71.1=h2a8f88b_1 132 | - libflac=1.3.3=h0e60522_1 133 | - libiconv=1.15=h1df5818_7 134 | - liblief=0.10.1=ha925a31_0 135 | - libllvm9=9.0.1=h21ff451_0 136 | - libogg=1.3.4=h8ffe710_1 137 | - libopencv=4.0.1=hbb9e17c_0 138 | - libopus=1.3.1=h8ffe710_1 139 | - libpng=1.6.37=h2a8f88b_0 140 | - libprotobuf=3.14.0=h7755175_0 141 | - librosa=0.8.1=pyhd8ed1ab_0 142 | - libsndfile=1.0.31=h0e60522_1 143 | - libsodium=1.0.18=h62dcd97_0 144 | - libspatialindex=1.9.3=h33f27b4_0 145 | - libssh2=1.9.0=h7a1dbc1_1 146 | - libtiff=4.1.0=h56a325e_1 147 | - libuv=1.41.1=h8ffe710_0 148 | - libvorbis=1.3.7=h0e60522_0 149 | - libxml2=2.9.10=h464c3ec_1 150 | - libxslt=1.1.34=he774522_0 151 | - littleutils=0.2.2=py_0 152 | - llvmlite=0.33.0=py38ha925a31_0 153 | - locket=0.2.0=py38_1 154 | - lxml=4.5.2=py38h1350720_0 155 | - lz4-c=1.9.2=h62dcd97_0 156 | - lzo=2.10=he774522_2 157 | - 
m2w64-gcc-libgfortran=5.3.0=6 158 | - m2w64-gcc-libs=5.3.0=7 159 | - m2w64-gcc-libs-core=5.3.0=7 160 | - m2w64-gmp=6.1.0=2 161 | - m2w64-libwinpthread-git=5.0.0.4634.697f757=2 162 | - markupsafe=1.1.1=py38he774522_0 163 | - matplotlib=3.2.2=0 164 | - matplotlib-base=3.2.2=py38h64f37c6_0 165 | - mccabe=0.6.1=py38_1 166 | - menuinst=1.4.16=py38he774522_1 167 | - mistune=0.8.4=py38he774522_1000 168 | - mkl=2020.1=216 169 | - mkl-service=2.3.0=py38hb782905_0 170 | - mkl_fft=1.1.0=py38h45dec08_0 171 | - mkl_random=1.1.1=py38h47e9c7a_0 172 | - mock=4.0.2=py_0 173 | - more-itertools=8.4.0=py_0 174 | - mpc=1.1.0=h7edee0f_1 175 | - mpfr=4.0.2=h62dcd97_1 176 | - mpir=3.0.0=hec2e145_1 177 | - mpmath=1.1.0=py38_0 178 | - msgpack-python=1.0.0=py38h74a9793_1 179 | - msys2-conda-epoch=20160418=1 180 | - multipledispatch=0.6.0=py38_0 181 | - navigator-updater=0.2.1=py38_0 182 | - nbconvert=5.6.1=py38_0 183 | - nbdime=2.1.0=py_0 184 | - nbformat=5.0.7=py_0 185 | - networkx=2.4=py_1 186 | - ninja=1.10.2=h5362a0b_0 187 | - nltk=3.5=py_0 188 | - nodejs=15.14.0=h57928b3_0 189 | - nose=1.3.7=py38_2 190 | - notebook=6.0.3=py38_0 191 | - numba=0.50.1=py38h47e9c7a_0 192 | - numexpr=2.7.1=py38h25d0782_0 193 | - numpy=1.18.5=py38h6530119_0 194 | - numpy-base=1.18.5=py38hc3f5095_0 195 | - numpydoc=1.1.0=py_0 196 | - olefile=0.46=py_0 197 | - opencv=4.0.1=py38h2a7c758_0 198 | - openpyxl=3.0.4=py_0 199 | - openssl=1.1.1g=he774522_0 200 | - outdated=0.2.0=py_0 201 | - packaging=20.4=py_0 202 | - pandas=1.0.5=py38h47e9c7a_0 203 | - pandoc=2.10=0 204 | - pandocfilters=1.4.2=py38_1 205 | - paramiko=2.7.1=py_0 206 | - parso=0.7.0=py_0 207 | - partd=1.1.0=py_0 208 | - path=13.1.0=py38_0 209 | - path.py=12.4.0=0 210 | - pathlib2=2.3.5=py38_0 211 | - pathtools=0.1.2=py_1 212 | - patsy=0.5.1=py38_0 213 | - pep8=1.7.1=py38_0 214 | - pexpect=4.8.0=py38_0 215 | - pickleshare=0.7.5=py38_1000 216 | - pillow=7.2.0=py38hcc1f983_0 217 | - pip=20.1.1=py38_1 218 | - pixman=0.38.0=hfa6e2cd_1003 219 | - pkginfo=1.5.0.1=py38_0 220 | - pluggy=0.13.1=py38_0 221 | - ply=3.11=py38_0 222 | - pooch=1.5.2=pyhd8ed1ab_0 223 | - portaudio=19.6.0=h0e60522_4 224 | - powershell_shortcut=0.0.1=3 225 | - prometheus_client=0.8.0=py_0 226 | - prompt-toolkit=3.0.5=py_0 227 | - prompt_toolkit=3.0.5=0 228 | - psutil=5.7.0=py38he774522_0 229 | - py=1.9.0=py_0 230 | - py-lief=0.10.1=py38ha925a31_0 231 | - py-opencv=4.0.1=py38he44ac1e_0 232 | - pycairo=1.20.1=py38h979ce04_0 233 | - pycodestyle=2.6.0=py_0 234 | - pycosat=0.6.3=py38he774522_0 235 | - pycparser=2.20=py_2 236 | - pycurl=7.43.0.5=py38h7a1dbc1_0 237 | - pydocstyle=5.0.2=py_0 238 | - pydot=1.4.1=py38_0 239 | - pyflakes=2.2.0=py_0 240 | - pygments=2.6.1=py_0 241 | - pylint=2.5.3=py38_0 242 | - pynacl=1.4.0=py38h62dcd97_1 243 | - pyodbc=4.0.30=py38ha925a31_0 244 | - pyopenssl=19.1.0=py_1 245 | - pyparsing=2.4.7=py_0 246 | - pyqt=5.9.2=py38ha925a31_4 247 | - pyreadline=2.1=py38_1 248 | - pyrsistent=0.16.0=py38he774522_0 249 | - pysocks=1.7.1=py38_0 250 | - pysoundfile=0.10.3.post1=pyhd3deb0d_0 251 | - pytables=3.6.1=py38ha5be198_0 252 | - pytest=5.4.3=py38_0 253 | - python=3.8.3=he1778fa_2 254 | - python-dateutil=2.8.1=py_0 255 | - python-jsonrpc-server=0.3.4=py_1 256 | - python-language-server=0.34.1=py38_0 257 | - python-libarchive-c=2.9=py_0 258 | - python-sounddevice=0.4.1=pyh9f0ad1d_0 259 | - python_abi=3.8=2_cp38 260 | - pytorch=1.7.0=py3.8_cuda110_cudnn8_0 261 | - pytz=2020.1=py_0 262 | - pywavelets=1.1.1=py38he774522_0 263 | - pywin32=227=py38he774522_1 264 | - pywin32-ctypes=0.2.0=py38_1000 265 | 
- pywinpty=0.5.7=py38_0 266 | - pyyaml=5.3.1=py38he774522_1 267 | - pyzmq=19.0.1=py38ha925a31_1 268 | - qdarkstyle=2.8.1=py_0 269 | - qt=5.9.7=vc14h73c81de_0 270 | - qtawesome=0.7.2=py_0 271 | - qtconsole=4.7.5=py_0 272 | - qtpy=1.9.0=py_0 273 | - rdkit=2020.09.1=py38h35bba09_0 274 | - regex=2020.6.8=py38he774522_0 275 | - requests=2.24.0=py_0 276 | - resampy=0.2.2=py_0 277 | - rope=0.17.0=py_0 278 | - rtree=0.9.4=py38h21ff451_1 279 | - ruamel_yaml=0.15.87=py38he774522_1 280 | - scikit-image=0.16.2=py38h47e9c7a_0 281 | - scikit-learn=0.23.1=py38h25d0782_0 282 | - scipy=1.5.0=py38h9439919_0 283 | - seaborn=0.10.1=py_0 284 | - send2trash=1.5.0=py38_0 285 | - setuptools=49.2.0=py38_0 286 | - simplegeneric=0.8.1=py38_2 287 | - singledispatch=3.4.0.3=py38_0 288 | - sip=4.19.13=py38ha925a31_0 289 | - six=1.15.0=py_0 290 | - smmap=3.0.5=pyh44b312d_0 291 | - snappy=1.1.8=h33f27b4_0 292 | - snowballstemmer=2.0.0=py_0 293 | - sortedcollections=1.2.1=py_0 294 | - sortedcontainers=2.2.2=py_0 295 | - soupsieve=2.0.1=py_0 296 | - sphinx=3.1.2=py_0 297 | - sphinxcontrib=1.0=py38_1 298 | - sphinxcontrib-applehelp=1.0.2=py_0 299 | - sphinxcontrib-devhelp=1.0.2=py_0 300 | - sphinxcontrib-htmlhelp=1.0.3=py_0 301 | - sphinxcontrib-jsmath=1.0.1=py_0 302 | - sphinxcontrib-qthelp=1.0.3=py_0 303 | - sphinxcontrib-serializinghtml=1.1.4=py_0 304 | - sphinxcontrib-websupport=1.2.3=py_0 305 | - spyder=4.1.4=py38_0 306 | - spyder-kernels=1.9.2=py38_0 307 | - sqlalchemy=1.3.18=py38he774522_0 308 | - sqlite=3.32.3=h2a8f88b_0 309 | - statsmodels=0.11.1=py38he774522_0 310 | - sympy=1.6.1=py38_0 311 | - tbb=2020.0=h74a9793_0 312 | - tblib=1.6.0=py_0 313 | - terminado=0.8.3=py38_0 314 | - testpath=0.4.4=py_0 315 | - threadpoolctl=2.1.0=pyh5ca1d4c_0 316 | - tk=8.6.10=he774522_0 317 | - toml=0.10.1=py_0 318 | - toolz=0.10.0=py_0 319 | - torchaudio=0.7.0=py38 320 | - torchvision=0.8.1=py38_cu110 321 | - tornado=6.0.4=py38he774522_1 322 | - tqdm=4.47.0=py_0 323 | - traitlets=4.3.3=py38_0 324 | - typing_extensions=3.7.4.2=py_0 325 | - ujson=1.35=py38he774522_0 326 | - unicodecsv=0.14.1=py38_0 327 | - urllib3=1.25.9=py_0 328 | - vc=14.1=h0510ff6_4 329 | - vs2015_runtime=14.16.27012=hf0eaf9b_3 330 | - watchdog=0.10.3=py38_0 331 | - wcwidth=0.2.5=py_0 332 | - webencodings=0.5.1=py38_1 333 | - werkzeug=1.0.1=py_0 334 | - wheel=0.34.2=py38_0 335 | - widgetsnbextension=3.5.1=py38_0 336 | - win_inet_pton=1.1.0=py38_0 337 | - win_unicode_console=0.5=py38_0 338 | - wincertstore=0.2=py38_0 339 | - winpty=0.4.3=4 340 | - wrapt=1.11.2=py38he774522_0 341 | - xlrd=1.2.0=py_0 342 | - xlsxwriter=1.2.9=py_0 343 | - xlwings=0.19.5=py38_0 344 | - xlwt=1.3.0=py38_0 345 | - xmltodict=0.12.0=py_0 346 | - xz=5.2.5=h62dcd97_0 347 | - yaml=0.2.5=he774522_0 348 | - yapf=0.30.0=py_0 349 | - zeromq=4.3.2=ha925a31_2 350 | - zict=2.0.0=py_0 351 | - zipp=3.1.0=py_0 352 | - zlib=1.2.11=h62dcd97_4 353 | - zope=1.0=py38_1 354 | - zope.event=4.4=py38_0 355 | - zope.interface=4.7.1=py38he774522_0 356 | - zstd=1.4.5=ha9fde0e_0 357 | - pip: 358 | - absl-py==0.10.0 359 | - astunparse==1.6.3 360 | - cachetools==4.1.1 361 | - dataclasses==0.6 362 | - dm-tree==0.1.5 363 | - gast==0.3.3 364 | - google-auth==1.22.1 365 | - google-auth-oauthlib==0.4.1 366 | - google-pasta==0.2.0 367 | - grpcio==1.33.1 368 | - jupyter-http-over-ws==0.0.8 369 | - keras-preprocessing==1.1.2 370 | - markdown==3.3.3 371 | - oauthlib==3.1.0 372 | - ogb==1.3.2 373 | - opt-einsum==3.3.0 374 | - protobuf==3.13.0 375 | - pyasn1==0.4.8 376 | - pyasn1-modules==0.2.8 377 | - python-speech-features==0.6 
378 | - requests-oauthlib==1.3.0 379 | - rsa==4.6 380 | - tensorboard==2.5.0 381 | - tensorboard-data-server==0.6.0 382 | - tensorboard-plugin-wit==1.7.0 383 | - tensorflow==2.3.1 384 | - tensorflow-estimator==2.3.0 385 | - tensorflow-probability==0.11.1 386 | - termcolor==1.1.0 387 | - torch-tb-profiler==0.1.0 388 | - youtube-dl==2020.11.1.1 389 | prefix: C:\ProgramData\Anaconda3 390 | -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shamim-hussain/egt_pytorch/9e66956a5fdc6f6e8a865863d029468380bb63e5/lib/__init__.py -------------------------------------------------------------------------------- /lib/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shamim-hussain/egt_pytorch/9e66956a5fdc6f6e8a865863d029468380bb63e5/lib/data/__init__.py -------------------------------------------------------------------------------- /lib/data/dataset_base.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch.utils.data import dataset 4 | from tqdm import tqdm 5 | from pathlib import Path 6 | 7 | 8 | class DatasetBase(dataset.Dataset): 9 | def __init__(self, 10 | dataset_name, split, 11 | cache_dir = None, 12 | load_cache_if_exists=True, 13 | **kwargs): 14 | super().__init__(**kwargs) 15 | self.dataset_name = dataset_name 16 | self.split = split 17 | self.cache_dir = cache_dir 18 | 19 | self.is_cached = False 20 | if load_cache_if_exists: 21 | self.cache(verbose=0, must_exist=True) 22 | 23 | @property 24 | def record_tokens(self): 25 | raise NotImplementedError 26 | 27 | def read_record(self, token): 28 | raise NotImplementedError 29 | 30 | def __len__(self): 31 | return len(self.record_tokens) 32 | 33 | def __getitem__(self, index): 34 | token = self.record_tokens[index] 35 | try: 36 | return self._records[token] 37 | except AttributeError: 38 | record = self.read_record(token) 39 | self._records = {token:record} 40 | return record 41 | except KeyError: 42 | record = self.read_record(token) 43 | self._records[token] = record 44 | return record 45 | 46 | def read_all_records(self, verbose=1): 47 | self._records = {} 48 | if verbose: 49 | print(f'Reading all {self.split} records...', flush=True) 50 | for token in tqdm(self.record_tokens): 51 | self._records[token] = self.read_record(token) 52 | else: 53 | for token in self.record_tokens: 54 | self._records[token] = self.read_record(token) 55 | 56 | def get_cache_path(self, path=None): 57 | if path is None: path = self.cache_dir 58 | base_path = (Path(path)/self.dataset_name)/self.split 59 | base_path.mkdir(parents=True, exist_ok=True) 60 | return base_path 61 | 62 | def cache_load_and_save(self, base_path, op, verbose): 63 | tokens_path = base_path/'tokens.pt' 64 | records_path = base_path/'records.pt' 65 | 66 | if op == 'load': 67 | self._record_tokens = torch.load(str(tokens_path)) 68 | self._records = torch.load(str(records_path)) 69 | elif op == 'save': 70 | if tokens_path.exists() and records_path.exists() \ 71 | and hasattr(self, '_record_tokens') and hasattr(self, '_records'): 72 | return 73 | self.read_all_records(verbose=verbose) 74 | torch.save(self.record_tokens, str(tokens_path)) 75 | torch.save(self._records, str(records_path)) 76 | else: 77 | raise ValueError(f'Unknown operation: {op}') 78 | 79 | def cache(self, path=None, verbose=1, 
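                  # must_exist=True makes this return silently when no cache is
                  # found on disk, instead of building and saving one (used by
                  # __init__ when load_cache_if_exists is set)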
must_exist=False): 80 | if self.is_cached: return 81 | 82 | base_path = self.get_cache_path(path) 83 | try: 84 | if verbose: print(f'Trying to load {self.split} cache from disk...', flush=True) 85 | self.cache_load_and_save(base_path, 'load', verbose) 86 | if verbose: print(f'Loaded {self.split} cache from disk.', flush=True) 87 | except FileNotFoundError: 88 | if must_exist: return 89 | 90 | if verbose: print(f'{self.split} cache does not exist! Cacheing...', flush=True) 91 | self.cache_load_and_save(base_path, 'save', verbose) 92 | if verbose: print(f'Saved {self.split} cache to disk.', flush=True) 93 | 94 | self.is_cached = True 95 | 96 | 97 | -------------------------------------------------------------------------------- /lib/data/graph_dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .graph_dataset import GraphDataset, graphdata_collate 2 | from .svd_encodings_dataset import SVDEncodingsGraphDataset 3 | from .structural_dataset import StructuralDataset 4 | -------------------------------------------------------------------------------- /lib/data/graph_dataset/graph_dataset.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import numpy as np 4 | 5 | from ..dataset_base import DatasetBase 6 | 7 | from .stack_with_pad import stack_with_pad 8 | from collections import defaultdict 9 | from numba.typed import List 10 | 11 | 12 | class GraphDataset(DatasetBase): 13 | def __init__(self, 14 | num_nodes_key = 'num_nodes', 15 | edges_key = 'edges', 16 | node_features_key = 'node_features', 17 | edge_features_key = 'edge_features', 18 | node_mask_key = 'node_mask', 19 | targets_key = 'target', 20 | include_node_mask = True, 21 | **kwargs): 22 | super().__init__(**kwargs) 23 | self.num_nodes_key = num_nodes_key 24 | self.edges_key = edges_key 25 | self.node_features_key = node_features_key 26 | self.edge_features_key = edge_features_key 27 | self.node_mask_key = node_mask_key 28 | self.targets_key = targets_key 29 | self.include_node_mask = include_node_mask 30 | 31 | def __getitem__(self, index): 32 | item = super().__getitem__(index) 33 | if self.include_node_mask: 34 | item = item.copy() 35 | item[self.node_mask_key] = np.ones((item[self.num_nodes_key],), dtype=np.uint8) 36 | return item 37 | 38 | def _calculate_max_nodes(self): 39 | self._max_nodes = self[0][self.num_nodes_key] 40 | self._max_nodes_index = 0 41 | for i in range(1, super().__len__()): 42 | graph = super().__getitem__(i) 43 | cur_nodes = graph[self.num_nodes_key] 44 | if cur_nodes > self._max_nodes: 45 | self._max_nodes = cur_nodes 46 | self._max_nodes_index = i 47 | 48 | @property 49 | def max_nodes(self): 50 | try: 51 | return self._max_nodes 52 | except AttributeError: 53 | self._calculate_max_nodes() 54 | return self._max_nodes 55 | 56 | @property 57 | def max_nodes_index(self): 58 | try: 59 | return self._max_nodes_index 60 | except AttributeError: 61 | self._calculate_max_nodes() 62 | return self._max_nodes_index 63 | 64 | def cache_load_and_save(self, base_path, op, verbose): 65 | super().cache_load_and_save(base_path, op, verbose) 66 | max_nodes_path = base_path/'max_nodes_data.pt' 67 | 68 | if op == 'load': 69 | max_nodes_data = torch.load(str(max_nodes_path)) 70 | self._max_nodes = max_nodes_data['max_nodes'] 71 | self._max_nodes_index = max_nodes_data['max_nodes_index'] 72 | elif op == 'save': 73 | if verbose: print(f'Calculating {self.split} max nodes...',flush=True) 74 | max_nodes_data = 
{'max_nodes': self.max_nodes, 75 | 'max_nodes_index': self.max_nodes_index} 76 | torch.save(max_nodes_data, str(max_nodes_path)) 77 | else: 78 | raise ValueError(f'Unknown operation: {op}') 79 | 80 | def max_batch(self, batch_size, collate_fn): 81 | return collate_fn([self.__getitem__(self.max_nodes_index)] * batch_size) 82 | 83 | 84 | 85 | def graphdata_collate(batch): 86 | batch_data = defaultdict(List) 87 | for elem in batch: 88 | for k,v in elem.items(): 89 | batch_data[k].append(v) 90 | 91 | out = {k:torch.from_numpy(stack_with_pad(dat)) 92 | for k, dat in batch_data.items()} 93 | return out 94 | -------------------------------------------------------------------------------- /lib/data/graph_dataset/stack_with_pad.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import numba as nb 3 | 4 | 5 | @nb.njit 6 | def stack_with_pad_4d(inputs): 7 | num_elem = len(inputs) 8 | ms_0, ms_1, ms_2, ms_3 = inputs[0].shape 9 | 10 | for i in range(1,num_elem): 11 | is_0, is_1, is_2, is_3 = inputs[i].shape 12 | ms_0 = max(is_0, ms_0) 13 | ms_1 = max(is_1, ms_1) 14 | ms_2 = max(is_2, ms_2) 15 | ms_3 = max(is_3, ms_3) 16 | 17 | stacked_shape = (num_elem,ms_0,ms_1,ms_2,ms_3) 18 | stacked = np.zeros(stacked_shape, dtype=inputs[0].dtype) 19 | 20 | for i, elem in enumerate(inputs): 21 | stacked[i][:elem.shape[0],:elem.shape[1],:elem.shape[2],:elem.shape[3]] = elem 22 | return stacked 23 | 24 | @nb.njit 25 | def stack_with_pad_3d(inputs): 26 | num_elem = len(inputs) 27 | ms_0, ms_1, ms_2 = inputs[0].shape 28 | 29 | for i in range(1,num_elem): 30 | is_0, is_1, is_2 = inputs[i].shape 31 | ms_0 = max(is_0, ms_0) 32 | ms_1 = max(is_1, ms_1) 33 | ms_2 = max(is_2, ms_2) 34 | 35 | stacked_shape = (num_elem,ms_0,ms_1,ms_2) 36 | stacked = np.zeros(stacked_shape, dtype=inputs[0].dtype) 37 | 38 | for i, elem in enumerate(inputs): 39 | stacked[i][:elem.shape[0],:elem.shape[1],:elem.shape[2]] = elem 40 | return stacked 41 | 42 | @nb.njit 43 | def stack_with_pad_2d(inputs): 44 | num_elem = len(inputs) 45 | ms_0, ms_1 = inputs[0].shape 46 | 47 | for i in range(1,num_elem): 48 | is_0, is_1 = inputs[i].shape 49 | ms_0 = max(is_0, ms_0) 50 | ms_1 = max(is_1, ms_1) 51 | 52 | stacked_shape = (num_elem,ms_0,ms_1) 53 | stacked = np.zeros(stacked_shape, dtype=inputs[0].dtype) 54 | 55 | for i, elem in enumerate(inputs): 56 | stacked[i][:elem.shape[0],:elem.shape[1]] = elem 57 | return stacked 58 | 59 | @nb.njit 60 | def stack_with_pad_1d(inputs): 61 | num_elem = len(inputs) 62 | ms_0 = inputs[0].shape[0] 63 | 64 | for i in range(1,num_elem): 65 | is_0 = inputs[i].shape[0] 66 | ms_0 = max(is_0, ms_0) 67 | 68 | stacked_shape = (num_elem,ms_0) 69 | stacked = np.zeros(stacked_shape, dtype=inputs[0].dtype) 70 | 71 | for i, elem in enumerate(inputs): 72 | stacked[i][:elem.shape[0]] = elem 73 | return stacked 74 | 75 | 76 | def stack_with_pad(inputs): 77 | shape_rank = np.ndim(inputs[0]) 78 | if shape_rank == 0: 79 | return np.stack(inputs) 80 | if shape_rank == 1: 81 | return stack_with_pad_1d(inputs) 82 | elif shape_rank == 2: 83 | return stack_with_pad_2d(inputs) 84 | elif shape_rank == 3: 85 | return stack_with_pad_3d(inputs) 86 | elif shape_rank == 4: 87 | return stack_with_pad_4d(inputs) 88 | else: 89 | raise ValueError('Only support up to 4D tensor') 90 | 91 | 92 | -------------------------------------------------------------------------------- /lib/data/graph_dataset/structural_dataset.py: -------------------------------------------------------------------------------- 1 
| import numpy as np 2 | import numba as nb 3 | 4 | from .graph_dataset import GraphDataset 5 | 6 | NODE_FEATURES_OFFSET = 128 7 | EDGE_FEATURES_OFFSET = 8 8 | 9 | @nb.njit 10 | def floyd_warshall(A): 11 | n = A.shape[0] 12 | D = np.zeros((n,n), dtype=np.int16) 13 | 14 | for i in range(n): 15 | for j in range(n): 16 | if i == j: 17 | pass 18 | elif A[i,j] == 0: 19 | D[i,j] = 510 20 | else: 21 | D[i,j] = 1 22 | 23 | for k in range(n): 24 | for i in range(n): 25 | for j in range(n): 26 | old_dist = D[i,j] 27 | new_dist = D[i,k] + D[k,j] 28 | if new_dist < old_dist: 29 | D[i,j] = new_dist 30 | return D 31 | 32 | @nb.njit 33 | def preprocess_data(num_nodes, edges, node_feats, edge_feats): 34 | node_feats = node_feats + np.arange(1,node_feats.shape[-1]*NODE_FEATURES_OFFSET+1, 35 | NODE_FEATURES_OFFSET,dtype=np.int16) 36 | edge_feats = edge_feats + np.arange(1,edge_feats.shape[-1]*EDGE_FEATURES_OFFSET+1, 37 | EDGE_FEATURES_OFFSET,dtype=np.int16) 38 | 39 | A = np.zeros((num_nodes,num_nodes),dtype=np.int16) 40 | E = np.zeros((num_nodes,num_nodes,edge_feats.shape[-1]),dtype=np.int16) 41 | for k in range(edges.shape[0]): 42 | i,j = edges[k,0], edges[k,1] 43 | A[i,j] = 1 44 | E[i,j] = edge_feats[k] 45 | 46 | D = floyd_warshall(A) 47 | return node_feats, D, E 48 | 49 | 50 | class StructuralDataset(GraphDataset): 51 | def __init__(self, 52 | distance_matrix_key = 'distance_matrix', 53 | feature_matrix_key = 'feature_matrix', 54 | **kwargs): 55 | super().__init__(**kwargs) 56 | self.distance_matrix_key = distance_matrix_key 57 | self.feature_matrix_key = feature_matrix_key 58 | 59 | def __getitem__(self, index): 60 | item = super().__getitem__(index) 61 | 62 | num_nodes = int(item[self.num_nodes_key]) 63 | edges = item.pop(self.edges_key) 64 | node_feats = item.pop(self.node_features_key) 65 | edge_feats = item.pop(self.edge_features_key) 66 | 67 | node_feats, dist_mat, edge_feats_mat = preprocess_data(num_nodes, edges, node_feats, edge_feats) 68 | item[self.node_features_key] = node_feats 69 | item[self.distance_matrix_key] = dist_mat 70 | item[self.feature_matrix_key] = edge_feats_mat 71 | 72 | return item 73 | 74 | -------------------------------------------------------------------------------- /lib/data/graph_dataset/svd_encodings_dataset.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch 4 | from tqdm import trange 5 | import numba as nb 6 | 7 | from .graph_dataset import GraphDataset 8 | 9 | class SVDEncodingsDatasetBase: 10 | def __init__(self, 11 | svd_encodings_key = 'svd_encodings', 12 | calculated_dim = 8, 13 | output_dim = 8, 14 | random_neg_splits = ['training'], 15 | **kwargs): 16 | if output_dim > calculated_dim: 17 | raise ValueError('SVD: output_dim > calculated_dim') 18 | super().__init__(**kwargs) 19 | self.svd_encodings_key = svd_encodings_key 20 | self.calculated_dim = calculated_dim 21 | self.output_dim = output_dim 22 | self.random_neg_splits = random_neg_splits 23 | 24 | def calculate_encodings(self, item): 25 | raise NotImplementedError('SVDEncodingsDatasetBase.calculate_encodings()') 26 | 27 | def __getitem__(self, index): 28 | item = super().__getitem__(index) 29 | token = self.record_tokens[index] 30 | 31 | try: 32 | encodings = self._svd_encodings[token] 33 | except AttributeError: 34 | encodings = self.calculate_encodings(item) 35 | self._svd_encodings = {token:encodings} 36 | except KeyError: 37 | encodings = self.calculate_encodings(item) 38 | self._svd_encodings[token] = encodings 39 | 40 | if 
self.output_dim < self.calculated_dim: 41 | encodings = encodings[:,:self.output_dim,:] 42 | 43 | if self.split in self.random_neg_splits: 44 | rn_factors = np.random.randint(0, high=2, size=(encodings.shape[1],1))*2-1 #size=(encodings.shape[0],1,1) 45 | encodings = encodings * rn_factors.astype(encodings.dtype) 46 | 47 | item[self.svd_encodings_key] = encodings.reshape(encodings.shape[0],-1) 48 | return item 49 | 50 | def calculate_all_svd_encodings(self,verbose=1): 51 | self._svd_encodings = {} 52 | if verbose: 53 | print(f'Calculating all {self.split} SVD encodings...', flush=True) 54 | for index in trange(super().__len__()): 55 | item = super().__getitem__(index) 56 | token = self.record_tokens[index] 57 | self._svd_encodings[token] = self.calculate_encodings(item) 58 | else: 59 | for index in range(super().__len__()): 60 | item = super().__getitem__(index) 61 | token = self.record_tokens[index] 62 | self._svd_encodings[token] = self.calculate_encodings(item) 63 | 64 | def cache_load_and_save(self, base_path, op, verbose): 65 | super().cache_load_and_save(base_path, op, verbose) 66 | svd_encodings_path = base_path/'svd_encodings.pt' 67 | 68 | if op == 'load': 69 | self._svd_encodings = torch.load(str(svd_encodings_path)) 70 | elif op == 'save': 71 | if verbose: print(f'{self.split} SVD encodings cache does not exist! Cacheing...', flush=True) 72 | self.calculate_all_svd_encodings(verbose=verbose) 73 | torch.save(self._svd_encodings, str(svd_encodings_path)) 74 | if verbose: print(f'Saved {self.split} SVD encodings cache to disk.', flush=True) 75 | else: 76 | raise ValueError(f'Unknown operation: {op}') 77 | 78 | 79 | @nb.njit 80 | def calculate_svd_encodings(edges, num_nodes, calculated_dim): 81 | adj = np.zeros((num_nodes,num_nodes),dtype=np.float32) 82 | for i in range(edges.shape[0]): 83 | adj[nb.int64(edges[i,0]),nb.int64(edges[i,1])] = 1 84 | 85 | for i in range(num_nodes): 86 | adj[i,i] = 1 87 | u, s, vh = np.linalg.svd(adj) 88 | 89 | if calculated_dim < num_nodes: 90 | s = s[:calculated_dim] 91 | u = u[:,:calculated_dim] 92 | vh = vh[:calculated_dim,:] 93 | 94 | encodings = np.stack((u,vh.T),axis=-1) * np.expand_dims(np.sqrt(s), axis=-1) 95 | elif calculated_dim > num_nodes: 96 | z = np.zeros((num_nodes,calculated_dim-num_nodes,2),dtype=np.float32) 97 | encodings = np.concatenate((np.stack((u,vh.T),axis=-1) * np.expand_dims(np.sqrt(s), axis=-1), z), axis=1) 98 | else: 99 | encodings = np.stack((u,vh.T),axis=-1) * np.expand_dims(np.sqrt(s), axis=-1) 100 | return encodings 101 | 102 | 103 | class SVDEncodingsGraphDataset(SVDEncodingsDatasetBase, GraphDataset): 104 | def calculate_encodings(self, item): 105 | num_nodes = int(item[self.num_nodes_key]) 106 | edges = item[self.edges_key] 107 | encodings = calculate_svd_encodings(edges, num_nodes, self.calculated_dim) 108 | return encodings 109 | 110 | 111 | -------------------------------------------------------------------------------- /lib/data/molhiv/__init__.py: -------------------------------------------------------------------------------- 1 | from .data import * -------------------------------------------------------------------------------- /lib/data/molhiv/data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from ..dataset_base import DatasetBase 4 | from ..graph_dataset import GraphDataset 5 | from ..graph_dataset import SVDEncodingsGraphDataset 6 | from ..graph_dataset import StructuralDataset 7 | 8 | class MOLHIVDataset(DatasetBase): 9 | def 
__init__(self, 10 | dataset_path , 11 | dataset_name = 'MOLHIV' , 12 | **kwargs 13 | ): 14 | super().__init__(dataset_name = dataset_name, 15 | **kwargs) 16 | self.dataset_path = dataset_path 17 | 18 | @property 19 | def dataset(self): 20 | try: 21 | return self._dataset 22 | except AttributeError: 23 | from ogb.graphproppred import GraphPropPredDataset 24 | self._dataset = GraphPropPredDataset(name='ogbg-molhiv', root=self.dataset_path) 25 | return self._dataset 26 | 27 | @property 28 | def record_tokens(self): 29 | try: 30 | return self._record_tokens 31 | except AttributeError: 32 | split = {'training':'train', 33 | 'validation':'valid', 34 | 'test':'test'}[self.split] 35 | self._record_tokens = self.dataset.get_idx_split()[split] 36 | return self._record_tokens 37 | 38 | def read_record(self, token): 39 | graph, target = self.dataset[token] 40 | graph['num_nodes'] = np.array(graph['num_nodes'], dtype=np.int16) 41 | graph['edges'] = graph.pop('edge_index').T.astype(np.int16) 42 | graph['edge_features'] = graph.pop('edge_feat').astype(np.int16) 43 | graph['node_features'] = graph.pop('node_feat').astype(np.int16) 44 | graph['target'] = np.array(target, np.float32) 45 | return graph 46 | 47 | 48 | 49 | class MOLHIVGraphDataset(GraphDataset,MOLHIVDataset): 50 | pass 51 | 52 | class MOLHIVSVDGraphDataset(SVDEncodingsGraphDataset,MOLHIVDataset): 53 | pass 54 | 55 | class MOLHIVStructuralGraphDataset(StructuralDataset,MOLHIVGraphDataset): 56 | pass 57 | 58 | class MOLHIVStructuralSVDGraphDataset(StructuralDataset,MOLHIVSVDGraphDataset): 59 | pass 60 | -------------------------------------------------------------------------------- /lib/data/molpcba/__init__.py: -------------------------------------------------------------------------------- 1 | from .data import * -------------------------------------------------------------------------------- /lib/data/molpcba/data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from ..dataset_base import DatasetBase 4 | from ..graph_dataset import GraphDataset 5 | from ..graph_dataset import SVDEncodingsGraphDataset 6 | from ..graph_dataset import StructuralDataset 7 | 8 | class MOLPCBADataset(DatasetBase): 9 | def __init__(self, 10 | dataset_path , 11 | dataset_name = 'MOLPCBA' , 12 | **kwargs 13 | ): 14 | super().__init__(dataset_name = dataset_name, 15 | **kwargs) 16 | self.dataset_path = dataset_path 17 | 18 | @property 19 | def dataset(self): 20 | try: 21 | return self._dataset 22 | except AttributeError: 23 | from ogb.graphproppred import GraphPropPredDataset 24 | self._dataset = GraphPropPredDataset(name='ogbg-molpcba', root=self.dataset_path) 25 | return self._dataset 26 | 27 | @property 28 | def record_tokens(self): 29 | try: 30 | return self._record_tokens 31 | except AttributeError: 32 | split = {'training':'train', 33 | 'validation':'valid', 34 | 'test':'test'}[self.split] 35 | self._record_tokens = self.dataset.get_idx_split()[split] 36 | return self._record_tokens 37 | 38 | def read_record(self, token): 39 | graph, target = self.dataset[token] 40 | graph['num_nodes'] = np.array(graph['num_nodes'], dtype=np.int16) 41 | graph['edges'] = graph.pop('edge_index').T.astype(np.int16) 42 | graph['edge_features'] = graph.pop('edge_feat').astype(np.int16) 43 | graph['node_features'] = graph.pop('node_feat').astype(np.int16) 44 | graph['target'] = np.array(target, np.float32) 45 | return graph 46 | 47 | 48 | 49 | class MOLPCBAGraphDataset(GraphDataset,MOLPCBADataset): 50 | 
pass 51 | 52 | class MOLPCBASVDGraphDataset(SVDEncodingsGraphDataset,MOLPCBADataset): 53 | pass 54 | 55 | class MOLPCBAStructuralGraphDataset(StructuralDataset,MOLPCBAGraphDataset): 56 | pass 57 | 58 | class MOLPCBAStructuralSVDGraphDataset(StructuralDataset,MOLPCBASVDGraphDataset): 59 | pass 60 | -------------------------------------------------------------------------------- /lib/data/pcqm4m/__init__.py: -------------------------------------------------------------------------------- 1 | from .data import * -------------------------------------------------------------------------------- /lib/data/pcqm4m/data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from ..dataset_base import DatasetBase 4 | from ..graph_dataset import GraphDataset 5 | from ..graph_dataset import SVDEncodingsGraphDataset 6 | from ..graph_dataset import StructuralDataset 7 | 8 | class PCQM4MDataset(DatasetBase): 9 | def __init__(self, 10 | dataset_path , 11 | dataset_name = 'PCQM4M', 12 | **kwargs 13 | ): 14 | super().__init__(dataset_name = dataset_name, 15 | **kwargs) 16 | self.dataset_path = dataset_path 17 | 18 | @property 19 | def dataset(self): 20 | try: 21 | return self._dataset 22 | except AttributeError: 23 | from ogb.lsc import PCQM4MDataset 24 | from ogb.utils import smiles2graph 25 | self._smiles2graph = smiles2graph 26 | self._dataset = PCQM4MDataset(root = self.dataset_path, only_smiles=True) 27 | return self._dataset 28 | 29 | @property 30 | def record_tokens(self): 31 | try: 32 | return self._record_tokens 33 | except AttributeError: 34 | split = {'training':'train', 35 | 'validation':'valid', 36 | 'test':'test'}[self.split] 37 | self._record_tokens = self.dataset.get_idx_split()[split] 38 | return self._record_tokens 39 | 40 | def read_record(self, token): 41 | smiles, target = self.dataset[token] 42 | graph = self._smiles2graph(smiles) 43 | graph['num_nodes'] = np.array(graph['num_nodes'], dtype=np.int16) 44 | graph['edges'] = graph.pop('edge_index').T.astype(np.int16) 45 | graph['edge_features'] = graph.pop('edge_feat').astype(np.int16) 46 | graph['node_features'] = graph.pop('node_feat').astype(np.int16) 47 | graph['target'] = np.array(target, np.float32) 48 | return graph 49 | 50 | 51 | 52 | class PCQM4MGraphDataset(GraphDataset,PCQM4MDataset): 53 | pass 54 | 55 | class PCQM4MSVDGraphDataset(SVDEncodingsGraphDataset,PCQM4MDataset): 56 | pass 57 | 58 | class PCQM4MStructuralGraphDataset(StructuralDataset,PCQM4MGraphDataset): 59 | pass 60 | 61 | class PCQM4MStructuralSVDGraphDataset(StructuralDataset,PCQM4MSVDGraphDataset): 62 | pass 63 | -------------------------------------------------------------------------------- /lib/data/pcqm4mv2/__init__.py: -------------------------------------------------------------------------------- 1 | from .data import * -------------------------------------------------------------------------------- /lib/data/pcqm4mv2/data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from ..dataset_base import DatasetBase 4 | from ..graph_dataset import GraphDataset 5 | from ..graph_dataset import SVDEncodingsGraphDataset 6 | from ..graph_dataset import StructuralDataset 7 | 8 | class PCQM4Mv2Dataset(DatasetBase): 9 | def __init__(self, 10 | dataset_path , 11 | dataset_name = 'PCQM4MV2', 12 | **kwargs 13 | ): 14 | super().__init__(dataset_name = dataset_name, 15 | **kwargs) 16 | self.dataset_path = dataset_path 17 | 18 
| @property 19 | def dataset(self): 20 | try: 21 | return self._dataset 22 | except AttributeError: 23 | from ogb.lsc import PCQM4Mv2Dataset 24 | from ogb.utils import smiles2graph 25 | self._smiles2graph = smiles2graph 26 | self._dataset = PCQM4Mv2Dataset(root = self.dataset_path, only_smiles=True) 27 | return self._dataset 28 | 29 | @property 30 | def record_tokens(self): 31 | try: 32 | return self._record_tokens 33 | except AttributeError: 34 | split = {'training':'train', 35 | 'validation':'valid', 36 | 'test':'test-dev', 37 | 'challenge': 'test-challenge'}[self.split] 38 | self._record_tokens = self.dataset.get_idx_split()[split] 39 | return self._record_tokens 40 | 41 | def read_record(self, token): 42 | smiles, target = self.dataset[token] 43 | graph = self._smiles2graph(smiles) 44 | graph['num_nodes'] = np.array(graph['num_nodes'], dtype=np.int16) 45 | graph['edges'] = graph.pop('edge_index').T.astype(np.int16) 46 | graph['edge_features'] = graph.pop('edge_feat').astype(np.int16) 47 | graph['node_features'] = graph.pop('node_feat').astype(np.int16) 48 | graph['target'] = np.array(target, np.float32) 49 | return graph 50 | 51 | 52 | 53 | class PCQM4Mv2GraphDataset(GraphDataset,PCQM4Mv2Dataset): 54 | pass 55 | 56 | class PCQM4Mv2SVDGraphDataset(SVDEncodingsGraphDataset,PCQM4Mv2Dataset): 57 | pass 58 | 59 | class PCQM4Mv2StructuralGraphDataset(StructuralDataset,PCQM4Mv2GraphDataset): 60 | pass 61 | 62 | class PCQM4Mv2StructuralSVDGraphDataset(StructuralDataset,PCQM4Mv2SVDGraphDataset): 63 | pass 64 | -------------------------------------------------------------------------------- /lib/models/egt.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | from .egt_layers import EGT_Layer, Graph 6 | 7 | 8 | class EGT_Base(nn.Module): 9 | def __init__(self, 10 | node_width = 128 , 11 | edge_width = 32 , 12 | num_heads = 8 , 13 | model_height = 4 , 14 | node_mha_dropout = 0. , 15 | node_ffn_dropout = 0. , 16 | edge_mha_dropout = 0. , 17 | edge_ffn_dropout = 0. , 18 | attn_dropout = 0. , 19 | attn_maskout = 0. , 20 | activation = 'elu' , 21 | clip_logits_value = [-5,5] , 22 | node_ffn_multiplier = 2. , 23 | edge_ffn_multiplier = 2. 
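                 # scale_dot: scale attention logits by 1/sqrt(dot_dim);
                 # scale_degree: scale attended values by log(1 + degree);
                 # egt_simple: disable the edge-channel updates in the stacked
                 # layers (see EGT_Layer in egt_layers.py)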
, 24 | scale_dot = True , 25 | scale_degree = False , 26 | node_ended = False , 27 | edge_ended = False , 28 | egt_simple = False , 29 | **kwargs 30 | ): 31 | super().__init__(**kwargs) 32 | 33 | self.node_width = node_width 34 | self.edge_width = edge_width 35 | self.num_heads = num_heads 36 | self.model_height = model_height 37 | self.node_mha_dropout = node_mha_dropout 38 | self.node_ffn_dropout = node_ffn_dropout 39 | self.edge_mha_dropout = edge_mha_dropout 40 | self.edge_ffn_dropout = edge_ffn_dropout 41 | self.attn_dropout = attn_dropout 42 | self.attn_maskout = attn_maskout 43 | self.activation = activation 44 | self.clip_logits_value = clip_logits_value 45 | self.node_ffn_multiplier = node_ffn_multiplier 46 | self.edge_ffn_multiplier = edge_ffn_multiplier 47 | self.scale_dot = scale_dot 48 | self.scale_degree = scale_degree 49 | self.node_ended = node_ended 50 | self.edge_ended = edge_ended 51 | self.egt_simple = egt_simple 52 | 53 | self.layer_common_kwargs = dict( 54 | node_width = self.node_width , 55 | edge_width = self.edge_width , 56 | num_heads = self.num_heads , 57 | node_mha_dropout = self.node_mha_dropout , 58 | node_ffn_dropout = self.node_ffn_dropout , 59 | edge_mha_dropout = self.edge_mha_dropout , 60 | edge_ffn_dropout = self.edge_ffn_dropout , 61 | attn_dropout = self.attn_dropout , 62 | attn_maskout = self.attn_maskout , 63 | activation = self.activation , 64 | clip_logits_value = self.clip_logits_value , 65 | scale_dot = self.scale_dot , 66 | scale_degree = self.scale_degree , 67 | node_ffn_multiplier = self.node_ffn_multiplier , 68 | edge_ffn_multiplier = self.edge_ffn_multiplier , 69 | ) 70 | 71 | def input_block(self, inputs): 72 | return Graph(inputs) 73 | 74 | def final_embedding(self, g): 75 | raise NotImplementedError 76 | 77 | def output_block(self, g): 78 | raise NotImplementedError 79 | 80 | def forward(self, inputs): 81 | raise NotImplementedError 82 | 83 | 84 | 85 | 86 | class EGT(EGT_Base): 87 | def __init__(self, **kwargs): 88 | super().__init__(**kwargs) 89 | 90 | self.EGT_layers = nn.ModuleList([EGT_Layer(**self.layer_common_kwargs, 91 | edge_update=(not self.egt_simple)) 92 | for _ in range(self.model_height-1)]) 93 | 94 | if (not self.node_ended) and (not self.edge_ended): 95 | pass 96 | elif not self.node_ended: 97 | self.EGT_layers.append(EGT_Layer(**self.layer_common_kwargs, node_update = False)) 98 | elif not self.edge_ended: 99 | self.EGT_layers.append(EGT_Layer(**self.layer_common_kwargs, edge_update = False)) 100 | else: 101 | self.EGT_layers.append(EGT_Layer(**self.layer_common_kwargs)) 102 | 103 | def forward(self, inputs): 104 | g = self.input_block(inputs) 105 | 106 | for layer in self.EGT_layers: 107 | g = layer(g) 108 | 109 | g = self.final_embedding(g) 110 | 111 | outputs = self.output_block(g) 112 | return outputs 113 | 114 | -------------------------------------------------------------------------------- /lib/models/egt_layers.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | from typing import Tuple, Optional 6 | 7 | class Graph(dict): 8 | def __dir__(self): 9 | return super().__dir__() + list(self.keys()) 10 | 11 | def __getattr__(self, key): 12 | try: 13 | return self[key] 14 | except KeyError: 15 | raise AttributeError('No such attribute: '+key) 16 | 17 | def __setattr__(self, key, value): 18 | self[key]=value 19 | 20 | def copy(self): 21 | return self.__class__(self) 22 | 23 | 24 | 25 | class 
EGT_Layer(nn.Module): 26 | @staticmethod 27 | @torch.jit.script 28 | def _egt(scale_dot: bool, 29 | scale_degree: bool, 30 | num_heads: int, 31 | dot_dim: int, 32 | clip_logits_min: float, 33 | clip_logits_max: float, 34 | attn_dropout: float, 35 | attn_maskout: float, 36 | training: bool, 37 | num_vns: int, 38 | QKV: torch.Tensor, 39 | G: torch.Tensor, 40 | E: torch.Tensor, 41 | mask: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: 42 | shp = QKV.shape 43 | Q, K, V = QKV.view(shp[0],shp[1],-1,num_heads).split(dot_dim,dim=2) 44 | 45 | A_hat = torch.einsum('bldh,bmdh->blmh', Q, K) 46 | if scale_dot: 47 | A_hat = A_hat * (dot_dim ** -0.5) 48 | 49 | H_hat = A_hat.clamp(clip_logits_min, clip_logits_max) + E 50 | 51 | if mask is None: 52 | if attn_maskout > 0 and training: 53 | rmask = torch.empty_like(H_hat).bernoulli_(attn_maskout) * -1e9 54 | gates = torch.sigmoid(G)#+rmask 55 | A_tild = F.softmax(H_hat+rmask, dim=2) * gates 56 | else: 57 | gates = torch.sigmoid(G) 58 | A_tild = F.softmax(H_hat, dim=2) * gates 59 | else: 60 | if attn_maskout > 0 and training: 61 | rmask = torch.empty_like(H_hat).bernoulli_(attn_maskout) * -1e9 62 | gates = torch.sigmoid(G+mask) 63 | A_tild = F.softmax(H_hat+mask+rmask, dim=2) * gates 64 | else: 65 | gates = torch.sigmoid(G+mask) 66 | A_tild = F.softmax(H_hat+mask, dim=2) * gates 67 | 68 | if attn_dropout > 0: 69 | A_tild = F.dropout(A_tild, p=attn_dropout, training=training) 70 | 71 | V_att = torch.einsum('blmh,bmkh->blkh', A_tild, V) 72 | 73 | if scale_degree: 74 | degrees = torch.sum(gates,dim=2,keepdim=True) 75 | degree_scalers = torch.log(1+degrees) 76 | degree_scalers[:,:num_vns] = 1. 77 | V_att = V_att * degree_scalers 78 | 79 | V_att = V_att.reshape(shp[0],shp[1],num_heads*dot_dim) 80 | return V_att, H_hat 81 | 82 | @staticmethod 83 | @torch.jit.script 84 | def _egt_edge(scale_dot: bool, 85 | num_heads: int, 86 | dot_dim: int, 87 | clip_logits_min: float, 88 | clip_logits_max: float, 89 | QK: torch.Tensor, 90 | E: torch.Tensor) -> torch.Tensor: 91 | shp = QK.shape 92 | Q, K = QK.view(shp[0],shp[1],-1,num_heads).split(dot_dim,dim=2) 93 | 94 | A_hat = torch.einsum('bldh,bmdh->blmh', Q, K) 95 | if scale_dot: 96 | A_hat = A_hat * (dot_dim ** -0.5) 97 | H_hat = A_hat.clamp(clip_logits_min, clip_logits_max) + E 98 | return H_hat 99 | 100 | def __init__(self, 101 | node_width , 102 | edge_width , 103 | num_heads , 104 | node_mha_dropout = 0 , 105 | edge_mha_dropout = 0 , 106 | node_ffn_dropout = 0 , 107 | edge_ffn_dropout = 0 , 108 | attn_dropout = 0 , 109 | attn_maskout = 0 , 110 | activation = 'elu' , 111 | clip_logits_value = [-5,5] , 112 | node_ffn_multiplier = 2. , 113 | edge_ffn_multiplier = 2. 
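# _egt above is the fused node-channel update: the raw logits A_hat = Q.K^T
# (scaled by dot_dim**-0.5 when scale_dot) are clipped and biased by the
# edge term E to give H_hat, and the softmax over keys is gated elementwise
# by sigmoid(G), so the edge channels control how much information flows
# between each node pair. With scale_degree, the aggregated values are
# rescaled by log(1 + sum(gates)), a soft degree scaler, exempting the first
# num_vns virtual nodes. H_hat is returned as well so the edge channels can
# be updated from the same pre-softmax logits.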
, 114 | scale_dot = True , 115 | scale_degree = False , 116 | node_update = True , 117 | edge_update = True , 118 | ): 119 | super().__init__() 120 | self.node_width = node_width 121 | self.edge_width = edge_width 122 | self.num_heads = num_heads 123 | self.node_mha_dropout = node_mha_dropout 124 | self.edge_mha_dropout = edge_mha_dropout 125 | self.node_ffn_dropout = node_ffn_dropout 126 | self.edge_ffn_dropout = edge_ffn_dropout 127 | self.attn_dropout = attn_dropout 128 | self.attn_maskout = attn_maskout 129 | self.activation = activation 130 | self.clip_logits_value = clip_logits_value 131 | self.node_ffn_multiplier = node_ffn_multiplier 132 | self.edge_ffn_multiplier = edge_ffn_multiplier 133 | self.scale_dot = scale_dot 134 | self.scale_degree = scale_degree 135 | self.node_update = node_update 136 | self.edge_update = edge_update 137 | 138 | assert not (self.node_width % self.num_heads) 139 | self.dot_dim = self.node_width//self.num_heads 140 | 141 | self.mha_ln_h = nn.LayerNorm(self.node_width) 142 | self.mha_ln_e = nn.LayerNorm(self.edge_width) 143 | self.lin_E = nn.Linear(self.edge_width, self.num_heads) 144 | if self.node_update: 145 | self.lin_QKV = nn.Linear(self.node_width, self.node_width*3) 146 | self.lin_G = nn.Linear(self.edge_width, self.num_heads) 147 | else: 148 | self.lin_QKV = nn.Linear(self.node_width, self.node_width*2) 149 | 150 | self.ffn_fn = getattr(F, self.activation) 151 | if self.node_update: 152 | self.lin_O_h = nn.Linear(self.node_width, self.node_width) 153 | if self.node_mha_dropout > 0: 154 | self.mha_drp_h = nn.Dropout(self.node_mha_dropout) 155 | 156 | node_inner_dim = round(self.node_width*self.node_ffn_multiplier) 157 | self.ffn_ln_h = nn.LayerNorm(self.node_width) 158 | self.lin_W_h_1 = nn.Linear(self.node_width, node_inner_dim) 159 | self.lin_W_h_2 = nn.Linear(node_inner_dim, self.node_width) 160 | if self.node_ffn_dropout > 0: 161 | self.ffn_drp_h = nn.Dropout(self.node_ffn_dropout) 162 | 163 | if self.edge_update: 164 | self.lin_O_e = nn.Linear(self.num_heads, self.edge_width) 165 | if self.edge_mha_dropout > 0: 166 | self.mha_drp_e = nn.Dropout(self.edge_mha_dropout) 167 | 168 | edge_inner_dim = round(self.edge_width*self.edge_ffn_multiplier) 169 | self.ffn_ln_e = nn.LayerNorm(self.edge_width) 170 | self.lin_W_e_1 = nn.Linear(self.edge_width, edge_inner_dim) 171 | self.lin_W_e_2 = nn.Linear(edge_inner_dim, self.edge_width) 172 | if self.edge_ffn_dropout > 0: 173 | self.ffn_drp_e = nn.Dropout(self.edge_ffn_dropout) 174 | 175 | def forward(self, g): 176 | h, e = g.h, g.e 177 | mask = g.mask 178 | 179 | h_r1 = h 180 | e_r1 = e 181 | 182 | h_ln = self.mha_ln_h(h) 183 | e_ln = self.mha_ln_e(e) 184 | 185 | QKV = self.lin_QKV(h_ln) 186 | E = self.lin_E(e_ln) 187 | 188 | if self.node_update: 189 | G = self.lin_G(e_ln) 190 | V_att, H_hat = self._egt(self.scale_dot, 191 | self.scale_degree, 192 | self.num_heads, 193 | self.dot_dim, 194 | self.clip_logits_value[0], 195 | self.clip_logits_value[1], 196 | self.attn_dropout, 197 | self.attn_maskout, 198 | self.training, 199 | 0 if 'num_vns' not in g else g.num_vns, 200 | QKV, 201 | G, E, mask) 202 | 203 | h = self.lin_O_h(V_att) 204 | if self.node_mha_dropout > 0: 205 | h = self.mha_drp_h(h) 206 | h.add_(h_r1) 207 | 208 | h_r2 = h 209 | h_ln = self.ffn_ln_h(h) 210 | h = self.lin_W_h_2(self.ffn_fn(self.lin_W_h_1(h_ln))) 211 | if self.node_ffn_dropout > 0: 212 | h = self.ffn_drp_h(h) 213 | h.add_(h_r2) 214 | else: 215 | H_hat = self._egt_edge(self.scale_dot, 216 | self.num_heads, 217 | self.dot_dim, 218 | 
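# Edge-only path: with node_update=False the layer still forms Q and K, but
# _egt_edge only recomputes the clipped logits H_hat = clip(Q.K^T) + E and
# skips value aggregation, so only the edge channels are refreshed below.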
self.clip_logits_value[0], 219 | self.clip_logits_value[1], 220 | QKV, E) 221 | 222 | 223 | if self.edge_update: 224 | e = self.lin_O_e(H_hat) 225 | if self.edge_mha_dropout > 0: 226 | e = self.mha_drp_e(e) 227 | e.add_(e_r1) 228 | 229 | e_r2 = e 230 | e_ln = self.ffn_ln_e(e) 231 | e = self.lin_W_e_2(self.ffn_fn(self.lin_W_e_1(e_ln))) 232 | if self.edge_ffn_dropout > 0: 233 | e = self.ffn_drp_e(e) 234 | e.add_(e_r2) 235 | 236 | g = g.copy() 237 | g.h, g.e = h, e 238 | return g 239 | 240 | def __repr__(self): 241 | rep = super().__repr__() 242 | rep = (rep + ' (' 243 | + f'num_heads: {self.num_heads},' 244 | + f'activation: {self.activation},' 245 | + f'attn_maskout: {self.attn_maskout},' 246 | + f'attn_dropout: {self.attn_dropout}' 247 | +')') 248 | return rep 249 | 250 | 251 | 252 | class VirtualNodes(nn.Module): 253 | def __init__(self, node_width, edge_width, num_virtual_nodes = 1): 254 | super().__init__() 255 | self.node_width = node_width 256 | self.edge_width = edge_width 257 | self.num_virtual_nodes = num_virtual_nodes 258 | 259 | self.vn_node_embeddings = nn.Parameter(torch.empty(num_virtual_nodes, 260 | self.node_width)) 261 | self.vn_edge_embeddings = nn.Parameter(torch.empty(num_virtual_nodes, 262 | self.edge_width)) 263 | nn.init.normal_(self.vn_node_embeddings) 264 | nn.init.normal_(self.vn_edge_embeddings) 265 | 266 | def forward(self, g): 267 | h, e = g.h, g.e 268 | mask = g.mask 269 | 270 | node_emb = self.vn_node_embeddings.unsqueeze(0).expand(h.shape[0], -1, -1) 271 | h = torch.cat([node_emb, h], dim=1) 272 | 273 | e_shape = e.shape 274 | edge_emb_row = self.vn_edge_embeddings.unsqueeze(1) 275 | edge_emb_col = self.vn_edge_embeddings.unsqueeze(0) 276 | edge_emb_box = 0.5 * (edge_emb_row + edge_emb_col) 277 | 278 | edge_emb_row = edge_emb_row.unsqueeze(0).expand(e_shape[0], -1, e_shape[2], -1) 279 | edge_emb_col = edge_emb_col.unsqueeze(0).expand(e_shape[0], e_shape[1], -1, -1) 280 | edge_emb_box = edge_emb_box.unsqueeze(0).expand(e_shape[0], -1, -1, -1) 281 | 282 | e = torch.cat([edge_emb_row, e], dim=1) 283 | e_col_box = torch.cat([edge_emb_box, edge_emb_col], dim=1) 284 | e = torch.cat([e_col_box, e], dim=2) 285 | 286 | g = g.copy() 287 | g.h, g.e = h, e 288 | 289 | g.num_vns = self.num_virtual_nodes 290 | 291 | if mask is not None: 292 | g.mask = F.pad(mask, (0,0, self.num_virtual_nodes,0, self.num_virtual_nodes,0), 293 | mode='constant', value=0) 294 | return g 295 | 296 | -------------------------------------------------------------------------------- /lib/models/egt_molgraph.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from .egt import EGT 5 | from .egt_layers import VirtualNodes 6 | 7 | 8 | NODE_FEATURES_OFFSET = 128 9 | NUM_NODE_FEATURES = 9 10 | EDGE_FEATURES_OFFSET = 8 11 | NUM_EDGE_FEATURES = 3 12 | 13 | 14 | class EGT_MOL(EGT): 15 | def __init__(self, 16 | upto_hop = 16, 17 | mlp_ratios = [1., 1.], 18 | num_virtual_nodes = 0, 19 | svd_encodings = 0, 20 | output_dim = 1, 21 | **kwargs): 22 | super().__init__(node_ended=True, **kwargs) 23 | 24 | self.upto_hop = upto_hop 25 | self.mlp_ratios = mlp_ratios 26 | self.num_virtual_nodes = num_virtual_nodes 27 | self.svd_encodings = svd_encodings 28 | self.output_dim = output_dim 29 | 30 | self.nodef_embed = nn.Embedding(NUM_NODE_FEATURES*NODE_FEATURES_OFFSET+1, 31 | self.node_width, padding_idx=0) 32 | if self.svd_encodings: 33 | self.svd_embed = nn.Linear(self.svd_encodings*2, self.node_width) 34 
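# Note on the offset-encoded embeddings here: the tables are sized
# NUM_*_FEATURES*OFFSET+1 because the data pipeline (per this sizing) shifts
# each categorical OGB feature into its own OFFSET-wide slice of one shared
# table; summing the per-feature embeddings in input_block below then adds
# one learned vector per atom/bond feature, with index 0 reserved for
# padding.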
| 35 | self.dist_embed = nn.Embedding(self.upto_hop+2, self.edge_width) 36 | self.featm_embed = nn.Embedding(NUM_EDGE_FEATURES*EDGE_FEATURES_OFFSET+1, 37 | self.edge_width, padding_idx=0) 38 | 39 | if self.num_virtual_nodes > 0: 40 | self.vn_layer = VirtualNodes(self.node_width, self.edge_width, 41 | self.num_virtual_nodes) 42 | 43 | self.final_ln_h = nn.LayerNorm(self.node_width) 44 | mlp_dims = [self.node_width * max(self.num_virtual_nodes, 1)]\ 45 | +[round(self.node_width*r) for r in self.mlp_ratios]\ 46 | +[self.output_dim] 47 | self.mlp_layers = nn.ModuleList([nn.Linear(mlp_dims[i],mlp_dims[i+1]) 48 | for i in range(len(mlp_dims)-1)]) 49 | self.mlp_fn = getattr(F, self.activation) 50 | 51 | 52 | def input_block(self, inputs): 53 | g = super().input_block(inputs) 54 | nodef = g.node_features.long() # (b,i,f) 55 | nodem = g.node_mask.float() # (b,i) 56 | 57 | dm0 = g.distance_matrix # (b,i,j) 58 | dm = dm0.long().clamp(max=self.upto_hop+1) # (b,i,j) 59 | featm = g.feature_matrix.long() # (b,i,j,f) 60 | 61 | h = self.nodef_embed(nodef).sum(dim=2) # (b,i,w,h) -> (b,i,h) 62 | 63 | if self.svd_encodings: 64 | h = h + self.svd_embed(g.svd_encodings) 65 | 66 | e = self.dist_embed(dm)\ 67 | + self.featm_embed(featm).sum(dim=3) # (b,i,j,f,e) -> (b,i,j,e) 68 | 69 | g.mask = (nodem[:,:,None,None] * nodem[:,None,:,None] - 1)*1e9 70 | g.h, g.e = h, e 71 | 72 | if self.num_virtual_nodes > 0: 73 | g = self.vn_layer(g) 74 | return g 75 | 76 | def final_embedding(self, g): 77 | h = g.h 78 | h = self.final_ln_h(h) 79 | if self.num_virtual_nodes > 0: 80 | h = h[:,:self.num_virtual_nodes].reshape(h.shape[0],-1) 81 | else: 82 | nodem = g.node_mask.float().unsqueeze(dim=-1) 83 | h = (h*nodem).sum(dim=1)/(nodem.sum(dim=1)+1e-9) 84 | g.h = h 85 | return g 86 | 87 | def output_block(self, g): 88 | h = g.h 89 | h = self.mlp_layers[0](h) 90 | for layer in self.mlp_layers[1:]: 91 | h = layer(self.mlp_fn(h)) 92 | return h 93 | 94 | 95 | -------------------------------------------------------------------------------- /lib/models/molhiv/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import * -------------------------------------------------------------------------------- /lib/models/molhiv/model.py: -------------------------------------------------------------------------------- 1 | from ..egt_molgraph import EGT_MOL 2 | 3 | class EGT_MOLHIV(EGT_MOL): 4 | def __init__(self, **kwargs): 5 | super().__init__(output_dim=1, **kwargs) 6 | -------------------------------------------------------------------------------- /lib/models/molpcba/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import * -------------------------------------------------------------------------------- /lib/models/molpcba/model.py: -------------------------------------------------------------------------------- 1 | from ..egt_molgraph import EGT_MOL 2 | 3 | class EGT_MOLPCBA(EGT_MOL): 4 | def __init__(self, **kwargs): 5 | super().__init__(output_dim=128, **kwargs) -------------------------------------------------------------------------------- /lib/models/pcqm4m/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import * -------------------------------------------------------------------------------- /lib/models/pcqm4m/model.py: -------------------------------------------------------------------------------- 1 | from ..egt_molgraph import EGT_MOL 2 | 3 | class 
EGT_PCQM4M(EGT_MOL): 4 | def __init__(self, **kwargs): 5 | super().__init__(output_dim=1, **kwargs) 6 | 7 | def output_block(self, g): 8 | h = super().output_block(g) 9 | return h.squeeze(-1) 10 | -------------------------------------------------------------------------------- /lib/models/pcqm4mv2/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import * -------------------------------------------------------------------------------- /lib/models/pcqm4mv2/model.py: -------------------------------------------------------------------------------- 1 | from ..egt_molgraph import EGT_MOL 2 | 3 | class EGT_PCQM4MV2(EGT_MOL): 4 | def __init__(self, **kwargs): 5 | super().__init__(output_dim=1, **kwargs) 6 | 7 | def output_block(self, g): 8 | h = super().output_block(g) 9 | return h.squeeze(-1) 10 | -------------------------------------------------------------------------------- /lib/training/execute.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | from yaml import SafeLoader as yaml_Loader 4 | import numpy as np 5 | import torch 6 | import random 7 | from lib.training.training import read_config_from_file 8 | import importlib 9 | 10 | MASTER_ADDR = 'localhost' 11 | MASTER_PORT = '12356' 12 | 13 | SCHEME_LIB = 'lib.training.schemes' 14 | SCHEME_CLS = 'SCHEME' 15 | 16 | KEY_SCHEME = 'scheme' 17 | KEY_SEED = 'random_seed' 18 | KEY_DISTRIBUTED = 'distributed' 19 | 20 | COMMANDS = { 21 | 'train': 'execute_training', 22 | 'predict': 'make_predictions', 23 | 'evaluate': 'do_evaluations', 24 | } 25 | 26 | DEFAULT_CONFIG_FILE = 'config_input.yaml' 27 | 28 | def get_configs_from_args(args): 29 | config = {} 30 | args = args[1:].copy() 31 | 32 | if os.path.isfile(args[0]): 33 | config.update(read_config_from_file(args[0])) 34 | args = args[1:] 35 | elif os.path.isdir(args[0]): 36 | config_path = os.path.join(args[0], 'config_input.yaml') 37 | config.update(read_config_from_file(config_path)) 38 | args = args[1:] 39 | 40 | if len(args)>0: 41 | additional_configs = yaml.load('\n'.join(args), 42 | Loader=yaml_Loader) 43 | config.update(additional_configs) 44 | 45 | if not KEY_SCHEME in config: 46 | raise ValueError(f'"{KEY_SCHEME}" is not in config!') 47 | return config 48 | 49 | def import_scheme(scheme_name): 50 | full_name = f'{SCHEME_LIB}.{scheme_name}.{SCHEME_CLS}' 51 | module_name, object_name = full_name.rsplit('.', 1) 52 | imported_module = importlib.import_module(module_name) 53 | return getattr(imported_module, object_name) 54 | 55 | 56 | def run_worker(rank, world_size, command, scheme_class, config, seed): 57 | torch.cuda.set_device(rank) 58 | torch.manual_seed(seed) 59 | random.seed(seed) 60 | np.random.seed(seed) 61 | torch.distributed.init_process_group(backend="nccl", 62 | rank=rank, 63 | world_size=world_size) 64 | 65 | print(f'Initiated rank: {rank}', flush=True) 66 | try: 67 | scheme = scheme_class(config, rank, world_size) 68 | getattr(scheme, COMMANDS[command])() 69 | finally: 70 | torch.distributed.destroy_process_group() 71 | print(f'Rank {rank}:Destroyed process!', flush=True) 72 | 73 | 74 | def execute(command, config): 75 | scheme_class = import_scheme(config[KEY_SCHEME]) 76 | 77 | world_size = torch.cuda.device_count() 78 | 79 | if KEY_SEED in config and config[KEY_SEED] is not None: 80 | seed = config[KEY_SEED] 81 | else: 82 | seed = random.randint(0, 100000) 83 | 84 | if KEY_DISTRIBUTED in config and config[KEY_DISTRIBUTED] and world_size>1: 85 | 
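# Distributed branch: one worker process is spawned per visible GPU and
# joined into a NCCL process group at MASTER_ADDR:MASTER_PORT (set just
# below). An illustrative invocation (the trailing YAML override is
# hypothetical):
#   python run_training.py configs/pcqm4mv2/egt_90m.yaml 'distributed: true'
# Any arguments after the config file are parsed as YAML and override it.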
os.environ['MASTER_ADDR'] = MASTER_ADDR 86 | os.environ['MASTER_PORT'] = MASTER_PORT 87 | torch.multiprocessing.spawn(fn = run_worker, 88 | args = (world_size,command,scheme_class,config,seed), 89 | nprocs = world_size, 90 | join = True) 91 | else: 92 | torch.manual_seed(seed) 93 | random.seed(seed) 94 | np.random.seed(seed) 95 | scheme = scheme_class(config) 96 | getattr(scheme, COMMANDS[command])() 97 | 98 | -------------------------------------------------------------------------------- /lib/training/schemes/egt_mol_training.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | 4 | from .egt_training import EGTTraining 5 | from ..training_mixins import LinearLRWarmupCosineDecay, VerboseLR 6 | 7 | class EGT_MOL_Training(LinearLRWarmupCosineDecay, VerboseLR, EGTTraining): 8 | def get_default_config(self): 9 | config_dict = super().get_default_config() 10 | config_dict.update( 11 | num_virtual_nodes = 1, 12 | upto_hop = 16, 13 | svd_calculated_dim = 8, 14 | svd_output_dim = 8, 15 | svd_random_neg = True, 16 | pretrained_weights_file = None, 17 | num_epochs = 1000, 18 | ) 19 | return config_dict 20 | 21 | def get_dataset_config(self): 22 | dataset_config, dataset_class = super().get_dataset_config() 23 | if self.config.svd_output_dim > 0: 24 | dataset_config.update( 25 | calculated_dim = self.config.svd_calculated_dim, 26 | output_dim = self.config.svd_output_dim, 27 | random_neg_splits = ['training'] if self.config.svd_random_neg else [], 28 | ) 29 | return dataset_config, dataset_class 30 | 31 | def get_model_config(self): 32 | model_config, model_class = super().get_model_config() 33 | model_config.update( 34 | num_virtual_nodes = self.config.num_virtual_nodes, 35 | upto_hop = self.config.upto_hop, 36 | svd_encodings = self.config.svd_output_dim, 37 | ) 38 | return model_config, model_class 39 | 40 | def load_checkpoint(self): 41 | super().load_checkpoint() 42 | w_file = self.config.pretrained_weights_file 43 | if w_file is not None and self.state.global_step == 0: 44 | weights = torch.load(w_file) 45 | for k in list(weights.keys()).copy(): 46 | if 'mlp_layers.2' in k: 47 | del weights[k] 48 | 49 | missing, unexpected = self.base_model.load_state_dict(weights, strict=False) 50 | torch.cuda.empty_cache() 51 | if self.is_main_rank: 52 | print(f'Loaded pretrained weights from {w_file}',flush=True) 53 | print(f'missing keys: {missing}',flush=True) 54 | print(f'unexpected keys: {unexpected}',flush=True) 55 | -------------------------------------------------------------------------------- /lib/training/schemes/egt_training.py: -------------------------------------------------------------------------------- 1 | 2 | from lib.training.training import TrainingBase, cached_property, CollatedBatch 3 | from lib.training.testing import TestingBase 4 | from contextlib import nullcontext 5 | from lib.training.training_mixins import SaveModel, VerboseLR 6 | from lib.utils.dotdict import HDict 7 | import torch 8 | from lib.data.graph_dataset import graphdata_collate 9 | 10 | class EGTTraining(TestingBase,TrainingBase): 11 | def get_default_config(self): 12 | config = super().get_default_config() 13 | config.update( 14 | model_name = 'egt', 15 | cache_dir = 'cache_data', 16 | dataset_name = 'unnamed_dataset', 17 | dataset_path = HDict.L('c:f"{c.cache_dir}/{c.dataset_name.upper()}"'), 18 | save_path = HDict.L('c:path.join(f"models/{c.dataset_name.lower()}",c.model_name)'), 19 | model_height = 4, 20 | node_width = 64, 21 | edge_width = 64, 22 | 
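# Keys wrapped in HDict.L('c:...') above are lazy expressions (see
# lib/utils/dotdict): the text after 'c:' is evaluated against the config
# object c when the key is read, so dataset_path and save_path track later
# overrides of cache_dir, dataset_name and model_name; plain keys are
# ordinary defaults.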
num_heads = 8, 23 | node_dropout = 0., 24 | edge_dropout = 0., 25 | node_ffn_dropout = HDict.L('c:c.node_dropout'), 26 | edge_ffn_dropout = HDict.L('c:c.edge_dropout'), 27 | attn_dropout = 0., 28 | attn_maskout = 0., 29 | activation = 'elu', 30 | clip_logits_value = [-5,5], 31 | scale_degree = True, 32 | node_ffn_multiplier = 1., 33 | edge_ffn_multiplier = 1., 34 | allocate_max_batch = True, 35 | scale_dot_product = True, 36 | egt_simple = False, 37 | ) 38 | return config 39 | 40 | 41 | def get_dataset_config(self): 42 | config = self.config 43 | dataset_config = dict( 44 | dataset_path = config.dataset_path, 45 | cache_dir = config.cache_dir, 46 | ) 47 | return dataset_config, None 48 | 49 | def get_model_config(self): 50 | config = self.config 51 | model_config = dict( 52 | model_height = config.model_height , 53 | node_width = config.node_width , 54 | edge_width = config.edge_width , 55 | num_heads = config.num_heads , 56 | node_mha_dropout = config.node_dropout , 57 | edge_mha_dropout = config.edge_dropout , 58 | node_ffn_dropout = config.node_ffn_dropout , 59 | edge_ffn_dropout = config.edge_ffn_dropout , 60 | attn_dropout = config.attn_dropout , 61 | attn_maskout = config.attn_maskout , 62 | activation = config.activation , 63 | clip_logits_value = config.clip_logits_value , 64 | scale_degree = config.scale_degree , 65 | node_ffn_multiplier = config.node_ffn_multiplier , 66 | edge_ffn_multiplier = config.edge_ffn_multiplier , 67 | scale_dot = config.scale_dot_product , 68 | egt_simple = config.egt_simple , 69 | ) 70 | return model_config, None 71 | 72 | def _cache_dataset(self, dataset): 73 | if self.is_main_rank: 74 | dataset.cache() 75 | self.distributed_barrier() 76 | if not self.is_main_rank: 77 | dataset.cache(verbose=0) 78 | 79 | def _get_dataset(self, split): 80 | dataset_config, dataset_class = self.get_dataset_config() 81 | if dataset_class is None: 82 | raise NotImplementedError('Dataset class not specified') 83 | dataset = dataset_class(**dataset_config, split=split) 84 | self._cache_dataset(dataset) 85 | return dataset 86 | 87 | @cached_property 88 | def train_dataset(self): 89 | return self._get_dataset('training') 90 | @cached_property 91 | def val_dataset(self): 92 | return self._get_dataset('validation') 93 | @cached_property 94 | def test_dataset(self): 95 | return self._get_dataset('test') 96 | 97 | @property 98 | def collate_fn(self): 99 | return graphdata_collate 100 | 101 | @cached_property 102 | def base_model(self): 103 | model_config, model_class = self.get_model_config() 104 | if model_class is None: 105 | raise NotImplementedError 106 | model = model_class(**model_config).cuda() 107 | return model 108 | 109 | def prepare_for_training(self): 110 | # cache datasets in same order on all ranks 111 | if self.is_distributed: 112 | self.train_dataset 113 | self.val_dataset 114 | super().prepare_for_training() 115 | 116 | # GPU memory cache for biggest batch 117 | if self.config.allocate_max_batch: 118 | if self.is_main_rank: print('Allocating cache for max batch size...', flush=True) 119 | torch.cuda.empty_cache() 120 | self.model.train() 121 | max_batch = self.train_dataset.max_batch(self.config.batch_size, self.collate_fn) 122 | max_batch = self.preprocess_batch(max_batch) 123 | 124 | outputs = self.model(max_batch) 125 | loss = self.calculate_loss(outputs=outputs, inputs=max_batch) 126 | loss.backward() 127 | 128 | for param in self.model.parameters(): 129 | param.grad = None 130 | 131 | def initialize_losses(self, logs, training): 132 | self._total_loss = 0. 
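# This override makes the running loss sample-weighted: update_losses below
# accumulates loss*num_graphs together with the graph count (all-reducing
# both across ranks when distributed), so the reported epoch loss is a true
# per-graph mean even with variable batch sizes, unlike the per-step mean
# kept by TrainingBase.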
133 | self._total_samples = 0. 134 | 135 | def update_losses(self, i, loss, inputs, logs, training): 136 | if not isinstance(inputs, CollatedBatch): 137 | step_samples = float(inputs['num_nodes'].shape[0]) 138 | else: 139 | step_samples = float(sum(i['num_nodes'].shape[0] for i in inputs)) 140 | if not self.is_distributed: 141 | step_loss = loss.item() * step_samples 142 | else: 143 | step_samples = torch.tensor(step_samples, device=loss.device, 144 | dtype=loss.dtype) 145 | 146 | if training: 147 | loss = loss.detach() 148 | step_loss = loss * step_samples 149 | 150 | torch.distributed.all_reduce(step_loss) 151 | torch.distributed.all_reduce(step_samples) 152 | 153 | step_loss = step_loss.item() 154 | step_samples = step_samples.item() 155 | 156 | self._total_loss += step_loss 157 | self._total_samples += step_samples 158 | self.update_logs(logs=logs, training=training, 159 | loss=self._total_loss/self._total_samples) 160 | 161 | -------------------------------------------------------------------------------- /lib/training/schemes/molhiv/__init__.py: -------------------------------------------------------------------------------- 1 | from .scheme import SCHEME -------------------------------------------------------------------------------- /lib/training/schemes/molhiv/scheme.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from lib.training.training import cached_property 5 | from ..egt_mol_training import EGT_MOL_Training 6 | 7 | from lib.models.molhiv import EGT_MOLHIV 8 | from lib.data.molhiv import MOLHIVStructuralSVDGraphDataset 9 | 10 | class MOLHIV_Training(EGT_MOL_Training): 11 | def get_default_config(self): 12 | config_dict = super().get_default_config() 13 | config_dict.update( 14 | dataset_name = 'molhiv', 15 | dataset_path = 'cache_data/MOLHIV', 16 | evaluation_type = 'prediction', 17 | predict_on = ['test'], 18 | state_file = None, 19 | ) 20 | return config_dict 21 | 22 | def get_dataset_config(self): 23 | dataset_config, _ = super().get_dataset_config() 24 | return dataset_config, MOLHIVStructuralSVDGraphDataset 25 | 26 | def get_model_config(self): 27 | model_config, _ = super().get_model_config() 28 | return model_config, EGT_MOLHIV 29 | 30 | def calculate_bce_loss(self, outputs, targets): 31 | outputs = outputs.view(-1) 32 | targets = targets.view(-1) 33 | return F.binary_cross_entropy_with_logits(outputs, targets) 34 | 35 | def calculate_loss(self, outputs, inputs): 36 | return self.calculate_bce_loss(outputs, inputs['target']) 37 | 38 | @cached_property 39 | def evaluator(self): 40 | from ogb.graphproppred import Evaluator 41 | evaluator = Evaluator(name = "ogbg-molhiv") 42 | return evaluator 43 | 44 | def prediction_step(self, batch): 45 | return dict( 46 | predictions = torch.sigmoid(self.model(batch)), 47 | targets = batch['target'], 48 | ) 49 | 50 | def evaluate_predictions(self, predictions): 51 | input_dict = {"y_true": predictions['targets'], 52 | "y_pred": predictions['predictions']} 53 | results = self.evaluator.eval(input_dict) 54 | 55 | xent = self.calculate_bce_loss(torch.from_numpy(predictions['predictions']), 56 | torch.from_numpy(predictions['targets'])).item() 57 | results['xent'] = xent 58 | 59 | for k, v in results.items(): 60 | if hasattr(v, 'tolist'): 61 | results[k] = v.tolist() 62 | return results 63 | 64 | def evaluate_on(self, dataset_name, dataset, predictions): 65 | print(f'Evaluating on {dataset_name}') 66 | results = 
self.evaluate_predictions(predictions) 67 | return results 68 | 69 | SCHEME = MOLHIV_Training 70 | -------------------------------------------------------------------------------- /lib/training/schemes/molpcba/__init__.py: -------------------------------------------------------------------------------- 1 | from .scheme import SCHEME -------------------------------------------------------------------------------- /lib/training/schemes/molpcba/scheme.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from lib.training.training import cached_property 5 | from ..egt_mol_training import EGT_MOL_Training 6 | 7 | from lib.models.molpcba import EGT_MOLPCBA 8 | from lib.data.molpcba import MOLPCBAStructuralSVDGraphDataset 9 | 10 | class MOLPCBA_Training(EGT_MOL_Training): 11 | def get_default_config(self): 12 | config_dict = super().get_default_config() 13 | config_dict.update( 14 | dataset_name = 'molpcba', 15 | dataset_path = 'cache_data/MOLPCBA', 16 | evaluation_type = 'prediction', 17 | predict_on = ['test'], 18 | state_file = None, 19 | ) 20 | return config_dict 21 | 22 | 23 | def get_dataset_config(self): 24 | dataset_config, _ = super().get_dataset_config() 25 | return dataset_config, MOLPCBAStructuralSVDGraphDataset 26 | 27 | def get_model_config(self): 28 | model_config, _ = super().get_model_config() 29 | return model_config, EGT_MOLPCBA 30 | 31 | def calculate_masked_loss(self, outputs, targets): 32 | outputs = outputs.view(-1) 33 | targets = targets.view(-1) 34 | targets_mask = (targets == targets) 35 | loss = F.binary_cross_entropy_with_logits(outputs[targets_mask], 36 | targets[targets_mask]) 37 | return loss 38 | 39 | def calculate_loss(self, outputs, inputs): 40 | return self.calculate_masked_loss(outputs, inputs['target']) 41 | 42 | @cached_property 43 | def evaluator(self): 44 | from ogb.graphproppred import Evaluator 45 | evaluator = Evaluator(name = "ogbg-molpcba") 46 | return evaluator 47 | 48 | def prediction_step(self, batch): 49 | return dict( 50 | predictions = torch.sigmoid(self.model(batch)), 51 | targets = batch['target'], 52 | ) 53 | 54 | def evaluate_predictions(self, predictions): 55 | input_dict = {"y_true": predictions['targets'], 56 | "y_pred": predictions['predictions']} 57 | results = self.evaluator.eval(input_dict) 58 | 59 | xent = self.calculate_masked_loss(torch.from_numpy(predictions['predictions']), 60 | torch.from_numpy(predictions['targets'])).item() 61 | results['xent'] = xent 62 | 63 | for k, v in results.items(): 64 | if hasattr(v, 'tolist'): 65 | results[k] = v.tolist() 66 | return results 67 | 68 | def evaluate_on(self, dataset_name, dataset, predictions): 69 | print(f'Evaluating on {dataset_name}') 70 | results = self.evaluate_predictions(predictions) 71 | return results 72 | 73 | 74 | SCHEME = MOLPCBA_Training 75 | -------------------------------------------------------------------------------- /lib/training/schemes/pcqm4m/__init__.py: -------------------------------------------------------------------------------- 1 | from .scheme import SCHEME -------------------------------------------------------------------------------- /lib/training/schemes/pcqm4m/scheme.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn.functional as F 4 | 5 | from lib.training.training import cached_property 6 | from ..egt_mol_training import EGT_MOL_Training 7 | 8 | from lib.models.pcqm4m 
import EGT_PCQM4M 9 | from lib.data.pcqm4m import PCQM4MStructuralSVDGraphDataset 10 | 11 | class PCQM4M_Training(EGT_MOL_Training): 12 | def get_default_config(self): 13 | config_dict = super().get_default_config() 14 | config_dict.update( 15 | dataset_name = 'pcqm4m', 16 | dataset_path = 'cache_data/PCQM4M', 17 | predict_on = ['val'], 18 | evaluate_on = ['val'], 19 | state_file = None, 20 | ) 21 | return config_dict 22 | 23 | def get_dataset_config(self): 24 | dataset_config, _ = super().get_dataset_config() 25 | return dataset_config, PCQM4MStructuralSVDGraphDataset 26 | 27 | def get_model_config(self): 28 | model_config, _ = super().get_model_config() 29 | return model_config, EGT_PCQM4M 30 | 31 | def calculate_loss(self, outputs, inputs): 32 | return F.l1_loss(outputs, inputs['target']) 33 | 34 | @cached_property 35 | def evaluator(self): 36 | from ogb.lsc.pcqm4m import PCQM4MEvaluator 37 | evaluator = PCQM4MEvaluator() 38 | return evaluator 39 | 40 | def prediction_step(self, batch): 41 | return dict( 42 | predictions = self.model(batch), 43 | targets = batch['target'], 44 | ) 45 | 46 | def evaluate_on(self, dataset_name, dataset, predictions): 47 | print(f'Evaluating on {dataset_name}') 48 | input_dict = {"y_true": predictions['targets'], 49 | "y_pred": predictions['predictions']} 50 | results = self.evaluator.eval(input_dict) 51 | for k, v in results.items(): 52 | if hasattr(v, 'tolist'): 53 | results[k] = v.tolist() 54 | return results 55 | 56 | SCHEME = PCQM4M_Training 57 | -------------------------------------------------------------------------------- /lib/training/schemes/pcqm4mv2/__init__.py: -------------------------------------------------------------------------------- 1 | from .scheme import SCHEME -------------------------------------------------------------------------------- /lib/training/schemes/pcqm4mv2/scheme.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn.functional as F 4 | 5 | from lib.training.training import cached_property 6 | from ..egt_mol_training import EGT_MOL_Training 7 | 8 | from lib.models.pcqm4mv2 import EGT_PCQM4MV2 9 | from lib.data.pcqm4mv2 import PCQM4Mv2StructuralSVDGraphDataset 10 | 11 | class PCQM4MV2_Training(EGT_MOL_Training): 12 | def get_default_config(self): 13 | config_dict = super().get_default_config() 14 | config_dict.update( 15 | dataset_name = 'pcqm4mv2', 16 | dataset_path = 'cache_data/PCQM4MV2', 17 | predict_on = ['val','test'], 18 | evaluate_on = ['val','test'], 19 | state_file = None, 20 | ) 21 | return config_dict 22 | 23 | def get_dataset_config(self): 24 | dataset_config, _ = super().get_dataset_config() 25 | return dataset_config, PCQM4Mv2StructuralSVDGraphDataset 26 | 27 | def get_model_config(self): 28 | model_config, _ = super().get_model_config() 29 | return model_config, EGT_PCQM4MV2 30 | 31 | def calculate_loss(self, outputs, inputs): 32 | return F.l1_loss(outputs, inputs['target']) 33 | 34 | @cached_property 35 | def evaluator(self): 36 | from ogb.lsc.pcqm4mv2 import PCQM4Mv2Evaluator 37 | evaluator = PCQM4Mv2Evaluator() 38 | return evaluator 39 | 40 | def prediction_step(self, batch): 41 | return dict( 42 | predictions = self.model(batch), 43 | targets = batch['target'], 44 | ) 45 | 46 | def evaluate_on(self, dataset_name, dataset, predictions): 47 | if dataset_name == 'test': 48 | self.evaluator.save_test_submission( 49 | input_dict = {'y_pred': predictions['predictions']}, 50 | dir_path = self.config.predictions_path, 51 
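# save_test_submission writes y_pred in the official OGB-LSC submission
# format (a compressed .npz file) under predictions_path; mode 'test-dev'
# targets the public leaderboard split and 'test-challenge' the challenge
# split.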
| mode = 'test-dev', 52 | ) 53 | print(f'Saved final test-dev predictions to {self.config.predictions_path}') 54 | return {'mae': np.nan} 55 | 56 | print(f'Evaluating on {dataset_name}') 57 | input_dict = {"y_true": predictions['targets'], 58 | "y_pred": predictions['predictions']} 59 | results = self.evaluator.eval(input_dict) 60 | for k, v in results.items(): 61 | if hasattr(v, 'tolist'): 62 | results[k] = v.tolist() 63 | return results 64 | 65 | SCHEME = PCQM4MV2_Training 66 | -------------------------------------------------------------------------------- /lib/training/testing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.utils.data import DataLoader 4 | import os 5 | import yaml 6 | from yaml import SafeDumper as yaml_Dumper 7 | 8 | from lib.utils.dotdict import HDict 9 | from lib.training.training import TrainingBase, DistributedTestDataSampler, cached_property 10 | 11 | class TestingBase(TrainingBase): 12 | def get_default_config(self): 13 | config = super().get_default_config() 14 | config.update( 15 | state_file = None, 16 | predict_on = ['train', 'val', 'test'], 17 | evaluate_on = HDict.L('c:c.predict_on'), 18 | predictions_path = HDict.L('c:path.join(c.save_path,"predictions")'), 19 | ) 20 | return config 21 | 22 | @cached_property 23 | def test_dataset(self): 24 | raise NotImplementedError 25 | 26 | @cached_property 27 | def train_pred_dataloader(self): 28 | prediction_batch_size = self.config.batch_size*self.config.prediction_bmult 29 | if not self.is_distributed: 30 | dataloader = DataLoader(dataset=self.train_dataset, 31 | batch_size=prediction_batch_size, 32 | shuffle=False, 33 | drop_last=False, 34 | collate_fn=self.collate_fn, 35 | pin_memory=True) 36 | else: 37 | sampler = DistributedTestDataSampler(data_source=self.train_dataset, 38 | batch_size=prediction_batch_size, 39 | rank=self.ddp_rank, 40 | world_size=self.ddp_world_size) 41 | dataloader = DataLoader(dataset=self.train_dataset, 42 | collate_fn=self.collate_fn, 43 | batch_sampler=sampler, 44 | pin_memory=True) 45 | return dataloader 46 | 47 | @cached_property 48 | def test_dataloader(self): 49 | prediction_batch_size = self.config.batch_size*self.config.prediction_bmult 50 | if not self.is_distributed: 51 | dataloader = DataLoader(dataset=self.test_dataset, 52 | batch_size=prediction_batch_size, 53 | shuffle=False, 54 | drop_last=False, 55 | collate_fn=self.collate_fn, 56 | pin_memory=True) 57 | else: 58 | sampler = DistributedTestDataSampler(data_source=self.test_dataset, 59 | batch_size=prediction_batch_size, 60 | rank=self.ddp_rank, 61 | world_size=self.ddp_world_size) 62 | dataloader = DataLoader(dataset=self.test_dataset, 63 | collate_fn=self.collate_fn, 64 | batch_sampler=sampler, 65 | pin_memory=True) 66 | return dataloader 67 | 68 | 69 | def test_dataloader_for_dataset(self, dataset_name): 70 | if dataset_name == 'train': 71 | return self.train_pred_dataloader 72 | elif dataset_name == 'val': 73 | return self.val_dataloader 74 | elif dataset_name == 'test': 75 | return self.test_dataloader 76 | else: 77 | raise ValueError(f'Unknown dataset name: {dataset_name}') 78 | 79 | def predict_and_save(self): 80 | for dataset_name in self.config.predict_on: 81 | if self.is_main_rank: 82 | print(f'Predicting on {dataset_name} dataset...') 83 | dataloader = self.test_dataloader_for_dataset(dataset_name) 84 | outputs = self.prediction_loop(dataloader) 85 | outputs = self.preprocess_predictions(outputs) 86 | 87 | if 
self.is_distributed: 88 | outputs = self.distributed_gather_predictions(outputs) 89 | 90 | if self.is_main_rank: 91 | predictions = self.postprocess_predictions(outputs) 92 | self.save_predictions(dataset_name, predictions) 93 | 94 | 95 | def load_model_state(self): 96 | if self.config.state_file is None: 97 | state_file = os.path.join(self.config.checkpoint_path, 'model_state') 98 | else: 99 | state_file = self.config.state_file 100 | self.base_model.load_state_dict(torch.load(state_file)) 101 | 102 | if self.is_main_rank: 103 | print(f'Loaded model state from {state_file}') 104 | 105 | def prepare_for_testing(self): 106 | self.config_summary() 107 | self.load_model_state() 108 | 109 | def make_predictions(self): 110 | self.prepare_for_testing() 111 | self.predict_and_save() 112 | if len(self.config.evaluate_on) > 0: 113 | self.evaluate_and_save() 114 | 115 | 116 | def get_dataset(self, dataset_name): 117 | if dataset_name == 'train': 118 | return self.train_dataset 119 | elif dataset_name == 'val': 120 | return self.val_dataset 121 | elif dataset_name == 'test': 122 | return self.test_dataset 123 | else: 124 | raise ValueError(f'Unknown dataset name: {dataset_name}') 125 | 126 | def evaluate_on(self, dataset_name, dataset, predictions): 127 | raise NotImplementedError() 128 | 129 | def evaluate_and_save(self): 130 | if not self.is_main_rank: 131 | return 132 | 133 | results = {} 134 | results_file = os.path.join(self.config.predictions_path, 'results.yaml') 135 | 136 | for dataset_name in self.config.evaluate_on: 137 | dataset = self.get_dataset(dataset_name) 138 | predictions = torch.load(os.path.join(self.config.predictions_path, f'{dataset_name}.pt')) 139 | dataset_results = self.evaluate_on(dataset_name, dataset, predictions) 140 | 141 | for k,v in dataset_results.items(): 142 | print(f'{dataset_name} {k}: {v}') 143 | 144 | results[dataset_name] = dataset_results 145 | with open(results_file, 'w') as fp: 146 | yaml.dump(results, fp, sort_keys=False, Dumper=yaml_Dumper) 147 | 148 | 149 | def do_evaluations(self): 150 | self.evaluate_and_save() 151 | 152 | def finalize_training(self): 153 | super().finalize_training() 154 | self.predict_and_save() 155 | if len(self.config.evaluate_on) > 0: 156 | self.evaluate_and_save() 157 | -------------------------------------------------------------------------------- /lib/training/training.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | from collections import OrderedDict 5 | from torch.utils.data import DataLoader 6 | from torch.utils.data import Sampler 7 | from contextlib import nullcontext 8 | 9 | import yaml 10 | from yaml import SafeLoader as yaml_Loader, SafeDumper as yaml_Dumper 11 | import os,sys 12 | 13 | from tqdm import tqdm 14 | 15 | from lib.utils.dotdict import HDict 16 | HDict.L.update_globals({'path':os.path}) 17 | 18 | def str_presenter(dumper, data): 19 | if len(data.splitlines()) > 1: # check for multiline string 20 | return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|') 21 | return dumper.represent_scalar('tag:yaml.org,2002:str', data) 22 | yaml.representer.SafeRepresenter.add_representer(str, str_presenter) 23 | 24 | 25 | def read_config_from_file(config_file): 26 | with open(config_file, 'r') as fp: 27 | return yaml.load(fp, Loader=yaml_Loader) 28 | 29 | def save_config_to_file(config, config_file): 30 | with open(config_file, 'w') as fp: 31 | return yaml.dump(config, fp, sort_keys=False, 
Dumper=yaml_Dumper) 32 | 33 | 34 | class StopTrainingException(Exception): 35 | pass 36 | 37 | class CollatedBatch(list): 38 | pass 39 | 40 | class DistributedTestDataSampler(Sampler): 41 | def __init__(self, data_source, batch_size, rank, world_size): 42 | data_len = len(data_source) 43 | all_indices = np.arange(data_len, dtype=int) 44 | split_indices = np.array_split(all_indices, world_size) 45 | 46 | num_batches = (len(split_indices[0]) + batch_size -1) // batch_size 47 | self.batch_indices = [i.tolist() for i in np.array_split(split_indices[rank], 48 | num_batches)] 49 | 50 | def __iter__(self): 51 | return iter(self.batch_indices) 52 | 53 | def __len__(self): 54 | return len(self.batch_indices) 55 | 56 | 57 | 58 | def cached_property(func): 59 | atrribute_name = f'_{func.__name__}' 60 | def _wrapper(self): 61 | try: 62 | return getattr(self, atrribute_name) 63 | except AttributeError: 64 | val = func(self) 65 | self.__dict__[atrribute_name] = val 66 | return val 67 | return property(_wrapper) 68 | 69 | 70 | class TrainingBase: 71 | def __init__(self, config=None, ddp_rank=0, ddp_world_size=1): 72 | self.config_input = config 73 | self.config = self.get_default_config() 74 | if config is not None: 75 | for k in config.keys(): 76 | if not k in self.config: 77 | raise KeyError(f'Unknown config "{k}"') 78 | self.config.update(config) 79 | 80 | self.state = self.get_default_state() 81 | 82 | self.ddp_rank = ddp_rank 83 | self.ddp_world_size = ddp_world_size 84 | self.is_distributed = (self.ddp_world_size > 1) 85 | self.is_main_rank = (self.ddp_rank == 0) 86 | 87 | 88 | @cached_property 89 | def train_dataset(self): 90 | raise NotImplementedError 91 | 92 | @cached_property 93 | def val_dataset(self): 94 | raise NotImplementedError 95 | 96 | @cached_property 97 | def collate_fn(self): 98 | return None 99 | 100 | @cached_property 101 | def train_sampler(self): 102 | return torch.utils.data.DistributedSampler(self.train_dataset, 103 | shuffle=True) 104 | 105 | @cached_property 106 | def train_dataloader(self): 107 | common_kwargs = dict( 108 | dataset=self.train_dataset, 109 | batch_size=self.config.batch_size, 110 | collate_fn=self.collate_fn, 111 | pin_memory=True, 112 | ) 113 | if self.config.dataloader_workers > 0: 114 | common_kwargs.update( 115 | num_workers=self.config.dataloader_workers, 116 | persistent_workers=True, 117 | multiprocessing_context=self.config.dataloader_mp_context, 118 | ) 119 | if not self.is_distributed: 120 | dataloader = DataLoader(**common_kwargs, shuffle=True, 121 | drop_last=False) 122 | else: 123 | dataloader = DataLoader(**common_kwargs, 124 | sampler=self.train_sampler) 125 | return dataloader 126 | 127 | @cached_property 128 | def val_dataloader(self): 129 | common_kwargs = dict( 130 | dataset=self.val_dataset, 131 | collate_fn=self.collate_fn, 132 | pin_memory=True, 133 | ) 134 | if self.config.dataloader_workers > 0: 135 | common_kwargs.update( 136 | num_workers=self.config.dataloader_workers, 137 | persistent_workers=True, 138 | multiprocessing_context=self.config.dataloader_mp_context, 139 | ) 140 | prediction_batch_size = self.config.batch_size*self.config.prediction_bmult 141 | if not self.is_distributed: 142 | dataloader = DataLoader(**common_kwargs, 143 | batch_size=prediction_batch_size, 144 | shuffle=False, drop_last=False) 145 | else: 146 | sampler = DistributedTestDataSampler(data_source=self.val_dataset, 147 | batch_size=prediction_batch_size, 148 | rank=self.ddp_rank, 149 | world_size=self.ddp_world_size) 150 | dataloader = 
DataLoader(**common_kwargs, batch_sampler=sampler) 151 | return dataloader 152 | 153 | @cached_property 154 | def base_model(self): 155 | raise NotImplementedError 156 | 157 | @cached_property 158 | def model(self): 159 | model = self.base_model 160 | if self.is_distributed: 161 | model = torch.nn.parallel.DistributedDataParallel(model,device_ids=[self.ddp_rank], 162 | output_device=self.ddp_rank) 163 | return model 164 | 165 | @cached_property 166 | def optimizer(self): 167 | config = self.config 168 | optimizer_class = getattr(torch.optim, config.optimizer) 169 | optimizer = optimizer_class(self.model.parameters(), 170 | lr=config.max_lr, 171 | **config.optimizer_params) 172 | return optimizer 173 | 174 | def get_default_config(self): 175 | return HDict( 176 | scheme = None, 177 | model_name = 'unnamed_model', 178 | distributed = False, 179 | random_seed = None, 180 | num_epochs = 100, 181 | save_path = HDict.L('c:path.join("models",c.model_name)'), 182 | checkpoint_path = HDict.L('c:path.join(c.save_path,"checkpoint")'), 183 | config_path = HDict.L('c:path.join(c.save_path,"config")'), 184 | summary_path = HDict.L('c:path.join(c.save_path,"summary")'), 185 | log_path = HDict.L('c:path.join(c.save_path,"logs")'), 186 | validation_frequency = 1, 187 | batch_size = HDict.L('c:128 if c.distributed else 32'), 188 | optimizer = 'Adam' , 189 | max_lr = 5e-4 , 190 | clip_grad_value = None , 191 | optimizer_params = {} , 192 | dataloader_workers = 0 , 193 | dataloader_mp_context = 'forkserver', 194 | training_type = 'normal' , 195 | evaluation_type = 'validation', 196 | predictions_path = HDict.L('c:path.join(c.save_path,"predictions")'), 197 | grad_accum_steps = 1 , 198 | prediction_bmult = 1 , 199 | ) 200 | 201 | def get_default_state(self): 202 | state = HDict( 203 | current_epoch = 0, 204 | global_step = 0, 205 | ) 206 | return state 207 | 208 | def config_summary(self): 209 | if not self.is_main_rank: return 210 | for k,v in self.config.get_dict().items(): 211 | print(f'{k} : {v}', flush=True) 212 | 213 | def save_config_file(self): 214 | if not self.is_main_rank: return 215 | os.makedirs(os.path.dirname(self.config.config_path), exist_ok=True) 216 | save_config_to_file(self.config.get_dict(), self.config.config_path+'.yaml') 217 | save_config_to_file(self.config_input, self.config.config_path+'_input.yaml') 218 | 219 | def model_summary(self): 220 | if not self.is_main_rank: return 221 | os.makedirs(os.path.dirname(self.config.summary_path), exist_ok=True) 222 | trainable_params = 0 223 | non_trainable_params = 0 224 | for p in self.model.parameters(): 225 | if p.requires_grad: 226 | trainable_params += p.numel() 227 | else: 228 | non_trainable_params += p.numel() 229 | summary = dict( 230 | trainable_params = trainable_params, 231 | non_trainable_params = non_trainable_params, 232 | model_representation = repr(self.model), 233 | ) 234 | with open(self.config.summary_path+'.txt', 'w') as fp: 235 | yaml.dump(summary, fp, sort_keys=False, Dumper=yaml_Dumper) 236 | 237 | def save_checkpoint(self): 238 | if not self.is_main_rank: return 239 | ckpt_path = self.config.checkpoint_path 240 | os.makedirs(ckpt_path, exist_ok=True) 241 | 242 | torch.save(self.state, os.path.join(ckpt_path, 'training_state')) 243 | torch.save(self.base_model.state_dict(), os.path.join(ckpt_path, 'model_state')) 244 | torch.save(self.optimizer.state_dict(), os.path.join(ckpt_path, 'optimizer_state')) 245 | print(f'Checkpoint saved to: {ckpt_path}',flush=True) 246 | 247 | def load_checkpoint(self): 248 | ckpt_path = 
self.config.checkpoint_path 249 | try: 250 | self.state.update(torch.load(os.path.join(ckpt_path, 'training_state'))) 251 | self.base_model.load_state_dict(torch.load(os.path.join(ckpt_path, 'model_state'))) 252 | self.optimizer.load_state_dict(torch.load(os.path.join(ckpt_path, 'optimizer_state'))) 253 | if self.is_main_rank: 254 | print(f'Checkpoint loaded from: {ckpt_path}',flush=True) 255 | torch.cuda.empty_cache() 256 | except FileNotFoundError: 257 | pass 258 | 259 | # Callbacks 260 | def on_train_begin(self): 261 | pass 262 | def on_train_end(self): 263 | pass 264 | def on_epoch_begin(self, logs, training): 265 | pass 266 | def on_epoch_end(self, logs, training): 267 | pass 268 | def on_batch_begin(self, i, logs, training): 269 | pass 270 | def on_batch_end(self, i, logs, training): 271 | pass 272 | 273 | 274 | # Logging 275 | def get_verbose_logs(self): 276 | return OrderedDict(loss='0.4f') 277 | 278 | @cached_property 279 | def verbose_logs(self): 280 | return self.get_verbose_logs() 281 | 282 | def update_logs(self, logs, training, **updates): 283 | if training: 284 | logs.update(updates) 285 | else: 286 | logs.update(('val_'+k,v) for k,v in updates.items()) 287 | 288 | def log_description(self, i, logs, training): 289 | if training: 290 | return list(f'{k} = {logs[k]:{f}}' 291 | for k,f in self.verbose_logs.items()) 292 | else: 293 | return list(f'val_{k} = {logs["val_"+k]:{f}}' 294 | for k,f in self.verbose_logs.items()) 295 | 296 | 297 | # Training loop 298 | def preprocess_batch(self, batch): 299 | if isinstance(batch, CollatedBatch): 300 | return CollatedBatch(self.preprocess_batch(b) for b in batch) 301 | elif hasattr(batch, 'cuda'): 302 | return batch.cuda(non_blocking=True) 303 | elif hasattr(batch, 'items'): 304 | return batch.__class__((k,v.cuda(non_blocking=True)) for k,v in batch.items()) 305 | elif hasattr(batch, '__iter__'): 306 | return batch.__class__(v.cuda(non_blocking=True) for v in batch) 307 | else: 308 | raise ValueError(f'Unsupported batch type: {type(batch)}') 309 | 310 | def calculate_loss(self, outputs, inputs): 311 | raise NotImplementedError 312 | 313 | def grad_accum_gather_outputs(self, outputs): 314 | return torch.cat(outputs, dim=0) 315 | 316 | def grad_accum_reduce_loss(self, loss): 317 | with torch.no_grad(): 318 | total_loss = sum(loss) 319 | return total_loss 320 | 321 | def grad_accum_collator(self, dataloader): 322 | dataloader_iter = iter(dataloader) 323 | if self.config.grad_accum_steps == 1: 324 | yield from dataloader_iter 325 | else: 326 | while True: 327 | collated_batch = CollatedBatch() 328 | try: 329 | for _ in range(self.config.grad_accum_steps): 330 | collated_batch.append(next(dataloader_iter)) 331 | except StopIteration: 332 | break 333 | finally: 334 | if len(collated_batch) > 0: yield collated_batch 335 | 336 | @cached_property 337 | def train_steps_per_epoch(self): 338 | if self.config.grad_accum_steps == 1: 339 | return len(self.train_dataloader) 340 | else: 341 | return (len(self.train_dataloader) + self.config.grad_accum_steps - 1)\ 342 | // self.config.grad_accum_steps 343 | 344 | @cached_property 345 | def validation_steps_per_epoch(self): 346 | return len(self.val_dataloader) 347 | 348 | 349 | def training_step(self, batch, logs): 350 | for param in self.model.parameters(): 351 | param.grad = None 352 | 353 | if not isinstance(batch, CollatedBatch): 354 | outputs = self.model(batch) 355 | loss = self.calculate_loss(outputs=outputs, inputs=batch) 356 | loss.backward() 357 | else: 358 | num_nested_batches = len(batch) 359 
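# Gradient-accumulation path: grad_accum_collator packs grad_accum_steps
# micro-batches into a CollatedBatch; each micro-batch backpropagates
# loss/num_nested_batches so the summed gradient matches a single large
# batch before the one optimizer.step() below.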
| outputs = CollatedBatch() 360 | loss = CollatedBatch() 361 | 362 | for j, b in enumerate(batch): 363 | # sync DDP gradients only on the final micro-batch 364 | sync_context = nullcontext() if ((not self.is_distributed) or j == num_nested_batches-1) else self.model.no_sync() 365 | with sync_context: 366 | o = self.model(b) 367 | l = self.calculate_loss(outputs=o, inputs=b) / num_nested_batches 368 | l.backward() 369 | outputs.append(o) 370 | loss.append(l) 371 | 372 | outputs = self.grad_accum_gather_outputs(outputs) 373 | loss = self.grad_accum_reduce_loss(loss) 374 | 375 | if self.config.clip_grad_value is not None: 376 | nn.utils.clip_grad_value_(self.model.parameters(), self.config.clip_grad_value) 377 | self.optimizer.step() 378 | return outputs, loss 379 | 380 | def validation_step(self, batch, logs): 381 | outputs = self.model(batch) 382 | loss = self.calculate_loss(outputs=outputs, inputs=batch) 383 | return outputs, loss 384 | 385 | def initialize_metrics(self, logs, training): 386 | pass 387 | 388 | def update_metrics(self, outputs, inputs, logs, training): 389 | pass 390 | 391 | def initialize_losses(self, logs, training): 392 | self._total_loss = 0. 393 | 394 | def update_losses(self, i, loss, inputs, logs, training): 395 | if not self.is_distributed: 396 | step_loss = loss.item() 397 | else: 398 | if training: 399 | loss = loss.detach() 400 | torch.distributed.all_reduce(loss) 401 | step_loss = loss.item()/self.ddp_world_size 402 | self._total_loss += step_loss 403 | self.update_logs(logs=logs, training=training, 404 | loss=self._total_loss/(i+1)) 405 | 406 | 407 | def train_epoch(self, epoch, logs): 408 | self.model.train() 409 | self.initialize_losses(logs, True) 410 | self.initialize_metrics(logs, True) 411 | 412 | if self.is_distributed: 413 | self.train_sampler.set_epoch(epoch) 414 | 415 | gen = self.grad_accum_collator(self.train_dataloader) 416 | if self.is_main_rank: 417 | gen = tqdm(gen, dynamic_ncols=True, 418 | total=self.train_steps_per_epoch) 419 | try: 420 | for i, batch in enumerate(gen): 421 | self.on_batch_begin(i, logs, True) 422 | batch = self.preprocess_batch(batch) 423 | outputs, loss = self.training_step(batch, logs) 424 | 425 | self.state.global_step = self.state.global_step + 1 426 | logs.update(global_step=self.state.global_step) 427 | 428 | self.update_losses(i, loss, batch, logs, True) 429 | self.update_metrics(outputs, batch, logs, True) 430 | 431 | self.on_batch_end(i, logs, True) 432 | 433 | if self.is_main_rank: 434 | desc = 'Training: '+'; '.join(self.log_description(i, logs, True)) 435 | gen.set_description(desc) 436 | finally: 437 | if self.is_main_rank: gen.close() 438 | for param in self.model.parameters(): 439 | param.grad = None 440 | 441 | def minimal_train_epoch(self, epoch, logs): 442 | self.model.train() 443 | 444 | if self.is_distributed: 445 | self.train_sampler.set_epoch(epoch) 446 | 447 | gen = self.grad_accum_collator(self.train_dataloader) 448 | if self.is_main_rank: 449 | gen = tqdm(gen, dynamic_ncols=True, desc='Training: ', 450 | total=self.train_steps_per_epoch) 451 | try: 452 | for i, batch in enumerate(gen): 453 | self.on_batch_begin(i, logs, True) 454 | batch = self.preprocess_batch(batch) 455 | _ = self.training_step(batch, logs) 456 | 457 | self.state.global_step = self.state.global_step + 1 458 | logs.update(global_step=self.state.global_step) 459 | 460 | self.on_batch_end(i, logs, True) 461 | finally: 462 | if self.is_main_rank: gen.close() 463 | for param in self.model.parameters(): 464 | param.grad = None 465 | 466 | 467 | def validation_epoch(self, epoch, logs): 468 | self.model.eval() 469 | 
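# Validation mirrors train_epoch without gradient updates: the same loss and
# metric hooks run with training=False, so their log entries get a 'val_'
# prefix via update_logs, and train_model appends each epoch's merged logs
# dict to history.yaml under log_path.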
self.initialize_losses(logs, False) 470 | self.initialize_metrics(logs, False) 471 | 472 | gen = self.val_dataloader 473 | if self.is_main_rank: 474 | gen = tqdm(gen, dynamic_ncols=True, 475 | total=self.validation_steps_per_epoch) 476 | try: 477 | with torch.no_grad(): 478 | for i, batch in enumerate(gen): 479 | self.on_batch_begin(i, logs, False) 480 | batch = self.preprocess_batch(batch) 481 | outputs, loss = self.validation_step(batch, logs) 482 | 483 | self.update_losses(i, loss, batch, logs, False) 484 | self.update_metrics(outputs, batch, logs, False) 485 | 486 | self.on_batch_end(i, logs, False) 487 | 488 | if self.is_main_rank: 489 | desc = 'Validation: '+'; '.join(self.log_description(i, logs, False)) 490 | gen.set_description(desc) 491 | finally: 492 | if self.is_main_rank: gen.close() 493 | 494 | def load_history(self): 495 | history_file = os.path.join(self.config.log_path, 'history.yaml') 496 | try: 497 | with open(history_file, 'r') as fp: 498 | return yaml.load(fp, Loader=yaml_Loader) 499 | except FileNotFoundError: 500 | return [] 501 | 502 | def save_history(self, history): 503 | os.makedirs(self.config.log_path, exist_ok=True) 504 | history_file = os.path.join(self.config.log_path, 'history.yaml') 505 | with open(history_file, 'w') as fp: 506 | yaml.dump(history, fp, sort_keys=False, Dumper=yaml_Dumper) 507 | 508 | 509 | def train_model(self): 510 | if self.is_main_rank: 511 | history = self.load_history() 512 | starting_epoch = self.state.current_epoch 513 | 514 | self.on_train_begin() 515 | should_stop_training = False 516 | try: 517 | for i in range(starting_epoch, self.config.num_epochs): 518 | self.state.current_epoch = i 519 | if self.is_main_rank: 520 | print(f'\nEpoch {i+1}/{self.config.num_epochs}:', flush=True) 521 | logs = dict(epoch = self.state.current_epoch, 522 | global_step = self.state.global_step) 523 | 524 | try: 525 | self.on_epoch_begin(logs, True) 526 | if self.config.training_type == 'normal': 527 | self.train_epoch(i, logs) 528 | elif self.config.training_type == 'minimal': 529 | self.minimal_train_epoch(i, logs) 530 | else: 531 | raise ValueError(f'Unknown training type: {self.config.training_type}') 532 | self.on_epoch_end(logs, True) 533 | except StopTrainingException: 534 | should_stop_training = True 535 | 536 | try: 537 | if (self.val_dataloader is not None)\ 538 | and (not ((i+1) % self.config.validation_frequency)): 539 | self.on_epoch_begin(logs, False) 540 | if self.config.evaluation_type == 'validation': 541 | self.validation_epoch(i, logs) 542 | elif self.config.evaluation_type == 'prediction': 543 | self.prediction_epoch(i, logs) 544 | else: 545 | raise ValueError(f'Unknown evaluation type: {self.config.evaluation_type}') 546 | self.on_epoch_end(logs, False) 547 | except StopTrainingException: 548 | should_stop_training = True 549 | 550 | self.state.current_epoch = i + 1 551 | if self.is_main_rank: 552 | self.save_checkpoint() 553 | 554 | history.append(logs) 555 | self.save_history(history) 556 | 557 | if should_stop_training: 558 | if self.is_main_rank: 559 | print('Stopping training ...') 560 | break 561 | finally: 562 | self.on_train_end() 563 | 564 | def distributed_barrier(self): 565 | if self.is_distributed: 566 | dummy = torch.ones((),dtype=torch.int64).cuda() 567 | torch.distributed.all_reduce(dummy) 568 | 569 | # Prediction logic 570 | def prediction_step(self, batch): 571 | predictions = self.model(batch) 572 | if isinstance(batch, torch.Tensor): 573 | return dict(inputs=batch, predictions=predictions) 574 | elif 
    # Prediction logic
    def prediction_step(self, batch):
        predictions = self.model(batch)
        if isinstance(batch, torch.Tensor):
            return dict(inputs=batch, predictions=predictions)
        elif isinstance(batch, list):
            outputs = batch.copy()
            # Append to the copy so the input batch is left untouched and the
            # predictions are actually part of the returned list.
            outputs.append(predictions)
            return outputs
        elif isinstance(batch, dict):
            outputs = batch.copy()
            outputs.update(predictions=predictions)
            return outputs

    def prediction_loop(self, dataloader):
        self.model.eval()

        outputs = []

        if self.is_main_rank:
            gen = tqdm(dataloader, dynamic_ncols=True)
        else:
            gen = dataloader
        try:
            with torch.no_grad():
                for batch in gen:
                    batch = self.preprocess_batch(batch)
                    outputs.append(self.prediction_step(batch))
        finally:
            if self.is_main_rank: gen.close()

        return outputs

    def preprocess_predictions(self, outputs):
        if isinstance(outputs[0], torch.Tensor):
            return torch.cat(outputs, dim=0)
        elif isinstance(outputs[0], dict):
            return {k: torch.cat([o[k] for o in outputs], dim=0)
                    for k in outputs[0].keys()}
        elif isinstance(outputs[0], list):
            return [torch.cat([o[i] for o in outputs], dim=0)
                    for i in range(len(outputs[0]))]
        else:
            raise ValueError('Unsupported output type')

    def postprocess_predictions(self, outputs):
        if isinstance(outputs, torch.Tensor):
            return outputs.cpu().numpy()
        elif isinstance(outputs, dict):
            return {k: v.cpu().numpy() for k, v in outputs.items()}
        elif isinstance(outputs, list):
            return [v.cpu().numpy() for v in outputs]
        else:
            raise ValueError('Unsupported output type')

    def distributed_gather_tensor(self, tensors):
        shapes = torch.zeros(self.ddp_world_size+1, dtype=torch.long).cuda()
        shapes[self.ddp_rank+1] = tensors.shape[0]
        torch.distributed.all_reduce(shapes)

        offsets = torch.cumsum(shapes, dim=0)
        all_tensors = torch.zeros(offsets[-1], *tensors.shape[1:], dtype=tensors.dtype).cuda()
        all_tensors[offsets[self.ddp_rank]:offsets[self.ddp_rank+1]] = tensors

        torch.distributed.all_reduce(all_tensors)
        return all_tensors

    def distributed_gather_predictions(self, predictions):
        if self.is_main_rank:
            print('Gathering predictions from all ranks...')

        if isinstance(predictions, torch.Tensor):
            all_predictions = self.distributed_gather_tensor(predictions)
        elif isinstance(predictions, list):
            all_predictions = [self.distributed_gather_tensor(pred) for pred in predictions]
        elif isinstance(predictions, dict):
            all_predictions = {key:self.distributed_gather_tensor(pred)
                               for key, pred in predictions.items()}
        else:
            raise ValueError('Unsupported output type')

        if self.is_main_rank:
            print('Done.')
        return all_predictions

    def save_predictions(self, dataset_name, predictions):
        os.makedirs(self.config.predictions_path, exist_ok=True)
        predictions_file = os.path.join(self.config.predictions_path, f'{dataset_name}.pt')
        torch.save(predictions, predictions_file)
        print(f'Saved predictions to {predictions_file}')

    def evaluate_predictions(self, predictions):
        raise NotImplementedError

    def prediction_epoch(self, epoch, logs):
        if self.is_main_rank:
            print('Predicting on validation dataset...')
        dataloader = self.val_dataloader
        outputs = self.prediction_loop(dataloader)
        outputs = self.preprocess_predictions(outputs)

        if self.is_distributed:
            outputs = self.distributed_gather_predictions(outputs)

        predictions = self.postprocess_predictions(outputs)
        if self.is_main_rank:
            self.save_predictions('validation', predictions)
        results = self.evaluate_predictions(predictions)
        results = {f'val_{k}': v for k, v in results.items()}
        logs.update(results)
        if self.is_main_rank:
            desc = 'Validation: '+'; '.join(f'{k}: {v:.4f}' for k, v in results.items())
            print(desc, flush=True)


    # Interface
    def prepare_for_training(self):
        self.config_summary()
        self.save_config_file()
        self.load_checkpoint()
        self.model_summary()

    def execute_training(self):
        self.prepare_for_training()
        self.train_model()
        self.finalize_training()

    def finalize_training(self):
        pass


--------------------------------------------------------------------------------
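# ==============================================================================
# [Editor's note] The sketch below is NOT a file in this repository. It
# illustrates how a training scheme might override `evaluate_predictions`,
# which `prediction_epoch` above calls with the post-processed numpy
# predictions. The key names ('predictions', 'targets') and the MAE metric
# are assumptions for illustration; the real schemes live under
# lib/training/schemes/.

import numpy as np

class ExampleScheme:  # stands in for a TrainingBase subclass
    def evaluate_predictions(self, predictions):
        # `predictions` is assumed to be a dict of numpy arrays produced by
        # postprocess_predictions(); the keys here are hypothetical.
        mae = float(np.mean(np.abs(predictions['predictions']
                                   - predictions['targets'])))
        # prediction_epoch prefixes each key with 'val_', so this is logged
        # as 'val_mae'.
        return {'mae': mae}
# ==============================================================================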
/lib/training/training_mixins.py:
--------------------------------------------------------------------------------
from lib.training.training import TrainingBase, StopTrainingException
from lib.utils.dotdict import HDict

import torch
import numpy as np
import os


class SaveModel(TrainingBase):
    def get_default_config(self):
        config = super().get_default_config()
        config.update(
            saved_model_path = HDict.L('c:path.join(c.save_path,"saved_model")'),
            save_model_when = 'epoch',
            saved_model_name = "epoch_{epoch:0>4d}",
            save_model_monitor = 'val_loss',
            save_monitor_improves_when = 'less',
            save_model_condition = HDict.L("c:c.save_model_monitor+"+
                                           "('<=' if c.save_monitor_improves_when=='less' else '>=')+"+
                                           "'save_monitor_value'"),
            save_last_only = True,
        )
        return config

    def get_default_state(self):
        state = super().get_default_state()
        state.update(
            last_saved_model_file = None,
        )
        if self.config.save_monitor_improves_when == 'less':
            state.update(
                save_monitor_value = np.inf,
                save_monitor_epoch = -1,
            )
        elif self.config.save_monitor_improves_when == 'greater':
            state.update(
                save_monitor_value = 0,
                save_monitor_epoch = -1,
            )
        else:
            raise ValueError
        return state

    def save_model(self, name):
        if not self.is_main_rank: return

        os.makedirs(self.config.saved_model_path, exist_ok=True)
        save_file = os.path.join(self.config.saved_model_path, name+'.pt')
        torch.save(self.base_model.state_dict(), save_file)
        print(f'SAVE: model saved to {save_file}', flush=True)

        if self.config.save_last_only and (self.state.last_saved_model_file is not None)\
                and (os.path.exists(self.state.last_saved_model_file)):
            os.remove(self.state.last_saved_model_file)
            print(f'SAVE: removed old model file {self.state.last_saved_model_file}', flush=True)

        self.state.last_saved_model_file = save_file

    def on_batch_end(self, i, logs, training):
        super().on_batch_end(i, logs, training)
        if self.config.save_model_when != 'batch' or not training or not self.is_main_rank: return
        config = self.config
        scope = dict(batch=i)
        scope.update(self.state)
        scope.update(logs)
        if eval(config.save_model_condition, scope):
            self.save_model(config.saved_model_name.format(**scope))

    def on_epoch_end(self, logs, training):
        super().on_epoch_end(logs, training)
        if training: return

        config = self.config
        state = self.state
        monitor = config.save_model_monitor
        try:
            new_value = logs[monitor]
            new_epoch = logs['epoch']
        except KeyError:
            print('Warning: SAVE: COULD NOT FIND LOG!', flush=True)
            return

        old_value = state.save_monitor_value
        old_epoch = state.save_monitor_epoch

        if (self.config.save_monitor_improves_when == 'less' and new_value <= old_value)\
                or (self.config.save_monitor_improves_when == 'greater' and new_value >= old_value):
            state.save_monitor_value = new_value
            state.save_monitor_epoch = new_epoch
            if self.is_main_rank:
                print(f'MONITOR BEST: {monitor} improved from (epoch:{old_epoch},value:{old_value:0.5f})'+
                      f' to (epoch:{new_epoch},value:{new_value:0.5f})', flush=True)
        elif self.is_main_rank:
            print(f'MONITOR BEST: {monitor} did NOT improve from'+
                  f' (epoch:{old_epoch},value:{old_value:0.5f})', flush=True)

        if config.save_model_when != 'epoch' or not self.is_main_rank: return
        scope = {}
        scope.update(self.state)
        scope.update(logs)
        if eval(config.save_model_condition, scope):
            self.save_model(config.saved_model_name.format(**scope))
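# [Editor's note] How `save_model_condition` is evaluated, for illustration:
# with the default config, the HDict.L lambda above expands to the string
#
#     'val_loss<=save_monitor_value'
#
# which is eval()'d against a scope assembled from `self.state` and `logs`
# (plus the batch index for per-batch saving). Since the monitor state is
# updated before the condition is checked, a checkpoint named e.g.
# 'epoch_0012.pt' (from saved_model_name = "epoch_{epoch:0>4d}") is written
# whenever the monitored value is at least as good as the best seen so far.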
class VerboseLR(TrainingBase):
    def get_default_config(self):
        config = super().get_default_config()
        config.update(
            verbose_lr_log = True,
        )
        return config

    def log_description(self, i, logs, training):
        descriptions = super().log_description(i, logs, training)
        if training and self.config.verbose_lr_log:
            descriptions.append(f'(lr:{logs["lr"]:0.3e})')
        return descriptions
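# [Editor's note] These mixins are designed to be stacked on top of
# TrainingBase via multiple inheritance; each contributes its own config,
# state and hook behavior through cooperative super() calls. A hypothetical
# composition (the real ones live under lib/training/schemes/):
#
#   class MyTrainingScheme(SaveModel, VerboseLR, ReduceLR, TrainingBase):
#       pass
#
# Cooperative super() calls in get_default_config(), on_epoch_end(), etc.
# ensure every mixin in the MRO contributes its piece.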
class ReduceLR(TrainingBase):
    def get_default_config(self):
        config = super().get_default_config()
        config.update(
            rlr_factor = 0.5,
            rlr_patience = 10,
            min_lr = 1e-6,
            stopping_lr = 0.,
            rlr_monitor = 'val_loss',
            rlr_monitor_improves_when = 'less',
        )
        return config

    def get_default_state(self):
        state = super().get_default_state()
        state.update(
            last_rlr_epoch = -1,
        )
        if self.config.rlr_monitor_improves_when == 'less':
            state.update(
                rlr_monitor_value = np.inf,
                rlr_monitor_epoch = -1,
            )
        elif self.config.rlr_monitor_improves_when == 'greater':
            state.update(
                rlr_monitor_value = 0,
                rlr_monitor_epoch = -1,
            )
        else:
            raise ValueError
        return state

    def on_epoch_begin(self, logs, training):
        super().on_epoch_begin(logs, training)
        if 'lr' not in logs:
            logs['lr'] = max(group['lr'] for group in self.optimizer.param_groups)

    def on_epoch_end(self, logs, training):
        super().on_epoch_end(logs, training)
        if training: return

        config = self.config
        state = self.state
        monitor = config.rlr_monitor
        try:
            new_value = logs[monitor]
            new_epoch = logs['epoch']
        except KeyError:
            print('Warning: RLR: COULD NOT FIND LOG!', flush=True)
            return

        old_value = state.rlr_monitor_value
        old_epoch = state.rlr_monitor_epoch

        if (self.config.rlr_monitor_improves_when == 'less' and new_value <= old_value)\
                or (self.config.rlr_monitor_improves_when == 'greater' and new_value >= old_value):
            state.rlr_monitor_value = new_value
            state.rlr_monitor_epoch = new_epoch
        else:
            if config.rlr_factor < 1:
                epoch_gap = (new_epoch - max(state.last_rlr_epoch, old_epoch))
                if epoch_gap >= config.rlr_patience:
                    old_lrs = []
                    new_lrs = []
                    for group in self.optimizer.param_groups:
                        old_lr = group['lr']
                        new_lr = max(old_lr*config.rlr_factor, config.min_lr)
                        group['lr'] = new_lr
                        old_lrs.append(old_lr)
                        new_lrs.append(new_lr)

                    old_lr = max(old_lrs)
                    new_lr = max(new_lrs)

                    logs['lr'] = new_lr

                    state.last_rlr_epoch = new_epoch
                    if self.is_main_rank:
                        print(f'\nRLR: {monitor} did NOT improve for {epoch_gap} epochs,'+
                              f' new lr = {new_lr}', flush=True)

                    if new_lr < config.stopping_lr:
                        if self.is_main_rank:
                            print(f'\nSTOP: lr fell below {config.stopping_lr}, STOPPING TRAINING!', flush=True)
                        raise StopTrainingException


class LinearLRWarmup(TrainingBase):
    def get_default_config(self):
        config = super().get_default_config()
        config.update(
            lr_warmup_steps = -1,
        )
        return config

    def on_batch_begin(self, i, logs, training):
        super().on_batch_begin(i, logs, training)
        if training and self.state.global_step <= self.config.lr_warmup_steps:
            new_lr = self.config.max_lr * (self.state.global_step / self.config.lr_warmup_steps)
            for group in self.optimizer.param_groups:
                group['lr'] = new_lr
            logs['lr'] = new_lr


class LinearLRWarmupCosineDecay(TrainingBase):
    def get_default_config(self):
        config = super().get_default_config()
        config.update(
            lr_warmup_steps = 60_000,
            lr_total_steps = 1_000_000,
            min_lr = 1e-6,
            cosine_halfwave = False,
        )
        return config

    def on_batch_begin(self, i, logs, training):
        super().on_batch_begin(i, logs, training)
        if training:
            global_step = self.state.global_step
            lr_total_steps = self.config.lr_total_steps
            lr_warmup_steps = self.config.lr_warmup_steps
            max_lr = self.config.max_lr
            min_lr = self.config.min_lr

            if global_step > lr_total_steps:
                if self.is_main_rank:
                    print('\nSTOP: global_step > lr_total_steps, STOPPING TRAINING!', flush=True)
                raise StopTrainingException

            if global_step <= lr_warmup_steps:
                new_lr = min_lr + (max_lr - min_lr) * (global_step / lr_warmup_steps)
            else:
                if self.config.cosine_halfwave:
                    new_lr = min_lr + (max_lr - min_lr) * np.cos(0.5 * np.pi * (global_step - lr_warmup_steps) / (lr_total_steps - lr_warmup_steps))
                else:
                    new_lr = min_lr + (max_lr - min_lr) * (1 + np.cos(np.pi * (global_step - lr_warmup_steps) / (lr_total_steps - lr_warmup_steps))) * 0.5
                new_lr = float(new_lr)

            for group in self.optimizer.param_groups:
                group['lr'] = new_lr
            logs['lr'] = new_lr
--------------------------------------------------------------------------------
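# ==============================================================================
# [Editor's note] Not a repository file: a worked example of the
# LinearLRWarmupCosineDecay schedule above, with an assumed max_lr of 2e-4
# and the default warmup/total steps.

import numpy as np

def cosine_schedule(step, max_lr=2e-4, min_lr=1e-6,
                    warmup=60_000, total=1_000_000):
    if step <= warmup:
        # Linear warmup from min_lr to max_lr.
        return min_lr + (max_lr - min_lr) * step / warmup
    # Full cosine half-period decay back down to min_lr.
    progress = (step - warmup) / (total - warmup)
    return min_lr + (max_lr - min_lr) * (1 + np.cos(np.pi * progress)) * 0.5

for s in (0, 30_000, 60_000, 530_000, 1_000_000):
    print(f'step {s:>9,d}: lr = {cosine_schedule(s):.3e}')
# Step 30,000 gives ~1.0e-4 (halfway up the warmup), step 60,000 gives
# max_lr, step 530,000 gives ~1.0e-4 (halfway down the cosine), and
# step 1,000,000 returns min_lr.
# ==============================================================================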
/lib/utils/dotdict/__init__.py:
--------------------------------------------------------------------------------
from .dotdict import *
--------------------------------------------------------------------------------
/lib/utils/dotdict/dotdict.py:
--------------------------------------------------------------------------------
class DDict(dict):
    def __dir__(self):
        return super().__dir__() + list(self.keys())
    def __setattr__(self, key, value):
        if key in self.keys():
            self[key]=value
        else:
            raise AttributeError('No such attribute: '+key)
    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError:
            raise AttributeError('No such attribute: '+key)
    def copy(self):
        return self.__class__(self)

class Assignable:
    def __setattr__(self, key, value):
        self[key]=value


class Lambda(DDict):
    _globals={}
    @classmethod
    def update_globals(cls, list_or_dict):
        dictionary = dict((x.__name__,x)
                          for x in list_or_dict)\
                     if isinstance(list_or_dict, list)\
                     else list_or_dict
        cls._globals.update(dictionary)

    def __init__(self, func):
        super().__init__(tag='', func=func)
    def __call__(self, base):
        lfunc = eval('lambda '+self.func, self._globals)
        return lfunc(base)


class MDict(DDict):
    L = Lambda

    def call_macro(self, macro):
        return macro(self)

    def is_macro(self, value):
        return callable(value)

    def __getattr__(self, key):
        value = super().__getattr__(key)
        if self.is_macro(value):
            value = self.call_macro(value)
        return value

    def get_dict(self):
        ret = {}
        for key, value in self.items():
            if self.is_macro(value):
                ret[key] = self.call_macro(value)
            else:
                ret[key] = value
        return ret


class Inherit(DDict):
    def __init__(self, key, default=None, max_nest=None):
        super().__init__(tag='', key=key, default=default, max_nest=max_nest)

    def __call__(self, base):
        parent = base
        val = self.default
        for _ in range(self.max_nest
                       if self.max_nest is not None
                       else 1000):
            parent = parent.get_parent()
            if parent is None: break

            try:
                val = getattr(parent, self.key)
                break
            except AttributeError:
                pass

        return val


class HDict(MDict):
    I = Inherit

    def set_parent(self, parent):
        self.__dict__['_parent'] = parent
        return self

    def get_parent(self):
        return self.__dict__.get('_parent', None)

    def __call__(self, parent):
        return self.set_parent(parent)

    def get_dict(self):
        ret = super().get_dict()
        for key, value in ret.items():
            if isinstance(value, HDict):
                ret.update({key:value.get_dict()})
        return ret


class ADDict(Assignable,DDict):
    pass

class AMDict(Assignable,MDict):
    pass

class AHDict(Assignable,HDict):
    pass
--------------------------------------------------------------------------------
/make_predictions.py:
--------------------------------------------------------------------------------
import sys
from lib.training.execute import get_configs_from_args, execute

if __name__ == '__main__':
    config = get_configs_from_args(sys.argv)
    execute('predict', config)
--------------------------------------------------------------------------------
/run_training.py:
--------------------------------------------------------------------------------
import sys
from lib.training.execute import get_configs_from_args, execute

if __name__ == '__main__':
    config = get_configs_from_args(sys.argv)
    execute('train', config)
--------------------------------------------------------------------------------
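# ==============================================================================
# [Editor's note] Not a repository file: usage sketches for the entry points
# and the dotdict utilities above.
#
# run_training.py and make_predictions.py share the same entry point: they
# parse a config from the command line and dispatch to execute(). A typical
# invocation presumably looks like (the exact argument format is defined in
# lib/training/execute.py, which is not shown here):
#
#   python run_training.py configs/pcqm4mv2/egt_47m.yaml
#   python make_predictions.py configs/pcqm4mv2/egt_47m.yaml

# A minimal, runnable demo of the HDict "macro" mechanism from
# lib/utils/dotdict/dotdict.py, which the training mixins use for lazily
# derived config values:

from lib.utils.dotdict import HDict

config = HDict(
    save_path='results/run1',
    # HDict.L is the Lambda macro: the string is compiled to `lambda c: ...`
    # and re-evaluated against this HDict on every attribute access.
    saved_model_path=HDict.L('c: c.save_path + "/saved_model"'),
)
print(config.saved_model_path)   # -> results/run1/saved_model
config.save_path = 'results/run2'
print(config.saved_model_path)   # -> results/run2/saved_model

# The repository configs also use names like `path` inside these lambda
# strings (e.g. HDict.L('c:path.join(c.save_path,"saved_model")')); those
# are resolved through Lambda._globals, presumably populated elsewhere via
# Lambda.update_globals.
# ==============================================================================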