├── .gitignore
├── LICENSE.md
├── README.md
├── fragnet
│   ├── assets
│   │   ├── README.md
│   │   ├── app.png
│   │   ├── fragnet.png
│   │   └── weights_main.png
│   ├── data_create
│   │   ├── create_dataset.py
│   │   ├── create_finetune_datasets.py
│   │   └── create_pretrain_datasets.py
│   ├── dataset
│   │   ├── cdrp.py
│   │   ├── custom_dataset.py
│   │   ├── data.py
│   │   ├── dataset.py
│   │   ├── dta.py
│   │   ├── ext_data_utils
│   │   │   ├── Step1_getData.py
│   │   │   ├── __init__.py
│   │   │   └── deepttc.py
│   │   ├── feature_utils.py
│   │   ├── features.py
│   │   ├── features0.py
│   │   ├── features2.py
│   │   ├── features_check.py
│   │   ├── features_exp2.py
│   │   ├── features_exp_safe.py
│   │   ├── fix_pt_data.py
│   │   ├── fragments.py
│   │   ├── general.py
│   │   ├── loader_molebert.py
│   │   ├── moleculenet.py
│   │   ├── scaffold_split_from_df.py
│   │   ├── simsgt.py
│   │   ├── splitters.py
│   │   ├── splitters_molebert.py
│   │   └── utils.py
│   ├── exps
│   │   ├── README.md
│   │   ├── ft
│   │   │   ├── esol
│   │   │   │   └── e1pt4.yaml
│   │   │   ├── lipo
│   │   │   │   ├── download.sh
│   │   │   │   └── fragnet_hpdl_exp1s_pt4_30
│   │   │   │       ├── config_exp100.yaml
│   │   │   │       └── ft_100.pt
│   │   │   └── pnnl_full
│   │   │       └── fragnet_hpdl_exp1s_h4pt4_10
│   │   │           ├── config_exp100.yaml
│   │   │           ├── ft_100.pt
│   │   │           └── ft_100.pt.data
│   │   └── pt
│   │       └── unimol_exp1s4
│   │           ├── config.yaml
│   │           ├── pt.pt
│   │           └── pt.pt.data
│   ├── hp
│   │   ├── hp.py
│   │   ├── hp2.py
│   │   ├── hp_cdrp.py
│   │   ├── hp_clf.py
│   │   ├── hp_dta.py
│   │   ├── hpft.py
│   │   ├── hpoptuna.py
│   │   └── hpray.py
│   ├── model
│   │   ├── cdrp
│   │   │   ├── __init__.py
│   │   │   └── model.py
│   │   ├── dta
│   │   │   ├── __init__.py
│   │   │   ├── drug_encoder.py
│   │   │   └── model.py
│   │   ├── gat
│   │   │   ├── __init__.py
│   │   │   ├── extra_optimizers.py
│   │   │   ├── gat.py
│   │   │   ├── gat2.py
│   │   │   ├── gat2_cv.py
│   │   │   ├── gat2_edge.py
│   │   │   ├── gat2_lite.py
│   │   │   ├── gat2_pl.py
│   │   │   ├── gat2_pretrain.py
│   │   │   └── pretrain_heads.py
│   │   └── gcn
│   │       ├── gcn.py
│   │       ├── gcn2.py
│   │       ├── gcn3.py
│   │       └── gcn_pl.py
│   ├── train
│   │   ├── finetune
│   │   │   ├── __init__.py
│   │   │   ├── finetune_cdrp.py
│   │   │   ├── finetune_dta.py
│   │   │   ├── finetune_gat.py
│   │   │   ├── finetune_gat2.py
│   │   │   ├── finetune_gat2_pl.py
│   │   │   ├── finetune_norm.py
│   │   │   ├── gat2_cv_frag.py
│   │   │   ├── trainer_cdrp.py
│   │   │   └── trainer_dta.py
│   │   ├── pretrain
│   │   │   ├── __init__.py
│   │   │   ├── pretrain_data_pnnl.py
│   │   │   ├── pretrain_gat2.py
│   │   │   ├── pretrain_gat_mol.py
│   │   │   ├── pretrain_gat_str.py
│   │   │   ├── pretrain_gcn.py
│   │   │   ├── pretrain_heads.py
│   │   │   └── pretrain_utils.py
│   │   ├── utils.py
│   │   └── utils_pl.py
│   └── vizualize
│       ├── app.py
│       ├── config.py
│       ├── model.py
│       ├── model_attr.py
│       ├── property.py
│       └── viz.py
├── install_cpu.sh
├── install_gpu.sh
├── requirements.txt
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .ipynb_checkpoints
2 | */.ipynb_checkpoints/*
3 | *__pycache__
4 | __pycache__/
5 | bin/
6 | build/
7 | develop-eggs/
8 | dist/
9 | eggs/
10 | lib/
11 | lib64/
12 | parts/
13 | sdist/
14 | var/
15 | *.egg-info/
16 | .installed.cfg
17 | *.egg
18 | .tox/
19 | .cache
20 | .project
21 | .pydevproject
22 |
23 |
24 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | FragNet
2 |
3 | Copyright © 2024, Battelle Memorial Institute
4 | All rights reserved.
5 |
6 | 1. Battelle Memorial Institute (hereinafter Battelle) hereby grants permission to any person or entity lawfully obtaining a copy of this software and associated documentation files (hereinafter “the Software”) to redistribute and use the Software in source and binary forms, with or without modification. Such person or entity may use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and may permit others to do so, subject to the following conditions:
7 |
8 | 0. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimers.
9 |
10 | 1. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
11 |
12 | 2. Other than as used herein, neither the name Battelle Memorial Institute or Battelle may be used in any form whatsoever without the express written consent of Battelle.
13 |
14 | 2. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL BATTELLE OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
15 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FragNet
2 |
3 | FragNet is a Graph Neural Network designed for molecular property prediction that can offer insights into how different substructures influence the predictions. More details about FragNet can be found in our paper,
4 | [FragNet: A Graph Neural Network for Molecular Property Prediction with Four Layers of Interpretability](https://arxiv.org/abs/2410.12156).
5 |
6 |
7 |
8 |
9 |
10 | Figure 1: FragNet’s architecture and data representation. (a) Atom and Fragment graphs’
11 | edge features are learned from Bond and Fragment connection graphs, respectively. (b) Initial
12 | fragment features for the fragment graph are the summation of the updated atom features
13 | that compose the fragment. (c) Illustration of FragNet’s message passing taking place
14 | between two non-covalently bonded substructures. Fragment-Fragment connections are also
15 | present between adjacent fragments in each non-covalently bonded structure of the
16 | compound.
17 |
18 |
19 |
20 | Figure 2: Different types of attention weights and contribution values available in FragNet, visualized for CC[NH+](CCCl)CCOc1cccc2ccccc12.[Cl-], with atom, bond, and fragment
21 | attention weights shown in (a), (b), and (c) and fragment contribution values shown in (d).
22 | The top table provides the atom-to-fragment mapping and the bottom table provides the
23 | fragment connection attention weights. Atom and bond attention weights are scaled to
24 | values between 0 and 1. The fragment and fragment connection weights are not scaled. The
25 | numbers in blue boxes in (d) correspond to Fragment IDs in the ‘Atoms in Fragments’ table.
26 | # Usage
27 |
28 | ### Installation
29 |
30 | The installation has been tested with Python 3.11 and CUDA 12.1.
31 |
32 | #### For CPU
33 |
34 | 1. Create a Python 3.11 virtual environment and install the required packages using the command `pip install -r requirements.txt`
35 | 2. Install torch-scatter using `pip install torch-scatter -f https://data.pyg.org/whl/torch-2.4.0+cpu.html`
36 | 3. Next install FragNet. In the directory where `setup.py` is, run the command `pip install .`
37 |
38 | Alternatively, and more conveniently, you can run `bash install_cpu.sh`, which will install FragNet and create pretraining and finetuning data for the ESOL dataset.
39 |
40 | #### For GPU
41 |
42 | 1. Create a Python 3.11 virtual environment and install the required packages using the command `pip install -r requirements.txt`
43 | 2. Install torch-scatter using `pip install torch-scatter -f https://data.pyg.org/whl/torch-2.4.0+cu121.html`
44 | 3. Next install FragNet. In the directory where `setup.py` is, run the command `pip install .`
45 |
46 | Alternatively, run `bash install_gpu.sh`.
47 |
48 | -------
49 |
50 | ### Creating pretraining data
51 |
52 | FragNet was pretrained using part of the data used by [UniMol](https://github.com/deepmodeling/Uni-Mol/tree/main/unimol).
53 |
54 | Here, we use the ESOL dataset to demonstrate data creation. The following commands should be run from the `FragNet/fragnet` directory.
55 |
56 | First, create a directory to save data.
57 |
58 | ```mkdir -p finetune_data/moleculenet/esol/raw/```
59 |
60 | Next, download the ESOL dataset.
61 |
62 | ```
63 | wget -O finetune_data/moleculenet/esol/raw/delaney-processed.csv https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/delaney-processed.csv
64 | ```
65 |
66 | Next, run the following command to create pretraining data.
67 |
68 | ```
69 | python data_create/create_pretrain_datasets.py --save_path pretrain_data/esol --data_type exp1s --maxiters 500 --raw_data_path finetune_data/moleculenet/esol/raw/delaney-processed.csv
70 | ```
71 |
72 |
73 | - save_path: where the datasets should be saved
74 | - data_type: use exp1s for all the calculations
75 | - maxiters: maximum number of iterations for 3D coordinate generation
76 | - raw_data_path: location of the smiles dataset
77 |
78 | ------
79 |
80 | ### Creating finetuning data
81 |
82 | Finetuning data for the MoleculeNet datasets can be created as follows:
83 |
84 |
85 | `python data_create/create_finetune_datasets.py --dataset_name moleculenet --dataset_subset esol --use_molebert True --output_dir finetune_data/moleculenet_exp1s --data_dir finetune_data/moleculenet --data_type exp1s`
86 |
87 |
88 | - dataset_name: dataset type
89 | - dataset_subset: dataset sub-type
90 | - use_molebert: whether to use the dataset splitting method used by the MoleBert model
91 |
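The command writes pickled train/val/test splits (along with the raw `.pt` split files) under `<output_dir>/<dataset_subset>/`. A minimal loading sketch, with paths inferred from the example command above:

```
import pickle

# Load the finetuning splits created for ESOL. The directory follows from
# --output_dir finetune_data/moleculenet_exp1s and --dataset_subset esol.
def load_split(fold):
    with open(f"finetune_data/moleculenet_exp1s/esol/{fold}.pkl", "rb") as f:
        return pickle.load(f)

train, val, test = (load_split(fold) for fold in ("train", "val", "test"))
print(len(train), len(val), len(test))  # sizes of the scaffold splits
```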
92 | ------
93 |
94 | ### Pretrain
95 |
96 | To pretrain, run the following command. All input parameters must be provided in a config file.
97 |
98 | ```
99 | python train/pretrain/pretrain_gat2.py --config exps/pt/unimol_exp1s4/config.yaml
100 | ```
101 |
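The provided config at `exps/pt/unimol_exp1s4/config.yaml` can be copied and edited for your own runs. A minimal sketch for inspecting its parameters before launching (assuming PyYAML is available in the environment):

```
import yaml

# Print the pretraining parameters defined in the provided config file.
with open("exps/pt/unimol_exp1s4/config.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg)
```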
102 | ------
103 |
104 | ### Finetune
105 | ```
106 | python train/finetune/finetune_gat2.py --config exps/ft/esol/e1pt4.yaml
107 | ```
108 |
109 |
110 | ------
111 |
112 | ## Interactive Web Application
113 |
114 | To run this application, run the command `streamlit run fragnet/vizualize/app.py` from the root directory.
115 |
116 |
117 |
118 | ------
119 |
120 | ## Optional
121 | ### Hyperparameter tuning
122 | ```
123 | python hp/hpoptuna.py --config exps/ft/esol/e1pt4.yaml --n_trials 10 \
124 | --chkpt hpruns/pt.pt --seed 10 --ft_epochs 10 --prune 1
125 | ```
126 |
127 | - config: initial parameters
128 | - n_trials: number of hyperparameter optimization trials
129 | - chkpt: where the checkpoint produced during hyperparameter optimization will be saved. Note that you will have to create the output directory yourself (in this case, hpruns); otherwise, the output directory is assumed to be the current working directory.
130 | - seed: random seed
131 | - ft_epochs: number of training epochs
132 | - prune: for Optuna runs; whether to prune unpromising trials.
133 |
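As noted for `chkpt`, the output directory must exist before the run; a minimal sketch for creating it:

```
import os

# Create the directory that holds the checkpoint passed via --chkpt hpruns/pt.pt.
os.makedirs("hpruns", exist_ok=True)
```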
134 |
135 |
136 | ## Citation
137 | If you use our work, please cite it as,
138 |
139 | ```
140 | @misc{panapitiya2024fragnetgraphneuralnetwork,
141 | title={FragNet: A Graph Neural Network for Molecular Property Prediction with Four Layers of Interpretability},
142 | author={Gihan Panapitiya and Peiyuan Gao and C Mark Maupin and Emily G Saldanha},
143 | year={2024},
144 | eprint={2410.12156},
145 | archivePrefix={arXiv},
146 | primaryClass={cs.LG},
147 | url={https://arxiv.org/abs/2410.12156},
148 | }
149 | ```
150 |
151 |
152 |
153 | Disclaimer
154 |
155 | This material was prepared as an account of work sponsored by an agency of the United States Government. Neither the United States Government nor the United States Department of Energy, nor Battelle, nor any of their employees, nor any jurisdiction or organization that has cooperated in the development of these materials, makes any warranty, express or implied, or assumes any legal liability or responsibility for the accuracy, completeness, or usefulness of any information, apparatus, product, software, or process disclosed, or represents that its use would not infringe privately owned rights.
156 | Reference herein to any specific commercial product, process, or service by trade name, trademark, manufacturer, or otherwise does not necessarily constitute or imply its endorsement, recommendation, or favoring by the United States Government or any agency thereof, or Battelle Memorial Institute. The views and opinions of authors expressed herein do not necessarily state or reflect those of the United States Government or any agency thereof.
157 | PACIFIC NORTHWEST NATIONAL LABORATORY
158 | operated by
159 | BATTELLE
160 | for the
161 | UNITED STATES DEPARTMENT OF ENERGY
162 | under Contract DE-AC05-76RL01830
163 |
164 |
165 |
166 |
--------------------------------------------------------------------------------
/fragnet/assets/README.md:
--------------------------------------------------------------------------------
1 | download cell_line_data.csv
2 |
--------------------------------------------------------------------------------
/fragnet/assets/app.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pnnl/FragNet/54f99099a08ab991a5ba6845f7338492d03d4f1c/fragnet/assets/app.png
--------------------------------------------------------------------------------
/fragnet/assets/fragnet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pnnl/FragNet/54f99099a08ab991a5ba6845f7338492d03d4f1c/fragnet/assets/fragnet.png
--------------------------------------------------------------------------------
/fragnet/assets/weights_main.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pnnl/FragNet/54f99099a08ab991a5ba6845f7338492d03d4f1c/fragnet/assets/weights_main.png
--------------------------------------------------------------------------------
/fragnet/data_create/create_dataset.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from dataset import save_dataset_parts, save_dataset
3 | import argparse
4 |
5 |
6 | if __name__ == "__main__":
7 |
8 | parser = argparse.ArgumentParser()
9 | parser.add_argument(
10 | "--data_path",
11 | help="dataframe or csv file",
12 | type=str,
13 | required=False,
14 | default=None,
15 | )
16 | parser.add_argument(
17 | "--train_path",
18 | help="path to train input data",
19 | type=str,
20 | required=False,
21 | default=None,
22 | )
23 | parser.add_argument(
24 | "--val_path",
25 | help="path to val input data",
26 | type=str,
27 | required=False,
28 | default=None,
29 | )
30 | parser.add_argument(
31 | "--test_path",
32 | help="path to test input data",
33 | type=str,
34 | required=False,
35 | default=None,
36 | )
37 | parser.add_argument("--start_row", help="", type=int, required=False, default=None)
38 | parser.add_argument("--end_row", help="", type=int, required=False, default=None)
39 | parser.add_argument("--start_id", help="", type=int, required=False, default=0)
40 | parser.add_argument(
41 | "--rows_per_part", help="", type=int, required=False, default=1000
42 | )
43 | parser.add_argument(
44 | "--create_bond_graph_data", help="", type=bool, required=False, default=True
45 | )
46 | parser.add_argument(
47 | "--add_dhangles", help="", type=bool, required=False, default=False
48 | )
49 | parser.add_argument(
50 | "--feature_type",
51 | help="one_hot or embed",
52 | type=str,
53 | required=False,
54 | default="one_hot",
55 | )
56 | parser.add_argument(
57 | "--target",
58 | help="target property",
59 | type=str,
60 | required=False,
61 | nargs="+",
62 | default="log_sol",
63 | )
64 | parser.add_argument(
65 | "--save_path",
66 | help="folder where the data is saved",
67 | type=str,
68 | required=False,
69 | default=None,
70 | )
71 | parser.add_argument(
72 | "--save_name", help="saving file name", type=str, required=False, default=None
73 | )
74 | args = parser.parse_args()
75 |
76 | dataset_args = {
77 | "create_bond_graph_data": args.create_bond_graph_data,
78 | "add_dhangles": args.add_dhangles,
79 | "feature_type": args.feature_type,
80 | "target": args.target,
81 | "save_path": args.save_path,
82 | "save_name": args.save_name,
83 | "start_id": args.start_id,
84 | }
85 |
86 | if args.train_path:
87 | train = pd.read_csv(args.train_path)
88 | train.to_csv(f"{args.save_path}/train.csv", index=False)
89 | save_dataset(
90 | df=train,
91 | save_path=f"{args.save_path}",
92 | save_name=f"train",
93 | target=args.target,
94 | feature_type=args.feature_type,
95 | create_bond_graph_data=args.create_bond_graph_data,
96 | )
97 |
98 | elif args.val_path:
99 | val = pd.read_csv(args.val_path)
100 | val.to_csv(f"{args.save_path}/val.csv", index=False)
101 | save_dataset(
102 | df=val,
103 | save_path=f"{args.save_path}",
104 | save_name=f"val",
105 | target=args.target,
106 | feature_type=args.feature_type,
107 | create_bond_graph_data=args.create_bond_graph_data,
108 | )
109 |
110 | elif args.test_path:
111 | test = pd.read_csv(args.test_path)
112 | test.to_csv(f"{args.save_path}/test.csv", index=False)
113 | save_dataset(
114 | df=test,
115 | save_path=f"{args.save_path}",
116 | save_name=f"test",
117 | target=args.target,
118 | feature_type=args.feature_type,
119 | create_bond_graph_data=args.create_bond_graph_data,
120 | )
121 |
122 | else:
123 | save_dataset_parts(
124 | args.data_path,
125 | start_row=args.start_row,
126 | end_row=args.end_row,
127 | rows_per_part=args.rows_per_part,
128 | dataset_args=dataset_args,
129 | )
130 |
--------------------------------------------------------------------------------
/fragnet/data_create/create_finetune_datasets.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from fragnet.dataset.moleculenet import create_moleculenet_dataset
4 | from fragnet.dataset.general import create_general_dataset
5 |
6 | # from fragnet.dataset.unimol import create_moleculenet_dataset_from_unimol_data
7 | from fragnet.dataset.simsgt import create_moleculenet_dataset_simsgt
8 | from fragnet.dataset.dta import create_dta_dataset
9 | from fragnet.dataset.cdrp import create_cdrp_dataset
10 | from fragnet.dataset.scaffold_split_from_df import create_scaffold_split_data_from_df
11 |
12 | if __name__ == "__main__":
13 |
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument(
16 | "--dataset_name",
17 | help="moleculenet, moleculenet-custom",
18 | type=str,
19 | required=False,
20 | default="moleculenet",
21 | )
22 | parser.add_argument(
23 | "--dataset_subset",
24 | help="esol, freesolv",
25 | type=str,
26 | required=False,
27 | default=None,
28 | )
29 | parser.add_argument(
30 | "--output_dir", help="", type=str, required=False, default="finetune_data"
31 | )
32 | parser.add_argument(
33 | "--data_dir", help="", type=str, required=False, default="finetune_data"
34 | )
35 | parser.add_argument(
36 | "--frag_type", help="", type=str, required=False, default="brics"
37 | )
38 | parser.add_argument(
39 | "--use_molebert", help="", type=bool, required=False, default=False
40 | )
41 | parser.add_argument("--train_path", help="", type=str, required=False, default=None)
42 | parser.add_argument("--save_parts", help="", type=int, required=False, default=0)
43 | parser.add_argument("--val_path", help="", type=str, required=False, default=None)
44 | parser.add_argument("--test_path", help="", type=str, required=False, default=None)
45 | parser.add_argument("--target_name", help="", type=str, required=False, default="y")
46 | parser.add_argument("--data_type", help="", type=str, required=False, default="exp")
47 | parser.add_argument(
48 | "--multi_conf_data",
49 | help="create multiple conformers for a smiles",
50 | type=int,
51 | required=False,
52 | default=0,
53 | )
54 | parser.add_argument(
55 | "--use_genes",
56 | help="whether a gene subset is used for cdrp data",
57 | type=int,
58 | required=False,
59 | default=0,
60 | )
61 |
62 | args = parser.parse_args()
63 |
64 | print("args.use_genes: ", args.use_genes)
65 | print("args.multi_conf_data: ", args.multi_conf_data)
66 |
67 | if "gen" in args.dataset_name:
68 |
69 | create_general_dataset(args)
70 | elif args.dataset_name == "moleculenet":
71 |
72 | create_moleculenet_dataset("MoleculeNet", args.dataset_subset.lower(), args)
73 | elif args.dataset_name == "moleculedataset":
74 | create_moleculenet_dataset("MoleculeDataset", args.dataset_subset.lower(), args)
75 |
76 | elif args.dataset_name == "unimol":
77 | create_moleculenet_dataset_from_unimol_data(args)
78 |
79 | elif args.dataset_name == "simsgt":
80 | create_moleculenet_dataset_simsgt(args.dataset_subset.lower(), args)
81 |
82 | elif args.dataset_name in ["davis", "kiba"]:
83 | create_dta_dataset(args)
84 |
85 | elif args.dataset_name in ["cep", "malaria"]:
86 | create_scaffold_split_data_from_df(args)
87 |
88 | elif args.dataset_name in ["gdsc", "gdsc_full", "ccle"]:
89 | print("args.use_genes0: ", args.use_genes)
90 | create_cdrp_dataset(args)
91 |
--------------------------------------------------------------------------------
/fragnet/data_create/create_pretrain_datasets.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from fragnet.dataset.dataset import get_pt_dataset
3 | from fragnet.dataset.utils import extract_data, save_datasets
4 | import argparse
5 | import os
6 | import logging
7 | import logging.config
8 |
9 |
10 | def continuous_creation(args):
11 |
12 | l_limit = args.low
13 | h_limit = args.high
14 | n_rows = args.n_rows
15 |
16 | for i in range(l_limit, h_limit, n_rows):
17 | start_id = i
18 | end_id = i + n_rows - 1
19 | dfi = df.loc[start_id:end_id]
20 |
21 | ds = get_pt_dataset(dfi, args.data_type, maxiters=args.maxiters)
22 | ds = extract_data(ds)
23 | save_path = f"{args.save_path}/pt_{start_id}_{end_id}"
24 | save_datasets(ds, save_path)
25 |
26 | print(dfi.shape)
27 |
28 |
29 | def create_from_ids(args):
30 |
31 | curr = pd.read_pickle("pretrain_data/unimol_exp1s/curr_tmp.pkl")
32 | full = pd.read_pickle("../fragnet1.2/pretrain_data/unimol/file_list.pkl")
33 | new = set(full).difference(curr)
34 | new = list(new)
35 | new = new[args.low : args.high]
36 | start_ids = [int(i.split("_")[1]) for i in new]
37 |
38 | for start_id in start_ids:
39 | end_id = start_id + n_rows - 1
40 | dfi = df.loc[start_id:end_id]
41 |
42 | ds = get_pt_dataset(dfi, args.data_type)
43 | ds = extract_data(ds)
44 | save_path = f"{args.save_path}/pt_{start_id}_{end_id}"
45 | logger.info(save_path)
46 | save_datasets(ds, save_path)
47 |
48 | print(dfi.shape)
49 |
50 |
51 | def continuous_creation_add_new(args):
52 |
53 | l_limit = args.low
54 | h_limit = args.high
55 | n_rows = args.n_rows
56 |
57 | for i in range(l_limit, h_limit, n_rows):
58 | start_id = i
59 | end_id = i + n_rows - 1
60 | dfi = df.loc[start_id:end_id]
61 |
62 | save_path = f"{args.save_path}/pt_{start_id}_{end_id}"
63 |
64 | curr_data = pd.read_pickle(save_path + ".pkl")
65 | curr_smiles = [d.smiles for d in curr_data]
66 |
67 | dfi_rem = dfi[~dfi.smiles.isin(curr_smiles)]
68 |
69 | ds = get_pt_dataset(
70 | dfi_rem, args.data_type, maxiters=args.maxiters, frag_type=args.frag_type
71 | )
72 | ds = extract_data(ds)
73 |
74 | ds_updated = ds + curr_data
75 | save_datasets(ds_updated, save_path)
76 | print(dfi.shape)
77 |
78 |
79 | if __name__ == "__main__":
80 |
81 | parser = argparse.ArgumentParser()
82 | parser.add_argument(
83 | "--save_path",
84 | help="",
85 | type=str,
86 | required=False,
87 | default="pretrain_data/unimol/ds",
88 | )
89 | parser.add_argument("--low", help="", type=int, required=False, default=None)
90 | parser.add_argument("--high", help="", type=int, required=False, default=None)
91 | parser.add_argument("--data_type", help="", type=str, required=False, default="exp")
92 | parser.add_argument(
93 | "--raw_data_path", help="", type=str, required=False, default=None
94 | )
95 | parser.add_argument(
96 | "--calc_type", help="", type=str, required=False, default="scratch"
97 | )
98 | parser.add_argument("--maxiters", help="", type=int, required=False, default=200)
99 | parser.add_argument(
100 | "--frag_type", help="", type=str, required=False, default="brics"
101 | )
102 |
103 | args = parser.parse_args()
104 |
105 | if not os.path.exists(args.save_path):
106 | os.makedirs(args.save_path, exist_ok=True)
107 |
108 | if args.raw_data_path == None:
109 | df = pd.read_csv("pretrain_data/input/train_no_modulus.csv")
110 | else:
111 | df = pd.read_csv(args.raw_data_path)
112 |
113 | n_rows = 1000
114 | if not args.low:
115 | args.low = 0
116 | if not args.high:
117 | args.high = len(df)
118 | args.n_rows = n_rows
119 |
120 | if args.calc_type == "scratch":
121 | continuous_creation(args)
122 | elif args.calc_type == "add":
123 | continuous_creation_add_new(args)
124 |
--------------------------------------------------------------------------------
/fragnet/dataset/cdrp.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from .dataset import FinetuneDataCDRP
4 | from .ext_data_utils.deepttc import DataEncoding
5 | from sklearn.model_selection import train_test_split
6 | from .utils import extract_data, save_datasets, save_ds_parts
7 |
8 |
9 | def create_cdrp_dataset(args):
10 |
11 | if not os.path.exists(args.output_dir):
12 | os.makedirs(args.output_dir, exist_ok=True)
13 |
14 | print("args.use_genes1: ", args.use_genes)
15 |
16 | obj = DataEncoding(args.data_dir)
17 |
18 | print("args.use_genes: ", args.use_genes)
19 |
20 | dataset = FinetuneDataCDRP(
21 | target_name=args.target_name,
22 | data_type=args.data_type,
23 | args=args,
24 | use_genes=args.use_genes,
25 | )
26 |
27 | traindata, testdata = obj.Getdata.ByCancer(random_seed=1, test_size=0.05)
28 | traindata, valdata = train_test_split(traindata, test_size=0.1)
29 |
30 | traindata, valdata, testdata = obj.encode2(
31 | traindata=traindata, valdata=valdata, testdata=testdata
32 | )
33 |
34 | traindata.to_csv(f"{args.output_dir}/train.csv", index=False)
35 | valdata.to_csv(f"{args.output_dir}/val.csv", index=False)
36 | testdata.to_csv(f"{args.output_dir}/test.csv", index=False)
37 |
38 | if args.save_parts == 0:
39 |
40 | ds = dataset.get_ft_dataset(traindata)
41 | ds = extract_data(ds)
42 | print("ds: ", ds[0])
43 | save_path = f"{args.output_dir}/train"
44 | save_datasets(ds, save_path)
45 |
46 | ds = dataset.get_ft_dataset(valdata)
47 | ds = extract_data(ds)
48 | save_path = f"{args.output_dir}/val"
49 | save_datasets(ds, save_path)
50 |
51 | ds = dataset.get_ft_dataset(testdata)
52 | ds = extract_data(ds)
53 | save_path = f"{args.output_dir}/test"
54 | save_datasets(ds, save_path)
55 |
56 | else:
57 | save_ds_parts(
58 | data_creater=dataset, ds=traindata, output_dir=args.output_dir, fold="train"
59 | )
60 | save_ds_parts(
61 | data_creater=dataset, ds=valdata, output_dir=args.output_dir, fold="val"
62 | )
63 | save_ds_parts(
64 | data_creater=dataset, ds=testdata, output_dir=args.output_dir, fold="test"
65 | )
66 |
--------------------------------------------------------------------------------
/fragnet/dataset/custom_dataset.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from torch_geometric.datasets import MoleculeNet
3 | from .utils import remove_non_mols
4 | from torch_geometric.data import Data
5 |
6 |
7 | class MoleculeDataset:
8 | def __init__(self, name, data_dir):
9 | self.name = name
10 | self.data_dir = data_dir
11 |
12 | def get_data(self):
13 | if self.name == "tox21":
14 | return self.get_tox21()
15 | elif self.name == "toxcast":
16 | return self.get_toxcast()
17 | elif self.name == "clintox":
18 | return self.get_clintox()
19 | elif self.name == "sider":
20 | return self.get_sider()
21 | elif self.name == "bbbp":
22 | return self.get_bbbp()
23 | elif self.name == "hiv":
24 | return self.get_hiv_dataset()
25 | elif self.name == "muv":
26 | return self.get_muv()
27 |
28 | def get_muv(self):
29 | from loader_molebert import _load_muv_dataset
30 |
31 | raw_path = f"{self.data_dir}/muv/raw/muv.csv"
32 | smiles_list, rdkit_mol_objs, labels = _load_muv_dataset(raw_path)
33 |
34 | assert len(smiles_list) == len(labels)
35 | data = []
36 | for i in range(len(smiles_list)):
37 | y = [list(labels[i])]
38 | smiles = smiles_list[i]
39 | if smiles != None:
40 | data.append(Data(smiles=smiles, y=y))
41 |
42 | return data
43 |
44 | def get_tox21(self):
45 | from loader_molebert import _load_tox21_dataset
46 |
47 | raw_path = f"{self.data_dir}/tox21/raw/tox21.csv"
48 | smiles_list, rdkit_mol_objs, labels = _load_tox21_dataset(raw_path)
49 |
50 | assert len(smiles_list) == len(labels)
51 | data = []
52 | for i in range(len(smiles_list)):
53 | y = [list(labels[i])]
54 | smiles = smiles_list[i]
55 | if smiles != None:
56 | data.append(Data(smiles=smiles, y=y))
57 |
58 | return data
59 |
60 | def get_hiv_dataset(self):
61 | from loader_molebert import _load_hiv_dataset
62 |
63 | raw_path = f"{self.data_dir}/hiv/raw/HIV.csv"
64 |
65 | smiles_list, rdkit_mol_objs, labels = _load_hiv_dataset(raw_path)
66 |
67 | assert len(smiles_list) == len(labels)
68 | data = []
69 | for i in range(len(smiles_list)):
70 |
71 | y = [labels[i]]
72 | smiles = smiles_list[i]
73 | if smiles != None:
74 | data.append(Data(smiles=smiles, y=y))
75 |
76 | return data
77 |
78 | def get_toxcast(self):
79 | from loader_molebert import _load_toxcast_dataset
80 |
81 | raw_path = f"{self.data_dir}/toxcast/raw/toxcast_data.csv"
82 |
83 | smiles_list, rdkit_mol_objs, labels = _load_toxcast_dataset(raw_path)
84 |
85 | assert len(smiles_list) == len(labels)
86 | data = []
87 | for i in range(len(smiles_list)):
88 | y = [list(labels[i])]
89 | smiles = smiles_list[i]
90 | if smiles != None:
91 | data.append(Data(smiles=smiles, y=y))
92 |
93 | return data
94 |
95 | def get_clintox(self):
96 | from loader_molebert import _load_clintox_dataset
97 |
98 | raw_path = f"{self.data_dir}/clintox/raw/clintox.csv"
99 |
100 | smiles_list, rdkit_mol_objs, labels = _load_clintox_dataset(raw_path)
101 |
102 | assert len(smiles_list) == len(labels)
103 | data = []
104 | for i in range(len(smiles_list)):
105 | y = [list(labels[i])]
106 | smiles = smiles_list[i]
107 | if smiles != None:
108 | data.append(Data(smiles=smiles, y=y))
109 |
110 | return data
111 |
112 | def get_bbbp(self):
113 | from loader_molebert import _load_bbbp_dataset
114 |
115 | raw_path = f"{self.data_dir}/bbbp/raw/BBBP.csv"
116 |
117 | smiles_list, rdkit_mol_objs, labels = _load_bbbp_dataset(raw_path)
118 |
119 | assert len(smiles_list) == len(labels)
120 | data = []
121 | for i in range(len(smiles_list)):
122 | y = [labels[i]]
123 | smiles = smiles_list[i]
124 | if smiles != None:
125 | data.append(Data(smiles=smiles, y=y))
126 |
127 | return data
128 |
129 | def get_sider(self):
130 | from loader_molebert import _load_sider_dataset
131 |
132 | dataset = MoleculeNet(self.data_dir, name=self.name)
133 | raw_path = f"{self.data_dir}/sider/raw/sider.csv"
134 | smiles_list, rdkit_mol_objs, labels = _load_sider_dataset(raw_path)
135 |
136 | assert len(smiles_list) == len(labels)
137 | data = []
138 | for i in range(len(smiles_list)):
139 | y = [list(labels[i, :])]
140 | smiles = smiles_list[i]
141 |
142 | if smiles != None:
143 | data.append(Data(smiles=smiles, y=y))
144 |
145 | return data
146 |
147 | def get_pcba(self):
148 |
149 | dataset = MoleculeNet(self.data_dir, name=self.name)
150 | df = pd.read_csv(f"{self.data_dir}/sider/raw/pcba.csv")
151 | df = remove_non_mols(df)
152 | df = df.fillna(-1)
153 | df.reset_index(drop=True, inplace=True)
154 |
155 | data = []
156 | for i in df.index:
157 | y = [df.iloc[i, :128].values.tolist()]
158 | smiles = df.loc[i, "smiles"]
159 | data.append(Data(smiles=smiles, y=y))
160 |
161 | return data
162 |
--------------------------------------------------------------------------------
/fragnet/dataset/dta.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pandas as pd
3 | from .dataset import FinetuneDataDTA
4 | from .utils import extract_data, save_datasets, save_ds_parts
5 |
6 |
7 | def create_dta_dataset(args):
8 |
9 | if not os.path.exists(args.output_dir):
10 | os.makedirs(args.output_dir, exist_ok=True)
11 |
12 | dataset = FinetuneDataDTA(target_name=args.target_name, data_type=args.data_type)
13 |
14 | if args.save_parts == 0:
15 |
16 | if args.train_path:
17 | train = pd.read_csv(args.train_path)
18 | train.to_csv(f"{args.output_dir}/train.csv", index=False)
19 | ds = dataset.get_ft_dataset(train)
20 | ds = extract_data(ds)
21 | save_path = f"{args.output_dir}/train"
22 | save_datasets(ds, save_path)
23 |
24 | if args.val_path:
25 | val = pd.read_csv(args.val_path)
26 | val.to_csv(f"{args.output_dir}/val.csv", index=False)
27 | ds = dataset.get_ft_dataset(val)
28 | ds = extract_data(ds)
29 | save_path = f"{args.output_dir}/val"
30 | save_datasets(ds, save_path)
31 |
32 | if args.test_path:
33 | test = pd.read_csv(args.test_path)
34 | test.to_csv(f"{args.output_dir}/test.csv", index=False)
35 | ds = dataset.get_ft_dataset(test)
36 | ds = extract_data(ds)
37 | save_path = f"{args.output_dir}/test"
38 | save_datasets(ds, save_path)
39 |
40 | else:
41 | save_ds_parts(
42 | data_creater=dataset, ds=train, output_dir=args.output_dir, fold="train"
43 | )
44 | save_ds_parts(
45 | data_creater=dataset, ds=val, output_dir=args.output_dir, fold="val"
46 | )
47 | save_ds_parts(
48 | data_creater=dataset, ds=test, output_dir=args.output_dir, fold="test"
49 | )
50 |
--------------------------------------------------------------------------------
/fragnet/dataset/ext_data_utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pnnl/FragNet/54f99099a08ab991a5ba6845f7338492d03d4f1c/fragnet/dataset/ext_data_utils/__init__.py
--------------------------------------------------------------------------------
/fragnet/dataset/ext_data_utils/deepttc.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | from .Step1_getData import GetData
4 |
5 | class DataEncoding:
6 | def __init__(self, data_dir):
7 | self.Getdata = GetData(data_dir)
8 |
9 | def encode2(self,traindata, valdata, testdata):
10 | drug_smiles = self.Getdata.getDrug()
11 | drugid2smile = dict(zip(drug_smiles['drug_id'],drug_smiles['smiles']))
12 |
13 | traindata['smiles'] = [drugid2smile[i] for i in traindata['DRUG_ID']]
14 | valdata['smiles'] = [drugid2smile[i] for i in valdata['DRUG_ID']]
15 | testdata['smiles'] = [drugid2smile[i] for i in testdata['DRUG_ID']]
16 |
17 |
18 | traindata = traindata.reset_index()
19 | traindata['Label'] = traindata['LN_IC50']
20 |
21 | valdata = valdata.reset_index()
22 | valdata['Label'] = valdata['LN_IC50']
23 |
24 |
25 | testdata = testdata.reset_index()
26 | testdata['Label'] = testdata['LN_IC50']
27 |
28 |
29 | return traindata, valdata, testdata
30 |
--------------------------------------------------------------------------------
/fragnet/dataset/features.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from rdkit import Chem
3 | from .feature_utils import one_of_k_encoding, one_of_k_encoding_unk
4 | from .feature_utils import get_bond_pair
5 |
6 |
7 | class FeaturesEXP:
8 |
9 | """
10 | Class for creating initial atom, bond and connection features.
11 | Used by CreateData in dataset/data.py
12 | """
13 |
14 | def __init__(self, add_connection_chrl=False):
15 | self.atom_list_one_hot = list(range(1, 119))
16 | self.use_bond_chirality = True
17 | self.add_connection_chrl = add_connection_chrl
18 |
19 | def get_atom_and_bond_features_atom_graph_one_hot(self, mol, use_chirality):
20 |
21 | add_self_loops = False
22 |
23 | atoms = mol.GetAtoms()
24 | bonds = mol.GetBonds()
25 | edge_index = get_bond_pair(mol, add_self_loops)
26 |
27 | node_f = [self.atom_features_one_hot(atom) for atom in atoms]
28 | edge_attr = []
29 | for bond in bonds:
30 | edge_attr.append(
31 | self.bond_features_one_hot(bond, use_chirality=use_chirality)
32 | )
33 | edge_attr.append(
34 | self.bond_features_one_hot(bond, use_chirality=use_chirality)
35 | )
36 |
37 | return node_f, edge_index, edge_attr
38 |
39 | def atom_features_one_hot(
40 | self, atom, explicit_H=False, use_chirality=False, angle_f=False
41 | ):
42 |
43 | atom_type = one_of_k_encoding_unk(atom.GetAtomicNum(), self.atom_list_one_hot)
44 | degree = one_of_k_encoding(atom.GetDegree(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
45 |
46 | valence = one_of_k_encoding_unk(
47 | atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5, 6]
48 | )
49 |
50 | charge = one_of_k_encoding_unk(
51 | atom.GetFormalCharge(), [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5]
52 | )
53 | rad_elec = one_of_k_encoding_unk(atom.GetNumRadicalElectrons(), [0, 1, 2, 3, 4])
54 |
55 | hyb = one_of_k_encoding_unk(
56 | atom.GetHybridization(),
57 | [
58 | Chem.rdchem.HybridizationType.S,
59 | Chem.rdchem.HybridizationType.SP,
60 | Chem.rdchem.HybridizationType.SP2,
61 | Chem.rdchem.HybridizationType.SP3,
62 | Chem.rdchem.HybridizationType.SP3D,
63 | Chem.rdchem.HybridizationType.SP3D2,
64 | Chem.rdchem.HybridizationType.UNSPECIFIED,
65 | ],
66 | )
67 |
68 | arom = one_of_k_encoding(atom.GetIsAromatic(), [False, True])
69 | atom_ring = one_of_k_encoding(atom.IsInRing(), [False, True])
70 | numhs = [atom.GetTotalNumHs()]
71 |
72 | chiral = one_of_k_encoding_unk(
73 | atom.GetChiralTag(),
74 | [
75 | Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
76 | Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
77 | Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
78 | ],
79 | )
80 | results = (
81 | atom_type
82 | + degree
83 | + valence
84 | + charge
85 | + rad_elec
86 | + hyb
87 | + arom
88 | + atom_ring
89 | + chiral
90 | + numhs
91 | )
92 |
93 | return np.array(results)
94 |
95 | def bond_features_one_hot(self, bond, use_chirality=True):
96 | bt = bond.GetBondType()
97 |
98 | bond_feats = [
99 | bt == Chem.rdchem.BondType.SINGLE,
100 | bt == Chem.rdchem.BondType.DOUBLE,
101 | bt == Chem.rdchem.BondType.TRIPLE,
102 | bt == Chem.rdchem.BondType.AROMATIC,
103 | ]
104 |
105 | conj = one_of_k_encoding(bond.GetIsConjugated(), [False, True])
106 | inring = one_of_k_encoding(bond.IsInRing(), [False, True])
107 |
108 | bond_feats = bond_feats + conj + inring
109 |
110 | if use_chirality:
111 | bond_feats = bond_feats + one_of_k_encoding_unk(
112 | str(bond.GetStereo()), ["STEREOANY", "STEREOZ", "STEREOE", "STEREONONE"]
113 | )
114 |
115 | bond_feats = bond_feats + one_of_k_encoding_unk(
116 | bond.GetBondDir(),
117 | [
118 | Chem.rdchem.BondDir.BEGINWEDGE,
119 | Chem.rdchem.BondDir.BEGINDASH,
120 | Chem.rdchem.BondDir.ENDDOWNRIGHT,
121 | Chem.rdchem.BondDir.ENDUPRIGHT,
122 | Chem.rdchem.BondDir.NONE,
123 | ],
124 | )
125 | return list(bond_feats)
126 |
127 | def connection_features_one_hot(self, connection):
128 |
129 | bond = connection.bond
130 | bt = connection.bond_type
131 |
132 | bond_feats = [
133 | bt == Chem.rdchem.BondType.SINGLE,
134 | bt == Chem.rdchem.BondType.DOUBLE,
135 | bt == Chem.rdchem.BondType.TRIPLE,
136 | bt == Chem.rdchem.BondType.AROMATIC,
137 | bt == "self_cn",
138 | bt == "iso_cn3",
139 | ]
140 |
141 | if self.add_connection_chrl:
142 |
143 | conj = one_of_k_encoding(bond.GetIsConjugated(), [False, True])
144 | inring = one_of_k_encoding(bond.IsInRing(), [False, True])
145 |
146 | bond_feats = bond_feats + conj + inring
147 | bond_feats = bond_feats + one_of_k_encoding_unk(
148 | str(bond.GetStereo()), ["STEREOANY", "STEREOZ", "STEREOE", "STEREONONE"]
149 | )
150 |
151 | bond_feats = bond_feats + one_of_k_encoding_unk(
152 | bond.GetBondDir(),
153 | [
154 | Chem.rdchem.BondDir.BEGINWEDGE,
155 | Chem.rdchem.BondDir.BEGINDASH,
156 | Chem.rdchem.BondDir.ENDDOWNRIGHT,
157 | Chem.rdchem.BondDir.ENDUPRIGHT,
158 | Chem.rdchem.BondDir.NONE,
159 | ],
160 | )
161 |
162 | return list(bond_feats)
163 |
--------------------------------------------------------------------------------
/fragnet/dataset/features0.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from rdkit import Chem
3 | from feature_utils import one_of_k_encoding, one_of_k_encoding_unk
4 | from feature_utils import get_bond_pair
5 |
6 |
7 | class FeaturesEXP:
8 |
9 | def __init__(self):
10 | self.atom_list_one_hot = [
11 | "Br",
12 | "C",
13 | "Cl",
14 | "F",
15 | "H",
16 | "I",
17 | "K",
18 | "N",
19 | "Na",
20 | "O",
21 | "P",
22 | "S",
23 | "Unknown",
24 | ]
25 | self.use_bond_chirality = False
26 |
27 | def get_atom_and_bond_features_atom_graph_one_hot(self, mol, use_chirality):
28 |
29 | add_self_loops = False
30 |
31 | atoms = mol.GetAtoms()
32 | bonds = mol.GetBonds()
33 | edge_index = get_bond_pair(mol, add_self_loops)
34 |
35 | node_f = [self.atom_features_one_hot(atom) for atom in atoms]
36 | edge_attr = []
37 | for bond in bonds:
38 | edge_attr.append(
39 | self.bond_features_one_hot(bond, use_chirality=use_chirality)
40 | )
41 | edge_attr.append(
42 | self.bond_features_one_hot(bond, use_chirality=use_chirality)
43 | )
44 |
45 | if add_self_loops:
46 | self_loop_attr = np.zeros((mol.GetNumAtoms(), 12)).tolist()
47 |
48 | edge_attr = edge_attr + self_loop_attr
49 |
50 | return node_f, edge_index, edge_attr
51 |
52 | def atom_features_one_hot(
53 | self,
54 | atom,
55 | bool_id_feat=False,
56 | explicit_H=False,
57 | use_chirality=False,
58 | angle_f=False,
59 | ):
60 |
61 | atom_type = one_of_k_encoding_unk(atom.GetSymbol(), self.atom_list_one_hot)
62 | degree = one_of_k_encoding(atom.GetDegree(), [0, 1, 2, 3, 4, 5, 6])
63 | valence = one_of_k_encoding_unk(
64 | atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5, 6]
65 | )
66 | charge = [atom.GetFormalCharge()]
67 | rad_elec = [atom.GetNumRadicalElectrons()]
68 | hyb = one_of_k_encoding_unk(
69 | atom.GetHybridization(),
70 | [
71 | Chem.rdchem.HybridizationType.SP,
72 | Chem.rdchem.HybridizationType.SP2,
73 | Chem.rdchem.HybridizationType.SP3,
74 | Chem.rdchem.HybridizationType.SP3D,
75 | Chem.rdchem.HybridizationType.SP3D2,
76 | Chem.rdchem.HybridizationType.UNSPECIFIED,
77 | ],
78 | )
79 | arom = [atom.GetIsAromatic()]
80 | atom_ring = [atom.IsInRing()]
81 | numhs = [atom.GetTotalNumHs()]
82 |
83 | chiral = one_of_k_encoding_unk(
84 | atom.GetChiralTag(),
85 | [
86 | Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
87 | Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
88 | Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
89 | ],
90 | )
91 |
92 | results = (
93 | atom_type
94 | + degree
95 | + valence
96 | + charge
97 | + rad_elec
98 | + hyb
99 | + arom
100 | + atom_ring
101 | + numhs
102 | )
103 |
104 | if use_chirality:
105 | try:
106 | results = (
107 | results
108 | + one_of_k_encoding_unk(atom.GetProp("_CIPCode"), ["R", "S"])
109 | + [atom.HasProp("_ChiralityPossible")]
110 | )
111 | except:
112 | results = (
113 | results + [False, False] + [atom.HasProp("_ChiralityPossible")]
114 | )
115 |
116 | return np.array(results)
117 |
118 | def bond_features_one_hot(self, bond, use_chirality=True):
119 | bt = bond.GetBondType()
120 |
121 | bond_feats = [
122 | bt == Chem.rdchem.BondType.SINGLE,
123 | bt == Chem.rdchem.BondType.DOUBLE,
124 | bt == Chem.rdchem.BondType.TRIPLE,
125 | bt == Chem.rdchem.BondType.AROMATIC,
126 | bond.GetIsConjugated(),
127 | bond.IsInRing(),
128 | ]
129 |
130 | if use_chirality:
131 | bond_feats = bond_feats + one_of_k_encoding_unk(
132 | str(bond.GetStereo()), ["STEREOANY", "STEREOZ", "STEREOE", "STEREONONE"]
133 | )
134 |
135 | bond_feats = bond_feats + one_of_k_encoding_unk(
136 | bond.GetBondDir(),
137 | [
138 | Chem.rdchem.BondDir.BEGINWEDGE,
139 | Chem.rdchem.BondDir.BEGINDASH,
140 | Chem.rdchem.BondDir.ENDDOWNRIGHT,
141 | Chem.rdchem.BondDir.ENDUPRIGHT,
142 | Chem.rdchem.BondDir.NONE,
143 | ],
144 | )
145 |
146 | return list(bond_feats)
147 |
148 | def connection_features_one_hot(self, connection):
149 |
150 | bt = connection.bond_type
151 |
152 | bond_feats = [
153 | bt == Chem.rdchem.BondType.SINGLE,
154 | bt == Chem.rdchem.BondType.DOUBLE,
155 | bt == Chem.rdchem.BondType.TRIPLE,
156 | bt == Chem.rdchem.BondType.AROMATIC,
157 | bt == "self_cn",
158 | bt == "iso_cn3",
159 | ]
160 | return list(bond_feats)
161 |
--------------------------------------------------------------------------------
/fragnet/dataset/features2.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from rdkit import Chem
3 | import numpy as np
4 | from rdkit import Chem
5 | from .feature_utils import one_of_k_encoding, one_of_k_encoding_unk
6 | from .feature_utils import get_bond_pair
7 |
8 |
9 | class FeaturesEXP:
10 |
11 | def __init__(self):
12 | self.atom_list_one_hot = list(range(1, 119))
13 | self.use_bond_chirality = True
14 |
15 | def get_atom_and_bond_features_atom_graph_one_hot(self, mol, use_chirality):
16 |
17 | add_self_loops = False
18 |
19 | atoms = mol.GetAtoms()
20 | bonds = mol.GetBonds()
21 | edge_index = get_bond_pair(mol, add_self_loops)
22 |
23 | node_f = [self.atom_features_one_hot(atom) for atom in atoms]
24 | edge_attr = []
25 | for bond in bonds:
26 | edge_attr.append(
27 | self.bond_features_one_hot(bond, use_chirality=use_chirality)
28 | )
29 | edge_attr.append(
30 | self.bond_features_one_hot(bond, use_chirality=use_chirality)
31 | )
32 |
33 | return node_f, edge_index, edge_attr
34 |
35 | def atom_features_one_hot(
36 | self, atom, explicit_H=False, use_chirality=False, angle_f=False
37 | ):
38 |
39 | atom_type = one_of_k_encoding_unk(atom.GetAtomicNum(), self.atom_list_one_hot)
40 | degree = one_of_k_encoding(atom.GetDegree(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
41 |
42 | valence = one_of_k_encoding_unk(
43 | atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5, 6]
44 | )
45 |
46 | charge = one_of_k_encoding_unk(
47 | atom.GetFormalCharge(), [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5]
48 | )
49 | rad_elec = one_of_k_encoding_unk(atom.GetNumRadicalElectrons(), [0, 1, 2, 3, 4])
50 |
51 | hyb = one_of_k_encoding_unk(
52 | atom.GetHybridization(),
53 | [
54 | Chem.rdchem.HybridizationType.S,
55 | Chem.rdchem.HybridizationType.SP,
56 | Chem.rdchem.HybridizationType.SP2,
57 | Chem.rdchem.HybridizationType.SP3,
58 | Chem.rdchem.HybridizationType.SP3D,
59 | Chem.rdchem.HybridizationType.SP3D2,
60 | Chem.rdchem.HybridizationType.UNSPECIFIED,
61 | ],
62 | )
63 | arom = one_of_k_encoding(atom.GetIsAromatic(), [False, True])
64 | atom_ring = one_of_k_encoding(atom.IsInRing(), [False, True])
65 | numhs = [atom.GetTotalNumHs()]
66 |
67 | chiral = one_of_k_encoding_unk(
68 | atom.GetChiralTag(),
69 | [
70 | Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
71 | Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
72 | Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
73 | ],
74 | )
75 |
76 | results = (
77 | atom_type
78 | + degree
79 | + valence
80 | + charge
81 | + rad_elec
82 | + hyb
83 | + arom
84 | + atom_ring
85 | + chiral
86 | + numhs
87 | )
88 |
89 | return np.array(results)
90 |
91 | def bond_features_one_hot(self, bond, use_chirality=True):
92 | bt = bond.GetBondType()
93 |
94 | bond_feats = [
95 | bt == Chem.rdchem.BondType.SINGLE,
96 | bt == Chem.rdchem.BondType.DOUBLE,
97 | bt == Chem.rdchem.BondType.TRIPLE,
98 | bt == Chem.rdchem.BondType.AROMATIC,
99 | ]
100 |
101 | conj = one_of_k_encoding(bond.GetIsConjugated(), [False, True])
102 | inring = one_of_k_encoding(bond.IsInRing(), [False, True])
103 |
104 | bond_feats = bond_feats + conj + inring
105 |
106 | if use_chirality:
107 | bond_feats = bond_feats + one_of_k_encoding_unk(
108 | str(bond.GetStereo()), ["STEREOZ", "STEREOE", "STEREONONE"]
109 | )
110 |
111 | bond_feats = bond_feats + one_of_k_encoding_unk(
112 | bond.GetBondDir(),
113 | [
114 | Chem.rdchem.BondDir.ENDDOWNRIGHT,
115 | Chem.rdchem.BondDir.ENDUPRIGHT,
116 | Chem.rdchem.BondDir.NONE,
117 | ],
118 | )
119 | return list(bond_feats)
120 |
121 | def connection_features_one_hot(self, connection):
122 |
123 | bt = connection.bond_type
124 |
125 | bond_feats = [
126 | bt == Chem.rdchem.BondType.SINGLE,
127 | bt == Chem.rdchem.BondType.DOUBLE,
128 | bt == Chem.rdchem.BondType.TRIPLE,
129 | bt == Chem.rdchem.BondType.AROMATIC,
130 | bt == "self_cn",
131 | bt == "iso_cn3",
132 | ]
133 | return list(bond_feats)
134 |
--------------------------------------------------------------------------------
/fragnet/dataset/fix_pt_data.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import pandas as pd
3 | import torch
4 | from models import FragNetPreTrain
5 | from data import collate_fn_pt as collate_fn
6 | from torch.utils.data import DataLoader
7 | import os
8 | import numpy as np
9 |
10 |
11 | files = os.listdir("pretrain_data/unimol_exp_/ds/")
12 | data = []
13 | lens = 0
14 | for f in files:
15 |
16 | df1 = pd.read_pickle(f"pretrain_data/unimol_exp_/ds/{f}")
17 | df2 = pd.read_pickle(f"pretrain_data/unimol_exp/ds/{f}")
18 |
19 | for i in range(len(df1)):
20 |
21 | data.append(df1[i].y.item() == df2[i].y.item())
22 |
23 | # break
24 | lens += len(df2)
25 |
26 | print("non_zero: ", np.count_nonzero(np.array(data)) == len(data) == lens)
27 |
--------------------------------------------------------------------------------
/fragnet/dataset/general.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torch_geometric.datasets import MoleculeNet
4 |
5 | from .utils import save_datasets
6 | from .dataset import FinetuneData, FinetuneMultiConfData
7 | from .utils import extract_data
8 | from .splitters import ScaffoldSplitter
9 | from .custom_dataset import MoleculeDataset
10 | from .utils import save_ds_parts
11 | import pandas as pd
12 |
13 |
14 | def create_general_dataset(args):
15 |
16 | if args.multi_conf_data == 1:
17 | print("generating multi-conf data")
18 | dataset = FinetuneMultiConfData(
19 | target_name=args.target_name,
20 | data_type=args.data_type,
21 | frag_type=args.frag_type,
22 | )
23 | else:
24 | print("generating single-conf data")
25 | dataset = FinetuneData(
26 | target_name=args.target_name,
27 | data_type=args.data_type,
28 | frag_type=args.frag_type,
29 | )
30 |
31 | if not os.path.exists(args.output_dir):
32 | os.makedirs(args.output_dir, exist_ok=True)
33 |
34 | if args.train_path:
35 | train = pd.read_csv(args.train_path)
36 | train.to_csv(f"{args.output_dir}/train.csv", index=False)
37 | ds = dataset.get_ft_dataset(train)
38 | ds = extract_data(ds)
39 | save_path = f"{args.output_dir}/train"
40 | save_datasets(ds, save_path)
41 |
42 | if args.val_path:
43 | val = pd.read_csv(args.val_path)
44 | val.to_csv(f"{args.output_dir}/val.csv", index=False)
45 | ds = dataset.get_ft_dataset(val)
46 | ds = extract_data(ds)
47 | save_path = f"{args.output_dir}/val"
48 | save_datasets(ds, save_path)
49 |
50 | if args.test_path:
51 | test = pd.read_csv(args.test_path)
52 | test.to_csv(f"{args.output_dir}/test.csv", index=False)
53 | ds = dataset.get_ft_dataset(test)
54 | ds = extract_data(ds)
55 | save_path = f"{args.output_dir}/test"
56 | save_datasets(ds, save_path)
57 |
--------------------------------------------------------------------------------
/fragnet/dataset/moleculenet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torch_geometric.datasets import MoleculeNet
4 |
5 | from .utils import save_datasets
6 | from .dataset import FinetuneData, FinetuneMultiConfData
7 | from .utils import extract_data
8 | from .splitters import ScaffoldSplitter
9 | from .custom_dataset import MoleculeDataset
10 | from .utils import save_ds_parts
11 |
12 |
13 | def create_moleculenet_dataset(dstype, name, args):
14 |
15 | if not os.path.exists(args.output_dir + "/" + name):
16 | os.makedirs(args.output_dir + "/" + name, exist_ok=True)
17 |
18 | if dstype == "MoleculeNet":
19 | ds = MoleculeNet(f"{args.data_dir}", name=name)
20 | elif dstype == "MoleculeDataset":
21 | _ = MoleculeNet(f"{args.data_dir}", name=name)
22 | dataset = MoleculeDataset(name, args.data_dir)
23 | ds = dataset.get_data()
24 |
25 | if not args.use_molebert:
26 | scaffold_split = ScaffoldSplitter()
27 | train, val, test = scaffold_split.split(dataset=dataset, include_chirality=True)
28 | elif args.use_molebert:
29 | from .splitters_molebert import scaffold_split
30 |
31 | smiles = [i.smiles for i in ds]
32 | train, val, test, (train_smiles, valid_smiles, test_smiles) = scaffold_split(
33 | ds, smiles, return_smiles=True
34 | )
35 |
36 | torch.save(train, f"{args.output_dir}/{name}/train.pt")
37 | torch.save(val, f"{args.output_dir}/{name}/val.pt")
38 | torch.save(test, f"{args.output_dir}/{name}/test.pt")
39 |
40 | if args.multi_conf_data:
41 | dataset = FinetuneMultiConfData(args.target_name, args.data_type)
42 | else:
43 | dataset = FinetuneData(
44 | args.target_name, args.data_type, frag_type=args.frag_type
45 | )
46 |
47 | if not args.save_parts:
48 |
49 | ds = dataset.get_ft_dataset(train)
50 | ds = extract_data(ds)
51 | save_path = f"{args.output_dir}/{name}/train"
52 | save_datasets(ds, save_path)
53 |
54 | ds = dataset.get_ft_dataset(val)
55 | ds = extract_data(ds)
56 | save_path = f"{args.output_dir}/{name}/val"
57 | save_datasets(ds, save_path)
58 |
59 | ds = dataset.get_ft_dataset(test)
60 | ds = extract_data(ds)
61 | save_path = f"{args.output_dir}/{name}/test"
62 | save_datasets(ds, save_path)
63 |
64 | elif args.save_parts:
65 | save_ds_parts(
66 | data_creater=dataset,
67 | ds=train,
68 | output_dir=args.output_dir,
69 | name=name,
70 | fold="train",
71 | )
72 | save_ds_parts(
73 | data_creater=dataset,
74 | ds=val,
75 | output_dir=args.output_dir,
76 | name=name,
77 | fold="val",
78 | )
79 | save_ds_parts(
80 | data_creater=dataset,
81 | ds=test,
82 | output_dir=args.output_dir,
83 | name=name,
84 | fold="test",
85 | )
86 |
--------------------------------------------------------------------------------
/fragnet/dataset/scaffold_split_from_df.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pandas as pd
3 | from .dataset import FinetuneData
4 | from .utils import extract_data
5 | from .utils import save_datasets
6 |
7 |
8 | def create_scaffold_split_data_from_df(args):
9 |
10 | if not os.path.exists(args.output_dir):
11 | os.makedirs(args.output_dir, exist_ok=True)
12 |
13 | ds = pd.read_csv(args.data_dir)
14 | ds.reset_index(drop=True, inplace=True)
15 |
16 | from splitters_molebert import scaffold_split
17 |
18 | smiles = ds.smiles.values.tolist()
19 |
20 | train, val, test, (train_smiles, valid_smiles, test_smiles) = scaffold_split(
21 | ds, smiles, return_smiles=True
22 | )
23 |
24 | train.reset_index(drop=True, inplace=True)
25 | val.reset_index(drop=True, inplace=True)
26 | test.reset_index(drop=True, inplace=True)
27 |
28 | dataset = FinetuneData(target_name=args.target_name, data_type=args.data_type)
29 |
30 | if not args.save_parts:
31 |
32 | train.to_csv(f"{args.output_dir}/train.csv", index=False)
33 | ds = dataset.get_ft_dataset(train)
34 | ds = extract_data(ds)
35 | save_path = f"{args.output_dir}/train"
36 | save_datasets(ds, save_path)
37 |
38 | val.to_csv(f"{args.output_dir}/val.csv", index=False)
39 | ds = dataset.get_ft_dataset(val)
40 | ds = extract_data(ds)
41 | save_path = f"{args.output_dir}/val"
42 | save_datasets(ds, save_path)
43 |
44 | test.to_csv(f"{args.output_dir}/test.csv", index=False)
45 | ds = dataset.get_ft_dataset(test)
46 | ds = extract_data(ds)
47 | save_path = f"{args.output_dir}/test"
48 | save_datasets(ds, save_path)
49 |
--------------------------------------------------------------------------------
/fragnet/dataset/simsgt.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.datasets import MoleculeNet
3 | from .utils import save_datasets
4 | import pandas as pd
5 | from .dataset import FinetuneData
6 | from .utils import extract_data
7 |
8 |
9 | def create_moleculenet_dataset_simsgt(name, args):
10 | from splitters_simsgt import scaffold_split
11 | from loader_simsgt import MoleculeDataset
12 |
13 | d = MoleculeNet(f"{args.output_dir}/simsgt", name=name)
14 | dataset = MoleculeDataset(f"{args.output_dir}/simsgt/" + name, dataset=name)
15 |
16 | smiles_list = pd.read_csv(
17 | f"{args.output_dir}/simsgt/" + name + "/processed/smiles.csv", header=None
18 | )[0].tolist()
19 | (
20 | train_dataset,
21 | valid_dataset,
22 | test_dataset,
23 | (train_smiles, valid_smiles, test_smiles),
24 | ) = scaffold_split(
25 | dataset,
26 | smiles_list,
27 | null_value=0,
28 | frac_train=0.8,
29 | frac_valid=0.1,
30 | frac_test=0.1,
31 | return_smiles=True,
32 | )
33 |
34 | torch.save(train_dataset, f"{args.output_dir}/simsgt/{name}/train.pt")
35 | torch.save(valid_dataset, f"{args.output_dir}/simsgt/{name}/val.pt")
36 | torch.save(test_dataset, f"{args.output_dir}/simsgt/{name}/test.pt")
37 |
38 | dataset = FinetuneData(args.target_name)
39 |
40 | if not args.save_parts:
41 |
42 | ds = dataset.get_ft_dataset(train_dataset)
43 | ds = extract_data(ds)
44 | save_path = f"{args.output_dir}/simsgt/{name}/train"
45 | save_datasets(ds, save_path)
46 |
47 | ds = dataset.get_ft_dataset(valid_dataset)
48 | ds = extract_data(ds)
49 | save_path = f"{args.output_dir}/simsgt/{name}/val"
50 | save_datasets(ds, save_path)
51 |
52 | ds = dataset.get_ft_dataset(test_dataset)
53 | ds = extract_data(ds)
54 | save_path = f"{args.output_dir}/simsgt/{name}/test"
55 | save_datasets(ds, save_path)
56 |
--------------------------------------------------------------------------------
/fragnet/dataset/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pickle
3 | import pandas as pd
4 | import os
5 | import lmdb
6 | from rdkit import Chem
7 | from tqdm import tqdm
8 |
9 |
10 | def remove_non_mols(ds):
11 |
12 | keep_smiles = []
13 | keep_ids = []
14 | if isinstance(ds, pd.DataFrame):
15 | smiles_list = ds.smiles.values
16 | else:
17 | smiles_list = [i.smiles for i in ds]
18 |
19 | for i, smiles in enumerate(smiles_list):
20 | mol = Chem.MolFromSmiles(smiles)
21 | if mol:
22 | keep_ids.append(i)
23 |
24 | if isinstance(ds, pd.DataFrame):
25 | df2 = ds.loc[keep_ids, :]
26 | df2.reset_index(drop=True, inplace=True)
27 | else:
28 | df2 = ds.index_select(keep_ids)
29 |
30 | return df2
31 |
32 |
33 | def extract_data(ds):
34 | res = []
35 | for i, data in tqdm(enumerate(ds)):
36 | if data:
37 | res.append(data)
38 | return res
39 |
40 |
41 | def save_datasets(ds, save_path):
42 | with open(f"{save_path}.pkl", "wb") as f:
43 | pickle.dump(ds, f)
44 |
45 |
46 | def remove_frags(train, val):
47 |     """Keep only molecules that consist of a single fragment in the train and val dataframes."""
48 |     frags = [len(Chem.GetMolFrags(Chem.MolFromSmiles(sm))) for sm in train.smiles]
49 |     train["mfrags"] = frags
50 |
51 |     frags = [len(Chem.GetMolFrags(Chem.MolFromSmiles(sm))) for sm in val.smiles]
52 |     val["mfrags"] = frags
53 |
54 |     train = train[train.mfrags == 1]
55 |     val = val[val.mfrags == 1]
56 |
57 |     train.reset_index(drop=True, inplace=True)
58 |     val.reset_index(drop=True, inplace=True)
59 |
60 |     return train, val
61 |
62 |
63 | def mol_with_atom_index(mol):
64 | for atom in mol.GetAtoms():
65 | atom.SetAtomMapNum(atom.GetIdx())
66 | return mol
67 |
68 |
69 | # TODO: test this function
70 | def remove_bond(rwmol, idx1, idx2):
71 | rwmol.RemoveBond(idx1, idx2)
72 | for idx in [idx1, idx2]:
73 | atom = rwmol.GetAtomWithIdx(idx)
74 | if atom.GetSymbol() == "N" and atom.GetIsAromatic() is True:
75 | atom.SetNumExplicitHs(1)
76 |
77 |
78 | def get_data(lmdb_path, name=None):
79 |
80 | # lmdb_path='ligands/train.lmdb'
81 |
82 | env = lmdb.open(
83 | lmdb_path,
84 | subdir=False,
85 | readonly=True,
86 | lock=False,
87 | readahead=False,
88 | meminit=False,
89 | max_readers=256,
90 | )
91 | txn = env.begin()
92 | keys = list(txn.cursor().iternext(values=False))
93 |
94 | smiles_data = []
95 | for idx in keys:
96 | datapoint_pickled = txn.get(idx)
97 | data = pickle.loads(datapoint_pickled)
98 | smiles_data.append({"smiles": data["smi"], "target": data["target"]})
99 |
100 | if name in ["clintox", "tox21", "toxcast", "sider", "pcba", "muv"]:
101 | for i in range(len(smiles_data)):
102 | smiles_data[i]["target"] = [list(smiles_data[i]["target"])]
103 |
104 | return smiles_data
105 |
106 |
107 | def collect_and_save(path, fold):
108 |
109 | f = os.listdir(path)
110 | t = [i for i in f if fold + "_p" in i and i.endswith("pkl")]
111 |
112 | data = []
113 | for i in t:
114 | data.extend(pd.read_pickle(path + "/" + i))
115 |
116 | with open(f"{path}/{fold}.pkl", "wb") as f:
117 | pickle.dump(data, f)
118 | # return data
119 |
120 |
121 | def save_ds_parts(data_creater=None, ds=None, output_dir=None, name=None, fold=None):
122 |
123 | if isinstance(ds, pd.DataFrame):
124 | ds.reset_index(drop=True, inplace=True)
125 |
126 | n = len(ds) // 1000
127 | parts = np.array_split(ds, n)
128 |
129 | for ipart, part in enumerate(parts):
130 | # ds_tmp = [ds[i] for i in ids]
131 |
132 | ds_tmp = data_creater.get_ft_dataset(part)
133 | ds_tmp = extract_data(ds_tmp)
134 | save_path = f"{output_dir}/{name}/{fold}_p_{ipart}"
135 | save_datasets(ds_tmp, save_path)
136 | if name:
137 | collect_and_save(f"{output_dir}/{name}", fold)
138 | else:
139 | collect_and_save(f"{output_dir}", fold)
140 |
141 | else:
142 |
143 | n = len(ds) // 1000
144 | id_list = np.array_split(np.arange(len(ds)), n)
145 |
146 | for ipart, ids in enumerate(id_list):
147 | ds_tmp = [ds[i] for i in ids]
148 |
149 | ds_tmp = data_creater.get_ft_dataset(ds_tmp)
150 | ds_tmp = extract_data(ds_tmp)
151 | save_path = f"{output_dir}/{name}/{fold}_p_{ipart}"
152 | save_datasets(ds_tmp, save_path)
153 | if name:
154 | collect_and_save(f"{output_dir}/{name}", fold)
155 | else:
156 | collect_and_save(f"{output_dir}", fold)
157 |
--------------------------------------------------------------------------------
/fragnet/exps/README.md:
--------------------------------------------------------------------------------
1 | If you want to view the contents of the weights at `pt/unimol_exp1s4/pt.pt`, you can run the following code:
2 |
3 | ```
4 | import torch
5 | pt = torch.load('fragnet/exps/pt/unimol_exp1s4/pt.pt')
6 | print(pt)
7 | ```
8 | The contents of `pt` are provided in `pt/unimol_exp1s4/pt.pt.data`.
9 |
10 |
11 | Similarly, for the checkpoint at `ft/pnnl_full/fragnet_hpdl_exp1s_h4pt4_10/ft_100.pt`:
12 |
13 | ```
14 | import torch
15 | ft = torch.load('fragnet/exps/ft/pnnl_full/fragnet_hpdl_exp1s_h4pt4_10/ft_100.pt')
16 | print(ft)
17 | ```
18 |
19 | The contents of `ft` are provided in `ft/pnnl_full/fragnet_hpdl_exp1s_h4pt4_10/ft_100.pt.data`.
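20 |
21 | Note: if these checkpoints were saved from a GPU run, loading them on a CPU-only machine may fail while deserializing CUDA tensors. A minimal sketch of a workaround, using the standard `map_location` argument of `torch.load`:
22 |
23 | ```
24 | import torch
25 |
26 | # Load the fine-tuned checkpoint onto the CPU so no GPU is required.
27 | ft = torch.load(
28 |     'fragnet/exps/ft/pnnl_full/fragnet_hpdl_exp1s_h4pt4_10/ft_100.pt',
29 |     map_location='cpu',
30 | )
31 | print(ft)
32 | ```
33 |
34 | The same argument works when loading `pt/unimol_exp1s4/pt.pt`.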
--------------------------------------------------------------------------------
/fragnet/exps/ft/esol/e1pt4.yaml:
--------------------------------------------------------------------------------
1 | seed: 123
2 | # seed: 4
3 | # seed: 456
4 | data_seed: None
5 | exp_dir: exps/ft/esol
6 | model_version: gat2
7 | device: gpu
8 | atom_features: 167
9 | frag_features: 167
10 | edge_features: 17
11 | fedge_in: 6
12 | fbond_edge_in: 6
13 | # atom_features: 49
14 | # frag_features: 49
15 | # edge_features: 12
16 | #model_version: gat2_transformer
17 | # model_version: gat2_transformer2
18 | pretrain:
19 | model_version: gat2
20 | num_layer: 4
21 | drop_ratio: 0.2
22 | num_heads: 4
23 | emb_dim: 128
24 | chkpoint_name: exps/pt/unimol_exp1s4/pt.pt
25 | loss: mse
26 | batch_size: 128
27 | es_patience: 500
28 | lr: 1e-4
29 | n_epochs: 20000
30 | n_classes: 1
31 | # molebert splitting
32 |
33 | finetune:
34 | n_multi_task_heads: 0
35 | # batch_size: 24 # best
36 | batch_size: 16
37 | lr: 1e-4
38 | model:
39 | n_classes: 1
40 | num_layer: 4
41 | drop_ratio: 0.1
42 | num_heads: 4
43 | emb_dim: 128
44 | h1: 128 #128
45 | h2: 1024
46 | h3: 1024
47 | h4: 512
48 | act: relu
49 | fthead: FTHead3
50 |
51 | n_epochs: 10000
52 | target_type: regr
53 | loss: mse
54 | use_schedular: False
55 | es_patience: 100
56 | chkpoint_name: ${exp_dir}/ft.pt
57 | train:
58 | path: finetune_data/moleculenet_exp1s/esol/train.pkl # 20 exp node features
59 | val:
60 | path: finetune_data/moleculenet_exp1s/esol/val.pkl # 20 exp node features
61 | test:
62 | path: finetune_data/moleculenet_exp1s/esol/test.pkl # 20 exp node features
63 |
--------------------------------------------------------------------------------
/fragnet/exps/ft/lipo/download.sh:
--------------------------------------------------------------------------------
1 | scp marianas:/people/pana982/solubility/models/fragnet/fragnet/exps/ft/lipo/fragnet_hpdl_exp1s_pt4_30/ft_100.pt fragnet_hpdl_exp1s_pt4_30/
2 |
3 | scp marianas:/people/pana982/solubility/models/fragnet/fragnet/exps/ft/lipo/fragnet_hpdl_exp1s_pt4_30/config_exp100.yaml fragnet_hpdl_exp1s_pt4_30/
4 |
--------------------------------------------------------------------------------
/fragnet/exps/ft/lipo/fragnet_hpdl_exp1s_pt4_30/config_exp100.yaml:
--------------------------------------------------------------------------------
1 | atom_features: 167
2 | data_seed: None
3 | device: gpu
4 | edge_features: 17
5 | exp_dir: exps/ft/lipo/fragnet_hpdl_exp1s_pt4_30
6 | finetune:
7 | batch_size: 16
8 | chkpoint_name: exps/ft/lipo/fragnet_hpdl_exp1s_pt4_30/ft_100.pt
9 | es_patience: 100
10 | loss: mse
11 | lr: 1e-4
12 | model:
13 | act: relu
14 | drop_ratio: 0.0
15 | emb_dim: 128
16 | fthead: FTHead3
17 | h1: 640
18 | h2: 1408
19 | h3: 1472
20 | h4: 1792
21 | n_classes: 1
22 | num_heads: 4
23 | num_layer: 4
24 | n_epochs: 10000
25 | n_multi_task_heads: 0
26 | target_type: regr
27 | test:
28 | path: finetune_data/moleculenet_exp1s/lipo/test.pkl
29 | train:
30 | path: finetune_data/moleculenet_exp1s/lipo/train.pkl
31 | use_schedular: false
32 | val:
33 | path: finetune_data/moleculenet_exp1s/lipo/val.pkl
34 | frag_features: 167
35 | model_version: gat2
36 | pretrain:
37 | batch_size: 128
38 | chkpoint_name: exps/pt/unimol_exp1s4/pt.pt
39 | drop_ratio: 0.2
40 | emb_dim: 128
41 | es_patience: 500
42 | loss: mse
43 | lr: 1e-4
44 | n_epochs: 20000
45 | num_heads: 4
46 | num_layer: 4
47 | seed: 100
48 |
--------------------------------------------------------------------------------
/fragnet/exps/ft/lipo/fragnet_hpdl_exp1s_pt4_30/ft_100.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pnnl/FragNet/54f99099a08ab991a5ba6845f7338492d03d4f1c/fragnet/exps/ft/lipo/fragnet_hpdl_exp1s_pt4_30/ft_100.pt
--------------------------------------------------------------------------------
/fragnet/exps/ft/pnnl_full/fragnet_hpdl_exp1s_h4pt4_10/config_exp100.yaml:
--------------------------------------------------------------------------------
1 | atom_features: 167
2 | data_seed: None
3 | device: gpu
4 | edge_features: 17
5 | exp_dir: exps/ft/pnnl_full/fragnet_hpdl_exp1s_h4pt4_10
6 | finetune:
7 | batch_size: 16
8 | chkpoint_name: exps/ft/pnnl_full/fragnet_hpdl_exp1s_h4pt4_10/ft_100.pt
9 | es_patience: 100
10 | loss: mse
11 | lr: 1e-4
12 | model:
13 | act: selu
14 | drop_ratio: 0.1
15 | emb_dim: 128
16 | fthead: FTHead4
17 | h1: 1472
18 | h2: 1024
19 | h3: 1024
20 | h4: 512
21 | n_classes: 1
22 | num_heads: 4
23 | num_layer: 4
24 | n_epochs: 10000
25 | n_multi_task_heads: 0
26 | target_type: regr
27 | test:
28 | path: finetune_data/pnnl_full/test.pkl
29 | train:
30 | path: finetune_data/pnnl_full/train.pkl
31 | use_schedular: false
32 | val:
33 | path: finetune_data/pnnl_full/val.pkl
34 | frag_features: 167
35 | model_version: gat2
36 | pretrain:
37 | batch_size: 128
38 | chkpoint_name: exps/pt/unimol_exp1s4/pt.pt
39 | drop_ratio: 0.2
40 | emb_dim: 128
41 | es_patience: 500
42 | loss: mse
43 | lr: 1e-4
44 | n_classes: 1
45 | n_epochs: 20000
46 | num_heads: 4
47 | num_layer: 4
48 | seed: 100
49 |
--------------------------------------------------------------------------------
/fragnet/exps/ft/pnnl_full/fragnet_hpdl_exp1s_h4pt4_10/ft_100.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pnnl/FragNet/54f99099a08ab991a5ba6845f7338492d03d4f1c/fragnet/exps/ft/pnnl_full/fragnet_hpdl_exp1s_h4pt4_10/ft_100.pt
--------------------------------------------------------------------------------
/fragnet/exps/pt/unimol_exp1s4/config.yaml:
--------------------------------------------------------------------------------
1 | seed: 123
2 | exp_dir: exps/pt/unimol_exp1s4
3 | atom_features: 167
4 | frag_features: 167
5 | edge_features: 17
6 | model_version: gat2
7 | device: cpu
8 | fedge_in: 6
9 | fbond_edge_in: 6
10 | pretrain:
11 | num_layer: 4
12 | drop_ratio: 0.2
13 | num_heads: 4
14 | emb_dim: 128
15 | chkpoint_name: ${exp_dir}/pt.pt
16 | saved_checkpoint: null
17 | loss: mse
18 | batch_size: 512
19 | es_patience: 200
20 | lr: 1e-4
21 | n_epochs: 200
22 | valdiate_every: 5
23 | data:
24 | - pretrain_data/esol/
25 | train_smiles: null
26 | val_smiles: null
--------------------------------------------------------------------------------
/fragnet/exps/pt/unimol_exp1s4/pt.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pnnl/FragNet/54f99099a08ab991a5ba6845f7338492d03d4f1c/fragnet/exps/pt/unimol_exp1s4/pt.pt
--------------------------------------------------------------------------------
/fragnet/hp/hp.py:
--------------------------------------------------------------------------------
1 |
2 | from torch_geometric.data import DataLoader
3 | import random
4 | import hyperopt
5 | from hyperopt import fmin, hp, Trials, STATUS_OK
6 | import time
7 | from dataset import load_pickle_dataset
8 | from torch.utils.data import DataLoader
9 | from data import collate_fn
10 | from gat import FragNet
11 | import torch.nn as nn
12 | from gat import FragNetPreTrain
13 | from utils import EarlyStopping
14 | import numpy as np
15 | from utils import test_fn
16 | from dataset import load_data_parts
17 | import os
18 | import torch
19 | from torch_scatter import scatter_add
20 |
21 | def get_optimizer(model, freeze_pt_weights=False, lr=1e-4):
22 |
23 | if freeze_pt_weights:
24 | print('freezing pretrain weights')
25 | for name, param in model.named_parameters():
26 | if param.requires_grad and 'pretrain' in name:
27 | param.requires_grad = False
28 |
29 | non_frozen_parameters = [p for p in model.parameters() if p.requires_grad]
30 | optimizer = torch.optim.Adam(non_frozen_parameters, lr = lr)
31 | else:
32 | print('no freezing of the weights')
33 | optimizer = torch.optim.Adam(model.parameters(), lr = lr )
34 |
35 | return model, optimizer
36 |
37 |
38 | def set_seed(seed):
39 | torch.manual_seed(seed)
40 | torch.cuda.manual_seed_all(seed)
41 | torch.backends.cudnn.deterministic = True
42 | torch.backends.cudnn.benchmark = False
43 | np.random.seed(seed)
44 | random.seed(seed)
45 | os.environ['PYTHONHASHSEED'] = str(seed)
46 |
47 |
48 |
49 | def trainModel(params):
50 |
51 |
52 | class FragNetFineTune(nn.Module):
53 |
54 | def __init__(self):
55 | super(FragNetFineTune, self).__init__()
56 |
57 | self.pretrain = FragNet(num_layer=6, drop_ratio= params['d1'] )
58 | self.lin1 = nn.Linear(128*2, int(params['f2']))
59 | self.lin2 = nn.Linear(int(params['f2']), int(params['f3']))
60 | self.out = nn.Linear(int(params['f3']), 1)
61 | self.dropout = nn.Dropout(p= params['d2'] )
62 | self.activation = nn.ReLU()
63 |
64 |
65 | def forward(self, batch):
66 |
67 | x_atoms, x_frags = self.pretrain(batch)
68 |
69 | x_frags_pooled = scatter_add(src=x_frags, index=batch['frag_batch'], dim=0)
70 | x_atoms_pooled = scatter_add(src=x_atoms, index=batch['batch'], dim=0)
71 |
72 | cat = torch.cat((x_atoms_pooled, x_frags_pooled), 1)
73 | x = self.dropout(cat)
74 |
75 | x = self.lin1(x)
76 | x = self.activation(x)
77 | x = self.dropout(x)
78 |
79 | x = self.lin2(x)
80 | x = self.activation(x)
81 | x = self.dropout(x)
82 |
83 | x = self.out(x)
84 |
85 | return x
86 |
87 |
88 |
89 | train_loader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=params['bs'], shuffle=True, drop_last=True)
90 | val_loader = DataLoader(val_dataset, collate_fn=collate_fn, batch_size=128, shuffle=False, drop_last=False)
91 |
92 | model_pretrain = FragNetPreTrain();
93 | model_pretrain.to(device);
94 | model_pretrain.load_state_dict(torch.load('pt.pt'))
95 | model = FragNetFineTune()
96 | model.to(device);
97 | loss_fn = nn.MSELoss()
98 | chkpoint_name = 'hp2.pt'
99 |
100 |
101 | model, optimizer = get_optimizer(model, freeze_pt_weights=freeze_pt_weights, lr=1e-5)
102 | early_stopping = EarlyStopping(patience=20, verbose=True, chkpoint_name=chkpoint_name)
103 |
104 | for epoch in range(n_epochs):
105 |
106 | res = []
107 | model.train()
108 | total_loss = 0
109 | for batch in train_loader:
110 | for k,v in batch.items():
111 | batch[k] = batch[k].to(device)
112 | optimizer.zero_grad()
113 | out = model(batch).view(-1,)
114 | loss = loss_fn(out, batch['y'])
115 | loss.backward()
116 | total_loss += loss.item()
117 | optimizer.step()
118 |
119 |
120 | try:
121 | val_loss, _, _ = test_fn(val_loader, model, device)
122 | res.append(val_loss)
123 | print("val mse: ", val_loss)
124 |
125 | early_stopping(val_loss, model)
126 |
127 | if early_stopping.early_stop:
128 | print("Early stopping")
129 | break
130 | except:
131 | print('val loss cannot be calculated')
132 | pass
133 |
134 | try:
135 | model.load_state_dict(torch.load(chkpoint_name))
136 | test_mse, test_t, test_p = test_fn(val_loader, model, device)
137 | return {'loss':test_mse, 'status':STATUS_OK}
138 |
139 | except:
140 | return {'loss':1000, 'status':STATUS_OK}
141 |
142 | if __name__ == '__main__':
143 |
144 | dataset_name='moleculenet'
145 | dataset_subset = 'esol'
146 | freeze_pt_weights=False
147 | # add_pt_weights=True
148 | seed = None
149 | n_epochs = 100
150 |
151 | if dataset_name == 'moleculenet':
152 | train_dataset = load_pickle_dataset(f'{dataset_name}/{dataset_subset}', f'train_{str(seed)}')
153 | val_dataset = load_pickle_dataset(f'{dataset_name}/{dataset_subset}', f'val_{str(seed)}')
154 | test_dataset = load_pickle_dataset(f'{dataset_name}/{dataset_subset}', f'test_{str(seed)}')
155 | elif 'pnnl' in dataset_name:
156 | train_dataset = load_data_parts(path=dataset_name, name='train')
157 | val_dataset = load_data_parts(path=dataset_name, name='val')
158 | test_dataset = load_data_parts(path=dataset_name, name='test')
159 |
160 |
161 | space = {
162 | 'f1': hp.quniform('f1', 32, 320, 32),
163 | 'f2': hp.quniform('f2', 32, 320, 32),
164 | 'f3': hp.quniform('f3', 32, 320, 32),
165 | 'd1': hp.uniform('d1', 0,1),
166 | 'd2': hp.uniform('d2', 0,1),
167 | 'bs' : hp.choice('bs', [16, 32, 128]),
168 | }
169 |
170 |
171 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
172 |
173 |
174 | trials = Trials()
175 | st_time = time.time()
176 | best = fmin(trainModel, space, algo=hyperopt.tpe.suggest, max_evals=100, trials=trials)
177 |
178 | end_time = time.time()
179 | fo = open("res2.txt", "w")
180 | fo.write(repr(best))
181 | fo.close()
182 |
183 | print( (end_time-st_time)/3600 )
--------------------------------------------------------------------------------
/fragnet/hp/hp_clf.py:
--------------------------------------------------------------------------------
1 |
2 | from torch_geometric.data import DataLoader
3 | import random
4 | import hyperopt
5 | from hyperopt import fmin, hp, Trials, STATUS_OK
6 | import time
7 | from dataset import load_pickle_dataset
8 | from torch.utils.data import DataLoader
9 | from data import collate_fn
10 | import torch.nn as nn
11 | from utils import EarlyStopping
12 | import numpy as np
13 | import os
14 | import torch
15 | from omegaconf import OmegaConf
16 | import argparse
17 | from utils import TrainerFineTune as Trainer
18 | import torch.optim.lr_scheduler as lr_scheduler
19 |
20 | def get_optimizer(model, freeze_pt_weights=False, lr=1e-4):
21 |
22 | if freeze_pt_weights:
23 | print('freezing pretrain weights')
24 | for name, param in model.named_parameters():
25 | if param.requires_grad and 'pretrain' in name:
26 | param.requires_grad = False
27 |
28 | non_frozen_parameters = [p for p in model.parameters() if p.requires_grad]
29 | optimizer = torch.optim.Adam(non_frozen_parameters, lr = lr)
30 | else:
31 | print('no freezing of the weights')
32 | optimizer = torch.optim.Adam(model.parameters(), lr = lr )
33 |
34 | return model, optimizer
35 |
36 |
37 | def set_seed(seed):
38 | torch.manual_seed(seed)
39 | torch.cuda.manual_seed_all(seed)
40 | torch.backends.cudnn.deterministic = True
41 | torch.backends.cudnn.benchmark = False
42 | np.random.seed(seed)
43 | random.seed(seed)
44 | os.environ['PYTHONHASHSEED'] = str(seed)
45 |
46 |
47 |
48 | def trainModel(params):
49 |
50 |
51 | exp_dir = args.exp_dir
52 | n_classes_pt = args.pretrain.n_classes
53 | device = torch.device('cuda') if torch.cuda.is_available() else 'cpu'
54 |
55 | pt_chkpoint_name = args.pretrain.chkpoint_name #'exps/pt_rings'
56 |
57 | from gat2 import FragNetFineTune
58 | model = FragNetFineTune(n_classes=args.finetune.model.n_classes,
59 | num_layer=params['num_layer'],
60 | drop_ratio=params['drop_ratio'])
61 |
62 |
63 | if pt_chkpoint_name:
64 | model_pretrain = FragNetFineTune(n_classes_pt)
65 | model_pretrain.load_state_dict(torch.load(pt_chkpoint_name))
66 | state_dict_to_load={}
67 | for k,v in model.state_dict().items():
68 |
69 | if v.size() == model_pretrain.state_dict()[k].size():
70 | state_dict_to_load[k] = model_pretrain.state_dict()[k]
71 | else:
72 | state_dict_to_load[k] = v
73 |
74 | model.load_state_dict(state_dict_to_load)
75 |
76 | train_dataset2 = load_pickle_dataset(args.finetune.train.path, args.finetune.train.name)
77 | val_dataset2 = load_pickle_dataset(args.finetune.val.path, args.finetune.val.name)
78 | test_dataset2 = load_pickle_dataset(args.finetune.test.path, args.finetune.test.name) #'finetune_data/pnnl_exp'
79 |
80 |
81 | train_loader = DataLoader(train_dataset2, collate_fn=collate_fn, batch_size=params['batch_size'], shuffle=True, drop_last=True)
82 | val_loader = DataLoader(val_dataset2, collate_fn=collate_fn, batch_size=params['batch_size'], shuffle=False, drop_last=False)
83 | test_loader = DataLoader(test_dataset2, collate_fn=collate_fn, batch_size=params['batch_size'], shuffle=False, drop_last=False)
84 |
85 | trainer = Trainer(target_type=args.finetune.target_type)
86 |
87 |
88 | model.to(device);
89 | ft_chk_point = f'{args.exp_dir}/fthp.pt'
90 | early_stopping = EarlyStopping(patience=100, verbose=True, chkpoint_name=ft_chk_point)
91 |
92 |
93 | if args.finetune.loss == 'mse':
94 | loss_fn = nn.MSELoss()
95 | elif args.finetune.loss == 'cel':
96 | loss_fn = nn.CrossEntropyLoss()
97 |
98 | optimizer = torch.optim.Adam(model.parameters(), lr = params['lr'] ) # before 1e-4
99 | if args.finetune.use_schedular:
100 | scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
101 | else:
102 | scheduler=None
103 |
104 | for epoch in range(args.finetune.n_epochs):
105 |
106 | train_loss = trainer.train(model=model, loader=train_loader, optimizer=optimizer, scheduler=scheduler, loss_fn=loss_fn, device=device)
107 | val_loss, _, _ = trainer.test(model=model, loader=val_loader, device=device)
108 | print(train_loss, val_loss)
109 | early_stopping(val_loss, model)
110 |
111 | if early_stopping.early_stop:
112 | print("Early stopping")
113 | break
114 |
115 | try:
116 | model.load_state_dict(torch.load(ft_chk_point))
117 | mse, true, pred = trainer.test(model=model, loader=val_loader, device=device)
118 |
119 | return {'loss':-mse, 'status':STATUS_OK}
120 |
121 | except:
122 | return {'loss':-100000, 'status':STATUS_OK}
123 |
124 |
125 | parser = argparse.ArgumentParser()
126 | parser.add_argument('--config', help="configuration file *.yml", type=str, required=False, default='config.yaml')
127 | args = parser.parse_args()
128 |
129 |
130 | if args.config: # args priority is higher than yaml
131 | opt = OmegaConf.load(args.config)
132 | OmegaConf.resolve(opt)
133 | args=opt
134 |
135 |
136 | space = {
137 | 'num_layer': hp.choice('num_layer', [3, 4, 5,6,7,8]),
138 | 'lr': hp.choice('lr', [1e-3, 1e-4, 1e-6]),
139 | 'drop_ratio': hp.choice('drop_ratio', [0.15, 0.2, 0.3, 0.5]),
140 | 'batch_size' : hp.choice('batch_size', [8, 16, 32, 64, 128]),
141 | }
142 |
143 |
144 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
145 |
146 | trials = Trials()
147 | st_time = time.time()
148 | best = fmin(trainModel, space, algo=hyperopt.rand.suggest, max_evals=25, trials=trials)
149 |
150 | end_time = time.time()
151 | fo = open(f"{args.exp_dir}/res2.txt", "w")
152 | fo.write(repr(best))
153 | fo.close()
154 |
155 | print( (end_time-st_time)/3600 )
156 |
--------------------------------------------------------------------------------
/fragnet/hp/hpoptuna.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import optuna
3 | from torch_geometric.data import DataLoader
4 | import random
5 | import time
6 | from fragnet.dataset.dataset import load_pickle_dataset
7 | from torch.utils.data import DataLoader
8 | from fragnet.dataset.data import collate_fn
9 | import torch.nn as nn
10 | from fragnet.model.gat.gat2 import FragNet
11 | from fragnet.train.utils import EarlyStopping
12 | import numpy as np
13 | from fragnet.train.utils import test_fn
14 | from fragnet.dataset.dataset import load_data_parts
15 | import os
16 | import torch
17 | from torch_scatter import scatter_add
18 | from omegaconf import OmegaConf
19 | import argparse
20 | from fragnet.train.utils import TrainerFineTune as Trainer
21 | import torch.optim.lr_scheduler as lr_scheduler
24 | import matplotlib.pyplot as plt
25 | from fragnet.model.gat.gat2 import FragNetFineTune
26 | import pytorch_lightning as pl
27 |
28 | def seed_everything(seed: int):
29 | import random, os
30 |
31 | random.seed(seed)
32 | os.environ['PYTHONHASHSEED'] = str(seed)
33 | np.random.seed(seed)
34 | torch.manual_seed(seed)
35 | torch.cuda.manual_seed_all(seed)
36 | torch.backends.cudnn.deterministic = True
37 |
38 |
39 | def trainModel(trial):
40 |
41 | device = torch.device('cuda') if torch.cuda.is_available() else 'cpu'
42 |
43 | num_heads = args.finetune.model.num_heads
44 | num_layer = args.finetune.model.num_layer
45 |
46 | drop_ratio = trial.suggest_categorical('drop_ratio', [0.0, 0.1, 0.2, 0.3])
47 |
48 | fthead = args.finetune.model.fthead
49 |
50 | if fthead == 'FTHead3':
51 | h1 = trial.suggest_int('h1', 64, 2048, step=64)
52 | h2=trial.suggest_int('h2', 64, 2048, step=64)
53 | h3=trial.suggest_int('h3', 64, 2048, step=64)
54 | h4=trial.suggest_int('h4', 64, 2048, step=64)
55 |
56 | elif fthead == 'FTHead4':
57 | h1 = trial.suggest_int('h1', 64, 2048, step=64)
58 | h2, h3, h4 = None, None, None
59 |
60 |
61 | act = trial.suggest_categorical("act", ['relu','silu','gelu','celu','selu','rrelu','relu6','prelu','leakyrelu'])
62 | batch_size = trial.suggest_categorical('batch_size', [16,32,64,128])
63 | lr = args.finetune.lr
64 |
65 |
66 | pt_chkpoint_name = args.pretrain.chkpoint_name #'exps/pt_rings'
67 | if args.model_version == 'gat2':
68 | from fragnet.model.gat.gat2 import FragNetFineTune
69 | elif args.model_version == 'gat2_lite':
70 | from fragnet.model.gat.gat2_lite import FragNetFineTune
71 |
72 | model = FragNetFineTune(n_classes=args.finetune.model.n_classes,
73 | atom_features=args.atom_features,
74 | frag_features=args.frag_features,
75 | edge_features=args.edge_features,
76 | num_heads = num_heads,
77 | num_layer= num_layer,
78 | drop_ratio= drop_ratio,
79 | h1=h1,
80 | h2=h2,
81 | h3=h3,
82 | h4=h4,
83 | act=act,
84 | fthead=fthead
85 | )
86 |
87 |
88 |
89 | if pt_chkpoint_name:
90 |
91 | from fragnet.model.gat.gat2_pretrain import FragNetPreTrain
92 | modelpt = FragNetPreTrain(
93 | atom_features=args.atom_features,
94 | frag_features=args.frag_features,
95 | edge_features=args.edge_features,
96 | num_layer=args.pretrain.num_layer,
97 | drop_ratio=args.pretrain.drop_ratio,
98 | num_heads=args.pretrain.num_heads,
99 | emb_dim=args.pretrain.emb_dim)
100 | modelpt.load_state_dict(torch.load(pt_chkpoint_name, map_location=torch.device(device)))
101 |
102 |
103 | print('loading pretrained weights')
104 | model.pretrain.load_state_dict(modelpt.pretrain.state_dict())
105 | print('weights loaded')
106 | else:
107 | print('no pretrained weights')
108 |
109 |
110 | trainer = Trainer(target_type=args.finetune.target_type)
111 |
112 | train_dataset2 = load_pickle_dataset(args.finetune.train.path)
113 | val_dataset2 = load_pickle_dataset(args.finetune.val.path)
114 |
115 |
116 | train_loader = DataLoader(train_dataset2, collate_fn=collate_fn, batch_size=batch_size, shuffle=True, drop_last=True)
117 | val_loader = DataLoader(val_dataset2, collate_fn=collate_fn, batch_size=128, shuffle=False, drop_last=False)
118 |
119 | model.to(device);
120 |
121 | optimizer = torch.optim.Adam(model.parameters(), lr = lr ) # before 1e-4
122 |     if args.finetune.use_schedular:
123 |         scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, cooldown=0)
124 |     else:
125 |         scheduler = None
126 |
128 |
129 | ft_chk_point = args.chkpt
130 | print("checkpoint name: ", ft_chk_point)
131 | early_stopping = EarlyStopping(patience=args.finetune.es_patience, verbose=True, chkpoint_name=ft_chk_point)
132 |
133 | for epoch in range(args.finetune.n_epochs):
134 |
135 | train_loss = trainer.train(model=model, loader=train_loader, optimizer=optimizer, scheduler=scheduler, device=device, val_loader=val_loader)
136 | val_loss, _, _ = trainer.test(model=model, loader=val_loader, device=device)
137 | print(train_loss, val_loss)
138 |
139 |
140 | trial.report(val_loss, epoch)
141 | if args.prune == 1:
142 | if trial.should_prune():
143 | raise optuna.TrialPruned()
144 |
145 | early_stopping(val_loss, model)
146 |
147 | if early_stopping.early_stop:
148 | print("Early stopping")
149 | break
150 |
151 |
152 | try:
153 | model.load_state_dict(torch.load(ft_chk_point))
154 | mse, true, pred = trainer.test(model=model, loader=val_loader, device=device)
155 |
156 | return mse
157 |
158 | except:
159 | return 1000.0
160 |
161 |
162 | if __name__ == '__main__':
163 |
164 | parser = argparse.ArgumentParser()
165 | parser.add_argument('--config', help="configuration file *.yml", type=str, required=False, default='exps/ft/esol2/config.yaml')
166 | parser.add_argument('--chkpt', help="checkpoint name", type=str, required=False, default='opt.pt')
167 | parser.add_argument('--direction', help="", type=str, required=False, default='minimize')
168 | parser.add_argument('--n_trials', help="", type=int, required=False, default=100)
169 | parser.add_argument('--embed_max', help="", type=int, required=False, default=512)
170 | parser.add_argument('--seed', help="", type=int, required=False, default=1)
171 | parser.add_argument('--ft_epochs', help="", type=int, required=False, default=100)
172 | parser.add_argument('--choose_random', help="", type=bool, required=False, default=False)
173 | parser.add_argument('--prune', help="", type=int, required=False, default=1)
174 | args = parser.parse_args()
175 |
176 |
177 | if args.config: # args priority is higher than yaml
178 | opt = OmegaConf.load(args.config)
179 | OmegaConf.resolve(opt)
180 |
181 | opt.update(vars(args))
182 | args = opt
183 |
184 | seed_everything(args.seed)
185 | print('seed: ', args.seed)
186 | print('choose_random: ', args.choose_random)
187 | args.finetune.n_epochs = args.ft_epochs
188 |
189 | # for resuming
190 | study_name = args.chkpt.replace('.pt', '').split('/')[-1]
191 | storage = args.chkpt.replace('.pt', '.db')
192 | study = optuna.create_study(direction = args.direction, study_name=study_name, storage=f'sqlite:///{storage}', load_if_exists=True)
193 | # for resuming
194 |
195 | study.optimize(trainModel, n_trials=args.n_trials, gc_after_trial=True)
196 | df_study = study.trials_dataframe(attrs=('number', 'value', 'params', 'state'))
197 | df_name = args.chkpt.replace('.pt', '.csv')
198 | df_study.to_csv(df_name)
199 |
200 | print("best params:")
201 | print(study.best_params)
202 |
203 |
--------------------------------------------------------------------------------
/fragnet/hp/hpray.py:
--------------------------------------------------------------------------------
1 |
2 | from torch_geometric.data import DataLoader
3 | from hyperopt import STATUS_OK
4 | from dataset import load_pickle_dataset
5 | from torch.utils.data import DataLoader
6 | from data import collate_fn
7 | import torch.nn as nn
8 | import os
9 | import torch
10 | from omegaconf import OmegaConf
11 | import argparse
12 | from utils import TrainerFineTune as Trainer
13 | import torch.optim.lr_scheduler as lr_scheduler
14 | from ray import tune
15 |
16 | RESULTS_PATH = './'
17 |
18 | def trainModel(params):
19 |
20 | exp_dir = args.exp_dir
21 | n_classes_pt = args.pretrain.n_classes
22 | device = torch.device('cuda') if torch.cuda.is_available() else 'cpu'
23 |
24 |
25 | pt_chkpoint_name = args.pretrain.chkpoint_name #'exps/pt_rings'
26 | from gat2 import FragNetFineTune
27 |
28 | model = FragNetFineTune(n_classes=args.finetune.model.n_classes,
29 | num_layer=params['num_layer'],
30 | drop_ratio=params['drop_ratio'])
31 | trainer = Trainer(target_type=args.finetune.target_type)
32 |
33 |
34 | if pt_chkpoint_name:
35 | model_pretrain = FragNetFineTune(n_classes_pt)
36 | model_pretrain.load_state_dict(torch.load(pt_chkpoint_name))
37 | state_dict_to_load={}
38 | for k,v in model.state_dict().items():
39 |
40 | if v.size() == model_pretrain.state_dict()[k].size():
41 | state_dict_to_load[k] = model_pretrain.state_dict()[k]
42 | else:
43 | state_dict_to_load[k] = v
44 |
45 | model.load_state_dict(state_dict_to_load)
46 |
47 |
48 | args.finetune.train.path = os.path.join(RESULTS_PATH, args.finetune.train.path)
49 | args.finetune.val.path = os.path.join(RESULTS_PATH, args.finetune.val.path)
50 | args.finetune.test.path = os.path.join(RESULTS_PATH, args.finetune.test.path)
51 |
52 | train_dataset2 = load_pickle_dataset(args.finetune.train.path, args.finetune.train.name)
53 | val_dataset2 = load_pickle_dataset(args.finetune.val.path, args.finetune.val.name)
54 | test_dataset2 = load_pickle_dataset(args.finetune.test.path, args.finetune.test.name) #'finetune_data/pnnl_exp'
55 |
56 | train_loader = DataLoader(train_dataset2, collate_fn=collate_fn, batch_size=params['batch_size'], shuffle=True, drop_last=True)
57 | val_loader = DataLoader(val_dataset2, collate_fn=collate_fn, batch_size=params['batch_size'], shuffle=False, drop_last=False)
58 | test_loader = DataLoader(test_dataset2, collate_fn=collate_fn, batch_size=params['batch_size'], shuffle=False, drop_last=False)
59 |
60 | model.to(device);
61 |
62 |
63 |
64 | if args.finetune.loss == 'mse':
65 | loss_fn = nn.MSELoss()
66 | elif args.finetune.loss == 'cel':
67 | loss_fn = nn.CrossEntropyLoss()
68 |
69 | optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4 ) # before 1e-4
70 | if args.finetune.use_schedular:
71 | scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
72 | else:
73 | scheduler=None
74 |
75 | for epoch in range(args.finetune.n_epochs):
76 |
77 | train_loss = trainer.train(model=model, loader=train_loader, optimizer=optimizer, scheduler=scheduler, loss_fn=loss_fn, device=device)
78 | val_loss, _, _ = trainer.test(model=model, loader=val_loader, device=device)
79 | print(train_loss, val_loss)
80 |
81 | try:
82 | mse, true, pred = trainer.test(model=model, loader=val_loader, device=device)
83 | return {'score':mse, 'status':STATUS_OK}
84 |
85 | except:
86 | return {'score':1000, 'status':STATUS_OK}
87 |
88 |
89 |
90 | parser = argparse.ArgumentParser()
91 | parser.add_argument('--config', help="configuration file *.yml", type=str, required=False, default='exps/ft/esol2/None/no_pt/config.yaml')
92 | args = parser.parse_args()
93 |
94 |
95 | if args.config: # args priority is higher than yaml
96 | opt = OmegaConf.load(args.config)
97 | OmegaConf.resolve(opt)
98 | args=opt
99 |
100 |
101 | args.finetune.n_epochs = 20
102 | search_space = {
103 | "num_layer": tune.choice([3,4,5,6,7,8]),
104 | "drop_ratio": tune.choice([0.1, 0.15, 0.2, 0.25, .3]),
105 | "batch_size": tune.choice([8, 16, 32, 64]),
106 | }
107 |
108 | tuner = tune.Tuner(trainModel, param_space=search_space,
109 | tune_config=tune.TuneConfig(num_samples=100))
110 | results = tuner.fit()
111 | print(results.get_best_result(metric="score", mode="min").config)
--------------------------------------------------------------------------------
/fragnet/model/cdrp/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pnnl/FragNet/54f99099a08ab991a5ba6845f7338492d03d4f1c/fragnet/model/cdrp/__init__.py
--------------------------------------------------------------------------------
/fragnet/model/cdrp/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class MLP(nn.Sequential):
7 | def __init__(self, gene_dim=903, device='cuda'):
8 | self.device = device
9 | input_dim_gene = gene_dim
10 | hidden_dim_gene = 256
11 | mlp_hidden_dims_gene = [1024, 256, 64]
12 | super(MLP, self).__init__()
13 | layer_size = len(mlp_hidden_dims_gene) + 1
14 | dims = [input_dim_gene] + mlp_hidden_dims_gene + [hidden_dim_gene]
15 | self.predictor = nn.ModuleList([nn.Linear(dims[i], dims[i + 1]) for i in range(layer_size)])
16 |
17 | def forward(self, v):
18 | # predict
19 | v = v.float().to(self.device)
20 | for i, l in enumerate(self.predictor):
21 | v = F.relu(l(v))
22 | return v
23 |
24 |
25 | class CDRPModel(nn.Module):
26 | def __init__(self, drug_model, gene_dim, device):
27 | super().__init__()
28 | self.drug_model = drug_model # FragNetFineTune()
29 | self.fc1 = nn.Linear(256+256, 128)
30 | self.fc2 = nn.Linear(128, 1)
31 | self.cell_model = MLP(gene_dim, device)
32 |
33 |
34 | def forward(self, batch):
35 |
36 | drug_enc = self.drug_model(batch)
37 | gene_expr = batch['gene_expr']
38 | cell_enc = self.cell_model(gene_expr)
39 |
40 | cat = torch.cat((drug_enc, cell_enc), 1)
41 | out = self.fc1(cat)
42 | out = self.fc2(out)
43 | return out
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
--------------------------------------------------------------------------------
/fragnet/model/dta/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pnnl/FragNet/54f99099a08ab991a5ba6845f7338492d03d4f1c/fragnet/model/dta/__init__.py
--------------------------------------------------------------------------------
/fragnet/model/dta/drug_encoder.py:
--------------------------------------------------------------------------------
1 | """
2 | This script was copied from
3 | https://github.com/jianglikun/DeepTTC/blob/main/model_helper.py
4 | and modified
5 | """
6 |
7 | import torch
8 | from torch import nn
9 | import torch.nn.functional as F
10 | import numpy as np
11 | import copy
12 | import math
13 |
14 | torch.manual_seed(1)
15 | np.random.seed(1)
16 |
17 | class LayerNorm(nn.Module):
18 | def __init__(self, hidden_size, variance_epsilon=1e-12):
19 |
20 | super(LayerNorm, self).__init__()
21 | self.gamma = nn.Parameter(torch.ones(hidden_size))
22 | self.beta = nn.Parameter(torch.zeros(hidden_size))
23 | self.variance_epsilon = variance_epsilon
24 |
25 | def forward(self, x):
26 | u = x.mean(-1, keepdim=True)
27 | s = (x - u).pow(2).mean(-1, keepdim=True)
28 | x = (x - u) / torch.sqrt(s + self.variance_epsilon)
29 | return self.gamma * x + self.beta
30 |
31 |
32 | class Embeddings(nn.Module):
33 | """Construct the embeddings from protein/target, position embeddings.
34 | """
35 | def __init__(self, vocab_size, hidden_size, max_position_size, dropout_rate):
36 | super(Embeddings, self).__init__()
37 | self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
38 | self.position_embeddings = nn.Embedding(max_position_size, hidden_size)
39 |
40 | self.LayerNorm = LayerNorm(hidden_size)
41 | self.dropout = nn.Dropout(dropout_rate)
42 |
43 | def forward(self, input_ids):
44 | seq_length = input_ids.size(1)
45 | position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
46 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
47 |
48 | words_embeddings = self.word_embeddings(input_ids)
49 | position_embeddings = self.position_embeddings(position_ids)
50 |
51 | embeddings = words_embeddings + position_embeddings
52 | embeddings = self.LayerNorm(embeddings)
53 | embeddings = self.dropout(embeddings)
54 | return embeddings
55 |
56 |
57 | class SelfAttention(nn.Module):
58 | def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob):
59 | super(SelfAttention, self).__init__()
60 | if hidden_size % num_attention_heads != 0:
61 | raise ValueError(
62 | "The hidden size (%d) is not a multiple of the number of attention "
63 | "heads (%d)" % (hidden_size, num_attention_heads))
64 | self.num_attention_heads = num_attention_heads
65 | self.attention_head_size = int(hidden_size / num_attention_heads)
66 | self.all_head_size = self.num_attention_heads * self.attention_head_size
67 |
68 | self.query = nn.Linear(hidden_size, self.all_head_size)
69 | self.key = nn.Linear(hidden_size, self.all_head_size)
70 | self.value = nn.Linear(hidden_size, self.all_head_size)
71 |
72 | self.dropout = nn.Dropout(attention_probs_dropout_prob)
73 |
74 | def transpose_for_scores(self, x):
75 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
76 | x = x.view(*new_x_shape)
77 | return x.permute(0, 2, 1, 3)
78 |
79 | def forward(self, hidden_states, attention_mask):
80 | mixed_query_layer = self.query(hidden_states)
81 | mixed_key_layer = self.key(hidden_states)
82 | mixed_value_layer = self.value(hidden_states)
83 |
84 | query_layer = self.transpose_for_scores(mixed_query_layer)
85 | key_layer = self.transpose_for_scores(mixed_key_layer)
86 | value_layer = self.transpose_for_scores(mixed_value_layer)
87 |
88 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
89 | attention_scores = attention_scores / math.sqrt(self.attention_head_size)
90 |
91 | attention_scores = attention_scores + attention_mask
92 |
93 | # Normalize the attention scores to probabilities.
94 | attention_probs = nn.Softmax(dim=-1)(attention_scores)
95 | attention_probs = self.dropout(attention_probs)
96 |
97 | context_layer = torch.matmul(attention_probs, value_layer)
98 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
99 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
100 | context_layer = context_layer.view(*new_context_layer_shape)
101 | return context_layer
102 |
103 |
104 | class SelfOutput(nn.Module):
105 | def __init__(self, hidden_size, hidden_dropout_prob):
106 | super(SelfOutput, self).__init__()
107 | self.dense = nn.Linear(hidden_size, hidden_size)
108 | self.LayerNorm = LayerNorm(hidden_size)
109 | self.dropout = nn.Dropout(hidden_dropout_prob)
110 |
111 | def forward(self, hidden_states, input_tensor):
112 | hidden_states = self.dense(hidden_states)
113 | hidden_states = self.dropout(hidden_states)
114 | hidden_states = self.LayerNorm(hidden_states + input_tensor)
115 | return hidden_states
116 |
117 |
118 | class Attention(nn.Module):
119 | def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob):
120 | super(Attention, self).__init__()
121 | self.self = SelfAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob)
122 | self.output = SelfOutput(hidden_size, hidden_dropout_prob)
123 |
124 | def forward(self, input_tensor, attention_mask):
125 | self_output = self.self(input_tensor, attention_mask)
126 | attention_output = self.output(self_output, input_tensor)
127 | return attention_output
128 |
129 | class Intermediate(nn.Module):
130 | def __init__(self, hidden_size, intermediate_size):
131 | super(Intermediate, self).__init__()
132 | self.dense = nn.Linear(hidden_size, intermediate_size)
133 |
134 | def forward(self, hidden_states):
135 | hidden_states = self.dense(hidden_states)
136 | hidden_states = F.relu(hidden_states)
137 | return hidden_states
138 |
139 | class Output(nn.Module):
140 | def __init__(self, intermediate_size, hidden_size, hidden_dropout_prob):
141 | super(Output, self).__init__()
142 | self.dense = nn.Linear(intermediate_size, hidden_size)
143 | self.LayerNorm = LayerNorm(hidden_size)
144 | self.dropout = nn.Dropout(hidden_dropout_prob)
145 |
146 | def forward(self, hidden_states, input_tensor):
147 | hidden_states = self.dense(hidden_states)
148 | hidden_states = self.dropout(hidden_states)
149 | hidden_states = self.LayerNorm(hidden_states + input_tensor)
150 | return hidden_states
151 |
152 | class Encoder(nn.Module):
153 | def __init__(self, hidden_size, intermediate_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob):
154 | super(Encoder, self).__init__()
155 | self.attention = Attention(hidden_size, num_attention_heads,
156 | attention_probs_dropout_prob, hidden_dropout_prob)
157 | self.intermediate = Intermediate(hidden_size, intermediate_size)
158 | self.output = Output(intermediate_size, hidden_size, hidden_dropout_prob)
159 |
160 | def forward(self, hidden_states, attention_mask):
161 | attention_output = self.attention(hidden_states, attention_mask)
162 | intermediate_output = self.intermediate(attention_output)
163 | layer_output = self.output(intermediate_output, attention_output)
164 | return layer_output
165 |
166 |
167 | class Encoder_MultipleLayers(nn.Module):
168 | def __init__(self, n_layer, hidden_size, intermediate_size,
169 | num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob):
170 | super(Encoder_MultipleLayers, self).__init__()
171 | layer = Encoder(hidden_size, intermediate_size, num_attention_heads,
172 | attention_probs_dropout_prob, hidden_dropout_prob)
173 | self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(n_layer)])
174 |
175 |     def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
176 |         # Only the final layer's hidden states are returned; intermediate layers are not collected.
177 |         for layer_module in self.layer:
178 |             hidden_states = layer_module(hidden_states, attention_mask)
179 |         return hidden_states
180 |
--------------------------------------------------------------------------------
/fragnet/model/dta/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import numpy as np
4 | from .drug_encoder import Encoder_MultipleLayers
5 |
6 | torch.manual_seed(1)
7 | np.random.seed(1)
8 |
9 | class LayerNorm(nn.Module):
10 | def __init__(self, hidden_size, variance_epsilon=1e-12):
11 |
12 | super(LayerNorm, self).__init__()
13 | self.gamma = nn.Parameter(torch.ones(hidden_size))
14 | self.beta = nn.Parameter(torch.zeros(hidden_size))
15 | self.variance_epsilon = variance_epsilon
16 |
17 | def forward(self, x):
18 | u = x.mean(-1, keepdim=True)
19 | s = (x - u).pow(2).mean(-1, keepdim=True)
20 | x = (x - u) / torch.sqrt(s + self.variance_epsilon)
21 | return self.gamma * x + self.beta
22 |
23 |
24 | class Embeddings(nn.Module):
25 | """Construct the embeddings from protein/target, position embeddings.
26 | """
27 | def __init__(self, vocab_size, hidden_size, max_position_size, dropout_rate):
28 | super(Embeddings, self).__init__()
29 | self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
30 | self.position_embeddings = nn.Embedding(max_position_size, hidden_size)
31 |
32 | self.LayerNorm = LayerNorm(hidden_size)
33 | self.dropout = nn.Dropout(dropout_rate)
34 |
35 | def forward(self, input_ids):
36 | seq_length = input_ids.size(1)
37 | position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
38 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
39 |
40 | words_embeddings = self.word_embeddings(input_ids)
41 | position_embeddings = self.position_embeddings(position_ids)
42 |
43 | embeddings = words_embeddings + position_embeddings
44 | embeddings = self.LayerNorm(embeddings)
45 | embeddings = self.dropout(embeddings)
46 | return embeddings
47 |
48 |
49 |
50 | class transformer(nn.Sequential):
51 | def __init__(self):
52 | super(transformer, self).__init__()
53 | input_dim_drug = 25
54 | transformer_emb_size_drug = 128
55 | transformer_dropout_rate = 0.1
56 | transformer_n_layer_drug = 8
57 | transformer_intermediate_size_drug = 512
58 | transformer_num_attention_heads_drug = 8
59 | transformer_attention_probs_dropout = 0.1
60 | transformer_hidden_dropout_rate = 0.1
61 | max_position_size = 1000
62 | self.emb = Embeddings(input_dim_drug,
63 | transformer_emb_size_drug,
64 | max_position_size,
65 | transformer_dropout_rate)
66 |
67 | self.encoder = Encoder_MultipleLayers(transformer_n_layer_drug,
68 | transformer_emb_size_drug,
69 | transformer_intermediate_size_drug,
70 | transformer_num_attention_heads_drug,
71 | transformer_attention_probs_dropout,
72 | transformer_hidden_dropout_rate)
73 | def forward(self, e, e_mask):
74 | ex_e_mask = e_mask.unsqueeze(1).unsqueeze(2)
75 | ex_e_mask = (1.0 - ex_e_mask) * -10000.0
76 |
77 | emb = self.emb(e)
78 | encoded_layers = self.encoder(emb.float(), ex_e_mask.float())
79 | return encoded_layers[:, 0]
80 |
81 |
82 |
83 | class DTAModel(nn.Module):
84 | def __init__(self, drug_model):
85 | super().__init__()
86 | self.drug_model = drug_model # FragNetFineTune()
87 | self.target_model = transformer()
88 | self.fc1 = nn.Linear(256+128, 128)
89 | self.fc2 = nn.Linear(128, 1)
90 |
91 | def forward(self, batch):
92 |
93 | drug_enc = self.drug_model(batch)
94 |
95 | tokens = batch['protein']
96 | padding_mask = ~tokens.eq(0)*1
97 |
98 | target_enc = self.target_model(tokens, padding_mask)
99 | cat = torch.cat((drug_enc, target_enc), 1)
100 |
101 | out = self.fc1(cat)
102 | out = self.fc2(out)
103 |
104 | return out
105 |
106 |
107 | class DTAModel2(nn.Module):
108 | def __init__(self, drug_model):
109 | super().__init__()
110 | self.drug_model = drug_model # FragNetFineTune()
111 | self.fc1 = nn.Linear(256+300, 128)
112 | self.fc2 = nn.Linear(128, 1)
113 |
114 | num_features = 25
115 | prot_emb_dim = 300
116 | self.in_channels = 1000
117 | n_filters = 32
118 | kernel_size = 8
119 | prot_output_dim=300
120 |
121 | self.embedding_xt = nn.Embedding(num_features + 1, prot_emb_dim)
122 | self.conv_xt_1 = nn.Conv1d(in_channels=self.in_channels, out_channels=n_filters, kernel_size=kernel_size)
123 | intermediate_dim = prot_emb_dim - kernel_size + 1
124 | self.fc1_xt_dim = n_filters*intermediate_dim
125 | self.fc1_xt = nn.Linear(self.fc1_xt_dim, prot_output_dim)
126 |
127 | def forward(self, batch):
128 |
129 | drug_enc = self.drug_model(batch)
130 |
131 | tokens = batch['protein']
132 | tokens = tokens.reshape(-1, self.in_channels)
133 |
134 | embedded_xt = self.embedding_xt(tokens)
135 | conv_xt = self.conv_xt_1(embedded_xt)
136 | # flatten
137 | xt = conv_xt.view(-1, self.fc1_xt_dim)
138 | xt = self.fc1_xt(xt)
139 |
140 |
141 | cat = torch.cat((drug_enc, xt), 1)
142 |
143 | out = self.fc1(cat)
144 | out = self.fc2(out)
145 |
146 | return out
147 |
148 |
149 |
150 |
--------------------------------------------------------------------------------
/fragnet/model/gat/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pnnl/FragNet/54f99099a08ab991a5ba6845f7338492d03d4f1c/fragnet/model/gat/__init__.py
--------------------------------------------------------------------------------
/fragnet/model/gat/extra_optimizers.py:
--------------------------------------------------------------------------------
1 | """
2 | This code was copied from https://github.com/LARS-research/3D-PGT/blob/main/graphgps/optimizer/extra_optimizers.py
3 | and modified
4 | """
5 |
6 | import math
7 | from typing import Iterator
8 | import torch.optim as optim
9 | from torch.nn import Parameter
10 | from torch.optim import Adagrad, AdamW, Optimizer
11 | from torch.optim.lr_scheduler import ReduceLROnPlateau
12 |
13 | def adagrad_optimizer(params: Iterator[Parameter], base_lr: float,
14 | weight_decay: float) -> Adagrad:
15 | return Adagrad(params, lr=base_lr, weight_decay=weight_decay)
16 |
17 | def adamW_optimizer(params: Iterator[Parameter], base_lr: float,
18 | weight_decay: float) -> AdamW:
19 | return AdamW(params, lr=base_lr, weight_decay=weight_decay)
20 |
21 |
22 | def plateau_scheduler(optimizer: Optimizer, patience: int,
23 | lr_decay: float) -> ReduceLROnPlateau:
24 | return ReduceLROnPlateau(optimizer, patience=patience, factor=lr_decay)
25 |
26 |
27 | def scheduler_reduce_on_plateau(optimizer: Optimizer, reduce_factor: float,
28 | schedule_patience: int, min_lr: float,
29 | train_mode: str, eval_period: int):
30 |
31 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(
32 | optimizer=optimizer,
33 | mode='min',
34 | factor=reduce_factor,
35 | patience=schedule_patience,
36 | min_lr=min_lr,
37 | verbose=True
38 | )
39 | if not hasattr(scheduler, 'get_last_lr'):
40 | # ReduceLROnPlateau doesn't have `get_last_lr` method as of current
41 | # pytorch1.10; we add it here for consistency with other schedulers.
42 | def get_last_lr(self):
43 | """ Return last computed learning rate by current scheduler.
44 | """
45 | return self._last_lr
46 |
47 | scheduler.get_last_lr = get_last_lr.__get__(scheduler)
48 | scheduler._last_lr = [group['lr']
49 | for group in scheduler.optimizer.param_groups]
50 |
51 | def modified_state_dict(ref):
52 | """Returns the state of the scheduler as a :class:`dict`.
53 | Additionally modified to ignore 'get_last_lr', 'state_dict'.
54 | Including these entries in the state dict would cause issues when
55 | loading a partially trained / pretrained model from a checkpoint.
56 | """
57 | return {key: value for key, value in ref.__dict__.items()
58 | if key not in ['sparsifier', 'get_last_lr', 'state_dict']}
59 |
60 | scheduler.state_dict = modified_state_dict.__get__(scheduler)
61 |
62 | return scheduler
63 |
64 |
65 | # @register.register_scheduler('linear_with_warmup')
66 | def linear_with_warmup_scheduler(optimizer: Optimizer,
67 | num_warmup_epochs: int, max_epoch: int):
68 | scheduler = get_linear_schedule_with_warmup(
69 | optimizer=optimizer,
70 | num_warmup_steps=num_warmup_epochs,
71 | num_training_steps=max_epoch
72 | )
73 | return scheduler
74 |
75 |
76 | # @register.register_scheduler('cosine_with_warmup')
77 | def cosine_with_warmup_scheduler(optimizer: Optimizer,
78 | num_warmup_epochs: int, max_epoch: int):
79 | scheduler = get_cosine_schedule_with_warmup(
80 | optimizer=optimizer,
81 | num_warmup_steps=num_warmup_epochs,
82 | num_training_steps=max_epoch
83 | )
84 | return scheduler
85 |
86 |
87 | def get_linear_schedule_with_warmup(
88 | optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int,
89 | last_epoch: int = -1):
90 | """
91 | Implementation by Huggingface:
92 | https://github.com/huggingface/transformers/blob/v4.16.2/src/transformers/optimization.py
93 |
94 | Create a schedule with a learning rate that decreases linearly from the
95 | initial lr set in the optimizer to 0, after a warmup period during which it
96 | increases linearly from 0 to the initial lr set in the optimizer.
97 | Args:
98 | optimizer ([`~torch.optim.Optimizer`]):
99 | The optimizer for which to schedule the learning rate.
100 | num_warmup_steps (`int`):
101 | The number of steps for the warmup phase.
102 | num_training_steps (`int`):
103 | The total number of training steps.
104 | last_epoch (`int`, *optional*, defaults to -1):
105 | The index of the last epoch when resuming training.
106 | Return:
107 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
108 | """
109 |
110 | def lr_lambda(current_step: int):
111 | if current_step < num_warmup_steps:
112 | return max(1e-6, float(current_step) / float(max(1, num_warmup_steps)))
113 | return max(
114 | 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
115 | )
116 |
117 | return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch)
118 |
119 |
120 | def get_cosine_schedule_with_warmup(
121 | optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int,
122 | num_cycles: float = 0.5, last_epoch: int = -1):
123 | """
124 | Implementation by Huggingface:
125 | https://github.com/huggingface/transformers/blob/v4.16.2/src/transformers/optimization.py
126 |
127 | Create a schedule with a learning rate that decreases following the values
128 | of the cosine function between the initial lr set in the optimizer to 0,
129 | after a warmup period during which it increases linearly between 0 and the
130 | initial lr set in the optimizer.
131 | Args:
132 | optimizer ([`~torch.optim.Optimizer`]):
133 | The optimizer for which to schedule the learning rate.
134 | num_warmup_steps (`int`):
135 | The number of steps for the warmup phase.
136 | num_training_steps (`int`):
137 | The total number of training steps.
138 | num_cycles (`float`, *optional*, defaults to 0.5):
139 | The number of waves in the cosine schedule (the defaults is to just
140 | decrease from the max value to 0 following a half-cosine).
141 | last_epoch (`int`, *optional*, defaults to -1):
142 | The index of the last epoch when resuming training.
143 | Return:
144 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
145 | """
146 |
147 | def lr_lambda(current_step):
148 | if current_step < num_warmup_steps:
149 | return max(1e-6, float(current_step) / float(max(1, num_warmup_steps)))
150 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
151 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
152 |
153 | return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch)
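154 | 
155 | # Minimal usage sketch for the warmup schedulers above (illustrative only; the toy
156 | # Linear model and the step counts are placeholders):
157 | #
158 | #   model = torch.nn.Linear(8, 1)
159 | #   optimizer = optim.Adam(model.parameters(), lr=1e-3)
160 | #   scheduler = get_cosine_schedule_with_warmup(
161 | #       optimizer, num_warmup_steps=10, num_training_steps=100)
162 | #   for step in range(100):
163 | #       optimizer.step()      # learning rate rises linearly for 10 steps,
164 | #       scheduler.step()      # then decays along a half-cosine towards 0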
--------------------------------------------------------------------------------
/fragnet/model/gat/gat2_cv.py:
--------------------------------------------------------------------------------
1 | # from gat import FragNetFineTune
2 | import torch
3 | from dataset import load_pickle_dataset
4 | import torch.nn as nn
5 | from utils import EarlyStopping
6 | import torch
7 | from data import collate_fn
8 | from torch.utils.data import DataLoader
9 | import argparse
10 | from utils import TrainerFineTune as Trainer
11 | import numpy as np
12 | from omegaconf import OmegaConf
13 | import pickle
14 | import torch.optim.lr_scheduler as lr_scheduler
15 | from sklearn.model_selection import KFold
16 | from gat2 import FragNetFineTune
17 | import os
18 |
19 |
20 | def seed_everything(seed: int):
21 | import random, os
22 |
23 | random.seed(seed)
24 | os.environ['PYTHONHASHSEED'] = str(seed)
25 | np.random.seed(seed)
26 | torch.manual_seed(seed)
27 | torch.cuda.manual_seed(seed)
28 | torch.backends.cudnn.deterministic = True
29 |     torch.backends.cudnn.benchmark = False  # autotuning off so the deterministic flag above holds
30 |
31 | seed_everything(26)
32 |
33 | def get_predictions(trainer, loader, model, device):
34 | mse, true, pred = trainer.test(model=model, loader=loader, device=device)
35 | smiles = [i.smiles for i in loader.dataset]
36 | res = {'smiles': smiles, 'true': true, 'pred': pred}
37 | return res
38 |
39 | def save_predictions(exp_dir, save_name, res):
40 |
41 | with open(f"{exp_dir}/{save_name}.pkl", 'wb') as f:
42 | pickle.dump(res,f )
43 |
44 |
45 |
46 | if __name__ == "__main__":
47 |
48 | parser = argparse.ArgumentParser()
49 | parser.add_argument('--config', help="configuration file *.yml", type=str, required=False, default='config.yaml')
50 | args = parser.parse_args()
51 |
52 |     if args.config:  # replace the CLI namespace with the resolved YAML config
53 | opt = OmegaConf.load(args.config)
54 | OmegaConf.resolve(opt)
55 | args=opt
56 |
57 | seed_everything(args.seed)
58 | exp_dir = args['exp_dir']
59 | os.makedirs(exp_dir, exist_ok=True)
60 | device = torch.device('cuda') if torch.cuda.is_available() else 'cpu'
61 |
62 | pt_chkpoint_name = args.pretrain.chkpoint_name #'exps/pt_rings'
63 |
64 | if args.model_version == 'gat':
65 | from gat import FragNetFineTune
66 | print('loaded from gat')
67 | elif args.model_version=='gat2':
68 | from gat2 import FragNetFineTune
69 | print('loaded from gat2')
70 |
71 |
72 | model = FragNetFineTune(n_classes=args.finetune.model.n_classes,
73 | atom_features=args.atom_features,
74 | frag_features=args.frag_features,
75 | edge_features=args.edge_features,
76 | num_layer=args.finetune.model.num_layer,
77 | drop_ratio=args.finetune.model.drop_ratio,
78 | num_heads=args.finetune.model.num_heads,
79 | emb_dim=args.finetune.model.emb_dim,
80 | h1=args.finetune.model.h1,
81 | h2=args.finetune.model.h2,
82 | h3=args.finetune.model.h3,
83 | h4=args.finetune.model.h4,
84 | act=args.finetune.model.act,
85 | fthead=args.finetune.model.fthead
86 | )
87 | trainer = Trainer(target_type=args.finetune.target_type)
88 |
89 |
90 | pt_chkpoint_name = args.pretrain.chkpoint_name #'exps/pt_rings'
91 | if pt_chkpoint_name:
92 |
93 |         from gat2_pretrain import FragNetPreTrain  # pretraining wrapper defined alongside gat2
94 | modelpt = FragNetPreTrain(num_layer=args.finetune.model.num_layer,
95 | drop_ratio=args.finetune.model.drop_ratio,
96 | num_heads=args.finetune.model.num_heads,
97 | emb_dim=args.finetune.model.emb_dim,
98 | atom_features=args.atom_features, frag_features=args.frag_features, edge_features=args.edge_features)
99 |
100 |
101 | modelpt.load_state_dict(torch.load(pt_chkpoint_name, map_location=torch.device(device)))
102 |
103 | print('loading pretrained weights')
104 | model.pretrain.load_state_dict(modelpt.pretrain.state_dict())
105 | print('weights loaded')
106 | else:
107 | print('no pretrained weights')
108 |
109 |
110 | train_dataset2 = load_pickle_dataset(args.finetune.train.path)
111 | val_dataset2 = load_pickle_dataset(args.finetune.val.path)
112 |
113 | kf = KFold(n_splits=5)
114 | train_val = train_dataset2 + val_dataset2
115 |
116 | for icv, (train_index, test_index) in enumerate(kf.split(train_val)):
117 | train_ds = [train_val[i] for i in train_index]
118 | val_ds = [train_val[i] for i in test_index]
119 |
120 | train_loader = DataLoader(train_ds, collate_fn=collate_fn, batch_size=args.finetune.batch_size, shuffle=True, drop_last=True)
121 | val_loader = DataLoader(val_ds, collate_fn=collate_fn, batch_size=args.finetune.batch_size, shuffle=False, drop_last=False)
122 | trainer = Trainer(target_type=args.finetune.target_type)
123 |
124 |
125 | model.to(device);
126 | ft_chk_point = os.path.join(args.exp_dir, f'cv_{icv}.pt')
127 | early_stopping = EarlyStopping(patience=args.finetune.es_patience, verbose=True, chkpoint_name=ft_chk_point)
128 |
129 |
130 | if args.finetune.loss == 'mse':
131 | loss_fn = nn.MSELoss()
132 | elif args.finetune.loss == 'cel':
133 | loss_fn = nn.CrossEntropyLoss()
134 |
135 | optimizer = torch.optim.Adam(model.parameters(), lr = args.finetune.lr ) # before 1e-4
136 | if args.finetune.use_schedular:
137 | scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
138 | else:
139 | scheduler=None
140 |
141 | for epoch in range(args.finetune.n_epochs):
142 |
143 | train_loss = trainer.train(model=model, loader=train_loader, optimizer=optimizer, scheduler=scheduler, device=device)
144 | val_loss, _, _ = trainer.test(model=model, loader=val_loader, device=device)
145 |
146 | print(train_loss, val_loss)
147 | early_stopping(val_loss, model)
148 |
149 | if early_stopping.early_stop:
150 | print("Early stopping")
151 | break
152 |
153 |
154 | model.load_state_dict(torch.load(ft_chk_point))
155 | res = get_predictions(trainer, val_loader, model, device)
156 |
157 | with open(f'{exp_dir}/cv_{icv}.pkl', 'wb') as f:
158 | pickle.dump(res, f)
159 |
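160 | # Invocation sketch (the config keys listed are exactly those read above; the file
161 | # name is a placeholder):
162 | #   python gat2_cv.py --config config.yaml
163 | # The YAML is expected to define seed, exp_dir, model_version ('gat' or 'gat2'),
164 | # pretrain.chkpoint_name, atom/frag/edge feature sizes, and the finetune.* block
165 | # (model hyperparameters, train/val paths, batch_size, lr, loss, use_schedular,
166 | # es_patience, n_epochs, target_type). Per-fold predictions go to <exp_dir>/cv_<i>.pkl.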
--------------------------------------------------------------------------------
/fragnet/model/gat/gat2_pl.py:
--------------------------------------------------------------------------------
1 | from pytorch_lightning.callbacks import LearningRateMonitor
2 | import pytorch_lightning as pl
3 | from pytorch_lightning import loggers
4 | import torch
5 | import torch.nn.functional as F
6 | from gat2 import FragNetFineTune
7 | from pytorch_lightning import LightningModule, Trainer
8 | from dataset import load_pickle_dataset
9 | from torch.utils.data import DataLoader
10 | import argparse
11 | from omegaconf import OmegaConf
12 | from torch_geometric.nn.norm import BatchNorm
13 | from data import collate_fn
14 | import math
15 | from pytorch_lightning.callbacks.early_stopping import EarlyStopping
16 | from pytorch_lightning.callbacks import ModelCheckpoint
17 |
18 | def get_cosine_schedule_with_warmup(
19 | optimizer, num_warmup_steps, num_training_steps,
20 | num_cycles = 0.5, last_epoch = -1):
21 | """
22 | Implementation by Huggingface:
23 | https://github.com/huggingface/transformers/blob/v4.16.2/src/transformers/optimization.py
24 |
25 | Create a schedule with a learning rate that decreases following the values
26 | of the cosine function between the initial lr set in the optimizer to 0,
27 | after a warmup period during which it increases linearly between 0 and the
28 | initial lr set in the optimizer.
29 | Args:
30 | optimizer ([`~torch.optim.Optimizer`]):
31 | The optimizer for which to schedule the learning rate.
32 | num_warmup_steps (`int`):
33 | The number of steps for the warmup phase.
34 | num_training_steps (`int`):
35 | The total number of training steps.
36 | num_cycles (`float`, *optional*, defaults to 0.5):
37 |             The number of waves in the cosine schedule (the default is to just
38 | decrease from the max value to 0 following a half-cosine).
39 | last_epoch (`int`, *optional*, defaults to -1):
40 | The index of the last epoch when resuming training.
41 | Return:
42 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
43 | """
44 |
45 | def lr_lambda(current_step):
46 | if current_step < num_warmup_steps:
47 | return max(1e-6, float(current_step) / float(max(1, num_warmup_steps)))
48 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
49 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
50 |
51 | return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch)
52 |
53 | class FragNetPL(pl.LightningModule):
54 | def __init__(self, args):
55 | super().__init__()
56 | self.save_hyperparameters()
57 |
58 | self.model = FragNetFineTune(n_classes=args.finetune.model.n_classes,
59 | atom_features=args.atom_features,
60 | frag_features=args.frag_features,
61 | edge_features=args.edge_features,
62 | num_layer=args.finetune.model.num_layer,
63 | drop_ratio=args.finetune.model.drop_ratio,
64 | num_heads=args.finetune.model.num_heads,
65 | emb_dim=args.finetune.model.emb_dim,
66 | h1=args.finetune.model.h1,
67 | h2=args.finetune.model.h2,
68 | h3=args.finetune.model.h3,
69 | h4=args.finetune.model.h4,
70 | act=args.finetune.model.act,
71 | fthead=args.finetune.model.fthead
72 | )
73 |
74 | self.args = args
75 |
76 | def forward(self, batch):
77 | return self.model(batch)
78 |
79 |     def common_step(self, batch):
80 |         # model output first, target second
81 |         y_pred = self(batch)
82 |         y_true = batch['y']
83 |         return y_pred.reshape_as(y_true), y_true
84 | 
85 | 
86 |     def training_step(self, batch, batch_idx):
87 |         y_pred, y_true = self.common_step(batch)
88 | 
89 |         loss = F.mse_loss(y_pred, y_true)
90 |         self.log('train_loss', loss, batch_size=self.args.finetune.batch_size)
91 |         return loss
92 | 
93 |     def validation_step(self, batch, batch_idx):
94 |         y_pred, y_true = self.common_step(batch)
95 |         val_loss = F.mse_loss(y_pred, y_true)
96 | 
97 |         self.log('val_loss', val_loss, batch_size=self.args.finetune.batch_size)
98 | 
99 |     def test_step(self, batch, batch_idx):
100 |         y_pred, y_true = self.common_step(batch)
101 | 
102 |         test_loss = F.mse_loss(y_pred, y_true)
103 |         self.log('test_rmse', test_loss**.5, batch_size=self.args.finetune.batch_size)
104 |
105 |
106 | def configure_optimizers(self):
107 | optimizer = torch.optim.Adam(self.parameters(), self.args.finetune.lr)
108 | return [optimizer]
109 |
110 |
111 |
112 | if __name__=="__main__":
113 |
114 | parser = argparse.ArgumentParser()
115 | parser.add_argument('--config', help="configuration file *.yml", type=str, required=False, default='config_exp1.yaml')
116 | args = parser.parse_args()
117 |
118 |
119 |     if args.config:  # replace the CLI namespace with the resolved YAML config
120 | opt = OmegaConf.load(args.config)
121 | OmegaConf.resolve(opt)
122 | args=opt
123 |
124 |
125 | pl.seed_everything(args.seed)
126 | model = FragNetPL(args)
127 |
128 |
129 | early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=0.00, patience=100, verbose=False, mode="min")
130 | checkpoint = ModelCheckpoint(monitor='val_loss', save_top_k=1, mode='min')
131 |
132 | trainer = Trainer(accelerator='auto', max_epochs=args.finetune.n_epochs, gradient_clip_val=1, devices=-1,
133 | log_every_n_steps=1, callbacks=[early_stop_callback, checkpoint])
134 |
135 |
136 | train_dataset2 = load_pickle_dataset(args.finetune.train.path)
137 | val_dataset2 = load_pickle_dataset(args.finetune.val.path)
138 | test_dataset2 = load_pickle_dataset(args.finetune.test.path) #'finetune_data/pnnl_exp'
139 |
140 | train_loader = DataLoader(train_dataset2, collate_fn=collate_fn, batch_size=args.finetune.batch_size, shuffle=True, drop_last=True)
141 | val_loader = DataLoader(val_dataset2, collate_fn=collate_fn, batch_size=64, shuffle=False, drop_last=False)
142 | test_loader = DataLoader(test_dataset2, collate_fn=collate_fn, batch_size=64, shuffle=False, drop_last=False)
143 |
144 | trainer.fit(model, train_loader, val_loader)
145 | trainer.test(model, test_loader)
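146 | 
147 | # Run sketch: python gat2_pl.py --config config_exp1.yaml
148 | # EarlyStopping monitors val_loss (patience=100), ModelCheckpoint keeps the best
149 | # epoch, and trainer.test reports test_rmse on the held-out test loader.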
--------------------------------------------------------------------------------
/fragnet/model/gat/gat2_pretrain.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | from .pretrain_heads import PretrainTask
4 | from .gat2 import FragNet
5 | import random
6 |
7 | class FragNetPreTrain(nn.Module):
8 |
9 | def __init__(self, num_layer=4, drop_ratio=0.15, num_heads=4, emb_dim=128,
10 | atom_features=167, frag_features=167, edge_features=16,
11 | fedge_in=6,
12 | fbond_edge_in=6):
13 | super(FragNetPreTrain, self).__init__()
14 |
15 | self.pretrain = FragNet(num_layer=num_layer, drop_ratio=drop_ratio, num_heads=num_heads, emb_dim=emb_dim,
16 | atom_features=atom_features, frag_features=frag_features, edge_features=edge_features,
17 | fedge_in=fedge_in,
18 | fbond_edge_in=fbond_edge_in,)
19 | self.head = PretrainTask(128, 1)
20 |
21 |
22 | def forward(self, batch):
23 |
24 | x_atoms, x_frags, e_edge, e_fedge = self.pretrain(batch)
25 | bond_length_pred, bond_angle_pred, dihedral_angle_pred, graph_rep = self.head(x_atoms, x_frags, e_edge, batch)
26 |
27 | return bond_length_pred, bond_angle_pred, dihedral_angle_pred, graph_rep
28 |
29 |
30 | class FragNetPreTrainMasked(nn.Module):
31 |
32 | def __init__(self, num_layer=4, drop_ratio=0.15, num_heads=4, emb_dim=128,
33 | atom_features=167, frag_features=167, edge_features=16,
34 | fedge_in=6, fbond_edge_in=6):
35 | super(FragNetPreTrainMasked, self).__init__()
36 |
37 | self.pretrain = FragNet(num_layer=num_layer, drop_ratio=drop_ratio, num_heads=num_heads, emb_dim=emb_dim,
38 | atom_features=atom_features, frag_features=frag_features, edge_features=edge_features,
39 | fedge_in=fedge_in, fbond_edge_in=fbond_edge_in)
40 | self.head = PretrainTask(128, 1)
41 |
42 |
43 | def forward(self, batch):
44 |
45 | x_atoms, x_frags, e_edge, e_fedge = self.pretrain(batch)
46 |
47 | with torch.no_grad():
48 | n_atoms = x_atoms.shape[0]
49 | unmask_atoms = random.sample(list(range(n_atoms)), int(n_atoms*.85) )
50 |             x_atoms_masked = torch.zeros_like(x_atoms)  # same shape/device/dtype as x_atoms
51 | x_atoms_masked[unmask_atoms] = x_atoms[unmask_atoms]
52 |
53 |
54 | bond_length_pred, bond_angle_pred, dihedral_angle_pred, graph_rep = self.head(x_atoms_masked, x_frags, e_edge, batch)
55 |
56 | return bond_length_pred, bond_angle_pred, dihedral_angle_pred, graph_rep
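57 | 
58 | # Output sketch (shapes assume the PretrainTask head above with dim_out=1):
59 | #   bond_length_pred    -> [num_edges, 1]   (one value per column of edge_index)
60 | #   bond_angle_pred     -> [num_atoms, 1]
61 | #   dihedral_angle_pred -> [num_edges, 1]   (predicted from the edge embeddings)
62 | #   graph_rep           -> [num_graphs, 1]  (pooled atom + fragment representation)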
--------------------------------------------------------------------------------
/fragnet/model/gat/pretrain_heads.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import random
4 | from torch_scatter import scatter_add
5 | from fragnet.model.gat.gat2 import FragNet
6 |
7 |
8 | class PretrainTask(nn.Module):
9 |     """
10 |     Multi-task pretraining head, adapted (with modifications) from
11 |     https://github.com/LARS-research/3D-PGT/blob/main/graphgps/head/pretrain_task.py
12 |     (the SAN prediction head for graph prediction tasks).
13 | 
14 |     Predicts bond lengths, bond angles, dihedral angles, and a graph-level energy.
15 | 
16 |     Args:
17 |         dim_in (int): Input dimension.
18 |         dim_out (int): Output dimension. For binary prediction, dim_out=1.
19 |         L (int): Number of hidden layers.
20 |     """
21 |
22 | def __init__(self, dim_in=128, dim_out=1, L=2):
23 | super().__init__()
24 |
25 | # bond_length
26 | self.bl_reduce_layer = nn.Linear(dim_in * 3, dim_in)
27 | list_bl_layers = [
28 | nn.Linear(dim_in // 2**l, dim_in // 2 ** (l + 1), bias=True)
29 | for l in range(L)
30 | ]
31 | list_bl_layers.append(nn.Linear(dim_in // 2**L, dim_out, bias=True))
32 | self.bl_layers = nn.ModuleList(list_bl_layers)
33 |
34 | # bond_angle
35 | list_ba_layers = [
36 | nn.Linear(dim_in // 2**l, dim_in // 2 ** (l + 1), bias=True)
37 | for l in range(L)
38 | ]
39 | list_ba_layers.append(nn.Linear(dim_in // 2**L, dim_out, bias=True))
40 | self.ba_layers = nn.ModuleList(list_ba_layers)
41 |
42 | # dihedral_angle
43 | list_da_layers = [
44 | nn.Linear(dim_in // 2**l, dim_in // 2 ** (l + 1), bias=True)
45 | for l in range(L)
46 | ]
47 | list_da_layers.append(nn.Linear(dim_in // 2**L, dim_out, bias=True))
48 | self.da_layers = nn.ModuleList(list_da_layers)
49 |
50 | # graph-level prediction (energy)
51 | list_FC_layers = [
52 | nn.Linear(dim_in * 2 // 2**l, dim_in * 2 // 2 ** (l + 1), bias=True)
53 | for l in range(L)
54 | ]
55 | list_FC_layers.append(nn.Linear(dim_in * 2 // 2**L, dim_out, bias=True))
56 | self.FC_layers = nn.ModuleList(list_FC_layers)
57 |
58 | self.L = L
59 | self.activation = nn.ReLU()
60 |
61 | def _apply_index(self, batch):
62 | return batch.bond_length, batch.distance
63 |
64 | def forward(self, x_atoms, x_frags, edge_attr, batch):
65 | edge_index = batch["edge_index"]
66 |
67 | bond_length_pred = torch.concat(
68 | (x_atoms[edge_index.T][:, 0, :], x_atoms[edge_index.T][:, 1, :], edge_attr),
69 | axis=1,
70 | )
71 | bond_length_pred = self.bl_reduce_layer(bond_length_pred)
72 | for l in range(self.L + 1):
73 | bond_length_pred = self.activation(bond_length_pred)
74 | bond_length_pred = self.bl_layers[l](bond_length_pred)
75 |
76 | # bond_angle
77 | bond_angle_pred = x_atoms
78 | for l in range(self.L):
79 | bond_angle_pred = self.ba_layers[l](bond_angle_pred)
80 | bond_angle_pred = self.activation(bond_angle_pred)
81 | bond_angle_pred = self.ba_layers[self.L](bond_angle_pred)
82 |
83 | # dihedral_angle
84 | dihedral_angle_pred = edge_attr
85 | for l in range(self.L):
86 | dihedral_angle_pred = self.da_layers[l](dihedral_angle_pred)
87 | dihedral_angle_pred = self.activation(dihedral_angle_pred)
88 | dihedral_angle_pred = self.da_layers[self.L](dihedral_angle_pred)
89 |
90 | # total energy
91 | # graph_rep = self.pooling_fun(batch.x, batch.batch)
92 |
93 | x_frags_pooled = scatter_add(src=x_frags, index=batch["frag_batch"], dim=0)
94 | x_atoms_pooled = scatter_add(src=x_atoms, index=batch["batch"], dim=0)
95 |
96 | graph_rep = torch.cat((x_atoms_pooled, x_frags_pooled), 1)
97 | for l in range(self.L):
98 | graph_rep = self.FC_layers[l](graph_rep)
99 | graph_rep = self.activation(graph_rep)
100 | graph_rep = self.FC_layers[self.L](graph_rep)
101 |
102 | return bond_length_pred, bond_angle_pred, dihedral_angle_pred, graph_rep
103 |
104 |
105 | class FragNetPreTrain(nn.Module):
106 |
107 | def __init__(
108 | self,
109 | num_layer=4,
110 | drop_ratio=0.15,
111 | num_heads=4,
112 | emb_dim=128,
113 | atom_features=167,
114 | frag_features=167,
115 | edge_features=16,
116 | fedge_in=6,
117 | fbond_edge_in=6,
118 | ):
119 | super(FragNetPreTrain, self).__init__()
120 |
121 | self.pretrain = FragNet(
122 | num_layer=num_layer,
123 | drop_ratio=drop_ratio,
124 | num_heads=num_heads,
125 | emb_dim=emb_dim,
126 | atom_features=atom_features,
127 | frag_features=frag_features,
128 | edge_features=edge_features,
129 | fedge_in=fedge_in,
130 | fbond_edge_in=fbond_edge_in,
131 | )
132 | self.head = PretrainTask(128, 1)
133 |
134 | def forward(self, batch):
135 |
136 | x_atoms, x_frags, e_edge, e_fedge = self.pretrain(batch)
137 | bond_length_pred, bond_angle_pred, dihedral_angle_pred, graph_rep = self.head(
138 | x_atoms, x_frags, e_edge, batch
139 | )
140 |
141 | return bond_length_pred, bond_angle_pred, dihedral_angle_pred, graph_rep
142 |
143 |
144 | class FragNetPreTrainMasked(nn.Module):
145 |
146 | def __init__(
147 | self,
148 | num_layer=4,
149 | drop_ratio=0.15,
150 | num_heads=4,
151 | emb_dim=128,
152 | atom_features=167,
153 | frag_features=167,
154 | edge_features=16,
155 | fedge_in=6,
156 | fbond_edge_in=6,
157 | ):
158 | super(FragNetPreTrainMasked, self).__init__()
159 |
160 | # self.pretrain = FragNet(num_layer=num_layer, drop_ratio=drop_ratio)
161 | self.pretrain = FragNet(
162 | num_layer=num_layer,
163 | drop_ratio=drop_ratio,
164 | num_heads=num_heads,
165 | emb_dim=emb_dim,
166 | atom_features=atom_features,
167 | frag_features=frag_features,
168 | edge_features=edge_features,
169 | fedge_in=fedge_in,
170 | fbond_edge_in=fbond_edge_in,
171 | )
172 | self.head = PretrainTask(128, 1)
173 |
174 | def forward(self, batch):
175 |
176 | x_atoms, x_frags, e_edge, e_fedge = self.pretrain(batch)
177 |
178 |         with torch.no_grad():
179 |             n_atoms = x_atoms.shape[0]
180 |             # keep a random 85% of the atom embeddings; zero out (mask) the rest
181 |             unmask_atoms = random.sample(list(range(n_atoms)), int(n_atoms * 0.85))
182 |             x_atoms_masked = torch.zeros_like(x_atoms)
183 |             x_atoms_masked[unmask_atoms] = x_atoms[unmask_atoms]
184 | 
185 |         return self.head(x_atoms_masked, x_frags, e_edge, batch)
186 | 
187 | class FragNetPreTrainMasked2(nn.Module):
188 |
189 | def __init__(
190 | self,
191 | num_layer=4,
192 | drop_ratio=0.15,
193 | num_heads=4,
194 | emb_dim=128,
195 | atom_features=167,
196 | frag_features=167,
197 | edge_features=16,
198 | fedge_in=6,
199 | fbond_edge_in=6,
200 | ):
201 | super(FragNetPreTrainMasked2, self).__init__()
202 |
203 | # self.pretrain = FragNet(num_layer=num_layer, drop_ratio=drop_ratio)
204 | self.pretrain = FragNet(
205 | num_layer=num_layer,
206 | drop_ratio=drop_ratio,
207 | num_heads=num_heads,
208 | emb_dim=emb_dim,
209 | atom_features=atom_features,
210 | frag_features=frag_features,
211 | edge_features=edge_features,
212 | fedge_in=fedge_in,
213 | fbond_edge_in=fbond_edge_in,
214 | )
215 | self.head = PretrainTask(128, 1)
216 |
217 | def forward(self, batch):
218 |
219 | with torch.no_grad():
220 | x_atoms = batch["x_atoms"]
221 | n_atoms = x_atoms.shape[0]
222 |
223 | unmask_atoms = random.sample(list(range(n_atoms)), int(n_atoms * 0.85))
224 | x_atoms_masked = torch.zeros(x_atoms.size()) + 0.0
225 | x_atoms_masked = x_atoms_masked.to(x_atoms.device)
226 | x_atoms_masked[unmask_atoms] = x_atoms[unmask_atoms]
227 |
228 | batch["x_atoms"] = x_atoms_masked
229 |
230 | x_atoms, x_frags, e_edge, e_fedge = self.pretrain(batch)
231 |
232 | bond_length_pred, bond_angle_pred, dihedral_angle_pred, graph_rep = self.head(
233 | x_atoms, x_frags, e_edge, batch
234 | )
235 |
236 | return bond_length_pred, bond_angle_pred, dihedral_angle_pred, graph_rep
237 |
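238 | # Width sketch for PretrainTask with the defaults used above (dim_in=128, dim_out=1, L=2):
239 | #   bond length : concat(atom_i, atom_j, edge)  384 -> 128 -> 64 -> 32 -> 1
240 | #   bond angle / dihedral angle :               128 -> 64 -> 32 -> 1
241 | #   graph energy: concat(atom pool, frag pool)  256 -> 128 -> 64 -> 1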
--------------------------------------------------------------------------------
/fragnet/model/gcn/gcn.py:
--------------------------------------------------------------------------------
1 | from torch.nn import Parameter
2 | import torch
3 | import torch.nn as nn
4 | from torch_geometric.utils import add_self_loops, degree
5 | from torch_scatter import scatter_add, scatter_softmax
6 | 
7 | 
8 | 
9 |
10 |
11 | class FragNetLayer(nn.Module):
12 | def __init__(self, atom_in=128, atom_out=128, frag_in=128, frag_out=128,
13 | edge_in=128, edge_out=128):
14 | super(FragNetLayer, self).__init__()
15 |
16 |
17 | self.atom_embed = nn.Linear(atom_in, atom_out, bias=True)
18 | self.frag_embed = nn.Linear(frag_in, frag_out)
19 | self.edge_embed = nn.Linear(edge_in, edge_out)
20 |
21 | self.frag_message_mlp = nn.Linear(atom_out*2, atom_out)
22 | self.atom_mlp = torch.nn.Sequential(torch.nn.Linear(atom_out, 2*atom_out),
23 | torch.nn.ReLU(),
24 | torch.nn.Linear(2*atom_out, atom_out))
25 |
26 | self.frag_mlp = torch.nn.Sequential(torch.nn.Linear(atom_out, 2*atom_out),
27 | torch.nn.ReLU(),
28 | torch.nn.Linear(2*atom_out, atom_out))
29 |
30 | def forward(self, x_atoms,
31 | edge_index,
32 | edge_attr,
33 | frag_index,
34 | x_frags,
35 | atom_to_frag_ids):
36 |
37 |
38 | edge_index, _ = add_self_loops(edge_index=edge_index)
39 |
40 | self_loop_attr = torch.zeros(x_atoms.size(0), 12, dtype=torch.long)
41 | self_loop_attr[:,0] = 0 # bond type for self-loop edge
42 | edge_attr = torch.cat((edge_attr, self_loop_attr.to(edge_attr)), dim=0)
43 |
44 |
45 | x_atoms = self.atom_embed(x_atoms)
46 |
47 | edge_attr = self.edge_embed(edge_attr)
48 |
49 | source, target = edge_index
50 |
51 | source_features = torch.index_select(input=x_atoms, index=source, dim=0)
52 |
53 |
54 | deg = degree(source, x_atoms.size(0), dtype=x_atoms.dtype)
55 | deg_inv_sqrt = deg.pow(-0.5)
56 | deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
57 | norm = deg_inv_sqrt[source] * deg_inv_sqrt[target]
58 |
59 | message = source_features*norm.view(-1,1)
60 |
61 | x_atoms_new = scatter_add(src=message, index=target, dim=0)
62 |
63 | x_frags = scatter_add(src=x_atoms_new, index=atom_to_frag_ids, dim=0)
64 |
65 | source, target = frag_index
66 | source_features = torch.index_select(input=x_frags, index=source, dim=0)
67 | frag_message = source_features
68 | frag_feats_sum = scatter_add(src=frag_message, index=target, dim=0)
69 | frag_feats_sum = self.frag_mlp(frag_feats_sum)
70 |
71 | x_frags_new = frag_feats_sum
72 | return x_atoms_new, x_frags_new
73 |
74 |
75 |
76 |
77 | class FragNet(nn.Module):
78 |
79 | def __init__(self, num_layer, drop_ratio = 0, emb_dim=128,
80 | atom_features=45, frag_features=45, edge_features=12):
81 | super(FragNet, self).__init__()
82 | self.num_layer = num_layer
83 | self.dropout = nn.Dropout(p=drop_ratio)
84 | self.act = nn.ReLU()
85 |
86 | self.layer1 = FragNetLayer(atom_in=atom_features, atom_out=emb_dim, frag_in=frag_features,
87 | frag_out=emb_dim, edge_in=edge_features, edge_out=emb_dim)
88 | self.layer2 = FragNetLayer(atom_in=emb_dim, atom_out=emb_dim, frag_in=emb_dim,
89 | frag_out=emb_dim, edge_in=edge_features, edge_out=emb_dim)
90 | self.layer3 = FragNetLayer(atom_in=emb_dim, atom_out=emb_dim, frag_in=emb_dim,
91 | frag_out=emb_dim, edge_in=edge_features, edge_out=emb_dim)
92 | self.layer4 = FragNetLayer(atom_in=emb_dim, atom_out=emb_dim, frag_in=emb_dim,
93 | frag_out=emb_dim, edge_in=edge_features, edge_out=emb_dim)
94 |
95 | ###List of batchnorms
96 | self.batch_norms = torch.nn.ModuleList()
97 | for layer in range(num_layer):
98 | self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))
99 |
100 | def forward(self, batch):
101 |
102 |
103 | x_atoms = batch['x_atoms']
104 | edge_index = batch['edge_index']
105 | frag_index = batch['frag_index']
106 |
107 | x_frags = batch['x_frags']
108 | edge_attr = batch['edge_attr']
109 | atom_batch = batch['batch']
110 | frag_batch = batch['frag_batch']
111 | atom_to_frag_ids = batch['atom_to_frag_ids']
112 |
113 |         node_features_bond_graph = batch['node_features_bonds']
114 | edge_index_bonds_graph=batch['edge_index_bonds_graph']
115 | edge_attr_bond_graph=batch['edge_attr_bonds']
116 |
117 |
118 | x_atoms = self.dropout(x_atoms)
119 | x_frags = self.dropout(x_frags)
120 |
121 | x_atoms, x_frags = self.layer1(x_atoms, edge_index, edge_attr,
122 | frag_index, x_frags, atom_to_frag_ids)
123 | x_atoms, x_frags = self.act(x_atoms), self.act(x_frags)
124 |
125 | x_atoms, x_frags = self.layer2(x_atoms, edge_index, edge_attr,
126 | frag_index, x_frags, atom_to_frag_ids)
127 | x_atoms, x_frags = self.act(x_atoms), self.act(x_frags)
128 |
129 | x_atoms, x_frags = self.layer3(x_atoms, edge_index, edge_attr,
130 | frag_index, x_frags, atom_to_frag_ids)
131 | x_atoms, x_frags = self.act(x_atoms), self.act(x_frags)
132 |
133 | x_atoms, x_frags = self.layer4(x_atoms, edge_index, edge_attr,
134 | frag_index, x_frags, atom_to_frag_ids)
135 | x_atoms, x_frags = self.act(x_atoms), self.act(x_frags)
136 |
137 |
138 | return x_atoms, x_frags
139 |
140 |
141 | class FragNetPreTrain(nn.Module):
142 |
143 | def __init__(self):
144 | super(FragNetPreTrain, self).__init__()
145 |
146 | self.pretrain = FragNet(num_layer=6, drop_ratio=0.15)
147 | self.lin1 = nn.Linear(128, 128)
148 | self.out = nn.Linear(128, 13)
149 | self.dropout = nn.Dropout(p=0.15)
150 | self.activation = nn.ReLU()
151 |
152 | def forward(self, batch):
153 |
154 | x_atoms, x_frags = self.pretrain(batch)
155 |
156 | x = self.dropout(x_atoms)
157 | x = self.lin1(x)
158 | x = self.activation(x)
159 | x = self.dropout(x)
160 | x = self.out(x)
161 |
162 |
163 | return x
164 |
165 | class FragNetFineTune(nn.Module):
166 |
167 | def __init__(self, n_classes=1):
168 | super(FragNetFineTune, self).__init__()
169 |
170 | self.pretrain = FragNet(num_layer=6, drop_ratio=0.15)
171 | self.lin1 = nn.Linear(128*2, 128*2)
172 |
173 | self.dropout = nn.Dropout(p=0.15)
174 | self.activation = nn.ReLU()
175 | self.out = nn.Linear(128*2, n_classes)
176 |
177 |
178 | def forward(self, batch):
179 |
180 | x_atoms, x_frags = self.pretrain(batch)
181 |
182 | x_frags_pooled = scatter_add(src=x_frags, index=batch['frag_batch'], dim=0)
183 | x_atoms_pooled = scatter_add(src=x_atoms, index=batch['batch'], dim=0)
184 |
185 | cat = torch.cat((x_atoms_pooled, x_frags_pooled), 1)
186 | x = self.dropout(cat)
187 | x = self.lin1(x)
188 | x = self.activation(x)
189 | x = self.dropout(x)
190 | x = self.out(x)
191 |
192 |
193 | return x
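194 | 
195 | # Batch contract sketch (keys read in FragNet.forward above): x_atoms, edge_index,
196 | # edge_attr, frag_index, x_frags, atom_to_frag_ids, batch and frag_batch, plus the
197 | # bond-graph tensors (node_features_bonds, edge_index_bonds_graph, edge_attr_bonds),
198 | # which this GCN variant unpacks but does not use.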
--------------------------------------------------------------------------------
/fragnet/model/gcn/gcn2.py:
--------------------------------------------------------------------------------
1 | from torch.nn import Parameter
2 | import torch
3 | import torch.nn as nn
4 | from torch_geometric.utils import add_self_loops, degree
5 | from torch_scatter import scatter_add, scatter_softmax
6 | from gat2 import FTHead3, FTHead4
7 | 
8 | 
9 | 
10 |
11 | class FragNetLayer(nn.Module):
12 | def __init__(self, atom_in=128, atom_out=128, frag_in=128, frag_out=128,
13 | edge_in=128, edge_out=128):
14 | super(FragNetLayer, self).__init__()
15 |
16 |
17 | self.atom_embed = nn.Linear(atom_in, atom_out, bias=True)
18 | self.frag_embed = nn.Linear(frag_in, frag_out)
19 | self.edge_embed = nn.Linear(edge_in, edge_out)
20 |
21 | self.frag_message_mlp = nn.Linear(atom_out*2, atom_out)
22 | self.atom_mlp = torch.nn.Sequential(torch.nn.Linear(atom_out, 2*atom_out),
23 | torch.nn.ReLU(),
24 | torch.nn.Linear(2*atom_out, atom_out))
25 |
26 | self.frag_mlp = torch.nn.Sequential(torch.nn.Linear(atom_out, 2*atom_out),
27 | torch.nn.ReLU(),
28 | torch.nn.Linear(2*atom_out, atom_out))
29 |
30 |
31 | def forward(self, x_atoms,
32 | edge_index,
33 | edge_attr,
34 | frag_index,
35 | x_frags,
36 | atom_to_frag_ids):
37 |
38 |
39 | edge_index, _ = add_self_loops(edge_index=edge_index)
40 |
41 | self_loop_attr = torch.zeros(x_atoms.size(0), edge_attr.shape[1], dtype=torch.long)
42 | edge_attr = torch.cat((edge_attr, self_loop_attr.to(edge_attr)), dim=0)
43 |
44 |
45 | x_atoms = self.atom_embed(x_atoms)
46 | edge_attr = self.edge_embed(edge_attr)
47 |
48 | source, target = edge_index
49 | source_features = torch.index_select(input=x_atoms, index=source, dim=0)
50 |
51 | deg = degree(source, x_atoms.size(0), dtype=x_atoms.dtype)
52 | deg_inv_sqrt = deg.pow(-0.5)
53 | deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
54 | norm = deg_inv_sqrt[source] * deg_inv_sqrt[target]
55 |
56 | message = source_features*norm.view(-1,1)
57 |
58 | x_atoms_new = scatter_add(src=message, index=target, dim=0)
59 | x_frags = scatter_add(src=x_atoms_new, index=atom_to_frag_ids, dim=0)
60 |
61 | source, target = frag_index
62 | source_features = torch.index_select(input=x_frags, index=source, dim=0)
63 |
64 | frag_message = source_features
65 | frag_feats_sum = scatter_add(src=frag_message, index=target, dim=0)
66 | frag_feats_sum = self.frag_mlp(frag_feats_sum)
67 |
68 | x_frags_new = frag_feats_sum
69 |
70 |
71 | return x_atoms_new, x_frags_new
72 |
73 |
74 |
75 |
76 | class FragNet(nn.Module):
77 |
78 | def __init__(self, num_layer, drop_ratio = 0, emb_dim=128,
79 | atom_features=45, frag_features=45, edge_features=12):
80 | super(FragNet, self).__init__()
81 | self.num_layer = num_layer
82 | self.dropout = nn.Dropout(p=drop_ratio)
83 | self.act = nn.ReLU()
84 |
85 |
86 | self.layers = torch.nn.ModuleList()
87 | self.layers.append(FragNetLayer(atom_in=atom_features, atom_out=emb_dim, frag_in=frag_features,
88 | frag_out=emb_dim, edge_in=edge_features, edge_out=emb_dim))
89 |
90 | for i in range(num_layer-1):
91 | self.layers.append(FragNetLayer(atom_in=emb_dim, atom_out=emb_dim, frag_in=emb_dim,
92 | frag_out=emb_dim, edge_in=edge_features, edge_out=emb_dim))
93 |
94 | ###List of batchnorms
95 | self.batch_norms = torch.nn.ModuleList()
96 | for layer in range(num_layer):
97 | self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))
98 |
99 | def forward(self, batch):
100 |
101 |
102 | x_atoms = batch['x_atoms']
103 | edge_index = batch['edge_index']
104 | frag_index = batch['frag_index']
105 |
106 | x_frags = batch['x_frags']
107 | edge_attr = batch['edge_attr']
108 | atom_batch = batch['batch']
109 | frag_batch = batch['frag_batch']
110 | atom_to_frag_ids = batch['atom_to_frag_ids']
111 |
112 |         node_features_bond_graph = batch['node_features_bonds']
113 | edge_index_bonds_graph=batch['edge_index_bonds_graph']
114 | edge_attr_bond_graph=batch['edge_attr_bonds']
115 |
116 |
117 | x_atoms = self.dropout(x_atoms)
118 | x_frags = self.dropout(x_frags)
119 |
120 |
121 | x_atoms, x_frags = self.layers[0](x_atoms, edge_index, edge_attr,
122 | frag_index, x_frags, atom_to_frag_ids)
123 | x_atoms, x_frags = self.act(self.dropout(x_atoms)), self.act(self.dropout(x_frags))
124 |
125 | for layer in self.layers[1:]:
126 | x_atoms, x_frags = layer(x_atoms, edge_index, edge_attr,
127 | frag_index, x_frags, atom_to_frag_ids)
128 | x_atoms, x_frags = self.act(self.dropout(x_atoms)), self.act(self.dropout(x_frags))
129 |
130 | return x_atoms, x_frags
131 |
132 |
133 | class FragNetPreTrain(nn.Module):
134 |
135 | def __init__(self):
136 | super(FragNetPreTrain, self).__init__()
137 |
138 | self.pretrain = FragNet(num_layer=6, drop_ratio=0.15)
139 | self.lin1 = nn.Linear(128, 128)
140 | self.out = nn.Linear(128, 13)
141 | self.dropout = nn.Dropout(p=0.15)
142 | self.activation = nn.ReLU()
143 |
144 | def forward(self, batch):
145 |
146 | x_atoms, x_frags = self.pretrain(batch)
147 |
148 | x = self.dropout(x_atoms)
149 | x = self.lin1(x)
150 | x = self.activation(x)
151 | x = self.dropout(x)
152 | x = self.out(x)
153 |
154 |
155 | return x
156 |
157 |
158 |
159 | class FragNetFineTune(nn.Module):
160 |
161 | def __init__(self, n_classes=1, atom_features=167, frag_features=167, edge_features=16,
162 | num_layer=4, drop_ratio=.15, emb_dim=128,
163 | h1=256, h2=256, h3=256, h4=256, act='celu',fthead='FTHead3'):
164 | super().__init__()
165 |
166 | self.pretrain = FragNet(num_layer=num_layer, drop_ratio=drop_ratio, emb_dim=emb_dim,
167 | atom_features=atom_features, frag_features=frag_features,
168 | edge_features=edge_features)
169 | self.lin1 = nn.Linear(emb_dim*2, emb_dim*2)
170 |
171 | self.dropout = nn.Dropout(p=0.15)
172 | self.activation = nn.ReLU()
173 | if fthead == 'FTHead3':
174 | self.fthead = FTHead3(n_classes=n_classes,
175 | h1=h1, h2=h2, h3=h3, h4=h4,
176 | drop_ratio=drop_ratio, act=act)
177 | elif fthead == 'FTHead4':
178 | print('using FTHead4' )
179 | self.fthead = FTHead4(n_classes=n_classes,
180 | h1=h1, drop_ratio=drop_ratio, act=act)
181 |
182 |
183 | def forward(self, batch):
184 |
185 | x_atoms, x_frags = self.pretrain(batch)
186 |
187 | x_frags_pooled = scatter_add(src=x_frags, index=batch['frag_batch'], dim=0)
188 | x_atoms_pooled = scatter_add(src=x_atoms, index=batch['batch'], dim=0)
189 |
190 | cat = torch.cat((x_atoms_pooled, x_frags_pooled), 1)
191 | x = self.fthead(cat)
192 |
193 |
194 | return x
195 |
--------------------------------------------------------------------------------
/fragnet/model/gcn/gcn_pl.py:
--------------------------------------------------------------------------------
1 | from gcn import FragNetPreTrain
2 | from dataset import load_data_parts
3 | from data import mask_atom_features
4 | import torch.nn as nn
5 | from utils import EarlyStopping
6 | import torch
7 | from data import collate_fn
8 | from torch.utils.data import DataLoader
9 | from features import atom_list_one_hot
10 | from sklearn.model_selection import train_test_split
11 | import os
12 | import pickle
13 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
14 | import pytorch_lightning as pl
15 | from pytorch_lightning import loggers
16 | import torch.optim as optim
17 |
18 |
19 | def load_ids(fn):
20 |
21 | if not os.path.exists('gcn_output/train_ids.pkl'):
22 |
23 | train_ids, test_ids = train_test_split(fn, test_size=.2)
24 | test_ids, val_ids = train_test_split(test_ids, test_size=.5)
25 |
26 | with open('gcn_output/train_ids.pkl', 'wb') as f:
27 | pickle.dump(train_ids, f)
28 | with open('gcn_output/val_ids.pkl', 'wb') as f:
29 | pickle.dump(val_ids, f)
30 | with open('gcn_output/test_ids.pkl', 'wb') as f:
31 | pickle.dump(test_ids, f)
32 |
33 | else:
34 | with open('gcn_output/train_ids.pkl', 'rb') as f:
35 | train_ids = pickle.load(f)
36 | with open('gcn_output/val_ids.pkl', 'rb') as f:
37 | val_ids = pickle.load(f)
38 | with open('gcn_output/test_ids.pkl', 'rb') as f:
39 | test_ids = pickle.load(f)
40 |
41 | return train_ids, val_ids, test_ids
42 |
43 |
44 | class FragNetModule(pl.LightningModule):
45 |
46 | def __init__(self, model):
47 |         """
48 |         Inputs:
49 |             model - The FragNet pretraining model to wrap. The AdamW optimizer and
50 |                 MultiStepLR schedule are defined in configure_optimizers, and the
51 |                 masked-atom-type cross-entropy objective is computed in the
52 |                 training/validation/test steps below.
53 |         """
54 | super().__init__()
55 | # Exports the hyperparameters to a YAML file, and create "self.hparams" namespace
56 | self.save_hyperparameters()
57 | # Create model
58 | self.model = model
59 |
60 |
61 | def forward(self, batch):
62 | # Forward function that is run when visualizing the graph
63 |
64 | out = self.model(batch)
65 | labels = batch['x_atoms'][:, :len(atom_list_one_hot)].argmax(1)
66 | preds = out.argmax(1)
67 | return preds, labels
68 |
69 | def configure_optimizers(self):
70 |
71 | optimizer = optim.AdamW(self.parameters())
72 |
73 | # We will reduce the learning rate by 0.1 after 100 and 150 epochs
74 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)
75 |
76 |
77 | return [optimizer], [scheduler]
78 |
79 | def training_step(self, batch, batch_idx):
80 | # "batch" is the output of the training data loader.
81 | out = self.model(batch)
82 | labels = batch['x_atoms'][:, :len(atom_list_one_hot)].argmax(1)
83 |
84 | loss = loss_fn(out, labels)
85 |
86 | self.log('train_loss', loss)
87 | return loss # Return tensor to call ".backward" on
88 |
89 | def validation_step(self, batch, batch_idx):
90 | # x, y, p, scaffold = batch
91 | out = self.model(batch)
92 | labels = batch['x_atoms'][:, :len(atom_list_one_hot)].argmax(1)
93 | loss = loss_fn(out, labels)
94 |
95 | self.log('val_loss', loss)
96 |
97 | def test_step(self, batch, batch_idx):
98 | out = self.model(batch)
99 | labels = batch['x_atoms'][:, :len(atom_list_one_hot)].argmax(1)
100 | loss = loss_fn(out, labels)
101 |
102 | self.log('test_loss', loss)
103 |
104 |
105 |
106 | if __name__ == "__main__":
107 |
108 |
109 | files = os.listdir('pretrain_data/')
110 | fn = sorted([ int(i.split('.pkl')[0].strip('train')) for i in files if i.endswith('.pkl')])
111 |
112 | train_ids, val_ids,test_ids = load_ids(fn)
113 |
114 | train_dataset = load_data_parts('pretrain_data', 'train', include=train_ids)
115 | val_dataset = load_data_parts('pretrain_data', 'train', include=val_ids)
116 | test_dataset = load_data_parts('pretrain_data', 'train', include=test_ids)
117 |
118 | train_loader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=512, shuffle=True, drop_last=True)
119 | val_loader = DataLoader(val_dataset, collate_fn=collate_fn, batch_size=256, shuffle=False, drop_last=False)
120 | test_loader = DataLoader(test_dataset, collate_fn=collate_fn, batch_size=256, shuffle=False, drop_last=False)
121 |
122 | device = torch.device('cuda') if torch.cuda.is_available() else 'cpu'
123 |
124 |
125 | model_pretrain = FragNetPreTrain()
126 | loss_fn = nn.CrossEntropyLoss()
127 |
128 |
129 | checkpoint_callback = ModelCheckpoint(
130 | save_weights_only=True, mode="min", monitor="val_loss",
131 | save_top_k=1,
132 | verbose=True,
133 | # dirpath=f'{args.exp_dir}/ckpt',
134 | dirpath=None,
135 | filename='model_best')
136 | # if args.trainer_version:
137 | trainer_version='tmp'
138 |
139 | logger = loggers.TensorBoardLogger(save_dir='./', version=trainer_version, name="lightning_logs")
140 | trainer = pl.Trainer(default_root_dir='./', # Where to save models
141 | accelerator="gpu" if str(device).startswith("cuda") else "cpu",
142 | devices=1,
143 | max_epochs=100,
144 | callbacks=[checkpoint_callback, LearningRateMonitor("epoch")],
145 | enable_progress_bar=True,
146 | gradient_clip_val=10,
147 | precision=16,
148 | limit_train_batches=None,
149 | limit_val_batches=None,
150 | logger=logger
151 | )
152 | trainer.logger._log_graph = True
153 | trainer.logger._default_hp_metric = None
154 |
155 |
156 | model = FragNetModule(model_pretrain)
157 | trainer.fit(model, train_loader, val_loader)
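158 | 
159 | # Pretraining objective in this script: atom-type prediction. Labels are the argmax
160 | # over the one-hot atom-type block of x_atoms, the GCN FragNetPreTrain head emits 13
161 | # logits per atom, and the module-level CrossEntropyLoss (loss_fn) scores them.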
--------------------------------------------------------------------------------
/fragnet/train/finetune/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pnnl/FragNet/54f99099a08ab991a5ba6845f7338492d03d4f1c/fragnet/train/finetune/__init__.py
--------------------------------------------------------------------------------
/fragnet/train/finetune/finetune_gat.py:
--------------------------------------------------------------------------------
1 | from dataset import load_pickle_dataset
2 | from torch.utils.data import DataLoader
3 | from data import collate_fn
4 | from gat import FragNetFineTune
5 | import torch.nn as nn
6 | from gat import FragNetPreTrain
7 | import torch
8 | from utils import EarlyStopping
9 | import numpy as np
10 | from utils import test_fn
11 | import argparse
12 | import os
13 | import pandas as pd
14 |
15 |
16 | def seed_everything(seed: int):
17 | import random, os
18 |
19 | random.seed(seed)
20 | os.environ['PYTHONHASHSEED'] = str(seed)
21 | np.random.seed(seed)
22 | torch.manual_seed(seed)
23 | torch.cuda.manual_seed(seed)
24 | torch.backends.cudnn.deterministic = True
25 |     torch.backends.cudnn.benchmark = False  # autotuning off so the deterministic flag above holds
26 |
27 | seed_everything(26)
28 |
29 | def save_preds(test_dataset, true, pred, exp_dir):
30 |
31 | smiles = [i.smiles for i in test_dataset]
32 | res= pd.DataFrame(np.column_stack([smiles, true, pred]), columns=['smiles', 'true', 'pred'])
33 | res.to_csv(os.path.join(exp_dir, 'test_predictions.csv'), index=False)
34 |
35 | def train(train_loader, model, device, optimizer):
36 | model.train()
37 | total_loss = 0
38 | for batch in train_loader:
39 | for k,v in batch.items():
40 | batch[k] = batch[k].to(device)
41 | optimizer.zero_grad()
42 | out = model(batch).view(-1,)
43 | loss = loss_fn(out, batch['y'])
44 | loss.backward()
45 | total_loss += loss.item()
46 | optimizer.step()
47 | return total_loss / len(train_loader.dataset)
48 |
49 | def get_optimizer(model, freeze_pt_weights=False, lr=1e-4):
50 |
51 | if freeze_pt_weights:
52 | print('freezing pretrain weights')
53 | for name, param in model.named_parameters():
54 | if param.requires_grad and 'pretrain' in name:
55 | param.requires_grad = False
56 |
57 | non_frozen_parameters = [p for p in model.parameters() if p.requires_grad]
58 | optimizer = torch.optim.Adam(non_frozen_parameters, lr = lr)
59 | else:
60 | print('no freezing of the weights')
61 | optimizer = torch.optim.Adam(model.parameters(), lr = lr )
62 |
63 | return model, optimizer
64 |
65 |
66 | if __name__ == "__main__":
67 |
68 | parser = argparse.ArgumentParser()
69 | parser.add_argument('--dataset_name', help="", type=str,required=False, default='moleculenet')
70 | parser.add_argument('--dataset_subset', help="", type=str,required=False, default='esol')
71 | parser.add_argument('--seed', help="", type=int,required=False, default=None)
72 |     parser.add_argument('--batch_size', help="batch size", type=int, required=False, default=32)
73 |     parser.add_argument('--checkpoint', help="checkpoint file name", type=str, required=False, default='gnn.pt')
74 |     parser.add_argument('--add_pt_weights', help="initialize from pretrained weights (pt.pt)", action='store_true')
75 |     parser.add_argument('--freeze_pt_weights', help="freeze the pretrained encoder during finetuning", action='store_true')
76 | parser.add_argument('--exp_dir', help="", type=str,required=False, default='exps/pnnl')
77 | args = parser.parse_args()
78 |
79 | dataset_name = args.dataset_name # 'esol'
80 | seed = args.seed # 36
81 | dataset_subset = args.dataset_subset.lower()# 'moleculenet'
82 | batch_size = args.batch_size
83 | exp_dir = args.exp_dir
84 | if not os.path.exists(exp_dir):
85 | os.makedirs(exp_dir)
86 | print('dataset: ', dataset_name, 'subset: ', dataset_subset)
87 |
88 |
89 | device = torch.device('cuda') if torch.cuda.is_available() else 'cpu'
90 |
91 | if dataset_name == 'moleculenet':
92 | train_dataset = load_pickle_dataset(f'{dataset_name}/{dataset_subset}', f'train_{seed}')
93 | val_dataset = load_pickle_dataset(f'{dataset_name}/{dataset_subset}', f'val_{seed}')
94 | test_dataset = load_pickle_dataset(f'{dataset_name}/{dataset_subset}', f'test_{seed}')
95 | elif 'pnnl' in dataset_name:
96 | train_dataset = load_pickle_dataset(path=dataset_name, name='train')
97 | val_dataset = load_pickle_dataset(path=dataset_name, name='val')
98 | test_dataset = load_pickle_dataset(path=dataset_name, name='test')
99 |
100 |
101 |
102 | train_loader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=batch_size, shuffle=True, drop_last=True)
103 | val_loader = DataLoader(val_dataset, collate_fn=collate_fn, batch_size=batch_size, shuffle=False, drop_last=False)
104 | test_loader = DataLoader(test_dataset, collate_fn=collate_fn, batch_size=batch_size, shuffle=False, drop_last=False)
105 |
106 | if args.add_pt_weights:
107 | model_pretrain = FragNetPreTrain();
108 | model_pretrain.to(device);
109 | model_pretrain.load_state_dict(torch.load('pt.pt'))
110 | model = FragNetFineTune()
111 | model.to(device);
112 | loss_fn = nn.MSELoss()
113 |
114 |
115 | if args.add_pt_weights:
116 | print("adding pretrain weights to the model")
117 | model.pretrain.load_state_dict(model_pretrain.pretrain.state_dict())
118 |
119 | model, optimizer = get_optimizer(model, freeze_pt_weights=args.freeze_pt_weights, lr=1e-4)
120 |
121 | chkpoint_name= os.path.join(exp_dir, args.checkpoint)
122 | early_stopping = EarlyStopping(patience=200, verbose=True, chkpoint_name=chkpoint_name)
123 |
124 |
125 | for epoch in range(2000):
126 |
127 | res = []
128 | model.train()
129 | total_loss = 0
130 | for batch in train_loader:
131 | for k,v in batch.items():
132 | batch[k] = batch[k].to(device)
133 | optimizer.zero_grad()
134 | out = model(batch).view(-1,)
135 | loss = loss_fn(out, batch['y'])
136 | loss.backward()
137 | total_loss += loss.item()
138 | optimizer.step()
139 |
140 |
141 |
142 | val_loss, _, _ = test_fn(val_loader, model, device)
143 | res.append(val_loss)
144 | print("val mse: ", val_loss)
145 |
146 | early_stopping(val_loss, model)
147 |
148 | if early_stopping.early_stop:
149 | print("Early stopping")
150 | break
151 |
152 |
153 | model.load_state_dict(torch.load(chkpoint_name))
154 |
155 | mse, true, pred = test_fn(test_loader, model, device)
156 | save_preds(test_dataset, true, pred, exp_dir)
157 | print("rmse: ", mse**.5)
158 |
159 |
160 |
161 |
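162 | # CLI sketch using the flags defined above (values are placeholders):
163 | #   python finetune_gat.py --dataset_name moleculenet --dataset_subset esol \
164 | #       --seed 36 --batch_size 32 --exp_dir exps/esol --checkpoint gnn.pt
165 | # Add --add_pt_weights (optionally with --freeze_pt_weights) to initialize the
166 | # encoder from the pretrained checkpoint 'pt.pt' in the working directory.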
--------------------------------------------------------------------------------
/fragnet/train/finetune/gat2_cv_frag.py:
--------------------------------------------------------------------------------
1 | # from gat import FragNetFineTune
2 | import torch
3 | from fragnet.dataset.dataset import load_pickle_dataset
4 | import torch.nn as nn
5 | from fragnet.train.utils import EarlyStopping
6 | import torch
7 | from fragnet.dataset.data import collate_fn
8 | from torch.utils.data import DataLoader
9 | import pandas as pd
10 | import argparse
11 | from fragnet.train.utils import TrainerFineTune as Trainer
12 | import numpy as np
13 | from omegaconf import OmegaConf
14 | import pickle
15 | import torch.optim.lr_scheduler as lr_scheduler
16 | from sklearn.model_selection import KFold
17 | import os
18 | from sklearn.model_selection import train_test_split
19 |
20 | def seed_everything(seed: int):
21 | import random, os
22 |
23 | random.seed(seed)
24 | os.environ['PYTHONHASHSEED'] = str(seed)
25 | np.random.seed(seed)
26 | torch.manual_seed(seed)
27 | torch.cuda.manual_seed(seed)
28 | torch.backends.cudnn.deterministic = True
29 |     torch.backends.cudnn.benchmark = False  # autotuning off so the deterministic flag above holds
30 |
31 | seed_everything(26)
32 |
33 | def get_predictions(trainer, loader, model, device):
34 | mse, true, pred = trainer.test(model=model, loader=loader, device=device)
35 | smiles = [i.smiles for i in loader.dataset]
36 | res = {'smiles': smiles, 'true': true, 'pred': pred}
37 | return res
38 |
39 | def save_predictions(exp_dir, save_name, res):
40 |
41 | with open(f"{exp_dir}/{save_name}.pkl", 'wb') as f:
42 | pickle.dump(res,f )
43 |
44 |
45 | if __name__ == "__main__":
46 |
47 | parser = argparse.ArgumentParser()
48 | parser.add_argument('--config', help="configuration file *.yml", type=str, required=False, default='config.yaml')
49 | args = parser.parse_args()
50 |
51 | if args.config:
52 | opt = OmegaConf.load(args.config)
53 | OmegaConf.resolve(opt)
54 | args=opt
55 |
56 | seed_everything(args.seed)
57 | exp_dir = args['exp_dir']
58 | os.makedirs(exp_dir, exist_ok=True)
59 | device = torch.device('cuda') if torch.cuda.is_available() else 'cpu'
60 |
61 | pt_chkpoint_name = args.pretrain.chkpoint_name #'exps/pt_rings'
62 |
63 | if args.model_version == 'gat':
64 | from fragnet.model.gat.gat import FragNetFineTune
65 | print('loaded from gat')
66 | elif args.model_version=='gat2':
67 | from fragnet.model.gat.gat2 import FragNetFineTune
68 | print('loaded from gat2')
69 |
70 |
71 | model = FragNetFineTune(n_classes=args.finetune.model.n_classes,
72 | atom_features=args.atom_features,
73 | frag_features=args.frag_features,
74 | edge_features=args.edge_features,
75 | num_layer=args.finetune.model.num_layer,
76 | drop_ratio=args.finetune.model.drop_ratio,
77 | num_heads=args.finetune.model.num_heads,
78 | emb_dim=args.finetune.model.emb_dim,
79 | h1=args.finetune.model.h1,
80 | h2=args.finetune.model.h2,
81 | h3=args.finetune.model.h3,
82 | h4=args.finetune.model.h4,
83 | act=args.finetune.model.act,
84 | fthead=args.finetune.model.fthead
85 | )
86 | trainer = Trainer(target_type=args.finetune.target_type)
87 |
88 | pt_chkpoint_name = args.pretrain.chkpoint_name #'exps/pt_rings'
89 | if pt_chkpoint_name:
90 |
91 | from fragnet.model.gat.gat2_pretrain import FragNetPreTrain
92 | modelpt = FragNetPreTrain(num_layer=args.finetune.model.num_layer,
93 | drop_ratio=args.finetune.model.drop_ratio,
94 | num_heads=args.finetune.model.num_heads,
95 | emb_dim=args.finetune.model.emb_dim,
96 | atom_features=args.atom_features, frag_features=args.frag_features, edge_features=args.edge_features)
97 |
98 |
99 | modelpt.load_state_dict(torch.load(pt_chkpoint_name, map_location=torch.device(device)))
100 |
101 | print('loading pretrained weights')
102 | model.pretrain.load_state_dict(modelpt.pretrain.state_dict())
103 | print('weights loaded')
104 | else:
105 | print('no pretrained weights')
106 |
107 |
108 | train_dataset2 = load_pickle_dataset(args.finetune.train.path)
109 |
110 | kf = KFold(n_splits=5)
111 | train_val = train_dataset2
112 |
113 | for icv, (train_index, test_index) in enumerate(kf.split(train_val)):
114 | train_index, val_index = train_test_split(train_index, test_size=.1)
115 |
116 | train_ds = [train_val[i] for i in train_index]
117 | val_ds = [train_val[i] for i in val_index]
118 | test_ds = [train_val[i] for i in test_index]
119 |
120 | train_loader = DataLoader(train_ds, collate_fn=collate_fn, batch_size=args.finetune.batch_size, shuffle=True, drop_last=True)
121 | val_loader = DataLoader(val_ds, collate_fn=collate_fn, batch_size=args.finetune.batch_size, shuffle=False, drop_last=False)
122 | test_loader = DataLoader(test_ds, collate_fn=collate_fn, batch_size=args.finetune.batch_size, shuffle=False, drop_last=False)
123 |
124 |
125 | trainer = Trainer(target_type=args.finetune.target_type)
126 |
127 |
128 | model.to(device);
129 | ft_chk_point = os.path.join(args.exp_dir, f'cv_{icv}.pt')
130 | early_stopping = EarlyStopping(patience=args.finetune.es_patience, verbose=True, chkpoint_name=ft_chk_point)
131 |
132 |
133 | if args.finetune.loss == 'mse':
134 | loss_fn = nn.MSELoss()
135 | elif args.finetune.loss == 'cel':
136 | loss_fn = nn.CrossEntropyLoss()
137 |
138 | optimizer = torch.optim.Adam(model.parameters(), lr = args.finetune.lr ) # before 1e-4
139 | if args.finetune.use_schedular:
140 | scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
141 | else:
142 | scheduler=None
143 |
144 | for epoch in range(args.finetune.n_epochs):
145 |
146 | train_loss = trainer.train(model=model, loader=train_loader, optimizer=optimizer,
147 | scheduler=scheduler, device=device, val_loader=val_loader)
148 | val_loss, _, _ = trainer.test(model=model, loader=val_loader, device=device)
149 |
150 |
151 | print(train_loss, val_loss)
152 |
153 |
154 | early_stopping(val_loss, model)
155 |
156 | if early_stopping.early_stop:
157 | print("Early stopping")
158 | break
159 |
160 |
161 | model.load_state_dict(torch.load(ft_chk_point))
162 | res = get_predictions(trainer, test_loader, model, device)
163 |
164 | with open(f'{exp_dir}/cv_{icv}.pkl', 'wb') as f:
165 | pickle.dump(res, f)
166 |
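167 | # Unlike gat2_cv.py, which folds train+val together and scores each fold on its
168 | # validation split, this variant folds only the training pickle: each fold's test
169 | # indices are held out for the final predictions and 10% of the remaining data is
170 | # split off for early stopping. Results are written to <exp_dir>/cv_<i>.pkl.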
--------------------------------------------------------------------------------
/fragnet/train/finetune/trainer_cdrp.py:
--------------------------------------------------------------------------------
1 | from utils import compute_bce_loss
2 | import torch
3 | import numpy as np
4 | from sklearn.metrics import mean_squared_error, roc_auc_score
5 | import torch.nn as nn
6 |
7 |
8 | class TrainerFineTune:
9 | def __init__(self, target_pos=None, target_type='regr',
10 | n_multi_task_heads=0):
11 | self.target_pos = target_pos
12 | if target_type=='clsf':
13 |
14 | self.train = self.train_clsf_bce
15 | self.validate = self.validate_clsf_bce
16 | self.test = self.test_clsf_bce
17 | self.loss_fn = compute_bce_loss
18 |
19 | elif target_type=='regr':
20 | self.train = self.train_regr
21 | self.validate = self.validate_regr
22 | self.test = self.test_regr
23 | self.loss_fn = nn.MSELoss()
24 |
25 |
26 | elif target_type=='clsf_ms':
27 | self.train = self.train_clsf_multi_task
28 | self.validate = self.validate_clsf_multi_task
29 | self.test = self.test_clsf_multi_task
30 | self.n_multi_task_heads = n_multi_task_heads
31 |
32 |
33 | def train_regr(self, model, loader, optimizer, scheduler, device, val_loader, label_mean, label_sdev):
34 | model.train()
35 | total_loss = 0
36 | for batch in loader:
37 | for k,v in batch.items():
38 | batch[k] = batch[k].to(device)
39 | optimizer.zero_grad()
40 | out = model(batch).view(-1,)
41 | labels = batch['y']
42 |
43 | loss = self.loss_fn(out, labels)
44 | loss.backward()
45 | total_loss += loss.item()
46 | optimizer.step()
47 |
48 | if scheduler:
49 |             val_loss = self.validate(model, val_loader, device, label_mean, label_sdev)
50 | scheduler.step()
51 |
52 | return total_loss / len(loader.dataset)
53 |
54 | def validate_regr(self, model,loader, device, label_mean, label_sdev):
55 | model.eval()
56 | total_loss = 0
57 | with torch.no_grad():
58 |
59 | for batch in loader:
60 | for k,v in batch.items():
61 | batch[k] = batch[k].to(device)
62 |
63 | out = model(batch).view(-1,)
64 | labels = batch['y']
65 | loss = self.loss_fn(out, labels)
66 | total_loss += loss.item()
67 | return total_loss / len(loader.dataset)
68 |
69 |
70 | def test_regr(self, model, loader, device, label_mean, label_sdev):
71 |
72 | model.eval()
73 | with torch.no_grad():
74 | target, predicted = [], []
75 | for data in loader:
76 | for k,v in data.items():
77 | data[k] = data[k].to(device)
78 |
79 | output = model(data).view(-1,)
80 | pred = output
81 |
82 | target += list(data['y'].cpu().detach().numpy().ravel() )
83 | predicted += list(pred.cpu().detach().numpy().ravel() )
84 | mse = mean_squared_error(target, predicted)
85 |
86 | return mse, np.array(target), np.array(predicted)
87 |
88 |
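89 | # Usage sketch for the regression branch (illustrative only; names are placeholders):
90 | #   trainer = TrainerFineTune(target_type='regr')
91 | #   train_loss = trainer.train(model, train_loader, optimizer, scheduler, device,
92 | #                              val_loader, label_mean, label_sdev)
93 | #   mse, y_true, y_pred = trainer.test(model, test_loader, device, label_mean, label_sdev)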
--------------------------------------------------------------------------------
/fragnet/train/pretrain/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pnnl/FragNet/54f99099a08ab991a5ba6845f7338492d03d4f1c/fragnet/train/pretrain/__init__.py
--------------------------------------------------------------------------------
/fragnet/train/pretrain/pretrain_data_pnnl.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from dataset import get_pt_dataset, extract_data, save_datasets
3 |
4 | tr = pd.read_csv('/people/pana982/solubility/data/full_dataset/isomers/wang/set2/train_new.csv')
5 | vl = pd.read_csv('/people/pana982/solubility/data/full_dataset/isomers/wang/set2/val_new.csv')
6 | ts = pd.read_csv('/people/pana982/solubility/data/full_dataset/isomers/wang/set2/test_new.csv')
7 |
8 | df = pd.concat([tr, vl, ts], axis=0)
9 | df.reset_index(drop=True, inplace=True)
10 |
11 | ds = get_pt_dataset(df)
12 | ds = extract_data(ds)
13 | save_path = 'pretrain_data/pnnl_iso/ds/train'
14 | save_datasets(ds, save_path)
15 |
--------------------------------------------------------------------------------
/fragnet/train/pretrain/pretrain_gat2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from fragnet.train.utils import EarlyStopping
4 | import torch
5 | from torch.utils.data import DataLoader
6 | from fragnet.dataset.dataset import load_data_parts
7 | import argparse
8 | import numpy as np
9 | from omegaconf import OmegaConf
10 | import pickle
11 | from torch.utils.tensorboard import SummaryWriter
12 | from fragnet.model.gat.pretrain_heads import FragNetPreTrain, FragNetPreTrainMasked, FragNetPreTrainMasked2
13 | from fragnet.train.pretrain.pretrain_utils import Trainer
14 | from sklearn.model_selection import train_test_split
15 | from fragnet.dataset.data import collate_fn_pt as collate_fn
16 | import pandas as pd
17 | from fragnet.model.gat.gat2 import FragNet
18 | from torch_scatter import scatter_add
19 |
20 | def seed_everything(seed: int):
21 | import random, os
22 |
23 | random.seed(seed)
24 | os.environ['PYTHONHASHSEED'] = str(seed)
25 | np.random.seed(seed)
26 | torch.manual_seed(seed)
27 | torch.cuda.manual_seed(seed)
28 | torch.backends.cudnn.deterministic = True
29 |     torch.backends.cudnn.benchmark = False  # keep False so the deterministic setting above is not undermined
30 |
31 |
32 | def save_ds_smiles(ds,name, exp_dir):
33 | smiles = [i.smiles for i in ds]
34 | with open(exp_dir+f'/{name}_smiles.pkl', 'wb') as f:
35 | pickle.dump(smiles, f)
36 |
37 | def load_train_val_dataset(args, ds):
38 |
39 | train_smiles = pd.read_pickle(args.pretrain.train_smiles)
40 |
41 |
42 | train_dataset, val_dataset = [], []
43 | for data in ds:
44 | if data.smiles in train_smiles:
45 | train_dataset.append(data)
46 | else:
47 | val_dataset.append(data)
48 |
49 |
50 | return train_dataset, val_dataset
51 |
52 | def remove_duplicates_and_add(ds, path):
53 | t = load_data_parts(path)
54 | if len(ds) != 0:
55 | curr_smiles = [i.smiles for i in ds]
56 | new_ds = [i for i in t if i.smiles not in curr_smiles]
57 | else:
58 | new_ds=t
59 | ds += new_ds
60 | return ds
61 |
62 | def save_predictions(trainer, loader, model, exp_dir, device, save_name='test_res', loss_type='mse'):
63 | score, true, pred = trainer.test(model=model, loader=loader, device=device)
64 | smiles = [i.smiles for i in loader.dataset]
65 |
66 | if loss_type=='mse':
67 | print(f'{save_name} rmse: ', score**.5)
68 | res = {'acc': score**.5, 'true': true, 'pred': pred, 'smiles': smiles}
69 | elif loss_type=='cel':
70 | # score = roc_auc_score(true, pred[:,1])
71 | print(f'{save_name} auc: ', score)
72 | res = {'acc': score, 'true': true, 'pred': pred, 'smiles': smiles}
73 |
74 | with open(f"{exp_dir}/{save_name}.pkl", 'wb') as f:
75 | pickle.dump(res,f )
76 |
77 |
78 |
79 | if __name__ == "__main__":
80 |
81 | parser = argparse.ArgumentParser()
82 | parser.add_argument('--config', help="configuration file *.yml", type=str, required=False, default='config.yaml')
83 | args = parser.parse_args()
84 |
85 | if args.config:
86 | opt = OmegaConf.load(args.config)
87 | OmegaConf.resolve(opt)
88 | args=opt
89 |
90 | seed_everything(args.seed)
91 | exp_dir = args['exp_dir']
92 | device = torch.device('cuda') if torch.cuda.is_available() else 'cpu'
93 | writer = SummaryWriter(exp_dir+'/runs')
94 |
95 | pt_chkpoint_name = args.pretrain.chkpoint_name #'exps/pt_rings'
96 |
97 | if args.model_version == 'gat2':
98 | model = FragNetPreTrain(num_layer=args.pretrain.num_layer,
99 | drop_ratio=args.pretrain.drop_ratio,
100 | num_heads=args.pretrain.num_heads,
101 | emb_dim=args.pretrain.emb_dim,
102 | atom_features=args.atom_features,
103 | frag_features=args.frag_features,
104 | edge_features=args.edge_features,
105 | fedge_in=args.fedge_in,
106 | fbond_edge_in=args.fbond_edge_in
107 | )
108 | elif args.model_version == 'gat2_masked':
109 | model = FragNetPreTrainMasked(num_layer=args.pretrain.num_layer,
110 | drop_ratio=args.pretrain.drop_ratio,
111 | num_heads=args.pretrain.num_heads,
112 | emb_dim=args.pretrain.emb_dim,
113 | atom_features=args.atom_features,
114 | frag_features=args.frag_features,
115 | edge_features=args.edge_features,
116 | fedge_in=args.fedge_in,
117 | fbond_edge_in=args.fbond_edge_in)
118 |
119 | elif args.model_version == 'gat2_masked2':
120 | model = FragNetPreTrainMasked2(num_layer=args.pretrain.num_layer,
121 | drop_ratio=args.pretrain.drop_ratio,
122 | num_heads=args.pretrain.num_heads,
123 | emb_dim=args.pretrain.emb_dim,
124 | atom_features=args.atom_features,
125 | frag_features=args.frag_features,
126 | edge_features=args.edge_features,
127 | fedge_in=args.fedge_in,
128 | fbond_edge_in=args.fbond_edge_in)
129 |
130 | if args.pretrain.saved_checkpoint:
131 | model.load_state_dict(torch.load(args.pretrain.saved_checkpoint))
132 |
133 | ds=[]
134 | for path in args.pretrain.data:
135 | ds = remove_duplicates_and_add(ds, path)
136 |
137 |
138 | if args.pretrain.train_smiles:
139 | train_dataset, val_dataset = load_train_val_dataset(args, ds)
140 | else:
141 | train_dataset, val_dataset = train_test_split(ds, test_size=.1, random_state=42)
142 |
143 | # NOTE: Start new runs in a new directory
144 | save_ds_smiles(train_dataset, 'train', args.exp_dir)
145 | save_ds_smiles(val_dataset, 'val', args.exp_dir)
146 |
147 | print('number of data points: ', len(ds))
148 | writer.add_scalar('number of data points: ', len(ds))
149 |
150 | train_loader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=args.pretrain.batch_size, shuffle=True, drop_last=True)
151 | val_loader = DataLoader(val_dataset, collate_fn=collate_fn, batch_size=args.pretrain.batch_size, shuffle=False, drop_last=False)
152 |
153 |
154 |     model.to(device)
155 | early_stopping = EarlyStopping(patience=args.pretrain.es_patience, verbose=True, chkpoint_name=pt_chkpoint_name)
156 |
157 |
158 | if args.pretrain.loss == 'mse':
159 | loss_fn = nn.MSELoss()
160 | elif args.pretrain.loss == 'cel':
161 | loss_fn = nn.CrossEntropyLoss()
162 |
163 | trainer = Trainer(loss_fn=loss_fn)
164 |
165 | optimizer = torch.optim.Adam(model.parameters(), lr = args.pretrain.lr ) # before 1e-4
166 |
167 | for epoch in range(args.pretrain.n_epochs):
168 |
169 | train_loss = trainer.train(model=model, loader=train_loader, optimizer=optimizer, device=device)
170 |
171 | writer.add_scalar('Loss/train', train_loss, epoch)
172 |
173 |
174 | if epoch%5==0:
175 | val_loss = trainer.validate(model=model, loader=val_loader, device=device)
176 | print(train_loss, val_loss)
177 | writer.add_scalar('Loss/val', val_loss, epoch)
178 |
179 | early_stopping(val_loss, model)
180 |
181 | if early_stopping.early_stop:
182 | print("Early stopping")
183 | break
184 |
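The script above takes all of its settings from a YAML file loaded with OmegaConf. The sketch below lists the keys it actually reads; every value is a placeholder or assumption and should be set to match your data, featurisation, and hardware.

from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "seed": 26,
    "exp_dir": "exps/pt/my_run",
    "model_version": "gat2",           # or "gat2_masked" / "gat2_masked2"
    "atom_features": 167,
    "frag_features": 167,
    "edge_features": 16,
    "fedge_in": 6,                     # assumed value
    "fbond_edge_in": 6,                # assumed value
    "pretrain": {
        "num_layer": 4,
        "drop_ratio": 0.15,
        "num_heads": 4,
        "emb_dim": 128,
        "chkpoint_name": "exps/pt/my_run/pt.pt",
        "saved_checkpoint": None,      # checkpoint to resume from, or None
        "data": ["pretrain_data/esol/ds/train"],   # directories readable by load_data_parts
        "train_smiles": None,          # pickle of training SMILES for a fixed split, or None
        "batch_size": 32,
        "es_patience": 50,
        "loss": "mse",                 # or "cel"
        "lr": 1e-4,
        "n_epochs": 100,
    },
})
OmegaConf.save(cfg, "config.yaml")     # then: python pretrain_gat2.py --config config.yaml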
--------------------------------------------------------------------------------
/fragnet/train/pretrain/pretrain_gat_mol.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from utils import EarlyStopping
3 | import torch
4 | from data import collate_fn
5 | from torch.utils.data import DataLoader
6 | import argparse
7 | import numpy as np
8 | from omegaconf import OmegaConf
9 | from utils import Trainer
10 | from torch.utils.tensorboard import SummaryWriter
11 | from dataset import LoadDataSets
12 | from pretrain_utils import load_prop_data, add_props_to_ds
13 |
14 |
15 | def seed_everything(seed: int):
16 | import random, os
17 |
18 | random.seed(seed)
19 | os.environ['PYTHONHASHSEED'] = str(seed)
20 | np.random.seed(seed)
21 | torch.manual_seed(seed)
22 | torch.cuda.manual_seed(seed)
23 | torch.backends.cudnn.deterministic = True
24 |     torch.backends.cudnn.benchmark = False  # keep False so the deterministic setting above is not undermined
25 |
26 | seed_everything(26)
27 |
28 | """
29 | Pretrain the model on molecular property targets loaded via load_prop_data/add_props_to_ds.
30 | """
31 |
32 |
33 | if __name__=="__main__":
34 |
35 | parser = argparse.ArgumentParser()
36 | parser.add_argument('--config', help="configuration file *.yml", type=str, required=False, default='config.yml')
37 | args = parser.parse_args()
38 |
39 | if args.config: # args priority is higher than yaml
40 | opt = OmegaConf.load(args.config)
41 | OmegaConf.resolve(opt)
42 | args=opt
43 |
44 | writer = SummaryWriter(args.exp_dir)
45 | if args.model_version == 'gat':
46 | from gat import FragNetFineTune
47 | elif args.model_version == 'gat2':
48 | from gat2 import FragNetFineTune
49 |
50 |
51 | ds = LoadDataSets()
52 |
53 | train_dataset, val_dataset, test_dataset = ds.load_datasets(args)
54 |
55 | prop_dict, _ = load_prop_data(args)
56 | add_props_to_ds(train_dataset, prop_dict)
57 | add_props_to_ds(val_dataset, prop_dict)
58 |
59 |
60 | train_loader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=512, shuffle=True, drop_last=True)
61 | val_loader = DataLoader(val_dataset, collate_fn=collate_fn, batch_size=512, shuffle=False, drop_last=False)
62 |
63 | n_classes = args.pretrain.n_classes # 31 for nRings
64 | target_pos = args.pretrain.target_pos
65 | target_type = args.pretrain.target_type
66 |
67 | device = torch.device('cuda') if torch.cuda.is_available() else 'cpu'
68 | model = FragNetFineTune(n_classes=args.pretrain.n_classes,
69 | num_layer=args.pretrain.num_layer,
70 | drop_ratio=args.pretrain.drop_ratio)
71 |
72 | model.to(device)
73 |
74 | optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4 )
75 | chkpoint_name = args.pretrain.chkpoint_name #'pt_gat_nring.pt'
76 | early_stopping = EarlyStopping(patience=args.pretrain.n_epochs, verbose=True, chkpoint_name=chkpoint_name)
77 |
78 | if args.pretrain.loss == 'mse':
79 | loss_fn = nn.MSELoss()
80 | elif args.pretrain.loss == 'cel':
81 | loss_fn = nn.CrossEntropyLoss()
82 |
83 | pretrainer = Trainer(target_pos=target_pos, target_type=target_type, loss_fn=loss_fn)
84 |
85 |
86 | for epoch in range(args.pretrain.n_epochs):
87 |
88 | train_loss = pretrainer.train(model, train_loader, optimizer, device)
89 | val_loss = pretrainer.validate(val_loader, model, device)
90 | writer.add_scalar('Loss/train', train_loss, epoch)
91 | writer.add_scalar('Loss/val', val_loss, epoch)
92 |
93 | early_stopping(val_loss, model)
94 |
95 | if early_stopping.early_stop:
96 | print("Early stopping")
97 | break
98 |
--------------------------------------------------------------------------------
/fragnet/train/pretrain/pretrain_gat_str.py:
--------------------------------------------------------------------------------
1 | from gat import FragNetPreTrain
2 | from dataset import load_data_parts
3 | from data import mask_atom_features
4 | import torch.nn as nn
5 | from utils import EarlyStopping
6 | import torch
7 | from data import collate_fn
8 | from torch.utils.data import DataLoader
9 | from features import atom_list_one_hot
10 |
11 |
12 | """
13 | Pretrain FragNet by masking atom features and predicting the masked atoms' element types.
14 | """
15 | def train(model, train_loader, optimizer, device):
16 | model.train()
17 | total_loss = 0
18 | for batch in train_loader:
19 | mask_atom_features(batch)
20 | for k,v in batch.items():
21 | batch[k] = batch[k].to(device)
22 | optimizer.zero_grad()
23 | out = model(batch)
24 | labels = batch['x_atoms'][:, :len(atom_list_one_hot)].argmax(1)
25 |
26 | loss = loss_fn(out, labels)
27 | loss.backward()
28 | total_loss += loss.item()
29 | optimizer.step()
30 | return total_loss / len(train_loader.dataset)
31 |
32 | def validate(loader, model, device):
33 | model.eval()
34 | total_loss = 0
35 | with torch.no_grad():
36 |
37 | for batch in loader:
38 | for k,v in batch.items():
39 | batch[k] = batch[k].to(device)
40 |
41 | out = model(batch)
42 | labels = batch['x_atoms'][:, :len(atom_list_one_hot)].argmax(1)
43 |
44 | loss = loss_fn(out, labels)
45 | total_loss += loss.item()
46 | return total_loss / len(loader.dataset)
47 |
48 |
49 |
50 | if __name__ == '__main__':
51 |
52 | train_dataset = load_data_parts('pretrain_data', 'train', include=range(0,479))
53 | val_dataset = load_data_parts('pretrain_data', 'train', include=range(479,499))
54 |
55 | train_loader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=512, shuffle=True, drop_last=True)
56 | val_loader = DataLoader(val_dataset, collate_fn=collate_fn, batch_size=256, shuffle=False, drop_last=False)
57 |
58 | device = torch.device('cuda') if torch.cuda.is_available() else 'cpu'
59 |
60 | model_pretrain = FragNetPreTrain()
61 | model_pretrain.to(device)
62 |
63 | loss_fn = nn.CrossEntropyLoss()
64 |
65 | optimizer = torch.optim.Adam(model_pretrain.parameters(), lr = 1e-4 )
66 | chkpoint_name='pt_gat.pt'
67 | early_stopping = EarlyStopping(patience=100, verbose=True, chkpoint_name=chkpoint_name)
68 |
69 |
70 |
71 | for epoch in range(500):
72 |
73 | train_loss = train(model_pretrain, train_loader, optimizer, device)
74 | val_loss = validate(val_loader, model_pretrain, device)
75 | print(train_loss, val_loss)
76 |
77 |
78 | early_stopping(val_loss, model_pretrain)
79 |
80 | if early_stopping.early_stop:
81 | print("Early stopping")
82 | break
83 |
--------------------------------------------------------------------------------
/fragnet/train/pretrain/pretrain_gcn.py:
--------------------------------------------------------------------------------
1 | from gcn import FragNetPreTrain
2 | from dataset import load_data_parts
3 | from data import mask_atom_features
4 | import torch.nn as nn
5 | from utils import EarlyStopping
6 | import torch
7 | from data import collate_fn
8 | from torch.utils.data import DataLoader
9 | from features import atom_list_one_hot
10 | from tqdm import tqdm
11 | import pickle
12 | import os
13 | from sklearn.model_selection import train_test_split
14 |
15 | def load_ids(fn):
16 |
17 | if not os.path.exists('gcn_output/train_ids.pkl'):
18 |
19 | train_ids, test_ids = train_test_split(fn, test_size=.2)
20 | test_ids, val_ids = train_test_split(test_ids, test_size=.5)
21 |
22 | with open('gcn_output/train_ids.pkl', 'wb') as f:
23 | pickle.dump(train_ids, f)
24 | with open('gcn_output/val_ids.pkl', 'wb') as f:
25 | pickle.dump(val_ids, f)
26 | with open('gcn_output/test_ids.pkl', 'wb') as f:
27 | pickle.dump(test_ids, f)
28 |
29 | else:
30 | with open('gcn_output/train_ids.pkl', 'rb') as f:
31 | train_ids = pickle.load(f)
32 | with open('gcn_output/val_ids.pkl', 'rb') as f:
33 | val_ids = pickle.load(f)
34 | with open('gcn_output/test_ids.pkl', 'rb') as f:
35 | test_ids = pickle.load(f)
36 |
37 | return train_ids, val_ids, test_ids
38 |
39 |
40 | def train(model, train_loader, optimizer, device):
41 | model.train()
42 | total_loss = 0
43 | for batch in train_loader:
44 | # batch = data.to(device)
45 | mask_atom_features(batch)
46 | for k,v in batch.items():
47 | batch[k] = batch[k].to(device)
48 | optimizer.zero_grad()
49 | out = model(batch)
50 | labels = batch['x_atoms'][:, :len(atom_list_one_hot)].argmax(1)
51 |
52 | loss = loss_fn(out, labels)
53 | loss.backward()
54 | total_loss += loss.item()
55 | optimizer.step()
56 | return total_loss / len(train_loader.dataset)
57 |
58 | def validate(loader, model, device):
59 | model.eval()
60 | total_loss = 0
61 | with torch.no_grad():
62 |
63 | for batch in loader:
64 | for k,v in batch.items():
65 | batch[k] = batch[k].to(device)
66 |
67 | out = model(batch)
68 | labels = batch['x_atoms'][:, :len(atom_list_one_hot)].argmax(1)
69 |
70 | loss = loss_fn(out, labels)
71 | total_loss += loss.item()
72 | return total_loss / len(loader.dataset)
73 |
74 |
75 |
76 | if __name__ == '__main__':
77 |
78 | files = os.listdir('pretrain_data/')
79 | fn = sorted([ int(i.split('.pkl')[0].strip('train')) for i in files if i.endswith('.pkl')])
80 | train_ids, val_ids,test_ids = load_ids(fn)
81 | train_dataset = load_data_parts('pretrain_data', 'train', include=train_ids)
82 | val_dataset = load_data_parts('pretrain_data', 'train', include=val_ids)
83 |
84 | train_loader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=512, shuffle=True, drop_last=True)
85 | val_loader = DataLoader(val_dataset, collate_fn=collate_fn, batch_size=256, shuffle=False, drop_last=False)
86 |
87 | device = torch.device('cuda') if torch.cuda.is_available() else 'cpu'
88 |
89 | model_pretrain = FragNetPreTrain()
90 | model_pretrain.to(device)
91 |
92 | loss_fn = nn.CrossEntropyLoss()
93 |
94 | optimizer = torch.optim.Adam(model_pretrain.parameters(), lr = 1e-4 )
95 | chkpoint_name='pt_gcn.pt'
96 | early_stopping = EarlyStopping(patience=500, verbose=True, chkpoint_name=chkpoint_name)
97 |
98 |
99 |
100 | for epoch in tqdm(range(500)):
101 |
102 | train_loss = train(model_pretrain, train_loader, optimizer, device)
103 | val_loss = validate(val_loader, model_pretrain, device)
104 | print(train_loss, val_loss)
105 |
106 |
107 | early_stopping(val_loss, model_pretrain)
108 |
109 | if early_stopping.early_stop:
110 | print("Early stopping")
111 | break
112 |
--------------------------------------------------------------------------------
/fragnet/train/pretrain/pretrain_heads.py:
--------------------------------------------------------------------------------
1 | # from torch.nn import Parameter
2 | # from torch_geometric.utils import add_self_loops, degree
3 | import torch
4 | import torch.nn as nn
5 | # from torch_scatter import scatter_add, scatter_softmax
6 | # from torch_geometric.utils import add_self_loops
7 | from torch_scatter import scatter_add
8 | # import numpy as np
9 | # import torch.nn.functional as F
10 | # from torch_geometric.nn.norm import BatchNorm
11 | from fragnet.model.gat.gat2 import FragNet
12 |
13 |
14 |
15 | # get their activation functions
16 | class PretrainTask(nn.Module):
17 | """
18 | SAN prediction head for graph prediction tasks.
19 |
20 | Args:
21 | dim_in (int): Input dimension.
22 | dim_out (int): Output dimension. For binary prediction, dim_out=1.
23 | L (int): Number of hidden layers.
24 | """
25 |
26 | def __init__(self, dim_in=128, dim_out=1, L=2):
27 | super().__init__()
28 | # self.pooling_fun = register.pooling_dict[cfg.model.graph_pooling]
29 |
30 | # bond_length
31 | self.bl_reduce_layer = nn.Linear(dim_in * 3, dim_in)
32 | list_bl_layers = [
33 | nn.Linear(dim_in // 2 ** l, dim_in // 2 ** (l + 1), bias=True)
34 | for l in range(L)]
35 | list_bl_layers.append(
36 | nn.Linear(dim_in // 2 ** L, dim_out, bias=True))
37 | self.bl_layers = nn.ModuleList(list_bl_layers)
38 |
39 | # bond_angle
40 | list_ba_layers = [
41 | nn.Linear(dim_in // 2 ** l, dim_in // 2 ** (l + 1), bias=True)
42 | for l in range(L)]
43 | list_ba_layers.append(
44 | nn.Linear(dim_in // 2 ** L, dim_out, bias=True))
45 | self.ba_layers = nn.ModuleList(list_ba_layers)
46 |
47 | # dihedral_angle
48 | list_da_layers = [
49 | nn.Linear(dim_in // 2 ** l, dim_in // 2 ** (l + 1), bias=True)
50 | for l in range(L)]
51 | list_da_layers.append(
52 | nn.Linear(dim_in // 2 ** L, dim_out, bias=True))
53 | self.da_layers = nn.ModuleList(list_da_layers)
54 |
55 | # graph-level prediction (energy)
56 | list_FC_layers = [
57 | nn.Linear(dim_in*2 // 2 ** l, dim_in*2 // 2 ** (l + 1), bias=True)
58 | for l in range(L)]
59 | list_FC_layers.append(
60 | nn.Linear(dim_in*2 // 2 ** L, dim_out, bias=True))
61 | self.FC_layers = nn.ModuleList(list_FC_layers)
62 |
63 | self.L = L
64 | self.activation = nn.ReLU()
65 |
66 | def _apply_index(self, batch):
67 | return batch.bond_length, batch.distance
68 |
69 | def forward(self, x_atoms, x_frags, edge_attr, batch):
70 | edge_index = batch['edge_index']
71 | # bond_length
72 | # bond_length_pair = batch.positions[batch.edge_index.T]
73 | # bond_length_true = torch.sum((bond_length_pair[:, 0, :] - bond_length_pair[:, 1, :]) ** 2, axis=1)
74 | bond_length_pred = torch.concat((x_atoms[edge_index.T][:,0,:], x_atoms[edge_index.T][:,1,:], edge_attr),axis=1)
75 | bond_length_pred = self.bl_reduce_layer(bond_length_pred)
76 | for l in range(self.L + 1):
77 | bond_length_pred = self.activation(bond_length_pred)
78 | bond_length_pred = self.bl_layers[l](bond_length_pred)
79 |
80 | # bond_angle
81 | bond_angle_pred = x_atoms
82 | for l in range(self.L):
83 | bond_angle_pred = self.ba_layers[l](bond_angle_pred)
84 | bond_angle_pred = self.activation(bond_angle_pred)
85 | bond_angle_pred = self.ba_layers[self.L](bond_angle_pred)
86 |
87 | # dihedral_angle
88 | dihedral_angle_pred = edge_attr
89 | for l in range(self.L):
90 | dihedral_angle_pred = self.da_layers[l](dihedral_angle_pred)
91 | dihedral_angle_pred = self.activation(dihedral_angle_pred)
92 | dihedral_angle_pred = self.da_layers[self.L](dihedral_angle_pred)
93 |
94 | # total energy
95 | # graph_rep = self.pooling_fun(batch.x, batch.batch)
96 |
97 | x_frags_pooled = scatter_add(src=x_frags, index=batch['frag_batch'], dim=0)
98 | x_atoms_pooled = scatter_add(src=x_atoms, index=batch['batch'], dim=0)
99 |
100 | graph_rep = torch.cat((x_atoms_pooled, x_frags_pooled), 1)
101 | for l in range(self.L):
102 | graph_rep = self.FC_layers[l](graph_rep)
103 | graph_rep = self.activation(graph_rep)
104 | graph_rep = self.FC_layers[self.L](graph_rep)
105 |
106 |
107 | return bond_length_pred, bond_angle_pred, dihedral_angle_pred, graph_rep
108 |
109 |
110 |
111 |
112 | class FragNetPreTrain(nn.Module):
113 |
114 | def __init__(self, num_layer=4, drop_ratio=0.15, num_heads=4, emb_dim=128,
115 | atom_features=167, frag_features=167, edge_features=16):
116 | super(FragNetPreTrain, self).__init__()
117 |
118 | # self.pretrain = FragNet(num_layer=num_layer, drop_ratio=drop_ratio)
119 | self.pretrain = FragNet(num_layer=num_layer, drop_ratio=drop_ratio, num_heads=num_heads, emb_dim=emb_dim,
120 | atom_features=atom_features, frag_features=frag_features, edge_features=edge_features)
121 |         self.head = PretrainTask(emb_dim, 1)  # head width must match the encoder embedding size
122 |
123 |
124 | def forward(self, batch):
125 |
126 | x_atoms, x_frags, e_edge, e_fedge = self.pretrain(batch)
127 | bond_length_pred, bond_angle_pred, dihedral_angle_pred, graph_rep = self.head(x_atoms, x_frags, e_edge, batch)
128 |
129 | return bond_length_pred, bond_angle_pred, dihedral_angle_pred, graph_rep
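A quick shape check for the PretrainTask head defined above, using random stand-in tensors. The sizes and batch-index vectors are arbitrary, and the import path is assumed.

import torch
from fragnet.train.pretrain.pretrain_heads import PretrainTask  # assumed import path

dim = 128
head = PretrainTask(dim_in=dim, dim_out=1, L=2)

n_atoms, n_frags, n_edges = 10, 4, 18
x_atoms = torch.randn(n_atoms, dim)      # per-atom embeddings from the encoder
x_frags = torch.randn(n_frags, dim)      # per-fragment embeddings
edge_attr = torch.randn(n_edges, dim)    # per-edge embeddings
batch = {
    "edge_index": torch.randint(0, n_atoms, (2, n_edges)),  # atom-atom connectivity
    "batch": torch.tensor([0] * 5 + [1] * 5),               # atom -> graph id (two graphs)
    "frag_batch": torch.tensor([0, 0, 1, 1]),               # fragment -> graph id
}

bl, ba, da, g = head(x_atoms, x_frags, edge_attr, batch)
print(bl.shape, ba.shape, da.shape, g.shape)
# torch.Size([18, 1]) torch.Size([10, 1]) torch.Size([18, 1]) torch.Size([2, 1])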
--------------------------------------------------------------------------------
/fragnet/train/pretrain/pretrain_utils.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import torch
3 |
4 | class Trainer:
5 | def __init__(self, loss_fn=None):
6 | self.loss_fn = loss_fn
7 |
8 | # single output regression
9 | def train(self, model, loader, optimizer, device):
10 | model.train()
11 | total_loss = 0
12 | for batch in loader:
13 | for k,v in batch.items():
14 | batch[k] = batch[k].to(device)
15 | optimizer.zero_grad()
16 | bond_length_pred, bond_angle_pred, dihedral_angle_pred, graph_rep = model(batch)
17 | bond_length_true = batch['bnd_lngth']
18 | bond_angle_true = batch['bnd_angl']
19 | dihedral_angle_true = batch['dh_angl']
20 | E = batch['y']
21 |
22 |             loss_lngth = self.loss_fn(bond_length_pred, bond_length_true)
23 |             loss_angle = self.loss_fn(bond_angle_pred, bond_angle_true)
24 |             loss_dihedral = self.loss_fn(dihedral_angle_pred, dihedral_angle_true)
25 |             loss_E = self.loss_fn(graph_rep.view(-1), E)
26 |             loss = loss_lngth + loss_angle + loss_dihedral + loss_E
27 |
28 | loss.backward()
29 | total_loss += loss.item()
30 | optimizer.step()
31 | return total_loss / len(loader.dataset)
32 |
33 | def validate(self, loader, model, device):
34 | model.eval()
35 | total_loss = 0
36 | with torch.no_grad():
37 |
38 | for batch in loader:
39 | for k,v in batch.items():
40 | batch[k] = batch[k].to(device)
41 |
42 |
43 | bond_length_pred, bond_angle_pred, dihedral_angle_pred, graph_rep = model(batch)
44 | bond_length_true = batch['bnd_lngth']
45 | bond_angle_true = batch['bnd_angl']
46 | dihedral_angle_true = batch['dh_angl']
47 | E = batch['y']
48 |
49 |                 loss_lngth = self.loss_fn(bond_length_pred, bond_length_true)
50 |                 loss_angle = self.loss_fn(bond_angle_pred, bond_angle_true)
51 |                 loss_dihedral = self.loss_fn(dihedral_angle_pred, dihedral_angle_true)
52 |                 loss_E = self.loss_fn(graph_rep.view(-1), E)
53 |                 loss = loss_lngth + loss_angle + loss_dihedral + loss_E
54 |
55 | total_loss += loss.item()
56 | return total_loss / len(loader.dataset)
57 |
58 |
59 | def load_prop_data(args):
60 |
61 | fs = []
62 | for f in args.pretrain.prop_files:
63 | fs.append(pd.read_csv(f))
64 | fs = pd.concat(fs, axis=0)
65 | fs.reset_index(drop=True, inplace=True)
66 |
67 | fs = fs.loc[:, ['smiles'] + args.pretrain.props]
68 |
69 | prop_dict = dict(zip(fs.smiles, fs.iloc[:, 1:].values.tolist()))
70 |
71 | return prop_dict, fs
72 |
73 | def add_props_to_ds(ds, prop_dict):
74 | for d in ds:
75 | smiles = d.smiles
76 | props = prop_dict[smiles]
77 | d.y = torch.tensor([props], dtype=torch.float)
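The property files consumed by load_prop_data are CSVs with a smiles column plus one column per name in args.pretrain.props. A small illustrative example follows; the file name, the property columns, and the SimpleNamespace stand-in for the OmegaConf args object are all hypothetical.

from types import SimpleNamespace
import pandas as pd
from fragnet.train.pretrain.pretrain_utils import load_prop_data  # assumed import path

# Write a tiny property table: 'smiles' plus the requested property columns.
pd.DataFrame({
    "smiles": ["CCO", "c1ccccc1"],
    "logp": [-0.31, 2.13],
    "n_rings": [0, 1],
}).to_csv("props.csv", index=False)

args = SimpleNamespace(pretrain=SimpleNamespace(prop_files=["props.csv"],
                                                props=["logp", "n_rings"]))
prop_dict, fs = load_prop_data(args)
# prop_dict -> {'CCO': [-0.31, 0.0], 'c1ccccc1': [2.13, 1.0]}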
--------------------------------------------------------------------------------
/fragnet/vizualize/app.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 | import pandas as pd
3 | from pathlib import Path
4 | import base64
5 | from fragnet.vizualize.viz import FragNetVizApp
6 | from fragnet.vizualize.model import FragNetPreTrainViz
7 | from streamlit_ketcher import st_ketcher
8 | from fragnet.vizualize.model_attr import get_attr_image
9 | # Initial page config
10 |
11 | st.set_page_config(
12 | page_title='FragNet Vizualize',
13 | layout="wide",
14 | initial_sidebar_state="expanded",
15 | )
16 |
17 | # Thanks to streamlitopedia for the following code snippet
18 |
19 | def img_to_bytes(img_path):
20 | img_bytes = Path(img_path).read_bytes()
21 | encoded = base64.b64encode(img_bytes).decode()
22 | return encoded
23 |
24 | # sidebar
25 | def input_callback():
26 | st.session_state.input = st.session_state.my_input
27 | # def cs_sidebar():
28 |
29 | def predict_cdrp(smiles, cell_line, cell_line_df):
30 | gene_expr = cell_line_df.loc[cell_line,:].values
31 | viz.calc_weights_cdrp(smiles, gene_expr)
32 | prop_prediction = -1
33 | return viz, prop_prediction
34 |
35 |
36 |
37 | def resolve_prop_model(prop_type):
38 |
39 | if prop_type == 'Solubility':
40 | model_config = './fragnet/exps/ft/pnnl_full/fragnet_hpdl_exp1s_h4pt4_10/config_exp100.yaml'
41 | chkpt_path = './fragnet/exps/ft/pnnl_full/fragnet_hpdl_exp1s_h4pt4_10/ft_100.pt'
42 | # model_config = './fragnet/exps/ft/pnnl_set2/fragnet_hpdl_exp1s_h4pt4_10/config_exp100.yaml'
43 | # chkpt_path = '../fragnet/exps/ft/pnnl_set2/fragnet_hpdl_exp1s_h4pt4_10/ft_100.pt'
44 |
45 | viz = FragNetVizApp(model_config, chkpt_path)
46 |
47 | prop_prediction = viz.calc_weights(selected)
48 |
49 |
50 | elif prop_type == 'Lipophilicity':
51 | model_config = './fragnet/exps/ft/lipo/fragnet_hpdl_exp1s_pt4_30/config_exp100.yaml'
52 | chkpt_path = './fragnet/exps/ft/lipo/fragnet_hpdl_exp1s_pt4_30/ft_100.pt'
53 | viz = FragNetVizApp(model_config, chkpt_path)
54 |
55 | prop_prediction = viz.calc_weights(selected)
56 |
57 | elif prop_type == 'Energy':
58 | model_config = '../fragnet/fragnet/exps/pt/unimol_exp1s4/config.yaml'
59 | chkpt_path = '../fragnet/fragnet/exps/pt/unimol_exp1s4/pt.pt'
60 | viz = FragNetVizApp(model_config, chkpt_path, 'energy')
61 | prop_prediction = viz.calc_weights(selected)
62 |
63 | return viz, prop_prediction, model_config, chkpt_path
64 |
65 | def resolve_DRP(smiles, cell_line, cell_line_df):
66 |
67 | model_config = '../fragnet/fragnet/exps/ft/gdsc/fragnet_hpdl_exp1s_pt4_30/config_exp100.yaml'
68 | chkpt_path = '../fragnet/fragnet/exps/ft/gdsc/fragnet_hpdl_exp1s_pt4_30/ft_100.pt'
69 | viz = FragNetVizApp(model_config, chkpt_path,'cdrp')
70 |
71 | # viz, prop_prediction = predict_cdrp(smiles=selected, cell_line=cell_line, cell_line_df=cell_line_df)
72 | gene_expr = cell_line_df.loc[cell_line,:].values
73 | viz.calc_weights_cdrp(smiles, gene_expr)
74 | prop_prediction = -1
75 |
76 | return viz, prop_prediction, model_config, chkpt_path
77 |
78 |
79 | # st.sidebar.markdown('''[ ](https://streamlit.io/)'''.format(img_to_bytes("logomark_website.png")), unsafe_allow_html=True)
80 | st.sidebar.header('FragNet Vizualize')
81 |
82 | prop_type = st.sidebar.radio(
83 | "Select the Property type",
84 | # ["Solubility", "Lipophilicity", "Energy", "DRP"],
85 | ["Solubility", "Lipophilicity"],
86 | # captions = ["In logS units", "Lipophilicity", "Energy", "Drug Response Prediction"]
87 | captions = ["In logS units", "Lipophilicity"]
88 | )
89 |
90 | # def input_callback():
91 | # st.session_state.input = st.session_state.my_input
92 | # selected = st.text_input("Input Your Own SMILES :", key="my_input",on_change=input_callback,args=None)
93 |
94 | selected = st.sidebar.text_input("Input SMILES :", key="my_input",on_change=input_callback,args=None,
95 | value="CC1(C)CC(O)CC(C)(C)N1[O]")
96 | selected = st_ketcher( selected )
97 |
98 | # if prop_type=="DRP":
99 |
100 | # cell_line = st.sidebar.selectbox(
101 | # 'Select the cell line identifier',
102 | # ['DATA.906826',
103 | # 'DATA.687983',
104 | # 'DATA.910927',
105 | # 'DATA.1240138',
106 | # 'DATA.1240139',
107 | # 'DATA.906792',
108 | # 'DATA.910688',
109 | # 'DATA.1240135',
110 | # 'DATA.1290812',
111 | # 'DATA.907045',
112 | # 'DATA.906861',
113 | # 'DATA.906830',
114 | # 'DATA.909750',
115 | # 'DATA.1240137',
116 | # 'DATA.753552',
117 | # 'DATA.907065',
118 | # 'DATA.925338',
119 | # 'DATA.1290809',
120 | # 'DATA.949158',
121 | # 'DATA.924110'])
122 | # cell_line='DATA.924110'
123 | # cell_line_df = pd.read_csv('../fragnet/fragnet/assets/cell_line_data.csv', index_col=0)
124 |
125 | # st.sidebar.write(f'selected cell line: {cell_line}')
126 |
127 | col1, col2, col3 = st.columns(3)
128 |
129 | if prop_type in ["Solubility", "Lipophilicity", "Energy"]:
130 | viz, prop_prediction, model_config, chkpt_path = resolve_prop_model(prop_type)
131 | # elif prop_type == "DRP":
132 | # viz, prop_prediction, model_config, chkpt_path = resolve_DRP(selected, cell_line, cell_line_df)
133 |
134 |
135 | hide_bond_weights = st.sidebar.checkbox("Hide bond weights")
136 | hide_atom_weights = st.sidebar.checkbox("Hide atom weights")
137 |
138 | with col1:
139 | # root='/Users/pana982/projects/esmi/models/fragnet/fragnet/'
140 |
141 | png, atom_weights = viz.vizualize_atom_weights(hide_bond_weights, hide_atom_weights)
142 | col1.image(png, caption='Atom Weights')
143 |
144 | # png_attr = get_attr_image(selected)
145 | # col1.image(png_attr, caption='Fragment Attributions')
146 |
147 | attn_atoms = pd.DataFrame(atom_weights)
148 | attn_atoms.index.rename('Atom Index', inplace=True)
149 | attn_atoms.columns = ['Atom Weights']
150 | col1.dataframe(attn_atoms)
151 |
152 |
153 |
154 |
155 |
156 | if prop_type == "Solubility":
157 | st.sidebar.write(f"Predicted Solubility (logS): {prop_prediction:.4f}")
158 | if prop_type == "Lipophilicity":
159 | st.sidebar.write(f"Predicted Lipophilicity: {prop_prediction:.4f}")
160 | if prop_type == "Energy":
161 | st.sidebar.write(f"Predicted Energy: {prop_prediction:.4f}")
162 |
163 | hide_bond_weights=False
164 | png_frag_attn, png_frag_highlight, frag_w, connection_w, atoms_in_frags = viz.frag_weight_highlight()
165 |
166 | with col2:
167 |
168 | png_attr = get_attr_image(selected, model_config, chkpt_path)
169 | col2.image(png_attr, caption='Fragment Attributions')
170 |
171 |
172 | col2.image(png_frag_highlight, caption='Fragments')
173 | st.write("Atoms in each Fragment")
174 | df_atoms_in_frags = pd.DataFrame(dict([ (k,pd.Series(v)) for k,v in atoms_in_frags.items() ])).T
175 | df_atoms_in_frags.index.rename('Fragment', inplace=True)
176 | st.dataframe(df_atoms_in_frags)
177 |
178 | with col3:
179 |
180 |
181 | st.image(png_frag_attn, caption='Fragment Weights')
182 |
183 |
184 | st.write('Fragment Weight Values')
185 | st.dataframe(frag_w)
186 |
187 | st.write('Fragment Connection Weight Values')
188 | st.dataframe(connection_w)
189 |
190 | st.sidebar.write(f"Current smiles: {selected}")
--------------------------------------------------------------------------------
/fragnet/vizualize/config.py:
--------------------------------------------------------------------------------
1 | # MODEL_CONFIG_PATH = '../fragnet_edge/exps/ft/esol/fragnet_hpdl_exp1s_pt4_30'
2 | # MODEL_CONFIG = f'{MODEL_CONFIG_PATH}/config_exp100.yaml'
3 | # MODEL_PATH = f'{MODEL_CONFIG_PATH}/ft_100.pt'
4 |
5 |
6 | MODEL_CONFIG_PATH = '../fragnet_edge/exps/ft/pnnl_set2/'
7 | MODEL_CONFIG = f'{MODEL_CONFIG_PATH}/exp1s_h4pt4.yaml'
8 | MODEL_PATH = f'{MODEL_CONFIG_PATH}/h4/ft.pt'
9 |
10 |
11 |
12 |
--------------------------------------------------------------------------------
/install_cpu.sh:
--------------------------------------------------------------------------------
1 | python3.11 -m venv ~/.env/fragnet
2 | source ~/.env/fragnet/bin/activate
3 | pip install --upgrade pip
4 | pip install -r requirements.txt
5 | pip install torch-scatter -f https://data.pyg.org/whl/torch-2.4.0+cpu.html
6 | pip install .
7 |
8 | mkdir -p fragnet/finetune_data/moleculenet/esol/raw/
9 |
10 | wget -O fragnet/finetune_data/moleculenet/esol/raw/delaney-processed.csv https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/delaney-processed.csv
11 |
12 | python fragnet/data_create/create_pretrain_datasets.py --save_path fragnet/pretrain_data/esol --data_type exp1s --maxiters 500 --raw_data_path fragnet/finetune_data/moleculenet/esol/raw/delaney-processed.csv
13 |
14 | python fragnet/data_create/create_finetune_datasets.py --dataset_name moleculenet --dataset_subset esol --use_molebert True --output_dir fragnet/finetune_data/moleculenet_exp1s --data_dir fragnet/finetune_data/moleculenet --data_type exp1s
15 |
--------------------------------------------------------------------------------
/install_gpu.sh:
--------------------------------------------------------------------------------
1 | python -m venv ~/.env/fragnet
2 | source ~/.env/fragnet/bin/activate
3 | pip install --upgrade pip
4 | pip install -r requirements.txt
5 | pip install torch-scatter -f https://data.pyg.org/whl/torch-2.4.0+cu121.html
6 | pip install .
7 |
8 | mkdir -p fragnet/finetune_data/moleculenet/esol/raw/
9 |
10 | wget -O fragnet/finetune_data/moleculenet/esol/raw/delaney-processed.csv https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/delaney-processed.csv
11 | 
12 | python fragnet/data_create/create_pretrain_datasets.py --save_path fragnet/pretrain_data/esol --data_type exp1s --maxiters 500 --raw_data_path fragnet/finetune_data/moleculenet/esol/raw/delaney-processed.csv
13 | 
14 | python fragnet/data_create/create_finetune_datasets.py --dataset_name moleculenet --dataset_subset esol --use_molebert True --output_dir fragnet/finetune_data/moleculenet_exp1s --data_dir fragnet/finetune_data/moleculenet --data_type exp1s
15 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | lmdb==1.4.1
2 | ipython
3 | matplotlib==3.9.2
4 | networkx==2.8.8
5 | numpy==1.24.1
6 | ogb==1.2.0
7 | omegaconf==2.3.0
8 | optuna==3.5.0
9 | pandas==2.2.3
10 | parmap==1.7.0
11 | Pillow==10.4.0
12 | PyYAML==6.0.2
13 | rdkit==2023.9.6
14 | scikit_learn==1.3.2
15 | scipy==1.14.1
16 | streamlit==1.38.0
17 | streamlit_ketcher==0.0.1
18 | torch==2.4.0
19 | pytorch_lightning
20 | torch_geometric==2.6.1
21 | tqdm==4.66.1
22 | tensorboard
23 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 | setup(
4 | # Application name:
5 | name="fragnet",
6 |
7 | # Version number (initial):
8 | version="0.1.0",
9 |
10 | # Application author details:
11 | author="Gihan Panapitiya",
12 | author_email="gihan.panapitiya@pnnl.gov",
13 |
14 | # Packages
15 | #packages=["mpet","mpet.utils", "mpet.models", "mpet.doa"],
16 |
17 | # Include additional files into the package
18 | include_package_data=True,
19 |
20 | # Details
21 | url="",
22 |
23 | #
24 | # license="LICENSE.txt",
25 | description="FragNet: A GNN with Four Layers of Interpretability",
26 |
27 | # long_description=open("README.txt").read(),
28 |
29 | # Dependent packages (distributions)
30 | # install_requires=["openbabel"],
31 | )
32 |
33 |
--------------------------------------------------------------------------------