├── .gitignore ├── LICENSE ├── README.md ├── data ├── GEOM │ ├── __init__.py │ ├── analyze_GEOM.py │ └── pre_process.py ├── README.md └── download_data.sh ├── env.yaml ├── figures ├── Diagram.png └── flow.svg ├── log └── readme.md ├── script ├── FineTuning.sh ├── PreTraining.sh ├── Probing.sh └── README.md └── src ├── __init__.py ├── batch.py ├── config ├── __init__.py ├── aug_whitening.py ├── sweeps │ ├── __init__.py │ └── mlp.yaml ├── training_config.py └── validation_config.py ├── dataloader.py ├── datasets ├── __init__.py ├── molecule_contextual.py ├── molecule_datasets.py ├── molecule_gpt_gnn.py ├── molecule_graphcl.py ├── molecule_graphmvp.py ├── molecule_motif.py ├── molecule_rgcl.py └── utils.py ├── init.py ├── load_save.py ├── logger.py ├── models ├── __init__.py ├── attribute_masking.py ├── building_blocks │ ├── __init__.py │ ├── auto_encoder.py │ ├── flow.py │ ├── gnn.py │ ├── mlp.py │ └── schnet.py ├── context_prediction.py ├── contextual.py ├── discriminator.py ├── edge_prediction.py ├── gpt_gnn.py ├── graph_cl.py ├── graphmae.py ├── graphmvp.py ├── graphpred.py ├── info_max.py ├── joao_v2.py ├── motif.py ├── pre_trainer_model.py └── rgcl.py ├── pretrainers ├── __init__.py ├── attribute_masking.py ├── context_prediction.py ├── contextual.py ├── edge_prediction.py ├── gpt_gnn.py ├── graph_cl.py ├── graphmae.py ├── graphmvp.py ├── info_max.py ├── joao.py ├── joao_v2.py ├── motif.py ├── pretrainer.py └── rgcl.py ├── run_embedding_extraction.py ├── run_pretraining.py ├── run_validation.py ├── splitters.py ├── util.py └── validation ├── __init__.py ├── dataset.py ├── task ├── __init__.py ├── finetune_task.py ├── graph_edit_distance.py ├── graph_level.py ├── metrics.py ├── node_level.py ├── pair_level.py ├── prober_task.py ├── task.py └── weisfeiler_lehman.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | wandb/ 3 | analysis/ 4 | pretrain_models/ 5 | .vscode/ 6 | embedding_dir/ 7 | 8 | 
# Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | pip-wheel-metadata/ 31 | share/python-wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .nox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | *.py,cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | db.sqlite3-journal 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 102 | __pypackages__/ 103 | 104 | # Celery stuff 105 | celerybeat-schedule 106 | celerybeat.pid 107 | 108 | # SageMath parsed files 109 | *.sage.py 110 | 111 | # Environments 112 | .env 113 | .venv 114 | env/ 115 | venv/ 116 | ENV/ 117 | env.bak/ 118 | venv.bak/ 119 | 120 | # Spyder project settings 121 | .spyderproject 122 | .spyproject 123 | 124 | # Rope project settings 125 | .ropeproject 126 | 127 | # mkdocs documentation 128 | /site 129 | 130 | # mypy 131 | .mypy_cache/ 132 | .dmypy.json 133 | dmypy.json 134 | 135 | # Pyre type checker 136 | .pyre/ 137 | .DS_Store 138 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Hanchen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## MolGraphEval 2 | 3 | This repository is the official implementation of paper: "**Evaluating Self-supervised Learning for Molecular Graph Embeddings**”, NeurIPS 2023, Datasets and Benchmarks Track. 4 | 5 | ![Diagram](figures/Diagram.png) 6 | 7 | ### Citation 8 | ```bibtex 9 | @inproceedings{GraphEval, 10 | title = {Evaluating Self-supervised Learning for Molecular Graph Embeddings}, 11 | author = {Hanchen Wang* and Jean Kaddour* and Shengchao Liu and Jian Tang and Joan Lasenby and Qi Liu}, 12 | booktitle = {NeurIPS 2023, Datasets and Benchmarks Track}, 13 | year = 2023 14 | } 15 | ``` 16 | 17 | ### Usage 18 | We include scripts for pre-training, probing and fine-tuning for GraphSSL on molecules, see script folder. We use conda to set up the environment: 19 | ```bash 20 | conda env create -f env.yaml 21 | ``` -------------------------------------------------------------------------------- /data/GEOM/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hansen7/MolGraphEval/2f83950b622d98937500c2efd8091a5af88c8cca/data/GEOM/__init__.py -------------------------------------------------------------------------------- /data/GEOM/analyze_GEOM.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021. 
Shengchao & Hanchen 2 | # liusheng@mila.quebec & hw501@cam.ac.uk 3 | 4 | import hashlib 5 | import json 6 | import os 7 | import pickle 8 | import random 9 | from os.path import join 10 | 11 | import msgpack 12 | from rdkit import Chem 13 | 14 | # from rdkit.Chem import AllChem 15 | # from rdkit.Chem.Draw import MolsToGridImage 16 | # from rdkit.Chem.rdmolfiles import MolToPDBFile 17 | # from rdkit.Chem.rdMolDescriptors import CalcWHIM 18 | 19 | zip2md5 = { 20 | "drugs_crude.msgpack.tar.gz": "7778e84c50b7cde755cca670d1f75091", 21 | "drugs_featurized.msgpack.tar.gz": "2fb86edc50e3ab3b96f78fe01082965b", 22 | "qm9_crude.msgpack.tar.gz": "aad0081ed5d9b8c93c2bd0235987573b", 23 | "qm9_featurized.msgpack.tar.gz": "09655f470f438e3a7a0dfd20f40f6f22", 24 | "rdkit_folder.tar.gz": "e8f2168b7050652db22c976be25c450e", 25 | } 26 | 27 | 28 | def compute_md5(file_name, chunk_size=65536): 29 | md5 = hashlib.md5() 30 | with open(file_name, "rb") as fin: 31 | chunk = fin.read(chunk_size) 32 | while chunk: 33 | md5.update(chunk) 34 | chunk = fin.read(chunk_size) 35 | return md5.hexdigest() 36 | 37 | 38 | def analyze_crude_file(data): 39 | drugs_file = "{}_crude.msgpack".format(data) 40 | unpacker = msgpack.Unpacker(open(drugs_file, "rb")) 41 | print(compute_md5("{}.tar.gz".format(drugs_file))) 42 | print(zip2md5["{}.tar.gz".format(drugs_file)]) 43 | 44 | total_smiles_list = [] 45 | for idx, drug_batch in enumerate(unpacker): 46 | smiles_list = list(drug_batch.keys()) 47 | print(idx, "\t", len(smiles_list)) 48 | total_smiles_list.extend(smiles_list) 49 | 50 | for smiles in smiles_list: 51 | print(smiles) 52 | if smiles == "CCOCC[C@@H](O)C=O": 53 | print(drug_batch[smiles]) 54 | conformer_list = drug_batch[smiles]["conformers"] 55 | print(len(conformer_list)) 56 | # break 57 | break 58 | print("total smiles list {}".format(len(total_smiles_list))) 59 | return 60 | 61 | 62 | def analyze_featurized_file(data): 63 | drugs_file = "{}_featurized.msgpack".format(data) 64 | unpacker = 
msgpack.Unpacker(open(drugs_file, "rb")) 65 | print(compute_md5("{}.tar.gz".format(drugs_file))) 66 | print(zip2md5["{}.tar.gz".format(drugs_file)]) 67 | 68 | for idx, drug_batch in enumerate(unpacker): 69 | smiles_list = list(drug_batch.keys()) 70 | for smiles in smiles_list: 71 | print(smiles) 72 | print(len(drug_batch[smiles])) 73 | # print(drug_batch[smiles]) 74 | break 75 | break 76 | 77 | return 78 | 79 | 80 | def analyze_rdkit_file(data): 81 | dir_name = "rdkit_folder" 82 | # dir_zip_name = '{}.tar.gz'.format(dir_name) 83 | # assert compute_md5(dir_zip_name) == zip2md5[dir_zip_name] 84 | 85 | drugs_file = "{}/summary_{}.json".format(dir_name, data) 86 | with open(drugs_file, "r") as f: 87 | drugs_summary = json.load(f) 88 | 89 | smiles_list = list(drugs_summary.keys()) 90 | print("# SMILES: {}".format(len(smiles_list))) # 304,466 91 | example_smiles = smiles_list[0] 92 | print(drugs_summary[example_smiles]) 93 | 94 | # Now let's find active molecules and their pickle paths: 95 | active_mol_paths = [] 96 | active_smiles_list = [] 97 | for smiles, sub_dic in drugs_summary.items(): 98 | if sub_dic.get("sars_cov_one_cl_protease_active") == 1: 99 | pickle_path = join(dir_name, sub_dic.get("pickle_path", "")) 100 | if os.path.isfile(pickle_path): 101 | active_mol_paths.append(pickle_path) 102 | print("# active mols on CoV 3CL: {}\n".format(len(active_mol_paths))) 103 | 104 | # Now randomly sample inactive molecules and their pickle paths: 105 | random_smiles = list(drugs_summary.keys()) 106 | random.shuffle(random_smiles) 107 | random_smiles = random_smiles[:1000] 108 | inactive_mol_paths = [] 109 | for smiles in random_smiles: 110 | sub_dic = drugs_summary[smiles] 111 | if sub_dic.get("sars_cov_one_cl_protease_active") == 0: 112 | pickle_path = join(dir_name, sub_dic.get("pickle_path", "")) 113 | if os.path.isfile(pickle_path): 114 | inactive_mol_paths.append(pickle_path) 115 | print("# inactive mols on CoV 3CL: {}\n".format(len(inactive_mol_paths))) 116 | 117 | 
sample_dic = {} 118 | sample_smiles = active_smiles_list 119 | sample_smiles.extend(random_smiles) 120 | for mol_path in [*active_mol_paths, *inactive_mol_paths]: 121 | with open(mol_path, "rb") as f: 122 | dic = pickle.load(f) 123 | sample_dic.update({dic["smiles"]: dic}) 124 | print("# all mols on CoV 3CL: {}\n".format(len(sample_dic))) 125 | 126 | idx = 0 127 | for k, v in sample_dic.items(): 128 | conf_list = v["conformers"] 129 | # print(k) 130 | # print(len(conf_list)) 131 | for conf in conf_list: 132 | mol = conf["rd_mol"] 133 | # print(conf) 134 | # print(mol) 135 | print(Chem.MolToSmiles(mol)) 136 | smiles = Chem.MolToSmiles(Chem.MolFromSmiles(Chem.MolToSmiles(mol))) 137 | print(smiles) 138 | print("\n") 139 | 140 | idx += 1 141 | if idx > 9: 142 | break 143 | 144 | return 145 | 146 | 147 | if __name__ == "__main__": 148 | data = "drugs" 149 | # analyze_crude_file(data) 150 | # analyze_featurized_file(data) 151 | analyze_rdkit_file(data) -------------------------------------------------------------------------------- /data/GEOM/pre_process.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | 4 | def clean_rdkit_file(data): 5 | """Write smiles in a csv file""" 6 | 7 | dir_name = "rdkit_folder" 8 | drugs_file = "{}/summary_{}.json".format(dir_name, data) 9 | with open(drugs_file, "r") as f: 10 | drugs_summary = json.load(f) 11 | 12 | # 304,466 molecules in total 13 | smiles_list = list(drugs_summary.keys()) 14 | print("Number of total items (SMILES): {}".format(len(smiles_list))) 15 | 16 | drug_file = "{}.csv".format(data) 17 | with open(drug_file, "w") as f: 18 | f.write("smiles\n") 19 | for smiles in smiles_list: 20 | f.write("{}\n".format(smiles)) 21 | 22 | return 23 | 24 | 25 | if __name__ == "__main__": 26 | data = "drugs" 27 | clean_rdkit_file(data) 28 | -------------------------------------------------------------------------------- /data/README.md: 
-------------------------------------------------------------------------------- 1 | # Datasets 2 | 3 | 4 | 5 | #### Geometric Ensemble Of Molecules (GEOM) 6 | 7 | ```bash 8 | mkdir -p GEOM/raw 9 | mkdir -p GEOM/processed 10 | ``` 11 | 12 | ```bash 13 | wget https://dataverse.harvard.edu/api/access/datafile/4327252 14 | mv 4327252 rdkit_folder.tar.gz 15 | tar -xvf rdkit_folder.tar.gz 16 | ``` 17 | 18 | 19 | 20 | #### Chem Dataset 21 | 22 | ```bash 23 | bash download_data.sh 24 | ``` 25 | -------------------------------------------------------------------------------- /data/download_data.sh: -------------------------------------------------------------------------------- 1 | 2 | # MoleculeNet + ZINC 3 | wget http://snap.stanford.edu/gnn-pretrain/data/chem_dataset.zip 4 | unzip chem_dataset.zip 5 | mv dataset molecule_datasets 6 | rm chem_dataset.zip 7 | rm -r molecule_datasets/*/processed 8 | 9 | # for d in molecule_datasets/*/ 10 | # do 11 | # echo "$d" 12 | # ln -s $d ./ 13 | # done 14 | 15 | # GEOM 16 | wget https://dataverse.harvard.edu/api/access/datafile/4327252 17 | mv 4327252 rdkit_folder.tar.gz 18 | tar -xvf rdkit_folder.tar.gz 19 | mv rdkit_folder GEOM/ 20 | rm rdkit_folder.tar.gz 21 | -------------------------------------------------------------------------------- /env.yaml: -------------------------------------------------------------------------------- 1 | name: GraphEval 2 | channels: 3 | - pyg 4 | - pytorch 5 | - nvidia 6 | - conda-forge 7 | - defaults 8 | dependencies: 9 | - _libgcc_mutex=0.1=main 10 | - _openmp_mutex=5.1=1_gnu 11 | - black=23.3.0=py39h06a4308_0 12 | - blas=1.0=mkl 13 | - brotlipy=0.7.0=py39h27cfd23_1003 14 | - bzip2=1.0.8=h7b6447c_0 15 | - ca-certificates=2023.05.30=h06a4308_0 16 | - certifi=2023.7.22=py39h06a4308_0 17 | - cffi=1.15.1=py39h5eee18b_3 18 | - cryptography=41.0.2=py39h22a60cf_0 19 | - cuda-cudart=11.7.99=0 20 | - cuda-cupti=11.7.101=0 21 | - cuda-libraries=11.7.1=0 22 | - cuda-nvrtc=11.7.99=0 23 | - cuda-nvtx=11.7.91=0 
24 | - cuda-runtime=11.7.1=0 25 | - cudatoolkit=10.2.89=hfd86e86_1 26 | - ffmpeg=4.3=hf484d3e_0 27 | - filelock=3.9.0=py39h06a4308_0 28 | - freetype=2.12.1=h4a9f257_0 29 | - giflib=5.2.1=h5eee18b_3 30 | - gmp=6.2.1=h295c915_3 31 | - gmpy2=2.1.2=py39heeb90bb_0 32 | - gnutls=3.6.15=he1e5248_0 33 | - idna=3.4=py39h06a4308_0 34 | - intel-openmp=2023.1.0=hdb19cb5_46305 35 | - jinja2=3.1.2=py39h06a4308_0 36 | - joblib=1.2.0=py39h06a4308_0 37 | - jpeg=9e=h5eee18b_1 38 | - lame=3.100=h7b6447c_0 39 | - lcms2=2.12=h3be6417_0 40 | - ld_impl_linux-64=2.38=h1181459_1 41 | - lerc=3.0=h295c915_0 42 | - libcublas=11.10.3.66=0 43 | - libcufft=10.7.2.124=h4fbf590_0 44 | - libcufile=1.7.1.12=0 45 | - libcurand=10.3.3.129=0 46 | - libcusolver=11.4.0.1=0 47 | - libcusparse=11.7.4.91=0 48 | - libdeflate=1.17=h5eee18b_0 49 | - libffi=3.4.4=h6a678d5_0 50 | - libgcc-ng=11.2.0=h1234567_1 51 | - libgfortran-ng=11.2.0=h00389a5_1 52 | - libgfortran5=11.2.0=h1234567_1 53 | - libgomp=11.2.0=h1234567_1 54 | - libiconv=1.16=h7f8727e_2 55 | - libidn2=2.3.4=h5eee18b_0 56 | - libnpp=11.7.4.75=0 57 | - libnvjpeg=11.8.0.2=0 58 | - libpng=1.6.39=h5eee18b_0 59 | - libstdcxx-ng=11.2.0=h1234567_1 60 | - libtasn1=4.19.0=h5eee18b_0 61 | - libtiff=4.5.0=h6a678d5_2 62 | - libunistring=0.9.10=h27cfd23_0 63 | - libuv=1.44.2=h5eee18b_0 64 | - libwebp=1.2.4=h11a3e52_1 65 | - libwebp-base=1.2.4=h5eee18b_1 66 | - lz4-c=1.9.4=h6a678d5_0 67 | - markupsafe=2.1.1=py39h7f8727e_0 68 | - mkl=2023.1.0=h213fc3f_46343 69 | - mkl-service=2.4.0=py39h5eee18b_1 70 | - mkl_fft=1.3.6=py39h417a72b_1 71 | - mkl_random=1.2.2=py39h417a72b_1 72 | - mpc=1.1.0=h10f8cd9_1 73 | - mpfr=4.0.2=hb69a4c5_1 74 | - mpmath=1.3.0=py39h06a4308_0 75 | - mypy_extensions=0.4.3=py39h06a4308_1 76 | - ncurses=6.4=h6a678d5_0 77 | - nettle=3.7.3=hbbd107a_1 78 | - networkx=3.1=py39h06a4308_0 79 | - ninja=1.10.2=h06a4308_5 80 | - ninja-base=1.10.2=hd09550d_5 81 | - numpy=1.25.2=py39h5f9d8c6_0 82 | - numpy-base=1.25.2=py39hb5e798b_0 83 | - 
openh264=2.1.1=h4ff587b_0 84 | - openssl=3.0.10=h7f8727e_0 85 | - pathspec=0.10.3=py39h06a4308_0 86 | - pip=23.2.1=py39h06a4308_0 87 | - platformdirs=2.5.2=py39h06a4308_0 88 | - psutil=5.9.0=py39h5eee18b_0 89 | - pycparser=2.21=pyhd3eb1b0_0 90 | - pyg=2.3.1=py39_torch_2.0.0_cu117 91 | - pyopenssl=23.2.0=py39h06a4308_0 92 | - pyparsing=3.0.9=py39h06a4308_0 93 | - pysocks=1.7.1=py39h06a4308_0 94 | - python=3.9.17=h955ad1f_0 95 | - pytorch=2.0.1=py3.9_cuda11.7_cudnn8.5.0_0 96 | - pytorch-cluster=1.6.1=py39_torch_2.0.0_cu117 97 | - pytorch-cuda=11.7=h778d358_5 98 | - pytorch-mutex=1.0=cuda 99 | - pytorch-scatter=2.1.1=py39_torch_2.0.0_cu117 100 | - readline=8.2=h5eee18b_0 101 | - requests=2.31.0=py39h06a4308_0 102 | - scikit-learn=1.3.0=py39h1128e8f_0 103 | - scipy=1.11.1=py39h5f9d8c6_0 104 | - setuptools=68.0.0=py39h06a4308_0 105 | - sqlite=3.41.2=h5eee18b_0 106 | - sympy=1.11.1=py39h06a4308_0 107 | - tbb=2021.8.0=hdb19cb5_0 108 | - threadpoolctl=2.2.0=pyh0d69192_0 109 | - tk=8.6.12=h1ccaba5_0 110 | - tomli=2.0.1=py39h06a4308_0 111 | - torchaudio=2.0.2=py39_cu117 112 | - torchtriton=2.0.0=py39 113 | - torchvision=0.15.2=py39_cu117 114 | - typing_extensions=4.7.1=py39h06a4308_0 115 | - wheel=0.38.4=py39h06a4308_0 116 | - xz=5.4.2=h5eee18b_0 117 | - zlib=1.2.13=h5eee18b_0 118 | - zstd=1.5.5=hc292b87_0 119 | - pip: 120 | - appdirs==1.4.4 121 | - ase==3.22.1 122 | - calmsize==0.1.3 123 | - chardet==5.2.0 124 | - charset-normalizer==3.2.0 125 | - click==8.1.6 126 | - contourpy==1.1.0 127 | - cycler==0.11.0 128 | - cython==3.0.0 129 | - descriptastorus==2.5.0.23 130 | - docker-pycreds==0.4.0 131 | - fonttools==4.42.0 132 | - fvcore==0.1.5.post20221221 133 | - gensim==4.3.1 134 | - gitdb==4.0.10 135 | - gitpython==3.1.32 136 | - gmatch4py==0.2.5b0 137 | - importlib-resources==6.0.1 138 | - iopath==0.1.10 139 | - kiwisolver==1.4.4 140 | - littleutils==0.2.2 141 | - matplotlib==3.7.2 142 | - msgpack==1.0.5 143 | - ogb==1.3.6 144 | - outdated==0.2.2 145 | - packaging==23.1 146 
| - pandas==2.0.3 147 | - pandas-flavor==0.6.0 148 | - pathtools==0.1.2 149 | - pillow==10.0.0 150 | - portalocker==2.7.0 151 | - protobuf==4.24.0 152 | - python-dateutil==2.8.2 153 | - pytorch-memlab==0.3.0 154 | - pytz==2023.3 155 | - pyyaml==6.0.1 156 | - rdkit==2023.3.2 157 | - sentry-sdk==1.29.2 158 | - setproctitle==1.3.2 159 | - six==1.16.0 160 | - smart-open==6.3.0 161 | - smmap==5.0.0 162 | - tabulate==0.9.0 163 | - termcolor==2.3.0 164 | - tqdm==4.66.1 165 | - tzdata==2023.3 166 | - urllib3==2.0.4 167 | - wandb==0.15.8 168 | - xarray==2023.7.0 169 | - yacs==0.1.8 170 | - zipp==3.16.2 171 | prefix: /home/wangh256/anaconda3/envs/GraphEval 172 | -------------------------------------------------------------------------------- /figures/Diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hansen7/MolGraphEval/2f83950b622d98937500c2efd8091a5af88c8cca/figures/Diagram.png -------------------------------------------------------------------------------- /figures/flow.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | Produced by OmniGraffle 7.18.5\n2021-09-03 12:12:02 +0000 22 | 23 | Canvas 1 24 | 25 | 26 | Layer 1 27 | 28 | 29 | 30 | 31 | run_pretraining.py 32 | 33 | 34 | 35 | 36 | 37 | 38 | run_embedding_extraction.py 39 | 40 | 41 | 42 | 43 | 44 | 45 | run_validation.py 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | Config 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | Saved model 65 | 66 | 67 | 68 | 69 | Saved embeds. 
70 | 71 | 72 | 73 | 74 | Results 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 1 83 | 84 | 85 | 86 | 87 | 2 88 | 89 | 90 | 91 | 92 | 3 93 | 94 | 95 | 96 | 97 | 98 | -------------------------------------------------------------------------------- /log/readme.md: -------------------------------------------------------------------------------- 1 | ### Path to save log files 2 | -------------------------------------------------------------------------------- /script/FineTuning.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | cd ../ 3 | 4 | export PreTrainer=AM # or others 5 | export PreTrainData=zinc_standard_agent # or others 6 | 7 | export ProbeTask=downstream 8 | export FineTuneData_List=(bbbp tox21 toxcast sider clintox muv hiv bace) 9 | export Checkpoint="./pretrain_models/$PreTrainer/$PreTrainData/epoch99_model_complete.pth" 10 | 11 | for FineTuneData in "${FineTuneData_List[@]}"; do 12 | python src/run_validation.py \ 13 | --val_task="finetune" \ 14 | --pretrainer="FineTune" \ 15 | --dataset=$FineTuneData \ 16 | --probe_task=$ProbeTask \ 17 | --input_model_file=$Checkpoint ; 18 | done >> FineTune_${PreTrainer}_${PreTrainData}.log 19 | -------------------------------------------------------------------------------- /script/PreTraining.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | cd ../ 3 | 4 | export PreTrainData=zinc_standard_agent # or geom2d_nmol500000_nconf1 5 | export PreTrainer=AM # or IM EP CP GPT_GNN JOAO JOAOv2 GraphCL Motif Contextual GraphMAE RGCL 6 | 7 | # If the PreTrainer is GraphMVP, then the PreTrainData is geom3d_nmol500000_nconf1 8 | 9 | # === Pre-Training === 10 | python src/run_pretraining.py \ 11 | --dataset="$PreTrainData" \ 12 | --pretrainer="$PreTrainer" \ 13 | --output_model_dir=./pretrain_models/; 14 | -------------------------------------------------------------------------------- /script/Probing.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | cd ../ 3 | 4 | export PreTrainer=AM # or others 5 | export PreTrainData=zinc_standard_agent # or others 6 | export EmbeddingDir="./embedding_dir/$PreTrainer/$PreTrainData/" 7 | export Checkpoint="./pretrain_models/$PreTrainer/$PreTrainData/epoch99_model_complete.pth" 8 | # for AM, the default path of checkpoint is "./pretrain_models/$PreTrainer/$PreTrainData/mask_rate-0.15_seed-42_lr-0.001/epoch0_model_complete.pth" 9 | # for the random baseline, just set the checkpoint as epoch0_model_complete.pth 10 | 11 | export FineTuneData_List=(bbbp tox21 toxcast sider clintox muv hiv bace) 12 | 13 | # === Embedding Extractions, from Fixed Pre-Trained GNNs === 14 | for FineTuneData in "${FineTuneData_List[@]}"; do 15 | python src/run_embedding_extraction.py \ 16 | --dataset="$FineTuneData" \ 17 | --pretrainer="$PreTrainer" \ 18 | --embedding_dir="$EmbeddingDir" \ 19 | --input_model_file="$Checkpoint" ; 20 | done 21 | 22 | # === Probing on Downstream Tasks === 23 | export ProbeTask=downstream 24 | for FineTuneData in "${FineTuneData_List[@]}"; do 25 | python run_validation.py \ 26 | --dataset="$FineTuneData" \ 27 | --probe_task="$ProbeTask" \ 28 | --pretrainer="$PreTrainer" \ 29 | --embedding_dir="$EmbeddingDir"; 30 | done >> Downstream_Probe_${PreTrainer}_${PreTrainData}.log 31 | 32 | 33 | # === Probing on Topological Metrics === 34 | export ProbeTask=node_degree # or node_centrality node_clustering link_prediction jaccard_coefficient katz_index graph_diameter node_connectivity cycle_basis assortativity_coefficient average_clustering_coefficient 35 | for FineTuneData in "${FineTuneData_List[@]}"; do 36 | python src/run_validation.py \ 37 | --probe_task="$ProbeTask" \ 38 | --dataset="$FineTuneData" \ 39 | --pretrainer="$PreTrainer" \ 40 | --embedding_dir="$EmbeddingDir" ; 41 | done >> Topological_Probe_${PreTrainer}_${PreTrainData}.log 42 | 43 | 44 | # === Probing on 
Substructures === 45 | export Substructure_List=(fr_epoxide fr_lactam fr_morpholine fr_oxazole \ 46 | fr_tetrazole fr_N_O fr_ether fr_furan fr_guanido fr_halogen fr_morpholine \ 47 | fr_piperdine fr_thiazole fr_thiophene fr_urea fr_allylic_oxid fr_amide \ 48 | fr_amidine fr_azo fr_benzene fr_imidazole fr_imide fr_piperzine fr_pyridine) 49 | 50 | for Substructure in "${Substructure_List[@]}"; do 51 | for FineTuneData in "${FineTuneData_List[@]}"; do 52 | python src/run_validation.py \ 53 | --dataset="$FineTuneData" \ 54 | --pretrainer="$PreTrainer" \ 55 | --embedding_dir="$EmbeddingDir" \ 56 | --probe_task="RDKiTFragment_$Substructure" ; 57 | done 58 | done >> Substructure_Probe_${PreTrainer}_${PreTrainData}.log 59 | -------------------------------------------------------------------------------- /script/README.md: -------------------------------------------------------------------------------- 1 | ## Scripts 2 | 3 | ### see PreTraining.sh, Probing.sh, FineTuning.sh for details. 4 | 5 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hansen7/MolGraphEval/2f83950b622d98937500c2efd8091a5af88c8cca/src/__init__.py -------------------------------------------------------------------------------- /src/config/__init__.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from abc import ABC 3 | 4 | 5 | def str2bool(v: str): 6 | return v.lower() == "true" 7 | 8 | 9 | @dataclasses.dataclass 10 | class Config(ABC): 11 | # about seed and basic info 12 | seed: int 13 | runseed: int 14 | device: int 15 | no_cuda: bool 16 | dataset: str 17 | # about model and pre-trainer 18 | model: str 19 | pretrainer: str 20 | -------------------------------------------------------------------------------- /src/config/aug_whitening.py: 
-------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | from .training_config import TrainingConfig 4 | from .validation_config import ValidationConfig 5 | 6 | 7 | def aug_whitening( 8 | config: Union[TrainingConfig, ValidationConfig] 9 | ) -> Union[TrainingConfig, ValidationConfig]: 10 | """Remove the augmentations in dataloader.""" 11 | 12 | """ AM PreTrainer """ 13 | config.mask_rate = 0.0 14 | config.mask_edge = 0 15 | 16 | """ GraphCL/JOAO/JOAOv2 PreTrainer """ 17 | config.aug_mode = "no_aug" 18 | config.aug_strength = 0.0 19 | config.aug_prob = 0.0 20 | 21 | # unresolved Methods: 22 | # GraphGPTGNN, Contextual 23 | 24 | # unchanged Methods: 25 | # Motif, IM 26 | return config 27 | -------------------------------------------------------------------------------- /src/config/sweeps/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hansen7/MolGraphEval/2f83950b622d98937500c2efd8091a5af88c8cca/src/config/sweeps/__init__.py -------------------------------------------------------------------------------- /src/config/sweeps/mlp.yaml: -------------------------------------------------------------------------------- 1 | # see https://docs.wandb.ai/guides/sweeps/configuration for more infos 2 | method: grid 3 | metric: 4 | name: val_loss 5 | goal: minimize 6 | parameters: 7 | mlp_dim_hidden: 8 | values: 9 | - 128 10 | - 256 11 | - 512 12 | - 1024 13 | mlp_num_layers: 14 | values: 15 | - 1 16 | - 2 17 | - 4 18 | program: run_validation.py 19 | -------------------------------------------------------------------------------- /src/config/training_config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import dataclasses 3 | 4 | from config import Config, str2bool 5 | 6 | 7 | @dataclasses.dataclass 8 | class TrainingConfig(Config): 9 | model: str 10 | pretrainer: str 11 | 
log_filepath: str 12 | log_to_wandb: bool 13 | log_interval: int 14 | project_name: str 15 | run_name: str 16 | val_interval: int 17 | eval_train: bool 18 | input_data_dir: str 19 | save_model: bool 20 | input_model_file: str 21 | output_model_dir: str 22 | verbose: bool 23 | num_workers: str 24 | optimizer_name: str 25 | split: str 26 | batch_size: int 27 | epochs: int 28 | epochs_save: int 29 | lr: float 30 | # lr_scale: float 31 | weight_decay: float 32 | gnn_type: str 33 | num_layer: int 34 | emb_dim: int 35 | dropout_ratio: float 36 | graph_pooling: str 37 | JK: str 38 | # gnn_lr_scale: float 39 | aggr: str 40 | 41 | # === GraphCL === 42 | aug_mode: str 43 | aug_strength: float 44 | aug_prob: float 45 | # === AttrMask === 46 | mask_rate: float 47 | mask_edge: int 48 | num_atom_type: int 49 | num_edge_type: int 50 | # === ContextPred === 51 | csize: int 52 | contextpred_neg_samples: int 53 | atom_vocab_size: int 54 | # === JOAO === 55 | gamma_joao: float 56 | gamma_joaov2: float 57 | # === GraphMVP === 58 | GMVP_alpha1: float 59 | GMVP_alpha2: float 60 | GMVP_T: float 61 | GMVP_normalize: bool 62 | GMVP_CL_similarity_metric: str 63 | GMVP_CL_Neg_Samples: int 64 | GMVP_Masking_Ratio: float 65 | 66 | 67 | def parse_config() -> TrainingConfig: 68 | parser = argparse.ArgumentParser() 69 | 70 | # Seed and Basic Info 71 | parser.add_argument("--seed", type=int, default=42) 72 | parser.add_argument("--device", type=int, default=0) 73 | parser.add_argument("--runseed", type=int, default=0) 74 | parser.add_argument("--no_cuda", type=str2bool, default=False, help="Disable CUDA") 75 | parser.add_argument( 76 | "--model", type=str, default="gnn", choices=["gnn", "schnet", "egnn"] 77 | ) 78 | parser.add_argument( 79 | "--pretrainer", 80 | type=str, 81 | default="GraphCL", 82 | choices=[ 83 | "Motif", 84 | "Contextual", 85 | "GPT_GNN", 86 | "GraphCL", 87 | "JOAO", 88 | "JOAOv2", 89 | "AM", 90 | "IM", 91 | "CP", 92 | "EP", 93 | "GraphMVP", 94 | "RGCL", 95 | "GraphMAE", 96 | 
], 97 | ) 98 | 99 | # Logging 100 | parser.add_argument("--log_filepath", type=str, default="./log/") 101 | parser.add_argument("--log_to_wandb", type=str2bool, default=True) 102 | parser.add_argument( 103 | "--log_interval", default=10, type=int, help="Log every n steps" 104 | ) 105 | parser.add_argument( 106 | "--val_interval", 107 | default=1, 108 | type=int, 109 | help="Evaluate validation push_loss every n steps", 110 | ) 111 | parser.add_argument( 112 | "--project_name", default="GraphEval", type=str, help="project name in wandb" 113 | ) 114 | parser.add_argument("--run_name", type=str, help="run name in wandb") 115 | parser.add_argument("--verbose", type=str2bool, default=False) 116 | 117 | # about if we would print out eval metric for training data 118 | parser.add_argument("--eval_train", type=str2bool, default=True) 119 | parser.add_argument("--input_data_dir", type=str, default="") 120 | 121 | # Loading and saving model checkpoints 122 | parser.add_argument("--save_model", type=str2bool, default=True) 123 | parser.add_argument("--input_model_file", type=str, default="") 124 | parser.add_argument("--output_model_dir", type=str, default="./saved_models/") 125 | 126 | # about dataset and dataloader 127 | parser.add_argument("--dataset", type=str, default="bace") 128 | parser.add_argument("--num_workers", type=int, default=8) 129 | 130 | # Training strategies (shared by PreTraining and FineTuning) 131 | parser.add_argument("--optimizer_name", type=str, default="adam") 132 | parser.add_argument("--split", type=str, default="scaffold") 133 | parser.add_argument("--batch_size", type=int, default=256) 134 | parser.add_argument("--epochs", type=int, default=100) 135 | parser.add_argument("--epochs_save", type=int, default=20) 136 | parser.add_argument("--lr", type=float, default=0.001) 137 | # parser.add_argument("--lr_scale", type=float, default=1) 138 | parser.add_argument("--weight_decay", type=float, default=0) 139 | 140 | # Molecule GNN 141 | 
parser.add_argument("--gnn_type", type=str, default="gin") 142 | parser.add_argument("--num_layer", type=int, default=5) 143 | parser.add_argument("--emb_dim", type=int, default=300) 144 | parser.add_argument("--dropout_ratio", type=float, default=0.5) 145 | parser.add_argument("--graph_pooling", type=str, default="mean") 146 | parser.add_argument( 147 | "--JK", 148 | type=str, 149 | default="last", 150 | choices=["last", "sum", "max", "concat"], 151 | help="how the node features across layers are combined.", 152 | ) 153 | # parser.add_argument("--gnn_lr_scale", type=float, default=1) 154 | parser.add_argument("--aggr", type=str, default="add") 155 | 156 | # PreTrainer: GraphCL, JOAO, JOAOv2 157 | parser.add_argument("--aug_mode", type=str, default="sample") 158 | parser.add_argument("--aug_strength", type=float, default=0.2) 159 | parser.add_argument("--aug_prob", type=float, default=0.1) 160 | 161 | # PreTrainer: AttrMask 162 | parser.add_argument("--mask_rate", type=float, default=0.15) 163 | parser.add_argument("--mask_edge", type=int, default=0) 164 | parser.add_argument("--num_atom_type", type=int, default=119) 165 | parser.add_argument("--num_edge_type", type=int, default=5) 166 | 167 | # PreTrainer: G-Cont, will automatically adjust based on pre-training data 168 | parser.add_argument("--atom_vocab_size", type=int, default=1) 169 | 170 | # PreTrainer: ContextPred 171 | parser.add_argument("--csize", type=int, default=3) 172 | parser.add_argument("--contextpred_neg_samples", type=int, default=1) 173 | 174 | # PreTrainer: JOAO and JOAOv2 175 | parser.add_argument("--gamma_joao", type=float, default=0.1) 176 | parser.add_argument("--gamma_joaov2", type=float, default=0.1) 177 | 178 | # PreTrainer: GraphMVP 179 | # Ref: https://github.com/chao1224/GraphMVP/blob/main/scripts_classification/submit_pre_training_GraphMVP_hybrid.sh#L14-L50 180 | parser.add_argument("--GMVP_alpha1", type=float, default=1) 181 | parser.add_argument("--GMVP_alpha2", type=float, 
default=1) # 0.1, 1, 10 182 | parser.add_argument("--GMVP_T", type=float, default=0.1) # 0.1, 0.2, 0.5, 1, 2 183 | parser.add_argument("--GMVP_normalize", type=bool, default=True) 184 | parser.add_argument( 185 | "--GMVP_CL_similarity_metric", 186 | type=str, 187 | default="EBM_dot_prod", 188 | choices=["InfoNCE_dot_prod", "EBM_dot_prod"], 189 | ) 190 | parser.add_argument("--GMVP_CL_Neg_Samples", type=int, default=5) 191 | parser.add_argument("--GMVP_Masking_Ratio", type=float, default=0.0) 192 | 193 | args = parser.parse_args() 194 | return TrainingConfig(**vars(args)) 195 | -------------------------------------------------------------------------------- /src/config/validation_config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import dataclasses 3 | 4 | from config import Config, str2bool 5 | 6 | # TODO (Jean): There are lots of redundancies between TrainingConfig and ValidationConfig 7 | # I am sure there is a better way; maybe making ValidationConfig a subclass of TrainingConfig 8 | 9 | 10 | @dataclasses.dataclass 11 | class ValidationConfig(Config): 12 | # validation task 13 | val_task: str 14 | probe_task: str 15 | # logging 16 | log_filepath: str 17 | log_to_wandb: bool 18 | log_interval: int 19 | val_interval: int 20 | project_name: str 21 | run_name: str 22 | # about if we would print out eval metric for training data 23 | eval_train: bool 24 | input_data_dir: str 25 | # about loading and saving 26 | save_model: bool 27 | input_model_file: str 28 | output_model_dir: str 29 | embedding_dir: str 30 | verbose: bool 31 | # about dataset and dataloader 32 | batch_size: int 33 | num_workers: str 34 | # about molecule GNN 35 | gnn_type: str 36 | num_layer: int 37 | emb_dim: int 38 | dropout_ratio: float 39 | graph_pooling: str 40 | JK: str 41 | # gnn_lr_scale: float 42 | aggr: str 43 | 44 | # for ProberTaskMLP 45 | mlp_dim_hidden: int 46 | mlp_dim_out: int 47 | mlp_num_layers: int 48 | 
mlp_batch_norm: bool 49 | mlp_initializer: str 50 | mlp_dropout: float 51 | mlp_activation: str 52 | mlp_leaky_relu: float 53 | 54 | # for GraphCL 55 | aug_mode: str 56 | aug_strength: float 57 | aug_prob: float 58 | # for AttributeMask 59 | mask_rate: float 60 | mask_edge: int 61 | num_atom_type: int 62 | num_edge_type: int 63 | # for ContextPred 64 | csize: int 65 | atom_vocab_size: int 66 | contextpred_neg_samples: int 67 | # for JOAO 68 | gamma_joao: float 69 | gamma_joaov2: float 70 | 71 | # Validation metric arguments 72 | 73 | # ProberTask 74 | optimizer_name: str 75 | split: str 76 | batch_size: int 77 | epochs: int 78 | lr: float 79 | # lr_scale: float 80 | weight_decay: float 81 | criterion_type: float 82 | 83 | 84 | def parse_config(parser: argparse.ArgumentParser = None) -> ValidationConfig: 85 | parser = argparse.ArgumentParser() if parser is None else parser 86 | # val and probe task 87 | parser.add_argument( 88 | "--val_task", 89 | type=str, 90 | default="prober", 91 | choices=["prober", "smoothing_metric", "finetune"], 92 | ) 93 | parser.add_argument("--probe_task", type=str, default="downstream") 94 | 95 | # seed and basic info 96 | parser.add_argument("--seed", type=int, default=42) 97 | parser.add_argument("--runseed", type=int, default=0) 98 | parser.add_argument("--no_cuda", type=str2bool, default=False) 99 | parser.add_argument("--device", type=int, default=0) 100 | parser.add_argument( 101 | "--model", type=str, default="gnn", choices=["gnn", "schnet", "egnn"] 102 | ) 103 | parser.add_argument( 104 | "--pretrainer", 105 | type=str, 106 | default="AM", 107 | choices=[ 108 | "Motif", 109 | "Contextual", 110 | "GPT_GNN", 111 | "GraphCL", 112 | "JOAO", 113 | "JOAOv2", 114 | "AM", 115 | "IM", 116 | "GraphMVP", 117 | "CP", 118 | "EP", 119 | "GraphMAE", 120 | "RGCL", 121 | "FineTune", 122 | ], 123 | ) 124 | 125 | # logging 126 | parser.add_argument("--log_filepath", type=str, default="./log/") 127 | parser.add_argument("--log_to_wandb", 
type=str2bool, default=True) 128 | parser.add_argument("--log_interval", default=10, type=int, help="Log steps") 129 | parser.add_argument( 130 | "--val_interval", 131 | default=1, 132 | type=int, 133 | help="Evaluate validation push_loss every n steps", 134 | ) 135 | parser.add_argument( 136 | "--project_name", default="GraphEval", type=str, help="project name in wandb" 137 | ) 138 | parser.add_argument("--run_name", type=str, help="run name in wandb") 139 | 140 | # about loading and saving 141 | parser.add_argument("--save_model", type=str2bool, default=True) 142 | parser.add_argument("--input_model_file", type=str, default="") 143 | parser.add_argument("--output_model_dir", type=str, default="") 144 | parser.add_argument("--embedding_dir", type=str, default="") 145 | # parser.add_argument("--embedding_dir", type=str, 146 | # default="./embedding_dir_x/Contextual/geom2d_nmol50000_nconf1_nupper1000/") 147 | parser.add_argument("--verbose", type=str2bool, default=False) 148 | 149 | # about dataset and dataloader 150 | parser.add_argument("--dataset", type=str, default="tox21") 151 | parser.add_argument("--num_workers", type=int, default=8) 152 | parser.add_argument("--batch_size", type=int, default=256) 153 | 154 | # ProberTask 155 | # about training strategies 156 | parser.add_argument("--optimizer_name", type=str, default="adam") 157 | parser.add_argument("--split", type=str, default="scaffold") 158 | parser.add_argument("--epochs", type=int, default=100) 159 | parser.add_argument("--lr", type=float, default=0.001) 160 | # parser.add_argument("--lr_scale", type=float, default=1) 161 | parser.add_argument("--weight_decay", type=float, default=0) 162 | parser.add_argument("--criterion_type", type=str, default="mse") 163 | 164 | # about molecule GNN 165 | parser.add_argument("--gnn_type", type=str, default="gin") 166 | parser.add_argument("--num_layer", type=int, default=5) 167 | parser.add_argument("--emb_dim", type=int, default=300) 168 | 
parser.add_argument("--dropout_ratio", type=float, default=0.5) 169 | parser.add_argument("--graph_pooling", type=str, default="mean") 170 | parser.add_argument( 171 | "--JK", 172 | type=str, 173 | default="last", 174 | choices=["last", "sum", "max", "concat"], 175 | help="how the node features across layers are combined.", 176 | ) 177 | # parser.add_argument("--gnn_lr_scale", type=float, default=1) 178 | parser.add_argument("--aggr", type=str, default="add") 179 | 180 | # for ProberTaskMLP 181 | parser.add_argument("--mlp_dim_hidden", type=int, default=600) 182 | parser.add_argument("--mlp_dim_out", type=int, default=1) 183 | parser.add_argument("--mlp_num_layers", type=int, default=2) 184 | parser.add_argument("--mlp_batch_norm", type=str2bool, default=False) 185 | parser.add_argument("--mlp_initializer", type=str, default="xavier") 186 | parser.add_argument("--mlp_dropout", type=float, default=0.0) 187 | parser.add_argument( 188 | "--mlp_activation", 189 | type=str, 190 | default="relu", 191 | choices=["leaky_relu", "rrelu", "relu", "elu", "gelu", "prelu", "selu"], 192 | ) 193 | parser.add_argument("--mlp_leaky_relu", type=float, default=0.5) 194 | 195 | # for GraphCL 196 | parser.add_argument("--aug_mode", type=str, default="sample") 197 | parser.add_argument("--aug_strength", type=float, default=0.2) 198 | parser.add_argument("--aug_prob", type=float, default=0.1) 199 | 200 | # for AttributeMask 201 | parser.add_argument("--mask_rate", type=float, default=0.15) 202 | parser.add_argument("--mask_edge", type=int, default=0) 203 | parser.add_argument("--num_atom_type", type=int, default=119) 204 | parser.add_argument("--num_edge_type", type=int, default=5) 205 | 206 | # PreTrainer: G-Cont, will automatically adjust based on pre-training data 207 | parser.add_argument("--atom_vocab_size", type=int, default=1) 208 | 209 | # for ContextPred 210 | parser.add_argument("--csize", type=int, default=3) 211 | parser.add_argument("--contextpred_neg_samples", type=int, 
default=1) 212 | 213 | # for JOAO 214 | parser.add_argument("--gamma_joao", type=float, default=0.1) 215 | parser.add_argument("--gamma_joaov2", type=float, default=0.1) 216 | 217 | # about if we would print out eval metric for training data 218 | parser.add_argument("--eval_train", type=str2bool, default=True) 219 | parser.add_argument("--input_data_dir", type=str, default="") 220 | 221 | args = parser.parse_args() 222 | return ValidationConfig(**vars(args)) 223 | -------------------------------------------------------------------------------- /src/dataloader.py: -------------------------------------------------------------------------------- 1 | # Ref: snap-stanford/pretrain-gnns/blob/master/bio/dataloader.py 2 | # Ref: snap-stanford/pretrain-gnns/blob/master/chem/dataloader.py 3 | from util import MaskAtom 4 | from torch.utils.data import DataLoader 5 | from batch import BatchAE, BatchMasking, BatchSubstructContext 6 | 7 | 8 | class DataLoaderSubstructContext(DataLoader): 9 | """Data loader which merges data objects from a 10 | :class:`torch_geometric.data.dataset` to a mini-batch. 11 | Args: 12 | dataset (Dataset): The dataset from which to load the data. 13 | batch_size (int, optional): How may samples per batch to load. 14 | (default: :obj:`1`) 15 | shuffle (bool, optional): If set to :obj:`True`, the data will be 16 | reshuffled at every epoch (default: :obj:`True`)""" 17 | 18 | def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs): 19 | super(DataLoaderSubstructContext, self).__init__( 20 | dataset, 21 | batch_size, 22 | shuffle, 23 | collate_fn=lambda l: BatchSubstructContext.from_data_list(l), 24 | **kwargs 25 | ) 26 | 27 | 28 | class DataLoaderMasking(DataLoader): 29 | """Data loader which merges data objects from a 30 | :class:`torch_geometric.data.dataset` to a mini-batch. 31 | Args: 32 | dataset (Dataset): The dataset from which to load the data. 33 | batch_size (int, optional): How may samples per batch to load. 
34 | (default: :obj:`1`) 35 | shuffle (bool, optional): If set to :obj:`True`, the data will be 36 | reshuffled at every epoch (default: :obj:`True`)""" 37 | 38 | def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs): 39 | super(DataLoaderMasking, self).__init__( 40 | dataset, 41 | batch_size, 42 | shuffle, 43 | collate_fn=lambda l: BatchMasking.from_data_list(l), 44 | **kwargs 45 | ) 46 | 47 | 48 | class DataLoaderAE(DataLoader): 49 | """Data loader which merges data objects from a 50 | :class:`torch_geometric.data.dataset` to a mini-batch. 51 | Args: 52 | dataset (Dataset): The dataset from which to load the data. 53 | batch_size (int, optional): How may samples per batch to load. 54 | (default: :obj:`1`) 55 | shuffle (bool, optional): If set to :obj:`True`, the data will be 56 | reshuffled at every epoch (default: :obj:`True`)""" 57 | 58 | def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs): 59 | super(DataLoaderAE, self).__init__( 60 | dataset, 61 | batch_size, 62 | shuffle, 63 | collate_fn=lambda l: BatchAE.from_data_list(l), 64 | **kwargs 65 | ) 66 | 67 | 68 | class DataLoaderMaskingPred(DataLoader): 69 | r"""Data loader which merges data objects from a 70 | :class:`torch_geometric.data.dataset` to a mini-batch. 71 | Args: 72 | dataset (Dataset): The dataset from which to load the data. 73 | batch_size (int, optional): How may samples per batch to load. 
74 | (default: :obj:`1`) 75 | shuffle (bool, optional): If set to :obj:`True`, the data will be 76 | reshuffled at every epoch (default: :obj:`True`) 77 | """ 78 | 79 | def __init__( 80 | self, 81 | dataset, 82 | batch_size=1, 83 | shuffle=True, 84 | mask_rate=0.25, 85 | mask_edge=0.0, 86 | **kwargs 87 | ): 88 | self._transform = MaskAtom( 89 | num_atom_type=119, num_edge_type=5, mask_rate=mask_rate, mask_edge=mask_edge 90 | ) 91 | super(DataLoaderMaskingPred, self).__init__( 92 | dataset, batch_size, shuffle, collate_fn=self.collate_fn, **kwargs 93 | ) 94 | 95 | def collate_fn(self, batches): 96 | batchs = [self._transform(x) for x in batches] 97 | return BatchMasking.from_data_list(batchs) 98 | -------------------------------------------------------------------------------- /src/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .molecule_contextual import MoleculeDataset_Contextual 2 | from .molecule_datasets import MoleculeDataset 3 | from .molecule_gpt_gnn import MoleculeDataset_GPTGNN 4 | from .molecule_graphcl import MoleculeDataset_GraphCL 5 | from .molecule_motif import RDKIT_PROPS, MoleculeDataset_Motif 6 | from .molecule_rgcl import MoleculeDataset_RGCL 7 | from .molecule_graphmvp import Molecule3DDataset, Molecule3DMaskingDataset 8 | from .utils import ( 9 | allowable_features, 10 | graph_data_obj_to_mol_simple, 11 | graph_data_obj_to_nx_simple, 12 | nx_to_graph_data_obj_simple, 13 | ) 14 | -------------------------------------------------------------------------------- /src/datasets/molecule_gpt_gnn.py: -------------------------------------------------------------------------------- 1 | import torch, random 2 | from tqdm import tqdm 3 | from torch_geometric.utils import subgraph 4 | from torch_geometric.data import Data, InMemoryDataset 5 | 6 | 7 | def search_graph(graph): 8 | num_node, edge_set = len(graph.x), set() 9 | u_list, v_list = graph.edge_index[0].numpy(), 
def search_graph(graph):
    """Return a random BFS visit order covering every node of ``graph``.

    Starts a BFS from a randomly chosen unvisited node and repeats until all
    nodes are visited, so disconnected components are still fully covered.

    :param graph: a graph object exposing ``x`` (node features, used only for
        its length) and ``edge_index`` (2 x E tensor of directed edges).
    :return: list of node indices in visitation order (a permutation of
        ``range(len(graph.x))``).
    """
    num_node, edge_set = len(graph.x), set()
    u_list, v_list = graph.edge_index[0].numpy(), graph.edge_index[1].numpy()
    # Store both directions so the traversal treats the graph as undirected.
    for u, v in zip(u_list, v_list):
        edge_set.add((u, v))
        edge_set.add((v, u))

    visited_list = list()
    visited_set = set()  # mirrors visited_list for O(1) membership tests
    unvisited_set = set(range(num_node))

    while len(unvisited_set) > 0:
        # BUGFIX(review): random.sample() on a set is deprecated since
        # Python 3.9 and raises TypeError on 3.11+; draw via a tuple instead.
        start = random.choice(tuple(unvisited_set))
        queue = [start]
        while len(queue):
            u = queue.pop(0)
            if u in visited_set:
                # Nodes can be enqueued more than once; skip repeats.
                continue
            visited_list.append(u)
            visited_set.add(u)
            unvisited_set.remove(u)

            for v in range(num_node):
                if (v not in visited_set) and ((u, v) in edge_set):
                    queue.append(v)
    assert len(visited_list) == num_node
    return visited_list
class MoleculeDataset_GPTGNN(InMemoryDataset):
    # Autoregressive pre-training dataset for GPT-GNN: every molecule is
    # expanded into a sequence of growing sub-graphs, each labelled with the
    # atom type of the next node in a (random-BFS) generation order.
    def __init__(self, molecule_dataset, transform=None, pre_transform=None):
        # Processed copies live next to the source dataset, suffixed "_GPT".
        self.molecule_dataset = molecule_dataset
        self.root = molecule_dataset.root + "_GPT"
        super(MoleculeDataset_GPTGNN, self).__init__(
            self.root, transform=transform, pre_transform=pre_transform
        )

        self.data, self.slices = torch.load(self.processed_paths[0])

        return

    def process(self):
        num_molecule, data_list = len(self.molecule_dataset), list()
        for i in tqdm(range(num_molecule)):
            graph = self.molecule_dataset.get(i)

            num_node = len(graph.x)
            # TODO: will replace this with DFS/BFS searching
            node_list = search_graph(graph)

            # One training example per prefix of the generation order.
            for idx in range(num_node - 1):
                # [0..idx] -> [idx+1]
                sub_node_list = node_list[: idx + 1]
                next_node = node_list[idx + 1]

                edge_index, edge_attr = subgraph(
                    subset=sub_node_list,
                    edge_index=graph.edge_index,
                    edge_attr=graph.edge_attr,
                    relabel_nodes=True,
                    num_nodes=num_node,
                )

                # Take the subgraph and predict the next node (atom type only)
                sub_graph = Data(
                    x=graph.x[sub_node_list],
                    edge_index=edge_index,
                    edge_attr=edge_attr,
                    next_x=graph.x[next_node, :1],
                )
                data_list.append(sub_graph)

        print("len of data\t", len(data_list))
        data, slices = self.collate(data_list)
        print("Saving...")
        torch.save((data, slices), self.processed_paths[0])
        return

    @property
    def processed_file_names(self):
        return "geometric_data_processed.pt"
--------------------------------------------------------------------------------
/src/datasets/molecule_graphcl.py:
--------------------------------------------------------------------------------
import torch, numpy as np
from itertools import repeat
from torch_geometric.data import Data
from torch_geometric.utils import subgraph, to_networkx

from .molecule_datasets import MoleculeDataset


class MoleculeDataset_GraphCL(MoleculeDataset):
    # used in GraphCL, JOAOv1, JOAOv2
    def __init__(
        self,
        root,
        transform=None,
        pre_transform=None,
        pre_filter=None,
        dataset=None,
        empty=False,
    ):
        # Augmentation configuration; indices into `augmentations` are fixed:
        # 0=node_drop, 1=subgraph, 2=edge_pert, 3=attr_mask, 4=identity.
        self.aug_prob = None
        self.aug_mode = "no_aug"
        self.aug_strength = 0.2
        self.augmentations = [
            self.node_drop,
            self.subgraph,
            self.edge_pert,
            self.attr_mask,
            lambda x: x,
        ]
        super(MoleculeDataset_GraphCL, self).__init__(
            root, transform, pre_transform, pre_filter, dataset, empty
        )

    def set_augMode(self, aug_mode):
        self.aug_mode = aug_mode

    def set_augStrength(self, aug_strength):
        self.aug_strength = aug_strength

    def set_augProb(self, aug_prob):
        self.aug_prob = aug_prob

    def node_drop(self, data):
        # Drop a fraction (aug_strength) of nodes and keep the induced
        # sub-graph on the surviving, re-labelled nodes.
        node_num, _ = data.x.size()
        _, edge_num = data.edge_index.size()
        drop_num = int(node_num * self.aug_strength)

        idx_perm = np.random.permutation(node_num)
        idx_nodrop = idx_perm[drop_num:].tolist()
        idx_nodrop.sort()

        edge_idx, edge_attr = subgraph(
            subset=idx_nodrop,
            edge_index=data.edge_index,
            edge_attr=data.edge_attr,
            relabel_nodes=True,
            num_nodes=node_num,
        )

        data.edge_index = edge_idx
        data.edge_attr = edge_attr
        data.x = data.x[idx_nodrop]
        data.__num_nodes__, _ = data.x.shape
        return data

    def edge_pert(self, data):
        # Perturb a fraction (aug_strength) of edges: delete `pert_num`
        # random edges, then add the same number of random new edges with
        # random attributes.
        node_num, _ = data.x.size()
        _, edge_num = data.edge_index.size()
        pert_num = int(edge_num * self.aug_strength)

        # del edges
        idx_drop = np.random.choice(edge_num, (edge_num - pert_num), replace=False)
        edge_index = data.edge_index[:, idx_drop]
        edge_attr = data.edge_attr[idx_drop]

        # add edges
        # adj marks currently-nonexistent (directed) edges with 1.
        adj = torch.ones((node_num, node_num))
        adj[edge_index[0], edge_index[1]] = 0
        edge_index_nonexist = torch.nonzero(adj, as_tuple=False).t()
        idx_add = np.random.choice(
            edge_index_nonexist.shape[1], pert_num, replace=False
        )
        edge_index_add = edge_index_nonexist[:, idx_add]

        # random 4-class & 3-class edge_attr for 1st & 2nd dimension
        edge_attr_add_1 = torch.tensor(
            np.random.randint(4, size=(edge_index_add.shape[1], 1))
        )
        edge_attr_add_2 = torch.tensor(
            np.random.randint(3, size=(edge_index_add.shape[1], 1))
        )
        edge_attr_add = torch.cat((edge_attr_add_1, edge_attr_add_2), dim=1)
        edge_index = torch.cat((edge_index, edge_index_add), dim=1)
        edge_attr = torch.cat((edge_attr, edge_attr_add), dim=0)

        data.edge_index = edge_index
        data.edge_attr = edge_attr
        return data

    def attr_mask(self, data):
        # Replace the attributes of a fraction (aug_strength) of nodes with
        # the mean attribute vector ("token") of the whole graph.
        _x = data.x.clone()
        node_num, _ = data.x.size()
        mask_num = int(node_num * self.aug_strength)

        token = data.x.float().mean(dim=0).long()
        idx_mask = np.random.choice(node_num, mask_num, replace=False)

        _x[idx_mask] = token
        data.x = _x
        return data

    def subgraph(self, data):
        # Keep a random connected sub-graph containing (1 - aug_strength) of
        # the nodes, grown BFS-style from a random seed; if the frontier
        # empties (disconnected graph), restart from a random unused node.
        G = to_networkx(data)
        node_num, _ = data.x.size()
        _, edge_num = data.edge_index.size()
        sub_num = int(node_num * (1 - self.aug_strength))

        idx_sub = [np.random.randint(node_num, size=1)[0]]
        idx_neigh = set([n for n in G.neighbors(idx_sub[-1])])

        while len(idx_sub) <= sub_num:
            if len(idx_neigh) == 0:
                idx_unsub = list(
                    set([n for n in range(node_num)]).difference(set(idx_sub))
                )
                idx_neigh = set([np.random.choice(idx_unsub)])
            sample_node = np.random.choice(list(idx_neigh))

            idx_sub.append(sample_node)
            idx_neigh = idx_neigh.union(
                set([n for n in G.neighbors(idx_sub[-1])])
            ).difference(set(idx_sub))

        idx_nondrop = idx_sub
        idx_nondrop.sort()

        edge_idx, edge_attr = subgraph(
            subset=idx_nondrop,
            edge_index=data.edge_index,
            edge_attr=data.edge_attr,
            relabel_nodes=True,
            num_nodes=node_num,
        )

        data.edge_index = edge_idx
        data.edge_attr = edge_attr
        data.x = data.x[idx_nondrop]
        data.__num_nodes__, _ = data.x.shape
        return data
    def get(self, idx):
        # Return the original graph plus two independently augmented views
        # (the contrastive pair used by GraphCL/JOAO). Only the 2D keys are
        # copied into the augmented views.
        data, data1, data2 = Data(), Data(), Data()
        keys_for_2D = ["x", "edge_index", "edge_attr"]
        for key in self.data.keys:
            item, slices = self.data[key], self.slices[key]
            s = list(repeat(slice(None), item.dim()))
            s[data.__cat_dim__(key, item)] = slice(slices[idx], slices[idx + 1])
            if key in keys_for_2D:
                data[key], data1[key], data2[key] = item[s], item[s], item[s]
            else:
                data[key] = item[s]

        if self.aug_mode == "no_aug":
            # Index 4 is the identity augmentation.
            n_aug1, n_aug2 = 4, 4
            data1 = self.augmentations[n_aug1](data1)
            data2 = self.augmentations[n_aug2](data2)
        elif self.aug_mode == "uniform":
            # Uniform draw over the 5x5 grid of augmentation pairs.
            n_aug = np.random.choice(25, 1)[0]
            n_aug1, n_aug2 = n_aug // 5, n_aug % 5
            data1 = self.augmentations[n_aug1](data1)
            data2 = self.augmentations[n_aug2](data2)
        elif self.aug_mode == "sample":
            # Weighted draw over the 25 pairs (aug_prob set by JOAO).
            n_aug = np.random.choice(25, 1, p=self.aug_prob)[0]
            n_aug1, n_aug2 = n_aug // 5, n_aug % 5
            data1 = self.augmentations[n_aug1](data1)
            data2 = self.augmentations[n_aug2](data2)
        else:
            raise NotImplementedError("aug_mode not implemented")
        return data, data1, data2
--------------------------------------------------------------------------------
/src/datasets/molecule_graphmvp.py:
--------------------------------------------------------------------------------
import os
import torch
import numpy as np
from os.path import join
from itertools import repeat
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.utils import subgraph, to_networkx


class Molecule3DDataset(InMemoryDataset):
    # Thin InMemoryDataset wrapper that loads pre-processed 3D molecule data
    # from disk (process() is a no-op: processing happens offline).
    def __init__(
        self,
        root,
        dataset,
        transform=None,
        pre_transform=None,
        pre_filter=None,
        empty=False,
    ):
        self.root = root
        self.dataset = dataset

        super(Molecule3DDataset, self).__init__(
            root, transform, pre_transform, pre_filter
        )
        self.transform, self.pre_transform, self.pre_filter = (
            transform,
            pre_transform,
            pre_filter,
        )

        if not empty:
            self.data, self.slices = torch.load(self.processed_paths[0])
            print("Dataset: {}\nData: {}".format(self.dataset, self.data))

    def get(self, idx):
        # Slice one graph out of the collated storage tensors.
        data = Data()
        for key in self.data.keys:
            item, slices = self.data[key], self.slices[key]
            s = list(repeat(slice(None), item.dim()))
            s[data.__cat_dim__(key, item)] = slice(slices[idx], slices[idx + 1])
            data[key] = item[s]
        return data

    @property
    def raw_file_names(self):
        return os.listdir(self.raw_dir)

    @property
    def processed_file_names(self):
        return "geometric_data_processed.pt"

    def process(self):
        return


class Molecule3DMaskingDataset(InMemoryDataset):
    # Same as Molecule3DDataset, but every returned graph is reduced to a
    # random connected sub-graph keeping (1 - mask_ratio) of the nodes
    # (GraphMVP's masking augmentation); 3D positions are sliced alongside x.
    def __init__(
        self,
        root,
        dataset,
        mask_ratio,
        transform=None,
        pre_transform=None,
        pre_filter=None,
        empty=False,
    ):
        self.root = root
        self.dataset = dataset
        self.mask_ratio = mask_ratio

        super(Molecule3DMaskingDataset, self).__init__(
            root, transform, pre_transform, pre_filter
        )
        self.transform, self.pre_transform, self.pre_filter = (
            transform,
            pre_transform,
            pre_filter,
        )

        if not empty:
            self.data, self.slices = torch.load(self.processed_paths[0])
            print("Dataset: {}\nData: {}".format(self.dataset, self.data))

    def subgraph(self, data):
        # BFS-grow a connected node set from a random seed; restart from a
        # random unused node if the frontier empties (disconnected graphs).
        G = to_networkx(data)
        node_num, _ = data.x.size()
        sub_num = int(node_num * (1 - self.mask_ratio))

        idx_sub = [np.random.randint(node_num, size=1)[0]]
        idx_neigh = set([n for n in G.neighbors(idx_sub[-1])])

        # BFS
        while len(idx_sub) <= sub_num:
            if len(idx_neigh) == 0:
                idx_unsub = list(
                    set([n for n in range(node_num)]).difference(set(idx_sub))
                )
                idx_neigh = set([np.random.choice(idx_unsub)])
            sample_node = np.random.choice(list(idx_neigh))

            idx_sub.append(sample_node)
            idx_neigh = idx_neigh.union(
                set([n for n in G.neighbors(idx_sub[-1])])
            ).difference(set(idx_sub))

        idx_nondrop = idx_sub
        idx_nondrop.sort()

        edge_idx, edge_attr = subgraph(
            subset=idx_nondrop,
            edge_index=data.edge_index,
            edge_attr=data.edge_attr,
            relabel_nodes=True,
            num_nodes=node_num,
        )

        data.edge_index = edge_idx
        data.edge_attr = edge_attr
        data.x = data.x[idx_nondrop]
        data.positions = data.positions[idx_nondrop]
        data.__num_nodes__, _ = data.x.shape
        return data

    def get(self, idx):
        data = Data()
        for key in self.data.keys:
            item, slices = self.data[key], self.slices[key]
            s = list(repeat(slice(None), item.dim()))
            s[data.__cat_dim__(key, item)] = slice(slices[idx], slices[idx + 1])
            data[key] = item[s]

        if self.mask_ratio > 0:
            data = self.subgraph(data)
        return data

    @property
    def raw_file_names(self):
        return os.listdir(self.raw_dir)

    @property
    def processed_file_names(self):
        return "geometric_data_processed.pt"

    def process(self):
        return


if __name__ == "__main__":
    # Smoke test: load a GEOM 3D dataset, apply masking, and print batch shapes.
    dataset = "geom3d_nmol500000_nconf1"
    root = join("~/GraphEval_dev/data/GEOM", dataset)
    dataset = Molecule3DDataset(root, dataset=dataset)
    dataset = Molecule3DMaskingDataset(root, dataset=dataset, mask_ratio=0.1)

    loader = DataLoader(dataset[:64], batch_size=16, shuffle=True, num_workers=4)
    for _, batch in enumerate(loader):
        print(
            "\n",
            "batch.x.shape: ",
            batch.x.shape,
            "\n",
            "batch.edge_index.shape: ",
            batch.edge_index.shape,
            "\n",
            "batch.edge_attr.shape: ",
            batch.edge_attr.shape,
            "\n",
            "batch.positions.shape: ",
            batch.positions.shape,
            "\n",
            "batch.batch.shape: ",
            batch.batch.shape,
            "\n",
            "=" * 7,
        )
--------------------------------------------------------------------------------
/src/datasets/molecule_motif.py:
--------------------------------------------------------------------------------
import os, torch, numpy as np
from tqdm import tqdm
from itertools import repeat
from descriptastorus.descriptors import rdDescriptors
from torch_geometric.data import Data, InMemoryDataset
# RDKit functional-group fragment descriptors used as motif labels
# (85 fr_* counts from rdkit.Chem.Fragments, computed via descriptastorus).
RDKIT_PROPS = [
    "fr_Al_COO",
    "fr_Al_OH",
    "fr_Al_OH_noTert",
    "fr_ArN",
    "fr_Ar_COO",
    "fr_Ar_N",
    "fr_Ar_NH",
    "fr_Ar_OH",
    "fr_COO",
    "fr_COO2",
    "fr_C_O",
    "fr_C_O_noCOO",
    "fr_C_S",
    "fr_HOCCN",
    "fr_Imine",
    "fr_NH0",
    "fr_NH1",
    "fr_NH2",
    "fr_N_O",
    "fr_Ndealkylation1",
    "fr_Ndealkylation2",
    "fr_Nhpyrrole",
    "fr_SH",
    "fr_aldehyde",
    "fr_alkyl_carbamate",
    "fr_alkyl_halide",
    "fr_allylic_oxid",
    "fr_amide",
    "fr_amidine",
    "fr_aniline",
    "fr_aryl_methyl",
    "fr_azide",
    "fr_azo",
    "fr_barbitur",
    "fr_benzene",
    "fr_benzodiazepine",
    "fr_bicyclic",
    "fr_diazo",
    "fr_dihydropyridine",
    "fr_epoxide",
    "fr_ester",
    "fr_ether",
    "fr_furan",
    "fr_guanido",
    "fr_halogen",
    "fr_hdrzine",
    "fr_hdrzone",
    "fr_imidazole",
    "fr_imide",
    "fr_isocyan",
    "fr_isothiocyan",
    "fr_ketone",
    "fr_ketone_Topliss",
    "fr_lactam",
    "fr_lactone",
    "fr_methoxy",
    "fr_morpholine",
    "fr_nitrile",
    "fr_nitro",
    "fr_nitro_arom",
    "fr_nitro_arom_nonortho",
    "fr_nitroso",
    "fr_oxazole",
    "fr_oxime",
    "fr_para_hydroxylation",
    "fr_phenol",
    "fr_phenol_noOrthoHbond",
    "fr_phos_acid",
    "fr_phos_ester",
    "fr_piperdine",
    "fr_piperzine",
    "fr_priamide",
    "fr_prisulfonamd",
    "fr_pyridine",
    "fr_quatN",
    "fr_sulfide",
    "fr_sulfonamd",
    "fr_sulfone",
    "fr_term_acetylene",
    "fr_tetrazole",
    "fr_thiazole",
    "fr_thiocyan",
    "fr_thiophene",
    "fr_unbrch_alkane",
    "fr_urea",
]


def rdkit_functional_group_label_features_generator(smiles):
    """
    Generates functional group labels for a molecule using RDKit.
    :param smiles: A molecule (i.e. either a SMILES string or an RDKit obj).
    :return: A 1D numpy array of binary functional-group labels (one per
        entry of RDKIT_PROPS; non-zero fragment counts are clamped to 1).
    """
    # smiles = Chem.MolToSmiles(mol, isomericSmiles=True)
    # if type(mol) != str else mol
    generator = rdDescriptors.RDKit2D(RDKIT_PROPS)
    # process() returns [success_flag, desc1, desc2, ...]; drop the flag.
    features = generator.process(smiles)[1:]
    features = np.array(features)
    # Binarize fragment counts into presence/absence labels.
    features[features != 0] = 1
    return features


class MoleculeDataset_Motif(InMemoryDataset):
    # Motif-prediction dataset: each graph's target y is the binary
    # functional-group label vector computed from its SMILES string.
    def __init__(
        self,
        root,
        dataset,
        transform=None,
        pre_transform=None,
        pre_filter=None,
        empty=False,
    ):
        self.dataset = dataset
        self.root = root

        super(MoleculeDataset_Motif, self).__init__(
            root, transform, pre_transform, pre_filter
        )
        self.transform, self.pre_transform, self.pre_filter = (
            transform,
            pre_transform,
            pre_filter,
        )

        if not empty:
            self.data, self.slices = torch.load(self.processed_paths[0])

        # Motif labels are cached next to the processed graphs and computed
        # once from processed/smiles.csv on first use.
        self.motif_file = os.path.join(root, "processed", "motif.pt")
        self.process_motif_file()
        self.motif_label_list = torch.load(self.motif_file)

        print(
            "Dataset: {}\nData: {}\nMotif: {}".format(
                self.dataset, self.data, self.motif_label_list.size()
            )
        )

    def process_motif_file(self):
        if not os.path.exists(self.motif_file):
            smiles_file = os.path.join(self.root, "processed", "smiles.csv")
            data_smiles_list = []
            with open(smiles_file, "r") as f:
                lines = f.readlines()
            for smiles in lines:
                data_smiles_list.append(smiles.strip())

            motif_label_list = []
            for smiles in tqdm(data_smiles_list):
                label = rdkit_functional_group_label_features_generator(smiles)
                motif_label_list.append(label)

            self.motif_label_list = torch.LongTensor(motif_label_list)
            torch.save(self.motif_label_list, self.motif_file)
        return

    def get(self, idx):
        data = Data()
        for key in self.data.keys:
            item, slices = self.data[key], self.slices[key]
            s = list(repeat(slice(None), item.dim()))
            s[data.__cat_dim__(key, item)] = slice(slices[idx], slices[idx + 1])
            data[key] = item[s]
        # Override y with the motif label vector for this molecule.
        data.y = self.motif_label_list[idx]
        return data

    @property
    def raw_file_names(self):
        return os.listdir(self.raw_dir)

    @property
    def processed_file_names(self):
        return "geometric_data_processed.pt"

    def download(self):
        return

    def process(self):
        return
self.data.keys: 167 | item, slices = self.data[key], self.slices[key] 168 | s = list(repeat(slice(None), item.dim())) 169 | s[data.__cat_dim__(key, item)] = slice(slices[idx], slices[idx + 1]) 170 | data[key] = item[s] 171 | data.y = self.motif_label_list[idx] 172 | return data 173 | 174 | @property 175 | def raw_file_names(self): 176 | return os.listdir(self.raw_dir) 177 | 178 | @property 179 | def processed_file_names(self): 180 | return "geometric_data_processed.pt" 181 | 182 | def download(self): 183 | return 184 | 185 | def process(self): 186 | return 187 | -------------------------------------------------------------------------------- /src/load_save.py: -------------------------------------------------------------------------------- 1 | import torch, pickle 2 | from tqdm import tqdm 3 | from typing import List 4 | from pathlib import Path 5 | from torch_geometric.data.dataset import Dataset 6 | from config.training_config import TrainingConfig 7 | from config.validation_config import ValidationConfig 8 | from models.pre_trainer_model import PreTrainerModel 9 | from torch_geometric.loader.dataloader import DataLoader 10 | 11 | 12 | def infer_and_save_embeddings( 13 | config: ValidationConfig, 14 | model: PreTrainerModel, 15 | device: torch.device, 16 | datasets: List[Dataset], 17 | loaders: List[DataLoader], 18 | smile_splits: list, 19 | save: bool = True, 20 | ) -> None: 21 | """Save graph and node representations for analysis. 22 | :param config: configurations 23 | :param model: pretrained or randomly-initialized model. 24 | :param device: device in use. 25 | :param datasets: train, valid and test data. 26 | :param loaders: train, valid and test data loaders. 27 | :param smile_splits: train, valid and test smiles. 28 | :param save: whether to save the pickle file. 
    :return: None"""
    model.eval()
    for dataset, loader, smiles, split in zip(
        datasets, loaders, smile_splits, ["train", "valid", "test"]
    ):
        pbar = tqdm(total=len(loader))
        pbar.set_description(f"{split} embeddings extracted: ")
        graph_embeddings_list, node_embeddings_list = [], []
        for batch in loader:
            # if config.pretrainer == 'GraphCL' and isinstance(batch, list):
            #     batch = batch[0]  # remove the contrastive augmented data.
            batch = batch.to(device)
            with torch.no_grad():
                node_embeddings, graph_embeddings = model.get_embeddings(
                    batch.x, batch.edge_index, batch.edge_attr, batch.batch
                )
            # Regroup per-node embeddings by their owning graph id so each
            # graph gets its own list of node vectors.
            unbatched_node_embedding = [[] for _ in range(batch.batch.max() + 1)]
            for embedding, graph_id in zip(node_embeddings, batch.batch):
                unbatched_node_embedding[graph_id].append(
                    embedding.detach().cpu().numpy()
                )
            graph_embeddings_list.append(graph_embeddings)
            node_embeddings_list.extend(unbatched_node_embedding)
            pbar.update(1)
        # One (num_graphs, dim) array per split.
        graph_embeddings_list = (
            torch.cat(graph_embeddings_list, dim=0).detach().cpu().numpy()
        )

        if save:
            save_embeddings(
                config=config,
                dataset=dataset,
                graph_embeddings=graph_embeddings_list,
                node_embeddings=node_embeddings_list,
                smiles=smiles,
                split=split,
            )

        pbar.close()
    # return dataset, graph_embeddings, node_embeddings, smiles


def save_embeddings(
    config: ValidationConfig,
    dataset: Dataset,
    graph_embeddings: List[torch.Tensor],
    node_embeddings: List[List[torch.Tensor]],
    smiles: List[str],
    split: str,
):
    # Pickle [graph_embeddings, node_embeddings, smiles] for one split.
    Path(config.embedding_dir).mkdir(parents=True, exist_ok=True)
    with open(f"{config.embedding_dir}{config.dataset}_{split}.pkl", "wb") as f:
        pickle.dump([graph_embeddings, node_embeddings, smiles], f)


def save_model(config: TrainingConfig, model: torch.nn.Module, epoch: int) -> None:
    saver_dict
= {"model": model.state_dict()} 86 | cfg = pretrain_config(config) 87 | 88 | # TODO: Update this for G-Contextual, 89 | path_ = f"{config.output_model_dir}/{config.pretrainer}/{config.dataset}/{cfg}" 90 | Path(path_).mkdir(parents=True, exist_ok=True) 91 | torch.save(saver_dict, f"{path_}/epoch{epoch}_model_complete.pth") 92 | 93 | 94 | def pretrain_config(config: TrainingConfig) -> None: 95 | cfg = "" 96 | if config.pretrainer == "AM": 97 | cfg = f"mask_rate-{config.mask_rate}" 98 | # TODO: run these experiments 99 | # elif config.pretrainer == "CP": 100 | # cfg = f"acsize-{config.csize}_atom_vocab_size-{config.atom_vocab_size}_contextpred_neg_samples-{config.contextpred_neg_samples}" 101 | elif config.pretrainer == "GraphCL": 102 | cfg = f"aug_mode-{config.aug_mode}_aug_strength-{config.aug_strength}_aug_prob-{config.aug_prob}" 103 | elif config.pretrainer in ["JOAO", "JOAOv2"]: 104 | cfg = f"gamma_joao-{config.gamma_joao}_gamma_joaov2-{config.gamma_joaov2}" 105 | # TODO: run these experiments 106 | elif config.pretrainer == "GraphMVP": 107 | cfg = f"alpha2-{config.GMVP_alpha2}_temper-{config.GMVP_T}" 108 | cfg = f"{cfg}_seed-{config.seed}_lr-{config.lr}" 109 | return cfg 110 | 111 | 112 | def load_checkpoint(config: ValidationConfig, device: torch.device) -> dict: 113 | print(f"\nLoad checkpoint from path {config.input_model_file}...") 114 | return torch.load(config.input_model_file, map_location=device) 115 | -------------------------------------------------------------------------------- /src/logger.py: -------------------------------------------------------------------------------- 1 | # Ref: https://raw.githubusercontent.com/davda54/sam/main/example/utility/log.py 2 | 3 | import time, wandb, logging, dataclasses 4 | from pathlib import Path 5 | from datetime import datetime, timedelta 6 | from config.training_config import TrainingConfig 7 | 8 | TIME_STR = "{:%Y_%m_%d_%H_%M_%S_%f}".format(datetime.now()) 9 | 10 | 11 | @dataclasses.dataclass 12 | class 
EpochState: 13 | loss: float = 0.0 14 | accuracy: float = 0.0 15 | samples: int = 0 16 | 17 | def reset(self) -> None: 18 | self.loss, self.accuracy, self.samples = 0.0, 0.0, 0 19 | 20 | def add_to_loss(self, loss: float) -> None: 21 | self.loss += loss 22 | 23 | def add_to_accuracy(self, accuracy: float) -> None: 24 | self.accuracy += accuracy 25 | 26 | def add_to_samples(self, samples: int) -> None: 27 | self.samples += samples 28 | 29 | 30 | class LoadingBar: 31 | def __init__(self, length: int = 40) -> None: 32 | self.length = length 33 | self.symbols = ["┈", "░", "▒", "▓"] 34 | 35 | def __call__(self, progress: float) -> str: 36 | p = int(progress * self.length * 4 + 0.5) 37 | d, r = p // 4, p % 4 38 | return ( 39 | "┠┈" 40 | + d * "█" 41 | + ( 42 | (self.symbols[r]) + max(0, self.length - 1 - d) * "┈" 43 | if p < self.length * 4 44 | else "" 45 | ) 46 | + "┈┨" 47 | ) 48 | 49 | 50 | class LogFormatter: 51 | def __init__(self): 52 | self.start_time = time.time() 53 | 54 | def format(self, record): 55 | elapsed_seconds = round(record.created - self.start_time) 56 | 57 | prefix = "%s - %s - %s" % ( 58 | record.levelname, 59 | time.strftime("%x %X"), 60 | timedelta(seconds=elapsed_seconds), 61 | ) 62 | message = record.getMessage() 63 | message = message.replace("\n", "\n" + " " * (len(prefix) + 3)) 64 | return "%s - %s" % (prefix, message) 65 | 66 | 67 | def create_info_logger(filepath) -> logging.Logger: 68 | """ 69 | Create a logger. 
70 | """ 71 | 72 | Path("log").mkdir(parents=True, exist_ok=True) 73 | # create log formatter 74 | log_formatter = LogFormatter() 75 | 76 | # create file handler and set level to debug 77 | file_handler = logging.FileHandler(filepath, "a") 78 | file_handler.setLevel(logging.DEBUG) 79 | file_handler.setFormatter(log_formatter) 80 | 81 | # TODO: If we want to print logs into the console, we can un-comment this 82 | # create console handler and set level to info 83 | # console_handler = logging.StreamHandler() 84 | # console_handler.setLevel(logging.INFO) 85 | # console_handler.setLevel(0) 86 | # console_handler.setFormatter(log_formatter) 87 | 88 | # create logger and set level to debug 89 | logger = logging.getLogger() 90 | logger.handlers = [] 91 | logger.setLevel(logging.DEBUG) 92 | logger.propagate = False 93 | logger.addHandler(file_handler) 94 | # logger.addHandler(console_handler) 95 | 96 | # reset logger elapsed time 97 | def reset_time(): 98 | log_formatter.start_time = time.time() 99 | 100 | logger.reset_time = reset_time 101 | 102 | return logger 103 | 104 | 105 | @dataclasses.dataclass 106 | class CombinedLogger: 107 | config: TrainingConfig 108 | is_train: bool = True 109 | epoch_state: EpochState = EpochState() 110 | last_steps_state: EpochState = EpochState() 111 | loading_bar: LoadingBar = LoadingBar(length=27) 112 | best_val_accuracy: float = 0.0 113 | epoch: int = -1 114 | step: int = 0 115 | info_logger: logging.Logger = None 116 | 117 | def __post_init__(self): 118 | self.info_logger = create_info_logger( 119 | f"{self.config.log_filepath}-{TIME_STR}.log" 120 | ) 121 | 122 | def log_value_dict(self, value_dict: dict): 123 | if self.config.log_to_wandb: 124 | wandb.log( 125 | { 126 | "epoch": self.epoch, 127 | **value_dict, 128 | } 129 | ) 130 | self.info_logger.info(value_dict) 131 | 132 | def train(self, num_batches: int) -> None: 133 | self.epoch += 1 134 | if self.epoch == 0: 135 | self._print_header() 136 | else: 137 | self.flush() 138 | 
        self.is_train = True
        self.last_steps_state.reset()
        self._reset(num_batches)

    def eval(self, num_batches: int) -> None:
        # Flush the training row before switching to evaluation accounting.
        self.flush()
        self.is_train = False
        self._reset(num_batches)

    def __call__(
        self, loss, accuracy, batch_size: int, learning_rate: float = None
    ) -> None:
        # Route a batch's statistics to the current phase's accumulator.
        # TODO: in training, we don't have to record the accuracy
        if self.is_train:
            self._train_step(loss, accuracy, batch_size, learning_rate)
        else:
            self._eval_step(loss, accuracy, batch_size)

    def flush(self) -> None:
        # Emit aggregated statistics for the phase that just finished.
        # NOTE(review): loss is averaged per batch while accuracy is averaged
        # per sample — confirm this asymmetry is intended.
        loss = self.epoch_state.loss / self.num_batches
        accuracy = self.epoch_state.accuracy / self.epoch_state.samples
        if self.is_train:
            print(
                f"\r┃{self.epoch:12d} ┃{loss:12.4f} │{100 * accuracy:10.2f} % ┃{self.learning_rate:12.3e} │{self._time():>12} ┃",
                end="",
                flush=True,
            )
            train_statistics = {
                "epoch": self.epoch,
                "train_accuracy": accuracy,
                "train_loss": loss,
                "lr": self.learning_rate,
            }
            if self.config.log_to_wandb:
                wandb.log(train_statistics)
            self.info_logger.info(train_statistics)

        else:
            print(f"{loss:12.4f} │{100 * accuracy:10.2f} % ┃", flush=True)

            if accuracy > self.best_val_accuracy:
                self.best_val_accuracy = accuracy
            validation_statistics = {
                "epoch": self.epoch,
                "val_accuracy": accuracy,
                "val_loss": loss,
                "best_val_accuracy": self.best_val_accuracy,
            }
            if self.config.log_to_wandb:
                wandb.log(validation_statistics)
            self.info_logger.info(validation_statistics)

    def _train_step(
        self, loss: float, accuracy: float, batch_size: int, learning_rate: float
    ) -> None:
        # Accumulate into both the interval and the whole-epoch counters.
        self.learning_rate = learning_rate
        self.last_steps_state.add_to_loss(loss)
        self.last_steps_state.add_to_accuracy(accuracy)
        self.last_steps_state.add_to_samples(batch_size)
        self.epoch_state.add_to_loss(loss)
        self.epoch_state.add_to_accuracy(accuracy)
        self.epoch_state.add_to_samples(batch_size)
        self.step += 1

        if self.step % self.config.log_interval == self.config.log_interval - 1:
            # NOTE(review): this divides the interval loss by the cumulative
            # self.step although last_steps_state is reset every interval, so
            # the displayed value shrinks over the epoch — confirm intent.
            loss = self.last_steps_state.loss / self.step
            accuracy = self.last_steps_state.accuracy / self.last_steps_state.samples

            self.last_steps_state.reset()
            progress = self.step / self.num_batches
            print(
                f"\r┃{self.epoch:12d} ┃{loss:12.4f} │{100 * accuracy:10.2f} % ┃{learning_rate:12.3e} │{self._time():>12} {self.loading_bar(progress)}",
                end="",
                flush=True,
            )

    def _eval_step(self, loss: float, accuracy: float, batch_size: int) -> None:
        self.epoch_state.add_to_loss(loss)
        self.epoch_state.add_to_accuracy(accuracy)
        self.epoch_state.add_to_samples(batch_size)

    def _reset(self, num_batches: int) -> None:
        # Start a new phase: restart the wall clock and epoch accumulators.
        self.start_time = time.time()
        self.step = 0
        self.num_batches = num_batches
        self.epoch_state.reset()

    def _time(self) -> str:
        time_seconds = int(time.time() - self.start_time)
        return f"{time_seconds // 60:02d}:{time_seconds % 60:02d} min"

    def _print_header(self) -> None:
        # Fixed-width unicode table header matching the rows printed above.
        print(
            "┏━━━━━━━━━━━━━━┳━━━━━━━╸T╺╸R╺╸A╺╸I╺╸N╺━━━━━━━┳━━━━━━━╸S╺╸T╺╸A╺╸T╺╸S╺━━━━━━━┳━━━━━━━╸V╺╸A╺╸L╺╸I╺╸D╺━━━━━━━┓"
        )
        print(
            "┃              ┃              ╷              ┃              ╷              ┃              ╷              ┃"
        )
        print(
            "┃       epoch  ┃        loss  │    accuracy  ┃        l.r.  │     elapsed  ┃        loss  │    accuracy  ┃"
        )
        print(
            "┠──────────────╂──────────────┼──────────────╂──────────────┼──────────────╂──────────────┼──────────────┨"
        )
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
from models.building_blocks.gnn import GNN, GNN_graphpred
from .attribute_masking import AttributeMaskingModel
from .building_blocks.auto_encoder import (
    AutoEncoder,
    EnergyVariationalAutoEncoder,
    ImportanceWeightedAutoEncoder,
    NormalizingFlowVariationalAutoEncoder,
    VariationalAutoEncoder,
)
from .context_prediction import ContextPredictionModel
from .contextual import ContextualModel
from .discriminator import Discriminator
from .edge_prediction import EdgePredictionModel
from .building_blocks.flow import (
    AffineFlow,
    BatchNormFlow,
    NormalizingFlow,
    PlanarFlow,
    PReLUFlow,
    RadialFlow,
)
from .gpt_gnn import GPTGNNModel
from .graph_cl import GraphCLModel
from .info_max import InfoMaxModel
from .joao_v2 import JOAOv2Model
from .motif import MotifModel
from .graphpred import GraphPred
from .graphmvp import GraphMVPModel
from .rgcl import RGCLModel
from .graphmae import GraphMAEModel
--------------------------------------------------------------------------------
/src/models/attribute_masking.py:
--------------------------------------------------------------------------------
import torch
from torch_geometric.data import Batch

from config.training_config import TrainingConfig
from models.building_blocks.gnn import GNN
from models.pre_trainer_model import PreTrainerModel


class AttributeMaskingModel(PreTrainerModel):
    # Attribute-masking pretrainer: predict the identity of masked atoms
    # from the node representations of the masked graph.
    def __init__(self, config: TrainingConfig, gnn: GNN):
        super().__init__(config=config)
        self.gnn:
torch.nn.Module = gnn
        # 119 output logits per masked node — presumably the atomic-number
        # vocabulary (confirm against the dataset's atom encoding).
        self.molecule_atom_masking_model = torch.nn.Linear(config.emb_dim, 119)

    def forward(self, batch: Batch) -> torch.Tensor:
        # Encode the masked graph, then classify only the masked positions.
        node_repr = self.gnn(batch.masked_x, batch.edge_index, batch.edge_attr)
        node_pred = self.molecule_atom_masking_model(
            node_repr[batch.masked_atom_indices]
        )
        return node_pred
--------------------------------------------------------------------------------
/src/models/building_blocks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hansen7/MolGraphEval/2f83950b622d98937500c2efd8091a5af88c8cca/src/models/building_blocks/__init__.py
--------------------------------------------------------------------------------
/src/models/building_blocks/flow.py:
--------------------------------------------------------------------------------
import torch
import torch.distributions as distrib
import torch.distributions.transforms as transform
import torch.nn as nn
import torch.nn.functional as F


class Flow(transform.Transform, nn.Module):
    # Base class marrying torch's Transform with nn.Module so a flow is
    # both an invertible transform and a trainable module.
    def __init__(self):
        transform.Transform.__init__(self)
        nn.Module.__init__(self)

    def init_parameters(self):
        # Small uniform init for all flow parameters.
        for param in self.parameters():
            param.data.uniform_(-0.01, 0.01)

    def __hash__(self):
        # Resolve the Transform/Module hash ambiguity in favor of nn.Module.
        return nn.Module.__hash__(self)


class PlanarFlow(Flow):
    # Planar flow: z + scale * h(w·z + b); hp is h's derivative, used by
    # the Jacobian term below.
    def __init__(self, dim, h=torch.tanh, hp=(lambda x: 1 - torch.tanh(x) ** 2)):
        super(PlanarFlow, self).__init__()
        self.weight = nn.Parameter(torch.Tensor(1, dim))
        self.scale = nn.Parameter(torch.Tensor(1, dim))
        self.bias = nn.Parameter(torch.Tensor(1))
        self.h = h
        self.hp = hp
        self.init_parameters()

    def _call(self, z):
        f_z = F.linear(z, self.weight, self.bias)
        return z + self.scale * self.h(f_z)

    def log_abs_det_jacobian(self, z):
        f_z = F.linear(z, self.weight, self.bias)
psi = self.hp(f_z) * self.weight 38 | det_grad = 1 + torch.mm(psi, self.scale.t()) 39 | return torch.log(det_grad.abs() + 1e-9) 40 | 41 | 42 | class RadialFlow(Flow): 43 | def __init__(self, dim): 44 | super(RadialFlow, self).__init__() 45 | self.z0 = nn.Parameter(torch.Tensor(1, dim)) 46 | self.alpha = nn.Parameter(torch.Tensor(1)) 47 | self.beta = nn.Parameter(torch.Tensor(1)) 48 | self.dim = dim 49 | self.init_parameters() 50 | 51 | def _call(self, z): 52 | r = torch.norm(z - self.z0, dim=1).unsqueeze(1) 53 | h = 1 / (self.alpha + r) 54 | return z + (self.beta * h * (z - self.z0)) 55 | 56 | def log_abs_det_jacobian(self, z): 57 | r = torch.norm(z - self.z0, dim=1).unsqueeze(1) 58 | h = 1 / (self.alpha + r) 59 | hp = -1 / (self.alpha + r) ** 2 60 | bh = self.beta * h 61 | det_grad = ((1 + bh) ** self.dim - 1) * (1 + bh + self.beta * hp * r) 62 | return torch.log(det_grad.abs() + 1e-9) 63 | 64 | 65 | class PReLUFlow(Flow): 66 | def __init__(self, dim): 67 | super(PReLUFlow, self).__init__() 68 | self.alpha = nn.Parameter(torch.Tensor([1])) 69 | self.bijective = True 70 | 71 | def init_parameters(self): 72 | for param in self.parameters(): 73 | param.data.uniform_(0.01, 0.99) 74 | 75 | def _call(self, z): 76 | return torch.where(z >= 0, z, torch.abs(self.alpha) * z) 77 | 78 | def _inverse(self, z): 79 | return torch.where(z >= 0, z, torch.abs(1.0 / self.alpha) * z) 80 | 81 | def log_abs_det_jacobian(self, z): 82 | I = torch.ones_like(z) 83 | J = torch.where(z >= 0, I, self.alpha * I) 84 | log_abs_det = torch.log(torch.abs(J) + 1e-5) 85 | return torch.sum(log_abs_det, dim=1) 86 | 87 | 88 | class BatchNormFlow(Flow): 89 | def __init__(self, dim, momentum=0.95, eps=1e-5): 90 | super(BatchNormFlow, self).__init__() 91 | # Running batch statistics 92 | self.r_mean = torch.zeros(dim) 93 | self.r_var = torch.ones(dim) 94 | # Momentum 95 | self.momentum = momentum 96 | self.eps = eps 97 | # Trainable scale and shift (cf. 
original paper) 98 | self.gamma = nn.Parameter(torch.ones(dim)) 99 | self.beta = nn.Parameter(torch.zeros(dim)) 100 | 101 | def _call(self, z): 102 | if self.training: 103 | # Current batch stats 104 | self.b_mean = z.mean(0) 105 | self.b_var = (z - self.b_mean).pow(2).mean(0) + self.eps 106 | # Running mean and var 107 | self.r_mean = ( 108 | self.momentum * self.r_mean + (1 - self.momentum) * self.b_mean 109 | ) 110 | self.r_var = self.momentum * self.r_var + (1 - self.momentum) * self.b_var 111 | mean = self.b_mean 112 | var = self.b_var 113 | else: 114 | mean = self.r_mean 115 | var = self.r_var 116 | x_hat = (z - mean) / var.sqrt() 117 | y = self.gamma * x_hat + self.beta 118 | return y 119 | 120 | def _inverse(self, x): 121 | if self.training: 122 | mean = self.b_mean 123 | var = self.b_var 124 | else: 125 | mean = self.r_mean 126 | var = self.r_var 127 | x_hat = (z - self.beta) / self.gamma 128 | y = x_hat * var.sqrt() + mean 129 | return y 130 | 131 | def log_abs_det_jacobian(self, z): 132 | # Here we only need the variance 133 | mean = z.mean(0) 134 | var = (z - mean).pow(2).mean(0) + self.eps 135 | log_det = torch.log(self.gamma) - 0.5 * torch.log(var + self.eps) 136 | return torch.sum(log_det, -1) 137 | 138 | 139 | class AffineFlow(Flow): 140 | def __init__(self, dim): 141 | super(AffineFlow, self).__init__() 142 | self.weights = nn.Parameter(torch.Tensor(dim, dim)) 143 | nn.init.orthogonal_(self.weights) 144 | 145 | def _call(self, z): 146 | return z @ self.weights 147 | 148 | def _inverse(self, z): 149 | return z @ torch.inverse(self.weights) 150 | 151 | def log_abs_det_jacobian(self, z): 152 | return torch.slogdet(self.weights)[-1].unsqueeze(0).repeat(z.size(0), 1) 153 | 154 | 155 | class NormalizingFlow(nn.Module): 156 | def __init__(self, dim, blocks, flow_length, density): 157 | super(NormalizingFlow, self).__init__() 158 | biject = [] 159 | for f in range(flow_length): 160 | for b_flow in blocks: 161 | biject.append(b_flow(dim)) 162 | 
        self.transforms = transform.ComposeTransform(biject)
        self.bijectors = nn.ModuleList(biject)
        self.base_density = density
        self.final_density = distrib.TransformedDistribution(density, self.transforms)
        self.log_det = []

    def forward(self, z):
        self.log_det = []
        # Applies series of flows, collecting each log|det J| along the way.
        for b in range(len(self.bijectors)):
            self.log_det.append(self.bijectors[b].log_abs_det_jacobian(z))
            z = self.bijectors[b](z)
        return z, self.log_det


if __name__ == "__main__":
    # Smoke test: fit a small flow to a ring-shaped 2D density.
    flow_model = "mlp"
    if flow_model == "planar":
        blocks = [PlanarFlow]
    elif flow_model == "radial":
        blocks = [RadialFlow]
    elif flow_model == "affine":
        blocks = [AffineFlow]
    elif flow_model == "mlp":
        blocks = [AffineFlow, BatchNormFlow, PReLUFlow]
    else:
        blocks = None

    flow = NormalizingFlow(
        dim=2,
        blocks=blocks,
        flow_length=8,
        density=distrib.MultivariateNormal(torch.zeros(2), torch.eye(2)),
    )

    import numpy as np
    import torch.optim as optim

    def density_ring(z):
        # Unnormalized ring density with two bumps on the x-axis.
        z1, z2 = torch.chunk(z, chunks=2, dim=1)
        norm = torch.sqrt(z1**2 + z2**2)
        exp1 = torch.exp(-0.5 * ((z1 - 2) / 0.8) ** 2)
        exp2 = torch.exp(-0.5 * ((z1 + 2) / 0.8) ** 2)
        u = 0.5 * ((norm - 4) / 0.4) ** 2 - torch.log(exp1 + exp2)
        return torch.exp(-u)

    def loss(density, zk, log_jacobians):
        # Negative log-likelihood of the transformed samples under `density`.
        sum_of_log_jacobians = sum(log_jacobians)
        return (-sum_of_log_jacobians - torch.log(density(zk) + 1e-9)).mean()

    optimizer = optim.Adam(flow.parameters(), lr=1e-3)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 0.9999)

    x = np.linspace(-5, 5, 1000)
    z = np.array(np.meshgrid(x, x)).transpose(1, 2, 0)
    z = np.reshape(z, [z.shape[0] * z.shape[1], -1])

    ref_distrib = distrib.MultivariateNormal(torch.zeros(2), torch.eye(2))
    for it in range(10001):
        # Draw a sample batch from Normal
        samples = ref_distrib.sample((512,))
        # Evaluate flow of transforms
        zk, log_jacobians = flow(samples)
        # Evaluate loss and backprop
        optimizer.zero_grad()
        loss_v = loss(density_ring, zk, log_jacobians)
        loss_v.backward()
        optimizer.step()
        scheduler.step()
        if it % 1000 == 0:
            print("Loss (it. %i) : %f" % (it, loss_v.item()))
    # Draw random samples
    samples = ref_distrib.sample((int(1e5),))
    # Evaluate flow and plot
    zk, _ = flow(samples)
--------------------------------------------------------------------------------
/src/models/building_blocks/mlp.py:
--------------------------------------------------------------------------------
from typing import Callable, Optional

import torch.nn as nn
from torch import Tensor


def get_activation(name: str, leaky_relu: Optional[float] = 0.5) -> nn.Module:
    # Map an activation name to a fresh nn.Module instance.
    str_to_activation = {
        "leaky_relu": nn.LeakyReLU(leaky_relu),
        "rrelu": nn.RReLU(),
        "relu": nn.ReLU(),
        "elu": nn.ELU(),
        "gelu": nn.GELU(),
        "prelu": nn.PReLU(),
        "selu": nn.SELU(),
    }

    return str_to_activation[name]


def create_batch_norm_1d_layers(num_layers: int, dim_hidden: int):
    # One BatchNorm1d per non-output layer (num_layers - 1 of them).
    batch_norm_layers = nn.ModuleList()
    for i in range(num_layers - 1):
        batch_norm_layers.append(nn.BatchNorm1d(num_features=dim_hidden))
    return batch_norm_layers


def create_linear_layers(
    num_layers: int, dim_input: int, dim_hidden: int, dim_output: int
):
    linear_layers = nn.ModuleList()
    # Input layer
    linear_layers.append(nn.Linear(in_features=dim_input, out_features=dim_hidden))
    # Hidden layers
    for i in range(1, num_layers - 1):
        linear_layers.append(nn.Linear(in_features=dim_hidden, out_features=dim_hidden))
    # Output layer
    linear_layers.append(nn.Linear(dim_hidden, dim_output))
    return
linear_layers


def init_layers(initializer_name: str, layers: nn.ModuleList):
    # Apply the chosen weight initializer to every layer's weight in-place.
    initializer = get_initializer(initializer_name)
    for layer in layers:
        initializer(layer.weight)


def get_initializer(name: str = "xavier") -> Callable:
    # Map an initializer name to the corresponding nn.init function.
    str_to_init = {
        "orthogonal": nn.init.orthogonal_,
        "xavier": nn.init.xavier_uniform_,
        "kaiming": nn.init.kaiming_uniform_,
    }
    return str_to_init[name]


class MLP(nn.Module):
    # Configurable multi-layer perceptron with optional batch-norm, dropout,
    # and an optional activation applied to the output layer.
    def __init__(
        self,
        dim_input: int,
        dim_hidden: int,
        dim_output: int,
        num_layers: int,
        batch_norm: bool,
        initializer: str,
        dropout: float,
        activation: str,
        leaky_relu: float,
        is_output_activation: bool,
    ):
        super().__init__()
        self.layers = create_linear_layers(
            num_layers=num_layers,
            dim_input=dim_input,
            dim_hidden=dim_hidden,
            dim_output=dim_output,
        )
        init_layers(initializer_name=initializer, layers=self.layers)
        self.dropout = nn.Dropout(dropout) if dropout > 0.0 else None
        self.batch_norm_layers = (
            create_batch_norm_1d_layers(num_layers=num_layers, dim_hidden=dim_hidden)
            if batch_norm
            else None
        )
        self.activation_function = get_activation(
            name=activation, leaky_relu=leaky_relu
        )
        self.is_output_activation = is_output_activation

    def forward(self, x: Tensor):
        # linear -> activation -> (batch norm) -> (dropout), except that the
        # last layer only gets the activation when is_output_activation.
        for i in range(len(self.layers) - 1):
            x = self.layers[i](x)
            x = self.activation_function(x)
            if self.batch_norm_layers:
                x = self.batch_norm_layers[i](x)
            if self.dropout:
                x = self.dropout(x)
        x = self.layers[-1](x)
        if self.is_output_activation:
            x = self.activation_function(x)
        return x
--------------------------------------------------------------------------------
/src/models/building_blocks/schnet.py:
--------------------------------------------------------------------------------
import ase
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing, radius_graph
from torch_scatter import scatter
from math import pi as PI

# try:
#     import schnetpack as spk
# except ImportError:
#     spk = None


class SchNet(nn.Module):
    # Continuous-filter convolutional network over atomic numbers `z` and
    # 3D positions `pos` (SchNet).
    def __init__(
        self,
        hidden_channels=128,
        num_filters=128,
        num_interactions=6,
        num_gaussians=51,
        cutoff=10.0,
        readout="mean",
        dipole=False,
        mean=None,
        std=None,
        atomref=None,
    ):
        super(SchNet, self).__init__()

        assert readout in ["add", "sum", "mean"]

        # Dipole mode always sums per-atom contributions.
        self.readout = "add" if dipole else readout
        self.num_interactions = num_interactions
        self.hidden_channels = hidden_channels
        self.num_gaussians = num_gaussians
        self.num_filters = num_filters
        self.cutoff = cutoff
        self.dipole = dipole
        self.scale = None
        self.mean = mean
        self.std = std

        # Atomic masses (from ase) are needed for the center-of-mass below.
        atomic_mass = torch.from_numpy(ase.data.atomic_masses)
        self.register_buffer("atomic_mass", atomic_mass)

        # self.embedding = Embedding(100, hidden_channels)
        self.embedding = nn.Embedding(119, hidden_channels)
        self.distance_expansion = GaussianSmearing(0.0, cutoff, num_gaussians)

        self.interactions = nn.ModuleList()
        for _ in range(num_interactions):
            block = InteractionBlock(
                hidden_channels, num_gaussians, num_filters, cutoff
            )
            self.interactions.append(block)

        self.lin1 = nn.Linear(hidden_channels, hidden_channels)
        self.act = ShiftedSoftplus()
        self.lin2 = nn.Linear(hidden_channels, hidden_channels)

        self.register_buffer("initial_atomref", atomref)
        self.atomref = None
        if atomref is not None:
            self.atomref = nn.Embedding(100, 1)
            self.atomref.weight.data.copy_(atomref)

        self.reset_parameters()

    def reset_parameters(self):
        # Re-initialize all trainable submodules to their starting state.
        self.embedding.reset_parameters()
        for interaction in self.interactions:
            interaction.reset_parameters()
        torch.nn.init.xavier_uniform_(self.lin1.weight)
        self.lin1.bias.data.fill_(0)
        torch.nn.init.xavier_uniform_(self.lin2.weight)
        self.lin2.bias.data.fill_(0)
        if self.atomref is not None:
            self.atomref.weight.data.copy_(self.initial_atomref)

    def forward(self, z, pos, batch=None):
        # z: (num_atoms,) long atomic numbers; pos: per-atom 3D coordinates.
        assert z.dim() == 1 and z.dtype == torch.long
        batch = torch.zeros_like(z) if batch is None else batch

        h = self.embedding(z)

        # Edges between atoms within `cutoff`, with distances expanded in a
        # Gaussian radial basis.
        edge_index = radius_graph(pos, r=self.cutoff, batch=batch)
        row, col = edge_index
        edge_weight = (pos[row] - pos[col]).norm(dim=-1)
        edge_attr = self.distance_expansion(edge_weight)

        # Residual interaction blocks.
        for interaction in self.interactions:
            h = h + interaction(h, edge_index, edge_weight, edge_attr)

        h = self.lin1(h)
        h = self.act(h)
        h = self.lin2(h)

        if self.dipole:
            # Get center of mass.
            mass = self.atomic_mass[z].view(-1, 1)
            c = scatter(mass * pos, batch, dim=0) / scatter(mass, batch, dim=0)
            h = h * (pos - c[batch])

        if not self.dipole and self.mean is not None and self.std is not None:
            h = h * self.std + self.mean

        if not self.dipole and self.atomref is not None:
            h = h + self.atomref(z)

        # Per-graph readout.
        out = scatter(h, batch, dim=0, reduce=self.readout)

        if self.dipole:
            out = torch.norm(out, dim=-1, keepdim=True)

        if self.scale is not None:
            out = self.scale * out

        return out

    def __repr__(self):
        return (
            f"{self.__class__.__name__}("
            f"hidden_channels={self.hidden_channels}, "
            f"num_filters={self.num_filters}, "
            f"num_interactions={self.num_interactions}, "
            f"num_gaussians={self.num_gaussians}, "
            f"cutoff={self.cutoff})"
        )


class InteractionBlock(torch.nn.Module):
    # Filter-generating MLP + continuous-filter convolution + linear output.
    def __init__(self, hidden_channels, num_gaussians,
num_filters, cutoff): 134 | super(InteractionBlock, self).__init__() 135 | self.mlp = nn.Sequential( 136 | nn.Linear(num_gaussians, num_filters), 137 | ShiftedSoftplus(), 138 | nn.Linear(num_filters, num_filters), 139 | ) 140 | self.conv = CFConv( 141 | hidden_channels, hidden_channels, num_filters, self.mlp, cutoff 142 | ) 143 | self.act = ShiftedSoftplus() 144 | self.lin = nn.Linear(hidden_channels, hidden_channels) 145 | 146 | self.reset_parameters() 147 | 148 | def reset_parameters(self): 149 | torch.nn.init.xavier_uniform_(self.mlp[0].weight) 150 | self.mlp[0].bias.data.fill_(0) 151 | torch.nn.init.xavier_uniform_(self.mlp[2].weight) 152 | self.mlp[0].bias.data.fill_(0) 153 | self.conv.reset_parameters() 154 | torch.nn.init.xavier_uniform_(self.lin.weight) 155 | self.lin.bias.data.fill_(0) 156 | 157 | def forward(self, x, edge_index, edge_weight, edge_attr): 158 | x = self.conv(x, edge_index, edge_weight, edge_attr) 159 | x = self.act(x) 160 | x = self.lin(x) 161 | return x 162 | 163 | 164 | class CFConv(MessagePassing): 165 | def __init__(self, in_channels, out_channels, num_filters, mlp, cutoff): 166 | super(CFConv, self).__init__(aggr="add") 167 | self.lin1 = nn.Linear(in_channels, num_filters, bias=False) 168 | self.lin2 = nn.Linear(num_filters, out_channels) 169 | self.mlp = mlp 170 | self.cutoff = cutoff 171 | 172 | self.reset_parameters() 173 | 174 | def reset_parameters(self): 175 | torch.nn.init.xavier_uniform_(self.lin1.weight) 176 | torch.nn.init.xavier_uniform_(self.lin2.weight) 177 | self.lin2.bias.data.fill_(0) 178 | 179 | def forward(self, x, edge_index, edge_weight, edge_attr): 180 | C = 0.5 * (torch.cos(edge_weight * PI / self.cutoff) + 1.0) 181 | W = self.mlp(edge_attr) * C.view(-1, 1) 182 | 183 | x = self.lin1(x) 184 | x = self.propagate(edge_index, x=x, W=W) 185 | x = self.lin2(x) 186 | return x 187 | 188 | def message(self, x_j, W): 189 | return x_j * W 190 | 191 | 192 | class GaussianSmearing(torch.nn.Module): 193 | def __init__(self, 
class ShiftedSoftplus(torch.nn.Module):
    """Softplus shifted down by log(2) so the activation is exactly 0 at x = 0."""

    def __init__(self):
        super(ShiftedSoftplus, self).__init__()
        # Compute log(2) through torch (float32) to keep the exact same
        # constant as the original implementation.
        self.shift = torch.log(torch.tensor(2.0)).item()

    def forward(self, x):
        """Return softplus(x) - log(2), elementwise."""
        activated = F.softplus(x)
        return activated - self.shift
class ContextPredictionModel(PreTrainerModel):
    """Context Prediction SSL model (ref: arXiv 1905.12265, Sec. 3.1.1).

    Holds two encoders: the main substructure GNN (``self.gnn``) and a
    smaller ``context_model`` GNN that encodes the context region
    surrounding each substructure.
    """

    def __init__(self, gnn: GNN, config: TrainingConfig):
        super(ContextPredictionModel, self).__init__(config)
        self.gnn: nn.Module = gnn
        self.pool: Callable = global_mean_pool

        # l2 - l1 == csize, so the context encoder gets config.csize layers
        # (presumably the width of the context ring around the substructure
        # — see the referenced paper; confirm against the dataset transform).
        l1 = self.config.num_layer - 1
        l2 = l1 + self.config.csize
        num_layer = self.config.num_layer
        # NOTE: config.num_layer is mutated in place so GNN(self.config)
        # builds a csize-layer network, then restored. Order matters here;
        # the shared config object is briefly inconsistent during __init__.
        self.config.num_layer = l2 - l1
        self.context_model: nn.Module = GNN(self.config)
        self.config.num_layer = num_layer

    def forward_substruct_model(
        self, x: Tensor, edge_index: Tensor, edge_attr: Tensor
    ) -> Tensor:
        """Node embeddings from the main (substructure) encoder."""
        return self.gnn(x, edge_index, edge_attr)

    def forward_context_model(
        self, x: Tensor, edge_index: Tensor, edge_attr: Tensor
    ) -> Tensor:
        """Node embeddings from the context encoder."""
        return self.context_model(x, edge_index, edge_attr)

    def forward_context_repr(
        self,
        x: Tensor,
        edge_index: Tensor,
        edge_attr: Tensor,
        overlap_context_substruct_idx: Tensor,
        batch_overlapped_context: Tensor,
    ) -> Tensor:
        """Pooled context representation per graph.

        Encodes the context graph, selects the nodes in the overlap region
        between context and substructure, and mean-pools them per graph.
        """
        # creating context representations
        overlapped_node_repr = self.context_model(x, edge_index, edge_attr)[
            overlap_context_substruct_idx
        ]

        # positive context representation
        # readout -> global_mean_pool by default
        context_repr = self.pool(overlapped_node_repr, batch_overlapped_context)
        return context_repr
class ContextualModel(PreTrainerModel):
    """Contextual property prediction head: classify each node into an
    atom-context vocabulary from its GNN embedding."""

    def __init__(self, gnn: GNN, config: TrainingConfig):
        super(ContextualModel, self).__init__(config)
        self.gnn: nn.Module = gnn
        self.emb_dim = config.emb_dim
        self.criterion = nn.CrossEntropyLoss()
        self.atom_vocab_size = config.atom_vocab_size
        # Single linear classifier over the atom vocabulary.
        self.atom_vocab_model: nn.Module = nn.Sequential(
            nn.Linear(self.emb_dim, self.atom_vocab_size)
        )

    def forward_cl(
        self,
        x: Tensor,
        edge_index: Tensor,
        edge_attr: Tensor,
        batch_assignments: Tensor,
    ) -> Tensor:
        """Per-node vocabulary logits; ``batch_assignments`` is unused but
        kept for signature parity with the other pre-training models."""
        node_repr = self.gnn(x, edge_index, edge_attr)
        logits = self.atom_vocab_model(node_repr)
        return logits

    def loss_cl(self, y_pred: Tensor, y_actual: Tensor) -> Tensor:
        """Cross-entropy between predicted logits and vocabulary labels."""
        return self.criterion(y_pred, y_actual)


class Discriminator(nn.Module):
    """Bilinear discriminator scoring node embeddings against a graph summary."""

    def __init__(self, hidden_dim):
        super(Discriminator, self).__init__()
        self.weight = nn.Parameter(torch.Tensor(hidden_dim, hidden_dim))
        self.reset_parameters()

    def reset_parameters(self):
        """Uniform re-initialization of the bilinear weight."""
        uniform(self.weight.size(0), self.weight)

    def forward(self, x, summary):
        """Return the bilinear score <x, W @ summary> per node."""
        projected = torch.matmul(summary, self.weight)
        return (x * projected).sum(dim=1)
class EdgePredictionModel(PreTrainerModel):
    """Thin wrapper exposing per-node embeddings for the edge-prediction
    SSL objective (scoring happens in the pre-trainer)."""

    def __init__(self, config: TrainingConfig, gnn: GNN):
        super().__init__(config=config)
        self.gnn: torch.nn.Module = gnn

    def forward(self, batch: Batch) -> torch.Tensor:
        """Return node representations for the graphs in ``batch``."""
        return self.gnn(batch.x, batch.edge_index, batch.edge_attr)
class GraphCLModel(PreTrainerModel):
    """GraphCL contrastive pre-training model: GNN encoder, mean-pool
    readout, and a two-layer MLP projection head."""

    def __init__(self, config: TrainingConfig, gnn: GNN):
        super(GraphCLModel, self).__init__(config)
        self.gnn: nn.Module = gnn
        self.emb_dim = config.emb_dim
        self.pool: Callable = global_mean_pool
        self.projection_head: nn.Module = nn.Sequential(
            nn.Linear(self.emb_dim, self.emb_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.emb_dim, self.emb_dim),
        )

    def forward_cl(
        self,
        x: Tensor,
        edge_index: Tensor,
        edge_attr: Tensor,
        batch_assignments: Tensor,
    ) -> Tensor:
        """Projected graph-level embedding for one augmented view."""
        x = self.gnn(x, edge_index, edge_attr)
        x = self.pool(x, batch_assignments)
        x = self.projection_head(x)
        return x

    def loss_cl(self, x1: Tensor, x2: Tensor, temp: float = 0.1) -> Tensor:
        """InfoNCE-style contrastive loss between two views.

        Args:
            x1, x2: (batch, dim) projected embeddings; row i of x1 and x2
                form the positive pair.
            temp: softmax temperature. Defaults to 0.1, the value that was
                previously hard-coded — parameterized for consistency with
                ``RGCLModel.loss_cl(..., temp=0.1)``.

        Returns:
            Scalar mean negative log-likelihood of the positive pairs.
        """
        T = temp
        batch, _ = x1.size()
        x1_abs = x1.norm(dim=1)
        x2_abs = x2.norm(dim=1)

        # Cosine-similarity matrix between all cross-view pairs.
        sim_matrix = torch.einsum("ik,jk->ij", x1, x2) / torch.einsum(
            "i,j->ij", x1_abs, x2_abs
        )
        sim_matrix = torch.exp(sim_matrix / T)
        pos_sim = sim_matrix[range(batch), range(batch)]
        # NOTE: the denominator keeps the positive pair (standard InfoNCE
        # form); RGCLModel.loss_cl uses the variant that excludes it.
        loss = pos_sim / (sim_matrix.sum(dim=1))
        loss = -torch.log(loss).mean()
        return loss
return loss.mean() 40 | 41 | node_attr_label = batch.node_attr_label 42 | # node_attr_label = batch.mask_node_label 43 | masked_node_indices = batch.masked_atom_indices 44 | 45 | if loss_fn == "sce": 46 | criterion = partial(sce_loss, alpha=alpha_l) 47 | loss = criterion(node_attr_label, pred_node[masked_node_indices]) 48 | else: 49 | criterion = nn.CrossEntropyLoss() 50 | loss = criterion( 51 | pred_node.double()[masked_node_indices], batch.mask_node_label[:, 0] 52 | ) 53 | 54 | return loss 55 | -------------------------------------------------------------------------------- /src/models/graphmvp.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Union, List 2 | 3 | # import torch 4 | import torch.nn as nn 5 | from torch import Tensor 6 | from torch_geometric.data.batch import Batch 7 | from torch_geometric.nn import global_mean_pool 8 | from config.training_config import TrainingConfig 9 | from models.building_blocks.gnn import GNN 10 | from models.building_blocks.schnet import SchNet 11 | from models.building_blocks.auto_encoder import AutoEncoder, VariationalAutoEncoder 12 | from models.pre_trainer_model import PreTrainerModel 13 | from util import dual_CL 14 | 15 | 16 | class GraphMVPModel(PreTrainerModel): 17 | def __init__( 18 | self, 19 | config: TrainingConfig, 20 | gnn_2d: GNN, 21 | gnn_3d: SchNet, 22 | ae2d3d: Union[AutoEncoder, VariationalAutoEncoder], 23 | ae3d2d: Union[AutoEncoder, VariationalAutoEncoder], 24 | ): 25 | super(GraphMVPModel, self).__init__(config) 26 | self.gnn: nn.Module = gnn_2d 27 | self.gnn_3d: nn.Module = gnn_3d 28 | self.ae2d3d: nn.Module = ae2d3d 29 | self.ae3d2d: nn.Module = ae3d2d 30 | self.config: TrainingConfig = config 31 | self.pool: Callable = global_mean_pool 32 | 33 | def forward(self, batch: Batch) -> List[Tensor]: 34 | # x, edge_index, edge_attr, positions, batch_assignments = batch.x, batch.edg 35 | repr_node = self.gnn(batch.x, batch.edge_index, 
class GraphPred(nn.Module):
    """Downstream graph-level prediction head: GNN encoder, mean-pool
    readout, and a linear head producing ``num_tasks`` outputs."""

    def __init__(self, config: "TrainingConfig", gnn: "GNN", num_tasks: int):
        super(GraphPred, self).__init__()
        self.gnn: nn.Module = gnn
        self.emb_dim = config.emb_dim
        # BUG FIX: self.num_layer was read below but never assigned, so the
        # JK == "concat" branch raised AttributeError.
        self.num_layer = config.num_layer
        self.pool: Callable = global_mean_pool
        if config.JK == "concat":
            # With "concat" jumping knowledge, node features concatenate the
            # outputs of all (num_layer + 1) layers.
            self.downstream_head = nn.Linear(
                (self.num_layer + 1) * self.emb_dim, num_tasks
            )
        else:
            self.downstream_head = nn.Linear(self.emb_dim, num_tasks)

    def forward(
        self,
        x: Tensor,
        edge_index: Tensor,
        edge_attr: Tensor,
        batch_assignments: Tensor,
    ) -> Tensor:
        """Return per-graph task predictions of shape (num_graphs, num_tasks)."""
        x = self.gnn(x, edge_index, edge_attr)
        x = self.pool(x, batch_assignments)
        x = self.downstream_head(x)
        return x
class MotifModel(PreTrainerModel):
    """Motif prediction head: multi-label classification of RDKit fragment
    properties from the pooled graph embedding."""

    def __init__(self, config: TrainingConfig, gnn: GNN):
        super(MotifModel, self).__init__(config)
        self.gnn: nn.Module = gnn
        self.emb_dim = config.emb_dim
        # One binary label per RDKit fragment property.
        self.num_tasks = len(RDKIT_PROPS)
        self.pool: Callable = global_mean_pool
        self.criterion = nn.BCEWithLogitsLoss()
        self.prediction_model: nn.Module = nn.Sequential(
            nn.Linear(self.emb_dim, self.num_tasks)
        )

    def forward_cl(
        self,
        x: Tensor,
        edge_index: Tensor,
        edge_attr: Tensor,
        batch_assignments: Tensor,
    ) -> Tensor:
        """Per-graph motif logits of shape (num_graphs, num_tasks)."""
        node_repr = self.gnn(x, edge_index, edge_attr)
        graph_repr = self.pool(node_repr, batch_assignments)
        return self.prediction_model(graph_repr)

    def loss_cl(self, y_pred: Tensor, y_actual: Tensor) -> Tensor:
        """Binary cross-entropy (with logits) over all motif labels."""
        return self.criterion(y_pred, y_actual)
class PreTrainerModel(torch.nn.Module):
    """Common base class for SSL pre-training models.

    Stores the training config; subclasses are expected to set ``self.gnn``,
    to which embedding extraction is delegated.
    """

    def __init__(self, config: "TrainingConfig"):
        super().__init__()
        self.config = config

    def get_embeddings(self, *argv):
        """Proxy to the wrapped encoder's ``get_embeddings``."""
        return self.gnn.get_embeddings(*argv)


class RGCLModel(PreTrainerModel):
    """RGCL: rationale-aware graph contrastive learning.

    Ref: https://github.com/lsh0520/RGCL/blob/main/transferLearning/chem/pretrain_rgcl.py
    """

    def __init__(self, config: "TrainingConfig", gnn: "GNN"):
        super(RGCLModel, self).__init__(config)
        self.gnn: nn.Module = gnn
        self.emb_dim = config.emb_dim
        self.pool: Callable = global_mean_pool
        # Auxiliary GNN that scores per-node importance (the "rationale").
        self.node_imp_estimator = GNN_IMP_Estimator()
        self.projection_head: nn.Module = nn.Sequential(
            nn.Linear(self.emb_dim, self.emb_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.emb_dim, self.emb_dim),
        )

    def forward_cl(
        self,
        x: Tensor,
        edge_index: Tensor,
        edge_attr: Tensor,
        batch_assignments: Tensor,
    ) -> Tensor:
        """Importance-weighted, pooled, projected embedding for one view."""
        node_imp = self.node_imp_estimator(x, edge_index, edge_attr, batch_assignments)
        x = self.gnn(x, edge_index, edge_attr)

        # Normalize importances per graph into [0.9, 1.0]:
        # divide by 10 * (per-graph max), then shift by 0.9.
        out, _ = scatter_max(torch.reshape(node_imp, (1, -1)), batch_assignments)
        out = out.reshape(-1, 1)
        out = out[batch_assignments]
        node_imp /= out * 10
        node_imp += 0.9
        # BUG FIX: broadcast over the configured embedding width instead of
        # the hard-coded 300, which silently assumed emb_dim == 300.
        node_imp = node_imp.expand(-1, self.emb_dim)

        x = torch.mul(x, node_imp)
        x = self.pool(x, batch_assignments)
        x = self.projection_head(x)

        return x

    def loss_cl(self, x1: Tensor, x2: Tensor, temp=0.1) -> Tensor:
        """Contrastive loss with the positive pair excluded from the denominator."""
        T = temp
        batch, _ = x1.size()
        x1_abs = x1.norm(dim=1)
        x2_abs = x2.norm(dim=1)

        sim_matrix = torch.einsum("ik,jk->ij", x1, x2) / torch.einsum(
            "i,j->ij", x1_abs, x2_abs
        )
        sim_matrix = torch.exp(sim_matrix / T)
        pos_sim = sim_matrix[range(batch), range(batch)]
        loss = pos_sim / (sim_matrix.sum(dim=1) - pos_sim)
        loss = -torch.log(loss).mean()
        return loss

    def loss_infonce(self, x1, x2, temp=0.1):
        """Standard InfoNCE loss (positive pair kept in the denominator)."""
        T = temp
        batch_size, _ = x1.size()
        x1_abs = x1.norm(dim=1)
        x2_abs = x2.norm(dim=1)

        sim_matrix = torch.einsum("ik,jk->ij", x1, x2) / torch.einsum(
            "i,j->ij", x1_abs, x2_abs
        )
        sim_matrix = torch.exp(sim_matrix / T)
        pos_sim = sim_matrix[range(batch_size), range(batch_size)]
        loss = pos_sim / sim_matrix.sum(dim=1)
        loss = -torch.log(loss).mean()
        return loss

    def loss_ra(self, x1, x2, x3, temp=0.1, lamda=0.1):
        """Rationale-aware loss: RA term between views x1/x2 plus a
        complement term against x3, weighted by ``lamda``.

        Returns (ra_loss, cp_loss, combined_loss).
        """
        batch_size, _ = x1.size()
        x1_abs = x1.norm(dim=1)
        x2_abs = x2.norm(dim=1)
        x3_abs = x3.norm(dim=1)

        cp_sim_matrix = torch.einsum("ik,jk->ij", x1, x3) / torch.einsum(
            "i,j->ij", x1_abs, x3_abs
        )
        cp_sim_matrix = torch.exp(cp_sim_matrix / temp)

        sim_matrix = torch.einsum("ik,jk->ij", x1, x2) / torch.einsum(
            "i,j->ij", x1_abs, x2_abs
        )
        sim_matrix = torch.exp(sim_matrix / temp)

        pos_sim = sim_matrix[range(batch_size), range(batch_size)]

        ra_loss = pos_sim / (sim_matrix.sum(dim=1) - pos_sim)
        ra_loss = -torch.log(ra_loss).mean()

        cp_loss = pos_sim / (cp_sim_matrix.sum(dim=1) + pos_sim)
        cp_loss = -torch.log(cp_loss).mean()

        loss = ra_loss + lamda * cp_loss

        return ra_loss, cp_loss, loss
-------------------------------------------------------------------------------- /src/pretrainers/__init__.py: -------------------------------------------------------------------------------- 1 | from .attribute_masking import AttributeMaskingPreTrainer 2 | from .context_prediction import ContextPredictionPreTrainer 3 | from .contextual import ContextualPreTrainer 4 | from .edge_prediction import EdgePredictionPreTrainer 5 | from .gpt_gnn import GPTGNNPreTrainer 6 | from .graph_cl import GraphCLPreTrainer 7 | from .info_max import InfoMaxPreTrainer 8 | from .joao import JOAOPreTrainer 9 | from .joao_v2 import JOAOv2PreTrainer 10 | from .motif import MotifPreTrainer 11 | from .graphmvp import GraphMVPPreTrainer 12 | from .rgcl import RGCLPreTrainer 13 | from .graphmae import GraphMAEPreTrainer 14 | from .pretrainer import PreTrainer 15 | -------------------------------------------------------------------------------- /src/pretrainers/attribute_masking.py: -------------------------------------------------------------------------------- 1 | """ GRAPH SSL Pre-Training via Attribute Masking 2 | i.e., maps nodes in similar structural contexts to closer embeddings 3 | Ref Paper: Sec. 
class AttributeMaskingPreTrainer(PreTrainer):
    """Pre-trainer for the attribute-masking objective
    (ref: arXiv 1905.12265, Sec. 3.1.2)."""

    def __init__(
        self,
        config: TrainingConfig,
        model: AttributeMaskingModel,
        optimizer: torch.optim.Optimizer,
        device: torch.device,
        logger: CombinedLogger,
    ) -> None:
        super(AttributeMaskingPreTrainer, self).__init__(
            config=config,
            model=model,
            optimizer=optimizer,
            device=device,
            logger=logger,
        )

    def train_for_one_epoch(self, train_data_loader: DataLoader) -> float:
        """Run one optimization epoch; return the mean masking loss."""
        self.model.train()
        self.logger.train(num_batches=len(train_data_loader))

        total_loss = 0.0
        criterion = torch.nn.CrossEntropyLoss()

        for step, batch in enumerate(train_data_loader):
            batch = batch.to(self.device)

            loss, acc = do_AttrMasking(
                batch=batch,
                criterion=criterion,
                model=self.model,
            )
            loss_value = loss.detach().cpu().item()
            total_loss += loss_value
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            self.logger(loss_value, acc, batch.num_graphs, get_lr(self.optimizer))

        return total_loss / (step + 1)

    # TODO: real validation is not implemented; log a placeholder entry.
    def validate_model(self, val_data_loader: DataLoader) -> float:
        # self.logger.eval(num_batches=len(val_data_loader))
        self.logger.eval(num_batches=1)
        self.logger(0.0, 0.0, 1)
        return 0.0
def do_AttrMasking(
    batch: Batch, criterion: torch.nn.Module, model: AttributeMaskingModel
) -> Tuple[torch.Tensor, float]:
    """Compute attribute-masking loss and accuracy for one batch.

    Column 0 of ``mask_node_label`` holds the target label of each masked
    node; the model predicts it from the masked graph.
    """
    target = batch.mask_node_label[:, 0]
    node_pred = model.forward(batch)
    # .double() casts logits to float64 before the criterion — presumably
    # for numerical parity with the other pretrainers; confirm if changed.
    attributemask_loss = criterion(node_pred.double(), target)
    attributemask_acc = compute_accuracy(node_pred, target)
    return attributemask_loss, attributemask_acc


class ContextPredictionPreTrainer(PreTrainer):
    """Pre-trainer for Context Prediction (arXiv 1905.12265, Sec. 3.1.1).

    Learns to tell whether a context embedding surrounds a given
    substructure embedding, with negatives drawn from other graphs in the
    same batch via cyclic index shifts.
    """

    def __init__(
        self,
        config: TrainingConfig,
        model: ContextPredictionModel,
        optimizer: torch.optim.Optimizer,
        device: torch.device,
        logger: CombinedLogger,
    ):
        super(ContextPredictionPreTrainer, self).__init__(
            config=config,
            model=model,
            optimizer=optimizer,
            device=device,
            logger=logger,
        )
        # Binary objective: positive (true) vs. shifted (negative) contexts.
        self.criterion = nn.BCEWithLogitsLoss()

    # Validation is not implemented for this objective.
    def validate_model(self, val_data_loader: DataLoader):
        pass

    def train_for_one_epoch(self, train_data_loader: DataLoader):
        """Run one optimization epoch; return the mean context-prediction loss."""
        self.model.train()
        contextpred_loss_accum, contextpred_acc_accum = 0, 0
        self.logger.train(num_batches=len(train_data_loader))

        for step, batch in enumerate(train_data_loader):
            batch = batch.to(self.device)
            contextpred_loss, contextpred_acc = self.do_ContextPred(batch=batch)
            loss_float = contextpred_loss.detach().cpu().item()
            ssl_loss = contextpred_loss
            self.optimizer.zero_grad()
            ssl_loss.backward()
            self.optimizer.step()
            self.logger(
                ssl_loss.cpu().item(),
                contextpred_acc,
                batch.num_graphs,
                get_lr(self.optimizer),
            )
            contextpred_loss_accum += loss_float
            contextpred_acc_accum += contextpred_acc

        return contextpred_loss_accum / len(train_data_loader)
        # return contextpred_loss_accum / len(train_data_loader), \
        #     contextpred_acc_accum / len(train_data_loader)

    def do_ContextPred(self, batch: Batch):
        """Score substructure/context pairs for one batch.

        Returns:
            (loss, accuracy): BCE-with-logits loss over positive and
            shifted-negative pairs, and the fraction of correctly signed
            scores.
        """
        # creating substructure representation: embedding of each center node
        substruct_repr = self.model.forward_substruct_model(
            x=batch.x_substruct,
            edge_index=batch.edge_index_substruct,
            edge_attr=batch.edge_attr_substruct,
        )[batch.center_substruct_idx]

        # substruct_repr = molecule_substruct_model(
        #     batch.x_substruct, batch.edge_index_substruct,
        #     batch.edge_attr_substruct)[batch.center_substruct_idx]

        # create positive context representation
        # readout -> global_mean_pool by default

        context_repr = self.model.forward_context_repr(
            x=batch.x_context,
            edge_index=batch.edge_index_context,
            edge_attr=batch.edge_attr_context,
            overlap_context_substruct_idx=batch.overlap_context_substruct_idx,
            batch_overlapped_context=batch.batch_overlapped_context,
        )

        # negative contexts are obtained by shifting
        # the indices of context embeddings
        neg_context_repr = torch.cat(
            [
                context_repr[cycle_idx(len(context_repr), i + 1)]
                for i in range(self.config.contextpred_neg_samples)
            ],
            dim=0,
        )

        num_neg = self.config.contextpred_neg_samples

        # Inner-product scores: each substructure against its own context
        # (positive) and against the shifted contexts (negative).
        pred_pos = torch.sum(substruct_repr * context_repr, dim=1)
        pred_neg = torch.sum(
            substruct_repr.repeat((num_neg, 1)) * neg_context_repr, dim=1
        )

        loss_pos = self.criterion(
            pred_pos.double(), torch.ones(len(pred_pos)).to(pred_pos.device).double()
        )
        loss_neg = self.criterion(
            pred_neg.double(), torch.zeros(len(pred_neg)).to(pred_neg.device).double()
        )

        # The criterion averages over its inputs, so scale the negative term
        # by num_neg to keep its total weight proportional to sample count.
        contextpred_loss = loss_pos + num_neg * loss_neg

        num_pred = len(pred_pos) + len(pred_neg)
        contextpred_acc = (
            torch.sum(pred_pos > 0).float() + torch.sum(pred_neg < 0).float()
        ) / num_pred
        contextpred_acc = contextpred_acc.detach().cpu().item()

        return contextpred_loss, contextpred_acc
device=device, 30 | logger=logger, 31 | ) 32 | 33 | def train_for_one_epoch(self, train_data_loader: DataLoader) -> float: 34 | self.model.train() 35 | train_loss_accum = 0.0 36 | self.logger.train(num_batches=len(train_data_loader)) 37 | 38 | for step, batch in enumerate(train_data_loader): 39 | batch = batch.to(self.device) 40 | node_pred = self.model.forward_cl( 41 | batch.x, batch.edge_index, batch.edge_attr, batch.batch 42 | ) 43 | node_target = batch.atom_vocab_label 44 | loss = self.model.loss_cl(node_pred, node_target) 45 | 46 | self.optimizer.zero_grad() 47 | loss.backward() 48 | self.optimizer.step() 49 | loss_float = float(loss.detach().cpu().item()) 50 | train_loss_accum += loss_float 51 | # loss, accuracy, batch_size, learning_rate 52 | self.logger(loss_float, 0.0, batch.num_graphs, get_lr(self.optimizer)) 53 | 54 | return train_loss_accum / (step + 1) 55 | 56 | # TODO 57 | def validate_model(self, val_data_loader: DataLoader) -> float: 58 | # self.logger.eval(num_batches=len(val_data_loader)) 59 | self.logger.eval(num_batches=1) 60 | self.logger(0.0, 0.0, 1) 61 | return 0.0 62 | -------------------------------------------------------------------------------- /src/pretrainers/edge_prediction.py: -------------------------------------------------------------------------------- 1 | """ GRAPH SSL Pre-Training via Edge Prediction 2 | Ref Paper: Sec. 
5.2 and Appendix G of 3 | https://arxiv.org/abs/1905.12265 ; 4 | which is adapted from 5 | https://arxiv.org/abs/1706.02216 ; 6 | 7 | Ref Code: ${GitHub_Repo}/chem/pretrain_edgepred.py """ 8 | 9 | 10 | import torch 11 | import torch.nn as nn 12 | from torch.utils.data import DataLoader 13 | 14 | from config.training_config import TrainingConfig 15 | from logger import CombinedLogger 16 | from models.edge_prediction import EdgePredictionModel 17 | from pretrainers.pretrainer import PreTrainer 18 | from util import get_lr 19 | 20 | 21 | def do_EdgePred(node_repr, batch, criterion): 22 | # positive/negative scores -> inner product of node features 23 | positive_score = torch.sum( 24 | node_repr[batch.edge_index[0, ::2]] * node_repr[batch.edge_index[1, ::2]], dim=1 25 | ) 26 | negative_score = torch.sum( 27 | node_repr[batch.negative_edge_index[0]] 28 | * node_repr[batch.negative_edge_index[1]], 29 | dim=1, 30 | ) 31 | 32 | edgepred_loss = criterion( 33 | positive_score, torch.ones_like(positive_score) 34 | ) + criterion(negative_score, torch.zeros_like(negative_score)) 35 | edgepred_acc = (torch.sum(positive_score > 0) + torch.sum(negative_score < 0)).to( 36 | torch.float32 37 | ) / float(2 * len(positive_score)) 38 | edgepred_acc = edgepred_acc.detach().cpu().item() 39 | 40 | return edgepred_loss, edgepred_acc 41 | 42 | 43 | class EdgePredictionPreTrainer(PreTrainer): 44 | def __init__( 45 | self, 46 | config: TrainingConfig, 47 | model: EdgePredictionModel, 48 | optimizer: torch.optim.Optimizer, 49 | device: torch.device, 50 | logger: CombinedLogger, 51 | ) -> None: 52 | super(EdgePredictionPreTrainer, self).__init__( 53 | config=config, 54 | model=model, 55 | optimizer=optimizer, 56 | device=device, 57 | logger=logger, 58 | ) 59 | 60 | def train_for_one_epoch(self, train_data_loader: DataLoader) -> float: 61 | self.model.train() 62 | self.logger.train(num_batches=len(train_data_loader)) 63 | 64 | loss_accum = 0.0 65 | 66 | for step, batch in 
enumerate(train_data_loader): 67 | batch = batch.to(self.device) 68 | node_repr = self.model(batch) 69 | edgepred_loss, edgepred_acc = do_EdgePred( 70 | node_repr=node_repr, batch=batch, criterion=nn.BCEWithLogitsLoss() 71 | ) 72 | loss_float = edgepred_loss.detach().cpu().item() 73 | loss_accum += loss_float 74 | self.optimizer.zero_grad() 75 | edgepred_loss.backward() 76 | self.optimizer.step() 77 | self.logger( 78 | loss_float, edgepred_acc, batch.num_graphs, get_lr(self.optimizer) 79 | ) 80 | 81 | return loss_accum / (step + 1) 82 | 83 | # TODO 84 | def validate_model(self, val_data_loader: DataLoader) -> float: 85 | # self.logger.eval(num_batches=len(val_data_loader)) 86 | self.logger.eval(num_batches=1) 87 | self.logger(0.0, 0.0, 1) 88 | return 0.0 89 | 90 | 91 | if __name__ == "__main__": 92 | pass 93 | -------------------------------------------------------------------------------- /src/pretrainers/gpt_gnn.py: -------------------------------------------------------------------------------- 1 | """ GPT-GNN pre-training 2 | Current version only supports the node reconstruction for molecular data. 3 | Because the current GIN model only supports node representation. 
4 | In the future, we will add edge reconstruction and more general graph data.""" 5 | 6 | import torch 7 | from torch.utils.data import DataLoader 8 | 9 | from config.training_config import TrainingConfig 10 | from logger import CombinedLogger 11 | from models import GPTGNNModel 12 | from pretrainers.pretrainer import PreTrainer 13 | from util import get_lr 14 | 15 | 16 | class GPTGNNPreTrainer(PreTrainer): 17 | def __init__( 18 | self, 19 | config: TrainingConfig, 20 | model: GPTGNNModel, 21 | optimizer: torch.optim.Optimizer, 22 | device: torch.device, 23 | logger: CombinedLogger, 24 | ) -> None: 25 | super(GPTGNNPreTrainer, self).__init__( 26 | config=config, 27 | model=model, 28 | optimizer=optimizer, 29 | device=device, 30 | logger=logger, 31 | ) 32 | 33 | def train_for_one_epoch(self, train_data_loader: DataLoader) -> float: 34 | self.model.train() 35 | self.logger.train(num_batches=len(train_data_loader)) 36 | gpt_loss_accum = 0.0 37 | criterion = torch.nn.CrossEntropyLoss() 38 | 39 | for step, batch in enumerate(train_data_loader): 40 | batch = batch.to(self.device) 41 | 42 | node_pred = self.model(batch) 43 | target = batch.next_x 44 | 45 | loss = criterion(node_pred.double(), target) 46 | acc = compute_accuracy(node_pred, target) 47 | 48 | loss_float = loss.detach().cpu().item() 49 | gpt_loss_accum += loss.detach().cpu().item() 50 | self.optimizer.zero_grad() 51 | loss.backward() 52 | self.optimizer.step() 53 | self.logger(loss_float, acc, batch.num_graphs, get_lr(self.optimizer)) 54 | 55 | return gpt_loss_accum / (step + 1) 56 | 57 | # TODO 58 | def validate_model(self, val_data_loader: DataLoader) -> float: 59 | # self.logger.eval(num_batches=len(val_data_loader)) 60 | self.logger.eval(num_batches=1) 61 | self.logger(0.0, 0.0, 1) 62 | return 0.0 63 | 64 | 65 | def compute_accuracy(pred, target): 66 | return float( 67 | torch.sum(torch.max(pred.detach(), dim=1)[1] == target).cpu().item() 68 | ) / len(pred) 69 | 
-------------------------------------------------------------------------------- /src/pretrainers/graph_cl.py: -------------------------------------------------------------------------------- 1 | # Ref: {GitHub}/transferLearning_MoleculeNet_PPI 2 | # Ref: https://arxiv.org/abs/2010.13902 3 | 4 | import torch 5 | from torch.utils.data import DataLoader 6 | from config.training_config import TrainingConfig 7 | from logger import CombinedLogger 8 | from models.graph_cl import GraphCLModel 9 | from pretrainers.pretrainer import PreTrainer 10 | from util import get_lr 11 | 12 | 13 | class GraphCLPreTrainer(PreTrainer): 14 | def __init__( 15 | self, 16 | config: TrainingConfig, 17 | model: GraphCLModel, 18 | optimizer: torch.optim.Optimizer, 19 | device: torch.device, 20 | logger: CombinedLogger, 21 | ): 22 | super(GraphCLPreTrainer, self).__init__( 23 | config=config, 24 | model=model, 25 | optimizer=optimizer, 26 | device=device, 27 | logger=logger, 28 | ) 29 | 30 | def train_for_one_epoch(self, train_data_loader: DataLoader) -> float: 31 | self.model.train() 32 | train_loss_accum = 0.0 33 | self.logger.train(num_batches=len(train_data_loader)) 34 | 35 | for step, (_, batch1, batch2) in enumerate(train_data_loader): 36 | batch1 = batch1.to(self.device) 37 | batch2 = batch2.to(self.device) 38 | 39 | x1 = self.model.forward_cl( 40 | batch1.x, batch1.edge_index, batch1.edge_attr, batch1.batch 41 | ) 42 | x2 = self.model.forward_cl( 43 | batch2.x, batch2.edge_index, batch2.edge_attr, batch2.batch 44 | ) 45 | loss = self.model.loss_cl(x1, x2) 46 | 47 | self.optimizer.zero_grad() 48 | loss.backward() 49 | self.optimizer.step() 50 | loss_float = float(loss.detach().cpu().item()) 51 | train_loss_accum += loss_float 52 | self.logger(loss_float, 0.0, batch1.num_graphs, get_lr(self.optimizer)) 53 | 54 | return train_loss_accum / (step + 1) 55 | 56 | # TODO 57 | def validate_model(self, val_data_loader: DataLoader) -> float: 58 | # 
self.logger.eval(num_batches=len(val_data_loader)) 59 | self.logger.eval(num_batches=1) 60 | self.logger(0.0, 0.0, 1) 61 | return 0.0 62 | -------------------------------------------------------------------------------- /src/pretrainers/graphmae.py: -------------------------------------------------------------------------------- 1 | # Ref: {GitHub}/lsh0520/RGCL/transferLearning/chem/pretrain_rgcl.py 2 | # Ref: https://arxiv.org/abs/2010.13902 3 | # TODO: TO UPDATE 4 | """ GRAPH SSL Pre-Training via InfoGraph [InfoGraph] 5 | i.e., maps nodes in similar structural contexts to closer embeddings 6 | Ref Paper: Sec. 5.2 and Appendix G of 7 | https://arxiv.org/abs/1905.12265 ; 8 | which is adapted from 9 | https://arxiv.org/abs/1809.10341 ; 10 | Ref Code: ${GitHub_Repo}/chem/pretrain_deepgraphinfomax.py """ 11 | 12 | import torch 13 | 14 | from config.training_config import TrainingConfig 15 | from pretrainers.pretrainer import PreTrainer 16 | from models.graphmae import GraphMAEModel 17 | from torch.utils.data import DataLoader 18 | from logger import CombinedLogger 19 | from util import get_lr 20 | from typing import List 21 | 22 | 23 | class GraphMAEPreTrainer(PreTrainer): 24 | def __init__( 25 | self, 26 | config: TrainingConfig, 27 | model: GraphMAEModel, 28 | optimizer: List, 29 | device: torch.device, 30 | logger: CombinedLogger, 31 | ): 32 | super(GraphMAEPreTrainer, self).__init__( 33 | config=config, 34 | model=model, 35 | optimizer=optimizer, 36 | device=device, 37 | logger=logger, 38 | ) 39 | assert len(self.optimizer) == 2 40 | 41 | def train_for_one_epoch(self, train_data_loader: DataLoader) -> float: 42 | train_loss_accum = 0 43 | self.model.train() 44 | self.logger.train(num_batches=len(train_data_loader)) 45 | 46 | for step, batch in enumerate(train_data_loader): 47 | batch = batch.to(self.device) 48 | self.optimizer[0].zero_grad() 49 | self.optimizer[1].zero_grad() 50 | pred_node = self.model.forward(batch) 51 | 52 | loss = self.model.loss(pred_node, 
batch) 53 | loss.backward() 54 | self.optimizer[0].step() 55 | self.optimizer[1].step() 56 | 57 | loss_float = float(loss.detach().cpu().item()) 58 | train_loss_accum += loss_float 59 | self.logger(loss_float, 0.0, batch.num_graphs, get_lr(self.optimizer[0])) 60 | 61 | return train_loss_accum / (step + 1) 62 | 63 | def validate_model(self, val_data_loader) -> float: 64 | # self.logger.eval(num_batches=len(val_data_loader)) 65 | self.logger.eval(num_batches=1) 66 | self.logger(0.0, 0.0, 1) 67 | return 0.0 68 | -------------------------------------------------------------------------------- /src/pretrainers/graphmvp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | 4 | from config.training_config import TrainingConfig 5 | from logger import CombinedLogger 6 | from models.graphmvp import GraphMVPModel 7 | from pretrainers.pretrainer import PreTrainer 8 | from util import get_lr 9 | 10 | 11 | class GraphMVPPreTrainer(PreTrainer): 12 | def __init__( 13 | self, 14 | config: TrainingConfig, 15 | model: GraphMVPModel, 16 | optimizer: torch.optim.Optimizer, 17 | device: torch.device, 18 | logger: CombinedLogger, 19 | ): 20 | super(GraphMVPPreTrainer, self).__init__( 21 | config=config, 22 | model=model, 23 | optimizer=optimizer, 24 | device=device, 25 | logger=logger, 26 | ) 27 | 28 | def train_for_one_epoch(self, train_data_loader: DataLoader) -> float: 29 | self.model.train() 30 | train_loss_accum = 0.0 31 | self.logger.train(num_batches=len(train_data_loader)) 32 | 33 | for step, batch in enumerate(train_data_loader): 34 | batch = batch.to(self.device) 35 | repr_2d, repr_3d = self.model(batch) 36 | loss = self.model.loss(repr_2d, repr_3d) 37 | 38 | self.optimizer.zero_grad() 39 | loss.backward() 40 | self.optimizer.step() 41 | 42 | loss_float = float(loss.detach().cpu().item()) 43 | train_loss_accum += loss_float 44 | self.logger(loss_float, 0.0, batch.num_graphs, 
get_lr(self.optimizer)) 45 | 46 | return train_loss_accum / (step + 1) 47 | 48 | def validate_model(self, val_data_loader: DataLoader) -> float: 49 | return 0 50 | -------------------------------------------------------------------------------- /src/pretrainers/info_max.py: -------------------------------------------------------------------------------- 1 | """ GRAPH SSL Pre-Training via InfoGraph [InfoGraph] 2 | i.e., maps nodes in similar structural contexts to closer embeddings 3 | Ref Paper: Sec. 5.2 and Appendix G of 4 | https://arxiv.org/abs/1905.12265 ; 5 | which is adapted from 6 | https://arxiv.org/abs/1809.10341 ; 7 | Ref Code: ${GitHub_Repo}/chem/pretrain_deepgraphinfomax.py """ 8 | 9 | import torch 10 | import torch.nn as nn 11 | from torch_geometric.data import Batch, DataLoader 12 | from torch_geometric.nn import global_mean_pool 13 | 14 | from config.training_config import TrainingConfig 15 | from logger import CombinedLogger 16 | from models.info_max import InfoMaxModel 17 | from pretrainers.pretrainer import PreTrainer 18 | from util import cycle_idx, get_lr 19 | 20 | 21 | class InfoMaxPreTrainer(PreTrainer): 22 | def __init__( 23 | self, 24 | config: TrainingConfig, 25 | model: InfoMaxModel, 26 | optimizer: torch.optim.Optimizer, 27 | device: torch.device, 28 | logger: CombinedLogger, 29 | ): 30 | super(InfoMaxPreTrainer, self).__init__( 31 | config=config, 32 | model=model, 33 | optimizer=optimizer, 34 | device=device, 35 | logger=logger, 36 | ) 37 | # TODO: perhaps we can add the below attributes as arguments? 
38 | self.molecule_readout_func = global_mean_pool 39 | self.criterion = nn.BCEWithLogitsLoss() 40 | 41 | def train_for_one_epoch(self, train_data_loader: DataLoader) -> float: 42 | self.model.train() 43 | infograph_loss_accum, infograph_acc_accum = 0, 0 44 | self.logger.train(num_batches=len(train_data_loader)) 45 | 46 | for step, batch in enumerate(train_data_loader): 47 | batch = batch.to(self.device) 48 | node_emb, pooled_emb = self.model.forward_embedding(batch) 49 | infograph_loss, infograph_acc = do_InfoGraph( 50 | node_repr=node_emb, 51 | batch=batch, 52 | molecule_repr=pooled_emb, 53 | criterion=self.criterion, 54 | model=self.model, 55 | ) 56 | 57 | self.optimizer.zero_grad() 58 | infograph_loss.backward() 59 | self.optimizer.step() 60 | infograph_loss = infograph_loss.detach().cpu().item() 61 | infograph_loss_accum += infograph_loss 62 | infograph_acc_accum += infograph_acc 63 | self.logger( 64 | infograph_loss, infograph_acc, batch.num_graphs, get_lr(self.optimizer) 65 | ) 66 | 67 | return infograph_loss_accum / (step + 1) 68 | 69 | # TODO 70 | def validate_model(self, val_data_loader: DataLoader) -> float: 71 | # self.logger.eval(num_batches=len(val_data_loader)) 72 | self.logger.eval(num_batches=1) 73 | self.logger(0.0, 0.0, 1) 74 | return 0.0 75 | 76 | 77 | def do_InfoGraph( 78 | node_repr: torch.Tensor, 79 | molecule_repr: torch.Tensor, 80 | batch: Batch, 81 | criterion: torch.nn.Module, 82 | model: InfoMaxModel, 83 | ): 84 | summary_repr = torch.sigmoid(molecule_repr) 85 | positive_expanded_summary_repr = summary_repr[batch.batch] 86 | shifted_summary_repr = summary_repr[cycle_idx(len(summary_repr), 1)] 87 | negative_expanded_summary_repr = shifted_summary_repr[batch.batch] 88 | 89 | positive_score = model.infograph_discriminator_SSL_model( 90 | node_repr, positive_expanded_summary_repr 91 | ) 92 | negative_score = model.infograph_discriminator_SSL_model( 93 | node_repr, negative_expanded_summary_repr 94 | ) 95 | infograph_loss = criterion( 96 | 
positive_score, torch.ones_like(positive_score) 97 | ) + criterion(negative_score, torch.zeros_like(negative_score)) 98 | 99 | num_sample = float(2 * len(positive_score)) 100 | infograph_acc = (torch.sum(positive_score > 0) + torch.sum(negative_score < 0)).to( 101 | torch.float32 102 | ) / num_sample 103 | infograph_acc = infograph_acc.detach().cpu().item() 104 | 105 | return infograph_loss, infograph_acc 106 | -------------------------------------------------------------------------------- /src/pretrainers/joao.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import DataLoader 4 | 5 | from config.training_config import TrainingConfig 6 | from logger import CombinedLogger 7 | from models.graph_cl import GraphCLModel 8 | from pretrainers.pretrainer import PreTrainer 9 | from util import get_lr 10 | 11 | 12 | class JOAOPreTrainer(PreTrainer): 13 | def __init__( 14 | self, 15 | config: TrainingConfig, 16 | model: GraphCLModel, 17 | optimizer: torch.optim.Optimizer, 18 | device: torch.device, 19 | logger: CombinedLogger, 20 | ): 21 | super(JOAOPreTrainer, self).__init__( 22 | config=config, 23 | model=model, 24 | optimizer=optimizer, 25 | device=device, 26 | logger=logger, 27 | ) 28 | 29 | def train_for_one_epoch(self, train_data_loader: DataLoader) -> float: 30 | self.model.train() 31 | train_loss_accum = 0.0 32 | self.logger.train(num_batches=len(train_data_loader)) 33 | 34 | for step, (_, batch1, batch2) in enumerate(train_data_loader): 35 | batch1 = batch1.to(self.device) 36 | batch2 = batch2.to(self.device) 37 | 38 | x1 = self.model.forward_cl( 39 | batch1.x, batch1.edge_index, batch1.edge_attr, batch1.batch 40 | ) 41 | x2 = self.model.forward_cl( 42 | batch2.x, batch2.edge_index, batch2.edge_attr, batch2.batch 43 | ) 44 | loss = self.model.loss_cl(x1, x2) 45 | 46 | self.optimizer.zero_grad() 47 | loss.backward() 48 | self.optimizer.step() 49 | 50 | loss_float = 
float(loss.detach().cpu().item()) 51 | train_loss_accum += loss_float 52 | self.logger(loss_float, 0.0, batch1.num_graphs, get_lr(self.optimizer)) 53 | 54 | # joao 55 | aug_prob = train_data_loader.dataset.aug_prob 56 | # TODO: Can we replace the below constants (25, 10, etc.) with arguments in the config? 57 | loss_aug = np.zeros(25) 58 | for n in range(25): 59 | _aug_prob = np.zeros(25) 60 | _aug_prob[n] = 1 61 | train_data_loader.dataset.set_augProb(_aug_prob) 62 | # for efficiency, we only use around 10% of data to estimate the loss 63 | count, count_stop = ( 64 | 0, 65 | len(train_data_loader.dataset) // (train_data_loader.batch_size * 10) 66 | + 1, 67 | ) 68 | 69 | with torch.no_grad(): 70 | for step, (_, batch1, batch2) in enumerate(train_data_loader): 71 | batch1 = batch1.to(self.device) 72 | batch2 = batch2.to(self.device) 73 | 74 | x1 = self.model.forward_cl( 75 | batch1.x, batch1.edge_index, batch1.edge_attr, batch1.batch 76 | ) 77 | x2 = self.model.forward_cl( 78 | batch2.x, batch2.edge_index, batch2.edge_attr, batch2.batch 79 | ) 80 | loss = self.model.loss_cl(x1, x2) 81 | loss_aug[n] += loss.item() 82 | count += 1 83 | if count == count_stop: 84 | break 85 | loss_aug[n] /= count 86 | 87 | # view selection, projected gradient descent, 88 | # reference: https://arxiv.org/abs/1906.03563 89 | beta = 1 90 | gamma = self.config.gamma_joao 91 | 92 | b = aug_prob + beta * (loss_aug - gamma * (aug_prob - 1 / 25)) 93 | mu_min, mu_max = b.min() - 1 / 25, b.max() - 1 / 25 94 | mu = (mu_min + mu_max) / 2 95 | 96 | # bisection method 97 | while abs(np.maximum(b - mu, 0).sum() - 1) > 1e-2: 98 | if np.maximum(b - mu, 0).sum() > 1: 99 | mu_min = mu 100 | else: 101 | mu_max = mu 102 | mu = (mu_min + mu_max) / 2 103 | 104 | aug_prob = np.maximum(b - mu, 0) 105 | aug_prob /= aug_prob.sum() 106 | train_data_loader.dataset.set_augProb(aug_prob=aug_prob) 107 | self.logger.log_value_dict({"aug_prob": aug_prob}) 108 | return train_loss_accum / (step + 1) 109 | 110 | # TODO 
111 | def validate_model(self, val_data_loader: DataLoader) -> float: 112 | # self.logger.eval(num_batches=len(val_data_loader)) 113 | self.logger.eval(num_batches=1) 114 | self.logger(0.0, 0.0, 1) 115 | return 0.0 116 | -------------------------------------------------------------------------------- /src/pretrainers/joao_v2.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import DataLoader 4 | 5 | from config.training_config import TrainingConfig 6 | from logger import CombinedLogger 7 | from models.joao_v2 import JOAOv2Model 8 | from pretrainers.pretrainer import PreTrainer 9 | from util import get_lr 10 | 11 | 12 | class JOAOv2PreTrainer(PreTrainer): 13 | def __init__( 14 | self, 15 | config: TrainingConfig, 16 | model: JOAOv2Model, 17 | optimizer: torch.optim.Optimizer, 18 | device: torch.device, 19 | logger: CombinedLogger, 20 | ): 21 | super(JOAOv2PreTrainer, self).__init__( 22 | config=config, 23 | model=model, 24 | optimizer=optimizer, 25 | device=device, 26 | logger=logger, 27 | ) 28 | 29 | def train_for_one_epoch(self, train_data_loader: DataLoader) -> float: 30 | self.model.train() 31 | train_loss_accum = 0.0 32 | self.logger.train(num_batches=len(train_data_loader)) 33 | 34 | aug_prob = train_data_loader.dataset.aug_prob 35 | n_aug = np.random.choice(25, 1, p=aug_prob)[0] 36 | n_aug1, n_aug2 = n_aug // 5, n_aug % 5 37 | 38 | for step, (_, batch1, batch2) in enumerate(train_data_loader): 39 | batch1 = batch1.to(self.device) 40 | batch2 = batch2.to(self.device) 41 | 42 | x1 = self.model.forward_cl( 43 | batch1.x, batch1.edge_index, batch1.edge_attr, batch1.batch, n_aug1 44 | ) 45 | x2 = self.model.forward_cl( 46 | batch2.x, batch2.edge_index, batch2.edge_attr, batch2.batch, n_aug2 47 | ) 48 | loss = self.model.loss_cl(x1, x2) 49 | 50 | self.optimizer.zero_grad() 51 | loss.backward() 52 | self.optimizer.step() 53 | 54 | loss_float = 
float(loss.detach().cpu().item()) 55 | train_loss_accum += loss_float 56 | self.logger(loss_float, 0.0, batch1.num_graphs, get_lr(self.optimizer)) 57 | 58 | # joaov2 59 | aug_prob = train_data_loader.dataset.aug_prob 60 | # TODO: Can we replace the below constants (25, 10, etc.) with arguments in the config? 61 | loss_aug = np.zeros(25) 62 | for n in range(25): 63 | _aug_prob = np.zeros(25) 64 | _aug_prob[n] = 1 65 | train_data_loader.dataset.set_augProb(_aug_prob) 66 | 67 | count, count_stop = ( 68 | 0, 69 | len(train_data_loader.dataset) // (train_data_loader.batch_size * 10) 70 | + 1, 71 | ) 72 | # for efficiency, we only use around 10% of data to estimate the loss 73 | n_aug1, n_aug2 = n // 5, n % 5 74 | with torch.no_grad(): 75 | for step, (_, batch1, batch2) in enumerate(train_data_loader): 76 | batch1 = batch1.to(self.device) 77 | batch2 = batch2.to(self.device) 78 | 79 | x1 = self.model.forward_cl( 80 | batch1.x, 81 | batch1.edge_index, 82 | batch1.edge_attr, 83 | batch1.batch, 84 | n_aug1, 85 | ) 86 | x2 = self.model.forward_cl( 87 | batch2.x, 88 | batch2.edge_index, 89 | batch2.edge_attr, 90 | batch2.batch, 91 | n_aug2, 92 | ) 93 | loss = self.model.loss_cl(x1, x2) 94 | loss_aug[n] += loss.item() 95 | count += 1 96 | if count == count_stop: 97 | break 98 | loss_aug[n] /= count 99 | 100 | # view selection, projected gradient descent, 101 | # reference: https://arxiv.org/abs/1906.03563 102 | beta = 1 103 | gamma = self.config.gamma_joaov2 104 | 105 | b = aug_prob + beta * (loss_aug - gamma * (aug_prob - 1 / 25)) 106 | mu_min, mu_max = b.min() - 1 / 25, b.max() - 1 / 25 107 | mu = (mu_min + mu_max) / 2 108 | 109 | # bisection method 110 | while abs(np.maximum(b - mu, 0).sum() - 1) > 1e-2: 111 | if np.maximum(b - mu, 0).sum() > 1: 112 | mu_min = mu 113 | else: 114 | mu_max = mu 115 | mu = (mu_min + mu_max) / 2 116 | 117 | aug_prob = np.maximum(b - mu, 0) 118 | aug_prob /= aug_prob.sum() 119 | train_data_loader.dataset.set_augProb(aug_prob=aug_prob) 120 | 
self.logger.log_value_dict({"aug_prob": aug_prob}) 121 | return train_loss_accum / (step + 1) 122 | 123 | # TODO 124 | def validate_model(self, val_data_loader: DataLoader) -> float: 125 | # self.logger.eval(num_batches=len(val_data_loader)) 126 | self.logger.eval(num_batches=1) 127 | self.logger(0.0, 0.0, 1) 128 | return 0.0 129 | -------------------------------------------------------------------------------- /src/pretrainers/motif.py: -------------------------------------------------------------------------------- 1 | """ GROVER, Graph-Level Motif Prediction 2 | Ref Paper: https://arxiv.org/abs/2007.02835 3 | Ref Code: https://github.com/tencent-ailab/grover """ 4 | 5 | 6 | import torch 7 | from torch.utils.data import DataLoader 8 | 9 | from config.training_config import TrainingConfig 10 | from logger import CombinedLogger 11 | from models.motif import MotifModel 12 | from pretrainers.pretrainer import PreTrainer 13 | from util import get_lr 14 | 15 | 16 | class MotifPreTrainer(PreTrainer): 17 | def __init__( 18 | self, 19 | config: TrainingConfig, 20 | model: MotifModel, 21 | optimizer: torch.optim.Optimizer, 22 | device: torch.device, 23 | logger: CombinedLogger, 24 | ): 25 | super(MotifPreTrainer, self).__init__( 26 | config=config, 27 | model=model, 28 | optimizer=optimizer, 29 | device=device, 30 | logger=logger, 31 | ) 32 | 33 | def train_for_one_epoch(self, train_data_loader: DataLoader) -> float: 34 | self.model.train() 35 | train_loss_accum = 0.0 36 | self.logger.train(num_batches=len(train_data_loader)) 37 | 38 | for step, batch in enumerate(train_data_loader): 39 | batch = batch.to(self.device) 40 | pred = self.model.forward_cl( 41 | batch.x, batch.edge_index, batch.edge_attr, batch.batch 42 | ) 43 | y = batch.y.view(pred.shape) 44 | 45 | loss = self.model.loss_cl(pred.double(), y.double()) 46 | 47 | self.optimizer.zero_grad() 48 | loss.backward() 49 | self.optimizer.step() 50 | loss_float = float(loss.detach().cpu().item()) 51 | train_loss_accum += 
loss_float 52 | self.logger(loss_float, 0.0, batch.num_graphs, get_lr(self.optimizer)) 53 | 54 | return train_loss_accum / (step + 1) 55 | 56 | # TODO 57 | def validate_model(self, val_data_loader: DataLoader) -> float: 58 | # self.logger.eval(num_batches=len(val_data_loader)) 59 | self.logger.eval(num_batches=1) 60 | self.logger(0.0, 0.0, 1) 61 | return 0.0 62 | -------------------------------------------------------------------------------- /src/pretrainers/pretrainer.py: -------------------------------------------------------------------------------- 1 | import abc, torch, dataclasses 2 | 3 | from typing import Union, List 4 | from logger import CombinedLogger 5 | from torch.utils.data import DataLoader, Dataset 6 | from config.training_config import TrainingConfig 7 | 8 | 9 | @dataclasses.dataclass 10 | class PreTrainer(abc.ABC): 11 | config: TrainingConfig 12 | model: torch.nn.Module 13 | # optimizer: torch.optim.Optimizer 14 | optimizer: Union[torch.optim.Optimizer, List] 15 | device: torch.device 16 | logger: CombinedLogger 17 | 18 | @abc.abstractmethod 19 | def train_for_one_epoch(self, train_data_loader: Union[DataLoader, Dataset]): 20 | pass 21 | 22 | @abc.abstractmethod 23 | def validate_model(self, val_data_loader: Union[DataLoader, Dataset]): 24 | pass 25 | -------------------------------------------------------------------------------- /src/pretrainers/rgcl.py: -------------------------------------------------------------------------------- 1 | # Ref: {GitHub}/lsh0520/RGCL/transferLearning/chem/pretrain_rgcl.py 2 | # Ref: https://arxiv.org/abs/2010.13902 3 | # TODO: TO UPDATE 4 | """ GRAPH SSL Pre-Training via InfoGraph [InfoGraph] 5 | i.e., maps nodes in similar structural contexts to closer embeddings 6 | Ref Paper: Sec. 
5.2 and Appendix G of 7 | https://arxiv.org/abs/1905.12265 ; 8 | which is adapted from 9 | https://arxiv.org/abs/1809.10341 ; 10 | Ref Code: ${GitHub_Repo}/chem/pretrain_deepgraphinfomax.py """ 11 | 12 | import gc, torch 13 | from config.training_config import TrainingConfig 14 | from pretrainers.pretrainer import PreTrainer 15 | from torch_geometric.loader import DataLoader 16 | 17 | # from torch.utils.data import DataLoader 18 | from logger import CombinedLogger 19 | from models.rgcl import RGCLModel 20 | from copy import deepcopy 21 | from util import get_lr 22 | 23 | 24 | class RGCLPreTrainer(PreTrainer): 25 | def __init__( 26 | self, 27 | config: TrainingConfig, 28 | model: RGCLModel, 29 | optimizer: torch.optim.Optimizer, 30 | device: torch.device, 31 | logger: CombinedLogger, 32 | ): 33 | super(RGCLPreTrainer, self).__init__( 34 | config=config, 35 | model=model, 36 | optimizer=optimizer, 37 | device=device, 38 | logger=logger, 39 | ) 40 | 41 | # def train_for_one_epoch(self, train_data_loader: DataLoader) -> float: 42 | def train_for_one_epoch(self, dataset) -> float: 43 | dataset.aug = "none" 44 | loader = DataLoader(dataset, batch_size=2048, num_workers=4, shuffle=False) 45 | self.model.eval() 46 | torch.set_grad_enabled(False) 47 | 48 | for step, batch in enumerate(loader): 49 | node_index_start = step * 2048 50 | node_index_end = min(node_index_start + 2048 - 1, len(dataset) - 1) 51 | batch = batch.to(self.device) 52 | node_imp = self.model.node_imp_estimator( 53 | batch.x, batch.edge_index, batch.edge_attr, batch.batch 54 | ).detach() 55 | dataset.node_score[ 56 | dataset.slices["x"][node_index_start] : dataset.slices["x"][ 57 | node_index_end + 1 58 | ] 59 | ] = torch.squeeze(node_imp.half()) 60 | 61 | dataset1 = deepcopy(dataset) 62 | dataset1 = dataset1.shuffle() 63 | dataset2 = deepcopy(dataset1) 64 | dataset3 = deepcopy(dataset1) 65 | 66 | dataset1.aug, dataset1.aug_ratio = "dropN", 0.2 67 | dataset2.aug, dataset2.aug_ratio = "dropN", 0.2 68 | 
dataset3.aug, dataset3.aug_ratio = "dropN" + "_cp", 0.2 69 | 70 | loader1 = DataLoader( 71 | dataset1, 72 | batch_size=self.config.batch_size, 73 | num_workers=self.config.num_workers, 74 | shuffle=False, 75 | ) 76 | loader2 = DataLoader( 77 | dataset2, 78 | batch_size=self.config.batch_size, 79 | num_workers=self.config.num_workers, 80 | shuffle=False, 81 | ) 82 | loader3 = DataLoader( 83 | dataset3, 84 | batch_size=self.config.batch_size, 85 | num_workers=self.config.num_workers, 86 | shuffle=False, 87 | ) 88 | train_loss_accum = 0 89 | ra_loss_accum = 0 90 | cp_loss_accum = 0 91 | 92 | torch.set_grad_enabled(True) 93 | self.model.train() 94 | self.logger.train(num_batches=len(loader1)) 95 | 96 | for step, batch in enumerate(zip(loader1, loader2, loader3)): 97 | batch1, batch2, batch3 = batch 98 | batch1 = batch1.to(self.device) 99 | batch2 = batch2.to(self.device) 100 | batch3 = batch3.to(self.device) 101 | 102 | self.optimizer.zero_grad() 103 | 104 | x1 = self.model.forward_cl( 105 | batch1.x, batch1.edge_index, batch1.edge_attr, batch1.batch 106 | ) 107 | x2 = self.model.forward_cl( 108 | batch2.x, batch2.edge_index, batch2.edge_attr, batch2.batch 109 | ) 110 | x3 = self.model.forward_cl( 111 | batch3.x, batch3.edge_index, batch3.edge_attr, batch3.batch 112 | ) 113 | 114 | ra_loss, cp_loss, loss = self.model.loss_ra(x1, x2, x3) 115 | 116 | loss.backward() 117 | self.optimizer.step() 118 | loss_float = float(loss.detach().cpu().item()) 119 | train_loss_accum += loss_float 120 | ra_loss_accum += float(ra_loss.detach().cpu().item()) 121 | cp_loss_accum += float(cp_loss.detach().cpu().item()) 122 | self.logger(loss_float, 0.0, batch1.num_graphs, get_lr(self.optimizer)) 123 | # del dataset1, dataset2, dataset3 124 | # gc.collect() 125 | 126 | # return train_loss_accum/(step+1), ra_loss_accum/(step+1), cp_loss_accum/(step+1) 127 | return train_loss_accum / (step + 1) 128 | 129 | def validate_model(self, val_data_loader) -> float: 130 | # 
self.logger.eval(num_batches=len(val_data_loader)) 131 | self.logger.eval(num_batches=1) 132 | self.logger(0.0, 0.0, 1) 133 | return 0.0 134 | -------------------------------------------------------------------------------- /src/run_embedding_extraction.py: -------------------------------------------------------------------------------- 1 | from load_save import infer_and_save_embeddings, load_checkpoint 2 | from config.validation_config import parse_config 3 | from init import ( 4 | get_data_loader_val, 5 | get_dataset_extraction, 6 | get_dataset_split, 7 | get_device, 8 | get_model, 9 | get_smiles_list, 10 | init, 11 | ) 12 | 13 | 14 | def main() -> None: 15 | """=== Load the PreTrained Weights ===""" 16 | config = parse_config() 17 | init(config=config) 18 | device = get_device(config=config) 19 | 20 | """ === Set the Datasets and Loader === """ 21 | dataset = get_dataset_extraction(config=config) 22 | smiles_list = get_smiles_list(config=config) 23 | dataset_splits = get_dataset_split( 24 | config=config, dataset=dataset, smiles_list=smiles_list 25 | ) 26 | datasets_list = [dataset for dataset in dataset_splits.values()] 27 | # a list of MoleculeDataset class: ['train', 'val', 'test', 'smiles'] 28 | loaders = [ 29 | get_data_loader_val(config=config, dataset=dataset, shuffle=False) 30 | for dataset in datasets_list 31 | ] 32 | # a tuple of list of smiles: ('train', 'val', 'test') 33 | smile_splits = dataset_splits["smiles"] 34 | 35 | pre_trained_model = get_model(config=config).to(device) 36 | checkpoint = load_checkpoint(config=config, device=device) 37 | pre_trained_model.load_state_dict(checkpoint["model"]) 38 | 39 | """ === Save Node & Graph Embeddings === """ 40 | infer_and_save_embeddings( 41 | config=config, 42 | model=pre_trained_model, 43 | device=device, 44 | datasets=datasets_list, 45 | loaders=loaders, 46 | smile_splits=smile_splits, 47 | ) 48 | 49 | 50 | if __name__ == "__main__": 51 | main() 52 | 
def pretrain_model(
    config: TrainingConfig,
    dataset: MoleculeDataset,
    model: torch.nn.Module,
    device: torch.device,
) -> None:
    """=== Generic Pre-Training Wrapper ===

    Runs ``config.epochs`` epochs of the configured pre-trainer and
    periodically checkpoints the model.

    Args:
        config: training configuration (epochs, save cadence, pretrainer name).
        dataset: the molecule dataset to pre-train on.
        model: the GNN to pre-train (already moved to ``device`` by the caller).
        device: target device; passed through to the pre-trainer.
    """
    logger = CombinedLogger(config=config)
    optimizer = get_optimizer(config=config, model=model)
    train_data_loader = get_data_loader(config=config, dataset=dataset)
    pre_trainer = get_pretrainer(
        config=config,
        model=model,
        optimizer=optimizer,
        device=device,
        logger=logger,
    )

    # FIX: bind `epoch` before the loop so the final save below does not
    # raise NameError when config.epochs == 0.
    epoch = 0
    for epoch in range(config.epochs):
        # The epoch-0 snapshot is the randomly initialised weights.
        if epoch % config.epochs_save == 0 and config.save_model:
            save_model(config=config, model=pre_trainer.model, epoch=epoch)
        # RGCL re-derives its rationales from the raw dataset each epoch,
        # so it consumes the dataset instead of the prepared loader.
        if config.pretrainer == "RGCL":
            pre_trainer.train_for_one_epoch(dataset)
        else:
            pre_trainer.train_for_one_epoch(train_data_loader)

    # Final checkpoint, tagged with the last completed epoch index.
    if config.save_model:
        save_model(config=config, model=pre_trainer.model, epoch=epoch)
class ProberDataset(Dataset):
    """Holds pre-computed (representation, label) pairs for probing tasks."""

    def __init__(self, representations, labels):
        # Every representation must have a matching label.
        assert len(representations) == len(labels)
        # Width of one representation vector (e.g. 300 for the GNN embeddings).
        self.input_dim = len(representations[0])
        self.representations = representations
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Dict form keeps batches self-describing for the probe loop.
        return dict(
            representation=self.representations[idx],
            label=self.labels[idx],
        )
class ProberTask(TrainValTestTask):
    """Fits a probe head on frozen representations and reports split losses.

    NOTE(review): ``super().__init__`` receives ``config.criterion_type``
    while the criterion branch below uses the ``criterion_type`` argument;
    callers should keep the two consistent — confirm intended.
    """

    def __init__(
        self,
        config: ValidationConfig,
        model: nn.Module,
        device: torch.device,
        optimizer: torch.optim.Optimizer,
        train_dataset: ProberDataset,
        val_dataset: ProberDataset,
        test_dataset: ProberDataset,
        criterion_type: str = "mse",
    ):
        super(ProberTask, self).__init__(
            config=config,
            model=model,
            device=device,
            optimizer=optimizer,
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            test_dataset=test_dataset,
            criterion_type=config.criterion_type,
        )
        # Only the training split is shuffled.
        self.train_loader = DataLoader(
            self.train_dataset, batch_size=config.batch_size, shuffle=True
        )
        self.val_loader = DataLoader(
            self.val_dataset, batch_size=config.batch_size, shuffle=False
        )
        self.test_loader = DataLoader(
            self.test_dataset, batch_size=config.batch_size, shuffle=False
        )

        if criterion_type == "mse":
            self.criterion = nn.MSELoss(reduction="mean")
        elif criterion_type == "bce":
            self.criterion = nn.BCEWithLogitsLoss(reduction="mean")
        elif criterion_type == "ce":
            self.criterion = nn.CrossEntropyLoss(reduction="mean")
        else:
            raise Exception("Unknown criterion {}".format(criterion_type))

        self.train_loss = []
        self.train_score, self.eval_score, self.test_score = [], [], []

    def run(self, model: nn.Module = None, device: torch.device = None) -> Dict:
        """Train the probe, then evaluate on all three splits.

        ``model`` and ``device`` are unused; they exist for interface
        compatibility with sibling tasks.
        """
        final_train_loss = self.train()
        final_val_loss = self.eval_val_dataset()
        final_test_loss = self.eval_test_dataset()
        results_dict = {
            "train_loss": final_train_loss,
            "val_loss": final_val_loss,
            "test_loss": final_test_loss,
        }
        print(results_dict)
        return results_dict

    def train(self) -> float:
        """Run ``config.epochs`` epochs; return the last epoch's mean loss."""
        for _ in tqdm(range(self.config.epochs)):
            train_loss = self.train_step()
        return train_loss

    def train_step(self) -> float:
        """One optimisation pass over the training loader; returns mean loss."""
        self.model.train()
        total_loss = 0

        for _, batch in enumerate(self.train_loader):
            pred = self.model(batch["representation"].to(self.device)).squeeze()
            y = batch["label"].to(torch.float32).to(self.device)
            loss = self.criterion(pred, y)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            total_loss += loss.detach().item()

        return total_loss / len(self.train_loader)

    def _eval(self, loader: DataLoader) -> float:
        """Sum of per-batch criterion losses over ``loader`` (not averaged)."""
        self.model.eval()
        total_val_loss = 0.0
        for _, batch in enumerate(loader):
            with torch.no_grad():
                inputs = batch["representation"].to(self.device)
                pred = self.model(inputs).squeeze()
                true_target = batch["label"].to(torch.float32).to(self.device)
                val_loss = self.criterion(pred, true_target).item()
                total_val_loss += val_loss
        return total_val_loss

    def _eval_roc(self, loader: DataLoader) -> float:
        """Mean per-target ROC-AUC, skipping targets lacking both classes.

        Labels follow the chem-benchmark convention: +1 positive,
        -1 negative, 0 missing.
        """
        # FIX: `np` and `roc_auc_score` were referenced without being
        # imported anywhere in this module, so this method raised
        # NameError; import them locally.
        import numpy as np
        from sklearn.metrics import roc_auc_score

        self.model.eval()
        y_true, y_pred = [], []

        for _, batch in enumerate(loader):
            with torch.no_grad():
                if self.config.val_task == "prober":
                    inputs = batch["representation"].to(self.device)
                    pred = self.model(inputs)
                    true = batch["label"].to(torch.float32).to(self.device)
                elif self.config.val_task == "finetune":
                    batch = batch.to(self.device)
                    pred = self.model(
                        batch.x, batch.edge_index, batch.edge_attr, batch.batch
                    )
                    true = batch.y.view(pred.shape)

                y_true.append(true)
                y_pred.append(pred)

        y_true = torch.cat(y_true, dim=0).cpu().numpy()
        y_pred = torch.cat(y_pred, dim=0).cpu().numpy()

        roc_list = []
        for i in range(y_true.shape[1]):
            # AUC is defined only when this target has at least one
            # positive (+1) and one negative (-1) example.
            if np.sum(y_true[:, i] == 1) > 0 and np.sum(y_true[:, i] == -1) > 0:
                is_valid = y_true[:, i] ** 2 > 0  # mask out missing (0) labels
                roc_list.append(
                    roc_auc_score((y_true[is_valid, i] + 1) / 2, y_pred[is_valid, i])
                )

        if len(roc_list) < y_true.shape[1]:
            print("Some target is missing!")
            print("Missing ratio: %f" % (1 - float(len(roc_list)) / y_true.shape[1]))

        return sum(roc_list) / len(roc_list)

    def eval_train_dataset(self) -> float:
        return self._eval(loader=self.train_loader)

    def eval_val_dataset(self) -> float:
        return self._eval(loader=self.val_loader)

    def eval_test_dataset(self) -> float:
        return self._eval(loader=self.test_loader)
class GraphDiameterDataset(GraphLevelMetric):
    """Probe target: 2D graph diameter (length of the longest shortest path)."""

    def __init__(self, representation_path, config, split, **kwargs):
        super().__init__(representation_path, config, split, **kwargs)

    @staticmethod
    def graph_level_metric(graph, method="largest") -> int:
        g = graph_data_obj_to_nx_simple(graph)
        if nx.is_connected(g):
            return diameter(g)
        if method == "largest":
            # Disconnected molecule: report the largest component's diameter.
            return max(
                diameter(g.subgraph(component).copy())
                for component in nx.connected_components(g)
            )
        # Unknown handling method: None labels are filtered out upstream.
        return None
class AverageClusteringCoefficientDataset(GraphLevelMetric):
    """Probe target: approximate average clustering coefficient of a graph.

    Uses networkx's Monte-Carlo approximation: repeatedly pick a random
    node and two of its neighbours and check whether they are connected;
    the coefficient is the fraction of closed triangles observed.
    """

    def __init__(self, representation_path, config, split, **kwargs):
        super().__init__(representation_path, config, split, **kwargs)

    @staticmethod
    def graph_level_metric(graph) -> float:
        return nx.algorithms.approximation.average_clustering(
            G=graph_data_obj_to_nx_simple(graph)
        )
By Menger's theorem, 95 | this is equal to the number of node independent paths (paths that 96 | share no nodes other than source and target). 97 | 98 | This algorithm is based on a fast approximation that gives an strict lower 99 | bound on the actual number of node independent paths between two nodes [1]_. 100 | It works for both directed and undirected graphs. 101 | 102 | References 103 | ---------- 104 | .. [1] White, Douglas R., and Mark Newman. 2001 A Fast Algorithm for 105 | Node-Independent Paths. Santa Fe Institute Working Paper #01-07-035 106 | http://eclectic.ss.uci.edu/~drwhite/working.pdf""" 107 | 108 | def __init__(self, representation_path, config, split, **kwargs): 109 | super(NodeConnectivityDataset, self).__init__( 110 | representation_path, config, split, **kwargs 111 | ) 112 | 113 | @staticmethod 114 | def graph_level_metric(graph) -> int: 115 | graph_nx = graph_data_obj_to_nx_simple(graph) 116 | return nx.algorithms.approximation.node_connectivity(G=graph_nx) 117 | 118 | 119 | # class BenzeneRingDataset(GraphLevelMetric): 120 | # """Calculate the number of Benzene rings.""" 121 | # 122 | # def __init__(self, representation_path, config, split, **kwargs): 123 | # super(BenzeneRingDataset, self).__init__( 124 | # representation_path, config, split, **kwargs) 125 | # 126 | # @staticmethod 127 | # def graph_level_metric(graph) -> int: 128 | # mol = graph_data_obj_to_mol_simple( 129 | # graph.x, graph.edge_index, graph.edge_attr) 130 | # return Fragments.fr_benzene(mol) 131 | 132 | 133 | RDKIT_fragments_valid = [ 134 | "fr_epoxide", 135 | "fr_lactam", 136 | "fr_morpholine", 137 | "fr_oxazole", 138 | "fr_tetrazole", 139 | "fr_N_O", 140 | "fr_ether", 141 | "fr_furan", 142 | "fr_guanido", 143 | "fr_halogen", 144 | "fr_morpholine", 145 | "fr_piperdine", 146 | "fr_thiazole", 147 | "fr_thiophene", 148 | "fr_urea", 149 | "fr_allylic_oxid", 150 | "fr_amide", 151 | "fr_amidine", 152 | "fr_azo", 153 | "fr_benzene", 154 | "fr_imidazole", 155 | "fr_imide", 156 
| "fr_piperzine", 157 | "fr_pyridine", 158 | ] 159 | 160 | 161 | class RDKiTFragmentDataset(GraphLevelMetric): 162 | """Calculate the number of Benzene rings.""" 163 | 164 | def __init__(self, representation_path, config, split, des): 165 | super(RDKiTFragmentDataset, self).__init__( 166 | representation_path, config, split, des=des 167 | ) 168 | 169 | # @staticmethod 170 | def graph_level_metric(self, graph): 171 | assert self.des in RDKIT_fragments_valid 172 | mol = graph_data_obj_to_mol_simple(graph.x, graph.edge_index, graph.edge_attr) 173 | cmd = compile("Fragments.%s(mol)" % self.des, "", "eval") 174 | try: 175 | return eval(cmd) 176 | except: 177 | return 178 | 179 | 180 | class DownstreamDataset(GraphLevelMetric): 181 | def __init__(self, representation_path, config, split, des): 182 | super(DownstreamDataset, self).__init__( 183 | representation_path, config, split, des=des 184 | ) 185 | 186 | @staticmethod 187 | def graph_level_metric(graph): 188 | labels = graph.y.cpu().numpy().ravel() 189 | return labels 190 | # if ((labels ** 2) != 1).any(): 191 | # return None 192 | # return (labels + 1)/2 193 | -------------------------------------------------------------------------------- /src/validation/task/metrics.py: -------------------------------------------------------------------------------- 1 | # import pdb 2 | import pickle as pkl 3 | import random 4 | from abc import ABC, abstractmethod 5 | 6 | import numpy as np 7 | from tqdm import tqdm 8 | 9 | from datasets import graph_data_obj_to_nx_simple 10 | from validation.dataset import ProberDataset 11 | from validation.utils import get_dataset_extraction, get_dataset_split, get_smiles_list 12 | 13 | 14 | class NodeLevelMetric(ABC): 15 | def __init__(self, representation_path, config, split, **kwargs): 16 | self.representation_path = representation_path 17 | self.__dict__.update(kwargs) 18 | self.config = config 19 | with open(representation_path, "rb") as f: 20 | ( 21 | _, 22 | self.node_repr_list, 23 | 
self.smiles, 24 | ) = pkl.load(f) 25 | self.num_graph = len(self.smiles) 26 | 27 | dataset = get_dataset_extraction(config=config) 28 | smiles_list = get_smiles_list(config=config) 29 | dataset_splits = get_dataset_split( 30 | config=config, dataset=dataset, smiles_list=smiles_list 31 | ) 32 | # "train", "val", "test" 33 | self.dataset = dataset_splits[split] 34 | 35 | def create_datasets(self): 36 | labels = [] 37 | representations = [] 38 | for graph_idx in range(self.num_graph): 39 | graph = self.dataset[graph_idx] 40 | node_repr = self.node_repr_list[graph_idx] 41 | metrics = self.node_level_metric(graph) 42 | assert len(metrics) == len(node_repr) 43 | for node_idx in range(len(metrics)): 44 | labels.append(metrics[node_idx]) 45 | representations.append(node_repr[node_idx]) 46 | return ProberDataset(representations, labels) 47 | 48 | @abstractmethod 49 | def node_level_metric(self): 50 | pass 51 | 52 | 53 | class GraphLevelMetric(ABC): 54 | def __init__(self, representation_path, config, split, **kwargs): 55 | self.representation_path = representation_path 56 | self.__dict__.update(kwargs) 57 | self.config = config 58 | with open(representation_path, "rb") as f: 59 | (self.graph_repr_list, _, self.smiles) = pkl.load(f) 60 | self.num_graph = len(self.smiles) 61 | dataset = get_dataset_extraction(config=config) 62 | smiles_list = get_smiles_list(config=config) 63 | dataset_splits = get_dataset_split( 64 | config=config, dataset=dataset, smiles_list=smiles_list 65 | ) 66 | # "train", "validation", "test" 67 | self.dataset = dataset_splits[split] 68 | 69 | def create_datasets(self): 70 | labels = [] 71 | representations = [] 72 | for graph_idx in range(self.num_graph): 73 | graph = self.dataset[graph_idx] 74 | graph_repr = self.graph_repr_list[graph_idx] 75 | metrics = self.graph_level_metric(graph) 76 | if metrics is not None: 77 | labels.append(metrics) 78 | representations.append(graph_repr) 79 | return ProberDataset(representations, labels) 80 | 81 | 
class NodePairMetric(ABC):
    """Base class for probes whose label is defined on a pair of nodes.

    Loads node representations from the extraction pickle and pairs them
    with labels computed by the subclass's ``node_pair_metric``.
    """

    def __init__(self, representation_path, config, split, **kwargs):
        # representation_path: pickle of (graph_reprs, node_reprs, smiles);
        # only the node representations and smiles are used here.
        self.representation_path = representation_path
        self.__dict__.update(kwargs)
        self.config = config
        with open(representation_path, "rb") as f:
            (
                _,
                self.node_repr_list,
                self.smiles,
            ) = pkl.load(f)
        self.num_graph = len(self.smiles)

        dataset = get_dataset_extraction(config=config)
        smiles_list = get_smiles_list(config=config)
        dataset_splits = get_dataset_split(
            config=config, dataset=dataset, smiles_list=smiles_list
        )
        # "train", "validation", "test"
        self.dataset = dataset_splits[split]

    def create_datasets(self, num_pairs=30):
        """Sample up to ``num_pairs`` random node pairs per graph.

        For each sampled pair the probe input is the concatenation
        [h_u, h_v, h_u * h_v]; the label comes from ``node_pair_metric``.
        NOTE: pairs where both draws hit the same node are silently
        dropped, so each graph contributes at most ``num_pairs`` pairs.
        """
        labels = []
        representations = []
        for graph_idx in range(self.num_graph):
            graph = self.dataset[graph_idx]
            node_repr = self.node_repr_list[graph_idx]
            num_nodes = len(node_repr)
            graph_nx = graph_data_obj_to_nx_simple(graph)
            for _ in range(num_pairs):
                start_node = random.choice(list(range(num_nodes)))
                end_node = random.choice(list(range(num_nodes)))
                if start_node != end_node:
                    representations.append(
                        np.concatenate(
                            [
                                node_repr[start_node],
                                node_repr[end_node],
                                np.multiply(node_repr[start_node], node_repr[end_node]),
                            ]
                        )
                    )
                    labels.append(self.node_pair_metric(graph_nx, start_node, end_node))
        return ProberDataset(representations, labels)

    @abstractmethod
    def node_pair_metric(self, graph_nx, start_node, end_node):
        # Subclasses return the scalar label for the (start, end) pair.
        pass
145 | self.smiles, 146 | ) = pkl.load(f) 147 | self.num_graph = len(self.smiles) 148 | 149 | def create_datasets(self, num_pairs=10000): 150 | def take_out_graph(dataset, idx): 151 | graph = dataset[idx] 152 | if self.config.pretrainer == "GraphCL" and isinstance(graph, tuple): 153 | graph = graph[0] # remove the contrastive augmented data. 154 | graph_nx = graph_data_obj_to_nx_simple(graph) 155 | return graph_nx 156 | 157 | labels = [] 158 | representations = [] 159 | for _ in tqdm(range(num_pairs)): 160 | graph_one_idx = random.choice(list(range(self.num_graph))) 161 | graph_two_idx = random.choice(list(range(self.num_graph))) 162 | if graph_one_idx != graph_two_idx: 163 | graph_nx_one = take_out_graph(self.dataset, graph_one_idx) 164 | graph_nx_two = take_out_graph(self.dataset, graph_two_idx) 165 | graph_repr_one = self.graph_repr_list[graph_one_idx] 166 | graph_repr_two = self.graph_repr_list[graph_two_idx] 167 | 168 | representations.append( 169 | np.concatenate( 170 | [ 171 | graph_repr_one, 172 | graph_repr_two, 173 | np.multiply(graph_repr_one, graph_repr_two), 174 | ] 175 | ) 176 | ) 177 | labels.append(self.graph_pair_metric(graph_nx_one, graph_nx_two)) 178 | return ProberDataset(representations, labels) 179 | 180 | @abstractmethod 181 | def graph_pair_metric(self, graph_nx_one, graph_nx_two): 182 | pass 183 | 184 | 185 | class FineTune_Metric(ABC): 186 | def __init__(self, representation_path, config, split, **kwargs): 187 | self.representation_path = representation_path 188 | self.__dict__.update(kwargs) 189 | self.config = config 190 | with open(representation_path, "rb") as f: 191 | _, _, self.smiles = pkl.load(f) 192 | self.num_graph = len(self.smiles) 193 | 194 | dataset = get_dataset_extraction(config=config) 195 | smiles_list = get_smiles_list(config=config) 196 | dataset_splits = get_dataset_split( 197 | config=config, dataset=dataset, smiles_list=smiles_list 198 | ) 199 | # "train", "val", "test" 200 | self.dataset = dataset_splits[split] 201 | 
class NodeCentralityDataset(NodeLevelMetric):
    """Probe target: eigenvector centrality of every node in a graph."""

    def __init__(self, representation_path, config, split, **kwargs):
        super().__init__(representation_path, config, split, **kwargs)

    @staticmethod
    def node_level_metric(graph):
        graph_nx = graph_data_obj_to_nx_simple(graph)
        try:
            centrality = nx.eigenvector_centrality(graph_nx, max_iter=1500)
        except nx.exception.PowerIterationFailedConvergence:
            # Power iteration can fail to converge on some graphs; fall back
            # to an all-zero label vector instead of aborting extraction.
            print("PowerIterationFailedConvergence")
            return [0] * len(graph_nx.nodes())

        # Sort by node id so labels align with the node representation order.
        return [score for _, score in sorted(centrality.items())]
class LinkPredictionDataset(NodePairMetric):
    """Probe target: whether two sampled nodes share an edge (1) or not (0)."""

    def __init__(self, representation_path, config, split, **kwargs):
        super().__init__(representation_path, config, split, **kwargs)

    def node_pair_metric(self, graph_nx, start_node, end_node):
        # bool -> {0, 1} label for the binary probe.
        return int(graph_nx.has_edge(start_node, end_node))
21 | To compute the Katz index we simply count the number of paths 22 | of all lengths between a pair of nodes.""" 23 | 24 | def __init__(self, representation_path, config, split, **kwargs): 25 | super(KatzIndexDataset, self).__init__( 26 | representation_path, config, split, **kwargs 27 | ) 28 | 29 | def node_pair_metric( 30 | self, graph_nx: nx.Graph, start_node: int, end_node: int 31 | ) -> float: 32 | alpha = 0.3 33 | I = np.identity(len(graph_nx.nodes)) 34 | katz_matrix = np.linalg.inv(I - nx.to_numpy_array(graph_nx) * alpha) - I 35 | return katz_matrix[start_node][end_node] 36 | 37 | 38 | class JaccardCoefficientDataset(NodePairMetric): 39 | r"""Compute the Jaccard coefficient of all node pairs. 40 | 41 | Jaccard coefficient of nodes `u` and `v` is defined as 42 | .. math:: 43 | \frac{|\Gamma(u) \cap \Gamma(v)|}{|\Gamma(u) \cup \Gamma(v)|} 44 | where $\Gamma(u)$ denotes the set of neighbors of $u$. 45 | 46 | References 47 | ---------- 48 | .. [1] D. Liben-Nowell, J. Kleinberg. 49 | The Link Prediction Problem for Social Networks (2004). 
50 | http://www.cs.cornell.edu/home/kleinber/link-pred.pdf 51 | """ 52 | 53 | def __init__(self, representation_path, config, split, **kwargs): 54 | super(JaccardCoefficientDataset, self).__init__( 55 | representation_path, config, split, **kwargs 56 | ) 57 | 58 | def node_pair_metric(self, graph_nx, start_node, end_node): 59 | preds = nx.jaccard_coefficient(graph_nx, [(start_node, end_node)]) 60 | for u, v, p in preds: 61 | return p 62 | 63 | 64 | # class GraphEditDistanceDataset(GraphPairMetric): 65 | # """ Compare the distance (minimal edits) between two graphs""" 66 | # def __init__(self, representation_path, config): 67 | # super(GraphEditDistanceDataset, self).__init__( 68 | # representation_path, config) 69 | # 70 | # def graph_pair_metric( 71 | # self, graph_nx_one: nx.Graph, graph_nx_two: nx.Graph) -> float: 72 | # ged = gm.GraphEditDistance(1, 1, 1, 1) # all edit costs equal to 1 73 | # result = ged.compare([graph_nx_one, graph_nx_two], None) 74 | # return result[0][1] / 100 75 | -------------------------------------------------------------------------------- /src/validation/task/prober_task.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Union 2 | import wandb 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from sklearn.metrics import roc_auc_score 7 | from torch.utils.data import DataLoader 8 | from tqdm import tqdm 9 | 10 | from config.validation_config import ValidationConfig 11 | 12 | # from logger import CombinedLogger 13 | from util import get_lr 14 | from datasets import MoleculeDataset 15 | from validation.dataset import ProberDataset 16 | from validation.task import TrainValTestTask 17 | 18 | 19 | class ProberTask(TrainValTestTask): 20 | def __init__( 21 | self, 22 | config: ValidationConfig, 23 | model: nn.Module, 24 | device: torch.device, 25 | optimizer: torch.optim.Optimizer, 26 | # logger: CombinedLogger, 27 | train_dataset: Union[MoleculeDataset, 
ProberDataset], 28 | val_dataset: Union[MoleculeDataset, ProberDataset], 29 | test_dataset: Union[MoleculeDataset, ProberDataset], 30 | criterion_type: str = "mse", 31 | ): 32 | super(ProberTask, self).__init__( 33 | config=config, 34 | model=model, 35 | device=device, 36 | optimizer=optimizer, 37 | # logger=logger, 38 | train_dataset=train_dataset, 39 | val_dataset=val_dataset, 40 | test_dataset=test_dataset, 41 | criterion_type=config.criterion_type, 42 | ) 43 | if config.val_task == "finetune": 44 | from torch_geometric.loader import DataLoader 45 | else: 46 | from torch.utils.data import DataLoader 47 | self.train_loader = DataLoader( 48 | self.train_dataset, batch_size=config.batch_size, shuffle=True 49 | ) 50 | self.val_loader = DataLoader( 51 | self.val_dataset, batch_size=config.batch_size, shuffle=False 52 | ) 53 | self.test_loader = DataLoader( 54 | self.test_dataset, batch_size=config.batch_size, shuffle=False 55 | ) 56 | 57 | if criterion_type == "mse": 58 | self.criterion = nn.MSELoss(reduction="mean") 59 | elif criterion_type == "bce": 60 | self.criterion = nn.BCEWithLogitsLoss(reduction="mean") 61 | elif criterion_type == "ce": 62 | self.criterion = nn.CrossEntropyLoss(reduction="mean") 63 | else: 64 | raise Exception("Unknown criterion {}".format(criterion_type)) 65 | 66 | def train(self) -> float: 67 | pass 68 | 69 | def run(self): 70 | for _ in tqdm(range(self.config.epochs)): 71 | if self.config.val_task == "finetune": 72 | train_loss = self.train_finetune_step() 73 | else: 74 | train_loss = self.train_step() 75 | 76 | if self.config.probe_task == "downstream": 77 | # train_score = self._eval_roc(loader=self.train_loader) 78 | val_score = self._eval_roc(loader=self.val_loader) 79 | test_score = self._eval_roc(loader=self.test_loader) 80 | else: 81 | val_score = self.eval_val_dataset() 82 | test_score = self.eval_test_dataset() 83 | 84 | results_dict = { 85 | "train_loss": train_loss, 86 | "val_score": val_score, 87 | "test_score": test_score, 88 
| } 89 | wandb.log(results_dict) 90 | 91 | print(results_dict) 92 | print("\n\n\n") 93 | return 94 | 95 | def train_step(self) -> float: 96 | self.model.train() 97 | total_loss = 0 98 | 99 | for _, batch in enumerate(self.train_loader): 100 | if self.config.probe_task == "downstream": 101 | pred = self.model(batch["representation"].to(self.device)) 102 | y = batch["label"].to(torch.float32).to(self.device) 103 | 104 | # Whether y is non-null or not. 105 | is_valid = y**2 > 0 106 | # Loss matrix 107 | loss_mat = self.criterion(pred.double(), (y + 1) / 2) 108 | # loss matrix after removing null target 109 | loss_mat = torch.where( 110 | is_valid, 111 | loss_mat, 112 | torch.zeros(loss_mat.shape).to(self.device).to(loss_mat.dtype), 113 | ) 114 | 115 | self.optimizer.zero_grad() 116 | loss = torch.sum(loss_mat) / torch.sum(is_valid) 117 | loss.backward() 118 | self.optimizer.step() 119 | total_loss += loss.detach().item() 120 | 121 | else: 122 | pred = self.model(batch["representation"].to(self.device)).squeeze() 123 | y = batch["label"].to(torch.float32).to(self.device) 124 | loss = self.criterion(pred, y) 125 | self.optimizer.zero_grad() 126 | loss.backward() 127 | self.optimizer.step() 128 | total_loss += loss.detach().item() 129 | 130 | return total_loss / len(self.train_loader) 131 | 132 | def train_finetune_step(self) -> float: 133 | self.model.train() 134 | total_loss = 0 135 | 136 | if self.config.probe_task != "downstream": 137 | raise NotImplementedError 138 | 139 | for _, batch in enumerate(self.train_loader): 140 | batch = batch.to(self.device) 141 | pred = self.model(batch.x, batch.edge_index, batch.edge_attr, batch.batch) 142 | y = batch.y.view(pred.shape).to(torch.float64) 143 | 144 | is_valid = y**2 > 0 145 | loss_mat = self.criterion(pred.double(), (y + 1) / 2) 146 | loss_mat = torch.where( 147 | is_valid, 148 | loss_mat, 149 | torch.zeros(loss_mat.shape).to(self.device).to(loss_mat.dtype), 150 | ) 151 | 152 | self.optimizer.zero_grad() 153 | loss = 
torch.sum(loss_mat) / torch.sum(is_valid) 154 | loss.backward() 155 | self.optimizer.step() 156 | total_loss += loss.detach().item() 157 | 158 | return total_loss / len(self.train_loader) 159 | 160 | def _eval(self, loader: DataLoader) -> float: 161 | self.model.eval() 162 | y_true, y_pred = [], [] 163 | for _, batch in enumerate(loader): 164 | inputs = batch["representation"].to(self.device) 165 | with torch.no_grad(): 166 | pred = self.model(inputs).squeeze() 167 | y_true.append(batch["label"]) 168 | y_pred.append(pred.detach().cpu()) 169 | y_true, y_pred = torch.cat(y_true), torch.cat(y_pred) 170 | if self.criterion is nn.BCELoss: 171 | y_pred = nn.Sigmoid(y_pred) 172 | return self.criterion(y_pred, y_true).item() 173 | 174 | def _eval_roc(self, loader: DataLoader) -> float: 175 | self.model.eval() 176 | y_true, y_pred = [], [] 177 | 178 | for _, batch in enumerate(loader): 179 | with torch.no_grad(): 180 | if self.config.val_task == "prober": 181 | inputs = batch["representation"].to(self.device) 182 | pred = self.model(inputs) 183 | true = batch["label"].to(torch.float32).to(self.device) 184 | elif self.config.val_task == "finetune": 185 | batch = batch.to(self.device) 186 | pred = self.model( 187 | batch.x, batch.edge_index, batch.edge_attr, batch.batch 188 | ) 189 | true = batch.y.view(pred.shape) 190 | 191 | y_true.append(true) 192 | y_pred.append(pred) 193 | # y_pred.append(pred.view(true.shape)) 194 | 195 | y_true = torch.cat(y_true, dim=0).cpu().numpy() 196 | y_pred = torch.cat(y_pred, dim=0).cpu().numpy() 197 | 198 | roc_list = [] 199 | for i in range(y_true.shape[1]): 200 | # AUC is only defined when there is at least one positive data. 
201 | # if np.sum(y_true[:, i] == 1) > 0 and np.sum(y_true[:, i] == 0) > 0: 202 | # roc_list.append(roc_auc_score(y_true[:, i], y_pred[:, i])) 203 | 204 | if np.sum(y_true[:, i] == 1) > 0 and np.sum(y_true[:, i] == -1) > 0: 205 | is_valid = y_true[:, i] ** 2 > 0 206 | roc_list.append( 207 | roc_auc_score((y_true[is_valid, i] + 1) / 2, y_pred[is_valid, i]) 208 | ) 209 | 210 | if len(roc_list) < y_true.shape[1]: 211 | print("Some target is missing!") 212 | print("Missing ratio: %f" % (1 - float(len(roc_list)) / y_true.shape[1])) 213 | 214 | return sum(roc_list) / len(roc_list) 215 | 216 | def eval_train_dataset(self) -> float: 217 | return self._eval(loader=self.train_loader) 218 | 219 | def eval_val_dataset(self) -> float: 220 | return self._eval(loader=self.val_loader) 221 | 222 | def eval_test_dataset(self) -> float: 223 | return self._eval(loader=self.test_loader) 224 | -------------------------------------------------------------------------------- /src/validation/task/task.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import dataclasses 3 | from typing import Dict 4 | import torch 5 | from torch.utils.data import DataLoader, Dataset 6 | from config.validation_config import ValidationConfig 7 | from logger import CombinedLogger 8 | 9 | 10 | @dataclasses.dataclass 11 | class Task(abc.ABC): 12 | config: ValidationConfig 13 | 14 | @abc.abstractmethod 15 | def run(self, model: torch.nn.Module, device: torch.device) -> Dict: 16 | pass 17 | 18 | 19 | @dataclasses.dataclass 20 | class TrainValTestTask(Task): 21 | config: ValidationConfig 22 | model: torch.nn.Module 23 | device: torch.device 24 | optimizer: torch.optim.Optimizer 25 | train_dataset: Dataset 26 | val_dataset: Dataset 27 | test_dataset: Dataset 28 | criterion_type: str 29 | # logger: CombinedLogger 30 | 31 | @abc.abstractmethod 32 | def train(self) -> float: 33 | pass 34 | 35 | @abc.abstractmethod 36 | def _eval(self, loader: DataLoader) -> float: 
37 | pass 38 | 39 | @abc.abstractmethod 40 | def eval_val_dataset(self) -> float: 41 | pass 42 | 43 | @abc.abstractmethod 44 | def eval_test_dataset(self) -> float: 45 | pass 46 | -------------------------------------------------------------------------------- /src/validation/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional, Union 2 | 3 | import pandas as pd 4 | 5 | from config import Config 6 | from config.training_config import TrainingConfig 7 | from config.validation_config import ValidationConfig 8 | from datasets import MoleculeDataset 9 | from splitters import random_scaffold_split, random_split, scaffold_split 10 | 11 | 12 | def get_dataset_extraction( 13 | config: Union[TrainingConfig, ValidationConfig] 14 | ) -> MoleculeDataset: 15 | root: str = f"data/molecule_datasets/{config.dataset}/" 16 | return MoleculeDataset(root=root, dataset=config.dataset) 17 | 18 | 19 | def get_dataset_split( 20 | config: Union[TrainingConfig, ValidationConfig], 21 | dataset: MoleculeDataset, 22 | smiles_list: Optional[List[str]], 23 | ) -> Dict[str, MoleculeDataset]: 24 | if config.split == "scaffold": 25 | train_, val_, test_, smile_ = scaffold_split( 26 | dataset, 27 | smiles_list, 28 | null_value=0, 29 | frac_train=0.8, 30 | frac_valid=0.1, 31 | frac_test=0.1, 32 | return_smiles=True, 33 | ) 34 | elif config.split == "random": 35 | train_, val_, test_, smile_ = random_split( 36 | dataset, 37 | null_value=0, 38 | frac_train=0.8, 39 | frac_valid=0.1, 40 | frac_test=0.1, 41 | seed=config.seed, 42 | smiles_list=smiles_list, 43 | ) 44 | elif config.split == "random_scaffold": 45 | train_, val_, test_, smile_ = random_scaffold_split( 46 | dataset, 47 | smiles_list, 48 | null_value=0, 49 | frac_train=0.8, 50 | frac_valid=0.1, 51 | frac_test=0.1, 52 | seed=config.seed, 53 | return_smiles=True, 54 | ) 55 | else: 56 | raise ValueError("Invalid split option.") 57 | return { 58 | "train": train_, 59 
| "val": val_, 60 | "test": test_, 61 | "smiles": smile_, 62 | } 63 | 64 | 65 | def get_smiles_list(config: Config) -> List[str]: 66 | smiles_list = pd.read_csv( 67 | f"data/molecule_datasets/{config.dataset}/processed/smiles.csv", 68 | header=None, 69 | ) 70 | return smiles_list[0].tolist() 71 | --------------------------------------------------------------------------------