├── LICENSE
├── README.md
├── datasets
│   ├── README.md
│   └── dti_datasets
│       ├── davis
│       │   ├── preprocess.out
│       │   └── preprocess.py
│       └── kiba
│           ├── preprocess.out
│           └── preprocess.py
├── fig
│   ├── baselines.png
│   └── pipeline.png
├── legacy
│   └── src_classification
│       ├── pretrain_3D_hybrid_02.py
│       ├── pretrain_3D_hybrid_03.py
│       ├── run_pretrain_3D_hybrid_02.sh
│       └── run_pretrain_3D_hybrid_03.sh
├── scripts_classification
│   ├── run_fine_tuning_model.sh
│   ├── submit_fine_tuning.sh
│   ├── submit_pre_training_GraphMVP.sh
│   ├── submit_pre_training_GraphMVP_hybrid.sh
│   └── submit_pre_training_baselines.sh
├── scripts_regression
│   ├── run_fine_tuning_model.sh
│   ├── run_fine_tuning_model_DTI.sh
│   ├── submit_fine_tuning.sh
│   ├── submit_pre_training_GraphMVP.sh
│   ├── submit_pre_training_GraphMVP_hybrid.sh
│   └── submit_pre_training_baselines.sh
├── src_classification
│   ├── GEOM_dataset_preparation.py
│   ├── batch.py
│   ├── config.py
│   ├── dataloader.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── datasets_GPT.py
│   │   ├── molecule_3D_dataset.py
│   │   ├── molecule_3D_masking_dataset.py
│   │   ├── molecule_contextual_datasets.py
│   │   ├── molecule_contextual_datasets_utils.py
│   │   ├── molecule_datasets.py
│   │   ├── molecule_graphcl_dataset.py
│   │   ├── molecule_graphcl_masking_dataset.py
│   │   └── molecule_motif_datasets.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── auto_encoder.py
│   │   ├── dti_model.py
│   │   ├── molecule_gnn_model.py
│   │   └── schnet.py
│   ├── molecule_finetune.py
│   ├── pretrain_AM.py
│   ├── pretrain_CP.py
│   ├── pretrain_Contextual.py
│   ├── pretrain_EP.py
│   ├── pretrain_GPT_GNN.py
│   ├── pretrain_GraphCL.py
│   ├── pretrain_GraphLoG.py
│   ├── pretrain_GraphMVP.py
│   ├── pretrain_GraphMVP_hybrid.py
│   ├── pretrain_IG.py
│   ├── pretrain_JOAO.py
│   ├── pretrain_JOAOv2.py
│   ├── pretrain_Motif.py
│   ├── run_molecule_finetune.sh
│   ├── run_pretrain_AM.sh
│   ├── run_pretrain_CP.sh
│   ├── run_pretrain_Contextual.sh
│   ├── run_pretrain_EP.sh
│   ├── run_pretrain_GPT_GNN.sh
│   ├── run_pretrain_GraphCL.sh
│   ├── run_pretrain_GraphLoG.sh
│   ├── run_pretrain_GraphMVP.sh
│   ├── run_pretrain_GraphMVP_hybrid.sh
│   ├── run_pretrain_IG.sh
│   ├── run_pretrain_JOAO.sh
│   ├── run_pretrain_JOAOv2.sh
│   ├── run_pretrain_Motif.sh
│   ├── splitters.py
│   └── util.py
└── src_regression
    ├── GEOM_dataset_preparation.py
    ├── datasets_complete_feature
    │   ├── __init__.py
    │   ├── dti_datasets.py
    │   ├── molecule_datasets.py
    │   └── molecule_graphcl_dataset.py
    ├── dti_finetune.py
    ├── models_complete_feature
    │   ├── __init__.py
    │   └── molecule_gnn_model.py
    ├── molecule_finetune_regression.py
    ├── pretrain_AM.py
    ├── pretrain_CP.py
    ├── pretrain_GraphCL.py
    ├── pretrain_GraphMVP.py
    ├── pretrain_GraphMVP_hybrid.py
    ├── pretrain_JOAO.py
    ├── pretrain_JOAOv2.py
    ├── run_dti_finetune.sh
    ├── run_molecule_finetune_regression.sh
    ├── run_pretrain_AM.sh
    ├── run_pretrain_CP.sh
    ├── run_pretrain_GraphCL.sh
    ├── run_pretrain_GraphMVP.sh
    ├── run_pretrain_GraphMVP_hybrid.sh
    ├── run_pretrain_JOAO.sh
    ├── run_pretrain_JOAOv2.sh
    └── util_complete_feature.py
/LICENSE:
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Shengchao Liu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Pre-training Molecular Graph Representation with 3D Geometry
2 |
3 | **ICLR 2022**
4 |
5 | Authors: Shengchao Liu, Hanchen Wang, Weiyang Liu, Joan Lasenby, Hongyu Guo, Jian Tang
6 |
7 | [[Project Page](https://chao1224.github.io/GraphMVP)]
8 | [[Paper](https://openreview.net/forum?id=xQUe1pOKPam)]
9 | [[ArXiv](https://arxiv.org/abs/2110.07728)]
10 | [[Slides](https://drive.google.com/file/d/1-lDWtdgeEgTO009YVPzHK8f7yYbvQ1oY/view?usp=sharing)]
11 | [[Poster](https://drive.google.com/file/d/1L_XrlgfmCmycfGf47Dt6nnaKpZtiqiN-/view?usp=sharing)]
12 |
13 | [[NeurIPS SSL Workshop 2021](https://sslneurips21.github.io/)]
14 | [[ICLR GTRL Workshop 2022 (Spotlight)](https://gt-rl.github.io/)]
15 |
16 | This repository provides the source code for the ICLR'22 paper **Pre-training Molecular Graph Representation with 3D Geometry**, under the following setup:
17 | - During pre-training, we consider both the 2D topology and the 3D geometry.
18 | - During downstream fine-tuning, we consider tasks with the 2D topology only.
19 |
20 | In the future, we will merge it into the [TorchDrug](https://github.com/DeepGraphLearning/torchdrug) package.
21 |
22 |
23 |
24 |
25 |
26 | ## Baselines
27 | This repository also provides implementations of the following graph SSL baselines:
28 | - Generative Graph SSL:
29 | - [Edge Prediction (EdgePred)](https://proceedings.neurips.cc/paper/2017/file/5dd9db5e033da9c6fb5ba83c7a7ebea9-Paper.pdf)
30 | - [AttributeMasking (AttrMask)](https://openreview.net/forum?id=HJlWWJSFDH)
31 | - [GPT-GNN](https://arxiv.org/abs/2006.15437)
32 | - Contrastive Graph SSL:
33 | - [InfoGraph](https://openreview.net/pdf?id=r1lfF2NYvH)
34 | - [Context Prediction (ContextPred)](https://openreview.net/forum?id=HJlWWJSFDH)
35 | - [GraphLoG](http://proceedings.mlr.press/v139/xu21g/xu21g.pdf)
36 | - [Grover-Contextual](https://papers.nips.cc/paper/2020/hash/94aef38441efa3380a3bed3faf1f9d5d-Abstract.html)
37 | - [GraphCL](https://papers.nips.cc/paper/2020/file/3fe230348e9a12c13120749e3f9fa4cd-Paper.pdf)
38 | - [JOAO](https://arxiv.org/abs/2106.07594)
39 | - Predictive Graph SSL:
40 | - [Grover-Motif](https://papers.nips.cc/paper/2020/hash/94aef38441efa3380a3bed3faf1f9d5d-Abstract.html)
41 |
42 |
43 |
44 |
45 |
46 | ## Environments
47 | Install the required packages in a conda environment:
48 | ```bash
49 | conda create -n GraphMVP python=3.7
50 | conda activate GraphMVP
51 |
52 | conda install -y -c rdkit rdkit
53 | conda install -y -c pytorch pytorch=1.9.1
54 | conda install -y numpy networkx scikit-learn
55 | pip install ase
56 | pip install git+https://github.com/bp-kelley/descriptastorus
57 | pip install ogb
58 | export TORCH=1.9.0
59 | export CUDA=cu102 # cu102, cu110
60 |
61 | wget https://data.pyg.org/whl/torch-${TORCH}%2B${CUDA}/torch_cluster-1.5.9-cp37-cp37m-linux_x86_64.whl
62 | pip install torch_cluster-1.5.9-cp37-cp37m-linux_x86_64.whl
63 | wget https://data.pyg.org/whl/torch-${TORCH}%2B${CUDA}/torch_scatter-2.0.9-cp37-cp37m-linux_x86_64.whl
64 | pip install torch_scatter-2.0.9-cp37-cp37m-linux_x86_64.whl
65 | wget https://data.pyg.org/whl/torch-${TORCH}%2B${CUDA}/torch_sparse-0.6.12-cp37-cp37m-linux_x86_64.whl
66 | pip install torch_sparse-0.6.12-cp37-cp37m-linux_x86_64.whl
67 | pip install torch-geometric==1.7.2
68 | ```
69 |
70 | ## Dataset Preprocessing
71 |
72 | For dataset download, please follow the instruction [here](https://github.com/chao1224/GraphMVP/tree/main/datasets).
73 |
74 | For data preprocessing (GEOM), please use the following commands:
75 | ```bash
76 | cd src_classification
77 | python GEOM_dataset_preparation.py --n_mol 50000 --n_conf 5 --n_upper 1000 --data_folder $SLURM_TMPDIR
78 | cd ..
79 |
80 | cd src_regression
81 | python GEOM_dataset_preparation.py --n_mol 50000 --n_conf 5 --n_upper 1000 --data_folder $SLURM_TMPDIR
82 | cd ..
83 |
84 | mv $SLURM_TMPDIR/GEOM datasets
85 | ```
86 |
87 | **Featurization**. We employ two sets of featurization methods on atoms.
88 | 1. For classification tasks, to stay consistent with the main molecular graph SSL research line, we use the same minimal atom featurization (atom type and chirality).
89 | 2. For regression tasks, these two atom-level features alone give poor results, so we adopt the more comprehensive atom features from OGB.
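As an illustration only (hypothetical feature lists, not the exact vocabularies used in this codebase), the two featurization schemes differ roughly as follows:

```python
# Hypothetical sketch of the two atom-featurization schemes. The real
# vocabularies live in src_classification/datasets/molecule_datasets.py
# and in OGB's feature utilities; they are not reproduced here.

def minimal_atom_features(atomic_num, chiral_tag):
    """Classification setting: atom type + chirality only."""
    return [atomic_num, chiral_tag]

def ogb_style_atom_features(atomic_num, chiral_tag, degree, formal_charge,
                            num_h, num_radical_e, hybridization,
                            is_aromatic, in_ring):
    """Regression setting: a richer, OGB-style 9-dimensional atom feature."""
    return [atomic_num, chiral_tag, degree, formal_charge, num_h,
            num_radical_e, hybridization, int(is_aromatic), int(in_ring)]

# An aromatic-free ring carbon with no chirality, for example:
assert minimal_atom_features(6, 0) == [6, 0]
assert len(ogb_style_atom_features(6, 0, 4, 0, 0, 0, 3, False, True)) == 9
```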
90 |
91 | ## Experiments
92 |
93 | ### Terminology specification
94 |
95 | In the latest scripts, `GraphMVP` refers to the basic GraphMVP (Eq. 7 in the paper), while `GraphMVP_hybrid` covers the two variants that add extra 2D SSL pretext tasks (Eq. 8 in the paper).
96 | In earlier scripts, these two were called `3D_hybrid_02_masking` and `3D_hybrid_03_masking`, respectively.
97 | These names may still appear in some pre-trained log files [here](https://drive.google.com/drive/folders/1uPsBiQF3bfeCAXSDd4JfyXiTh-qxYfu6?usp=sharing).
98 |
99 | | GraphMVP | Latest scripts | Previous scripts |
100 | | :--: | :--: | :--: |
101 | | Eq. 7 | `GraphMVP` | `3D_hybrid_02_masking` |
102 | | Eq. 8 | `GraphMVP_hybrid` | `3D_hybrid_03_masking` |
103 |
104 | ### For GraphMVP pre-training
105 |
106 | Check the following scripts:
107 | - `scripts_classification/submit_pre_training_GraphMVP.sh`
108 | - `scripts_classification/submit_pre_training_GraphMVP_hybrid.sh`
109 | - `scripts_regression/submit_pre_training_GraphMVP.sh`
110 | - `scripts_regression/submit_pre_training_GraphMVP_hybrid.sh`
111 |
112 | The pre-trained model weights, training logs, and prediction files can be found [here](https://drive.google.com/drive/folders/1uPsBiQF3bfeCAXSDd4JfyXiTh-qxYfu6?usp=sharing).
113 |
114 | ### For Other SSL pre-training baselines
115 |
116 | Check the following scripts:
117 | - `scripts_classification/submit_pre_training_baselines.sh`
118 | - `scripts_regression/submit_pre_training_baselines.sh`
119 |
120 | ### For Downstream tasks
121 |
122 | Check the following scripts:
123 | - `scripts_classification/submit_fine_tuning.sh`
124 | - `scripts_regression/submit_fine_tuning.sh`
125 |
126 | ## Cite Us
127 |
128 | Feel free to cite this work if you find it useful!
129 |
130 | ```
131 | @inproceedings{liu2022pretraining,
132 | title={Pre-training Molecular Graph Representation with 3D Geometry},
133 | author={Shengchao Liu and Hanchen Wang and Weiyang Liu and Joan Lasenby and Hongyu Guo and Jian Tang},
134 | booktitle={International Conference on Learning Representations},
135 | year={2022},
136 | url={https://openreview.net/forum?id=xQUe1pOKPam}
137 | }
138 | ```
139 |
--------------------------------------------------------------------------------
/datasets/README.md:
--------------------------------------------------------------------------------
1 | # Datasets
2 |
3 | ## Geometric Ensemble Of Molecules (GEOM)
4 |
5 | ```bash
6 | mkdir -p GEOM/raw
7 | mkdir -p GEOM/processed
8 | ```
9 |
10 | + GEOM: [Paper](https://arxiv.org/pdf/2006.05531v3.pdf), [GitHub](https://github.com/learningmatter-mit/geom)
11 | + Data Download:
12 | + [Not Used] [Drug Crude](https://dataverse.harvard.edu/api/access/datafile/4360331),
13 | [Drug Featurized](https://dataverse.harvard.edu/api/access/datafile/4327295),
14 | [QM9 Crude](https://dataverse.harvard.edu/api/access/datafile/4327190),
15 | [QM9 Featurized](https://dataverse.harvard.edu/api/access/datafile/4327191)
16 |
17 | + [Mainly Used] [RdKit Folder](https://dataverse.harvard.edu/api/access/datafile/4327252)
18 | ```bash
19 | wget https://dataverse.harvard.edu/api/access/datafile/4327252
20 | mv 4327252 rdkit_folder.tar.gz
21 | tar -xvf rdkit_folder.tar.gz
22 | ```
23 | or, if you are using a Slurm system:
24 | ```bash
25 | cp rdkit_folder.tar.gz $SLURM_TMPDIR
26 | cd $SLURM_TMPDIR
27 | tar -xvf rdkit_folder.tar.gz
28 | ```
29 | + over 33m conformations
30 | + over 430k molecules
31 | + 304,466 species contain experimental data for the inhibition of various pathogens
32 | + 133,258 are species from the QM9
33 |
34 | ## Chem Dataset
35 |
36 | ```bash
37 | wget http://snap.stanford.edu/gnn-pretrain/data/chem_dataset.zip
38 | unzip chem_dataset.zip
39 | mv dataset molecule_datasets
40 | ```
41 |
42 | ## Other Chem Datasets
43 |
44 | - delaney/esol (already included)
45 | - lipophilicity (already included)
46 | - malaria
47 | - cep
48 |
49 | ```bash
50 | wget -O malaria-processed.csv https://raw.githubusercontent.com/HIPS/neural-fingerprint/master/data/2015-06-03-malaria/malaria-processed.csv
51 | mkdir -p ./molecule_datasets/malaria/raw
52 | mv malaria-processed.csv ./molecule_datasets/malaria/raw/malaria.csv
53 |
54 | wget -O cep-processed.csv https://raw.githubusercontent.com/HIPS/neural-fingerprint/master/data/2015-06-02-cep-pce/cep-processed.csv
55 | mkdir -p ./molecule_datasets/cep/raw
56 | mv cep-processed.csv ./molecule_datasets/cep/raw/cep.csv
57 | ```
58 |
59 | Then we copy them for the regression tasks (which use more atom features).
60 | ```bash
61 | mkdir -p ./molecule_datasets_regression/esol
62 | cp -r ./molecule_datasets/esol/raw ./molecule_datasets_regression/esol/
63 |
64 | mkdir -p ./molecule_datasets_regression/lipophilicity
65 | cp -r ./molecule_datasets/lipophilicity/raw ./molecule_datasets_regression/lipophilicity/
66 |
67 | mkdir -p ./molecule_datasets_regression/malaria
68 | cp -r ./molecule_datasets/malaria/raw ./molecule_datasets_regression/malaria/
69 |
70 | mkdir -p ./molecule_datasets_regression/cep
71 | cp -r ./molecule_datasets/cep/raw ./molecule_datasets_regression/cep/
72 | ```
73 |
74 | ## Drug-Target Interaction
75 |
76 | - Davis
77 | - Kiba
78 |
79 | ```bash
80 | mkdir -p dti_datasets
81 | cd dti_datasets
82 | ```
83 |
84 | Then we can follow [DeepDTA](https://github.com/hkmztrk/DeepDTA).
85 | ```bash
86 | git clone git@github.com:hkmztrk/DeepDTA.git
87 | cp -r DeepDTA/data/davis davis/
88 | cp -r DeepDTA/data/kiba kiba/
89 |
90 | cd davis
91 | python preprocess.py > preprocess.out
92 | cd ..
93 |
94 | cd kiba
95 | python preprocess.py > preprocess.out
96 | cd ..
97 | ```
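For reference, `davis/preprocess.py` converts the raw Davis affinities (Kd in nM) to pKd via `-log10(Kd / 1e9)`; kiba scores are used as-is. A quick sanity check of that transform:

```python
import math

# The Davis conversion used in davis/preprocess.py: Kd in nM -> pKd.
def kd_nm_to_pkd(kd_nm):
    return -math.log10(kd_nm / 1e9)

# 10,000 nM (i.e., 10 uM) maps to pKd = 5.0
assert abs(kd_nm_to_pkd(10000.0) - 5.0) < 1e-12
# 1 nM maps to pKd = 9.0
assert abs(kd_nm_to_pkd(1.0) - 9.0) < 1e-12
```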
98 |
--------------------------------------------------------------------------------
/datasets/dti_datasets/davis/preprocess.out:
--------------------------------------------------------------------------------
1 | convert data from DeepDTA for davis
2 | 68 drugs 68 unique drugs
3 |
4 | dataset: davis
5 | train_fold: 25046
6 | test_fold: 5010
7 | len(set(drugs)),len(set(prots)): 68 379
8 |
--------------------------------------------------------------------------------
/datasets/dti_datasets/davis/preprocess.py:
--------------------------------------------------------------------------------
1 | import json
2 | from collections import OrderedDict
3 | import pickle
4 | import numpy as np
5 | from rdkit import Chem
6 | from rdkit.Chem import MolFromSmiles
7 |
8 |
9 |
10 | if __name__ == '__main__':
11 | # from DeepDTA data
12 | dataset = 'davis'
13 |
14 | print('convert data from DeepDTA for ', dataset)
15 | fpath = dataset + '/'
16 | train_fold = json.load(open(fpath + "folds/train_fold_setting1.txt"))
17 | train_fold = [ee for e in train_fold for ee in e]
18 | valid_fold = json.load(open(fpath + "folds/test_fold_setting1.txt"))
19 | ligands = json.load(open(fpath + "ligands_can.txt"), object_pairs_hook=OrderedDict)
20 | proteins = json.load(open(fpath + "proteins.txt"), object_pairs_hook=OrderedDict)
21 | affinity = pickle.load(open(fpath + "Y", "rb"), encoding='latin1')
22 | drugs = []
23 | prots = []
24 |
25 | for d in ligands.keys():
26 | lg = Chem.MolToSmiles(Chem.MolFromSmiles(ligands[d]), isomericSmiles=True)
27 | drugs.append(lg)
28 | for t in proteins.keys():
29 | prots.append(proteins[t])
30 | if dataset == 'davis':
31 | affinity = [-np.log10(y / 1e9) for y in affinity]
32 | affinity = np.asarray(affinity)
33 | opts = ['train', 'test']
34 |
35 | print('{} drugs\t{} unique drugs'.format(len(drugs), len(set(drugs))))
36 |
37 | with open('smiles.csv', 'w') as f:
38 | print('smiles', file=f)
39 | for drug in drugs:
40 | print(drug, file=f)
41 | with open('protein.csv', 'w') as f:
42 | print('protein', file=f)
43 | for prot in prots:
44 | print(prot, file=f)
45 |
46 | for opt in opts:
47 | rows, cols = np.where(np.isnan(affinity) == False)
48 | if opt == 'train':
49 | rows, cols = rows[train_fold], cols[train_fold]
50 | elif opt == 'test':
51 | rows, cols = rows[valid_fold], cols[valid_fold]
52 | # with open(opt + '.csv', 'w') as f:
53 | # f.write('compound_iso_smiles,target_sequence,affinity\n')
54 | # for pair_ind in range(len(rows)):
55 | # ls = []
56 | # ls += [drugs[rows[pair_ind]]]
57 | # ls += [prots[cols[pair_ind]]]
58 | # ls += [affinity[rows[pair_ind], cols[pair_ind]]]
59 | # f.write(','.join(map(str, ls)) + '\n')
60 | with open(opt + '.csv', 'w') as f:
61 | f.write('smiles_id,target_id,affinity\n')
62 | for pair_ind in range(len(rows)):
63 | ls = []
64 | ls += [rows[pair_ind]]
65 | ls += [cols[pair_ind]]
66 | ls += [affinity[rows[pair_ind], cols[pair_ind]]]
67 | f.write(','.join(map(str, ls)) + '\n')
68 |
69 | print('\ndataset:', dataset)
70 | print('train_fold:', len(train_fold))
71 | print('test_fold:', len(valid_fold))
72 | print('len(set(drugs)),len(set(prots)):', len(set(drugs)), len(set(prots)))
73 |
--------------------------------------------------------------------------------
/datasets/dti_datasets/kiba/preprocess.out:
--------------------------------------------------------------------------------
1 | convert data from DeepDTA for kiba
2 | 2111 drugs 2068 unique drugs
3 |
4 | dataset: kiba
5 | train_fold: 98545
6 | test_fold: 19709
7 | len(set(drugs)),len(set(prots)): 2068 229
8 |
--------------------------------------------------------------------------------
/datasets/dti_datasets/kiba/preprocess.py:
--------------------------------------------------------------------------------
1 | import json
2 | from collections import OrderedDict
3 | import pickle
4 | import numpy as np
5 | from rdkit import Chem
6 | from rdkit.Chem import MolFromSmiles
7 |
8 |
9 |
10 | if __name__ == '__main__':
11 | # from DeepDTA data
12 | dataset = 'kiba'
13 |
14 | print('convert data from DeepDTA for ', dataset)
15 | fpath = dataset + '/'
16 | train_fold = json.load(open(fpath + "folds/train_fold_setting1.txt"))
17 | train_fold = [ee for e in train_fold for ee in e]
18 | valid_fold = json.load(open(fpath + "folds/test_fold_setting1.txt"))
19 | ligands = json.load(open(fpath + "ligands_can.txt"), object_pairs_hook=OrderedDict)
20 | proteins = json.load(open(fpath + "proteins.txt"), object_pairs_hook=OrderedDict)
21 | affinity = pickle.load(open(fpath + "Y", "rb"), encoding='latin1')
22 | drugs = []
23 | prots = []
24 |
25 | for d in ligands.keys():
26 | lg = Chem.MolToSmiles(Chem.MolFromSmiles(ligands[d]), isomericSmiles=True)
27 | drugs.append(lg)
28 | for t in proteins.keys():
29 | prots.append(proteins[t])
30 | if dataset == 'davis':
31 | affinity = [-np.log10(y / 1e9) for y in affinity]
32 | affinity = np.asarray(affinity)
33 | opts = ['train', 'test']
34 |
35 | print('{} drugs\t{} unique drugs'.format(len(drugs), len(set(drugs))))
36 |
37 | with open('smiles.csv', 'w') as f:
38 | print('smiles', file=f)
39 | for drug in drugs:
40 | print(drug, file=f)
41 | with open('protein.csv', 'w') as f:
42 | print('protein', file=f)
43 | for prot in prots:
44 | print(prot, file=f)
45 |
46 | for opt in opts:
47 | rows, cols = np.where(np.isnan(affinity) == False)
48 | if opt == 'train':
49 | rows, cols = rows[train_fold], cols[train_fold]
50 | elif opt == 'test':
51 | rows, cols = rows[valid_fold], cols[valid_fold]
52 | # with open(opt + '.csv', 'w') as f:
53 | # f.write('compound_iso_smiles,target_sequence,affinity\n')
54 | # for pair_ind in range(len(rows)):
55 | # ls = []
56 | # ls += [drugs[rows[pair_ind]]]
57 | # ls += [prots[cols[pair_ind]]]
58 | # ls += [affinity[rows[pair_ind], cols[pair_ind]]]
59 | # f.write(','.join(map(str, ls)) + '\n')
60 | with open(opt + '.csv', 'w') as f:
61 | f.write('smiles_id,target_id,affinity\n')
62 | for pair_ind in range(len(rows)):
63 | ls = []
64 | ls += [rows[pair_ind]]
65 | ls += [cols[pair_ind]]
66 | ls += [affinity[rows[pair_ind], cols[pair_ind]]]
67 | f.write(','.join(map(str, ls)) + '\n')
68 |
69 | print('\ndataset:', dataset)
70 | print('train_fold:', len(train_fold))
71 | print('test_fold:', len(valid_fold))
72 | print('len(set(drugs)),len(set(prots)):', len(set(drugs)), len(set(prots)))
73 |
--------------------------------------------------------------------------------
/fig/baselines.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chao1224/GraphMVP/80c42aff21ece462c951cfda110f8eaf866aea0c/fig/baselines.png
--------------------------------------------------------------------------------
/fig/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chao1224/GraphMVP/80c42aff21ece462c951cfda110f8eaf866aea0c/fig/pipeline.png
--------------------------------------------------------------------------------
/legacy/src_classification/run_pretrain_3D_hybrid_02.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | source $HOME/.bashrc
4 | conda activate PC3D
5 | conda deactivate
6 | conda activate PC3D
7 |
8 | #cp -r ../datasets/GEOM_* $SLURM_TMPDIR
9 | #
10 | #echo $@
11 | #date
12 | #echo "start"
13 | #python pretrain_3D_hybrid_02.py --input_data_dir="$SLURM_TMPDIR" $@
14 | #echo "end"
15 | #date
16 |
17 |
18 | echo $@
19 | date
20 | echo "start"
21 | python pretrain_3D_hybrid_02.py $@
22 | echo "end"
23 | date
24 |
--------------------------------------------------------------------------------
/legacy/src_classification/run_pretrain_3D_hybrid_03.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | source $HOME/.bashrc
4 | conda activate PC3D
5 | conda deactivate
6 | conda activate PC3D
7 |
8 | echo $@
9 | date
10 | echo "start"
11 | python pretrain_3D_hybrid_03.py $@
12 | echo "end"
13 | date
14 |
--------------------------------------------------------------------------------
/scripts_classification/run_fine_tuning_model.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH --account=rrg-bengioy-ad
4 | #SBATCH --cpus-per-task=8
5 | #SBATCH --gres=gpu:v100l:1
6 | #SBATCH --mem=32G
7 | #SBATCH --time=2:59:00
8 | #SBATCH --ntasks=1
9 | #SBATCH --array=0-2%3
10 | #SBATCH --output=logs/%j.out
11 |
12 |
13 | ###############SBATCH --gres=gpu:v100l:1
14 |
15 | cd src
16 |
17 | echo "My SLURM_ARRAY_TASK_ID: " $SLURM_ARRAY_TASK_ID
18 | export dataset_list=(tox21 toxcast clintox bbbp sider muv hiv bace)
19 | export seed_list=(0 1 2 3 4 5 6 7 8 9)
20 | export batch_size=256
21 | export mode=$1
22 | export seed=${seed_list[$SLURM_ARRAY_TASK_ID]}
23 |
24 |
25 |
26 |
27 |
28 | if [ "$mode" == "random" ]; then
29 |
30 | for dataset in "${dataset_list[@]}"; do
31 | export folder="$mode"/"$seed"
32 | mkdir -p ../output/"$folder"
33 | mkdir -p ../output/"$folder"/"$dataset"
34 |
35 | export output_path=../output/"$folder"/"$dataset".out
36 | export output_model_dir=../output/"$folder"/"$dataset"
37 |
38 | echo "$SLURM_JOB_ID"_"$SLURM_ARRAY_TASK_ID" > "$output_path"
39 | echo `date` >> "$output_path"
40 |
41 | bash ./run_molecule_finetune.sh \
42 | --dataset="$dataset" --runseed="$seed" --eval_train --batch_size="$batch_size" \
43 | --dropout_ratio=0.5 \
44 | --output_model_dir="$output_model_dir" \
45 | >> "$output_path"
46 |
47 | echo `date` >> "$output_path"
48 | done
49 |
50 |
51 |
52 |
53 | else
54 |
55 | for dataset in "${dataset_list[@]}"; do
56 | export folder="$mode"/"$seed"
57 | mkdir -p ../output/"$folder"
58 | mkdir -p ../output/"$folder"/"$dataset"
59 |
60 | export output_path=../output/"$folder"/"$dataset".out
61 | # export output_model_dir=../output/"$folder"/"$dataset"
62 | export input_model_file=../output/"$mode"/pretraining_model.pth
63 |
64 | echo "$SLURM_JOB_ID"_"$SLURM_ARRAY_TASK_ID" > "$output_path"
65 | echo `date` >> "$output_path"
66 |
67 | bash ./run_molecule_finetune.sh \
68 | --dataset="$dataset" --runseed="$seed" --eval_train --batch_size="$batch_size" \
69 | --dropout_ratio=0.5 \
70 | --input_model_file="$input_model_file" \
71 | >> "$output_path"
72 | # --input_model_file="$input_model_file" --output_model_dir="$output_model_dir" \
73 | # >> "$output_path"
74 |
75 | echo `date` >> "$output_path"
76 | done
77 |
78 | fi
79 |
80 |
--------------------------------------------------------------------------------
/scripts_classification/submit_fine_tuning.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd ../src_classification
4 |
5 | mode_list=(
6 | random
7 | EP/GEOM_2D_nmol50000_nconf1_nupper1000/epochs_100_0
8 | AM/GEOM_2D_nmol50000_nconf1_nupper1000/epochs_100_0
9 | IG/GEOM_2D_nmol50000_nconf1_nupper1000/epochs_100_0
10 | CP/GEOM_2D_nmol50000_nconf1_nupper1000/epochs_100_0
11 | GraphLoG/GEOM_2D_nmol50000_nconf1_nupper1000/epochs_100_0
12 | Motif/GEOM_2D_nmol50000_nconf1_nupper1000/epochs_100_0
13 | Contextual/GEOM_2D_nmol50000_nconf1_nupper1000/epochs_100_0
14 | GraphCL/GEOM_2D_nmol50000_nconf1_nupper1000/epochs_100_0
15 | JOAO/GEOM_2D_nmol50000_nconf1_nupper1000/epochs_100_0
16 | JOAOv2/GEOM_2D_nmol50000_nconf1_nupper1000/epochs_100_0
17 |
18 | GraphMVP/GEOM_3D_nmol50000_nconf5_nupper1000/CL_1_VAE_1/6_51_10_0.1/0.15_EBM_dot_prod_0.2_normalize_l2_detach_target_2_100_0
19 | GraphMVP_hybrid/GEOM_3D_nmol50000_nconf5_nupper1000/CL_1_VAE_1_AM_1/6_51_10_0.1/0.15_EBM_dot_prod_0.05_normalize_l2_detach_target_2_100_0
20 | GraphMVP_hybrid/GEOM_3D_nmol50000_nconf5_nupper1000/CL_1_VAE_1_CP_0.1/6_51_10_0.1/0.15_EBM_dot_prod_0.2_normalize_l2_detach_target_2_100_0
21 | )
22 |
23 | for mode in "${mode_list[@]}"; do
24 | echo "$mode"
25 | ls output/"$mode"
26 |
27 | sbatch run_fine_tuning_model.sh "$mode"
28 |
29 | echo
30 |
31 | done
32 |
33 |
34 |
35 | # Below is for ablation study
36 | mode_list=(
37 | GraphMVP/GEOM_3D_nmol50000_nconf5_nupper1000/CL_1_VAE_1/6_51_10_0.1/0.15_EBM_dot_prod_0.2_normalize_l2_detach_target_2_100_0
38 | GraphMVP/GEOM_3D_nmol50000_nconf5_nupper1000/CL_1_VAE_1/6_51_10_0.1/0.15_InfoNCE_dot_prod_0.2_normalize_l2_detach_target_2_100_0
39 | GraphMVP/GEOM_3D_nmol50000_nconf5_nupper1000/CL_1_VAE_0/6_51_10_0.1/0.15_InfoNCE_dot_prod_0.2_normalize_l2_detach_target_2_100_0
40 | GraphMVP/GEOM_3D_nmol50000_nconf5_nupper1000/CL_1_VAE_0/6_51_10_0.1/0.15_EBM_dot_prod_0.2_normalize_l2_detach_target_2_100_0
41 | GraphMVP/GEOM_3D_nmol50000_nconf5_nupper1000/CL_0_VAE_1/6_51_10_0.1/0.15_EBM_dot_prod_0.2_normalize_l2_detach_target_2_100_0
42 | GraphMVP/GEOM_3D_nmol50000_nconf5_nupper1000/CL_1_AE_1/6_51_10_0.1/0.15_InfoNCE_dot_prod_0.2_normalize_l2_detach_target_2_100_0
43 | GraphMVP/GEOM_3D_nmol50000_nconf5_nupper1000/CL_1_AE_1/6_51_10_0.1/0.15_EBM_dot_prod_0.2_normalize_l2_detach_target_2_100_0
44 | GraphMVP/GEOM_3D_nmol50000_nconf5_nupper1000/CL_0_AE_1/6_51_10_0.1/0.15_EBM_dot_prod_0.2_normalize_l2_detach_target_2_100_0
45 | )
46 |
--------------------------------------------------------------------------------
/scripts_classification/submit_pre_training_GraphMVP.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | cd ../src_classification
4 |
5 |
6 | export mode=GraphMVP
7 | export dataset_list=(GEOM_3D_nmol50000_nconf5_nupper1000)
8 | export epochs=100
9 | export time=9
10 |
11 |
12 | # For SchNet and GNN
13 | export schnet_lr_scale_list=(0.1)
14 | export num_interactions=6
15 | export num_gaussians=51
16 | export cutoff=10
17 | export dropout_ratio_list=(0)
18 | export SSL_masking_ratio_list=(0.15 0.3)
19 |
20 |
21 |
22 | # For CL
23 | # export CL_similarity_metric_list=(InfoNCE_dot_prod EBM_dot_prod)
24 | export CL_similarity_metric_list=(EBM_dot_prod)
25 | export T_list=(0.1 0.2 0.5 1 2)
26 | export normalize_list=(normalize)
27 |
28 |
29 |
30 | # For VAE
31 | export AE_model=VAE
32 | # export AE_loss_list=(l1 l2 cosine)
33 | export AE_loss_list=(l2)
34 | # export detach_list=(detach_target no_detach_target)
35 | export detach_list=(detach_target)
36 | # export beta_list=(0.1 1 2)
37 | export beta_list=(1 2)
38 |
39 |
40 |
41 |
42 | # For CL + VAE
43 | export alpha_1_list=(1)
44 | export alpha_2_list=(0.1 1)
45 |
46 |
47 |
48 | export SSL_masking_ratio_list=(0)
49 | export CL_similarity_metric_list=(EBM_dot_prod)
50 | export T_list=(0.1 0.2)
51 |
52 |
53 |
54 |
55 | for dataset in "${dataset_list[@]}"; do
56 | for SSL_masking_ratio in "${SSL_masking_ratio_list[@]}"; do
57 |
58 | for alpha_1 in "${alpha_1_list[@]}"; do
59 | for alpha_2 in "${alpha_2_list[@]}"; do
60 | for CL_similarity_metric in "${CL_similarity_metric_list[@]}"; do
61 | for normalize in "${normalize_list[@]}"; do
62 | for T in "${T_list[@]}"; do
63 | for AE_loss in "${AE_loss_list[@]}"; do
64 | for detach in "${detach_list[@]}"; do
65 | for beta in "${beta_list[@]}"; do
66 |
67 |
68 | for schnet_lr_scale in "${schnet_lr_scale_list[@]}"; do
69 | for dropout_ratio in "${dropout_ratio_list[@]}"; do
70 | export folder="$mode"/"$dataset"/CL_"$alpha_1"_"$AE_model"_"$alpha_2"/"$num_interactions"_"$num_gaussians"_"$cutoff"_"$schnet_lr_scale"/"$SSL_masking_ratio"_"$CL_similarity_metric"_"$T"_"$normalize"_"$AE_loss"_"$detach"_"$beta"_"$epochs"_"$dropout_ratio"
71 |
72 | echo "$folder"
73 | mkdir -p ../output/"$folder"
74 | ls ../output/"$folder"
75 |
76 | export output_file=../output/"$folder"/pretraining.out
77 | export output_model_dir=../output/"$folder"/pretraining
78 |
79 |
80 | echo "$output_file" undone
81 |
82 | sbatch --gres=gpu:v100l:1 -c 8 --mem=32G -t "$time":00:00 --account=rrg-bengioy-ad --qos=high --job-name=CL_VAE_"$time" \
83 | --output="$output_file" \
84 | ./run_pretrain_"$mode".sh \
85 | --epochs="$epochs" \
86 | --dataset="$dataset" \
87 | --batch_size=256 \
88 | --SSL_masking_ratio="$SSL_masking_ratio" \
89 | --CL_similarity_metric="$CL_similarity_metric" --T="$T" --"$normalize" \
90 | --AE_model="$AE_model" --AE_loss="$AE_loss" --"$detach" --beta="$beta" \
91 | --alpha_1="$alpha_1" --alpha_2="$alpha_2" \
92 | --num_interactions="$num_interactions" --num_gaussians="$num_gaussians" --cutoff="$cutoff" --schnet_lr_scale="$schnet_lr_scale" \
93 | --dropout_ratio="$dropout_ratio" --num_workers=8 \
94 | --output_model_dir="$output_model_dir"
95 |
96 | done
97 | done
98 |
99 | done
100 | done
101 | done
102 | done
103 | done
104 | done
105 | done
106 | done
107 | done
108 | done
109 |
--------------------------------------------------------------------------------
/scripts_classification/submit_pre_training_GraphMVP_hybrid.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | cd ../src_classification
4 |
5 |
6 | export mode=GraphMVP_hybrid
7 | export dataset_list=(GEOM_3D_nmol50000_nconf5_nupper1000)
8 | export epochs=100
9 |
10 |
11 |
12 |
13 |
14 | # For SchNet and GNN
15 | export schnet_lr_scale_list=(0.1)
16 | export num_interactions=6
17 | export num_gaussians=51
18 | export cutoff=10
19 | export dropout_ratio_list=(0)
20 | export SSL_masking_ratio_list=(0.15 0.3)
21 |
22 |
23 |
24 |
25 | # For CL
26 | # export CL_similarity_metric_list=(InfoNCE_dot_prod EBM_dot_prod)
27 | export CL_similarity_metric_list=(EBM_dot_prod)
28 | # export T_list=(0.05 0.1 0.2 0.5 1 2)
29 | export T_list=(0.1 0.2)
30 | export normalize_list=(normalize)
31 |
32 |
33 |
34 | # For VAE
35 | export AE_model=VAE
36 | # export AE_loss_list=(l1 l2 cosine)
37 | export AE_loss_list=(l2)
38 | # export detach_list=(detach_target no_detach_target)
39 | export detach_list=(detach_target)
40 | # export beta_list=(0.1 1 2)
41 | export beta_list=(1 2)
42 |
43 |
44 |
45 |
46 |
47 | # For CL + VAE
48 | export alpha_1_list=(1)
49 | export alpha_2_list=(0.1 1)
50 | export alpha_3_list=(0.1 1)
51 | export SSL_2D_mode_list=( CP AM IG JOAOv2 JOAO GraphCL)
52 | # export time_list=( 3 3 3 6 6 3)
53 | export time_list=( 9 9 6 12 12 6)
54 |
55 |
56 |
57 |
58 | for dataset in "${dataset_list[@]}"; do
59 | for SSL_masking_ratio in "${SSL_masking_ratio_list[@]}"; do
60 |
61 | for i in {0..1}; do
62 | SSL_2D_mode=${SSL_2D_mode_list[$i]}
63 | time=${time_list[$i]}
64 |
65 | for alpha_3 in "${alpha_3_list[@]}"; do
66 |
67 | for alpha_1 in "${alpha_1_list[@]}"; do
68 | for alpha_2 in "${alpha_2_list[@]}"; do
69 | for CL_similarity_metric in "${CL_similarity_metric_list[@]}"; do
70 | for normalize in "${normalize_list[@]}"; do
71 | for T in "${T_list[@]}"; do
72 | for AE_loss in "${AE_loss_list[@]}"; do
73 | for detach in "${detach_list[@]}"; do
74 | for beta in "${beta_list[@]}"; do
75 |
76 |
77 | for schnet_lr_scale in "${schnet_lr_scale_list[@]}"; do
78 | for dropout_ratio in "${dropout_ratio_list[@]}"; do
79 | export folder="$mode"/"$dataset"/CL_"$alpha_1"_"$AE_model"_"$alpha_2"_"$SSL_2D_mode"_"$alpha_3"/"$num_interactions"_"$num_gaussians"_"$cutoff"_"$schnet_lr_scale"/"$SSL_masking_ratio"_"$CL_similarity_metric"_"$T"_"$normalize"_"$AE_loss"_"$detach"_"$beta"_"$epochs"_"$dropout_ratio"
80 |
81 | echo "$folder"
82 | mkdir -p ../output/"$folder"
83 | ls ../output/"$folder"
84 |
85 | export output_file=../output/"$folder"/pretraining.out
86 | export output_model_dir=../output/"$folder"/pretraining
87 |
88 | echo "$output_model_dir"_model_final.pth undone
89 | ls "$output_model_dir"*
90 | echo "$output_file"
91 | ls ../output/"$folder"
92 | rm ../output/"$folder"/*
93 |
94 |
95 | sbatch --gres=gpu:v100l:1 -c 8 --mem=32G -t "$time":00:00 --account=rrg-bengioy-ad --qos=high --job-name=CL_VAE_"$SSL_2D_mode"_"$time" \
96 | --output="$output_file" \
97 | ./run_pretrain_"$mode".sh \
98 | --epochs="$epochs" \
99 | --dataset="$dataset" \
100 | --batch_size=256 \
101 | --SSL_masking_ratio="$SSL_masking_ratio" \
102 | --CL_similarity_metric="$CL_similarity_metric" --T="$T" --"$normalize" \
103 | --AE_model="$AE_model" --AE_loss="$AE_loss" --"$detach" --beta="$beta" \
104 | --alpha_1="$alpha_1" --alpha_2="$alpha_2" \
105 | --SSL_2D_mode="$SSL_2D_mode" --alpha_3="$alpha_3" \
106 | --num_interactions="$num_interactions" --num_gaussians="$num_gaussians" --cutoff="$cutoff" --schnet_lr_scale="$schnet_lr_scale" \
107 | --dropout_ratio="$dropout_ratio" --num_workers=8 \
108 | --output_model_dir="$output_model_dir"
109 |
110 | echo
111 |
112 | done
113 | done
114 |
115 | done
116 | done
117 | done
118 | done
119 | done
120 | done
121 | done
122 | done
123 | done
124 |
125 | done
126 | done
127 | done
128 |
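The eleven nested `for` loops above are simply a Cartesian product over the exported hyperparameter lists, with `$folder` encoding one grid point per job. A minimal Python sketch of the same enumeration, using a hypothetical three-parameter subset of the grid:

```python
from itertools import product

# Hypothetical subset of the grid swept by the nested shell loops above.
SSL_masking_ratios = [0.15, 0.3]
T_values = [0.1, 0.2]
betas = [1, 2]

configs = [
    {"SSL_masking_ratio": m, "T": t, "beta": b}
    for m, t, b in product(SSL_masking_ratios, T_values, betas)
]

# One folder name per configuration, mirroring the "$folder" pattern.
folders = [
    "GraphMVP_hybrid/{SSL_masking_ratio}_{T}_{beta}".format(**c) for c in configs
]
```

Each entry of `folders` corresponds to one `sbatch` submission in the script.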
--------------------------------------------------------------------------------
/scripts_classification/submit_pre_training_baselines.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | cd ../src_classification
4 |
5 | export dataset=GEOM_2D_nmol50000_nconf1_nupper1000
6 | export dropout_ratio=0
7 | export epochs=100
8 |
9 |
10 | # export time=3
11 | # export mode_list=(EP IG AM CP GraphLoG Motif Contextual GraphCL JOAO JOAOv2)
12 |
13 | export time=12
14 | export mode_list=(GPT_GNN)
15 |
16 |
17 | for mode in "${mode_list[@]}"; do
18 | export folder="$mode"/"$dataset"/epochs_"$epochs"_"$dropout_ratio"
19 | echo "$folder"
20 |
21 | mkdir -p ../output/"$folder"
22 |
23 | export output_file=../output/"$folder"/pretraining.out
24 | export output_model_dir=../output/"$folder"/pretraining
25 |
26 |
27 | if [[ ! -f "$output_file" ]]; then
28 | echo "$folder" undone
29 |
30 | sbatch --gres=gpu:v100l:1 -c 8 --mem=32G -t "$time":00:00 --account=rrg-bengioy-ad --qos=high --job-name=baselines \
31 | --output="$output_file" \
32 | ./run_pretrain_"$mode".sh \
33 | --epochs="$epochs" \
34 | --dataset="$dataset" \
35 | --batch_size=256 \
36 | --dropout_ratio="$dropout_ratio" --num_workers=8 \
37 | --output_model_dir="$output_model_dir"
38 | fi
39 | done
40 |
--------------------------------------------------------------------------------
/scripts_regression/run_fine_tuning_model.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH --account=rrg-bengioy-ad
4 | #SBATCH --cpus-per-task=8
5 | #SBATCH --gres=gpu:v100l:1
6 | #SBATCH --mem=32G
7 | #SBATCH --time=2:59:00
8 | #SBATCH --ntasks=1
9 | #SBATCH --array=0-2%3
10 | #SBATCH --output=logs/%j.out
11 | #SBATCH --job-name=reg
12 |
13 |
14 | ###############SBATCH --gres=gpu:v100l:1
15 |
16 | echo "My SLURM_ARRAY_TASK_ID: $SLURM_ARRAY_TASK_ID"
17 | export dataset_list=(esol freesolv lipophilicity malaria cep)
18 | export split_list=(scaffold random)
19 | export seed_list=(0 1 2 3 4 5 6 7 8 9)
20 | export batch_size=256
21 | export mode=$1
22 | export seed=${seed_list[$SLURM_ARRAY_TASK_ID]}
23 |
24 |
25 |
26 |
27 |
28 | if [ "$mode" == "random" ]; then
29 |
30 | for dataset in "${dataset_list[@]}"; do
31 | for split in "${split_list[@]}"; do
32 | export folder="$mode"/"$seed"/"$split"
33 | mkdir -p ./output/"$folder"/"$dataset"
34 |
35 | export output_path=./output/"$folder"/"$dataset".out
36 | export output_model_dir=./output/"$folder"/"$dataset"
37 |
38 | echo "$SLURM_JOB_ID"_"$SLURM_ARRAY_TASK_ID" > "$output_path"
39 | echo "$(date)" >> "$output_path"
40 |
41 | bash ./run_molecule_finetune_regression.sh \
42 | --dataset="$dataset" --runseed="$seed" --eval_train --batch_size="$batch_size" \
43 | --dropout_ratio=0.2 --split="$split" \
44 | --output_model_dir="$output_model_dir" \
45 | >> "$output_path"
46 |
47 | echo "$(date)" >> "$output_path"
48 | done
49 | done
50 |
51 |
52 |
53 |
54 | else
55 |
56 | for dataset in "${dataset_list[@]}"; do
57 | for split in "${split_list[@]}"; do
58 | export folder="$mode"/"$seed"/"$split"
59 | mkdir -p ./output/"$folder"
60 | mkdir -p ./output/"$folder"/"$dataset"
61 |
62 | export output_path=./output/"$folder"/"$dataset".out
63 | export output_model_dir=./output/"$folder"/"$dataset"
64 | export input_model_file=./output/"$mode"/pretraining_model.pth
65 |
66 | echo "$SLURM_JOB_ID"_"$SLURM_ARRAY_TASK_ID" > "$output_path"
67 | echo "$(date)" >> "$output_path"
68 |
69 | bash ./run_molecule_finetune_regression.sh \
70 | --dataset="$dataset" --runseed="$seed" --eval_train --batch_size="$batch_size" \
71 | --dropout_ratio=0.2 --split="$split" \
72 | --input_model_file="$input_model_file" \
73 | >> "$output_path"
74 |
75 | echo "$(date)" >> "$output_path"
76 | done
77 | done
78 |
79 | fi
80 |
81 |
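The seed-selection idiom above (`seed=${seed_list[$SLURM_ARRAY_TASK_ID]}`) maps each SLURM array task to one entry of `seed_list`. A small Python sketch of the same lookup, with the environment passed explicitly so the logic is testable outside SLURM (`pick_seed` is a hypothetical name):

```python
import os

# SLURM sets SLURM_ARRAY_TASK_ID per array task; the task id indexes seed_list,
# mirroring seed=${seed_list[$SLURM_ARRAY_TASK_ID]} in the script.
seed_list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

def pick_seed(env=os.environ):
    task_id = int(env.get("SLURM_ARRAY_TASK_ID", "0"))
    return seed_list[task_id]
```

Note that `#SBATCH --array=0-2%3` above launches only tasks 0, 1, and 2, so only the first three seeds are actually used per submission.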
--------------------------------------------------------------------------------
/scripts_regression/run_fine_tuning_model_DTI.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH --account=rrg-bengioy-ad
4 | #SBATCH --cpus-per-task=8
5 | #SBATCH --gres=gpu:v100l:1
6 | #SBATCH --mem=32G
7 | #SBATCH --time=5:55:00
8 | #SBATCH --ntasks=1
9 | #SBATCH --array=0-2%3
10 | #SBATCH --output=logs/%j.out
11 | #SBATCH --job-name=repurpose_complete
12 |
13 |
14 | ###############SBATCH --gres=gpu:v100l:1
15 |
16 | echo "My SLURM_ARRAY_TASK_ID: $SLURM_ARRAY_TASK_ID"
17 | export dataset_list=(davis kiba)
18 | export seed_list=(0 1 2 3 4 5 6 7 8 9)
19 | export batch_size=256
20 | export mode=$1
21 | export seed=${seed_list[$SLURM_ARRAY_TASK_ID]}
22 |
23 |
24 |
25 |
26 |
27 | if [ "$mode" == "random" ]; then
28 |
29 | for dataset in "${dataset_list[@]}"; do
30 | export folder="$mode"/"$seed"
31 | mkdir -p ./output/"$folder"
32 | mkdir -p ./output/"$folder"/"$dataset"
33 |
34 | export output_path=./output/"$folder"/"$dataset".out
35 | export output_model_dir=./output/"$folder"/"$dataset"
36 |
37 | echo "$SLURM_JOB_ID"_"$SLURM_ARRAY_TASK_ID" > "$output_path"
38 | echo "$(date)" >> "$output_path"
39 |
40 | bash ./run_dti_finetune.sh \
41 | --dataset="$dataset" --runseed="$seed" --batch_size="$batch_size" \
42 | >> "$output_path"
43 |
44 | echo "$(date)" >> "$output_path"
45 | done
46 |
47 |
48 |
49 |
50 | else
51 |
52 | for dataset in "${dataset_list[@]}"; do
53 | export folder="$mode"/"$seed"
54 | mkdir -p ./output/"$folder"
55 | mkdir -p ./output/"$folder"/"$dataset"
56 |
57 | export output_path=./output/"$folder"/"$dataset".out
58 | export output_model_dir=./output/"$folder"/"$dataset"
59 | export input_model_file=./output/"$mode"/pretraining_model.pth
60 |
61 | echo "$SLURM_JOB_ID"_"$SLURM_ARRAY_TASK_ID" > "$output_path"
62 | echo "$(date)" >> "$output_path"
63 |
64 | bash ./run_dti_finetune.sh \
65 | --dataset="$dataset" --runseed="$seed" --batch_size="$batch_size" \
66 | --input_model_file="$input_model_file" \
67 | >> "$output_path"
68 |
69 | echo "$(date)" >> "$output_path"
70 | done
71 |
72 | fi
73 |
74 |
--------------------------------------------------------------------------------
/scripts_regression/submit_fine_tuning.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd ../src_classification
4 |
5 | mode_list=(
6 | random
7 | EP/GEOM_2D_nmol50000_nconf1_nupper1000_morefeat/epochs_100_0
8 | AM/GEOM_2D_nmol50000_nconf1_nupper1000_morefeat/epochs_100_0
9 | IG/GEOM_2D_nmol50000_nconf1_nupper1000_morefeat/epochs_100_0
10 | CP/GEOM_2D_nmol50000_nconf1_nupper1000_morefeat/epochs_100_0
11 | GraphLoG/GEOM_2D_nmol50000_nconf1_nupper1000_morefeat/epochs_100_0
12 | Motif/GEOM_2D_nmol50000_nconf1_nupper1000_morefeat/epochs_100_0
13 | Contextual/GEOM_2D_nmol50000_nconf1_nupper1000_morefeat/epochs_100_0
14 | GraphCL/GEOM_2D_nmol50000_nconf1_nupper1000_morefeat/epochs_100_0
15 | JOAO/GEOM_2D_nmol50000_nconf1_nupper1000_morefeat/epochs_100_0
16 | JOAOv2/GEOM_2D_nmol50000_nconf1_nupper1000_morefeat/epochs_100_0
17 |
18 | GraphMVP/GEOM_3D_nmol50000_nconf5_nupper1000_morefeat/CL_1_VAE_1/6_51_10_0.1/0.15_EBM_dot_prod_0.2_normalize_l2_detach_target_2_100_0
19 | GraphMVP_hybrid/GEOM_3D_nmol50000_nconf5_nupper1000_morefeat/CL_1_VAE_1_AM_1/6_51_10_0.1/0.15_EBM_dot_prod_0.05_normalize_l2_detach_target_2_100_0
20 | GraphMVP_hybrid/GEOM_3D_nmol50000_nconf5_nupper1000_morefeat/CL_1_VAE_1_CP_0.1/6_51_10_0.1/0.15_EBM_dot_prod_0.2_normalize_l2_detach_target_2_100_0
21 | )
22 |
23 | for mode in "${mode_list[@]}"; do
24 | echo "$mode"
25 | ls output/"$mode"
26 |
27 | sbatch run_fine_tuning_model.sh "$mode"
28 |
29 | echo
30 |
31 | done
32 |
--------------------------------------------------------------------------------
/scripts_regression/submit_pre_training_GraphMVP.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | cd ../src_classification
4 |
5 |
6 | export mode=GraphMVP
7 | export dataset_list=(GEOM_3D_nmol50000_nconf5_nupper1000_morefeat)
8 | export epochs=100
9 | export time=9
10 |
11 |
12 | # For SchNet and GNN
13 | export schnet_lr_scale_list=(0.1)
14 | export num_interactions=6
15 | export num_gaussians=51
16 | export cutoff=10
17 | export dropout_ratio_list=(0)
18 | export SSL_masking_ratio_list=(0.15 0.3)
19 |
20 |
21 |
22 | # For CL
23 | # export CL_similarity_metric_list=(InfoNCE_dot_prod EBM_dot_prod)
24 | export CL_similarity_metric_list=(EBM_dot_prod)
25 | export T_list=(0.1 0.2 0.5 1 2)
26 | export normalize_list=(normalize)
27 |
28 |
29 |
30 | # For VAE
31 | export AE_model=VAE
32 | # export AE_loss_list=(l1 l2 cosine)
33 | export AE_loss_list=(l2)
34 | # export detach_list=(detach_target no_detach_target)
35 | export detach_list=(detach_target)
36 | # export beta_list=(0.1 1 2)
37 | export beta_list=(1 2)
38 |
39 |
40 |
41 |
42 | # For CL + VAE
43 | export alpha_1_list=(1)
44 | export alpha_2_list=(0.1 1)
45 |
46 |
47 |
48 | export SSL_masking_ratio_list=(0)
49 | export CL_similarity_metric_list=(EBM_dot_prod)
50 | export T_list=(0.1 0.2)
51 |
52 |
53 |
54 |
55 | for dataset in "${dataset_list[@]}"; do
56 | for SSL_masking_ratio in "${SSL_masking_ratio_list[@]}"; do
57 |
58 | for alpha_1 in "${alpha_1_list[@]}"; do
59 | for alpha_2 in "${alpha_2_list[@]}"; do
60 | for CL_similarity_metric in "${CL_similarity_metric_list[@]}"; do
61 | for normalize in "${normalize_list[@]}"; do
62 | for T in "${T_list[@]}"; do
63 | for AE_loss in "${AE_loss_list[@]}"; do
64 | for detach in "${detach_list[@]}"; do
65 | for beta in "${beta_list[@]}"; do
66 |
67 |
68 | for schnet_lr_scale in "${schnet_lr_scale_list[@]}"; do
69 | for dropout_ratio in "${dropout_ratio_list[@]}"; do
70 | export folder="$mode"/"$dataset"/CL_"$alpha_1"_"$AE_model"_"$alpha_2"/"$num_interactions"_"$num_gaussians"_"$cutoff"_"$schnet_lr_scale"/"$SSL_masking_ratio"_"$CL_similarity_metric"_"$T"_"$normalize"_"$AE_loss"_"$detach"_"$beta"_"$epochs"_"$dropout_ratio"
71 |
72 | echo "$folder"
73 | mkdir -p ../output/"$folder"
74 | ls ../output/"$folder"
75 |
76 | export output_file=../output/"$folder"/pretraining.out
77 | export output_model_dir=../output/"$folder"/pretraining
78 |
79 |
80 | echo "$output_file" undone
81 |
82 | sbatch --gres=gpu:v100l:1 -c 8 --mem=32G -t "$time":00:00 --account=rrg-bengioy-ad --qos=high --job-name=CL_VAE_"$time" \
83 | --output="$output_file" \
84 | ./run_pretrain_"$mode".sh \
85 | --epochs="$epochs" \
86 | --dataset="$dataset" \
87 | --batch_size=256 \
88 | --SSL_masking_ratio="$SSL_masking_ratio" \
89 | --CL_similarity_metric="$CL_similarity_metric" --T="$T" --"$normalize" \
90 | --AE_model="$AE_model" --AE_loss="$AE_loss" --"$detach" --beta="$beta" \
91 | --alpha_1="$alpha_1" --alpha_2="$alpha_2" \
92 | --num_interactions="$num_interactions" --num_gaussians="$num_gaussians" --cutoff="$cutoff" --schnet_lr_scale="$schnet_lr_scale" \
93 | --dropout_ratio="$dropout_ratio" --num_workers=8 \
94 | --output_model_dir="$output_model_dir"
95 |
96 | done
97 | done
98 |
99 | done
100 | done
101 | done
102 | done
103 | done
104 | done
105 | done
106 | done
107 | done
108 | done
109 |
--------------------------------------------------------------------------------
/scripts_regression/submit_pre_training_GraphMVP_hybrid.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | cd ../src_classification
4 |
5 |
6 | export mode=GraphMVP_hybrid
7 | export dataset_list=(GEOM_3D_nmol50000_nconf5_nupper1000_morefeat)
8 | export epochs=100
9 |
10 |
11 |
12 |
13 |
14 | # For SchNet and GNN
15 | export schnet_lr_scale_list=(0.1)
16 | export num_interactions=6
17 | export num_gaussians=51
18 | export cutoff=10
19 | export dropout_ratio_list=(0)
20 | export SSL_masking_ratio_list=(0.15 0.3)
21 |
22 |
23 |
24 |
25 | # For CL
26 | # export CL_similarity_metric_list=(InfoNCE_dot_prod EBM_dot_prod)
27 | export CL_similarity_metric_list=(EBM_dot_prod)
28 | # export T_list=(0.05 0.1 0.2 0.5 1 2)
29 | export T_list=(0.1 0.2)
30 | export normalize_list=(normalize)
31 |
32 |
33 |
34 | # For VAE
35 | export AE_model=VAE
36 | # export AE_loss_list=(l1 l2 cosine)
37 | export AE_loss_list=(l2)
38 | # export detach_list=(detach_target no_detach_target)
39 | export detach_list=(detach_target)
40 | # export beta_list=(0.1 1 2)
41 | export beta_list=(1 2)
42 |
43 |
44 |
45 |
46 |
47 | # For CL + VAE
48 | export alpha_1_list=(1)
49 | export alpha_2_list=(0.1 1)
50 | export alpha_3_list=(0.1 1)
51 | export SSL_2D_mode_list=( CP AM IG JOAOv2 JOAO GraphCL)
52 | # export time_list=( 3 3 3 6 6 3)
53 | export time_list=( 9 9 6 12 12 6)
54 |
55 |
56 |
57 |
58 | for dataset in "${dataset_list[@]}"; do
59 | for SSL_masking_ratio in "${SSL_masking_ratio_list[@]}"; do
60 |
61 | for i in {0..1}; do
62 | SSL_2D_mode=${SSL_2D_mode_list[$i]}
63 | time=${time_list[$i]}
64 |
65 | for alpha_3 in "${alpha_3_list[@]}"; do
66 |
67 | for alpha_1 in "${alpha_1_list[@]}"; do
68 | for alpha_2 in "${alpha_2_list[@]}"; do
69 | for CL_similarity_metric in "${CL_similarity_metric_list[@]}"; do
70 | for normalize in "${normalize_list[@]}"; do
71 | for T in "${T_list[@]}"; do
72 | for AE_loss in "${AE_loss_list[@]}"; do
73 | for detach in "${detach_list[@]}"; do
74 | for beta in "${beta_list[@]}"; do
75 |
76 |
77 | for schnet_lr_scale in "${schnet_lr_scale_list[@]}"; do
78 | for dropout_ratio in "${dropout_ratio_list[@]}"; do
79 | export folder="$mode"/"$dataset"/CL_"$alpha_1"_"$AE_model"_"$alpha_2"_"$SSL_2D_mode"_"$alpha_3"/"$num_interactions"_"$num_gaussians"_"$cutoff"_"$schnet_lr_scale"/"$SSL_masking_ratio"_"$CL_similarity_metric"_"$T"_"$normalize"_"$AE_loss"_"$detach"_"$beta"_"$epochs"_"$dropout_ratio"
80 |
81 | echo "$folder"
82 | mkdir -p ../output/"$folder"
83 | ls ../output/"$folder"
84 |
85 | export output_file=../output/"$folder"/pretraining.out
86 | export output_model_dir=../output/"$folder"/pretraining
87 |
88 | echo "$output_model_dir"_model_final.pth undone
89 | ls "$output_model_dir"*
90 | echo "$output_file"
91 | ls ../output/"$folder"
92 | rm ../output/"$folder"/*
93 |
94 |
95 | sbatch --gres=gpu:v100l:1 -c 8 --mem=32G -t "$time":00:00 --account=rrg-bengioy-ad --qos=high --job-name=CL_VAE_"$SSL_2D_mode"_"$time" \
96 | --output="$output_file" \
97 | ./run_pretrain_"$mode".sh \
98 | --epochs="$epochs" \
99 | --dataset="$dataset" \
100 | --batch_size=256 \
101 | --SSL_masking_ratio="$SSL_masking_ratio" \
102 | --CL_similarity_metric="$CL_similarity_metric" --T="$T" --"$normalize" \
103 | --AE_model="$AE_model" --AE_loss="$AE_loss" --"$detach" --beta="$beta" \
104 | --alpha_1="$alpha_1" --alpha_2="$alpha_2" \
105 | --SSL_2D_mode="$SSL_2D_mode" --alpha_3="$alpha_3" \
106 | --num_interactions="$num_interactions" --num_gaussians="$num_gaussians" --cutoff="$cutoff" --schnet_lr_scale="$schnet_lr_scale" \
107 | --dropout_ratio="$dropout_ratio" --num_workers=8 \
108 | --output_model_dir="$output_model_dir"
109 |
110 | echo
111 |
112 | done
113 | done
114 |
115 | done
116 | done
117 | done
118 | done
119 | done
120 | done
121 | done
122 | done
123 | done
124 |
125 | done
126 | done
127 | done
128 |
--------------------------------------------------------------------------------
/scripts_regression/submit_pre_training_baselines.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | cd ../src_regression
4 |
5 | epochs=100
6 | time=2
7 | mode_list=(AM CP GraphCL JOAO JOAOv2)
8 | dropout_ratio=0
9 | dataset=GEOM_2D_nmol50000_nconf1_nupper1000_morefeat
10 |
11 |
12 | for mode in "${mode_list[@]}"; do
13 | export folder="$mode"/"$dataset"/epochs_"$epochs"_"$dropout_ratio"
14 | echo "$folder"
15 |
16 | mkdir -p ./output/"$folder"
17 | ls ./output/"$folder"
18 |
19 | export output_file=./output/"$folder"/pretraining.out
20 | export output_model_dir=./output/"$folder"/pretraining
21 |
22 | sbatch --gres=gpu:v100l:1 -c 8 --mem=32G -t "$time":59:00 --account=rrg-bengioy-ad --qos=high --job-name=baselines \
23 | --output="$output_file" \
24 | ./run_pretrain_"$mode".sh \
25 | --epochs="$epochs" \
26 | --dataset="$dataset" \
27 | --batch_size=256 \
28 | --dropout_ratio="$dropout_ratio" --num_workers=8 \
29 | --output_model_dir="$output_model_dir"
30 | done
31 |
--------------------------------------------------------------------------------
/src_classification/config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | parser = argparse.ArgumentParser()
4 |
5 | # about seed and basic info
6 | parser.add_argument('--seed', type=int, default=42)
7 | parser.add_argument('--runseed', type=int, default=0)
8 | parser.add_argument('--device', type=int, default=0)
9 |
10 | # about dataset and dataloader
11 | parser.add_argument('--input_data_dir', type=str, default='')
12 | parser.add_argument('--dataset', type=str, default='bace')
13 | parser.add_argument('--num_workers', type=int, default=8)
14 |
15 | # about training strategies
16 | parser.add_argument('--split', type=str, default='scaffold')
17 | parser.add_argument('--batch_size', type=int, default=256)
18 | parser.add_argument('--epochs', type=int, default=100)
19 | parser.add_argument('--lr', type=float, default=0.001)
20 | parser.add_argument('--lr_scale', type=float, default=1)
21 | parser.add_argument('--decay', type=float, default=0)
22 |
23 | # about molecule GNN
24 | parser.add_argument('--gnn_type', type=str, default='gin')
25 | parser.add_argument('--num_layer', type=int, default=5)
26 | parser.add_argument('--emb_dim', type=int, default=300)
27 | parser.add_argument('--dropout_ratio', type=float, default=0.5)
28 | parser.add_argument('--graph_pooling', type=str, default='mean')
29 | parser.add_argument('--JK', type=str, default='last')
30 | parser.add_argument('--gnn_lr_scale', type=float, default=1)
31 | parser.add_argument('--model_3d', type=str, default='schnet', choices=['schnet'])
32 |
33 | # for AttributeMask
34 | parser.add_argument('--mask_rate', type=float, default=0.15)
35 | parser.add_argument('--mask_edge', type=int, default=0)
36 |
37 | # for ContextPred
38 | parser.add_argument('--csize', type=int, default=3)
39 | parser.add_argument('--contextpred_neg_samples', type=int, default=1)
40 |
41 | # for SchNet
42 | parser.add_argument('--num_filters', type=int, default=128)
43 | parser.add_argument('--num_interactions', type=int, default=6)
44 | parser.add_argument('--num_gaussians', type=int, default=51)
45 | parser.add_argument('--cutoff', type=float, default=10)
46 | parser.add_argument('--readout', type=str, default='mean', choices=['mean', 'add'])
47 | parser.add_argument('--schnet_lr_scale', type=float, default=1)
48 |
49 | # for 2D-3D Contrastive CL
50 | parser.add_argument('--CL_neg_samples', type=int, default=1)
51 | parser.add_argument('--CL_similarity_metric', type=str, default='InfoNCE_dot_prod',
52 | choices=['InfoNCE_dot_prod', 'EBM_dot_prod'])
53 | parser.add_argument('--T', type=float, default=0.1)
54 | parser.add_argument('--normalize', dest='normalize', action='store_true')
55 | parser.add_argument('--no_normalize', dest='normalize', action='store_false')
56 | parser.add_argument('--SSL_masking_ratio', type=float, default=0)
57 | # This is for generative SSL.
58 | parser.add_argument('--AE_model', type=str, default='AE', choices=['AE', 'VAE'])
59 | parser.set_defaults(AE_model='AE')
60 |
61 | # for 2D-3D AutoEncoder
62 | parser.add_argument('--AE_loss', type=str, default='l2', choices=['l1', 'l2', 'cosine'])
63 | parser.add_argument('--detach_target', dest='detach_target', action='store_true')
64 | parser.add_argument('--no_detach_target', dest='detach_target', action='store_false')
65 | parser.set_defaults(detach_target=True)
66 |
67 | # for 2D-3D Variational AutoEncoder
68 | parser.add_argument('--beta', type=float, default=1)
69 |
70 | # for 2D-3D Contrastive CL and AE/VAE
71 | parser.add_argument('--alpha_1', type=float, default=1)
72 | parser.add_argument('--alpha_2', type=float, default=1)
73 |
74 | # for 2D SSL and 3D-2D SSL
75 | parser.add_argument('--SSL_2D_mode', type=str, default='AM')
76 | parser.add_argument('--alpha_3', type=float, default=0.1)
77 | parser.add_argument('--gamma_joao', type=float, default=0.1)
78 | parser.add_argument('--gamma_joaov2', type=float, default=0.1)
79 |
80 | # whether to print the evaluation metric on the training data
81 | parser.add_argument('--eval_train', dest='eval_train', action='store_true')
82 | parser.add_argument('--no_eval_train', dest='eval_train', action='store_false')
83 | parser.set_defaults(eval_train=True)
84 |
85 | # about loading and saving
86 | parser.add_argument('--input_model_file', type=str, default='')
87 | parser.add_argument('--output_model_dir', type=str, default='')
88 |
89 | # verbosity
90 | parser.add_argument('--verbose', dest='verbose', action='store_true')
91 | parser.add_argument('--no_verbose', dest='verbose', action='store_false')
92 | parser.set_defaults(verbose=False)
93 |
94 | args = parser.parse_args()
95 | print('arguments\t', args)
96 |
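`config.py` uses the paired `store_true`/`store_false` idiom several times (`--normalize`/`--no_normalize`, `--detach_target`/`--no_detach_target`, `--eval_train`/`--no_eval_train`); both flags write to the same `dest`, which is what lets the shell scripts pass a flag name held in a variable, e.g. `--"$normalize"` or `--"$detach"`. A minimal standalone sketch of the idiom:

```python
import argparse

# Paired-flag pattern as in config.py: two options share one dest,
# and set_defaults fixes the value used when neither flag is given.
parser = argparse.ArgumentParser()
parser.add_argument('--normalize', dest='normalize', action='store_true')
parser.add_argument('--no_normalize', dest='normalize', action='store_false')
parser.set_defaults(normalize=False)

args = parser.parse_args(['--normalize'])
```

With this setup `args.normalize` is `True`, `--no_normalize` yields `False`, and an empty command line falls back to the `set_defaults` value.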
--------------------------------------------------------------------------------
/src_classification/dataloader.py:
--------------------------------------------------------------------------------
1 | from batch import (BatchAE, BatchMasking, BatchSubstructContext,
2 | BatchSubstructContext3D)
3 | from torch.utils.data import DataLoader
4 |
5 |
6 | class DataLoaderSubstructContext(DataLoader):
7 | """Data loader which merges data objects from a
8 | :class:`torch_geometric.data.dataset` to a mini-batch.
9 | Args:
10 | dataset (Dataset): The dataset from which to load the data.
11 |         batch_size (int, optional): How many samples per batch to load.
12 | (default: :obj:`1`)
13 | shuffle (bool, optional): If set to :obj:`True`, the data will be
14 | reshuffled at every epoch (default: :obj:`True`) """
15 |
16 | def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs):
17 | super(DataLoaderSubstructContext, self).__init__(
18 | dataset,
19 | batch_size,
20 | shuffle,
21 | collate_fn=lambda data_list: BatchSubstructContext.from_data_list(data_list),
22 | **kwargs)
23 |
24 |
25 | class DataLoaderSubstructContext3D(DataLoader):
26 | """Data loader which merges data objects from a
27 | :class:`torch_geometric.data.dataset` to a mini-batch.
28 | Args:
29 | dataset (Dataset): The dataset from which to load the data.
30 |         batch_size (int, optional): How many samples per batch to load.
31 | (default: :obj:`1`)
32 | shuffle (bool, optional): If set to :obj:`True`, the data will be
33 | reshuffled at every epoch (default: :obj:`True`) """
34 |
35 | def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs):
36 | super(DataLoaderSubstructContext3D, self).__init__(
37 | dataset,
38 | batch_size,
39 | shuffle,
40 | collate_fn=lambda data_list: BatchSubstructContext3D.from_data_list(data_list),
41 | **kwargs)
42 |
43 |
44 | class DataLoaderMasking(DataLoader):
45 | """Data loader which merges data objects from a
46 | :class:`torch_geometric.data.dataset` to a mini-batch.
47 | Args:
48 | dataset (Dataset): The dataset from which to load the data.
49 |         batch_size (int, optional): How many samples per batch to load.
50 | (default: :obj:`1`)
51 | shuffle (bool, optional): If set to :obj:`True`, the data will be
52 | reshuffled at every epoch (default: :obj:`True`) """
53 |
54 | def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs):
55 | super(DataLoaderMasking, self).__init__(
56 | dataset,
57 | batch_size,
58 | shuffle,
59 | collate_fn=lambda data_list: BatchMasking.from_data_list(data_list),
60 | **kwargs)
61 |
62 |
63 | class DataLoaderAE(DataLoader):
64 | """ Data loader which merges data objects from a
65 | :class:`torch_geometric.data.dataset` to a mini-batch.
66 | Args:
67 | dataset (Dataset): The dataset from which to load the data.
68 |         batch_size (int, optional): How many samples per batch to load.
69 | (default: :obj:`1`)
70 | shuffle (bool, optional): If set to :obj:`True`, the data will be
71 | reshuffled at every epoch (default: :obj:`True`) """
72 |
73 | def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs):
74 | super(DataLoaderAE, self).__init__(
75 | dataset,
76 | batch_size,
77 | shuffle,
78 | collate_fn=lambda data_list: BatchAE.from_data_list(data_list),
79 | **kwargs)
80 |
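All four loaders follow the same pattern: `DataLoader` delegates batch assembly to `collate_fn`, so each `Batch*` class only needs a `from_data_list` classmethod. A plain-Python sketch of that division of labor (toy stand-ins, no torch or PyG required; `ToyBatch` and `iterate_batches` are hypothetical names):

```python
# collate_fn receives a list of samples and returns one batch object,
# mirroring how DataLoaderMasking wires BatchMasking.from_data_list.
class ToyBatch:
    def __init__(self, items):
        self.items = items

    @classmethod
    def from_data_list(cls, data_list):
        return cls(list(data_list))

def iterate_batches(dataset, batch_size, collate_fn):
    # Simplified stand-in for DataLoader's batching loop (no shuffling).
    for start in range(0, len(dataset), batch_size):
        yield collate_fn(dataset[start:start + batch_size])

batches = list(iterate_batches(list(range(5)), 2, ToyBatch.from_data_list))
```

Swapping the collate function is the only difference between the four loader classes above.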
--------------------------------------------------------------------------------
/src_classification/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .datasets_GPT import MoleculeDatasetGPT
2 | from .molecule_3D_dataset import Molecule3DDataset
3 | from .molecule_3D_masking_dataset import Molecule3DMaskingDataset
4 | from .molecule_contextual_datasets import MoleculeContextualDataset
5 | from .molecule_datasets import (MoleculeDataset, allowable_features,
6 | graph_data_obj_to_nx_simple,
7 | nx_to_graph_data_obj_simple)
8 | from .molecule_graphcl_dataset import MoleculeDataset_graphcl
9 | from .molecule_graphcl_masking_dataset import MoleculeGraphCLMaskingDataset
10 | from .molecule_motif_datasets import RDKIT_PROPS, MoleculeMotifDataset
11 |
--------------------------------------------------------------------------------
/src_classification/datasets/datasets_GPT.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import numpy as np
4 | import torch
5 | from torch_geometric.data import Data, InMemoryDataset
6 | from torch_geometric.utils import subgraph
7 | from tqdm import tqdm
8 |
9 |
10 | def search_graph(graph):
11 | num_node = len(graph.x)
12 | edge_set = set()
13 |
14 | u_list, v_list = graph.edge_index[0].numpy(), graph.edge_index[1].numpy()
15 | for u,v in zip(u_list, v_list):
16 | edge_set.add((u,v))
17 | edge_set.add((v,u))
18 |
19 | visited_list = []
20 | unvisited_set = set([i for i in range(num_node)])
21 |
22 | while len(unvisited_set) > 0:
23 |         u = random.choice(list(unvisited_set))  # random.sample on a set raises TypeError on Python 3.11+
24 | queue = [u]
25 | while len(queue):
26 | u = queue.pop(0)
27 | if u in visited_list:
28 | continue
29 | visited_list.append(u)
30 | unvisited_set.remove(u)
31 |
32 | for v in range(num_node):
33 | if (v not in visited_list) and ((u,v) in edge_set):
34 | queue.append(v)
35 | assert len(visited_list) == num_node
36 | return visited_list
37 |
38 |
39 | class MoleculeDatasetGPT(InMemoryDataset):
40 | def __init__(self, molecule_dataset, transform=None, pre_transform=None):
41 | self.molecule_dataset = molecule_dataset
42 | self.root = molecule_dataset.root + '_GPT'
43 | super(MoleculeDatasetGPT, self).__init__(self.root, transform=transform, pre_transform=pre_transform)
44 |
45 | self.data, self.slices = torch.load(self.processed_paths[0])
46 |
47 | return
48 |
49 | def process(self):
50 | num_molecule = len(self.molecule_dataset)
51 | data_list = []
52 | for i in tqdm(range(num_molecule)):
53 | graph = self.molecule_dataset.get(i)
54 |
55 | num_node = len(graph.x)
56 | # TODO: will replace this with DFS/BFS searching
57 | node_list = search_graph(graph)
58 |
59 | for idx in range(num_node-1):
60 | # print('sub_node_list: {}\nnext_node: {}'.format(sub_node_list, next_node))
61 | # [0..idx] -> [idx+1]
62 | sub_node_list = node_list[:idx+1]
63 | next_node = node_list[idx+1]
64 |
65 | edge_index, edge_attr = subgraph(
66 | subset=sub_node_list, edge_index=graph.edge_index, edge_attr=graph.edge_attr,
67 | relabel_nodes=True, num_nodes=num_node)
68 |
69 | # Take the subgraph and predict on the next node (atom type only)
70 | sub_graph = Data(x=graph.x[sub_node_list], edge_index=edge_index, edge_attr=edge_attr, next_x=graph.x[next_node, :1])
71 | data_list.append(sub_graph)
72 |
73 | print('len of data\t', len(data_list))
74 | data, slices = self.collate(data_list)
75 | print('Saving...')
76 | torch.save((data, slices), self.processed_paths[0])
77 | return
78 |
79 | @property
80 | def processed_file_names(self):
81 | return 'geometric_data_processed.pt'
82 |
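`search_graph` performs a BFS that restarts from a random unvisited node whenever a connected component is exhausted, so the returned order always covers every node. A standalone sketch of the same traversal over a plain edge list (seeded RNG for reproducibility; `bfs_order` is a hypothetical name):

```python
import random

def bfs_order(num_node, edges, rng=random.Random(0)):
    # Build an undirected adjacency map, as search_graph's edge_set does.
    adj = {i: set() for i in range(num_node)}
    for u, v in edges:
        adj[u].add(v)
        adj[v].add(u)

    visited, unvisited = [], set(range(num_node))
    while unvisited:
        # Restart BFS from a random unvisited node (handles disconnected parts).
        queue = [rng.choice(sorted(unvisited))]
        while queue:
            u = queue.pop(0)
            if u not in unvisited:
                continue
            visited.append(u)
            unvisited.remove(u)
            queue.extend(v for v in adj[u] if v in unvisited)
    return visited
```

`MoleculeDatasetGPT.process` then takes each prefix of this order as a subgraph and predicts the atom type of the next node in the order.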
--------------------------------------------------------------------------------
/src_classification/datasets/molecule_3D_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | from itertools import repeat
3 |
4 | import torch
5 | from torch_geometric.data import Data, InMemoryDataset
6 |
7 |
8 | class Molecule3DDataset(InMemoryDataset):
9 | def __init__(self, root, dataset='zinc250k',
10 | transform=None, pre_transform=None, pre_filter=None, empty=False):
11 | self.dataset = dataset
12 | self.root = root
13 |
14 | super(Molecule3DDataset, self).__init__(root, transform, pre_transform, pre_filter)
15 | self.transform, self.pre_transform, self.pre_filter = transform, pre_transform, pre_filter
16 |
17 | if not empty:
18 | self.data, self.slices = torch.load(self.processed_paths[0])
19 | print('Dataset: {}\nData: {}'.format(self.dataset, self.data))
20 |
21 | def get(self, idx):
22 | data = Data()
23 | for key in self.data.keys:
24 | item, slices = self.data[key], self.slices[key]
25 | s = list(repeat(slice(None), item.dim()))
26 | s[data.__cat_dim__(key, item)] = slice(slices[idx], slices[idx + 1])
27 | data[key] = item[s]
28 | return data
29 |
30 | @property
31 | def raw_file_names(self):
32 | return os.listdir(self.raw_dir)
33 |
34 | @property
35 | def processed_file_names(self):
36 | return 'geometric_data_processed.pt'
37 |
38 | def download(self):
39 | return
40 |
41 | def process(self):
42 | return
43 |
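The `get` method above reconstructs one graph from the collated storage produced by `InMemoryDataset`: every attribute of all graphs is stacked into one tensor, and a cumulative slice array marks graph boundaries. A minimal numpy sketch of that retrieval pattern (names and values are illustrative):

```python
import numpy as np

# Hypothetical collated storage for 3 graphs: node features stacked along
# dim 0, with a cumulative slice array marking graph boundaries.
x = np.arange(10).reshape(10, 1)       # 10 nodes across all graphs
slices = np.array([0, 3, 7, 10])       # graph 0: rows 0:3, graph 1: 3:7, graph 2: 7:10

def get_graph_x(idx):
    # the per-key slicing performed in get() above, for a node-level key
    return x[slices[idx]:slices[idx + 1]]

print(get_graph_x(1).ravel().tolist())  # [3, 4, 5, 6]
```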
--------------------------------------------------------------------------------
/src_classification/datasets/molecule_3D_masking_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | from itertools import repeat
3 |
4 | import numpy as np
5 | import torch
6 | from torch_geometric.data import Data, InMemoryDataset
7 | from torch_geometric.utils import subgraph, to_networkx
8 |
9 |
10 | class Molecule3DMaskingDataset(InMemoryDataset):
11 | def __init__(self, root, dataset, mask_ratio,
12 | transform=None, pre_transform=None, pre_filter=None, empty=False):
13 | self.root = root
14 | self.dataset = dataset
15 | self.mask_ratio = mask_ratio
16 |
17 | super(Molecule3DMaskingDataset, self).__init__(root, transform, pre_transform, pre_filter)
18 | self.transform, self.pre_transform, self.pre_filter = transform, pre_transform, pre_filter
19 |
20 | if not empty:
21 | self.data, self.slices = torch.load(self.processed_paths[0])
22 | print('Dataset: {}\nData: {}'.format(self.dataset, self.data))
23 |
24 | def subgraph(self, data):
25 | G = to_networkx(data)
26 | node_num, _ = data.x.size()
27 | sub_num = int(node_num * (1 - self.mask_ratio))
28 |
29 | idx_sub = [np.random.randint(node_num, size=1)[0]]
30 | idx_neigh = set([n for n in G.neighbors(idx_sub[-1])])
31 |
32 | # BFS
33 | while len(idx_sub) <= sub_num:
34 | if len(idx_neigh) == 0:
35 | idx_unsub = list(set([n for n in range(node_num)]).difference(set(idx_sub)))
36 | idx_neigh = set([np.random.choice(idx_unsub)])
37 | sample_node = np.random.choice(list(idx_neigh))
38 |
39 | idx_sub.append(sample_node)
40 | idx_neigh = idx_neigh.union(
41 | set([n for n in G.neighbors(idx_sub[-1])])).difference(set(idx_sub))
42 |
43 | idx_nondrop = idx_sub
44 | idx_nondrop.sort()
45 |
46 | edge_idx, edge_attr = subgraph(subset=idx_nondrop,
47 | edge_index=data.edge_index,
48 | edge_attr=data.edge_attr,
49 | relabel_nodes=True,
50 | num_nodes=node_num)
51 |
52 | data.edge_index = edge_idx
53 | data.edge_attr = edge_attr
54 | data.x = data.x[idx_nondrop]
55 | data.positions = data.positions[idx_nondrop]
56 | data.__num_nodes__, _ = data.x.shape
57 | return data
58 |
59 | def get(self, idx):
60 | data = Data()
61 | for key in self.data.keys:
62 | item, slices = self.data[key], self.slices[key]
63 | s = list(repeat(slice(None), item.dim()))
64 | s[data.__cat_dim__(key, item)] = slice(slices[idx], slices[idx + 1])
65 | data[key] = item[s]
66 |
67 | if self.mask_ratio > 0:
68 | data = self.subgraph(data)
69 | return data
70 |
71 | @property
72 | def raw_file_names(self):
73 | return os.listdir(self.raw_dir)
74 |
75 | @property
76 | def processed_file_names(self):
77 | return 'geometric_data_processed.pt'
78 |
79 | def download(self):
80 | return
81 |
82 | def process(self):
83 | return
84 |
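The `subgraph` method above grows a connected node set by BFS-style sampling from a random seed until roughly `1 - mask_ratio` of the nodes are kept, jumping to a random unvisited node if the frontier empties. A dependency-free rehearsal of the same sampling loop on a toy adjacency list (illustrative, pure Python):

```python
import random

# BFS-style subgraph sampling mirroring Molecule3DMaskingDataset.subgraph.
def sample_subgraph(adj, mask_ratio, rng):
    node_num = len(adj)
    sub_num = int(node_num * (1 - mask_ratio))
    idx_sub = [rng.randrange(node_num)]
    idx_neigh = set(adj[idx_sub[-1]])
    while len(idx_sub) <= sub_num:
        if not idx_neigh:
            # frontier exhausted (disconnected graph): restart from an unvisited node
            idx_neigh = {rng.choice([n for n in range(node_num) if n not in idx_sub])}
        v = rng.choice(sorted(idx_neigh))
        idx_sub.append(v)
        idx_neigh = (idx_neigh | set(adj[v])) - set(idx_sub)
    return sorted(idx_sub)

adj = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3]}  # path graph on 5 nodes
kept = sample_subgraph(adj, mask_ratio=0.25, rng=random.Random(0))
print(len(kept))  # 4 of 5 nodes kept
```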
--------------------------------------------------------------------------------
/src_classification/datasets/molecule_contextual_datasets_utils.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | from collections import Counter
3 | from multiprocessing import Pool
4 |
5 | import tqdm
6 |
7 | BOND_FEATURES = ['BondType', 'BondDir']
8 |
9 |
10 | def atom_to_vocab(mol, atom):
11 | """
12 | Convert atom to vocabulary. The convention is based on atom type and bond type.
13 | :param mol: the molecule.
14 | :param atom: the target atom.
15 | :return: the generated atom vocabulary with its contexts.
16 | """
17 | nei = Counter()
18 | for a in atom.GetNeighbors():
19 | bond = mol.GetBondBetweenAtoms(atom.GetIdx(), a.GetIdx())
20 | nei[str(a.GetSymbol()) + "-" + str(bond.GetBondType())] += 1
21 | keys = nei.keys()
22 | keys = list(keys)
23 | keys.sort()
24 | output = atom.GetSymbol()
25 | for k in keys:
26 | output = "%s_%s%d" % (output, k, nei[k])
27 |
28 | # The generated atom_vocab is too long?
29 | return output
30 |
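`atom_to_vocab` builds each atom's token from its symbol plus sorted "Neighbor-BondType" context counts. A small RDKit-free rehearsal of that convention (the bond-type strings below are illustrative stand-ins for RDKit's):

```python
from collections import Counter

# An atom token is its symbol plus sorted "Neighbor-BondType" context counts.
def context_token(symbol, neighbor_bonds):
    nei = Counter(f"{s}-{b}" for s, b in neighbor_bonds)
    output = symbol
    for k in sorted(nei):
        output = "%s_%s%d" % (output, k, nei[k])
    return output

# e.g. the carboxyl carbon of acetic acid: bonded to C (single), O (double), O (single)
print(context_token("C", [("C", "SINGLE"), ("O", "DOUBLE"), ("O", "SINGLE")]))
# C_C-SINGLE1_O-DOUBLE1_O-SINGLE1
```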
31 |
32 | def bond_to_vocab(mol, bond):
33 | """
34 | Convert bond to vocabulary. The convention is based on atom type and bond type.
35 | Considering one-hop neighbor atoms.
36 | :param mol: the molecule.
37 | :param bond: the target bond.
38 | :return: the generated bond vocabulary with its contexts.
39 | """
40 | nei = Counter()
41 | two_neighbors = (bond.GetBeginAtom(), bond.GetEndAtom())
42 | two_indices = [a.GetIdx() for a in two_neighbors]
43 | for nei_atom in two_neighbors:
44 | for a in nei_atom.GetNeighbors():
45 | a_idx = a.GetIdx()
46 | if a_idx in two_indices:
47 | continue
48 | tmp_bond = mol.GetBondBetweenAtoms(nei_atom.GetIdx(), a_idx)
49 | nei[str(nei_atom.GetSymbol()) + '-' + get_bond_feature_name(tmp_bond)] += 1
50 | keys = list(nei.keys())
51 | keys.sort()
52 | output = get_bond_feature_name(bond)
53 | for k in keys:
54 | output = "%s_%s%d" % (output, k, nei[k])
55 | return output
56 |
57 |
58 | def get_bond_feature_name(bond):
59 | """
60 | Return the string format of bond features.
61 | Bond features are surrounded with ()
62 | """
63 | ret = []
64 | for bond_feature in BOND_FEATURES:
65 | fea = getattr(bond, f"Get{bond_feature}")()
66 | ret.append(str(fea))
67 |
68 | return '(' + '-'.join(ret) + ')'
69 |
70 |
71 | class TorchVocab(object):
72 | def __init__(self, counter, max_size=None, min_freq=1, specials=('', ''), vocab_type='atom'):
73 | """
74 | :param counter: Counter mapping tokens to their frequencies.
75 | :param max_size: maximum vocabulary size, or None for no limit.
76 | :param min_freq: minimum frequency for a token to be included.
77 | :param specials: special tokens prepended to the vocabulary.
78 | :param vocab_type: 'atom' for an atom vocab; 'bond' for a bond vocab.
79 | """
80 | self.freqs = counter
81 | counter = counter.copy()
82 | min_freq = max(min_freq, 1)
83 | if vocab_type in ('atom', 'bond'):
84 | self.vocab_type = vocab_type
85 | else:
86 | raise ValueError('Wrong input for vocab_type!')
87 | self.itos = list(specials)
88 |
89 | max_size = None if max_size is None else max_size + len(self.itos)
90 | # sort by frequency, then alphabetically
91 | words_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0])
92 | words_and_frequencies.sort(key=lambda tup: tup[1], reverse=True)
93 |
94 | for word, freq in words_and_frequencies:
95 | if freq < min_freq or len(self.itos) == max_size:
96 | break
97 | self.itos.append(word)
98 | # stoi is simply a reverse dict for itos
99 | self.stoi = {tok: i for i, tok in enumerate(self.itos)}
100 | self.other_index = 1
101 | self.pad_index = 0
102 |
103 | def __eq__(self, other):
104 | if self.freqs != other.freqs:
105 | return False
106 | if self.stoi != other.stoi:
107 | return False
108 | if self.itos != other.itos:
109 | return False
110 | return True
111 |
112 | def __len__(self):
113 | return len(self.itos)
114 |
115 | def vocab_rerank(self):
116 | self.stoi = {word: i for i, word in enumerate(self.itos)}
117 |
118 | def extend(self, v, sort=False):
119 | words = sorted(v.itos) if sort else v.itos
120 | for w in words:
121 | if w not in self.stoi:
122 | self.itos.append(w)
123 | self.stoi[w] = len(self.itos) - 1
124 | self.freqs[w] = 0
125 | self.freqs[w] += v.freqs[w]
126 |
127 | def save_vocab(self, vocab_path):
128 | with open(vocab_path, "wb") as f:
129 | pickle.dump(self, f)
130 |
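The ordering rule in `TorchVocab.__init__` sorts tokens alphabetically first and then stably by descending frequency, so higher-frequency tokens come first and ties break alphabetically. A quick sketch (the special-token names are placeholders; the class itself uses empty strings):

```python
from collections import Counter

# Alphabetical sort, then a stable sort by descending frequency:
# ties break alphabetically.
counter = Counter({"C_C-SINGLE1": 5, "O_C-SINGLE1": 5, "N_C-SINGLE1": 9})
words_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0])
words_and_frequencies.sort(key=lambda tup: tup[1], reverse=True)
itos = ["<pad>", "<other>"] + [w for w, _ in words_and_frequencies]
print(itos)  # ['<pad>', '<other>', 'N_C-SINGLE1', 'C_C-SINGLE1', 'O_C-SINGLE1']
```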
131 |
132 | class MolVocab(TorchVocab):
133 | def __init__(self, molecule_list, max_size=None, min_freq=1, num_workers=1, total_lines=None, vocab_type='atom'):
134 | if vocab_type in ('atom', 'bond'):
135 | self.vocab_type = vocab_type
136 | else:
137 | raise ValueError('Wrong input for vocab_type!')
138 | print("Building {} vocab from molecule-list".format(self.vocab_type))
139 |
140 | from rdkit import RDLogger
141 | lg = RDLogger.logger()
142 | lg.setLevel(RDLogger.CRITICAL)
143 |
144 | if total_lines is None:
145 | total_lines = len(molecule_list)
146 |
147 | counter = Counter()
148 | pbar = tqdm.tqdm(total=total_lines)
149 | pool = Pool(num_workers)
150 | res = []
151 | batch = 50000
152 | callback = lambda a: pbar.update(batch)
153 | for i in range(int(total_lines / batch + 1)):
154 | start = int(batch * i)
155 | end = min(total_lines, batch * (i + 1))
156 | res.append(pool.apply_async(MolVocab.read_counter_from_molecules,
157 | args=(molecule_list, start, end, vocab_type,),
158 | callback=callback))
159 | pool.close()
160 | pool.join()
161 | for r in res:
162 | sub_counter = r.get()
163 | for k in sub_counter:
164 | if k not in counter:
165 | counter[k] = 0
166 | counter[k] += sub_counter[k]
167 | super().__init__(counter, max_size=max_size, min_freq=min_freq, vocab_type=vocab_type)
168 |
169 | @staticmethod
170 | def read_counter_from_molecules(molecule_list, start, end, vocab_type):
171 | sub_counter = Counter()
172 | for i, mol in enumerate(molecule_list):
173 | if i < start:
174 | continue
175 | if i >= end:
176 | break
177 | if vocab_type == 'atom':
178 | for atom in mol.GetAtoms():
179 | v = atom_to_vocab(mol, atom)
180 | sub_counter[v] += 1
181 | else:
182 | for bond in mol.GetBonds():
183 | v = bond_to_vocab(mol, bond)
184 | sub_counter[v] += 1
185 | # print("end")
186 | return sub_counter
187 |
188 | @staticmethod
189 | def load_vocab(vocab_path: str) -> 'MolVocab':
190 | with open(vocab_path, "rb") as f:
191 | return pickle.load(f)
192 |
--------------------------------------------------------------------------------
/src_classification/datasets/molecule_graphcl_dataset.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | from itertools import repeat
4 |
5 | import numpy as np
6 | import torch
7 | from torch_geometric.data import Data
8 | from torch_geometric.utils import subgraph, to_networkx
9 |
10 | from .molecule_datasets import MoleculeDataset
11 |
12 |
13 | class MoleculeDataset_graphcl(MoleculeDataset):
14 |
15 | def __init__(self,
16 | root,
17 | transform=None,
18 | pre_transform=None,
19 | pre_filter=None,
20 | dataset=None,
21 | empty=False):
22 |
23 | self.aug_prob = None
24 | self.aug_mode = 'no_aug'
25 | self.aug_strength = 0.2
26 | self.augmentations = [self.node_drop, self.subgraph,
27 | self.edge_pert, self.attr_mask, lambda x: x]
28 | super(MoleculeDataset_graphcl, self).__init__(
29 | root, transform, pre_transform, pre_filter, dataset, empty)
30 |
31 | def set_augMode(self, aug_mode):
32 | self.aug_mode = aug_mode
33 |
34 | def set_augStrength(self, aug_strength):
35 | self.aug_strength = aug_strength
36 |
37 | def set_augProb(self, aug_prob):
38 | self.aug_prob = aug_prob
39 |
40 | def node_drop(self, data):
41 |
42 | node_num, _ = data.x.size()
43 | _, edge_num = data.edge_index.size()
44 | drop_num = int(node_num * self.aug_strength)
45 |
46 | idx_perm = np.random.permutation(node_num)
47 | idx_nodrop = idx_perm[drop_num:].tolist()
48 | idx_nodrop.sort()
49 |
50 | edge_idx, edge_attr = subgraph(subset=idx_nodrop,
51 | edge_index=data.edge_index,
52 | edge_attr=data.edge_attr,
53 | relabel_nodes=True,
54 | num_nodes=node_num)
55 |
56 | data.edge_index = edge_idx
57 | data.edge_attr = edge_attr
58 | data.x = data.x[idx_nodrop]
59 | data.__num_nodes__, _ = data.x.shape
60 | return data
61 |
62 | def edge_pert(self, data):
63 | node_num, _ = data.x.size()
64 | _, edge_num = data.edge_index.size()
65 | pert_num = int(edge_num * self.aug_strength)
66 |
67 | # delete edges
68 | idx_drop = np.random.choice(edge_num, (edge_num - pert_num),
69 | replace=False)
70 | edge_index = data.edge_index[:, idx_drop]
71 | edge_attr = data.edge_attr[idx_drop]
72 |
73 | # add edges
74 | adj = torch.ones((node_num, node_num))
75 | adj[edge_index[0], edge_index[1]] = 0
76 | # edge_index_nonexist = adj.nonzero(as_tuple=False).t()
77 | edge_index_nonexist = torch.nonzero(adj, as_tuple=False).t()
78 | idx_add = np.random.choice(edge_index_nonexist.shape[1],
79 | pert_num, replace=False)
80 | edge_index_add = edge_index_nonexist[:, idx_add]
81 | # random 4-class & 3-class edge_attr for 1st & 2nd dimension
82 | edge_attr_add_1 = torch.tensor(np.random.randint(
83 | 4, size=(edge_index_add.shape[1], 1)))
84 | edge_attr_add_2 = torch.tensor(np.random.randint(
85 | 3, size=(edge_index_add.shape[1], 1)))
86 | edge_attr_add = torch.cat((edge_attr_add_1, edge_attr_add_2), dim=1)
87 | edge_index = torch.cat((edge_index, edge_index_add), dim=1)
88 | edge_attr = torch.cat((edge_attr, edge_attr_add), dim=0)
89 |
90 | data.edge_index = edge_index
91 | data.edge_attr = edge_attr
92 | return data
93 |
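`edge_pert` keeps the edge count fixed: it deletes `pert_num` random edges, then adds `pert_num` edges sampled from currently absent pairs. The same bookkeeping on a toy directed edge list (numpy sketch; note that, as in the method above, self-loop pairs are among the "absent" candidates):

```python
import numpy as np

rng = np.random.default_rng(0)
node_num = 4
edge_index = np.array([[0, 1, 2, 3],
                       [1, 2, 3, 0]])            # 4 directed edges on a cycle
pert_num = int(edge_index.shape[1] * 0.25)       # 1 edge dropped, 1 added

# delete edges
keep = rng.choice(4, 4 - pert_num, replace=False)
edge_index = edge_index[:, keep]

# add edges sampled from pairs with no current edge
adj = np.ones((node_num, node_num))
adj[edge_index[0], edge_index[1]] = 0
nonexist = np.argwhere(adj).T
add = nonexist[:, rng.choice(nonexist.shape[1], pert_num, replace=False)]
edge_index = np.concatenate([edge_index, add], axis=1)
print(edge_index.shape[1])  # back to 4 edges
```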
94 | def attr_mask(self, data):
95 |
96 | _x = data.x.clone()
97 | node_num, _ = data.x.size()
98 | mask_num = int(node_num * self.aug_strength)
99 |
100 | token = data.x.float().mean(dim=0).long()
101 | idx_mask = np.random.choice(
102 | node_num, mask_num, replace=False)
103 |
104 | _x[idx_mask] = token
105 | data.x = _x
106 | return data
107 |
108 | def subgraph(self, data):
109 |
110 | G = to_networkx(data)
111 | node_num, _ = data.x.size()
112 | _, edge_num = data.edge_index.size()
113 | sub_num = int(node_num * (1 - self.aug_strength))
114 |
115 | idx_sub = [np.random.randint(node_num, size=1)[0]]
116 | idx_neigh = set([n for n in G.neighbors(idx_sub[-1])])
117 |
118 | while len(idx_sub) <= sub_num:
119 | if len(idx_neigh) == 0:
120 | idx_unsub = list(set([n for n in range(node_num)]).difference(set(idx_sub)))
121 | idx_neigh = set([np.random.choice(idx_unsub)])
122 | sample_node = np.random.choice(list(idx_neigh))
123 |
124 | idx_sub.append(sample_node)
125 | idx_neigh = idx_neigh.union(
126 | set([n for n in G.neighbors(idx_sub[-1])])).difference(set(idx_sub))
127 |
128 | idx_nondrop = idx_sub
129 | idx_nondrop.sort()
130 |
131 | edge_idx, edge_attr = subgraph(subset=idx_nondrop,
132 | edge_index=data.edge_index,
133 | edge_attr=data.edge_attr,
134 | relabel_nodes=True,
135 | num_nodes=node_num)
136 |
137 | data.edge_index = edge_idx
138 | data.edge_attr = edge_attr
139 | data.x = data.x[idx_nondrop]
140 | data.__num_nodes__, _ = data.x.shape
141 | return data
142 |
143 | def get(self, idx):
144 | data, data1, data2 = Data(), Data(), Data()
145 | keys_for_2D = ['x', 'edge_index', 'edge_attr']
146 | for key in self.data.keys:
147 | item, slices = self.data[key], self.slices[key]
148 | s = list(repeat(slice(None), item.dim()))
149 | s[data.__cat_dim__(key, item)] = slice(slices[idx], slices[idx + 1])
150 | if key in keys_for_2D:
151 | data[key], data1[key], data2[key] = item[s], item[s], item[s]
152 | else:
153 | data[key] = item[s]
154 |
155 | if self.aug_mode == 'no_aug':
156 | n_aug1, n_aug2 = 4, 4
157 | data1 = self.augmentations[n_aug1](data1)
158 | data2 = self.augmentations[n_aug2](data2)
159 | elif self.aug_mode == 'uniform':
160 | n_aug = np.random.choice(25, 1)[0]
161 | n_aug1, n_aug2 = n_aug // 5, n_aug % 5
162 | data1 = self.augmentations[n_aug1](data1)
163 | data2 = self.augmentations[n_aug2](data2)
164 | elif self.aug_mode == 'sample':
165 | n_aug = np.random.choice(25, 1, p=self.aug_prob)[0]
166 | n_aug1, n_aug2 = n_aug // 5, n_aug % 5
167 | data1 = self.augmentations[n_aug1](data1)
168 | data2 = self.augmentations[n_aug2](data2)
169 | else:
170 | raise ValueError
171 | return data, data1, data2
172 |
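In the `'uniform'` and `'sample'` modes of `get`, a single draw over 25 outcomes decodes into a pair of indices over the 5 augmentations (0 `node_drop`, 1 `subgraph`, 2 `edge_pert`, 3 `attr_mask`, 4 identity):

```python
# One draw over 25 outcomes decodes to an (augmentation, augmentation) pair.
n_aug = 13
n_aug1, n_aug2 = n_aug // 5, n_aug % 5
print(n_aug1, n_aug2)  # 2 3 -> (edge_pert, attr_mask)
```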
--------------------------------------------------------------------------------
/src_classification/datasets/molecule_motif_datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | from itertools import repeat
3 |
4 | import numpy as np
5 | import torch
6 | from descriptastorus.descriptors import rdDescriptors
7 | from torch_geometric.data import Data, InMemoryDataset
8 | from tqdm import tqdm
9 |
10 |
11 | RDKIT_PROPS = ['fr_Al_COO', 'fr_Al_OH', 'fr_Al_OH_noTert', 'fr_ArN',
12 | 'fr_Ar_COO', 'fr_Ar_N', 'fr_Ar_NH', 'fr_Ar_OH', 'fr_COO', 'fr_COO2',
13 | 'fr_C_O', 'fr_C_O_noCOO', 'fr_C_S', 'fr_HOCCN', 'fr_Imine', 'fr_NH0',
14 | 'fr_NH1', 'fr_NH2', 'fr_N_O', 'fr_Ndealkylation1', 'fr_Ndealkylation2',
15 | 'fr_Nhpyrrole', 'fr_SH', 'fr_aldehyde', 'fr_alkyl_carbamate', 'fr_alkyl_halide',
16 | 'fr_allylic_oxid', 'fr_amide', 'fr_amidine', 'fr_aniline', 'fr_aryl_methyl',
17 | 'fr_azide', 'fr_azo', 'fr_barbitur', 'fr_benzene', 'fr_benzodiazepine',
18 | 'fr_bicyclic', 'fr_diazo', 'fr_dihydropyridine', 'fr_epoxide', 'fr_ester',
19 | 'fr_ether', 'fr_furan', 'fr_guanido', 'fr_halogen', 'fr_hdrzine', 'fr_hdrzone',
20 | 'fr_imidazole', 'fr_imide', 'fr_isocyan', 'fr_isothiocyan', 'fr_ketone',
21 | 'fr_ketone_Topliss', 'fr_lactam', 'fr_lactone', 'fr_methoxy', 'fr_morpholine',
22 | 'fr_nitrile', 'fr_nitro', 'fr_nitro_arom', 'fr_nitro_arom_nonortho',
23 | 'fr_nitroso', 'fr_oxazole', 'fr_oxime', 'fr_para_hydroxylation', 'fr_phenol',
24 | 'fr_phenol_noOrthoHbond', 'fr_phos_acid', 'fr_phos_ester', 'fr_piperdine',
25 | 'fr_piperzine', 'fr_priamide', 'fr_prisulfonamd', 'fr_pyridine', 'fr_quatN',
26 | 'fr_sulfide', 'fr_sulfonamd', 'fr_sulfone', 'fr_term_acetylene', 'fr_tetrazole',
27 | 'fr_thiazole', 'fr_thiocyan', 'fr_thiophene', 'fr_unbrch_alkane', 'fr_urea']
28 |
29 |
30 | def rdkit_functional_group_label_features_generator(smiles):
31 | """
32 | Generates functional group labels for a molecule using RDKit.
33 | :param smiles: A SMILES string.
34 | :return: A 1D numpy array of binary functional group labels.
35 | """
36 | # smiles = Chem.MolToSmiles(mol, isomericSmiles=True)
37 | # if type(mol) != str else mol
38 | generator = rdDescriptors.RDKit2D(RDKIT_PROPS)
39 | features = generator.process(smiles)[1:]
40 | features = np.array(features)
41 | features[features != 0] = 1
42 | return features
43 |
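The generator above reduces RDKit fragment counts to a binary multi-label target: any nonzero count becomes 1. The binarization step in isolation (counts are made up for illustration):

```python
import numpy as np

# Any nonzero fragment count becomes a positive label.
counts = np.array([0.0, 2.0, 0.0, 5.0, 1.0])
labels = counts.copy()
labels[labels != 0] = 1
print(labels.astype(int).tolist())  # [0, 1, 0, 1, 1]
```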
44 |
45 | class MoleculeMotifDataset(InMemoryDataset):
46 | def __init__(self, root, dataset,
47 | transform=None, pre_transform=None, pre_filter=None, empty=False):
48 | self.dataset = dataset
49 | self.root = root
50 |
51 | super(MoleculeMotifDataset, self).__init__(root, transform, pre_transform, pre_filter)
52 | self.transform, self.pre_transform, self.pre_filter = transform, pre_transform, pre_filter
53 |
54 | if not empty:
55 | self.data, self.slices = torch.load(self.processed_paths[0])
56 |
57 | self.motif_file = os.path.join(root, 'processed', 'motif.pt')
58 | self.process_motif_file()
59 | self.motif_label_list = torch.load(self.motif_file)
60 |
61 | print('Dataset: {}\nData: {}\nMotif: {}'.format(self.dataset, self.data, self.motif_label_list.size()))
62 |
63 | def process_motif_file(self):
64 | if not os.path.exists(self.motif_file):
65 | smiles_file = os.path.join(self.root, 'processed', 'smiles.csv')
66 | data_smiles_list = []
67 | with open(smiles_file, 'r') as f:
68 | lines = f.readlines()
69 | for smiles in lines:
70 | data_smiles_list.append(smiles.strip())
71 |
72 | motif_label_list = []
73 | for smiles in tqdm(data_smiles_list):
74 | label = rdkit_functional_group_label_features_generator(smiles)
75 | motif_label_list.append(label)
76 |
77 | self.motif_label_list = torch.LongTensor(motif_label_list)
78 | torch.save(self.motif_label_list, self.motif_file)
79 | return
80 |
81 | def get(self, idx):
82 | data = Data()
83 | for key in self.data.keys:
84 | item, slices = self.data[key], self.slices[key]
85 | s = list(repeat(slice(None), item.dim()))
86 | s[data.__cat_dim__(key, item)] = slice(slices[idx], slices[idx + 1])
87 | data[key] = item[s]
88 | data.y = self.motif_label_list[idx]
89 | return data
90 |
91 | @property
92 | def raw_file_names(self):
93 | return os.listdir(self.raw_dir)
94 |
95 | @property
96 | def processed_file_names(self):
97 | return 'geometric_data_processed.pt'
98 |
99 | def download(self):
100 | return
101 |
102 | def process(self):
103 | return
104 |
--------------------------------------------------------------------------------
/src_classification/models/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch_geometric.nn.inits import uniform
4 |
5 | from .auto_encoder import AutoEncoder, VariationalAutoEncoder
6 | from .molecule_gnn_model import GNN, GNN_graphpred
7 | from .schnet import SchNet
8 | from .dti_model import ProteinModel, MoleculeProteinModel
9 |
10 |
11 | class Discriminator(nn.Module):
12 | def __init__(self, hidden_dim):
13 | super(Discriminator, self).__init__()
14 | self.weight = nn.Parameter(torch.Tensor(hidden_dim, hidden_dim))
15 | self.reset_parameters()
16 |
17 | def reset_parameters(self):
18 | size = self.weight.size(0)
19 | uniform(size, self.weight)
20 |
21 | def forward(self, x, summary):
22 | h = torch.matmul(summary, self.weight)
23 | return torch.sum(x*h, dim=1)
24 |
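The `Discriminator` scores each node embedding against a graph summary via a bilinear form: `h = summary @ W`, then a per-node dot product with `x`. A numpy rehearsal (shapes illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.standard_normal((4, 4))   # stands in for the learned weight matrix
x = rng.standard_normal((3, 4))   # 3 node embeddings
s = rng.standard_normal(4)        # graph summary vector

# h = summary @ W, then score_i = x_i . h, as in forward() above
scores = (x * (s @ W)).sum(axis=1)
print(scores.shape)  # (3,)
```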
--------------------------------------------------------------------------------
/src_classification/models/auto_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 |
5 |
6 | def cosine_similarity(p, z, average=True):
7 | p = F.normalize(p, p=2, dim=1)
8 | z = F.normalize(z, p=2, dim=1)
9 | loss = -(p * z).sum(dim=1)
10 | if average:
11 | loss = loss.mean()
12 | return loss
13 |
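A numerical check of `cosine_similarity` above: after L2-normalizing the rows, `-(p * z).sum(dim=1)` is the negative cosine similarity, so aligned vectors score -1 and opposite vectors score +1:

```python
import numpy as np

def neg_cosine(p, z):
    p = p / np.linalg.norm(p, axis=1, keepdims=True)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)
    return float(-(p * z).sum(axis=1).mean())

v = np.array([[3.0, 4.0]])
print(round(neg_cosine(v, v), 6))   # -1.0 (aligned)
print(round(neg_cosine(v, -v), 6))  # 1.0 (opposite)
```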
14 |
15 | class AutoEncoder(torch.nn.Module):
16 |
17 | def __init__(self, emb_dim, loss, detach_target):
18 | super(AutoEncoder, self).__init__()
19 | self.loss = loss
20 | self.emb_dim = emb_dim
21 | self.detach_target = detach_target
22 |
23 | self.criterion = None
24 | if loss == 'l1':
25 | self.criterion = nn.L1Loss()
26 | elif loss == 'l2':
27 | self.criterion = nn.MSELoss()
28 | elif loss == 'cosine':
29 | self.criterion = cosine_similarity
30 |
31 | self.fc_layers = nn.Sequential(
32 | nn.Linear(self.emb_dim, self.emb_dim),
33 | nn.BatchNorm1d(self.emb_dim),
34 | nn.ReLU(),
35 | nn.Linear(self.emb_dim, self.emb_dim),
36 | )
37 | return
38 |
39 | def forward(self, x, y):
40 | if self.detach_target:
41 | y = y.detach()
42 | x = self.fc_layers(x)
43 | loss = self.criterion(x, y)
44 |
45 | return loss
46 |
47 |
48 | class VariationalAutoEncoder(torch.nn.Module):
49 | def __init__(self, emb_dim, loss, detach_target, beta=1):
50 | super(VariationalAutoEncoder, self).__init__()
51 | self.emb_dim = emb_dim
52 | self.loss = loss
53 | self.detach_target = detach_target
54 | self.beta = beta
55 |
56 | self.criterion = None
57 | if loss == 'l1':
58 | self.criterion = nn.L1Loss()
59 | elif loss == 'l2':
60 | self.criterion = nn.MSELoss()
61 | elif loss == 'cosine':
62 | self.criterion = cosine_similarity
63 |
64 | self.fc_mu = nn.Linear(self.emb_dim, self.emb_dim)
65 | self.fc_var = nn.Linear(self.emb_dim, self.emb_dim)
66 |
67 | self.decoder = nn.Sequential(
68 | nn.Linear(self.emb_dim, self.emb_dim),
69 | nn.BatchNorm1d(self.emb_dim),
70 | nn.ReLU(),
71 | nn.Linear(self.emb_dim, self.emb_dim),
72 | )
73 | return
74 |
75 | def encode(self, x):
76 | mu = self.fc_mu(x)
77 | log_var = self.fc_var(x)
78 | return mu, log_var
79 |
80 | def reparameterize(self, mu, log_var):
81 | std = torch.exp(0.5 * log_var)
82 | eps = torch.randn_like(std)
83 | return mu + eps * std
84 |
85 | def forward(self, x, y):
86 | if self.detach_target:
87 | y = y.detach()
88 |
89 | mu, log_var = self.encode(x)
90 | z = self.reparameterize(mu, log_var)
91 | y_hat = self.decoder(z)
92 |
93 | reconstruction_loss = self.criterion(y_hat, y)
94 | kl_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)
95 |
96 | loss = reconstruction_loss + self.beta * kl_loss
97 |
98 | return loss
99 |
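`VariationalAutoEncoder.forward` combines the reparameterization trick with the analytic Gaussian KL term. Both pieces in numpy, at the point where the posterior equals the standard-normal prior so the KL vanishes:

```python
import numpy as np

rng = np.random.default_rng(0)
mu = np.zeros((4, 8))
log_var = np.zeros((4, 8))                 # posterior equals the N(0, I) prior

eps = rng.standard_normal(mu.shape)
z = mu + eps * np.exp(0.5 * log_var)       # differentiable sample of z

kl = np.mean(-0.5 * np.sum(1 + log_var - mu**2 - np.exp(log_var), axis=1))
print(kl == 0.0)  # True: KL(N(0, I) || N(0, I)) vanishes
```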
--------------------------------------------------------------------------------
/src_classification/models/dti_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch_geometric.nn import global_mean_pool
4 |
5 |
6 | class ProteinModel(nn.Module):
7 | def __init__(self, emb_dim=128, num_features=25, output_dim=128, n_filters=32, kernel_size=8):
8 | super(ProteinModel, self).__init__()
9 | self.n_filters = n_filters
10 | self.kernel_size = kernel_size
11 | self.intermediate_dim = emb_dim - kernel_size + 1
12 |
13 | self.embedding = nn.Embedding(num_features+1, emb_dim)
14 | self.n_filters = n_filters
15 | self.conv1 = nn.Conv1d(in_channels=1000, out_channels=n_filters, kernel_size=kernel_size)
16 | self.fc = nn.Linear(n_filters*self.intermediate_dim, output_dim)
17 |
18 | def forward(self, x):
19 | x = self.embedding(x)
20 | x = self.conv1(x)
21 | x = x.view(-1, self.n_filters*self.intermediate_dim)
22 | x = self.fc(x)
23 | return x
24 |
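A note on the shape bookkeeping in `ProteinModel`: `Conv1d` with stride 1 and no padding shrinks the convolved dimension to `L_in - kernel_size + 1`, which is why `intermediate_dim = emb_dim - kernel_size + 1` sizes the final linear layer (here the convolution runs over the embedding dimension, with the sequence length of 1000 as input channels). The arithmetic with the defaults above:

```python
# Conv1d output length: L_out = L_in - kernel_size + 1 (stride 1, no padding).
emb_dim, kernel_size, n_filters = 128, 8, 32
intermediate_dim = emb_dim - kernel_size + 1
print(intermediate_dim)               # 121
print(n_filters * intermediate_dim)   # 3872: input features of self.fc
```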
25 |
26 | class MoleculeProteinModel(nn.Module):
27 | def __init__(self, molecule_model, protein_model, molecule_emb_dim, protein_emb_dim, output_dim=1, dropout=0.2):
28 | super(MoleculeProteinModel, self).__init__()
29 | self.fc1 = nn.Linear(molecule_emb_dim+protein_emb_dim, 1024)
30 | self.fc2 = nn.Linear(1024, 512)
31 | self.out = nn.Linear(512, output_dim)
32 | self.molecule_model = molecule_model
33 | self.protein_model = protein_model
34 | self.pool = global_mean_pool
35 | self.relu = nn.ReLU()
36 | self.dropout = nn.Dropout(dropout)
37 |
38 | def forward(self, molecule, protein):
39 | molecule_node_representation = self.molecule_model(molecule)
40 | molecule_representation = self.pool(molecule_node_representation, molecule.batch)
41 | protein_representation = self.protein_model(protein)
42 |
43 | x = torch.cat([molecule_representation, protein_representation], dim=1)
44 |
45 | x = self.fc1(x)
46 | x = self.relu(x)
47 | x = self.dropout(x)
48 | x = self.fc2(x)
49 | x = self.relu(x)
50 | x = self.dropout(x)
51 | x = self.out(x)
52 |
53 | return x
54 |
--------------------------------------------------------------------------------
/src_classification/models/schnet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import warnings
4 | from math import pi as PI
5 |
6 | import ase
7 | import numpy as np
8 | import torch
9 | import torch.nn.functional as F
10 | from torch.nn import Embedding, Linear, ModuleList, Sequential
11 | from torch_geometric.data.makedirs import makedirs
12 | from torch_geometric.nn import MessagePassing, radius_graph
13 | from torch_scatter import scatter
14 |
15 | try:
16 | import schnetpack as spk
17 | except ImportError:
18 | spk = None
19 |
20 |
21 | class SchNet(torch.nn.Module):
22 |
23 | def __init__(self, hidden_channels=128, num_filters=128,
24 | num_interactions=6, num_gaussians=50, cutoff=10.0,
25 | readout='mean', dipole=False, mean=None, std=None, atomref=None):
26 | super(SchNet, self).__init__()
27 |
28 | assert readout in ['add', 'sum', 'mean']
29 |
30 | self.readout = 'add' if dipole else readout
31 | self.num_interactions = num_interactions
32 | self.hidden_channels = hidden_channels
33 | self.num_gaussians = num_gaussians
34 | self.num_filters = num_filters
35 | # self.readout = readout
36 | self.cutoff = cutoff
37 | self.dipole = dipole
38 | self.scale = None
39 | self.mean = mean
40 | self.std = std
41 |
42 | atomic_mass = torch.from_numpy(ase.data.atomic_masses)
43 | self.register_buffer('atomic_mass', atomic_mass)
44 |
45 | # self.embedding = Embedding(100, hidden_channels)
46 | self.embedding = Embedding(119, hidden_channels)
47 | self.distance_expansion = GaussianSmearing(0.0, cutoff, num_gaussians)
48 |
49 | self.interactions = ModuleList()
50 | for _ in range(num_interactions):
51 | block = InteractionBlock(hidden_channels, num_gaussians,
52 | num_filters, cutoff)
53 | self.interactions.append(block)
54 |
55 | # TODO: double-check hidden size
56 | self.lin1 = Linear(hidden_channels, hidden_channels)
57 | self.act = ShiftedSoftplus()
58 | self.lin2 = Linear(hidden_channels, hidden_channels)
59 |
60 | self.register_buffer('initial_atomref', atomref)
61 | self.atomref = None
62 | if atomref is not None:
63 | self.atomref = Embedding(100, 1)
64 | self.atomref.weight.data.copy_(atomref)
65 |
66 | self.reset_parameters()
67 |
68 | def reset_parameters(self):
69 | self.embedding.reset_parameters()
70 | for interaction in self.interactions:
71 | interaction.reset_parameters()
72 | torch.nn.init.xavier_uniform_(self.lin1.weight)
73 | self.lin1.bias.data.fill_(0)
74 | torch.nn.init.xavier_uniform_(self.lin2.weight)
75 | self.lin2.bias.data.fill_(0)
76 | if self.atomref is not None:
77 | self.atomref.weight.data.copy_(self.initial_atomref)
78 |
79 | def forward(self, z, pos, batch=None):
80 | assert z.dim() == 1 and z.dtype == torch.long
81 | batch = torch.zeros_like(z) if batch is None else batch
82 |
83 | h = self.embedding(z)
84 |
85 | edge_index = radius_graph(pos, r=self.cutoff, batch=batch)
86 | row, col = edge_index
87 | edge_weight = (pos[row] - pos[col]).norm(dim=-1)
88 | edge_attr = self.distance_expansion(edge_weight)
89 |
90 | for interaction in self.interactions:
91 | h = h + interaction(h, edge_index, edge_weight, edge_attr)
92 |
93 | h = self.lin1(h)
94 | h = self.act(h)
95 | h = self.lin2(h)
96 |
97 | if self.dipole:
98 | # Get center of mass.
99 | mass = self.atomic_mass[z].view(-1, 1)
100 | c = scatter(mass * pos, batch, dim=0) / scatter(mass, batch, dim=0)
101 | h = h * (pos - c[batch])
102 |
103 | if not self.dipole and self.mean is not None and self.std is not None:
104 | h = h * self.std + self.mean
105 |
106 | if not self.dipole and self.atomref is not None:
107 | h = h + self.atomref(z)
108 |
109 | out = scatter(h, batch, dim=0, reduce=self.readout)
110 |
111 | if self.dipole:
112 | out = torch.norm(out, dim=-1, keepdim=True)
113 |
114 | if self.scale is not None:
115 | out = self.scale * out
116 |
117 | return out
118 |
119 | def __repr__(self):
120 | return (f'{self.__class__.__name__}('
121 | f'hidden_channels={self.hidden_channels}, '
122 | f'num_filters={self.num_filters}, '
123 | f'num_interactions={self.num_interactions}, '
124 | f'num_gaussians={self.num_gaussians}, '
125 | f'cutoff={self.cutoff})')
126 |
127 |
128 | class InteractionBlock(torch.nn.Module):
129 | def __init__(self, hidden_channels, num_gaussians, num_filters, cutoff):
130 | super(InteractionBlock, self).__init__()
131 | self.mlp = Sequential(
132 | Linear(num_gaussians, num_filters),
133 | ShiftedSoftplus(),
134 | Linear(num_filters, num_filters),
135 | )
136 | self.conv = CFConv(hidden_channels, hidden_channels,
137 | num_filters, self.mlp, cutoff)
138 | self.act = ShiftedSoftplus()
139 | self.lin = Linear(hidden_channels, hidden_channels)
140 |
141 | self.reset_parameters()
142 |
143 | def reset_parameters(self):
144 | torch.nn.init.xavier_uniform_(self.mlp[0].weight)
145 | self.mlp[0].bias.data.fill_(0)
146 | torch.nn.init.xavier_uniform_(self.mlp[2].weight)
147 | self.mlp[2].bias.data.fill_(0)
148 | self.conv.reset_parameters()
149 | torch.nn.init.xavier_uniform_(self.lin.weight)
150 | self.lin.bias.data.fill_(0)
151 |
152 | def forward(self, x, edge_index, edge_weight, edge_attr):
153 | x = self.conv(x, edge_index, edge_weight, edge_attr)
154 | x = self.act(x)
155 | x = self.lin(x)
156 | return x
157 |
158 |
159 | class CFConv(MessagePassing):
160 | def __init__(self, in_channels, out_channels, num_filters, nn, cutoff):
161 | super(CFConv, self).__init__(aggr='add')
162 | self.lin1 = Linear(in_channels, num_filters, bias=False)
163 | self.lin2 = Linear(num_filters, out_channels)
164 | self.nn = nn
165 | self.cutoff = cutoff
166 |
167 | self.reset_parameters()
168 |
169 | def reset_parameters(self):
170 | torch.nn.init.xavier_uniform_(self.lin1.weight)
171 | torch.nn.init.xavier_uniform_(self.lin2.weight)
172 | self.lin2.bias.data.fill_(0)
173 |
174 | def forward(self, x, edge_index, edge_weight, edge_attr):
175 | C = 0.5 * (torch.cos(edge_weight * PI / self.cutoff) + 1.0)
176 | W = self.nn(edge_attr) * C.view(-1, 1)
177 |
178 | x = self.lin1(x)
179 | x = self.propagate(edge_index, x=x, W=W)
180 | x = self.lin2(x)
181 | return x
182 |
183 | def message(self, x_j, W):
184 | return x_j * W
185 |
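The factor `C` in `CFConv.forward` is a cosine cutoff: filter weights decay smoothly from 1 at distance 0 to 0 at the cutoff radius. Numerically:

```python
import numpy as np

cutoff = 10.0
d = np.array([0.0, 5.0, 10.0])
C = 0.5 * (np.cos(d * np.pi / cutoff) + 1.0)
print(np.round(C, 6).tolist())  # [1.0, 0.5, 0.0]
```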
186 |
187 | class GaussianSmearing(torch.nn.Module):
188 |
189 | def __init__(self, start=0.0, stop=5.0, num_gaussians=50):
190 | super(GaussianSmearing, self).__init__()
191 | offset = torch.linspace(start, stop, num_gaussians)
192 | self.coeff = -0.5 / (offset[1] - offset[0]).item() ** 2
193 | self.register_buffer('offset', offset)
194 |
195 | def forward(self, dist):
196 | dist = dist.view(-1, 1) - self.offset.view(1, -1)
197 | return torch.exp(self.coeff * torch.pow(dist, 2))
198 |
199 |
200 | class ShiftedSoftplus(torch.nn.Module):
201 |
202 | def __init__(self):
203 | super(ShiftedSoftplus, self).__init__()
204 | self.shift = torch.log(torch.tensor(2.0)).item()
205 |
206 | def forward(self, x):
207 | return F.softplus(x) - self.shift
208 |
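The `GaussianSmearing` and `ShiftedSoftplus` modules above are small and self-contained. A pure-Python sketch of what they compute (standard library only; the helper names here are illustrative, not part of the codebase):

```python
import math

# Mirror of GaussianSmearing: expand a scalar distance into num_gaussians
# radial-basis features, one Gaussian per evenly spaced offset in [start, stop].
def gaussian_smearing(dist, start=0.0, stop=5.0, num_gaussians=50):
    step = (stop - start) / (num_gaussians - 1)
    offsets = [start + i * step for i in range(num_gaussians)]
    coeff = -0.5 / step ** 2
    return [math.exp(coeff * (dist - o) ** 2) for o in offsets]

feat = gaussian_smearing(0.0)
print(len(feat))   # 50 features per distance
print(feat[0])     # 1.0 -- the Gaussian centred at 0.0 fires maximally

# Mirror of ShiftedSoftplus: softplus(x) - log(2), so it passes through the origin.
def shifted_softplus(x):
    return math.log(1.0 + math.exp(x)) - math.log(2.0)

print(shifted_softplus(0.0))  # 0.0
```

The shift by log(2) keeps the activation zero-centred at zero input while staying smooth, which matters for SchNet's continuous-filter convolutions.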
--------------------------------------------------------------------------------
/src_classification/pretrain_AM.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.optim as optim
7 | from config import args
8 | from dataloader import DataLoaderMasking
9 | from models import GNN
10 | from torch_geometric.nn import global_mean_pool
11 | from util import MaskAtom
12 |
13 | from datasets import MoleculeDataset
14 |
15 |
16 | def compute_accuracy(pred, target):
17 | return float(torch.sum(torch.max(pred.detach(), dim=1)[1] == target).cpu().item())/len(pred)
18 |
19 |
20 | def do_AttrMasking(batch, criterion, node_repr, molecule_atom_masking_model):
21 | target = batch.mask_node_label[:, 0]
22 | node_pred = molecule_atom_masking_model(node_repr[batch.masked_atom_indices])
23 | attributemask_loss = criterion(node_pred.double(), target)
24 | attributemask_acc = compute_accuracy(node_pred, target)
25 | return attributemask_loss, attributemask_acc
26 |
27 |
28 | def train(device, loader, optimizer):
29 |
30 | start = time.time()
31 | molecule_model.train()
32 | molecule_atom_masking_model.train()
33 | attributemask_loss_accum, attributemask_acc_accum = 0, 0
34 |
35 | for step, batch in enumerate(loader):
36 | batch = batch.to(device)
37 | node_repr = molecule_model(batch.masked_x, batch.edge_index, batch.edge_attr)
38 |
39 | attributemask_loss, attributemask_acc = do_AttrMasking(
40 | batch=batch, criterion=criterion, node_repr=node_repr,
41 | molecule_atom_masking_model=molecule_atom_masking_model)
42 |
43 | attributemask_loss_accum += attributemask_loss.detach().cpu().item()
44 | attributemask_acc_accum += attributemask_acc
45 | loss = attributemask_loss
46 |
47 | optimizer.zero_grad()
48 | loss.backward()
49 | optimizer.step()
50 |
51 | print('AM Loss: {:.5f}\tAM Acc: {:.5f}\tTime: {:.5f}'.format(
52 | attributemask_loss_accum / len(loader),
53 | attributemask_acc_accum / len(loader),
54 | time.time() - start))
55 | return
56 |
57 |
58 | if __name__ == '__main__':
59 |
60 | np.random.seed(0)
61 | torch.manual_seed(0)
62 | device = torch.device('cuda:' + str(args.device)) \
63 | if torch.cuda.is_available() else torch.device('cpu')
64 | if torch.cuda.is_available():
65 | torch.cuda.manual_seed_all(0)
66 | torch.cuda.set_device(args.device)
67 |
68 | if 'GEOM' in args.dataset:
69 | dataset = MoleculeDataset(
70 | '../datasets/{}/'.format(args.dataset), dataset=args.dataset,
71 | transform=MaskAtom(num_atom_type=119, num_edge_type=5,
72 | mask_rate=args.mask_rate, mask_edge=args.mask_edge))
73 | loader = DataLoaderMasking(dataset, batch_size=args.batch_size,
74 | shuffle=True, num_workers=args.num_workers)
75 |
76 | # set up model
77 | molecule_model = GNN(args.num_layer, args.emb_dim,
78 | JK=args.JK, drop_ratio=args.dropout_ratio,
79 | gnn_type=args.gnn_type).to(device)
80 | molecule_readout_func = global_mean_pool
81 |
82 | molecule_atom_masking_model = torch.nn.Linear(args.emb_dim, 119).to(device)
83 |
84 | model_param_group = [{'params': molecule_model.parameters(), 'lr': args.lr},
85 | {'params': molecule_atom_masking_model.parameters(), 'lr': args.lr}]
86 |
87 | optimizer = optim.Adam(model_param_group, lr=args.lr, weight_decay=args.decay)
88 | criterion = nn.CrossEntropyLoss()
89 |
90 | for epoch in range(1, args.epochs + 1):
91 | print('epoch: {}'.format(epoch))
92 | train(device, loader, optimizer)
93 |
94 | if not args.output_model_dir == '':
95 | torch.save(molecule_model.state_dict(), args.output_model_dir + '_model.pth')
96 |
97 | saver_dict = {'model': molecule_model.state_dict(),
98 | 'molecule_atom_masking_model': molecule_atom_masking_model.state_dict()}
99 |
100 | torch.save(saver_dict, args.output_model_dir + '_model_complete.pth')
101 |
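`compute_accuracy` above (repeated verbatim in several of these pretraining scripts) is just an argmax-match rate over class logits. The same logic in plain Python, with made-up toy inputs:

```python
def compute_accuracy(pred, target):
    # pred: per-sample lists of class scores; target: per-sample class ids.
    # A prediction counts as a hit when its argmax equals the target class.
    hits = sum(
        1 for scores, t in zip(pred, target)
        if max(range(len(scores)), key=scores.__getitem__) == t)
    return hits / len(pred)

pred = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]
target = [1, 0, 0]
print(compute_accuracy(pred, target))  # 2 of 3 predictions match
```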
--------------------------------------------------------------------------------
/src_classification/pretrain_CP.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.optim as optim
7 | from config import args
8 | from dataloader import DataLoaderSubstructContext
9 | from models import GNN
10 | from torch_geometric.nn import global_mean_pool
11 | from util import ExtractSubstructureContextPair, cycle_index
12 |
13 | from datasets import MoleculeDataset
14 |
15 |
16 | def do_ContextPred(batch, criterion, args, molecule_substruct_model,
17 | molecule_context_model, molecule_readout_func):
18 |
19 | # creating substructure representation
20 | substruct_repr = molecule_substruct_model(
21 | batch.x_substruct, batch.edge_index_substruct,
22 | batch.edge_attr_substruct)[batch.center_substruct_idx]
23 |
24 | # creating context representations
25 | overlapped_node_repr = molecule_context_model(
26 | batch.x_context, batch.edge_index_context,
27 | batch.edge_attr_context)[batch.overlap_context_substruct_idx]
28 |
29 | # positive context representation
30 | # readout -> global_mean_pool by default
31 | context_repr = molecule_readout_func(overlapped_node_repr,
32 | batch.batch_overlapped_context)
33 |
34 | # negative contexts are obtained by shifting
35 | # the indices of context embeddings
36 | neg_context_repr = torch.cat(
37 | [context_repr[cycle_index(len(context_repr), i + 1)]
38 | for i in range(args.contextpred_neg_samples)], dim=0)
39 |
40 | num_neg = args.contextpred_neg_samples
41 | pred_pos = torch.sum(substruct_repr * context_repr, dim=1)
42 | pred_neg = torch.sum(substruct_repr.repeat((num_neg, 1)) * neg_context_repr, dim=1)
43 |
44 | loss_pos = criterion(pred_pos.double(),
45 | torch.ones(len(pred_pos)).to(pred_pos.device).double())
46 | loss_neg = criterion(pred_neg.double(),
47 | torch.zeros(len(pred_neg)).to(pred_neg.device).double())
48 |
49 | contextpred_loss = loss_pos + num_neg * loss_neg
50 |
51 | num_pred = len(pred_pos) + len(pred_neg)
52 | contextpred_acc = (torch.sum(pred_pos > 0).float() +
53 | torch.sum(pred_neg < 0).float()) / num_pred
54 | contextpred_acc = contextpred_acc.detach().cpu().item()
55 |
56 | return contextpred_loss, contextpred_acc
57 |
58 |
59 | def train(args, device, loader, optimizer):
60 |
61 | start_time = time.time()
62 | molecule_context_model.train()
63 | molecule_substruct_model.train()
64 | contextpred_loss_accum, contextpred_acc_accum = 0, 0
65 |
66 | for step, batch in enumerate(loader):
67 |
68 | batch = batch.to(device)
69 | contextpred_loss, contextpred_acc = do_ContextPred(
70 | batch=batch, criterion=criterion, args=args,
71 | molecule_substruct_model=molecule_substruct_model,
72 | molecule_context_model=molecule_context_model,
73 | molecule_readout_func=molecule_readout_func)
74 |
75 | contextpred_loss_accum += contextpred_loss.detach().cpu().item()
76 | contextpred_acc_accum += contextpred_acc
77 | ssl_loss = contextpred_loss
78 | optimizer.zero_grad()
79 | ssl_loss.backward()
80 | optimizer.step()
81 |
82 | print('CP Loss: {:.5f}\tCP Acc: {:.5f}\tTime: {:.3f}'.format(
83 | contextpred_loss_accum / len(loader),
84 | contextpred_acc_accum / len(loader),
85 | time.time() - start_time))
86 |
87 | return
88 |
89 |
90 | if __name__ == '__main__':
91 |
92 | np.random.seed(0)
93 | torch.manual_seed(0)
94 | device = torch.device('cuda:' + str(args.device)) \
95 | if torch.cuda.is_available() else torch.device('cpu')
96 | if torch.cuda.is_available():
97 | torch.cuda.manual_seed_all(0)
98 | torch.cuda.set_device(args.device)
99 |
100 | l1 = args.num_layer - 1
101 | l2 = l1 + args.csize
102 | print('num layer: %d l1: %d l2: %d' % (args.num_layer, l1, l2))
103 |
104 | if 'GEOM' in args.dataset:
105 | dataset = MoleculeDataset(
106 | '../datasets/{}/'.format(args.dataset), dataset=args.dataset,
107 | transform=ExtractSubstructureContextPair(args.num_layer, l1, l2))
108 | loader = DataLoaderSubstructContext(dataset, batch_size=args.batch_size,
109 | shuffle=True, num_workers=args.num_workers)
110 |
111 | ''' === set up model, mainly used in do_ContextPred() === '''
112 | molecule_substruct_model = GNN(
113 | args.num_layer, args.emb_dim, JK=args.JK,
114 | drop_ratio=args.dropout_ratio, gnn_type=args.gnn_type).to(device)
115 | molecule_context_model = GNN(
116 | int(l2 - l1), args.emb_dim, JK=args.JK,
117 | drop_ratio=args.dropout_ratio, gnn_type=args.gnn_type).to(device)
118 |
119 | ''' === set up loss and optimiser === '''
120 | criterion = nn.BCEWithLogitsLoss()
121 | molecule_readout_func = global_mean_pool
122 | model_param_group = [{'params': molecule_substruct_model.parameters(), 'lr': args.lr},
123 | {'params': molecule_context_model.parameters(), 'lr': args.lr}]
124 | optimizer = optim.Adam(model_param_group, lr=args.lr, weight_decay=args.decay)
125 |
126 | for epoch in range(1, args.epochs + 1):
127 | print('epoch: {}'.format(epoch))
128 | train(args, device, loader, optimizer)
129 |
130 | if not args.output_model_dir == '':
131 | torch.save(molecule_substruct_model.state_dict(),
132 | args.output_model_dir + '_model.pth')
133 |
134 | saver_dict = {
135 | 'molecule_substruct_model': molecule_substruct_model.state_dict(),
136 | 'molecule_context_model': molecule_context_model.state_dict()}
137 |
138 | torch.save(saver_dict, args.output_model_dir + '_model_complete.pth')
139 |
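The negative sampling in `do_ContextPred` relies on `cycle_index`, which is imported from `util` and not shown in this chunk. Assuming it performs a cyclic shift of the indices 0..n-1 (as its name and usage suggest; the real version returns a torch tensor), the trick looks like this in plain Python:

```python
def cycle_index(num, shift):
    # Plain-list stand-in for util.cycle_index: indices rotated by `shift`,
    # so position i is paired with position (i + shift) mod num.
    return list(range(shift, num)) + list(range(shift))

contexts = ['c0', 'c1', 'c2', 'c3']
# Shift by 1: each substructure is paired with a *different* molecule's
# context, giving a cheap in-batch negative sample.
negatives = [contexts[i] for i in cycle_index(len(contexts), 1)]
print(negatives)  # ['c1', 'c2', 'c3', 'c0']
```

Stacking the shifts for i in range(`contextpred_neg_samples`), as the loop in `do_ContextPred` does, yields several distinct negatives per substructure without extra forward passes.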
--------------------------------------------------------------------------------
/src_classification/pretrain_Contextual.py:
--------------------------------------------------------------------------------
1 | import time
2 | from os.path import join
3 |
4 | import numpy as np
5 | import pandas as pd
6 | import torch
7 | import torch.nn as nn
8 | import torch.optim as optim
9 | from config import args
10 | from models import GNN, GNN_graphpred
11 | from sklearn.metrics import roc_auc_score
12 | from torch_geometric.data import DataLoader
13 | from util import get_num_task
14 |
15 | from datasets import MoleculeContextualDataset
16 |
17 |
18 | def compute_accuracy(pred, target):
19 | return float(torch.sum(torch.max(pred.detach(), dim=1)[1] == target).cpu().item())/len(pred)
20 |
21 |
22 | def do_Contextual(batch, criterion, node_repr, atom_vocab_model):
23 | target = batch.atom_vocab_label
24 | node_pred = atom_vocab_model(node_repr)
25 | loss = criterion(node_pred, target)
26 | acc = compute_accuracy(node_pred, target)
27 | return loss, acc
28 |
29 |
30 | def train(device, loader, optimizer):
31 | start = time.time()
32 | molecule_model.train()
33 | atom_vocab_model.train()
34 |
35 | contextual_loss_accum, contextual_acc_accum = 0, 0
36 | for step, batch in enumerate(loader):
37 | batch = batch.to(device)
38 | node_repr = molecule_model(batch.x, batch.edge_index, batch.edge_attr)
39 |
40 | contextual_loss, contextual_acc = do_Contextual(batch, criterion, node_repr, atom_vocab_model)
41 | contextual_loss_accum += contextual_loss.detach().cpu().item()
42 | contextual_acc_accum += contextual_acc
43 | loss = contextual_loss
44 |
45 | optimizer.zero_grad()
46 | loss.backward()
47 | optimizer.step()
48 |
49 | print('Contextual Loss: {:.5f}\tContextual Acc: {:.5f}\tTime: {:.5f}'.format(
50 | contextual_loss_accum / len(loader),
51 | contextual_acc_accum / len(loader),
52 | time.time() - start))
53 | return
54 |
55 |
56 | if __name__ == '__main__':
57 | torch.manual_seed(args.runseed)
58 | np.random.seed(args.runseed)
59 | device = torch.device('cuda:' + str(args.device)) \
60 | if torch.cuda.is_available() else torch.device('cpu')
61 | if torch.cuda.is_available():
62 | torch.cuda.manual_seed_all(args.runseed)
63 |
64 | # Bunch of classification tasks
65 | assert 'GEOM' in args.dataset
66 | dataset_folder = '../datasets/'
67 | dataset = MoleculeContextualDataset(dataset_folder + args.dataset, dataset=args.dataset)
68 | print(dataset)
69 |
70 | atom_vocab = dataset.atom_vocab
71 | atom_vocab_size = len(atom_vocab)
72 | print('atom_vocab\t', len(atom_vocab), atom_vocab_size)
73 |
74 | loader = DataLoader(dataset, batch_size=args.batch_size,
75 | shuffle=True, num_workers=args.num_workers)
76 |
77 | # set up model
78 | molecule_model = GNN(num_layer=args.num_layer, emb_dim=args.emb_dim,
79 | JK=args.JK, drop_ratio=args.dropout_ratio,
80 | gnn_type=args.gnn_type).to(device)
81 | atom_vocab_model = nn.Linear(args.emb_dim, atom_vocab_size).to(device)
82 |
83 | # set up optimizer
84 | # different learning rates for different parts of GNN
85 | model_param_group = [{'params': molecule_model.parameters()},
86 | {'params': atom_vocab_model.parameters(), 'lr': args.lr * args.lr_scale}]
87 | optimizer = optim.Adam(model_param_group, lr=args.lr,
88 | weight_decay=args.decay)
89 | criterion = nn.CrossEntropyLoss()
90 | train_roc_list, val_roc_list, test_roc_list = [], [], []
91 | best_val_roc, best_val_idx = -1, 0
92 |
93 | print('\nStart pre-training Contextual')
94 | for epoch in range(1, args.epochs + 1):
95 | print('epoch: {}'.format(epoch))
96 | train(device, loader, optimizer)
97 |
98 | if args.output_model_dir != '':
99 | print('saving to {}'.format(args.output_model_dir + '_model.pth'))
100 | torch.save(molecule_model.state_dict(), args.output_model_dir + '_model.pth')
101 | saved_model_dict = {
102 | 'molecule_model': molecule_model.state_dict(),
103 | 'atom_vocab_model': atom_vocab_model.state_dict(),
104 | }
105 | torch.save(saved_model_dict, args.output_model_dir + '_model_complete.pth')
106 |
--------------------------------------------------------------------------------
/src_classification/pretrain_EP.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.optim as optim
7 | from config import args
8 | from dataloader import DataLoaderAE
9 | from models import GNN
10 | from torch_geometric.nn import global_mean_pool
11 | from util import NegativeEdge
12 |
13 | from datasets import MoleculeDataset
14 |
15 |
16 | def do_EdgePred(node_repr, batch, criterion=nn.BCEWithLogitsLoss()):
17 |
18 | # positive/negative scores -> inner product of node features
19 | positive_score = torch.sum(node_repr[batch.edge_index[0, ::2]] *
20 | node_repr[batch.edge_index[1, ::2]], dim=1)
21 | negative_score = torch.sum(node_repr[batch.negative_edge_index[0]] *
22 | node_repr[batch.negative_edge_index[1]], dim=1)
23 |
24 | edgepred_loss = criterion(positive_score, torch.ones_like(positive_score)) + \
25 | criterion(negative_score, torch.zeros_like(negative_score))
26 | edgepred_acc = (torch.sum(positive_score > 0) +
27 | torch.sum(negative_score < 0)).to(torch.float32) / \
28 | float(2 * len(positive_score))
29 | edgepred_acc = edgepred_acc.detach().cpu().item()
30 |
31 | return edgepred_loss, edgepred_acc
32 |
33 |
34 | def train(molecule_model, device, loader, optimizer,
35 | criterion=nn.BCEWithLogitsLoss()):
36 |
37 | # Train for one epoch
38 | molecule_model.train()
39 | start_time = time.time()
40 | edgepred_loss_accum, edgepred_acc_accum = 0, 0
41 |
42 | for step, batch in enumerate(loader):
43 |
44 | batch = batch.to(device)
45 |
46 | node_repr = molecule_model(batch.x, batch.edge_index, batch.edge_attr)
47 | edgepred_loss, edgepred_acc = do_EdgePred(
48 | node_repr=node_repr, batch=batch, criterion=criterion)
49 | edgepred_loss_accum += edgepred_loss.detach().cpu().item()
50 | edgepred_acc_accum += edgepred_acc
51 | ssl_loss = edgepred_loss
52 |
53 | optimizer.zero_grad()
54 | ssl_loss.backward()
55 | optimizer.step()
56 |
57 | print('EP Loss: {:.5f}\tEP Acc: {:.5f}\tTime: {:.5f}'.format(
58 | edgepred_loss_accum / len(loader),
59 | edgepred_acc_accum / len(loader),
60 | time.time() - start_time))
61 | return
62 |
63 |
64 | if __name__ == '__main__':
65 | torch.manual_seed(0)
66 | np.random.seed(0)
67 | device = torch.device('cuda:' + str(args.device)) if torch.cuda.is_available() else torch.device('cpu')
68 | if torch.cuda.is_available():
69 | torch.cuda.manual_seed_all(0)
70 | torch.cuda.set_device(args.device)
71 |
72 | if 'GEOM' in args.dataset:
73 | dataset = MoleculeDataset('../datasets/{}/'.format(args.dataset), dataset=args.dataset, transform=NegativeEdge())
74 | loader = DataLoaderAE(dataset, batch_size=args.batch_size,
75 | shuffle=True, num_workers=args.num_workers)
76 |
77 | # set up model
78 | molecule_model = GNN(args.num_layer, args.emb_dim,
79 | JK=args.JK, drop_ratio=args.dropout_ratio,
80 | gnn_type=args.gnn_type).to(device)
81 | molecule_readout_func = global_mean_pool
82 |
83 | model_param_group = [{'params': molecule_model.parameters(), 'lr': args.lr}]
84 | optimizer = optim.Adam(model_param_group, lr=args.lr, weight_decay=args.decay)
85 | criterion = nn.BCEWithLogitsLoss()
86 |
87 | for epoch in range(1, args.epochs + 1):
88 | print('epoch: {}'.format(epoch))
89 | train(molecule_model, device, loader, optimizer)
90 |
91 | if not args.output_model_dir == '':
92 | torch.save(molecule_model.state_dict(), args.output_model_dir + '_model.pth')
93 | saver_dict = {'model': molecule_model.state_dict()}
94 | torch.save(saver_dict, args.output_model_dir + '_model_complete.pth')
95 |
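`do_EdgePred` scores each candidate edge as the inner product of its endpoint embeddings and trains positive and negative edges with `BCEWithLogitsLoss`. A scalar sketch of that objective (standard library only; the embeddings are made-up numbers):

```python
import math

def edge_score(u, v):
    # Inner product of the two endpoint embeddings.
    return sum(a * b for a, b in zip(u, v))

def bce_with_logits(score, label):
    # Numerically stable -[y*log(sigmoid(s)) + (1-y)*log(1-sigmoid(s))],
    # matching what nn.BCEWithLogitsLoss computes per element.
    return max(score, 0.0) - score * label + math.log(1.0 + math.exp(-abs(score)))

pos = edge_score([0.5, 1.0], [1.0, 2.0])    # 2.5: confidently predicted edge
neg = edge_score([0.5, 1.0], [-1.0, 0.0])   # -0.5: weakly predicted non-edge
loss = bce_with_logits(pos, 1.0) + bce_with_logits(neg, 0.0)
print(loss < 1.0)  # True -- both scores fall on the correct side of 0
```

The accuracy reported in `do_EdgePred` follows the same sign convention: a positive edge counts as correct when its score is above 0, a negative edge when below.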
--------------------------------------------------------------------------------
/src_classification/pretrain_GPT_GNN.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.optim as optim
7 | from config import args
8 | from models import GNN
9 | from torch_geometric.data import DataLoader
10 | from torch_geometric.nn import global_mean_pool
11 |
12 | from datasets import MoleculeDataset, MoleculeDatasetGPT
13 |
14 |
15 | def compute_accuracy(pred, target):
16 | return float(torch.sum(torch.max(pred.detach(), dim=1)[1] == target).cpu().item())/len(pred)
17 |
18 |
19 | def train(device, loader, optimizer):
20 | start = time.time()
21 | molecule_model.train()
22 | node_pred_model.train()
23 | gpt_loss_accum, gpt_acc_accum = 0, 0
24 |
25 | for step, batch in enumerate(loader):
26 | batch = batch.to(device)
27 | node_repr = molecule_model(batch.x, batch.edge_index, batch.edge_attr)
28 | graph_repr = molecule_readout_func(node_repr, batch.batch)
29 | node_pred = node_pred_model(graph_repr)
30 | target = batch.next_x
31 |
32 | gpt_loss = criterion(node_pred.double(), target)
33 | gpt_acc = compute_accuracy(node_pred, target)
34 |
35 | gpt_loss_accum += gpt_loss.detach().cpu().item()
36 | gpt_acc_accum += gpt_acc
37 | loss = gpt_loss
38 |
39 | optimizer.zero_grad()
40 | loss.backward()
41 | optimizer.step()
42 |
43 | print('GPT Loss: {:.5f}\tGPT Acc: {:.5f}\tTime: {:.5f}'.format(
44 | gpt_loss_accum / len(loader), gpt_acc_accum / len(loader), time.time() - start))
45 | return
46 |
47 |
48 | if __name__ == '__main__':
49 | torch.manual_seed(0)
50 | np.random.seed(0)
51 | device = torch.device('cuda:' + str(args.device)) \
52 | if torch.cuda.is_available() else torch.device('cpu')
53 | if torch.cuda.is_available():
54 | torch.cuda.manual_seed_all(0)
55 | torch.cuda.set_device(args.device)
56 |
57 | if 'GEOM' in args.dataset:
58 | molecule_dataset = MoleculeDataset('../datasets/{}/'.format(args.dataset), dataset=args.dataset)
59 | molecule_gpt_dataset = MoleculeDatasetGPT(molecule_dataset)
60 | loader = DataLoader(molecule_gpt_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
61 |
62 | # set up model
63 | molecule_model = GNN(num_layer=args.num_layer, emb_dim=args.emb_dim,
64 | JK=args.JK, drop_ratio=args.dropout_ratio,
65 | gnn_type=args.gnn_type).to(device)
66 | node_pred_model = nn.Linear(args.emb_dim, 120).to(device)
67 |
68 | model_param_group = [
69 | {'params': molecule_model.parameters(), 'lr': args.lr},
70 | {'params': node_pred_model.parameters(), 'lr': args.lr},
71 | ]
72 |
73 | optimizer = optim.Adam(model_param_group, lr=args.lr, weight_decay=args.decay)
74 | molecule_readout_func = global_mean_pool
75 | criterion = nn.CrossEntropyLoss()
76 |
77 | for epoch in range(1, args.epochs + 1):
78 | print('epoch: {}'.format(epoch))
79 | train(device, loader, optimizer)
80 |
81 | if not args.output_model_dir == '':
82 | torch.save(molecule_model.state_dict(), args.output_model_dir + '_model.pth')
83 | saver_dict = {
84 | 'model': molecule_model.state_dict(),
85 | 'node_pred_model': node_pred_model.state_dict(),
86 | }
87 | torch.save(saver_dict, args.output_model_dir + '_model_complete.pth')
88 |
--------------------------------------------------------------------------------
/src_classification/pretrain_GraphCL.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 |
4 | import numpy as np
5 | import torch
6 | import torch.optim as optim
7 | from models import GNN
8 | from pretrain_JOAO import graphcl
9 | from torch_geometric.data import DataLoader
10 |
11 | from datasets import MoleculeDataset_graphcl
12 |
13 |
14 | def train(loader, model, optimizer, device):
15 |
16 | model.train()
17 | train_loss_accum = 0
18 |
19 | for step, (_, batch1, batch2) in enumerate(loader):
20 | # _, batch1, batch2 = batch
21 | batch1 = batch1.to(device)
22 | batch2 = batch2.to(device)
23 |
24 | x1 = model.forward_cl(batch1.x, batch1.edge_index,
25 | batch1.edge_attr, batch1.batch)
26 | x2 = model.forward_cl(batch2.x, batch2.edge_index,
27 | batch2.edge_attr, batch2.batch)
28 | loss = model.loss_cl(x1, x2)
29 |
30 | optimizer.zero_grad()
31 | loss.backward()
32 | optimizer.step()
33 |
34 | train_loss_accum += float(loss.detach().cpu().item())
35 |
36 | return train_loss_accum / (step + 1)
37 |
38 |
39 | if __name__ == "__main__":
40 | # Training settings
41 | parser = argparse.ArgumentParser(description='GraphCL')
42 | parser.add_argument('--device', type=int, default=0, help='gpu')
43 | parser.add_argument('--batch_size', type=int, default=256, help='batch')
44 | parser.add_argument('--decay', type=float, default=0, help='weight decay')
45 | parser.add_argument('--epochs', type=int, default=100, help='train epochs')
46 | parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
47 | parser.add_argument('--JK', type=str, default="last",
48 | choices=['last', 'sum', 'max', 'concat'],
49 | help='how the node features across layers are combined.')
50 | parser.add_argument('--gnn_type', type=str, default="gin", help='gnn model type')
51 | parser.add_argument('--dropout_ratio', type=float, default=0, help='dropout ratio')
52 | parser.add_argument('--emb_dim', type=int, default=300, help='embedding dimensions')
53 | parser.add_argument('--dataset', type=str, default=None, help='root dir of dataset')
54 | parser.add_argument('--num_layer', type=int, default=5, help='message passing layers')
55 | # parser.add_argument('--seed', type=int, default=0, help="Seed for splitting dataset")
56 | parser.add_argument('--output_model_file', type=str, default='', help='model save path')
57 | parser.add_argument('--num_workers', type=int, default=8, help='workers for dataset loading')
58 |
59 | parser.add_argument('--aug_mode', type=str, default='sample')
60 | parser.add_argument('--aug_strength', type=float, default=0.2)
61 |
62 | # parser.add_argument('--gamma', type=float, default=0.1)
63 | parser.add_argument('--output_model_dir', type=str, default='')
64 | args = parser.parse_args()
65 |
66 | torch.manual_seed(0)
67 | np.random.seed(0)
68 | device = torch.device("cuda:" + str(args.device)) \
69 | if torch.cuda.is_available() else torch.device("cpu")
70 | if torch.cuda.is_available():
71 | torch.cuda.manual_seed_all(0)
72 |
73 | # set up dataset
74 | if 'GEOM' in args.dataset:
75 | dataset = MoleculeDataset_graphcl('../datasets/{}/'.format(args.dataset),
76 | dataset=args.dataset)
77 | dataset.set_augMode(args.aug_mode)
78 | dataset.set_augStrength(args.aug_strength)
79 | print(dataset)
80 |
81 | loader = DataLoader(dataset, batch_size=args.batch_size,
82 | num_workers=args.num_workers, shuffle=True)
83 |
84 | # set up model
85 | gnn = GNN(num_layer=args.num_layer, emb_dim=args.emb_dim, JK=args.JK,
86 | drop_ratio=args.dropout_ratio, gnn_type=args.gnn_type)
87 |
88 | model = graphcl(gnn)
89 | model.to(device)
90 |
91 | # set up optimizer
92 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.decay)
93 | print(optimizer)
94 |
95 | aug_prob = np.ones(25) / 25
96 | dataset.set_augProb(aug_prob)
97 | for epoch in range(1, args.epochs + 1):
98 | start_time = time.time()
99 | pretrain_loss = train(loader, model, optimizer, device)
100 |
101 | print('Epoch: {:3d}\tLoss: {:.3f}\tTime: {:.3f}'.format(
102 | epoch, pretrain_loss, time.time() - start_time))
103 |
104 | if not args.output_model_dir == '':
105 | saver_dict = {'model': model.state_dict()}
106 | torch.save(saver_dict, args.output_model_dir + '_model_complete.pth')
107 | torch.save(model.gnn.state_dict(), args.output_model_dir + '_model.pth')
108 |
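`model.loss_cl` comes from `pretrain_JOAO` and is not shown in this chunk. GraphCL's published objective is NT-Xent (normalized temperature-scaled cross entropy) between the two augmented views, so a reasonable pure-Python sketch of what that loss computes is the following. This is an assumption about `loss_cl`, not a copy of it:

```python
import math

def nt_xent(z1, z2, tau=0.1):
    # z1[i], z2[i]: embeddings of two augmented views of graph i. Each view
    # must identify its partner among all second-view embeddings in the batch.
    def cos(u, v):
        dot = sum(a * b for a, b in zip(u, v))
        nu = math.sqrt(sum(a * a for a in u))
        nv = math.sqrt(sum(b * b for b in v))
        return dot / (nu * nv)

    n, loss = len(z1), 0.0
    for i in range(n):
        sims = [math.exp(cos(z1[i], z2[j]) / tau) for j in range(n)]
        loss += -math.log(sims[i] / sum(sims))
    return loss / n

views = [[1.0, 0.0], [0.0, 1.0]]
print(nt_xent(views, views))        # near 0: every pair is matched correctly
print(nt_xent(views, views[::-1]))  # large: every pair is mismatched
```

The temperature `tau` sharpens the softmax over similarities; the actual implementation typically also symmetrizes the loss over both view orderings.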
--------------------------------------------------------------------------------
/src_classification/pretrain_GraphMVP.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import torch.optim as optim
8 | from config import args
9 | from models import GNN, AutoEncoder, SchNet, VariationalAutoEncoder
10 | from torch_geometric.data import DataLoader
11 | from torch_geometric.nn import global_mean_pool
12 | from tqdm import tqdm
13 | from util import dual_CL
14 |
15 | from datasets import Molecule3DMaskingDataset
16 |
17 |
18 | def save_model(save_best):
19 | if not args.output_model_dir == '':
20 | if save_best:
21 | global optimal_loss
22 | print('save model with loss: {:.5f}'.format(optimal_loss))
23 | torch.save(molecule_model_2D.state_dict(), args.output_model_dir + '_model.pth')
24 | saver_dict = {
25 | 'model': molecule_model_2D.state_dict(),
26 | 'model_3D': molecule_model_3D.state_dict(),
27 | 'AE_2D_3D_model': AE_2D_3D_model.state_dict(),
28 | 'AE_3D_2D_model': AE_3D_2D_model.state_dict(),
29 | }
30 | torch.save(saver_dict, args.output_model_dir + '_model_complete.pth')
31 |
32 | else:
33 | torch.save(molecule_model_2D.state_dict(), args.output_model_dir + '_model_final.pth')
34 | saver_dict = {
35 | 'model': molecule_model_2D.state_dict(),
36 | 'model_3D': molecule_model_3D.state_dict(),
37 | 'AE_2D_3D_model': AE_2D_3D_model.state_dict(),
38 | 'AE_3D_2D_model': AE_3D_2D_model.state_dict(),
39 | }
40 | torch.save(saver_dict, args.output_model_dir + '_model_complete_final.pth')
41 | return
42 |
43 |
44 | def train(args, molecule_model_2D, device, loader, optimizer):
45 | start_time = time.time()
46 |
47 | molecule_model_2D.train()
48 | molecule_model_3D.train()
49 | if molecule_projection_layer is not None:
50 | molecule_projection_layer.train()
51 |
52 | AE_loss_accum, AE_acc_accum = 0, 0
53 | CL_loss_accum, CL_acc_accum = 0, 0
54 |
55 | if args.verbose:
56 | l = tqdm(loader)
57 | else:
58 | l = loader
59 | for step, batch in enumerate(l):
60 | batch = batch.to(device)
61 |
62 | node_repr = molecule_model_2D(batch.x, batch.edge_index, batch.edge_attr)
63 | molecule_2D_repr = molecule_readout_func(node_repr, batch.batch)
64 |
65 | if args.model_3d == 'schnet':
66 | molecule_3D_repr = molecule_model_3D(batch.x[:, 0], batch.positions, batch.batch)
67 |
68 |
69 | CL_loss, CL_acc = dual_CL(molecule_2D_repr, molecule_3D_repr, args)
70 | AE_loss_1 = AE_2D_3D_model(molecule_2D_repr, molecule_3D_repr)
71 | AE_loss_2 = AE_3D_2D_model(molecule_3D_repr, molecule_2D_repr)
72 | AE_acc_1 = AE_acc_2 = 0
73 | AE_loss = (AE_loss_1 + AE_loss_2) / 2
74 |
75 | CL_loss_accum += CL_loss.detach().cpu().item()
76 | CL_acc_accum += CL_acc
77 | AE_loss_accum += AE_loss.detach().cpu().item()
78 | AE_acc_accum += (AE_acc_1 + AE_acc_2) / 2
79 |
80 | loss = 0
81 | if args.alpha_1 > 0:
82 | loss += CL_loss * args.alpha_1
83 | if args.alpha_2 > 0:
84 | loss += AE_loss * args.alpha_2
85 |
86 | optimizer.zero_grad()
87 | loss.backward()
88 | optimizer.step()
89 |
90 | global optimal_loss
91 | CL_loss_accum /= len(loader)
92 | CL_acc_accum /= len(loader)
93 | AE_loss_accum /= len(loader)
94 | AE_acc_accum /= len(loader)
95 | temp_loss = args.alpha_1 * CL_loss_accum + args.alpha_2 * AE_loss_accum
96 | if temp_loss < optimal_loss:
97 | optimal_loss = temp_loss
98 | save_model(save_best=True)
99 | print('CL Loss: {:.5f}\tCL Acc: {:.5f}\t\tAE Loss: {:.5f}\tAE Acc: {:.5f}\tTime: {:.5f}'.format(
100 | CL_loss_accum, CL_acc_accum, AE_loss_accum, AE_acc_accum, time.time() - start_time))
101 | return
102 |
103 |
104 | if __name__ == '__main__':
105 | torch.manual_seed(0)
106 | np.random.seed(0)
107 | device = torch.device('cuda:' + str(args.device)) if torch.cuda.is_available() else torch.device('cpu')
108 | if torch.cuda.is_available():
109 | torch.cuda.manual_seed_all(0)
110 | torch.cuda.set_device(args.device)
111 |
112 | if 'GEOM' in args.dataset:
113 | data_root = '../datasets/{}/'.format(args.dataset) if args.input_data_dir == '' else '{}/{}/'.format(args.input_data_dir, args.dataset)
114 | dataset = Molecule3DMaskingDataset(data_root, dataset=args.dataset, mask_ratio=args.SSL_masking_ratio)
115 | else:
116 | raise Exception
117 | loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
118 |
119 | # set up model
120 | molecule_model_2D = GNN(args.num_layer, args.emb_dim, JK=args.JK, drop_ratio=args.dropout_ratio, gnn_type=args.gnn_type).to(device)
121 | molecule_readout_func = global_mean_pool
122 |
123 | print('Using 3d model\t', args.model_3d)
124 | molecule_projection_layer = None
125 | if args.model_3d == 'schnet':
126 | molecule_model_3D = SchNet(
127 | hidden_channels=args.emb_dim, num_filters=args.num_filters, num_interactions=args.num_interactions,
128 | num_gaussians=args.num_gaussians, cutoff=args.cutoff, atomref=None, readout=args.readout).to(device)
129 | else:
130 | raise NotImplementedError('Model {} not included.'.format(args.model_3d))
131 |
132 | if args.AE_model == 'AE':
133 | AE_2D_3D_model = AutoEncoder(
134 | emb_dim=args.emb_dim, loss=args.AE_loss, detach_target=args.detach_target).to(device)
135 | AE_3D_2D_model = AutoEncoder(
136 | emb_dim=args.emb_dim, loss=args.AE_loss, detach_target=args.detach_target).to(device)
137 | elif args.AE_model == 'VAE':
138 | AE_2D_3D_model = VariationalAutoEncoder(
139 | emb_dim=args.emb_dim, loss=args.AE_loss, detach_target=args.detach_target, beta=args.beta).to(device)
140 | AE_3D_2D_model = VariationalAutoEncoder(
141 | emb_dim=args.emb_dim, loss=args.AE_loss, detach_target=args.detach_target, beta=args.beta).to(device)
142 | else:
143 | raise Exception
144 |
145 | model_param_group = []
146 | model_param_group.append({'params': molecule_model_2D.parameters(), 'lr': args.lr * args.gnn_lr_scale})
147 | model_param_group.append({'params': molecule_model_3D.parameters(), 'lr': args.lr * args.schnet_lr_scale})
148 | model_param_group.append({'params': AE_2D_3D_model.parameters(), 'lr': args.lr * args.gnn_lr_scale})
149 | model_param_group.append({'params': AE_3D_2D_model.parameters(), 'lr': args.lr * args.schnet_lr_scale})
150 | if molecule_projection_layer is not None:
151 | model_param_group.append({'params': molecule_projection_layer.parameters(), 'lr': args.lr * args.schnet_lr_scale})
152 |
153 | optimizer = optim.Adam(model_param_group, lr=args.lr, weight_decay=args.decay)
154 | optimal_loss = 1e10
155 |
156 | for epoch in range(1, args.epochs + 1):
157 | print('epoch: {}'.format(epoch))
158 | train(args, molecule_model_2D, device, loader, optimizer)
159 |
160 | save_model(save_best=False)
161 |
--------------------------------------------------------------------------------
/src_classification/pretrain_IG.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.optim as optim
7 | from config import args
8 | from models import GNN, Discriminator
9 | from torch_geometric.data import DataLoader
10 | from torch_geometric.nn import global_mean_pool
11 | from util import cycle_index
12 |
13 | from datasets import MoleculeDataset
14 |
15 |
16 | def do_InfoGraph(node_repr, molecule_repr, batch,
17 | criterion, infograph_discriminator_SSL_model):
18 |
19 | summary_repr = torch.sigmoid(molecule_repr)
20 | positive_expanded_summary_repr = summary_repr[batch.batch]
21 | shifted_summary_repr = summary_repr[cycle_index(len(summary_repr), 1)]
22 | negative_expanded_summary_repr = shifted_summary_repr[batch.batch]
23 |
24 | positive_score = infograph_discriminator_SSL_model(
25 | node_repr, positive_expanded_summary_repr)
26 | negative_score = infograph_discriminator_SSL_model(
27 | node_repr, negative_expanded_summary_repr)
28 | infograph_loss = criterion(positive_score, torch.ones_like(positive_score)) + \
29 | criterion(negative_score, torch.zeros_like(negative_score))
30 |
31 | num_sample = float(2 * len(positive_score))
32 | infograph_acc = (torch.sum(positive_score > 0) +
33 | torch.sum(negative_score < 0)).to(torch.float32) / num_sample
34 | infograph_acc = infograph_acc.detach().cpu().item()
35 |
36 | return infograph_loss, infograph_acc
37 |
38 |
39 | def train(molecule_model, device, loader, optimizer):
40 |
41 | start = time.time()
42 | molecule_model.train()
43 | infograph_loss_accum, infograph_acc_accum = 0, 0
44 |
45 | for step, batch in enumerate(loader):
46 |
47 | batch = batch.to(device)
48 | node_repr = molecule_model(batch.x, batch.edge_index, batch.edge_attr)
49 | molecule_repr = molecule_readout_func(node_repr, batch.batch)
50 |
51 | infograph_loss, infograph_acc = do_InfoGraph(
52 | node_repr=node_repr, batch=batch,
53 | molecule_repr=molecule_repr, criterion=criterion,
54 | infograph_discriminator_SSL_model=infograph_discriminator_SSL_model)
55 |
56 | infograph_loss_accum += infograph_loss.detach().cpu().item()
57 | infograph_acc_accum += infograph_acc
58 | ssl_loss = infograph_loss
59 | optimizer.zero_grad()
60 | ssl_loss.backward()
61 | optimizer.step()
62 |
63 | print('IG Loss: {:.5f}\tIG Acc: {:.5f}\tTime: {:.3f}'.format(
64 | infograph_loss_accum / len(loader),
65 | infograph_acc_accum / len(loader),
66 | time.time() - start))
67 | return
68 |
69 |
70 | if __name__ == '__main__':
71 |
72 | torch.manual_seed(0)
73 | np.random.seed(0)
74 | device = torch.device('cuda:' + str(args.device)) \
75 | if torch.cuda.is_available() else torch.device('cpu')
76 | if torch.cuda.is_available():
77 | torch.cuda.manual_seed_all(0)
78 | torch.cuda.set_device(args.device)
79 |
80 | if 'GEOM' in args.dataset:
81 | dataset = MoleculeDataset('../datasets/{}/'.format(args.dataset), dataset=args.dataset)
82 | loader = DataLoader(dataset, batch_size=args.batch_size,
83 | shuffle=True, num_workers=args.num_workers)
84 |
85 | # set up model
86 | molecule_model = GNN(num_layer=args.num_layer, emb_dim=args.emb_dim,
87 | JK=args.JK, drop_ratio=args.dropout_ratio,
88 | gnn_type=args.gnn_type).to(device)
89 | infograph_discriminator_SSL_model = Discriminator(args.emb_dim).to(device)
90 |
91 | model_param_group = [{'params': molecule_model.parameters(), 'lr': args.lr},
92 | {'params': infograph_discriminator_SSL_model.parameters(),
93 | 'lr': args.lr}]
94 |
95 | optimizer = optim.Adam(model_param_group, lr=args.lr, weight_decay=args.decay)
96 | molecule_readout_func = global_mean_pool
97 | criterion = nn.BCEWithLogitsLoss()
98 |
99 | for epoch in range(1, args.epochs + 1):
100 | print('epoch: {}'.format(epoch))
101 | train(molecule_model, device, loader, optimizer)
102 |
103 | if args.output_model_dir != '':
104 | torch.save(molecule_model.state_dict(), args.output_model_dir + '_model.pth')
105 |
106 | saver_dict = {'model': molecule_model.state_dict(),
107 | 'infograph_discriminator_SSL_model': infograph_discriminator_SSL_model.state_dict()}
108 |
109 | torch.save(saver_dict, args.output_model_dir + '_model_complete.pth')
110 |
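The positive/negative pairing in `do_InfoGraph` above hinges on indexing the per-graph summaries with `batch.batch` and a cyclic shift. Below is a minimal sketch of that pairing; the `cycle_index` helper is an illustrative re-implementation (the repo's own version lives in `util`), assumed here to rotate indices by `shift`:

```python
import torch

def cycle_index(num, shift):
    # illustrative re-implementation of util.cycle_index:
    # indices 0..num-1 rotated left by `shift`
    arange = torch.arange(num)
    return torch.cat([arange[shift:], arange[:shift]])

# toy setup: 2 graphs with 3 nodes each
batch_vec = torch.tensor([0, 0, 0, 1, 1, 1])   # node -> graph assignment
molecule_repr = torch.randn(2, 4)               # one summary per graph

summary = torch.sigmoid(molecule_repr)
positive = summary[batch_vec]                    # each node with its own graph's summary
shifted = summary[cycle_index(len(summary), 1)]  # summaries rotated by one graph
negative = shifted[batch_vec]                    # each node with another graph's summary

assert positive.shape == negative.shape == (6, 4)
assert torch.equal(negative[:3], positive[3:])   # graph 0's nodes now see graph 1's summary
```

Rotating the summaries by one guarantees every node is scored against a different graph's summary, which is what makes the shifted rows valid negatives for the discriminator.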
--------------------------------------------------------------------------------
/src_classification/pretrain_JOAO.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | from itertools import repeat
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.optim as optim
9 | from models import GNN
10 | from torch_geometric.data import DataLoader
11 | from torch_geometric.nn import global_mean_pool
12 |
13 | from datasets import MoleculeDataset_graphcl
14 |
15 |
16 | class graphcl(nn.Module):
17 |
18 | def __init__(self, gnn):
19 | super(graphcl, self).__init__()
20 | self.gnn = gnn
21 | self.pool = global_mean_pool
22 | self.projection_head = nn.Sequential(
23 | nn.Linear(300, 300),
24 | nn.ReLU(inplace=True),
25 | nn.Linear(300, 300))
26 |
27 | def forward_cl(self, x, edge_index, edge_attr, batch):
28 | x = self.gnn(x, edge_index, edge_attr)
29 | x = self.pool(x, batch)
30 | x = self.projection_head(x)
31 | return x
32 |
33 | def loss_cl(self, x1, x2):
34 | T = 0.1
35 | batch, _ = x1.size()
36 | x1_abs = x1.norm(dim=1)
37 | x2_abs = x2.norm(dim=1)
38 |
39 | sim_matrix = torch.einsum('ik,jk->ij', x1, x2) / \
40 | torch.einsum('i,j->ij', x1_abs, x2_abs)
41 | sim_matrix = torch.exp(sim_matrix / T)
42 | pos_sim = sim_matrix[range(batch), range(batch)]
43 | loss = pos_sim / (sim_matrix.sum(dim=1) - pos_sim)
44 | loss = - torch.log(loss).mean()
45 | return loss
46 |
47 |
48 | def train(loader, model, optimizer, device, gamma_joao):
49 |
50 | model.train()
51 | train_loss_accum = 0
52 |
53 | for step, (_, batch1, batch2) in enumerate(loader):
54 | # _, batch1, batch2 = batch
55 | batch1 = batch1.to(device)
56 | batch2 = batch2.to(device)
57 |
58 | # pdb.set_trace()
59 | x1 = model.forward_cl(batch1.x, batch1.edge_index,
60 | batch1.edge_attr, batch1.batch)
61 | x2 = model.forward_cl(batch2.x, batch2.edge_index,
62 | batch2.edge_attr, batch2.batch)
63 | loss = model.loss_cl(x1, x2)
64 |
65 | optimizer.zero_grad()
66 | loss.backward()
67 | optimizer.step()
68 |
69 | train_loss_accum += float(loss.detach().cpu().item())
70 |
71 | # joao
72 | aug_prob = loader.dataset.aug_prob
73 | loss_aug = np.zeros(25)
74 | for n in range(25):
75 | _aug_prob = np.zeros(25)
76 | _aug_prob[n] = 1
77 | loader.dataset.set_augProb(_aug_prob)
78 | # for efficiency, we only use around 10% of data to estimate the loss
79 | count, count_stop = 0, len(loader.dataset) // (loader.batch_size * 10) + 1
80 |
81 | with torch.no_grad():
82 | for step, (_, batch1, batch2) in enumerate(loader):
83 | # _, batch1, batch2 = batch
84 | batch1 = batch1.to(device)
85 | batch2 = batch2.to(device)
86 |
87 | x1 = model.forward_cl(batch1.x, batch1.edge_index,
88 | batch1.edge_attr, batch1.batch)
89 | x2 = model.forward_cl(batch2.x, batch2.edge_index,
90 | batch2.edge_attr, batch2.batch)
91 | loss = model.loss_cl(x1, x2)
92 | loss_aug[n] += loss.item()
93 | count += 1
94 | if count == count_stop:
95 | break
96 | loss_aug[n] /= count
97 |
98 | # view selection, projected gradient descent,
99 | # reference: https://arxiv.org/abs/1906.03563
100 | beta = 1
101 | gamma = gamma_joao
102 |
103 | b = aug_prob + beta * (loss_aug - gamma * (aug_prob - 1 / 25))
104 | mu_min, mu_max = b.min() - 1 / 25, b.max() - 1 / 25
105 | mu = (mu_min + mu_max) / 2
106 |
107 | # bisection method
108 | while abs(np.maximum(b - mu, 0).sum() - 1) > 1e-2:
109 | if np.maximum(b - mu, 0).sum() > 1:
110 | mu_min = mu
111 | else:
112 | mu_max = mu
113 | mu = (mu_min + mu_max) / 2
114 |
115 | aug_prob = np.maximum(b - mu, 0)
116 | aug_prob /= aug_prob.sum()
117 |
118 | return train_loss_accum / len(loader), aug_prob
119 |
120 |
121 | if __name__ == "__main__":
122 | # Training settings
123 | parser = argparse.ArgumentParser(description='JOAO')
124 | parser.add_argument('--device', type=int, default=0, help='gpu')
125 | parser.add_argument('--batch_size', type=int, default=256, help='batch')
126 | parser.add_argument('--decay', type=float, default=0, help='weight decay')
127 | parser.add_argument('--epochs', type=int, default=100, help='train epochs')
128 | parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
129 | parser.add_argument('--JK', type=str, default="last",
130 | choices=['last', 'sum', 'max', 'concat'],
131 | help='how the node features across layers are combined.')
132 | parser.add_argument('--gnn_type', type=str, default="gin", help='gnn model type')
133 | parser.add_argument('--dropout_ratio', type=float, default=0, help='dropout ratio')
134 | parser.add_argument('--emb_dim', type=int, default=300, help='embedding dimensions')
135 | parser.add_argument('--dataset', type=str, default=None, help='root dir of dataset')
136 | parser.add_argument('--num_layer', type=int, default=5, help='message passing layers')
137 | # parser.add_argument('--seed', type=int, default=0, help="Seed for splitting dataset")
138 | parser.add_argument('--output_model_file', type=str, default='', help='model save path')
139 | parser.add_argument('--num_workers', type=int, default=8, help='workers for dataset loading')
140 |
141 | parser.add_argument('--aug_mode', type=str, default='sample')
142 | parser.add_argument('--aug_strength', type=float, default=0.2)
143 |
144 | parser.add_argument('--gamma', type=float, default=0.1)
145 | parser.add_argument('--output_model_dir', type=str, default='')
146 | args = parser.parse_args()
147 |
148 | torch.manual_seed(0)
149 | np.random.seed(0)
150 | device = torch.device("cuda:" + str(args.device)) \
151 | if torch.cuda.is_available() else torch.device("cpu")
152 | if torch.cuda.is_available():
153 | torch.cuda.manual_seed_all(0)
154 |
155 | if 'GEOM' in args.dataset:
156 | dataset = MoleculeDataset_graphcl('../datasets/{}/'.format(args.dataset), dataset=args.dataset)
157 | dataset.set_augMode(args.aug_mode)
158 | dataset.set_augStrength(args.aug_strength)
159 | print(dataset)
160 |
161 | loader = DataLoader(dataset, batch_size=args.batch_size,
162 | num_workers=args.num_workers, shuffle=True)
163 |
164 | # set up model
165 | gnn = GNN(num_layer=args.num_layer, emb_dim=args.emb_dim, JK=args.JK,
166 | drop_ratio=args.dropout_ratio, gnn_type=args.gnn_type)
167 |
168 | model = graphcl(gnn)
169 | model.to(device)
170 |
171 | # set up optimizer
172 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.decay)
173 | # print(optimizer)
174 |
175 | # pdb.set_trace()
176 | aug_prob = np.ones(25) / 25
177 | np.set_printoptions(precision=3, floatmode='fixed')
178 |
179 | for epoch in range(1, args.epochs + 1):
180 | print('\n\n')
181 | start_time = time.time()
182 | dataset.set_augProb(aug_prob)
183 | pretrain_loss, aug_prob = train(loader, model, optimizer, device, args.gamma)
184 |
185 | print('Epoch: {:3d}\tLoss:{:.3f}\tTime: {:.3f}\tAugmentation Probability:'.format(
186 | epoch, pretrain_loss, time.time() - start_time))
187 | print(aug_prob)
188 |
189 | if args.output_model_dir != '':
190 | saver_dict = {'model': model.state_dict()}
191 | torch.save(saver_dict, args.output_model_dir + '_model_complete.pth')
192 | torch.save(model.gnn.state_dict(), args.output_model_dir + '_model.pth')
193 |
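The view-selection step in `train` above projects the updated scores `b` onto the probability simplex by bisecting on the threshold `mu` until `max(b - mu, 0)` sums to one. A self-contained sketch of that projection (the function name and the random stand-in for `b` are illustrative):

```python
import numpy as np

def project_to_simplex(b, tol=1e-2):
    # bisection for mu such that sum(max(b - mu, 0)) ~= 1,
    # mirroring the view-selection step in train()
    n = len(b)
    mu_min, mu_max = b.min() - 1 / n, b.max() - 1 / n
    mu = (mu_min + mu_max) / 2
    while abs(np.maximum(b - mu, 0).sum() - 1) > tol:
        if np.maximum(b - mu, 0).sum() > 1:
            mu_min = mu
        else:
            mu_max = mu
        mu = (mu_min + mu_max) / 2
    p = np.maximum(b - mu, 0)
    return p / p.sum()   # renormalize so the result is an exact distribution

rng = np.random.default_rng(0)
b = rng.random(25)       # stand-in for aug_prob + beta * (loss_aug - gamma * ...)
p = project_to_simplex(b)
assert p.min() >= 0 and abs(p.sum() - 1) < 1e-8
```

The bracketing works because `sum(max(b - mu, 0))` is continuous and non-increasing in `mu`, at least 1 at `b.min() - 1/n`, and at most 1 at `b.max() - 1/n`.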
--------------------------------------------------------------------------------
/src_classification/pretrain_Motif.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import torch
4 | import torch.nn as nn
5 | import torch.optim as optim
6 | from config import args
7 | from models import GNN, GNN_graphpred
8 | from sklearn.metrics import roc_auc_score
9 | from torch_geometric.data import DataLoader
10 |
11 | from datasets import RDKIT_PROPS, MoleculeMotifDataset
12 |
13 |
14 | def train(model, device, loader, optimizer):
15 | model.train()
16 | total_loss = 0
17 |
18 | for step, batch in enumerate(loader):
19 | batch = batch.to(device)
20 | pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
21 | y = batch.y.view(pred.shape).double()
22 |
23 | loss = criterion(pred.double(), y)
24 |
25 | optimizer.zero_grad()
26 | loss.backward()
27 | optimizer.step()
28 | total_loss += loss.detach().item()
29 | return total_loss / len(loader)
30 |
31 |
32 | def eval(model, device, loader):
33 | model.eval()
34 | y_true, y_scores = [], []
35 |
36 | for step, batch in enumerate(loader):
37 | batch = batch.to(device)
38 | with torch.no_grad():
39 | pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
40 | true = batch.y.view(pred.shape)
41 | y_true.append(true)
42 | y_scores.append(pred)
43 |
44 | y_true = torch.cat(y_true, dim=0).cpu().numpy()
45 | y_scores = torch.cat(y_scores, dim=0).cpu().numpy()
46 |
47 | roc_list = []
48 | for i in range(y_true.shape[1]):
49 | if np.sum(y_true[:, i] == 1) > 0:
50 | roc_list.append(roc_auc_score(y_true[:, i], y_scores[:, i]))
51 | return sum(roc_list) / len(roc_list), y_true, y_scores
52 |
53 |
54 | if __name__ == '__main__':
55 | torch.manual_seed(args.runseed)
56 | np.random.seed(args.runseed)
57 | device = torch.device('cuda:' + str(args.device)) \
58 | if torch.cuda.is_available() else torch.device('cpu')
59 | if torch.cuda.is_available():
60 | torch.cuda.manual_seed_all(args.runseed)
61 |
62 | # Bunch of classification tasks
63 | num_tasks = len(RDKIT_PROPS)
64 | assert 'GEOM' in args.dataset
65 | dataset_folder = '../datasets/'
66 | dataset = MoleculeMotifDataset(dataset_folder + args.dataset, dataset=args.dataset)
67 | print(dataset)
68 |
69 | loader = DataLoader(dataset, batch_size=args.batch_size,
70 | shuffle=True, num_workers=args.num_workers)
71 |
72 | # set up model
73 | molecule_model = GNN(num_layer=args.num_layer, emb_dim=args.emb_dim,
74 | JK=args.JK, drop_ratio=args.dropout_ratio,
75 | gnn_type=args.gnn_type)
76 | model = GNN_graphpred(args=args, num_tasks=num_tasks, molecule_model=molecule_model)
77 | if args.input_model_file != '':
78 | model.from_pretrained(args.input_model_file)
79 | model.to(device)
80 |
81 | # set up optimizer
82 | # different learning rates for different parts of GNN
83 | model_param_group = [{'params': model.molecule_model.parameters()},
84 | {'params': model.graph_pred_linear.parameters(),
85 | 'lr': args.lr * args.lr_scale}]
86 | optimizer = optim.Adam(model_param_group, lr=args.lr,
87 | weight_decay=args.decay)
88 | criterion = nn.BCEWithLogitsLoss()
89 | train_roc_list, val_roc_list, test_roc_list = [], [], []
90 | best_val_roc, best_val_idx = -1, 0
91 |
92 | print('\nStart pre-training Motif')
93 | for epoch in range(1, args.epochs + 1):
94 | loss_acc = train(model, device, loader, optimizer)
95 | print('Epoch: {}\nLoss: {}'.format(epoch, loss_acc))
96 |
97 | if args.eval_train:
98 | train_roc, train_target, train_pred = eval(model, device, loader)
99 | else:
100 | train_roc = 0
101 |
102 | train_roc_list.append(train_roc)
103 | print('train: {:.6f}\n'.format(train_roc))
104 |
105 | if args.output_model_dir != '':
106 | print('saving to {}'.format(args.output_model_dir + '_model.pth'))
107 | torch.save(molecule_model.state_dict(), args.output_model_dir + '_model.pth')
108 | saved_model_dict = {
109 | 'molecule_model': molecule_model.state_dict(),
110 | 'model': model.state_dict(),
111 | }
112 | torch.save(saved_model_dict, args.output_model_dir + '_model_complete.pth')
113 |
--------------------------------------------------------------------------------
/src_classification/run_molecule_finetune.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | source $HOME/.bashrc
4 | conda activate GraphMVP
5 |
6 | echo $@
7 | date
8 |
9 | echo "start"
10 | python molecule_finetune.py $@
11 | echo "end"
12 | date
13 |
--------------------------------------------------------------------------------
/src_classification/run_pretrain_AM.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | source $HOME/.bashrc
4 | conda activate GraphMVP
5 |
6 | echo $@
7 | date
8 |
9 | echo "start"
10 | python pretrain_AM.py $@
11 | echo "end"
12 | date
--------------------------------------------------------------------------------
/src_classification/run_pretrain_CP.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | source $HOME/.bashrc
4 | conda activate GraphMVP
5 |
6 | echo $@
7 | date
8 |
9 | echo "start"
10 | python pretrain_CP.py $@
11 | echo "end"
12 | date
13 |
--------------------------------------------------------------------------------
/src_classification/run_pretrain_Contextual.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | source $HOME/.bashrc
4 | conda activate GraphMVP
5 |
6 | echo $@
7 | date
8 |
9 | echo "start"
10 | python pretrain_Contextual.py $@
11 | echo "end"
12 | date
--------------------------------------------------------------------------------
/src_classification/run_pretrain_EP.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | source $HOME/.bashrc
4 | conda activate GraphMVP
5 |
6 | echo $@
7 | date
8 |
9 | echo "start"
10 | python pretrain_EP.py $@
11 | echo "end"
12 | date
13 |
--------------------------------------------------------------------------------
/src_classification/run_pretrain_GPT_GNN.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | source $HOME/.bashrc
4 | conda activate GraphMVP
5 |
6 | echo $@
7 | date
8 |
9 | echo "start"
10 | python pretrain_GPT_GNN.py $@
11 | echo "end"
12 | date
--------------------------------------------------------------------------------
/src_classification/run_pretrain_GraphCL.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | source $HOME/.bashrc
4 | conda activate GraphMVP
5 |
6 | echo $@
7 | date
8 |
9 | echo "start"
10 | python pretrain_GraphCL.py $@
11 | echo "end"
12 | date
13 |
--------------------------------------------------------------------------------
/src_classification/run_pretrain_GraphLoG.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | source $HOME/.bashrc
4 | conda activate GraphMVP
5 |
6 | echo $@
7 | date
8 |
9 | echo "start"
10 | python pretrain_GraphLoG.py $@
11 | echo "end"
12 | date
13 |
--------------------------------------------------------------------------------
/src_classification/run_pretrain_GraphMVP.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | source $HOME/.bashrc
4 | conda activate GraphMVP
5 |
6 |
7 | echo $@
8 | date
9 | echo "start"
10 | python pretrain_GraphMVP.py $@
11 | echo "end"
12 | date
13 |
--------------------------------------------------------------------------------
/src_classification/run_pretrain_GraphMVP_hybrid.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | source $HOME/.bashrc
4 | conda activate GraphMVP
5 |
6 | echo $@
7 | date
8 | echo "start"
9 | python pretrain_GraphMVP_hybrid.py $@
10 | echo "end"
11 | date
12 |
--------------------------------------------------------------------------------
/src_classification/run_pretrain_IG.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | source $HOME/.bashrc
4 | conda activate GraphMVP
5 |
6 | echo $@
7 | date
8 |
9 | echo "start"
10 | python pretrain_IG.py $@
11 | echo "end"
12 | date
13 |
--------------------------------------------------------------------------------
/src_classification/run_pretrain_JOAO.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | source $HOME/.bashrc
4 | conda activate GraphMVP
5 |
6 | echo $@
7 | date
8 |
9 | echo "start"
10 | python pretrain_JOAO.py $@
11 | echo "end"
12 | date
13 |
--------------------------------------------------------------------------------
/src_classification/run_pretrain_JOAOv2.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | source $HOME/.bashrc
4 | conda activate GraphMVP
5 |
6 | echo $@
7 | date
8 |
9 | echo "start"
10 | python pretrain_JOAOv2.py $@
11 | echo "end"
12 | date
13 |
--------------------------------------------------------------------------------
/src_classification/run_pretrain_Motif.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | source $HOME/.bashrc
4 | conda activate GraphMVP
5 |
6 | echo $@
7 | date
8 |
9 | echo "start"
10 | python pretrain_Motif.py $@
11 | echo "end"
12 | date
--------------------------------------------------------------------------------
/src_regression/datasets_complete_feature/__init__.py:
--------------------------------------------------------------------------------
1 | from .dti_datasets import MoleculeProteinDataset
2 | from .molecule_datasets import (MoleculeDatasetComplete,
3 | graph_data_obj_to_nx_simple,
4 | nx_to_graph_data_obj_simple)
5 | from .molecule_graphcl_dataset import MoleculeDataset_graphcl_complete
6 |
--------------------------------------------------------------------------------
/src_regression/datasets_complete_feature/dti_datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import pandas as pd
5 | import torch
6 | from rdkit.Chem import AllChem
7 | from torch_geometric.data import InMemoryDataset
8 |
9 | from .molecule_datasets import mol_to_graph_data_obj_simple
10 |
11 | seq_voc = "ABCDEFGHIKLMNOPQRSTUVWXYZ"
12 | seq_dict = {v:(i+1) for i,v in enumerate(seq_voc)}
13 | seq_dict_len = len(seq_dict)
14 | max_seq_len = 1000
15 |
16 |
17 | def seq_cat(prot):
18 | x = np.zeros(max_seq_len)
19 | for i, ch in enumerate(prot[:max_seq_len]):
20 | x[i] = seq_dict[ch]
21 | return x
22 |
23 |
24 | class MoleculeProteinDataset(InMemoryDataset):
25 | def __init__(self, root, dataset, mode):
26 | super(InMemoryDataset, self).__init__()
27 | self.root = root
28 | self.dataset = dataset
29 | datapath = os.path.join(self.root, self.dataset, '{}.csv'.format(mode))
30 | print('datapath\t', datapath)
31 |
32 | self.process_molecule()
33 | self.process_protein()
34 |
35 | df = pd.read_csv(datapath)
36 | self.molecule_index_list = df['smiles_id'].tolist()
37 | self.protein_index_list = df['target_id'].tolist()
38 | self.label_list = df['affinity'].tolist()
39 | self.label_list = torch.FloatTensor(self.label_list)
40 |
41 | return
42 |
43 | def process_molecule(self):
44 | input_path = os.path.join(self.root, self.dataset, 'smiles.csv')
45 | input_df = pd.read_csv(input_path, sep=',')
46 | smiles_list = input_df['smiles']
47 |
48 | rdkit_mol_objs_list = [AllChem.MolFromSmiles(s) for s in smiles_list]
49 | preprocessed_rdkit_mol_objs_list = [m if m is not None else None for m in rdkit_mol_objs_list]
50 | preprocessed_smiles_list = [AllChem.MolToSmiles(m) if m is not None else None for m in preprocessed_rdkit_mol_objs_list]
51 | assert len(smiles_list) == len(preprocessed_rdkit_mol_objs_list)
52 | assert len(smiles_list) == len(preprocessed_smiles_list)
53 |
54 | smiles_list, rdkit_mol_objs = preprocessed_smiles_list, preprocessed_rdkit_mol_objs_list
55 |
56 | data_list = []
57 | for i in range(len(smiles_list)):
58 | rdkit_mol = rdkit_mol_objs[i]
59 | if rdkit_mol is not None:
60 | data = mol_to_graph_data_obj_simple(rdkit_mol)
61 | data.id = torch.tensor([i])
62 | data_list.append(data)
63 |
64 | self.molecule_list = data_list
65 | return
66 |
67 | def process_protein(self):
68 | datapath = os.path.join(self.root, self.dataset, 'protein.csv')
69 |
70 | input_df = pd.read_csv(datapath, sep=',')
71 | protein_list = input_df['protein'].tolist()
72 |
73 | self.protein_list = [seq_cat(t) for t in protein_list]
74 | self.protein_list = torch.LongTensor(self.protein_list)
75 | return
76 |
77 | def __getitem__(self, idx):
78 | molecule = self.molecule_list[self.molecule_index_list[idx]]
79 | protein = self.protein_list[self.protein_index_list[idx]]
80 | label = self.label_list[idx]
81 | return molecule, protein, label
82 |
83 | def __len__(self):
84 | return len(self.label_list)
85 |
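The `seq_cat` encoding above maps each amino-acid character to a 1-based integer (index 0 is reserved for padding; note `seq_voc` deliberately has no `J`) and zero-pads or truncates to `max_seq_len = 1000`. A small worked example with an arbitrary short sequence:

```python
import numpy as np

seq_voc = "ABCDEFGHIKLMNOPQRSTUVWXYZ"
seq_dict = {v: (i + 1) for i, v in enumerate(seq_voc)}
max_seq_len = 1000

def seq_cat(prot):
    # integer-encode a protein string, zero-padded/truncated to max_seq_len
    x = np.zeros(max_seq_len)
    for i, ch in enumerate(prot[:max_seq_len]):
        x[i] = seq_dict[ch]
    return x

enc = seq_cat("MKV")                 # hypothetical 3-residue sequence
assert enc.shape == (max_seq_len,)
assert enc[0] == seq_dict['M']
assert enc[3:].sum() == 0            # everything past the sequence is zero padding
```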
--------------------------------------------------------------------------------
/src_regression/datasets_complete_feature/molecule_graphcl_dataset.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | from itertools import repeat
4 |
5 | import numpy as np
6 | import torch
7 | from torch_geometric.data import Data
8 | from torch_geometric.utils import subgraph, to_networkx
9 |
10 | from .molecule_datasets import MoleculeDatasetComplete
11 |
12 |
13 | class MoleculeDataset_graphcl_complete(MoleculeDatasetComplete):
14 |
15 | def __init__(self,
16 | root,
17 | transform=None,
18 | pre_transform=None,
19 | pre_filter=None,
20 | dataset=None,
21 | empty=False):
22 |
23 | self.aug_prob = None
24 | self.aug_mode = 'no_aug'
25 | self.aug_strength = 0.2
26 | self.augmentations = [self.node_drop, self.subgraph,
27 | self.edge_pert, self.attr_mask, lambda x: x]
28 | super(MoleculeDataset_graphcl_complete, self).__init__(
29 | root, transform, pre_transform, pre_filter, dataset, empty)
30 |
31 | def set_augMode(self, aug_mode):
32 | self.aug_mode = aug_mode
33 |
34 | def set_augStrength(self, aug_strength):
35 | self.aug_strength = aug_strength
36 |
37 | def set_augProb(self, aug_prob):
38 | self.aug_prob = aug_prob
39 |
40 | def node_drop(self, data):
41 |
42 | node_num, _ = data.x.size()
43 | _, edge_num = data.edge_index.size()
44 | drop_num = int(node_num * self.aug_strength)
45 |
46 | idx_perm = np.random.permutation(node_num)
47 | idx_nodrop = idx_perm[drop_num:].tolist()
48 | idx_nodrop.sort()
49 |
50 | edge_idx, edge_attr = subgraph(subset=idx_nodrop,
51 | edge_index=data.edge_index,
52 | edge_attr=data.edge_attr,
53 | relabel_nodes=True,
54 | num_nodes=node_num)
55 |
56 | data.edge_index = edge_idx
57 | data.edge_attr = edge_attr
58 | data.x = data.x[idx_nodrop]
59 | data.__num_nodes__, _ = data.x.shape
60 | return data
61 |
62 | def edge_pert(self, data):
63 | node_num, _ = data.x.size()
64 | _, edge_num = data.edge_index.size()
65 | pert_num = int(edge_num * self.aug_strength)
66 |
67 | # delete edges
68 | idx_drop = np.random.choice(edge_num, (edge_num - pert_num),
69 | replace=False)
70 | edge_index = data.edge_index[:, idx_drop]
71 | edge_attr = data.edge_attr[idx_drop]
72 |
73 | # add edges
74 | adj = torch.ones((node_num, node_num))
75 | adj[edge_index[0], edge_index[1]] = 0
76 | # edge_index_nonexist = adj.nonzero(as_tuple=False).t()
77 | edge_index_nonexist = torch.nonzero(adj, as_tuple=False).t()
78 | idx_add = np.random.choice(edge_index_nonexist.shape[1],
79 | pert_num, replace=False)
80 | edge_index_add = edge_index_nonexist[:, idx_add]
81 | # check here: https://github.com/snap-stanford/ogb/blob/master/ogb/utils/features.py#L2
82 | edge_attr_add_1 = torch.tensor(np.random.randint(
83 | 5, size=(edge_index_add.shape[1], 1)))
84 | edge_attr_add_2 = torch.tensor(np.random.randint(
85 | 6, size=(edge_index_add.shape[1], 1)))
86 | edge_attr_add_3 = torch.tensor(np.random.randint(
87 | 2, size=(edge_index_add.shape[1], 1)))
88 | edge_attr_add = torch.cat((edge_attr_add_1, edge_attr_add_2, edge_attr_add_3), dim=1)
89 | edge_index = torch.cat((edge_index, edge_index_add), dim=1)
90 | edge_attr = torch.cat((edge_attr, edge_attr_add), dim=0)
91 |
92 | data.edge_index = edge_index
93 | data.edge_attr = edge_attr
94 | return data
95 |
96 | def attr_mask(self, data):
97 |
98 | _x = data.x.clone()
99 | node_num, _ = data.x.size()
100 | mask_num = int(node_num * self.aug_strength)
101 |
102 | token = data.x.float().mean(dim=0).long()
103 | idx_mask = np.random.choice(
104 | node_num, mask_num, replace=False)
105 |
106 | _x[idx_mask] = token
107 | data.x = _x
108 | return data
109 |
110 | def subgraph(self, data):
111 |
112 | G = to_networkx(data)
113 | node_num, _ = data.x.size()
114 | _, edge_num = data.edge_index.size()
115 | sub_num = int(node_num * (1 - self.aug_strength))
116 |
117 | idx_sub = [np.random.randint(node_num, size=1)[0]]
118 | idx_neigh = set([n for n in G.neighbors(idx_sub[-1])])
119 |
120 | while len(idx_sub) <= sub_num:
121 | if len(idx_neigh) == 0:
122 | idx_unsub = list(set([n for n in range(node_num)]).difference(set(idx_sub)))
123 | idx_neigh = set([np.random.choice(idx_unsub)])
124 | sample_node = np.random.choice(list(idx_neigh))
125 |
126 | idx_sub.append(sample_node)
127 | idx_neigh = idx_neigh.union(
128 | set([n for n in G.neighbors(idx_sub[-1])])).difference(set(idx_sub))
129 |
130 | idx_nondrop = idx_sub
131 | idx_nondrop.sort()
132 |
133 | edge_idx, edge_attr = subgraph(subset=idx_nondrop,
134 | edge_index=data.edge_index,
135 | edge_attr=data.edge_attr,
136 | relabel_nodes=True,
137 | num_nodes=node_num)
138 |
139 | data.edge_index = edge_idx
140 | data.edge_attr = edge_attr
141 | data.x = data.x[idx_nondrop]
142 | data.__num_nodes__, _ = data.x.shape
143 | return data
144 |
145 | def get(self, idx):
146 | data, data1, data2 = Data(), Data(), Data()
147 | keys_for_2D = ['x', 'edge_index', 'edge_attr']
148 | for key in self.data.keys:
149 | item, slices = self.data[key], self.slices[key]
150 | s = list(repeat(slice(None), item.dim()))
151 | s[data.__cat_dim__(key, item)] = slice(slices[idx], slices[idx + 1])
152 | if key in keys_for_2D:
153 | data[key], data1[key], data2[key] = item[s], item[s], item[s]
154 | else:
155 | data[key] = item[s]
156 |
157 | if self.aug_mode == 'no_aug':
158 | n_aug1, n_aug2 = 4, 4
159 | data1 = self.augmentations[n_aug1](data1)
160 | data2 = self.augmentations[n_aug2](data2)
161 | elif self.aug_mode == 'uniform':
162 | n_aug = np.random.choice(25, 1)[0]
163 | n_aug1, n_aug2 = n_aug // 5, n_aug % 5
164 | data1 = self.augmentations[n_aug1](data1)
165 | data2 = self.augmentations[n_aug2](data2)
166 | elif self.aug_mode == 'sample':
167 | n_aug = np.random.choice(25, 1, p=self.aug_prob)[0]
168 | n_aug1, n_aug2 = n_aug // 5, n_aug % 5
169 | data1 = self.augmentations[n_aug1](data1)
170 | data2 = self.augmentations[n_aug2](data2)
171 | else:
172 | raise ValueError
173 | return data, data1, data2
174 |
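In `get` above, a single draw `n_aug` from `[0, 25)` selects one augmentation per view via `n_aug // 5` and `n_aug % 5`; the fifth entry of `self.augmentations` (`lambda x: x`) is the identity, which is also what `no_aug` mode picks with indices `(4, 4)`. The 25 view pairs can be enumerated directly:

```python
# Enumerate the 25 (view 1, view 2) augmentation pairs sampled in get().
# The names mirror the order of self.augmentations in the class above.
aug_names = ['node_drop', 'subgraph', 'edge_pert', 'attr_mask', 'identity']
pairs = [(aug_names[n // 5], aug_names[n % 5]) for n in range(25)]

assert len(pairs) == 25 and len(set(pairs)) == 25   # all pairs distinct
assert pairs[0] == ('node_drop', 'node_drop')
assert pairs[24] == ('identity', 'identity')        # the 'no_aug' pair
```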
--------------------------------------------------------------------------------
/src_regression/dti_finetune.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import copy
3 | import os
4 | import sys
5 | import time
6 |
7 | import numpy as np
8 | import pandas as pd
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | import torch.optim as optim
13 | from torch_geometric.data import DataLoader
14 | from tqdm import tqdm
15 |
16 | sys.path.insert(0, '../src_classification')
17 | from datasets_complete_feature import MoleculeProteinDataset
18 | from models import MoleculeProteinModel, ProteinModel
19 | from models_complete_feature import GNNComplete
20 | from util import ci, mse, pearson, rmse, spearman
21 |
22 |
23 | def train(repurpose_model, device, dataloader, optimizer):
24 | repurpose_model.train()
25 | loss_accum = 0
26 | for step_idx, batch in enumerate(dataloader):
27 | molecule, protein, label = batch
28 | molecule = molecule.to(device)
29 | protein = protein.to(device)
30 | label = label.to(device)
31 |
32 | pred = repurpose_model(molecule, protein).squeeze()
33 |
34 | optimizer.zero_grad()
35 | loss = criterion(pred, label)
36 | loss.backward()
37 | optimizer.step()
38 | loss_accum += loss.detach().item()
39 | print('Loss:\t{}'.format(loss_accum / len(dataloader)))
40 |
41 |
42 | def predicting(repurpose_model, device, dataloader):
43 | repurpose_model.eval()
44 | total_preds = []
45 | total_labels = []
46 | with torch.no_grad():
47 | for batch in dataloader:
48 | molecule, protein, label = batch
49 | molecule = molecule.to(device)
50 | protein = protein.to(device)
51 | label = label.to(device)
52 | pred = repurpose_model(molecule, protein).squeeze()
53 |
54 | total_preds.append(pred.detach().cpu())
55 | total_labels.append(label.detach().cpu())
56 | total_preds = torch.cat(total_preds, dim=0)
57 | total_labels = torch.cat(total_labels, dim=0)
58 | return total_labels.numpy(), total_preds.numpy()
59 |
60 |
61 | if __name__ == '__main__':
 62 |     parser = argparse.ArgumentParser(description='PyTorch fine-tuning of GNNs for drug-target affinity prediction')
63 | parser.add_argument('--device', type=int, default=0)
64 | parser.add_argument('--num_layer', type=int, default=5)
65 | parser.add_argument('--emb_dim', type=int, default=300)
66 | parser.add_argument('--dropout_ratio', type=float, default=0.)
67 | parser.add_argument('--graph_pooling', type=str, default='mean')
68 | parser.add_argument('--JK', type=str, default='last')
69 | parser.add_argument('--dataset', type=str, default='davis', choices=['davis', 'kiba'])
70 | parser.add_argument('--gnn_type', type=str, default='gin')
71 | parser.add_argument('--seed', type=int, default=42)
72 | parser.add_argument('--runseed', type=int, default=0)
73 | parser.add_argument('--batch_size', type=int, default=512)
74 | parser.add_argument('--learning_rate', type=float, default=0.0005)
75 | parser.add_argument('--epochs', type=int, default=200)
76 | parser.add_argument('--input_model_file', type=str, default='')
77 | parser.add_argument('--output_model_file', type=str, default='')
78 | ########## For protein embedding ##########
79 | parser.add_argument('--protein_emb_dim', type=int, default=300)
80 | parser.add_argument('--protein_hidden_dim', type=int, default=300)
81 | parser.add_argument('--num_features', type=int, default=25)
82 | args = parser.parse_args()
83 |
84 | torch.manual_seed(args.runseed)
85 | np.random.seed(args.runseed)
86 | device = torch.device('cuda:' + str(args.device)) if torch.cuda.is_available() else torch.device('cpu')
87 |
88 | ########## Set up dataset and dataloader ##########
89 | root = '../datasets/dti_datasets'
90 | train_val_dataset = MoleculeProteinDataset(root=root, dataset=args.dataset, mode='train')
91 | train_size = int(0.8 * len(train_val_dataset))
92 | valid_size = len(train_val_dataset) - train_size
93 | train_dataset, valid_dataset = torch.utils.data.random_split(train_val_dataset, [train_size, valid_size])
94 | test_dataset = MoleculeProteinDataset(root=root, dataset=args.dataset, mode='test')
95 | print('size of train: {}\tval: {}\ttest: {}'.format(len(train_dataset), len(valid_dataset), len(test_dataset)))
96 |
97 | train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
98 | valid_dataloader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False)
99 | test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)
100 |
101 | ########## Set up model ##########
102 | molecule_model = GNNComplete(args.num_layer, args.emb_dim, JK=args.JK, drop_ratio=args.dropout_ratio, gnn_type=args.gnn_type)
103 | ########## Load pre-trained model ##########
104 | if not args.input_model_file == '':
105 | print('========= Loading from {}'.format(args.input_model_file))
106 | molecule_model.load_state_dict(torch.load(args.input_model_file))
107 | protein_model = ProteinModel(
108 | emb_dim=args.protein_emb_dim, num_features=args.num_features, output_dim=args.protein_hidden_dim)
109 | repurpose_model = MoleculeProteinModel(
110 | molecule_model, protein_model,
111 | molecule_emb_dim=args.emb_dim, protein_emb_dim=args.protein_hidden_dim).to(device)
112 | print('repurpose model\n', repurpose_model)
113 |
114 | criterion = nn.MSELoss()
115 | optimizer = torch.optim.Adam(repurpose_model.parameters(), lr=args.learning_rate)
116 |
117 | best_repurpose_model = None
118 | best_mse = 1000
119 | best_epoch = 0
120 |
121 | for epoch in range(1, 1+args.epochs):
122 | start_time = time.time()
123 | print('Start training at epoch: {}'.format(epoch))
124 | train(repurpose_model, device, train_dataloader, optimizer)
125 |
126 | G, P = predicting(repurpose_model, device, valid_dataloader)
127 | current_mse = mse(G, P)
128 | print('MSE:\t{}'.format(current_mse))
129 | if current_mse < best_mse:
130 | best_repurpose_model = copy.deepcopy(repurpose_model)
131 | best_mse = current_mse
132 | best_epoch = epoch
133 | print('MSE improved at epoch {}\tbest MSE: {}'.format(best_epoch, best_mse))
134 | else:
135 | print('No improvement since epoch {}\tbest MSE: {}'.format(best_epoch, best_mse))
136 | print('Took {:.5f}s.'.format(time.time() - start_time))
137 | print()
138 |
139 | start_time = time.time()
140 | print('Last epoch: {}'.format(args.epochs))
141 | G, P = predicting(repurpose_model, device, test_dataloader)
142 | ret = [rmse(G, P), mse(G, P), pearson(G, P), spearman(G, P), ci(G, P)]
143 | print('RMSE: {}\tMSE: {}\tPearson: {}\tSpearman: {}\tCI: {}'.format(ret[0], ret[1], ret[2], ret[3], ret[4]))
144 | print('Took {:.5f}s.'.format(time.time() - start_time))
145 |
146 | start_time = time.time()
147 | print('Best epoch: {}'.format(best_epoch))
148 | G, P = predicting(best_repurpose_model, device, test_dataloader)
149 | ret = [rmse(G, P), mse(G, P), pearson(G, P), spearman(G, P), ci(G, P)]
150 | print('RMSE: {}\tMSE: {}\tPearson: {}\tSpearman: {}\tCI: {}'.format(ret[0], ret[1], ret[2], ret[3], ret[4]))
151 | print('Took {:.5f}s.'.format(time.time() - start_time))
152 |
153 | if not args.output_model_file == '':
154 | torch.save({
155 | 'repurpose_model': repurpose_model.state_dict(),
156 | 'best_repurpose_model': best_repurpose_model.state_dict()
157 | }, args.output_model_file + '.pth')
158 |
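The evaluation above relies on the metrics imported from `util` (`ci`, `mse`, `pearson`, `rmse`, `spearman`). For reference, a minimal stdlib sketch of two of them; these mirror the standard DTI definitions, not the exact `util` implementations:

```python
# Sketch of two metrics used above: MSE, and the concordance index (CI),
# i.e. the fraction of label-ordered pairs whose predictions keep the
# order (a prediction tie counts as half).

def mse_sketch(y_true, y_pred):
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)

def ci_sketch(y_true, y_pred):
    num, den = 0.0, 0
    n = len(y_true)
    for a in range(n):
        for b in range(n):
            if y_true[a] > y_true[b]:  # only strictly ordered label pairs
                den += 1
                if y_pred[a] > y_pred[b]:
                    num += 1.0
                elif y_pred[a] == y_pred[b]:
                    num += 0.5
    return num / den if den else 0.0
```

A perfectly ordered prediction gives CI = 1.0; a perfectly reversed one gives 0.0.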
--------------------------------------------------------------------------------
/src_regression/models_complete_feature/__init__.py:
--------------------------------------------------------------------------------
1 | from .molecule_gnn_model import GNN_graphpredComplete, GNNComplete
2 |
--------------------------------------------------------------------------------
/src_regression/models_complete_feature/molecule_gnn_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
5 | from torch_geometric.nn import (MessagePassing, global_add_pool,
6 | global_max_pool, global_mean_pool)
7 | from torch_geometric.nn.inits import glorot, zeros
  8 | from torch_geometric.utils import add_self_loops, degree, softmax
9 | from torch_scatter import scatter_add
10 |
11 |
12 | class GINConv(MessagePassing):
13 | def __init__(self, emb_dim, aggr="add"):
14 | super(GINConv, self).__init__(aggr=aggr)
15 |
16 | self.mlp = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2*emb_dim), torch.nn.BatchNorm1d(2*emb_dim), torch.nn.ReLU(), torch.nn.Linear(2*emb_dim, emb_dim))
17 | self.eps = torch.nn.Parameter(torch.Tensor([0]))
18 |
 19 |         self.bond_encoder = BondEncoder(emb_dim=emb_dim)
20 |
21 | def forward(self, x, edge_index, edge_attr):
22 | edge_embedding = self.bond_encoder(edge_attr)
 23 |         out = self.mlp((1 + self.eps) * x + self.propagate(edge_index, x=x, edge_attr=edge_embedding))
24 | return out
25 |
26 | def message(self, x_j, edge_attr):
27 | return F.relu(x_j + edge_attr)
28 |
29 | def update(self, aggr_out):
30 | return aggr_out
31 |
32 |
33 | class GCNConv(MessagePassing):
34 | def __init__(self, emb_dim, aggr="add"):
35 | super(GCNConv, self).__init__(aggr=aggr)
36 |
37 | self.linear = torch.nn.Linear(emb_dim, emb_dim)
38 | self.root_emb = torch.nn.Embedding(1, emb_dim)
 39 |         self.bond_encoder = BondEncoder(emb_dim=emb_dim)
40 |
41 | def forward(self, x, edge_index, edge_attr):
42 | x = self.linear(x)
43 | edge_embedding = self.bond_encoder(edge_attr)
44 |
45 | row, col = edge_index
46 |
 47 |         deg = degree(row, x.size(0), dtype=x.dtype) + 1
48 | deg_inv_sqrt = deg.pow(-0.5)
49 | deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
50 |
51 | norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
52 |
 53 |         return self.propagate(edge_index, x=x, edge_attr=edge_embedding, norm=norm) + F.relu(x + self.root_emb.weight) * 1. / deg.view(-1, 1)
54 |
55 | def message(self, x_j, edge_attr, norm):
56 | return norm.view(-1, 1) * F.relu(x_j + edge_attr)
57 |
58 | def update(self, aggr_out):
59 | return aggr_out
60 |
61 |
62 | class GNNComplete(nn.Module):
63 | def __init__(self, num_layer, emb_dim, JK="last", drop_ratio=0., gnn_type="gin"):
64 |
65 | if num_layer < 2:
66 | raise ValueError("Number of GNN layers must be greater than 1.")
67 |
68 | super(GNNComplete, self).__init__()
69 | self.drop_ratio = drop_ratio
70 | self.num_layer = num_layer
71 | self.JK = JK
72 |
73 | self.atom_encoder = AtomEncoder(emb_dim)
74 |
75 | ###List of MLPs
76 | self.gnns = nn.ModuleList()
77 | for layer in range(num_layer):
78 | if gnn_type == "gin":
79 | self.gnns.append(GINConv(emb_dim, aggr="add"))
80 | elif gnn_type == "gcn":
81 | self.gnns.append(GCNConv(emb_dim, aggr="add"))
82 |
83 | ###List of batchnorms
84 | self.batch_norms = nn.ModuleList()
85 | for layer in range(num_layer):
86 | self.batch_norms.append(nn.BatchNorm1d(emb_dim))
87 |
88 | # def forward(self, x, edge_index, edge_attr):
89 | def forward(self, *argv):
90 | if len(argv) == 3:
91 | x, edge_index, edge_attr = argv[0], argv[1], argv[2]
92 | elif len(argv) == 1:
93 | data = argv[0]
94 | x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
95 | else:
96 | raise ValueError("unmatched number of arguments.")
97 |
98 | x = self.atom_encoder(x)
99 |
100 | h_list = [x]
101 | for layer in range(self.num_layer):
102 | h = self.gnns[layer](h_list[layer], edge_index, edge_attr)
103 | h = self.batch_norms[layer](h)
104 | # h = F.dropout(F.relu(h), self.drop_ratio, training = self.training)
105 | if layer == self.num_layer - 1:
106 | # remove relu for the last layer
107 | h = F.dropout(h, self.drop_ratio, training=self.training)
108 | else:
109 | h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)
110 | h_list.append(h)
111 |
112 | ### Different implementations of Jk-concat
113 | if self.JK == "concat":
114 | node_representation = torch.cat(h_list, dim=1)
115 | elif self.JK == "last":
116 | node_representation = h_list[-1]
117 | elif self.JK == "max":
118 | h_list = [h.unsqueeze_(0) for h in h_list]
119 | node_representation = torch.max(torch.cat(h_list, dim=0), dim=0)[0]
120 | elif self.JK == "sum":
121 | h_list = [h.unsqueeze_(0) for h in h_list]
122 |             node_representation = torch.sum(torch.cat(h_list, dim=0), dim=0)
123 | else:
124 |             raise ValueError("Invalid JK option: {}".format(self.JK))
125 | return node_representation
126 |
127 |
128 | class GNN_graphpredComplete(nn.Module):
129 | def __init__(self, args, num_tasks, molecule_model=None):
130 | super(GNN_graphpredComplete, self).__init__()
131 |
132 | if args.num_layer < 2:
133 |             raise ValueError("Number of GNN layers must be greater than 1.")
134 |
135 | self.molecule_model = molecule_model
136 | self.num_layer = args.num_layer
137 | self.emb_dim = args.emb_dim
138 | self.num_tasks = num_tasks
139 | self.JK = args.JK
140 |
141 | # Different kind of graph pooling
142 | if args.graph_pooling == "sum":
143 | self.pool = global_add_pool
144 | elif args.graph_pooling == "mean":
145 | self.pool = global_mean_pool
146 | elif args.graph_pooling == "max":
147 | self.pool = global_max_pool
148 | else:
149 | raise ValueError("Invalid graph pooling type.")
150 |
151 | # For graph-level binary classification
152 | self.mult = 1
153 |
154 | if self.JK == "concat":
155 | self.graph_pred_linear = nn.Linear(self.mult * (self.num_layer + 1) * self.emb_dim,
156 | self.num_tasks)
157 | else:
158 | self.graph_pred_linear = nn.Linear(self.mult * self.emb_dim, self.num_tasks)
159 | return
160 |
161 | def from_pretrained(self, model_file):
162 | self.molecule_model.load_state_dict(torch.load(model_file))
163 | return
164 |
165 | def get_graph_representation(self, *argv):
166 | if len(argv) == 4:
167 | x, edge_index, edge_attr, batch = argv[0], argv[1], argv[2], argv[3]
168 | elif len(argv) == 1:
169 | data = argv[0]
170 | x, edge_index, edge_attr, batch = data.x, data.edge_index, \
171 | data.edge_attr, data.batch
172 | else:
173 | raise ValueError("unmatched number of arguments.")
174 |
175 | node_representation = self.molecule_model(x, edge_index, edge_attr)
176 | graph_representation = self.pool(node_representation, batch)
177 | pred = self.graph_pred_linear(graph_representation)
178 |
179 | return graph_representation, pred
180 |
181 | def forward(self, *argv):
182 | if len(argv) == 4:
183 | x, edge_index, edge_attr, batch = argv[0], argv[1], argv[2], argv[3]
184 | elif len(argv) == 1:
185 | data = argv[0]
186 | x, edge_index, edge_attr, batch = data.x, data.edge_index, \
187 | data.edge_attr, data.batch
188 | else:
189 | raise ValueError("unmatched number of arguments.")
190 |
191 | node_representation = self.molecule_model(x, edge_index, edge_attr)
192 | graph_representation = self.pool(node_representation, batch)
193 | output = self.graph_pred_linear(graph_representation)
194 |
195 | return output
196 |
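`GINConv` above implements the GIN update `h_v = MLP((1 + eps) * x_v + sum over neighbors u of relu(x_u + e_uv))`. A tensor-free sketch of that aggregation on a toy graph, with the MLP replaced by the identity for readability (all names here are illustrative):

```python
# Sketch of the GINConv aggregation above, on scalar node features:
# h_v = (1 + eps) * x_v + sum of relu(x_u + e_uv) over edges (u, v),
# with the trainable MLP dropped (identity) to keep the example small.

def gin_layer(x, edges, edge_attr, eps=0.0):
    out = [(1 + eps) * xv for xv in x]
    for (u, v), e in zip(edges, edge_attr):
        out[v] += max(x[u] + e, 0.0)  # message: relu(x_j + edge_attr)
    return out

# Toy graph: edges 0 -> 1 and 1 -> 0, scalar features and edge attributes.
h = gin_layer(x=[1.0, 2.0], edges=[(0, 1), (1, 0)], edge_attr=[0.5, 0.5])
# h == [3.5, 3.5]: node 0 gets 1.0 + relu(2.0 + 0.5), node 1 gets 2.0 + relu(1.0 + 0.5)
```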
--------------------------------------------------------------------------------
/src_regression/molecule_finetune_regression.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | import numpy as np
5 | import pandas as pd
6 | import torch
7 | import torch.nn as nn
8 | import torch.optim as optim
9 |
10 | sys.path.insert(0, '../src_classification')
11 | from os.path import join
12 |
13 | from config import args
14 | from datasets_complete_feature import MoleculeDatasetComplete
15 | from models_complete_feature import GNN_graphpredComplete, GNNComplete
16 | from sklearn.metrics import mean_absolute_error, mean_squared_error
17 | from splitters import random_scaffold_split, random_split, scaffold_split
18 | from torch_geometric.data import DataLoader
19 |
20 |
21 | def train(model, device, loader, optimizer):
22 | model.train()
23 | total_loss = 0
24 |
25 | for step, batch in enumerate(loader):
26 | batch = batch.to(device)
27 | pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch).squeeze()
28 | y = batch.y.squeeze()
29 |
30 | loss = reg_criterion(pred, y)
31 |
32 | optimizer.zero_grad()
33 | loss.backward()
34 | optimizer.step()
35 | total_loss += loss.detach().item()
36 |
37 | return total_loss / len(loader)
38 |
39 |
40 | def eval(model, device, loader):
41 | model.eval()
42 | y_true, y_pred = [], []
43 |
44 | for step, batch in enumerate(loader):
45 | batch = batch.to(device)
46 | with torch.no_grad():
47 | pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch).squeeze(1)
48 |
49 | true = batch.y.view(pred.shape)
50 | y_true.append(true)
51 | y_pred.append(pred)
52 |
53 | y_true = torch.cat(y_true, dim=0).cpu().numpy()
54 | y_pred = torch.cat(y_pred, dim=0).cpu().numpy()
55 | rmse = mean_squared_error(y_true, y_pred, squared=False)
56 | mae = mean_absolute_error(y_true, y_pred)
57 | return {'RMSE': rmse, 'MAE': mae}, y_true, y_pred
58 |
59 |
60 | if __name__ == '__main__':
61 | torch.manual_seed(args.runseed)
62 | np.random.seed(args.runseed)
63 | device = torch.device('cuda:' + str(args.device)) \
64 | if torch.cuda.is_available() else torch.device('cpu')
65 | if torch.cuda.is_available():
66 | torch.cuda.manual_seed_all(args.runseed)
67 |
68 | num_tasks = 1
69 | dataset_folder = '../datasets/molecule_datasets_regression/'
70 | dataset_folder = os.path.join(dataset_folder, args.dataset)
71 | dataset = MoleculeDatasetComplete(dataset_folder, dataset=args.dataset)
72 | print('dataset_folder:', dataset_folder)
73 | print(dataset)
74 |
75 | if args.split == 'scaffold':
76 | smiles_list = pd.read_csv(dataset_folder + '/processed/smiles.csv', header=None)[0].tolist()
77 | train_dataset, valid_dataset, test_dataset = scaffold_split(
78 | dataset, smiles_list, null_value=0, frac_train=0.8,
79 | frac_valid=0.1, frac_test=0.1)
80 | print('split via scaffold')
81 | elif args.split == 'random':
82 | train_dataset, valid_dataset, test_dataset = random_split(
83 | dataset, null_value=0, frac_train=0.8, frac_valid=0.1,
84 | frac_test=0.1, seed=args.seed)
85 | print('randomly split')
86 | elif args.split == 'random_scaffold':
87 | smiles_list = pd.read_csv(dataset_folder + '/processed/smiles.csv', header=None)[0].tolist()
88 | train_dataset, valid_dataset, test_dataset = random_scaffold_split(
89 | dataset, smiles_list, null_value=0, frac_train=0.8,
90 | frac_valid=0.1, frac_test=0.1, seed=args.seed)
91 | print('random scaffold')
92 | else:
93 | raise ValueError('Invalid split option.')
94 |
95 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
96 | val_loader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
97 | test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
98 |
99 | # set up model
100 | molecule_model = GNNComplete(num_layer=args.num_layer, emb_dim=args.emb_dim, JK=args.JK, drop_ratio=args.dropout_ratio, gnn_type=args.gnn_type)
101 | model = GNN_graphpredComplete(args=args, num_tasks=num_tasks, molecule_model=molecule_model)
102 | if not args.input_model_file == '':
103 | model.from_pretrained(args.input_model_file)
104 | model.to(device)
105 | print(model)
106 |
107 | model_param_group = [
108 | {'params': model.molecule_model.parameters()},
109 | {'params': model.graph_pred_linear.parameters(), 'lr': args.lr * args.lr_scale}
110 | ]
111 | optimizer = optim.Adam(model_param_group, lr=args.lr, weight_decay=args.decay)
112 | # reg_criterion = torch.nn.L1Loss()
113 | reg_criterion = torch.nn.MSELoss()
114 |
115 | train_result_list, val_result_list, test_result_list = [], [], []
116 | # metric_list = ['RMSE', 'MAE', 'R2']
117 | metric_list = ['RMSE', 'MAE']
118 | best_val_rmse, best_val_idx = 1e10, 0
119 |
120 | for epoch in range(1, args.epochs + 1):
121 | loss_acc = train(model, device, train_loader, optimizer)
122 | print('Epoch: {}\nLoss: {}'.format(epoch, loss_acc))
123 |
124 | if args.eval_train:
125 | train_result, train_target, train_pred = eval(model, device, train_loader)
126 | else:
127 | train_result = {'RMSE': 0, 'MAE': 0, 'R2': 0}
128 | val_result, val_target, val_pred = eval(model, device, val_loader)
129 | test_result, test_target, test_pred = eval(model, device, test_loader)
130 |
131 | train_result_list.append(train_result)
132 | val_result_list.append(val_result)
133 | test_result_list.append(test_result)
134 |
135 | for metric in metric_list:
136 | print('{} train: {:.6f}\tval: {:.6f}\ttest: {:.6f}'.format(metric, train_result[metric], val_result[metric], test_result[metric]))
137 | print()
138 |
139 | if val_result['RMSE'] < best_val_rmse:
140 | best_val_rmse = val_result['RMSE']
141 | best_val_idx = epoch - 1
142 | if not args.output_model_dir == '':
143 | output_model_path = join(args.output_model_dir, 'model_best.pth')
144 | saved_model_dict = {
145 | 'molecule_model': molecule_model.state_dict(),
146 | 'model': model.state_dict()
147 | }
148 | torch.save(saved_model_dict, output_model_path)
149 |
150 | filename = join(args.output_model_dir, 'evaluation_best.pth')
151 | np.savez(filename, val_target=val_target, val_pred=val_pred,
152 | test_target=test_target, test_pred=test_pred)
153 |
154 | for metric in metric_list:
155 | print('Best (RMSE), {} train: {:.6f}\tval: {:.6f}\ttest: {:.6f}'.format(
156 | metric, train_result_list[best_val_idx][metric], val_result_list[best_val_idx][metric], test_result_list[best_val_idx][metric]))
157 |
158 |     if not args.output_model_dir == '':
159 | output_model_path = join(args.output_model_dir, 'model_final.pth')
160 | saved_model_dict = {
161 | 'molecule_model': molecule_model.state_dict(),
162 | 'model': model.state_dict()
163 | }
164 | torch.save(saved_model_dict, output_model_path)
165 |
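`eval()` above reports RMSE and MAE via scikit-learn. A stdlib sketch of what those two calls compute on 1-D targets, for reference:

```python
# Sketch of the two metrics reported by eval() above; equivalent to
# sklearn's mean_squared_error(..., squared=False) and
# mean_absolute_error on 1-D arrays.
import math

def rmse_sketch(y_true, y_pred):
    return math.sqrt(sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true))

def mae_sketch(y_true, y_pred):
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)

metrics = {'RMSE': rmse_sketch([0.0, 2.0], [0.0, 0.0]),
           'MAE': mae_sketch([0.0, 2.0], [0.0, 0.0])}
```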
--------------------------------------------------------------------------------
/src_regression/pretrain_AM.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import time
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | import torch.optim as optim
8 |
9 | sys.path.insert(0, '../src_classification')
10 | from config import args
11 | from dataloader import DataLoaderMasking
12 | from datasets_complete_feature import MoleculeDatasetComplete
13 | from models_complete_feature import GNNComplete
14 | from torch_geometric.nn import global_mean_pool
15 | from util_complete_feature import MaskAtom
16 |
17 |
18 | def compute_accuracy(pred, target):
 19 |     return float(torch.sum(torch.max(pred.detach(), dim=1)[1] == target).cpu().item()) / len(pred)
20 |
21 |
22 | def do_AttrMasking(batch, criterion, node_repr, molecule_atom_masking_model):
23 | target = batch.mask_node_label[:, 0]
24 | node_pred = molecule_atom_masking_model(node_repr[batch.masked_atom_indices])
25 | attributemask_loss = criterion(node_pred.double(), target)
26 | attributemask_acc = compute_accuracy(node_pred, target)
27 | return attributemask_loss, attributemask_acc
28 |
29 |
30 | def train(device, loader, optimizer):
31 |
32 | start = time.time()
33 | molecule_model.train()
34 | molecule_atom_masking_model.train()
35 | attributemask_loss_accum, attributemask_acc_accum = 0, 0
36 |
37 | for step, batch in enumerate(loader):
38 | batch = batch.to(device)
39 | node_repr = molecule_model(batch.masked_x, batch.edge_index, batch.edge_attr)
40 |
41 | attributemask_loss, attributemask_acc = do_AttrMasking(
42 | batch=batch, criterion=criterion, node_repr=node_repr,
43 | molecule_atom_masking_model=molecule_atom_masking_model)
44 |
45 | attributemask_loss_accum += attributemask_loss.detach().cpu().item()
46 | attributemask_acc_accum += attributemask_acc
47 | loss = attributemask_loss
48 |
49 | optimizer.zero_grad()
50 | loss.backward()
51 | optimizer.step()
52 |
53 | print('AM Loss: {:.5f}\tAM Acc: {:.5f}\tTime: {:.5f}'.format(
54 | attributemask_loss_accum / len(loader),
55 | attributemask_acc_accum / len(loader),
56 | time.time() - start))
57 | return
58 |
59 |
60 | if __name__ == '__main__':
61 |
62 | np.random.seed(0)
63 | torch.manual_seed(0)
64 | device = torch.device('cuda:' + str(args.device)) \
65 | if torch.cuda.is_available() else torch.device('cpu')
66 | if torch.cuda.is_available():
67 | torch.cuda.manual_seed_all(0)
68 | torch.cuda.set_device(args.device)
69 |
70 | if 'GEOM' in args.dataset:
71 | dataset = MoleculeDatasetComplete(
72 | '../datasets/{}/'.format(args.dataset), dataset=args.dataset,
73 | transform=MaskAtom(num_atom_type=119, num_edge_type=5,
74 | mask_rate=args.mask_rate, mask_edge=args.mask_edge))
75 | loader = DataLoaderMasking(dataset, batch_size=args.batch_size,
76 | shuffle=True, num_workers=args.num_workers)
77 |
78 | # set up model
79 | molecule_model = GNNComplete(args.num_layer, args.emb_dim, JK=args.JK, drop_ratio=args.dropout_ratio, gnn_type=args.gnn_type).to(device)
80 | molecule_readout_func = global_mean_pool
81 |
82 | molecule_atom_masking_model = torch.nn.Linear(args.emb_dim, 119).to(device)
83 |
84 | model_param_group = [{'params': molecule_model.parameters(), 'lr': args.lr},
85 | {'params': molecule_atom_masking_model.parameters(), 'lr': args.lr}]
86 |
87 | optimizer = optim.Adam(model_param_group, lr=args.lr, weight_decay=args.decay)
88 | criterion = nn.CrossEntropyLoss()
89 |
90 | for epoch in range(1, args.epochs + 1):
91 | print('epoch: {}'.format(epoch))
92 | train(device, loader, optimizer)
93 |
94 | if not args.output_model_dir == '':
95 | torch.save(molecule_model.state_dict(), args.output_model_dir + '_model.pth')
96 |
97 | saver_dict = {'model': molecule_model.state_dict(),
98 | 'molecule_atom_masking_model': molecule_atom_masking_model.state_dict()}
99 |
100 | torch.save(saver_dict, args.output_model_dir + '_model_complete.pth')
101 |
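`compute_accuracy` above takes the argmax over the 119 atom-type logits and compares it with the masked label. A tensor-free sketch of the same computation:

```python
# Sketch of compute_accuracy above: fraction of masked atoms whose
# predicted atom type (argmax over the logits row) matches the label.

def masked_atom_accuracy(logits, targets):
    correct = sum(
        1 for row, t in zip(logits, targets)
        if max(range(len(row)), key=row.__getitem__) == t)
    return correct / len(targets)

# Two masked atoms, 3 atom types: first prediction right, second wrong.
acc = masked_atom_accuracy([[0.1, 0.9, 0.0], [0.8, 0.1, 0.1]], [1, 2])
# acc == 0.5
```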
--------------------------------------------------------------------------------
/src_regression/pretrain_CP.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import time
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | import torch.optim as optim
8 |
9 | sys.path.insert(0, '../src_classification')
10 |
11 | from config import args
12 | from dataloader import DataLoaderSubstructContext
13 | from datasets_complete_feature import MoleculeDatasetComplete
14 | from models_complete_feature import GNNComplete
15 | from torch_geometric.nn import global_mean_pool
16 | from util import cycle_index
17 | from util_complete_feature import ExtractSubstructureContextPair
18 |
19 |
20 | def do_ContextPred(batch, criterion, args, molecule_substruct_model,
21 | molecule_context_model, molecule_readout_func):
22 |
23 | # creating substructure representation
24 | substruct_repr = molecule_substruct_model(
25 | batch.x_substruct, batch.edge_index_substruct,
26 | batch.edge_attr_substruct)[batch.center_substruct_idx]
27 |
28 | # creating context representations
29 | overlapped_node_repr = molecule_context_model(
30 | batch.x_context, batch.edge_index_context,
31 | batch.edge_attr_context)[batch.overlap_context_substruct_idx]
32 |
33 | # positive context representation
34 | # readout -> global_mean_pool by default
35 | context_repr = molecule_readout_func(overlapped_node_repr,
36 | batch.batch_overlapped_context)
37 |
38 | # negative contexts are obtained by shifting
39 | # the indices of context embeddings
40 | neg_context_repr = torch.cat(
41 | [context_repr[cycle_index(len(context_repr), i + 1)]
42 | for i in range(args.contextpred_neg_samples)], dim=0)
43 |
44 | num_neg = args.contextpred_neg_samples
45 | pred_pos = torch.sum(substruct_repr * context_repr, dim=1)
46 | pred_neg = torch.sum(substruct_repr.repeat((num_neg, 1)) * neg_context_repr, dim=1)
47 |
48 | loss_pos = criterion(pred_pos.double(),
49 | torch.ones(len(pred_pos)).to(pred_pos.device).double())
50 | loss_neg = criterion(pred_neg.double(),
51 | torch.zeros(len(pred_neg)).to(pred_neg.device).double())
52 |
53 | contextpred_loss = loss_pos + num_neg * loss_neg
54 |
55 | num_pred = len(pred_pos) + len(pred_neg)
56 | contextpred_acc = (torch.sum(pred_pos > 0).float() +
57 | torch.sum(pred_neg < 0).float()) / num_pred
58 | contextpred_acc = contextpred_acc.detach().cpu().item()
59 |
60 | return contextpred_loss, contextpred_acc
61 |
62 |
63 | def train(args, device, loader, optimizer):
64 |
65 | start_time = time.time()
66 | molecule_context_model.train()
67 | molecule_substruct_model.train()
68 | contextpred_loss_accum, contextpred_acc_accum = 0, 0
69 |
70 | for step, batch in enumerate(loader):
71 |
72 | batch = batch.to(device)
73 | contextpred_loss, contextpred_acc = do_ContextPred(
74 | batch=batch, criterion=criterion, args=args,
75 | molecule_substruct_model=molecule_substruct_model,
76 | molecule_context_model=molecule_context_model,
77 | molecule_readout_func=molecule_readout_func)
78 |
79 | contextpred_loss_accum += contextpred_loss.detach().cpu().item()
80 | contextpred_acc_accum += contextpred_acc
81 | ssl_loss = contextpred_loss
82 | optimizer.zero_grad()
83 | ssl_loss.backward()
84 | optimizer.step()
85 |
86 | print('CP Loss: {:.5f}\tCP Acc: {:.5f}\tTime: {:.3f}'.format(
87 | contextpred_loss_accum / len(loader),
88 | contextpred_acc_accum / len(loader),
89 | time.time() - start_time))
90 |
91 | return
92 |
93 |
94 | if __name__ == '__main__':
95 |
96 | np.random.seed(0)
97 | torch.manual_seed(0)
98 | device = torch.device('cuda:' + str(args.device)) \
99 | if torch.cuda.is_available() else torch.device('cpu')
100 | if torch.cuda.is_available():
101 | torch.cuda.manual_seed_all(0)
102 | torch.cuda.set_device(args.device)
103 |
104 | l1 = args.num_layer - 1
105 | l2 = l1 + args.csize
106 | print('num layer: %d l1: %d l2: %d' % (args.num_layer, l1, l2))
107 |
108 | if 'GEOM' in args.dataset:
109 | dataset = MoleculeDatasetComplete(
110 | '../datasets/{}/'.format(args.dataset), dataset=args.dataset,
111 | transform=ExtractSubstructureContextPair(args.num_layer, l1, l2))
112 | loader = DataLoaderSubstructContext(dataset, batch_size=args.batch_size,
113 | shuffle=True, num_workers=args.num_workers)
114 |
115 | ''' === set up model, mainly used in do_ContextPred() === '''
116 | molecule_substruct_model = GNNComplete(
117 | args.num_layer, args.emb_dim, JK=args.JK,
118 | drop_ratio=args.dropout_ratio, gnn_type=args.gnn_type).to(device)
119 | molecule_context_model = GNNComplete(
120 | int(l2 - l1), args.emb_dim, JK=args.JK,
121 | drop_ratio=args.dropout_ratio, gnn_type=args.gnn_type).to(device)
122 |
123 | ''' === set up loss and optimiser === '''
124 | criterion = nn.BCEWithLogitsLoss()
125 | molecule_readout_func = global_mean_pool
126 | model_param_group = [{'params': molecule_substruct_model.parameters(), 'lr': args.lr},
127 | {'params': molecule_context_model.parameters(), 'lr': args.lr}]
128 | optimizer = optim.Adam(model_param_group, lr=args.lr, weight_decay=args.decay)
129 |
130 | for epoch in range(1, args.epochs + 1):
131 | print('epoch: {}'.format(epoch))
132 | train(args, device, loader, optimizer)
133 |
134 | if not args.output_model_dir == '':
135 | torch.save(molecule_substruct_model.state_dict(),
136 | args.output_model_dir + '_model.pth')
137 |
138 | saver_dict = {
139 | 'molecule_substruct_model': molecule_substruct_model.state_dict(),
140 | 'molecule_context_model': molecule_context_model.state_dict()}
141 |
142 | torch.save(saver_dict, args.output_model_dir + '_model_complete.pth')
143 |
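`do_ContextPred` above builds negative pairs by cyclically shifting the batch of context embeddings with `util.cycle_index`. A minimal sketch of that idea, assuming `cycle_index(num, shift)` yields indices rotated by `shift` (the real helper lives in `src_classification/util.py`; the function names below are illustrative):

```python
# Sketch of the cyclic-shift negative sampling in do_ContextPred: for
# shift s, molecule i is paired with the context of molecule (i + s) % n,
# which is a mismatched, hence "negative", substructure/context pair.

def cycle_index_sketch(num, shift):
    return [(i + shift) % num for i in range(num)]

def negative_pairs(substructs, contexts, num_neg):
    n = len(contexts)
    pairs = []
    for s in range(1, num_neg + 1):  # shifts 1..num_neg, as in the loop above
        idx = cycle_index_sketch(n, s)
        pairs.extend((substructs[i], contexts[j]) for i, j in zip(range(n), idx))
    return pairs

pairs = negative_pairs(['s0', 's1', 's2'], ['c0', 'c1', 'c2'], num_neg=2)
# 3 molecules x 2 shifts = 6 negatives, and no matched (s_i, c_i) pair appears.
```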
--------------------------------------------------------------------------------
/src_regression/pretrain_GraphCL.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import sys
3 | import time
4 |
5 | import numpy as np
6 | import torch
7 | import torch.optim as optim
8 |
9 | sys.path.insert(0, '../src_classification')
10 |
11 | from datasets_complete_feature import MoleculeDataset_graphcl_complete
12 | from models_complete_feature import GNNComplete
13 | from pretrain_JOAO import graphcl
14 | from torch_geometric.data import DataLoader
15 |
16 |
17 | def train(loader, model, optimizer, device):
18 |
19 | model.train()
20 | train_loss_accum = 0
21 |
22 | for step, (_, batch1, batch2) in enumerate(loader):
23 | # _, batch1, batch2 = batch
24 | batch1 = batch1.to(device)
25 | batch2 = batch2.to(device)
26 |
27 | x1 = model.forward_cl(batch1.x, batch1.edge_index,
28 | batch1.edge_attr, batch1.batch)
29 | x2 = model.forward_cl(batch2.x, batch2.edge_index,
30 | batch2.edge_attr, batch2.batch)
31 | loss = model.loss_cl(x1, x2)
32 |
33 | optimizer.zero_grad()
34 | loss.backward()
35 | optimizer.step()
36 |
37 | train_loss_accum += float(loss.detach().cpu().item())
38 |
39 | return train_loss_accum / (step + 1)
40 |
41 |
42 | if __name__ == "__main__":
43 | # Training settings
44 | parser = argparse.ArgumentParser(description='GraphCL')
45 | parser.add_argument('--device', type=int, default=0, help='gpu')
46 | parser.add_argument('--batch_size', type=int, default=256, help='batch')
47 | parser.add_argument('--decay', type=float, default=0, help='weight decay')
48 | parser.add_argument('--epochs', type=int, default=100, help='train epochs')
49 | parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
50 | parser.add_argument('--JK', type=str, default="last",
51 | choices=['last', 'sum', 'max', 'concat'],
52 | help='how the node features across layers are combined.')
53 | parser.add_argument('--gnn_type', type=str, default="gin", help='gnn model type')
54 | parser.add_argument('--dropout_ratio', type=float, default=0, help='dropout ratio')
55 | parser.add_argument('--emb_dim', type=int, default=300, help='embedding dimensions')
56 | parser.add_argument('--dataset', type=str, default=None, help='root dir of dataset')
57 | parser.add_argument('--num_layer', type=int, default=5, help='message passing layers')
58 | parser.add_argument('--output_model_file', type=str, default='', help='model save path')
59 | parser.add_argument('--num_workers', type=int, default=8, help='workers for dataset loading')
60 |
61 | parser.add_argument('--aug_mode', type=str, default='sample')
62 | parser.add_argument('--aug_strength', type=float, default=0.2)
63 |
64 | # parser.add_argument('--gamma', type=float, default=0.1)
65 | parser.add_argument('--output_model_dir', type=str, default='')
66 | args = parser.parse_args()
67 |
68 | torch.manual_seed(0)
69 | np.random.seed(0)
70 | device = torch.device("cuda:" + str(args.device)) \
71 | if torch.cuda.is_available() else torch.device("cpu")
72 | if torch.cuda.is_available():
73 | torch.cuda.manual_seed_all(0)
74 |
75 | # set up dataset
76 | if 'GEOM' in args.dataset:
77 | dataset = MoleculeDataset_graphcl_complete('../datasets/{}/'.format(args.dataset), dataset=args.dataset)
78 | dataset.set_augMode(args.aug_mode)
79 | dataset.set_augStrength(args.aug_strength)
80 | print(dataset)
81 |
82 | loader = DataLoader(dataset, batch_size=args.batch_size,
83 | num_workers=args.num_workers, shuffle=True)
84 |
85 | # set up model
86 | gnn = GNNComplete(num_layer=args.num_layer, emb_dim=args.emb_dim, JK=args.JK,
87 | drop_ratio=args.dropout_ratio, gnn_type=args.gnn_type)
88 |
89 | model = graphcl(gnn)
90 | model.to(device)
91 |
92 | # set up optimizer
93 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.decay)
94 | print(optimizer)
95 |
96 | aug_prob = np.ones(25) / 25
97 | dataset.set_augProb(aug_prob)
98 | for epoch in range(1, args.epochs + 1):
99 | start_time = time.time()
100 | pretrain_loss = train(loader, model, optimizer, device)
101 |
102 | print('Epoch: {:3d}\tLoss:{:.3f}\tTime: {:.3f}'.format(
103 | epoch, pretrain_loss, time.time() - start_time))
104 |
105 |     if args.output_model_dir != '':
106 | saver_dict = {'model': model.state_dict()}
107 | torch.save(saver_dict, args.output_model_dir + '_model_complete.pth')
108 | torch.save(model.gnn.state_dict(), args.output_model_dir + '_model.pth')
109 |
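The contrastive objective used by this script is `graphcl.loss_cl`, imported from `pretrain_JOAO`. As a sanity check, the same NT-Xent-style computation can be sketched in plain NumPy, mirroring the PyTorch code (temperature T = 0.1, positive pairs on the diagonal, positives excluded from the denominator); the function below is an illustrative stand-in, not part of the repository:

```python
import numpy as np

def nt_xent(x1, x2, T=0.1):
    """NumPy mirror of graphcl.loss_cl: row i of x1 and row i of x2 form the
    positive pair; every other row of x2 in the batch acts as a negative."""
    x1n = x1 / np.linalg.norm(x1, axis=1, keepdims=True)
    x2n = x2 / np.linalg.norm(x2, axis=1, keepdims=True)
    sim = np.exp(x1n @ x2n.T / T)      # exp(cosine similarity / temperature)
    pos = np.diag(sim)                 # similarities of the positive pairs
    return float(-np.log(pos / (sim.sum(axis=1) - pos)).mean())

rng = np.random.default_rng(0)
x = rng.normal(size=(8, 16))
aligned = nt_xent(x, x)                         # perfectly matching views
mismatched = nt_xent(x, np.roll(x, 1, axis=0))  # every positive pair is wrong
print(aligned < mismatched)  # True: matching views score the lower loss
```

Identical views score a much lower loss than deliberately mismatched ones, which is exactly what the pre-training objective rewards.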
--------------------------------------------------------------------------------
/src_regression/pretrain_GraphMVP.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import time
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | import torch.optim as optim
8 |
9 | sys.path.insert(0, '../src_classification')
10 | from config import args
11 | from models import AutoEncoder, SchNet, VariationalAutoEncoder
12 | from models_complete_feature import GNNComplete
13 | from torch_geometric.data import DataLoader
14 | from torch_geometric.nn import global_mean_pool
15 | from tqdm import tqdm
16 | from util import dual_CL
17 |
18 | from datasets import Molecule3DMaskingDataset
19 |
20 |
21 | def save_model(save_best):
22 |     if args.output_model_dir != '':
23 | if save_best:
24 | global optimal_loss
25 | print('save model with loss: {:.5f}'.format(optimal_loss))
26 | torch.save(molecule_model_2D.state_dict(), args.output_model_dir + '_model.pth')
27 | saver_dict = {
28 | 'model': molecule_model_2D.state_dict(),
29 | 'model_3D': molecule_model_3D.state_dict(),
30 | 'AE_2D_3D_model': AE_2D_3D_model.state_dict(),
31 | 'AE_3D_2D_model': AE_3D_2D_model.state_dict(),
32 | }
33 | torch.save(saver_dict, args.output_model_dir + '_model_complete.pth')
34 |
35 | else:
36 | torch.save(molecule_model_2D.state_dict(), args.output_model_dir + '_model_final.pth')
37 | saver_dict = {
38 | 'model': molecule_model_2D.state_dict(),
39 | 'model_3D': molecule_model_3D.state_dict(),
40 | 'AE_2D_3D_model': AE_2D_3D_model.state_dict(),
41 | 'AE_3D_2D_model': AE_3D_2D_model.state_dict(),
42 | }
43 | torch.save(saver_dict, args.output_model_dir + '_model_complete_final.pth')
44 | return
45 |
46 |
47 | def train(args, molecule_model_2D, device, loader, optimizer):
48 | start_time = time.time()
49 |
50 | molecule_model_2D.train()
51 | molecule_model_3D.train()
52 | if molecule_projection_layer is not None:
53 | molecule_projection_layer.train()
54 |
55 | AE_loss_accum, AE_acc_accum = 0, 0
56 | CL_loss_accum, CL_acc_accum = 0, 0
57 |
58 | if args.verbose:
59 | l = tqdm(loader)
60 | else:
61 | l = loader
62 | for step, batch in enumerate(l):
63 | batch = batch.to(device)
64 |
65 | node_repr = molecule_model_2D(batch.x, batch.edge_index, batch.edge_attr)
66 | molecule_2D_repr = molecule_readout_func(node_repr, batch.batch)
67 |
68 | if args.model_3d == 'schnet':
69 | molecule_3D_repr = molecule_model_3D(batch.x[:, 0], batch.positions, batch.batch)
70 |
71 | CL_loss, CL_acc = dual_CL(molecule_2D_repr, molecule_3D_repr, args)
72 |
73 | AE_loss_1 = AE_2D_3D_model(molecule_2D_repr, molecule_3D_repr)
74 | AE_loss_2 = AE_3D_2D_model(molecule_3D_repr, molecule_2D_repr)
75 |         AE_acc_1 = AE_acc_2 = 0  # reconstruction losses report no accuracy; placeholders keep the logging uniform
76 | AE_loss = (AE_loss_1 + AE_loss_2) / 2
77 |
78 | CL_loss_accum += CL_loss.detach().cpu().item()
79 | CL_acc_accum += CL_acc
80 | AE_loss_accum += AE_loss.detach().cpu().item()
81 | AE_acc_accum += (AE_acc_1 + AE_acc_2) / 2
82 |
83 | loss = 0
84 | if args.alpha_1 > 0:
85 | loss += CL_loss * args.alpha_1
86 | if args.alpha_2 > 0:
87 | loss += AE_loss * args.alpha_2
88 |
89 | optimizer.zero_grad()
90 | loss.backward()
91 | optimizer.step()
92 |
93 | global optimal_loss
94 | CL_loss_accum /= len(loader)
95 | CL_acc_accum /= len(loader)
96 | AE_loss_accum /= len(loader)
97 | AE_acc_accum /= len(loader)
98 | temp_loss = args.alpha_1 * CL_loss_accum + args.alpha_2 * AE_loss_accum
99 | if temp_loss < optimal_loss:
100 | optimal_loss = temp_loss
101 | save_model(save_best=True)
102 | print('CL Loss: {:.5f}\tCL Acc: {:.5f}\t\tAE Loss: {:.5f}\tAE Acc: {:.5f}\tTime: {:.5f}'.format(
103 | CL_loss_accum, CL_acc_accum, AE_loss_accum, AE_acc_accum, time.time() - start_time))
104 | return
105 |
106 |
107 | if __name__ == '__main__':
108 | torch.manual_seed(0)
109 | np.random.seed(0)
110 | device = torch.device('cuda:' + str(args.device)) if torch.cuda.is_available() else torch.device('cpu')
111 | if torch.cuda.is_available():
112 | torch.cuda.manual_seed_all(0)
113 | torch.cuda.set_device(args.device)
114 |
115 | if 'GEOM' in args.dataset:
116 | data_root = '../datasets/{}/'.format(args.dataset) if args.input_data_dir == '' else '{}/{}/'.format(args.input_data_dir, args.dataset)
117 | dataset = Molecule3DMaskingDataset(data_root, dataset=args.dataset, mask_ratio=args.SSL_masking_ratio)
118 | else:
119 |         raise ValueError('Unsupported dataset: {}'.format(args.dataset))
120 | loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
121 |
122 | # set up model
123 | molecule_model_2D = GNNComplete(args.num_layer, args.emb_dim, JK=args.JK, drop_ratio=args.dropout_ratio, gnn_type=args.gnn_type).to(device)
124 | molecule_readout_func = global_mean_pool
125 |
126 | print('Using 3d model\t', args.model_3d)
127 | molecule_projection_layer = None
128 | if args.model_3d == 'schnet':
129 | molecule_model_3D = SchNet(
130 | hidden_channels=args.emb_dim, num_filters=args.num_filters, num_interactions=args.num_interactions,
131 | num_gaussians=args.num_gaussians, cutoff=args.cutoff, atomref=None, readout=args.readout).to(device)
132 | else:
133 | raise NotImplementedError('Model {} not included.'.format(args.model_3d))
134 |
135 | if args.AE_model == 'AE':
136 | AE_2D_3D_model = AutoEncoder(
137 | emb_dim=args.emb_dim, loss=args.AE_loss, detach_target=args.detach_target).to(device)
138 | AE_3D_2D_model = AutoEncoder(
139 | emb_dim=args.emb_dim, loss=args.AE_loss, detach_target=args.detach_target).to(device)
140 | elif args.AE_model == 'VAE':
141 | AE_2D_3D_model = VariationalAutoEncoder(
142 | emb_dim=args.emb_dim, loss=args.AE_loss, detach_target=args.detach_target, beta=args.beta).to(device)
143 | AE_3D_2D_model = VariationalAutoEncoder(
144 | emb_dim=args.emb_dim, loss=args.AE_loss, detach_target=args.detach_target, beta=args.beta).to(device)
145 | else:
146 |         raise ValueError('Unsupported AE model: {}'.format(args.AE_model))
147 |
148 | model_param_group = []
149 | model_param_group.append({'params': molecule_model_2D.parameters(), 'lr': args.lr * args.gnn_lr_scale})
150 | model_param_group.append({'params': molecule_model_3D.parameters(), 'lr': args.lr * args.schnet_lr_scale})
151 | model_param_group.append({'params': AE_2D_3D_model.parameters(), 'lr': args.lr * args.gnn_lr_scale})
152 | model_param_group.append({'params': AE_3D_2D_model.parameters(), 'lr': args.lr * args.schnet_lr_scale})
153 | if molecule_projection_layer is not None:
154 | model_param_group.append({'params': molecule_projection_layer.parameters(), 'lr': args.lr * args.schnet_lr_scale})
155 |
156 | optimizer = optim.Adam(model_param_group, lr=args.lr, weight_decay=args.decay)
157 | optimal_loss = 1e10
158 |
159 | for epoch in range(1, args.epochs + 1):
160 | print('epoch: {}'.format(epoch))
161 | train(args, molecule_model_2D, device, loader, optimizer)
162 |
163 | save_model(save_best=False)
164 |
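`save_model(save_best=True)` fires inside `train` only when the epoch-averaged weighted loss (`args.alpha_1 * CL_loss_accum + args.alpha_2 * AE_loss_accum`, i.e. `temp_loss`) improves on the running `optimal_loss`. A minimal sketch of that bookkeeping, with made-up loss values and alpha weights:

```python
def epoch_objective(cl_loss, ae_loss, alpha_1=1.0, alpha_2=0.1):
    # mirrors temp_loss in train(); the alpha values here are illustrative
    return alpha_1 * cl_loss + alpha_2 * ae_loss

optimal_loss = 1e10          # same sentinel as the script
history = []
for cl, ae in [(0.9, 0.5), (0.7, 0.6), (0.8, 0.3)]:
    temp_loss = epoch_objective(cl, ae)
    if temp_loss < optimal_loss:
        optimal_loss = temp_loss
        history.append('saved')      # stands in for save_model(save_best=True)
    else:
        history.append('skipped')
print(history)  # ['saved', 'saved', 'skipped']
```

Only epochs that beat the best weighted loss seen so far trigger a best-model checkpoint; the final `save_model(save_best=False)` always writes the last state.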
--------------------------------------------------------------------------------
/src_regression/pretrain_JOAO.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | import torch.optim as optim
8 | from datasets_complete_feature import MoleculeDataset_graphcl_complete
9 | from models_complete_feature import GNNComplete
10 | from torch_geometric.data import DataLoader
11 | from torch_geometric.nn import global_mean_pool
12 |
13 |
14 | class graphcl(nn.Module):
15 |
16 | def __init__(self, gnn):
17 | super(graphcl, self).__init__()
18 | self.gnn = gnn
19 | self.pool = global_mean_pool
20 | self.projection_head = nn.Sequential(
21 | nn.Linear(300, 300),
22 | nn.ReLU(inplace=True),
23 | nn.Linear(300, 300))
24 |
25 | def forward_cl(self, x, edge_index, edge_attr, batch):
26 | x = self.gnn(x, edge_index, edge_attr)
27 | x = self.pool(x, batch)
28 | x = self.projection_head(x)
29 | return x
30 |
31 | def loss_cl(self, x1, x2):
32 | T = 0.1
33 | batch, _ = x1.size()
34 | x1_abs = x1.norm(dim=1)
35 | x2_abs = x2.norm(dim=1)
36 |
37 | sim_matrix = torch.einsum('ik,jk->ij', x1, x2) / \
38 | torch.einsum('i,j->ij', x1_abs, x2_abs)
39 | sim_matrix = torch.exp(sim_matrix / T)
40 | pos_sim = sim_matrix[range(batch), range(batch)]
41 | loss = pos_sim / (sim_matrix.sum(dim=1) - pos_sim)
42 | loss = - torch.log(loss).mean()
43 | return loss
44 |
45 |
46 | def train(loader, model, optimizer, device, gamma_joao):
47 |
48 | model.train()
49 | train_loss_accum = 0
50 |
51 | for step, (_, batch1, batch2) in enumerate(loader):
52 | # _, batch1, batch2 = batch
53 | batch1 = batch1.to(device)
54 | batch2 = batch2.to(device)
55 |
57 | x1 = model.forward_cl(batch1.x, batch1.edge_index,
58 | batch1.edge_attr, batch1.batch)
59 | x2 = model.forward_cl(batch2.x, batch2.edge_index,
60 | batch2.edge_attr, batch2.batch)
61 | loss = model.loss_cl(x1, x2)
62 |
63 | optimizer.zero_grad()
64 | loss.backward()
65 | optimizer.step()
66 |
67 | train_loss_accum += float(loss.detach().cpu().item())
68 |
69 | # joao
70 | aug_prob = loader.dataset.aug_prob
71 | loss_aug = np.zeros(25)
72 | for n in range(25):
73 | _aug_prob = np.zeros(25)
74 | _aug_prob[n] = 1
75 | loader.dataset.set_augProb(_aug_prob)
76 | # for efficiency, we only use around 10% of data to estimate the loss
77 | count, count_stop = 0, len(loader.dataset) // (loader.batch_size * 10) + 1
78 |
79 | with torch.no_grad():
80 |             for _, batch1, batch2 in loader:
81 |                 # note: do not bind `step` here; the outer loop's value feeds the loss average below
82 | batch1 = batch1.to(device)
83 | batch2 = batch2.to(device)
84 |
85 | x1 = model.forward_cl(batch1.x, batch1.edge_index,
86 | batch1.edge_attr, batch1.batch)
87 | x2 = model.forward_cl(batch2.x, batch2.edge_index,
88 | batch2.edge_attr, batch2.batch)
89 | loss = model.loss_cl(x1, x2)
90 | loss_aug[n] += loss.item()
91 | count += 1
92 | if count == count_stop:
93 | break
94 | loss_aug[n] /= count
95 |
96 | # view selection, projected gradient descent,
97 | # reference: https://arxiv.org/abs/1906.03563
98 | beta = 1
99 | gamma = gamma_joao
100 |
101 | b = aug_prob + beta * (loss_aug - gamma * (aug_prob - 1 / 25))
102 | mu_min, mu_max = b.min() - 1 / 25, b.max() - 1 / 25
103 | mu = (mu_min + mu_max) / 2
104 |
105 | # bisection method
106 | while abs(np.maximum(b - mu, 0).sum() - 1) > 1e-2:
107 | if np.maximum(b - mu, 0).sum() > 1:
108 | mu_min = mu
109 | else:
110 | mu_max = mu
111 | mu = (mu_min + mu_max) / 2
112 |
113 | aug_prob = np.maximum(b - mu, 0)
114 | aug_prob /= aug_prob.sum()
115 |
116 | return train_loss_accum / (step + 1), aug_prob
117 |
118 |
119 | if __name__ == "__main__":
120 | # Training settings
121 | parser = argparse.ArgumentParser(description='JOAO')
122 | parser.add_argument('--device', type=int, default=0, help='gpu')
123 | parser.add_argument('--batch_size', type=int, default=256, help='batch')
124 | parser.add_argument('--decay', type=float, default=0, help='weight decay')
125 | parser.add_argument('--epochs', type=int, default=100, help='train epochs')
126 | parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
127 | parser.add_argument('--JK', type=str, default="last",
128 | choices=['last', 'sum', 'max', 'concat'],
129 | help='how the node features across layers are combined.')
130 | parser.add_argument('--gnn_type', type=str, default="gin", help='gnn model type')
131 | parser.add_argument('--dropout_ratio', type=float, default=0, help='dropout ratio')
132 | parser.add_argument('--emb_dim', type=int, default=300, help='embedding dimensions')
133 | parser.add_argument('--dataset', type=str, default=None, help='root dir of dataset')
134 | parser.add_argument('--num_layer', type=int, default=5, help='message passing layers')
135 | parser.add_argument('--output_model_file', type=str, default='', help='model save path')
136 | parser.add_argument('--num_workers', type=int, default=8, help='workers for dataset loading')
137 |
138 | parser.add_argument('--aug_mode', type=str, default='sample')
139 | parser.add_argument('--aug_strength', type=float, default=0.2)
140 |
141 | parser.add_argument('--gamma', type=float, default=0.1)
142 | parser.add_argument('--output_model_dir', type=str, default='')
143 | args = parser.parse_args()
144 |
145 | torch.manual_seed(0)
146 | np.random.seed(0)
147 | device = torch.device("cuda:" + str(args.device)) \
148 | if torch.cuda.is_available() else torch.device("cpu")
149 | if torch.cuda.is_available():
150 | torch.cuda.manual_seed_all(0)
151 |
152 | if 'GEOM' in args.dataset:
153 | dataset = MoleculeDataset_graphcl_complete('../datasets/{}/'.format(args.dataset), dataset=args.dataset)
154 | dataset.set_augMode(args.aug_mode)
155 | dataset.set_augStrength(args.aug_strength)
156 | print(dataset)
157 |
158 | loader = DataLoader(dataset, batch_size=args.batch_size,
159 | num_workers=args.num_workers, shuffle=True)
160 |
161 | # set up model
162 | gnn = GNNComplete(num_layer=args.num_layer, emb_dim=args.emb_dim, JK=args.JK,
163 | drop_ratio=args.dropout_ratio, gnn_type=args.gnn_type)
164 |
165 | model = graphcl(gnn)
166 | model.to(device)
167 |
168 | # set up optimizer
169 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.decay)
170 | # print(optimizer)
171 |
173 | aug_prob = np.ones(25) / 25
174 | np.set_printoptions(precision=3, floatmode='fixed')
175 |
176 | for epoch in range(1, args.epochs + 1):
177 | print('\n\n')
178 | start_time = time.time()
179 | dataset.set_augProb(aug_prob)
180 | pretrain_loss, aug_prob = train(loader, model, optimizer, device, args.gamma)
181 |
182 | print('Epoch: {:3d}\tLoss:{:.3f}\tTime: {:.3f}\tAugmentation Probability:'.format(
183 | epoch, pretrain_loss, time.time() - start_time))
184 | print(aug_prob)
185 |
186 |     if args.output_model_dir != '':
187 | saver_dict = {'model': model.state_dict()}
188 | torch.save(saver_dict, args.output_model_dir + '_model_complete.pth')
189 | torch.save(model.gnn.state_dict(), args.output_model_dir + '_model.pth')
190 |
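The view-selection step in `train` above is a projected gradient update: the candidate vector `b` is projected back onto the probability simplex by bisecting for the threshold `mu`. That projection can be isolated as a standalone NumPy function, exercised here with made-up per-augmentation losses:

```python
import numpy as np

def project_simplex(b, tol=1e-2):
    """Mirror of the JOAO view-selection step: find mu by bisection so that
    max(b - mu, 0) sums to 1, i.e. project b onto the probability simplex."""
    mu_min, mu_max = b.min() - 1 / len(b), b.max() - 1 / len(b)
    mu = (mu_min + mu_max) / 2
    while abs(np.maximum(b - mu, 0).sum() - 1) > tol:
        if np.maximum(b - mu, 0).sum() > 1:
            mu_min = mu
        else:
            mu_max = mu
        mu = (mu_min + mu_max) / 2
    p = np.maximum(b - mu, 0)
    return p / p.sum()

aug_prob = np.ones(25) / 25                # uniform prior over the 5x5 aug pairs
loss_aug = np.linspace(0.0, 1.0, 25)       # stand-in estimated losses
b = aug_prob + 1.0 * (loss_aug - 0.1 * (aug_prob - 1 / 25))  # beta=1, gamma=0.1
p = project_simplex(b)
print(abs(p.sum() - 1) < 1e-9)             # True: a valid distribution
print(p.argmax() == loss_aug.argmax())     # True: hardest augmentation gets most mass
```

Augmentations that currently incur a higher contrastive loss receive more probability mass, which is the min-max intuition behind JOAO's joint augmentation optimization.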
--------------------------------------------------------------------------------
/src_regression/run_dti_finetune.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | source $HOME/.bashrc
4 | conda activate GraphMVP
5 |
6 | echo $@
7 | date
8 |
9 | echo "start"
10 | python dti_finetune.py "$@"
11 | echo "end"
12 | date
13 |
--------------------------------------------------------------------------------
/src_regression/run_molecule_finetune_regression.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | source $HOME/.bashrc
4 | conda activate GraphMVP
5 |
6 | echo $@
7 | date
8 |
9 | echo "start"
10 | python molecule_finetune_regression.py "$@"
11 | echo "end"
12 | date
13 |
--------------------------------------------------------------------------------
/src_regression/run_pretrain_AM.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | source $HOME/.bashrc
4 | conda activate GraphMVP
5 |
6 | echo $@
7 | date
8 | echo "start"
9 | python pretrain_AM.py "$@"
10 | echo "end"
11 | date
12 |
--------------------------------------------------------------------------------
/src_regression/run_pretrain_CP.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | source $HOME/.bashrc
4 | conda activate GraphMVP
5 |
6 | echo $@
7 | date
8 | echo "start"
9 | python pretrain_CP.py "$@"
10 | echo "end"
11 | date
12 |
--------------------------------------------------------------------------------
/src_regression/run_pretrain_GraphCL.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | source $HOME/.bashrc
4 | conda activate GraphMVP
5 |
6 | echo $@
7 | date
8 | echo "start"
9 | python pretrain_GraphCL.py "$@"
10 | echo "end"
11 | date
12 |
--------------------------------------------------------------------------------
/src_regression/run_pretrain_GraphMVP.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | source $HOME/.bashrc
4 | conda activate GraphMVP
5 |
6 | echo $@
7 | date
8 | echo "start"
9 | python pretrain_GraphMVP.py "$@"
10 | echo "end"
11 | date
12 |
--------------------------------------------------------------------------------
/src_regression/run_pretrain_GraphMVP_hybrid.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | source $HOME/.bashrc
4 | conda activate GraphMVP
5 |
6 | echo $@
7 | date
8 | echo "start"
9 | python pretrain_GraphMVP_hybrid.py "$@"
10 | echo "end"
11 | date
12 |
--------------------------------------------------------------------------------
/src_regression/run_pretrain_JOAO.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | source $HOME/.bashrc
4 | conda activate GraphMVP
5 |
6 | echo $@
7 | date
8 |
9 | echo "start"
10 | python pretrain_JOAO.py "$@"
11 | echo "end"
12 | date
13 |
--------------------------------------------------------------------------------
/src_regression/run_pretrain_JOAOv2.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | source $HOME/.bashrc
4 | conda activate GraphMVP
5 |
6 | echo $@
7 | date
8 |
9 | echo "start"
10 | python pretrain_JOAOv2.py "$@"
11 | echo "end"
12 | date
13 |
--------------------------------------------------------------------------------
/src_regression/util_complete_feature.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import networkx as nx
4 | import torch
5 | import torch.nn.functional as F
6 | from datasets_complete_feature import (graph_data_obj_to_nx_simple,
7 | nx_to_graph_data_obj_simple)
8 |
9 |
10 | class ExtractSubstructureContextPair:
11 |
12 | def __init__(self, k, l1, l2):
13 |         """
14 |         Randomly selects a node from the data object, and adds attributes
15 |         that contain the substructure corresponding to the k-hop neighbourhood
16 |         rooted at the node, and the context substructure corresponding to
17 |         the subgraph between l1 and l2 hops away from the root node. """
18 | self.k = k
19 | self.l1 = l1
20 | self.l2 = l2
21 |
22 | # for the special case of 0, addresses the quirk with
23 | # single_source_shortest_path_length
24 | if self.k == 0:
25 | self.k = -1
26 | if self.l1 == 0:
27 | self.l1 = -1
28 | if self.l2 == 0:
29 | self.l2 = -1
30 |
31 | def __call__(self, data, root_idx=None):
32 | """
33 | :param data: pytorch geometric data object
34 | :param root_idx: If None, then randomly samples an atom idx.
35 | Otherwise sets atom idx of root (for debugging only)
36 |         :return: the input data object, with new attributes created:
37 | data.center_substruct_idx
38 | data.x_substruct
39 | data.edge_attr_substruct
40 | data.edge_index_substruct
41 | data.x_context
42 | data.edge_attr_context
43 | data.edge_index_context
44 | data.overlap_context_substruct_idx """
45 | num_atoms = data.x.size()[0]
46 | if root_idx is None:
47 | root_idx = random.sample(range(num_atoms), 1)[0]
48 |
49 | G = graph_data_obj_to_nx_simple(data) # same ordering as input data obj
50 |
51 | # Get k-hop subgraph rooted at specified atom idx
52 | substruct_node_idxes = nx.single_source_shortest_path_length(G, root_idx, self.k).keys()
53 | if len(substruct_node_idxes) > 0:
54 | substruct_G = G.subgraph(substruct_node_idxes)
55 | substruct_G, substruct_node_map = reset_idxes(substruct_G) # need
56 | # to reset node idx to 0 -> num_nodes - 1, otherwise data obj does not
57 | # make sense, since the node indices in data obj must start at 0
58 | substruct_data = nx_to_graph_data_obj_simple(substruct_G)
59 | data.x_substruct = substruct_data.x
60 | data.edge_attr_substruct = substruct_data.edge_attr
61 | data.edge_index_substruct = substruct_data.edge_index
62 | data.center_substruct_idx = torch.tensor([substruct_node_map[root_idx]]) # need
63 | # to convert center idx from original graph node ordering to the
64 | # new substruct node ordering
65 |
66 |         # Get the subgraph that is between l1 and l2 hops away from the root node
67 | l1_node_idxes = nx.single_source_shortest_path_length(G, root_idx, self.l1).keys()
68 | l2_node_idxes = nx.single_source_shortest_path_length(G, root_idx, self.l2).keys()
69 | context_node_idxes = set(l1_node_idxes).symmetric_difference(set(l2_node_idxes))
70 | if len(context_node_idxes) > 0:
71 | context_G = G.subgraph(context_node_idxes)
72 | context_G, context_node_map = reset_idxes(context_G) # need to
73 | # reset node idx to 0 -> num_nodes - 1, otherwise data obj does not
74 | # make sense, since the node indices in data obj must start at 0
75 | context_data = nx_to_graph_data_obj_simple(context_G)
76 | data.x_context = context_data.x
77 | data.edge_attr_context = context_data.edge_attr
78 | data.edge_index_context = context_data.edge_index
79 |
80 | # Get indices of overlapping nodes between substruct and context,
81 | # WRT context ordering
82 | context_substruct_overlap_idxes = list(set(context_node_idxes).intersection(
83 | set(substruct_node_idxes)))
84 | if len(context_substruct_overlap_idxes) > 0:
85 | context_substruct_overlap_idxes_reorder = [
86 | context_node_map[old_idx]
87 | for old_idx in context_substruct_overlap_idxes]
88 | # need to convert the overlap node idxes, which is from the
89 | # original graph node ordering to the new context node ordering
90 | data.overlap_context_substruct_idx = \
91 | torch.tensor(context_substruct_overlap_idxes_reorder)
92 |
93 | return data
94 |
95 | def __repr__(self):
96 |         return '{}(k={}, l1={}, l2={})'.format(
97 |             self.__class__.__name__, self.k, self.l1, self.l2)
98 |
99 |
100 | def reset_idxes(G):
101 | """ Resets node indices such that they are numbered from 0 to num_nodes - 1
102 | :return: copy of G with relabelled node indices, mapping """
103 | mapping = {}
104 | for new_idx, old_idx in enumerate(G.nodes()):
105 | mapping[old_idx] = new_idx
106 | new_G = nx.relabel_nodes(G, mapping, copy=True)
107 | return new_G, mapping
108 |
109 |
110 | class MaskAtom:
111 | def __init__(self, num_atom_type, num_edge_type, mask_rate, mask_edge=True):
112 | """
113 | Randomly masks an atom, and optionally masks edges connecting to it.
114 | The mask atom type index is num_possible_atom_type
115 |         The mask edge type index is num_possible_edge_type
116 | :param num_atom_type:
117 | :param num_edge_type:
118 | :param mask_rate: % of atoms to be masked
119 | :param mask_edge: If True, also mask the edges that connect to the
120 | masked atoms """
121 | self.num_atom_type = num_atom_type
122 | self.num_edge_type = num_edge_type
123 | self.mask_rate = mask_rate
124 | self.mask_edge = mask_edge
125 |
126 | def __call__(self, data, masked_atom_indices=None):
127 | """
128 | :param data: pytorch geometric data object. Assume that the edge
129 | ordering is the default pytorch geometric ordering, where the two
130 | directions of a single edge occur in pairs.
131 | Eg. data.edge_index = tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]])
132 | :param masked_atom_indices: If None, then randomly samples num_atoms
133 | * mask rate number of atom indices
134 | Otherwise a list of atom idx that sets the atoms to be masked (for
135 | debugging only)
136 |         :return: the input data object, with new attributes created:
137 |             data.mask_node_label
138 |             data.masked_atom_indices
139 |             data.masked_x
140 |             (edge masking is not implemented in this variant) """
141 |
142 | if masked_atom_indices is None:
143 | # sample x distinct atoms to be masked, based on mask rate. But
144 | # will sample at least 1 atom
145 | num_atoms = data.x.size()[0]
146 | sample_size = int(num_atoms * self.mask_rate + 1)
147 | masked_atom_indices = random.sample(range(num_atoms), sample_size)
148 |
149 | # create mask node label by copying atom feature of mask atom
150 | mask_node_labels_list = []
151 | for atom_idx in masked_atom_indices:
152 | mask_node_labels_list.append(data.x[atom_idx].view(1, -1))
153 | data.mask_node_label = torch.cat(mask_node_labels_list, dim=0)
154 | data.masked_atom_indices = torch.tensor(masked_atom_indices)
155 |
156 | # modify the original node feature of the masked node
157 | data.masked_x = data.x.clone()
158 | for atom_idx in masked_atom_indices:
159 | data.masked_x[atom_idx] = torch.tensor([self.num_atom_type-1, 0, 0, 0, 0, 0, 0, 0, 0])
160 |
161 | return data
162 |
163 | def __repr__(self):
164 | return '{}(num_atom_type={}, num_edge_type={}, mask_rate={}, mask_edge={})'.format(
165 | self.__class__.__name__, self.num_atom_type, self.num_edge_type,
166 | self.mask_rate, self.mask_edge)
167 |
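`ExtractSubstructureContextPair` leans on `nx.single_source_shortest_path_length` with a cutoff to carve out the k-hop substructure and the ring of nodes between `l1` and `l2` hops. The same node sets can be derived with a plain BFS; the adjacency dict below is a hypothetical toy path graph standing in for a molecule graph:

```python
from collections import deque

def hop_distances(adj, root, cutoff):
    """BFS distances from root, up to cutoff hops (mirrors
    nx.single_source_shortest_path_length(G, root, cutoff))."""
    dist = {root: 0}
    queue = deque([root])
    while queue:
        u = queue.popleft()
        if dist[u] == cutoff:
            continue                    # do not expand past the cutoff
        for v in adj[u]:
            if v not in dist:
                dist[v] = dist[u] + 1
                queue.append(v)
    return dist

# toy path graph 0-1-2-3-4-5
adj = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3, 5], 5: [4]}
k, l1, l2 = 1, 1, 3
substruct = set(hop_distances(adj, 0, k))            # nodes with dist <= k
within_l1 = set(hop_distances(adj, 0, l1))
within_l2 = set(hop_distances(adj, 0, l2))
context = within_l1.symmetric_difference(within_l2)  # nodes with l1 < dist <= l2
print(sorted(substruct))  # [0, 1]
print(sorted(context))    # [2, 3]
```

The symmetric difference of the two cutoff balls leaves exactly the nodes with `l1 < dist <= l2`, matching the context definition in `__call__`; when `l1 < k`, the substructure and context rings overlap, which is what `overlap_context_substruct_idx` records.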
--------------------------------------------------------------------------------