├── LICENSE ├── README.md ├── datasets ├── README.md └── dti_datasets │ ├── davis │ ├── preprocess.out │ └── preprocess.py │ └── kiba │ ├── preprocess.out │ └── preprocess.py ├── fig ├── baselines.png └── pipeline.png ├── legacy └── src_classification │ ├── pretrain_3D_hybrid_02.py │ ├── pretrain_3D_hybrid_03.py │ ├── run_pretrain_3D_hybrid_02.sh │ └── run_pretrain_3D_hybrid_03.sh ├── scripts_classification ├── run_fine_tuning_model.sh ├── submit_fine_tuning.sh ├── submit_pre_training_GraphMVP.sh ├── submit_pre_training_GraphMVP_hybrid.sh └── submit_pre_training_baselines.sh ├── scripts_regression ├── run_fine_tuning_model.sh ├── run_fine_tuning_model_DTI.sh ├── submit_fine_tuning.sh ├── submit_pre_training_GraphMVP.sh ├── submit_pre_training_GraphMVP_hybrid.sh └── submit_pre_training_baselines.sh ├── src_classification ├── GEOM_dataset_preparation.py ├── batch.py ├── config.py ├── dataloader.py ├── datasets │ ├── __init__.py │ ├── datasets_GPT.py │ ├── molecule_3D_dataset.py │ ├── molecule_3D_masking_dataset.py │ ├── molecule_contextual_datasets.py │ ├── molecule_contextual_datasets_utils.py │ ├── molecule_datasets.py │ ├── molecule_graphcl_dataset.py │ ├── molecule_graphcl_masking_dataset.py │ └── molecule_motif_datasets.py ├── models │ ├── __init__.py │ ├── auto_encoder.py │ ├── dti_model.py │ ├── molecule_gnn_model.py │ └── schnet.py ├── molecule_finetune.py ├── pretrain_AM.py ├── pretrain_CP.py ├── pretrain_Contextual.py ├── pretrain_EP.py ├── pretrain_GPT_GNN.py ├── pretrain_GraphCL.py ├── pretrain_GraphLoG.py ├── pretrain_GraphMVP.py ├── pretrain_GraphMVP_hybrid.py ├── pretrain_IG.py ├── pretrain_JOAO.py ├── pretrain_JOAOv2.py ├── pretrain_Motif.py ├── run_molecule_finetune.sh ├── run_pretrain_AM.sh ├── run_pretrain_CP.sh ├── run_pretrain_Contextual.sh ├── run_pretrain_EP.sh ├── run_pretrain_GPT_GNN.sh ├── run_pretrain_GraphCL.sh ├── run_pretrain_GraphLoG.sh ├── run_pretrain_GraphMVP.sh ├── run_pretrain_GraphMVP_hybrid.sh ├── run_pretrain_IG.sh ├── 
run_pretrain_JOAO.sh ├── run_pretrain_JOAOv2.sh ├── run_pretrain_Motif.sh ├── splitters.py └── util.py └── src_regression ├── GEOM_dataset_preparation.py ├── datasets_complete_feature ├── __init__.py ├── dti_datasets.py ├── molecule_datasets.py └── molecule_graphcl_dataset.py ├── dti_finetune.py ├── models_complete_feature ├── __init__.py └── molecule_gnn_model.py ├── molecule_finetune_regression.py ├── pretrain_AM.py ├── pretrain_CP.py ├── pretrain_GraphCL.py ├── pretrain_GraphMVP.py ├── pretrain_GraphMVP_hybrid.py ├── pretrain_JOAO.py ├── pretrain_JOAOv2.py ├── run_dti_finetune.sh ├── run_molecule_finetune_regression.sh ├── run_pretrain_AM.sh ├── run_pretrain_CP.sh ├── run_pretrain_GraphCL.sh ├── run_pretrain_GraphMVP.sh ├── run_pretrain_GraphMVP_hybrid.sh ├── run_pretrain_JOAO.sh ├── run_pretrain_JOAOv2.sh └── util_complete_feature.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Shengchao Liu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pre-training Molecular Graph Representation with 3D Geometry 2 | 3 | **ICLR 2022** 4 | 5 | Authors: Shengchao Liu, Hanchen Wang, Weiyang Liu, Joan Lasenby, Hongyu Guo, Jian Tang 6 | 7 | [[Project Page](https://chao1224.github.io/GraphMVP)] 8 | [[Paper](https://openreview.net/forum?id=xQUe1pOKPam)] 9 | [[ArXiv](https://arxiv.org/abs/2110.07728)] 10 | [[Slides](https://drive.google.com/file/d/1-lDWtdgeEgTO009YVPzHK8f7yYbvQ1oY/view?usp=sharing)] 11 | [[Poster](https://drive.google.com/file/d/1L_XrlgfmCmycfGf47Dt6nnaKpZtiqiN-/view?usp=sharing)] 12 |
13 | [[NeurIPS SSL Workshop 2021](https://sslneurips21.github.io/)] 14 | [[ICLR GTRL Workshop 2022 (Spotlight)](https://gt-rl.github.io/)] 15 | 16 | This repository provides the source code for the ICLR'22 paper **Pre-training Molecular Graph Representation with 3D Geometry**, with the following tasks: 17 | - During pre-training, we consider both the 2D topology and 3D geometry. 18 | - During downstream, we consider tasks with 2D topology only. 19 | 20 | In the future, we will merge it into the [TorchDrug](https://github.com/DeepGraphLearning/torchdrug) package. 21 | 22 |

23 | 24 |

25 | 26 | ## Baselines 27 | For implementation, this repository also provides the following graph SSL baselines: 28 | - Generative Graph SSL: 29 | - [Edge Prediction (EdgePred)](https://proceedings.neurips.cc/paper/2017/file/5dd9db5e033da9c6fb5ba83c7a7ebea9-Paper.pdf) 30 | - [AttributeMasking (AttrMask)](https://openreview.net/forum?id=HJlWWJSFDH) 31 | - [GPT-GNN](https://arxiv.org/abs/2006.15437) 32 | - Contrastive Graph SSL: 33 | - [InfoGraph](https://openreview.net/pdf?id=r1lfF2NYvH) 34 | - [Context Prediction (ContextPred)](https://openreview.net/forum?id=HJlWWJSFDH) 35 | - [GraphLoG](http://proceedings.mlr.press/v139/xu21g/xu21g.pdf) 36 | - [Grover-Contextual](https://papers.nips.cc/paper/2020/hash/94aef38441efa3380a3bed3faf1f9d5d-Abstract.html) 37 | - [GraphCL](https://papers.nips.cc/paper/2020/file/3fe230348e9a12c13120749e3f9fa4cd-Paper.pdf) 38 | - [JOAO](https://arxiv.org/abs/2106.07594) 39 | - Predictive Graph SSL: 40 | - [Grover-Motif](https://papers.nips.cc/paper/2020/hash/94aef38441efa3380a3bed3faf1f9d5d-Abstract.html) 41 | 42 |

43 | 44 |

45 | 46 | ## Environments 47 | Install packages under conda env 48 | ```bash 49 | conda create -n GraphMVP python=3.7 50 | conda activate GraphMVP 51 | 52 | conda install -y -c rdkit rdkit 53 | conda install -y -c pytorch pytorch=1.9.1 54 | conda install -y numpy networkx scikit-learn 55 | pip install ase 56 | pip install git+https://github.com/bp-kelley/descriptastorus 57 | pip install ogb 58 | export TORCH=1.9.0 59 | export CUDA=cu102 # cu102, cu110 60 | 61 | wget https://data.pyg.org/whl/torch-${TORCH}%2B${CUDA}/torch_cluster-1.5.9-cp37-cp37m-linux_x86_64.whl 62 | pip install torch_cluster-1.5.9-cp37-cp37m-linux_x86_64.whl 63 | wget https://data.pyg.org/whl/torch-${TORCH}%2B${CUDA}/torch_scatter-2.0.9-cp37-cp37m-linux_x86_64.whl 64 | pip install torch_scatter-2.0.9-cp37-cp37m-linux_x86_64.whl 65 | wget https://data.pyg.org/whl/torch-${TORCH}%2B${CUDA}/torch_sparse-0.6.12-cp37-cp37m-linux_x86_64.whl 66 | pip install torch_sparse-0.6.12-cp37-cp37m-linux_x86_64.whl 67 | pip install torch-geometric==1.7.2 68 | ``` 69 | 70 | ## Dataset Preprocessing 71 | 72 | For dataset download, please follow the instruction [here](https://github.com/chao1224/GraphMVP/tree/main/datasets). 73 | 74 | For data preprocessing (GEOM), please use the following commands: 75 | ``` 76 | cd src_classification 77 | python GEOM_dataset_preparation.py --n_mol 50000 --n_conf 5 --n_upper 1000 --data_folder $SLURM_TMPDIR 78 | cd .. 79 | 80 | cd src_regression 81 | python GEOM_dataset_preparation.py --n_mol 50000 --n_conf 5 --n_upper 1000 --data_folder $SLURM_TMPDIR 82 | cd .. 83 | 84 | mv $SLURM_TMPDIR/GEOM datasets 85 | ``` 86 | 87 | **Featurization**. We employ two sets of featurization methods on atoms. 88 | 1. For classification tasks, in order to follow the main molecular graph SSL research line, we use the same atom featurization methods (consider the atom types and chirality). 89 | 2. For regression tasks, results with the above two atom-level features are too bad. 
Thus, we consider more comprehensive features from OGB. 90 | 91 | ## Experiments 92 | 93 | ### Terminology specification 94 | 95 | In the latest scripts, we use `GraphMVP` for the trivial GraphMVP (Eq. 7 in the paper), and `GraphMVP_hybrid` includes two variants adding extra 2D SSL pretext tasks (Eq 8. in the paper). 96 | In the previous scripts, we call these two terms as `3D_hybrid_02_masking` and `3D_hybrid_03_masking` respectively. 97 | This could show up in some pre-trained log files [here](https://drive.google.com/drive/folders/1uPsBiQF3bfeCAXSDd4JfyXiTh-qxYfu6?usp=sharing). 98 | 99 | | GraphMVP | Latest scripts | Previous scripts | 100 | | :--: | :--: | :--: | 101 | | Eq. 7 | `GraphMVP` | `3D_hybrid_02_masking` | 102 | | Eq. 8 | `GraphMVP_hybrid` | `3D_hybrid_03_masking` | 103 | 104 | ### For GraphMVP pre-training 105 | 106 | Check the following scripts: 107 | - `scripts_classification/submit_pre_training_GraphMVP.sh` 108 | - `scripts_classification/submit_pre_training_GraphMVP_hybrid.sh` 109 | - `scripts_regression/submit_pre_training_GraphMVP.sh` 110 | - `scripts_regression/submit_pre_training_GraphMVP_hybrid.sh` 111 | 112 | The pre-trained model weights, training logs, and prediction files can be found [here](https://drive.google.com/drive/folders/1uPsBiQF3bfeCAXSDd4JfyXiTh-qxYfu6?usp=sharing). 113 | 114 | ### For Other SSL pre-training baselines 115 | 116 | Check the following scripts: 117 | - `scripts_classification/submit_pre_training_baselines.sh` 118 | - `scripts_regression/submit_pre_training_baselines.sh` 119 | 120 | ### For Downstream tasks 121 | 122 | Check the following scripts: 123 | - `scripts_classification/submit_fine_tuning.sh` 124 | - `scripts_regression/submit_fine_tuning.sh` 125 | 126 | ## Cite Us 127 | 128 | Feel free to cite this work if you find it useful to you! 
129 | 130 | ``` 131 | @inproceedings{liu2022pretraining, 132 | title={Pre-training Molecular Graph Representation with 3D Geometry}, 133 | author={Shengchao Liu and Hanchen Wang and Weiyang Liu and Joan Lasenby and Hongyu Guo and Jian Tang}, 134 | booktitle={International Conference on Learning Representations}, 135 | year={2022}, 136 | url={https://openreview.net/forum?id=xQUe1pOKPam} 137 | } 138 | ``` 139 | -------------------------------------------------------------------------------- /datasets/README.md: -------------------------------------------------------------------------------- 1 | # Datasets 2 | 3 | ## Geometric Ensemble Of Molecules (GEOM) 4 | 5 | ```bash 6 | mkdir -p GEOM/raw 7 | mkdir -p GEOM/processed 8 | ``` 9 | 10 | + GEOM: [Paper](https://arxiv.org/pdf/2006.05531v3.pdf), [GitHub](https://github.com/learningmatter-mit/geom) 11 | + Data Download: 12 | + [Not Used] [Drug Crude](https://dataverse.harvard.edu/api/access/datafile/4360331), 13 | [Drug Featurized](https://dataverse.harvard.edu/api/access/datafile/4327295), 14 | [QM9 Crude](https://dataverse.harvard.edu/api/access/datafile/4327190), 15 | [QM9 Featurized](https://dataverse.harvard.edu/api/access/datafile/4327191) 16 | 17 | + [Mainly Used] [RdKit Folder](https://dataverse.harvard.edu/api/access/datafile/4327252) 18 | ```bash 19 | wget https://dataverse.harvard.edu/api/access/datafile/4327252 20 | mv 4327252 rdkit_folder.tar.gz 21 | tar -xvf rdkit_folder.tar.gz 22 | ``` 23 | or do the following if you are using slurm system 24 | ``` 25 | cp rdkit_folder.tar.gz $SLURM_TMPDIR 26 | cd $SLURM_TMPDIR 27 | tar -xvf rdkit_folder.tar.gz 28 | ``` 29 | + over 33m conformations 30 | + over 430k molecules 31 | + 304,466 species contain experimental data for the inhibition of various pathogens 32 | + 133,258 are species from the QM9 33 | 34 | ## Chem Dataset 35 | 36 | ```bash 37 | wget http://snap.stanford.edu/gnn-pretrain/data/chem_dataset.zip 38 | unzip chem_dataset.zip 39 | mv dataset 
molecule_datasets 40 | ``` 41 | 42 | ## Other Chem Datasets 43 | 44 | - delaney/esol (already included) 45 | - lipophilicity (already included) 46 | - malaria 47 | - cep 48 | 49 | ``` 50 | wget -O malaria-processed.csv https://raw.githubusercontent.com/HIPS/neural-fingerprint/master/data/2015-06-03-malaria/malaria-processed.csv 51 | mkdir -p ./molecule_datasets/malaria/raw 52 | mv malaria-processed.csv ./molecule_datasets/malaria/raw/malaria.csv 53 | 54 | wget -O cep-processed.csv https://raw.githubusercontent.com/HIPS/neural-fingerprint/master/data/2015-06-02-cep-pce/cep-processed.csv 55 | mkdir -p ./molecule_datasets/cep/raw 56 | mv cep-processed.csv ./molecule_datasets/cep/raw/cep.csv 57 | ``` 58 | 59 | Then we copy them for the regression (more atom features). 60 | ``` 61 | mkdir -p ./molecule_datasets_regression/esol 62 | cp -r ./molecule_datasets/esol/raw ./molecule_datasets_regression/esol/ 63 | 64 | mkdir -p ./molecule_datasets_regression/lipophilicity 65 | cp -r ./molecule_datasets/lipophilicity/raw ./molecule_datasets_regression/lipophilicity/ 66 | 67 | mkdir -p ./molecule_datasets_regression/malaria 68 | cp -r ./molecule_datasets/malaria/raw ./molecule_datasets_regression/malaria/ 69 | 70 | mkdir -p ./molecule_datasets_regression/cep 71 | cp -r ./molecule_datasets/cep/raw ./molecule_datasets_regression/cep/ 72 | ``` 73 | 74 | ## Drug-Target Interaction 75 | 76 | - Davis 77 | - Kiba 78 | 79 | ``` 80 | mkdir -p dti_datasets 81 | cd dti_datasets 82 | ``` 83 | 84 | Then we can follow [DeepDTA](https://github.com/hkmztrk/DeepDTA). 85 | ``` 86 | git clone git@github.com:hkmztrk/DeepDTA.git 87 | cp -r DeepDTA/data/davis davis/ 88 | cp -r DeepDTA/data/kiba kiba/ 89 | 90 | cd davis 91 | python preprocess.py > preprocess.out 92 | cd .. 93 | 94 | cd kiba 95 | python preprocess.py > preprocess.out 96 | cd .. 
97 | ``` 98 | -------------------------------------------------------------------------------- /datasets/dti_datasets/davis/preprocess.out: -------------------------------------------------------------------------------- 1 | convert data from DeepDTA for davis 2 | 68 drugs 68 unique drugs 3 | 4 | dataset: davis 5 | train_fold: 25046 6 | test_fold: 5010 7 | len(set(drugs)),len(set(prots)): 68 379 8 | -------------------------------------------------------------------------------- /datasets/dti_datasets/davis/preprocess.py: -------------------------------------------------------------------------------- 1 | import json 2 | from collections import OrderedDict 3 | import pickle 4 | import numpy as np 5 | from rdkit import Chem 6 | from rdkit.Chem import MolFromSmiles 7 | 8 | 9 | 10 | if __name__ == '__main__': 11 | # from DeepDTA data 12 | dataset = 'davis' 13 | 14 | print('convert data from DeepDTA for ', dataset) 15 | fpath = dataset + '/' 16 | train_fold = json.load(open(fpath + "folds/train_fold_setting1.txt")) 17 | train_fold = [ee for e in train_fold for ee in e] 18 | valid_fold = json.load(open(fpath + "folds/test_fold_setting1.txt")) 19 | ligands = json.load(open(fpath + "ligands_can.txt"), object_pairs_hook=OrderedDict) 20 | proteins = json.load(open(fpath + "proteins.txt"), object_pairs_hook=OrderedDict) 21 | affinity = pickle.load(open(fpath + "Y", "rb"), encoding='latin1') 22 | drugs = [] 23 | prots = [] 24 | 25 | for d in ligands.keys(): 26 | lg = Chem.MolToSmiles(Chem.MolFromSmiles(ligands[d]), isomericSmiles=True) 27 | drugs.append(lg) 28 | for t in proteins.keys(): 29 | prots.append(proteins[t]) 30 | if dataset == 'davis': 31 | affinity = [-np.log10(y / 1e9) for y in affinity] 32 | affinity = np.asarray(affinity) 33 | opts = ['train', 'test'] 34 | 35 | print('{} drugs\t{} unique drugs'.format(len(drugs), len(set(drugs)))) 36 | 37 | with open('smiles.csv', 'w') as f: 38 | print('smiles', file=f) 39 | for drug in drugs: 40 | print(drug, file=f) 41 | with 
open('protein.csv', 'w') as f: 42 | print('protein', file=f) 43 | for prot in prots: 44 | print(prot, file=f) 45 | 46 | for opt in opts: 47 | rows, cols = np.where(np.isnan(affinity) == False) 48 | if opt == 'train': 49 | rows, cols = rows[train_fold], cols[train_fold] 50 | elif opt == 'test': 51 | rows, cols = rows[valid_fold], cols[valid_fold] 52 | # with open(opt + '.csv', 'w') as f: 53 | # f.write('compound_iso_smiles,target_sequence,affinity\n') 54 | # for pair_ind in range(len(rows)): 55 | # ls = [] 56 | # ls += [drugs[rows[pair_ind]]] 57 | # ls += [prots[cols[pair_ind]]] 58 | # ls += [affinity[rows[pair_ind], cols[pair_ind]]] 59 | # f.write(','.join(map(str, ls)) + '\n') 60 | with open(opt + '.csv', 'w') as f: 61 | f.write('smiles_id,target_id,affinity\n') 62 | for pair_ind in range(len(rows)): 63 | ls = [] 64 | ls += [rows[pair_ind]] 65 | ls += [cols[pair_ind]] 66 | ls += [affinity[rows[pair_ind], cols[pair_ind]]] 67 | f.write(','.join(map(str, ls)) + '\n') 68 | 69 | print('\ndataset:', dataset) 70 | print('train_fold:', len(train_fold)) 71 | print('test_fold:', len(valid_fold)) 72 | print('len(set(drugs)),len(set(prots)):', len(set(drugs)), len(set(prots))) 73 | -------------------------------------------------------------------------------- /datasets/dti_datasets/kiba/preprocess.out: -------------------------------------------------------------------------------- 1 | convert data from DeepDTA for kiba 2 | 2111 drugs 2068 unique drugs 3 | 4 | dataset: kiba 5 | train_fold: 98545 6 | test_fold: 19709 7 | len(set(drugs)),len(set(prots)): 2068 229 8 | -------------------------------------------------------------------------------- /datasets/dti_datasets/kiba/preprocess.py: -------------------------------------------------------------------------------- 1 | import json 2 | from collections import OrderedDict 3 | import pickle 4 | import numpy as np 5 | from rdkit import Chem 6 | from rdkit.Chem import MolFromSmiles 7 | 8 | 9 | 10 | if __name__ == '__main__': 11 
| # from DeepDTA data 12 | dataset = 'kiba' 13 | 14 | print('convert data from DeepDTA for ', dataset) 15 | fpath = dataset + '/' 16 | train_fold = json.load(open(fpath + "folds/train_fold_setting1.txt")) 17 | train_fold = [ee for e in train_fold for ee in e] 18 | valid_fold = json.load(open(fpath + "folds/test_fold_setting1.txt")) 19 | ligands = json.load(open(fpath + "ligands_can.txt"), object_pairs_hook=OrderedDict) 20 | proteins = json.load(open(fpath + "proteins.txt"), object_pairs_hook=OrderedDict) 21 | affinity = pickle.load(open(fpath + "Y", "rb"), encoding='latin1') 22 | drugs = [] 23 | prots = [] 24 | 25 | for d in ligands.keys(): 26 | lg = Chem.MolToSmiles(Chem.MolFromSmiles(ligands[d]), isomericSmiles=True) 27 | drugs.append(lg) 28 | for t in proteins.keys(): 29 | prots.append(proteins[t]) 30 | if dataset == 'davis': 31 | affinity = [-np.log10(y / 1e9) for y in affinity] 32 | affinity = np.asarray(affinity) 33 | opts = ['train', 'test'] 34 | 35 | print('{} drugs\t{} unique drugs'.format(len(drugs), len(set(drugs)))) 36 | 37 | with open('smiles.csv', 'w') as f: 38 | print('smiles', file=f) 39 | for drug in drugs: 40 | print(drug, file=f) 41 | with open('protein.csv', 'w') as f: 42 | print('protein', file=f) 43 | for prot in prots: 44 | print(prot, file=f) 45 | 46 | for opt in opts: 47 | rows, cols = np.where(np.isnan(affinity) == False) 48 | if opt == 'train': 49 | rows, cols = rows[train_fold], cols[train_fold] 50 | elif opt == 'test': 51 | rows, cols = rows[valid_fold], cols[valid_fold] 52 | # with open(opt + '.csv', 'w') as f: 53 | # f.write('compound_iso_smiles,target_sequence,affinity\n') 54 | # for pair_ind in range(len(rows)): 55 | # ls = [] 56 | # ls += [drugs[rows[pair_ind]]] 57 | # ls += [prots[cols[pair_ind]]] 58 | # ls += [affinity[rows[pair_ind], cols[pair_ind]]] 59 | # f.write(','.join(map(str, ls)) + '\n') 60 | with open(opt + '.csv', 'w') as f: 61 | f.write('smiles_id,target_id,affinity\n') 62 | for pair_ind in range(len(rows)): 63 | ls = 
[] 64 | ls += [rows[pair_ind]] 65 | ls += [cols[pair_ind]] 66 | ls += [affinity[rows[pair_ind], cols[pair_ind]]] 67 | f.write(','.join(map(str, ls)) + '\n') 68 | 69 | print('\ndataset:', dataset) 70 | print('train_fold:', len(train_fold)) 71 | print('test_fold:', len(valid_fold)) 72 | print('len(set(drugs)),len(set(prots)):', len(set(drugs)), len(set(prots))) 73 | -------------------------------------------------------------------------------- /fig/baselines.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chao1224/GraphMVP/80c42aff21ece462c951cfda110f8eaf866aea0c/fig/baselines.png -------------------------------------------------------------------------------- /fig/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chao1224/GraphMVP/80c42aff21ece462c951cfda110f8eaf866aea0c/fig/pipeline.png -------------------------------------------------------------------------------- /legacy/src_classification/run_pretrain_3D_hybrid_02.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | source $HOME/.bashrc 4 | conda activate PC3D 5 | conda deactivate 6 | conda activate PC3D 7 | 8 | #cp -r ../datasets/GEOM_* $SLURM_TMPDIR 9 | # 10 | #echo $@ 11 | #date 12 | #echo "start" 13 | #python pretrain_3D_hybrid_02.py --input_data_dir="$SLURM_TMPDIR" $@ 14 | #echo "end" 15 | #date 16 | 17 | 18 | echo $@ 19 | date 20 | echo "start" 21 | python pretrain_3D_hybrid_02.py $@ 22 | echo "end" 23 | date 24 | -------------------------------------------------------------------------------- /legacy/src_classification/run_pretrain_3D_hybrid_03.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | source $HOME/.bashrc 4 | conda activate PC3D 5 | conda deactivate 6 | conda activate PC3D 7 | 8 | echo $@ 9 | date 10 | echo "start" 
11 | python pretrain_3D_hybrid_03.py $@ 12 | echo "end" 13 | date 14 | -------------------------------------------------------------------------------- /scripts_classification/run_fine_tuning_model.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --account=rrg-bengioy-ad 4 | #SBATCH --cpus-per-task=8 5 | #SBATCH --gres=gpu:v100l:1 6 | #SBATCH --mem=32G 7 | #SBATCH --time=2:59:00 8 | #SBATCH --ntasks=1 9 | #SBATCH --array=0-2%3 10 | #SBATCH --output=logs/%j.out 11 | 12 | 13 | ###############SBATCH --gres=gpu:v100l:1 14 | 15 | cd src 16 | 17 | echo "My SLURM_ARRAY_TASK_ID: " $SLURM_ARRAY_TASK_ID 18 | export dataset_list=(tox21 toxcast clintox bbbp sider muv hiv bace) 19 | export seed_list=(0 1 2 3 4 5 6 7 8 9) 20 | export batch_size=256 21 | export mode=$1 22 | export seed=${seed_list[$SLURM_ARRAY_TASK_ID]} 23 | 24 | 25 | 26 | 27 | 28 | if [ "$mode" == "random" ]; then 29 | 30 | for dataset in "${dataset_list[@]}"; do 31 | export folder="$mode"/"$seed" 32 | mkdir -p ../output/"$folder" 33 | mkdir -p ../output/"$folder"/"$dataset" 34 | 35 | export output_path=../output/"$folder"/"$dataset".out 36 | export output_model_dir=../output/"$folder"/"$dataset" 37 | 38 | echo "$SLURM_JOB_ID"_"$SLURM_ARRAY_TASK_ID" > "$output_path" 39 | echo `date` >> "$output_path" 40 | 41 | bash ./run_molecule_finetune.sh \ 42 | --dataset="$dataset" --runseed="$seed" --eval_train --batch_size="$batch_size" \ 43 | --dropout_ratio=0.5 \ 44 | --output_model_dir="$output_model_dir" \ 45 | >> "$output_path" 46 | 47 | echo `date` >> "$output_path" 48 | done 49 | 50 | 51 | 52 | 53 | else 54 | 55 | for dataset in "${dataset_list[@]}"; do 56 | export folder="$mode"/"$seed" 57 | mkdir -p ../output/"$folder" 58 | mkdir -p ../output/"$folder"/"$dataset" 59 | 60 | export output_path=../output/"$folder"/"$dataset".out 61 | # export output_model_dir=../output/"$folder"/"$dataset" 62 | export 
input_model_file=../output/"$mode"/pretraining_model.pth 63 | 64 | echo "$SLURM_JOB_ID"_"$SLURM_ARRAY_TASK_ID" > "$output_path" 65 | echo `date` >> "$output_path" 66 | 67 | bash ./run_molecule_finetune.sh \ 68 | --dataset="$dataset" --runseed="$seed" --eval_train --batch_size="$batch_size" \ 69 | --dropout_ratio=0.5 \ 70 | --input_model_file="$input_model_file" \ 71 | >> "$output_path" 72 | # --input_model_file="$input_model_file" --output_model_dir="$output_model_dir" \ 73 | # >> "$output_path" 74 | 75 | echo `date` >> "$output_path" 76 | done 77 | 78 | fi 79 | 80 | -------------------------------------------------------------------------------- /scripts_classification/submit_fine_tuning.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd ../src_classification 4 | 5 | mode_list=( 6 | random 7 | EP/GEOM_2D_nmol50000_nconf1_nupper1000/epochs_100_0 8 | AM/GEOM_2D_nmol50000_nconf1_nupper1000/epochs_100_0 9 | IG/GEOM_2D_nmol50000_nconf1_nupper1000/epochs_100_0 10 | CP/GEOM_2D_nmol50000_nconf1_nupper1000/epochs_100_0 11 | GraphLoG/GEOM_2D_nmol50000_nconf1_nupper1000/epochs_100_0 12 | Motif/GEOM_2D_nmol50000_nconf1_nupper1000/epochs_100_0 13 | Contextual/GEOM_2D_nmol50000_nconf1_nupper1000/epochs_100_0 14 | GraphCL/GEOM_2D_nmol50000_nconf1_nupper1000/epochs_100_0 15 | JOAO/GEOM_2D_nmol50000_nconf1_nupper1000/epochs_100_0 16 | JOAOv2/GEOM_2D_nmol50000_nconf1_nupper1000/epochs_100_0 17 | 18 | GraphMVP/GEOM_3D_nmol50000_nconf5_nupper1000/CL_1_VAE_1/6_51_10_0.1/0.15_EBM_dot_prod_0.2_normalize_l2_detach_target_2_100_0 19 | GraphMVP_hybrid/GEOM_3D_nmol50000_nconf5_nupper1000/CL_1_VAE_1_AM_1/6_51_10_0.1/0.15_EBM_dot_prod_0.05_normalize_l2_detach_target_2_100_0 20 | GraphMVP_hybrid/GEOM_3D_nmol50000_nconf5_nupper1000/CL_1_VAE_1_CP_0.1/6_51_10_0.1/0.15_EBM_dot_prod_0.2_normalize_l2_detach_target_2_100_0 21 | ) 22 | 23 | for mode in "${mode_list[@]}"; do 24 | echo "$mode" 25 | ls output/"$mode" 26 | 27 | sbatch 
run_fine_tuning_model.sh "$mode" 28 | 29 | echo 30 | 31 | done 32 | 33 | 34 | 35 | # Below is for ablation study 36 | mode_list=( 37 | GraphMVP/GEOM_3D_nmol50000_nconf5_nupper1000/CL_1_VAE_1/6_51_10_0.1/0.15_EBM_dot_prod_0.2_normalize_l2_detach_target_2_100_0 38 | GraphMVP/GEOM_3D_nmol50000_nconf5_nupper1000/CL_1_VAE_1/6_51_10_0.1/0.15_InfoNCE_dot_prod_0.2_normalize_l2_detach_target_2_100_0 39 | GraphMVP/GEOM_3D_nmol50000_nconf5_nupper1000/CL_1_VAE_0/6_51_10_0.1/0.15_InfoNCE_dot_prod_0.2_normalize_l2_detach_target_2_100_0 40 | GraphMVP/GEOM_3D_nmol50000_nconf5_nupper1000/CL_1_VAE_0/6_51_10_0.1/0.15_EBM_dot_prod_0.2_normalize_l2_detach_target_2_100_0 41 | GraphMVP/GEOM_3D_nmol50000_nconf5_nupper1000/CL_0_VAE_1/6_51_10_0.1/0.15_EBM_dot_prod_0.2_normalize_l2_detach_target_2_100_0 42 | GraphMVP/GEOM_3D_nmol50000_nconf5_nupper1000/CL_1_AE_1/6_51_10_0.1/0.15_InfoNCE_dot_prod_0.2_normalize_l2_detach_target_2_100_0 43 | GraphMVP/GEOM_3D_nmol50000_nconf5_nupper1000/CL_1_AE_1/6_51_10_0.1/0.15_EBM_dot_prod_0.2_normalize_l2_detach_target_2_100_0 44 | GraphMVP/GEOM_3D_nmol50000_nconf5_nupper1000/CL_0_AE_1/6_51_10_0.1/0.15_EBM_dot_prod_0.2_normalize_l2_detach_target_2_100_0 45 | ) 46 | -------------------------------------------------------------------------------- /scripts_classification/submit_pre_training_GraphMVP.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | cd ../src_classification 4 | 5 | 6 | export mode=GraphMVP 7 | export dataset_list=(GEOM_3D_nmol50000_nconf5_nupper1000) 8 | export epochs=100 9 | export time=9 10 | 11 | 12 | # For SchNet and GNN 13 | export schnet_lr_scale_list=(0.1) 14 | export num_interactions=6 15 | export num_gaussians=51 16 | export cutoff=10 17 | export dropout_ratio_list=(0) 18 | export SSL_masking_ratio_list=(0.15 0.3) 19 | 20 | 21 | 22 | # For CL 23 | # export CL_similarity_metric_list=(InfoNCE_dot_prod EBM_dot_prod) 24 | export CL_similarity_metric_list=(EBM_dot_prod) 25 | 
export T_list=(0.1 0.2 0.5 1 2) 26 | export normalize_list=(normalize) 27 | 28 | 29 | 30 | # For VAE 31 | export AE_model=VAE 32 | # export AE_loss_list=(l1 l2 cosine) 33 | export AE_loss_list=(l2) 34 | # export detach_list=(detach_target no_detach_target) 35 | export detach_list=(detach_target) 36 | # export beta_list=(0.1 1 2) 37 | export beta_list=(1 2) 38 | 39 | 40 | 41 | 42 | # For CL + VAE 43 | export alpha_1_list=(1) 44 | export alpha_2_list=(0.1 1) 45 | 46 | 47 | 48 | export SSL_masking_ratio_list=(0) 49 | export CL_similarity_metric_list=(EBM_dot_prod) 50 | export T_list=(0.1 0.2) 51 | 52 | 53 | 54 | 55 | for dataset in "${dataset_list[@]}"; do 56 | for SSL_masking_ratio in "${SSL_masking_ratio_list[@]}"; do 57 | 58 | for alpha_1 in "${alpha_1_list[@]}"; do 59 | for alpha_2 in "${alpha_2_list[@]}"; do 60 | for CL_similarity_metric in "${CL_similarity_metric_list[@]}"; do 61 | for normalize in "${normalize_list[@]}"; do 62 | for T in "${T_list[@]}"; do 63 | for AE_loss in "${AE_loss_list[@]}"; do 64 | for detach in "${detach_list[@]}"; do 65 | for beta in "${beta_list[@]}"; do 66 | 67 | 68 | for schnet_lr_scale in "${schnet_lr_scale_list[@]}"; do 69 | for dropout_ratio in "${dropout_ratio_list[@]}"; do 70 | export folder="$mode"/"$dataset"/CL_"$alpha_1"_"$AE_model"_"$alpha_2"/"$num_interactions"_"$num_gaussians"_"$cutoff"_"$schnet_lr_scale"/"$SSL_masking_ratio"_"$CL_similarity_metric"_"$T"_"$normalize"_"$AE_loss"_"$detach"_"$beta"_"$epochs"_"$dropout_ratio" 71 | 72 | echo "$folder" 73 | mkdir -p ../output/"$folder" 74 | ls ../output/"$folder" 75 | 76 | export output_file=../output/"$folder"/pretraining.out 77 | export output_model_dir=../output/"$folder"/pretraining 78 | 79 | 80 | echo "$output_file" undone 81 | 82 | sbatch --gres=gpu:v100l:1 -c 8 --mem=32G -t "$time":00:00 --account=rrg-bengioy-ad --qos=high --job-name=CL_VAE_"$time" \ 83 | --output="$output_file" \ 84 | ./run_pretrain_"$mode".sh \ 85 | --epochs="$epochs" \ 86 | --dataset="$dataset" \ 87 | 
--batch_size=256 \ 88 | --SSL_masking_ratio="$SSL_masking_ratio" \ 89 | --CL_similarity_metric="$CL_similarity_metric" --T="$T" --"$normalize" \ 90 | --AE_model="$AE_model" --AE_loss="$AE_loss" --"$detach" --beta="$beta" \ 91 | --alpha_1="$alpha_1" --alpha_2="$alpha_2" \ 92 | --num_interactions="$num_interactions" --num_gaussians="$num_gaussians" --cutoff="$cutoff" --schnet_lr_scale="$schnet_lr_scale" \ 93 | --dropout_ratio="$dropout_ratio" --num_workers=8 \ 94 | --output_model_dir="$output_model_dir" 95 | 96 | done 97 | done 98 | 99 | done 100 | done 101 | done 102 | done 103 | done 104 | done 105 | done 106 | done 107 | done 108 | done 109 | -------------------------------------------------------------------------------- /scripts_classification/submit_pre_training_GraphMVP_hybrid.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | cd ../src_classification 4 | 5 | 6 | export mode=GraphMVP_hybrid 7 | export dataset_list=(GEOM_3D_nmol50000_nconf5_nupper1000) 8 | export epochs=100 9 | 10 | 11 | 12 | 13 | 14 | # For SchNet and GNN 15 | export schnet_lr_scale_list=(0.1) 16 | export num_interactions=6 17 | export num_gaussians=51 18 | export cutoff=10 19 | export dropout_ratio_list=(0) 20 | export SSL_masking_ratio_list=(0.15 0.3) 21 | 22 | 23 | 24 | 25 | # For CL 26 | # export CL_similarity_metric_list=(InfoNCE_dot_prod EBM_dot_prod) 27 | export CL_similarity_metric_list=(EBM_dot_prod) 28 | # export T_list=(0.05 0.1 0.2 0.5 1 2) 29 | export T_list=(0.1 0.2) 30 | export normalize_list=(normalize) 31 | 32 | 33 | 34 | # For VAE 35 | export AE_model=VAE 36 | # export AE_loss_list=(l1 l2 cosine) 37 | export AE_loss_list=(l2) 38 | # export detach_list=(detach_target no_detach_target) 39 | export detach_list=(detach_target) 40 | # export beta_list=(0.1 1 2) 41 | export beta_list=(1 2) 42 | 43 | 44 | 45 | 46 | 47 | # For CL + VAE 48 | export alpha_1_list=(1) 49 | export alpha_2_list=(0.1 1) 50 | export 
alpha_3_list=(0.1 1) 51 | export SSL_2D_mode_list=( CP AM IG JOAOv2 JOAO GraphCL) 52 | export time_list=( 3 3 3 6 6 3) 53 | export time_list=( 9 9 6 12 12 6) 54 | 55 | 56 | 57 | 58 | for dataset in "${dataset_list[@]}"; do 59 | for SSL_masking_ratio in "${SSL_masking_ratio_list[@]}"; do 60 | 61 | for i in {0..1}; do 62 | SSL_2D_mode=${SSL_2D_mode_list[$i]} 63 | time=${time_list[$i]} 64 | 65 | for alpha_3 in "${alpha_3_list[@]}"; do 66 | 67 | for alpha_1 in "${alpha_1_list[@]}"; do 68 | for alpha_2 in "${alpha_2_list[@]}"; do 69 | for CL_similarity_metric in "${CL_similarity_metric_list[@]}"; do 70 | for normalize in "${normalize_list[@]}"; do 71 | for T in "${T_list[@]}"; do 72 | for AE_loss in "${AE_loss_list[@]}"; do 73 | for detach in "${detach_list[@]}"; do 74 | for beta in "${beta_list[@]}"; do 75 | 76 | 77 | for schnet_lr_scale in "${schnet_lr_scale_list[@]}"; do 78 | for dropout_ratio in "${dropout_ratio_list[@]}"; do 79 | export folder="$mode"/"$dataset"/CL_"$alpha_1"_"$AE_model"_"$alpha_2"_"$SSL_2D_mode"_"$alpha_3"/"$num_interactions"_"$num_gaussians"_"$cutoff"_"$schnet_lr_scale"/"$SSL_masking_ratio"_"$CL_similarity_metric"_"$T"_"$normalize"_"$AE_loss"_"$detach"_"$beta"_"$epochs"_"$dropout_ratio" 80 | 81 | echo "$folder" 82 | mkdir -p ../output/"$folder" 83 | ls ../output/"$folder" 84 | 85 | export output_file=../output/"$folder"/pretraining.out 86 | export output_model_dir=../output/"$folder"/pretraining 87 | 88 | echo "$output_model_dir"_model_final.pth undone 89 | ls "$output_model_dir"* 90 | echo "$output_file" 91 | ls ../output/"$folder" 92 | rm ../output/"$folder"/* 93 | 94 | 95 | sbatch --gres=gpu:v100l:1 -c 8 --mem=32G -t "$time":00:00 --account=rrg-bengioy-ad --qos=high --job-name=CL_VAE_"$SSL_2D_mode"_"$time" \ 96 | --output="$output_file" \ 97 | ./run_pretrain_"$mode".sh \ 98 | --epochs="$epochs" \ 99 | --dataset="$dataset" \ 100 | --batch_size=256 \ 101 | --SSL_masking_ratio="$SSL_masking_ratio" \ 102 | 
--CL_similarity_metric="$CL_similarity_metric" --T="$T" --"$normalize" \ 103 | --AE_model="$AE_model" --AE_loss="$AE_loss" --"$detach" --beta="$beta" \ 104 | --alpha_1="$alpha_1" --alpha_2="$alpha_2" \ 105 | --SSL_2D_mode="$SSL_2D_mode" --alpha_3="$alpha_3" \ 106 | --num_interactions="$num_interactions" --num_gaussians="$num_gaussians" --cutoff="$cutoff" --schnet_lr_scale="$schnet_lr_scale" \ 107 | --dropout_ratio="$dropout_ratio" --num_workers=8 \ 108 | --output_model_dir="$output_model_dir" 109 | 110 | echo 111 | 112 | done 113 | done 114 | 115 | done 116 | done 117 | done 118 | done 119 | done 120 | done 121 | done 122 | done 123 | done 124 | 125 | done 126 | done 127 | done 128 | -------------------------------------------------------------------------------- /scripts_classification/submit_pre_training_baselines.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | cd ../src_classification 4 | 5 | export dataset=GEOM_2D_nmol50000_nconf1_nupper1000 6 | export dropout_ratio=0 7 | export epochs=100 8 | 9 | 10 | export time=3 11 | export mode_list=(EP IG AM CP GraphLoG Motif Contextual GraphCL JOAO JOAOv2) 12 | 13 | export time=12 14 | export mode_list=(GPT_GNN) 15 | 16 | 17 | for mode in "${mode_list[@]}"; do 18 | export folder="$mode"/"$dataset"/epochs_"$epochs"_"$dropout_ratio" 19 | echo "$folder" 20 | 21 | mkdir -p ../output/"$folder" 22 | 23 | export output_file=../output/"$folder"/pretraining.out 24 | export output_model_dir=../output/"$folder"/pretraining 25 | 26 | 27 | if [[ ! 
-f "$output_file" ]]; then 28 | echo "$folder" undone 29 | 30 | sbatch --gres=gpu:v100l:1 -c 8 --mem=32G -t "$time":00:00 --account=rrg-bengioy-ad --qos=high --job-name=baselines \ 31 | --output="$output_file" \ 32 | ./run_pretrain_"$mode".sh \ 33 | --epochs="$epochs" \ 34 | --dataset="$dataset" \ 35 | --batch_size=256 \ 36 | --dropout_ratio="$dropout_ratio" --num_workers=8 \ 37 | --output_model_dir="$output_model_dir" 38 | fi 39 | done 40 | -------------------------------------------------------------------------------- /scripts_regression/run_fine_tuning_model.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --account=rrg-bengioy-ad 4 | #SBATCH --cpus-per-task=8 5 | #SBATCH --gres=gpu:v100l:1 6 | #SBATCH --mem=32G 7 | #SBATCH --time=2:59:00 8 | #SBATCH --ntasks=1 9 | #SBATCH --array=0-2%3 10 | #SBATCH --output=logs/%j.out 11 | #SBATCH --job-name=reg 12 | 13 | 14 | ###############SBATCH --gres=gpu:v100l:1 15 | 16 | echo "My SLURM_ARRAY_TASK_ID: " $SLURM_ARRAY_TASK_ID 17 | export dataset_list=(esol freesolv lipophilicity malaria cep) 18 | export split_list=(scaffold random) 19 | export seed_list=(0 1 2 3 4 5 6 7 8 9) 20 | export batch_size=256 21 | export mode=$1 22 | export seed=${seed_list[$SLURM_ARRAY_TASK_ID]} 23 | 24 | 25 | 26 | 27 | 28 | if [ "$mode" == "random" ]; then 29 | 30 | for dataset in "${dataset_list[@]}"; do 31 | for split in "${split_list[@]}"; do 32 | export folder="$mode"/"$seed"/"$split" 33 | mkdir -p ./output/"$folder"/"$dataset" 34 | 35 | export output_path=./output/"$folder"/"$dataset".out 36 | export output_model_dir=./output/"$folder"/"$dataset" 37 | 38 | echo "$SLURM_JOB_ID"_"$SLURM_ARRAY_TASK_ID" > "$output_path" 39 | echo `date` >> "$output_path" 40 | 41 | bash ./run_molecule_finetune_regression.sh \ 42 | --dataset="$dataset" --runseed="$seed" --eval_train --batch_size="$batch_size" \ 43 | --dropout_ratio=0.2 --split="$split" \ 44 | 
--output_model_dir="$output_model_dir" \ 45 | >> "$output_path" 46 | 47 | echo `date` >> "$output_path" 48 | done 49 | done 50 | 51 | 52 | 53 | 54 | else 55 | 56 | for dataset in "${dataset_list[@]}"; do 57 | for split in "${split_list[@]}"; do 58 | export folder="$mode"/"$seed"/"$split" 59 | mkdir -p ./output/"$folder" 60 | mkdir -p ./output/"$folder"/"$dataset" 61 | 62 | export output_path=./output/"$folder"/"$dataset".out 63 | export output_model_dir=./output/"$folder"/"$dataset" 64 | export input_model_file=./output/"$mode"/pretraining_model.pth 65 | 66 | echo "$SLURM_JOB_ID"_"$SLURM_ARRAY_TASK_ID" > "$output_path" 67 | echo `date` >> "$output_path" 68 | 69 | bash ./run_molecule_finetune_regression.sh \ 70 | --dataset="$dataset" --runseed="$seed" --eval_train --batch_size="$batch_size" \ 71 | --dropout_ratio=0.2 --split="$split" \ 72 | --input_model_file="$input_model_file" \ 73 | >> "$output_path" 74 | 75 | echo `date` >> "$output_path" 76 | done 77 | done 78 | 79 | fi 80 | 81 | -------------------------------------------------------------------------------- /scripts_regression/run_fine_tuning_model_DTI.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --account=rrg-bengioy-ad 4 | #SBATCH --cpus-per-task=8 5 | #SBATCH --gres=gpu:v100l:1 6 | #SBATCH --mem=32G 7 | #SBATCH --time=5:55:00 8 | #SBATCH --ntasks=1 9 | #SBATCH --array=0-2%3 10 | #SBATCH --output=logs/%j.out 11 | #SBATCH --job-name=repurpose_complete 12 | 13 | 14 | ###############SBATCH --gres=gpu:v100l:1 15 | 16 | echo "My SLURM_ARRAY_TASK_ID: " $SLURM_ARRAY_TASK_ID 17 | export dataset_list=(davis kiba) 18 | export seed_list=(0 1 2 3 4 5 6 7 8 9) 19 | export batch_size=256 20 | export mode=$1 21 | export seed=${seed_list[$SLURM_ARRAY_TASK_ID]} 22 | 23 | 24 | 25 | 26 | 27 | if [ "$mode" == "random" ]; then 28 | 29 | for dataset in "${dataset_list[@]}"; do 30 | export folder="$mode"/"$seed" 31 | mkdir -p ./output/"$folder" 32 | mkdir -p 
./output/"$folder"/"$dataset" 33 | 34 | export output_path=./output/"$folder"/"$dataset".out 35 | export output_model_dir=./output/"$folder"/"$dataset" 36 | 37 | echo "$SLURM_JOB_ID"_"$SLURM_ARRAY_TASK_ID" > "$output_path" 38 | echo `date` >> "$output_path" 39 | 40 | bash ./run_dti_finetune.sh \ 41 | --dataset="$dataset" --runseed="$seed" --batch_size="$batch_size" \ 42 | >> "$output_path" 43 | 44 | echo `date` >> "$output_path" 45 | done 46 | 47 | 48 | 49 | 50 | else 51 | 52 | for dataset in "${dataset_list[@]}"; do 53 | export folder="$mode"/"$seed" 54 | mkdir -p ./output/"$folder" 55 | mkdir -p ./output/"$folder"/"$dataset" 56 | 57 | export output_path=./output/"$folder"/"$dataset".out 58 | export output_model_dir=./output/"$folder"/"$dataset" 59 | export input_model_file=./output/"$mode"/pretraining_model.pth 60 | 61 | echo "$SLURM_JOB_ID"_"$SLURM_ARRAY_TASK_ID" > "$output_path" 62 | echo `date` >> "$output_path" 63 | 64 | bash ./run_dti_finetune.sh \ 65 | --dataset="$dataset" --runseed="$seed" --batch_size="$batch_size" \ 66 | --input_model_file="$input_model_file" \ 67 | >> "$output_path" 68 | 69 | echo `date` >> "$output_path" 70 | done 71 | 72 | fi 73 | 74 | -------------------------------------------------------------------------------- /scripts_regression/submit_fine_tuning.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd ../src_classification 4 | 5 | mode_list=( 6 | random 7 | EP/GEOM_2D_nmol50000_nconf1_nupper1000_morefeat/epochs_100_0 8 | AM/GEOM_2D_nmol50000_nconf1_nupper1000_morefeat/epochs_100_0 9 | IG/GEOM_2D_nmol50000_nconf1_nupper1000_morefeat/epochs_100_0 10 | CP/GEOM_2D_nmol50000_nconf1_nupper1000_morefeat/epochs_100_0 11 | GraphLoG/GEOM_2D_nmol50000_nconf1_nupper1000_morefeat/epochs_100_0 12 | Motif/GEOM_2D_nmol50000_nconf1_nupper1000_morefeat/epochs_100_0 13 | Contextual/GEOM_2D_nmol50000_nconf1_nupper1000_morefeat/epochs_100_0 14 | 
GraphCL/GEOM_2D_nmol50000_nconf1_nupper1000_morefeat/epochs_100_0 15 | JOAO/GEOM_2D_nmol50000_nconf1_nupper1000_morefeat/epochs_100_0 16 | JOAOv2/GEOM_2D_nmol50000_nconf1_nupper1000_morefeat/epochs_100_0 17 | 18 | GraphMVP/GEOM_3D_nmol50000_nconf5_nupper1000_morefeat/CL_1_VAE_1/6_51_10_0.1/0.15_EBM_dot_prod_0.2_normalize_l2_detach_target_2_100_0 19 | GraphMVP_hybrid/GEOM_3D_nmol50000_nconf5_nupper1000_morefeat/CL_1_VAE_1_AM_1/6_51_10_0.1/0.15_EBM_dot_prod_0.05_normalize_l2_detach_target_2_100_0 20 | GraphMVP_hybrid/GEOM_3D_nmol50000_nconf5_nupper1000_morefeat/CL_1_VAE_1_CP_0.1/6_51_10_0.1/0.15_EBM_dot_prod_0.2_normalize_l2_detach_target_2_100_0 21 | ) 22 | 23 | for mode in "${mode_list[@]}"; do 24 | echo "$mode" 25 | ls output/"$mode" 26 | 27 | sbatch run_fine_tuning_model.sh "$mode" 28 | 29 | echo 30 | 31 | done 32 | -------------------------------------------------------------------------------- /scripts_regression/submit_pre_training_GraphMVP.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | cd ../src_classification 4 | 5 | 6 | export mode=GraphMVP 7 | export dataset_list=(GEOM_3D_nmol50000_nconf5_nupper1000_morefeat) 8 | export epochs=100 9 | export time=9 10 | 11 | 12 | # For SchNet and GNN 13 | export schnet_lr_scale_list=(0.1) 14 | export num_interactions=6 15 | export num_gaussians=51 16 | export cutoff=10 17 | export dropout_ratio_list=(0) 18 | export SSL_masking_ratio_list=(0.15 0.3) 19 | 20 | 21 | 22 | # For CL 23 | # export CL_similarity_metric_list=(InfoNCE_dot_prod EBM_dot_prod) 24 | export CL_similarity_metric_list=(EBM_dot_prod) 25 | export T_list=(0.1 0.2 0.5 1 2) 26 | export normalize_list=(normalize) 27 | 28 | 29 | 30 | # For VAE 31 | export AE_model=VAE 32 | # export AE_loss_list=(l1 l2 cosine) 33 | export AE_loss_list=(l2) 34 | # export detach_list=(detach_target no_detach_target) 35 | export detach_list=(detach_target) 36 | # export beta_list=(0.1 1 2) 37 | export beta_list=(1 
2) 38 | 39 | 40 | 41 | 42 | # For CL + VAE 43 | export alpha_1_list=(1) 44 | export alpha_2_list=(0.1 1) 45 | 46 | 47 | 48 | export SSL_masking_ratio_list=(0) 49 | export CL_similarity_metric_list=(EBM_dot_prod) 50 | export T_list=(0.1 0.2) 51 | 52 | 53 | 54 | 55 | for dataset in "${dataset_list[@]}"; do 56 | for SSL_masking_ratio in "${SSL_masking_ratio_list[@]}"; do 57 | 58 | for alpha_1 in "${alpha_1_list[@]}"; do 59 | for alpha_2 in "${alpha_2_list[@]}"; do 60 | for CL_similarity_metric in "${CL_similarity_metric_list[@]}"; do 61 | for normalize in "${normalize_list[@]}"; do 62 | for T in "${T_list[@]}"; do 63 | for AE_loss in "${AE_loss_list[@]}"; do 64 | for detach in "${detach_list[@]}"; do 65 | for beta in "${beta_list[@]}"; do 66 | 67 | 68 | for schnet_lr_scale in "${schnet_lr_scale_list[@]}"; do 69 | for dropout_ratio in "${dropout_ratio_list[@]}"; do 70 | export folder="$mode"/"$dataset"/CL_"$alpha_1"_"$AE_model"_"$alpha_2"/"$num_interactions"_"$num_gaussians"_"$cutoff"_"$schnet_lr_scale"/"$SSL_masking_ratio"_"$CL_similarity_metric"_"$T"_"$normalize"_"$AE_loss"_"$detach"_"$beta"_"$epochs"_"$dropout_ratio" 71 | 72 | echo "$folder" 73 | mkdir -p ../output/"$folder" 74 | ls ../output/"$folder" 75 | 76 | export output_file=../output/"$folder"/pretraining.out 77 | export output_model_dir=../output/"$folder"/pretraining 78 | 79 | 80 | echo "$output_file" undone 81 | 82 | sbatch --gres=gpu:v100l:1 -c 8 --mem=32G -t "$time":00:00 --account=rrg-bengioy-ad --qos=high --job-name=CL_VAE_"$time" \ 83 | --output="$output_file" \ 84 | ./run_pretrain_"$mode".sh \ 85 | --epochs="$epochs" \ 86 | --dataset="$dataset" \ 87 | --batch_size=256 \ 88 | --SSL_masking_ratio="$SSL_masking_ratio" \ 89 | --CL_similarity_metric="$CL_similarity_metric" --T="$T" --"$normalize" \ 90 | --AE_model="$AE_model" --AE_loss="$AE_loss" --"$detach" --beta="$beta" \ 91 | --alpha_1="$alpha_1" --alpha_2="$alpha_2" \ 92 | --num_interactions="$num_interactions" --num_gaussians="$num_gaussians" 
--cutoff="$cutoff" --schnet_lr_scale="$schnet_lr_scale" \ 93 | --dropout_ratio="$dropout_ratio" --num_workers=8 \ 94 | --output_model_dir="$output_model_dir" 95 | 96 | done 97 | done 98 | 99 | done 100 | done 101 | done 102 | done 103 | done 104 | done 105 | done 106 | done 107 | done 108 | done 109 | -------------------------------------------------------------------------------- /scripts_regression/submit_pre_training_GraphMVP_hybrid.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | cd ../src_classification 4 | 5 | 6 | export mode=GraphMVP_hybrid 7 | export dataset_list=(GEOM_3D_nmol50000_nconf5_nupper1000_morefeat) 8 | export epochs=100 9 | 10 | 11 | 12 | 13 | 14 | # For SchNet and GNN 15 | export schnet_lr_scale_list=(0.1) 16 | export num_interactions=6 17 | export num_gaussians=51 18 | export cutoff=10 19 | export dropout_ratio_list=(0) 20 | export SSL_masking_ratio_list=(0.15 0.3) 21 | 22 | 23 | 24 | 25 | # For CL 26 | # export CL_similarity_metric_list=(InfoNCE_dot_prod EBM_dot_prod) 27 | export CL_similarity_metric_list=(EBM_dot_prod) 28 | # export T_list=(0.05 0.1 0.2 0.5 1 2) 29 | export T_list=(0.1 0.2) 30 | export normalize_list=(normalize) 31 | 32 | 33 | 34 | # For VAE 35 | export AE_model=VAE 36 | # export AE_loss_list=(l1 l2 cosine) 37 | export AE_loss_list=(l2) 38 | # export detach_list=(detach_target no_detach_target) 39 | export detach_list=(detach_target) 40 | # export beta_list=(0.1 1 2) 41 | export beta_list=(1 2) 42 | 43 | 44 | 45 | 46 | 47 | # For CL + VAE 48 | export alpha_1_list=(1) 49 | export alpha_2_list=(0.1 1) 50 | export alpha_3_list=(0.1 1) 51 | export SSL_2D_mode_list=( CP AM IG JOAOv2 JOAO GraphCL) 52 | export time_list=( 3 3 3 6 6 3) 53 | export time_list=( 9 9 6 12 12 6) 54 | 55 | 56 | 57 | 58 | for dataset in "${dataset_list[@]}"; do 59 | for SSL_masking_ratio in "${SSL_masking_ratio_list[@]}"; do 60 | 61 | for i in {0..1}; do 62 | 
SSL_2D_mode=${SSL_2D_mode_list[$i]} 63 | time=${time_list[$i]} 64 | 65 | for alpha_3 in "${alpha_3_list[@]}"; do 66 | 67 | for alpha_1 in "${alpha_1_list[@]}"; do 68 | for alpha_2 in "${alpha_2_list[@]}"; do 69 | for CL_similarity_metric in "${CL_similarity_metric_list[@]}"; do 70 | for normalize in "${normalize_list[@]}"; do 71 | for T in "${T_list[@]}"; do 72 | for AE_loss in "${AE_loss_list[@]}"; do 73 | for detach in "${detach_list[@]}"; do 74 | for beta in "${beta_list[@]}"; do 75 | 76 | 77 | for schnet_lr_scale in "${schnet_lr_scale_list[@]}"; do 78 | for dropout_ratio in "${dropout_ratio_list[@]}"; do 79 | export folder="$mode"/"$dataset"/CL_"$alpha_1"_"$AE_model"_"$alpha_2"_"$SSL_2D_mode"_"$alpha_3"/"$num_interactions"_"$num_gaussians"_"$cutoff"_"$schnet_lr_scale"/"$SSL_masking_ratio"_"$CL_similarity_metric"_"$T"_"$normalize"_"$AE_loss"_"$detach"_"$beta"_"$epochs"_"$dropout_ratio" 80 | 81 | echo "$folder" 82 | mkdir -p ../output/"$folder" 83 | ls ../output/"$folder" 84 | 85 | export output_file=../output/"$folder"/pretraining.out 86 | export output_model_dir=../output/"$folder"/pretraining 87 | 88 | echo "$output_model_dir"_model_final.pth undone 89 | ls "$output_model_dir"* 90 | echo "$output_file" 91 | ls ../output/"$folder" 92 | rm ../output/"$folder"/* 93 | 94 | 95 | sbatch --gres=gpu:v100l:1 -c 8 --mem=32G -t "$time":00:00 --account=rrg-bengioy-ad --qos=high --job-name=CL_VAE_"$SSL_2D_mode"_"$time" \ 96 | --output="$output_file" \ 97 | ./run_pretrain_"$mode".sh \ 98 | --epochs="$epochs" \ 99 | --dataset="$dataset" \ 100 | --batch_size=256 \ 101 | --SSL_masking_ratio="$SSL_masking_ratio" \ 102 | --CL_similarity_metric="$CL_similarity_metric" --T="$T" --"$normalize" \ 103 | --AE_model="$AE_model" --AE_loss="$AE_loss" --"$detach" --beta="$beta" \ 104 | --alpha_1="$alpha_1" --alpha_2="$alpha_2" \ 105 | --SSL_2D_mode="$SSL_2D_mode" --alpha_3="$alpha_3" \ 106 | --num_interactions="$num_interactions" --num_gaussians="$num_gaussians" --cutoff="$cutoff" 
--schnet_lr_scale="$schnet_lr_scale" \
--dropout_ratio="$dropout_ratio" --num_workers=8 \
--output_model_dir="$output_model_dir"

echo

done
done

done
done
done
done
done
done
done
done
done

done
done
done
-------------------------------------------------------------------------------- /scripts_regression/submit_pre_training_baselines.sh: --------------------------------------------------------------------------------
#!/usr/bin/env bash

# Submit SLURM pre-training jobs for the 2D SSL baselines in the
# regression (complete-feature, "morefeat") setting.
# One job per mode in mode_list; each job runs ./run_pretrain_<mode>.sh.

cd ../src_regression

epochs=100
# wall-clock hours requested per job (used below as "$time":59:00)
time=2
mode_list=(AM CP GraphCL JOAO JOAOv2)
dropout_ratio=0
dataset=GEOM_2D_nmol50000_nconf1_nupper1000_morefeat


for mode in "${mode_list[@]}"; do
    # output layout: ./output/<mode>/<dataset>/epochs_<epochs>_<dropout>
    export folder="$mode"/"$dataset"/epochs_"$epochs"_"$dropout_ratio"
    echo "$folder"

    mkdir -p ./output/"$folder"
    ls ./output/"$folder"

    export output_file=./output/"$folder"/pretraining.out
    export output_model_dir=./output/"$folder"/pretraining

    sbatch --gres=gpu:v100l:1 -c 8 --mem=32G -t "$time":59:00 --account=rrg-bengioy-ad --qos=high --job-name=baselines \
    --output="$output_file" \
    ./run_pretrain_"$mode".sh \
    --epochs="$epochs" \
    --dataset="$dataset" \
    --batch_size=256 \
    --dropout_ratio="$dropout_ratio" --num_workers=8 \
    --output_model_dir="$output_model_dir"
done
-------------------------------------------------------------------------------- /src_classification/config.py: --------------------------------------------------------------------------------
# Central argparse configuration shared by the pre-training and
# fine-tuning entry points in src_classification.
import argparse

parser = argparse.ArgumentParser()

# about seed and basic info
parser.add_argument('--seed', type=int, default=42)
parser.add_argument('--runseed', type=int, default=0)
parser.add_argument('--device', type=int, default=0)

# about dataset and dataloader
parser.add_argument('--input_data_dir', type=str, default='')
parser.add_argument('--dataset', type=str, default='bace')
parser.add_argument('--num_workers', type=int, default=8)

# about training strategies
parser.add_argument('--split', type=str, default='scaffold')
parser.add_argument('--batch_size', type=int, default=256)
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--lr_scale', type=float, default=1)
parser.add_argument('--decay', type=float, default=0)

# about molecule GNN
parser.add_argument('--gnn_type', type=str, default='gin')
parser.add_argument('--num_layer', type=int, default=5)
parser.add_argument('--emb_dim', type=int, default=300)
parser.add_argument('--dropout_ratio', type=float, default=0.5)
parser.add_argument('--graph_pooling', type=str, default='mean')
parser.add_argument('--JK', type=str, default='last')
# per-module LR multiplier for the 2D GNN branch
parser.add_argument('--gnn_lr_scale', type=float, default=1)
parser.add_argument('--model_3d', type=str, default='schnet', choices=['schnet'])

# for AttributeMask
parser.add_argument('--mask_rate', type=float, default=0.15)
parser.add_argument('--mask_edge', type=int, default=0)

# for ContextPred
parser.add_argument('--csize', type=int, default=3)
parser.add_argument('--contextpred_neg_samples', type=int, default=1)

# for SchNet
parser.add_argument('--num_filters', type=int, default=128)
parser.add_argument('--num_interactions', type=int, default=6)
parser.add_argument('--num_gaussians', type=int, default=51)
parser.add_argument('--cutoff', type=float, default=10)
parser.add_argument('--readout', type=str, default='mean', choices=['mean', 'add'])
# per-module LR multiplier for the 3D SchNet branch
parser.add_argument('--schnet_lr_scale', type=float, default=1)

# for 2D-3D Contrastive CL
parser.add_argument('--CL_neg_samples', type=int, default=1)
parser.add_argument('--CL_similarity_metric', type=str, default='InfoNCE_dot_prod',
                    choices=['InfoNCE_dot_prod', 'EBM_dot_prod'])
# temperature for the contrastive objective
parser.add_argument('--T', type=float, default=0.1)
# NOTE(review): unlike detach_target/eval_train below, there is no explicit
# parser.set_defaults(normalize=...) here, so the effective default depends on
# argparse's action registration order — consider adding set_defaults to make
# it explicit; TODO confirm the intended default.
parser.add_argument('--normalize', dest='normalize', action='store_true')
parser.add_argument('--no_normalize', dest='normalize', action='store_false')
parser.add_argument('--SSL_masking_ratio', type=float, default=0)
# This is for generative SSL.
parser.add_argument('--AE_model', type=str, default='AE', choices=['AE', 'VAE'])
# NOTE(review): redundant — default='AE' is already set on the line above.
parser.set_defaults(AE_model='AE')

# for 2D-3D AutoEncoder
parser.add_argument('--AE_loss', type=str, default='l2', choices=['l1', 'l2', 'cosine'])
parser.add_argument('--detach_target', dest='detach_target', action='store_true')
parser.add_argument('--no_detach_target', dest='detach_target', action='store_false')
parser.set_defaults(detach_target=True)

# for 2D-3D Variational AutoEncoder
# KL weight in the VAE objective
parser.add_argument('--beta', type=float, default=1)

# for 2D-3D Contrastive CL and AE/VAE
# alpha_1 weights the contrastive loss, alpha_2 the generative (AE/VAE) loss
parser.add_argument('--alpha_1', type=float, default=1)
parser.add_argument('--alpha_2', type=float, default=1)

# for 2D SSL and 3D-2D SSL
parser.add_argument('--SSL_2D_mode', type=str, default='AM')
# alpha_3 weights the auxiliary 2D SSL loss in the hybrid objective
parser.add_argument('--alpha_3', type=float, default=0.1)
parser.add_argument('--gamma_joao', type=float, default=0.1)
parser.add_argument('--gamma_joaov2', type=float, default=0.1)

# about if we would print out eval metric for training data
parser.add_argument('--eval_train', dest='eval_train', action='store_true')
parser.add_argument('--no_eval_train', dest='eval_train', action='store_false')
parser.set_defaults(eval_train=True)

# about loading and saving
parser.add_argument('--input_model_file', type=str, default='')
parser.add_argument('--output_model_dir', type=str, default='') 88 | 89 | # verbosity 90 | parser.add_argument('--verbose', dest='verbose', action='store_true') 91 | parser.add_argument('--no_verbose', dest='verbose', action='store_false') 92 | parser.set_defaults(verbose=False) 93 | 94 | args = parser.parse_args() 95 | print('arguments\t', args) 96 | -------------------------------------------------------------------------------- /src_classification/dataloader.py: -------------------------------------------------------------------------------- 1 | from batch import (BatchAE, BatchMasking, BatchSubstructContext, 2 | BatchSubstructContext3D) 3 | from torch.utils.data import DataLoader 4 | 5 | 6 | class DataLoaderSubstructContext(DataLoader): 7 | """Data loader which merges data objects from a 8 | :class:`torch_geometric.data.dataset` to a mini-batch. 9 | Args: 10 | dataset (Dataset): The dataset from which to load the data. 11 | batch_size (int, optional): How may samples per batch to load. 12 | (default: :obj:`1`) 13 | shuffle (bool, optional): If set to :obj:`True`, the data will be 14 | reshuffled at every epoch (default: :obj:`True`) """ 15 | 16 | def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs): 17 | super(DataLoaderSubstructContext, self).__init__( 18 | dataset, 19 | batch_size, 20 | shuffle, 21 | collate_fn=lambda data_list: BatchSubstructContext.from_data_list(data_list), 22 | **kwargs) 23 | 24 | 25 | class DataLoaderSubstructContext3D(DataLoader): 26 | """Data loader which merges data objects from a 27 | :class:`torch_geometric.data.dataset` to a mini-batch. 28 | Args: 29 | dataset (Dataset): The dataset from which to load the data. 30 | batch_size (int, optional): How may samples per batch to load. 
            (default: :obj:`1`)
        shuffle (bool, optional): If set to :obj:`True`, the data will be
            reshuffled at every epoch (default: :obj:`True`) """

    def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs):
        # Use BatchSubstructContext3D as the collate function so the extra
        # 3D substructure/context attributes are merged per mini-batch.
        super(DataLoaderSubstructContext3D, self).__init__(
            dataset,
            batch_size,
            shuffle,
            collate_fn=lambda data_list: BatchSubstructContext3D.from_data_list(data_list),
            **kwargs)


class DataLoaderMasking(DataLoader):
    """Data loader which merges data objects from a
    :class:`torch_geometric.data.dataset` to a mini-batch.
    Args:
        dataset (Dataset): The dataset from which to load the data.
        batch_size (int, optional): How many samples per batch to load.
            (default: :obj:`1`)
        shuffle (bool, optional): If set to :obj:`True`, the data will be
            reshuffled at every epoch (default: :obj:`True`) """

    def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs):
        # Use BatchMasking as the collate function so masked-node
        # bookkeeping survives batching.
        super(DataLoaderMasking, self).__init__(
            dataset,
            batch_size,
            shuffle,
            collate_fn=lambda data_list: BatchMasking.from_data_list(data_list),
            **kwargs)


class DataLoaderAE(DataLoader):
    """ Data loader which merges data objects from a
    :class:`torch_geometric.data.dataset` to a mini-batch.
    Args:
        dataset (Dataset): The dataset from which to load the data.
        batch_size (int, optional): How many samples per batch to load.
69 | (default: :obj:`1`) 70 | shuffle (bool, optional): If set to :obj:`True`, the data will be 71 | reshuffled at every epoch (default: :obj:`True`) """ 72 | 73 | def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs): 74 | super(DataLoaderAE, self).__init__( 75 | dataset, 76 | batch_size, 77 | shuffle, 78 | collate_fn=lambda data_list: BatchAE.from_data_list(data_list), 79 | **kwargs) 80 | -------------------------------------------------------------------------------- /src_classification/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets_GPT import MoleculeDatasetGPT 2 | from .molecule_3D_dataset import Molecule3DDataset 3 | from .molecule_3D_masking_dataset import Molecule3DMaskingDataset 4 | from .molecule_contextual_datasets import MoleculeContextualDataset 5 | from .molecule_datasets import (MoleculeDataset, allowable_features, 6 | graph_data_obj_to_nx_simple, 7 | nx_to_graph_data_obj_simple) 8 | from .molecule_graphcl_dataset import MoleculeDataset_graphcl 9 | from .molecule_graphcl_masking_dataset import MoleculeGraphCLMaskingDataset 10 | from .molecule_motif_datasets import RDKIT_PROPS, MoleculeMotifDataset 11 | -------------------------------------------------------------------------------- /src_classification/datasets/datasets_GPT.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import torch 5 | from torch_geometric.data import Data, InMemoryDataset 6 | from torch_geometric.utils import subgraph 7 | from tqdm import tqdm 8 | 9 | 10 | def search_graph(graph): 11 | num_node = len(graph.x) 12 | edge_set = set() 13 | 14 | u_list, v_list = graph.edge_index[0].numpy(), graph.edge_index[1].numpy() 15 | for u,v in zip(u_list, v_list): 16 | edge_set.add((u,v)) 17 | edge_set.add((v,u)) 18 | 19 | visited_list = [] 20 | unvisited_set = set([i for i in range(num_node)]) 21 | 22 | while 
len(unvisited_set) > 0: 23 | u = random.sample(unvisited_set, 1)[0] 24 | queue = [u] 25 | while len(queue): 26 | u = queue.pop(0) 27 | if u in visited_list: 28 | continue 29 | visited_list.append(u) 30 | unvisited_set.remove(u) 31 | 32 | for v in range(num_node): 33 | if (v not in visited_list) and ((u,v) in edge_set): 34 | queue.append(v) 35 | assert len(visited_list) == num_node 36 | return visited_list 37 | 38 | 39 | class MoleculeDatasetGPT(InMemoryDataset): 40 | def __init__(self, molecule_dataset, transform=None, pre_transform=None): 41 | self.molecule_dataset = molecule_dataset 42 | self.root = molecule_dataset.root + '_GPT' 43 | super(MoleculeDatasetGPT, self).__init__(self.root, transform=transform, pre_transform=pre_transform) 44 | 45 | self.data, self.slices = torch.load(self.processed_paths[0]) 46 | 47 | return 48 | 49 | def process(self): 50 | num_molecule = len(self.molecule_dataset) 51 | data_list = [] 52 | for i in tqdm(range(num_molecule)): 53 | graph = self.molecule_dataset.get(i) 54 | 55 | num_node = len(graph.x) 56 | # TODO: will replace this with DFS/BFS searching 57 | node_list = search_graph(graph) 58 | 59 | for idx in range(num_node-1): 60 | # print('sub_node_list: {}\nnext_node: {}'.format(sub_node_list, next_node)) 61 | # [0..idx] -> [idx+1] 62 | sub_node_list = node_list[:idx+1] 63 | next_node = node_list[idx+1] 64 | 65 | edge_index, edge_attr = subgraph( 66 | subset=sub_node_list, edge_index=graph.edge_index, edge_attr=graph.edge_attr, 67 | relabel_nodes=True, num_nodes=num_node) 68 | 69 | # Take the subgraph and predict on the next node (atom type only) 70 | sub_graph = Data(x=graph.x[sub_node_list], edge_index=edge_index, edge_attr=edge_attr, next_x=graph.x[next_node, :1]) 71 | data_list.append(sub_graph) 72 | 73 | print('len of data\t', len(data_list)) 74 | data, slices = self.collate(data_list) 75 | print('Saving...') 76 | torch.save((data, slices), self.processed_paths[0]) 77 | return 78 | 79 | @property 80 | def 
processed_file_names(self): 81 | return 'geometric_data_processed.pt' 82 | -------------------------------------------------------------------------------- /src_classification/datasets/molecule_3D_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from itertools import repeat 3 | 4 | import torch 5 | from torch_geometric.data import Data, InMemoryDataset 6 | 7 | 8 | class Molecule3DDataset(InMemoryDataset): 9 | def __init__(self, root, dataset='zinc250k', 10 | transform=None, pre_transform=None, pre_filter=None, empty=False): 11 | self.dataset = dataset 12 | self.root = root 13 | 14 | super(Molecule3DDataset, self).__init__(root, transform, pre_transform, pre_filter) 15 | self.transform, self.pre_transform, self.pre_filter = transform, pre_transform, pre_filter 16 | 17 | if not empty: 18 | self.data, self.slices = torch.load(self.processed_paths[0]) 19 | print('Dataset: {}\nData: {}'.format(self.dataset, self.data)) 20 | 21 | def get(self, idx): 22 | data = Data() 23 | for key in self.data.keys: 24 | item, slices = self.data[key], self.slices[key] 25 | s = list(repeat(slice(None), item.dim())) 26 | s[data.__cat_dim__(key, item)] = slice(slices[idx], slices[idx + 1]) 27 | data[key] = item[s] 28 | return data 29 | 30 | @property 31 | def raw_file_names(self): 32 | return os.listdir(self.raw_dir) 33 | 34 | @property 35 | def processed_file_names(self): 36 | return 'geometric_data_processed.pt' 37 | 38 | def download(self): 39 | return 40 | 41 | def process(self): 42 | return 43 | -------------------------------------------------------------------------------- /src_classification/datasets/molecule_3D_masking_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from itertools import repeat 3 | 4 | import numpy as np 5 | import torch 6 | from torch_geometric.data import Data, InMemoryDataset 7 | from torch_geometric.utils import subgraph, to_networkx 8 | 
class Molecule3DMaskingDataset(InMemoryDataset):
    """3D molecule dataset that optionally keeps only a random connected
    subgraph of each molecule (node masking via BFS sub-sampling)."""

    def __init__(self, root, dataset, mask_ratio,
                 transform=None, pre_transform=None, pre_filter=None, empty=False):
        self.root = root
        self.dataset = dataset
        self.mask_ratio = mask_ratio  # fraction of nodes dropped in get()

        super(Molecule3DMaskingDataset, self).__init__(root, transform, pre_transform, pre_filter)
        self.transform, self.pre_transform, self.pre_filter = transform, pre_transform, pre_filter

        if not empty:
            self.data, self.slices = torch.load(self.processed_paths[0])
            print('Dataset: {}\nData: {}'.format(self.dataset, self.data))

    def subgraph(self, data):
        """Keep a random connected subgraph with ~(1 - mask_ratio) of the nodes."""
        G = to_networkx(data)
        node_num, _ = data.x.size()
        sub_num = int(node_num * (1 - self.mask_ratio))

        # BFS frontier seeded at a random node.
        idx_sub = [np.random.randint(node_num, size=1)[0]]
        idx_neigh = set([n for n in G.neighbors(idx_sub[-1])])

        # BFS
        while len(idx_sub) <= sub_num:
            if len(idx_neigh) == 0:
                # Component exhausted: jump to a random not-yet-kept node.
                idx_unsub = list(set([n for n in range(node_num)]).difference(set(idx_sub)))
                idx_neigh = set([np.random.choice(idx_unsub)])
            sample_node = np.random.choice(list(idx_neigh))

            idx_sub.append(sample_node)
            idx_neigh = idx_neigh.union(
                set([n for n in G.neighbors(idx_sub[-1])])).difference(set(idx_sub))

        idx_nondrop = idx_sub
        idx_nondrop.sort()

        edge_idx, edge_attr = subgraph(subset=idx_nondrop,
                                       edge_index=data.edge_index,
                                       edge_attr=data.edge_attr,
                                       relabel_nodes=True,
                                       num_nodes=node_num)

        data.edge_index = edge_idx
        data.edge_attr = edge_attr
        data.x = data.x[idx_nondrop]
        data.positions = data.positions[idx_nondrop]  # keep 3D coords in sync
        data.__num_nodes__, _ = data.x.shape
        return data

    def get(self, idx):
        data = Data()
        for key in self.data.keys:
            item, slices = self.data[key], self.slices[key]
            s = list(repeat(slice(None), item.dim()))
            s[data.__cat_dim__(key, item)] = slice(slices[idx], slices[idx + 1])
            data[key] = item[s]

        if self.mask_ratio > 0:
            data = self.subgraph(data)
        return data

    @property
    def raw_file_names(self):
        return os.listdir(self.raw_dir)

    @property
    def processed_file_names(self):
        return 'geometric_data_processed.pt'

    def download(self):
        return

    def process(self):
        return


# ===== src_classification/datasets/molecule_contextual_datasets_utils.py =====
import pickle
from collections import Counter
from multiprocessing import Pool

import tqdm

BOND_FEATURES = ['BondType', 'BondDir']


def atom_to_vocab(mol, atom):
    """
    Convert an atom to a vocabulary token. The convention is based on atom
    type and bond type.
    :param mol: the molecule (RDKit Mol).
    :param atom: the target atom.
    :return: the generated atom vocabulary token with its context.
    """
    nei = Counter()
    for a in atom.GetNeighbors():
        bond = mol.GetBondBetweenAtoms(atom.GetIdx(), a.GetIdx())
        nei[str(a.GetSymbol()) + "-" + str(bond.GetBondType())] += 1
    keys = nei.keys()
    keys = list(keys)
    keys.sort()
    output = atom.GetSymbol()
    for k in keys:
        output = "%s_%s%d" % (output, k, nei[k])

    # The generated atom_vocab is too long?
    return output


def bond_to_vocab(mol, bond):
    """
    Convert a bond to a vocabulary token. The convention is based on atom
    type and bond type, considering one-hop neighbor atoms of both endpoints.
    :param mol: the molecule (RDKit Mol).
    :param bond: the target bond.
    :return: the generated bond vocabulary token with its context.
    """
    nei = Counter()
    two_neighbors = (bond.GetBeginAtom(), bond.GetEndAtom())
    two_indices = [a.GetIdx() for a in two_neighbors]
    for nei_atom in two_neighbors:
        for a in nei_atom.GetNeighbors():
            a_idx = a.GetIdx()
            if a_idx in two_indices:
                continue  # skip the bond's own endpoints
            tmp_bond = mol.GetBondBetweenAtoms(nei_atom.GetIdx(), a_idx)
            nei[str(nei_atom.GetSymbol()) + '-' + get_bond_feature_name(tmp_bond)] += 1
    keys = list(nei.keys())
    keys.sort()
    output = get_bond_feature_name(bond)
    for k in keys:
        output = "%s_%s%d" % (output, k, nei[k])
    return output


def get_bond_feature_name(bond):
    """
    Return the string format of bond features.
    Bond features are surrounded with ().
    """
    ret = []
    for bond_feature in BOND_FEATURES:
        # BUGFIX/idiom: dynamic attribute access via getattr instead of eval.
        fea = getattr(bond, 'Get' + bond_feature)()
        ret.append(str(fea))

    return '(' + '-'.join(ret) + ')'


class TorchVocab(object):
    """Frequency-ordered token vocabulary with itos/stoi lookup tables."""

    def __init__(self, counter, max_size=None, min_freq=1, specials=('', ''), vocab_type='atom'):
        """
        :param counter: Counter of token frequencies.
        :param max_size: maximum vocabulary size (excluding specials), or None.
        :param min_freq: minimum frequency for a token to be kept.
        :param specials: special tokens placed first in the vocabulary.
            NOTE(review): pad_index=0 / other_index=1 below suggest these were
            named tokens (e.g. '<pad>', '<other>') stripped during rendering --
            confirm against the upstream source.
        :param vocab_type: 'atom': atom atom_vocab; 'bond': bond atom_vocab.
        """
        self.freqs = counter
        counter = counter.copy()
        min_freq = max(min_freq, 1)
        if vocab_type in ('atom', 'bond'):
            self.vocab_type = vocab_type
        else:
            raise ValueError('Wrong input for vocab_type!')
        self.itos = list(specials)

        max_size = None if max_size is None else max_size + len(self.itos)
        # sort by frequency, then alphabetically
        words_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0])
        words_and_frequencies.sort(key=lambda tup: tup[1], reverse=True)

        for word, freq in words_and_frequencies:
            if freq < min_freq or len(self.itos) == max_size:
                break
            self.itos.append(word)
        # stoi is simply a reverse dict for itos
        self.stoi = {tok: i for i, tok in enumerate(self.itos)}
        self.other_index = 1
        self.pad_index = 0

    def __eq__(self, other):
        if self.freqs != other.freqs:
            return False
        if self.stoi != other.stoi:
            return False
        if self.itos != other.itos:
            return False
        return True

    def __len__(self):
        return len(self.itos)

    def vocab_rerank(self):
        self.stoi = {word: i for i, word in enumerate(self.itos)}

    def extend(self, v, sort=False):
        """Append tokens of vocabulary `v` not yet present; merge frequencies."""
        words = sorted(v.itos) if sort else v.itos
        for w in words:
            if w not in self.stoi:
                self.itos.append(w)
                self.stoi[w] = len(self.itos) - 1
                self.freqs[w] = 0
            self.freqs[w] += v.freqs[w]

    def save_vocab(self, vocab_path):
        with open(vocab_path, "wb") as f:
            pickle.dump(self, f)


class MolVocab(TorchVocab):
    """Vocabulary built from a list of RDKit molecules, using a worker pool."""

    def __init__(self, molecule_list, max_size=None, min_freq=1, num_workers=1, total_lines=None, vocab_type='atom'):
        if vocab_type in ('atom', 'bond'):
            self.vocab_type = vocab_type
        else:
            raise ValueError('Wrong input for vocab_type!')
        print("Building {} vocab from molecule-list".format((self.vocab_type)))

from rdkit import RDLogger 141 | lg = RDLogger.logger() 142 | lg.setLevel(RDLogger.CRITICAL) 143 | 144 | if total_lines is None: 145 | total_lines = len(molecule_list) 146 | 147 | counter = Counter() 148 | pbar = tqdm.tqdm(total=total_lines) 149 | pool = Pool(num_workers) 150 | res = [] 151 | batch = 50000 152 | callback = lambda a: pbar.update(batch) 153 | for i in range(int(total_lines / batch + 1)): 154 | start = int(batch * i) 155 | end = min(total_lines, batch * (i + 1)) 156 | res.append(pool.apply_async(MolVocab.read_counter_from_molecules, 157 | args=(molecule_list, start, end, vocab_type,), 158 | callback=callback)) 159 | pool.close() 160 | pool.join() 161 | for r in res: 162 | sub_counter = r.get() 163 | for k in sub_counter: 164 | if k not in counter: 165 | counter[k] = 0 166 | counter[k] += sub_counter[k] 167 | super().__init__(counter, max_size=max_size, min_freq=min_freq, vocab_type=vocab_type) 168 | 169 | @staticmethod 170 | def read_counter_from_molecules(molecule_list, start, end, vocab_type): 171 | sub_counter = Counter() 172 | for i, mol in enumerate(molecule_list): 173 | if i < start: 174 | continue 175 | if i >= end: 176 | break 177 | if vocab_type == 'atom': 178 | for atom in mol.GetAtoms(): 179 | v = atom_to_vocab(mol, atom) 180 | sub_counter[v] += 1 181 | else: 182 | for bond in mol.GetBonds(): 183 | v = bond_to_vocab(mol, bond) 184 | sub_counter[v] += 1 185 | # print("end") 186 | return sub_counter 187 | 188 | @staticmethod 189 | def load_vocab(vocab_path: str) -> 'MolVocab': 190 | with open(vocab_path, "rb") as f: 191 | return pickle.load(f) 192 | -------------------------------------------------------------------------------- /src_classification/datasets/molecule_graphcl_dataset.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | from itertools import repeat 4 | 5 | import numpy as np 6 | import torch 7 | from torch_geometric.data import Data 8 | from torch_geometric.utils import 
subgraph, to_networkx 9 | 10 | from .molecule_datasets import MoleculeDataset 11 | 12 | 13 | class MoleculeDataset_graphcl(MoleculeDataset): 14 | 15 | def __init__(self, 16 | root, 17 | transform=None, 18 | pre_transform=None, 19 | pre_filter=None, 20 | dataset=None, 21 | empty=False): 22 | 23 | self.aug_prob = None 24 | self.aug_mode = 'no_aug' 25 | self.aug_strength = 0.2 26 | self.augmentations = [self.node_drop, self.subgraph, 27 | self.edge_pert, self.attr_mask, lambda x: x] 28 | super(MoleculeDataset_graphcl, self).__init__( 29 | root, transform, pre_transform, pre_filter, dataset, empty) 30 | 31 | def set_augMode(self, aug_mode): 32 | self.aug_mode = aug_mode 33 | 34 | def set_augStrength(self, aug_strength): 35 | self.aug_strength = aug_strength 36 | 37 | def set_augProb(self, aug_prob): 38 | self.aug_prob = aug_prob 39 | 40 | def node_drop(self, data): 41 | 42 | node_num, _ = data.x.size() 43 | _, edge_num = data.edge_index.size() 44 | drop_num = int(node_num * self.aug_strength) 45 | 46 | idx_perm = np.random.permutation(node_num) 47 | idx_nodrop = idx_perm[drop_num:].tolist() 48 | idx_nodrop.sort() 49 | 50 | edge_idx, edge_attr = subgraph(subset=idx_nodrop, 51 | edge_index=data.edge_index, 52 | edge_attr=data.edge_attr, 53 | relabel_nodes=True, 54 | num_nodes=node_num) 55 | 56 | data.edge_index = edge_idx 57 | data.edge_attr = edge_attr 58 | data.x = data.x[idx_nodrop] 59 | data.__num_nodes__, _ = data.x.shape 60 | return data 61 | 62 | def edge_pert(self, data): 63 | node_num, _ = data.x.size() 64 | _, edge_num = data.edge_index.size() 65 | pert_num = int(edge_num * self.aug_strength) 66 | 67 | # delete edges 68 | idx_drop = np.random.choice(edge_num, (edge_num - pert_num), 69 | replace=False) 70 | edge_index = data.edge_index[:, idx_drop] 71 | edge_attr = data.edge_attr[idx_drop] 72 | 73 | # add edges 74 | adj = torch.ones((node_num, node_num)) 75 | adj[edge_index[0], edge_index[1]] = 0 76 | # edge_index_nonexist = adj.nonzero(as_tuple=False).t() 77 | 
edge_index_nonexist = torch.nonzero(adj, as_tuple=False).t() 78 | idx_add = np.random.choice(edge_index_nonexist.shape[1], 79 | pert_num, replace=False) 80 | edge_index_add = edge_index_nonexist[:, idx_add] 81 | # random 4-class & 3-class edge_attr for 1st & 2nd dimension 82 | edge_attr_add_1 = torch.tensor(np.random.randint( 83 | 4, size=(edge_index_add.shape[1], 1))) 84 | edge_attr_add_2 = torch.tensor(np.random.randint( 85 | 3, size=(edge_index_add.shape[1], 1))) 86 | edge_attr_add = torch.cat((edge_attr_add_1, edge_attr_add_2), dim=1) 87 | edge_index = torch.cat((edge_index, edge_index_add), dim=1) 88 | edge_attr = torch.cat((edge_attr, edge_attr_add), dim=0) 89 | 90 | data.edge_index = edge_index 91 | data.edge_attr = edge_attr 92 | return data 93 | 94 | def attr_mask(self, data): 95 | 96 | _x = data.x.clone() 97 | node_num, _ = data.x.size() 98 | mask_num = int(node_num * self.aug_strength) 99 | 100 | token = data.x.float().mean(dim=0).long() 101 | idx_mask = np.random.choice( 102 | node_num, mask_num, replace=False) 103 | 104 | _x[idx_mask] = token 105 | data.x = _x 106 | return data 107 | 108 | def subgraph(self, data): 109 | 110 | G = to_networkx(data) 111 | node_num, _ = data.x.size() 112 | _, edge_num = data.edge_index.size() 113 | sub_num = int(node_num * (1 - self.aug_strength)) 114 | 115 | idx_sub = [np.random.randint(node_num, size=1)[0]] 116 | idx_neigh = set([n for n in G.neighbors(idx_sub[-1])]) 117 | 118 | while len(idx_sub) <= sub_num: 119 | if len(idx_neigh) == 0: 120 | idx_unsub = list(set([n for n in range(node_num)]).difference(set(idx_sub))) 121 | idx_neigh = set([np.random.choice(idx_unsub)]) 122 | sample_node = np.random.choice(list(idx_neigh)) 123 | 124 | idx_sub.append(sample_node) 125 | idx_neigh = idx_neigh.union( 126 | set([n for n in G.neighbors(idx_sub[-1])])).difference(set(idx_sub)) 127 | 128 | idx_nondrop = idx_sub 129 | idx_nondrop.sort() 130 | 131 | edge_idx, edge_attr = subgraph(subset=idx_nondrop, 132 | 
edge_index=data.edge_index, 133 | edge_attr=data.edge_attr, 134 | relabel_nodes=True, 135 | num_nodes=node_num) 136 | 137 | data.edge_index = edge_idx 138 | data.edge_attr = edge_attr 139 | data.x = data.x[idx_nondrop] 140 | data.__num_nodes__, _ = data.x.shape 141 | return data 142 | 143 | def get(self, idx): 144 | data, data1, data2 = Data(), Data(), Data() 145 | keys_for_2D = ['x', 'edge_index', 'edge_attr'] 146 | for key in self.data.keys: 147 | item, slices = self.data[key], self.slices[key] 148 | s = list(repeat(slice(None), item.dim())) 149 | s[data.__cat_dim__(key, item)] = slice(slices[idx], slices[idx + 1]) 150 | if key in keys_for_2D: 151 | data[key], data1[key], data2[key] = item[s], item[s], item[s] 152 | else: 153 | data[key] = item[s] 154 | 155 | if self.aug_mode == 'no_aug': 156 | n_aug1, n_aug2 = 4, 4 157 | data1 = self.augmentations[n_aug1](data1) 158 | data2 = self.augmentations[n_aug2](data2) 159 | elif self.aug_mode == 'uniform': 160 | n_aug = np.random.choice(25, 1)[0] 161 | n_aug1, n_aug2 = n_aug // 5, n_aug % 5 162 | data1 = self.augmentations[n_aug1](data1) 163 | data2 = self.augmentations[n_aug2](data2) 164 | elif self.aug_mode == 'sample': 165 | n_aug = np.random.choice(25, 1, p=self.aug_prob)[0] 166 | n_aug1, n_aug2 = n_aug // 5, n_aug % 5 167 | data1 = self.augmentations[n_aug1](data1) 168 | data2 = self.augmentations[n_aug2](data2) 169 | else: 170 | raise ValueError 171 | return data, data1, data2 172 | -------------------------------------------------------------------------------- /src_classification/datasets/molecule_motif_datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | from itertools import repeat 3 | 4 | import numpy as np 5 | import torch 6 | from descriptastorus.descriptors import rdDescriptors 7 | from torch_geometric.data import Data, InMemoryDataset 8 | from tqdm import tqdm 9 | 10 | 11 | RDKIT_PROPS = ['fr_Al_COO', 'fr_Al_OH', 'fr_Al_OH_noTert', 'fr_ArN', 12 | 
'fr_Ar_COO', 'fr_Ar_N', 'fr_Ar_NH', 'fr_Ar_OH', 'fr_COO', 'fr_COO2', 13 | 'fr_C_O', 'fr_C_O_noCOO', 'fr_C_S', 'fr_HOCCN', 'fr_Imine', 'fr_NH0', 14 | 'fr_NH1', 'fr_NH2', 'fr_N_O', 'fr_Ndealkylation1', 'fr_Ndealkylation2', 15 | 'fr_Nhpyrrole', 'fr_SH', 'fr_aldehyde', 'fr_alkyl_carbamate', 'fr_alkyl_halide', 16 | 'fr_allylic_oxid', 'fr_amide', 'fr_amidine', 'fr_aniline', 'fr_aryl_methyl', 17 | 'fr_azide', 'fr_azo', 'fr_barbitur', 'fr_benzene', 'fr_benzodiazepine', 18 | 'fr_bicyclic', 'fr_diazo', 'fr_dihydropyridine', 'fr_epoxide', 'fr_ester', 19 | 'fr_ether', 'fr_furan', 'fr_guanido', 'fr_halogen', 'fr_hdrzine', 'fr_hdrzone', 20 | 'fr_imidazole', 'fr_imide', 'fr_isocyan', 'fr_isothiocyan', 'fr_ketone', 21 | 'fr_ketone_Topliss', 'fr_lactam', 'fr_lactone', 'fr_methoxy', 'fr_morpholine', 22 | 'fr_nitrile', 'fr_nitro', 'fr_nitro_arom', 'fr_nitro_arom_nonortho', 23 | 'fr_nitroso', 'fr_oxazole', 'fr_oxime', 'fr_para_hydroxylation', 'fr_phenol', 24 | 'fr_phenol_noOrthoHbond', 'fr_phos_acid', 'fr_phos_ester', 'fr_piperdine', 25 | 'fr_piperzine', 'fr_priamide', 'fr_prisulfonamd', 'fr_pyridine', 'fr_quatN', 26 | 'fr_sulfide', 'fr_sulfonamd', 'fr_sulfone', 'fr_term_acetylene', 'fr_tetrazole', 27 | 'fr_thiazole', 'fr_thiocyan', 'fr_thiophene', 'fr_unbrch_alkane', 'fr_urea'] 28 | 29 | 30 | def rdkit_functional_group_label_features_generator(smiles): 31 | """ 32 | Generates functional group label for a molecule using RDKit. 33 | :param mol: A molecule (i.e. either a SMILES string or an RDKit molecule). 34 | :return: A 1D numpy array containing the RDKit 2D features. 
35 | """ 36 | # smiles = Chem.MolToSmiles(mol, isomericSmiles=True) 37 | # if type(mol) != str else mol 38 | generator = rdDescriptors.RDKit2D(RDKIT_PROPS) 39 | features = generator.process(smiles)[1:] 40 | features = np.array(features) 41 | features[features != 0] = 1 42 | return features 43 | 44 | 45 | class MoleculeMotifDataset(InMemoryDataset): 46 | def __init__(self, root, dataset, 47 | transform=None, pre_transform=None, pre_filter=None, empty=False): 48 | self.dataset = dataset 49 | self.root = root 50 | 51 | super(MoleculeMotifDataset, self).__init__(root, transform, pre_transform, pre_filter) 52 | self.transform, self.pre_transform, self.pre_filter = transform, pre_transform, pre_filter 53 | 54 | if not empty: 55 | self.data, self.slices = torch.load(self.processed_paths[0]) 56 | 57 | self.motif_file = os.path.join(root, 'processed', 'motif.pt') 58 | self.process_motif_file() 59 | self.motif_label_list = torch.load(self.motif_file) 60 | 61 | print('Dataset: {}\nData: {}\nMotif: {}'.format(self.dataset, self.data, self.motif_label_list.size())) 62 | 63 | def process_motif_file(self): 64 | if not os.path.exists(self.motif_file): 65 | smiles_file = os.path.join(self.root, 'processed', 'smiles.csv') 66 | data_smiles_list = [] 67 | with open(smiles_file, 'r') as f: 68 | lines = f.readlines() 69 | for smiles in lines: 70 | data_smiles_list.append(smiles.strip()) 71 | 72 | motif_label_list = [] 73 | for smiles in tqdm(data_smiles_list): 74 | label = rdkit_functional_group_label_features_generator(smiles) 75 | motif_label_list.append(label) 76 | 77 | self.motif_label_list = torch.LongTensor(motif_label_list) 78 | torch.save(self.motif_label_list, self.motif_file) 79 | return 80 | 81 | def get(self, idx): 82 | data = Data() 83 | for key in self.data.keys: 84 | item, slices = self.data[key], self.slices[key] 85 | s = list(repeat(slice(None), item.dim())) 86 | s[data.__cat_dim__(key, item)] = slice(slices[idx], slices[idx + 1]) 87 | data[key] = item[s] 88 | data.y = 
self.motif_label_list[idx] 89 | return data 90 | 91 | @property 92 | def raw_file_names(self): 93 | return os.listdir(self.raw_dir) 94 | 95 | @property 96 | def processed_file_names(self): 97 | return 'geometric_data_processed.pt' 98 | 99 | def download(self): 100 | return 101 | 102 | def process(self): 103 | return 104 | -------------------------------------------------------------------------------- /src_classification/models/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch_geometric.nn.inits import uniform 4 | 5 | from .auto_encoder import AutoEncoder, VariationalAutoEncoder 6 | from .molecule_gnn_model import GNN, GNN_graphpred 7 | from .schnet import SchNet 8 | from .dti_model import ProteinModel, MoleculeProteinModel 9 | 10 | 11 | class Discriminator(nn.Module): 12 | def __init__(self, hidden_dim): 13 | super(Discriminator, self).__init__() 14 | self.weight = nn.Parameter(torch.Tensor(hidden_dim, hidden_dim)) 15 | self.reset_parameters() 16 | 17 | def reset_parameters(self): 18 | size = self.weight.size(0) 19 | uniform(size, self.weight) 20 | 21 | def forward(self, x, summary): 22 | h = torch.matmul(summary, self.weight) 23 | return torch.sum(x*h, dim=1) 24 | -------------------------------------------------------------------------------- /src_classification/models/auto_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | 6 | def cosine_similarity(p, z, average=True): 7 | p = F.normalize(p, p=2, dim=1) 8 | z = F.normalize(z, p=2, dim=1) 9 | loss = -(p * z).sum(dim=1) 10 | if average: 11 | loss = loss.mean() 12 | return loss 13 | 14 | 15 | class AutoEncoder(torch.nn.Module): 16 | 17 | def __init__(self, emb_dim, loss, detach_target): 18 | super(AutoEncoder, self).__init__() 19 | self.loss = loss 20 | self.emb_dim = emb_dim 21 | 
self.detach_target = detach_target 22 | 23 | self.criterion = None 24 | if loss == 'l1': 25 | self.criterion = nn.L1Loss() 26 | elif loss == 'l2': 27 | self.criterion = nn.MSELoss() 28 | elif loss == 'cosine': 29 | self.criterion = cosine_similarity 30 | 31 | self.fc_layers = nn.Sequential( 32 | nn.Linear(self.emb_dim, self.emb_dim), 33 | nn.BatchNorm1d(self.emb_dim), 34 | nn.ReLU(), 35 | nn.Linear(self.emb_dim, self.emb_dim), 36 | ) 37 | return 38 | 39 | def forward(self, x, y): 40 | if self.detach_target: 41 | y = y.detach() 42 | x = self.fc_layers(x) 43 | loss = self.criterion(x, y) 44 | 45 | return loss 46 | 47 | 48 | class VariationalAutoEncoder(torch.nn.Module): 49 | def __init__(self, emb_dim, loss, detach_target, beta=1): 50 | super(VariationalAutoEncoder, self).__init__() 51 | self.emb_dim = emb_dim 52 | self.loss = loss 53 | self.detach_target = detach_target 54 | self.beta = beta 55 | 56 | self.criterion = None 57 | if loss == 'l1': 58 | self.criterion = nn.L1Loss() 59 | elif loss == 'l2': 60 | self.criterion = nn.MSELoss() 61 | elif loss == 'cosine': 62 | self.criterion = cosine_similarity 63 | 64 | self.fc_mu = nn.Linear(self.emb_dim, self.emb_dim) 65 | self.fc_var = nn.Linear(self.emb_dim, self.emb_dim) 66 | 67 | self.decoder = nn.Sequential( 68 | nn.Linear(self.emb_dim, self.emb_dim), 69 | nn.BatchNorm1d(self.emb_dim), 70 | nn.ReLU(), 71 | nn.Linear(self.emb_dim, self.emb_dim), 72 | ) 73 | return 74 | 75 | def encode(self, x): 76 | mu = self.fc_mu(x) 77 | log_var = self.fc_var(x) 78 | return mu, log_var 79 | 80 | def reparameterize(self, mu, log_var): 81 | std = torch.exp(0.5 * log_var) 82 | eps = torch.randn_like(std) 83 | return mu + eps * std 84 | 85 | def forward(self, x, y): 86 | if self.detach_target: 87 | y = y.detach() 88 | 89 | mu, log_var = self.encode(x) 90 | z = self.reparameterize(mu, log_var) 91 | y_hat = self.decoder(z) 92 | 93 | reconstruction_loss = self.criterion(y_hat, y) 94 | kl_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu 
** 2 - log_var.exp(), dim=1), dim=0) 95 | 96 | loss = reconstruction_loss + self.beta * kl_loss 97 | 98 | return loss 99 | -------------------------------------------------------------------------------- /src_classification/models/dti_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch_geometric.nn import global_mean_pool 4 | 5 | 6 | class ProteinModel(nn.Module): 7 | def __init__(self, emb_dim=128, num_features=25, output_dim=128, n_filters=32, kernel_size=8): 8 | super(ProteinModel, self).__init__() 9 | self.n_filters = n_filters 10 | self.kernel_size = kernel_size 11 | self.intermediate_dim = emb_dim - kernel_size + 1 12 | 13 | self.embedding = nn.Embedding(num_features+1, emb_dim) 14 | self.n_filters = n_filters 15 | self.conv1 = nn.Conv1d(in_channels=1000, out_channels=n_filters, kernel_size=kernel_size) 16 | self.fc = nn.Linear(n_filters*self.intermediate_dim, output_dim) 17 | 18 | def forward(self, x): 19 | x = self.embedding(x) 20 | x = self.conv1(x) 21 | x = x.view(-1, self.n_filters*self.intermediate_dim) 22 | x = self.fc(x) 23 | return x 24 | 25 | 26 | class MoleculeProteinModel(nn.Module): 27 | def __init__(self, molecule_model, protein_model, molecule_emb_dim, protein_emb_dim, output_dim=1, dropout=0.2): 28 | super(MoleculeProteinModel, self).__init__() 29 | self.fc1 = nn.Linear(molecule_emb_dim+protein_emb_dim, 1024) 30 | self.fc2 = nn.Linear(1024, 512) 31 | self.out = nn.Linear(512, output_dim) 32 | self.molecule_model = molecule_model 33 | self.protein_model = protein_model 34 | self.pool = global_mean_pool 35 | self.relu = nn.ReLU() 36 | self.dropout = nn.Dropout(dropout) 37 | 38 | def forward(self, molecule, protein): 39 | molecule_node_representation = self.molecule_model(molecule) 40 | molecule_representation = self.pool(molecule_node_representation, molecule.batch) 41 | protein_representation = self.protein_model(protein) 42 | 43 | x = 
torch.cat([molecule_representation, protein_representation], dim=1) 44 | 45 | x = self.fc1(x) 46 | x = self.relu(x) 47 | x = self.dropout(x) 48 | x = self.fc2(x) 49 | x = self.relu(x) 50 | x = self.dropout(x) 51 | x = self.out(x) 52 | 53 | return x 54 | -------------------------------------------------------------------------------- /src_classification/models/schnet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import warnings 4 | from math import pi as PI 5 | 6 | import ase 7 | import numpy as np 8 | import torch 9 | import torch.nn.functional as F 10 | from torch.nn import Embedding, Linear, ModuleList, Sequential 11 | from torch_geometric.data.makedirs import makedirs 12 | from torch_geometric.nn import MessagePassing, radius_graph 13 | from torch_scatter import scatter 14 | 15 | try: 16 | import schnetpack as spk 17 | except ImportError: 18 | spk = None 19 | 20 | 21 | class SchNet(torch.nn.Module): 22 | 23 | def __init__(self, hidden_channels=128, num_filters=128, 24 | num_interactions=6, num_gaussians=50, cutoff=10.0, 25 | readout='mean', dipole=False, mean=None, std=None, atomref=None): 26 | super(SchNet, self).__init__() 27 | 28 | assert readout in ['add', 'sum', 'mean'] 29 | 30 | self.readout = 'add' if dipole else readout 31 | self.num_interactions = num_interactions 32 | self.hidden_channels = hidden_channels 33 | self.num_gaussians = num_gaussians 34 | self.num_filters = num_filters 35 | # self.readout = readout 36 | self.cutoff = cutoff 37 | self.dipole = dipole 38 | self.scale = None 39 | self.mean = mean 40 | self.std = std 41 | 42 | atomic_mass = torch.from_numpy(ase.data.atomic_masses) 43 | self.register_buffer('atomic_mass', atomic_mass) 44 | 45 | # self.embedding = Embedding(100, hidden_channels) 46 | self.embedding = Embedding(119, hidden_channels) 47 | self.distance_expansion = GaussianSmearing(0.0, cutoff, num_gaussians) 48 | 49 | self.interactions = ModuleList() 
50 | for _ in range(num_interactions): 51 | block = InteractionBlock(hidden_channels, num_gaussians, 52 | num_filters, cutoff) 53 | self.interactions.append(block) 54 | 55 | # TODO: double-check hidden size 56 | self.lin1 = Linear(hidden_channels, hidden_channels) 57 | self.act = ShiftedSoftplus() 58 | self.lin2 = Linear(hidden_channels, hidden_channels) 59 | 60 | self.register_buffer('initial_atomref', atomref) 61 | self.atomref = None 62 | if atomref is not None: 63 | self.atomref = Embedding(100, 1) 64 | self.atomref.weight.data.copy_(atomref) 65 | 66 | self.reset_parameters() 67 | 68 | def reset_parameters(self): 69 | self.embedding.reset_parameters() 70 | for interaction in self.interactions: 71 | interaction.reset_parameters() 72 | torch.nn.init.xavier_uniform_(self.lin1.weight) 73 | self.lin1.bias.data.fill_(0) 74 | torch.nn.init.xavier_uniform_(self.lin2.weight) 75 | self.lin2.bias.data.fill_(0) 76 | if self.atomref is not None: 77 | self.atomref.weight.data.copy_(self.initial_atomref) 78 | 79 | def forward(self, z, pos, batch=None): 80 | assert z.dim() == 1 and z.dtype == torch.long 81 | batch = torch.zeros_like(z) if batch is None else batch 82 | 83 | h = self.embedding(z) 84 | 85 | edge_index = radius_graph(pos, r=self.cutoff, batch=batch) 86 | row, col = edge_index 87 | edge_weight = (pos[row] - pos[col]).norm(dim=-1) 88 | edge_attr = self.distance_expansion(edge_weight) 89 | 90 | for interaction in self.interactions: 91 | h = h + interaction(h, edge_index, edge_weight, edge_attr) 92 | 93 | h = self.lin1(h) 94 | h = self.act(h) 95 | h = self.lin2(h) 96 | 97 | if self.dipole: 98 | # Get center of mass. 
99 | mass = self.atomic_mass[z].view(-1, 1) 100 | c = scatter(mass * pos, batch, dim=0) / scatter(mass, batch, dim=0) 101 | h = h * (pos - c[batch]) 102 | 103 | if not self.dipole and self.mean is not None and self.std is not None: 104 | h = h * self.std + self.mean 105 | 106 | if not self.dipole and self.atomref is not None: 107 | h = h + self.atomref(z) 108 | 109 | out = scatter(h, batch, dim=0, reduce=self.readout) 110 | 111 | if self.dipole: 112 | out = torch.norm(out, dim=-1, keepdim=True) 113 | 114 | if self.scale is not None: 115 | out = self.scale * out 116 | 117 | return out 118 | 119 | def __repr__(self): 120 | return (f'{self.__class__.__name__}(' 121 | f'hidden_channels={self.hidden_channels}, ' 122 | f'num_filters={self.num_filters}, ' 123 | f'num_interactions={self.num_interactions}, ' 124 | f'num_gaussians={self.num_gaussians}, ' 125 | f'cutoff={self.cutoff})') 126 | 127 | 128 | class InteractionBlock(torch.nn.Module): 129 | def __init__(self, hidden_channels, num_gaussians, num_filters, cutoff): 130 | super(InteractionBlock, self).__init__() 131 | self.mlp = Sequential( 132 | Linear(num_gaussians, num_filters), 133 | ShiftedSoftplus(), 134 | Linear(num_filters, num_filters), 135 | ) 136 | self.conv = CFConv(hidden_channels, hidden_channels, 137 | num_filters, self.mlp, cutoff) 138 | self.act = ShiftedSoftplus() 139 | self.lin = Linear(hidden_channels, hidden_channels) 140 | 141 | self.reset_parameters() 142 | 143 | def reset_parameters(self): 144 | torch.nn.init.xavier_uniform_(self.mlp[0].weight) 145 | self.mlp[0].bias.data.fill_(0) 146 | torch.nn.init.xavier_uniform_(self.mlp[2].weight) 147 | self.mlp[0].bias.data.fill_(0) 148 | self.conv.reset_parameters() 149 | torch.nn.init.xavier_uniform_(self.lin.weight) 150 | self.lin.bias.data.fill_(0) 151 | 152 | def forward(self, x, edge_index, edge_weight, edge_attr): 153 | x = self.conv(x, edge_index, edge_weight, edge_attr) 154 | x = self.act(x) 155 | x = self.lin(x) 156 | return x 157 | 158 | 159 | 
class CFConv(MessagePassing):
    """Continuous-filter convolution (SchNet).

    Messages are the neighbor features multiplied elementwise by edge
    filters produced by ``nn`` (the filter-generating network), smoothly
    damped to zero at the distance cutoff by a cosine envelope.
    """

    # NOTE: the parameter ``nn`` is the filter network module, not the
    # ``torch.nn`` package -- it shadows that common alias inside this class.
    def __init__(self, in_channels, out_channels, num_filters, nn, cutoff):
        super(CFConv, self).__init__(aggr='add')
        self.lin1 = Linear(in_channels, num_filters, bias=False)
        self.lin2 = Linear(num_filters, out_channels)
        self.nn = nn
        self.cutoff = cutoff

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialize the projections; lin1 has no bias by design."""
        torch.nn.init.xavier_uniform_(self.lin1.weight)
        torch.nn.init.xavier_uniform_(self.lin2.weight)
        self.lin2.bias.data.fill_(0)

    def forward(self, x, edge_index, edge_weight, edge_attr):
        # Cosine cutoff: 1 at distance 0, smoothly decaying to 0 at cutoff.
        C = 0.5 * (torch.cos(edge_weight * PI / self.cutoff) + 1.0)
        W = self.nn(edge_attr) * C.view(-1, 1)

        x = self.lin1(x)
        x = self.propagate(edge_index, x=x, W=W)
        x = self.lin2(x)
        return x

    def message(self, x_j, W):
        # Elementwise modulation of the neighbor features by the edge filter.
        return x_j * W


class GaussianSmearing(torch.nn.Module):
    """Expand scalar distances into a vector of Gaussian radial basis values.

    Requires num_gaussians >= 2 (the shared width is derived from the
    spacing of the first two offsets).
    """

    def __init__(self, start=0.0, stop=5.0, num_gaussians=50):
        super(GaussianSmearing, self).__init__()
        offset = torch.linspace(start, stop, num_gaussians)
        # All Gaussians share one width, set from the (uniform) grid spacing.
        self.coeff = -0.5 / (offset[1] - offset[0]).item() ** 2
        self.register_buffer('offset', offset)

    def forward(self, dist):
        # [num_edges] -> [num_edges, num_gaussians]
        dist = dist.view(-1, 1) - self.offset.view(1, -1)
        return torch.exp(self.coeff * torch.pow(dist, 2))


class ShiftedSoftplus(torch.nn.Module):
    """softplus(x) - log(2), so the activation passes through the origin."""

    def __init__(self):
        super(ShiftedSoftplus, self).__init__()
        self.shift = torch.log(torch.tensor(2.0)).item()

    def forward(self, x):
        return F.softplus(x) - self.shift
--------------------------------------------------------------------------------
/src_classification/pretrain_AM.py:
--------------------------------------------------------------------------------
import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from config import args
from dataloader import DataLoaderMasking 9 | from models import GNN 10 | from torch_geometric.nn import global_mean_pool 11 | from util import MaskAtom 12 | 13 | from datasets import MoleculeDataset 14 | 15 | 16 | def compute_accuracy(pred, target): 17 | return float(torch.sum(torch.max(pred.detach(), dim=1)[1] == target).cpu().item())/len(pred) 18 | 19 | 20 | def do_AttrMasking(batch, criterion, node_repr, molecule_atom_masking_model): 21 | target = batch.mask_node_label[:, 0] 22 | node_pred = molecule_atom_masking_model(node_repr[batch.masked_atom_indices]) 23 | attributemask_loss = criterion(node_pred.double(), target) 24 | attributemask_acc = compute_accuracy(node_pred, target) 25 | return attributemask_loss, attributemask_acc 26 | 27 | 28 | def train(device, loader, optimizer): 29 | 30 | start = time.time() 31 | molecule_model.train() 32 | molecule_atom_masking_model.train() 33 | attributemask_loss_accum, attributemask_acc_accum = 0, 0 34 | 35 | for step, batch in enumerate(loader): 36 | batch = batch.to(device) 37 | node_repr = molecule_model(batch.masked_x, batch.edge_index, batch.edge_attr) 38 | 39 | attributemask_loss, attributemask_acc = do_AttrMasking( 40 | batch=batch, criterion=criterion, node_repr=node_repr, 41 | molecule_atom_masking_model=molecule_atom_masking_model) 42 | 43 | attributemask_loss_accum += attributemask_loss.detach().cpu().item() 44 | attributemask_acc_accum += attributemask_acc 45 | loss = attributemask_loss 46 | 47 | optimizer.zero_grad() 48 | loss.backward() 49 | optimizer.step() 50 | 51 | print('AM Loss: {:.5f}\tAM Acc: {:.5f}\tTime: {:.5f}'.format( 52 | attributemask_loss_accum / len(loader), 53 | attributemask_acc_accum / len(loader), 54 | time.time() - start)) 55 | return 56 | 57 | 58 | if __name__ == '__main__': 59 | 60 | np.random.seed(0) 61 | torch.manual_seed(0) 62 | device = torch.device('cuda:' + str(args.device)) \ 63 | if torch.cuda.is_available() else torch.device('cpu') 64 | if torch.cuda.is_available(): 65 | 
torch.cuda.manual_seed_all(0) 66 | torch.cuda.set_device(args.device) 67 | 68 | if 'GEOM' in args.dataset: 69 | dataset = MoleculeDataset( 70 | '../datasets/{}/'.format(args.dataset), dataset=args.dataset, 71 | transform=MaskAtom(num_atom_type=119, num_edge_type=5, 72 | mask_rate=args.mask_rate, mask_edge=args.mask_edge)) 73 | loader = DataLoaderMasking(dataset, batch_size=args.batch_size, 74 | shuffle=True, num_workers=args.num_workers) 75 | 76 | # set up model 77 | molecule_model = GNN(args.num_layer, args.emb_dim, 78 | JK=args.JK, drop_ratio=args.dropout_ratio, 79 | gnn_type=args.gnn_type).to(device) 80 | molecule_readout_func = global_mean_pool 81 | 82 | molecule_atom_masking_model = torch.nn.Linear(args.emb_dim, 119).to(device) 83 | 84 | model_param_group = [{'params': molecule_model.parameters(), 'lr': args.lr}, 85 | {'params': molecule_atom_masking_model.parameters(), 'lr': args.lr}] 86 | 87 | optimizer = optim.Adam(model_param_group, lr=args.lr, weight_decay=args.decay) 88 | criterion = nn.CrossEntropyLoss() 89 | 90 | for epoch in range(1, args.epochs + 1): 91 | print('epoch: {}'.format(epoch)) 92 | train(device, loader, optimizer) 93 | 94 | if not args.output_model_dir == '': 95 | torch.save(molecule_model.state_dict(), args.output_model_dir + '_model.pth') 96 | 97 | saver_dict = {'model': molecule_model.state_dict(), 98 | 'molecule_atom_masking_model': molecule_atom_masking_model.state_dict()} 99 | 100 | torch.save(saver_dict, args.output_model_dir + '_model_complete.pth') 101 | -------------------------------------------------------------------------------- /src_classification/pretrain_CP.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | from config import args 8 | from dataloader import DataLoaderSubstructContext 9 | from models import GNN 10 | from torch_geometric.nn import global_mean_pool 11 | 
from util import ExtractSubstructureContextPair, cycle_index 12 | 13 | from datasets import MoleculeDataset 14 | 15 | 16 | def do_ContextPred(batch, criterion, args, molecule_substruct_model, 17 | molecule_context_model, molecule_readout_func): 18 | 19 | # creating substructure representation 20 | substruct_repr = molecule_substruct_model( 21 | batch.x_substruct, batch.edge_index_substruct, 22 | batch.edge_attr_substruct)[batch.center_substruct_idx] 23 | 24 | # creating context representations 25 | overlapped_node_repr = molecule_context_model( 26 | batch.x_context, batch.edge_index_context, 27 | batch.edge_attr_context)[batch.overlap_context_substruct_idx] 28 | 29 | # positive context representation 30 | # readout -> global_mean_pool by default 31 | context_repr = molecule_readout_func(overlapped_node_repr, 32 | batch.batch_overlapped_context) 33 | 34 | # negative contexts are obtained by shifting 35 | # the indices of context embeddings 36 | neg_context_repr = torch.cat( 37 | [context_repr[cycle_index(len(context_repr), i + 1)] 38 | for i in range(args.contextpred_neg_samples)], dim=0) 39 | 40 | num_neg = args.contextpred_neg_samples 41 | pred_pos = torch.sum(substruct_repr * context_repr, dim=1) 42 | pred_neg = torch.sum(substruct_repr.repeat((num_neg, 1)) * neg_context_repr, dim=1) 43 | 44 | loss_pos = criterion(pred_pos.double(), 45 | torch.ones(len(pred_pos)).to(pred_pos.device).double()) 46 | loss_neg = criterion(pred_neg.double(), 47 | torch.zeros(len(pred_neg)).to(pred_neg.device).double()) 48 | 49 | contextpred_loss = loss_pos + num_neg * loss_neg 50 | 51 | num_pred = len(pred_pos) + len(pred_neg) 52 | contextpred_acc = (torch.sum(pred_pos > 0).float() + 53 | torch.sum(pred_neg < 0).float()) / num_pred 54 | contextpred_acc = contextpred_acc.detach().cpu().item() 55 | 56 | return contextpred_loss, contextpred_acc 57 | 58 | 59 | def train(args, device, loader, optimizer): 60 | 61 | start_time = time.time() 62 | molecule_context_model.train() 63 | 
molecule_substruct_model.train() 64 | contextpred_loss_accum, contextpred_acc_accum = 0, 0 65 | 66 | for step, batch in enumerate(loader): 67 | 68 | batch = batch.to(device) 69 | contextpred_loss, contextpred_acc = do_ContextPred( 70 | batch=batch, criterion=criterion, args=args, 71 | molecule_substruct_model=molecule_substruct_model, 72 | molecule_context_model=molecule_context_model, 73 | molecule_readout_func=molecule_readout_func) 74 | 75 | contextpred_loss_accum += contextpred_loss.detach().cpu().item() 76 | contextpred_acc_accum += contextpred_acc 77 | ssl_loss = contextpred_loss 78 | optimizer.zero_grad() 79 | ssl_loss.backward() 80 | optimizer.step() 81 | 82 | print('CP Loss: {:.5f}\tCP Acc: {:.5f}\tTime: {:.3f}'.format( 83 | contextpred_loss_accum / len(loader), 84 | contextpred_acc_accum / len(loader), 85 | time.time() - start_time)) 86 | 87 | return 88 | 89 | 90 | if __name__ == '__main__': 91 | 92 | np.random.seed(0) 93 | torch.manual_seed(0) 94 | device = torch.device('cuda:' + str(args.device)) \ 95 | if torch.cuda.is_available() else torch.device('cpu') 96 | if torch.cuda.is_available(): 97 | torch.cuda.manual_seed_all(0) 98 | torch.cuda.set_device(args.device) 99 | 100 | l1 = args.num_layer - 1 101 | l2 = l1 + args.csize 102 | print('num layer: %d l1: %d l2: %d' % (args.num_layer, l1, l2)) 103 | 104 | if 'GEOM' in args.dataset: 105 | dataset = MoleculeDataset( 106 | '../datasets/{}/'.format(args.dataset), dataset=args.dataset, 107 | transform=ExtractSubstructureContextPair(args.num_layer, l1, l2)) 108 | loader = DataLoaderSubstructContext(dataset, batch_size=args.batch_size, 109 | shuffle=True, num_workers=args.num_workers) 110 | 111 | ''' === set up model, mainly used in do_ContextPred() === ''' 112 | molecule_substruct_model = GNN( 113 | args.num_layer, args.emb_dim, JK=args.JK, 114 | drop_ratio=args.dropout_ratio, gnn_type=args.gnn_type).to(device) 115 | molecule_context_model = GNN( 116 | int(l2 - l1), args.emb_dim, JK=args.JK, 117 | 
drop_ratio=args.dropout_ratio, gnn_type=args.gnn_type).to(device) 118 | 119 | ''' === set up loss and optimiser === ''' 120 | criterion = nn.BCEWithLogitsLoss() 121 | molecule_readout_func = global_mean_pool 122 | model_param_group = [{'params': molecule_substruct_model.parameters(), 'lr': args.lr}, 123 | {'params': molecule_context_model.parameters(), 'lr': args.lr}] 124 | optimizer = optim.Adam(model_param_group, lr=args.lr, weight_decay=args.decay) 125 | 126 | for epoch in range(1, args.epochs + 1): 127 | print('epoch: {}'.format(epoch)) 128 | train(args, device, loader, optimizer) 129 | 130 | if not args.output_model_dir == '': 131 | torch.save(molecule_substruct_model.state_dict(), 132 | args.output_model_dir + '_model.pth') 133 | 134 | saver_dict = { 135 | 'molecule_substruct_model': molecule_substruct_model.state_dict(), 136 | 'molecule_context_model': molecule_context_model.state_dict()} 137 | 138 | torch.save(saver_dict, args.output_model_dir + '_model_complete.pth') 139 | -------------------------------------------------------------------------------- /src_classification/pretrain_Contextual.py: -------------------------------------------------------------------------------- 1 | import time 2 | from os.path import join 3 | 4 | import numpy as np 5 | import pandas as pd 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | from config import args 10 | from models import GNN, GNN_graphpred 11 | from sklearn.metrics import roc_auc_score 12 | from torch_geometric.data import DataLoader 13 | from util import get_num_task 14 | 15 | from datasets import MoleculeContextualDataset 16 | 17 | 18 | def compute_accuracy(pred, target): 19 | return float(torch.sum(torch.max(pred.detach(), dim=1)[1] == target).cpu().item())/len(pred) 20 | 21 | 22 | def do_Contextual(batch, criterion, node_repr, atom_vocab_model): 23 | target = batch.atom_vocab_label 24 | node_pred = atom_vocab_model(node_repr) 25 | loss = criterion(node_pred, target) 26 | acc = 
compute_accuracy(node_pred, target) 27 | return loss, acc 28 | 29 | 30 | def train(device, loader, optimizer): 31 | start = time.time() 32 | molecule_model.train() 33 | atom_vocab_model.train() 34 | 35 | contextual_loss_accum, contextual_acc_accum = 0, 0 36 | for step, batch in enumerate(loader): 37 | batch = batch.to(device) 38 | node_repr = molecule_model(batch.x, batch.edge_index, batch.edge_attr) 39 | 40 | contextual_loss, contextual_acc = do_Contextual(batch, criterion, node_repr, atom_vocab_model) 41 | contextual_loss_accum += contextual_loss.detach().cpu().item() 42 | contextual_acc_accum += contextual_acc 43 | loss = contextual_loss 44 | 45 | optimizer.zero_grad() 46 | loss.backward() 47 | optimizer.step() 48 | 49 | print('Contextual Loss: {:.5f}\tContextual Acc: {:.5f}\tTime: {:.5f}'.format( 50 | contextual_loss_accum / len(loader), 51 | contextual_acc_accum / len(loader), 52 | time.time() - start)) 53 | return 54 | 55 | 56 | if __name__ == '__main__': 57 | torch.manual_seed(args.runseed) 58 | np.random.seed(args.runseed) 59 | device = torch.device('cuda:' + str(args.device)) \ 60 | if torch.cuda.is_available() else torch.device('cpu') 61 | if torch.cuda.is_available(): 62 | torch.cuda.manual_seed_all(args.runseed) 63 | 64 | # Bunch of classification tasks 65 | assert 'GEOM' in args.dataset 66 | dataset_folder = '../datasets/' 67 | dataset = MoleculeContextualDataset(dataset_folder + args.dataset, dataset=args.dataset) 68 | print(dataset) 69 | 70 | atom_vocab = dataset.atom_vocab 71 | atom_vocab_size = len(atom_vocab) 72 | print('atom_vocab\t', len(atom_vocab), atom_vocab_size) 73 | 74 | loader = DataLoader(dataset, batch_size=args.batch_size, 75 | shuffle=True, num_workers=args.num_workers) 76 | 77 | # set up model 78 | molecule_model = GNN(num_layer=args.num_layer, emb_dim=args.emb_dim, 79 | JK=args.JK, drop_ratio=args.dropout_ratio, 80 | gnn_type=args.gnn_type).to(device) 81 | atom_vocab_model = nn.Linear(args.emb_dim, atom_vocab_size).to(device) 82 | 
83 | # set up optimizer 84 | # different learning rates for different parts of GNN 85 | model_param_group = [{'params': molecule_model.parameters()}, 86 | {'params': atom_vocab_model.parameters(), 'lr': args.lr * args.lr_scale}] 87 | optimizer = optim.Adam(model_param_group, lr=args.lr, 88 | weight_decay=args.decay) 89 | criterion = nn.CrossEntropyLoss() 90 | train_roc_list, val_roc_list, test_roc_list = [], [], [] 91 | best_val_roc, best_val_idx = -1, 0 92 | 93 | print('\nStart pre-training Contextual') 94 | for epoch in range(1, args.epochs + 1): 95 | print('epoch: {}'.format(epoch)) 96 | train(device, loader, optimizer) 97 | 98 | if args.output_model_dir is not '': 99 | print('saving to {}'.format(args.output_model_dir + '_model.pth')) 100 | torch.save(molecule_model.state_dict(), args.output_model_dir + '_model.pth') 101 | saved_model_dict = { 102 | 'molecule_model': molecule_model.state_dict(), 103 | 'atom_vocab_model': atom_vocab_model.state_dict(), 104 | } 105 | torch.save(saved_model_dict, args.output_model_dir + '_model_complete.pth') 106 | -------------------------------------------------------------------------------- /src_classification/pretrain_EP.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | from config import args 8 | from dataloader import DataLoaderAE 9 | from models import GNN 10 | from torch_geometric.nn import global_mean_pool 11 | from util import NegativeEdge 12 | 13 | from datasets import MoleculeDataset 14 | 15 | 16 | def do_EdgePred(node_repr, batch, criterion=nn.BCEWithLogitsLoss()): 17 | 18 | # positive/negative scores -> inner product of node features 19 | positive_score = torch.sum(node_repr[batch.edge_index[0, ::2]] * 20 | node_repr[batch.edge_index[1, ::2]], dim=1) 21 | negative_score = torch.sum(node_repr[batch.negative_edge_index[0]] * 22 | 
node_repr[batch.negative_edge_index[1]], dim=1) 23 | 24 | edgepred_loss = criterion(positive_score, torch.ones_like(positive_score)) + \ 25 | criterion(negative_score, torch.zeros_like(negative_score)) 26 | edgepred_acc = (torch.sum(positive_score > 0) + 27 | torch.sum(negative_score < 0)).to(torch.float32) / \ 28 | float(2 * len(positive_score)) 29 | edgepred_acc = edgepred_acc.detach().cpu().item() 30 | 31 | return edgepred_loss, edgepred_acc 32 | 33 | 34 | def train(molecule_model, device, loader, optimizer, 35 | criterion=nn.BCEWithLogitsLoss()): 36 | 37 | # Train for one epoch 38 | molecule_model.train() 39 | start_time = time.time() 40 | edgepred_loss_accum, edgepred_acc_accum = 0, 0 41 | 42 | for step, batch in enumerate(loader): 43 | 44 | batch = batch.to(device) 45 | 46 | node_repr = molecule_model(batch.x, batch.edge_index, batch.edge_attr) 47 | edgepred_loss, edgepred_acc = do_EdgePred( 48 | node_repr=node_repr, batch=batch, criterion=criterion) 49 | edgepred_loss_accum += edgepred_loss.detach().cpu().item() 50 | edgepred_acc_accum += edgepred_acc 51 | ssl_loss = edgepred_loss 52 | 53 | optimizer.zero_grad() 54 | ssl_loss.backward() 55 | optimizer.step() 56 | 57 | print('EP Loss: {:.5f}\tEP Acc: {:.5f}\tTime: {:.5f}'.format( 58 | edgepred_loss_accum / len(loader), 59 | edgepred_acc_accum / len(loader), 60 | time.time() - start_time)) 61 | return 62 | 63 | 64 | if __name__ == '__main__': 65 | torch.manual_seed(0) 66 | np.random.seed(0) 67 | device = torch.device('cuda:' + str(args.device)) if torch.cuda.is_available() else torch.device('cpu') 68 | if torch.cuda.is_available(): 69 | torch.cuda.manual_seed_all(0) 70 | torch.cuda.set_device(args.device) 71 | 72 | if 'GEOM' in args.dataset: 73 | dataset = MoleculeDataset('../datasets/{}/'.format(args.dataset), dataset=args.dataset, transform=NegativeEdge()) 74 | loader = DataLoaderAE(dataset, batch_size=args.batch_size, 75 | shuffle=True, num_workers=args.num_workers) 76 | 77 | # set up model 78 | 
molecule_model = GNN(args.num_layer, args.emb_dim, 79 | JK=args.JK, drop_ratio=args.dropout_ratio, 80 | gnn_type=args.gnn_type).to(device) 81 | molecule_readout_func = global_mean_pool 82 | 83 | model_param_group = [{'params': molecule_model.parameters(), 'lr': args.lr}] 84 | optimizer = optim.Adam(model_param_group, lr=args.lr, weight_decay=args.decay) 85 | criterion = nn.BCEWithLogitsLoss() 86 | 87 | for epoch in range(1, args.epochs + 1): 88 | print('epoch: {}'.format(epoch)) 89 | train(molecule_model, device, loader, optimizer) 90 | 91 | if not args.output_model_dir == '': 92 | torch.save(molecule_model.state_dict(), args.output_model_dir + '_model.pth') 93 | saver_dict = {'model': molecule_model.state_dict()} 94 | torch.save(saver_dict, args.output_model_dir + '_model_complete.pth') 95 | -------------------------------------------------------------------------------- /src_classification/pretrain_GPT_GNN.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | from config import args 8 | from models import GNN 9 | from torch_geometric.data import DataLoader 10 | from torch_geometric.nn import global_mean_pool 11 | 12 | from datasets import MoleculeDataset, MoleculeDatasetGPT 13 | 14 | 15 | def compute_accuracy(pred, target): 16 | return float(torch.sum(torch.max(pred.detach(), dim=1)[1] == target).cpu().item())/len(pred) 17 | 18 | 19 | def train(device, loader, optimizer): 20 | start = time.time() 21 | molecule_model.train() 22 | node_pred_model.train() 23 | gpt_loss_accum, gpt_acc_accum = 0, 0 24 | 25 | for step, batch in enumerate(loader): 26 | batch = batch.to(device) 27 | node_repr = molecule_model(batch.x, batch.edge_index, batch.edge_attr) 28 | graph_repr = molecule_readout_func(node_repr, batch.batch) 29 | node_pred = node_pred_model(graph_repr) 30 | target = batch.next_x 31 | 32 | gpt_loss = 
criterion(node_pred.double(), target) 33 | gpt_acc = compute_accuracy(node_pred, target) 34 | 35 | gpt_loss_accum += gpt_loss.detach().cpu().item() 36 | gpt_acc_accum += gpt_acc 37 | loss = gpt_loss 38 | 39 | optimizer.zero_grad() 40 | loss.backward() 41 | optimizer.step() 42 | 43 | print('GPT Loss: {:.5f}\tGPT Acc: {:.5f}\tTime: {:.5f}'.format( 44 | gpt_loss_accum / len(loader), gpt_acc_accum / len(loader), time.time() - start)) 45 | return 46 | 47 | 48 | if __name__ == '__main__': 49 | torch.manual_seed(0) 50 | np.random.seed(0) 51 | device = torch.device('cuda:' + str(args.device)) \ 52 | if torch.cuda.is_available() else torch.device('cpu') 53 | if torch.cuda.is_available(): 54 | torch.cuda.manual_seed_all(0) 55 | torch.cuda.set_device(args.device) 56 | 57 | if 'GEOM' in args.dataset: 58 | molecule_dataset = MoleculeDataset('../datasets/{}/'.format(args.dataset), dataset=args.dataset) 59 | molecule_gpt_dataset = MoleculeDatasetGPT(molecule_dataset) 60 | loader = DataLoader(molecule_gpt_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) 61 | 62 | # set up model 63 | molecule_model = GNN(num_layer=args.num_layer, emb_dim=args.emb_dim, 64 | JK=args.JK, drop_ratio=args.dropout_ratio, 65 | gnn_type=args.gnn_type).to(device) 66 | node_pred_model = nn.Linear(args.emb_dim, 120).to(device) 67 | 68 | model_param_group = [ 69 | {'params': molecule_model.parameters(), 'lr': args.lr}, 70 | {'params': node_pred_model.parameters(), 'lr': args.lr}, 71 | ] 72 | 73 | optimizer = optim.Adam(model_param_group, lr=args.lr, weight_decay=args.decay) 74 | molecule_readout_func = global_mean_pool 75 | criterion = nn.CrossEntropyLoss() 76 | 77 | for epoch in range(1, args.epochs + 1): 78 | print('epoch: {}'.format(epoch)) 79 | train(device, loader, optimizer) 80 | 81 | if not args.output_model_dir == '': 82 | torch.save(molecule_model.state_dict(), args.output_model_dir + '_model.pth') 83 | saver_dict = { 84 | 'model': molecule_model.state_dict(), 85 | 
'node_pred_model': node_pred_model.state_dict(), 86 | } 87 | torch.save(saver_dict, args.output_model_dir + '_model_complete.pth') 88 | -------------------------------------------------------------------------------- /src_classification/pretrain_GraphCL.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | 4 | import numpy as np 5 | import torch 6 | import torch.optim as optim 7 | from models import GNN 8 | from pretrain_JOAO import graphcl 9 | from torch_geometric.data import DataLoader 10 | 11 | from datasets import MoleculeDataset_graphcl 12 | 13 | 14 | def train(loader, model, optimizer, device): 15 | 16 | model.train() 17 | train_loss_accum = 0 18 | 19 | for step, (_, batch1, batch2) in enumerate(loader): 20 | # _, batch1, batch2 = batch 21 | batch1 = batch1.to(device) 22 | batch2 = batch2.to(device) 23 | 24 | x1 = model.forward_cl(batch1.x, batch1.edge_index, 25 | batch1.edge_attr, batch1.batch) 26 | x2 = model.forward_cl(batch2.x, batch2.edge_index, 27 | batch2.edge_attr, batch2.batch) 28 | loss = model.loss_cl(x1, x2) 29 | 30 | optimizer.zero_grad() 31 | loss.backward() 32 | optimizer.step() 33 | 34 | train_loss_accum += float(loss.detach().cpu().item()) 35 | 36 | return train_loss_accum / (step + 1) 37 | 38 | 39 | if __name__ == "__main__": 40 | # Training settings 41 | parser = argparse.ArgumentParser(description='GraphCL') 42 | parser.add_argument('--device', type=int, default=0, help='gpu') 43 | parser.add_argument('--batch_size', type=int, default=256, help='batch') 44 | parser.add_argument('--decay', type=float, default=0, help='weight decay') 45 | parser.add_argument('--epochs', type=int, default=100, help='train epochs') 46 | parser.add_argument('--lr', type=float, default=0.001, help='learning rate') 47 | parser.add_argument('--JK', type=str, default="last", 48 | choices=['last', 'sum', 'max', 'concat'], 49 | help='how the node features across layers are combined.') 50 | 
parser.add_argument('--gnn_type', type=str, default="gin", help='gnn model type') 51 | parser.add_argument('--dropout_ratio', type=float, default=0, help='dropout ratio') 52 | parser.add_argument('--emb_dim', type=int, default=300, help='embedding dimensions') 53 | parser.add_argument('--dataset', type=str, default=None, help='root dir of dataset') 54 | parser.add_argument('--num_layer', type=int, default=5, help='message passing layers') 55 | # parser.add_argument('--seed', type=int, default=0, help="Seed for splitting dataset") 56 | parser.add_argument('--output_model_file', type=str, default='', help='model save path') 57 | parser.add_argument('--num_workers', type=int, default=8, help='workers for dataset loading') 58 | 59 | parser.add_argument('--aug_mode', type=str, default='sample') 60 | parser.add_argument('--aug_strength', type=float, default=0.2) 61 | 62 | # parser.add_argument('--gamma', type=float, default=0.1) 63 | parser.add_argument('--output_model_dir', type=str, default='') 64 | args = parser.parse_args() 65 | 66 | torch.manual_seed(0) 67 | np.random.seed(0) 68 | device = torch.device("cuda:" + str(args.device)) \ 69 | if torch.cuda.is_available() else torch.device("cpu") 70 | if torch.cuda.is_available(): 71 | torch.cuda.manual_seed_all(0) 72 | 73 | # set up dataset 74 | if 'GEOM' in args.dataset: 75 | dataset = MoleculeDataset_graphcl('../datasets/{}/'.format(args.dataset), 76 | dataset=args.dataset) 77 | dataset.set_augMode(args.aug_mode) 78 | dataset.set_augStrength(args.aug_strength) 79 | print(dataset) 80 | 81 | loader = DataLoader(dataset, batch_size=args.batch_size, 82 | num_workers=args.num_workers, shuffle=True) 83 | 84 | # set up model 85 | gnn = GNN(num_layer=args.num_layer, emb_dim=args.emb_dim, JK=args.JK, 86 | drop_ratio=args.dropout_ratio, gnn_type=args.gnn_type) 87 | 88 | model = graphcl(gnn) 89 | model.to(device) 90 | 91 | # set up optimizer 92 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.decay) 93 | 
print(optimizer) 94 | 95 | aug_prob = np.ones(25) / 25 96 | dataset.set_augProb(aug_prob) 97 | for epoch in range(1, args.epochs + 1): 98 | start_time = time.time() 99 | pretrain_loss = train(loader, model, optimizer, device) 100 | 101 | print('Epoch: {:3d}\tLoss:{:.3f}\tTime: {:.3f}:'.format( 102 | epoch, pretrain_loss, time.time() - start_time)) 103 | 104 | if not args.output_model_dir == '': 105 | saver_dict = {'model': model.state_dict()} 106 | torch.save(saver_dict, args.output_model_dir + '_model_complete.pth') 107 | torch.save(model.gnn.state_dict(), args.output_model_dir + '_model.pth') 108 | -------------------------------------------------------------------------------- /src_classification/pretrain_GraphMVP.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.optim as optim 8 | from config import args 9 | from models import GNN, AutoEncoder, SchNet, VariationalAutoEncoder 10 | from torch_geometric.data import DataLoader 11 | from torch_geometric.nn import global_mean_pool 12 | from tqdm import tqdm 13 | from util import dual_CL 14 | 15 | from datasets import Molecule3DMaskingDataset 16 | 17 | 18 | def save_model(save_best): 19 | if not args.output_model_dir == '': 20 | if save_best: 21 | global optimal_loss 22 | print('save model with loss: {:.5f}'.format(optimal_loss)) 23 | torch.save(molecule_model_2D.state_dict(), args.output_model_dir + '_model.pth') 24 | saver_dict = { 25 | 'model': molecule_model_2D.state_dict(), 26 | 'model_3D': molecule_model_3D.state_dict(), 27 | 'AE_2D_3D_model': AE_2D_3D_model.state_dict(), 28 | 'AE_3D_2D_model': AE_3D_2D_model.state_dict(), 29 | } 30 | torch.save(saver_dict, args.output_model_dir + '_model_complete.pth') 31 | 32 | else: 33 | torch.save(molecule_model_2D.state_dict(), args.output_model_dir + '_model_final.pth') 34 | saver_dict = { 35 | 
'model': molecule_model_2D.state_dict(), 36 | 'model_3D': molecule_model_3D.state_dict(), 37 | 'AE_2D_3D_model': AE_2D_3D_model.state_dict(), 38 | 'AE_3D_2D_model': AE_3D_2D_model.state_dict(), 39 | } 40 | torch.save(saver_dict, args.output_model_dir + '_model_complete_final.pth') 41 | return 42 | 43 | 44 | def train(args, molecule_model_2D, device, loader, optimizer): 45 | start_time = time.time() 46 | 47 | molecule_model_2D.train() 48 | molecule_model_3D.train() 49 | if molecule_projection_layer is not None: 50 | molecule_projection_layer.train() 51 | 52 | AE_loss_accum, AE_acc_accum = 0, 0 53 | CL_loss_accum, CL_acc_accum = 0, 0 54 | 55 | if args.verbose: 56 | l = tqdm(loader) 57 | else: 58 | l = loader 59 | for step, batch in enumerate(l): 60 | batch = batch.to(device) 61 | 62 | node_repr = molecule_model_2D(batch.x, batch.edge_index, batch.edge_attr) 63 | molecule_2D_repr = molecule_readout_func(node_repr, batch.batch) 64 | 65 | if args.model_3d == 'schnet': 66 | molecule_3D_repr = molecule_model_3D(batch.x[:, 0], batch.positions, batch.batch) 67 | 68 | 69 | CL_loss, CL_acc = dual_CL(molecule_2D_repr, molecule_3D_repr, args) 70 | AE_loss_1 = AE_2D_3D_model(molecule_2D_repr, molecule_3D_repr) 71 | AE_loss_2 = AE_3D_2D_model(molecule_3D_repr, molecule_2D_repr) 72 | AE_acc_1 = AE_acc_2 = 0 73 | AE_loss = (AE_loss_1 + AE_loss_2) / 2 74 | 75 | CL_loss_accum += CL_loss.detach().cpu().item() 76 | CL_acc_accum += CL_acc 77 | AE_loss_accum += AE_loss.detach().cpu().item() 78 | AE_acc_accum += (AE_acc_1 + AE_acc_2) / 2 79 | 80 | loss = 0 81 | if args.alpha_1 > 0: 82 | loss += CL_loss * args.alpha_1 83 | if args.alpha_2 > 0: 84 | loss += AE_loss * args.alpha_2 85 | 86 | optimizer.zero_grad() 87 | loss.backward() 88 | optimizer.step() 89 | 90 | global optimal_loss 91 | CL_loss_accum /= len(loader) 92 | CL_acc_accum /= len(loader) 93 | AE_loss_accum /= len(loader) 94 | AE_acc_accum /= len(loader) 95 | temp_loss = args.alpha_1 * CL_loss_accum + args.alpha_2 * AE_loss_accum 
96 | if temp_loss < optimal_loss: 97 | optimal_loss = temp_loss 98 | save_model(save_best=True) 99 | print('CL Loss: {:.5f}\tCL Acc: {:.5f}\t\tAE Loss: {:.5f}\tAE Acc: {:.5f}\tTime: {:.5f}'.format( 100 | CL_loss_accum, CL_acc_accum, AE_loss_accum, AE_acc_accum, time.time() - start_time)) 101 | return 102 | 103 | 104 | if __name__ == '__main__': 105 | torch.manual_seed(0) 106 | np.random.seed(0) 107 | device = torch.device('cuda:' + str(args.device)) if torch.cuda.is_available() else torch.device('cpu') 108 | if torch.cuda.is_available(): 109 | torch.cuda.manual_seed_all(0) 110 | torch.cuda.set_device(args.device) 111 | 112 | if 'GEOM' in args.dataset: 113 | data_root = '../datasets/{}/'.format(args.dataset) if args.input_data_dir == '' else '{}/{}/'.format(args.input_data_dir, args.dataset) 114 | dataset = Molecule3DMaskingDataset(data_root, dataset=args.dataset, mask_ratio=args.SSL_masking_ratio) 115 | else: 116 | raise Exception 117 | loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) 118 | 119 | # set up model 120 | molecule_model_2D = GNN(args.num_layer, args.emb_dim, JK=args.JK, drop_ratio=args.dropout_ratio, gnn_type=args.gnn_type).to(device) 121 | molecule_readout_func = global_mean_pool 122 | 123 | print('Using 3d model\t', args.model_3d) 124 | molecule_projection_layer = None 125 | if args.model_3d == 'schnet': 126 | molecule_model_3D = SchNet( 127 | hidden_channels=args.emb_dim, num_filters=args.num_filters, num_interactions=args.num_interactions, 128 | num_gaussians=args.num_gaussians, cutoff=args.cutoff, atomref=None, readout=args.readout).to(device) 129 | else: 130 | raise NotImplementedError('Model {} not included.'.format(args.model_3d)) 131 | 132 | if args.AE_model == 'AE': 133 | AE_2D_3D_model = AutoEncoder( 134 | emb_dim=args.emb_dim, loss=args.AE_loss, detach_target=args.detach_target).to(device) 135 | AE_3D_2D_model = AutoEncoder( 136 | emb_dim=args.emb_dim, loss=args.AE_loss, 
def do_InfoGraph(node_repr, molecule_repr, batch,
                 criterion, infograph_discriminator_SSL_model):
    """Compute the InfoGraph SSL loss and accuracy.

    Positive pairs: each node against its own graph's summary vector.
    Negative pairs: each node against the summary of the *next* graph in the
    batch (summaries rotated by one via cycle_index).
    Returns (loss, accuracy) where accuracy is a detached python float.
    """
    # Graph-level summaries squashed to (0, 1), then broadcast back to nodes.
    summary_repr = torch.sigmoid(molecule_repr)
    positive_expanded_summary_repr = summary_repr[batch.batch]

    # Rotate summaries by one graph to build mismatched (negative) pairs.
    shifted_summary_repr = summary_repr[cycle_index(len(summary_repr), 1)]
    negative_expanded_summary_repr = shifted_summary_repr[batch.batch]

    positive_score = infograph_discriminator_SSL_model(
        node_repr, positive_expanded_summary_repr)
    negative_score = infograph_discriminator_SSL_model(
        node_repr, negative_expanded_summary_repr)

    infograph_loss = criterion(positive_score, torch.ones_like(positive_score)) + \
        criterion(negative_score, torch.zeros_like(negative_score))

    # A prediction is "correct" when positives score > 0 and negatives < 0
    # (logit threshold at 0.5 probability).
    num_sample = float(2 * len(positive_score))
    correct = (positive_score > 0).sum() + (negative_score < 0).sum()
    infograph_acc = (correct.to(torch.float32) / num_sample).detach().cpu().item()

    return infograph_loss, infograph_acc
if __name__ == '__main__':
    # Reproducibility: fix all RNG seeds before touching data or weights.
    torch.manual_seed(0)
    np.random.seed(0)
    device = torch.device('cuda:' + str(args.device)) \
        if torch.cuda.is_available() else torch.device('cpu')
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(0)
        torch.cuda.set_device(args.device)

    if 'GEOM' in args.dataset:
        dataset = MoleculeDataset('../datasets/{}/'.format(args.dataset), dataset=args.dataset)
    else:
        # Fail fast on unsupported datasets, consistent with the other
        # pre-training scripts in this directory. Previously this branch fell
        # through silently and the script crashed later with a NameError on
        # `dataset`.
        raise Exception
    loader = DataLoader(dataset, batch_size=args.batch_size,
                        shuffle=True, num_workers=args.num_workers)

    # Set up the 2D GNN encoder and the InfoGraph discriminator head.
    molecule_model = GNN(num_layer=args.num_layer, emb_dim=args.emb_dim,
                         JK=args.JK, drop_ratio=args.dropout_ratio,
                         gnn_type=args.gnn_type).to(device)
    infograph_discriminator_SSL_model = Discriminator(args.emb_dim).to(device)

    model_param_group = [{'params': molecule_model.parameters(), 'lr': args.lr},
                         {'params': infograph_discriminator_SSL_model.parameters(),
                          'lr': args.lr}]

    optimizer = optim.Adam(model_param_group, lr=args.lr, weight_decay=args.decay)
    molecule_readout_func = global_mean_pool
    criterion = nn.BCEWithLogitsLoss()

    for epoch in range(1, args.epochs + 1):
        print('epoch: {}'.format(epoch))
        train(molecule_model, device, loader, optimizer)

    if not args.output_model_dir == '':
        torch.save(molecule_model.state_dict(), args.output_model_dir + '_model.pth')
        # The complete checkpoint is saved inside the same guard so that an
        # empty --output_model_dir never writes a stray '_model_complete.pth'
        # into the working directory.
        saver_dict = {'model': molecule_model.state_dict(),
                      'infograph_discriminator_SSL_model': infograph_discriminator_SSL_model.state_dict()}
        torch.save(saver_dict, args.output_model_dir + '_model_complete.pth')
class graphcl(nn.Module):
    """GraphCL wrapper: GNN encoder + MLP projection head + NT-Xent loss."""

    def __init__(self, gnn, emb_dim=300):
        """
        Args:
            gnn: node-level encoder, called as gnn(x, edge_index, edge_attr).
            emb_dim: embedding width of the projection head. Generalizes the
                previously hard-coded 300 so the head can follow --emb_dim;
                the default keeps full backward compatibility.
        """
        super(graphcl, self).__init__()
        self.gnn = gnn
        self.pool = global_mean_pool
        self.projection_head = nn.Sequential(
            nn.Linear(emb_dim, emb_dim),
            nn.ReLU(inplace=True),
            nn.Linear(emb_dim, emb_dim))

    def forward_cl(self, x, edge_index, edge_attr, batch):
        """Encode one (possibly augmented) batch into projected graph embeddings."""
        x = self.gnn(x, edge_index, edge_attr)
        x = self.pool(x, batch)
        x = self.projection_head(x)
        return x

    def loss_cl(self, x1, x2):
        """NT-Xent contrastive loss between two views (temperature T = 0.1).

        x1, x2: (batch, dim) projected embeddings; row i of x1 and row i of x2
        are the positive pair, all other rows are negatives.
        """
        T = 0.1
        batch, _ = x1.size()
        x1_abs = x1.norm(dim=1)
        x2_abs = x2.norm(dim=1)

        # Cosine-similarity matrix, temperature-scaled then exponentiated.
        sim_matrix = torch.einsum('ik,jk->ij', x1, x2) / \
            torch.einsum('i,j->ij', x1_abs, x2_abs)
        sim_matrix = torch.exp(sim_matrix / T)
        pos_sim = sim_matrix[range(batch), range(batch)]
        # Positive similarity over the sum of negatives for each anchor.
        loss = pos_sim / (sim_matrix.sum(dim=1) - pos_sim)
        loss = - torch.log(loss).mean()
        return loss
_aug_prob = np.zeros(25) 76 | _aug_prob[n] = 1 77 | loader.dataset.set_augProb(_aug_prob) 78 | # for efficiency, we only use around 10% of data to estimate the loss 79 | count, count_stop = 0, len(loader.dataset) // (loader.batch_size * 10) + 1 80 | 81 | with torch.no_grad(): 82 | for step, (_, batch1, batch2) in enumerate(loader): 83 | # _, batch1, batch2 = batch 84 | batch1 = batch1.to(device) 85 | batch2 = batch2.to(device) 86 | 87 | x1 = model.forward_cl(batch1.x, batch1.edge_index, 88 | batch1.edge_attr, batch1.batch) 89 | x2 = model.forward_cl(batch2.x, batch2.edge_index, 90 | batch2.edge_attr, batch2.batch) 91 | loss = model.loss_cl(x1, x2) 92 | loss_aug[n] += loss.item() 93 | count += 1 94 | if count == count_stop: 95 | break 96 | loss_aug[n] /= count 97 | 98 | # view selection, projected gradient descent, 99 | # reference: https://arxiv.org/abs/1906.03563 100 | beta = 1 101 | gamma = gamma_joao 102 | 103 | b = aug_prob + beta * (loss_aug - gamma * (aug_prob - 1 / 25)) 104 | mu_min, mu_max = b.min() - 1 / 25, b.max() - 1 / 25 105 | mu = (mu_min + mu_max) / 2 106 | 107 | # bisection method 108 | while abs(np.maximum(b - mu, 0).sum() - 1) > 1e-2: 109 | if np.maximum(b - mu, 0).sum() > 1: 110 | mu_min = mu 111 | else: 112 | mu_max = mu 113 | mu = (mu_min + mu_max) / 2 114 | 115 | aug_prob = np.maximum(b - mu, 0) 116 | aug_prob /= aug_prob.sum() 117 | 118 | return train_loss_accum / (step + 1), aug_prob 119 | 120 | 121 | if __name__ == "__main__": 122 | # Training settings 123 | parser = argparse.ArgumentParser(description='JOAO') 124 | parser.add_argument('--device', type=int, default=0, help='gpu') 125 | parser.add_argument('--batch_size', type=int, default=256, help='batch') 126 | parser.add_argument('--decay', type=float, default=0, help='weight decay') 127 | parser.add_argument('--epochs', type=int, default=100, help='train epochs') 128 | parser.add_argument('--lr', type=float, default=0.001, help='learning rate') 129 | parser.add_argument('--JK', 
type=str, default="last", 130 | choices=['last', 'sum', 'max', 'concat'], 131 | help='how the node features across layers are combined.') 132 | parser.add_argument('--gnn_type', type=str, default="gin", help='gnn model type') 133 | parser.add_argument('--dropout_ratio', type=float, default=0, help='dropout ratio') 134 | parser.add_argument('--emb_dim', type=int, default=300, help='embedding dimensions') 135 | parser.add_argument('--dataset', type=str, default=None, help='root dir of dataset') 136 | parser.add_argument('--num_layer', type=int, default=5, help='message passing layers') 137 | # parser.add_argument('--seed', type=int, default=0, help="Seed for splitting dataset") 138 | parser.add_argument('--output_model_file', type=str, default='', help='model save path') 139 | parser.add_argument('--num_workers', type=int, default=8, help='workers for dataset loading') 140 | 141 | parser.add_argument('--aug_mode', type=str, default='sample') 142 | parser.add_argument('--aug_strength', type=float, default=0.2) 143 | 144 | parser.add_argument('--gamma', type=float, default=0.1) 145 | parser.add_argument('--output_model_dir', type=str, default='') 146 | args = parser.parse_args() 147 | 148 | torch.manual_seed(0) 149 | np.random.seed(0) 150 | device = torch.device("cuda:" + str(args.device)) \ 151 | if torch.cuda.is_available() else torch.device("cpu") 152 | if torch.cuda.is_available(): 153 | torch.cuda.manual_seed_all(0) 154 | 155 | if 'GEOM' in args.dataset: 156 | dataset = MoleculeDataset_graphcl('../datasets/{}/'.format(args.dataset), dataset=args.dataset) 157 | dataset.set_augMode(args.aug_mode) 158 | dataset.set_augStrength(args.aug_strength) 159 | print(dataset) 160 | 161 | loader = DataLoader(dataset, batch_size=args.batch_size, 162 | num_workers=args.num_workers, shuffle=True) 163 | 164 | # set up model 165 | gnn = GNN(num_layer=args.num_layer, emb_dim=args.emb_dim, JK=args.JK, 166 | drop_ratio=args.dropout_ratio, gnn_type=args.gnn_type) 167 | 168 | model = 
def train(model, device, loader, optimizer):
    """Run one supervised epoch over the motif labels; return the mean loss.

    Relies on the module-level `criterion` (BCEWithLogitsLoss) that __main__
    sets up before calling this function.
    """
    model.train()
    running_loss = 0

    for batch in loader:
        batch = batch.to(device)
        logits = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
        target = batch.y.view(logits.shape).double()

        loss = criterion(logits.double(), target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.detach().item()

    return running_loss / len(loader)
def eval(model, device, loader):
    """Evaluate `model` on `loader`.

    Returns (mean ROC-AUC over valid tasks, y_true, y_scores).
    NOTE: the name shadows the builtin eval(); kept for caller compatibility.
    """
    model.eval()
    y_true, y_scores = [], []

    for step, batch in enumerate(loader):
        batch = batch.to(device)
        with torch.no_grad():
            pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
        true = batch.y.view(pred.shape)
        y_true.append(true)
        y_scores.append(pred)

    y_true = torch.cat(y_true, dim=0).cpu().numpy()
    y_scores = torch.cat(y_scores, dim=0).cpu().numpy()

    roc_list = []
    for i in range(y_true.shape[1]):
        # ROC-AUC is only defined when BOTH classes are present. The original
        # checked only for positives and crashed inside roc_auc_score whenever
        # a task had positives but no negatives.
        if np.sum(y_true[:, i] == 1) > 0 and np.sum(y_true[:, i] == 0) > 0:
            roc_list.append(roc_auc_score(y_true[:, i], y_scores[:, i]))
    if len(roc_list) == 0:
        # Previously an opaque ZeroDivisionError.
        raise RuntimeError('No task contains both classes; ROC-AUC is undefined.')
    return sum(roc_list) / len(roc_list), y_true, y_scores


if __name__ == '__main__':
    torch.manual_seed(args.runseed)
    np.random.seed(args.runseed)
    device = torch.device('cuda:' + str(args.device)) \
        if torch.cuda.is_available() else torch.device('cpu')
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.runseed)

    # One binary classification task per RDKit functional-group descriptor.
    num_tasks = len(RDKIT_PROPS)
    assert 'GEOM' in args.dataset
    dataset_folder = '../datasets/'
    dataset = MoleculeMotifDataset(dataset_folder + args.dataset, dataset=args.dataset)
    print(dataset)

    loader = DataLoader(dataset, batch_size=args.batch_size,
                        shuffle=True, num_workers=args.num_workers)

    # set up model
    molecule_model = GNN(num_layer=args.num_layer, emb_dim=args.emb_dim,
                         JK=args.JK, drop_ratio=args.dropout_ratio,
                         gnn_type=args.gnn_type)
    model = GNN_graphpred(args=args, num_tasks=num_tasks, molecule_model=molecule_model)
    if not args.input_model_file == '':
        model.from_pretrained(args.input_model_file)
    model.to(device)

    # set up optimizer:
    # different learning rates for the encoder and the prediction head
    model_param_group = [{'params': model.molecule_model.parameters()},
                         {'params': model.graph_pred_linear.parameters(),
                          'lr': args.lr * args.lr_scale}]
    optimizer = optim.Adam(model_param_group, lr=args.lr,
                           weight_decay=args.decay)
    criterion = nn.BCEWithLogitsLoss()
    train_roc_list, val_roc_list, test_roc_list = [], [], []
    best_val_roc, best_val_idx = -1, 0

    print('\nStart pre-training Motif')
    for epoch in range(1, args.epochs + 1):
        loss_acc = train(model, device, loader, optimizer)
        print('Epoch: {}\nLoss: {}'.format(epoch, loss_acc))

        if args.eval_train:
            train_roc, train_target, train_pred = eval(model, device, loader)
        else:
            train_roc = 0

        train_roc_list.append(train_roc)
        print('train: {:.6f}\n'.format(train_roc))

    # Fixed: was `args.output_model_dir is not ''` — an identity comparison
    # with a literal (SyntaxWarning; only "worked" via CPython string
    # interning). Use equality, as every sibling pre-training script does.
    if args.output_model_dir != '':
        print('saving to {}'.format(args.output_model_dir + '_model.pth'))
        torch.save(molecule_model.state_dict(), args.output_model_dir + '_model.pth')
        saved_model_dict = {
            'molecule_model': molecule_model.state_dict(),
            'model': model.state_dict(),
        }
        torch.save(saved_model_dict, args.output_model_dir + '_model_complete.pth')
/src_classification/run_pretrain_CP.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | source $HOME/.bashrc 4 | conda activate GraphMVP 5 | 6 | echo $@ 7 | date 8 | 9 | echo "start" 10 | python pretrain_CP.py $@ 11 | echo "end" 12 | date 13 | -------------------------------------------------------------------------------- /src_classification/run_pretrain_Contextual.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | source $HOME/.bashrc 4 | conda activate GraphMVP 5 | 6 | echo $@ 7 | date 8 | 9 | echo "start" 10 | python pretrain_Contextual.py $@ 11 | echo "end" 12 | date -------------------------------------------------------------------------------- /src_classification/run_pretrain_EP.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | source $HOME/.bashrc 4 | conda activate GraphMVP 5 | 6 | echo $@ 7 | date 8 | 9 | echo "start" 10 | python pretrain_EP.py $@ 11 | echo "end" 12 | date 13 | -------------------------------------------------------------------------------- /src_classification/run_pretrain_GPT_GNN.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | source $HOME/.bashrc 4 | conda activate GraphMVP 5 | 6 | echo $@ 7 | date 8 | 9 | echo "start" 10 | python pretrain_GPT_GNN.py $@ 11 | echo "end" 12 | date -------------------------------------------------------------------------------- /src_classification/run_pretrain_GraphCL.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | source $HOME/.bashrc 4 | conda activate GraphMVP 5 | 6 | echo $@ 7 | date 8 | 9 | echo "start" 10 | python pretrain_GraphCL.py $@ 11 | echo "end" 12 | date 13 | -------------------------------------------------------------------------------- 
/src_classification/run_pretrain_GraphLoG.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | source $HOME/.bashrc 4 | conda activate GraphMVP 5 | 6 | echo $@ 7 | date 8 | 9 | echo "start" 10 | python pretrain_GraphLoG.py $@ 11 | echo "end" 12 | date 13 | -------------------------------------------------------------------------------- /src_classification/run_pretrain_GraphMVP.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | source $HOME/.bashrc 4 | conda activate GraphMVP 5 | 6 | 7 | echo $@ 8 | date 9 | echo "start" 10 | python pretrain_GraphMVP.py $@ 11 | echo "end" 12 | date 13 | -------------------------------------------------------------------------------- /src_classification/run_pretrain_GraphMVP_hybrid.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | source $HOME/.bashrc 4 | conda activate GraphMVP 5 | 6 | echo $@ 7 | date 8 | echo "start" 9 | python pretrain_GraphMVP_hybrid.py $@ 10 | echo "end" 11 | date 12 | -------------------------------------------------------------------------------- /src_classification/run_pretrain_IG.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | source $HOME/.bashrc 4 | conda activate GraphMVP 5 | 6 | echo $@ 7 | date 8 | 9 | echo "start" 10 | python pretrain_IG.py $@ 11 | echo "end" 12 | date 13 | -------------------------------------------------------------------------------- /src_classification/run_pretrain_JOAO.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | source $HOME/.bashrc 4 | conda activate GraphMVP 5 | 6 | echo $@ 7 | date 8 | 9 | echo "start" 10 | python pretrain_JOAO.py $@ 11 | echo "end" 12 | date 13 | -------------------------------------------------------------------------------- 
# Protein sequence vocabulary: 25 amino-acid letters (no 'J'), 1-indexed so
# that 0 is reserved for padding.
seq_voc = "ABCDEFGHIKLMNOPQRSTUVWXYZ"
seq_dict = {ch: pos + 1 for pos, ch in enumerate(seq_voc)}
seq_dict_len = len(seq_dict)
max_seq_len = 1000


def seq_cat(prot):
    """Integer-encode a protein string into a fixed-length (1000,) vector.

    Each residue maps to its 1-based vocabulary index; positions past the end
    of the sequence stay 0 (padding); sequences longer than `max_seq_len` are
    truncated.
    """
    encoded = np.zeros(max_seq_len)
    for pos, residue in enumerate(prot[:max_seq_len]):
        encoded[pos] = seq_dict[residue]
    return encoded
MoleculeProteinDataset(InMemoryDataset): 25 | def __init__(self, root, dataset, mode): 26 | super(InMemoryDataset, self).__init__() 27 | self.root = root 28 | self.dataset = dataset 29 | datapath = os.path.join(self.root, self.dataset, '{}.csv'.format(mode)) 30 | print('datapath\t', datapath) 31 | 32 | self.process_molecule() 33 | self.process_protein() 34 | 35 | df = pd.read_csv(datapath) 36 | self.molecule_index_list = df['smiles_id'].tolist() 37 | self.protein_index_list = df['target_id'].tolist() 38 | self.label_list = df['affinity'].tolist() 39 | self.label_list = torch.FloatTensor(self.label_list) 40 | 41 | return 42 | 43 | def process_molecule(self): 44 | input_path = os.path.join(self.root, self.dataset, 'smiles.csv') 45 | input_df = pd.read_csv(input_path, sep=',') 46 | smiles_list = input_df['smiles'] 47 | 48 | rdkit_mol_objs_list = [AllChem.MolFromSmiles(s) for s in smiles_list] 49 | preprocessed_rdkit_mol_objs_list = [m if m != None else None for m in rdkit_mol_objs_list] 50 | preprocessed_smiles_list = [AllChem.MolToSmiles(m) if m != None else None for m in preprocessed_rdkit_mol_objs_list] 51 | assert len(smiles_list) == len(preprocessed_rdkit_mol_objs_list) 52 | assert len(smiles_list) == len(preprocessed_smiles_list) 53 | 54 | smiles_list, rdkit_mol_objs = preprocessed_smiles_list, preprocessed_rdkit_mol_objs_list 55 | 56 | data_list = [] 57 | for i in range(len(smiles_list)): 58 | rdkit_mol = rdkit_mol_objs[i] 59 | if rdkit_mol != None: 60 | data = mol_to_graph_data_obj_simple(rdkit_mol) 61 | data.id = torch.tensor([i]) 62 | data_list.append(data) 63 | 64 | self.molecule_list = data_list 65 | return 66 | 67 | def process_protein(self): 68 | datapath = os.path.join(self.root, self.dataset, 'protein.csv') 69 | 70 | input_df = pd.read_csv(datapath, sep=',') 71 | protein_list = input_df['protein'].tolist() 72 | 73 | self.protein_list = [seq_cat(t) for t in protein_list] 74 | self.protein_list = torch.LongTensor(self.protein_list) 75 | return 76 | 77 | 
def __getitem__(self, idx): 78 | molecule = self.molecule_list[self.molecule_index_list[idx]] 79 | protein = self.protein_list[self.protein_index_list[idx]] 80 | label = self.label_list[idx] 81 | return molecule, protein, label 82 | 83 | def __len__(self): 84 | return len(self.label_list) 85 | -------------------------------------------------------------------------------- /src_regression/datasets_complete_feature/molecule_graphcl_dataset.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | from itertools import repeat 4 | 5 | import numpy as np 6 | import torch 7 | from torch_geometric.data import Data 8 | from torch_geometric.utils import subgraph, to_networkx 9 | 10 | from .molecule_datasets import MoleculeDatasetComplete 11 | 12 | 13 | class MoleculeDataset_graphcl_complete(MoleculeDatasetComplete): 14 | 15 | def __init__(self, 16 | root, 17 | transform=None, 18 | pre_transform=None, 19 | pre_filter=None, 20 | dataset=None, 21 | empty=False): 22 | 23 | self.aug_prob = None 24 | self.aug_mode = 'no_aug' 25 | self.aug_strength = 0.2 26 | self.augmentations = [self.node_drop, self.subgraph, 27 | self.edge_pert, self.attr_mask, lambda x: x] 28 | super(MoleculeDataset_graphcl_complete, self).__init__( 29 | root, transform, pre_transform, pre_filter, dataset, empty) 30 | 31 | def set_augMode(self, aug_mode): 32 | self.aug_mode = aug_mode 33 | 34 | def set_augStrength(self, aug_strength): 35 | self.aug_strength = aug_strength 36 | 37 | def set_augProb(self, aug_prob): 38 | self.aug_prob = aug_prob 39 | 40 | def node_drop(self, data): 41 | 42 | node_num, _ = data.x.size() 43 | _, edge_num = data.edge_index.size() 44 | drop_num = int(node_num * self.aug_strength) 45 | 46 | idx_perm = np.random.permutation(node_num) 47 | idx_nodrop = idx_perm[drop_num:].tolist() 48 | idx_nodrop.sort() 49 | 50 | edge_idx, edge_attr = subgraph(subset=idx_nodrop, 51 | edge_index=data.edge_index, 52 | edge_attr=data.edge_attr, 53 | 
relabel_nodes=True, 54 | num_nodes=node_num) 55 | 56 | data.edge_index = edge_idx 57 | data.edge_attr = edge_attr 58 | data.x = data.x[idx_nodrop] 59 | data.__num_nodes__, _ = data.x.shape 60 | return data 61 | 62 | def edge_pert(self, data): 63 | node_num, _ = data.x.size() 64 | _, edge_num = data.edge_index.size() 65 | pert_num = int(edge_num * self.aug_strength) 66 | 67 | # delete edges 68 | idx_drop = np.random.choice(edge_num, (edge_num - pert_num), 69 | replace=False) 70 | edge_index = data.edge_index[:, idx_drop] 71 | edge_attr = data.edge_attr[idx_drop] 72 | 73 | # add edges 74 | adj = torch.ones((node_num, node_num)) 75 | adj[edge_index[0], edge_index[1]] = 0 76 | # edge_index_nonexist = adj.nonzero(as_tuple=False).t() 77 | edge_index_nonexist = torch.nonzero(adj, as_tuple=False).t() 78 | idx_add = np.random.choice(edge_index_nonexist.shape[1], 79 | pert_num, replace=False) 80 | edge_index_add = edge_index_nonexist[:, idx_add] 81 | # check here: https://github.com/snap-stanford/ogb/blob/master/ogb/utils/features.py#L2 82 | edge_attr_add_1 = torch.tensor(np.random.randint( 83 | 5, size=(edge_index_add.shape[1], 1))) 84 | edge_attr_add_2 = torch.tensor(np.random.randint( 85 | 6, size=(edge_index_add.shape[1], 1))) 86 | edge_attr_add_3 = torch.tensor(np.random.randint( 87 | 2, size=(edge_index_add.shape[1], 1))) 88 | edge_attr_add = torch.cat((edge_attr_add_1, edge_attr_add_2, edge_attr_add_3), dim=1) 89 | edge_index = torch.cat((edge_index, edge_index_add), dim=1) 90 | edge_attr = torch.cat((edge_attr, edge_attr_add), dim=0) 91 | 92 | data.edge_index = edge_index 93 | data.edge_attr = edge_attr 94 | return data 95 | 96 | def attr_mask(self, data): 97 | 98 | _x = data.x.clone() 99 | node_num, _ = data.x.size() 100 | mask_num = int(node_num * self.aug_strength) 101 | 102 | token = data.x.float().mean(dim=0).long() 103 | idx_mask = np.random.choice( 104 | node_num, mask_num, replace=False) 105 | 106 | _x[idx_mask] = token 107 | data.x = _x 108 | return data 109 
    def subgraph(self, data):
        """GraphCL augmentation: keep a random connected-ish subgraph.

        Grows a node set by random-walk-style neighbor sampling until it holds
        (1 - aug_strength) of the nodes, then restricts the graph to that set.
        Mutates `data` in place and returns it. Left as-is: the result depends
        on the exact order of np.random calls, so the statement order below is
        behavior-significant.
        """
        G = to_networkx(data)
        node_num, _ = data.x.size()
        _, edge_num = data.edge_index.size()
        sub_num = int(node_num * (1 - self.aug_strength))

        # Seed the walk at a uniformly random node; track its neighbor frontier.
        idx_sub = [np.random.randint(node_num, size=1)[0]]
        idx_neigh = set([n for n in G.neighbors(idx_sub[-1])])

        while len(idx_sub) <= sub_num:
            if len(idx_neigh) == 0:
                # Frontier exhausted (disconnected component): restart from a
                # random not-yet-selected node.
                idx_unsub = list(set([n for n in range(node_num)]).difference(set(idx_sub)))
                idx_neigh = set([np.random.choice(idx_unsub)])
            sample_node = np.random.choice(list(idx_neigh))

            idx_sub.append(sample_node)
            # Extend the frontier with the new node's neighbors, minus anything
            # already selected.
            idx_neigh = idx_neigh.union(
                set([n for n in G.neighbors(idx_sub[-1])])).difference(set(idx_sub))

        idx_nondrop = idx_sub
        idx_nondrop.sort()

        # NOTE: this bare `subgraph(...)` resolves to the module-level
        # torch_geometric.utils.subgraph import, not this method (method names
        # are only reachable via self.), so there is no accidental recursion.
        edge_idx, edge_attr = subgraph(subset=idx_nondrop,
                                       edge_index=data.edge_index,
                                       edge_attr=data.edge_attr,
                                       relabel_nodes=True,
                                       num_nodes=node_num)

        data.edge_index = edge_idx
        data.edge_attr = edge_attr
        data.x = data.x[idx_nondrop]
        data.__num_nodes__, _ = data.x.shape
        return data
def train(repurpose_model, device, dataloader, optimizer):
    """Train the DTI model for one epoch over (molecule, protein, affinity) triples.

    Uses the module-level `criterion` (MSELoss) that __main__ sets up; prints
    the mean loss for the epoch.
    """
    repurpose_model.train()
    loss_accum = 0
    for molecule, protein, label in dataloader:
        molecule = molecule.to(device)
        protein = protein.to(device)
        label = label.to(device)

        pred = repurpose_model(molecule, protein).squeeze()

        optimizer.zero_grad()
        loss = criterion(pred, label)
        loss.backward()
        optimizer.step()
        loss_accum += loss.detach().item()
    print('Loss:\t{}'.format(loss_accum / len(dataloader)))
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='PyTorch implementation of pre-training of graph neural networks')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--num_layer', type=int, default=5)
    parser.add_argument('--emb_dim', type=int, default=300)
    parser.add_argument('--dropout_ratio', type=float, default=0.)
    parser.add_argument('--graph_pooling', type=str, default='mean')
    parser.add_argument('--JK', type=str, default='last')
    parser.add_argument('--dataset', type=str, default='davis', choices=['davis', 'kiba'])
    parser.add_argument('--gnn_type', type=str, default='gin')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--runseed', type=int, default=0)
    parser.add_argument('--batch_size', type=int, default=512)
    parser.add_argument('--learning_rate', type=float, default=0.0005)
    parser.add_argument('--epochs', type=int, default=200)
    parser.add_argument('--input_model_file', type=str, default='')
    parser.add_argument('--output_model_file', type=str, default='')
    ########## For protein embedding ##########
    parser.add_argument('--protein_emb_dim', type=int, default=300)
    parser.add_argument('--protein_hidden_dim', type=int, default=300)
    parser.add_argument('--num_features', type=int, default=25)
    args = parser.parse_args()

    torch.manual_seed(args.runseed)
    np.random.seed(args.runseed)
    device = torch.device('cuda:' + str(args.device)) if torch.cuda.is_available() else torch.device('cpu')

    ########## Set up dataset and dataloader ##########
    root = '../datasets/dti_datasets'
    train_val_dataset = MoleculeProteinDataset(root=root, dataset=args.dataset, mode='train')
    # 80/20 random split of the official train file into train/validation.
    train_size = int(0.8 * len(train_val_dataset))
    valid_size = len(train_val_dataset) - train_size
    train_dataset, valid_dataset = torch.utils.data.random_split(train_val_dataset, [train_size, valid_size])
    test_dataset = MoleculeProteinDataset(root=root, dataset=args.dataset, mode='test')
    print('size of train: {}\tval: {}\ttest: {}'.format(len(train_dataset), len(valid_dataset), len(test_dataset)))

    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    valid_dataloader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False)
    test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)

    ########## Set up model ##########
    molecule_model = GNNComplete(args.num_layer, args.emb_dim, JK=args.JK, drop_ratio=args.dropout_ratio, gnn_type=args.gnn_type)
    ########## Load pre-trained model ##########
    if not args.input_model_file == '':
        print('========= Loading from {}'.format(args.input_model_file))
        molecule_model.load_state_dict(torch.load(args.input_model_file))
    protein_model = ProteinModel(
        emb_dim=args.protein_emb_dim, num_features=args.num_features, output_dim=args.protein_hidden_dim)
    repurpose_model = MoleculeProteinModel(
        molecule_model, protein_model,
        molecule_emb_dim=args.emb_dim, protein_emb_dim=args.protein_hidden_dim).to(device)
    print('repurpose model\n', repurpose_model)

    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(repurpose_model.parameters(), lr=args.learning_rate)

    best_repurpose_model = None
    best_mse = 1000
    best_epoch = 0

    for epoch in range(1, 1 + args.epochs):
        start_time = time.time()
        print('Start training at epoch: {}'.format(epoch))
        train(repurpose_model, device, train_dataloader, optimizer)

        G, P = predicting(repurpose_model, device, valid_dataloader)
        current_mse = mse(G, P)
        print('MSE:\t{}'.format(current_mse))
        if current_mse < best_mse:
            best_repurpose_model = copy.deepcopy(repurpose_model)
            best_mse = current_mse
            best_epoch = epoch
            print('MSE improved at epoch {}\tbest MSE: {}'.format(best_epoch, best_mse))
        else:
            print('No improvement since epoch {}\tbest MSE: {}'.format(best_epoch, best_mse))
        print('Took {:.5f}s.'.format(time.time() - start_time))
        print()

    # Robustness fix: with --epochs 0 (or if validation MSE never dropped
    # below the initial 1000) best_repurpose_model stayed None and the
    # evaluation/saving below crashed on None.state_dict(). Fall back to the
    # final model in that case.
    if best_repurpose_model is None:
        best_repurpose_model = copy.deepcopy(repurpose_model)

    start_time = time.time()
    print('Last epoch: {}'.format(args.epochs))
    G, P = predicting(repurpose_model, device, test_dataloader)
    ret = [rmse(G, P), mse(G, P), pearson(G, P), spearman(G, P), ci(G, P)]
    print('RMSE: {}\tMSE: {}\tPearson: {}\tSpearman: {}\tCI: {}'.format(ret[0], ret[1], ret[2], ret[3], ret[4]))
    print('Took {:.5f}s.'.format(time.time() - start_time))

    start_time = time.time()
    print('Best epoch: {}'.format(best_epoch))
    G, P = predicting(best_repurpose_model, device, test_dataloader)
    ret = [rmse(G, P), mse(G, P), pearson(G, P), spearman(G, P), ci(G, P)]
    print('RMSE: {}\tMSE: {}\tPearson: {}\tSpearman: {}\tCI: {}'.format(ret[0], ret[1], ret[2], ret[3], ret[4]))
    print('Took {:.5f}s.'.format(time.time() - start_time))

    if not args.output_model_file == '':
        torch.save({
            'repurpose_model': repurpose_model.state_dict(),
            'best_repurpose_model': best_repurpose_model.state_dict()
        }, args.output_model_file + '.pth')
class GINConv(MessagePassing):
    """Graph Isomorphism Network layer with OGB bond-feature encoding.

    Computes h_v = MLP((1 + eps) * h_v + sum_{u in N(v)} ReLU(h_u + e_uv)),
    where e_uv is the learned embedding of the bond features.
    """

    def __init__(self, emb_dim, aggr="add"):
        super(GINConv, self).__init__(aggr=aggr)
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(emb_dim, 2 * emb_dim),
            torch.nn.BatchNorm1d(2 * emb_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(2 * emb_dim, emb_dim),
        )
        # Learnable epsilon weighting the central node, initialised to zero.
        self.eps = torch.nn.Parameter(torch.Tensor([0]))
        self.bond_encoder = BondEncoder(emb_dim=emb_dim)

    def forward(self, x, edge_index, edge_attr):
        bond_emb = self.bond_encoder(edge_attr)
        aggregated = self.propagate(edge_index, x=x, edge_attr=bond_emb)
        return self.mlp((1 + self.eps) * x + aggregated)

    def message(self, x_j, edge_attr):
        return F.relu(x_j + edge_attr)

    def update(self, aggr_out):
        return aggr_out
class GCNConv(MessagePassing):
    """GCN layer over OGB-encoded bond features with symmetric normalization.

    Bug fix: the original body called ``degree(row, ...)``, but ``degree``
    (torch_geometric.utils.degree) is never imported in this module, so
    forward() raised NameError at runtime. The degree is now computed with
    ``torch.bincount``, which requires no additional import and yields the
    same per-node counts.
    """

    def __init__(self, emb_dim, aggr="add"):
        super(GCNConv, self).__init__(aggr=aggr)
        self.linear = torch.nn.Linear(emb_dim, emb_dim)
        # Embedding for the implicit self-connection ("root") contribution.
        self.root_emb = torch.nn.Embedding(1, emb_dim)
        self.bond_encoder = BondEncoder(emb_dim=emb_dim)

    def forward(self, x, edge_index, edge_attr):
        x = self.linear(x)
        edge_embedding = self.bond_encoder(edge_attr)

        row, col = edge_index
        # Node out-degrees (+1 accounts for the self-loop handled via root_emb).
        deg = torch.bincount(row, minlength=x.size(0)).to(x.dtype) + 1
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

        # Symmetric normalization 1/sqrt(d_i * d_j) per edge.
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        return self.propagate(edge_index, x=x, edge_attr=edge_embedding, norm=norm) \
            + F.relu(x + self.root_emb.weight) * 1. / deg.view(-1, 1)

    def message(self, x_j, edge_attr, norm):
        return norm.view(-1, 1) * F.relu(x_j + edge_attr)

    def update(self, aggr_out):
        return aggr_out
class GNNComplete(nn.Module):
    """Stack of GIN/GCN message-passing layers over OGB atom features.

    Args:
        num_layer: number of message-passing layers (must be >= 2).
        emb_dim: hidden dimensionality of node representations.
        JK: jumping-knowledge aggregation: 'last' | 'concat' | 'max' | 'sum'.
        drop_ratio: dropout probability applied after every layer.
        gnn_type: 'gin' or 'gcn'.
    """

    def __init__(self, num_layer, emb_dim, JK="last", drop_ratio=0., gnn_type="gin"):
        if num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        super(GNNComplete, self).__init__()
        self.drop_ratio = drop_ratio
        self.num_layer = num_layer
        self.JK = JK

        self.atom_encoder = AtomEncoder(emb_dim)

        # List of message-passing layers.
        self.gnns = nn.ModuleList()
        for layer in range(num_layer):
            if gnn_type == "gin":
                self.gnns.append(GINConv(emb_dim, aggr="add"))
            elif gnn_type == "gcn":
                self.gnns.append(GCNConv(emb_dim, aggr="add"))
            else:
                # Bug fix: an unknown gnn_type previously produced an empty
                # layer list silently and failed much later in forward().
                raise ValueError("Invalid gnn_type: {}".format(gnn_type))

        # One BatchNorm per layer.
        self.batch_norms = nn.ModuleList()
        for layer in range(num_layer):
            self.batch_norms.append(nn.BatchNorm1d(emb_dim))

    def forward(self, *argv):
        """Accepts (x, edge_index, edge_attr) or a single PyG data object."""
        if len(argv) == 3:
            x, edge_index, edge_attr = argv[0], argv[1], argv[2]
        elif len(argv) == 1:
            data = argv[0]
            x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        else:
            raise ValueError("unmatched number of arguments.")

        x = self.atom_encoder(x)

        h_list = [x]
        for layer in range(self.num_layer):
            h = self.gnns[layer](h_list[layer], edge_index, edge_attr)
            h = self.batch_norms[layer](h)
            if layer == self.num_layer - 1:
                # remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training=self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)
            h_list.append(h)

        # Jumping-knowledge aggregation over all layer outputs.
        if self.JK == "concat":
            node_representation = torch.cat(h_list, dim=1)
        elif self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "max":
            node_representation = torch.stack(h_list, dim=0).max(dim=0)[0]
        elif self.JK == "sum":
            # Bug fix: the original did torch.sum(torch.cat(...), dim=0)[0],
            # and the trailing [0] returned only the FIRST node's vector [D]
            # instead of the full [N, D] node representation.
            node_representation = torch.stack(h_list, dim=0).sum(dim=0)
        else:
            raise ValueError("not implemented.")
        return node_representation
class GNN_graphpredComplete(nn.Module):
    """Graph-level prediction head on top of a node-level GNN.

    Pools node representations produced by ``molecule_model`` and applies a
    linear read-out with ``num_tasks`` outputs.
    """

    def __init__(self, args, num_tasks, molecule_model=None):
        super(GNN_graphpredComplete, self).__init__()

        if args.num_layer < 2:
            raise ValueError("# layers must > 1.")

        self.molecule_model = molecule_model
        self.num_layer = args.num_layer
        self.emb_dim = args.emb_dim
        self.num_tasks = num_tasks
        self.JK = args.JK

        # Different kind of graph pooling
        if args.graph_pooling == "sum":
            self.pool = global_add_pool
        elif args.graph_pooling == "mean":
            self.pool = global_mean_pool
        elif args.graph_pooling == "max":
            self.pool = global_max_pool
        else:
            raise ValueError("Invalid graph pooling type.")

        # For graph-level binary classification
        self.mult = 1

        # JK 'concat' keeps every layer's output, so the read-out input is
        # (num_layer + 1) times wider.
        if self.JK == "concat":
            self.graph_pred_linear = nn.Linear(
                self.mult * (self.num_layer + 1) * self.emb_dim, self.num_tasks)
        else:
            self.graph_pred_linear = nn.Linear(self.mult * self.emb_dim, self.num_tasks)
        return

    def from_pretrained(self, model_file):
        """Load pre-trained weights into the underlying molecule model."""
        self.molecule_model.load_state_dict(torch.load(model_file))
        return

    def get_graph_representation(self, *argv):
        """Return (pooled graph representation, prediction) for a batch.

        Accepts (x, edge_index, edge_attr, batch) or a single PyG data object.
        """
        if len(argv) == 4:
            x, edge_index, edge_attr, batch = argv[0], argv[1], argv[2], argv[3]
        elif len(argv) == 1:
            data = argv[0]
            x, edge_index, edge_attr, batch = data.x, data.edge_index, \
                                              data.edge_attr, data.batch
        else:
            raise ValueError("unmatched number of arguments.")

        node_representation = self.molecule_model(x, edge_index, edge_attr)
        graph_representation = self.pool(node_representation, batch)
        pred = self.graph_pred_linear(graph_representation)

        return graph_representation, pred

    def forward(self, *argv):
        # De-duplicated: forward is exactly the prediction half of
        # get_graph_representation (the two bodies were identical copies).
        return self.get_graph_representation(*argv)[1]
def train(model, device, loader, optimizer, criterion=None):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: graph model called as model(x, edge_index, edge_attr, batch).
        device: device batches are moved to.
        loader: iterable of PyG batches carrying a `y` regression target.
        optimizer: optimizer over the model's parameters.
        criterion: optional loss override. Defaults to the module-level
            ``reg_criterion`` configured in __main__ (MSELoss), so existing
            call sites keep their exact behavior.
    """
    if criterion is None:
        criterion = reg_criterion
    model.train()
    total_loss = 0

    for step, batch in enumerate(loader):
        batch = batch.to(device)
        pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch).squeeze()
        y = batch.y.squeeze()

        loss = criterion(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.detach().item()

    return total_loss / len(loader)


def eval(model, device, loader):
    """Evaluate the model; returns ({'RMSE', 'MAE'}, y_true, y_pred).

    NOTE: intentionally keeps the original name even though it shadows the
    builtin ``eval``, because the training loop in __main__ calls it by name.
    """
    model.eval()
    y_true, y_pred = [], []

    for step, batch in enumerate(loader):
        batch = batch.to(device)
        with torch.no_grad():
            pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch).squeeze(1)

        y_true.append(batch.y.view(pred.shape))
        y_pred.append(pred)

    y_true = torch.cat(y_true, dim=0).cpu().numpy()
    y_pred = torch.cat(y_pred, dim=0).cpu().numpy()
    rmse = mean_squared_error(y_true, y_pred, squared=False)
    mae = mean_absolute_error(y_true, y_pred)
    return {'RMSE': rmse, 'MAE': mae}, y_true, y_pred
dataset, smiles_list, null_value=0, frac_train=0.8, 79 | frac_valid=0.1, frac_test=0.1) 80 | print('split via scaffold') 81 | elif args.split == 'random': 82 | train_dataset, valid_dataset, test_dataset = random_split( 83 | dataset, null_value=0, frac_train=0.8, frac_valid=0.1, 84 | frac_test=0.1, seed=args.seed) 85 | print('randomly split') 86 | elif args.split == 'random_scaffold': 87 | smiles_list = pd.read_csv(dataset_folder + '/processed/smiles.csv', header=None)[0].tolist() 88 | train_dataset, valid_dataset, test_dataset = random_scaffold_split( 89 | dataset, smiles_list, null_value=0, frac_train=0.8, 90 | frac_valid=0.1, frac_test=0.1, seed=args.seed) 91 | print('random scaffold') 92 | else: 93 | raise ValueError('Invalid split option.') 94 | 95 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) 96 | val_loader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers) 97 | test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers) 98 | 99 | # set up model 100 | molecule_model = GNNComplete(num_layer=args.num_layer, emb_dim=args.emb_dim, JK=args.JK, drop_ratio=args.dropout_ratio, gnn_type=args.gnn_type) 101 | model = GNN_graphpredComplete(args=args, num_tasks=num_tasks, molecule_model=molecule_model) 102 | if not args.input_model_file == '': 103 | model.from_pretrained(args.input_model_file) 104 | model.to(device) 105 | print(model) 106 | 107 | model_param_group = [ 108 | {'params': model.molecule_model.parameters()}, 109 | {'params': model.graph_pred_linear.parameters(), 'lr': args.lr * args.lr_scale} 110 | ] 111 | optimizer = optim.Adam(model_param_group, lr=args.lr, weight_decay=args.decay) 112 | # reg_criterion = torch.nn.L1Loss() 113 | reg_criterion = torch.nn.MSELoss() 114 | 115 | train_result_list, val_result_list, test_result_list = [], [], [] 116 | # metric_list = ['RMSE', 'MAE', 'R2'] 117 | 
metric_list = ['RMSE', 'MAE'] 118 | best_val_rmse, best_val_idx = 1e10, 0 119 | 120 | for epoch in range(1, args.epochs + 1): 121 | loss_acc = train(model, device, train_loader, optimizer) 122 | print('Epoch: {}\nLoss: {}'.format(epoch, loss_acc)) 123 | 124 | if args.eval_train: 125 | train_result, train_target, train_pred = eval(model, device, train_loader) 126 | else: 127 | train_result = {'RMSE': 0, 'MAE': 0, 'R2': 0} 128 | val_result, val_target, val_pred = eval(model, device, val_loader) 129 | test_result, test_target, test_pred = eval(model, device, test_loader) 130 | 131 | train_result_list.append(train_result) 132 | val_result_list.append(val_result) 133 | test_result_list.append(test_result) 134 | 135 | for metric in metric_list: 136 | print('{} train: {:.6f}\tval: {:.6f}\ttest: {:.6f}'.format(metric, train_result[metric], val_result[metric], test_result[metric])) 137 | print() 138 | 139 | if val_result['RMSE'] < best_val_rmse: 140 | best_val_rmse = val_result['RMSE'] 141 | best_val_idx = epoch - 1 142 | if not args.output_model_dir == '': 143 | output_model_path = join(args.output_model_dir, 'model_best.pth') 144 | saved_model_dict = { 145 | 'molecule_model': molecule_model.state_dict(), 146 | 'model': model.state_dict() 147 | } 148 | torch.save(saved_model_dict, output_model_path) 149 | 150 | filename = join(args.output_model_dir, 'evaluation_best.pth') 151 | np.savez(filename, val_target=val_target, val_pred=val_pred, 152 | test_target=test_target, test_pred=test_pred) 153 | 154 | for metric in metric_list: 155 | print('Best (RMSE), {} train: {:.6f}\tval: {:.6f}\ttest: {:.6f}'.format( 156 | metric, train_result_list[best_val_idx][metric], val_result_list[best_val_idx][metric], test_result_list[best_val_idx][metric])) 157 | 158 | if args.output_model_dir is not '': 159 | output_model_path = join(args.output_model_dir, 'model_final.pth') 160 | saved_model_dict = { 161 | 'molecule_model': molecule_model.state_dict(), 162 | 'model': model.state_dict() 163 | 
def compute_accuracy(pred, target):
    """Fraction of rows in `pred` whose argmax matches `target`."""
    correct = (pred.detach().argmax(dim=1) == target).sum()
    return float(correct.cpu().item()) / len(pred)


def do_AttrMasking(batch, criterion, node_repr, molecule_atom_masking_model):
    """Attribute-masking SSL objective: predict the masked atoms' types.

    Gathers the node representations at the masked positions, classifies them
    with the linear masking head, and returns (loss, accuracy) against the
    atom-type labels stored in batch.mask_node_label[:, 0].
    """
    target = batch.mask_node_label[:, 0]
    node_pred = molecule_atom_masking_model(node_repr[batch.masked_atom_indices])
    loss = criterion(node_pred.double(), target)
    acc = compute_accuracy(node_pred, target)
    return loss, acc
def train(device, loader, optimizer):
    """One epoch of attribute-masking pre-training.

    Relies on the module-level `molecule_model`, `molecule_atom_masking_model`
    and `criterion` set up in __main__.
    """
    start = time.time()
    molecule_model.train()
    molecule_atom_masking_model.train()
    attributemask_loss_accum, attributemask_acc_accum = 0, 0

    for step, batch in enumerate(loader):
        batch = batch.to(device)
        # Node representations computed from the MASKED node features.
        node_repr = molecule_model(batch.masked_x, batch.edge_index, batch.edge_attr)

        attributemask_loss, attributemask_acc = do_AttrMasking(
            batch=batch, criterion=criterion, node_repr=node_repr,
            molecule_atom_masking_model=molecule_atom_masking_model)

        attributemask_loss_accum += attributemask_loss.detach().cpu().item()
        attributemask_acc_accum += attributemask_acc
        loss = attributemask_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print('AM Loss: {:.5f}\tAM Acc: {:.5f}\tTime: {:.5f}'.format(
        attributemask_loss_accum / len(loader),
        attributemask_acc_accum / len(loader),
        time.time() - start))
    return


if __name__ == '__main__':

    np.random.seed(0)
    torch.manual_seed(0)
    device = torch.device('cuda:' + str(args.device)) \
        if torch.cuda.is_available() else torch.device('cpu')
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(0)
        torch.cuda.set_device(args.device)

    if 'GEOM' in args.dataset:
        dataset = MoleculeDatasetComplete(
            '../datasets/{}/'.format(args.dataset), dataset=args.dataset,
            transform=MaskAtom(num_atom_type=119, num_edge_type=5,
                               mask_rate=args.mask_rate, mask_edge=args.mask_edge))
        loader = DataLoaderMasking(dataset, batch_size=args.batch_size,
                                   shuffle=True, num_workers=args.num_workers)
    else:
        # Bug fix: an unsupported dataset previously left `loader` undefined
        # and crashed later with a NameError.
        raise ValueError('Unsupported dataset: {}'.format(args.dataset))

    # set up model
    molecule_model = GNNComplete(args.num_layer, args.emb_dim, JK=args.JK, drop_ratio=args.dropout_ratio, gnn_type=args.gnn_type).to(device)
    molecule_readout_func = global_mean_pool

    # Linear head predicting one of the 119 atom types for each masked node.
    molecule_atom_masking_model = torch.nn.Linear(args.emb_dim, 119).to(device)

    model_param_group = [{'params': molecule_model.parameters(), 'lr': args.lr},
                         {'params': molecule_atom_masking_model.parameters(), 'lr': args.lr}]

    optimizer = optim.Adam(model_param_group, lr=args.lr, weight_decay=args.decay)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(1, args.epochs + 1):
        print('epoch: {}'.format(epoch))
        train(device, loader, optimizer)

    if args.output_model_dir != '':
        torch.save(molecule_model.state_dict(), args.output_model_dir + '_model.pth')

        saver_dict = {'model': molecule_model.state_dict(),
                      'molecule_atom_masking_model': molecule_atom_masking_model.state_dict()}

        torch.save(saver_dict, args.output_model_dir + '_model_complete.pth')
def do_ContextPred(batch, criterion, args, molecule_substruct_model,
                   molecule_context_model, molecule_readout_func):
    """Context-prediction SSL objective; returns (loss, accuracy).

    Scores dot products between each substructure's center representation and
    its own context (positives) versus cyclically-shifted in-batch contexts
    (negatives), trained with binary cross-entropy on the logits.
    """
    # Representation of the center node of every substructure.
    substruct_repr = molecule_substruct_model(
        batch.x_substruct, batch.edge_index_substruct,
        batch.edge_attr_substruct)[batch.center_substruct_idx]

    # Node representations in the substructure/context overlap region.
    overlapped_node_repr = molecule_context_model(
        batch.x_context, batch.edge_index_context,
        batch.edge_attr_context)[batch.overlap_context_substruct_idx]

    # Positive context representation (readout is global_mean_pool by default).
    context_repr = molecule_readout_func(overlapped_node_repr,
                                         batch.batch_overlapped_context)

    # Negatives: cyclic shifts of the in-batch context representations.
    num_neg = args.contextpred_neg_samples
    neg_context_repr = torch.cat(
        [context_repr[cycle_index(len(context_repr), shift + 1)]
         for shift in range(num_neg)], dim=0)

    pred_pos = torch.sum(substruct_repr * context_repr, dim=1)
    pred_neg = torch.sum(substruct_repr.repeat((num_neg, 1)) * neg_context_repr, dim=1)

    loss_pos = criterion(pred_pos.double(),
                         torch.ones(len(pred_pos)).to(pred_pos.device).double())
    loss_neg = criterion(pred_neg.double(),
                         torch.zeros(len(pred_neg)).to(pred_neg.device).double())
    contextpred_loss = loss_pos + num_neg * loss_neg

    num_pred = len(pred_pos) + len(pred_neg)
    contextpred_acc = (torch.sum(pred_pos > 0).float()
                       + torch.sum(pred_neg < 0).float()) / num_pred
    return contextpred_loss, contextpred_acc.detach().cpu().item()


def train(args, device, loader, optimizer):
    """One epoch of context-prediction pre-training (uses module globals)."""
    epoch_start = time.time()
    molecule_context_model.train()
    molecule_substruct_model.train()
    loss_accum, acc_accum = 0, 0

    for step, batch in enumerate(loader):
        batch = batch.to(device)
        contextpred_loss, contextpred_acc = do_ContextPred(
            batch=batch, criterion=criterion, args=args,
            molecule_substruct_model=molecule_substruct_model,
            molecule_context_model=molecule_context_model,
            molecule_readout_func=molecule_readout_func)

        loss_accum += contextpred_loss.detach().cpu().item()
        acc_accum += contextpred_acc

        optimizer.zero_grad()
        contextpred_loss.backward()
        optimizer.step()

    print('CP Loss: {:.5f}\tCP Acc: {:.5f}\tTime: {:.3f}'.format(
        loss_accum / len(loader),
        acc_accum / len(loader),
        time.time() - epoch_start))

    return
if __name__ == '__main__':

    np.random.seed(0)
    torch.manual_seed(0)
    device = torch.device('cuda:' + str(args.device)) \
        if torch.cuda.is_available() else torch.device('cpu')
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(0)
        torch.cuda.set_device(args.device)

    # l1/l2 bound the context ring extracted around each substructure.
    l1 = args.num_layer - 1
    l2 = l1 + args.csize
    print('num layer: %d l1: %d l2: %d' % (args.num_layer, l1, l2))

    if 'GEOM' in args.dataset:
        dataset = MoleculeDatasetComplete(
            '../datasets/{}/'.format(args.dataset), dataset=args.dataset,
            transform=ExtractSubstructureContextPair(args.num_layer, l1, l2))
        loader = DataLoaderSubstructContext(dataset, batch_size=args.batch_size,
                                            shuffle=True, num_workers=args.num_workers)
    else:
        # Bug fix: a non-GEOM dataset previously left `loader` undefined and
        # crashed later with a NameError.
        raise ValueError('Unsupported dataset: {}'.format(args.dataset))

    ''' === set up model, mainly used in do_ContextPred() === '''
    molecule_substruct_model = GNNComplete(
        args.num_layer, args.emb_dim, JK=args.JK,
        drop_ratio=args.dropout_ratio, gnn_type=args.gnn_type).to(device)
    molecule_context_model = GNNComplete(
        int(l2 - l1), args.emb_dim, JK=args.JK,
        drop_ratio=args.dropout_ratio, gnn_type=args.gnn_type).to(device)

    ''' === set up loss and optimiser === '''
    criterion = nn.BCEWithLogitsLoss()
    molecule_readout_func = global_mean_pool
    model_param_group = [{'params': molecule_substruct_model.parameters(), 'lr': args.lr},
                         {'params': molecule_context_model.parameters(), 'lr': args.lr}]
    optimizer = optim.Adam(model_param_group, lr=args.lr, weight_decay=args.decay)

    for epoch in range(1, args.epochs + 1):
        print('epoch: {}'.format(epoch))
        train(args, device, loader, optimizer)

    if args.output_model_dir != '':
        torch.save(molecule_substruct_model.state_dict(),
                   args.output_model_dir + '_model.pth')

        saver_dict = {
            'molecule_substruct_model': molecule_substruct_model.state_dict(),
            'molecule_context_model': molecule_context_model.state_dict()}

        torch.save(saver_dict, args.output_model_dir + '_model_complete.pth')
def train(loader, model, optimizer, device):
    """One GraphCL contrastive epoch; returns the mean per-step loss.

    Each loader item is (original, view1, view2); the two augmented views are
    embedded with forward_cl and pulled together by the contrastive loss_cl.
    """
    model.train()
    loss_total = 0

    for step, (_, view1, view2) in enumerate(loader):
        view1 = view1.to(device)
        view2 = view2.to(device)

        z1 = model.forward_cl(view1.x, view1.edge_index,
                              view1.edge_attr, view1.batch)
        z2 = model.forward_cl(view2.x, view2.edge_index,
                              view2.edge_attr, view2.batch)
        loss = model.loss_cl(z1, z2)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_total += float(loss.detach().cpu().item())

    return loss_total / (step + 1)
if __name__ == "__main__":
    # Training settings
    parser = argparse.ArgumentParser(description='GraphCL')
    parser.add_argument('--device', type=int, default=0, help='gpu')
    parser.add_argument('--batch_size', type=int, default=256, help='batch')
    parser.add_argument('--decay', type=float, default=0, help='weight decay')
    parser.add_argument('--epochs', type=int, default=100, help='train epochs')
    parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
    parser.add_argument('--JK', type=str, default="last",
                        choices=['last', 'sum', 'max', 'concat'],
                        help='how the node features across layers are combined.')
    parser.add_argument('--gnn_type', type=str, default="gin", help='gnn model type')
    parser.add_argument('--dropout_ratio', type=float, default=0, help='dropout ratio')
    parser.add_argument('--emb_dim', type=int, default=300, help='embedding dimensions')
    parser.add_argument('--dataset', type=str, default=None, help='root dir of dataset')
    parser.add_argument('--num_layer', type=int, default=5, help='message passing layers')
    parser.add_argument('--output_model_file', type=str, default='', help='model save path')
    parser.add_argument('--num_workers', type=int, default=8, help='workers for dataset loading')

    parser.add_argument('--aug_mode', type=str, default='sample')
    parser.add_argument('--aug_strength', type=float, default=0.2)

    parser.add_argument('--output_model_dir', type=str, default='')
    args = parser.parse_args()

    torch.manual_seed(0)
    np.random.seed(0)
    device = torch.device("cuda:" + str(args.device)) \
        if torch.cuda.is_available() else torch.device("cpu")
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(0)

    # set up dataset
    # Bug fix: --dataset defaults to None, so the bare `'GEOM' in args.dataset`
    # raised TypeError, and any other value left `dataset` undefined (NameError
    # further down). Fail fast with a clear message instead.
    if args.dataset is not None and 'GEOM' in args.dataset:
        dataset = MoleculeDataset_graphcl_complete('../datasets/{}/'.format(args.dataset), dataset=args.dataset)
        dataset.set_augMode(args.aug_mode)
        dataset.set_augStrength(args.aug_strength)
        print(dataset)
    else:
        raise ValueError('Unsupported dataset: {}'.format(args.dataset))

    loader = DataLoader(dataset, batch_size=args.batch_size,
                        num_workers=args.num_workers, shuffle=True)

    # set up model
    gnn = GNNComplete(num_layer=args.num_layer, emb_dim=args.emb_dim, JK=args.JK,
                      drop_ratio=args.dropout_ratio, gnn_type=args.gnn_type)

    model = graphcl(gnn)
    model.to(device)

    # set up optimizer
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.decay)
    print(optimizer)

    # Uniform sampling probability over the 5x5 augmentation pairs.
    aug_prob = np.ones(25) / 25
    dataset.set_augProb(aug_prob)
    for epoch in range(1, args.epochs + 1):
        start_time = time.time()
        pretrain_loss = train(loader, model, optimizer, device)

        print('Epoch: {:3d}\tLoss:{:.3f}\tTime: {:.3f}'.format(
            epoch, pretrain_loss, time.time() - start_time))

    if args.output_model_dir != '':
        saver_dict = {'model': model.state_dict()}
        torch.save(saver_dict, args.output_model_dir + '_model_complete.pth')
        torch.save(model.gnn.state_dict(), args.output_model_dir + '_model.pth')
def save_model(save_best):
    """Persist the 2D model and the full component checkpoint.

    Uses the module-level args/models/optimal_loss. With save_best=True the
    current best checkpoint is written ('_model.pth'/'_model_complete.pth');
    otherwise the '*_final' variants are written at the end of training.
    No-op when no output directory is configured.
    """
    if args.output_model_dir == '':
        return

    if save_best:
        global optimal_loss
        print('save model with loss: {:.5f}'.format(optimal_loss))
        model_path = args.output_model_dir + '_model.pth'
        complete_path = args.output_model_dir + '_model_complete.pth'
    else:
        model_path = args.output_model_dir + '_model_final.pth'
        complete_path = args.output_model_dir + '_model_complete_final.pth'

    torch.save(molecule_model_2D.state_dict(), model_path)
    saver_dict = {
        'model': molecule_model_2D.state_dict(),
        'model_3D': molecule_model_3D.state_dict(),
        'AE_2D_3D_model': AE_2D_3D_model.state_dict(),
        'AE_3D_2D_model': AE_3D_2D_model.state_dict(),
    }
    torch.save(saver_dict, complete_path)
    return
return 45 | 46 | 47 | def train(args, molecule_model_2D, device, loader, optimizer): 48 | start_time = time.time() 49 | 50 | molecule_model_2D.train() 51 | molecule_model_3D.train() 52 | if molecule_projection_layer is not None: 53 | molecule_projection_layer.train() 54 | 55 | AE_loss_accum, AE_acc_accum = 0, 0 56 | CL_loss_accum, CL_acc_accum = 0, 0 57 | 58 | if args.verbose: 59 | l = tqdm(loader) 60 | else: 61 | l = loader 62 | for step, batch in enumerate(l): 63 | batch = batch.to(device) 64 | 65 | node_repr = molecule_model_2D(batch.x, batch.edge_index, batch.edge_attr) 66 | molecule_2D_repr = molecule_readout_func(node_repr, batch.batch) 67 | 68 | if args.model_3d == 'schnet': 69 | molecule_3D_repr = molecule_model_3D(batch.x[:, 0], batch.positions, batch.batch) 70 | 71 | CL_loss, CL_acc = dual_CL(molecule_2D_repr, molecule_3D_repr, args) 72 | 73 | AE_loss_1 = AE_2D_3D_model(molecule_2D_repr, molecule_3D_repr) 74 | AE_loss_2 = AE_3D_2D_model(molecule_3D_repr, molecule_2D_repr) 75 | AE_acc_1 = AE_acc_2 = 0 76 | AE_loss = (AE_loss_1 + AE_loss_2) / 2 77 | 78 | CL_loss_accum += CL_loss.detach().cpu().item() 79 | CL_acc_accum += CL_acc 80 | AE_loss_accum += AE_loss.detach().cpu().item() 81 | AE_acc_accum += (AE_acc_1 + AE_acc_2) / 2 82 | 83 | loss = 0 84 | if args.alpha_1 > 0: 85 | loss += CL_loss * args.alpha_1 86 | if args.alpha_2 > 0: 87 | loss += AE_loss * args.alpha_2 88 | 89 | optimizer.zero_grad() 90 | loss.backward() 91 | optimizer.step() 92 | 93 | global optimal_loss 94 | CL_loss_accum /= len(loader) 95 | CL_acc_accum /= len(loader) 96 | AE_loss_accum /= len(loader) 97 | AE_acc_accum /= len(loader) 98 | temp_loss = args.alpha_1 * CL_loss_accum + args.alpha_2 * AE_loss_accum 99 | if temp_loss < optimal_loss: 100 | optimal_loss = temp_loss 101 | save_model(save_best=True) 102 | print('CL Loss: {:.5f}\tCL Acc: {:.5f}\t\tAE Loss: {:.5f}\tAE Acc: {:.5f}\tTime: {:.5f}'.format( 103 | CL_loss_accum, CL_acc_accum, AE_loss_accum, AE_acc_accum, time.time() - 
start_time)) 104 | return 105 | 106 | 107 | if __name__ == '__main__': 108 | torch.manual_seed(0) 109 | np.random.seed(0) 110 | device = torch.device('cuda:' + str(args.device)) if torch.cuda.is_available() else torch.device('cpu') 111 | if torch.cuda.is_available(): 112 | torch.cuda.manual_seed_all(0) 113 | torch.cuda.set_device(args.device) 114 | 115 | if 'GEOM' in args.dataset: 116 | data_root = '../datasets/{}/'.format(args.dataset) if args.input_data_dir == '' else '{}/{}/'.format(args.input_data_dir, args.dataset) 117 | dataset = Molecule3DMaskingDataset(data_root, dataset=args.dataset, mask_ratio=args.SSL_masking_ratio) 118 | else: 119 | raise Exception 120 | loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) 121 | 122 | # set up model 123 | molecule_model_2D = GNNComplete(args.num_layer, args.emb_dim, JK=args.JK, drop_ratio=args.dropout_ratio, gnn_type=args.gnn_type).to(device) 124 | molecule_readout_func = global_mean_pool 125 | 126 | print('Using 3d model\t', args.model_3d) 127 | molecule_projection_layer = None 128 | if args.model_3d == 'schnet': 129 | molecule_model_3D = SchNet( 130 | hidden_channels=args.emb_dim, num_filters=args.num_filters, num_interactions=args.num_interactions, 131 | num_gaussians=args.num_gaussians, cutoff=args.cutoff, atomref=None, readout=args.readout).to(device) 132 | else: 133 | raise NotImplementedError('Model {} not included.'.format(args.model_3d)) 134 | 135 | if args.AE_model == 'AE': 136 | AE_2D_3D_model = AutoEncoder( 137 | emb_dim=args.emb_dim, loss=args.AE_loss, detach_target=args.detach_target).to(device) 138 | AE_3D_2D_model = AutoEncoder( 139 | emb_dim=args.emb_dim, loss=args.AE_loss, detach_target=args.detach_target).to(device) 140 | elif args.AE_model == 'VAE': 141 | AE_2D_3D_model = VariationalAutoEncoder( 142 | emb_dim=args.emb_dim, loss=args.AE_loss, detach_target=args.detach_target, beta=args.beta).to(device) 143 | AE_3D_2D_model = VariationalAutoEncoder( 144 | 
emb_dim=args.emb_dim, loss=args.AE_loss, detach_target=args.detach_target, beta=args.beta).to(device) 145 | else: 146 | raise Exception 147 | 148 | model_param_group = [] 149 | model_param_group.append({'params': molecule_model_2D.parameters(), 'lr': args.lr * args.gnn_lr_scale}) 150 | model_param_group.append({'params': molecule_model_3D.parameters(), 'lr': args.lr * args.schnet_lr_scale}) 151 | model_param_group.append({'params': AE_2D_3D_model.parameters(), 'lr': args.lr * args.gnn_lr_scale}) 152 | model_param_group.append({'params': AE_3D_2D_model.parameters(), 'lr': args.lr * args.schnet_lr_scale}) 153 | if molecule_projection_layer is not None: 154 | model_param_group.append({'params': molecule_projection_layer.parameters(), 'lr': args.lr * args.schnet_lr_scale}) 155 | 156 | optimizer = optim.Adam(model_param_group, lr=args.lr, weight_decay=args.decay) 157 | optimal_loss = 1e10 158 | 159 | for epoch in range(1, args.epochs + 1): 160 | print('epoch: {}'.format(epoch)) 161 | train(args, molecule_model_2D, device, loader, optimizer) 162 | 163 | save_model(save_best=False) 164 | -------------------------------------------------------------------------------- /src_regression/pretrain_JOAO.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | from datasets_complete_feature import MoleculeDataset_graphcl_complete 9 | from models_complete_feature import GNNComplete 10 | from torch_geometric.data import DataLoader 11 | from torch_geometric.nn import global_mean_pool 12 | 13 | 14 | class graphcl(nn.Module): 15 | 16 | def __init__(self, gnn): 17 | super(graphcl, self).__init__() 18 | self.gnn = gnn 19 | self.pool = global_mean_pool 20 | self.projection_head = nn.Sequential( 21 | nn.Linear(300, 300), 22 | nn.ReLU(inplace=True), 23 | nn.Linear(300, 300)) 24 | 25 | def forward_cl(self, x, edge_index, 
edge_attr, batch): 26 | x = self.gnn(x, edge_index, edge_attr) 27 | x = self.pool(x, batch) 28 | x = self.projection_head(x) 29 | return x 30 | 31 | def loss_cl(self, x1, x2): 32 | T = 0.1 33 | batch, _ = x1.size() 34 | x1_abs = x1.norm(dim=1) 35 | x2_abs = x2.norm(dim=1) 36 | 37 | sim_matrix = torch.einsum('ik,jk->ij', x1, x2) / \ 38 | torch.einsum('i,j->ij', x1_abs, x2_abs) 39 | sim_matrix = torch.exp(sim_matrix / T) 40 | pos_sim = sim_matrix[range(batch), range(batch)] 41 | loss = pos_sim / (sim_matrix.sum(dim=1) - pos_sim) 42 | loss = - torch.log(loss).mean() 43 | return loss 44 | 45 | 46 | def train(loader, model, optimizer, device, gamma_joao): 47 | 48 | model.train() 49 | train_loss_accum = 0 50 | 51 | for step, (_, batch1, batch2) in enumerate(loader): 52 | # _, batch1, batch2 = batch 53 | batch1 = batch1.to(device) 54 | batch2 = batch2.to(device) 55 | 56 | # pdb.set_trace() 57 | x1 = model.forward_cl(batch1.x, batch1.edge_index, 58 | batch1.edge_attr, batch1.batch) 59 | x2 = model.forward_cl(batch2.x, batch2.edge_index, 60 | batch2.edge_attr, batch2.batch) 61 | loss = model.loss_cl(x1, x2) 62 | 63 | optimizer.zero_grad() 64 | loss.backward() 65 | optimizer.step() 66 | 67 | train_loss_accum += float(loss.detach().cpu().item()) 68 | 69 | # joao 70 | aug_prob = loader.dataset.aug_prob 71 | loss_aug = np.zeros(25) 72 | for n in range(25): 73 | _aug_prob = np.zeros(25) 74 | _aug_prob[n] = 1 75 | loader.dataset.set_augProb(_aug_prob) 76 | # for efficiency, we only use around 10% of data to estimate the loss 77 | count, count_stop = 0, len(loader.dataset) // (loader.batch_size * 10) + 1 78 | 79 | with torch.no_grad(): 80 | for step, (_, batch1, batch2) in enumerate(loader): 81 | # _, batch1, batch2 = batch 82 | batch1 = batch1.to(device) 83 | batch2 = batch2.to(device) 84 | 85 | x1 = model.forward_cl(batch1.x, batch1.edge_index, 86 | batch1.edge_attr, batch1.batch) 87 | x2 = model.forward_cl(batch2.x, batch2.edge_index, 88 | batch2.edge_attr, batch2.batch) 89 | 
loss = model.loss_cl(x1, x2) 90 | loss_aug[n] += loss.item() 91 | count += 1 92 | if count == count_stop: 93 | break 94 | loss_aug[n] /= count 95 | 96 | # view selection, projected gradient descent, 97 | # reference: https://arxiv.org/abs/1906.03563 98 | beta = 1 99 | gamma = gamma_joao 100 | 101 | b = aug_prob + beta * (loss_aug - gamma * (aug_prob - 1 / 25)) 102 | mu_min, mu_max = b.min() - 1 / 25, b.max() - 1 / 25 103 | mu = (mu_min + mu_max) / 2 104 | 105 | # bisection method 106 | while abs(np.maximum(b - mu, 0).sum() - 1) > 1e-2: 107 | if np.maximum(b - mu, 0).sum() > 1: 108 | mu_min = mu 109 | else: 110 | mu_max = mu 111 | mu = (mu_min + mu_max) / 2 112 | 113 | aug_prob = np.maximum(b - mu, 0) 114 | aug_prob /= aug_prob.sum() 115 | 116 | return train_loss_accum / (step + 1), aug_prob 117 | 118 | 119 | if __name__ == "__main__": 120 | # Training settings 121 | parser = argparse.ArgumentParser(description='JOAO') 122 | parser.add_argument('--device', type=int, default=0, help='gpu') 123 | parser.add_argument('--batch_size', type=int, default=256, help='batch') 124 | parser.add_argument('--decay', type=float, default=0, help='weight decay') 125 | parser.add_argument('--epochs', type=int, default=100, help='train epochs') 126 | parser.add_argument('--lr', type=float, default=0.001, help='learning rate') 127 | parser.add_argument('--JK', type=str, default="last", 128 | choices=['last', 'sum', 'max', 'concat'], 129 | help='how the node features across layers are combined.') 130 | parser.add_argument('--gnn_type', type=str, default="gin", help='gnn model type') 131 | parser.add_argument('--dropout_ratio', type=float, default=0, help='dropout ratio') 132 | parser.add_argument('--emb_dim', type=int, default=300, help='embedding dimensions') 133 | parser.add_argument('--dataset', type=str, default=None, help='root dir of dataset') 134 | parser.add_argument('--num_layer', type=int, default=5, help='message passing layers') 135 | 
parser.add_argument('--output_model_file', type=str, default='', help='model save path') 136 | parser.add_argument('--num_workers', type=int, default=8, help='workers for dataset loading') 137 | 138 | parser.add_argument('--aug_mode', type=str, default='sample') 139 | parser.add_argument('--aug_strength', type=float, default=0.2) 140 | 141 | parser.add_argument('--gamma', type=float, default=0.1) 142 | parser.add_argument('--output_model_dir', type=str, default='') 143 | args = parser.parse_args() 144 | 145 | torch.manual_seed(0) 146 | np.random.seed(0) 147 | device = torch.device("cuda:" + str(args.device)) \ 148 | if torch.cuda.is_available() else torch.device("cpu") 149 | if torch.cuda.is_available(): 150 | torch.cuda.manual_seed_all(0) 151 | 152 | if 'GEOM' in args.dataset: 153 | dataset = MoleculeDataset_graphcl_complete('../datasets/{}/'.format(args.dataset), dataset=args.dataset) 154 | dataset.set_augMode(args.aug_mode) 155 | dataset.set_augStrength(args.aug_strength) 156 | print(dataset) 157 | 158 | loader = DataLoader(dataset, batch_size=args.batch_size, 159 | num_workers=args.num_workers, shuffle=True) 160 | 161 | # set up model 162 | gnn = GNNComplete(num_layer=args.num_layer, emb_dim=args.emb_dim, JK=args.JK, 163 | drop_ratio=args.dropout_ratio, gnn_type=args.gnn_type) 164 | 165 | model = graphcl(gnn) 166 | model.to(device) 167 | 168 | # set up optimizer 169 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.decay) 170 | # print(optimizer) 171 | 172 | # pdb.set_trace() 173 | aug_prob = np.ones(25) / 25 174 | np.set_printoptions(precision=3, floatmode='fixed') 175 | 176 | for epoch in range(1, args.epochs + 1): 177 | print('\n\n') 178 | start_time = time.time() 179 | dataset.set_augProb(aug_prob) 180 | pretrain_loss, aug_prob = train(loader, model, optimizer, device, args.gamma) 181 | 182 | print('Epoch: {:3d}\tLoss:{:.3f}\tTime: {:.3f}\tAugmentation Probability:'.format( 183 | epoch, pretrain_loss, time.time() - start_time)) 184 
| print(aug_prob) 185 | 186 | if not args.output_model_dir == '': 187 | saver_dict = {'model': model.state_dict()} 188 | torch.save(saver_dict, args.output_model_dir + '_model_complete.pth') 189 | torch.save(model.gnn.state_dict(), args.output_model_dir + '_model.pth') 190 | -------------------------------------------------------------------------------- /src_regression/run_dti_finetune.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | source $HOME/.bashrc 4 | conda activate GraphMVP 5 | 6 | echo $@ 7 | date 8 | 9 | echo "start" 10 | python dti_finetune.py $@ 11 | echo "end" 12 | date 13 | -------------------------------------------------------------------------------- /src_regression/run_molecule_finetune_regression.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | source $HOME/.bashrc 4 | conda activate GraphMVP 5 | 6 | echo $@ 7 | date 8 | 9 | echo "start" 10 | python molecule_finetune_regression.py $@ 11 | echo "end" 12 | date 13 | -------------------------------------------------------------------------------- /src_regression/run_pretrain_AM.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | source $HOME/.bashrc 4 | conda activate GraphMVP 5 | 6 | echo $@ 7 | date 8 | echo "start" 9 | python pretrain_AM.py $@ 10 | echo "end" 11 | date 12 | -------------------------------------------------------------------------------- /src_regression/run_pretrain_CP.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | source $HOME/.bashrc 4 | conda activate GraphMVP 5 | 6 | echo $@ 7 | date 8 | echo "start" 9 | python pretrain_CP.py $@ 10 | echo "end" 11 | date 12 | -------------------------------------------------------------------------------- /src_regression/run_pretrain_GraphCL.sh: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | source $HOME/.bashrc 4 | conda activate GraphMVP 5 | 6 | echo $@ 7 | date 8 | echo "start" 9 | python pretrain_GraphCL.py $@ 10 | echo "end" 11 | date 12 | -------------------------------------------------------------------------------- /src_regression/run_pretrain_GraphMVP.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | source $HOME/.bashrc 4 | conda activate GraphMVP 5 | 6 | echo $@ 7 | date 8 | echo "start" 9 | python pretrain_GraphMVP.py $@ 10 | echo "end" 11 | date 12 | -------------------------------------------------------------------------------- /src_regression/run_pretrain_GraphMVP_hybrid.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | source $HOME/.bashrc 4 | conda activate GraphMVP 5 | 6 | echo $@ 7 | date 8 | echo "start" 9 | python pretrain_GraphMVP_hybrid.py $@ 10 | echo "end" 11 | date 12 | -------------------------------------------------------------------------------- /src_regression/run_pretrain_JOAO.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | source $HOME/.bashrc 4 | conda activate GraphMVP 5 | 6 | echo $@ 7 | date 8 | 9 | echo "start" 10 | python pretrain_JOAO.py $@ 11 | echo "end" 12 | date 13 | -------------------------------------------------------------------------------- /src_regression/run_pretrain_JOAOv2.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | source $HOME/.bashrc 4 | conda activate GraphMVP 5 | 6 | echo $@ 7 | date 8 | 9 | echo "start" 10 | python pretrain_JOAOv2.py $@ 11 | echo "end" 12 | date 13 | -------------------------------------------------------------------------------- /src_regression/util_complete_feature.py: 
-------------------------------------------------------------------------------- 1 | import random 2 | 3 | import networkx as nx 4 | import torch 5 | import torch.nn.functional as F 6 | from datasets_complete_feature import (graph_data_obj_to_nx_simple, 7 | nx_to_graph_data_obj_simple) 8 | 9 | 10 | class ExtractSubstructureContextPair: 11 | 12 | def __init__(self, k, l1, l2): 13 | """ 14 | Randomly selects a node from the data object, and adds attributes 15 | that contain the substructure that corresponds to k hop neighbours 16 | rooted at the node, and the context substructures that corresponds to 17 | the subgraph that is between l1 and l2 hops away from the root node. """ 18 | self.k = k 19 | self.l1 = l1 20 | self.l2 = l2 21 | 22 | # for the special case of 0, addresses the quirk with 23 | # single_source_shortest_path_length 24 | if self.k == 0: 25 | self.k = -1 26 | if self.l1 == 0: 27 | self.l1 = -1 28 | if self.l2 == 0: 29 | self.l2 = -1 30 | 31 | def __call__(self, data, root_idx=None): 32 | """ 33 | :param data: pytorch geometric data object 34 | :param root_idx: If None, then randomly samples an atom idx. 35 | Otherwise sets atom idx of root (for debugging only) 36 | :return: None. 
Creates new attributes in original data object: 37 | data.center_substruct_idx 38 | data.x_substruct 39 | data.edge_attr_substruct 40 | data.edge_index_substruct 41 | data.x_context 42 | data.edge_attr_context 43 | data.edge_index_context 44 | data.overlap_context_substruct_idx """ 45 | num_atoms = data.x.size()[0] 46 | if root_idx is None: 47 | root_idx = random.sample(range(num_atoms), 1)[0] 48 | 49 | G = graph_data_obj_to_nx_simple(data) # same ordering as input data obj 50 | 51 | # Get k-hop subgraph rooted at specified atom idx 52 | substruct_node_idxes = nx.single_source_shortest_path_length(G, root_idx, self.k).keys() 53 | if len(substruct_node_idxes) > 0: 54 | substruct_G = G.subgraph(substruct_node_idxes) 55 | substruct_G, substruct_node_map = reset_idxes(substruct_G) # need 56 | # to reset node idx to 0 -> num_nodes - 1, otherwise data obj does not 57 | # make sense, since the node indices in data obj must start at 0 58 | substruct_data = nx_to_graph_data_obj_simple(substruct_G) 59 | data.x_substruct = substruct_data.x 60 | data.edge_attr_substruct = substruct_data.edge_attr 61 | data.edge_index_substruct = substruct_data.edge_index 62 | data.center_substruct_idx = torch.tensor([substruct_node_map[root_idx]]) # need 63 | # to convert center idx from original graph node ordering to the 64 | # new substruct node ordering 65 | 66 | # Get subgraphs that is between l1 and l2 hops away from the root node 67 | l1_node_idxes = nx.single_source_shortest_path_length(G, root_idx, self.l1).keys() 68 | l2_node_idxes = nx.single_source_shortest_path_length(G, root_idx, self.l2).keys() 69 | context_node_idxes = set(l1_node_idxes).symmetric_difference(set(l2_node_idxes)) 70 | if len(context_node_idxes) > 0: 71 | context_G = G.subgraph(context_node_idxes) 72 | context_G, context_node_map = reset_idxes(context_G) # need to 73 | # reset node idx to 0 -> num_nodes - 1, otherwise data obj does not 74 | # make sense, since the node indices in data obj must start at 0 75 | 
context_data = nx_to_graph_data_obj_simple(context_G) 76 | data.x_context = context_data.x 77 | data.edge_attr_context = context_data.edge_attr 78 | data.edge_index_context = context_data.edge_index 79 | 80 | # Get indices of overlapping nodes between substruct and context, 81 | # WRT context ordering 82 | context_substruct_overlap_idxes = list(set(context_node_idxes).intersection( 83 | set(substruct_node_idxes))) 84 | if len(context_substruct_overlap_idxes) > 0: 85 | context_substruct_overlap_idxes_reorder = [ 86 | context_node_map[old_idx] 87 | for old_idx in context_substruct_overlap_idxes] 88 | # need to convert the overlap node idxes, which is from the 89 | # original graph node ordering to the new context node ordering 90 | data.overlap_context_substruct_idx = \ 91 | torch.tensor(context_substruct_overlap_idxes_reorder) 92 | 93 | return data 94 | 95 | def __repr__(self): 96 | return '{}(k={},l1={}, l2={})'.format( 97 | self.__class__.__name__, self.k, self.l1, self.l2) 98 | 99 | 100 | def reset_idxes(G): 101 | """ Resets node indices such that they are numbered from 0 to num_nodes - 1 102 | :return: copy of G with relabelled node indices, mapping """ 103 | mapping = {} 104 | for new_idx, old_idx in enumerate(G.nodes()): 105 | mapping[old_idx] = new_idx 106 | new_G = nx.relabel_nodes(G, mapping, copy=True) 107 | return new_G, mapping 108 | 109 | 110 | class MaskAtom: 111 | def __init__(self, num_atom_type, num_edge_type, mask_rate, mask_edge=True): 112 | """ 113 | Randomly masks an atom, and optionally masks edges connecting to it. 
114 | The mask atom type index is num_possible_atom_type 115 | The mask edge type index in num_possible_edge_type 116 | :param num_atom_type: 117 | :param num_edge_type: 118 | :param mask_rate: % of atoms to be masked 119 | :param mask_edge: If True, also mask the edges that connect to the 120 | masked atoms """ 121 | self.num_atom_type = num_atom_type 122 | self.num_edge_type = num_edge_type 123 | self.mask_rate = mask_rate 124 | self.mask_edge = mask_edge 125 | 126 | def __call__(self, data, masked_atom_indices=None): 127 | """ 128 | :param data: pytorch geometric data object. Assume that the edge 129 | ordering is the default pytorch geometric ordering, where the two 130 | directions of a single edge occur in pairs. 131 | Eg. data.edge_index = tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]) 132 | :param masked_atom_indices: If None, then randomly samples num_atoms 133 | * mask rate number of atom indices 134 | Otherwise a list of atom idx that sets the atoms to be masked (for 135 | debugging only) 136 | :return: None, Creates new attributes in original data object: 137 | data.mask_node_idx 138 | data.mask_node_label 139 | data.mask_edge_idx 140 | data.mask_edge_label """ 141 | 142 | if masked_atom_indices is None: 143 | # sample x distinct atoms to be masked, based on mask rate. 
But 144 | # will sample at least 1 atom 145 | num_atoms = data.x.size()[0] 146 | sample_size = int(num_atoms * self.mask_rate + 1) 147 | masked_atom_indices = random.sample(range(num_atoms), sample_size) 148 | 149 | # create mask node label by copying atom feature of mask atom 150 | mask_node_labels_list = [] 151 | for atom_idx in masked_atom_indices: 152 | mask_node_labels_list.append(data.x[atom_idx].view(1, -1)) 153 | data.mask_node_label = torch.cat(mask_node_labels_list, dim=0) 154 | data.masked_atom_indices = torch.tensor(masked_atom_indices) 155 | 156 | # modify the original node feature of the masked node 157 | data.masked_x = data.x.clone() 158 | for atom_idx in masked_atom_indices: 159 | data.masked_x[atom_idx] = torch.tensor([self.num_atom_type-1, 0, 0, 0, 0, 0, 0, 0, 0]) 160 | 161 | return data 162 | 163 | def __repr__(self): 164 | return '{}(num_atom_type={}, num_edge_type={}, mask_rate={}, mask_edge={})'.format( 165 | self.__class__.__name__, self.num_atom_type, self.num_edge_type, 166 | self.mask_rate, self.mask_edge) 167 | --------------------------------------------------------------------------------