├── .gitignore
├── LICENSE
├── README.md
├── data
├── GEOM
│ ├── __init__.py
│ ├── analyze_GEOM.py
│ └── pre_process.py
├── README.md
└── download_data.sh
├── env.yaml
├── figures
├── Diagram.png
└── flow.svg
├── log
└── readme.md
├── script
├── FineTuning.sh
├── PreTraining.sh
├── Probing.sh
└── README.md
└── src
├── __init__.py
├── batch.py
├── config
├── __init__.py
├── aug_whitening.py
├── sweeps
│ ├── __init__.py
│ └── mlp.yaml
├── training_config.py
└── validation_config.py
├── dataloader.py
├── datasets
├── __init__.py
├── molecule_contextual.py
├── molecule_datasets.py
├── molecule_gpt_gnn.py
├── molecule_graphcl.py
├── molecule_graphmvp.py
├── molecule_motif.py
├── molecule_rgcl.py
└── utils.py
├── init.py
├── load_save.py
├── logger.py
├── models
├── __init__.py
├── attribute_masking.py
├── building_blocks
│ ├── __init__.py
│ ├── auto_encoder.py
│ ├── flow.py
│ ├── gnn.py
│ ├── mlp.py
│ └── schnet.py
├── context_prediction.py
├── contextual.py
├── discriminator.py
├── edge_prediction.py
├── gpt_gnn.py
├── graph_cl.py
├── graphmae.py
├── graphmvp.py
├── graphpred.py
├── info_max.py
├── joao_v2.py
├── motif.py
├── pre_trainer_model.py
└── rgcl.py
├── pretrainers
├── __init__.py
├── attribute_masking.py
├── context_prediction.py
├── contextual.py
├── edge_prediction.py
├── gpt_gnn.py
├── graph_cl.py
├── graphmae.py
├── graphmvp.py
├── info_max.py
├── joao.py
├── joao_v2.py
├── motif.py
├── pretrainer.py
└── rgcl.py
├── run_embedding_extraction.py
├── run_pretraining.py
├── run_validation.py
├── splitters.py
├── util.py
└── validation
├── __init__.py
├── dataset.py
├── task
├── __init__.py
├── finetune_task.py
├── graph_edit_distance.py
├── graph_level.py
├── metrics.py
├── node_level.py
├── pair_level.py
├── prober_task.py
├── task.py
└── weisfeiler_lehman.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | data/
2 | wandb/
3 | analysis/
4 | pretrain_models/
5 | .vscode/
6 | embedding_dir/
7 |
8 | # Byte-compiled / optimized / DLL files
9 | __pycache__/
10 | *.py[cod]
11 | *$py.class
12 |
13 | # C extensions
14 | *.so
15 |
16 | # Distribution / packaging
17 | .Python
18 | build/
19 | develop-eggs/
20 | dist/
21 | downloads/
22 | eggs/
23 | .eggs/
24 | lib/
25 | lib64/
26 | parts/
27 | sdist/
28 | var/
29 | wheels/
30 | pip-wheel-metadata/
31 | share/python-wheels/
32 | *.egg-info/
33 | .installed.cfg
34 | *.egg
35 | MANIFEST
36 |
37 | # PyInstaller
38 | # Usually these files are written by a python script from a template
39 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
40 | *.manifest
41 | *.spec
42 |
43 | # Installer logs
44 | pip-log.txt
45 | pip-delete-this-directory.txt
46 |
47 | # Unit test / coverage reports
48 | htmlcov/
49 | .tox/
50 | .nox/
51 | .coverage
52 | .coverage.*
53 | .cache
54 | nosetests.xml
55 | coverage.xml
56 | *.cover
57 | *.py,cover
58 | .hypothesis/
59 | .pytest_cache/
60 |
61 | # Translations
62 | *.mo
63 | *.pot
64 |
65 | # Django stuff:
66 | *.log
67 | local_settings.py
68 | db.sqlite3
69 | db.sqlite3-journal
70 |
71 | # Flask stuff:
72 | instance/
73 | .webassets-cache
74 |
75 | # Scrapy stuff:
76 | .scrapy
77 |
78 | # Sphinx documentation
79 | docs/_build/
80 |
81 | # PyBuilder
82 | target/
83 |
84 | # Jupyter Notebook
85 | .ipynb_checkpoints
86 |
87 | # IPython
88 | profile_default/
89 | ipython_config.py
90 |
91 | # pyenv
92 | .python-version
93 |
94 | # pipenv
95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
98 | # install all needed dependencies.
99 | #Pipfile.lock
100 |
101 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
102 | __pypackages__/
103 |
104 | # Celery stuff
105 | celerybeat-schedule
106 | celerybeat.pid
107 |
108 | # SageMath parsed files
109 | *.sage.py
110 |
111 | # Environments
112 | .env
113 | .venv
114 | env/
115 | venv/
116 | ENV/
117 | env.bak/
118 | venv.bak/
119 |
120 | # Spyder project settings
121 | .spyderproject
122 | .spyproject
123 |
124 | # Rope project settings
125 | .ropeproject
126 |
127 | # mkdocs documentation
128 | /site
129 |
130 | # mypy
131 | .mypy_cache/
132 | .dmypy.json
133 | dmypy.json
134 |
135 | # Pyre type checker
136 | .pyre/
137 | .DS_Store
138 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Hanchen
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## MolGraphEval
2 |
3 | This repository is the official implementation of the paper: "**Evaluating Self-supervised Learning for Molecular Graph Embeddings**", NeurIPS 2023, Datasets and Benchmarks Track.
4 |
5 | 
6 |
7 | ### Citation
8 | ```bibtex
9 | @inproceedings{GraphEval,
10 | title = {Evaluating Self-supervised Learning for Molecular Graph Embeddings},
11 | author = {Hanchen Wang* and Jean Kaddour* and Shengchao Liu and Jian Tang and Joan Lasenby and Qi Liu},
12 | booktitle = {NeurIPS 2023, Datasets and Benchmarks Track},
13 | year = 2023
14 | }
15 | ```
16 |
17 | ### Usage
18 | We include scripts for pre-training, probing and fine-tuning of GraphSSL models on molecules; see the `script` folder. We use conda to set up the environment:
19 | ```bash
20 | conda env create -f env.yaml
21 | ```
--------------------------------------------------------------------------------
/data/GEOM/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hansen7/MolGraphEval/2f83950b622d98937500c2efd8091a5af88c8cca/data/GEOM/__init__.py
--------------------------------------------------------------------------------
/data/GEOM/analyze_GEOM.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021. Shengchao & Hanchen
2 | # liusheng@mila.quebec & hw501@cam.ac.uk
3 |
4 | import hashlib
5 | import json
6 | import os
7 | import pickle
8 | import random
9 | from os.path import join
10 |
11 | import msgpack
12 | from rdkit import Chem
13 |
14 | # from rdkit.Chem import AllChem
15 | # from rdkit.Chem.Draw import MolsToGridImage
16 | # from rdkit.Chem.rdmolfiles import MolToPDBFile
17 | # from rdkit.Chem.rdMolDescriptors import CalcWHIM
18 |
19 | zip2md5 = {
20 | "drugs_crude.msgpack.tar.gz": "7778e84c50b7cde755cca670d1f75091",
21 | "drugs_featurized.msgpack.tar.gz": "2fb86edc50e3ab3b96f78fe01082965b",
22 | "qm9_crude.msgpack.tar.gz": "aad0081ed5d9b8c93c2bd0235987573b",
23 | "qm9_featurized.msgpack.tar.gz": "09655f470f438e3a7a0dfd20f40f6f22",
24 | "rdkit_folder.tar.gz": "e8f2168b7050652db22c976be25c450e",
25 | }
26 |
27 |
def compute_md5(file_name, chunk_size=65536):
    """Return the hex MD5 digest of *file_name*, read in fixed-size chunks.

    Args:
        file_name: path of the file to hash (e.g. a GEOM tarball).
        chunk_size: bytes read per iteration; chunking keeps memory flat
            for multi-gigabyte archives.

    Returns:
        The 32-character hexadecimal MD5 digest string.
    """
    md5 = hashlib.md5()
    with open(file_name, "rb") as fin:
        # Walrus loop avoids the duplicated read() call of the classic
        # read-before-loop pattern.
        while chunk := fin.read(chunk_size):
            md5.update(chunk)
    return md5.hexdigest()
36 |
37 |
def analyze_crude_file(data):
    """Inspect the first batch of a GEOM "crude" msgpack dump.

    Prints the tarball MD5 (next to the expected value from ``zip2md5``),
    the SMILES of the first batch, and — for one hard-coded example
    molecule — its raw record and conformer count.

    Args:
        data: dataset stem, e.g. "drugs" or "qm9".
    """
    drugs_file = "{}_crude.msgpack".format(data)
    print(compute_md5("{}.tar.gz".format(drugs_file)))
    print(zip2md5["{}.tar.gz".format(drugs_file)])

    total_smiles_list = []
    # Keep the handle open for the lifetime of the lazy Unpacker; the
    # original code opened the file inline and leaked the handle.
    with open(drugs_file, "rb") as fin:
        unpacker = msgpack.Unpacker(fin)
        for idx, drug_batch in enumerate(unpacker):
            smiles_list = list(drug_batch.keys())
            print(idx, "\t", len(smiles_list))
            total_smiles_list.extend(smiles_list)

            for smiles in smiles_list:
                print(smiles)
                if smiles == "CCOCC[C@@H](O)C=O":
                    print(drug_batch[smiles])
                    conformer_list = drug_batch[smiles]["conformers"]
                    print(len(conformer_list))
                    # break
            # Only the first batch is inspected.
            break
    print("total smiles list {}".format(len(total_smiles_list)))
    return
60 |
61 |
def analyze_featurized_file(data):
    """Peek at the first molecule of a GEOM "featurized" msgpack dump.

    Prints the tarball MD5 (next to the expected value from ``zip2md5``),
    then the SMILES and feature-entry count of the first molecule of the
    first batch.

    Args:
        data: dataset stem, e.g. "drugs" or "qm9".
    """
    drugs_file = "{}_featurized.msgpack".format(data)
    print(compute_md5("{}.tar.gz".format(drugs_file)))
    print(zip2md5["{}.tar.gz".format(drugs_file)])

    # with-block closes the handle (previously leaked); the Unpacker
    # streams lazily, so the file must stay open while iterating.
    with open(drugs_file, "rb") as fin:
        unpacker = msgpack.Unpacker(fin)
        for idx, drug_batch in enumerate(unpacker):
            smiles_list = list(drug_batch.keys())
            for smiles in smiles_list:
                print(smiles)
                print(len(drug_batch[smiles]))
                # print(drug_batch[smiles])
                break
            break

    return
78 |
79 |
def analyze_rdkit_file(data):
    """Explore the GEOM rdkit_folder summary for CoV-3CL activity.

    Loads ``rdkit_folder/summary_<data>.json``, collects the pickle paths
    of all molecules flagged active on the SARS-CoV 3CL protease plus a
    random sample of inactives, then prints (canonicalized) SMILES of the
    conformers of the first few sampled molecules.

    Args:
        data: dataset stem, e.g. "drugs".
    """
    dir_name = "rdkit_folder"
    # dir_zip_name = '{}.tar.gz'.format(dir_name)
    # assert compute_md5(dir_zip_name) == zip2md5[dir_zip_name]

    drugs_file = "{}/summary_{}.json".format(dir_name, data)
    with open(drugs_file, "r") as f:
        drugs_summary = json.load(f)

    smiles_list = list(drugs_summary.keys())
    print("# SMILES: {}".format(len(smiles_list)))  # 304,466
    example_smiles = smiles_list[0]
    print(drugs_summary[example_smiles])

    # Now let's find active molecules and their pickle paths:
    active_mol_paths = []
    active_smiles_list = []
    for smiles, sub_dic in drugs_summary.items():
        if sub_dic.get("sars_cov_one_cl_protease_active") == 1:
            pickle_path = join(dir_name, sub_dic.get("pickle_path", ""))
            if os.path.isfile(pickle_path):
                active_mol_paths.append(pickle_path)
                # FIX: this list was declared and used below as the seed of
                # sample_smiles, but never populated in the original code.
                active_smiles_list.append(smiles)
    print("# active mols on CoV 3CL: {}\n".format(len(active_mol_paths)))

    # Now randomly sample inactive molecules and their pickle paths:
    random_smiles = list(drugs_summary.keys())
    random.shuffle(random_smiles)
    random_smiles = random_smiles[:1000]
    inactive_mol_paths = []
    for smiles in random_smiles:
        sub_dic = drugs_summary[smiles]
        if sub_dic.get("sars_cov_one_cl_protease_active") == 0:
            pickle_path = join(dir_name, sub_dic.get("pickle_path", ""))
            if os.path.isfile(pickle_path):
                inactive_mol_paths.append(pickle_path)
    print("# inactive mols on CoV 3CL: {}\n".format(len(inactive_mol_paths)))

    sample_dic = {}
    # NOTE(review): sample_smiles is assembled but never read afterwards;
    # kept for parity with the original exploratory script.
    sample_smiles = active_smiles_list
    sample_smiles.extend(random_smiles)
    for mol_path in [*active_mol_paths, *inactive_mol_paths]:
        with open(mol_path, "rb") as f:
            dic = pickle.load(f)
        sample_dic.update({dic["smiles"]: dic})
    print("# all mols on CoV 3CL: {}\n".format(len(sample_dic)))

    # Print the conformer SMILES of the first 10 sampled molecules;
    # round-tripping through MolFromSmiles canonicalizes them.
    idx = 0
    for k, v in sample_dic.items():
        conf_list = v["conformers"]
        # print(k)
        # print(len(conf_list))
        for conf in conf_list:
            mol = conf["rd_mol"]
            # print(conf)
            # print(mol)
            print(Chem.MolToSmiles(mol))
            smiles = Chem.MolToSmiles(Chem.MolFromSmiles(Chem.MolToSmiles(mol)))
            print(smiles)
            print("\n")

        idx += 1
        if idx > 9:
            break

    return
145 |
146 |
if __name__ == "__main__":
    # Inspect the GEOM "drugs" subset; uncomment the other calls to
    # analyze the crude/featurized msgpack dumps instead.
    data = "drugs"
    # analyze_crude_file(data)
    # analyze_featurized_file(data)
    analyze_rdkit_file(data)
--------------------------------------------------------------------------------
/data/GEOM/pre_process.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 |
def clean_rdkit_file(data):
    """Export the SMILES keys of a GEOM summary json to ``<data>.csv``.

    Reads ``rdkit_folder/summary_<data>.json`` and writes a one-column
    csv file with header "smiles" and one row per molecule.

    Args:
        data: dataset stem, e.g. "drugs" (304,466 molecules in total).
    """
    dir_name = "rdkit_folder"
    drugs_file = "{}/summary_{}.json".format(dir_name, data)
    with open(drugs_file, "r") as f:
        drugs_summary = json.load(f)

    smiles_list = list(drugs_summary.keys())
    print("Number of total items (SMILES): {}".format(len(smiles_list)))

    drug_file = "{}.csv".format(data)
    with open(drug_file, "w") as f:
        f.write("smiles\n")
        # writelines with a generator: one buffered pass instead of
        # hundreds of thousands of tiny write() calls.
        f.writelines("{}\n".format(smiles) for smiles in smiles_list)

    return
23 |
24 |
if __name__ == "__main__":
    # Export the SMILES of the GEOM "drugs" subset to drugs.csv.
    data = "drugs"
    clean_rdkit_file(data)
28 |
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | # Datasets
2 |
3 |
4 |
5 | #### Geometric Ensemble Of Molecules (GEOM)
6 |
7 | ```bash
8 | mkdir -p GEOM/raw
9 | mkdir -p GEOM/processed
10 | ```
11 |
12 | ```bash
13 | wget https://dataverse.harvard.edu/api/access/datafile/4327252
14 | mv 4327252 rdkit_folder.tar.gz
15 | tar -xvf rdkit_folder.tar.gz
16 | ```
17 |
18 |
19 |
20 | #### Chem Dataset
21 |
22 | ```bash
23 | bash download_data.sh
24 | ```
25 |
--------------------------------------------------------------------------------
/data/download_data.sh:
--------------------------------------------------------------------------------
1 |
2 | # MoleculeNet + ZINC
3 | wget http://snap.stanford.edu/gnn-pretrain/data/chem_dataset.zip
4 | unzip chem_dataset.zip
5 | mv dataset molecule_datasets
6 | rm chem_dataset.zip
7 | rm -r molecule_datasets/*/processed
8 |
9 | # for d in molecule_datasets/*/
10 | # do
11 | # echo "$d"
12 | # ln -s $d ./
13 | # done
14 |
15 | # GEOM
16 | wget https://dataverse.harvard.edu/api/access/datafile/4327252
17 | mv 4327252 rdkit_folder.tar.gz
18 | tar -xvf rdkit_folder.tar.gz
19 | mv rdkit_folder GEOM/
20 | rm rdkit_folder.tar.gz
21 |
--------------------------------------------------------------------------------
/env.yaml:
--------------------------------------------------------------------------------
1 | name: GraphEval
2 | channels:
3 | - pyg
4 | - pytorch
5 | - nvidia
6 | - conda-forge
7 | - defaults
8 | dependencies:
9 | - _libgcc_mutex=0.1=main
10 | - _openmp_mutex=5.1=1_gnu
11 | - black=23.3.0=py39h06a4308_0
12 | - blas=1.0=mkl
13 | - brotlipy=0.7.0=py39h27cfd23_1003
14 | - bzip2=1.0.8=h7b6447c_0
15 | - ca-certificates=2023.05.30=h06a4308_0
16 | - certifi=2023.7.22=py39h06a4308_0
17 | - cffi=1.15.1=py39h5eee18b_3
18 | - cryptography=41.0.2=py39h22a60cf_0
19 | - cuda-cudart=11.7.99=0
20 | - cuda-cupti=11.7.101=0
21 | - cuda-libraries=11.7.1=0
22 | - cuda-nvrtc=11.7.99=0
23 | - cuda-nvtx=11.7.91=0
24 | - cuda-runtime=11.7.1=0
25 | - cudatoolkit=10.2.89=hfd86e86_1
26 | - ffmpeg=4.3=hf484d3e_0
27 | - filelock=3.9.0=py39h06a4308_0
28 | - freetype=2.12.1=h4a9f257_0
29 | - giflib=5.2.1=h5eee18b_3
30 | - gmp=6.2.1=h295c915_3
31 | - gmpy2=2.1.2=py39heeb90bb_0
32 | - gnutls=3.6.15=he1e5248_0
33 | - idna=3.4=py39h06a4308_0
34 | - intel-openmp=2023.1.0=hdb19cb5_46305
35 | - jinja2=3.1.2=py39h06a4308_0
36 | - joblib=1.2.0=py39h06a4308_0
37 | - jpeg=9e=h5eee18b_1
38 | - lame=3.100=h7b6447c_0
39 | - lcms2=2.12=h3be6417_0
40 | - ld_impl_linux-64=2.38=h1181459_1
41 | - lerc=3.0=h295c915_0
42 | - libcublas=11.10.3.66=0
43 | - libcufft=10.7.2.124=h4fbf590_0
44 | - libcufile=1.7.1.12=0
45 | - libcurand=10.3.3.129=0
46 | - libcusolver=11.4.0.1=0
47 | - libcusparse=11.7.4.91=0
48 | - libdeflate=1.17=h5eee18b_0
49 | - libffi=3.4.4=h6a678d5_0
50 | - libgcc-ng=11.2.0=h1234567_1
51 | - libgfortran-ng=11.2.0=h00389a5_1
52 | - libgfortran5=11.2.0=h1234567_1
53 | - libgomp=11.2.0=h1234567_1
54 | - libiconv=1.16=h7f8727e_2
55 | - libidn2=2.3.4=h5eee18b_0
56 | - libnpp=11.7.4.75=0
57 | - libnvjpeg=11.8.0.2=0
58 | - libpng=1.6.39=h5eee18b_0
59 | - libstdcxx-ng=11.2.0=h1234567_1
60 | - libtasn1=4.19.0=h5eee18b_0
61 | - libtiff=4.5.0=h6a678d5_2
62 | - libunistring=0.9.10=h27cfd23_0
63 | - libuv=1.44.2=h5eee18b_0
64 | - libwebp=1.2.4=h11a3e52_1
65 | - libwebp-base=1.2.4=h5eee18b_1
66 | - lz4-c=1.9.4=h6a678d5_0
67 | - markupsafe=2.1.1=py39h7f8727e_0
68 | - mkl=2023.1.0=h213fc3f_46343
69 | - mkl-service=2.4.0=py39h5eee18b_1
70 | - mkl_fft=1.3.6=py39h417a72b_1
71 | - mkl_random=1.2.2=py39h417a72b_1
72 | - mpc=1.1.0=h10f8cd9_1
73 | - mpfr=4.0.2=hb69a4c5_1
74 | - mpmath=1.3.0=py39h06a4308_0
75 | - mypy_extensions=0.4.3=py39h06a4308_1
76 | - ncurses=6.4=h6a678d5_0
77 | - nettle=3.7.3=hbbd107a_1
78 | - networkx=3.1=py39h06a4308_0
79 | - ninja=1.10.2=h06a4308_5
80 | - ninja-base=1.10.2=hd09550d_5
81 | - numpy=1.25.2=py39h5f9d8c6_0
82 | - numpy-base=1.25.2=py39hb5e798b_0
83 | - openh264=2.1.1=h4ff587b_0
84 | - openssl=3.0.10=h7f8727e_0
85 | - pathspec=0.10.3=py39h06a4308_0
86 | - pip=23.2.1=py39h06a4308_0
87 | - platformdirs=2.5.2=py39h06a4308_0
88 | - psutil=5.9.0=py39h5eee18b_0
89 | - pycparser=2.21=pyhd3eb1b0_0
90 | - pyg=2.3.1=py39_torch_2.0.0_cu117
91 | - pyopenssl=23.2.0=py39h06a4308_0
92 | - pyparsing=3.0.9=py39h06a4308_0
93 | - pysocks=1.7.1=py39h06a4308_0
94 | - python=3.9.17=h955ad1f_0
95 | - pytorch=2.0.1=py3.9_cuda11.7_cudnn8.5.0_0
96 | - pytorch-cluster=1.6.1=py39_torch_2.0.0_cu117
97 | - pytorch-cuda=11.7=h778d358_5
98 | - pytorch-mutex=1.0=cuda
99 | - pytorch-scatter=2.1.1=py39_torch_2.0.0_cu117
100 | - readline=8.2=h5eee18b_0
101 | - requests=2.31.0=py39h06a4308_0
102 | - scikit-learn=1.3.0=py39h1128e8f_0
103 | - scipy=1.11.1=py39h5f9d8c6_0
104 | - setuptools=68.0.0=py39h06a4308_0
105 | - sqlite=3.41.2=h5eee18b_0
106 | - sympy=1.11.1=py39h06a4308_0
107 | - tbb=2021.8.0=hdb19cb5_0
108 | - threadpoolctl=2.2.0=pyh0d69192_0
109 | - tk=8.6.12=h1ccaba5_0
110 | - tomli=2.0.1=py39h06a4308_0
111 | - torchaudio=2.0.2=py39_cu117
112 | - torchtriton=2.0.0=py39
113 | - torchvision=0.15.2=py39_cu117
114 | - typing_extensions=4.7.1=py39h06a4308_0
115 | - wheel=0.38.4=py39h06a4308_0
116 | - xz=5.4.2=h5eee18b_0
117 | - zlib=1.2.13=h5eee18b_0
118 | - zstd=1.5.5=hc292b87_0
119 | - pip:
120 | - appdirs==1.4.4
121 | - ase==3.22.1
122 | - calmsize==0.1.3
123 | - chardet==5.2.0
124 | - charset-normalizer==3.2.0
125 | - click==8.1.6
126 | - contourpy==1.1.0
127 | - cycler==0.11.0
128 | - cython==3.0.0
129 | - descriptastorus==2.5.0.23
130 | - docker-pycreds==0.4.0
131 | - fonttools==4.42.0
132 | - fvcore==0.1.5.post20221221
133 | - gensim==4.3.1
134 | - gitdb==4.0.10
135 | - gitpython==3.1.32
136 | - gmatch4py==0.2.5b0
137 | - importlib-resources==6.0.1
138 | - iopath==0.1.10
139 | - kiwisolver==1.4.4
140 | - littleutils==0.2.2
141 | - matplotlib==3.7.2
142 | - msgpack==1.0.5
143 | - ogb==1.3.6
144 | - outdated==0.2.2
145 | - packaging==23.1
146 | - pandas==2.0.3
147 | - pandas-flavor==0.6.0
148 | - pathtools==0.1.2
149 | - pillow==10.0.0
150 | - portalocker==2.7.0
151 | - protobuf==4.24.0
152 | - python-dateutil==2.8.2
153 | - pytorch-memlab==0.3.0
154 | - pytz==2023.3
155 | - pyyaml==6.0.1
156 | - rdkit==2023.3.2
157 | - sentry-sdk==1.29.2
158 | - setproctitle==1.3.2
159 | - six==1.16.0
160 | - smart-open==6.3.0
161 | - smmap==5.0.0
162 | - tabulate==0.9.0
163 | - termcolor==2.3.0
164 | - tqdm==4.66.1
165 | - tzdata==2023.3
166 | - urllib3==2.0.4
167 | - wandb==0.15.8
168 | - xarray==2023.7.0
169 | - yacs==0.1.8
170 | - zipp==3.16.2
171 | prefix: /home/wangh256/anaconda3/envs/GraphEval
172 |
--------------------------------------------------------------------------------
/figures/Diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hansen7/MolGraphEval/2f83950b622d98937500c2efd8091a5af88c8cca/figures/Diagram.png
--------------------------------------------------------------------------------
/figures/flow.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
98 |
--------------------------------------------------------------------------------
/log/readme.md:
--------------------------------------------------------------------------------
1 | ### Path to save log files
2 |
--------------------------------------------------------------------------------
/script/FineTuning.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Fine-tune a pre-trained GNN checkpoint on each MoleculeNet dataset.
cd ../

export PreTrainer=AM # or others
export PreTrainData=zinc_standard_agent # or others

export ProbeTask=downstream
export FineTuneData_List=(bbbp tox21 toxcast sider clintox muv hiv bace)
export Checkpoint="./pretrain_models/$PreTrainer/$PreTrainData/epoch99_model_complete.pth"

# One fine-tuning run per dataset; all stdout is appended to a single
# log file named after the pre-trainer / pre-training data pair.
for FineTuneData in "${FineTuneData_List[@]}"; do
    python src/run_validation.py \
        --val_task="finetune" \
        --pretrainer="FineTune" \
        --dataset=$FineTuneData \
        --probe_task=$ProbeTask \
        --input_model_file=$Checkpoint ;
done >> FineTune_${PreTrainer}_${PreTrainData}.log
19 |
--------------------------------------------------------------------------------
/script/PreTraining.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Launch self-supervised pre-training for one (pretrainer, dataset) pair.
cd ../

export PreTrainData=zinc_standard_agent # or geom2d_nmol500000_nconf1
export PreTrainer=AM # or IM EP CP GPT_GNN JOAO JOAOv2 GraphCL Motif Contextual GraphMAE RGCL

# If the PreTrainer is GraphMVP, then the PreTrainData is geom3d_nmol500000_nconf1

# === Pre-Training ===
python src/run_pretraining.py \
    --dataset="$PreTrainData" \
    --pretrainer="$PreTrainer" \
    --output_model_dir=./pretrain_models/;
--------------------------------------------------------------------------------
/script/Probing.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Extract frozen embeddings from a pre-trained GNN, then probe them on
# downstream labels, topological metrics, and RDKit substructure counts.
cd ../

export PreTrainer=AM # or others
export PreTrainData=zinc_standard_agent # or others
export EmbeddingDir="./embedding_dir/$PreTrainer/$PreTrainData/"
export Checkpoint="./pretrain_models/$PreTrainer/$PreTrainData/epoch99_model_complete.pth"
# for AM, the default path of checkpoint is "./pretrain_models/$PreTrainer/$PreTrainData/mask_rate-0.15_seed-42_lr-0.001/epoch0_model_complete.pth"
# for the random baseline, just set the checkpoint as epoch0_model_complete.pth

export FineTuneData_List=(bbbp tox21 toxcast sider clintox muv hiv bace)

# === Embedding Extractions, from Fixed Pre-Trained GNNs ===
for FineTuneData in "${FineTuneData_List[@]}"; do
    python src/run_embedding_extraction.py \
        --dataset="$FineTuneData" \
        --pretrainer="$PreTrainer" \
        --embedding_dir="$EmbeddingDir" \
        --input_model_file="$Checkpoint" ;
done

# === Probing on Downstream Tasks ===
export ProbeTask=downstream
for FineTuneData in "${FineTuneData_List[@]}"; do
    # FIX: was "python run_validation.py"; the entry point lives in src/
    # (matching every other invocation in this script).
    python src/run_validation.py \
        --dataset="$FineTuneData" \
        --probe_task="$ProbeTask" \
        --pretrainer="$PreTrainer" \
        --embedding_dir="$EmbeddingDir";
done >> Downstream_Probe_${PreTrainer}_${PreTrainData}.log


# === Probing on Topological Metrics ===
export ProbeTask=node_degree # or node_centrality node_clustering link_prediction jaccard_coefficient katz_index graph_diameter node_connectivity cycle_basis assortativity_coefficient average_clustering_coefficient
for FineTuneData in "${FineTuneData_List[@]}"; do
    python src/run_validation.py \
        --probe_task="$ProbeTask" \
        --dataset="$FineTuneData" \
        --pretrainer="$PreTrainer" \
        --embedding_dir="$EmbeddingDir" ;
done >> Topological_Probe_${PreTrainer}_${PreTrainData}.log


# === Probing on Substructures ===
# FIX: removed the duplicated fr_morpholine entry so each fragment is
# probed exactly once.
export Substructure_List=(fr_epoxide fr_lactam fr_morpholine fr_oxazole \
    fr_tetrazole fr_N_O fr_ether fr_furan fr_guanido fr_halogen \
    fr_piperdine fr_thiazole fr_thiophene fr_urea fr_allylic_oxid fr_amide \
    fr_amidine fr_azo fr_benzene fr_imidazole fr_imide fr_piperzine fr_pyridine)

for Substructure in "${Substructure_List[@]}"; do
    for FineTuneData in "${FineTuneData_List[@]}"; do
        python src/run_validation.py \
            --dataset="$FineTuneData" \
            --pretrainer="$PreTrainer" \
            --embedding_dir="$EmbeddingDir" \
            --probe_task="RDKiTFragment_$Substructure" ;
    done
done >> Substructure_Probe_${PreTrainer}_${PreTrainData}.log
--------------------------------------------------------------------------------
/script/README.md:
--------------------------------------------------------------------------------
1 | ## Scripts
2 |
3 | ### See `PreTraining.sh`, `Probing.sh`, and `FineTuning.sh` for details.
4 |
5 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hansen7/MolGraphEval/2f83950b622d98937500c2efd8091a5af88c8cca/src/__init__.py
--------------------------------------------------------------------------------
/src/config/__init__.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | from abc import ABC
3 |
4 |
def str2bool(v: str):
    """Parse a CLI boolean flag.

    Exactly the case-insensitive word "true" is truthy; any other
    string (including "1", "yes") yields False.
    """
    lowered = v.lower()
    return lowered == "true"
7 |
8 |
@dataclasses.dataclass
class Config(ABC):
    """Abstract base configuration shared by training and validation runs.

    Concrete subclasses (e.g. ``TrainingConfig``) re-declare and extend
    these fields; instances are populated from argparse.
    """

    # about seed and basic info
    seed: int      # random seed
    runseed: int   # second seed for the run itself (kept distinct from `seed`)
    device: int    # device index (presumably the CUDA ordinal — confirm)
    no_cuda: bool  # disable CUDA / force CPU
    dataset: str   # dataset name, e.g. "bace"
    # about model and pre-trainer
    model: str       # backbone, e.g. "gnn" or "schnet"
    pretrainer: str  # self-supervised pre-training method, e.g. "GraphCL"
20 |
--------------------------------------------------------------------------------
/src/config/aug_whitening.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 |
3 | from .training_config import TrainingConfig
4 | from .validation_config import ValidationConfig
5 |
6 |
def aug_whitening(
    config: Union[TrainingConfig, ValidationConfig]
) -> Union[TrainingConfig, ValidationConfig]:
    """Disable every dataloader augmentation on *config*, in place.

    The same config object is returned for call-chaining convenience.
    """
    # AM pre-trainer: no attribute/edge masking.
    config.mask_rate = 0.0
    config.mask_edge = 0

    # GraphCL / JOAO / JOAOv2 pre-trainers: switch augmentation off.
    config.aug_mode = "no_aug"
    config.aug_strength = 0.0
    config.aug_prob = 0.0

    # Unresolved methods: GraphGPTGNN, Contextual.
    # Unchanged methods (no augmentation knobs): Motif, IM.
    return config
27 |
--------------------------------------------------------------------------------
/src/config/sweeps/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hansen7/MolGraphEval/2f83950b622d98937500c2efd8091a5af88c8cca/src/config/sweeps/__init__.py
--------------------------------------------------------------------------------
/src/config/sweeps/mlp.yaml:
--------------------------------------------------------------------------------
1 | # see https://docs.wandb.ai/guides/sweeps/configuration for more infos
2 | method: grid
3 | metric:
4 | name: val_loss
5 | goal: minimize
6 | parameters:
7 | mlp_dim_hidden:
8 | values:
9 | - 128
10 | - 256
11 | - 512
12 | - 1024
13 | mlp_num_layers:
14 | values:
15 | - 1
16 | - 2
17 | - 4
18 | program: run_validation.py
19 |
--------------------------------------------------------------------------------
/src/config/training_config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import dataclasses
3 |
4 | from config import Config, str2bool
5 |
6 |
@dataclasses.dataclass
class TrainingConfig(Config):
    """Full configuration for a pre-training run.

    Populated by :func:`parse_config` from argparse; extends the basic
    seed / device / dataset fields inherited from ``Config``.
    """

    # re-declared (also on Config) so the dataclass field order is explicit
    model: str
    pretrainer: str
    # --- logging ---
    log_filepath: str
    log_to_wandb: bool  # mirror metrics to Weights & Biases
    log_interval: int   # log every n steps
    project_name: str   # wandb project name
    run_name: str       # wandb run name
    val_interval: int   # evaluate validation loss every n steps
    eval_train: bool    # also report eval metrics on training data
    input_data_dir: str
    # --- checkpoint loading/saving ---
    save_model: bool
    input_model_file: str  # checkpoint to load ("" = train from scratch)
    output_model_dir: str
    verbose: bool
    num_workers: str  # NOTE(review): annotated str but parsed with type=int — confirm
    # --- optimization (shared by pre-training and fine-tuning) ---
    optimizer_name: str
    split: str        # data split scheme, e.g. "scaffold"
    batch_size: int
    epochs: int
    epochs_save: int  # checkpoint every n epochs
    lr: float
    # lr_scale: float
    weight_decay: float
    # --- molecule GNN backbone ---
    gnn_type: str       # e.g. "gin"
    num_layer: int
    emb_dim: int
    dropout_ratio: float
    graph_pooling: str  # graph readout, e.g. "mean"
    JK: str             # how node features across layers are combined
    # gnn_lr_scale: float
    aggr: str           # message aggregation, e.g. "add"

    # === GraphCL ===
    aug_mode: str
    aug_strength: float
    aug_prob: float
    # === AttrMask ===
    mask_rate: float
    mask_edge: int
    num_atom_type: int
    num_edge_type: int
    # === ContextPred ===
    csize: int  # context size
    contextpred_neg_samples: int
    atom_vocab_size: int  # for G-Cont; adjusted automatically from the data
    # === JOAO ===
    gamma_joao: float
    gamma_joaov2: float
    # === GraphMVP ===
    GMVP_alpha1: float
    GMVP_alpha2: float
    GMVP_T: float
    GMVP_normalize: bool
    GMVP_CL_similarity_metric: str
    GMVP_CL_Neg_Samples: int
    GMVP_Masking_Ratio: float
65 |
66 |
67 | def parse_config() -> TrainingConfig:
68 | parser = argparse.ArgumentParser()
69 |
70 | # Seed and Basic Info
71 | parser.add_argument("--seed", type=int, default=42)
72 | parser.add_argument("--device", type=int, default=0)
73 | parser.add_argument("--runseed", type=int, default=0)
74 | parser.add_argument("--no_cuda", type=str2bool, default=False, help="Disable CUDA")
75 | parser.add_argument(
76 | "--model", type=str, default="gnn", choices=["gnn", "schnet", "egnn"]
77 | )
78 | parser.add_argument(
79 | "--pretrainer",
80 | type=str,
81 | default="GraphCL",
82 | choices=[
83 | "Motif",
84 | "Contextual",
85 | "GPT_GNN",
86 | "GraphCL",
87 | "JOAO",
88 | "JOAOv2",
89 | "AM",
90 | "IM",
91 | "CP",
92 | "EP",
93 | "GraphMVP",
94 | "RGCL",
95 | "GraphMAE",
96 | ],
97 | )
98 |
99 | # Logging
100 | parser.add_argument("--log_filepath", type=str, default="./log/")
101 | parser.add_argument("--log_to_wandb", type=str2bool, default=True)
102 | parser.add_argument(
103 | "--log_interval", default=10, type=int, help="Log every n steps"
104 | )
105 | parser.add_argument(
106 | "--val_interval",
107 | default=1,
108 | type=int,
109 | help="Evaluate validation push_loss every n steps",
110 | )
111 | parser.add_argument(
112 | "--project_name", default="GraphEval", type=str, help="project name in wandb"
113 | )
114 | parser.add_argument("--run_name", type=str, help="run name in wandb")
115 | parser.add_argument("--verbose", type=str2bool, default=False)
116 |
117 | # about if we would print out eval metric for training data
118 | parser.add_argument("--eval_train", type=str2bool, default=True)
119 | parser.add_argument("--input_data_dir", type=str, default="")
120 |
121 | # Loading and saving model checkpoints
122 | parser.add_argument("--save_model", type=str2bool, default=True)
123 | parser.add_argument("--input_model_file", type=str, default="")
124 | parser.add_argument("--output_model_dir", type=str, default="./saved_models/")
125 |
126 | # about dataset and dataloader
127 | parser.add_argument("--dataset", type=str, default="bace")
128 | parser.add_argument("--num_workers", type=int, default=8)
129 |
130 | # Training strategies (shared by PreTraining and FineTuning)
131 | parser.add_argument("--optimizer_name", type=str, default="adam")
132 | parser.add_argument("--split", type=str, default="scaffold")
133 | parser.add_argument("--batch_size", type=int, default=256)
134 | parser.add_argument("--epochs", type=int, default=100)
135 | parser.add_argument("--epochs_save", type=int, default=20)
136 | parser.add_argument("--lr", type=float, default=0.001)
137 | # parser.add_argument("--lr_scale", type=float, default=1)
138 | parser.add_argument("--weight_decay", type=float, default=0)
139 |
140 | # Molecule GNN
141 | parser.add_argument("--gnn_type", type=str, default="gin")
142 | parser.add_argument("--num_layer", type=int, default=5)
143 | parser.add_argument("--emb_dim", type=int, default=300)
144 | parser.add_argument("--dropout_ratio", type=float, default=0.5)
145 | parser.add_argument("--graph_pooling", type=str, default="mean")
146 | parser.add_argument(
147 | "--JK",
148 | type=str,
149 | default="last",
150 | choices=["last", "sum", "max", "concat"],
151 | help="how the node features across layers are combined.",
152 | )
153 | # parser.add_argument("--gnn_lr_scale", type=float, default=1)
154 | parser.add_argument("--aggr", type=str, default="add")
155 |
156 | # PreTrainer: GraphCL, JOAO, JOAOv2
157 | parser.add_argument("--aug_mode", type=str, default="sample")
158 | parser.add_argument("--aug_strength", type=float, default=0.2)
159 | parser.add_argument("--aug_prob", type=float, default=0.1)
160 |
161 | # PreTrainer: AttrMask
162 | parser.add_argument("--mask_rate", type=float, default=0.15)
163 | parser.add_argument("--mask_edge", type=int, default=0)
164 | parser.add_argument("--num_atom_type", type=int, default=119)
165 | parser.add_argument("--num_edge_type", type=int, default=5)
166 |
167 | # PreTrainer: G-Cont, will automatically adjust based on pre-training data
168 | parser.add_argument("--atom_vocab_size", type=int, default=1)
169 |
170 | # PreTrainer: ContextPred
171 | parser.add_argument("--csize", type=int, default=3)
172 | parser.add_argument("--contextpred_neg_samples", type=int, default=1)
173 |
174 | # PreTrainer: JOAO and JOAOv2
175 | parser.add_argument("--gamma_joao", type=float, default=0.1)
176 | parser.add_argument("--gamma_joaov2", type=float, default=0.1)
177 |
178 | # PreTrainer: GraphMVP
179 | # Ref: https://github.com/chao1224/GraphMVP/blob/main/scripts_classification/submit_pre_training_GraphMVP_hybrid.sh#L14-L50
180 | parser.add_argument("--GMVP_alpha1", type=float, default=1)
181 | parser.add_argument("--GMVP_alpha2", type=float, default=1) # 0.1, 1, 10
182 | parser.add_argument("--GMVP_T", type=float, default=0.1) # 0.1, 0.2, 0.5, 1, 2
183 | parser.add_argument("--GMVP_normalize", type=bool, default=True)
184 | parser.add_argument(
185 | "--GMVP_CL_similarity_metric",
186 | type=str,
187 | default="EBM_dot_prod",
188 | choices=["InfoNCE_dot_prod", "EBM_dot_prod"],
189 | )
190 | parser.add_argument("--GMVP_CL_Neg_Samples", type=int, default=5)
191 | parser.add_argument("--GMVP_Masking_Ratio", type=float, default=0.0)
192 |
193 | args = parser.parse_args()
194 | return TrainingConfig(**vars(args))
195 |
--------------------------------------------------------------------------------
/src/config/validation_config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import dataclasses
3 |
4 | from config import Config, str2bool
5 |
6 | # TODO (Jean): There are lots of redundancies between TrainingConfig and ValidationConfig
7 | # I am sure there is a better way; maybe making ValidationConfig a subclass of TrainingConfig
8 |
9 |
@dataclasses.dataclass
class ValidationConfig(Config):
    """Configuration for validation / probing runs.

    Field types mirror the argparse declarations in ``parse_config`` below;
    the dataclass is built via ``ValidationConfig(**vars(args))``.
    """

    # validation task
    val_task: str
    probe_task: str
    # logging
    log_filepath: str
    log_to_wandb: bool
    log_interval: int
    val_interval: int
    project_name: str
    run_name: str
    # about if we would print out eval metric for training data
    eval_train: bool
    input_data_dir: str
    # about loading and saving
    save_model: bool
    input_model_file: str
    output_model_dir: str
    embedding_dir: str
    verbose: bool
    # about dataset and dataloader
    batch_size: int
    num_workers: int  # fixed: was annotated as str, but argparse parses an int
    # about molecule GNN
    gnn_type: str
    num_layer: int
    emb_dim: int
    dropout_ratio: float
    graph_pooling: str
    JK: str
    # gnn_lr_scale: float
    aggr: str

    # for ProberTaskMLP
    mlp_dim_hidden: int
    mlp_dim_out: int
    mlp_num_layers: int
    mlp_batch_norm: bool
    mlp_initializer: str
    mlp_dropout: float
    mlp_activation: str
    mlp_leaky_relu: float

    # for GraphCL
    aug_mode: str
    aug_strength: float
    aug_prob: float
    # for AttributeMask
    mask_rate: float
    mask_edge: int
    num_atom_type: int
    num_edge_type: int
    # for ContextPred
    csize: int
    atom_vocab_size: int
    contextpred_neg_samples: int
    # for JOAO
    gamma_joao: float
    gamma_joaov2: float

    # Validation metric arguments

    # ProberTask
    optimizer_name: str
    split: str
    # note: duplicate `batch_size` annotation removed (declared once above)
    epochs: int
    lr: float
    # lr_scale: float
    weight_decay: float
    criterion_type: str  # fixed: was annotated as float, default is "mse"
82 |
83 |
def parse_config(parser: argparse.ArgumentParser = None) -> ValidationConfig:
    """Build a :class:`ValidationConfig` from command-line arguments.

    Args:
        parser: An existing ``ArgumentParser`` to extend; a fresh one is
            created when ``None``.

    Returns:
        ValidationConfig populated from the parsed arguments.
    """
    parser = argparse.ArgumentParser() if parser is None else parser
    # val and probe task
    parser.add_argument(
        "--val_task",
        type=str,
        default="prober",
        choices=["prober", "smoothing_metric", "finetune"],
    )
    parser.add_argument("--probe_task", type=str, default="downstream")

    # seed and basic info
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--runseed", type=int, default=0)
    parser.add_argument("--no_cuda", type=str2bool, default=False)
    parser.add_argument("--device", type=int, default=0)
    parser.add_argument(
        "--model", type=str, default="gnn", choices=["gnn", "schnet", "egnn"]
    )
    # which pre-training strategy produced the checkpoint being validated
    parser.add_argument(
        "--pretrainer",
        type=str,
        default="AM",
        choices=[
            "Motif",
            "Contextual",
            "GPT_GNN",
            "GraphCL",
            "JOAO",
            "JOAOv2",
            "AM",
            "IM",
            "GraphMVP",
            "CP",
            "EP",
            "GraphMAE",
            "RGCL",
            "FineTune",
        ],
    )

    # logging
    parser.add_argument("--log_filepath", type=str, default="./log/")
    parser.add_argument("--log_to_wandb", type=str2bool, default=True)
    parser.add_argument("--log_interval", default=10, type=int, help="Log steps")
    parser.add_argument(
        "--val_interval",
        default=1,
        type=int,
        help="Evaluate validation push_loss every n steps",
    )
    parser.add_argument(
        "--project_name", default="GraphEval", type=str, help="project name in wandb"
    )
    parser.add_argument("--run_name", type=str, help="run name in wandb")

    # about loading and saving
    parser.add_argument("--save_model", type=str2bool, default=True)
    parser.add_argument("--input_model_file", type=str, default="")
    parser.add_argument("--output_model_dir", type=str, default="")
    parser.add_argument("--embedding_dir", type=str, default="")
    # parser.add_argument("--embedding_dir", type=str,
    #   default="./embedding_dir_x/Contextual/geom2d_nmol50000_nconf1_nupper1000/")
    parser.add_argument("--verbose", type=str2bool, default=False)

    # about dataset and dataloader
    parser.add_argument("--dataset", type=str, default="tox21")
    parser.add_argument("--num_workers", type=int, default=8)
    parser.add_argument("--batch_size", type=int, default=256)

    # ProberTask
    # about training strategies
    parser.add_argument("--optimizer_name", type=str, default="adam")
    parser.add_argument("--split", type=str, default="scaffold")
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--lr", type=float, default=0.001)
    # parser.add_argument("--lr_scale", type=float, default=1)
    parser.add_argument("--weight_decay", type=float, default=0)
    parser.add_argument("--criterion_type", type=str, default="mse")

    # about molecule GNN
    parser.add_argument("--gnn_type", type=str, default="gin")
    parser.add_argument("--num_layer", type=int, default=5)
    parser.add_argument("--emb_dim", type=int, default=300)
    parser.add_argument("--dropout_ratio", type=float, default=0.5)
    parser.add_argument("--graph_pooling", type=str, default="mean")
    parser.add_argument(
        "--JK",
        type=str,
        default="last",
        choices=["last", "sum", "max", "concat"],
        help="how the node features across layers are combined.",
    )
    # parser.add_argument("--gnn_lr_scale", type=float, default=1)
    parser.add_argument("--aggr", type=str, default="add")

    # for ProberTaskMLP
    parser.add_argument("--mlp_dim_hidden", type=int, default=600)
    parser.add_argument("--mlp_dim_out", type=int, default=1)
    parser.add_argument("--mlp_num_layers", type=int, default=2)
    parser.add_argument("--mlp_batch_norm", type=str2bool, default=False)
    parser.add_argument("--mlp_initializer", type=str, default="xavier")
    parser.add_argument("--mlp_dropout", type=float, default=0.0)
    parser.add_argument(
        "--mlp_activation",
        type=str,
        default="relu",
        choices=["leaky_relu", "rrelu", "relu", "elu", "gelu", "prelu", "selu"],
    )
    parser.add_argument("--mlp_leaky_relu", type=float, default=0.5)

    # for GraphCL
    parser.add_argument("--aug_mode", type=str, default="sample")
    parser.add_argument("--aug_strength", type=float, default=0.2)
    parser.add_argument("--aug_prob", type=float, default=0.1)

    # for AttributeMask
    parser.add_argument("--mask_rate", type=float, default=0.15)
    parser.add_argument("--mask_edge", type=int, default=0)
    parser.add_argument("--num_atom_type", type=int, default=119)
    parser.add_argument("--num_edge_type", type=int, default=5)

    # PreTrainer: G-Cont, will automatically adjust based on pre-training data
    parser.add_argument("--atom_vocab_size", type=int, default=1)

    # for ContextPred
    parser.add_argument("--csize", type=int, default=3)
    parser.add_argument("--contextpred_neg_samples", type=int, default=1)

    # for JOAO
    parser.add_argument("--gamma_joao", type=float, default=0.1)
    parser.add_argument("--gamma_joaov2", type=float, default=0.1)

    # about if we would print out eval metric for training data
    parser.add_argument("--eval_train", type=str2bool, default=True)
    parser.add_argument("--input_data_dir", type=str, default="")

    args = parser.parse_args()
    return ValidationConfig(**vars(args))
223 |
--------------------------------------------------------------------------------
/src/dataloader.py:
--------------------------------------------------------------------------------
1 | # Ref: snap-stanford/pretrain-gnns/blob/master/bio/dataloader.py
2 | # Ref: snap-stanford/pretrain-gnns/blob/master/chem/dataloader.py
3 | from util import MaskAtom
4 | from torch.utils.data import DataLoader
5 | from batch import BatchAE, BatchMasking, BatchSubstructContext
6 |
7 |
class DataLoaderSubstructContext(DataLoader):
    """Data loader that merges data objects from a
    :class:`torch_geometric.data.dataset` into a mini-batch using
    :class:`BatchSubstructContext` collation.

    Args:
        dataset (Dataset): The dataset from which to load the data.
        batch_size (int, optional): How many samples per batch to load.
            (default: :obj:`1`)
        shuffle (bool, optional): If set to :obj:`True`, the data will be
            reshuffled at every epoch (default: :obj:`True`)"""

    def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs):
        # Collate a list of Data objects into one BatchSubstructContext.
        def _collate(data_list):
            return BatchSubstructContext.from_data_list(data_list)

        super().__init__(
            dataset, batch_size, shuffle, collate_fn=_collate, **kwargs
        )
26 |
27 |
class DataLoaderMasking(DataLoader):
    """Data loader that merges data objects from a
    :class:`torch_geometric.data.dataset` into a mini-batch using
    :class:`BatchMasking` collation.

    Args:
        dataset (Dataset): The dataset from which to load the data.
        batch_size (int, optional): How many samples per batch to load.
            (default: :obj:`1`)
        shuffle (bool, optional): If set to :obj:`True`, the data will be
            reshuffled at every epoch (default: :obj:`True`)"""

    def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs):
        # Collate a list of Data objects into one BatchMasking.
        def _collate(data_list):
            return BatchMasking.from_data_list(data_list)

        super().__init__(
            dataset, batch_size, shuffle, collate_fn=_collate, **kwargs
        )
46 |
47 |
class DataLoaderAE(DataLoader):
    """Data loader that merges data objects from a
    :class:`torch_geometric.data.dataset` into a mini-batch using
    :class:`BatchAE` collation.

    Args:
        dataset (Dataset): The dataset from which to load the data.
        batch_size (int, optional): How many samples per batch to load.
            (default: :obj:`1`)
        shuffle (bool, optional): If set to :obj:`True`, the data will be
            reshuffled at every epoch (default: :obj:`True`)"""

    def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs):
        # Collate a list of Data objects into one BatchAE.
        def _collate(data_list):
            return BatchAE.from_data_list(data_list)

        super().__init__(
            dataset, batch_size, shuffle, collate_fn=_collate, **kwargs
        )
66 |
67 |
class DataLoaderMaskingPred(DataLoader):
    r"""Data loader which merges data objects from a
    :class:`torch_geometric.data.dataset` to a mini-batch, applying a
    :class:`MaskAtom` transform to every sample at collate time.

    Args:
        dataset (Dataset): The dataset from which to load the data.
        batch_size (int, optional): How many samples per batch to load.
            (default: :obj:`1`)
        shuffle (bool, optional): If set to :obj:`True`, the data will be
            reshuffled at every epoch (default: :obj:`True`)
        mask_rate (float, optional): Fraction of atoms to mask.
            (default: :obj:`0.25`)
        mask_edge (float, optional): Whether/how much to mask edges, forwarded
            to :class:`MaskAtom`. (default: :obj:`0.0`)
        num_atom_type (int, optional): Atom-type vocabulary size, forwarded to
            :class:`MaskAtom`. (default: :obj:`119`)
        num_edge_type (int, optional): Edge-type vocabulary size, forwarded to
            :class:`MaskAtom`. (default: :obj:`5`)
    """

    def __init__(
        self,
        dataset,
        batch_size=1,
        shuffle=True,
        mask_rate=0.25,
        mask_edge=0.0,
        num_atom_type=119,
        num_edge_type=5,
        **kwargs
    ):
        # Generalized: vocabulary sizes were hard-coded (119 / 5); they are now
        # parameters with the same defaults, so existing callers are unchanged.
        self._transform = MaskAtom(
            num_atom_type=num_atom_type,
            num_edge_type=num_edge_type,
            mask_rate=mask_rate,
            mask_edge=mask_edge,
        )
        super(DataLoaderMaskingPred, self).__init__(
            dataset, batch_size, shuffle, collate_fn=self.collate_fn, **kwargs
        )

    def collate_fn(self, batches):
        """Mask each sample, then merge the masked samples into one batch."""
        masked = [self._transform(x) for x in batches]
        return BatchMasking.from_data_list(masked)
98 |
--------------------------------------------------------------------------------
/src/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .molecule_contextual import MoleculeDataset_Contextual
2 | from .molecule_datasets import MoleculeDataset
3 | from .molecule_gpt_gnn import MoleculeDataset_GPTGNN
4 | from .molecule_graphcl import MoleculeDataset_GraphCL
5 | from .molecule_motif import RDKIT_PROPS, MoleculeDataset_Motif
6 | from .molecule_rgcl import MoleculeDataset_RGCL
7 | from .molecule_graphmvp import Molecule3DDataset, Molecule3DMaskingDataset
8 | from .utils import (
9 | allowable_features,
10 | graph_data_obj_to_mol_simple,
11 | graph_data_obj_to_nx_simple,
12 | nx_to_graph_data_obj_simple,
13 | )
14 |
--------------------------------------------------------------------------------
/src/datasets/molecule_gpt_gnn.py:
--------------------------------------------------------------------------------
1 | import torch, random
2 | from tqdm import tqdm
3 | from torch_geometric.utils import subgraph
4 | from torch_geometric.data import Data, InMemoryDataset
5 |
6 |
def search_graph(graph):
    """Return a random BFS-style visiting order covering every node.

    The traversal starts from a random node; whenever a connected component
    is exhausted, it restarts from a random unvisited node, so the returned
    list is always a permutation of ``range(num_node)``.

    :param graph: a graph object with ``x`` (node features) and
        ``edge_index`` (2 x E tensor) attributes.
    :return: list of node indices in visiting order.
    """
    num_node, edge_set = len(graph.x), set()
    u_list, v_list = graph.edge_index[0].numpy(), graph.edge_index[1].numpy()
    for u, v in zip(u_list, v_list):
        edge_set.add((u, v))
        edge_set.add((v, u))

    visited_list = list()
    visited_set = set()  # O(1) membership; visited_list keeps the order
    unvisited_set = set([i for i in range(num_node)])

    while len(unvisited_set) > 0:
        # Fix: random.sample() on a set raises TypeError on Python >= 3.11
        # (set populations were deprecated in 3.9 and removed in 3.11);
        # convert to a sequence before drawing the restart node.
        start = random.choice(tuple(unvisited_set))
        queue = [start]
        head = 0  # advance an index instead of O(n) list.pop(0)
        while head < len(queue):
            u = queue[head]
            head += 1
            if u in visited_set:
                continue
            visited_set.add(u)
            visited_list.append(u)
            unvisited_set.remove(u)

            for v in range(num_node):
                if (v not in visited_set) and ((u, v) in edge_set):
                    queue.append(v)
    assert len(visited_list) == num_node
    return visited_list
32 |
33 |
class MoleculeDataset_GPTGNN(InMemoryDataset):
    """Wraps a molecule dataset into (subgraph -> next atom) samples for GPT-GNN.

    Every molecule's nodes are visited in a traversal order; each growing
    prefix of that order becomes one sample whose target (``next_x``) is the
    atom type of the next visited node.
    """

    def __init__(self, molecule_dataset, transform=None, pre_transform=None):
        self.molecule_dataset = molecule_dataset
        # processed data lives next to the wrapped dataset, suffixed "_GPT"
        self.root = molecule_dataset.root + "_GPT"
        super(MoleculeDataset_GPTGNN, self).__init__(
            self.root, transform=transform, pre_transform=pre_transform
        )
        self.data, self.slices = torch.load(self.processed_paths[0])

    def process(self):
        data_list = []
        for mol_idx in tqdm(range(len(self.molecule_dataset))):
            graph = self.molecule_dataset.get(mol_idx)
            n_nodes = len(graph.x)

            # TODO: will replace this with DFS/BFS searching
            visit_order = search_graph(graph)

            for prefix_len in range(1, n_nodes):
                # nodes visited so far -> predict the atom type of the next one
                prefix_nodes = visit_order[:prefix_len]
                target_node = visit_order[prefix_len]

                edge_index, edge_attr = subgraph(
                    subset=prefix_nodes,
                    edge_index=graph.edge_index,
                    edge_attr=graph.edge_attr,
                    relabel_nodes=True,
                    num_nodes=n_nodes,
                )

                data_list.append(
                    Data(
                        x=graph.x[prefix_nodes],
                        edge_index=edge_index,
                        edge_attr=edge_attr,
                        next_x=graph.x[target_node, :1],
                    )
                )

        print("len of data\t", len(data_list))
        data, slices = self.collate(data_list)
        print("Saving...")
        torch.save((data, slices), self.processed_paths[0])

    @property
    def processed_file_names(self):
        return "geometric_data_processed.pt"
86 |
--------------------------------------------------------------------------------
/src/datasets/molecule_graphcl.py:
--------------------------------------------------------------------------------
1 | import torch, numpy as np
2 | from itertools import repeat
3 | from torch_geometric.data import Data
4 | from torch_geometric.utils import subgraph, to_networkx
5 |
6 | from .molecule_datasets import MoleculeDataset
7 |
8 |
class MoleculeDataset_GraphCL(MoleculeDataset):
    # used in GraphCL, JOAOv1, JOAOv2
    """Molecule dataset yielding two augmented views of each graph.

    ``get(idx)`` returns ``(data, data1, data2)``: the original 2D graph plus
    two copies transformed by augmentations chosen according to ``aug_mode``.
    Augmentation indices 0-4 map to: node_drop, subgraph, edge_pert,
    attr_mask, identity.
    """

    def __init__(
        self,
        root,
        transform=None,
        pre_transform=None,
        pre_filter=None,
        dataset=None,
        empty=False,
    ):
        # Augmentation state; configured later via the set_aug* setters.
        self.aug_prob = None
        self.aug_mode = "no_aug"
        self.aug_strength = 0.2
        # Index order matters: aug_prob is a distribution over the 5x5 pairs.
        self.augmentations = [
            self.node_drop,
            self.subgraph,
            self.edge_pert,
            self.attr_mask,
            lambda x: x,
        ]
        super(MoleculeDataset_GraphCL, self).__init__(
            root, transform, pre_transform, pre_filter, dataset, empty
        )

    def set_augMode(self, aug_mode):
        # One of "no_aug", "uniform", "sample" (see get()).
        self.aug_mode = aug_mode

    def set_augStrength(self, aug_strength):
        # Fraction of nodes/edges each augmentation affects.
        self.aug_strength = aug_strength

    def set_augProb(self, aug_prob):
        # Length-25 probability vector over augmentation pairs ("sample" mode).
        self.aug_prob = aug_prob

    def node_drop(self, data):
        """Drop a random fraction of nodes; keep the relabeled induced subgraph."""
        node_num, _ = data.x.size()
        _, edge_num = data.edge_index.size()
        drop_num = int(node_num * self.aug_strength)

        idx_perm = np.random.permutation(node_num)
        idx_nodrop = idx_perm[drop_num:].tolist()
        idx_nodrop.sort()

        edge_idx, edge_attr = subgraph(
            subset=idx_nodrop,
            edge_index=data.edge_index,
            edge_attr=data.edge_attr,
            relabel_nodes=True,
            num_nodes=node_num,
        )

        data.edge_index = edge_idx
        data.edge_attr = edge_attr
        data.x = data.x[idx_nodrop]
        data.__num_nodes__, _ = data.x.shape
        return data

    def edge_pert(self, data):
        """Delete a random fraction of edges and add the same number of new ones."""
        node_num, _ = data.x.size()
        _, edge_num = data.edge_index.size()
        pert_num = int(edge_num * self.aug_strength)

        # del edges
        idx_drop = np.random.choice(edge_num, (edge_num - pert_num), replace=False)
        edge_index = data.edge_index[:, idx_drop]
        edge_attr = data.edge_attr[idx_drop]

        # add edges
        # NOTE(review): the mask keeps the diagonal at 1, so self-loops can be
        # sampled as "non-existent" edges — confirm this is intended.
        adj = torch.ones((node_num, node_num))
        adj[edge_index[0], edge_index[1]] = 0
        edge_index_nonexist = torch.nonzero(adj, as_tuple=False).t()
        idx_add = np.random.choice(
            edge_index_nonexist.shape[1], pert_num, replace=False
        )
        edge_index_add = edge_index_nonexist[:, idx_add]

        # random 4-class & 3-class edge_attr for 1st & 2nd dimension
        edge_attr_add_1 = torch.tensor(
            np.random.randint(4, size=(edge_index_add.shape[1], 1))
        )
        edge_attr_add_2 = torch.tensor(
            np.random.randint(3, size=(edge_index_add.shape[1], 1))
        )
        edge_attr_add = torch.cat((edge_attr_add_1, edge_attr_add_2), dim=1)
        edge_index = torch.cat((edge_index, edge_index_add), dim=1)
        edge_attr = torch.cat((edge_attr, edge_attr_add), dim=0)

        data.edge_index = edge_index
        data.edge_attr = edge_attr
        return data

    def attr_mask(self, data):
        """Overwrite a random fraction of node features with the mean token."""
        _x = data.x.clone()
        node_num, _ = data.x.size()
        mask_num = int(node_num * self.aug_strength)

        # "mean" atom features, rounded back to integer token space
        token = data.x.float().mean(dim=0).long()
        idx_mask = np.random.choice(node_num, mask_num, replace=False)

        _x[idx_mask] = token
        data.x = _x
        return data

    def subgraph(self, data):
        """Keep a random BFS-grown subgraph of ~(1 - aug_strength) of the nodes."""
        G = to_networkx(data)
        node_num, _ = data.x.size()
        _, edge_num = data.edge_index.size()
        sub_num = int(node_num * (1 - self.aug_strength))

        idx_sub = [np.random.randint(node_num, size=1)[0]]
        idx_neigh = set([n for n in G.neighbors(idx_sub[-1])])

        while len(idx_sub) <= sub_num:
            if len(idx_neigh) == 0:
                # frontier exhausted: restart from a random unselected node
                idx_unsub = list(
                    set([n for n in range(node_num)]).difference(set(idx_sub))
                )
                idx_neigh = set([np.random.choice(idx_unsub)])
            sample_node = np.random.choice(list(idx_neigh))

            idx_sub.append(sample_node)
            idx_neigh = idx_neigh.union(
                set([n for n in G.neighbors(idx_sub[-1])])
            ).difference(set(idx_sub))

        idx_nondrop = idx_sub
        idx_nondrop.sort()

        # bare name resolves to torch_geometric.utils.subgraph (module global),
        # not this method
        edge_idx, edge_attr = subgraph(
            subset=idx_nondrop,
            edge_index=data.edge_index,
            edge_attr=data.edge_attr,
            relabel_nodes=True,
            num_nodes=node_num,
        )

        data.edge_index = edge_idx
        data.edge_attr = edge_attr
        data.x = data.x[idx_nondrop]
        data.__num_nodes__, _ = data.x.shape
        return data

    def get(self, idx):
        """Return (original, view1, view2) for item ``idx``.

        Only the 2D keys (x, edge_index, edge_attr) are copied into the two
        augmented views; all other keys stay on the original ``data`` only.
        """
        data, data1, data2 = Data(), Data(), Data()
        keys_for_2D = ["x", "edge_index", "edge_attr"]
        for key in self.data.keys:
            item, slices = self.data[key], self.slices[key]
            s = list(repeat(slice(None), item.dim()))
            s[data.__cat_dim__(key, item)] = slice(slices[idx], slices[idx + 1])
            if key in keys_for_2D:
                data[key], data1[key], data2[key] = item[s], item[s], item[s]
            else:
                data[key] = item[s]

        if self.aug_mode == "no_aug":
            # identity augmentation on both views
            n_aug1, n_aug2 = 4, 4
            data1 = self.augmentations[n_aug1](data1)
            data2 = self.augmentations[n_aug2](data2)
        elif self.aug_mode == "uniform":
            # uniform over the 25 augmentation pairs
            n_aug = np.random.choice(25, 1)[0]
            n_aug1, n_aug2 = n_aug // 5, n_aug % 5
            data1 = self.augmentations[n_aug1](data1)
            data2 = self.augmentations[n_aug2](data2)
        elif self.aug_mode == "sample":
            # draw a pair according to aug_prob (JOAO-learned distribution)
            n_aug = np.random.choice(25, 1, p=self.aug_prob)[0]
            n_aug1, n_aug2 = n_aug // 5, n_aug % 5
            data1 = self.augmentations[n_aug1](data1)
            data2 = self.augmentations[n_aug2](data2)
        else:
            raise NotImplementedError("aug_mode not implemented")
        return data, data1, data2
180 |
--------------------------------------------------------------------------------
/src/datasets/molecule_graphmvp.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | from os.path import join
5 | from itertools import repeat
6 | from torch_geometric.loader import DataLoader
7 | from torch_geometric.data import Data, InMemoryDataset
8 | from torch_geometric.utils import subgraph, to_networkx
9 |
10 |
class Molecule3DDataset(InMemoryDataset):
    """In-memory dataset of pre-processed 3D molecule graphs.

    Expects the collated tensors to already exist at
    ``processed/geometric_data_processed.pt``; ``process`` is a no-op.
    """

    def __init__(
        self,
        root,
        dataset,
        transform=None,
        pre_transform=None,
        pre_filter=None,
        empty=False,
    ):
        self.root = root
        self.dataset = dataset

        super(Molecule3DDataset, self).__init__(
            root, transform, pre_transform, pre_filter
        )
        # keep explicit references to the hooks passed in by the caller
        self.transform = transform
        self.pre_transform = pre_transform
        self.pre_filter = pre_filter

        if not empty:
            self.data, self.slices = torch.load(self.processed_paths[0])
            print("Dataset: {}\nData: {}".format(self.dataset, self.data))

    def get(self, idx):
        """Rebuild the ``idx``-th Data object by slicing the collated storage."""
        data = Data()
        for key in self.data.keys:
            stored, bounds = self.data[key], self.slices[key]
            selector = list(repeat(slice(None), stored.dim()))
            selector[data.__cat_dim__(key, stored)] = slice(
                bounds[idx], bounds[idx + 1]
            )
            data[key] = stored[selector]
        return data

    @property
    def raw_file_names(self):
        return os.listdir(self.raw_dir)

    @property
    def processed_file_names(self):
        return "geometric_data_processed.pt"

    def process(self):
        # pre-processing happens offline; nothing to do here
        return
56 |
57 |
class Molecule3DMaskingDataset(InMemoryDataset):
    """In-memory 3D molecule dataset that returns randomly masked subgraphs.

    With ``mask_ratio > 0``, ``get`` keeps only a BFS-grown subgraph of
    roughly ``(1 - mask_ratio)`` of the nodes (3D positions included).
    """

    def __init__(
        self,
        root,
        dataset,
        mask_ratio,
        transform=None,
        pre_transform=None,
        pre_filter=None,
        empty=False,
    ):
        self.root = root
        self.dataset = dataset
        self.mask_ratio = mask_ratio

        super(Molecule3DMaskingDataset, self).__init__(
            root, transform, pre_transform, pre_filter
        )
        self.transform, self.pre_transform, self.pre_filter = (
            transform,
            pre_transform,
            pre_filter,
        )

        if not empty:
            self.data, self.slices = torch.load(self.processed_paths[0])
            print("Dataset: {}\nData: {}".format(self.dataset, self.data))

    def subgraph(self, data):
        """Keep a BFS-grown random subgraph of ~(1 - mask_ratio) of the nodes."""
        G = to_networkx(data)
        node_num, _ = data.x.size()
        sub_num = int(node_num * (1 - self.mask_ratio))

        idx_sub = [np.random.randint(node_num, size=1)[0]]
        idx_neigh = set([n for n in G.neighbors(idx_sub[-1])])

        # BFS
        while len(idx_sub) <= sub_num:
            if len(idx_neigh) == 0:
                # frontier exhausted: restart from a random unselected node
                idx_unsub = list(
                    set([n for n in range(node_num)]).difference(set(idx_sub))
                )
                idx_neigh = set([np.random.choice(idx_unsub)])
            sample_node = np.random.choice(list(idx_neigh))

            idx_sub.append(sample_node)
            idx_neigh = idx_neigh.union(
                set([n for n in G.neighbors(idx_sub[-1])])
            ).difference(set(idx_sub))

        idx_nondrop = idx_sub
        idx_nondrop.sort()

        # bare name resolves to torch_geometric.utils.subgraph, not this method
        edge_idx, edge_attr = subgraph(
            subset=idx_nondrop,
            edge_index=data.edge_index,
            edge_attr=data.edge_attr,
            relabel_nodes=True,
            num_nodes=node_num,
        )

        data.edge_index = edge_idx
        data.edge_attr = edge_attr
        data.x = data.x[idx_nondrop]
        data.positions = data.positions[idx_nondrop]
        data.__num_nodes__, _ = data.x.shape
        return data

    def get(self, idx):
        """Rebuild item ``idx`` from the collated storage, then optionally mask."""
        data = Data()
        for key in self.data.keys:
            item, slices = self.data[key], self.slices[key]
            s = list(repeat(slice(None), item.dim()))
            s[data.__cat_dim__(key, item)] = slice(slices[idx], slices[idx + 1])
            data[key] = item[s]

        if self.mask_ratio > 0:
            data = self.subgraph(data)
        return data

    @property
    def raw_file_names(self):
        return os.listdir(self.raw_dir)

    @property
    def processed_file_names(self):
        return "geometric_data_processed.pt"

    def process(self):
        # data is pre-processed offline; nothing to do here
        return
148 |
149 |
if __name__ == "__main__":
    # Smoke test: load a GEOM 3D dataset, wrap it with masking, and iterate
    # a few mini-batches while printing tensor shapes.
    name = "geom3d_nmol500000_nconf1"
    root = join("~/GraphEval_dev/data/GEOM", name)
    base_dataset = Molecule3DDataset(root, dataset=name)
    dataset = Molecule3DMaskingDataset(root, dataset=base_dataset, mask_ratio=0.1)

    loader = DataLoader(dataset[:64], batch_size=16, shuffle=True, num_workers=4)
    for batch in loader:
        print(
            "\n",
            "batch.x.shape: ",
            batch.x.shape,
            "\n",
            "batch.edge_index.shape: ",
            batch.edge_index.shape,
            "\n",
            "batch.edge_attr.shape: ",
            batch.edge_attr.shape,
            "\n",
            "batch.positions.shape: ",
            batch.positions.shape,
            "\n",
            "batch.batch.shape: ",
            batch.batch.shape,
            "\n",
            "=" * 7,
        )
--------------------------------------------------------------------------------
/src/datasets/molecule_motif.py:
--------------------------------------------------------------------------------
1 | import os, torch, numpy as np
2 | from tqdm import tqdm
3 | from itertools import repeat
4 | from descriptastorus.descriptors import rdDescriptors
5 | from torch_geometric.data import Data, InMemoryDataset
6 |
7 |
# Names of the RDKit fragment-count ("fr_*") descriptors used as binary
# functional-group (motif) labels; this list is passed to descriptastorus'
# RDKit2D generator in rdkit_functional_group_label_features_generator below.
RDKIT_PROPS = [
    "fr_Al_COO",
    "fr_Al_OH",
    "fr_Al_OH_noTert",
    "fr_ArN",
    "fr_Ar_COO",
    "fr_Ar_N",
    "fr_Ar_NH",
    "fr_Ar_OH",
    "fr_COO",
    "fr_COO2",
    "fr_C_O",
    "fr_C_O_noCOO",
    "fr_C_S",
    "fr_HOCCN",
    "fr_Imine",
    "fr_NH0",
    "fr_NH1",
    "fr_NH2",
    "fr_N_O",
    "fr_Ndealkylation1",
    "fr_Ndealkylation2",
    "fr_Nhpyrrole",
    "fr_SH",
    "fr_aldehyde",
    "fr_alkyl_carbamate",
    "fr_alkyl_halide",
    "fr_allylic_oxid",
    "fr_amide",
    "fr_amidine",
    "fr_aniline",
    "fr_aryl_methyl",
    "fr_azide",
    "fr_azo",
    "fr_barbitur",
    "fr_benzene",
    "fr_benzodiazepine",
    "fr_bicyclic",
    "fr_diazo",
    "fr_dihydropyridine",
    "fr_epoxide",
    "fr_ester",
    "fr_ether",
    "fr_furan",
    "fr_guanido",
    "fr_halogen",
    "fr_hdrzine",
    "fr_hdrzone",
    "fr_imidazole",
    "fr_imide",
    "fr_isocyan",
    "fr_isothiocyan",
    "fr_ketone",
    "fr_ketone_Topliss",
    "fr_lactam",
    "fr_lactone",
    "fr_methoxy",
    "fr_morpholine",
    "fr_nitrile",
    "fr_nitro",
    "fr_nitro_arom",
    "fr_nitro_arom_nonortho",
    "fr_nitroso",
    "fr_oxazole",
    "fr_oxime",
    "fr_para_hydroxylation",
    "fr_phenol",
    "fr_phenol_noOrthoHbond",
    "fr_phos_acid",
    "fr_phos_ester",
    "fr_piperdine",
    "fr_piperzine",
    "fr_priamide",
    "fr_prisulfonamd",
    "fr_pyridine",
    "fr_quatN",
    "fr_sulfide",
    "fr_sulfonamd",
    "fr_sulfone",
    "fr_term_acetylene",
    "fr_tetrazole",
    "fr_thiazole",
    "fr_thiocyan",
    "fr_thiophene",
    "fr_unbrch_alkane",
    "fr_urea",
]
95 |
96 |
def rdkit_functional_group_label_features_generator(smiles):
    """Compute binary functional-group labels for a molecule using RDKit.

    :param smiles: A molecule (i.e. either a SMILES string or an RDKit obj).
    :return: A 1D numpy array with one 0/1 flag per entry in RDKIT_PROPS.
    """
    # smiles = Chem.MolToSmiles(mol, isomericSmiles=True)
    # if type(mol) != str else mol
    generator = rdDescriptors.RDKit2D(RDKIT_PROPS)
    # process() returns [success_flag, feature_0, ...]; drop the flag,
    # then binarize every non-zero fragment count to 1
    labels = np.array(generator.process(smiles)[1:])
    labels[labels != 0] = 1
    return labels
109 |
110 |
class MoleculeDataset_Motif(InMemoryDataset):
    """Molecule dataset whose labels are RDKit functional-group (motif) flags.

    Motif labels are computed once from ``processed/smiles.csv`` and cached
    in ``processed/motif.pt``; ``get`` overwrites ``data.y`` with them.
    """

    def __init__(
        self,
        root,
        dataset,
        transform=None,
        pre_transform=None,
        pre_filter=None,
        empty=False,
    ):
        self.dataset = dataset
        self.root = root

        super(MoleculeDataset_Motif, self).__init__(
            root, transform, pre_transform, pre_filter
        )
        self.transform, self.pre_transform, self.pre_filter = (
            transform,
            pre_transform,
            pre_filter,
        )

        if not empty:
            self.data, self.slices = torch.load(self.processed_paths[0])

        # build (or load the cached) motif label matrix
        self.motif_file = os.path.join(root, "processed", "motif.pt")
        self.process_motif_file()
        self.motif_label_list = torch.load(self.motif_file)

        print(
            "Dataset: {}\nData: {}\nMotif: {}".format(
                self.dataset, self.data, self.motif_label_list.size()
            )
        )

    def process_motif_file(self):
        """Create the motif.pt cache from processed/smiles.csv if missing."""
        if not os.path.exists(self.motif_file):
            smiles_file = os.path.join(self.root, "processed", "smiles.csv")
            data_smiles_list = []
            with open(smiles_file, "r") as f:
                lines = f.readlines()
                for smiles in lines:
                    data_smiles_list.append(smiles.strip())

            # one binary label vector (len(RDKIT_PROPS)) per molecule
            motif_label_list = []
            for smiles in tqdm(data_smiles_list):
                label = rdkit_functional_group_label_features_generator(smiles)
                motif_label_list.append(label)

            self.motif_label_list = torch.LongTensor(motif_label_list)
            torch.save(self.motif_label_list, self.motif_file)
        return

    def get(self, idx):
        """Rebuild item ``idx`` and attach its motif label vector as ``y``."""
        data = Data()
        for key in self.data.keys:
            item, slices = self.data[key], self.slices[key]
            s = list(repeat(slice(None), item.dim()))
            s[data.__cat_dim__(key, item)] = slice(slices[idx], slices[idx + 1])
            data[key] = item[s]
        data.y = self.motif_label_list[idx]
        return data

    @property
    def raw_file_names(self):
        return os.listdir(self.raw_dir)

    @property
    def processed_file_names(self):
        return "geometric_data_processed.pt"

    def download(self):
        return

    def process(self):
        return
187 |
--------------------------------------------------------------------------------
/src/load_save.py:
--------------------------------------------------------------------------------
1 | import torch, pickle
2 | from tqdm import tqdm
3 | from typing import List
4 | from pathlib import Path
5 | from torch_geometric.data.dataset import Dataset
6 | from config.training_config import TrainingConfig
7 | from config.validation_config import ValidationConfig
8 | from models.pre_trainer_model import PreTrainerModel
9 | from torch_geometric.loader.dataloader import DataLoader
10 |
11 |
def infer_and_save_embeddings(
    config: ValidationConfig,
    model: PreTrainerModel,
    device: torch.device,
    datasets: List[Dataset],
    loaders: List[DataLoader],
    smile_splits: list,
    save: bool = True,
) -> None:
    """Extract and optionally pickle graph- and node-level embeddings.

    :param config: configurations
    :param model: pretrained or randomly-initialized model.
    :param device: device in use.
    :param datasets: train, valid and test data.
    :param loaders: train, valid and test data loaders.
    :param smile_splits: train, valid and test smiles.
    :param save: whether to save the pickle file.
    :return: None"""
    model.eval()
    for dataset, loader, smiles, split in zip(
        datasets, loaders, smile_splits, ["train", "valid", "test"]
    ):
        pbar = tqdm(total=len(loader))
        pbar.set_description(f"{split} embeddings extracted: ")
        graph_embeddings_list, node_embeddings_list = [], []
        for batch in loader:
            # if config.pretrainer == 'GraphCL' and isinstance(batch, list):
            #     batch = batch[0]  # remove the contrastive augmented data.
            batch = batch.to(device)
            with torch.no_grad():
                node_embeddings, graph_embeddings = model.get_embeddings(
                    batch.x, batch.edge_index, batch.edge_attr, batch.batch
                )
            # Regroup the flat node-embedding tensor per source graph, using
            # batch.batch (node -> graph id assignment) as the index.
            unbatched_node_embedding = [[] for _ in range(batch.batch.max() + 1)]
            for embedding, graph_id in zip(node_embeddings, batch.batch):
                unbatched_node_embedding[graph_id].append(
                    embedding.detach().cpu().numpy()
                )
            graph_embeddings_list.append(graph_embeddings)
            node_embeddings_list.extend(unbatched_node_embedding)
            pbar.update(1)
        # One (num_graphs, dim) numpy array per split.
        graph_embeddings_list = (
            torch.cat(graph_embeddings_list, dim=0).detach().cpu().numpy()
        )

        if save:
            save_embeddings(
                config=config,
                dataset=dataset,
                graph_embeddings=graph_embeddings_list,
                node_embeddings=node_embeddings_list,
                smiles=smiles,
                split=split,
            )

        pbar.close()
    # return dataset, graph_embeddings, node_embeddings, smiles
69 |
70 |
def save_embeddings(
    config: "ValidationConfig",
    dataset: "Dataset",
    graph_embeddings: List[torch.Tensor],
    node_embeddings: List[List[torch.Tensor]],
    smiles: List[str],
    split: str,
):
    """Pickle (graph embeddings, node embeddings, smiles) for one data split."""
    out_dir = Path(config.embedding_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    payload = [graph_embeddings, node_embeddings, smiles]
    with open(f"{config.embedding_dir}{config.dataset}_{split}.pkl", "wb") as fh:
        pickle.dump(payload, fh)
82 |
83 |
def save_model(config: "TrainingConfig", model: torch.nn.Module, epoch: int) -> None:
    """Persist the model state dict under a config-derived directory tree."""
    tag = pretrain_config(config)

    # TODO: Update this for G-Contextual,
    out_dir = f"{config.output_model_dir}/{config.pretrainer}/{config.dataset}/{tag}"
    Path(out_dir).mkdir(parents=True, exist_ok=True)
    checkpoint = {"model": model.state_dict()}
    torch.save(checkpoint, f"{out_dir}/epoch{epoch}_model_complete.pth")
92 |
93 |
def pretrain_config(config: "TrainingConfig") -> str:
    """Build the hyper-parameter tag used in checkpoint directory names.

    Fix: the return annotation previously claimed ``None`` although the
    function always returns a string.

    :param config: training configuration carrying the pretrainer name and
        its method-specific hyper-parameters.
    :return: tag such as ``"mask_rate-0.15_seed-42_lr-0.001"``; methods with
        no extra hyper-parameters yield just the seed/lr suffix.
    """
    cfg = ""
    if config.pretrainer == "AM":
        cfg = f"mask_rate-{config.mask_rate}"
    # TODO: run these experiments
    # elif config.pretrainer == "CP":
    #     cfg = f"acsize-{config.csize}_atom_vocab_size-{config.atom_vocab_size}_contextpred_neg_samples-{config.contextpred_neg_samples}"
    elif config.pretrainer == "GraphCL":
        cfg = f"aug_mode-{config.aug_mode}_aug_strength-{config.aug_strength}_aug_prob-{config.aug_prob}"
    elif config.pretrainer in ["JOAO", "JOAOv2"]:
        cfg = f"gamma_joao-{config.gamma_joao}_gamma_joaov2-{config.gamma_joaov2}"
    # TODO: run these experiments
    elif config.pretrainer == "GraphMVP":
        cfg = f"alpha2-{config.GMVP_alpha2}_temper-{config.GMVP_T}"
    # seed and learning rate are always appended
    cfg = f"{cfg}_seed-{config.seed}_lr-{config.lr}"
    return cfg
110 |
111 |
def load_checkpoint(config: "ValidationConfig", device: torch.device) -> dict:
    """Load a saved checkpoint dict, mapping tensors onto *device*."""
    checkpoint_path = config.input_model_file
    print(f"\nLoad checkpoint from path {checkpoint_path}...")
    return torch.load(checkpoint_path, map_location=device)
115 |
--------------------------------------------------------------------------------
/src/logger.py:
--------------------------------------------------------------------------------
1 | # Ref: https://raw.githubusercontent.com/davda54/sam/main/example/utility/log.py
2 |
3 | import time, wandb, logging, dataclasses
4 | from pathlib import Path
5 | from datetime import datetime, timedelta
6 | from config.training_config import TrainingConfig
7 |
8 | TIME_STR = "{:%Y_%m_%d_%H_%M_%S_%f}".format(datetime.now())
9 |
10 |
@dataclasses.dataclass
class EpochState:
    """Running loss/accuracy/sample totals accumulated over one epoch."""

    loss: float = 0.0
    accuracy: float = 0.0
    samples: int = 0

    def reset(self) -> None:
        """Zero all accumulators."""
        self.loss = 0.0
        self.accuracy = 0.0
        self.samples = 0

    def add_to_loss(self, loss: float) -> None:
        self.loss = self.loss + loss

    def add_to_accuracy(self, accuracy: float) -> None:
        self.accuracy = self.accuracy + accuracy

    def add_to_samples(self, samples: int) -> None:
        self.samples = self.samples + samples
28 |
29 |
class LoadingBar:
    """Renders a unicode progress bar with quarter-block resolution."""

    def __init__(self, length: int = 40) -> None:
        self.length = length
        # partial-cell glyphs, in increasing fill order
        self.symbols = ["┈", "░", "▒", "▓"]

    def __call__(self, progress: float) -> str:
        # total quarter-cells filled, rounded to nearest
        quarters = int(progress * self.length * 4 + 0.5)
        full, part = divmod(quarters, 4)
        if quarters < self.length * 4:
            tail = self.symbols[part] + max(0, self.length - 1 - full) * "┈"
        else:
            tail = ""
        return "┠┈" + full * "█" + tail + "┈┨"
48 |
49 |
class LogFormatter:
    """Formats records as 'LEVEL - wallclock - elapsed - message', indenting
    continuation lines of multi-line messages to align under the prefix."""

    def __init__(self):
        self.start_time = time.time()

    def format(self, record):
        elapsed_seconds = round(record.created - self.start_time)

        prefix = "%s - %s - %s" % (
            record.levelname,
            time.strftime("%x %X"),
            timedelta(seconds=elapsed_seconds),
        )
        indent = " " * (len(prefix) + 3)
        body = record.getMessage().replace("\n", "\n" + indent)
        return "%s - %s" % (prefix, body)
65 |
66 |
def create_info_logger(filepath) -> logging.Logger:
    """Configure the root logger to append DEBUG-level records to *filepath*.

    Returns the root logger, augmented with a ``reset_time`` callable that
    restarts the elapsed-time origin used by the formatter.
    """
    Path("log").mkdir(parents=True, exist_ok=True)
    formatter = LogFormatter()

    # File output only; console output deliberately disabled.
    handler = logging.FileHandler(filepath, "a")
    handler.setLevel(logging.DEBUG)
    handler.setFormatter(formatter)

    # TODO: If we want to print logs into the console, add a StreamHandler
    # here (setLevel(logging.INFO), same formatter) and register it below.

    logger = logging.getLogger()
    logger.handlers = []
    logger.setLevel(logging.DEBUG)
    logger.propagate = False
    logger.addHandler(handler)

    def reset_time():
        formatter.start_time = time.time()

    logger.reset_time = reset_time

    return logger
103 |
104 |
@dataclasses.dataclass
class CombinedLogger:
    """Combined console / file / wandb logger for train+eval loops.

    Fix: the ``EpochState()`` and ``LoadingBar(...)`` defaults were shared
    mutable instances — every CombinedLogger aliased the same accumulators,
    and on Python 3.11+ an unhashable dataclass default raises ValueError at
    class-creation time. They are now built per instance via
    ``default_factory``.
    """

    config: TrainingConfig
    is_train: bool = True
    # Per-instance accumulators (default_factory avoids shared mutable defaults).
    epoch_state: EpochState = dataclasses.field(default_factory=EpochState)
    last_steps_state: EpochState = dataclasses.field(default_factory=EpochState)
    loading_bar: LoadingBar = dataclasses.field(
        default_factory=lambda: LoadingBar(length=27)
    )
    best_val_accuracy: float = 0.0
    epoch: int = -1
    step: int = 0
    info_logger: logging.Logger = None

    def __post_init__(self):
        # One log file per run, timestamped at module import time.
        self.info_logger = create_info_logger(
            f"{self.config.log_filepath}-{TIME_STR}.log"
        )

    def log_value_dict(self, value_dict: dict):
        """Log arbitrary key/value metrics to wandb (if enabled) and file."""
        if self.config.log_to_wandb:
            wandb.log(
                {
                    "epoch": self.epoch,
                    **value_dict,
                }
            )
        self.info_logger.info(value_dict)

    def train(self, num_batches: int) -> None:
        """Enter training mode for a new epoch; flushes the previous epoch."""
        self.epoch += 1
        if self.epoch == 0:
            self._print_header()
        else:
            self.flush()
        self.is_train = True
        self.last_steps_state.reset()
        self._reset(num_batches)

    def eval(self, num_batches: int) -> None:
        """Enter evaluation mode; flushes the accumulated training stats."""
        self.flush()
        self.is_train = False
        self._reset(num_batches)

    def __call__(
        self, loss, accuracy, batch_size: int, learning_rate: float = None
    ) -> None:
        """Record one batch's statistics for the current mode."""
        # TODO: in training, we don't have to record the accuracy
        if self.is_train:
            self._train_step(loss, accuracy, batch_size, learning_rate)
        else:
            self._eval_step(loss, accuracy, batch_size)

    def flush(self) -> None:
        """Emit the averaged stats for the finished phase (train or eval).

        Relies on ``self.num_batches`` (set by ``_reset``) and, in training,
        ``self.learning_rate`` (set by ``_train_step``).
        """
        loss = self.epoch_state.loss / self.num_batches
        accuracy = self.epoch_state.accuracy / self.epoch_state.samples
        if self.is_train:
            print(
                f"\r┃{self.epoch:12d}  ┃{loss:12.4f}  │{100 * accuracy:10.2f} %  ┃{self.learning_rate:12.3e}  │{self._time():>12}  ┃",
                end="",
                flush=True,
            )
            train_statistics = {
                "epoch": self.epoch,
                "train_accuracy": accuracy,
                "train_loss": loss,
                "lr": self.learning_rate,
            }
            if self.config.log_to_wandb:
                wandb.log(train_statistics)
            self.info_logger.info(train_statistics)

        else:
            print(f"{loss:12.4f}  │{100 * accuracy:10.2f} %  ┃", flush=True)

            # Track the best validation accuracy seen so far.
            if accuracy > self.best_val_accuracy:
                self.best_val_accuracy = accuracy
            validation_statistics = {
                "epoch": self.epoch,
                "val_accuracy": accuracy,
                "val_loss": loss,
                "best_val_accuracy": self.best_val_accuracy,
            }
            if self.config.log_to_wandb:
                wandb.log(validation_statistics)
            self.info_logger.info(validation_statistics)

    def _train_step(
        self, loss: float, accuracy: float, batch_size: int, learning_rate: float
    ) -> None:
        """Accumulate one training batch and periodically redraw the bar."""
        self.learning_rate = learning_rate
        self.last_steps_state.add_to_loss(loss)
        self.last_steps_state.add_to_accuracy(accuracy)
        self.last_steps_state.add_to_samples(batch_size)
        self.epoch_state.add_to_loss(loss)
        self.epoch_state.add_to_accuracy(accuracy)
        self.epoch_state.add_to_samples(batch_size)
        self.step += 1

        if self.step % self.config.log_interval == self.config.log_interval - 1:
            loss = self.last_steps_state.loss / self.step
            accuracy = self.last_steps_state.accuracy / self.last_steps_state.samples

            self.last_steps_state.reset()
            progress = self.step / self.num_batches
            print(
                f"\r┃{self.epoch:12d}  ┃{loss:12.4f}  │{100 * accuracy:10.2f} %  ┃{learning_rate:12.3e}  │{self._time():>12}  {self.loading_bar(progress)}",
                end="",
                flush=True,
            )

    def _eval_step(self, loss: float, accuracy: float, batch_size: int) -> None:
        """Accumulate one evaluation batch."""
        self.epoch_state.add_to_loss(loss)
        self.epoch_state.add_to_accuracy(accuracy)
        self.epoch_state.add_to_samples(batch_size)

    def _reset(self, num_batches: int) -> None:
        """Restart timers and accumulators for a new phase."""
        self.start_time = time.time()
        self.step = 0
        self.num_batches = num_batches
        self.epoch_state.reset()

    def _time(self) -> str:
        """Elapsed time since the phase started, as 'MM:SS min'."""
        time_seconds = int(time.time() - self.start_time)
        return f"{time_seconds // 60:02d}:{time_seconds % 60:02d} min"

    def _print_header(self) -> None:
        """Draw the fixed table header once, before the first epoch."""
        print(
            "┏━━━━━━━━━━━━━━┳━━━━━━━╸T╺╸R╺╸A╺╸I╺╸N╺━━━━━━━┳━━━━━━━╸S╺╸T╺╸A╺╸T╺╸S╺━━━━━━━┳━━━━━━━╸V╺╸A╺╸L╺╸I╺╸D╺━━━━━━━┓"
        )
        print(
            "┃              ┃              │              ┃              │              ┃              │              ┃"
        )
        print(
            "┃       epoch  ┃        loss  │    accuracy  ┃        l.r.  │     elapsed  ┃        loss  │    accuracy  ┃"
        )
        print(
            "┠──────────────╂──────────────┼──────────────╂──────────────┼──────────────╂──────────────┼──────────────┨"
        )
242 |
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
1 | from models.building_blocks.gnn import GNN, GNN_graphpred
2 | from .attribute_masking import AttributeMaskingModel
3 | from .building_blocks.auto_encoder import (
4 | AutoEncoder,
5 | EnergyVariationalAutoEncoder,
6 | ImportanceWeightedAutoEncoder,
7 | NormalizingFlowVariationalAutoEncoder,
8 | VariationalAutoEncoder,
9 | )
10 | from .context_prediction import ContextPredictionModel
11 | from .contextual import ContextualModel
12 | from .discriminator import Discriminator
13 | from .edge_prediction import EdgePredictionModel
14 | from .building_blocks.flow import (
15 | AffineFlow,
16 | BatchNormFlow,
17 | NormalizingFlow,
18 | PlanarFlow,
19 | PReLUFlow,
20 | RadialFlow,
21 | )
22 | from .gpt_gnn import GPTGNNModel
23 | from .graph_cl import GraphCLModel
24 | from .info_max import InfoMaxModel
25 | from .joao_v2 import JOAOv2Model
26 | from .motif import MotifModel
27 | from .graphpred import GraphPred
28 | from .graphmvp import GraphMVPModel
29 | from .rgcl import RGCLModel
30 | from .graphmae import GraphMAEModel
31 |
--------------------------------------------------------------------------------
/src/models/attribute_masking.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.data import Batch
3 |
4 | from config.training_config import TrainingConfig
5 | from models.building_blocks.gnn import GNN
6 | from models.pre_trainer_model import PreTrainerModel
7 |
8 |
class AttributeMaskingModel(PreTrainerModel):
    """Predicts the identity of masked atoms from their GNN embeddings."""

    def __init__(self, config: TrainingConfig, gnn: GNN):
        super().__init__(config=config)
        self.gnn: torch.nn.Module = gnn
        # linear head over 119 atom-type classes
        self.molecule_atom_masking_model = torch.nn.Linear(config.emb_dim, 119)

    def forward(self, batch: Batch) -> torch.Tensor:
        """Return per-masked-atom logits of shape (num_masked, 119)."""
        hidden = self.gnn(batch.masked_x, batch.edge_index, batch.edge_attr)
        masked_repr = hidden[batch.masked_atom_indices]
        return self.molecule_atom_masking_model(masked_repr)
21 |
--------------------------------------------------------------------------------
/src/models/building_blocks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hansen7/MolGraphEval/2f83950b622d98937500c2efd8091a5af88c8cca/src/models/building_blocks/__init__.py
--------------------------------------------------------------------------------
/src/models/building_blocks/flow.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributions as distrib
3 | import torch.distributions.transforms as transform
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
class Flow(transform.Transform, nn.Module):
    """Base class marrying a torch Transform with trainable nn.Module state."""

    def __init__(self):
        # Both bases need explicit initialization; keep this order.
        transform.Transform.__init__(self)
        nn.Module.__init__(self)

    def init_parameters(self):
        """Small uniform init for every registered parameter."""
        for p in self.parameters():
            p.data.uniform_(-0.01, 0.01)

    def __hash__(self):
        # Transform also defines __hash__; prefer Module's identity hash.
        return nn.Module.__hash__(self)
19 |
20 |
class PlanarFlow(Flow):
    """Planar flow f(z) = z + u * h(w·z + b), with h defaulting to tanh."""

    def __init__(self, dim, h=torch.tanh, hp=(lambda x: 1 - torch.tanh(x) ** 2)):
        super(PlanarFlow, self).__init__()
        self.weight = nn.Parameter(torch.Tensor(1, dim))
        self.scale = nn.Parameter(torch.Tensor(1, dim))
        self.bias = nn.Parameter(torch.Tensor(1))
        self.h = h
        self.hp = hp  # derivative of h
        self.init_parameters()

    def _call(self, z):
        pre_activation = F.linear(z, self.weight, self.bias)
        return z + self.scale * self.h(pre_activation)

    def log_abs_det_jacobian(self, z):
        pre_activation = F.linear(z, self.weight, self.bias)
        psi = self.hp(pre_activation) * self.weight
        det_grad = 1 + torch.mm(psi, self.scale.t())
        # epsilon guards against log(0) when the determinant vanishes
        return torch.log(det_grad.abs() + 1e-9)
40 |
41 |
class RadialFlow(Flow):
    """Radial flow f(z) = z + beta * h(alpha, r) * (z - z0), r = |z - z0|."""

    def __init__(self, dim):
        super(RadialFlow, self).__init__()
        self.z0 = nn.Parameter(torch.Tensor(1, dim))
        self.alpha = nn.Parameter(torch.Tensor(1))
        self.beta = nn.Parameter(torch.Tensor(1))
        self.dim = dim
        self.init_parameters()

    def _call(self, z):
        diff = z - self.z0
        radius = torch.norm(diff, dim=1).unsqueeze(1)
        return z + self.beta * (1 / (self.alpha + radius)) * diff

    def log_abs_det_jacobian(self, z):
        radius = torch.norm(z - self.z0, dim=1).unsqueeze(1)
        h = 1 / (self.alpha + radius)
        h_prime = -1 / (self.alpha + radius) ** 2
        bh = self.beta * h
        det_grad = ((1 + bh) ** self.dim - 1) * (1 + bh + self.beta * h_prime * radius)
        # epsilon guards against log(0)
        return torch.log(det_grad.abs() + 1e-9)
63 |
64 |
class PReLUFlow(Flow):
    """Parametric-ReLU flow; invertible whenever alpha is non-zero."""

    def __init__(self, dim):
        super(PReLUFlow, self).__init__()
        self.alpha = nn.Parameter(torch.Tensor([1]))
        self.bijective = True

    def init_parameters(self):
        """Positive-slope init keeps the map invertible."""
        for p in self.parameters():
            p.data.uniform_(0.01, 0.99)

    def _call(self, z):
        slope = torch.abs(self.alpha)
        return torch.where(z >= 0, z, slope * z)

    def _inverse(self, z):
        inverse_slope = torch.abs(1.0 / self.alpha)
        return torch.where(z >= 0, z, inverse_slope * z)

    def log_abs_det_jacobian(self, z):
        ones = torch.ones_like(z)
        jacobian_diag = torch.where(z >= 0, ones, self.alpha * ones)
        log_abs_det = torch.log(torch.abs(jacobian_diag) + 1e-5)
        return torch.sum(log_abs_det, dim=1)
86 |
87 |
class BatchNormFlow(Flow):
    """Batch-normalization flow with trainable scale (gamma) and shift (beta).

    Fix: ``_inverse`` previously referenced an undefined name ``z`` (its
    parameter is ``x``), so any inverse call raised NameError.

    NOTE(review): ``r_mean``/``r_var`` are plain tensors, not registered
    buffers, so they will not follow ``.to(device)`` or be checkpointed —
    confirm whether that is intended.
    """

    def __init__(self, dim, momentum=0.95, eps=1e-5):
        super(BatchNormFlow, self).__init__()
        # Running batch statistics
        self.r_mean = torch.zeros(dim)
        self.r_var = torch.ones(dim)
        # Momentum
        self.momentum = momentum
        self.eps = eps
        # Trainable scale and shift (cf. original paper)
        self.gamma = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim))

    def _call(self, z):
        """Normalize with batch stats (training) or running stats (eval)."""
        if self.training:
            # Current batch stats
            self.b_mean = z.mean(0)
            self.b_var = (z - self.b_mean).pow(2).mean(0) + self.eps
            # Running mean and var
            self.r_mean = (
                self.momentum * self.r_mean + (1 - self.momentum) * self.b_mean
            )
            self.r_var = self.momentum * self.r_var + (1 - self.momentum) * self.b_var
            mean = self.b_mean
            var = self.b_var
        else:
            mean = self.r_mean
            var = self.r_var
        x_hat = (z - mean) / var.sqrt()
        y = self.gamma * x_hat + self.beta
        return y

    def _inverse(self, x):
        """Undo the affine + normalization applied by ``_call``."""
        if self.training:
            mean = self.b_mean
            var = self.b_var
        else:
            mean = self.r_mean
            var = self.r_var
        # Bug fix: was `(z - self.beta) / self.gamma` with `z` undefined.
        x_hat = (x - self.beta) / self.gamma
        y = x_hat * var.sqrt() + mean
        return y

    def log_abs_det_jacobian(self, z):
        # Here we only need the variance
        mean = z.mean(0)
        var = (z - mean).pow(2).mean(0) + self.eps
        log_det = torch.log(self.gamma) - 0.5 * torch.log(var + self.eps)
        return torch.sum(log_det, -1)
137 |
138 |
class AffineFlow(Flow):
    """Invertible linear map z ↦ zW, initialized with an orthogonal W."""

    def __init__(self, dim):
        super(AffineFlow, self).__init__()
        self.weights = nn.Parameter(torch.Tensor(dim, dim))
        nn.init.orthogonal_(self.weights)

    def _call(self, z):
        return torch.matmul(z, self.weights)

    def _inverse(self, z):
        return torch.matmul(z, torch.inverse(self.weights))

    def log_abs_det_jacobian(self, z):
        # log|det W| is sample-independent; broadcast it across the batch
        log_det = torch.slogdet(self.weights)[-1]
        return log_det.unsqueeze(0).repeat(z.size(0), 1)
153 |
154 |
class NormalizingFlow(nn.Module):
    """Stack of flow blocks repeated ``flow_length`` times over a base density."""

    def __init__(self, dim, blocks, flow_length, density):
        super(NormalizingFlow, self).__init__()
        bijectors = []
        for _ in range(flow_length):
            bijectors.extend(block_cls(dim) for block_cls in blocks)
        self.transforms = transform.ComposeTransform(bijectors)
        self.bijectors = nn.ModuleList(bijectors)
        self.base_density = density
        self.final_density = distrib.TransformedDistribution(density, self.transforms)
        self.log_det = []

    def forward(self, z):
        """Push z through every bijector, collecting each log|det J|."""
        self.log_det = []
        for bijector in self.bijectors:
            # Jacobian is evaluated at the bijector's input
            self.log_det.append(bijector.log_abs_det_jacobian(z))
            z = bijector(z)
        return z, self.log_det
175 |
176 |
if __name__ == "__main__":
    # Demo: train a normalizing flow to match a ring-shaped target density.
    flow_model = "mlp"
    if flow_model == "planar":
        blocks = [PlanarFlow]
    elif flow_model == "radial":
        blocks = [RadialFlow]
    elif flow_model == "affine":
        blocks = [AffineFlow]
    elif flow_model == "mlp":
        # affine + batch-norm + PReLU, repeated per flow step
        blocks = [AffineFlow, BatchNormFlow, PReLUFlow]
    else:
        blocks = None

    flow = NormalizingFlow(
        dim=2,
        blocks=blocks,
        flow_length=8,
        density=distrib.MultivariateNormal(torch.zeros(2), torch.eye(2)),
    )

    import numpy as np
    import torch.optim as optim

    def density_ring(z):
        # Unnormalized ring-shaped density (radius ~4) with two bumps along z1.
        z1, z2 = torch.chunk(z, chunks=2, dim=1)
        norm = torch.sqrt(z1**2 + z2**2)
        exp1 = torch.exp(-0.5 * ((z1 - 2) / 0.8) ** 2)
        exp2 = torch.exp(-0.5 * ((z1 + 2) / 0.8) ** 2)
        u = 0.5 * ((norm - 4) / 0.4) ** 2 - torch.log(exp1 + exp2)
        return torch.exp(-u)

    def loss(density, zk, log_jacobians):
        # Negative log-likelihood via the change-of-variables formula.
        sum_of_log_jacobians = sum(log_jacobians)
        return (-sum_of_log_jacobians - torch.log(density(zk) + 1e-9)).mean()

    optimizer = optim.Adam(flow.parameters(), lr=1e-3)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 0.9999)

    # Evaluation grid over [-5, 5]^2 (unused below; kept for plotting setups).
    x = np.linspace(-5, 5, 1000)
    z = np.array(np.meshgrid(x, x)).transpose(1, 2, 0)
    z = np.reshape(z, [z.shape[0] * z.shape[1], -1])

    ref_distrib = distrib.MultivariateNormal(torch.zeros(2), torch.eye(2))
    for it in range(10001):
        # Draw a sample batch from Normal
        samples = ref_distrib.sample((512,))
        # Evaluate flow of transforms
        zk, log_jacobians = flow(samples)
        # Evaluate loss and backprop
        optimizer.zero_grad()
        loss_v = loss(density_ring, zk, log_jacobians)
        loss_v.backward()
        optimizer.step()
        scheduler.step()
        if it % 1000 == 0:
            print("Loss (it. %i) : %f" % (it, loss_v.item()))
            # Draw random samples
            samples = ref_distrib.sample((int(1e5),))
            # Evaluate flow and plot
            zk, _ = flow(samples)
237 |
--------------------------------------------------------------------------------
/src/models/building_blocks/mlp.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Optional
2 |
3 | import torch.nn as nn
4 | from torch import Tensor
5 |
6 |
def get_activation(name: str, leaky_relu: Optional[float] = 0.5) -> nn.Module:
    """Construct the activation module registered under *name*.

    :param name: one of leaky_relu / rrelu / relu / elu / gelu / prelu / selu.
    :param leaky_relu: negative slope, used only when name == "leaky_relu".
    :raises KeyError: if *name* is not a known activation.
    """
    factories = {
        "leaky_relu": lambda: nn.LeakyReLU(leaky_relu),
        "rrelu": nn.RReLU,
        "relu": nn.ReLU,
        "elu": nn.ELU,
        "gelu": nn.GELU,
        "prelu": nn.PReLU,
        "selu": nn.SELU,
    }
    return factories[name]()
19 |
20 |
def create_batch_norm_1d_layers(num_layers: int, dim_hidden: int):
    """Return one BatchNorm1d per hidden layer (num_layers - 1 in total)."""
    return nn.ModuleList(
        nn.BatchNorm1d(num_features=dim_hidden) for _ in range(num_layers - 1)
    )
26 |
27 |
def create_linear_layers(
    num_layers: int, dim_input: int, dim_hidden: int, dim_output: int
):
    """Build input→hidden, (num_layers-2)× hidden→hidden, hidden→output layers."""
    linear_layers = nn.ModuleList()
    linear_layers.append(nn.Linear(in_features=dim_input, out_features=dim_hidden))
    linear_layers.extend(
        nn.Linear(in_features=dim_hidden, out_features=dim_hidden)
        for _ in range(1, num_layers - 1)
    )
    linear_layers.append(nn.Linear(dim_hidden, dim_output))
    return linear_layers
40 |
41 |
def init_layers(initializer_name: str, layers: nn.ModuleList):
    """Apply the named weight initializer to every layer's weight tensor."""
    init_fn = get_initializer(initializer_name)
    for lyr in layers:
        init_fn(lyr.weight)
46 |
47 |
def get_initializer(name: str = "xavier") -> Callable:
    """Look up a torch.nn.init function by short name.

    :raises KeyError: if *name* is unknown (matches the dict-lookup behavior).
    """
    if name == "orthogonal":
        return nn.init.orthogonal_
    if name == "xavier":
        return nn.init.xavier_uniform_
    if name == "kaiming":
        return nn.init.kaiming_uniform_
    raise KeyError(name)
55 |
56 |
class MLP(nn.Module):
    """Multi-layer perceptron with optional batch norm, dropout, and a
    configurable activation (optionally applied to the output too)."""

    def __init__(
        self,
        dim_input: int,
        dim_hidden: int,
        dim_output: int,
        num_layers: int,
        batch_norm: bool,
        initializer: str,
        dropout: float,
        activation: str,
        leaky_relu: float,
        is_output_activation: bool,
    ):
        super().__init__()
        self.layers = create_linear_layers(
            num_layers=num_layers,
            dim_input=dim_input,
            dim_hidden=dim_hidden,
            dim_output=dim_output,
        )
        init_layers(initializer_name=initializer, layers=self.layers)
        # dropout/batch-norm are disabled by storing None
        self.dropout = nn.Dropout(dropout) if dropout > 0.0 else None
        if batch_norm:
            self.batch_norm_layers = create_batch_norm_1d_layers(
                num_layers=num_layers, dim_hidden=dim_hidden
            )
        else:
            self.batch_norm_layers = None
        self.activation_function = get_activation(
            name=activation, leaky_relu=leaky_relu
        )
        self.is_output_activation = is_output_activation

    def forward(self, x: Tensor):
        """Per hidden layer: linear → activation → (batch norm) → (dropout);
        then the final linear, optionally followed by the activation."""
        for idx, layer in enumerate(self.layers[:-1]):
            x = layer(x)
            x = self.activation_function(x)
            if self.batch_norm_layers:
                x = self.batch_norm_layers[idx](x)
            if self.dropout:
                x = self.dropout(x)
        x = self.layers[-1](x)
        return self.activation_function(x) if self.is_output_activation else x
102 |
--------------------------------------------------------------------------------
/src/models/building_blocks/schnet.py:
--------------------------------------------------------------------------------
1 | import ase
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch_geometric.nn import MessagePassing, radius_graph
6 | from torch_scatter import scatter
7 | from math import pi as PI
8 |
9 | # try:
10 | # import schnetpack as spk
11 | # except ImportError:
12 | # spk = None
13 |
14 |
class SchNet(nn.Module):
    """SchNet (Schütt et al.) over 3D molecular geometries: embeds atomic
    numbers, builds a radius graph from positions, applies residual
    continuous-filter interaction blocks, and pools per-graph outputs."""

    def __init__(
        self,
        hidden_channels=128,
        num_filters=128,
        num_interactions=6,
        num_gaussians=51,
        cutoff=10.0,
        readout="mean",
        dipole=False,
        mean=None,
        std=None,
        atomref=None,
    ):
        # :param cutoff: radius (same units as ``pos``) for graph construction.
        # :param readout: pooling reduction; forced to "add" in dipole mode.
        # :param mean/std: optional output de-standardization (non-dipole only).
        # :param atomref: optional per-element reference values added to h.
        super(SchNet, self).__init__()

        assert readout in ["add", "sum", "mean"]

        # dipole mode always sums contributions over atoms
        self.readout = "add" if dipole else readout
        self.num_interactions = num_interactions
        self.hidden_channels = hidden_channels
        self.num_gaussians = num_gaussians
        self.num_filters = num_filters
        self.cutoff = cutoff
        self.dipole = dipole
        self.scale = None
        self.mean = mean
        self.std = std

        # atomic masses indexed by atomic number, used for the center of mass
        atomic_mass = torch.from_numpy(ase.data.atomic_masses)
        self.register_buffer("atomic_mass", atomic_mass)

        # self.embedding = Embedding(100, hidden_channels)
        self.embedding = nn.Embedding(119, hidden_channels)
        self.distance_expansion = GaussianSmearing(0.0, cutoff, num_gaussians)

        self.interactions = nn.ModuleList()
        for _ in range(num_interactions):
            block = InteractionBlock(
                hidden_channels, num_gaussians, num_filters, cutoff
            )
            self.interactions.append(block)

        self.lin1 = nn.Linear(hidden_channels, hidden_channels)
        self.act = ShiftedSoftplus()
        self.lin2 = nn.Linear(hidden_channels, hidden_channels)

        # keep the raw atomref so reset_parameters can restore it
        self.register_buffer("initial_atomref", atomref)
        self.atomref = None
        if atomref is not None:
            self.atomref = nn.Embedding(100, 1)
            self.atomref.weight.data.copy_(atomref)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize embeddings, interaction blocks, and output layers."""
        self.embedding.reset_parameters()
        for interaction in self.interactions:
            interaction.reset_parameters()
        torch.nn.init.xavier_uniform_(self.lin1.weight)
        self.lin1.bias.data.fill_(0)
        torch.nn.init.xavier_uniform_(self.lin2.weight)
        self.lin2.bias.data.fill_(0)
        if self.atomref is not None:
            self.atomref.weight.data.copy_(self.initial_atomref)

    def forward(self, z, pos, batch=None):
        """Compute per-graph outputs.

        :param z: (num_atoms,) long tensor of atomic numbers.
        :param pos: (num_atoms, 3) atom coordinates.
        :param batch: optional (num_atoms,) graph assignment; single graph
            assumed when None.
        """
        assert z.dim() == 1 and z.dtype == torch.long
        batch = torch.zeros_like(z) if batch is None else batch

        h = self.embedding(z)

        # edges between atoms closer than the cutoff radius
        edge_index = radius_graph(pos, r=self.cutoff, batch=batch)
        row, col = edge_index
        edge_weight = (pos[row] - pos[col]).norm(dim=-1)
        edge_attr = self.distance_expansion(edge_weight)

        # residual message-passing updates
        for interaction in self.interactions:
            h = h + interaction(h, edge_index, edge_weight, edge_attr)

        h = self.lin1(h)
        h = self.act(h)
        h = self.lin2(h)

        if self.dipole:
            # Get center of mass.
            mass = self.atomic_mass[z].view(-1, 1)
            c = scatter(mass * pos, batch, dim=0) / scatter(mass, batch, dim=0)
            h = h * (pos - c[batch])

        if not self.dipole and self.mean is not None and self.std is not None:
            h = h * self.std + self.mean

        if not self.dipole and self.atomref is not None:
            h = h + self.atomref(z)

        out = scatter(h, batch, dim=0, reduce=self.readout)

        if self.dipole:
            out = torch.norm(out, dim=-1, keepdim=True)

        if self.scale is not None:
            out = self.scale * out

        return out

    def __repr__(self):
        return (
            f"{self.__class__.__name__}("
            f"hidden_channels={self.hidden_channels}, "
            f"num_filters={self.num_filters}, "
            f"num_interactions={self.num_interactions}, "
            f"num_gaussians={self.num_gaussians}, "
            f"cutoff={self.cutoff})"
        )
130 |
131 |
class InteractionBlock(torch.nn.Module):
    """One SchNet interaction block: a continuous-filter convolution whose
    filters are generated from edge features, followed by an activation and
    a linear layer.

    Fix: ``reset_parameters`` zeroed ``self.mlp[0].bias`` twice and never
    touched ``self.mlp[2].bias`` (copy-paste bug); the second fill now
    targets ``self.mlp[2]``.
    """

    def __init__(self, hidden_channels, num_gaussians, num_filters, cutoff):
        super(InteractionBlock, self).__init__()
        # filter-generating network over Gaussian-expanded distances
        self.mlp = nn.Sequential(
            nn.Linear(num_gaussians, num_filters),
            ShiftedSoftplus(),
            nn.Linear(num_filters, num_filters),
        )
        self.conv = CFConv(
            hidden_channels, hidden_channels, num_filters, self.mlp, cutoff
        )
        self.act = ShiftedSoftplus()
        self.lin = nn.Linear(hidden_channels, hidden_channels)

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-init all weights; zero all biases."""
        torch.nn.init.xavier_uniform_(self.mlp[0].weight)
        self.mlp[0].bias.data.fill_(0)
        torch.nn.init.xavier_uniform_(self.mlp[2].weight)
        # Bug fix: was `self.mlp[0].bias.data.fill_(0)` (duplicate), leaving
        # mlp[2]'s bias at its default initialization.
        self.mlp[2].bias.data.fill_(0)
        self.conv.reset_parameters()
        torch.nn.init.xavier_uniform_(self.lin.weight)
        self.lin.bias.data.fill_(0)

    def forward(self, x, edge_index, edge_weight, edge_attr):
        """conv → shifted-softplus → linear over node features x."""
        x = self.conv(x, edge_index, edge_weight, edge_attr)
        x = self.act(x)
        x = self.lin(x)
        return x
162 |
163 |
class CFConv(MessagePassing):
    """Continuous-filter convolution (SchNet) with sum aggregation.

    Edge filters come from ``mlp`` applied to the expanded distances,
    smoothly damped to zero at the distance ``cutoff``.
    """

    def __init__(self, in_channels, out_channels, num_filters, mlp, cutoff):
        super(CFConv, self).__init__(aggr="add")
        self.lin1 = nn.Linear(in_channels, num_filters, bias=False)
        self.lin2 = nn.Linear(num_filters, out_channels)
        self.mlp = mlp
        self.cutoff = cutoff

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialize both linear layers; zero the output bias."""
        torch.nn.init.xavier_uniform_(self.lin1.weight)
        torch.nn.init.xavier_uniform_(self.lin2.weight)
        self.lin2.bias.data.fill_(0)

    def forward(self, x, edge_index, edge_weight, edge_attr):
        # Cosine cutoff in [0, 1]: 1 at distance 0, 0 at self.cutoff.
        smooth_cutoff = 0.5 * (torch.cos(edge_weight * PI / self.cutoff) + 1.0)
        edge_filters = self.mlp(edge_attr) * smooth_cutoff.view(-1, 1)

        projected = self.lin1(x)
        aggregated = self.propagate(edge_index, x=projected, W=edge_filters)
        return self.lin2(aggregated)

    def message(self, x_j, W):
        # Element-wise filtering of the source-node features.
        return x_j * W
190 |
191 |
class GaussianSmearing(torch.nn.Module):
    """Expand scalar distances in a basis of evenly spaced Gaussians.

    Produces a ``(num_distances, num_gaussians)`` feature matrix where
    column k is a Gaussian centered at the k-th grid point in
    ``[start, stop]``; all Gaussians share a width set by the grid spacing.
    """

    def __init__(self, start=0.0, stop=5.0, num_gaussians=50):
        super(GaussianSmearing, self).__init__()
        centers = torch.linspace(start, stop, num_gaussians)
        # Shared coefficient -1/(2*spacing^2) for every basis function.
        self.coeff = -0.5 / (centers[1] - centers[0]).item() ** 2
        self.register_buffer("offset", centers)

    def forward(self, dist):
        """Return the RBF expansion of the flattened distance tensor."""
        delta = dist.view(-1, 1) - self.offset.view(1, -1)
        return torch.exp(self.coeff * delta.pow(2))
202 |
203 |
class ShiftedSoftplus(torch.nn.Module):
    """Softplus shifted down by log(2) so that the activation maps 0 to 0."""

    def __init__(self):
        super(ShiftedSoftplus, self).__init__()
        # softplus(0) == log(2), so subtracting it centers the output.
        self.shift = float(torch.log(torch.tensor(2.0)))

    def forward(self, x):
        return F.softplus(x) - self.shift
211 |
--------------------------------------------------------------------------------
/src/models/context_prediction.py:
--------------------------------------------------------------------------------
1 | """ GRAPH SSL Pre-Training via Context Prediction (CP)
2 | i.e., maps nodes in similar structural contexts to closer embeddings
3 | Ref Paper: Sec. 3.1.1 and Appendix G of
4 | https://arxiv.org/abs/1905.12265 ;
5 | Ref Code: ${GitHub_Repo}/chem/pretrain_contextpred.py """
6 |
7 | from typing import Callable
8 |
9 | import torch.nn as nn
10 | from torch import Tensor
11 | from torch_geometric.nn import global_mean_pool
12 |
13 | from config.training_config import TrainingConfig
14 | from models import GNN
15 | from models.pre_trainer_model import PreTrainerModel
16 |
17 |
class ContextPredictionModel(PreTrainerModel):
    """Context Prediction SSL model.

    Wraps the main (substructure) GNN and builds a second, smaller context
    GNN whose depth equals ``config.csize``.
    """

    def __init__(self, gnn: GNN, config: TrainingConfig):
        super(ContextPredictionModel, self).__init__(config)
        self.gnn: nn.Module = gnn
        self.pool: Callable = global_mean_pool

        # The context ring spans layers [num_layer - 1, num_layer - 1 + csize),
        # so the context GNN needs depth l2 - l1 == csize.  GNN reads its
        # depth from the config, so patch it temporarily and restore after.
        inner_radius = self.config.num_layer - 1
        outer_radius = inner_radius + self.config.csize
        saved_depth = self.config.num_layer
        self.config.num_layer = outer_radius - inner_radius
        self.context_model: nn.Module = GNN(self.config)
        self.config.num_layer = saved_depth

    def forward_substruct_model(
        self, x: Tensor, edge_index: Tensor, edge_attr: Tensor
    ) -> Tensor:
        """Embed the substructure graph with the main GNN."""
        return self.gnn(x, edge_index, edge_attr)

    def forward_context_model(
        self, x: Tensor, edge_index: Tensor, edge_attr: Tensor
    ) -> Tensor:
        """Embed the context graph with the context GNN."""
        return self.context_model(x, edge_index, edge_attr)

    def forward_context_repr(
        self,
        x: Tensor,
        edge_index: Tensor,
        edge_attr: Tensor,
        overlap_context_substruct_idx: Tensor,
        batch_overlapped_context: Tensor,
    ) -> Tensor:
        """Mean-pooled representation of the context/substructure overlap.

        Selects the overlap-region nodes from the context GNN output and
        pools them per context (readout is global_mean_pool by default).
        """
        node_repr = self.context_model(x, edge_index, edge_attr)
        overlapped_node_repr = node_repr[overlap_context_substruct_idx]
        return self.pool(overlapped_node_repr, batch_overlapped_context)
58 |
59 |
60 | if __name__ == "__main__":
61 | pass
62 |
--------------------------------------------------------------------------------
/src/models/contextual.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from torch import Tensor
3 |
4 | from config.training_config import TrainingConfig
5 | from models.building_blocks.gnn import GNN
6 | from models.pre_trainer_model import PreTrainerModel
7 |
8 |
class ContextualModel(PreTrainerModel):
    """GROVER-style contextual property model.

    Classifies every node embedding into the atom context vocabulary.
    """

    def __init__(self, gnn: GNN, config: TrainingConfig):
        super(ContextualModel, self).__init__(config)
        self.gnn: nn.Module = gnn
        self.emb_dim = config.emb_dim
        self.criterion = nn.CrossEntropyLoss()
        self.atom_vocab_size = config.atom_vocab_size
        # Linear classifier over the atom context vocabulary.
        self.atom_vocab_model: nn.Module = nn.Sequential(
            nn.Linear(self.emb_dim, self.atom_vocab_size)
        )

    def forward_cl(
        self,
        x: Tensor,
        edge_index: Tensor,
        edge_attr: Tensor,
        batch_assignments: Tensor,
    ) -> Tensor:
        """Return per-node vocabulary logits (``batch_assignments`` unused)."""
        node_emb = self.gnn(x, edge_index, edge_attr)
        return self.atom_vocab_model(node_emb)

    def loss_cl(self, y_pred: Tensor, y_actual: Tensor) -> Tensor:
        """Cross-entropy between predicted logits and vocabulary labels."""
        return self.criterion(y_pred, y_actual)
34 |
--------------------------------------------------------------------------------
/src/models/discriminator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch_geometric.nn.inits import uniform
4 |
5 |
class Discriminator(nn.Module):
    """Bilinear discriminator: scores node features against a graph summary.

    Computes ``score_i = x_i . (W @ summary)`` for every row of ``x``.
    """

    def __init__(self, hidden_dim):
        super(Discriminator, self).__init__()
        self.weight = nn.Parameter(torch.Tensor(hidden_dim, hidden_dim))
        self.reset_parameters()

    def reset_parameters(self):
        """Uniform re-initialization scaled by the hidden dimension."""
        uniform(self.weight.size(0), self.weight)

    def forward(self, x, summary):
        """Return one score per row of ``x``."""
        projected_summary = torch.matmul(summary, self.weight)
        return (x * projected_summary).sum(dim=1)
19 |
--------------------------------------------------------------------------------
/src/models/edge_prediction.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.data import Batch
3 |
4 | from config.training_config import TrainingConfig
5 | from models.building_blocks.gnn import GNN
6 | from models.pre_trainer_model import PreTrainerModel
7 |
8 |
class EdgePredictionModel(PreTrainerModel):
    """Edge-prediction SSL wrapper: exposes raw node representations."""

    def __init__(self, config: TrainingConfig, gnn: GNN):
        super().__init__(config=config)
        self.gnn: torch.nn.Module = gnn

    def forward(self, batch: Batch) -> torch.Tensor:
        """Return per-node embeddings for the batched graph."""
        return self.gnn(batch.x, batch.edge_index, batch.edge_attr)
17 |
--------------------------------------------------------------------------------
/src/models/gpt_gnn.py:
--------------------------------------------------------------------------------
1 | from typing import Callable
2 |
3 | import torch.nn as nn
4 | from torch_geometric.data import Batch
5 | from torch_geometric.nn import global_mean_pool
6 |
7 | from config.training_config import TrainingConfig
8 | from models.building_blocks.gnn import GNN
9 | from models.pre_trainer_model import PreTrainerModel
10 |
11 |
class GPTGNNModel(PreTrainerModel):
    """Atom-prediction head for GPT-GNN pre-training.

    Predicts the next node's atom type from the pooled graph embedding.
    Edge generation is not implemented yet.
    """

    def __init__(self, config: TrainingConfig, gnn: GNN):
        super().__init__(config=config)
        self.gnn: nn.Module = gnn
        # 119 output classes for the atom-type prediction.
        self.atom_pred: nn.Module = nn.Linear(config.emb_dim, 119)
        self.molecule_readout_func: Callable = global_mean_pool

    def forward(self, batch: Batch):
        """Encode, mean-pool per graph, and predict the next atom type."""
        node_repr = self.gnn(batch.x, batch.edge_index, batch.edge_attr)
        graph_repr = self.molecule_readout_func(node_repr, batch.batch)
        return self.atom_pred(graph_repr)
27 |
--------------------------------------------------------------------------------
/src/models/graph_cl.py:
--------------------------------------------------------------------------------
1 | from typing import Callable
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch import Tensor
6 | from torch_geometric.nn import global_mean_pool
7 |
8 | from config.training_config import TrainingConfig
9 | from models.building_blocks.gnn import GNN
10 | from models.pre_trainer_model import PreTrainerModel
11 |
12 |
class GraphCLModel(PreTrainerModel):
    """GraphCL: contrastive pre-training over two augmented graph views."""

    def __init__(self, config: TrainingConfig, gnn: GNN):
        super(GraphCLModel, self).__init__(config)
        self.gnn: nn.Module = gnn
        self.emb_dim = config.emb_dim
        self.pool: Callable = global_mean_pool
        # Two-layer MLP projection applied to the pooled graph embedding.
        self.projection_head: nn.Module = nn.Sequential(
            nn.Linear(self.emb_dim, self.emb_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.emb_dim, self.emb_dim),
        )

    def forward_cl(
        self,
        x: Tensor,
        edge_index: Tensor,
        edge_attr: Tensor,
        batch_assignments: Tensor,
    ) -> Tensor:
        """Projected graph-level embedding for one augmented view."""
        node_emb = self.gnn(x, edge_index, edge_attr)
        graph_emb = self.pool(node_emb, batch_assignments)
        return self.projection_head(graph_emb)

    def loss_cl(self, x1: Tensor, x2: Tensor) -> Tensor:
        """NT-Xent-style contrastive loss at temperature 0.1.

        NOTE(review): the denominator includes the positive pair; the
        standard GraphCL formulation subtracts it — confirm this variant
        is intentional.
        """
        temperature = 0.1
        n_graphs = x1.size(0)
        norm1 = x1.norm(dim=1)
        norm2 = x2.norm(dim=1)

        cosine = torch.einsum("ik,jk->ij", x1, x2) / torch.einsum(
            "i,j->ij", norm1, norm2
        )
        logits = torch.exp(cosine / temperature)
        positives = logits[range(n_graphs), range(n_graphs)]
        ratio = positives / logits.sum(dim=1)
        return -torch.log(ratio).mean()
52 |
--------------------------------------------------------------------------------
/src/models/graphmae.py:
--------------------------------------------------------------------------------
1 | # Ref: https://github.com/THUDM/GraphMAE/blob/main/chem/pretraining.py#L50
2 |
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from torch import Tensor
7 | from functools import partial
8 | from config.training_config import TrainingConfig
9 | from models.pre_trainer_model import PreTrainerModel
10 | from models.building_blocks.gnn import GNN, GNNDecoder
11 |
12 | NUM_NODE_ATTR = 119
13 |
14 |
class GraphMAEModel(PreTrainerModel):
    """GraphMAE: masked-atom reconstruction with a GNN decoder."""

    def __init__(self, config: TrainingConfig, gnn: GNN):
        super(GraphMAEModel, self).__init__(config)
        self.gnn: nn.Module = gnn
        self.atom_pred_decoder = GNNDecoder(
            config.emb_dim, NUM_NODE_ATTR, JK=config.JK, gnn_type="gin"
        )
        # Bond reconstruction disabled, following the reference implementation:
        # https://github.com/THUDM/GraphMAE/blob/6d2636e942f6597d70f438e66ce876f80f9ca9e0/chem/pretraining.py#L137
        self.bond_pred_decoder = None

    def forward(self, batch) -> Tensor:
        """Encode the graph, then decode atom-type logits at masked nodes."""
        node_rep = self.gnn(batch.x, batch.edge_index, batch.edge_attr)
        return self.atom_pred_decoder(
            node_rep, batch.edge_index, batch.edge_attr, batch.masked_atom_indices
        )

    def loss(self, pred_node, batch, alpha_l=1.0, loss_fn="sce") -> Tensor:
        """Reconstruction loss over masked nodes.

        ``loss_fn == "sce"`` uses the scaled cosine error; any other value
        falls back to cross-entropy on the raw mask labels.
        """

        def _scaled_cosine_error(x, y, alpha=1):
            x = F.normalize(x, p=2, dim=-1)
            y = F.normalize(y, p=2, dim=-1)
            return (1 - (x * y).sum(dim=-1)).pow_(alpha).mean()

        masked_node_indices = batch.masked_atom_indices
        if loss_fn == "sce":
            return _scaled_cosine_error(
                batch.node_attr_label,
                pred_node[masked_node_indices],
                alpha=alpha_l,
            )
        criterion = nn.CrossEntropyLoss()
        return criterion(
            pred_node.double()[masked_node_indices], batch.mask_node_label[:, 0]
        )
55 |
--------------------------------------------------------------------------------
/src/models/graphmvp.py:
--------------------------------------------------------------------------------
from typing import Callable, List, Tuple, Union
2 |
3 | # import torch
4 | import torch.nn as nn
5 | from torch import Tensor
6 | from torch_geometric.data.batch import Batch
7 | from torch_geometric.nn import global_mean_pool
8 | from config.training_config import TrainingConfig
9 | from models.building_blocks.gnn import GNN
10 | from models.building_blocks.schnet import SchNet
11 | from models.building_blocks.auto_encoder import AutoEncoder, VariationalAutoEncoder
12 | from models.pre_trainer_model import PreTrainerModel
13 | from util import dual_CL
14 |
15 |
class GraphMVPModel(PreTrainerModel):
    """GraphMVP: aligns 2D-topology and 3D-geometry graph representations
    via a contrastive loss plus two cross-modal (variational) auto-encoders.
    """

    def __init__(
        self,
        config: TrainingConfig,
        gnn_2d: GNN,
        gnn_3d: SchNet,
        ae2d3d: Union[AutoEncoder, VariationalAutoEncoder],
        ae3d2d: Union[AutoEncoder, VariationalAutoEncoder],
    ):
        super(GraphMVPModel, self).__init__(config)
        self.gnn: nn.Module = gnn_2d
        self.gnn_3d: nn.Module = gnn_3d
        self.ae2d3d: nn.Module = ae2d3d
        self.ae3d2d: nn.Module = ae3d2d
        self.config: TrainingConfig = config
        self.pool: Callable = global_mean_pool

    def forward(self, batch: Batch) -> Tuple[Tensor, Tensor]:
        """Return the pooled (2D, 3D) graph representations.

        FIX: the annotation previously claimed ``List[Tensor]`` although a
        tuple of two tensors is returned.
        """
        repr_node = self.gnn(batch.x, batch.edge_index, batch.edge_attr)
        repr_2d = self.pool(repr_node, batch.batch)
        # SchNet consumes the first node-feature column plus 3D positions.
        repr_3d = self.gnn_3d(batch.x[:, 0], batch.positions, batch.batch)

        return repr_2d, repr_3d

    def loss(self, repr_2d: Tensor, repr_3d: Tensor) -> Tensor:
        """Weighted sum of the dual contrastive loss and the mean of the
        two cross-modal auto-encoder reconstruction losses."""
        CL_loss, _ = dual_CL(repr_2d, repr_3d, self.config)
        AE_loss_1 = self.ae2d3d(repr_2d, repr_3d)
        AE_loss_2 = self.ae3d2d(repr_3d, repr_2d)
        AE_loss = (AE_loss_1 + AE_loss_2) / 2

        return CL_loss * self.config.GMVP_alpha1 + AE_loss * self.config.GMVP_alpha2
50 |
--------------------------------------------------------------------------------
/src/models/graphpred.py:
--------------------------------------------------------------------------------
1 | from typing import Callable
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch import Tensor
6 | from torch_geometric.nn import global_mean_pool
7 |
8 | from config.training_config import TrainingConfig
9 | from models.building_blocks.gnn import GNN
10 | from models.pre_trainer_model import PreTrainerModel
11 |
12 |
class GraphPred(nn.Module):
    """Downstream graph-property prediction head on top of a GNN encoder."""

    def __init__(self, config: TrainingConfig, gnn: GNN, num_tasks: int):
        super(GraphPred, self).__init__()
        self.gnn: nn.Module = gnn
        self.emb_dim = config.emb_dim
        # BUG FIX: num_layer was read below as self.num_layer but never set,
        # raising AttributeError whenever config.JK == "concat".
        self.num_layer = config.num_layer
        self.pool: Callable = global_mean_pool
        if config.JK == "concat":
            # Jumping-knowledge "concat" stacks the input plus every layer's
            # embedding, so the head input is (num_layer + 1) * emb_dim wide.
            self.downstream_head = nn.Linear(
                (self.num_layer + 1) * self.emb_dim, num_tasks
            )
        else:
            self.downstream_head = nn.Linear(self.emb_dim, num_tasks)

    def forward(
        self,
        x: Tensor,
        edge_index: Tensor,
        edge_attr: Tensor,
        batch_assignments: Tensor,
    ) -> Tensor:
        """Encode nodes, mean-pool per graph, and predict the task outputs."""
        x = self.gnn(x, edge_index, edge_attr)
        x = self.pool(x, batch_assignments)
        x = self.downstream_head(x)
        return x
46 |
--------------------------------------------------------------------------------
/src/models/info_max.py:
--------------------------------------------------------------------------------
1 | """ InfoMax + Graph
2 |     Ref Paper 1: https://arxiv.org/abs/1908.01000
3 |     Ref Paper 2: https://arxiv.org/abs/1908.01000 (NOTE: duplicate of Ref Paper 1 — second reference likely intended a different paper; confirm) """
4 | from typing import Callable, Tuple
5 |
6 | import torch
7 | import torch.nn as nn
8 | from torch_geometric.data import Batch
9 | from torch_geometric.nn import global_mean_pool
10 |
11 | from config.training_config import TrainingConfig
12 | from models import GNN, Discriminator
13 | from models.pre_trainer_model import PreTrainerModel
14 |
15 |
class InfoMaxModel(PreTrainerModel):
    """InfoMax SSL model: node and graph embeddings plus a discriminator
    for node/graph mutual-information maximization."""

    def __init__(self, config: TrainingConfig, gnn: GNN):
        super(InfoMaxModel, self).__init__(config)
        self.gnn: nn.Module = gnn
        self.pool: Callable = global_mean_pool
        self.infograph_discriminator_SSL_model = Discriminator(config.emb_dim)

    def forward_embedding(self, batch: Batch) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (node embeddings, mean-pooled graph embeddings).

        FIX: the annotation previously claimed ``Tuple[torch.Tensor]`` (a
        1-tuple) although two tensors are returned.
        """
        node_emb = self.gnn(batch.x, batch.edge_index, batch.edge_attr)
        pooled_emb = self.pool(node_emb, batch.batch)
        return node_emb, pooled_emb
27 |
--------------------------------------------------------------------------------
/src/models/joao_v2.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from config.training_config import TrainingConfig
4 | from models.building_blocks.gnn import GNN
5 | from models.graph_cl import GraphCLModel
6 |
7 |
class JOAOv2Model(GraphCLModel):
    """JOAOv2: GraphCL variant with one projection head per augmentation."""

    def __init__(self, config: TrainingConfig, gnn: GNN):
        super(JOAOv2Model, self).__init__(config=config, gnn=gnn)
        self.gnn: nn.Module = gnn

        def _build_head() -> nn.Module:
            return nn.Sequential(
                nn.Linear(self.emb_dim, self.emb_dim),
                nn.ReLU(inplace=True),
                nn.Linear(self.emb_dim, self.emb_dim),
            )

        # Five heads: one per augmentation mode.
        self.projection_head: nn.ModuleList = nn.ModuleList(
            [_build_head() for _ in range(5)]
        )

    def forward_cl(self, x, edge_index, edge_attr, batch, n_aug=0):
        """Project the pooled graph embedding through head ``n_aug``."""
        node_emb = self.gnn(x, edge_index, edge_attr)
        pooled = self.pool(node_emb, batch)
        return self.projection_head[n_aug](pooled)
28 |
--------------------------------------------------------------------------------
/src/models/motif.py:
--------------------------------------------------------------------------------
1 | from typing import Callable
2 |
3 | import torch.nn as nn
4 | from torch import Tensor
5 | from torch_geometric.nn import global_mean_pool
6 |
7 | from config.training_config import TrainingConfig
8 | from datasets import RDKIT_PROPS
9 | from models.building_blocks.gnn import GNN
10 | from models.pre_trainer_model import PreTrainerModel
11 |
12 |
class MotifModel(PreTrainerModel):
    """Motif SSL: multi-label prediction of RDKit motif properties."""

    def __init__(self, config: TrainingConfig, gnn: GNN):
        super(MotifModel, self).__init__(config)
        self.gnn: nn.Module = gnn
        self.emb_dim = config.emb_dim
        # One binary task per RDKit-derived motif property.
        self.num_tasks = len(RDKIT_PROPS)
        self.pool: Callable = global_mean_pool
        self.criterion = nn.BCEWithLogitsLoss()
        self.prediction_model: nn.Module = nn.Sequential(
            nn.Linear(self.emb_dim, self.num_tasks)
        )

    def forward_cl(
        self,
        x: Tensor,
        edge_index: Tensor,
        edge_attr: Tensor,
        batch_assignments: Tensor,
    ) -> Tensor:
        """Per-graph motif logits: encode, mean-pool, linear head."""
        node_emb = self.gnn(x, edge_index, edge_attr)
        graph_emb = self.pool(node_emb, batch_assignments)
        return self.prediction_model(graph_emb)

    def loss_cl(self, y_pred: Tensor, y_actual: Tensor) -> Tensor:
        """Multi-label BCE-with-logits loss against the motif labels."""
        return self.criterion(y_pred, y_actual)
40 |
--------------------------------------------------------------------------------
/src/models/pre_trainer_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from config.training_config import TrainingConfig
3 |
4 |
class PreTrainerModel(torch.nn.Module):
    """Base class for all SSL pre-training models.

    Stores the training configuration and delegates embedding extraction
    to the wrapped encoder, which subclasses assign as ``self.gnn``.
    """

    def __init__(self, config: TrainingConfig):
        super().__init__()
        self.config = config

    def get_embeddings(self, *argv):
        # Delegates to self.gnn, which is set by subclasses (not defined here).
        return self.gnn.get_embeddings(*argv)
12 |
--------------------------------------------------------------------------------
/src/models/rgcl.py:
--------------------------------------------------------------------------------
1 | # Ref: https://github.com/lsh0520/RGCL/blob/main/transferLearning/chem/pretrain_rgcl.py
2 |
3 | from typing import Callable
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torch import Tensor
8 | from torch_scatter import scatter_max
9 | from torch_geometric.nn import global_mean_pool
10 | from config.training_config import TrainingConfig
11 | from models.pre_trainer_model import PreTrainerModel
12 | from models.building_blocks.gnn import GNN, GNN_IMP_Estimator
13 |
14 |
class RGCLModel(PreTrainerModel):
    """RGCL: rationale-aware graph contrastive learning.

    A separate GNN estimates per-node importance ("rationale") scores,
    which re-weight the node embeddings before pooling and projection.
    """

    def __init__(self, config: TrainingConfig, gnn: GNN):
        super(RGCLModel, self).__init__(config)
        self.gnn: nn.Module = gnn
        self.emb_dim = config.emb_dim
        self.pool: Callable = global_mean_pool
        self.node_imp_estimator = GNN_IMP_Estimator()
        self.projection_head: nn.Module = nn.Sequential(
            nn.Linear(self.emb_dim, self.emb_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.emb_dim, self.emb_dim),
        )

    def forward_cl(
        self,
        x: Tensor,
        edge_index: Tensor,
        edge_attr: Tensor,
        batch_assignments: Tensor,
    ) -> Tensor:
        """Importance-weighted projected graph embedding for one view."""
        node_imp = self.node_imp_estimator(x, edge_index, edge_attr, batch_assignments)
        x = self.gnn(x, edge_index, edge_attr)

        # Normalize each node's importance by its graph's maximum, then
        # squash into [0.9, 1.0] so no node is fully suppressed.
        out, _ = scatter_max(torch.reshape(node_imp, (1, -1)), batch_assignments)
        out = out.reshape(-1, 1)
        out = out[batch_assignments]
        node_imp /= out * 10
        node_imp += 0.9
        # Broadcast the scalar weight across the embedding dimension.
        # GENERALIZED: previously hard-coded to 300, which broke any
        # configuration with emb_dim != 300 (identical behavior at 300).
        node_imp = node_imp.expand(-1, self.emb_dim)

        x = torch.mul(x, node_imp)
        x = self.pool(x, batch_assignments)
        x = self.projection_head(x)

        return x

    def loss_cl(self, x1: Tensor, x2: Tensor, temp=0.1) -> Tensor:
        """NT-Xent contrastive loss (positive pair excluded from denominator)."""
        T = temp
        batch, _ = x1.size()
        x1_abs = x1.norm(dim=1)
        x2_abs = x2.norm(dim=1)

        sim_matrix = torch.einsum("ik,jk->ij", x1, x2) / torch.einsum(
            "i,j->ij", x1_abs, x2_abs
        )
        sim_matrix = torch.exp(sim_matrix / T)
        pos_sim = sim_matrix[range(batch), range(batch)]
        loss = pos_sim / (sim_matrix.sum(dim=1) - pos_sim)
        loss = -torch.log(loss).mean()
        return loss

    def loss_infonce(self, x1, x2, temp=0.1):
        """InfoNCE loss (positive pair kept in the denominator)."""
        T = temp
        batch_size, _ = x1.size()
        x1_abs = x1.norm(dim=1)
        x2_abs = x2.norm(dim=1)

        sim_matrix = torch.einsum("ik,jk->ij", x1, x2) / torch.einsum(
            "i,j->ij", x1_abs, x2_abs
        )
        sim_matrix = torch.exp(sim_matrix / T)
        pos_sim = sim_matrix[range(batch_size), range(batch_size)]
        loss = pos_sim / sim_matrix.sum(dim=1)
        loss = -torch.log(loss).mean()
        return loss

    def loss_ra(self, x1, x2, x3, temp=0.1, lamda=0.1):
        """Rationale-aware loss: contrastive term between the two rationale
        views (x1, x2) plus a complement term against the complement view x3,
        weighted by ``lamda``.  Returns (ra_loss, cp_loss, combined loss).
        """
        batch_size, _ = x1.size()
        x1_abs = x1.norm(dim=1)
        x2_abs = x2.norm(dim=1)
        x3_abs = x3.norm(dim=1)

        cp_sim_matrix = torch.einsum("ik,jk->ij", x1, x3) / torch.einsum(
            "i,j->ij", x1_abs, x3_abs
        )
        cp_sim_matrix = torch.exp(cp_sim_matrix / temp)

        sim_matrix = torch.einsum("ik,jk->ij", x1, x2) / torch.einsum(
            "i,j->ij", x1_abs, x2_abs
        )
        sim_matrix = torch.exp(sim_matrix / temp)

        pos_sim = sim_matrix[range(batch_size), range(batch_size)]

        ra_loss = pos_sim / (sim_matrix.sum(dim=1) - pos_sim)
        ra_loss = -torch.log(ra_loss).mean()

        # NOTE(review): the complement denominator ADDS pos_sim rather than
        # subtracting it — matches the sign used here deliberately? confirm
        # against the reference RGCL implementation.
        cp_loss = pos_sim / (cp_sim_matrix.sum(dim=1) + pos_sim)
        cp_loss = -torch.log(cp_loss).mean()

        loss = ra_loss + lamda * cp_loss

        return ra_loss, cp_loss, loss
108 |
--------------------------------------------------------------------------------
/src/pretrainers/__init__.py:
--------------------------------------------------------------------------------
1 | from .attribute_masking import AttributeMaskingPreTrainer
2 | from .context_prediction import ContextPredictionPreTrainer
3 | from .contextual import ContextualPreTrainer
4 | from .edge_prediction import EdgePredictionPreTrainer
5 | from .gpt_gnn import GPTGNNPreTrainer
6 | from .graph_cl import GraphCLPreTrainer
7 | from .info_max import InfoMaxPreTrainer
8 | from .joao import JOAOPreTrainer
9 | from .joao_v2 import JOAOv2PreTrainer
10 | from .motif import MotifPreTrainer
11 | from .graphmvp import GraphMVPPreTrainer
12 | from .rgcl import RGCLPreTrainer
13 | from .graphmae import GraphMAEPreTrainer
14 | from .pretrainer import PreTrainer
15 |
--------------------------------------------------------------------------------
/src/pretrainers/attribute_masking.py:
--------------------------------------------------------------------------------
1 | """ GRAPH SSL Pre-Training via Attribute Masking
2 | i.e., maps nodes in similar structural contexts to closer embeddings
3 | Ref Paper: Sec. 3.1.2 and Appendix G of
4 | https://arxiv.org/abs/1905.12265 ;
5 | Ref Code: ${GitHub_Repo}/chem/pretrain_contextpred.py """
6 |
7 | from typing import Tuple
8 |
9 | import torch
10 | from torch.utils.data import DataLoader
11 | from torch_geometric.data import Batch
12 |
13 | from config.training_config import TrainingConfig
14 | from logger import CombinedLogger
15 | from models.attribute_masking import AttributeMaskingModel
16 | from pretrainers.pretrainer import PreTrainer
17 | from util import get_lr
18 |
19 |
class AttributeMaskingPreTrainer(PreTrainer):
    """Pre-trainer that masks node attributes and learns to recover them."""

    def __init__(
        self,
        config: TrainingConfig,
        model: AttributeMaskingModel,
        optimizer: torch.optim.Optimizer,
        device: torch.device,
        logger: CombinedLogger,
    ) -> None:
        super(AttributeMaskingPreTrainer, self).__init__(
            config=config,
            model=model,
            optimizer=optimizer,
            device=device,
            logger=logger,
        )

    def train_for_one_epoch(self, train_data_loader: DataLoader) -> float:
        """Run one optimization epoch and return the mean masking loss."""
        self.model.train()
        self.logger.train(num_batches=len(train_data_loader))

        total_loss = 0.0

        for step, batch in enumerate(train_data_loader):
            batch = batch.to(self.device)

            loss, acc = do_AttrMasking(
                batch=batch,
                criterion=torch.nn.CrossEntropyLoss(),
                model=self.model,
            )
            loss_value = loss.detach().cpu().item()
            total_loss += loss_value
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            self.logger(loss_value, acc, batch.num_graphs, get_lr(self.optimizer))

        # NOTE(review): ``step`` is unbound if the loader is empty.
        return total_loss / (step + 1)

    # TODO: real validation is not implemented yet.
    def validate_model(self, val_data_loader: DataLoader) -> float:
        """Placeholder validation: logs a single zero entry, returns 0.0."""
        self.logger.eval(num_batches=1)
        self.logger(0.0, 0.0, 1)
        return 0.0
66 |
67 |
def compute_accuracy(pred: torch.Tensor, target: torch.Tensor) -> float:
    """Fraction of rows whose argmax over the logits matches ``target``."""
    predicted_classes = pred.detach().argmax(dim=1)
    num_correct = (predicted_classes == target).sum()
    return float(num_correct.cpu().item()) / len(pred)
71 |
72 |
def do_AttrMasking(
    batch: Batch, criterion: torch.nn.Module, model: AttributeMaskingModel
) -> Tuple[torch.Tensor, float]:
    """Predict masked node attributes and score them.

    Returns (loss, accuracy) where the target class is the first column
    of ``batch.mask_node_label``.
    """
    target = batch.mask_node_label[:, 0]
    prediction = model.forward(batch)
    # Cast logits to double before the loss, matching the training setup.
    masking_loss = criterion(prediction.double(), target)
    masking_acc = compute_accuracy(prediction, target)
    return masking_loss, masking_acc
81 |
--------------------------------------------------------------------------------
/src/pretrainers/context_prediction.py:
--------------------------------------------------------------------------------
1 | """ GRAPH SSL Pre-Training via Context Prediction (CP)
2 | i.e., maps nodes in similar structural contexts to closer embeddings
3 | Ref Paper: Sec. 3.1.1 and Appendix G of
4 | https://arxiv.org/abs/1905.12265 ;
5 | Ref Code: ${GitHub_Repo}/chem/pretrain_contextpred.py """
6 |
7 |
8 | import torch
9 | import torch.nn as nn
10 | from torch.utils.data import DataLoader
11 | from torch_geometric.data import Batch
12 |
13 | from config.training_config import TrainingConfig
14 | from logger import CombinedLogger
15 | from models.context_prediction import ContextPredictionModel
16 | from pretrainers.pretrainer import PreTrainer
17 | from util import cycle_idx, get_lr
18 |
19 |
class ContextPredictionPreTrainer(PreTrainer):
    """Pre-trainer for Context Prediction SSL.

    Trains a substructure encoder and a context encoder so that a node's
    substructure embedding is predictive of its surrounding context
    embedding, using negative contexts sampled from other graphs in the
    batch (Ref: Sec. 3.1.1 / Appendix G of https://arxiv.org/abs/1905.12265).
    """

    def __init__(
        self,
        config: TrainingConfig,
        model: ContextPredictionModel,
        optimizer: torch.optim.Optimizer,
        device: torch.device,
        logger: CombinedLogger,
    ):
        super(ContextPredictionPreTrainer, self).__init__(
            config=config,
            model=model,
            optimizer=optimizer,
            device=device,
            logger=logger,
        )
        # Binary classification: positive vs. negative substructure/context pairs.
        self.criterion = nn.BCEWithLogitsLoss()

    def validate_model(self, val_data_loader: DataLoader):
        # Validation is not implemented for this pre-training task.
        pass

    def train_for_one_epoch(self, train_data_loader: DataLoader):
        """Run one optimization epoch; return the mean context-prediction loss."""
        self.model.train()
        contextpred_loss_accum, contextpred_acc_accum = 0, 0
        self.logger.train(num_batches=len(train_data_loader))

        for step, batch in enumerate(train_data_loader):
            batch = batch.to(self.device)
            contextpred_loss, contextpred_acc = self.do_ContextPred(batch=batch)
            loss_float = contextpred_loss.detach().cpu().item()
            ssl_loss = contextpred_loss
            self.optimizer.zero_grad()
            ssl_loss.backward()
            self.optimizer.step()
            # logger signature: (loss, accuracy, batch_size, learning_rate)
            self.logger(
                ssl_loss.cpu().item(),
                contextpred_acc,
                batch.num_graphs,
                get_lr(self.optimizer),
            )
            contextpred_loss_accum += loss_float
            contextpred_acc_accum += contextpred_acc

        return contextpred_loss_accum / len(train_data_loader)
        # return contextpred_loss_accum / len(train_data_loader), \
        #     contextpred_acc_accum / len(train_data_loader)

    def do_ContextPred(self, batch: Batch):
        """Compute the context-prediction loss and accuracy for one batch.

        Returns:
            (loss, accuracy) where accuracy counts positive pairs scored > 0
            and negative pairs scored < 0.
        """
        # creating substructure representation: embedding of each graph's
        # center node within its substructure subgraph
        substruct_repr = self.model.forward_substruct_model(
            x=batch.x_substruct,
            edge_index=batch.edge_index_substruct,
            edge_attr=batch.edge_attr_substruct,
        )[batch.center_substruct_idx]

        # substruct_repr = molecule_substruct_model(
        #     batch.x_substruct, batch.edge_index_substruct,
        #     batch.edge_attr_substruct)[batch.center_substruct_idx]

        # create positive context representation
        # readout -> global_mean_pool by default
        context_repr = self.model.forward_context_repr(
            x=batch.x_context,
            edge_index=batch.edge_index_context,
            edge_attr=batch.edge_attr_context,
            overlap_context_substruct_idx=batch.overlap_context_substruct_idx,
            batch_overlapped_context=batch.batch_overlapped_context,
        )

        # negative contexts are obtained by shifting
        # the indices of context embeddings: shift i+1 pairs each
        # substructure with another graph's context
        neg_context_repr = torch.cat(
            [
                context_repr[cycle_idx(len(context_repr), i + 1)]
                for i in range(self.config.contextpred_neg_samples)
            ],
            dim=0,
        )

        num_neg = self.config.contextpred_neg_samples

        # dot-product scores for positive and negative pairs
        pred_pos = torch.sum(substruct_repr * context_repr, dim=1)
        pred_neg = torch.sum(
            substruct_repr.repeat((num_neg, 1)) * neg_context_repr, dim=1
        )

        # BCE-with-logits against all-ones (positive) / all-zeros (negative);
        # the negative term is scaled by num_neg
        loss_pos = self.criterion(
            pred_pos.double(), torch.ones(len(pred_pos)).to(pred_pos.device).double()
        )
        loss_neg = self.criterion(
            pred_neg.double(), torch.zeros(len(pred_neg)).to(pred_neg.device).double()
        )

        contextpred_loss = loss_pos + num_neg * loss_neg

        num_pred = len(pred_pos) + len(pred_neg)
        contextpred_acc = (
            torch.sum(pred_pos > 0).float() + torch.sum(pred_neg < 0).float()
        ) / num_pred
        contextpred_acc = contextpred_acc.detach().cpu().item()

        return contextpred_loss, contextpred_acc
123 |
124 |
125 | if __name__ == "__main__":
126 | pass
127 |
--------------------------------------------------------------------------------
/src/pretrainers/contextual.py:
--------------------------------------------------------------------------------
1 | """ GROVER, Contextual Property Prediction
2 | Ref Paper: https://arxiv.org/abs/2007.02835
3 | Ref Code: https://github.com/tencent-ailab/grover """
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torch.utils.data import DataLoader
8 |
9 | from config.training_config import TrainingConfig
10 | from logger import CombinedLogger
11 | from models.contextual import ContextualModel
12 | from pretrainers.pretrainer import PreTrainer
13 | from util import get_lr
14 |
15 |
class ContextualPreTrainer(PreTrainer):
    """Pre-trainer for GROVER-style contextual property prediction.

    Each batch carries per-atom vocabulary labels (`batch.atom_vocab_label`);
    the model predicts a contextual property for every node and is optimized
    with its own classification loss (`loss_cl`).
    """

    def __init__(
        self,
        config: TrainingConfig,
        model: ContextualModel,
        optimizer: torch.optim.Optimizer,
        device: torch.device,
        logger: CombinedLogger,
    ):
        super().__init__(
            config=config,
            model=model,
            optimizer=optimizer,
            device=device,
            logger=logger,
        )

    def train_for_one_epoch(self, train_data_loader: DataLoader) -> float:
        """Run one optimization pass over the loader; return the mean loss."""
        self.model.train()
        self.logger.train(num_batches=len(train_data_loader))

        total_loss = 0.0
        num_batches = 0
        for batch in train_data_loader:
            batch = batch.to(self.device)
            node_pred = self.model.forward_cl(
                batch.x, batch.edge_index, batch.edge_attr, batch.batch
            )
            loss = self.model.loss_cl(node_pred, batch.atom_vocab_label)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            batch_loss = float(loss.detach().cpu().item())
            total_loss += batch_loss
            num_batches += 1
            # logger call signature: loss, accuracy, batch_size, learning_rate
            self.logger(batch_loss, 0.0, batch.num_graphs, get_lr(self.optimizer))

        return total_loss / num_batches

    # TODO: real validation is not implemented yet; emit a placeholder entry.
    def validate_model(self, val_data_loader: DataLoader) -> float:
        self.logger.eval(num_batches=1)
        self.logger(0.0, 0.0, 1)
        return 0.0
62 |
--------------------------------------------------------------------------------
/src/pretrainers/edge_prediction.py:
--------------------------------------------------------------------------------
1 | """ GRAPH SSL Pre-Training via Edge Prediction
2 | Ref Paper: Sec. 5.2 and Appendix G of
3 | https://arxiv.org/abs/1905.12265 ;
4 | which is adapted from
5 | https://arxiv.org/abs/1706.02216 ;
6 |
7 | Ref Code: ${GitHub_Repo}/chem/pretrain_edgepred.py """
8 |
9 |
10 | import torch
11 | import torch.nn as nn
12 | from torch.utils.data import DataLoader
13 |
14 | from config.training_config import TrainingConfig
15 | from logger import CombinedLogger
16 | from models.edge_prediction import EdgePredictionModel
17 | from pretrainers.pretrainer import PreTrainer
18 | from util import get_lr
19 |
20 |
def do_EdgePred(node_repr, batch, criterion):
    """Score positive and sampled negative edges via node-feature inner
    products; return (BCE-with-logits loss, link-prediction accuracy).

    Only every other positive edge is scored (``::2``), since undirected
    edges are stored twice in ``batch.edge_index``.
    """
    src, dst = batch.edge_index[0, ::2], batch.edge_index[1, ::2]
    positive_score = (node_repr[src] * node_repr[dst]).sum(dim=1)

    neg_src = batch.negative_edge_index[0]
    neg_dst = batch.negative_edge_index[1]
    negative_score = (node_repr[neg_src] * node_repr[neg_dst]).sum(dim=1)

    edgepred_loss = criterion(
        positive_score, torch.ones_like(positive_score)
    ) + criterion(negative_score, torch.zeros_like(negative_score))

    # an edge is "correct" when a positive scores > 0 or a negative scores < 0
    num_correct = torch.sum(positive_score > 0) + torch.sum(negative_score < 0)
    edgepred_acc = num_correct.to(torch.float32) / float(2 * len(positive_score))
    edgepred_acc = edgepred_acc.detach().cpu().item()

    return edgepred_loss, edgepred_acc
41 |
42 |
class EdgePredictionPreTrainer(PreTrainer):
    """Edge-prediction SSL pre-trainer (GraphSAGE-style link prediction).

    Trains node representations so that connected nodes have positive
    inner products and sampled negative pairs have negative ones.
    """

    def __init__(
        self,
        config: TrainingConfig,
        model: EdgePredictionModel,
        optimizer: torch.optim.Optimizer,
        device: torch.device,
        logger: CombinedLogger,
    ) -> None:
        super(EdgePredictionPreTrainer, self).__init__(
            config=config,
            model=model,
            optimizer=optimizer,
            device=device,
            logger=logger,
        )

    def train_for_one_epoch(self, train_data_loader: DataLoader) -> float:
        """One epoch of edge-prediction training; returns the mean loss."""
        self.model.train()
        self.logger.train(num_batches=len(train_data_loader))

        # FIX: the loss module is stateless and loop-invariant — build it once
        # instead of instantiating a new nn.BCEWithLogitsLoss per batch.
        criterion = nn.BCEWithLogitsLoss()
        loss_accum = 0.0

        for step, batch in enumerate(train_data_loader):
            batch = batch.to(self.device)
            node_repr = self.model(batch)
            edgepred_loss, edgepred_acc = do_EdgePred(
                node_repr=node_repr, batch=batch, criterion=criterion
            )
            loss_float = edgepred_loss.detach().cpu().item()
            loss_accum += loss_float
            self.optimizer.zero_grad()
            edgepred_loss.backward()
            self.optimizer.step()
            # logger call signature: loss, accuracy, batch_size, learning_rate
            self.logger(
                loss_float, edgepred_acc, batch.num_graphs, get_lr(self.optimizer)
            )

        return loss_accum / (step + 1)

    # TODO: real validation is not implemented yet; emit a placeholder entry.
    def validate_model(self, val_data_loader: DataLoader) -> float:
        # self.logger.eval(num_batches=len(val_data_loader))
        self.logger.eval(num_batches=1)
        self.logger(0.0, 0.0, 1)
        return 0.0
89 |
90 |
# Import-only module: nothing to execute when run as a script.
if __name__ == "__main__":
    pass
93 |
--------------------------------------------------------------------------------
/src/pretrainers/gpt_gnn.py:
--------------------------------------------------------------------------------
1 | """ GPT-GNN pre-training
2 | Current version only supports the node reconstruction for molecular data.
3 | Because the current GIN model only supports node representation.
4 | In the future, we will add edge reconstruction and more general graph data."""
5 |
6 | import torch
7 | from torch.utils.data import DataLoader
8 |
9 | from config.training_config import TrainingConfig
10 | from logger import CombinedLogger
11 | from models import GPTGNNModel
12 | from pretrainers.pretrainer import PreTrainer
13 | from util import get_lr
14 |
15 |
class GPTGNNPreTrainer(PreTrainer):
    """GPT-GNN pre-trainer: autoregressive node-attribute reconstruction.

    The model predicts the attributes of the next node (`batch.next_x`)
    given the partially-generated graph; optimized with cross-entropy.
    """

    def __init__(
        self,
        config: TrainingConfig,
        model: GPTGNNModel,
        optimizer: torch.optim.Optimizer,
        device: torch.device,
        logger: CombinedLogger,
    ) -> None:
        super(GPTGNNPreTrainer, self).__init__(
            config=config,
            model=model,
            optimizer=optimizer,
            device=device,
            logger=logger,
        )

    def train_for_one_epoch(self, train_data_loader: DataLoader) -> float:
        """One epoch of node-reconstruction training; returns the mean loss."""
        self.model.train()
        self.logger.train(num_batches=len(train_data_loader))
        gpt_loss_accum = 0.0
        criterion = torch.nn.CrossEntropyLoss()

        for step, batch in enumerate(train_data_loader):
            batch = batch.to(self.device)

            node_pred = self.model(batch)
            target = batch.next_x

            loss = criterion(node_pred.double(), target)
            acc = compute_accuracy(node_pred, target)

            # FIX: compute the scalar once and reuse it; the original called
            # loss.detach().cpu().item() twice, forcing two device syncs.
            loss_float = loss.detach().cpu().item()
            gpt_loss_accum += loss_float
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            # logger call signature: loss, accuracy, batch_size, learning_rate
            self.logger(loss_float, acc, batch.num_graphs, get_lr(self.optimizer))

        return gpt_loss_accum / (step + 1)

    # TODO: real validation is not implemented yet; emit a placeholder entry.
    def validate_model(self, val_data_loader: DataLoader) -> float:
        # self.logger.eval(num_batches=len(val_data_loader))
        self.logger.eval(num_batches=1)
        self.logger(0.0, 0.0, 1)
        return 0.0
63 |
64 |
def compute_accuracy(pred, target):
    """Return the fraction of rows in `pred` whose argmax matches `target`."""
    predicted_class = pred.detach().argmax(dim=1)
    num_correct = (predicted_class == target).sum().cpu().item()
    return float(num_correct) / len(pred)
69 |
--------------------------------------------------------------------------------
/src/pretrainers/graph_cl.py:
--------------------------------------------------------------------------------
1 | # Ref: {GitHub}/transferLearning_MoleculeNet_PPI
2 | # Ref: https://arxiv.org/abs/2010.13902
3 |
4 | import torch
5 | from torch.utils.data import DataLoader
6 | from config.training_config import TrainingConfig
7 | from logger import CombinedLogger
8 | from models.graph_cl import GraphCLModel
9 | from pretrainers.pretrainer import PreTrainer
10 | from util import get_lr
11 |
12 |
class GraphCLPreTrainer(PreTrainer):
    """GraphCL pre-trainer: contrastive learning between two augmented
    views of each graph. Ref: https://arxiv.org/abs/2010.13902
    """

    def __init__(
        self,
        config: TrainingConfig,
        model: GraphCLModel,
        optimizer: torch.optim.Optimizer,
        device: torch.device,
        logger: CombinedLogger,
    ):
        super().__init__(
            config=config,
            model=model,
            optimizer=optimizer,
            device=device,
            logger=logger,
        )

    def _embed_view(self, view):
        """Project one augmented view through the contrastive head."""
        return self.model.forward_cl(
            view.x, view.edge_index, view.edge_attr, view.batch
        )

    def train_for_one_epoch(self, train_data_loader: DataLoader) -> float:
        """One contrastive epoch; returns the mean loss over batches."""
        self.model.train()
        self.logger.train(num_batches=len(train_data_loader))

        total_loss = 0.0
        num_batches = 0
        for _, view1, view2 in train_data_loader:
            view1 = view1.to(self.device)
            view2 = view2.to(self.device)

            loss = self.model.loss_cl(
                self._embed_view(view1), self._embed_view(view2)
            )

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            batch_loss = float(loss.detach().cpu().item())
            total_loss += batch_loss
            num_batches += 1
            self.logger(batch_loss, 0.0, view1.num_graphs, get_lr(self.optimizer))

        return total_loss / num_batches

    # TODO: real validation is not implemented yet; emit a placeholder entry.
    def validate_model(self, val_data_loader: DataLoader) -> float:
        self.logger.eval(num_batches=1)
        self.logger(0.0, 0.0, 1)
        return 0.0
62 |
--------------------------------------------------------------------------------
/src/pretrainers/graphmae.py:
--------------------------------------------------------------------------------
1 | # Ref: {GitHub}/lsh0520/RGCL/transferLearning/chem/pretrain_rgcl.py
2 | # Ref: https://arxiv.org/abs/2010.13902
3 | # TODO: TO UPDATE
4 | """ GRAPH SSL Pre-Training via InfoGraph [InfoGraph]
5 | i.e., maps nodes in similar structural contexts to closer embeddings
6 | Ref Paper: Sec. 5.2 and Appendix G of
7 | https://arxiv.org/abs/1905.12265 ;
8 | which is adapted from
9 | https://arxiv.org/abs/1809.10341 ;
10 | Ref Code: ${GitHub_Repo}/chem/pretrain_deepgraphinfomax.py """
11 |
12 | import torch
13 |
14 | from config.training_config import TrainingConfig
15 | from pretrainers.pretrainer import PreTrainer
16 | from models.graphmae import GraphMAEModel
17 | from torch.utils.data import DataLoader
18 | from logger import CombinedLogger
19 | from util import get_lr
20 | from typing import List
21 |
22 |
class GraphMAEPreTrainer(PreTrainer):
    """GraphMAE pre-trainer: masked node-attribute reconstruction.

    Drives exactly two optimizers (passed as a list) that are zeroed and
    stepped together every batch.
    """

    def __init__(
        self,
        config: TrainingConfig,
        model: GraphMAEModel,
        optimizer: List,
        device: torch.device,
        logger: CombinedLogger,
    ):
        super().__init__(
            config=config,
            model=model,
            optimizer=optimizer,
            device=device,
            logger=logger,
        )
        # this trainer assumes a pair of optimizers
        assert len(self.optimizer) == 2

    def train_for_one_epoch(self, train_data_loader: DataLoader) -> float:
        """One reconstruction epoch; returns the mean loss over batches."""
        self.model.train()
        self.logger.train(num_batches=len(train_data_loader))

        total_loss = 0.0
        num_batches = 0
        for batch in train_data_loader:
            batch = batch.to(self.device)
            for opt in self.optimizer:
                opt.zero_grad()

            pred_node = self.model.forward(batch)
            loss = self.model.loss(pred_node, batch)
            loss.backward()
            for opt in self.optimizer:
                opt.step()

            batch_loss = float(loss.detach().cpu().item())
            total_loss += batch_loss
            num_batches += 1
            # learning rate is reported from the first optimizer only
            self.logger(batch_loss, 0.0, batch.num_graphs, get_lr(self.optimizer[0]))

        return total_loss / num_batches

    # Validation not implemented; emit a placeholder logger entry.
    def validate_model(self, val_data_loader) -> float:
        self.logger.eval(num_batches=1)
        self.logger(0.0, 0.0, 1)
        return 0.0
68 |
--------------------------------------------------------------------------------
/src/pretrainers/graphmvp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader
3 |
4 | from config.training_config import TrainingConfig
5 | from logger import CombinedLogger
6 | from models.graphmvp import GraphMVPModel
7 | from pretrainers.pretrainer import PreTrainer
8 | from util import get_lr
9 |
10 |
class GraphMVPPreTrainer(PreTrainer):
    """GraphMVP pre-trainer: aligns 2D-topology and 3D-geometry
    representations of the same molecule via the model's own loss.
    """

    def __init__(
        self,
        config: TrainingConfig,
        model: GraphMVPModel,
        optimizer: torch.optim.Optimizer,
        device: torch.device,
        logger: CombinedLogger,
    ):
        super(GraphMVPPreTrainer, self).__init__(
            config=config,
            model=model,
            optimizer=optimizer,
            device=device,
            logger=logger,
        )

    def train_for_one_epoch(self, train_data_loader: DataLoader) -> float:
        """One epoch of 2D/3D alignment training; returns the mean loss."""
        self.model.train()
        train_loss_accum = 0.0
        self.logger.train(num_batches=len(train_data_loader))

        for step, batch in enumerate(train_data_loader):
            batch = batch.to(self.device)
            repr_2d, repr_3d = self.model(batch)
            loss = self.model.loss(repr_2d, repr_3d)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            loss_float = float(loss.detach().cpu().item())
            train_loss_accum += loss_float
            # logger call signature: loss, accuracy, batch_size, learning_rate
            self.logger(loss_float, 0.0, batch.num_graphs, get_lr(self.optimizer))

        return train_loss_accum / (step + 1)

    # NOTE(review): unlike sibling trainers, this one does not emit a
    # placeholder logger entry during validation — confirm whether the logger
    # expects one per epoch.
    def validate_model(self, val_data_loader: DataLoader) -> float:
        # FIX: return a float to match the annotation and sibling trainers
        # (previously returned the int 0).
        return 0.0
50 |
--------------------------------------------------------------------------------
/src/pretrainers/info_max.py:
--------------------------------------------------------------------------------
1 | """ GRAPH SSL Pre-Training via InfoGraph [InfoGraph]
2 | i.e., maps nodes in similar structural contexts to closer embeddings
3 | Ref Paper: Sec. 5.2 and Appendix G of
4 | https://arxiv.org/abs/1905.12265 ;
5 | which is adapted from
6 | https://arxiv.org/abs/1809.10341 ;
7 | Ref Code: ${GitHub_Repo}/chem/pretrain_deepgraphinfomax.py """
8 |
9 | import torch
10 | import torch.nn as nn
11 | from torch_geometric.data import Batch, DataLoader
12 | from torch_geometric.nn import global_mean_pool
13 |
14 | from config.training_config import TrainingConfig
15 | from logger import CombinedLogger
16 | from models.info_max import InfoMaxModel
17 | from pretrainers.pretrainer import PreTrainer
18 | from util import cycle_idx, get_lr
19 |
20 |
class InfoMaxPreTrainer(PreTrainer):
    """InfoGraph pre-trainer: maximizes mutual information between node
    embeddings and their graph-level summary.
    Ref: https://arxiv.org/abs/1809.10341
    """

    def __init__(
        self,
        config: TrainingConfig,
        model: InfoMaxModel,
        optimizer: torch.optim.Optimizer,
        device: torch.device,
        logger: CombinedLogger,
    ):
        super().__init__(
            config=config,
            model=model,
            optimizer=optimizer,
            device=device,
            logger=logger,
        )
        # TODO: perhaps we can add the below attributes as arguments?
        self.molecule_readout_func = global_mean_pool
        self.criterion = nn.BCEWithLogitsLoss()

    def train_for_one_epoch(self, train_data_loader: DataLoader) -> float:
        """One epoch of InfoGraph training; returns the mean loss."""
        self.model.train()
        self.logger.train(num_batches=len(train_data_loader))

        loss_total, acc_total = 0.0, 0.0
        num_batches = 0
        for batch in train_data_loader:
            batch = batch.to(self.device)
            node_emb, pooled_emb = self.model.forward_embedding(batch)
            batch_loss, batch_acc = do_InfoGraph(
                node_repr=node_emb,
                batch=batch,
                molecule_repr=pooled_emb,
                criterion=self.criterion,
                model=self.model,
            )

            self.optimizer.zero_grad()
            batch_loss.backward()
            self.optimizer.step()

            loss_value = batch_loss.detach().cpu().item()
            loss_total += loss_value
            acc_total += batch_acc
            num_batches += 1
            self.logger(
                loss_value, batch_acc, batch.num_graphs, get_lr(self.optimizer)
            )

        return loss_total / num_batches

    # TODO: real validation is not implemented yet; emit a placeholder entry.
    def validate_model(self, val_data_loader: DataLoader) -> float:
        self.logger.eval(num_batches=1)
        self.logger(0.0, 0.0, 1)
        return 0.0
75 |
76 |
def do_InfoGraph(
    node_repr: torch.Tensor,
    molecule_repr: torch.Tensor,
    batch: Batch,
    criterion: torch.nn.Module,
    model: InfoMaxModel,
):
    """Compute the InfoGraph node-vs-summary discrimination loss.

    Positive pairs match each node with its own graph's summary; negative
    pairs use summaries cyclically shifted by one graph. Returns
    (loss, accuracy).
    """
    summary = torch.sigmoid(molecule_repr)
    pos_summary = summary[batch.batch]
    # shift graph summaries by one position to form mismatched (negative) pairs
    neg_summary = summary[cycle_idx(len(summary), 1)][batch.batch]

    positive_score = model.infograph_discriminator_SSL_model(node_repr, pos_summary)
    negative_score = model.infograph_discriminator_SSL_model(node_repr, neg_summary)

    infograph_loss = criterion(
        positive_score, torch.ones_like(positive_score)
    ) + criterion(negative_score, torch.zeros_like(negative_score))

    num_correct = torch.sum(positive_score > 0) + torch.sum(negative_score < 0)
    infograph_acc = num_correct.to(torch.float32) / float(2 * len(positive_score))
    infograph_acc = infograph_acc.detach().cpu().item()

    return infograph_loss, infograph_acc
106 |
--------------------------------------------------------------------------------
/src/pretrainers/joao.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.utils.data import DataLoader
4 |
5 | from config.training_config import TrainingConfig
6 | from logger import CombinedLogger
7 | from models.graph_cl import GraphCLModel
8 | from pretrainers.pretrainer import PreTrainer
9 | from util import get_lr
10 |
11 |
class JOAOPreTrainer(PreTrainer):
    """JOAO pre-trainer: GraphCL contrastive training plus joint
    optimization of the augmentation-sampling distribution.
    """

    def __init__(
        self,
        config: TrainingConfig,
        model: GraphCLModel,
        optimizer: torch.optim.Optimizer,
        device: torch.device,
        logger: CombinedLogger,
    ):
        super(JOAOPreTrainer, self).__init__(
            config=config,
            model=model,
            optimizer=optimizer,
            device=device,
            logger=logger,
        )

    def train_for_one_epoch(self, train_data_loader: DataLoader) -> float:
        """Train contrastively for one epoch, then update the dataset's
        augmentation probabilities via projected gradient descent.

        Returns the mean contrastive loss over the epoch's training batches.
        """
        self.model.train()
        train_loss_accum = 0.0
        self.logger.train(num_batches=len(train_data_loader))

        for step, (_, batch1, batch2) in enumerate(train_data_loader):
            batch1 = batch1.to(self.device)
            batch2 = batch2.to(self.device)

            x1 = self.model.forward_cl(
                batch1.x, batch1.edge_index, batch1.edge_attr, batch1.batch
            )
            x2 = self.model.forward_cl(
                batch2.x, batch2.edge_index, batch2.edge_attr, batch2.batch
            )
            loss = self.model.loss_cl(x1, x2)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            loss_float = float(loss.detach().cpu().item())
            train_loss_accum += loss_float
            self.logger(loss_float, 0.0, batch1.num_graphs, get_lr(self.optimizer))

        # BUG FIX: `step` is reused by the estimation loop below, so remember
        # the number of training batches now. Previously the returned mean was
        # divided by the (much smaller) inner-loop step count.
        num_train_batches = step + 1

        # joao: estimate the loss of each of the 25 augmentation pairs
        aug_prob = train_data_loader.dataset.aug_prob
        # TODO: Can we replace the below constants (25, 10, etc.) with arguments in the config?
        loss_aug = np.zeros(25)
        for n in range(25):
            _aug_prob = np.zeros(25)
            _aug_prob[n] = 1
            train_data_loader.dataset.set_augProb(_aug_prob)
            # for efficiency, we only use around 10% of data to estimate the loss
            count, count_stop = (
                0,
                len(train_data_loader.dataset) // (train_data_loader.batch_size * 10)
                + 1,
            )

            with torch.no_grad():
                for step, (_, batch1, batch2) in enumerate(train_data_loader):
                    batch1 = batch1.to(self.device)
                    batch2 = batch2.to(self.device)

                    x1 = self.model.forward_cl(
                        batch1.x, batch1.edge_index, batch1.edge_attr, batch1.batch
                    )
                    x2 = self.model.forward_cl(
                        batch2.x, batch2.edge_index, batch2.edge_attr, batch2.batch
                    )
                    loss = self.model.loss_cl(x1, x2)
                    loss_aug[n] += loss.item()
                    count += 1
                    if count == count_stop:
                        break
            loss_aug[n] /= count

        # view selection, projected gradient descent,
        # reference: https://arxiv.org/abs/1906.03563
        beta = 1
        gamma = self.config.gamma_joao

        b = aug_prob + beta * (loss_aug - gamma * (aug_prob - 1 / 25))
        mu_min, mu_max = b.min() - 1 / 25, b.max() - 1 / 25
        mu = (mu_min + mu_max) / 2

        # bisection on the dual variable `mu` so the projected probabilities
        # sum to one (simplex projection)
        while abs(np.maximum(b - mu, 0).sum() - 1) > 1e-2:
            if np.maximum(b - mu, 0).sum() > 1:
                mu_min = mu
            else:
                mu_max = mu
            mu = (mu_min + mu_max) / 2

        aug_prob = np.maximum(b - mu, 0)
        aug_prob /= aug_prob.sum()
        train_data_loader.dataset.set_augProb(aug_prob=aug_prob)
        self.logger.log_value_dict({"aug_prob": aug_prob})
        return train_loss_accum / num_train_batches

    # TODO: real validation is not implemented yet; emit a placeholder entry.
    def validate_model(self, val_data_loader: DataLoader) -> float:
        # self.logger.eval(num_batches=len(val_data_loader))
        self.logger.eval(num_batches=1)
        self.logger(0.0, 0.0, 1)
        return 0.0
116 |
--------------------------------------------------------------------------------
/src/pretrainers/joao_v2.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.utils.data import DataLoader
4 |
5 | from config.training_config import TrainingConfig
6 | from logger import CombinedLogger
7 | from models.joao_v2 import JOAOv2Model
8 | from pretrainers.pretrainer import PreTrainer
9 | from util import get_lr
10 |
11 |
class JOAOv2PreTrainer(PreTrainer):
    """JOAOv2 pre-trainer: like JOAO, but each of the 5x5 augmentation
    pairs gets its own projection head (selected via `n_aug1`/`n_aug2`).
    """

    def __init__(
        self,
        config: TrainingConfig,
        model: JOAOv2Model,
        optimizer: torch.optim.Optimizer,
        device: torch.device,
        logger: CombinedLogger,
    ):
        super(JOAOv2PreTrainer, self).__init__(
            config=config,
            model=model,
            optimizer=optimizer,
            device=device,
            logger=logger,
        )

    def train_for_one_epoch(self, train_data_loader: DataLoader) -> float:
        """Train contrastively for one epoch with a sampled augmentation
        pair, then update the augmentation probabilities via projected
        gradient descent. Returns the mean loss over training batches.
        """
        self.model.train()
        train_loss_accum = 0.0
        self.logger.train(num_batches=len(train_data_loader))

        # sample one of the 25 augmentation pairs for this epoch
        aug_prob = train_data_loader.dataset.aug_prob
        n_aug = np.random.choice(25, 1, p=aug_prob)[0]
        n_aug1, n_aug2 = n_aug // 5, n_aug % 5

        for step, (_, batch1, batch2) in enumerate(train_data_loader):
            batch1 = batch1.to(self.device)
            batch2 = batch2.to(self.device)

            x1 = self.model.forward_cl(
                batch1.x, batch1.edge_index, batch1.edge_attr, batch1.batch, n_aug1
            )
            x2 = self.model.forward_cl(
                batch2.x, batch2.edge_index, batch2.edge_attr, batch2.batch, n_aug2
            )
            loss = self.model.loss_cl(x1, x2)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            loss_float = float(loss.detach().cpu().item())
            train_loss_accum += loss_float
            self.logger(loss_float, 0.0, batch1.num_graphs, get_lr(self.optimizer))

        # BUG FIX: `step` is reused by the estimation loop below, so remember
        # the number of training batches now. Previously the returned mean was
        # divided by the (much smaller) inner-loop step count.
        num_train_batches = step + 1

        # joaov2: estimate the loss of each of the 25 augmentation pairs
        aug_prob = train_data_loader.dataset.aug_prob
        # TODO: Can we replace the below constants (25, 10, etc.) with arguments in the config?
        loss_aug = np.zeros(25)
        for n in range(25):
            _aug_prob = np.zeros(25)
            _aug_prob[n] = 1
            train_data_loader.dataset.set_augProb(_aug_prob)

            count, count_stop = (
                0,
                len(train_data_loader.dataset) // (train_data_loader.batch_size * 10)
                + 1,
            )
            # for efficiency, we only use around 10% of data to estimate the loss
            n_aug1, n_aug2 = n // 5, n % 5
            with torch.no_grad():
                for step, (_, batch1, batch2) in enumerate(train_data_loader):
                    batch1 = batch1.to(self.device)
                    batch2 = batch2.to(self.device)

                    x1 = self.model.forward_cl(
                        batch1.x,
                        batch1.edge_index,
                        batch1.edge_attr,
                        batch1.batch,
                        n_aug1,
                    )
                    x2 = self.model.forward_cl(
                        batch2.x,
                        batch2.edge_index,
                        batch2.edge_attr,
                        batch2.batch,
                        n_aug2,
                    )
                    loss = self.model.loss_cl(x1, x2)
                    loss_aug[n] += loss.item()
                    count += 1
                    if count == count_stop:
                        break
            loss_aug[n] /= count

        # view selection, projected gradient descent,
        # reference: https://arxiv.org/abs/1906.03563
        beta = 1
        gamma = self.config.gamma_joaov2

        b = aug_prob + beta * (loss_aug - gamma * (aug_prob - 1 / 25))
        mu_min, mu_max = b.min() - 1 / 25, b.max() - 1 / 25
        mu = (mu_min + mu_max) / 2

        # bisection on the dual variable `mu` so the projected probabilities
        # sum to one (simplex projection)
        while abs(np.maximum(b - mu, 0).sum() - 1) > 1e-2:
            if np.maximum(b - mu, 0).sum() > 1:
                mu_min = mu
            else:
                mu_max = mu
            mu = (mu_min + mu_max) / 2

        aug_prob = np.maximum(b - mu, 0)
        aug_prob /= aug_prob.sum()
        train_data_loader.dataset.set_augProb(aug_prob=aug_prob)
        self.logger.log_value_dict({"aug_prob": aug_prob})
        return train_loss_accum / num_train_batches

    # TODO: real validation is not implemented yet; emit a placeholder entry.
    def validate_model(self, val_data_loader: DataLoader) -> float:
        # self.logger.eval(num_batches=len(val_data_loader))
        self.logger.eval(num_batches=1)
        self.logger(0.0, 0.0, 1)
        return 0.0
129 |
--------------------------------------------------------------------------------
/src/pretrainers/motif.py:
--------------------------------------------------------------------------------
1 | """ GROVER, Graph-Level Motif Prediction
2 | Ref Paper: https://arxiv.org/abs/2007.02835
3 | Ref Code: https://github.com/tencent-ailab/grover """
4 |
5 |
6 | import torch
7 | from torch.utils.data import DataLoader
8 |
9 | from config.training_config import TrainingConfig
10 | from logger import CombinedLogger
11 | from models.motif import MotifModel
12 | from pretrainers.pretrainer import PreTrainer
13 | from util import get_lr
14 |
15 |
class MotifPreTrainer(PreTrainer):
    """GROVER motif pre-trainer: multi-label prediction of graph-level
    motif labels stored in `batch.y`.
    Ref: https://arxiv.org/abs/2007.02835
    """

    def __init__(
        self,
        config: TrainingConfig,
        model: MotifModel,
        optimizer: torch.optim.Optimizer,
        device: torch.device,
        logger: CombinedLogger,
    ):
        super().__init__(
            config=config,
            model=model,
            optimizer=optimizer,
            device=device,
            logger=logger,
        )

    def train_for_one_epoch(self, train_data_loader: DataLoader) -> float:
        """One epoch of motif prediction; returns the mean loss."""
        self.model.train()
        self.logger.train(num_batches=len(train_data_loader))

        total_loss = 0.0
        num_batches = 0
        for batch in train_data_loader:
            batch = batch.to(self.device)
            logits = self.model.forward_cl(
                batch.x, batch.edge_index, batch.edge_attr, batch.batch
            )
            labels = batch.y.view(logits.shape)

            loss = self.model.loss_cl(logits.double(), labels.double())

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            batch_loss = float(loss.detach().cpu().item())
            total_loss += batch_loss
            num_batches += 1
            self.logger(batch_loss, 0.0, batch.num_graphs, get_lr(self.optimizer))

        return total_loss / num_batches

    # TODO: real validation is not implemented yet; emit a placeholder entry.
    def validate_model(self, val_data_loader: DataLoader) -> float:
        self.logger.eval(num_batches=1)
        self.logger(0.0, 0.0, 1)
        return 0.0
62 |
--------------------------------------------------------------------------------
/src/pretrainers/pretrainer.py:
--------------------------------------------------------------------------------
1 | import abc, torch, dataclasses
2 |
3 | from typing import Union, List
4 | from logger import CombinedLogger
5 | from torch.utils.data import DataLoader, Dataset
6 | from config.training_config import TrainingConfig
7 |
8 |
@dataclasses.dataclass
class PreTrainer(abc.ABC):
    """Abstract base class shared by all self-supervised pre-trainers.

    Holds the common training state as dataclass fields; subclasses
    implement one epoch of training and a validation pass.
    """

    # experiment / training hyper-parameters
    config: TrainingConfig
    # the model being pre-trained
    model: torch.nn.Module
    # optimizer: torch.optim.Optimizer
    # a single optimizer, or a list of optimizers (e.g. GraphMAE passes two)
    optimizer: Union[torch.optim.Optimizer, List]
    # device that batches/models are moved to
    device: torch.device
    # project logger used to report per-batch metrics
    logger: CombinedLogger

    @abc.abstractmethod
    def train_for_one_epoch(self, train_data_loader: Union[DataLoader, Dataset]):
        """Run one training epoch; implementations return the mean loss."""
        pass

    @abc.abstractmethod
    def validate_model(self, val_data_loader: Union[DataLoader, Dataset]):
        """Evaluate on validation data; implementations return a scalar."""
        pass
25 |
--------------------------------------------------------------------------------
/src/pretrainers/rgcl.py:
--------------------------------------------------------------------------------
1 | # Ref: {GitHub}/lsh0520/RGCL/transferLearning/chem/pretrain_rgcl.py
2 | # Ref: https://arxiv.org/abs/2010.13902
3 | # TODO: TO UPDATE
4 | """ GRAPH SSL Pre-Training via InfoGraph [InfoGraph]
5 | i.e., maps nodes in similar structural contexts to closer embeddings
6 | Ref Paper: Sec. 5.2 and Appendix G of
7 | https://arxiv.org/abs/1905.12265 ;
8 | which is adapted from
9 | https://arxiv.org/abs/1809.10341 ;
10 | Ref Code: ${GitHub_Repo}/chem/pretrain_deepgraphinfomax.py """
11 |
12 | import gc, torch
13 | from config.training_config import TrainingConfig
14 | from pretrainers.pretrainer import PreTrainer
15 | from torch_geometric.loader import DataLoader
16 |
17 | # from torch.utils.data import DataLoader
18 | from logger import CombinedLogger
19 | from models.rgcl import RGCLModel
20 | from copy import deepcopy
21 | from util import get_lr
22 |
23 |
class RGCLPreTrainer(PreTrainer):
    """RGCL pre-trainer (rationale-aware graph contrastive learning).

    Each epoch has two phases: (1) score node importance over the whole
    dataset with the model's estimator, (2) train contrastively on three
    augmented copies of the dataset.
    Ref: https://github.com/lsh0520/RGCL
    """

    def __init__(
        self,
        config: TrainingConfig,
        model: RGCLModel,
        optimizer: torch.optim.Optimizer,
        device: torch.device,
        logger: CombinedLogger,
    ):
        super(RGCLPreTrainer, self).__init__(
            config=config,
            model=model,
            optimizer=optimizer,
            device=device,
            logger=logger,
        )

    # def train_for_one_epoch(self, train_data_loader: DataLoader) -> float:
    def train_for_one_epoch(self, dataset) -> float:
        """Run one RGCL epoch. Unlike sibling trainers, this takes the raw
        dataset (not a DataLoader) because it mutates augmentation settings
        and per-node scores on the dataset itself. Returns the mean loss.
        """
        # --- Phase 1: node-importance scoring (no augmentation, no grads) ---
        dataset.aug = "none"
        # NOTE(review): the chunk size 2048 is hard-coded here AND in the
        # index arithmetic below — keep the two in sync.
        loader = DataLoader(dataset, batch_size=2048, num_workers=4, shuffle=False)
        self.model.eval()
        torch.set_grad_enabled(False)

        for step, batch in enumerate(loader):
            # graph-index range covered by this chunk
            node_index_start = step * 2048
            node_index_end = min(node_index_start + 2048 - 1, len(dataset) - 1)
            batch = batch.to(self.device)
            node_imp = self.model.node_imp_estimator(
                batch.x, batch.edge_index, batch.edge_attr, batch.batch
            ).detach()
            # dataset.slices["x"] maps graph index -> node offset, so this
            # slice covers exactly the nodes of the graphs in this chunk;
            # scores are stored as half precision to save memory
            dataset.node_score[
                dataset.slices["x"][node_index_start] : dataset.slices["x"][
                    node_index_end + 1
                ]
            ] = torch.squeeze(node_imp.half())

        # --- Phase 2: three aligned views of the same shuffled ordering ---
        dataset1 = deepcopy(dataset)
        dataset1 = dataset1.shuffle()
        dataset2 = deepcopy(dataset1)
        dataset3 = deepcopy(dataset1)

        # two rationale-aware node-drop views plus a "_cp" view
        # (presumably the complement view used by the cp loss term —
        # confirm against RGCLModel.loss_ra)
        dataset1.aug, dataset1.aug_ratio = "dropN", 0.2
        dataset2.aug, dataset2.aug_ratio = "dropN", 0.2
        dataset3.aug, dataset3.aug_ratio = "dropN" + "_cp", 0.2

        # shuffle=False keeps the three loaders aligned graph-for-graph
        loader1 = DataLoader(
            dataset1,
            batch_size=self.config.batch_size,
            num_workers=self.config.num_workers,
            shuffle=False,
        )
        loader2 = DataLoader(
            dataset2,
            batch_size=self.config.batch_size,
            num_workers=self.config.num_workers,
            shuffle=False,
        )
        loader3 = DataLoader(
            dataset3,
            batch_size=self.config.batch_size,
            num_workers=self.config.num_workers,
            shuffle=False,
        )
        train_loss_accum = 0
        ra_loss_accum = 0
        cp_loss_accum = 0

        # re-enable gradients that Phase 1 turned off globally
        torch.set_grad_enabled(True)
        self.model.train()
        self.logger.train(num_batches=len(loader1))

        for step, batch in enumerate(zip(loader1, loader2, loader3)):
            batch1, batch2, batch3 = batch
            batch1 = batch1.to(self.device)
            batch2 = batch2.to(self.device)
            batch3 = batch3.to(self.device)

            self.optimizer.zero_grad()

            x1 = self.model.forward_cl(
                batch1.x, batch1.edge_index, batch1.edge_attr, batch1.batch
            )
            x2 = self.model.forward_cl(
                batch2.x, batch2.edge_index, batch2.edge_attr, batch2.batch
            )
            x3 = self.model.forward_cl(
                bat3 := batch3.x, batch3.edge_index, batch3.edge_attr, batch3.batch
            ) if False else self.model.forward_cl(
                batch3.x, batch3.edge_index, batch3.edge_attr, batch3.batch
            )

            # loss_ra returns (rationale-aware loss, cp loss, total loss)
            ra_loss, cp_loss, loss = self.model.loss_ra(x1, x2, x3)

            loss.backward()
            self.optimizer.step()
            loss_float = float(loss.detach().cpu().item())
            train_loss_accum += loss_float
            ra_loss_accum += float(ra_loss.detach().cpu().item())
            cp_loss_accum += float(cp_loss.detach().cpu().item())
            self.logger(loss_float, 0.0, batch1.num_graphs, get_lr(self.optimizer))
        # del dataset1, dataset2, dataset3
        # gc.collect()

        # return train_loss_accum/(step+1), ra_loss_accum/(step+1), cp_loss_accum/(step+1)
        return train_loss_accum / (step + 1)

    # Validation not implemented; emit a placeholder logger entry.
    def validate_model(self, val_data_loader) -> float:
        # self.logger.eval(num_batches=len(val_data_loader))
        self.logger.eval(num_batches=1)
        self.logger(0.0, 0.0, 1)
        return 0.0
134 |
--------------------------------------------------------------------------------
/src/run_embedding_extraction.py:
--------------------------------------------------------------------------------
1 | from load_save import infer_and_save_embeddings, load_checkpoint
2 | from config.validation_config import parse_config
3 | from init import (
4 | get_data_loader_val,
5 | get_dataset_extraction,
6 | get_dataset_split,
7 | get_device,
8 | get_model,
9 | get_smiles_list,
10 | init,
11 | )
12 |
13 |
def main() -> None:
    """Extract node/graph embeddings with a pre-trained model and save them."""
    # --- config + pre-trained weights ---
    config = parse_config()
    init(config=config)
    device = get_device(config=config)

    # --- datasets and loaders ---
    dataset = get_dataset_extraction(config=config)
    smiles_list = get_smiles_list(config=config)
    dataset_splits = get_dataset_split(
        config=config, dataset=dataset, smiles_list=smiles_list
    )
    # a list of MoleculeDataset class: ['train', 'val', 'test', 'smiles']
    datasets_list = list(dataset_splits.values())
    loaders = [
        get_data_loader_val(config=config, dataset=split, shuffle=False)
        for split in datasets_list
    ]
    # a tuple of list of smiles: ('train', 'val', 'test')
    smile_splits = dataset_splits["smiles"]

    # --- restore the model ---
    pre_trained_model = get_model(config=config).to(device)
    checkpoint = load_checkpoint(config=config, device=device)
    pre_trained_model.load_state_dict(checkpoint["model"])

    # --- save node & graph embeddings ---
    infer_and_save_embeddings(
        config=config,
        model=pre_trained_model,
        device=device,
        datasets=datasets_list,
        loaders=loaders,
        smile_splits=smile_splits,
    )
48 |
49 |
# Script entry point.
if __name__ == "__main__":
    main()
52 |
--------------------------------------------------------------------------------
/src/run_pretraining.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from config.training_config import TrainingConfig, parse_config
4 | from datasets import MoleculeDataset
5 | from logger import CombinedLogger
6 | from load_save import save_model
7 | from init import (
8 | get_data_loader,
9 | get_dataset,
10 | get_device,
11 | get_model,
12 | get_optimizer,
13 | get_pretrainer,
14 | init,
15 | )
16 |
17 |
def pretrain_model(
    config: TrainingConfig,
    dataset: MoleculeDataset,
    model: torch.nn.Module,
    device: torch.device,
) -> None:
    """Generic pre-training wrapper.

    Builds the optimizer, data loader and pre-trainer from ``config``, runs
    ``config.epochs`` training epochs, snapshotting the model every
    ``config.epochs_save`` epochs (the epoch-0 snapshot is the randomly
    initialised weights) and once more after training finishes.
    """
    logger = CombinedLogger(config=config)
    optimizer = get_optimizer(config=config, model=model)
    train_data_loader = get_data_loader(config=config, dataset=dataset)
    pre_trainer = get_pretrainer(
        config=config,
        model=model,
        optimizer=optimizer,
        device=device,
        logger=logger,
    )

    epoch = 0  # keeps the final save well-defined when config.epochs == 0
    for epoch in range(config.epochs):
        # the epoch-0 snapshot is the randomly initialised weights
        if epoch % config.epochs_save == 0 and config.save_model:
            save_model(config=config, model=pre_trainer.model, epoch=epoch)
        # RGCL recomputes augmentations per epoch, so it consumes the raw
        # dataset instead of a fixed data loader.
        if config.pretrainer == "RGCL":
            pre_trainer.train_for_one_epoch(dataset)
        else:
            pre_trainer.train_for_one_epoch(train_data_loader)

    if config.save_model:
        save_model(config=config, model=pre_trainer.model, epoch=epoch)
47 |
48 |
def main() -> None:
    """Parse the training config and launch pre-training."""
    config = parse_config()
    init(config=config)
    device = get_device(config=config)
    dataset = get_dataset(config=config)
    model = get_model(config=config).to(device)
    pretrain_model(config=config, dataset=dataset, model=model, device=device)
56 |
57 |
# Script entry point.
if __name__ == "__main__":
    main()
60 |
--------------------------------------------------------------------------------
/src/run_validation.py:
--------------------------------------------------------------------------------
1 | from config.validation_config import parse_config
2 | from init import get_task, init
3 |
4 |
def main() -> None:
    """Build and run the validation task selected by the parsed config."""
    config = parse_config()
    init(config=config)
    get_task(config=config).run()
10 |
11 |
# Script entry point.
if __name__ == "__main__":
    main()
14 |
--------------------------------------------------------------------------------
/src/validation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hansen7/MolGraphEval/2f83950b622d98937500c2efd8091a5af88c8cca/src/validation/__init__.py
--------------------------------------------------------------------------------
/src/validation/dataset.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 |
3 |
class ProberDataset(Dataset):
    """A dataset class that holds (representation, label) pairs.

    Parameters
    ----------
    representations:
        Sequence of equal-length feature vectors (one per sample).
    labels:
        Sequence of probing targets, one per representation.

    Raises
    ------
    ValueError
        If ``representations`` and ``labels`` differ in length (this was a
        bare ``assert`` before, which is stripped under ``python -O``).
    """

    def __init__(self, representations, labels):
        if len(representations) != len(labels):
            raise ValueError(
                "representations and labels must have the same length "
                "({} != {})".format(len(representations), len(labels))
            )
        # Dimensionality of a single representation vector (e.g. 300);
        # 0 for an empty dataset instead of an IndexError on [0].
        self.input_dim = len(representations[0]) if len(representations) > 0 else 0
        self.representations = representations
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return {
            "representation": self.representations[idx],
            "label": self.labels[idx],
        }
21 |
--------------------------------------------------------------------------------
/src/validation/task/__init__.py:
--------------------------------------------------------------------------------
1 | from .task import Task, TrainValTestTask
2 |
--------------------------------------------------------------------------------
/src/validation/task/finetune_task.py:
--------------------------------------------------------------------------------
1 | # import pdb
2 | from typing import Dict
3 |
4 | import torch
5 | import torch.nn as nn
6 | from torch.utils.data import DataLoader
7 | from tqdm import tqdm
8 |
9 | from config.validation_config import ValidationConfig
10 |
11 | # from logger import CombinedLogger
12 | # from util import get_lr
13 | from validation.dataset import ProberDataset
14 | from validation.task import TrainValTestTask
15 |
16 |
class ProberTask(TrainValTestTask):
    """Train a prober head on pre-computed representations and report losses.

    NOTE(review): despite living in ``finetune_task.py`` this class is named
    ``ProberTask``, the same as the one in ``prober_task.py`` — confirm which
    of the two callers actually import.
    """

    def __init__(
        self,
        config: ValidationConfig,
        model: nn.Module,
        device: torch.device,
        optimizer: torch.optim.Optimizer,
        train_dataset: ProberDataset,
        val_dataset: ProberDataset,
        test_dataset: ProberDataset,
        criterion_type: str = "mse",
    ):
        super(ProberTask, self).__init__(
            config=config,
            model=model,
            device=device,
            optimizer=optimizer,
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            test_dataset=test_dataset,
            criterion_type=config.criterion_type,
        )
        self.train_loader = DataLoader(
            self.train_dataset, batch_size=config.batch_size, shuffle=True
        )
        self.val_loader = DataLoader(
            self.val_dataset, batch_size=config.batch_size, shuffle=False
        )
        self.test_loader = DataLoader(
            self.test_dataset, batch_size=config.batch_size, shuffle=False
        )

        # NOTE(review): the criterion is chosen from the *parameter* while
        # the base class receives config.criterion_type — confirm callers
        # pass the same value for both.
        if criterion_type == "mse":
            self.criterion = nn.MSELoss(reduction="mean")
        elif criterion_type == "bce":
            self.criterion = nn.BCEWithLogitsLoss(reduction="mean")
        elif criterion_type == "ce":
            self.criterion = nn.CrossEntropyLoss(reduction="mean")
        else:
            raise Exception("Unknown criterion {}".format(criterion_type))

        self.train_loss = []
        self.train_score, self.eval_score, self.test_score = [], [], []

    def run(self, model: nn.Module = None, device: torch.device = None) -> Dict:
        """Train, then evaluate on all three splits; returns a loss dict."""
        final_train_loss = self.train()
        final_val_loss = self.eval_val_dataset()
        final_test_loss = self.eval_test_dataset()
        results_dict = {
            "train_loss": final_train_loss,
            "val_loss": final_val_loss,
            "test_loss": final_test_loss,
        }
        print(results_dict)
        return results_dict

    def train(self) -> float:
        """Run ``config.epochs`` epochs; returns the last epoch's mean loss."""
        # Initialised so the return is defined even when config.epochs == 0
        # (the original raised NameError in that case).
        train_loss = 0.0
        for _ in tqdm(range(self.config.epochs)):
            train_loss = self.train_step()
        return train_loss

    def train_step(self) -> float:
        """One optimisation pass over the train loader; returns mean loss."""
        self.model.train()
        total_loss = 0.0

        for batch in self.train_loader:
            pred = self.model(batch["representation"].to(self.device)).squeeze()
            y = batch["label"].to(torch.float32).to(self.device)
            loss = self.criterion(pred, y)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            total_loss += loss.detach().item()

        return total_loss / len(self.train_loader)

    def _eval(self, loader: DataLoader) -> float:
        """Summed (not averaged) criterion loss over ``loader``."""
        self.model.eval()
        total_val_loss = 0.0
        for batch in loader:
            with torch.no_grad():
                inputs = batch["representation"].to(self.device)
                pred = self.model(inputs).squeeze()
                true_target = batch["label"].to(torch.float32).to(self.device)
                total_val_loss += self.criterion(pred, true_target).item()
        return total_val_loss

    def _eval_roc(self, loader: DataLoader) -> float:
        """Mean per-target ROC-AUC; labels are in {-1, 0, +1}, 0 = missing.

        Targets without at least one positive (+1) and one negative (-1)
        label are skipped, since AUC is undefined for them.
        """
        # Imported locally: this module does not import numpy/sklearn at the
        # top level, so the original body raised NameError when called.
        import numpy as np
        from sklearn.metrics import roc_auc_score

        self.model.eval()
        y_true, y_pred = [], []

        for batch in loader:
            with torch.no_grad():
                if self.config.val_task == "prober":
                    inputs = batch["representation"].to(self.device)
                    pred = self.model(inputs)
                    true = batch["label"].to(torch.float32).to(self.device)
                elif self.config.val_task == "finetune":
                    batch = batch.to(self.device)
                    pred = self.model(
                        batch.x, batch.edge_index, batch.edge_attr, batch.batch
                    )
                    true = batch.y.view(pred.shape)

                y_true.append(true)
                y_pred.append(pred)

        y_true = torch.cat(y_true, dim=0).cpu().numpy()
        y_pred = torch.cat(y_pred, dim=0).cpu().numpy()

        roc_list = []
        for i in range(y_true.shape[1]):
            # AUC needs at least one positive (+1) and one negative (-1).
            if np.sum(y_true[:, i] == 1) > 0 and np.sum(y_true[:, i] == -1) > 0:
                is_valid = y_true[:, i] ** 2 > 0  # drop missing (0) labels
                roc_list.append(
                    roc_auc_score((y_true[is_valid, i] + 1) / 2, y_pred[is_valid, i])
                )

        if len(roc_list) < y_true.shape[1]:
            print("Some target is missing!")
            print("Missing ratio: %f" % (1 - float(len(roc_list)) / y_true.shape[1]))

        return sum(roc_list) / len(roc_list)

    def eval_train_dataset(self) -> float:
        return self._eval(loader=self.train_loader)

    def eval_val_dataset(self) -> float:
        return self._eval(loader=self.val_loader)

    def eval_test_dataset(self) -> float:
        return self._eval(loader=self.test_loader)
160 |
--------------------------------------------------------------------------------
/src/validation/task/graph_edit_distance.py:
--------------------------------------------------------------------------------
1 | import gmatch4py as gm
2 | import networkx as nx
3 |
4 | from validation.task.metrics import GraphPairMetric
5 |
6 |
class GraphEditDistanceDataset(GraphPairMetric):
    """Graph-pair metric: minimal number of edits between two graphs."""

    def __init__(self, representation_path, config):
        super(GraphEditDistanceDataset, self).__init__(representation_path, config)

    def graph_pair_metric(
        self, graph_nx_one: nx.Graph, graph_nx_two: nx.Graph
    ) -> float:
        # All four edit operations (node/edge insertion/deletion) cost 1.
        comparator = gm.GraphEditDistance(1, 1, 1, 1)
        distances = comparator.compare([graph_nx_one, graph_nx_two], None)
        # Distance from graph 0 to graph 1, scaled down as a probing target.
        return distances[0][1] / 100
19 |
--------------------------------------------------------------------------------
/src/validation/task/graph_level.py:
--------------------------------------------------------------------------------
1 | import networkx as nx
2 | import numpy as np
3 | from networkx.algorithms.distance_measures import diameter
4 | from rdkit.Chem import Fragments
5 |
6 | from datasets import graph_data_obj_to_mol_simple, graph_data_obj_to_nx_simple
7 | from validation.task.metrics import GraphLevelMetric
8 |
9 |
class GraphDiameterDataset(GraphLevelMetric):
    """Calculate 2D graph diameter, i.e. the longest shortest path."""

    def __init__(self, representation_path, config, split, **kwargs):
        super(GraphDiameterDataset, self).__init__(
            representation_path, config, split, **kwargs
        )

    @staticmethod
    def graph_level_metric(graph, method="largest") -> int:
        graph_nx = graph_data_obj_to_nx_simple(graph)
        if nx.is_connected(graph_nx):
            return diameter(graph_nx)
        if method == "largest":
            # Disconnected graph: report the largest component's diameter.
            return max(
                diameter(graph_nx.subgraph(component).copy())
                for component in nx.connected_components(graph_nx)
            )
        # Unknown method on a disconnected graph: drop the sample.
        return None
29 |
30 |
class CycleBasisDataset(GraphLevelMetric):
    """Number of independent cycles (size of the cycle basis) of the graph."""

    def __init__(self, representation_path, config, split, **kwargs):
        super(CycleBasisDataset, self).__init__(
            representation_path, config, split, **kwargs
        )

    @staticmethod
    def graph_level_metric(graph):
        """Calculate number of cycles."""
        basis = nx.cycle_basis(graph_data_obj_to_nx_simple(graph))
        return len(basis)
42 |
43 |
class AssortativityCoefficientDataset(GraphLevelMetric):
    """Degree assortativity of the graph.

    Assortativity measures how similar the degrees of connected nodes are.
    """

    def __init__(self, representation_path, config, split, **kwargs):
        super(AssortativityCoefficientDataset, self).__init__(
            representation_path, config, split, **kwargs
        )

    @staticmethod
    def graph_level_metric(graph) -> float:
        nx_graph = graph_data_obj_to_nx_simple(graph)
        coeff = nx.algorithms.assortativity.degree_assortativity_coefficient(G=nx_graph)
        # NaN (e.g. every node has the same degree) -> None drops the sample.
        return None if np.isnan(coeff) else coeff
62 |
63 |
class AverageClusteringCoefficientDataset(GraphLevelMetric):
    """Approximate average clustering coefficient of the graph.

    Uses the networkx trial-based approximation: repeatedly pick a random
    node and two of its neighbours and test whether they are connected; the
    coefficient is the fraction of closed triangles found over the trials.
    """

    def __init__(self, representation_path, config, split, **kwargs):
        super(AverageClusteringCoefficientDataset, self).__init__(
            representation_path, config, split, **kwargs
        )

    @staticmethod
    def graph_level_metric(graph) -> float:
        nx_graph = graph_data_obj_to_nx_simple(graph)
        return nx.algorithms.approximation.average_clustering(G=nx_graph)
88 |
89 |
class NodeConnectivityDataset(GraphLevelMetric):
    """Approximate node connectivity of the graph.

    Node connectivity is the minimum number of nodes that must be removed to
    disconnect the graph or render it trivial. The networkx approximation
    (White & Newman, 2001, "A Fast Algorithm for Node-Independent Paths")
    gives a strict lower bound and works for directed and undirected graphs.
    """

    def __init__(self, representation_path, config, split, **kwargs):
        super(NodeConnectivityDataset, self).__init__(
            representation_path, config, split, **kwargs
        )

    @staticmethod
    def graph_level_metric(graph) -> int:
        nx_graph = graph_data_obj_to_nx_simple(graph)
        return nx.algorithms.approximation.node_connectivity(G=nx_graph)
117 |
118 |
119 | # class BenzeneRingDataset(GraphLevelMetric):
120 | # """Calculate the number of Benzene rings."""
121 | #
122 | # def __init__(self, representation_path, config, split, **kwargs):
123 | # super(BenzeneRingDataset, self).__init__(
124 | # representation_path, config, split, **kwargs)
125 | #
126 | # @staticmethod
127 | # def graph_level_metric(graph) -> int:
128 | # mol = graph_data_obj_to_mol_simple(
129 | # graph.x, graph.edge_index, graph.edge_attr)
130 | # return Fragments.fr_benzene(mol)
131 |
132 |
# RDKit fragment descriptors (Fragments.fr_*) that are valid probing targets.
# Names follow RDKit exactly, including its historical spellings
# ("fr_piperdine", "fr_piperzine"). The original list contained
# "fr_morpholine" twice; the duplicate has been removed.
RDKIT_fragments_valid = [
    "fr_epoxide",
    "fr_lactam",
    "fr_morpholine",
    "fr_oxazole",
    "fr_tetrazole",
    "fr_N_O",
    "fr_ether",
    "fr_furan",
    "fr_guanido",
    "fr_halogen",
    "fr_piperdine",
    "fr_thiazole",
    "fr_thiophene",
    "fr_urea",
    "fr_allylic_oxid",
    "fr_amide",
    "fr_amidine",
    "fr_azo",
    "fr_benzene",
    "fr_imidazole",
    "fr_imide",
    "fr_piperzine",
    "fr_pyridine",
]
159 |
160 |
class RDKiTFragmentDataset(GraphLevelMetric):
    """Count occurrences of the RDKit fragment ``des`` in each molecule."""

    def __init__(self, representation_path, config, split, des):
        super(RDKiTFragmentDataset, self).__init__(
            representation_path, config, split, des=des
        )

    def graph_level_metric(self, graph):
        """Return ``Fragments.<des>(mol)`` for the molecule, or None on failure.

        Raises ValueError for a descriptor outside RDKIT_fragments_valid
        (was a bare assert, which is stripped under ``python -O``).
        """
        if self.des not in RDKIT_fragments_valid:
            raise ValueError("Unknown RDKit fragment descriptor %s" % self.des)
        mol = graph_data_obj_to_mol_simple(graph.x, graph.edge_index, graph.edge_attr)
        # getattr replaces the original compile()/eval() indirection: same
        # behaviour, no dynamic code execution.
        fragment_counter = getattr(Fragments, self.des)
        try:
            return fragment_counter(mol)
        except Exception:
            # RDKit can fail on malformed molecules; None drops the sample.
            # (Narrowed from a bare except, which also swallowed
            # KeyboardInterrupt/SystemExit.)
            return None
178 |
179 |
class DownstreamDataset(GraphLevelMetric):
    """Use the downstream task labels (graph.y) directly as probing targets."""

    def __init__(self, representation_path, config, split, des):
        super(DownstreamDataset, self).__init__(
            representation_path, config, split, des=des
        )

    @staticmethod
    def graph_level_metric(graph):
        # Flatten the (possibly multi-task) label tensor to a 1-D numpy array.
        return graph.y.cpu().numpy().ravel()
193 |
--------------------------------------------------------------------------------
/src/validation/task/metrics.py:
--------------------------------------------------------------------------------
1 | # import pdb
2 | import pickle as pkl
3 | import random
4 | from abc import ABC, abstractmethod
5 |
6 | import numpy as np
7 | from tqdm import tqdm
8 |
9 | from datasets import graph_data_obj_to_nx_simple
10 | from validation.dataset import ProberDataset
11 | from validation.utils import get_dataset_extraction, get_dataset_split, get_smiles_list
12 |
13 |
class NodeLevelMetric(ABC):
    """Base class for node-level probing datasets.

    Loads per-node representations from ``representation_path`` (a pickle of
    (graph_reprs, node_reprs, smiles)) and pairs each node's representation
    with a node-level structural label computed from its molecular graph.
    """

    def __init__(self, representation_path, config, split, **kwargs):
        self.representation_path = representation_path
        # Stash extra keyword options (e.g. metric descriptors) as attributes.
        self.__dict__.update(kwargs)
        self.config = config
        with open(representation_path, "rb") as f:
            _, self.node_repr_list, self.smiles = pkl.load(f)
        self.num_graph = len(self.smiles)

        dataset = get_dataset_extraction(config=config)
        smiles_list = get_smiles_list(config=config)
        dataset_splits = get_dataset_split(
            config=config, dataset=dataset, smiles_list=smiles_list
        )
        # split is one of "train", "val", "test"
        self.dataset = dataset_splits[split]

    def create_datasets(self):
        """Build a ProberDataset of (node representation, node metric) pairs."""
        labels, representations = [], []
        for graph_idx in range(self.num_graph):
            graph = self.dataset[graph_idx]
            node_repr = self.node_repr_list[graph_idx]
            metrics = self.node_level_metric(graph)
            assert len(metrics) == len(node_repr)
            for metric, repr_vec in zip(metrics, node_repr):
                labels.append(metric)
                representations.append(repr_vec)
        return ProberDataset(representations, labels)

    @abstractmethod
    def node_level_metric(self):
        pass
51 |
52 |
class GraphLevelMetric(ABC):
    """Base class for graph-level probing datasets.

    Loads per-graph representations from ``representation_path`` (a pickle of
    (graph_reprs, node_reprs, smiles)) and pairs each graph representation
    with a graph-level label; graphs whose metric is None are skipped.
    """

    def __init__(self, representation_path, config, split, **kwargs):
        self.representation_path = representation_path
        # Stash extra keyword options (e.g. metric descriptors) as attributes.
        self.__dict__.update(kwargs)
        self.config = config
        with open(representation_path, "rb") as f:
            self.graph_repr_list, _, self.smiles = pkl.load(f)
        self.num_graph = len(self.smiles)

        dataset = get_dataset_extraction(config=config)
        smiles_list = get_smiles_list(config=config)
        dataset_splits = get_dataset_split(
            config=config, dataset=dataset, smiles_list=smiles_list
        )
        # split is one of "train", "validation", "test"
        self.dataset = dataset_splits[split]

    def create_datasets(self):
        """Build a ProberDataset of (graph representation, graph metric) pairs."""
        labels, representations = [], []
        for graph_idx in range(self.num_graph):
            metric = self.graph_level_metric(self.dataset[graph_idx])
            if metric is not None:
                labels.append(metric)
                representations.append(self.graph_repr_list[graph_idx])
        return ProberDataset(representations, labels)

    @abstractmethod
    def graph_level_metric(self):
        pass
84 |
85 |
class NodePairMetric(ABC):
    """Base class for node-pair probing datasets.

    Loads per-node representations from ``representation_path`` (a pickle of
    (graph_reprs, node_reprs, smiles)) and builds labelled examples for
    randomly sampled node pairs within each graph.
    """

    def __init__(self, representation_path, config, split, **kwargs):
        self.representation_path = representation_path
        # Stash extra keyword options (e.g. metric descriptors) as attributes.
        self.__dict__.update(kwargs)
        self.config = config
        with open(representation_path, "rb") as f:
            (
                _,
                self.node_repr_list,
                self.smiles,
            ) = pkl.load(f)
        self.num_graph = len(self.smiles)

        dataset = get_dataset_extraction(config=config)
        smiles_list = get_smiles_list(config=config)
        dataset_splits = get_dataset_split(
            config=config, dataset=dataset, smiles_list=smiles_list
        )
        # "train", "validation", "test"
        self.dataset = dataset_splits[split]

    def create_datasets(self, num_pairs=30):
        """Sample up to ``num_pairs`` node pairs per graph and label them.

        The pair representation is [h_u, h_v, h_u * h_v]. Draws that pick
        the same node twice are skipped, so fewer than ``num_pairs``
        examples may be produced per graph.
        """
        labels = []
        representations = []
        for graph_idx in range(self.num_graph):
            graph = self.dataset[graph_idx]
            node_repr = self.node_repr_list[graph_idx]
            num_nodes = len(node_repr)
            graph_nx = graph_data_obj_to_nx_simple(graph)
            for _ in range(num_pairs):
                start_node = random.choice(list(range(num_nodes)))
                end_node = random.choice(list(range(num_nodes)))
                if start_node != end_node:
                    representations.append(
                        np.concatenate(
                            [
                                node_repr[start_node],
                                node_repr[end_node],
                                np.multiply(node_repr[start_node], node_repr[end_node]),
                            ]
                        )
                    )
                    labels.append(self.node_pair_metric(graph_nx, start_node, end_node))
        return ProberDataset(representations, labels)

    @abstractmethod
    def node_pair_metric(self, graph_nx, start_node, end_node):
        """Return the scalar label for the (start_node, end_node) pair."""
        pass
134 |
135 |
class GraphPairMetric(ABC):
    """Base class for graph-pair probing datasets.

    Unlike the other metric base classes, the pickle here also contains the
    dataset itself: (dataset, graph_reprs, node_reprs, smiles).
    """

    def __init__(self, representation_path, config):
        self.representation_path = representation_path
        self.config = config
        with open(representation_path, "rb") as f:
            (
                self.dataset,
                self.graph_repr_list,
                self.node_repr_list,
                self.smiles,
            ) = pkl.load(f)
        self.num_graph = len(self.smiles)

    def create_datasets(self, num_pairs=10000):
        """Sample up to ``num_pairs`` graph pairs and label each pair.

        The pair representation is [g1, g2, g1 * g2]. Draws that pick the
        same graph twice are skipped.
        """
        def take_out_graph(dataset, idx):
            # GraphCL datasets may yield (graph, augmented_graph) tuples.
            graph = dataset[idx]
            if self.config.pretrainer == "GraphCL" and isinstance(graph, tuple):
                graph = graph[0]  # remove the contrastive augmented data.
            graph_nx = graph_data_obj_to_nx_simple(graph)
            return graph_nx

        labels = []
        representations = []
        for _ in tqdm(range(num_pairs)):
            graph_one_idx = random.choice(list(range(self.num_graph)))
            graph_two_idx = random.choice(list(range(self.num_graph)))
            if graph_one_idx != graph_two_idx:
                graph_nx_one = take_out_graph(self.dataset, graph_one_idx)
                graph_nx_two = take_out_graph(self.dataset, graph_two_idx)
                graph_repr_one = self.graph_repr_list[graph_one_idx]
                graph_repr_two = self.graph_repr_list[graph_two_idx]

                representations.append(
                    np.concatenate(
                        [
                            graph_repr_one,
                            graph_repr_two,
                            np.multiply(graph_repr_one, graph_repr_two),
                        ]
                    )
                )
                labels.append(self.graph_pair_metric(graph_nx_one, graph_nx_two))
        return ProberDataset(representations, labels)

    @abstractmethod
    def graph_pair_metric(self, graph_nx_one, graph_nx_two):
        """Return the scalar label for a pair of networkx graphs."""
        pass
183 |
184 |
class FineTune_Metric(ABC):
    """Base class for fine-tuning datasets.

    Pairs the raw graph object (repeated once per node) with node-level
    labels, instead of using pre-computed representations.
    """

    def __init__(self, representation_path, config, split, **kwargs):
        self.representation_path = representation_path
        # Stash extra keyword options (e.g. metric descriptors) as attributes.
        self.__dict__.update(kwargs)
        self.config = config
        with open(representation_path, "rb") as f:
            _, _, self.smiles = pkl.load(f)
        self.num_graph = len(self.smiles)

        dataset = get_dataset_extraction(config=config)
        smiles_list = get_smiles_list(config=config)
        dataset_splits = get_dataset_split(
            config=config, dataset=dataset, smiles_list=smiles_list
        )
        # split is one of "train", "val", "test"
        self.dataset = dataset_splits[split]

    def create_datasets(self):
        """Build a ProberDataset pairing each node label with its whole graph."""
        labels, representations = [], []
        for graph_idx in range(self.num_graph):
            graph = self.dataset[graph_idx]
            metrics = self.node_level_metric(graph)
            # NOTE(review): the graph object itself is the "representation";
            # it is appended once per node-level label.
            for metric in metrics:
                labels.append(metric)
                representations.append(graph)
        return ProberDataset(representations, labels)

    @abstractmethod
    def node_level_metric(self):
        pass
218 |
--------------------------------------------------------------------------------
/src/validation/task/node_level.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import networkx as nx
4 |
5 | from datasets import graph_data_obj_to_mol_simple, graph_data_obj_to_nx_simple
6 | from validation.task.metrics import NodeLevelMetric
7 |
8 |
class NodeCentralityDataset(NodeLevelMetric):
    """Eigenvector centrality for each node in a graph."""

    def __init__(self, representation_path, config, split, **kwargs):
        super(NodeCentralityDataset, self).__init__(
            representation_path, config, split, **kwargs
        )

    @staticmethod
    def node_level_metric(graph):
        graph_nx = graph_data_obj_to_nx_simple(graph)
        try:
            centrality = nx.eigenvector_centrality(graph_nx, max_iter=1500)
        except nx.exception.PowerIterationFailedConvergence:
            # Power iteration can fail on some graphs; fall back to zeros.
            print("PowerIterationFailedConvergence")
            return [0] * len(graph_nx.nodes())

        # Sort by node id so values align with the node-representation order.
        return [value for _, value in sorted(centrality.items())]
30 |
31 |
class NodeDegreeDataset(NodeLevelMetric):
    """Atom degree for each node in a molecular graph."""

    def __init__(self, representation_path, config, split, **kwargs):
        super(NodeDegreeDataset, self).__init__(
            representation_path, config, split, **kwargs
        )

    @staticmethod
    def node_level_metric(graph) -> List[int]:
        mol = graph_data_obj_to_mol_simple(graph.x, graph.edge_index, graph.edge_attr)
        # RDKit iterates atoms in the same order as the graph's node indices.
        return [atom.GetDegree() for atom in mol.GetAtoms()]
49 |
50 |
class NodeClusteringDataset(NodeLevelMetric):
    """Clustering coefficient for each node in a graph."""

    def __init__(self, representation_path, config, split, **kwargs):
        super(NodeClusteringDataset, self).__init__(
            representation_path, config, split, **kwargs
        )

    @staticmethod
    def node_level_metric(graph) -> List[float]:
        coefficients = nx.clustering(graph_data_obj_to_nx_simple(graph))
        # Sort by node id so values align with the node-representation order.
        return [value for _, value in sorted(coefficients.items())]
65 |
--------------------------------------------------------------------------------
/src/validation/task/pair_level.py:
--------------------------------------------------------------------------------
1 | import networkx as nx
2 | import numpy as np
3 |
4 | from validation.task.metrics import NodePairMetric
5 |
6 |
class LinkPredictionDataset(NodePairMetric):
    """Binary label: are the two nodes directly connected?"""

    def __init__(self, representation_path, config, split, **kwargs):
        super(LinkPredictionDataset, self).__init__(
            representation_path, config, split, **kwargs
        )

    def node_pair_metric(self, graph_nx, start_node, end_node):
        # 1 if an edge joins the pair, else 0.
        return 1 if graph_nx.has_edge(start_node, end_node) else 0
17 |
18 |
class KatzIndexDataset(NodePairMetric):
    """Katz index between two nodes.

    The most basic global overlap statistic: a weighted count of paths of
    every length between the pair.
    """

    def __init__(self, representation_path, config, split, **kwargs):
        super(KatzIndexDataset, self).__init__(
            representation_path, config, split, **kwargs
        )

    def node_pair_metric(
        self, graph_nx: nx.Graph, start_node: int, end_node: int
    ) -> float:
        alpha = 0.3
        identity = np.identity(len(graph_nx.nodes))
        adjacency = nx.to_numpy_array(graph_nx)
        # Katz matrix: sum over all path lengths = (I - alpha*A)^-1 - I.
        katz_matrix = np.linalg.inv(identity - adjacency * alpha) - identity
        return katz_matrix[start_node][end_node]
36 |
37 |
class JaccardCoefficientDataset(NodePairMetric):
    r"""Jaccard coefficient of a node pair.

    For nodes `u` and `v` it is
    .. math::
        \frac{|\Gamma(u) \cap \Gamma(v)|}{|\Gamma(u) \cup \Gamma(v)|}
    where $\Gamma(u)$ is the neighbour set of $u$ (Liben-Nowell &
    Kleinberg, 2004, "The Link Prediction Problem for Social Networks").
    """

    def __init__(self, representation_path, config, split, **kwargs):
        super(JaccardCoefficientDataset, self).__init__(
            representation_path, config, split, **kwargs
        )

    def node_pair_metric(self, graph_nx, start_node, end_node):
        # networkx yields (u, v, coefficient) triples; return the first
        # (and only) one for the requested pair.
        preds = nx.jaccard_coefficient(graph_nx, [(start_node, end_node)])
        for _, _, coefficient in preds:
            return coefficient
62 |
63 |
64 | # class GraphEditDistanceDataset(GraphPairMetric):
65 | # """ Compare the distance (minimal edits) between two graphs"""
66 | # def __init__(self, representation_path, config):
67 | # super(GraphEditDistanceDataset, self).__init__(
68 | # representation_path, config)
69 | #
70 | # def graph_pair_metric(
71 | # self, graph_nx_one: nx.Graph, graph_nx_two: nx.Graph) -> float:
72 | # ged = gm.GraphEditDistance(1, 1, 1, 1) # all edit costs equal to 1
73 | # result = ged.compare([graph_nx_one, graph_nx_two], None)
74 | # return result[0][1] / 100
75 |
--------------------------------------------------------------------------------
/src/validation/task/prober_task.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Union
2 | import wandb
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | from sklearn.metrics import roc_auc_score
7 | from torch.utils.data import DataLoader
8 | from tqdm import tqdm
9 |
10 | from config.validation_config import ValidationConfig
11 |
12 | # from logger import CombinedLogger
13 | from util import get_lr
14 | from datasets import MoleculeDataset
15 | from validation.dataset import ProberDataset
16 | from validation.task import TrainValTestTask
17 |
18 |
class ProberTask(TrainValTestTask):
    """Probing / finetuning task over train, val and test splits.

    Two modes, selected by ``config.val_task``:

    * ``"prober"`` — the model consumes pre-computed representations served by
      :class:`ProberDataset` batches (dicts with ``"representation"`` and
      ``"label"``) through a plain ``torch.utils.data.DataLoader``.
    * ``"finetune"`` — the model consumes ``torch_geometric`` graph batches
      from a :class:`MoleculeDataset` and is trained end-to-end.

    Scoring: ROC-AUC for the ``"downstream"`` probe task, otherwise the raw
    criterion loss. Downstream labels are encoded in {-1, +1} with 0 marking a
    missing target (see the ``(y + 1) / 2`` remap and the ``y**2 > 0`` mask).
    """

    def __init__(
        self,
        config: ValidationConfig,
        model: nn.Module,
        device: torch.device,
        optimizer: torch.optim.Optimizer,
        # logger: CombinedLogger,
        train_dataset: Union[MoleculeDataset, ProberDataset],
        val_dataset: Union[MoleculeDataset, ProberDataset],
        test_dataset: Union[MoleculeDataset, ProberDataset],
        criterion_type: str = "mse",
    ):
        super(ProberTask, self).__init__(
            config=config,
            model=model,
            device=device,
            optimizer=optimizer,
            # logger=logger,
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            test_dataset=test_dataset,
            criterion_type=config.criterion_type,
        )
        # Finetuning iterates torch_geometric graph batches; probing iterates
        # plain tensor batches — pick the matching DataLoader implementation.
        if config.val_task == "finetune":
            from torch_geometric.loader import DataLoader
        else:
            from torch.utils.data import DataLoader
        self.train_loader = DataLoader(
            self.train_dataset, batch_size=config.batch_size, shuffle=True
        )
        self.val_loader = DataLoader(
            self.val_dataset, batch_size=config.batch_size, shuffle=False
        )
        self.test_loader = DataLoader(
            self.test_dataset, batch_size=config.batch_size, shuffle=False
        )

        # NOTE(review): the superclass receives config.criterion_type while the
        # criterion below is chosen from the `criterion_type` argument; the two
        # can disagree when a caller passes criterion_type explicitly — confirm
        # which is authoritative.
        #
        # FIX: the masked-loss code in train_step / train_finetune_step needs
        # *per-element* losses. With reduction="mean" the criterion returns a
        # 0-dim scalar, torch.where() broadcasts it back over the mask, and the
        # missing-label mask silently becomes a no-op (null targets leak into
        # the loss). Keep the mean-reduced criterion for _eval(), and build an
        # elementwise twin for masked training.
        if criterion_type == "mse":
            self.criterion = nn.MSELoss(reduction="mean")
            self._elementwise_criterion = nn.MSELoss(reduction="none")
        elif criterion_type == "bce":
            self.criterion = nn.BCEWithLogitsLoss(reduction="mean")
            self._elementwise_criterion = nn.BCEWithLogitsLoss(reduction="none")
        elif criterion_type == "ce":
            self.criterion = nn.CrossEntropyLoss(reduction="mean")
            self._elementwise_criterion = nn.CrossEntropyLoss(reduction="none")
        else:
            raise Exception("Unknown criterion {}".format(criterion_type))

    def train(self) -> float:
        """Satisfies the abstract interface; the epoch loop lives in run()."""
        pass

    def run(self):
        """Train for config.epochs epochs, logging train loss and val/test scores."""
        for _ in tqdm(range(self.config.epochs)):
            if self.config.val_task == "finetune":
                train_loss = self.train_finetune_step()
            else:
                train_loss = self.train_step()

            if self.config.probe_task == "downstream":
                # train_score = self._eval_roc(loader=self.train_loader)
                val_score = self._eval_roc(loader=self.val_loader)
                test_score = self._eval_roc(loader=self.test_loader)
            else:
                val_score = self.eval_val_dataset()
                test_score = self.eval_test_dataset()

            results_dict = {
                "train_loss": train_loss,
                "val_score": val_score,
                "test_score": test_score,
            }
            wandb.log(results_dict)

            print(results_dict)
            print("\n\n\n")
        return

    def train_step(self) -> float:
        """Run one probing epoch and return the mean per-batch training loss."""
        self.model.train()
        total_loss = 0

        for _, batch in enumerate(self.train_loader):
            if self.config.probe_task == "downstream":
                pred = self.model(batch["representation"].to(self.device))
                y = batch["label"].to(torch.float32).to(self.device)

                # Whether y is non-null or not (labels are -1/+1; 0 = missing).
                is_valid = y**2 > 0
                # Per-element loss matrix; (y + 1) / 2 maps {-1, 1} -> {0, 1}.
                # FIX: use the elementwise criterion — the mean-reduced one
                # collapses to a scalar and defeats the mask below.
                loss_mat = self._elementwise_criterion(pred.double(), (y + 1) / 2)
                # loss matrix after removing null targets
                loss_mat = torch.where(
                    is_valid,
                    loss_mat,
                    torch.zeros(loss_mat.shape).to(self.device).to(loss_mat.dtype),
                )

                self.optimizer.zero_grad()
                # Average only over the valid (non-null) labels.
                loss = torch.sum(loss_mat) / torch.sum(is_valid)
                loss.backward()
                self.optimizer.step()
                total_loss += loss.detach().item()

            else:
                pred = self.model(batch["representation"].to(self.device)).squeeze()
                y = batch["label"].to(torch.float32).to(self.device)
                loss = self.criterion(pred, y)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                total_loss += loss.detach().item()

        return total_loss / len(self.train_loader)

    def train_finetune_step(self) -> float:
        """Run one end-to-end finetuning epoch on graph batches; return mean loss.

        Only implemented for the "downstream" probe task.
        """
        self.model.train()
        total_loss = 0

        if self.config.probe_task != "downstream":
            raise NotImplementedError

        for _, batch in enumerate(self.train_loader):
            batch = batch.to(self.device)
            pred = self.model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
            y = batch.y.view(pred.shape).to(torch.float64)

            # Labels are -1/+1 with 0 marking a missing target.
            is_valid = y**2 > 0
            # FIX: elementwise loss so the torch.where mask is meaningful.
            loss_mat = self._elementwise_criterion(pred.double(), (y + 1) / 2)
            loss_mat = torch.where(
                is_valid,
                loss_mat,
                torch.zeros(loss_mat.shape).to(self.device).to(loss_mat.dtype),
            )

            self.optimizer.zero_grad()
            loss = torch.sum(loss_mat) / torch.sum(is_valid)
            loss.backward()
            self.optimizer.step()
            total_loss += loss.detach().item()

        return total_loss / len(self.train_loader)

    def _eval(self, loader: DataLoader) -> float:
        """Return the (mean-reduced) criterion loss of the model over `loader`."""
        self.model.eval()
        y_true, y_pred = [], []
        for _, batch in enumerate(loader):
            inputs = batch["representation"].to(self.device)
            with torch.no_grad():
                pred = self.model(inputs).squeeze()
            y_true.append(batch["label"])
            y_pred.append(pred.detach().cpu())
        y_true, y_pred = torch.cat(y_true), torch.cat(y_pred)
        # nn.BCELoss expects probabilities, so map logits through a sigmoid.
        # FIX: the original tested `self.criterion is nn.BCELoss` (always False
        # for an instance) and called nn.Sigmoid(y_pred), which constructs a
        # module rather than applying the function.
        if isinstance(self.criterion, nn.BCELoss):
            y_pred = torch.sigmoid(y_pred)
        return self.criterion(y_pred, y_true).item()

    def _eval_roc(self, loader: DataLoader) -> float:
        """Return the mean per-target ROC-AUC of the model over `loader`.

        Targets without at least one positive (+1) and one negative (-1) label
        are skipped, since AUC is undefined for them.
        """
        self.model.eval()
        y_true, y_pred = [], []

        for _, batch in enumerate(loader):
            with torch.no_grad():
                if self.config.val_task == "prober":
                    inputs = batch["representation"].to(self.device)
                    pred = self.model(inputs)
                    true = batch["label"].to(torch.float32).to(self.device)
                elif self.config.val_task == "finetune":
                    batch = batch.to(self.device)
                    pred = self.model(
                        batch.x, batch.edge_index, batch.edge_attr, batch.batch
                    )
                    true = batch.y.view(pred.shape)

                y_true.append(true)
                y_pred.append(pred)
                # y_pred.append(pred.view(true.shape))

        y_true = torch.cat(y_true, dim=0).cpu().numpy()
        y_pred = torch.cat(y_pred, dim=0).cpu().numpy()

        roc_list = []
        for i in range(y_true.shape[1]):
            # AUC is only defined when there is at least one positive and one
            # negative label in the column (-1/+1 encoding; 0 = missing).
            if np.sum(y_true[:, i] == 1) > 0 and np.sum(y_true[:, i] == -1) > 0:
                is_valid = y_true[:, i] ** 2 > 0
                roc_list.append(
                    roc_auc_score((y_true[is_valid, i] + 1) / 2, y_pred[is_valid, i])
                )

        if len(roc_list) < y_true.shape[1]:
            print("Some target is missing!")
            print("Missing ratio: %f" % (1 - float(len(roc_list)) / y_true.shape[1]))

        # Robustness: avoid a bare ZeroDivisionError when no target qualifies.
        if not roc_list:
            raise RuntimeError(
                "ROC-AUC is undefined: no target column has both positive "
                "and negative labels."
            )
        return sum(roc_list) / len(roc_list)

    def eval_train_dataset(self) -> float:
        """Criterion loss on the training split."""
        return self._eval(loader=self.train_loader)

    def eval_val_dataset(self) -> float:
        """Criterion loss on the validation split."""
        return self._eval(loader=self.val_loader)

    def eval_test_dataset(self) -> float:
        """Criterion loss on the test split."""
        return self._eval(loader=self.test_loader)
224 |
--------------------------------------------------------------------------------
/src/validation/task/task.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import dataclasses
3 | from typing import Dict
4 | import torch
5 | from torch.utils.data import DataLoader, Dataset
6 | from config.validation_config import ValidationConfig
7 | from logger import CombinedLogger
8 |
9 |
@dataclasses.dataclass
class Task(abc.ABC):
    """Abstract base for validation tasks.

    A concrete task executes against a model on a device and reports its
    results as a dictionary.
    """

    # Validation-time configuration shared by every task.
    config: ValidationConfig

    @abc.abstractmethod
    def run(self, model: torch.nn.Module, device: torch.device) -> Dict:
        """Execute the task and return a dictionary of results."""
17 |
18 |
@dataclasses.dataclass
class TrainValTestTask(Task):
    """Validation task operating over train / val / test dataset splits.

    Subclasses supply the training loop and the evaluation helpers; this base
    only fixes the shared state and the abstract interface.
    """

    config: ValidationConfig
    model: torch.nn.Module
    device: torch.device
    optimizer: torch.optim.Optimizer
    train_dataset: Dataset
    val_dataset: Dataset
    test_dataset: Dataset
    criterion_type: str
    # logger: CombinedLogger

    @abc.abstractmethod
    def train(self) -> float:
        """Train the model; return a loss value."""

    @abc.abstractmethod
    def _eval(self, loader: DataLoader) -> float:
        """Evaluate the model over the given loader; return a score."""

    @abc.abstractmethod
    def eval_val_dataset(self) -> float:
        """Evaluate the model on the validation split."""

    @abc.abstractmethod
    def eval_test_dataset(self) -> float:
        """Evaluate the model on the test split."""
46 |
--------------------------------------------------------------------------------
/src/validation/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Optional, Union
2 |
3 | import pandas as pd
4 |
5 | from config import Config
6 | from config.training_config import TrainingConfig
7 | from config.validation_config import ValidationConfig
8 | from datasets import MoleculeDataset
9 | from splitters import random_scaffold_split, random_split, scaffold_split
10 |
11 |
def get_dataset_extraction(
    config: Union[TrainingConfig, ValidationConfig]
) -> MoleculeDataset:
    """Load the MoleculeDataset named by ``config.dataset``.

    The dataset is rooted at ``data/molecule_datasets/<name>/``.
    """
    return MoleculeDataset(
        root=f"data/molecule_datasets/{config.dataset}/",
        dataset=config.dataset,
    )
17 |
18 |
def get_dataset_split(
    config: Union[TrainingConfig, ValidationConfig],
    dataset: MoleculeDataset,
    smiles_list: Optional[List[str]],
) -> Dict[str, MoleculeDataset]:
    """Split ``dataset`` into train/val/test (80/10/10) per ``config.split``.

    Supported strategies: "scaffold", "random" and "random_scaffold"; any
    other value raises ValueError. Returns a dict with keys "train", "val",
    "test" and "smiles".
    """
    if config.split == "scaffold":
        split_result = scaffold_split(
            dataset,
            smiles_list,
            null_value=0,
            frac_train=0.8,
            frac_valid=0.1,
            frac_test=0.1,
            return_smiles=True,
        )
    elif config.split == "random":
        split_result = random_split(
            dataset,
            null_value=0,
            frac_train=0.8,
            frac_valid=0.1,
            frac_test=0.1,
            seed=config.seed,
            smiles_list=smiles_list,
        )
    elif config.split == "random_scaffold":
        split_result = random_scaffold_split(
            dataset,
            smiles_list,
            null_value=0,
            frac_train=0.8,
            frac_valid=0.1,
            frac_test=0.1,
            seed=config.seed,
            return_smiles=True,
        )
    else:
        raise ValueError("Invalid split option.")
    train_set, val_set, test_set, split_smiles = split_result
    return {
        "train": train_set,
        "val": val_set,
        "test": test_set,
        "smiles": split_smiles,
    }
63 |
64 |
65 | def get_smiles_list(config: Config) -> List[str]:
66 | smiles_list = pd.read_csv(
67 | f"data/molecule_datasets/{config.dataset}/processed/smiles.csv",
68 | header=None,
69 | )
70 | return smiles_list[0].tolist()
71 |
--------------------------------------------------------------------------------