├── .gitignore ├── LICENSE.md ├── README.md ├── fragnet ├── assets │ ├── README.md │ ├── app.png │ ├── fragnet.png │ └── weights_main.png ├── data_create │ ├── create_dataset.py │ ├── create_finetune_datasets.py │ └── create_pretrain_datasets.py ├── dataset │ ├── cdrp.py │ ├── custom_dataset.py │ ├── data.py │ ├── dataset.py │ ├── dta.py │ ├── ext_data_utils │ │ ├── Step1_getData.py │ │ ├── __init__.py │ │ └── deepttc.py │ ├── feature_utils.py │ ├── features.py │ ├── features0.py │ ├── features2.py │ ├── features_check.py │ ├── features_exp2.py │ ├── features_exp_safe.py │ ├── fix_pt_data.py │ ├── fragments.py │ ├── general.py │ ├── loader_molebert.py │ ├── moleculenet.py │ ├── scaffold_split_from_df.py │ ├── simsgt.py │ ├── splitters.py │ ├── splitters_molebert.py │ └── utils.py ├── exps │ ├── README.md │ ├── ft │ │ ├── esol │ │ │ └── e1pt4.yaml │ │ ├── lipo │ │ │ ├── download.sh │ │ │ └── fragnet_hpdl_exp1s_pt4_30 │ │ │ │ ├── config_exp100.yaml │ │ │ │ └── ft_100.pt │ │ └── pnnl_full │ │ │ └── fragnet_hpdl_exp1s_h4pt4_10 │ │ │ ├── config_exp100.yaml │ │ │ ├── ft_100.pt │ │ │ └── ft_100.pt.data │ └── pt │ │ └── unimol_exp1s4 │ │ ├── config.yaml │ │ ├── pt.pt │ │ └── pt.pt.data ├── hp │ ├── hp.py │ ├── hp2.py │ ├── hp_cdrp.py │ ├── hp_clf.py │ ├── hp_dta.py │ ├── hpft.py │ ├── hpoptuna.py │ └── hpray.py ├── model │ ├── cdrp │ │ ├── __init__.py │ │ └── model.py │ ├── dta │ │ ├── __init__.py │ │ ├── drug_encoder.py │ │ └── model.py │ ├── gat │ │ ├── __init__.py │ │ ├── extra_optimizers.py │ │ ├── gat.py │ │ ├── gat2.py │ │ ├── gat2_cv.py │ │ ├── gat2_edge.py │ │ ├── gat2_lite.py │ │ ├── gat2_pl.py │ │ ├── gat2_pretrain.py │ │ └── pretrain_heads.py │ └── gcn │ │ ├── gcn.py │ │ ├── gcn2.py │ │ ├── gcn3.py │ │ └── gcn_pl.py ├── train │ ├── finetune │ │ ├── __init__.py │ │ ├── finetune_cdrp.py │ │ ├── finetune_dta.py │ │ ├── finetune_gat.py │ │ ├── finetune_gat2.py │ │ ├── finetune_gat2_pl.py │ │ ├── finetune_norm.py │ │ ├── gat2_cv_frag.py │ │ ├── trainer_cdrp.py │ │ └── trainer_dta.py │ ├── pretrain │ │ ├── __init__.py │ │ ├── pretrain_data_pnnl.py │ │ ├── pretrain_gat2.py │ │ ├── pretrain_gat_mol.py │ │ ├── pretrain_gat_str.py │ │ ├── pretrain_gcn.py │ │ ├── pretrain_heads.py │ │ └── pretrain_utils.py │ ├── utils.py │ └── utils_pl.py └── vizualize │ ├── app.py │ ├── config.py │ ├── model.py │ ├── model_attr.py │ ├── property.py │ └── viz.py ├── install_cpu.sh ├── install_gpu.sh ├── requirements.txt └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints 2 | */.ipynb_checkpoints/* 3 | *__pycache__ 4 | __pycache__/ 5 | bin/ 6 | build/ 7 | develop-eggs/ 8 | dist/ 9 | eggs/ 10 | lib/ 11 | lib64/ 12 | parts/ 13 | sdist/ 14 | var/ 15 | *.egg-info/ 16 | .installed.cfg 17 | *.egg 18 | .tox/ 19 | .cache 20 | .project 21 | .pydevproject 22 | 23 | 24 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | FragNet 2 | 3 | Copyright © 2024, Battelle Memorial Institute 4 | All rights reserved. 5 | 6 | 1. Battelle Memorial Institute (hereinafter Battelle) hereby grants permission to any person or entity lawfully obtaining a copy of this software and associated documentation files (hereinafter “the Software”) to redistribute and use the Software in source and binary forms, with or without modification. 
Such person or entity may use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and may permit others to do so, subject to the following conditions:

0. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimers.

1. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.

2. Other than as used herein, neither the name Battelle Memorial Institute or Battelle may be used in any form whatsoever without the express written consent of Battelle.

2. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL BATTELLE OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# FragNet

FragNet is a Graph Neural Network designed for molecular property prediction that can offer insights into how different substructures influence the predictions. More details of FragNet can be found in our paper, [FragNet: A Graph Neural Network for Molecular Property Prediction with Four Layers of Interpretability](https://arxiv.org/abs/2410.12156).

drawing

Figure 1: FragNet’s architecture and data representation. (a) Atom and Fragment graphs’ edge features are learned from Bond and Fragment connection graphs, respectively. (b) Initial fragment features for the fragment graph are the summation of the updated atom features that compose the fragment. (c) Illustration of FragNet’s message passing taking place between two non-covalently bonded substructures. Fragment-Fragment connections are also present between adjacent fragments in each non-covalently bonded structure of the compound.

drawing

Figure 2: Different types of attention weights and contribution values available in FragNet, visualized for CC[NH+](CCCl)CCOc1cccc2ccccc12.[Cl-], with atom, bond, and fragment attention weights shown in (a), (b), and (c), and fragment contribution values shown in (d). The top table provides the atom-to-fragment mapping and the bottom table provides the fragment connection attention weights. Atom and bond attention weights are scaled to values between 0 and 1. The fragment and fragment connection weights are not scaled. The numbers in blue boxes in (d) correspond to Fragment IDs in the ‘Atoms in Fragments’ table.

# Usage

### Installation

The installation has been tested with Python 3.11 and CUDA 12.1.

#### For CPU

1. Create a Python 3.11 virtual environment and install the required packages using the command `pip install -r requirements.txt`
2. Install torch-scatter using `pip install torch-scatter -f https://data.pyg.org/whl/torch-2.4.0+cpu.html`
3. Next, install FragNet. In the directory where `setup.py` is, run the command `pip install .`

Alternatively, and more conveniently, you can run `bash install_cpu.sh`, which will install FragNet and create pretraining and finetuning data for the ESOL dataset.

#### For GPU

1. Create a Python 3.11 virtual environment and install the required packages using the command `pip install -r requirements.txt`
2. Install torch-scatter using `pip install torch-scatter -f https://data.pyg.org/whl/torch-2.4.0+cu121.html`
3. Next, install FragNet. In the directory where `setup.py` is, run the command `pip install .`

Alternatively, run `bash install_gpu.sh`.

-------

### Creating pretraining data

FragNet was pretrained using part of the data used by [UniMol](https://github.com/deepmodeling/Uni-Mol/tree/main/unimol).

Here, we use the ESOL dataset to demonstrate the data creation. The following commands should be run from the `FragNet/fragnet` directory.

First, create a directory to save the data.

```mkdir -p finetune_data/moleculenet/esol/raw/```

Next, download the ESOL dataset.

```
wget -O finetune_data/moleculenet/esol/raw/delaney-processed.csv https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/delaney-processed.csv
```

Next, run the following command to create the pretraining data.

```
python data_create/create_pretrain_datasets.py --save_path pretrain_data/esol --data_type exp1s --maxiters 500 --raw_data_path finetune_data/moleculenet/esol/raw/delaney-processed.csv
```

- save_path: where the datasets should be saved
- data_type: use exp1s for all the calculations
- maxiters: maximum number of iterations for 3D coordinate generation
- raw_data_path: location of the SMILES dataset

------

### Creating finetuning data

Finetuning data for the MoleculeNet datasets can be created as follows.

`python data_create/create_finetune_datasets.py --dataset_name moleculenet --dataset_subset esol --use_molebert True --output_dir finetune_data/moleculenet_exp1s --data_dir finetune_data/moleculenet --data_type exp1s`

- dataset_name: dataset type
- dataset_subset: dataset sub-type
- use_molebert: whether to use the dataset splitting method used by the MoleBert model

------

### Pretrain

All the input parameters for pretraining have to be given in a config file; the command below uses `exps/pt/unimol_exp1s4/config.yaml`.
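For orientation, the sketch below shows the kind of fields such a config contains. It is adapted from the `pretrain:` block of `exps/ft/esol/e1pt4.yaml` in this repository; the exact keys and values in `exps/pt/unimol_exp1s4/config.yaml` may differ, so treat this only as an illustration.

```
# Illustrative sketch only -- adapted from the pretrain: block of exps/ft/esol/e1pt4.yaml.
# The actual pretraining config (exps/pt/unimol_exp1s4/config.yaml) may use different keys and values.
model_version: gat2                          # FragNet model variant
num_layer: 4                                 # number of message-passing layers
num_heads: 4                                 # number of attention heads
emb_dim: 128                                 # embedding dimension
drop_ratio: 0.2                              # dropout ratio
loss: mse
batch_size: 128
lr: 1e-4
n_epochs: 20000
es_patience: 500                             # early-stopping patience
chkpoint_name: exps/pt/unimol_exp1s4/pt.pt   # path of the pretrained checkpoint
```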
To pretrain, run the following command.

```
python train/pretrain/pretrain_gat2.py --config exps/pt/unimol_exp1s4/config.yaml
```

------

### Finetune

To finetune on a downstream dataset (here, ESOL), run:

```
python train/finetune/finetune_gat2.py --config exps/ft/esol/e1pt4.yaml
```

------

## Interactive Web Application

To run this application, run the command `streamlit run fragnet/vizualize/app.py` from the root directory.

drawing

------

## Optional
### Hyperparameter tuning
```
python hp/hpoptuna.py --config exps/ft/esol/e1pt4.yaml --n_trials 10 \
    --chkpt hpruns/pt.pt --seed 10 --ft_epochs 10 --prune 1
```

- config: initial parameters
- n_trials: number of hyperparameter optimization trials
- chkpt: where the checkpoint produced during hyperparameter optimization will be saved. Note that you have to create the output directory yourself (`hpruns` in this case); otherwise, the output directory is assumed to be the current working directory.
- seed: random seed
- ft_epochs: number of training epochs
- prune: for Optuna runs; whether to prune unpromising trials

## Citation
If you use our work, please cite it as follows.

```
@misc{panapitiya2024fragnetgraphneuralnetwork,
      title={FragNet: A Graph Neural Network for Molecular Property Prediction with Four Layers of Interpretability},
      author={Gihan Panapitiya and Peiyuan Gao and C Mark Maupin and Emily G Saldanha},
      year={2024},
      eprint={2410.12156},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2410.12156},
}
```

Disclaimer

154 | 155 | This material was prepared as an account of work sponsored by an agency of the United States Government. Neither the United States Government nor the United States Department of Energy, nor Battelle, nor any of their employees, nor any jurisdiction or organization that has cooperated in the development of these materials, makes any warranty, express or implied, or assumes any legal liability or responsibility for the accuracy, completeness, or usefulness or any information, apparatus, product, software, or process disclosed, or represents that its use would not infringe privately owned rights. 156 | Reference herein to any specific commercial product, process, or service by trade name, trademark, manufacturer, or otherwise does not necessarily constitute or imply its endorsement, recommendation, or favoring by the United States Government or any agency thereof, or Battelle Memorial Institute. The views and opinions of authors expressed herein do not necessarily state or reflect those of the United States Government or any agency thereof. 157 | PACIFIC NORTHWEST NATIONAL LABORATORY 158 | operated by 159 | BATTELLE 160 | for the 161 | UNITED STATES DEPARTMENT OF ENERGY 162 | under Contract DE-AC05-76RL01830 163 | 164 | 165 | 166 | -------------------------------------------------------------------------------- /fragnet/assets/README.md: -------------------------------------------------------------------------------- 1 | download cell_line_data.csv 2 | -------------------------------------------------------------------------------- /fragnet/assets/app.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pnnl/FragNet/54f99099a08ab991a5ba6845f7338492d03d4f1c/fragnet/assets/app.png -------------------------------------------------------------------------------- /fragnet/assets/fragnet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pnnl/FragNet/54f99099a08ab991a5ba6845f7338492d03d4f1c/fragnet/assets/fragnet.png -------------------------------------------------------------------------------- /fragnet/assets/weights_main.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pnnl/FragNet/54f99099a08ab991a5ba6845f7338492d03d4f1c/fragnet/assets/weights_main.png -------------------------------------------------------------------------------- /fragnet/data_create/create_dataset.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from dataset import save_dataset_parts, save_dataset 3 | import argparse 4 | 5 | 6 | if __name__ == "__main__": 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument( 10 | "--data_path", 11 | help="dataframe or csv file", 12 | type=str, 13 | required=False, 14 | default=None, 15 | ) 16 | parser.add_argument( 17 | "--train_path", 18 | help="path to train input data", 19 | type=str, 20 | required=False, 21 | default=None, 22 | ) 23 | parser.add_argument( 24 | "--val_path", 25 | help="path to val input data", 26 | type=str, 27 | required=False, 28 | default=None, 29 | ) 30 | parser.add_argument( 31 | "--test_path", 32 | help="path to test input data", 33 | type=str, 34 | required=False, 35 | default=None, 36 | ) 37 | parser.add_argument("--start_row", help="", type=int, required=False, default=None) 38 | parser.add_argument("--end_row", help="", type=int, required=False, default=None) 39 | 
parser.add_argument("--start_id", help="", type=int, required=False, default=0) 40 | parser.add_argument( 41 | "--rows_per_part", help="", type=int, required=False, default=1000 42 | ) 43 | parser.add_argument( 44 | "--create_bond_graph_data", help="", type=bool, required=False, default=True 45 | ) 46 | parser.add_argument( 47 | "--add_dhangles", help="", type=bool, required=False, default=False 48 | ) 49 | parser.add_argument( 50 | "--feature_type", 51 | help="one_hot or embed", 52 | type=str, 53 | required=False, 54 | default="one_hot", 55 | ) 56 | parser.add_argument( 57 | "--target", 58 | help="target property", 59 | type=str, 60 | required=False, 61 | nargs="+", 62 | default="log_sol", 63 | ) 64 | parser.add_argument( 65 | "--save_path", 66 | help="folder where the data is saved", 67 | type=str, 68 | required=False, 69 | default=None, 70 | ) 71 | parser.add_argument( 72 | "--save_name", help="saving file name", type=str, required=False, default=None 73 | ) 74 | args = parser.parse_args() 75 | 76 | dataset_args = { 77 | "create_bond_graph_data": args.create_bond_graph_data, 78 | "add_dhangles": args.add_dhangles, 79 | "feature_type": args.feature_type, 80 | "target": args.target, 81 | "save_path": args.save_path, 82 | "save_name": args.save_name, 83 | "start_id": args.start_id, 84 | } 85 | 86 | if args.train_path: 87 | train = pd.read_csv(args.train_path) 88 | train.to_csv(f"{args.save_path}/train.csv", index=False) 89 | save_dataset( 90 | df=train, 91 | save_path=f"{args.save_path}", 92 | save_name=f"train", 93 | target=args.target, 94 | feature_type=args.feature_type, 95 | create_bond_graph_data=args.create_bond_graph_data, 96 | ) 97 | 98 | elif args.val_path: 99 | val = pd.read_csv(args.val_path) 100 | val.to_csv(f"{args.save_path}/val.csv", index=False) 101 | save_dataset( 102 | df=val, 103 | save_path=f"{args.save_path}", 104 | save_name=f"val", 105 | target=args.target, 106 | feature_type=args.feature_type, 107 | create_bond_graph_data=args.create_bond_graph_data, 108 | ) 109 | 110 | elif args.test_path: 111 | test = pd.read_csv(args.test_path) 112 | test.to_csv(f"{args.save_path}/test.csv", index=False) 113 | save_dataset( 114 | df=test, 115 | save_path=f"{args.save_path}", 116 | save_name=f"test", 117 | target=args.target, 118 | feature_type=args.feature_type, 119 | create_bond_graph_data=args.create_bond_graph_data, 120 | ) 121 | 122 | else: 123 | save_dataset_parts( 124 | args.data_path, 125 | start_row=args.start_row, 126 | end_row=args.end_row, 127 | rows_per_part=args.rows_per_part, 128 | dataset_args=dataset_args, 129 | ) 130 | -------------------------------------------------------------------------------- /fragnet/data_create/create_finetune_datasets.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from fragnet.dataset.moleculenet import create_moleculenet_dataset 4 | from fragnet.dataset.general import create_general_dataset 5 | 6 | # from fragnet.dataset.unimol import create_moleculenet_dataset_from_unimol_data 7 | from fragnet.dataset.simsgt import create_moleculenet_dataset_simsgt 8 | from fragnet.dataset.dta import create_dta_dataset 9 | from fragnet.dataset.cdrp import create_cdrp_dataset 10 | from fragnet.dataset.scaffold_split_from_df import create_scaffold_split_data_from_df 11 | 12 | if __name__ == "__main__": 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument( 16 | "--dataset_name", 17 | help="moleculenet, moleculenet-custom", 18 | type=str, 19 | required=False, 20 | 
default="moleculenet", 21 | ) 22 | parser.add_argument( 23 | "--dataset_subset", 24 | help="esol, freesolv", 25 | type=str, 26 | required=False, 27 | default=None, 28 | ) 29 | parser.add_argument( 30 | "--output_dir", help="", type=str, required=False, default="finetune_data" 31 | ) 32 | parser.add_argument( 33 | "--data_dir", help="", type=str, required=False, default="finetune_data" 34 | ) 35 | parser.add_argument( 36 | "--frag_type", help="", type=str, required=False, default="brics" 37 | ) 38 | parser.add_argument( 39 | "--use_molebert", help="", type=bool, required=False, default=False 40 | ) 41 | parser.add_argument("--train_path", help="", type=str, required=False, default=None) 42 | parser.add_argument("--save_parts", help="", type=int, required=False, default=0) 43 | parser.add_argument("--val_path", help="", type=str, required=False, default=None) 44 | parser.add_argument("--test_path", help="", type=str, required=False, default=None) 45 | parser.add_argument("--target_name", help="", type=str, required=False, default="y") 46 | parser.add_argument("--data_type", help="", type=str, required=False, default="exp") 47 | parser.add_argument( 48 | "--multi_conf_data", 49 | help="create multiple conformers for a smiles", 50 | type=int, 51 | required=False, 52 | default=0, 53 | ) 54 | parser.add_argument( 55 | "--use_genes", 56 | help="whether a gene subset is used for cdrp data", 57 | type=int, 58 | required=False, 59 | default=0, 60 | ) 61 | 62 | args = parser.parse_args() 63 | 64 | print("args.use_genes: ", args.use_genes) 65 | print("args.multi_conf_data: ", args.multi_conf_data) 66 | 67 | if "gen" in args.dataset_name: 68 | 69 | create_general_dataset(args) 70 | elif args.dataset_name == "moleculenet": 71 | 72 | create_moleculenet_dataset("MoleculeNet", args.dataset_subset.lower(), args) 73 | elif args.dataset_name == "moleculedataset": 74 | create_moleculenet_dataset("MoleculeDataset", args.dataset_subset.lower(), args) 75 | 76 | elif args.dataset_name == "unimol": 77 | create_moleculenet_dataset_from_unimol_data(args) 78 | 79 | elif args.dataset_name == "simsgt": 80 | create_moleculenet_dataset_simsgt(args.dataset_subset.lower(), args) 81 | 82 | elif args.dataset_name in ["davis", "kiba"]: 83 | create_dta_dataset(args) 84 | 85 | elif args.dataset_name in ["cep", "malaria"]: 86 | create_scaffold_split_data_from_df(args) 87 | 88 | elif args.dataset_name in ["gdsc", "gdsc_full", "ccle"]: 89 | print("args.use_genes0: ", args.use_genes) 90 | create_cdrp_dataset(args) 91 | -------------------------------------------------------------------------------- /fragnet/data_create/create_pretrain_datasets.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from fragnet.dataset.dataset import get_pt_dataset 3 | from fragnet.dataset.utils import extract_data, save_datasets 4 | import argparse 5 | import os 6 | import logging 7 | import logging.config 8 | 9 | 10 | def continuous_creation(args): 11 | 12 | l_limit = args.low 13 | h_limit = args.high 14 | n_rows = args.n_rows 15 | 16 | for i in range(l_limit, h_limit, n_rows): 17 | start_id = i 18 | end_id = i + n_rows - 1 19 | dfi = df.loc[start_id:end_id] 20 | 21 | ds = get_pt_dataset(dfi, args.data_type, maxiters=args.maxiters) 22 | ds = extract_data(ds) 23 | save_path = f"{args.save_path}/pt_{start_id}_{end_id}" 24 | save_datasets(ds, save_path) 25 | 26 | print(dfi.shape) 27 | 28 | 29 | def create_from_ids(args): 30 | 31 | curr = pd.read_pickle("pretrain_data/unimol_exp1s/curr_tmp.pkl") 
32 | full = pd.read_pickle("../fragnet1.2/pretrain_data/unimol/file_list.pkl") 33 | new = set(full).difference(curr) 34 | new = list(new) 35 | new = new[args.low : args.high] 36 | start_ids = [int(i.split("_")[1]) for i in new] 37 | 38 | for start_id in start_ids: 39 | end_id = start_id + n_rows - 1 40 | dfi = df.loc[start_id:end_id] 41 | 42 | ds = get_pt_dataset(dfi, args.data_type) 43 | ds = extract_data(ds) 44 | save_path = f"{args.save_path}/pt_{start_id}_{end_id}" 45 | logger.info(save_path) 46 | save_datasets(ds, save_path) 47 | 48 | print(dfi.shape) 49 | 50 | 51 | def continuous_creation_add_new(args): 52 | 53 | l_limit = args.low 54 | h_limit = args.high 55 | n_rows = args.n_rows 56 | 57 | for i in range(l_limit, h_limit, n_rows): 58 | start_id = i 59 | end_id = i + n_rows - 1 60 | dfi = df.loc[start_id:end_id] 61 | 62 | save_path = f"{args.save_path}/pt_{start_id}_{end_id}" 63 | 64 | curr_data = pd.read_pickle(save_path + ".pkl") 65 | curr_smiles = [d.smiles for d in curr_data] 66 | 67 | dfi_rem = dfi[~dfi.smiles.isin(curr_smiles)] 68 | 69 | ds = get_pt_dataset( 70 | dfi_rem, args.data_type, maxiters=args.maxiters, frag_type=args.frag_type 71 | ) 72 | ds = extract_data(ds) 73 | 74 | ds_updated = ds + curr_data 75 | save_datasets(ds_updated, save_path) 76 | print(dfi.shape) 77 | 78 | 79 | if __name__ == "__main__": 80 | 81 | parser = argparse.ArgumentParser() 82 | parser.add_argument( 83 | "--save_path", 84 | help="", 85 | type=str, 86 | required=False, 87 | default="pretrain_data/unimol/ds", 88 | ) 89 | parser.add_argument("--low", help="", type=int, required=False, default=None) 90 | parser.add_argument("--high", help="", type=int, required=False, default=None) 91 | parser.add_argument("--data_type", help="", type=str, required=False, default="exp") 92 | parser.add_argument( 93 | "--raw_data_path", help="", type=str, required=False, default=None 94 | ) 95 | parser.add_argument( 96 | "--calc_type", help="", type=str, required=False, default="scratch" 97 | ) 98 | parser.add_argument("--maxiters", help="", type=int, required=False, default=200) 99 | parser.add_argument( 100 | "--frag_type", help="", type=str, required=False, default="brics" 101 | ) 102 | 103 | args = parser.parse_args() 104 | 105 | if not os.path.exists(args.save_path): 106 | os.makedirs(args.save_path, exist_ok=True) 107 | 108 | if args.raw_data_path == None: 109 | df = pd.read_csv("pretrain_data/input/train_no_modulus.csv") 110 | else: 111 | df = pd.read_csv(args.raw_data_path) 112 | 113 | n_rows = 1000 114 | if not args.low: 115 | args.low = 0 116 | if not args.high: 117 | args.high = len(df) 118 | args.n_rows = n_rows 119 | 120 | if args.calc_type == "scratch": 121 | continuous_creation(args) 122 | elif args.calc_type == "add": 123 | continuous_creation_add_new(args) 124 | -------------------------------------------------------------------------------- /fragnet/dataset/cdrp.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .dataset import FinetuneDataCDRP 4 | from .ext_data_utils.deepttc import DataEncoding 5 | from sklearn.model_selection import train_test_split 6 | from .utils import extract_data, save_datasets, save_ds_parts 7 | 8 | 9 | def create_cdrp_dataset(args): 10 | 11 | if not os.path.exists(args.output_dir): 12 | os.makedirs(args.output_dir, exist_ok=True) 13 | 14 | print("args.use_genes1: ", args.use_genes) 15 | 16 | obj = DataEncoding(args.data_dir) 17 | 18 | print("args.use_genes: ", args.use_genes) 19 | 20 | dataset = FinetuneDataCDRP( 
21 | target_name=args.target_name, 22 | data_type=args.data_type, 23 | args=args, 24 | use_genes=args.use_genes, 25 | ) 26 | 27 | traindata, testdata = obj.Getdata.ByCancer(random_seed=1, test_size=0.05) 28 | traindata, valdata = train_test_split(traindata, test_size=0.1) 29 | 30 | traindata, valdata, testdata = obj.encode2( 31 | traindata=traindata, valdata=valdata, testdata=testdata 32 | ) 33 | 34 | traindata.to_csv(f"{args.output_dir}/train.csv", index=False) 35 | valdata.to_csv(f"{args.output_dir}/val.csv", index=False) 36 | testdata.to_csv(f"{args.output_dir}/test.csv", index=False) 37 | 38 | if args.save_parts == 0: 39 | 40 | ds = dataset.get_ft_dataset(traindata) 41 | ds = extract_data(ds) 42 | print("ds: ", ds[0]) 43 | save_path = f"{args.output_dir}/train" 44 | save_datasets(ds, save_path) 45 | 46 | ds = dataset.get_ft_dataset(valdata) 47 | ds = extract_data(ds) 48 | save_path = f"{args.output_dir}/val" 49 | save_datasets(ds, save_path) 50 | 51 | ds = dataset.get_ft_dataset(testdata) 52 | ds = extract_data(ds) 53 | save_path = f"{args.output_dir}/test" 54 | save_datasets(ds, save_path) 55 | 56 | else: 57 | save_ds_parts( 58 | data_creater=dataset, ds=traindata, output_dir=args.output_dir, fold="train" 59 | ) 60 | save_ds_parts( 61 | data_creater=dataset, ds=valdata, output_dir=args.output_dir, fold="val" 62 | ) 63 | save_ds_parts( 64 | data_creater=dataset, ds=testdata, output_dir=args.output_dir, fold="test" 65 | ) 66 | -------------------------------------------------------------------------------- /fragnet/dataset/custom_dataset.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from torch_geometric.datasets import MoleculeNet 3 | from .utils import remove_non_mols 4 | from torch_geometric.data import Data 5 | 6 | 7 | class MoleculeDataset: 8 | def __init__(self, name, data_dir): 9 | self.name = name 10 | self.data_dir = data_dir 11 | 12 | def get_data(self): 13 | if self.name == "tox21": 14 | return self.get_tox21() 15 | elif self.name == "toxcast": 16 | return self.get_toxcast() 17 | elif self.name == "clintox": 18 | return self.get_clintox() 19 | elif self.name == "sider": 20 | return self.get_sider() 21 | elif self.name == "bbbp": 22 | return self.get_bbbp() 23 | elif self.name == "hiv": 24 | return self.get_hiv_dataset() 25 | elif self.name == "muv": 26 | return self.get_muv() 27 | 28 | def get_muv(self): 29 | from loader_molebert import _load_muv_dataset 30 | 31 | raw_path = f"{self.data_dir}/muv/raw/muv.csv" 32 | smiles_list, rdkit_mol_objs, labels = _load_muv_dataset(raw_path) 33 | 34 | assert len(smiles_list) == len(labels) 35 | data = [] 36 | for i in range(len(smiles_list)): 37 | y = [list(labels[i])] 38 | smiles = smiles_list[i] 39 | if smiles != None: 40 | data.append(Data(smiles=smiles, y=y)) 41 | 42 | return data 43 | 44 | def get_tox21(self): 45 | from loader_molebert import _load_tox21_dataset 46 | 47 | raw_path = f"{self.data_dir}/tox21/raw/tox21.csv" 48 | smiles_list, rdkit_mol_objs, labels = _load_tox21_dataset(raw_path) 49 | 50 | assert len(smiles_list) == len(labels) 51 | data = [] 52 | for i in range(len(smiles_list)): 53 | y = [list(labels[i])] 54 | smiles = smiles_list[i] 55 | if smiles != None: 56 | data.append(Data(smiles=smiles, y=y)) 57 | 58 | return data 59 | 60 | def get_hiv_dataset(self): 61 | from loader_molebert import _load_hiv_dataset 62 | 63 | raw_path = f"{self.data_dir}/hiv/raw/HIV.csv" 64 | 65 | smiles_list, rdkit_mol_objs, labels = _load_hiv_dataset(raw_path) 66 | 67 | assert 
len(smiles_list) == len(labels) 68 | data = [] 69 | for i in range(len(smiles_list)): 70 | 71 | y = [labels[i]] 72 | smiles = smiles_list[i] 73 | if smiles != None: 74 | data.append(Data(smiles=smiles, y=y)) 75 | 76 | return data 77 | 78 | def get_toxcast(self): 79 | from loader_molebert import _load_toxcast_dataset 80 | 81 | raw_path = f"{self.data_dir}/toxcast/raw/toxcast_data.csv" 82 | 83 | smiles_list, rdkit_mol_objs, labels = _load_toxcast_dataset(raw_path) 84 | 85 | assert len(smiles_list) == len(labels) 86 | data = [] 87 | for i in range(len(smiles_list)): 88 | y = [list(labels[i])] 89 | smiles = smiles_list[i] 90 | if smiles != None: 91 | data.append(Data(smiles=smiles, y=y)) 92 | 93 | return data 94 | 95 | def get_clintox(self): 96 | from loader_molebert import _load_clintox_dataset 97 | 98 | raw_path = f"{self.data_dir}/clintox/raw/clintox.csv" 99 | 100 | smiles_list, rdkit_mol_objs, labels = _load_clintox_dataset(raw_path) 101 | 102 | assert len(smiles_list) == len(labels) 103 | data = [] 104 | for i in range(len(smiles_list)): 105 | y = [list(labels[i])] 106 | smiles = smiles_list[i] 107 | if smiles != None: 108 | data.append(Data(smiles=smiles, y=y)) 109 | 110 | return data 111 | 112 | def get_bbbp(self): 113 | from loader_molebert import _load_bbbp_dataset 114 | 115 | raw_path = f"{self.data_dir}/bbbp/raw/BBBP.csv" 116 | 117 | smiles_list, rdkit_mol_objs, labels = _load_bbbp_dataset(raw_path) 118 | 119 | assert len(smiles_list) == len(labels) 120 | data = [] 121 | for i in range(len(smiles_list)): 122 | y = [labels[i]] 123 | smiles = smiles_list[i] 124 | if smiles != None: 125 | data.append(Data(smiles=smiles, y=y)) 126 | 127 | return data 128 | 129 | def get_sider(self): 130 | from loader_molebert import _load_sider_dataset 131 | 132 | dataset = MoleculeNet(self.data_dir, name=self.name) 133 | raw_path = f"{self.data_dir}/sider/raw/sider.csv" 134 | smiles_list, rdkit_mol_objs, labels = _load_sider_dataset(raw_path) 135 | 136 | assert len(smiles_list) == len(labels) 137 | data = [] 138 | for i in range(len(smiles_list)): 139 | y = [list(labels[i, :])] 140 | smiles = smiles_list[i] 141 | 142 | if smiles != None: 143 | data.append(Data(smiles=smiles, y=y)) 144 | 145 | return data 146 | 147 | def get_pcba(self): 148 | 149 | dataset = MoleculeNet(self.data_dir, name=self.name) 150 | df = pd.read_csv(f"{self.data_dir}/sider/raw/pcba.csv") 151 | df = remove_non_mols(df) 152 | df = df.fillna(-1) 153 | df.reset_index(drop=True, inplace=True) 154 | 155 | data = [] 156 | for i in df.index: 157 | y = [df.iloc[i, :128].values.tolist()] 158 | smiles = df.loc[i, "smiles"] 159 | data.append(Data(smiles=smiles, y=y)) 160 | 161 | return data 162 | -------------------------------------------------------------------------------- /fragnet/dataset/dta.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | from .dataset import FinetuneDataDTA 4 | from .utils import extract_data, save_datasets, save_ds_parts 5 | 6 | 7 | def create_dta_dataset(args): 8 | 9 | if not os.path.exists(args.output_dir): 10 | os.makedirs(args.output_dir, exist_ok=True) 11 | 12 | dataset = FinetuneDataDTA(target_name=args.target_name, data_type=args.data_type) 13 | 14 | if args.save_parts == 0: 15 | 16 | if args.train_path: 17 | train = pd.read_csv(args.train_path) 18 | train.to_csv(f"{args.output_dir}/train.csv", index=False) 19 | ds = dataset.get_ft_dataset(train) 20 | ds = extract_data(ds) 21 | save_path = f"{args.output_dir}/train" 22 | 
save_datasets(ds, save_path) 23 | 24 | if args.val_path: 25 | val = pd.read_csv(args.val_path) 26 | val.to_csv(f"{args.output_dir}/val.csv", index=False) 27 | ds = dataset.get_ft_dataset(val) 28 | ds = extract_data(ds) 29 | save_path = f"{args.output_dir}/val" 30 | save_datasets(ds, save_path) 31 | 32 | if args.test_path: 33 | test = pd.read_csv(args.test_path) 34 | test.to_csv(f"{args.output_dir}/test.csv", index=False) 35 | ds = dataset.get_ft_dataset(test) 36 | ds = extract_data(ds) 37 | save_path = f"{args.output_dir}/test" 38 | save_datasets(ds, save_path) 39 | 40 | else: 41 | save_ds_parts( 42 | data_creater=dataset, ds=train, output_dir=args.output_dir, fold="train" 43 | ) 44 | save_ds_parts( 45 | data_creater=dataset, ds=val, output_dir=args.output_dir, fold="val" 46 | ) 47 | save_ds_parts( 48 | data_creater=dataset, ds=test, output_dir=args.output_dir, fold="test" 49 | ) 50 | -------------------------------------------------------------------------------- /fragnet/dataset/ext_data_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pnnl/FragNet/54f99099a08ab991a5ba6845f7338492d03d4f1c/fragnet/dataset/ext_data_utils/__init__.py -------------------------------------------------------------------------------- /fragnet/dataset/ext_data_utils/deepttc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from .Step1_getData import GetData 4 | 5 | class DataEncoding: 6 | def __init__(self, data_dir): 7 | self.Getdata = GetData(data_dir) 8 | 9 | def encode2(self,traindata, valdata, testdata): 10 | drug_smiles = self.Getdata.getDrug() 11 | drugid2smile = dict(zip(drug_smiles['drug_id'],drug_smiles['smiles'])) 12 | 13 | traindata['smiles'] = [drugid2smile[i] for i in traindata['DRUG_ID']] 14 | valdata['smiles'] = [drugid2smile[i] for i in valdata['DRUG_ID']] 15 | testdata['smiles'] = [drugid2smile[i] for i in testdata['DRUG_ID']] 16 | 17 | 18 | traindata = traindata.reset_index() 19 | traindata['Label'] = traindata['LN_IC50'] 20 | 21 | valdata = valdata.reset_index() 22 | valdata['Label'] = valdata['LN_IC50'] 23 | 24 | 25 | testdata = testdata.reset_index() 26 | testdata['Label'] = testdata['LN_IC50'] 27 | 28 | 29 | return traindata, valdata, testdata 30 | -------------------------------------------------------------------------------- /fragnet/dataset/features.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from rdkit import Chem 3 | from .feature_utils import one_of_k_encoding, one_of_k_encoding_unk 4 | from .feature_utils import get_bond_pair 5 | 6 | 7 | class FeaturesEXP: 8 | 9 | """ 10 | Class for creating initial atom, bond and connection features. 
11 | Used by CreateData in dataset/data.py 12 | """ 13 | 14 | def __init__(self, add_connection_chrl=False): 15 | self.atom_list_one_hot = list(range(1, 119)) 16 | self.use_bond_chirality = True 17 | self.add_connection_chrl = add_connection_chrl 18 | 19 | def get_atom_and_bond_features_atom_graph_one_hot(self, mol, use_chirality): 20 | 21 | add_self_loops = False 22 | 23 | atoms = mol.GetAtoms() 24 | bonds = mol.GetBonds() 25 | edge_index = get_bond_pair(mol, add_self_loops) 26 | 27 | node_f = [self.atom_features_one_hot(atom) for atom in atoms] 28 | edge_attr = [] 29 | for bond in bonds: 30 | edge_attr.append( 31 | self.bond_features_one_hot(bond, use_chirality=use_chirality) 32 | ) 33 | edge_attr.append( 34 | self.bond_features_one_hot(bond, use_chirality=use_chirality) 35 | ) 36 | 37 | return node_f, edge_index, edge_attr 38 | 39 | def atom_features_one_hot( 40 | self, atom, explicit_H=False, use_chirality=False, angle_f=False 41 | ): 42 | 43 | atom_type = one_of_k_encoding_unk(atom.GetAtomicNum(), self.atom_list_one_hot) 44 | degree = one_of_k_encoding(atom.GetDegree(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) 45 | 46 | valence = one_of_k_encoding_unk( 47 | atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5, 6] 48 | ) 49 | 50 | charge = one_of_k_encoding_unk( 51 | atom.GetFormalCharge(), [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5] 52 | ) 53 | rad_elec = one_of_k_encoding_unk(atom.GetNumRadicalElectrons(), [0, 1, 2, 3, 4]) 54 | 55 | hyb = one_of_k_encoding_unk( 56 | atom.GetHybridization(), 57 | [ 58 | Chem.rdchem.HybridizationType.S, 59 | Chem.rdchem.HybridizationType.SP, 60 | Chem.rdchem.HybridizationType.SP2, 61 | Chem.rdchem.HybridizationType.SP3, 62 | Chem.rdchem.HybridizationType.SP3D, 63 | Chem.rdchem.HybridizationType.SP3D2, 64 | Chem.rdchem.HybridizationType.UNSPECIFIED, 65 | ], 66 | ) 67 | 68 | arom = one_of_k_encoding(atom.GetIsAromatic(), [False, True]) 69 | atom_ring = one_of_k_encoding(atom.IsInRing(), [False, True]) 70 | numhs = [atom.GetTotalNumHs()] 71 | 72 | chiral = one_of_k_encoding_unk( 73 | atom.GetChiralTag(), 74 | [ 75 | Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW, 76 | Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW, 77 | Chem.rdchem.ChiralType.CHI_UNSPECIFIED, 78 | ], 79 | ) 80 | results = ( 81 | atom_type 82 | + degree 83 | + valence 84 | + charge 85 | + rad_elec 86 | + hyb 87 | + arom 88 | + atom_ring 89 | + chiral 90 | + numhs 91 | ) 92 | 93 | return np.array(results) 94 | 95 | def bond_features_one_hot(self, bond, use_chirality=True): 96 | bt = bond.GetBondType() 97 | 98 | bond_feats = [ 99 | bt == Chem.rdchem.BondType.SINGLE, 100 | bt == Chem.rdchem.BondType.DOUBLE, 101 | bt == Chem.rdchem.BondType.TRIPLE, 102 | bt == Chem.rdchem.BondType.AROMATIC, 103 | ] 104 | 105 | conj = one_of_k_encoding(bond.GetIsConjugated(), [False, True]) 106 | inring = one_of_k_encoding(bond.IsInRing(), [False, True]) 107 | 108 | bond_feats = bond_feats + conj + inring 109 | 110 | if use_chirality: 111 | bond_feats = bond_feats + one_of_k_encoding_unk( 112 | str(bond.GetStereo()), ["STEREOANY", "STEREOZ", "STEREOE", "STEREONONE"] 113 | ) 114 | 115 | bond_feats = bond_feats + one_of_k_encoding_unk( 116 | bond.GetBondDir(), 117 | [ 118 | Chem.rdchem.BondDir.BEGINWEDGE, 119 | Chem.rdchem.BondDir.BEGINDASH, 120 | Chem.rdchem.BondDir.ENDDOWNRIGHT, 121 | Chem.rdchem.BondDir.ENDUPRIGHT, 122 | Chem.rdchem.BondDir.NONE, 123 | ], 124 | ) 125 | return list(bond_feats) 126 | 127 | def connection_features_one_hot(self, connection): 128 | 129 | bond = connection.bond 130 | bt = connection.bond_type 131 | 132 | 
bond_feats = [ 133 | bt == Chem.rdchem.BondType.SINGLE, 134 | bt == Chem.rdchem.BondType.DOUBLE, 135 | bt == Chem.rdchem.BondType.TRIPLE, 136 | bt == Chem.rdchem.BondType.AROMATIC, 137 | bt == "self_cn", 138 | bt == "iso_cn3", 139 | ] 140 | 141 | if self.add_connection_chrl: 142 | 143 | conj = one_of_k_encoding(bond.GetIsConjugated(), [False, True]) 144 | inring = one_of_k_encoding(bond.IsInRing(), [False, True]) 145 | 146 | bond_feats = bond_feats + conj + inring 147 | bond_feats = bond_feats + one_of_k_encoding_unk( 148 | str(bond.GetStereo()), ["STEREOANY", "STEREOZ", "STEREOE", "STEREONONE"] 149 | ) 150 | 151 | bond_feats = bond_feats + one_of_k_encoding_unk( 152 | bond.GetBondDir(), 153 | [ 154 | Chem.rdchem.BondDir.BEGINWEDGE, 155 | Chem.rdchem.BondDir.BEGINDASH, 156 | Chem.rdchem.BondDir.ENDDOWNRIGHT, 157 | Chem.rdchem.BondDir.ENDUPRIGHT, 158 | Chem.rdchem.BondDir.NONE, 159 | ], 160 | ) 161 | 162 | return list(bond_feats) 163 | -------------------------------------------------------------------------------- /fragnet/dataset/features0.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from rdkit import Chem 3 | from feature_utils import one_of_k_encoding, one_of_k_encoding_unk 4 | from feature_utils import get_bond_pair 5 | 6 | 7 | class FeaturesEXP: 8 | 9 | def __init__(self): 10 | self.atom_list_one_hot = [ 11 | "Br", 12 | "C", 13 | "Cl", 14 | "F", 15 | "H", 16 | "I", 17 | "K", 18 | "N", 19 | "Na", 20 | "O", 21 | "P", 22 | "S", 23 | "Unknown", 24 | ] 25 | self.use_bond_chirality = False 26 | 27 | def get_atom_and_bond_features_atom_graph_one_hot(self, mol, use_chirality): 28 | 29 | add_self_loops = False 30 | 31 | atoms = mol.GetAtoms() 32 | bonds = mol.GetBonds() 33 | edge_index = get_bond_pair(mol, add_self_loops) 34 | 35 | node_f = [self.atom_features_one_hot(atom) for atom in atoms] 36 | edge_attr = [] 37 | for bond in bonds: 38 | edge_attr.append( 39 | self.bond_features_one_hot(bond, use_chirality=use_chirality) 40 | ) 41 | edge_attr.append( 42 | self.bond_features_one_hot(bond, use_chirality=use_chirality) 43 | ) 44 | 45 | if add_self_loops: 46 | self_loop_attr = np.zeros((mol.GetNumAtoms(), 12)).tolist() 47 | 48 | edge_attr = edge_attr + self_loop_attr 49 | 50 | return node_f, edge_index, edge_attr 51 | 52 | def atom_features_one_hot( 53 | self, 54 | atom, 55 | bool_id_feat=False, 56 | explicit_H=False, 57 | use_chirality=False, 58 | angle_f=False, 59 | ): 60 | 61 | atom_type = one_of_k_encoding_unk(atom.GetSymbol(), self.atom_list_one_hot) 62 | degree = one_of_k_encoding(atom.GetDegree(), [0, 1, 2, 3, 4, 5, 6]) 63 | valence = one_of_k_encoding_unk( 64 | atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5, 6] 65 | ) 66 | charge = [atom.GetFormalCharge()] 67 | rad_elec = [atom.GetNumRadicalElectrons()] 68 | hyb = one_of_k_encoding_unk( 69 | atom.GetHybridization(), 70 | [ 71 | Chem.rdchem.HybridizationType.SP, 72 | Chem.rdchem.HybridizationType.SP2, 73 | Chem.rdchem.HybridizationType.SP3, 74 | Chem.rdchem.HybridizationType.SP3D, 75 | Chem.rdchem.HybridizationType.SP3D2, 76 | Chem.rdchem.HybridizationType.UNSPECIFIED, 77 | ], 78 | ) 79 | arom = [atom.GetIsAromatic()] 80 | atom_ring = [atom.IsInRing()] 81 | numhs = [atom.GetTotalNumHs()] 82 | 83 | chiral = one_of_k_encoding_unk( 84 | atom.GetChiralTag(), 85 | [ 86 | Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW, 87 | Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW, 88 | Chem.rdchem.ChiralType.CHI_UNSPECIFIED, 89 | ], 90 | ) 91 | 92 | results = ( 93 | atom_type 94 | + degree 95 | + 
valence 96 | + charge 97 | + rad_elec 98 | + hyb 99 | + arom 100 | + atom_ring 101 | + numhs 102 | ) 103 | 104 | if use_chirality: 105 | try: 106 | results = ( 107 | results 108 | + one_of_k_encoding_unk(atom.GetProp("_CIPCode"), ["R", "S"]) 109 | + [atom.HasProp("_ChiralityPossible")] 110 | ) 111 | except: 112 | results = ( 113 | results + [False, False] + [atom.HasProp("_ChiralityPossible")] 114 | ) 115 | 116 | return np.array(results) 117 | 118 | def bond_features_one_hot(self, bond, use_chirality=True): 119 | bt = bond.GetBondType() 120 | 121 | bond_feats = [ 122 | bt == Chem.rdchem.BondType.SINGLE, 123 | bt == Chem.rdchem.BondType.DOUBLE, 124 | bt == Chem.rdchem.BondType.TRIPLE, 125 | bt == Chem.rdchem.BondType.AROMATIC, 126 | bond.GetIsConjugated(), 127 | bond.IsInRing(), 128 | ] 129 | 130 | if use_chirality: 131 | bond_feats = bond_feats + one_of_k_encoding_unk( 132 | str(bond.GetStereo()), ["STEREOANY", "STEREOZ", "STEREOE", "STEREONONE"] 133 | ) 134 | 135 | bond_feats = bond_feats + one_of_k_encoding_unk( 136 | bond.GetBondDir(), 137 | [ 138 | Chem.rdchem.BondDir.BEGINWEDGE, 139 | Chem.rdchem.BondDir.BEGINDASH, 140 | Chem.rdchem.BondDir.ENDDOWNRIGHT, 141 | Chem.rdchem.BondDir.ENDUPRIGHT, 142 | Chem.rdchem.BondDir.NONE, 143 | ], 144 | ) 145 | 146 | return list(bond_feats) 147 | 148 | def connection_features_one_hot(self, connection): 149 | 150 | bt = connection.bond_type 151 | 152 | bond_feats = [ 153 | bt == Chem.rdchem.BondType.SINGLE, 154 | bt == Chem.rdchem.BondType.DOUBLE, 155 | bt == Chem.rdchem.BondType.TRIPLE, 156 | bt == Chem.rdchem.BondType.AROMATIC, 157 | bt == "self_cn", 158 | bt == "iso_cn3", 159 | ] 160 | return list(bond_feats) 161 | -------------------------------------------------------------------------------- /fragnet/dataset/features2.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from rdkit import Chem 3 | import numpy as np 4 | from rdkit import Chem 5 | from .feature_utils import one_of_k_encoding, one_of_k_encoding_unk 6 | from .feature_utils import get_bond_pair 7 | 8 | 9 | class FeaturesEXP: 10 | 11 | def __init__(self): 12 | self.atom_list_one_hot = list(range(1, 119)) 13 | self.use_bond_chirality = True 14 | 15 | def get_atom_and_bond_features_atom_graph_one_hot(self, mol, use_chirality): 16 | 17 | add_self_loops = False 18 | 19 | atoms = mol.GetAtoms() 20 | bonds = mol.GetBonds() 21 | edge_index = get_bond_pair(mol, add_self_loops) 22 | 23 | node_f = [self.atom_features_one_hot(atom) for atom in atoms] 24 | edge_attr = [] 25 | for bond in bonds: 26 | edge_attr.append( 27 | self.bond_features_one_hot(bond, use_chirality=use_chirality) 28 | ) 29 | edge_attr.append( 30 | self.bond_features_one_hot(bond, use_chirality=use_chirality) 31 | ) 32 | 33 | return node_f, edge_index, edge_attr 34 | 35 | def atom_features_one_hot( 36 | self, atom, explicit_H=False, use_chirality=False, angle_f=False 37 | ): 38 | 39 | atom_type = one_of_k_encoding_unk(atom.GetAtomicNum(), self.atom_list_one_hot) 40 | degree = one_of_k_encoding(atom.GetDegree(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) 41 | 42 | valence = one_of_k_encoding_unk( 43 | atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5, 6] 44 | ) 45 | 46 | charge = one_of_k_encoding_unk( 47 | atom.GetFormalCharge(), [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5] 48 | ) 49 | rad_elec = one_of_k_encoding_unk(atom.GetNumRadicalElectrons(), [0, 1, 2, 3, 4]) 50 | 51 | hyb = one_of_k_encoding_unk( 52 | atom.GetHybridization(), 53 | [ 54 | Chem.rdchem.HybridizationType.S, 55 | 
Chem.rdchem.HybridizationType.SP, 56 | Chem.rdchem.HybridizationType.SP2, 57 | Chem.rdchem.HybridizationType.SP3, 58 | Chem.rdchem.HybridizationType.SP3D, 59 | Chem.rdchem.HybridizationType.SP3D2, 60 | Chem.rdchem.HybridizationType.UNSPECIFIED, 61 | ], 62 | ) 63 | arom = one_of_k_encoding(atom.GetIsAromatic(), [False, True]) 64 | atom_ring = one_of_k_encoding(atom.IsInRing(), [False, True]) 65 | numhs = [atom.GetTotalNumHs()] 66 | 67 | chiral = one_of_k_encoding_unk( 68 | atom.GetChiralTag(), 69 | [ 70 | Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW, 71 | Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW, 72 | Chem.rdchem.ChiralType.CHI_UNSPECIFIED, 73 | ], 74 | ) 75 | 76 | results = ( 77 | atom_type 78 | + degree 79 | + valence 80 | + charge 81 | + rad_elec 82 | + hyb 83 | + arom 84 | + atom_ring 85 | + chiral 86 | + numhs 87 | ) 88 | 89 | return np.array(results) 90 | 91 | def bond_features_one_hot(self, bond, use_chirality=True): 92 | bt = bond.GetBondType() 93 | 94 | bond_feats = [ 95 | bt == Chem.rdchem.BondType.SINGLE, 96 | bt == Chem.rdchem.BondType.DOUBLE, 97 | bt == Chem.rdchem.BondType.TRIPLE, 98 | bt == Chem.rdchem.BondType.AROMATIC, 99 | ] 100 | 101 | conj = one_of_k_encoding(bond.GetIsConjugated(), [False, True]) 102 | inring = one_of_k_encoding(bond.IsInRing(), [False, True]) 103 | 104 | bond_feats = bond_feats + conj + inring 105 | 106 | if use_chirality: 107 | bond_feats = bond_feats + one_of_k_encoding_unk( 108 | str(bond.GetStereo()), ["STEREOZ", "STEREOE", "STEREONONE"] 109 | ) 110 | 111 | bond_feats = bond_feats + one_of_k_encoding_unk( 112 | bond.GetBondDir(), 113 | [ 114 | Chem.rdchem.BondDir.ENDDOWNRIGHT, 115 | Chem.rdchem.BondDir.ENDUPRIGHT, 116 | Chem.rdchem.BondDir.NONE, 117 | ], 118 | ) 119 | return list(bond_feats) 120 | 121 | def connection_features_one_hot(self, connection): 122 | 123 | bt = connection.bond_type 124 | 125 | bond_feats = [ 126 | bt == Chem.rdchem.BondType.SINGLE, 127 | bt == Chem.rdchem.BondType.DOUBLE, 128 | bt == Chem.rdchem.BondType.TRIPLE, 129 | bt == Chem.rdchem.BondType.AROMATIC, 130 | bt == "self_cn", 131 | bt == "iso_cn3", 132 | ] 133 | return list(bond_feats) 134 | -------------------------------------------------------------------------------- /fragnet/dataset/fix_pt_data.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import pandas as pd 3 | import torch 4 | from models import FragNetPreTrain 5 | from data import collate_fn_pt as collate_fn 6 | from torch.utils.data import DataLoader 7 | import os 8 | import numpy as np 9 | 10 | 11 | files = os.listdir("pretrain_data/unimol_exp_/ds/") 12 | data = [] 13 | lens = 0 14 | for f in files: 15 | 16 | df1 = pd.read_pickle(f"pretrain_data/unimol_exp_/ds/{f}") 17 | df2 = pd.read_pickle(f"pretrain_data/unimol_exp/ds/{f}") 18 | 19 | for i in range(len(df1)): 20 | 21 | data.append(df1[i].y.item() == df2[i].y.item()) 22 | 23 | # break 24 | lens += len(df2) 25 | 26 | print("non_zero: ", np.count_nonzero(np.array(data)) == len(data) == lens) 27 | -------------------------------------------------------------------------------- /fragnet/dataset/general.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch_geometric.datasets import MoleculeNet 4 | 5 | from .utils import save_datasets 6 | from .dataset import FinetuneData, FinetuneMultiConfData 7 | from .utils import extract_data 8 | from .splitters import ScaffoldSplitter 9 | from .custom_dataset import MoleculeDataset 10 | from 
.utils import save_ds_parts 11 | import pandas as pd 12 | 13 | 14 | def create_general_dataset(args): 15 | 16 | if args.multi_conf_data == 1: 17 | print("generating multi-conf data") 18 | dataset = FinetuneMultiConfData( 19 | target_name=args.target_name, 20 | data_type=args.data_type, 21 | frag_type=args.frag_type, 22 | ) 23 | else: 24 | print("generating single-conf data") 25 | dataset = FinetuneData( 26 | target_name=args.target_name, 27 | data_type=args.data_type, 28 | frag_type=args.frag_type, 29 | ) 30 | 31 | if not os.path.exists(args.output_dir): 32 | os.makedirs(args.output_dir, exist_ok=True) 33 | 34 | if args.train_path: 35 | train = pd.read_csv(args.train_path) 36 | train.to_csv(f"{args.output_dir}/train.csv", index=False) 37 | ds = dataset.get_ft_dataset(train) 38 | ds = extract_data(ds) 39 | save_path = f"{args.output_dir}/train" 40 | save_datasets(ds, save_path) 41 | 42 | if args.val_path: 43 | val = pd.read_csv(args.val_path) 44 | val.to_csv(f"{args.output_dir}/val.csv", index=False) 45 | ds = dataset.get_ft_dataset(val) 46 | ds = extract_data(ds) 47 | save_path = f"{args.output_dir}/val" 48 | save_datasets(ds, save_path) 49 | 50 | if args.test_path: 51 | test = pd.read_csv(args.test_path) 52 | test.to_csv(f"{args.output_dir}/test.csv", index=False) 53 | ds = dataset.get_ft_dataset(test) 54 | ds = extract_data(ds) 55 | save_path = f"{args.output_dir}/test" 56 | save_datasets(ds, save_path) 57 | -------------------------------------------------------------------------------- /fragnet/dataset/moleculenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch_geometric.datasets import MoleculeNet 4 | 5 | from .utils import save_datasets 6 | from .dataset import FinetuneData, FinetuneMultiConfData 7 | from .utils import extract_data 8 | from .splitters import ScaffoldSplitter 9 | from .custom_dataset import MoleculeDataset 10 | from .utils import save_ds_parts 11 | 12 | 13 | def create_moleculenet_dataset(dstype, name, args): 14 | 15 | if not os.path.exists(args.output_dir + "/" + name): 16 | os.makedirs(args.output_dir + "/" + name, exist_ok=True) 17 | 18 | if dstype == "MoleculeNet": 19 | ds = MoleculeNet(f"{args.data_dir}", name=name) 20 | elif dstype == "MoleculeDataset": 21 | _ = MoleculeNet(f"{args.data_dir}", name=name) 22 | dataset = MoleculeDataset(name, args.data_dir) 23 | ds = dataset.get_data() 24 | 25 | if not args.use_molebert: 26 | scaffold_split = ScaffoldSplitter() 27 | train, val, test = scaffold_split.split(dataset=dataset, include_chirality=True) 28 | elif args.use_molebert: 29 | from .splitters_molebert import scaffold_split 30 | 31 | smiles = [i.smiles for i in ds] 32 | train, val, test, (train_smiles, valid_smiles, test_smiles) = scaffold_split( 33 | ds, smiles, return_smiles=True 34 | ) 35 | 36 | torch.save(train, f"{args.output_dir}/{name}/train.pt") 37 | torch.save(val, f"{args.output_dir}/{name}/val.pt") 38 | torch.save(test, f"{args.output_dir}/{name}/test.pt") 39 | 40 | if args.multi_conf_data: 41 | dataset = FinetuneMultiConfData(args.target_name, args.data_type) 42 | else: 43 | dataset = FinetuneData( 44 | args.target_name, args.data_type, frag_type=args.frag_type 45 | ) 46 | 47 | if not args.save_parts: 48 | 49 | ds = dataset.get_ft_dataset(train) 50 | ds = extract_data(ds) 51 | save_path = f"{args.output_dir}/{name}/train" 52 | save_datasets(ds, save_path) 53 | 54 | ds = dataset.get_ft_dataset(val) 55 | ds = extract_data(ds) 56 | save_path = 
f"{args.output_dir}/{name}/val" 57 | save_datasets(ds, save_path) 58 | 59 | ds = dataset.get_ft_dataset(test) 60 | ds = extract_data(ds) 61 | save_path = f"{args.output_dir}/{name}/test" 62 | save_datasets(ds, save_path) 63 | 64 | elif args.save_parts: 65 | save_ds_parts( 66 | data_creater=dataset, 67 | ds=train, 68 | output_dir=args.output_dir, 69 | name=name, 70 | fold="train", 71 | ) 72 | save_ds_parts( 73 | data_creater=dataset, 74 | ds=val, 75 | output_dir=args.output_dir, 76 | name=name, 77 | fold="val", 78 | ) 79 | save_ds_parts( 80 | data_creater=dataset, 81 | ds=test, 82 | output_dir=args.output_dir, 83 | name=name, 84 | fold="test", 85 | ) 86 | -------------------------------------------------------------------------------- /fragnet/dataset/scaffold_split_from_df.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | from .dataset import FinetuneData 4 | from .utils import extract_data 5 | from .utils import save_datasets 6 | 7 | 8 | def create_scaffold_split_data_from_df(args): 9 | 10 | if not os.path.exists(args.output_dir): 11 | os.makedirs(args.output_dir, exist_ok=True) 12 | 13 | ds = pd.read_csv(args.data_dir) 14 | ds.reset_index(drop=True, inplace=True) 15 | 16 | from splitters_molebert import scaffold_split 17 | 18 | smiles = ds.smiles.values.tolist() 19 | 20 | train, val, test, (train_smiles, valid_smiles, test_smiles) = scaffold_split( 21 | ds, smiles, return_smiles=True 22 | ) 23 | 24 | train.reset_index(drop=True, inplace=True) 25 | val.reset_index(drop=True, inplace=True) 26 | test.reset_index(drop=True, inplace=True) 27 | 28 | dataset = FinetuneData(target_name=args.target_name, data_type=args.data_type) 29 | 30 | if not args.save_parts: 31 | 32 | train.to_csv(f"{args.output_dir}/train.csv", index=False) 33 | ds = dataset.get_ft_dataset(train) 34 | ds = extract_data(ds) 35 | save_path = f"{args.output_dir}/train" 36 | save_datasets(ds, save_path) 37 | 38 | val.to_csv(f"{args.output_dir}/val.csv", index=False) 39 | ds = dataset.get_ft_dataset(val) 40 | ds = extract_data(ds) 41 | save_path = f"{args.output_dir}/val" 42 | save_datasets(ds, save_path) 43 | 44 | test.to_csv(f"{args.output_dir}/test.csv", index=False) 45 | ds = dataset.get_ft_dataset(test) 46 | ds = extract_data(ds) 47 | save_path = f"{args.output_dir}/test" 48 | save_datasets(ds, save_path) 49 | -------------------------------------------------------------------------------- /fragnet/dataset/simsgt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.datasets import MoleculeNet 3 | from .utils import save_datasets 4 | import pandas as pd 5 | from .dataset import FinetuneData 6 | from .utils import extract_data 7 | 8 | 9 | def create_moleculenet_dataset_simsgt(name, args): 10 | from splitters_simsgt import scaffold_split 11 | from loader_simsgt import MoleculeDataset 12 | 13 | d = MoleculeNet(f"{args.output_dir}/simsgt", name=name) 14 | dataset = MoleculeDataset(f"{args.output_dir}/simsgt/" + name, dataset=name) 15 | 16 | smiles_list = pd.read_csv( 17 | f"{args.output_dir}/simsgt/" + name + "/processed/smiles.csv", header=None 18 | )[0].tolist() 19 | ( 20 | train_dataset, 21 | valid_dataset, 22 | test_dataset, 23 | (train_smiles, valid_smiles, test_smiles), 24 | ) = scaffold_split( 25 | dataset, 26 | smiles_list, 27 | null_value=0, 28 | frac_train=0.8, 29 | frac_valid=0.1, 30 | frac_test=0.1, 31 | return_smiles=True, 32 | ) 33 | 34 | torch.save(train_dataset, 
f"{args.output_dir}/simsgt/{name}/train.pt") 35 | torch.save(valid_dataset, f"{args.output_dir}/simsgt/{name}/val.pt") 36 | torch.save(test_dataset, f"{args.output_dir}/simsgt/{name}/test.pt") 37 | 38 | dataset = FinetuneData(args.target_name) 39 | 40 | if not args.save_parts: 41 | 42 | ds = dataset.get_ft_dataset(train_dataset) 43 | ds = extract_data(ds) 44 | save_path = f"{args.output_dir}/simsgt/{name}/train" 45 | save_datasets(ds, save_path) 46 | 47 | ds = dataset.get_ft_dataset(valid_dataset) 48 | ds = extract_data(ds) 49 | save_path = f"{args.output_dir}/simsgt/{name}/val" 50 | save_datasets(ds, save_path) 51 | 52 | ds = dataset.get_ft_dataset(test_dataset) 53 | ds = extract_data(ds) 54 | save_path = f"{args.output_dir}/simsgt/{name}/test" 55 | save_datasets(ds, save_path) 56 | -------------------------------------------------------------------------------- /fragnet/dataset/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | import pandas as pd 4 | import os 5 | import lmdb 6 | from rdkit import Chem 7 | from tqdm import tqdm 8 | 9 | 10 | def remove_non_mols(ds): 11 | 12 | keep_smiles = [] 13 | keep_ids = [] 14 | if isinstance(ds, pd.DataFrame): 15 | smiles_list = ds.smiles.values 16 | else: 17 | smiles_list = [i.smiles for i in ds] 18 | 19 | for i, smiles in enumerate(smiles_list): 20 | mol = Chem.MolFromSmiles(smiles) 21 | if mol: 22 | keep_ids.append(i) 23 | 24 | if isinstance(ds, pd.DataFrame): 25 | df2 = ds.loc[keep_ids, :] 26 | df2.reset_index(drop=True, inplace=True) 27 | else: 28 | df2 = ds.index_select(keep_ids) 29 | 30 | return df2 31 | 32 | 33 | def extract_data(ds): 34 | res = [] 35 | for i, data in tqdm(enumerate(ds)): 36 | if data: 37 | res.append(data) 38 | return res 39 | 40 | 41 | def save_datasets(ds, save_path): 42 | with open(f"{save_path}.pkl", "wb") as f: 43 | pickle.dump(ds, f) 44 | 45 | 46 | def remove_frags(): 47 | 48 | frags = [len(Chem.GetMolFrags(Chem.MolFromSmiles(sm))) for sm in train.smiles] 49 | train["mfrags"] = frags 50 | 51 | frags = [len(Chem.GetMolFrags(Chem.MolFromSmiles(sm))) for sm in val.smiles] 52 | val["mfrags"] = frags 53 | 54 | train = train[train.mfrags == 1] 55 | val = val[val.mfrags == 1] 56 | 57 | # train.shape 58 | 59 | train.reset_index(drop=True, inplace=True) 60 | val.reset_index(drop=True, inplace=True) 61 | 62 | 63 | def mol_with_atom_index(mol): 64 | for atom in mol.GetAtoms(): 65 | atom.SetAtomMapNum(atom.GetIdx()) 66 | return mol 67 | 68 | 69 | # TODO: test this function 70 | def remove_bond(rwmol, idx1, idx2): 71 | rwmol.RemoveBond(idx1, idx2) 72 | for idx in [idx1, idx2]: 73 | atom = rwmol.GetAtomWithIdx(idx) 74 | if atom.GetSymbol() == "N" and atom.GetIsAromatic() is True: 75 | atom.SetNumExplicitHs(1) 76 | 77 | 78 | def get_data(lmdb_path, name=None): 79 | 80 | # lmdb_path='ligands/train.lmdb' 81 | 82 | env = lmdb.open( 83 | lmdb_path, 84 | subdir=False, 85 | readonly=True, 86 | lock=False, 87 | readahead=False, 88 | meminit=False, 89 | max_readers=256, 90 | ) 91 | txn = env.begin() 92 | keys = list(txn.cursor().iternext(values=False)) 93 | 94 | smiles_data = [] 95 | for idx in keys: 96 | datapoint_pickled = txn.get(idx) 97 | data = pickle.loads(datapoint_pickled) 98 | smiles_data.append({"smiles": data["smi"], "target": data["target"]}) 99 | 100 | if name in ["clintox", "tox21", "toxcast", "sider", "pcba", "muv"]: 101 | for i in range(len(smiles_data)): 102 | smiles_data[i]["target"] = [list(smiles_data[i]["target"])] 103 | 104 | return 
smiles_data 105 | 106 | 107 | def collect_and_save(path, fold): 108 | 109 | f = os.listdir(path) 110 | t = [i for i in f if fold + "_p" in i and i.endswith("pkl")] 111 | 112 | data = [] 113 | for i in t: 114 | data.extend(pd.read_pickle(path + "/" + i)) 115 | 116 | with open(f"{path}/{fold}.pkl", "wb") as f: 117 | pickle.dump(data, f) 118 | # return data 119 | 120 | 121 | def save_ds_parts(data_creater=None, ds=None, output_dir=None, name=None, fold=None): 122 | 123 | if isinstance(ds, pd.DataFrame): 124 | ds.reset_index(drop=True, inplace=True) 125 | 126 | n = len(ds) // 1000 127 | parts = np.array_split(ds, n) 128 | 129 | for ipart, part in enumerate(parts): 130 | # ds_tmp = [ds[i] for i in ids] 131 | 132 | ds_tmp = data_creater.get_ft_dataset(part) 133 | ds_tmp = extract_data(ds_tmp) 134 | save_path = f"{output_dir}/{name}/{fold}_p_{ipart}" 135 | save_datasets(ds_tmp, save_path) 136 | if name: 137 | collect_and_save(f"{output_dir}/{name}", fold) 138 | else: 139 | collect_and_save(f"{output_dir}", fold) 140 | 141 | else: 142 | 143 | n = len(ds) // 1000 144 | id_list = np.array_split(np.arange(len(ds)), n) 145 | 146 | for ipart, ids in enumerate(id_list): 147 | ds_tmp = [ds[i] for i in ids] 148 | 149 | ds_tmp = data_creater.get_ft_dataset(ds_tmp) 150 | ds_tmp = extract_data(ds_tmp) 151 | save_path = f"{output_dir}/{name}/{fold}_p_{ipart}" 152 | save_datasets(ds_tmp, save_path) 153 | if name: 154 | collect_and_save(f"{output_dir}/{name}", fold) 155 | else: 156 | collect_and_save(f"{output_dir}", fold) 157 | -------------------------------------------------------------------------------- /fragnet/exps/README.md: -------------------------------------------------------------------------------- 1 | If you want to view the contents of the weights at `pt/unimol_exp1s4/pt.pt`, you can run the following code, 2 | 3 | ``` 4 | import torch 5 | pt = torch.load('fragnet/exps/pt/unimol_exp1s4/pt.pt') 6 | print(pt) 7 | ``` 8 | I have given the content of pt in `pt/unimol_exp1s4/pt.pt.data`. 
9 | 10 | 11 | Similarly for the file at, `ft/pnnl_full/fragnet_hpdl_exp1s_h4pt4_10` 12 | 13 | ``` 14 | import torch 15 | ft = torch.load('fragnet/exps/ft/pnnl_full/fragnet_hpdl_exp1s_h4pt4_10/ft_100.pt') 16 | print(pt) 17 | ``` 18 | 19 | The content of ft is in `ft/pnnl_full/fragnet_hpdl_exp1s_h4pt4_10/ft_100.pt.data` -------------------------------------------------------------------------------- /fragnet/exps/ft/esol/e1pt4.yaml: -------------------------------------------------------------------------------- 1 | seed: 123 2 | # seed: 4 3 | # seed: 456 4 | data_seed: None 5 | exp_dir: exps/ft/esol 6 | model_version: gat2 7 | device: gpu 8 | atom_features: 167 9 | frag_features: 167 10 | edge_features: 17 11 | fedge_in: 6 12 | fbond_edge_in: 6 13 | # atom_features: 49 14 | # frag_features: 49 15 | # edge_features: 12 16 | #model_version: gat2_transformer 17 | # model_version: gat2_transformer2 18 | pretrain: 19 | model_version: gat2 20 | num_layer: 4 21 | drop_ratio: 0.2 22 | num_heads: 4 23 | emb_dim: 128 24 | chkpoint_name: exps/pt/unimol_exp1s4/pt.pt 25 | loss: mse 26 | batch_size: 128 27 | es_patience: 500 28 | lr: 1e-4 29 | n_epochs: 20000 30 | n_classes: 1 31 | # molebert splitting 32 | 33 | finetune: 34 | n_multi_task_heads: 0 35 | # batch_size: 24 # best 36 | batch_size: 16 37 | lr: 1e-4 38 | model: 39 | n_classes: 1 40 | num_layer: 4 41 | drop_ratio: 0.1 42 | num_heads: 4 43 | emb_dim: 128 44 | h1: 128 #128 45 | h2: 1024 46 | h3: 1024 47 | h4: 512 48 | act: relu 49 | fthead: FTHead3 50 | 51 | n_epochs: 10000 52 | target_type: regr 53 | loss: mse 54 | use_schedular: False 55 | es_patience: 100 56 | chkpoint_name: ${exp_dir}/ft.pt 57 | train: 58 | path: finetune_data/moleculenet_exp1s/esol/train.pkl # 20 exp node features 59 | val: 60 | path: finetune_data/moleculenet_exp1s/esol/val.pkl # 20 exp node features 61 | test: 62 | path: finetune_data/moleculenet_exp1s/esol/test.pkl # 20 exp node features 63 | -------------------------------------------------------------------------------- /fragnet/exps/ft/lipo/download.sh: -------------------------------------------------------------------------------- 1 | scp marianas:/people/pana982/solubility/models/fragnet/fragnet/exps/ft/lipo/fragnet_hpdl_exp1s_pt4_30/ft_100.pt fragnet_hpdl_exp1s_pt4_30/ 2 | 3 | scp marianas:/people/pana982/solubility/models/fragnet/fragnet/exps/ft/lipo/fragnet_hpdl_exp1s_pt4_30/config_exp100.yaml fragnet_hpdl_exp1s_pt4_30/ 4 | -------------------------------------------------------------------------------- /fragnet/exps/ft/lipo/fragnet_hpdl_exp1s_pt4_30/config_exp100.yaml: -------------------------------------------------------------------------------- 1 | atom_features: 167 2 | data_seed: None 3 | device: gpu 4 | edge_features: 17 5 | exp_dir: exps/ft/lipo/fragnet_hpdl_exp1s_pt4_30 6 | finetune: 7 | batch_size: 16 8 | chkpoint_name: exps/ft/lipo/fragnet_hpdl_exp1s_pt4_30/ft_100.pt 9 | es_patience: 100 10 | loss: mse 11 | lr: 1e-4 12 | model: 13 | act: relu 14 | drop_ratio: 0.0 15 | emb_dim: 128 16 | fthead: FTHead3 17 | h1: 640 18 | h2: 1408 19 | h3: 1472 20 | h4: 1792 21 | n_classes: 1 22 | num_heads: 4 23 | num_layer: 4 24 | n_epochs: 10000 25 | n_multi_task_heads: 0 26 | target_type: regr 27 | test: 28 | path: finetune_data/moleculenet_exp1s/lipo/test.pkl 29 | train: 30 | path: finetune_data/moleculenet_exp1s/lipo/train.pkl 31 | use_schedular: false 32 | val: 33 | path: finetune_data/moleculenet_exp1s/lipo/val.pkl 34 | frag_features: 167 35 | model_version: gat2 36 | pretrain: 37 | batch_size: 128 38 | 
chkpoint_name: exps/pt/unimol_exp1s4/pt.pt 39 | drop_ratio: 0.2 40 | emb_dim: 128 41 | es_patience: 500 42 | loss: mse 43 | lr: 1e-4 44 | n_epochs: 20000 45 | num_heads: 4 46 | num_layer: 4 47 | seed: 100 48 | -------------------------------------------------------------------------------- /fragnet/exps/ft/lipo/fragnet_hpdl_exp1s_pt4_30/ft_100.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pnnl/FragNet/54f99099a08ab991a5ba6845f7338492d03d4f1c/fragnet/exps/ft/lipo/fragnet_hpdl_exp1s_pt4_30/ft_100.pt -------------------------------------------------------------------------------- /fragnet/exps/ft/pnnl_full/fragnet_hpdl_exp1s_h4pt4_10/config_exp100.yaml: -------------------------------------------------------------------------------- 1 | atom_features: 167 2 | data_seed: None 3 | device: gpu 4 | edge_features: 17 5 | exp_dir: exps/ft/pnnl_full/fragnet_hpdl_exp1s_h4pt4_10 6 | finetune: 7 | batch_size: 16 8 | chkpoint_name: exps/ft/pnnl_full/fragnet_hpdl_exp1s_h4pt4_10/ft_100.pt 9 | es_patience: 100 10 | loss: mse 11 | lr: 1e-4 12 | model: 13 | act: selu 14 | drop_ratio: 0.1 15 | emb_dim: 128 16 | fthead: FTHead4 17 | h1: 1472 18 | h2: 1024 19 | h3: 1024 20 | h4: 512 21 | n_classes: 1 22 | num_heads: 4 23 | num_layer: 4 24 | n_epochs: 10000 25 | n_multi_task_heads: 0 26 | target_type: regr 27 | test: 28 | path: finetune_data/pnnl_full/test.pkl 29 | train: 30 | path: finetune_data/pnnl_full/train.pkl 31 | use_schedular: false 32 | val: 33 | path: finetune_data/pnnl_full/val.pkl 34 | frag_features: 167 35 | model_version: gat2 36 | pretrain: 37 | batch_size: 128 38 | chkpoint_name: exps/pt/unimol_exp1s4/pt.pt 39 | drop_ratio: 0.2 40 | emb_dim: 128 41 | es_patience: 500 42 | loss: mse 43 | lr: 1e-4 44 | n_classes: 1 45 | n_epochs: 20000 46 | num_heads: 4 47 | num_layer: 4 48 | seed: 100 49 | -------------------------------------------------------------------------------- /fragnet/exps/ft/pnnl_full/fragnet_hpdl_exp1s_h4pt4_10/ft_100.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pnnl/FragNet/54f99099a08ab991a5ba6845f7338492d03d4f1c/fragnet/exps/ft/pnnl_full/fragnet_hpdl_exp1s_h4pt4_10/ft_100.pt -------------------------------------------------------------------------------- /fragnet/exps/pt/unimol_exp1s4/config.yaml: -------------------------------------------------------------------------------- 1 | seed: 123 2 | exp_dir: exps/pt/unimol_exp1s4 3 | atom_features: 167 4 | frag_features: 167 5 | edge_features: 17 6 | model_version: gat2 7 | device: cpu 8 | fedge_in: 6 9 | fbond_edge_in: 6 10 | pretrain: 11 | num_layer: 4 12 | drop_ratio: 0.2 13 | num_heads: 4 14 | emb_dim: 128 15 | chkpoint_name: ${exp_dir}/pt.pt 16 | saved_checkpoint: null 17 | loss: mse 18 | batch_size: 512 19 | es_patience: 200 20 | lr: 1e-4 21 | n_epochs: 200 22 | valdiate_every: 5 23 | data: 24 | - pretrain_data/esol/ 25 | train_smiles: null 26 | val_smiles: null -------------------------------------------------------------------------------- /fragnet/exps/pt/unimol_exp1s4/pt.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pnnl/FragNet/54f99099a08ab991a5ba6845f7338492d03d4f1c/fragnet/exps/pt/unimol_exp1s4/pt.pt -------------------------------------------------------------------------------- /fragnet/hp/hp.py: -------------------------------------------------------------------------------- 1 | 2 | from 
torch_geometric.data import DataLoader 3 | import random 4 | import hyperopt 5 | from hyperopt import fmin, hp, Trials, STATUS_OK 6 | import time 7 | from dataset import load_pickle_dataset 8 | from torch.utils.data import DataLoader 9 | from data import collate_fn 10 | from gat import FragNet 11 | import torch.nn as nn 12 | from gat import FragNetPreTrain 13 | from utils import EarlyStopping 14 | import numpy as np 15 | from utils import test_fn 16 | from dataset import load_data_parts 17 | import os 18 | import torch 19 | from torch_scatter import scatter_add 20 | 21 | def get_optimizer(model, freeze_pt_weights=False, lr=1e-4): 22 | 23 | if freeze_pt_weights: 24 | print('freezing pretrain weights') 25 | for name, param in model.named_parameters(): 26 | if param.requires_grad and 'pretrain' in name: 27 | param.requires_grad = False 28 | 29 | non_frozen_parameters = [p for p in model.parameters() if p.requires_grad] 30 | optimizer = torch.optim.Adam(non_frozen_parameters, lr = lr) 31 | else: 32 | print('no freezing of the weights') 33 | optimizer = torch.optim.Adam(model.parameters(), lr = lr ) 34 | 35 | return model, optimizer 36 | 37 | 38 | def set_seed(seed): 39 | torch.manual_seed(seed) 40 | torch.cuda.manual_seed_all(seed) 41 | torch.backends.cudnn.deterministic = True 42 | torch.backends.cudnn.benchmark = False 43 | np.random.seed(seed) 44 | random.seed(seed) 45 | os.environ['PYTHONHASHSEED'] = str(seed) 46 | 47 | 48 | 49 | def trainModel(params): 50 | 51 | 52 | class FragNetFineTune(nn.Module): 53 | 54 | def __init__(self): 55 | super(FragNetFineTune, self).__init__() 56 | 57 | self.pretrain = FragNet(num_layer=6, drop_ratio= params['d1'] ) 58 | self.lin1 = nn.Linear(128*2, int(params['f2'])) 59 | self.lin2 = nn.Linear(int(params['f2']), int(params['f3'])) 60 | self.out = nn.Linear(int(params['f3']), 1) 61 | self.dropout = nn.Dropout(p= params['d2'] ) 62 | self.activation = nn.ReLU() 63 | 64 | 65 | def forward(self, batch): 66 | 67 | x_atoms, x_frags = self.pretrain(batch) 68 | 69 | x_frags_pooled = scatter_add(src=x_frags, index=batch['frag_batch'], dim=0) 70 | x_atoms_pooled = scatter_add(src=x_atoms, index=batch['batch'], dim=0) 71 | 72 | cat = torch.cat((x_atoms_pooled, x_frags_pooled), 1) 73 | x = self.dropout(cat) 74 | 75 | x = self.lin1(x) 76 | x = self.activation(x) 77 | x = self.dropout(x) 78 | 79 | x = self.lin2(x) 80 | x = self.activation(x) 81 | x = self.dropout(x) 82 | 83 | x = self.out(x) 84 | 85 | return x 86 | 87 | 88 | 89 | train_loader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=params['bs'], shuffle=True, drop_last=True) 90 | val_loader = DataLoader(val_dataset, collate_fn=collate_fn, batch_size=128, shuffle=False, drop_last=False) 91 | 92 | model_pretrain = FragNetPreTrain(); 93 | model_pretrain.to(device); 94 | model_pretrain.load_state_dict(torch.load('pt.pt')) 95 | model = FragNetFineTune() 96 | model.to(device); 97 | loss_fn = nn.MSELoss() 98 | chkpoint_name = 'hp2.pt' 99 | 100 | 101 | model, optimizer = get_optimizer(model, freeze_pt_weights=freeze_pt_weights, lr=1e-5) 102 | early_stopping = EarlyStopping(patience=20, verbose=True, chkpoint_name=chkpoint_name) 103 | 104 | for epoch in range(n_epochs): 105 | 106 | res = [] 107 | model.train() 108 | total_loss = 0 109 | for batch in train_loader: 110 | for k,v in batch.items(): 111 | batch[k] = batch[k].to(device) 112 | optimizer.zero_grad() 113 | out = model(batch).view(-1,) 114 | loss = loss_fn(out, batch['y']) 115 | loss.backward() 116 | total_loss += loss.item() 117 | optimizer.step() 118 
| 119 | 120 | try: 121 | val_loss, _, _ = test_fn(val_loader, model, device) 122 | res.append(val_loss) 123 | print("val mse: ", val_loss) 124 | 125 | early_stopping(val_loss, model) 126 | 127 | if early_stopping.early_stop: 128 | print("Early stopping") 129 | break 130 | except: 131 | print('val loss cannot be calculated') 132 | pass 133 | 134 | try: 135 | model.load_state_dict(torch.load(chkpoint_name)) 136 | test_mse, test_t, test_p = test_fn(val_loader, model, device) 137 | return {'loss':test_mse, 'status':STATUS_OK} 138 | 139 | except: 140 | return {'loss':1000, 'status':STATUS_OK} 141 | 142 | if __name__ == '__main__': 143 | 144 | dataset_name='moleculenet' 145 | dataset_subset = 'esol' 146 | freeze_pt_weights=False 147 | # add_pt_weights=True 148 | seed = None 149 | n_epochs = 100 150 | 151 | if dataset_name == 'moleculenet': 152 | train_dataset = load_pickle_dataset(f'{dataset_name}/{dataset_subset}', f'train_{str(seed)}') 153 | val_dataset = load_pickle_dataset(f'{dataset_name}/{dataset_subset}', f'val_{str(seed)}') 154 | test_dataset = load_pickle_dataset(f'{dataset_name}/{dataset_subset}', f'test_{str(seed)}') 155 | elif 'pnnl' in dataset_name: 156 | train_dataset = load_data_parts(path=dataset_name, name='train') 157 | val_dataset = load_data_parts(path=dataset_name, name='val') 158 | test_dataset = load_data_parts(path=dataset_name, name='test') 159 | 160 | 161 | space = { 162 | 'f1': hp.quniform('f1', 32, 320, 32), 163 | 'f2': hp.quniform('f2', 32, 320, 32), 164 | 'f3': hp.quniform('f3', 32, 320, 32), 165 | 'd1': hp.uniform('d1', 0,1), 166 | 'd2': hp.uniform('d2', 0,1), 167 | 'bs' : hp.choice('bs', [16, 32, 128]), 168 | } 169 | 170 | 171 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 172 | 173 | 174 | trials = Trials() 175 | st_time = time.time() 176 | best = fmin(trainModel, space, algo=hyperopt.tpe.suggest, max_evals=100, trials=trials) 177 | 178 | end_time = time.time() 179 | fo = open("res2.txt", "w") 180 | fo.write(repr(best)) 181 | fo.close() 182 | 183 | print( (end_time-st_time)/3600 ) -------------------------------------------------------------------------------- /fragnet/hp/hp_clf.py: -------------------------------------------------------------------------------- 1 | 2 | from torch_geometric.data import DataLoader 3 | import random 4 | import hyperopt 5 | from hyperopt import fmin, hp, Trials, STATUS_OK 6 | import time 7 | from dataset import load_pickle_dataset 8 | from torch.utils.data import DataLoader 9 | from data import collate_fn 10 | import torch.nn as nn 11 | from utils import EarlyStopping 12 | import numpy as np 13 | import os 14 | import torch 15 | from omegaconf import OmegaConf 16 | import argparse 17 | from utils import TrainerFineTune as Trainer 18 | import torch.optim.lr_scheduler as lr_scheduler 19 | 20 | def get_optimizer(model, freeze_pt_weights=False, lr=1e-4): 21 | 22 | if freeze_pt_weights: 23 | print('freezing pretrain weights') 24 | for name, param in model.named_parameters(): 25 | if param.requires_grad and 'pretrain' in name: 26 | param.requires_grad = False 27 | 28 | non_frozen_parameters = [p for p in model.parameters() if p.requires_grad] 29 | optimizer = torch.optim.Adam(non_frozen_parameters, lr = lr) 30 | else: 31 | print('no freezing of the weights') 32 | optimizer = torch.optim.Adam(model.parameters(), lr = lr ) 33 | 34 | return model, optimizer 35 | 36 | 37 | def set_seed(seed): 38 | torch.manual_seed(seed) 39 | torch.cuda.manual_seed_all(seed) 40 | torch.backends.cudnn.deterministic = True 41 | 
torch.backends.cudnn.benchmark = False 42 | np.random.seed(seed) 43 | random.seed(seed) 44 | os.environ['PYTHONHASHSEED'] = str(seed) 45 | 46 | 47 | 48 | def trainModel(params): 49 | 50 | 51 | exp_dir = args.exp_dir 52 | n_classes_pt = args.pretrain.n_classes 53 | device = torch.device('cuda') if torch.cuda.is_available() else 'cpu' 54 | 55 | pt_chkpoint_name = args.pretrain.chkpoint_name #'exps/pt_rings' 56 | 57 | from gat2 import FragNetFineTune 58 | model = FragNetFineTune(n_classes=args.finetune.model.n_classes, 59 | num_layer=params['num_layer'], 60 | drop_ratio=params['drop_ratio']) 61 | 62 | 63 | if pt_chkpoint_name: 64 | model_pretrain = FragNetFineTune(n_classes_pt) 65 | model_pretrain.load_state_dict(torch.load(pt_chkpoint_name)) 66 | state_dict_to_load={} 67 | for k,v in model.state_dict().items(): 68 | 69 | if v.size() == model_pretrain.state_dict()[k].size(): 70 | state_dict_to_load[k] = model_pretrain.state_dict()[k] 71 | else: 72 | state_dict_to_load[k] = v 73 | 74 | model.load_state_dict(state_dict_to_load) 75 | 76 | train_dataset2 = load_pickle_dataset(args.finetune.train.path, args.finetune.train.name) 77 | val_dataset2 = load_pickle_dataset(args.finetune.val.path, args.finetune.val.name) 78 | test_dataset2 = load_pickle_dataset(args.finetune.test.path, args.finetune.test.name) #'finetune_data/pnnl_exp' 79 | 80 | 81 | train_loader = DataLoader(train_dataset2, collate_fn=collate_fn, batch_size=params['batch_size'], shuffle=True, drop_last=True) 82 | val_loader = DataLoader(val_dataset2, collate_fn=collate_fn, batch_size=params['batch_size'], shuffle=False, drop_last=False) 83 | test_loader = DataLoader(test_dataset2, collate_fn=collate_fn, batch_size=params['batch_size'], shuffle=False, drop_last=False) 84 | 85 | trainer = Trainer(target_type=args.finetune.target_type) 86 | 87 | 88 | model.to(device); 89 | ft_chk_point = f'{args.exp_dir}/fthp.pt' 90 | early_stopping = EarlyStopping(patience=100, verbose=True, chkpoint_name=ft_chk_point) 91 | 92 | 93 | if args.finetune.loss == 'mse': 94 | loss_fn = nn.MSELoss() 95 | elif args.finetune.loss == 'cel': 96 | loss_fn = nn.CrossEntropyLoss() 97 | 98 | optimizer = torch.optim.Adam(model.parameters(), lr = params['lr'] ) # before 1e-4 99 | if args.finetune.use_schedular: 100 | scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.99) 101 | else: 102 | scheduler=None 103 | 104 | for epoch in range(args.finetune.n_epochs): 105 | 106 | train_loss = trainer.train(model=model, loader=train_loader, optimizer=optimizer, scheduler=scheduler, loss_fn=loss_fn, device=device) 107 | val_loss, _, _ = trainer.test(model=model, loader=val_loader, device=device) 108 | print(train_loss, val_loss) 109 | early_stopping(val_loss, model) 110 | 111 | if early_stopping.early_stop: 112 | print("Early stopping") 113 | break 114 | 115 | try: 116 | model.load_state_dict(torch.load(ft_chk_point)) 117 | mse, true, pred = trainer.test(model=model, loader=val_loader, device=device) 118 | 119 | return {'loss':-mse, 'status':STATUS_OK} 120 | 121 | except: 122 | return {'loss':-100000, 'status':STATUS_OK} 123 | 124 | 125 | parser = argparse.ArgumentParser() 126 | parser.add_argument('--config', help="configuration file *.yml", type=str, required=False, default='config.yaml') 127 | args = parser.parse_args() 128 | 129 | 130 | if args.config: # args priority is higher than yaml 131 | opt = OmegaConf.load(args.config) 132 | OmegaConf.resolve(opt) 133 | args=opt 134 | 135 | 136 | space = { 137 | 'num_layer': hp.choice('num_layer', [3, 4, 5,6,7,8]), 138 | 'lr': 
hp.choice('lr', [1e-3, 1e-4, 1e-6]), 139 | 'drop_ratio': hp.choice('drop_ratio', [0.15, 0.2, 0.3, 0.5]), 140 | 'batch_size' : hp.choice('batch_size', [8, 16, 32, 64, 128]), 141 | } 142 | 143 | 144 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 145 | 146 | trials = Trials() 147 | st_time = time.time() 148 | best = fmin(trainModel, space, algo=hyperopt.rand.suggest, max_evals=25, trials=trials) 149 | 150 | end_time = time.time() 151 | fo = open(f"{args.exp_dir}/res2.txt", "w") 152 | fo.write(repr(best)) 153 | fo.close() 154 | 155 | print( (end_time-st_time)/3600 ) 156 | -------------------------------------------------------------------------------- /fragnet/hp/hpoptuna.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import optuna 3 | from torch_geometric.data import DataLoader 4 | import random 5 | import time 6 | from fragnet.dataset.dataset import load_pickle_dataset 7 | from torch.utils.data import DataLoader 8 | from fragnet.dataset.data import collate_fn 9 | import torch.nn as nn 10 | from fragnet.model.gat.gat2 import FragNet 11 | from fragnet.train.utils import EarlyStopping 12 | import numpy as np 13 | from fragnet.train.utils import test_fn 14 | from fragnet.dataset.dataset import load_data_parts 15 | import os 16 | import torch 17 | from torch_scatter import scatter_add 18 | from omegaconf import OmegaConf 19 | import argparse 20 | from fragnet.train.utils import TrainerFineTune as Trainer 21 | import torch.optim.lr_scheduler as lr_scheduler 22 | from fragnet.dataset.data import collate_fn 23 | from fragnet.dataset.data import collate_fn 24 | import matplotlib.pyplot as plt 25 | from fragnet.model.gat.gat2 import FragNetFineTune 26 | import pytorch_lightning as pl 27 | 28 | def seed_everything(seed: int): 29 | import random, os 30 | 31 | random.seed(seed) 32 | os.environ['PYTHONHASHSEED'] = str(seed) 33 | np.random.seed(seed) 34 | torch.manual_seed(seed) 35 | torch.cuda.manual_seed_all(seed) 36 | torch.backends.cudnn.deterministic = True 37 | 38 | 39 | def trainModel(trial): 40 | 41 | device = torch.device('cuda') if torch.cuda.is_available() else 'cpu' 42 | 43 | num_heads = args.finetune.model.num_heads 44 | num_layer = args.finetune.model.num_layer 45 | 46 | drop_ratio = trial.suggest_categorical('drop_ratio', [0.0, 0.1, 0.2, 0.3]) 47 | 48 | fthead = args.finetune.model.fthead 49 | 50 | if fthead == 'FTHead3': 51 | h1 = trial.suggest_int('h1', 64, 2048, step=64) 52 | h2=trial.suggest_int('h2', 64, 2048, step=64) 53 | h3=trial.suggest_int('h3', 64, 2048, step=64) 54 | h4=trial.suggest_int('h4', 64, 2048, step=64) 55 | 56 | elif fthead == 'FTHead4': 57 | h1 = trial.suggest_int('h1', 64, 2048, step=64) 58 | h2, h3, h4 = None, None, None 59 | 60 | 61 | act = trial.suggest_categorical("act", ['relu','silu','gelu','celu','selu','rrelu','relu6','prelu','leakyrelu']) 62 | batch_size = trial.suggest_categorical('batch_size', [16,32,64,128]) 63 | lr = args.finetune.lr 64 | 65 | 66 | pt_chkpoint_name = args.pretrain.chkpoint_name #'exps/pt_rings' 67 | if args.model_version == 'gat2': 68 | from fragnet.model.gat.gat2 import FragNetFineTune 69 | elif args.model_version == 'gat2_lite': 70 | from fragnet.model.gat.gat2_lite import FragNetFineTune 71 | 72 | model = FragNetFineTune(n_classes=args.finetune.model.n_classes, 73 | atom_features=args.atom_features, 74 | frag_features=args.frag_features, 75 | edge_features=args.edge_features, 76 | num_heads = num_heads, 77 | num_layer= num_layer, 78 | drop_ratio= 
drop_ratio, 79 | h1=h1, 80 | h2=h2, 81 | h3=h3, 82 | h4=h4, 83 | act=act, 84 | fthead=fthead 85 | ) 86 | 87 | 88 | 89 | if pt_chkpoint_name: 90 | 91 | from fragnet.model.gat.gat2_pretrain import FragNetPreTrain 92 | modelpt = FragNetPreTrain( 93 | atom_features=args.atom_features, 94 | frag_features=args.frag_features, 95 | edge_features=args.edge_features, 96 | num_layer=args.pretrain.num_layer, 97 | drop_ratio=args.pretrain.drop_ratio, 98 | num_heads=args.pretrain.num_heads, 99 | emb_dim=args.pretrain.emb_dim) 100 | modelpt.load_state_dict(torch.load(pt_chkpoint_name, map_location=torch.device(device))) 101 | 102 | 103 | print('loading pretrained weights') 104 | model.pretrain.load_state_dict(modelpt.pretrain.state_dict()) 105 | print('weights loaded') 106 | else: 107 | print('no pretrained weights') 108 | 109 | 110 | trainer = Trainer(target_type=args.finetune.target_type) 111 | 112 | train_dataset2 = load_pickle_dataset(args.finetune.train.path) 113 | val_dataset2 = load_pickle_dataset(args.finetune.val.path) 114 | 115 | 116 | train_loader = DataLoader(train_dataset2, collate_fn=collate_fn, batch_size=batch_size, shuffle=True, drop_last=True) 117 | val_loader = DataLoader(val_dataset2, collate_fn=collate_fn, batch_size=128, shuffle=False, drop_last=False) 118 | 119 | model.to(device); 120 | 121 | optimizer = torch.optim.Adam(model.parameters(), lr = lr ) # before 1e-4 122 | if args.finetune.use_schedular: 123 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, cooldown=0) 124 | else: 125 | scheduler=None 126 | 127 | scheduler=None 128 | 129 | ft_chk_point = args.chkpt 130 | print("checkpoint name: ", ft_chk_point) 131 | early_stopping = EarlyStopping(patience=args.finetune.es_patience, verbose=True, chkpoint_name=ft_chk_point) 132 | 133 | for epoch in range(args.finetune.n_epochs): 134 | 135 | train_loss = trainer.train(model=model, loader=train_loader, optimizer=optimizer, scheduler=scheduler, device=device, val_loader=val_loader) 136 | val_loss, _, _ = trainer.test(model=model, loader=val_loader, device=device) 137 | print(train_loss, val_loss) 138 | 139 | 140 | trial.report(val_loss, epoch) 141 | if args.prune == 1: 142 | if trial.should_prune(): 143 | raise optuna.TrialPruned() 144 | 145 | early_stopping(val_loss, model) 146 | 147 | if early_stopping.early_stop: 148 | print("Early stopping") 149 | break 150 | 151 | 152 | try: 153 | model.load_state_dict(torch.load(ft_chk_point)) 154 | mse, true, pred = trainer.test(model=model, loader=val_loader, device=device) 155 | 156 | return mse 157 | 158 | except: 159 | return 1000.0 160 | 161 | 162 | if __name__ == '__main__': 163 | 164 | parser = argparse.ArgumentParser() 165 | parser.add_argument('--config', help="configuration file *.yml", type=str, required=False, default='exps/ft/esol2/config.yaml') 166 | parser.add_argument('--chkpt', help="checkpoint name", type=str, required=False, default='opt.pt') 167 | parser.add_argument('--direction', help="", type=str, required=False, default='minimize') 168 | parser.add_argument('--n_trials', help="", type=int, required=False, default=100) 169 | parser.add_argument('--embed_max', help="", type=int, required=False, default=512) 170 | parser.add_argument('--seed', help="", type=int, required=False, default=1) 171 | parser.add_argument('--ft_epochs', help="", type=int, required=False, default=100) 172 | parser.add_argument('--choose_random', help="", type=bool, required=False, default=False) 173 | parser.add_argument('--prune', help="", type=int, 
required=False, default=1) 174 | args = parser.parse_args() 175 | 176 | 177 | if args.config: # args priority is higher than yaml 178 | opt = OmegaConf.load(args.config) 179 | OmegaConf.resolve(opt) 180 | 181 | opt.update(vars(args)) 182 | args = opt 183 | 184 | seed_everything(args.seed) 185 | print('seed: ', args.seed) 186 | print('choose_random: ', args.choose_random) 187 | args.finetune.n_epochs = args.ft_epochs 188 | 189 | # for resuming 190 | study_name = args.chkpt.replace('.pt', '').split('/')[-1] 191 | storage = args.chkpt.replace('.pt', '.db') 192 | study = optuna.create_study(direction = args.direction, study_name=study_name, storage=f'sqlite:///{storage}', load_if_exists=True) 193 | # for resuming 194 | 195 | study.optimize(trainModel, n_trials=args.n_trials, gc_after_trial=True) 196 | df_study = study.trials_dataframe(attrs=('number', 'value', 'params', 'state')) 197 | df_name = args.chkpt.replace('.pt', '.csv') 198 | df_study.to_csv(df_name) 199 | 200 | print("best params:") 201 | print(study.best_params) 202 | 203 | -------------------------------------------------------------------------------- /fragnet/hp/hpray.py: -------------------------------------------------------------------------------- 1 | 2 | from torch_geometric.data import DataLoader 3 | from hyperopt import STATUS_OK 4 | from dataset import load_pickle_dataset 5 | from torch.utils.data import DataLoader 6 | from data import collate_fn 7 | import torch.nn as nn 8 | import os 9 | import torch 10 | from omegaconf import OmegaConf 11 | import argparse 12 | from utils import TrainerFineTune as Trainer 13 | import torch.optim.lr_scheduler as lr_scheduler 14 | from ray import tune 15 | 16 | RESULTS_PATH = './' 17 | 18 | def trainModel(params): 19 | 20 | exp_dir = args.exp_dir 21 | n_classes_pt = args.pretrain.n_classes 22 | device = torch.device('cuda') if torch.cuda.is_available() else 'cpu' 23 | 24 | 25 | pt_chkpoint_name = args.pretrain.chkpoint_name #'exps/pt_rings' 26 | from gat2 import FragNetFineTune 27 | 28 | model = FragNetFineTune(n_classes=args.finetune.model.n_classes, 29 | num_layer=params['num_layer'], 30 | drop_ratio=params['drop_ratio']) 31 | trainer = Trainer(target_type=args.finetune.target_type) 32 | 33 | 34 | if pt_chkpoint_name: 35 | model_pretrain = FragNetFineTune(n_classes_pt) 36 | model_pretrain.load_state_dict(torch.load(pt_chkpoint_name)) 37 | state_dict_to_load={} 38 | for k,v in model.state_dict().items(): 39 | 40 | if v.size() == model_pretrain.state_dict()[k].size(): 41 | state_dict_to_load[k] = model_pretrain.state_dict()[k] 42 | else: 43 | state_dict_to_load[k] = v 44 | 45 | model.load_state_dict(state_dict_to_load) 46 | 47 | 48 | args.finetune.train.path = os.path.join(RESULTS_PATH, args.finetune.train.path) 49 | args.finetune.val.path = os.path.join(RESULTS_PATH, args.finetune.val.path) 50 | args.finetune.test.path = os.path.join(RESULTS_PATH, args.finetune.test.path) 51 | 52 | train_dataset2 = load_pickle_dataset(args.finetune.train.path, args.finetune.train.name) 53 | val_dataset2 = load_pickle_dataset(args.finetune.val.path, args.finetune.val.name) 54 | test_dataset2 = load_pickle_dataset(args.finetune.test.path, args.finetune.test.name) #'finetune_data/pnnl_exp' 55 | 56 | train_loader = DataLoader(train_dataset2, collate_fn=collate_fn, batch_size=params['batch_size'], shuffle=True, drop_last=True) 57 | val_loader = DataLoader(val_dataset2, collate_fn=collate_fn, batch_size=params['batch_size'], shuffle=False, drop_last=False) 58 | test_loader = DataLoader(test_dataset2, 
collate_fn=collate_fn, batch_size=params['batch_size'], shuffle=False, drop_last=False) 59 | 60 | model.to(device); 61 | 62 | 63 | 64 | if args.finetune.loss == 'mse': 65 | loss_fn = nn.MSELoss() 66 | elif args.finetune.loss == 'cel': 67 | loss_fn = nn.CrossEntropyLoss() 68 | 69 | optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4 ) # before 1e-4 70 | if args.finetune.use_schedular: 71 | scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.99) 72 | else: 73 | scheduler=None 74 | 75 | for epoch in range(args.finetune.n_epochs): 76 | 77 | train_loss = trainer.train(model=model, loader=train_loader, optimizer=optimizer, scheduler=scheduler, loss_fn=loss_fn, device=device) 78 | val_loss, _, _ = trainer.test(model=model, loader=val_loader, device=device) 79 | print(train_loss, val_loss) 80 | 81 | try: 82 | mse, true, pred = trainer.test(model=model, loader=val_loader, device=device) 83 | return {'score':mse, 'status':STATUS_OK} 84 | 85 | except: 86 | return {'score':1000, 'status':STATUS_OK} 87 | 88 | 89 | 90 | parser = argparse.ArgumentParser() 91 | parser.add_argument('--config', help="configuration file *.yml", type=str, required=False, default='exps/ft/esol2/None/no_pt/config.yaml') 92 | args = parser.parse_args('') 93 | 94 | 95 | if args.config: # args priority is higher than yaml 96 | opt = OmegaConf.load(args.config) 97 | OmegaConf.resolve(opt) 98 | args=opt 99 | 100 | 101 | args.finetune.n_epochs = 20 102 | search_space = { 103 | "num_layer": tune.choice([3,4,5,6,7,8]), 104 | "drop_ratio": tune.choice([0.1, 0.15, 0.2, 0.25, .3]), 105 | "batch_size": tune.choice([8, 16, 32, 64]), 106 | } 107 | 108 | tuner = tune.Tuner(trainModel, param_space=search_space, 109 | tune_config=tune.TuneConfig(num_samples=100)) 110 | results = tuner.fit() 111 | print(results.get_best_result(metric="score", mode="min").config) -------------------------------------------------------------------------------- /fragnet/model/cdrp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pnnl/FragNet/54f99099a08ab991a5ba6845f7338492d03d4f1c/fragnet/model/cdrp/__init__.py -------------------------------------------------------------------------------- /fragnet/model/cdrp/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class MLP(nn.Sequential): 7 | def __init__(self, gene_dim=903, device='cuda'): 8 | self.device = device 9 | input_dim_gene = gene_dim 10 | hidden_dim_gene = 256 11 | mlp_hidden_dims_gene = [1024, 256, 64] 12 | super(MLP, self).__init__() 13 | layer_size = len(mlp_hidden_dims_gene) + 1 14 | dims = [input_dim_gene] + mlp_hidden_dims_gene + [hidden_dim_gene] 15 | self.predictor = nn.ModuleList([nn.Linear(dims[i], dims[i + 1]) for i in range(layer_size)]) 16 | 17 | def forward(self, v): 18 | # predict 19 | v = v.float().to(self.device) 20 | for i, l in enumerate(self.predictor): 21 | v = F.relu(l(v)) 22 | return v 23 | 24 | 25 | class CDRPModel(nn.Module): 26 | def __init__(self, drug_model, gene_dim, device): 27 | super().__init__() 28 | self.drug_model = drug_model # FragNetFineTune() 29 | self.fc1 = nn.Linear(256+256, 128) 30 | self.fc2 = nn.Linear(128, 1) 31 | self.cell_model = MLP(gene_dim, device) 32 | 33 | 34 | def forward(self, batch): 35 | 36 | drug_enc = self.drug_model(batch) 37 | gene_expr = batch['gene_expr'] 38 | cell_enc = self.cell_model(gene_expr) 39 | 40 | cat = 
torch.cat((drug_enc, cell_enc), 1) 41 | out = self.fc1(cat) 42 | out = self.fc2(out) 43 | return out 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | -------------------------------------------------------------------------------- /fragnet/model/dta/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pnnl/FragNet/54f99099a08ab991a5ba6845f7338492d03d4f1c/fragnet/model/dta/__init__.py -------------------------------------------------------------------------------- /fragnet/model/dta/drug_encoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script was copied from 3 | https://github.com/jianglikun/DeepTTC/blob/main/model_helper.py 4 | and modified 5 | """ 6 | 7 | import torch 8 | from torch import nn 9 | import torch.nn.functional as F 10 | import numpy as np 11 | import copy 12 | import math 13 | 14 | torch.manual_seed(1) 15 | np.random.seed(1) 16 | 17 | class LayerNorm(nn.Module): 18 | def __init__(self, hidden_size, variance_epsilon=1e-12): 19 | 20 | super(LayerNorm, self).__init__() 21 | self.gamma = nn.Parameter(torch.ones(hidden_size)) 22 | self.beta = nn.Parameter(torch.zeros(hidden_size)) 23 | self.variance_epsilon = variance_epsilon 24 | 25 | def forward(self, x): 26 | u = x.mean(-1, keepdim=True) 27 | s = (x - u).pow(2).mean(-1, keepdim=True) 28 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 29 | return self.gamma * x + self.beta 30 | 31 | 32 | class Embeddings(nn.Module): 33 | """Construct the embeddings from protein/target, position embeddings. 34 | """ 35 | def __init__(self, vocab_size, hidden_size, max_position_size, dropout_rate): 36 | super(Embeddings, self).__init__() 37 | self.word_embeddings = nn.Embedding(vocab_size, hidden_size) 38 | self.position_embeddings = nn.Embedding(max_position_size, hidden_size) 39 | 40 | self.LayerNorm = LayerNorm(hidden_size) 41 | self.dropout = nn.Dropout(dropout_rate) 42 | 43 | def forward(self, input_ids): 44 | seq_length = input_ids.size(1) 45 | position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) 46 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids) 47 | 48 | words_embeddings = self.word_embeddings(input_ids) 49 | position_embeddings = self.position_embeddings(position_ids) 50 | 51 | embeddings = words_embeddings + position_embeddings 52 | embeddings = self.LayerNorm(embeddings) 53 | embeddings = self.dropout(embeddings) 54 | return embeddings 55 | 56 | 57 | class SelfAttention(nn.Module): 58 | def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob): 59 | super(SelfAttention, self).__init__() 60 | if hidden_size % num_attention_heads != 0: 61 | raise ValueError( 62 | "The hidden size (%d) is not a multiple of the number of attention " 63 | "heads (%d)" % (hidden_size, num_attention_heads)) 64 | self.num_attention_heads = num_attention_heads 65 | self.attention_head_size = int(hidden_size / num_attention_heads) 66 | self.all_head_size = self.num_attention_heads * self.attention_head_size 67 | 68 | self.query = nn.Linear(hidden_size, self.all_head_size) 69 | self.key = nn.Linear(hidden_size, self.all_head_size) 70 | self.value = nn.Linear(hidden_size, self.all_head_size) 71 | 72 | self.dropout = nn.Dropout(attention_probs_dropout_prob) 73 | 74 | def transpose_for_scores(self, x): 75 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 76 | x = x.view(*new_x_shape) 77 | return x.permute(0, 2, 1, 3) 78 | 79 | def 
forward(self, hidden_states, attention_mask): 80 | mixed_query_layer = self.query(hidden_states) 81 | mixed_key_layer = self.key(hidden_states) 82 | mixed_value_layer = self.value(hidden_states) 83 | 84 | query_layer = self.transpose_for_scores(mixed_query_layer) 85 | key_layer = self.transpose_for_scores(mixed_key_layer) 86 | value_layer = self.transpose_for_scores(mixed_value_layer) 87 | 88 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 89 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 90 | 91 | attention_scores = attention_scores + attention_mask 92 | 93 | # Normalize the attention scores to probabilities. 94 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 95 | attention_probs = self.dropout(attention_probs) 96 | 97 | context_layer = torch.matmul(attention_probs, value_layer) 98 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 99 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 100 | context_layer = context_layer.view(*new_context_layer_shape) 101 | return context_layer 102 | 103 | 104 | class SelfOutput(nn.Module): 105 | def __init__(self, hidden_size, hidden_dropout_prob): 106 | super(SelfOutput, self).__init__() 107 | self.dense = nn.Linear(hidden_size, hidden_size) 108 | self.LayerNorm = LayerNorm(hidden_size) 109 | self.dropout = nn.Dropout(hidden_dropout_prob) 110 | 111 | def forward(self, hidden_states, input_tensor): 112 | hidden_states = self.dense(hidden_states) 113 | hidden_states = self.dropout(hidden_states) 114 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 115 | return hidden_states 116 | 117 | 118 | class Attention(nn.Module): 119 | def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob): 120 | super(Attention, self).__init__() 121 | self.self = SelfAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob) 122 | self.output = SelfOutput(hidden_size, hidden_dropout_prob) 123 | 124 | def forward(self, input_tensor, attention_mask): 125 | self_output = self.self(input_tensor, attention_mask) 126 | attention_output = self.output(self_output, input_tensor) 127 | return attention_output 128 | 129 | class Intermediate(nn.Module): 130 | def __init__(self, hidden_size, intermediate_size): 131 | super(Intermediate, self).__init__() 132 | self.dense = nn.Linear(hidden_size, intermediate_size) 133 | 134 | def forward(self, hidden_states): 135 | hidden_states = self.dense(hidden_states) 136 | hidden_states = F.relu(hidden_states) 137 | return hidden_states 138 | 139 | class Output(nn.Module): 140 | def __init__(self, intermediate_size, hidden_size, hidden_dropout_prob): 141 | super(Output, self).__init__() 142 | self.dense = nn.Linear(intermediate_size, hidden_size) 143 | self.LayerNorm = LayerNorm(hidden_size) 144 | self.dropout = nn.Dropout(hidden_dropout_prob) 145 | 146 | def forward(self, hidden_states, input_tensor): 147 | hidden_states = self.dense(hidden_states) 148 | hidden_states = self.dropout(hidden_states) 149 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 150 | return hidden_states 151 | 152 | class Encoder(nn.Module): 153 | def __init__(self, hidden_size, intermediate_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob): 154 | super(Encoder, self).__init__() 155 | self.attention = Attention(hidden_size, num_attention_heads, 156 | attention_probs_dropout_prob, hidden_dropout_prob) 157 | self.intermediate = Intermediate(hidden_size, 
intermediate_size) 158 | self.output = Output(intermediate_size, hidden_size, hidden_dropout_prob) 159 | 160 | def forward(self, hidden_states, attention_mask): 161 | attention_output = self.attention(hidden_states, attention_mask) 162 | intermediate_output = self.intermediate(attention_output) 163 | layer_output = self.output(intermediate_output, attention_output) 164 | return layer_output 165 | 166 | 167 | class Encoder_MultipleLayers(nn.Module): 168 | def __init__(self, n_layer, hidden_size, intermediate_size, 169 | num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob): 170 | super(Encoder_MultipleLayers, self).__init__() 171 | layer = Encoder(hidden_size, intermediate_size, num_attention_heads, 172 | attention_probs_dropout_prob, hidden_dropout_prob) 173 | self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(n_layer)]) 174 | 175 | def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True): 176 | all_encoder_layers = [] 177 | for layer_module in self.layer: 178 | hidden_states = layer_module(hidden_states, attention_mask) 179 | return hidden_states 180 | -------------------------------------------------------------------------------- /fragnet/model/dta/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | from .drug_encoder import Encoder_MultipleLayers 5 | 6 | torch.manual_seed(1) 7 | np.random.seed(1) 8 | 9 | class LayerNorm(nn.Module): 10 | def __init__(self, hidden_size, variance_epsilon=1e-12): 11 | 12 | super(LayerNorm, self).__init__() 13 | self.gamma = nn.Parameter(torch.ones(hidden_size)) 14 | self.beta = nn.Parameter(torch.zeros(hidden_size)) 15 | self.variance_epsilon = variance_epsilon 16 | 17 | def forward(self, x): 18 | u = x.mean(-1, keepdim=True) 19 | s = (x - u).pow(2).mean(-1, keepdim=True) 20 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 21 | return self.gamma * x + self.beta 22 | 23 | 24 | class Embeddings(nn.Module): 25 | """Construct the embeddings from protein/target, position embeddings. 
26 | """ 27 | def __init__(self, vocab_size, hidden_size, max_position_size, dropout_rate): 28 | super(Embeddings, self).__init__() 29 | self.word_embeddings = nn.Embedding(vocab_size, hidden_size) 30 | self.position_embeddings = nn.Embedding(max_position_size, hidden_size) 31 | 32 | self.LayerNorm = LayerNorm(hidden_size) 33 | self.dropout = nn.Dropout(dropout_rate) 34 | 35 | def forward(self, input_ids): 36 | seq_length = input_ids.size(1) 37 | position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) 38 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids) 39 | 40 | words_embeddings = self.word_embeddings(input_ids) 41 | position_embeddings = self.position_embeddings(position_ids) 42 | 43 | embeddings = words_embeddings + position_embeddings 44 | embeddings = self.LayerNorm(embeddings) 45 | embeddings = self.dropout(embeddings) 46 | return embeddings 47 | 48 | 49 | 50 | class transformer(nn.Sequential): 51 | def __init__(self): 52 | super(transformer, self).__init__() 53 | input_dim_drug = 25 54 | transformer_emb_size_drug = 128 55 | transformer_dropout_rate = 0.1 56 | transformer_n_layer_drug = 8 57 | transformer_intermediate_size_drug = 512 58 | transformer_num_attention_heads_drug = 8 59 | transformer_attention_probs_dropout = 0.1 60 | transformer_hidden_dropout_rate = 0.1 61 | max_position_size = 1000 62 | self.emb = Embeddings(input_dim_drug, 63 | transformer_emb_size_drug, 64 | max_position_size, 65 | transformer_dropout_rate) 66 | 67 | self.encoder = Encoder_MultipleLayers(transformer_n_layer_drug, 68 | transformer_emb_size_drug, 69 | transformer_intermediate_size_drug, 70 | transformer_num_attention_heads_drug, 71 | transformer_attention_probs_dropout, 72 | transformer_hidden_dropout_rate) 73 | def forward(self, e, e_mask): 74 | ex_e_mask = e_mask.unsqueeze(1).unsqueeze(2) 75 | ex_e_mask = (1.0 - ex_e_mask) * -10000.0 76 | 77 | emb = self.emb(e) 78 | encoded_layers = self.encoder(emb.float(), ex_e_mask.float()) 79 | return encoded_layers[:, 0] 80 | 81 | 82 | 83 | class DTAModel(nn.Module): 84 | def __init__(self, drug_model): 85 | super().__init__() 86 | self.drug_model = drug_model # FragNetFineTune() 87 | self.target_model = transformer() 88 | self.fc1 = nn.Linear(256+128, 128) 89 | self.fc2 = nn.Linear(128, 1) 90 | 91 | def forward(self, batch): 92 | 93 | drug_enc = self.drug_model(batch) 94 | 95 | tokens = batch['protein'] 96 | padding_mask = ~tokens.eq(0)*1 97 | 98 | target_enc = self.target_model(tokens, padding_mask) 99 | cat = torch.cat((drug_enc, target_enc), 1) 100 | 101 | out = self.fc1(cat) 102 | out = self.fc2(out) 103 | 104 | return out 105 | 106 | 107 | class DTAModel2(nn.Module): 108 | def __init__(self, drug_model): 109 | super().__init__() 110 | self.drug_model = drug_model # FragNetFineTune() 111 | self.fc1 = nn.Linear(256+300, 128) 112 | self.fc2 = nn.Linear(128, 1) 113 | 114 | num_features = 25 115 | prot_emb_dim = 300 116 | self.in_channels = 1000 117 | n_filters = 32 118 | kernel_size = 8 119 | prot_output_dim=300 120 | 121 | self.embedding_xt = nn.Embedding(num_features + 1, prot_emb_dim) 122 | self.conv_xt_1 = nn.Conv1d(in_channels=self.in_channels, out_channels=n_filters, kernel_size=kernel_size) 123 | intermediate_dim = prot_emb_dim - kernel_size + 1 124 | self.fc1_xt_dim = n_filters*intermediate_dim 125 | self.fc1_xt = nn.Linear(self.fc1_xt_dim, prot_output_dim) 126 | 127 | def forward(self, batch): 128 | 129 | drug_enc = self.drug_model(batch) 130 | 131 | tokens = batch['protein'] 132 | tokens = tokens.reshape(-1, 
self.in_channels) 133 | 134 | embedded_xt = self.embedding_xt(tokens) 135 | conv_xt = self.conv_xt_1(embedded_xt) 136 | # flatten 137 | xt = conv_xt.view(-1, self.fc1_xt_dim) 138 | xt = self.fc1_xt(xt) 139 | 140 | 141 | cat = torch.cat((drug_enc, xt), 1) 142 | 143 | out = self.fc1(cat) 144 | out = self.fc2(out) 145 | 146 | return out 147 | 148 | 149 | 150 | -------------------------------------------------------------------------------- /fragnet/model/gat/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pnnl/FragNet/54f99099a08ab991a5ba6845f7338492d03d4f1c/fragnet/model/gat/__init__.py -------------------------------------------------------------------------------- /fragnet/model/gat/extra_optimizers.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code was copied from https://github.com/LARS-research/3D-PGT/blob/main/graphgps/optimizer/extra_optimizers.py 3 | and modified 4 | """ 5 | 6 | import math 7 | from typing import Iterator 8 | import torch.optim as optim 9 | from torch.nn import Parameter 10 | from torch.optim import Adagrad, AdamW, Optimizer 11 | from torch.optim.lr_scheduler import ReduceLROnPlateau 12 | 13 | def adagrad_optimizer(params: Iterator[Parameter], base_lr: float, 14 | weight_decay: float) -> Adagrad: 15 | return Adagrad(params, lr=base_lr, weight_decay=weight_decay) 16 | 17 | def adamW_optimizer(params: Iterator[Parameter], base_lr: float, 18 | weight_decay: float) -> AdamW: 19 | return AdamW(params, lr=base_lr, weight_decay=weight_decay) 20 | 21 | 22 | def plateau_scheduler(optimizer: Optimizer, patience: int, 23 | lr_decay: float) -> ReduceLROnPlateau: 24 | return ReduceLROnPlateau(optimizer, patience=patience, factor=lr_decay) 25 | 26 | 27 | def scheduler_reduce_on_plateau(optimizer: Optimizer, reduce_factor: float, 28 | schedule_patience: int, min_lr: float, 29 | train_mode: str, eval_period: int): 30 | 31 | scheduler = optim.lr_scheduler.ReduceLROnPlateau( 32 | optimizer=optimizer, 33 | mode='min', 34 | factor=reduce_factor, 35 | patience=schedule_patience, 36 | min_lr=min_lr, 37 | verbose=True 38 | ) 39 | if not hasattr(scheduler, 'get_last_lr'): 40 | # ReduceLROnPlateau doesn't have `get_last_lr` method as of current 41 | # pytorch1.10; we add it here for consistency with other schedulers. 42 | def get_last_lr(self): 43 | """ Return last computed learning rate by current scheduler. 44 | """ 45 | return self._last_lr 46 | 47 | scheduler.get_last_lr = get_last_lr.__get__(scheduler) 48 | scheduler._last_lr = [group['lr'] 49 | for group in scheduler.optimizer.param_groups] 50 | 51 | def modified_state_dict(ref): 52 | """Returns the state of the scheduler as a :class:`dict`. 53 | Additionally modified to ignore 'get_last_lr', 'state_dict'. 54 | Including these entries in the state dict would cause issues when 55 | loading a partially trained / pretrained model from a checkpoint. 
56 | """ 57 | return {key: value for key, value in ref.__dict__.items() 58 | if key not in ['sparsifier', 'get_last_lr', 'state_dict']} 59 | 60 | scheduler.state_dict = modified_state_dict.__get__(scheduler) 61 | 62 | return scheduler 63 | 64 | 65 | # @register.register_scheduler('linear_with_warmup') 66 | def linear_with_warmup_scheduler(optimizer: Optimizer, 67 | num_warmup_epochs: int, max_epoch: int): 68 | scheduler = get_linear_schedule_with_warmup( 69 | optimizer=optimizer, 70 | num_warmup_steps=num_warmup_epochs, 71 | num_training_steps=max_epoch 72 | ) 73 | return scheduler 74 | 75 | 76 | # @register.register_scheduler('cosine_with_warmup') 77 | def cosine_with_warmup_scheduler(optimizer: Optimizer, 78 | num_warmup_epochs: int, max_epoch: int): 79 | scheduler = get_cosine_schedule_with_warmup( 80 | optimizer=optimizer, 81 | num_warmup_steps=num_warmup_epochs, 82 | num_training_steps=max_epoch 83 | ) 84 | return scheduler 85 | 86 | 87 | def get_linear_schedule_with_warmup( 88 | optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, 89 | last_epoch: int = -1): 90 | """ 91 | Implementation by Huggingface: 92 | https://github.com/huggingface/transformers/blob/v4.16.2/src/transformers/optimization.py 93 | 94 | Create a schedule with a learning rate that decreases linearly from the 95 | initial lr set in the optimizer to 0, after a warmup period during which it 96 | increases linearly from 0 to the initial lr set in the optimizer. 97 | Args: 98 | optimizer ([`~torch.optim.Optimizer`]): 99 | The optimizer for which to schedule the learning rate. 100 | num_warmup_steps (`int`): 101 | The number of steps for the warmup phase. 102 | num_training_steps (`int`): 103 | The total number of training steps. 104 | last_epoch (`int`, *optional*, defaults to -1): 105 | The index of the last epoch when resuming training. 106 | Return: 107 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 108 | """ 109 | 110 | def lr_lambda(current_step: int): 111 | if current_step < num_warmup_steps: 112 | return max(1e-6, float(current_step) / float(max(1, num_warmup_steps))) 113 | return max( 114 | 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)) 115 | ) 116 | 117 | return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch) 118 | 119 | 120 | def get_cosine_schedule_with_warmup( 121 | optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, 122 | num_cycles: float = 0.5, last_epoch: int = -1): 123 | """ 124 | Implementation by Huggingface: 125 | https://github.com/huggingface/transformers/blob/v4.16.2/src/transformers/optimization.py 126 | 127 | Create a schedule with a learning rate that decreases following the values 128 | of the cosine function between the initial lr set in the optimizer to 0, 129 | after a warmup period during which it increases linearly between 0 and the 130 | initial lr set in the optimizer. 131 | Args: 132 | optimizer ([`~torch.optim.Optimizer`]): 133 | The optimizer for which to schedule the learning rate. 134 | num_warmup_steps (`int`): 135 | The number of steps for the warmup phase. 136 | num_training_steps (`int`): 137 | The total number of training steps. 138 | num_cycles (`float`, *optional*, defaults to 0.5): 139 | The number of waves in the cosine schedule (the defaults is to just 140 | decrease from the max value to 0 following a half-cosine). 141 | last_epoch (`int`, *optional*, defaults to -1): 142 | The index of the last epoch when resuming training. 
143 | Return: 144 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 145 | """ 146 | 147 | def lr_lambda(current_step): 148 | if current_step < num_warmup_steps: 149 | return max(1e-6, float(current_step) / float(max(1, num_warmup_steps))) 150 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) 151 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) 152 | 153 | return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch) -------------------------------------------------------------------------------- /fragnet/model/gat/gat2_cv.py: -------------------------------------------------------------------------------- 1 | # from gat import FragNetFineTune 2 | import torch 3 | from dataset import load_pickle_dataset 4 | import torch.nn as nn 5 | from utils import EarlyStopping 6 | import torch 7 | from data import collate_fn 8 | from torch.utils.data import DataLoader 9 | import argparse 10 | from utils import TrainerFineTune as Trainer 11 | import numpy as np 12 | from omegaconf import OmegaConf 13 | import pickle 14 | import torch.optim.lr_scheduler as lr_scheduler 15 | from sklearn.model_selection import KFold 16 | from gat2 import FragNetFineTune 17 | import os 18 | 19 | 20 | def seed_everything(seed: int): 21 | import random, os 22 | 23 | random.seed(seed) 24 | os.environ['PYTHONHASHSEED'] = str(seed) 25 | np.random.seed(seed) 26 | torch.manual_seed(seed) 27 | torch.cuda.manual_seed(seed) 28 | torch.backends.cudnn.deterministic = True 29 | torch.backends.cudnn.benchmark = True 30 | 31 | seed_everything(26) 32 | 33 | def get_predictions(trainer, loader, model, device): 34 | mse, true, pred = trainer.test(model=model, loader=loader, device=device) 35 | smiles = [i.smiles for i in loader.dataset] 36 | res = {'smiles': smiles, 'true': true, 'pred': pred} 37 | return res 38 | 39 | def save_predictions(exp_dir, save_name, res): 40 | 41 | with open(f"{exp_dir}/{save_name}.pkl", 'wb') as f: 42 | pickle.dump(res,f ) 43 | 44 | 45 | 46 | if __name__ == "__main__": 47 | 48 | parser = argparse.ArgumentParser() 49 | parser.add_argument('--config', help="configuration file *.yml", type=str, required=False, default='config.yaml') 50 | args = parser.parse_args() 51 | 52 | if args.config: # args priority is higher than yaml 53 | opt = OmegaConf.load(args.config) 54 | OmegaConf.resolve(opt) 55 | args=opt 56 | 57 | seed_everything(args.seed) 58 | exp_dir = args['exp_dir'] 59 | os.makedirs(exp_dir, exist_ok=True) 60 | device = torch.device('cuda') if torch.cuda.is_available() else 'cpu' 61 | 62 | pt_chkpoint_name = args.pretrain.chkpoint_name #'exps/pt_rings' 63 | 64 | if args.model_version == 'gat': 65 | from gat import FragNetFineTune 66 | print('loaded from gat') 67 | elif args.model_version=='gat2': 68 | from gat2 import FragNetFineTune 69 | print('loaded from gat2') 70 | 71 | 72 | model = FragNetFineTune(n_classes=args.finetune.model.n_classes, 73 | atom_features=args.atom_features, 74 | frag_features=args.frag_features, 75 | edge_features=args.edge_features, 76 | num_layer=args.finetune.model.num_layer, 77 | drop_ratio=args.finetune.model.drop_ratio, 78 | num_heads=args.finetune.model.num_heads, 79 | emb_dim=args.finetune.model.emb_dim, 80 | h1=args.finetune.model.h1, 81 | h2=args.finetune.model.h2, 82 | h3=args.finetune.model.h3, 83 | h4=args.finetune.model.h4, 84 | act=args.finetune.model.act, 85 | fthead=args.finetune.model.fthead 86 | ) 87 | trainer = 
Trainer(target_type=args.finetune.target_type) 88 | 89 | 90 | pt_chkpoint_name = args.pretrain.chkpoint_name #'exps/pt_rings' 91 | if pt_chkpoint_name: 92 | 93 | from models import FragNetPreTrain 94 | modelpt = FragNetPreTrain(num_layer=args.finetune.model.num_layer, 95 | drop_ratio=args.finetune.model.drop_ratio, 96 | num_heads=args.finetune.model.num_heads, 97 | emb_dim=args.finetune.model.emb_dim, 98 | atom_features=args.atom_features, frag_features=args.frag_features, edge_features=args.edge_features) 99 | 100 | 101 | modelpt.load_state_dict(torch.load(pt_chkpoint_name, map_location=torch.device(device))) 102 | 103 | print('loading pretrained weights') 104 | model.pretrain.load_state_dict(modelpt.pretrain.state_dict()) 105 | print('weights loaded') 106 | else: 107 | print('no pretrained weights') 108 | 109 | 110 | train_dataset2 = load_pickle_dataset(args.finetune.train.path) 111 | val_dataset2 = load_pickle_dataset(args.finetune.val.path) 112 | 113 | kf = KFold(n_splits=5) 114 | train_val = train_dataset2 + val_dataset2 115 | 116 | for icv, (train_index, test_index) in enumerate(kf.split(train_val)): 117 | train_ds = [train_val[i] for i in train_index] 118 | val_ds = [train_val[i] for i in test_index] 119 | 120 | train_loader = DataLoader(train_ds, collate_fn=collate_fn, batch_size=args.finetune.batch_size, shuffle=True, drop_last=True) 121 | val_loader = DataLoader(val_ds, collate_fn=collate_fn, batch_size=args.finetune.batch_size, shuffle=False, drop_last=False) 122 | trainer = Trainer(target_type=args.finetune.target_type) 123 | 124 | 125 | model.to(device); 126 | ft_chk_point = os.path.join(args.exp_dir, f'cv_{icv}.pt') 127 | early_stopping = EarlyStopping(patience=args.finetune.es_patience, verbose=True, chkpoint_name=ft_chk_point) 128 | 129 | 130 | if args.finetune.loss == 'mse': 131 | loss_fn = nn.MSELoss() 132 | elif args.finetune.loss == 'cel': 133 | loss_fn = nn.CrossEntropyLoss() 134 | 135 | optimizer = torch.optim.Adam(model.parameters(), lr = args.finetune.lr ) # before 1e-4 136 | if args.finetune.use_schedular: 137 | scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.99) 138 | else: 139 | scheduler=None 140 | 141 | for epoch in range(args.finetune.n_epochs): 142 | 143 | train_loss = trainer.train(model=model, loader=train_loader, optimizer=optimizer, scheduler=scheduler, device=device) 144 | val_loss, _, _ = trainer.test(model=model, loader=val_loader, device=device) 145 | 146 | print(train_loss, val_loss) 147 | early_stopping(val_loss, model) 148 | 149 | if early_stopping.early_stop: 150 | print("Early stopping") 151 | break 152 | 153 | 154 | model.load_state_dict(torch.load(ft_chk_point)) 155 | res = get_predictions(trainer, val_loader, model, device) 156 | 157 | with open(f'{exp_dir}/cv_{icv}.pkl', 'wb') as f: 158 | pickle.dump(res, f) 159 | -------------------------------------------------------------------------------- /fragnet/model/gat/gat2_pl.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning.callbacks import LearningRateMonitor 2 | import pytorch_lightning as pl 3 | from pytorch_lightning import loggers 4 | import torch 5 | import torch.nn.functional as F 6 | from gat2 import FragNetFineTune 7 | from pytorch_lightning import LightningModule, Trainer 8 | from dataset import load_pickle_dataset 9 | from torch.utils.data import DataLoader 10 | import argparse 11 | from omegaconf import OmegaConf 12 | from torch_geometric.nn.norm import BatchNorm 13 | from data import collate_fn 14 | import math 15 
| from pytorch_lightning.callbacks.early_stopping import EarlyStopping 16 | from pytorch_lightning.callbacks import ModelCheckpoint 17 | 18 | def get_cosine_schedule_with_warmup( 19 | optimizer, num_warmup_steps, num_training_steps, 20 | num_cycles = 0.5, last_epoch = -1): 21 | """ 22 | Implementation by Huggingface: 23 | https://github.com/huggingface/transformers/blob/v4.16.2/src/transformers/optimization.py 24 | 25 | Create a schedule with a learning rate that decreases following the values 26 | of the cosine function between the initial lr set in the optimizer to 0, 27 | after a warmup period during which it increases linearly between 0 and the 28 | initial lr set in the optimizer. 29 | Args: 30 | optimizer ([`~torch.optim.Optimizer`]): 31 | The optimizer for which to schedule the learning rate. 32 | num_warmup_steps (`int`): 33 | The number of steps for the warmup phase. 34 | num_training_steps (`int`): 35 | The total number of training steps. 36 | num_cycles (`float`, *optional*, defaults to 0.5): 37 | The number of waves in the cosine schedule (the defaults is to just 38 | decrease from the max value to 0 following a half-cosine). 39 | last_epoch (`int`, *optional*, defaults to -1): 40 | The index of the last epoch when resuming training. 41 | Return: 42 | `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 43 | """ 44 | 45 | def lr_lambda(current_step): 46 | if current_step < num_warmup_steps: 47 | return max(1e-6, float(current_step) / float(max(1, num_warmup_steps))) 48 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) 49 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) 50 | 51 | return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch) 52 | 53 | class FragNetPL(pl.LightningModule): 54 | def __init__(self, args): 55 | super().__init__() 56 | self.save_hyperparameters() 57 | 58 | self.model = FragNetFineTune(n_classes=args.finetune.model.n_classes, 59 | atom_features=args.atom_features, 60 | frag_features=args.frag_features, 61 | edge_features=args.edge_features, 62 | num_layer=args.finetune.model.num_layer, 63 | drop_ratio=args.finetune.model.drop_ratio, 64 | num_heads=args.finetune.model.num_heads, 65 | emb_dim=args.finetune.model.emb_dim, 66 | h1=args.finetune.model.h1, 67 | h2=args.finetune.model.h2, 68 | h3=args.finetune.model.h3, 69 | h4=args.finetune.model.h4, 70 | act=args.finetune.model.act, 71 | fthead=args.finetune.model.fthead 72 | ) 73 | 74 | self.args = args 75 | 76 | def forward(self, batch): 77 | return self.model(batch) 78 | 79 | def common_step(self, batch): 80 | 81 | y = self(batch) 82 | y_pred = batch['y'] 83 | return y.reshape_as(y_pred), y_pred 84 | 85 | 86 | def training_step(self, batch, batch_idx): 87 | y, y_pred = self.common_step(batch) 88 | 89 | loss = F.mse_loss(y_pred, y) 90 | self.log('train_loss', loss, batch_size=args.finetune.batch_size) 91 | return loss 92 | 93 | def validation_step(self, batch, batch_idx): 94 | y, y_pred = self.common_step(batch) 95 | val_loss = F.mse_loss(y_pred, y) 96 | 97 | self.log('val_loss', val_loss, batch_size=args.finetune.batch_size) 98 | 99 | def test_step(self, batch, batch_idx): 100 | y, y_pred = self.common_step(batch) 101 | 102 | test_loss = F.mse_loss(y_pred, y) 103 | self.log('test_mse', test_loss**.5, batch_size=args.finetune.batch_size) 104 | 105 | 106 | def configure_optimizers(self): 107 | optimizer = torch.optim.Adam(self.parameters(), self.args.finetune.lr) 108 | return 
[optimizer] 109 | 110 | 111 | 112 | if __name__=="__main__": 113 | 114 | parser = argparse.ArgumentParser() 115 | parser.add_argument('--config', help="configuration file *.yml", type=str, required=False, default='config_exp1.yaml') 116 | args = parser.parse_args() 117 | 118 | 119 | if args.config: # args priority is higher than yaml 120 | opt = OmegaConf.load(args.config) 121 | OmegaConf.resolve(opt) 122 | args=opt 123 | 124 | 125 | pl.seed_everything(args.seed) 126 | model = FragNetPL(args) 127 | 128 | 129 | early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=0.00, patience=100, verbose=False, mode="min") 130 | checkpoint = ModelCheckpoint(monitor='val_loss', save_top_k=1, mode='min') 131 | 132 | trainer = Trainer(accelerator='auto', max_epochs=args.finetune.n_epochs, gradient_clip_val=1, devices=-1, 133 | log_every_n_steps=1, callbacks=[early_stop_callback, checkpoint]) 134 | 135 | 136 | train_dataset2 = load_pickle_dataset(args.finetune.train.path) 137 | val_dataset2 = load_pickle_dataset(args.finetune.val.path) 138 | test_dataset2 = load_pickle_dataset(args.finetune.test.path) #'finetune_data/pnnl_exp' 139 | 140 | train_loader = DataLoader(train_dataset2, collate_fn=collate_fn, batch_size=args.finetune.batch_size, shuffle=True, drop_last=True) 141 | val_loader = DataLoader(val_dataset2, collate_fn=collate_fn, batch_size=64, shuffle=False, drop_last=False) 142 | test_loader = DataLoader(test_dataset2, collate_fn=collate_fn, batch_size=64, shuffle=False, drop_last=False) 143 | 144 | trainer.fit(model, train_loader, val_loader) 145 | trainer.test(model, test_loader) -------------------------------------------------------------------------------- /fragnet/model/gat/gat2_pretrain.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from .pretrain_heads import PretrainTask 4 | from .gat2 import FragNet 5 | import random 6 | 7 | class FragNetPreTrain(nn.Module): 8 | 9 | def __init__(self, num_layer=4, drop_ratio=0.15, num_heads=4, emb_dim=128, 10 | atom_features=167, frag_features=167, edge_features=16, 11 | fedge_in=6, 12 | fbond_edge_in=6): 13 | super(FragNetPreTrain, self).__init__() 14 | 15 | self.pretrain = FragNet(num_layer=num_layer, drop_ratio=drop_ratio, num_heads=num_heads, emb_dim=emb_dim, 16 | atom_features=atom_features, frag_features=frag_features, edge_features=edge_features, 17 | fedge_in=fedge_in, 18 | fbond_edge_in=fbond_edge_in,) 19 | self.head = PretrainTask(128, 1) 20 | 21 | 22 | def forward(self, batch): 23 | 24 | x_atoms, x_frags, e_edge, e_fedge = self.pretrain(batch) 25 | bond_length_pred, bond_angle_pred, dihedral_angle_pred, graph_rep = self.head(x_atoms, x_frags, e_edge, batch) 26 | 27 | return bond_length_pred, bond_angle_pred, dihedral_angle_pred, graph_rep 28 | 29 | 30 | class FragNetPreTrainMasked(nn.Module): 31 | 32 | def __init__(self, num_layer=4, drop_ratio=0.15, num_heads=4, emb_dim=128, 33 | atom_features=167, frag_features=167, edge_features=16, 34 | fedge_in=6, fbond_edge_in=6): 35 | super(FragNetPreTrainMasked, self).__init__() 36 | 37 | self.pretrain = FragNet(num_layer=num_layer, drop_ratio=drop_ratio, num_heads=num_heads, emb_dim=emb_dim, 38 | atom_features=atom_features, frag_features=frag_features, edge_features=edge_features, 39 | fedge_in=fedge_in, fbond_edge_in=fbond_edge_in) 40 | self.head = PretrainTask(128, 1) 41 | 42 | 43 | def forward(self, batch): 44 | 45 | x_atoms, x_frags, e_edge, e_fedge = self.pretrain(batch) 46 | 47 | with torch.no_grad(): 
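            # Build the atom mask without tracking gradients: ~85% of the atom
            # embeddings returned by the FragNet encoder are kept and the rest
            # are zeroed, so the pretraining head below must predict bond
            # lengths, bond/dihedral angles and the graph-level target from
            # partially masked atom representations.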
48 | n_atoms = x_atoms.shape[0] 49 | unmask_atoms = random.sample(list(range(n_atoms)), int(n_atoms*.85) ) 50 | x_atoms_masked = torch.zeros(x_atoms.size()) + 0.0 51 | x_atoms_masked[unmask_atoms] = x_atoms[unmask_atoms] 52 | 53 | 54 | bond_length_pred, bond_angle_pred, dihedral_angle_pred, graph_rep = self.head(x_atoms_masked, x_frags, e_edge, batch) 55 | 56 | return bond_length_pred, bond_angle_pred, dihedral_angle_pred, graph_rep -------------------------------------------------------------------------------- /fragnet/model/gat/pretrain_heads.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import random 4 | from torch_scatter import scatter_add 5 | from fragnet.model.gat.gat2 import FragNet 6 | 7 | 8 | class PretrainTask(nn.Module): 9 | """ 10 | This function was copied from 11 | https://github.com/LARS-research/3D-PGT/blob/main/graphgps/head/pretrain_task.py 12 | and modified 13 | 14 | SAN prediction head for graph prediction tasks. 15 | 16 | Args: 17 | dim_in (int): Input dimension. 18 | dim_out (int): Output dimension. For binary prediction, dim_out=1. 19 | L (int): Number of hidden layers. 20 | """ 21 | 22 | def __init__(self, dim_in=128, dim_out=1, L=2): 23 | super().__init__() 24 | 25 | # bond_length 26 | self.bl_reduce_layer = nn.Linear(dim_in * 3, dim_in) 27 | list_bl_layers = [ 28 | nn.Linear(dim_in // 2**l, dim_in // 2 ** (l + 1), bias=True) 29 | for l in range(L) 30 | ] 31 | list_bl_layers.append(nn.Linear(dim_in // 2**L, dim_out, bias=True)) 32 | self.bl_layers = nn.ModuleList(list_bl_layers) 33 | 34 | # bond_angle 35 | list_ba_layers = [ 36 | nn.Linear(dim_in // 2**l, dim_in // 2 ** (l + 1), bias=True) 37 | for l in range(L) 38 | ] 39 | list_ba_layers.append(nn.Linear(dim_in // 2**L, dim_out, bias=True)) 40 | self.ba_layers = nn.ModuleList(list_ba_layers) 41 | 42 | # dihedral_angle 43 | list_da_layers = [ 44 | nn.Linear(dim_in // 2**l, dim_in // 2 ** (l + 1), bias=True) 45 | for l in range(L) 46 | ] 47 | list_da_layers.append(nn.Linear(dim_in // 2**L, dim_out, bias=True)) 48 | self.da_layers = nn.ModuleList(list_da_layers) 49 | 50 | # graph-level prediction (energy) 51 | list_FC_layers = [ 52 | nn.Linear(dim_in * 2 // 2**l, dim_in * 2 // 2 ** (l + 1), bias=True) 53 | for l in range(L) 54 | ] 55 | list_FC_layers.append(nn.Linear(dim_in * 2 // 2**L, dim_out, bias=True)) 56 | self.FC_layers = nn.ModuleList(list_FC_layers) 57 | 58 | self.L = L 59 | self.activation = nn.ReLU() 60 | 61 | def _apply_index(self, batch): 62 | return batch.bond_length, batch.distance 63 | 64 | def forward(self, x_atoms, x_frags, edge_attr, batch): 65 | edge_index = batch["edge_index"] 66 | 67 | bond_length_pred = torch.concat( 68 | (x_atoms[edge_index.T][:, 0, :], x_atoms[edge_index.T][:, 1, :], edge_attr), 69 | axis=1, 70 | ) 71 | bond_length_pred = self.bl_reduce_layer(bond_length_pred) 72 | for l in range(self.L + 1): 73 | bond_length_pred = self.activation(bond_length_pred) 74 | bond_length_pred = self.bl_layers[l](bond_length_pred) 75 | 76 | # bond_angle 77 | bond_angle_pred = x_atoms 78 | for l in range(self.L): 79 | bond_angle_pred = self.ba_layers[l](bond_angle_pred) 80 | bond_angle_pred = self.activation(bond_angle_pred) 81 | bond_angle_pred = self.ba_layers[self.L](bond_angle_pred) 82 | 83 | # dihedral_angle 84 | dihedral_angle_pred = edge_attr 85 | for l in range(self.L): 86 | dihedral_angle_pred = self.da_layers[l](dihedral_angle_pred) 87 | dihedral_angle_pred = self.activation(dihedral_angle_pred) 88 | 
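        # Final projection of the dihedral-angle head; like the bond-length and
        # bond-angle heads above, the hidden width is halved at every layer
        # (dim_in -> dim_in/2 -> ... ) down to the dim_out-sized prediction.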
dihedral_angle_pred = self.da_layers[self.L](dihedral_angle_pred) 89 | 90 | # total energy 91 | # graph_rep = self.pooling_fun(batch.x, batch.batch) 92 | 93 | x_frags_pooled = scatter_add(src=x_frags, index=batch["frag_batch"], dim=0) 94 | x_atoms_pooled = scatter_add(src=x_atoms, index=batch["batch"], dim=0) 95 | 96 | graph_rep = torch.cat((x_atoms_pooled, x_frags_pooled), 1) 97 | for l in range(self.L): 98 | graph_rep = self.FC_layers[l](graph_rep) 99 | graph_rep = self.activation(graph_rep) 100 | graph_rep = self.FC_layers[self.L](graph_rep) 101 | 102 | return bond_length_pred, bond_angle_pred, dihedral_angle_pred, graph_rep 103 | 104 | 105 | class FragNetPreTrain(nn.Module): 106 | 107 | def __init__( 108 | self, 109 | num_layer=4, 110 | drop_ratio=0.15, 111 | num_heads=4, 112 | emb_dim=128, 113 | atom_features=167, 114 | frag_features=167, 115 | edge_features=16, 116 | fedge_in=6, 117 | fbond_edge_in=6, 118 | ): 119 | super(FragNetPreTrain, self).__init__() 120 | 121 | self.pretrain = FragNet( 122 | num_layer=num_layer, 123 | drop_ratio=drop_ratio, 124 | num_heads=num_heads, 125 | emb_dim=emb_dim, 126 | atom_features=atom_features, 127 | frag_features=frag_features, 128 | edge_features=edge_features, 129 | fedge_in=fedge_in, 130 | fbond_edge_in=fbond_edge_in, 131 | ) 132 | self.head = PretrainTask(128, 1) 133 | 134 | def forward(self, batch): 135 | 136 | x_atoms, x_frags, e_edge, e_fedge = self.pretrain(batch) 137 | bond_length_pred, bond_angle_pred, dihedral_angle_pred, graph_rep = self.head( 138 | x_atoms, x_frags, e_edge, batch 139 | ) 140 | 141 | return bond_length_pred, bond_angle_pred, dihedral_angle_pred, graph_rep 142 | 143 | 144 | class FragNetPreTrainMasked(nn.Module): 145 | 146 | def __init__( 147 | self, 148 | num_layer=4, 149 | drop_ratio=0.15, 150 | num_heads=4, 151 | emb_dim=128, 152 | atom_features=167, 153 | frag_features=167, 154 | edge_features=16, 155 | fedge_in=6, 156 | fbond_edge_in=6, 157 | ): 158 | super(FragNetPreTrainMasked, self).__init__() 159 | 160 | # self.pretrain = FragNet(num_layer=num_layer, drop_ratio=drop_ratio) 161 | self.pretrain = FragNet( 162 | num_layer=num_layer, 163 | drop_ratio=drop_ratio, 164 | num_heads=num_heads, 165 | emb_dim=emb_dim, 166 | atom_features=atom_features, 167 | frag_features=frag_features, 168 | edge_features=edge_features, 169 | fedge_in=fedge_in, 170 | fbond_edge_in=fbond_edge_in, 171 | ) 172 | self.head = PretrainTask(128, 1) 173 | 174 | def forward(self, batch): 175 | 176 | x_atoms, x_frags, e_edge, e_fedge = self.pretrain(batch) 177 | 178 | with torch.no_grad(): 179 | n_atoms = x_atoms.shape[0] 180 | 181 | unmask_atoms = random.sample(list(range(n_atoms)), int(n_atoms * 0.85)) 182 | x_atoms_masked = torch.zeros(x_atoms.size()) + 0.0 183 | x_atoms_masked = x_atoms_masked.to(x_atoms.device) 184 | x_atoms_masked[unmask_atoms] = x_atoms[unmask_atoms] 185 | 186 | 187 | class FragNetPreTrainMasked2(nn.Module): 188 | 189 | def __init__( 190 | self, 191 | num_layer=4, 192 | drop_ratio=0.15, 193 | num_heads=4, 194 | emb_dim=128, 195 | atom_features=167, 196 | frag_features=167, 197 | edge_features=16, 198 | fedge_in=6, 199 | fbond_edge_in=6, 200 | ): 201 | super(FragNetPreTrainMasked2, self).__init__() 202 | 203 | # self.pretrain = FragNet(num_layer=num_layer, drop_ratio=drop_ratio) 204 | self.pretrain = FragNet( 205 | num_layer=num_layer, 206 | drop_ratio=drop_ratio, 207 | num_heads=num_heads, 208 | emb_dim=emb_dim, 209 | atom_features=atom_features, 210 | frag_features=frag_features, 211 | edge_features=edge_features, 212 | 
fedge_in=fedge_in, 213 | fbond_edge_in=fbond_edge_in, 214 | ) 215 | self.head = PretrainTask(128, 1) 216 | 217 | def forward(self, batch): 218 | 219 | with torch.no_grad(): 220 | x_atoms = batch["x_atoms"] 221 | n_atoms = x_atoms.shape[0] 222 | 223 | unmask_atoms = random.sample(list(range(n_atoms)), int(n_atoms * 0.85)) 224 | x_atoms_masked = torch.zeros(x_atoms.size()) + 0.0 225 | x_atoms_masked = x_atoms_masked.to(x_atoms.device) 226 | x_atoms_masked[unmask_atoms] = x_atoms[unmask_atoms] 227 | 228 | batch["x_atoms"] = x_atoms_masked 229 | 230 | x_atoms, x_frags, e_edge, e_fedge = self.pretrain(batch) 231 | 232 | bond_length_pred, bond_angle_pred, dihedral_angle_pred, graph_rep = self.head( 233 | x_atoms, x_frags, e_edge, batch 234 | ) 235 | 236 | return bond_length_pred, bond_angle_pred, dihedral_angle_pred, graph_rep 237 | -------------------------------------------------------------------------------- /fragnet/model/gcn/gcn.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Parameter 2 | from torch_geometric.utils import add_self_loops, degree 3 | import torch 4 | import torch.nn as nn 5 | from torch_scatter import scatter_add, scatter_softmax 6 | from torch_geometric.utils import add_self_loops, degree 7 | from torch_geometric.utils import add_self_loops 8 | from torch_scatter import scatter_add 9 | 10 | 11 | class FragNetLayer(nn.Module): 12 | def __init__(self, atom_in=128, atom_out=128, frag_in=128, frag_out=128, 13 | edge_in=128, edge_out=128): 14 | super(FragNetLayer, self).__init__() 15 | 16 | 17 | self.atom_embed = nn.Linear(atom_in, atom_out, bias=True) 18 | self.frag_embed = nn.Linear(frag_in, frag_out) 19 | self.edge_embed = nn.Linear(edge_in, edge_out) 20 | 21 | self.frag_message_mlp = nn.Linear(atom_out*2, atom_out) 22 | self.atom_mlp = torch.nn.Sequential(torch.nn.Linear(atom_out, 2*atom_out), 23 | torch.nn.ReLU(), 24 | torch.nn.Linear(2*atom_out, atom_out)) 25 | 26 | self.frag_mlp = torch.nn.Sequential(torch.nn.Linear(atom_out, 2*atom_out), 27 | torch.nn.ReLU(), 28 | torch.nn.Linear(2*atom_out, atom_out)) 29 | 30 | def forward(self, x_atoms, 31 | edge_index, 32 | edge_attr, 33 | frag_index, 34 | x_frags, 35 | atom_to_frag_ids): 36 | 37 | 38 | edge_index, _ = add_self_loops(edge_index=edge_index) 39 | 40 | self_loop_attr = torch.zeros(x_atoms.size(0), 12, dtype=torch.long) 41 | self_loop_attr[:,0] = 0 # bond type for self-loop edge 42 | edge_attr = torch.cat((edge_attr, self_loop_attr.to(edge_attr)), dim=0) 43 | 44 | 45 | x_atoms = self.atom_embed(x_atoms) 46 | 47 | edge_attr = self.edge_embed(edge_attr) 48 | 49 | source, target = edge_index 50 | 51 | source_features = torch.index_select(input=x_atoms, index=source, dim=0) 52 | 53 | 54 | deg = degree(source, x_atoms.size(0), dtype=x_atoms.dtype) 55 | deg_inv_sqrt = deg.pow(-0.5) 56 | deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 57 | norm = deg_inv_sqrt[source] * deg_inv_sqrt[target] 58 | 59 | message = source_features*norm.view(-1,1) 60 | 61 | x_atoms_new = scatter_add(src=message, index=target, dim=0) 62 | 63 | x_frags = scatter_add(src=x_atoms_new, index=atom_to_frag_ids, dim=0) 64 | 65 | source, target = frag_index 66 | source_features = torch.index_select(input=x_frags, index=source, dim=0) 67 | frag_message = source_features 68 | frag_feats_sum = scatter_add(src=frag_message, index=target, dim=0) 69 | frag_feats_sum = self.frag_mlp(frag_feats_sum) 70 | 71 | x_frags_new = frag_feats_sum 72 | return x_atoms_new, x_frags_new 73 | 74 | 75 | 76 | 77 | class 
FragNet(nn.Module): 78 | 79 | def __init__(self, num_layer, drop_ratio = 0, emb_dim=128, 80 | atom_features=45, frag_features=45, edge_features=12): 81 | super(FragNet, self).__init__() 82 | self.num_layer = num_layer 83 | self.dropout = nn.Dropout(p=drop_ratio) 84 | self.act = nn.ReLU() 85 | 86 | self.layer1 = FragNetLayer(atom_in=atom_features, atom_out=emb_dim, frag_in=frag_features, 87 | frag_out=emb_dim, edge_in=edge_features, edge_out=emb_dim) 88 | self.layer2 = FragNetLayer(atom_in=emb_dim, atom_out=emb_dim, frag_in=emb_dim, 89 | frag_out=emb_dim, edge_in=edge_features, edge_out=emb_dim) 90 | self.layer3 = FragNetLayer(atom_in=emb_dim, atom_out=emb_dim, frag_in=emb_dim, 91 | frag_out=emb_dim, edge_in=edge_features, edge_out=emb_dim) 92 | self.layer4 = FragNetLayer(atom_in=emb_dim, atom_out=emb_dim, frag_in=emb_dim, 93 | frag_out=emb_dim, edge_in=edge_features, edge_out=emb_dim) 94 | 95 | ###List of batchnorms 96 | self.batch_norms = torch.nn.ModuleList() 97 | for layer in range(num_layer): 98 | self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim)) 99 | 100 | def forward(self, batch): 101 | 102 | 103 | x_atoms = batch['x_atoms'] 104 | edge_index = batch['edge_index'] 105 | frag_index = batch['frag_index'] 106 | 107 | x_frags = batch['x_frags'] 108 | edge_attr = batch['edge_attr'] 109 | atom_batch = batch['batch'] 110 | frag_batch = batch['frag_batch'] 111 | atom_to_frag_ids = batch['atom_to_frag_ids'] 112 | 113 | node_feautures_bond_graph=batch['node_features_bonds'] 114 | edge_index_bonds_graph=batch['edge_index_bonds_graph'] 115 | edge_attr_bond_graph=batch['edge_attr_bonds'] 116 | 117 | 118 | x_atoms = self.dropout(x_atoms) 119 | x_frags = self.dropout(x_frags) 120 | 121 | x_atoms, x_frags = self.layer1(x_atoms, edge_index, edge_attr, 122 | frag_index, x_frags, atom_to_frag_ids) 123 | x_atoms, x_frags = self.act(x_atoms), self.act(x_frags) 124 | 125 | x_atoms, x_frags = self.layer2(x_atoms, edge_index, edge_attr, 126 | frag_index, x_frags, atom_to_frag_ids) 127 | x_atoms, x_frags = self.act(x_atoms), self.act(x_frags) 128 | 129 | x_atoms, x_frags = self.layer3(x_atoms, edge_index, edge_attr, 130 | frag_index, x_frags, atom_to_frag_ids) 131 | x_atoms, x_frags = self.act(x_atoms), self.act(x_frags) 132 | 133 | x_atoms, x_frags = self.layer4(x_atoms, edge_index, edge_attr, 134 | frag_index, x_frags, atom_to_frag_ids) 135 | x_atoms, x_frags = self.act(x_atoms), self.act(x_frags) 136 | 137 | 138 | return x_atoms, x_frags 139 | 140 | 141 | class FragNetPreTrain(nn.Module): 142 | 143 | def __init__(self): 144 | super(FragNetPreTrain, self).__init__() 145 | 146 | self.pretrain = FragNet(num_layer=6, drop_ratio=0.15) 147 | self.lin1 = nn.Linear(128, 128) 148 | self.out = nn.Linear(128, 13) 149 | self.dropout = nn.Dropout(p=0.15) 150 | self.activation = nn.ReLU() 151 | 152 | def forward(self, batch): 153 | 154 | x_atoms, x_frags = self.pretrain(batch) 155 | 156 | x = self.dropout(x_atoms) 157 | x = self.lin1(x) 158 | x = self.activation(x) 159 | x = self.dropout(x) 160 | x = self.out(x) 161 | 162 | 163 | return x 164 | 165 | class FragNetFineTune(nn.Module): 166 | 167 | def __init__(self, n_classes=1): 168 | super(FragNetFineTune, self).__init__() 169 | 170 | self.pretrain = FragNet(num_layer=6, drop_ratio=0.15) 171 | self.lin1 = nn.Linear(128*2, 128*2) 172 | 173 | self.dropout = nn.Dropout(p=0.15) 174 | self.activation = nn.ReLU() 175 | self.out = nn.Linear(128*2, n_classes) 176 | 177 | 178 | def forward(self, batch): 179 | 180 | x_atoms, x_frags = self.pretrain(batch) 181 | 182 | 
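        # Graph-level readout: atom and fragment embeddings are summed per
        # molecule with scatter_add using the atom-to-molecule ('batch') and
        # fragment-to-molecule ('frag_batch') index vectors, and the two pooled
        # vectors are concatenated before the prediction layers.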
x_frags_pooled = scatter_add(src=x_frags, index=batch['frag_batch'], dim=0) 183 | x_atoms_pooled = scatter_add(src=x_atoms, index=batch['batch'], dim=0) 184 | 185 | cat = torch.cat((x_atoms_pooled, x_frags_pooled), 1) 186 | x = self.dropout(cat) 187 | x = self.lin1(x) 188 | x = self.activation(x) 189 | x = self.dropout(x) 190 | x = self.out(x) 191 | 192 | 193 | return x -------------------------------------------------------------------------------- /fragnet/model/gcn/gcn2.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Parameter 2 | from torch_geometric.utils import add_self_loops, degree 3 | import torch 4 | import torch.nn as nn 5 | from torch_scatter import scatter_add, scatter_softmax 6 | from torch_geometric.utils import add_self_loops, degree 7 | from torch_geometric.utils import add_self_loops 8 | from torch_scatter import scatter_add 9 | from gat2 import FTHead3, FTHead4 10 | 11 | class FragNetLayer(nn.Module): 12 | def __init__(self, atom_in=128, atom_out=128, frag_in=128, frag_out=128, 13 | edge_in=128, edge_out=128): 14 | super(FragNetLayer, self).__init__() 15 | 16 | 17 | self.atom_embed = nn.Linear(atom_in, atom_out, bias=True) 18 | self.frag_embed = nn.Linear(frag_in, frag_out) 19 | self.edge_embed = nn.Linear(edge_in, edge_out) 20 | 21 | self.frag_message_mlp = nn.Linear(atom_out*2, atom_out) 22 | self.atom_mlp = torch.nn.Sequential(torch.nn.Linear(atom_out, 2*atom_out), 23 | torch.nn.ReLU(), 24 | torch.nn.Linear(2*atom_out, atom_out)) 25 | 26 | self.frag_mlp = torch.nn.Sequential(torch.nn.Linear(atom_out, 2*atom_out), 27 | torch.nn.ReLU(), 28 | torch.nn.Linear(2*atom_out, atom_out)) 29 | 30 | 31 | def forward(self, x_atoms, 32 | edge_index, 33 | edge_attr, 34 | frag_index, 35 | x_frags, 36 | atom_to_frag_ids): 37 | 38 | 39 | edge_index, _ = add_self_loops(edge_index=edge_index) 40 | 41 | self_loop_attr = torch.zeros(x_atoms.size(0), edge_attr.shape[1], dtype=torch.long) 42 | edge_attr = torch.cat((edge_attr, self_loop_attr.to(edge_attr)), dim=0) 43 | 44 | 45 | x_atoms = self.atom_embed(x_atoms) 46 | edge_attr = self.edge_embed(edge_attr) 47 | 48 | source, target = edge_index 49 | source_features = torch.index_select(input=x_atoms, index=source, dim=0) 50 | 51 | deg = degree(source, x_atoms.size(0), dtype=x_atoms.dtype) 52 | deg_inv_sqrt = deg.pow(-0.5) 53 | deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 54 | norm = deg_inv_sqrt[source] * deg_inv_sqrt[target] 55 | 56 | message = source_features*norm.view(-1,1) 57 | 58 | x_atoms_new = scatter_add(src=message, index=target, dim=0) 59 | x_frags = scatter_add(src=x_atoms_new, index=atom_to_frag_ids, dim=0) 60 | 61 | source, target = frag_index 62 | source_features = torch.index_select(input=x_frags, index=source, dim=0) 63 | 64 | frag_message = source_features 65 | frag_feats_sum = scatter_add(src=frag_message, index=target, dim=0) 66 | frag_feats_sum = self.frag_mlp(frag_feats_sum) 67 | 68 | x_frags_new = frag_feats_sum 69 | 70 | 71 | return x_atoms_new, x_frags_new 72 | 73 | 74 | 75 | 76 | class FragNet(nn.Module): 77 | 78 | def __init__(self, num_layer, drop_ratio = 0, emb_dim=128, 79 | atom_features=45, frag_features=45, edge_features=12): 80 | super(FragNet, self).__init__() 81 | self.num_layer = num_layer 82 | self.dropout = nn.Dropout(p=drop_ratio) 83 | self.act = nn.ReLU() 84 | 85 | 86 | self.layers = torch.nn.ModuleList() 87 | self.layers.append(FragNetLayer(atom_in=atom_features, atom_out=emb_dim, frag_in=frag_features, 88 | frag_out=emb_dim, 
edge_in=edge_features, edge_out=emb_dim)) 89 | 90 | for i in range(num_layer-1): 91 | self.layers.append(FragNetLayer(atom_in=emb_dim, atom_out=emb_dim, frag_in=emb_dim, 92 | frag_out=emb_dim, edge_in=edge_features, edge_out=emb_dim)) 93 | 94 | ###List of batchnorms 95 | self.batch_norms = torch.nn.ModuleList() 96 | for layer in range(num_layer): 97 | self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim)) 98 | 99 | def forward(self, batch): 100 | 101 | 102 | x_atoms = batch['x_atoms'] 103 | edge_index = batch['edge_index'] 104 | frag_index = batch['frag_index'] 105 | 106 | x_frags = batch['x_frags'] 107 | edge_attr = batch['edge_attr'] 108 | atom_batch = batch['batch'] 109 | frag_batch = batch['frag_batch'] 110 | atom_to_frag_ids = batch['atom_to_frag_ids'] 111 | 112 | node_feautures_bond_graph=batch['node_features_bonds'] 113 | edge_index_bonds_graph=batch['edge_index_bonds_graph'] 114 | edge_attr_bond_graph=batch['edge_attr_bonds'] 115 | 116 | 117 | x_atoms = self.dropout(x_atoms) 118 | x_frags = self.dropout(x_frags) 119 | 120 | 121 | x_atoms, x_frags = self.layers[0](x_atoms, edge_index, edge_attr, 122 | frag_index, x_frags, atom_to_frag_ids) 123 | x_atoms, x_frags = self.act(self.dropout(x_atoms)), self.act(self.dropout(x_frags)) 124 | 125 | for layer in self.layers[1:]: 126 | x_atoms, x_frags = layer(x_atoms, edge_index, edge_attr, 127 | frag_index, x_frags, atom_to_frag_ids) 128 | x_atoms, x_frags = self.act(self.dropout(x_atoms)), self.act(self.dropout(x_frags)) 129 | 130 | return x_atoms, x_frags 131 | 132 | 133 | class FragNetPreTrain(nn.Module): 134 | 135 | def __init__(self): 136 | super(FragNetPreTrain, self).__init__() 137 | 138 | self.pretrain = FragNet(num_layer=6, drop_ratio=0.15) 139 | self.lin1 = nn.Linear(128, 128) 140 | self.out = nn.Linear(128, 13) 141 | self.dropout = nn.Dropout(p=0.15) 142 | self.activation = nn.ReLU() 143 | 144 | def forward(self, batch): 145 | 146 | x_atoms, x_frags = self.pretrain(batch) 147 | 148 | x = self.dropout(x_atoms) 149 | x = self.lin1(x) 150 | x = self.activation(x) 151 | x = self.dropout(x) 152 | x = self.out(x) 153 | 154 | 155 | return x 156 | 157 | 158 | 159 | class FragNetFineTune(nn.Module): 160 | 161 | def __init__(self, n_classes=1, atom_features=167, frag_features=167, edge_features=16, 162 | num_layer=4, drop_ratio=.15, emb_dim=128, 163 | h1=256, h2=256, h3=256, h4=256, act='celu',fthead='FTHead3'): 164 | super().__init__() 165 | 166 | self.pretrain = FragNet(num_layer=num_layer, drop_ratio=drop_ratio, emb_dim=emb_dim, 167 | atom_features=atom_features, frag_features=frag_features, 168 | edge_features=edge_features) 169 | self.lin1 = nn.Linear(emb_dim*2, emb_dim*2) 170 | 171 | self.dropout = nn.Dropout(p=0.15) 172 | self.activation = nn.ReLU() 173 | if fthead == 'FTHead3': 174 | self.fthead = FTHead3(n_classes=n_classes, 175 | h1=h1, h2=h2, h3=h3, h4=h4, 176 | drop_ratio=drop_ratio, act=act) 177 | elif fthead == 'FTHead4': 178 | print('using FTHead4' ) 179 | self.fthead = FTHead4(n_classes=n_classes, 180 | h1=h1, drop_ratio=drop_ratio, act=act) 181 | 182 | 183 | def forward(self, batch): 184 | 185 | x_atoms, x_frags = self.pretrain(batch) 186 | 187 | x_frags_pooled = scatter_add(src=x_frags, index=batch['frag_batch'], dim=0) 188 | x_atoms_pooled = scatter_add(src=x_atoms, index=batch['batch'], dim=0) 189 | 190 | cat = torch.cat((x_atoms_pooled, x_frags_pooled), 1) 191 | x = self.fthead(cat) 192 | 193 | 194 | return x 195 | -------------------------------------------------------------------------------- 
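# A small, self-contained illustration of the message-passing normalisation
# used by FragNetLayer in gcn.py/gcn2.py above: messages from source to target
# atoms are scaled by deg(src)^-1/2 * deg(dst)^-1/2 (the symmetric GCN
# normalisation) and summed onto the target nodes with scatter_add. The toy
# 3-node path graph and feature values below are made up for illustration only.
import torch
from torch_geometric.utils import degree
from torch_scatter import scatter_add

x = torch.tensor([[1.0], [2.0], [3.0]])          # 3 nodes, 1 feature each
edge_index = torch.tensor([[0, 1, 1, 2],         # directed edges 0->1, 1->0, 1->2, 2->1
                           [1, 0, 2, 1]])
source, target = edge_index
source_features = torch.index_select(input=x, index=source, dim=0)

deg = degree(source, x.size(0), dtype=x.dtype)
deg_inv_sqrt = deg.pow(-0.5)
deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
norm = deg_inv_sqrt[source] * deg_inv_sqrt[target]

message = source_features * norm.view(-1, 1)
out = scatter_add(src=message, index=target, dim=0)   # normalised neighbourhood sums
print(out)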
/fragnet/model/gcn/gcn_pl.py: -------------------------------------------------------------------------------- 1 | from gcn import FragNetPreTrain 2 | from dataset import load_data_parts 3 | from data import mask_atom_features 4 | import torch.nn as nn 5 | from utils import EarlyStopping 6 | import torch 7 | from data import collate_fn 8 | from torch.utils.data import DataLoader 9 | from features import atom_list_one_hot 10 | from sklearn.model_selection import train_test_split 11 | import os 12 | import pickle 13 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint 14 | import pytorch_lightning as pl 15 | from pytorch_lightning import loggers 16 | import torch.optim as optim 17 | 18 | 19 | def load_ids(fn): 20 | 21 | if not os.path.exists('gcn_output/train_ids.pkl'): 22 | 23 | train_ids, test_ids = train_test_split(fn, test_size=.2) 24 | test_ids, val_ids = train_test_split(test_ids, test_size=.5) 25 | 26 | with open('gcn_output/train_ids.pkl', 'wb') as f: 27 | pickle.dump(train_ids, f) 28 | with open('gcn_output/val_ids.pkl', 'wb') as f: 29 | pickle.dump(val_ids, f) 30 | with open('gcn_output/test_ids.pkl', 'wb') as f: 31 | pickle.dump(test_ids, f) 32 | 33 | else: 34 | with open('gcn_output/train_ids.pkl', 'rb') as f: 35 | train_ids = pickle.load(f) 36 | with open('gcn_output/val_ids.pkl', 'rb') as f: 37 | val_ids = pickle.load(f) 38 | with open('gcn_output/test_ids.pkl', 'rb') as f: 39 | test_ids = pickle.load(f) 40 | 41 | return train_ids, val_ids, test_ids 42 | 43 | 44 | class FragNetModule(pl.LightningModule): 45 | 46 | def __init__(self, model): 47 | """ 48 | Inputs:` 49 | model_name - Name of the model/CNN to run. Used for creating the model (see function below) 50 | model_hparams - Hyperparameters for the model, as dictionary. 51 | optimizer_name - Name of the optimizer to use. Currently supported: Adam, SGD 52 | optimizer_hparams - Hyperparameters for the optimizer, as dictionary. This includes learning rate, weight decay, etc. 53 | """ 54 | super().__init__() 55 | # Exports the hyperparameters to a YAML file, and create "self.hparams" namespace 56 | self.save_hyperparameters() 57 | # Create model 58 | self.model = model 59 | 60 | 61 | def forward(self, batch): 62 | # Forward function that is run when visualizing the graph 63 | 64 | out = self.model(batch) 65 | labels = batch['x_atoms'][:, :len(atom_list_one_hot)].argmax(1) 66 | preds = out.argmax(1) 67 | return preds, labels 68 | 69 | def configure_optimizers(self): 70 | 71 | optimizer = optim.AdamW(self.parameters()) 72 | 73 | # We will reduce the learning rate by 0.1 after 100 and 150 epochs 74 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1) 75 | 76 | 77 | return [optimizer], [scheduler] 78 | 79 | def training_step(self, batch, batch_idx): 80 | # "batch" is the output of the training data loader. 
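        # Atom-type prediction objective: the network outputs one score per
        # atom-type class for every atom, the target is the argmax over the
        # one-hot atom-type block of 'x_atoms', and the module-level
        # cross-entropy loss_fn from the __main__ block compares the two.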
81 | out = self.model(batch) 82 | labels = batch['x_atoms'][:, :len(atom_list_one_hot)].argmax(1) 83 | 84 | loss = loss_fn(out, labels) 85 | 86 | self.log('train_loss', loss) 87 | return loss # Return tensor to call ".backward" on 88 | 89 | def validation_step(self, batch, batch_idx): 90 | # x, y, p, scaffold = batch 91 | out = self.model(batch) 92 | labels = batch['x_atoms'][:, :len(atom_list_one_hot)].argmax(1) 93 | loss = loss_fn(out, labels) 94 | 95 | self.log('val_loss', loss) 96 | 97 | def test_step(self, batch, batch_idx): 98 | out = self.model(batch) 99 | labels = batch['x_atoms'][:, :len(atom_list_one_hot)].argmax(1) 100 | loss = loss_fn(out, labels) 101 | 102 | self.log('test_loss', loss) 103 | 104 | 105 | 106 | if __name__ == "__main__": 107 | 108 | 109 | files = os.listdir('pretrain_data/') 110 | fn = sorted([ int(i.split('.pkl')[0].strip('train')) for i in files if i.endswith('.pkl')]) 111 | 112 | train_ids, val_ids,test_ids = load_ids(fn) 113 | 114 | train_dataset = load_data_parts('pretrain_data', 'train', include=train_ids) 115 | val_dataset = load_data_parts('pretrain_data', 'train', include=val_ids) 116 | test_dataset = load_data_parts('pretrain_data', 'train', include=test_ids) 117 | 118 | train_loader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=512, shuffle=True, drop_last=True) 119 | val_loader = DataLoader(val_dataset, collate_fn=collate_fn, batch_size=256, shuffle=False, drop_last=False) 120 | test_loader = DataLoader(test_dataset, collate_fn=collate_fn, batch_size=256, shuffle=False, drop_last=False) 121 | 122 | device = torch.device('cuda') if torch.cuda.is_available() else 'cpu' 123 | 124 | 125 | model_pretrain = FragNetPreTrain() 126 | loss_fn = nn.CrossEntropyLoss() 127 | 128 | 129 | checkpoint_callback = ModelCheckpoint( 130 | save_weights_only=True, mode="min", monitor="val_loss", 131 | save_top_k=1, 132 | verbose=True, 133 | # dirpath=f'{args.exp_dir}/ckpt', 134 | dirpath=None, 135 | filename='model_best') 136 | # if args.trainer_version: 137 | trainer_version='tmp' 138 | 139 | logger = loggers.TensorBoardLogger(save_dir='./', version=trainer_version, name="lightning_logs") 140 | trainer = pl.Trainer(default_root_dir='./', # Where to save models 141 | accelerator="gpu" if str(device).startswith("cuda") else "cpu", 142 | devices=1, 143 | max_epochs=100, 144 | callbacks=[checkpoint_callback, LearningRateMonitor("epoch")], 145 | enable_progress_bar=True, 146 | gradient_clip_val=10, 147 | precision=16, 148 | limit_train_batches=None, 149 | limit_val_batches=None, 150 | logger=logger 151 | ) 152 | trainer.logger._log_graph = True 153 | trainer.logger._default_hp_metric = None 154 | 155 | 156 | model = FragNetModule(model_pretrain) 157 | trainer.fit(model, train_loader, val_loader) -------------------------------------------------------------------------------- /fragnet/train/finetune/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pnnl/FragNet/54f99099a08ab991a5ba6845f7338492d03d4f1c/fragnet/train/finetune/__init__.py -------------------------------------------------------------------------------- /fragnet/train/finetune/finetune_gat.py: -------------------------------------------------------------------------------- 1 | from dataset import load_pickle_dataset 2 | from torch.utils.data import DataLoader 3 | from data import collate_fn 4 | from gat import FragNetFineTune 5 | import torch.nn as nn 6 | from gat import FragNetPreTrain 7 | import torch 8 | from utils 
import EarlyStopping 9 | import numpy as np 10 | from utils import test_fn 11 | import argparse 12 | import os 13 | import pandas as pd 14 | 15 | 16 | def seed_everything(seed: int): 17 | import random, os 18 | 19 | random.seed(seed) 20 | os.environ['PYTHONHASHSEED'] = str(seed) 21 | np.random.seed(seed) 22 | torch.manual_seed(seed) 23 | torch.cuda.manual_seed(seed) 24 | torch.backends.cudnn.deterministic = True 25 | torch.backends.cudnn.benchmark = True 26 | 27 | seed_everything(26) 28 | 29 | def save_preds(test_dataset, true, pred, exp_dir): 30 | 31 | smiles = [i.smiles for i in test_dataset] 32 | res= pd.DataFrame(np.column_stack([smiles, true, pred]), columns=['smiles', 'true', 'pred']) 33 | res.to_csv(os.path.join(exp_dir, 'test_predictions.csv'), index=False) 34 | 35 | def train(train_loader, model, device, optimizer): 36 | model.train() 37 | total_loss = 0 38 | for batch in train_loader: 39 | for k,v in batch.items(): 40 | batch[k] = batch[k].to(device) 41 | optimizer.zero_grad() 42 | out = model(batch).view(-1,) 43 | loss = loss_fn(out, batch['y']) 44 | loss.backward() 45 | total_loss += loss.item() 46 | optimizer.step() 47 | return total_loss / len(train_loader.dataset) 48 | 49 | def get_optimizer(model, freeze_pt_weights=False, lr=1e-4): 50 | 51 | if freeze_pt_weights: 52 | print('freezing pretrain weights') 53 | for name, param in model.named_parameters(): 54 | if param.requires_grad and 'pretrain' in name: 55 | param.requires_grad = False 56 | 57 | non_frozen_parameters = [p for p in model.parameters() if p.requires_grad] 58 | optimizer = torch.optim.Adam(non_frozen_parameters, lr = lr) 59 | else: 60 | print('no freezing of the weights') 61 | optimizer = torch.optim.Adam(model.parameters(), lr = lr ) 62 | 63 | return model, optimizer 64 | 65 | 66 | if __name__ == "__main__": 67 | 68 | parser = argparse.ArgumentParser() 69 | parser.add_argument('--dataset_name', help="", type=str,required=False, default='moleculenet') 70 | parser.add_argument('--dataset_subset', help="", type=str,required=False, default='esol') 71 | parser.add_argument('--seed', help="", type=int,required=False, default=None) 72 | parser.add_argument('--batch_size', help="saving file name", type=int,required=False, default=32) 73 | parser.add_argument('--checkpoint', help="checkpoint name", type=str,required=False, default='gnn.pt') 74 | parser.add_argument('--add_pt_weights', help="checkpoint name", type=bool,required=False, default=False) 75 | parser.add_argument('--freeze_pt_weights', help="checkpoint name", type=bool,required=False, default=False) 76 | parser.add_argument('--exp_dir', help="", type=str,required=False, default='exps/pnnl') 77 | args = parser.parse_args() 78 | 79 | dataset_name = args.dataset_name # 'esol' 80 | seed = args.seed # 36 81 | dataset_subset = args.dataset_subset.lower()# 'moleculenet' 82 | batch_size = args.batch_size 83 | exp_dir = args.exp_dir 84 | if not os.path.exists(exp_dir): 85 | os.makedirs(exp_dir) 86 | print('dataset: ', dataset_name, 'subset: ', dataset_subset) 87 | 88 | 89 | device = torch.device('cuda') if torch.cuda.is_available() else 'cpu' 90 | 91 | if dataset_name == 'moleculenet': 92 | train_dataset = load_pickle_dataset(f'{dataset_name}/{dataset_subset}', f'train_{seed}') 93 | val_dataset = load_pickle_dataset(f'{dataset_name}/{dataset_subset}', f'val_{seed}') 94 | test_dataset = load_pickle_dataset(f'{dataset_name}/{dataset_subset}', f'test_{seed}') 95 | elif 'pnnl' in dataset_name: 96 | train_dataset = load_pickle_dataset(path=dataset_name, name='train') 97 | 
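        # PNNL-style datasets: the pickled 'train'/'val'/'test' splits are read
        # directly from the directory given by --dataset_name, whereas the
        # MoleculeNet branch above appends the split seed to each file name.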
val_dataset = load_pickle_dataset(path=dataset_name, name='val') 98 | test_dataset = load_pickle_dataset(path=dataset_name, name='test') 99 | 100 | 101 | 102 | train_loader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=batch_size, shuffle=True, drop_last=True) 103 | val_loader = DataLoader(val_dataset, collate_fn=collate_fn, batch_size=batch_size, shuffle=False, drop_last=False) 104 | test_loader = DataLoader(test_dataset, collate_fn=collate_fn, batch_size=batch_size, shuffle=False, drop_last=False) 105 | 106 | if args.add_pt_weights: 107 | model_pretrain = FragNetPreTrain(); 108 | model_pretrain.to(device); 109 | model_pretrain.load_state_dict(torch.load('pt.pt')) 110 | model = FragNetFineTune() 111 | model.to(device); 112 | loss_fn = nn.MSELoss() 113 | 114 | 115 | if args.add_pt_weights: 116 | print("adding pretrain weights to the model") 117 | model.pretrain.load_state_dict(model_pretrain.pretrain.state_dict()) 118 | 119 | model, optimizer = get_optimizer(model, freeze_pt_weights=args.freeze_pt_weights, lr=1e-4) 120 | 121 | chkpoint_name= os.path.join(exp_dir, args.checkpoint) 122 | early_stopping = EarlyStopping(patience=200, verbose=True, chkpoint_name=chkpoint_name) 123 | 124 | 125 | for epoch in range(2000): 126 | 127 | res = [] 128 | model.train() 129 | total_loss = 0 130 | for batch in train_loader: 131 | for k,v in batch.items(): 132 | batch[k] = batch[k].to(device) 133 | optimizer.zero_grad() 134 | out = model(batch).view(-1,) 135 | loss = loss_fn(out, batch['y']) 136 | loss.backward() 137 | total_loss += loss.item() 138 | optimizer.step() 139 | 140 | 141 | 142 | val_loss, _, _ = test_fn(val_loader, model, device) 143 | res.append(val_loss) 144 | print("val mse: ", val_loss) 145 | 146 | early_stopping(val_loss, model) 147 | 148 | if early_stopping.early_stop: 149 | print("Early stopping") 150 | break 151 | 152 | 153 | model.load_state_dict(torch.load(chkpoint_name)) 154 | 155 | mse, true, pred = test_fn(test_loader, model, device) 156 | save_preds(test_dataset, true, pred, exp_dir) 157 | print("rmse: ", mse**.5) 158 | 159 | 160 | 161 | -------------------------------------------------------------------------------- /fragnet/train/finetune/gat2_cv_frag.py: -------------------------------------------------------------------------------- 1 | # from gat import FragNetFineTune 2 | import torch 3 | from fragnet.dataset.dataset import load_pickle_dataset 4 | import torch.nn as nn 5 | from fragnet.train.utils import EarlyStopping 6 | import torch 7 | from fragnet.dataset.data import collate_fn 8 | from torch.utils.data import DataLoader 9 | import pandas as pd 10 | import argparse 11 | from fragnet.train.utils import TrainerFineTune as Trainer 12 | import numpy as np 13 | from omegaconf import OmegaConf 14 | import pickle 15 | import torch.optim.lr_scheduler as lr_scheduler 16 | from sklearn.model_selection import KFold 17 | import os 18 | from sklearn.model_selection import train_test_split 19 | 20 | def seed_everything(seed: int): 21 | import random, os 22 | 23 | random.seed(seed) 24 | os.environ['PYTHONHASHSEED'] = str(seed) 25 | np.random.seed(seed) 26 | torch.manual_seed(seed) 27 | torch.cuda.manual_seed(seed) 28 | torch.backends.cudnn.deterministic = True 29 | torch.backends.cudnn.benchmark = True 30 | 31 | seed_everything(26) 32 | 33 | def get_predictions(trainer, loader, model, device): 34 | mse, true, pred = trainer.test(model=model, loader=loader, device=device) 35 | smiles = [i.smiles for i in loader.dataset] 36 | res = {'smiles': smiles, 'true': true, 'pred': 
pred} 37 | return res 38 | 39 | def save_predictions(exp_dir, save_name, res): 40 | 41 | with open(f"{exp_dir}/{save_name}.pkl", 'wb') as f: 42 | pickle.dump(res,f ) 43 | 44 | 45 | if __name__ == "__main__": 46 | 47 | parser = argparse.ArgumentParser() 48 | parser.add_argument('--config', help="configuration file *.yml", type=str, required=False, default='config.yaml') 49 | args = parser.parse_args() 50 | 51 | if args.config: 52 | opt = OmegaConf.load(args.config) 53 | OmegaConf.resolve(opt) 54 | args=opt 55 | 56 | seed_everything(args.seed) 57 | exp_dir = args['exp_dir'] 58 | os.makedirs(exp_dir, exist_ok=True) 59 | device = torch.device('cuda') if torch.cuda.is_available() else 'cpu' 60 | 61 | pt_chkpoint_name = args.pretrain.chkpoint_name #'exps/pt_rings' 62 | 63 | if args.model_version == 'gat': 64 | from fragnet.model.gat.gat import FragNetFineTune 65 | print('loaded from gat') 66 | elif args.model_version=='gat2': 67 | from fragnet.model.gat.gat2 import FragNetFineTune 68 | print('loaded from gat2') 69 | 70 | 71 | model = FragNetFineTune(n_classes=args.finetune.model.n_classes, 72 | atom_features=args.atom_features, 73 | frag_features=args.frag_features, 74 | edge_features=args.edge_features, 75 | num_layer=args.finetune.model.num_layer, 76 | drop_ratio=args.finetune.model.drop_ratio, 77 | num_heads=args.finetune.model.num_heads, 78 | emb_dim=args.finetune.model.emb_dim, 79 | h1=args.finetune.model.h1, 80 | h2=args.finetune.model.h2, 81 | h3=args.finetune.model.h3, 82 | h4=args.finetune.model.h4, 83 | act=args.finetune.model.act, 84 | fthead=args.finetune.model.fthead 85 | ) 86 | trainer = Trainer(target_type=args.finetune.target_type) 87 | 88 | pt_chkpoint_name = args.pretrain.chkpoint_name #'exps/pt_rings' 89 | if pt_chkpoint_name: 90 | 91 | from fragnet.model.gat.gat2_pretrain import FragNetPreTrain 92 | modelpt = FragNetPreTrain(num_layer=args.finetune.model.num_layer, 93 | drop_ratio=args.finetune.model.drop_ratio, 94 | num_heads=args.finetune.model.num_heads, 95 | emb_dim=args.finetune.model.emb_dim, 96 | atom_features=args.atom_features, frag_features=args.frag_features, edge_features=args.edge_features) 97 | 98 | 99 | modelpt.load_state_dict(torch.load(pt_chkpoint_name, map_location=torch.device(device))) 100 | 101 | print('loading pretrained weights') 102 | model.pretrain.load_state_dict(modelpt.pretrain.state_dict()) 103 | print('weights loaded') 104 | else: 105 | print('no pretrained weights') 106 | 107 | 108 | train_dataset2 = load_pickle_dataset(args.finetune.train.path) 109 | 110 | kf = KFold(n_splits=5) 111 | train_val = train_dataset2 112 | 113 | for icv, (train_index, test_index) in enumerate(kf.split(train_val)): 114 | train_index, val_index = train_test_split(train_index, test_size=.1) 115 | 116 | train_ds = [train_val[i] for i in train_index] 117 | val_ds = [train_val[i] for i in val_index] 118 | test_ds = [train_val[i] for i in test_index] 119 | 120 | train_loader = DataLoader(train_ds, collate_fn=collate_fn, batch_size=args.finetune.batch_size, shuffle=True, drop_last=True) 121 | val_loader = DataLoader(val_ds, collate_fn=collate_fn, batch_size=args.finetune.batch_size, shuffle=False, drop_last=False) 122 | test_loader = DataLoader(test_ds, collate_fn=collate_fn, batch_size=args.finetune.batch_size, shuffle=False, drop_last=False) 123 | 124 | 125 | trainer = Trainer(target_type=args.finetune.target_type) 126 | 127 | 128 | model.to(device); 129 | ft_chk_point = os.path.join(args.exp_dir, f'cv_{icv}.pt') 130 | early_stopping = 
EarlyStopping(patience=args.finetune.es_patience, verbose=True, chkpoint_name=ft_chk_point) 131 | 132 | 133 | if args.finetune.loss == 'mse': 134 | loss_fn = nn.MSELoss() 135 | elif args.finetune.loss == 'cel': 136 | loss_fn = nn.CrossEntropyLoss() 137 | 138 | optimizer = torch.optim.Adam(model.parameters(), lr = args.finetune.lr ) # before 1e-4 139 | if args.finetune.use_schedular: 140 | scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.99) 141 | else: 142 | scheduler=None 143 | 144 | for epoch in range(args.finetune.n_epochs): 145 | 146 | train_loss = trainer.train(model=model, loader=train_loader, optimizer=optimizer, 147 | scheduler=scheduler, device=device, val_loader=val_loader) 148 | val_loss, _, _ = trainer.test(model=model, loader=val_loader, device=device) 149 | 150 | 151 | print(train_loss, val_loss) 152 | 153 | 154 | early_stopping(val_loss, model) 155 | 156 | if early_stopping.early_stop: 157 | print("Early stopping") 158 | break 159 | 160 | 161 | model.load_state_dict(torch.load(ft_chk_point)) 162 | res = get_predictions(trainer, test_loader, model, device) 163 | 164 | with open(f'{exp_dir}/cv_{icv}.pkl', 'wb') as f: 165 | pickle.dump(res, f) 166 | -------------------------------------------------------------------------------- /fragnet/train/finetune/trainer_cdrp.py: -------------------------------------------------------------------------------- 1 | from utils import compute_bce_loss 2 | import torch 3 | import numpy as np 4 | from sklearn.metrics import mean_squared_error, roc_auc_score 5 | import torch.nn as nn 6 | 7 | 8 | class TrainerFineTune: 9 | def __init__(self, target_pos=None, target_type='regr', 10 | n_multi_task_heads=0): 11 | self.target_pos = target_pos 12 | if target_type=='clsf': 13 | 14 | self.train = self.train_clsf_bce 15 | self.validate = self.validate_clsf_bce 16 | self.test = self.test_clsf_bce 17 | self.loss_fn = compute_bce_loss 18 | 19 | elif target_type=='regr': 20 | self.train = self.train_regr 21 | self.validate = self.validate_regr 22 | self.test = self.test_regr 23 | self.loss_fn = nn.MSELoss() 24 | 25 | 26 | elif target_type=='clsf_ms': 27 | self.train = self.train_clsf_multi_task 28 | self.validate = self.validate_clsf_multi_task 29 | self.test = self.test_clsf_multi_task 30 | self.n_multi_task_heads = n_multi_task_heads 31 | 32 | 33 | def train_regr(self, model, loader, optimizer, scheduler, device, val_loader, label_mean, label_sdev): 34 | model.train() 35 | total_loss = 0 36 | for batch in loader: 37 | for k,v in batch.items(): 38 | batch[k] = batch[k].to(device) 39 | optimizer.zero_grad() 40 | out = model(batch).view(-1,) 41 | labels = batch['y'] 42 | 43 | loss = self.loss_fn(out, labels) 44 | loss.backward() 45 | total_loss += loss.item() 46 | optimizer.step() 47 | 48 | if scheduler: 49 | val_loss = self.validate(model, val_loader, device) 50 | scheduler.step() 51 | 52 | return total_loss / len(loader.dataset) 53 | 54 | def validate_regr(self, model,loader, device, label_mean, label_sdev): 55 | model.eval() 56 | total_loss = 0 57 | with torch.no_grad(): 58 | 59 | for batch in loader: 60 | for k,v in batch.items(): 61 | batch[k] = batch[k].to(device) 62 | 63 | out = model(batch).view(-1,) 64 | labels = batch['y'] 65 | loss = self.loss_fn(out, labels) 66 | total_loss += loss.item() 67 | return total_loss / len(loader.dataset) 68 | 69 | 70 | def test_regr(self, model, loader, device, label_mean, label_sdev): 71 | 72 | model.eval() 73 | with torch.no_grad(): 74 | target, predicted = [], [] 75 | for data in loader: 76 | for k,v in 
data.items(): 77 | data[k] = data[k].to(device) 78 | 79 | output = model(data).view(-1,) 80 | pred = output 81 | 82 | target += list(data['y'].cpu().detach().numpy().ravel() ) 83 | predicted += list(pred.cpu().detach().numpy().ravel() ) 84 | mse = mean_squared_error(target, predicted) 85 | 86 | return mse, np.array(target), np.array(predicted) 87 | 88 | -------------------------------------------------------------------------------- /fragnet/train/pretrain/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pnnl/FragNet/54f99099a08ab991a5ba6845f7338492d03d4f1c/fragnet/train/pretrain/__init__.py -------------------------------------------------------------------------------- /fragnet/train/pretrain/pretrain_data_pnnl.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from dataset import get_pt_dataset, extract_data, save_datasets 3 | 4 | tr = pd.read_csv('/people/pana982/solubility/data/full_dataset/isomers/wang/set2/train_new.csv') 5 | vl = pd.read_csv('/people/pana982/solubility/data/full_dataset/isomers/wang/set2/val_new.csv') 6 | ts = pd.read_csv('/people/pana982/solubility/data/full_dataset/isomers/wang/set2/test_new.csv') 7 | 8 | df = pd.concat([tr, vl, ts], axis=0) 9 | df.reset_index(drop=True, inplace=True) 10 | 11 | ds = get_pt_dataset(df) 12 | ds = extract_data(ds) 13 | save_path = f'pretrain_data/pnnl_iso/ds/train' 14 | save_datasets(ds, save_path) 15 | -------------------------------------------------------------------------------- /fragnet/train/pretrain/pretrain_gat2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from fragnet.train.utils import EarlyStopping 4 | import torch 5 | from torch.utils.data import DataLoader 6 | from fragnet.dataset.dataset import load_data_parts 7 | import argparse 8 | import numpy as np 9 | from omegaconf import OmegaConf 10 | import pickle 11 | from torch.utils.tensorboard import SummaryWriter 12 | from fragnet.model.gat.pretrain_heads import FragNetPreTrain, FragNetPreTrainMasked, FragNetPreTrainMasked2 13 | from fragnet.train.pretrain.pretrain_utils import Trainer 14 | from sklearn.model_selection import train_test_split 15 | from fragnet.dataset.data import collate_fn_pt as collate_fn 16 | import pandas as pd 17 | from fragnet.model.gat.gat2 import FragNet 18 | from torch_scatter import scatter_add 19 | 20 | def seed_everything(seed: int): 21 | import random, os 22 | 23 | random.seed(seed) 24 | os.environ['PYTHONHASHSEED'] = str(seed) 25 | np.random.seed(seed) 26 | torch.manual_seed(seed) 27 | torch.cuda.manual_seed(seed) 28 | torch.backends.cudnn.deterministic = True 29 | torch.backends.cudnn.benchmark = True 30 | 31 | 32 | def save_ds_smiles(ds,name, exp_dir): 33 | smiles = [i.smiles for i in ds] 34 | with open(exp_dir+f'/{name}_smiles.pkl', 'wb') as f: 35 | pickle.dump(smiles, f) 36 | 37 | def load_train_val_dataset(args, ds): 38 | 39 | train_smiles = pd.read_pickle(args.pretrain.train_smiles) 40 | 41 | 42 | train_dataset, val_dataset = [], [] 43 | for data in ds: 44 | if data.smiles in train_smiles: 45 | train_dataset.append(data) 46 | else: 47 | val_dataset.append(data) 48 | 49 | 50 | return train_dataset, val_dataset 51 | 52 | def remove_duplicates_and_add(ds, path): 53 | t = load_data_parts(path) 54 | if len(ds) != 0: 55 | curr_smiles = [i.smiles for i in ds] 56 | new_ds = [i for i in t if i.smiles not in curr_smiles] 
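        # Only molecules whose SMILES are not already present are kept, so the
        # shards listed in args.pretrain.data are deduplicated by SMILES as
        # they are concatenated into one pretraining set.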
57 | else: 58 | new_ds=t 59 | ds += new_ds 60 | return ds 61 | 62 | def save_predictions(trainer, loader, model, exp_dir, device, save_name='test_res', loss_type='mse'): 63 | score, true, pred = trainer.test(model=model, loader=loader, device=device) 64 | smiles = [i.smiles for i in loader.dataset] 65 | 66 | if loss_type=='mse': 67 | print(f'{save_name} rmse: ', score**.5) 68 | res = {'acc': score**.5, 'true': true, 'pred': pred, 'smiles': smiles} 69 | elif loss_type=='cel': 70 | # score = roc_auc_score(true, pred[:,1]) 71 | print(f'{save_name} auc: ', score) 72 | res = {'acc': score, 'true': true, 'pred': pred, 'smiles': smiles} 73 | 74 | with open(f"{exp_dir}/{save_name}.pkl", 'wb') as f: 75 | pickle.dump(res,f ) 76 | 77 | 78 | 79 | if __name__ == "__main__": 80 | 81 | parser = argparse.ArgumentParser() 82 | parser.add_argument('--config', help="configuration file *.yml", type=str, required=False, default='config.yaml') 83 | args = parser.parse_args() 84 | 85 | if args.config: 86 | opt = OmegaConf.load(args.config) 87 | OmegaConf.resolve(opt) 88 | args=opt 89 | 90 | seed_everything(args.seed) 91 | exp_dir = args['exp_dir'] 92 | device = torch.device('cuda') if torch.cuda.is_available() else 'cpu' 93 | writer = SummaryWriter(exp_dir+'/runs') 94 | 95 | pt_chkpoint_name = args.pretrain.chkpoint_name #'exps/pt_rings' 96 | 97 | if args.model_version == 'gat2': 98 | model = FragNetPreTrain(num_layer=args.pretrain.num_layer, 99 | drop_ratio=args.pretrain.drop_ratio, 100 | num_heads=args.pretrain.num_heads, 101 | emb_dim=args.pretrain.emb_dim, 102 | atom_features=args.atom_features, 103 | frag_features=args.frag_features, 104 | edge_features=args.edge_features, 105 | fedge_in=args.fedge_in, 106 | fbond_edge_in=args.fbond_edge_in 107 | ) 108 | elif args.model_version == 'gat2_masked': 109 | model = FragNetPreTrainMasked(num_layer=args.pretrain.num_layer, 110 | drop_ratio=args.pretrain.drop_ratio, 111 | num_heads=args.pretrain.num_heads, 112 | emb_dim=args.pretrain.emb_dim, 113 | atom_features=args.atom_features, 114 | frag_features=args.frag_features, 115 | edge_features=args.edge_features, 116 | fedge_in=args.fedge_in, 117 | fbond_edge_in=args.fbond_edge_in) 118 | 119 | elif args.model_version == 'gat2_masked2': 120 | model = FragNetPreTrainMasked2(num_layer=args.pretrain.num_layer, 121 | drop_ratio=args.pretrain.drop_ratio, 122 | num_heads=args.pretrain.num_heads, 123 | emb_dim=args.pretrain.emb_dim, 124 | atom_features=args.atom_features, 125 | frag_features=args.frag_features, 126 | edge_features=args.edge_features, 127 | fedge_in=args.fedge_in, 128 | fbond_edge_in=args.fbond_edge_in) 129 | 130 | if args.pretrain.saved_checkpoint: 131 | model.load_state_dict(torch.load(args.pretrain.saved_checkpoint)) 132 | 133 | ds=[] 134 | for path in args.pretrain.data: 135 | ds = remove_duplicates_and_add(ds, path) 136 | 137 | 138 | if args.pretrain.train_smiles: 139 | train_dataset, val_dataset = load_train_val_dataset(args, ds) 140 | else: 141 | train_dataset, val_dataset = train_test_split(ds, test_size=.1, random_state=42) 142 | 143 | # NOTE: Start new runs in a new directory 144 | save_ds_smiles(train_dataset, 'train', args.exp_dir) 145 | save_ds_smiles(val_dataset, 'val', args.exp_dir) 146 | 147 | print('number of data points: ', len(ds)) 148 | writer.add_scalar('number of data points: ', len(ds)) 149 | 150 | train_loader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=args.pretrain.batch_size, shuffle=True, drop_last=True) 151 | val_loader = DataLoader(val_dataset, 
collate_fn=collate_fn, batch_size=args.pretrain.batch_size, shuffle=False, drop_last=False) 152 | 153 | 154 | model.to(device); 155 | early_stopping = EarlyStopping(patience=args.pretrain.es_patience, verbose=True, chkpoint_name=pt_chkpoint_name) 156 | 157 | 158 | if args.pretrain.loss == 'mse': 159 | loss_fn = nn.MSELoss() 160 | elif args.pretrain.loss == 'cel': 161 | loss_fn = nn.CrossEntropyLoss() 162 | 163 | trainer = Trainer(loss_fn=loss_fn) 164 | 165 | optimizer = torch.optim.Adam(model.parameters(), lr = args.pretrain.lr ) # before 1e-4 166 | 167 | for epoch in range(args.pretrain.n_epochs): 168 | 169 | train_loss = trainer.train(model=model, loader=train_loader, optimizer=optimizer, device=device) 170 | 171 | writer.add_scalar('Loss/train', train_loss, epoch) 172 | 173 | 174 | if epoch%5==0: 175 | val_loss = trainer.validate(model=model, loader=val_loader, device=device) 176 | print(train_loss, val_loss) 177 | writer.add_scalar('Loss/val', val_loss, epoch) 178 | 179 | early_stopping(val_loss, model) 180 | 181 | if early_stopping.early_stop: 182 | print("Early stopping") 183 | break 184 | -------------------------------------------------------------------------------- /fragnet/train/pretrain/pretrain_gat_mol.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from utils import EarlyStopping 3 | import torch 4 | from data import collate_fn 5 | from torch.utils.data import DataLoader 6 | import argparse 7 | import numpy as np 8 | from omegaconf import OmegaConf 9 | from utils import Trainer 10 | from torch.utils.tensorboard import SummaryWriter 11 | from dataset import LoadDataSets 12 | from pretrain_utils import load_prop_data, add_props_to_ds 13 | 14 | 15 | def seed_everything(seed: int): 16 | import random, os 17 | 18 | random.seed(seed) 19 | os.environ['PYTHONHASHSEED'] = str(seed) 20 | np.random.seed(seed) 21 | torch.manual_seed(seed) 22 | torch.cuda.manual_seed(seed) 23 | torch.backends.cudnn.deterministic = True 24 | torch.backends.cudnn.benchmark = True 25 | 26 | seed_everything(26) 27 | 28 | """ 29 | This is to pretrain on molecular properties 30 | """ 31 | 32 | 33 | if __name__=="__main__": 34 | 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument('--config', help="configuration file *.yml", type=str, required=False, default='config.yml') 37 | args = parser.parse_args() 38 | 39 | if args.config: # args priority is higher than yaml 40 | opt = OmegaConf.load(args.config) 41 | OmegaConf.resolve(opt) 42 | args=opt 43 | 44 | writer = SummaryWriter(args.exp_dir) 45 | if args.model_version == 'gat': 46 | from gat import FragNetFineTune 47 | elif args.model_version == 'gat2': 48 | from gat2 import FragNetFineTune 49 | 50 | 51 | ds = LoadDataSets() 52 | 53 | train_dataset, val_dataset, test_dataset = ds.load_datasets(args) 54 | 55 | prop_dict, _ = load_prop_data(args) 56 | add_props_to_ds(train_dataset, prop_dict) 57 | add_props_to_ds(val_dataset, prop_dict) 58 | 59 | 60 | train_loader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=512, shuffle=True, drop_last=True) 61 | val_loader = DataLoader(val_dataset, collate_fn=collate_fn, batch_size=512, shuffle=False, drop_last=False) 62 | 63 | n_classes = args.pretrain.n_classes # 31 for nRings 64 | target_pos = args.pretrain.target_pos 65 | target_type = args.pretrain.target_type 66 | 67 | device = torch.device('cuda') if torch.cuda.is_available() else 'cpu' 68 | model = FragNetFineTune(n_classes=args.pretrain.n_classes, 69 | 
num_layer=args.pretrain.num_layer, 70 | drop_ratio=args.pretrain.drop_ratio) 71 | 72 | model.to(device) 73 | 74 | optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4 ) 75 | chkpoint_name = args.pretrain.chkpoint_name #'pt_gat_nring.pt' 76 | early_stopping = EarlyStopping(patience=args.pretrain.n_epochs, verbose=True, chkpoint_name=chkpoint_name) 77 | 78 | if args.pretrain.loss == 'mse': 79 | loss_fn = nn.MSELoss() 80 | elif args.pretrain.loss == 'cel': 81 | loss_fn = nn.CrossEntropyLoss() 82 | 83 | pretrainer = Trainer(target_pos=target_pos, target_type=target_type, loss_fn=loss_fn) 84 | 85 | 86 | for epoch in range(args.pretrain.n_epochs): 87 | 88 | train_loss = pretrainer.train(model, train_loader, optimizer, device) 89 | val_loss = pretrainer.validate(val_loader, model, device) 90 | writer.add_scalar('Loss/train', train_loss, epoch) 91 | writer.add_scalar('Loss/val', val_loss, epoch) 92 | 93 | early_stopping(val_loss, model) 94 | 95 | if early_stopping.early_stop: 96 | print("Early stopping") 97 | break 98 | -------------------------------------------------------------------------------- /fragnet/train/pretrain/pretrain_gat_str.py: -------------------------------------------------------------------------------- 1 | from gat import FragNetPreTrain 2 | from dataset import load_data_parts 3 | from data import mask_atom_features 4 | import torch.nn as nn 5 | from utils import EarlyStopping 6 | import torch 7 | from data import collate_fn 8 | from torch.utils.data import DataLoader 9 | from features import atom_list_one_hot 10 | 11 | 12 | """ 13 | this is to pretrain on molecular fragments 14 | """ 15 | def train(model, train_loader, optimizer, device): 16 | model.train() 17 | total_loss = 0 18 | for batch in train_loader: 19 | mask_atom_features(batch) 20 | for k,v in batch.items(): 21 | batch[k] = batch[k].to(device) 22 | optimizer.zero_grad() 23 | out = model(batch) 24 | labels = batch['x_atoms'][:, :len(atom_list_one_hot)].argmax(1) 25 | 26 | loss = loss_fn(out, labels) 27 | loss.backward() 28 | total_loss += loss.item() 29 | optimizer.step() 30 | return total_loss / len(train_loader.dataset) 31 | 32 | def validate(loader, model, device): 33 | model.eval() 34 | total_loss = 0 35 | with torch.no_grad(): 36 | 37 | for batch in loader: 38 | for k,v in batch.items(): 39 | batch[k] = batch[k].to(device) 40 | 41 | out = model(batch) 42 | labels = batch['x_atoms'][:, :len(atom_list_one_hot)].argmax(1) 43 | 44 | loss = loss_fn(out, labels) 45 | total_loss += loss.item() 46 | return total_loss / len(loader.dataset) 47 | 48 | 49 | 50 | if __name__ == '__main__': 51 | 52 | train_dataset = load_data_parts('pretrain_data', 'train', include=range(0,479)) 53 | val_dataset = load_data_parts('pretrain_data', 'train', include=range(479,499)) 54 | 55 | train_loader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=512, shuffle=True, drop_last=True) 56 | val_loader = DataLoader(val_dataset, collate_fn=collate_fn, batch_size=256, shuffle=False, drop_last=False) 57 | 58 | device = torch.device('cuda') if torch.cuda.is_available() else 'cpu' 59 | 60 | model_pretrain = FragNetPreTrain() 61 | model_pretrain.to(device) 62 | 63 | loss_fn = nn.CrossEntropyLoss() 64 | 65 | optimizer = torch.optim.Adam(model_pretrain.parameters(), lr = 1e-4 ) 66 | chkpoint_name='pt_gat.pt' 67 | early_stopping = EarlyStopping(patience=100, verbose=True, chkpoint_name=chkpoint_name) 68 | 69 | 70 | 71 | for epoch in range(500): 72 | 73 | train_loss = train(model_pretrain, train_loader, optimizer, device) 74 | val_loss 
= validate(val_loader, model_pretrain, device) 75 | print(train_loss, val_loss) 76 | 77 | 78 | early_stopping(val_loss, model_pretrain) 79 | 80 | if early_stopping.early_stop: 81 | print("Early stopping") 82 | break 83 | -------------------------------------------------------------------------------- /fragnet/train/pretrain/pretrain_gcn.py: -------------------------------------------------------------------------------- 1 | from gcn import FragNetPreTrain 2 | from dataset import load_data_parts 3 | from data import mask_atom_features 4 | import torch.nn as nn 5 | from utils import EarlyStopping 6 | import torch 7 | from data import collate_fn 8 | from torch.utils.data import DataLoader 9 | from features import atom_list_one_hot 10 | from tqdm import tqdm 11 | import pickle 12 | import os 13 | from sklearn.model_selection import train_test_split 14 | 15 | def load_ids(fn): 16 | 17 | if not os.path.exists('gcn_output/train_ids.pkl'): 18 | 19 | train_ids, test_ids = train_test_split(fn, test_size=.2) 20 | test_ids, val_ids = train_test_split(test_ids, test_size=.5) 21 | 22 | with open('gcn_output/train_ids.pkl', 'wb') as f: 23 | pickle.dump(train_ids, f) 24 | with open('gcn_output/val_ids.pkl', 'wb') as f: 25 | pickle.dump(val_ids, f) 26 | with open('gcn_output/test_ids.pkl', 'wb') as f: 27 | pickle.dump(test_ids, f) 28 | 29 | else: 30 | with open('gcn_output/train_ids.pkl', 'rb') as f: 31 | train_ids = pickle.load(f) 32 | with open('gcn_output/val_ids.pkl', 'rb') as f: 33 | val_ids = pickle.load(f) 34 | with open('gcn_output/test_ids.pkl', 'rb') as f: 35 | test_ids = pickle.load(f) 36 | 37 | return train_ids, val_ids, test_ids 38 | 39 | 40 | def train(model, train_loader, optimizer, device): 41 | model.train() 42 | total_loss = 0 43 | for batch in train_loader: 44 | # batch = data.to(device) 45 | mask_atom_features(batch) 46 | for k,v in batch.items(): 47 | batch[k] = batch[k].to(device) 48 | optimizer.zero_grad() 49 | out = model(batch) 50 | labels = batch['x_atoms'][:, :len(atom_list_one_hot)].argmax(1) 51 | 52 | loss = loss_fn(out, labels) 53 | loss.backward() 54 | total_loss += loss.item() 55 | optimizer.step() 56 | return total_loss / len(train_loader.dataset) 57 | 58 | def validate(loader, model, device): 59 | model.eval() 60 | total_loss = 0 61 | with torch.no_grad(): 62 | 63 | for batch in loader: 64 | for k,v in batch.items(): 65 | batch[k] = batch[k].to(device) 66 | 67 | out = model(batch) 68 | labels = batch['x_atoms'][:, :len(atom_list_one_hot)].argmax(1) 69 | 70 | loss = loss_fn(out, labels) 71 | total_loss += loss.item() 72 | return total_loss / len(loader.dataset) 73 | 74 | 75 | 76 | if __name__ == '__main__': 77 | 78 | files = os.listdir('pretrain_data/') 79 | fn = sorted([ int(i.split('.pkl')[0].strip('train')) for i in files if i.endswith('.pkl')]) 80 | train_ids, val_ids,test_ids = load_ids(fn) 81 | train_dataset = load_data_parts('pretrain_data', 'train', include=train_ids) 82 | val_dataset = load_data_parts('pretrain_data', 'train', include=val_ids) 83 | 84 | train_loader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=512, shuffle=True, drop_last=True) 85 | val_loader = DataLoader(val_dataset, collate_fn=collate_fn, batch_size=256, shuffle=False, drop_last=False) 86 | 87 | device = torch.device('cuda') if torch.cuda.is_available() else 'cpu' 88 | 89 | model_pretrain = FragNetPreTrain() 90 | model_pretrain.to(device) 91 | 92 | loss_fn = nn.CrossEntropyLoss() 93 | 94 | optimizer = torch.optim.Adam(model_pretrain.parameters(), lr = 1e-4 ) 95 | 
chkpoint_name='pt_gcn.pt' 96 | early_stopping = EarlyStopping(patience=500, verbose=True, chkpoint_name=chkpoint_name) 97 | 98 | 99 | 100 | for epoch in tqdm(range(500)): 101 | 102 | train_loss = train(model_pretrain, train_loader, optimizer, device) 103 | val_loss = validate(val_loader, model_pretrain, device) 104 | print(train_loss, val_loss) 105 | 106 | 107 | early_stopping(val_loss, model_pretrain) 108 | 109 | if early_stopping.early_stop: 110 | print("Early stopping") 111 | break 112 | -------------------------------------------------------------------------------- /fragnet/train/pretrain/pretrain_heads.py: -------------------------------------------------------------------------------- 1 | # from torch.nn import Parameter 2 | # from torch_geometric.utils import add_self_loops, degree 3 | import torch 4 | import torch.nn as nn 5 | # from torch_scatter import scatter_add, scatter_softmax 6 | # from torch_geometric.utils import add_self_loops 7 | from torch_scatter import scatter_add 8 | # import numpy as np 9 | # import torch.nn.functional as F 10 | # from torch_geometric.nn.norm import BatchNorm 11 | from fragnet.model.gat.gat2 import FragNet 12 | 13 | 14 | 15 | # get their activation functions 16 | class PretrainTask(nn.Module): 17 | """ 18 | SAN prediction head for graph prediction tasks. 19 | 20 | Args: 21 | dim_in (int): Input dimension. 22 | dim_out (int): Output dimension. For binary prediction, dim_out=1. 23 | L (int): Number of hidden layers. 24 | """ 25 | 26 | def __init__(self, dim_in=128, dim_out=1, L=2): 27 | super().__init__() 28 | # self.pooling_fun = register.pooling_dict[cfg.model.graph_pooling] 29 | 30 | # bond_length 31 | self.bl_reduce_layer = nn.Linear(dim_in * 3, dim_in) 32 | list_bl_layers = [ 33 | nn.Linear(dim_in // 2 ** l, dim_in // 2 ** (l + 1), bias=True) 34 | for l in range(L)] 35 | list_bl_layers.append( 36 | nn.Linear(dim_in // 2 ** L, dim_out, bias=True)) 37 | self.bl_layers = nn.ModuleList(list_bl_layers) 38 | 39 | # bond_angle 40 | list_ba_layers = [ 41 | nn.Linear(dim_in // 2 ** l, dim_in // 2 ** (l + 1), bias=True) 42 | for l in range(L)] 43 | list_ba_layers.append( 44 | nn.Linear(dim_in // 2 ** L, dim_out, bias=True)) 45 | self.ba_layers = nn.ModuleList(list_ba_layers) 46 | 47 | # dihedral_angle 48 | list_da_layers = [ 49 | nn.Linear(dim_in // 2 ** l, dim_in // 2 ** (l + 1), bias=True) 50 | for l in range(L)] 51 | list_da_layers.append( 52 | nn.Linear(dim_in // 2 ** L, dim_out, bias=True)) 53 | self.da_layers = nn.ModuleList(list_da_layers) 54 | 55 | # graph-level prediction (energy) 56 | list_FC_layers = [ 57 | nn.Linear(dim_in*2 // 2 ** l, dim_in*2 // 2 ** (l + 1), bias=True) 58 | for l in range(L)] 59 | list_FC_layers.append( 60 | nn.Linear(dim_in*2 // 2 ** L, dim_out, bias=True)) 61 | self.FC_layers = nn.ModuleList(list_FC_layers) 62 | 63 | self.L = L 64 | self.activation = nn.ReLU() 65 | 66 | def _apply_index(self, batch): 67 | return batch.bond_length, batch.distance 68 | 69 | def forward(self, x_atoms, x_frags, edge_attr, batch): 70 | edge_index = batch['edge_index'] 71 | # bond_length 72 | # bond_length_pair = batch.positions[batch.edge_index.T] 73 | # bond_length_true = torch.sum((bond_length_pair[:, 0, :] - bond_length_pair[:, 1, :]) ** 2, axis=1) 74 | bond_length_pred = torch.concat((x_atoms[edge_index.T][:,0,:], x_atoms[edge_index.T][:,1,:], edge_attr),axis=1) 75 | bond_length_pred = self.bl_reduce_layer(bond_length_pred) 76 | for l in range(self.L + 1): 77 | bond_length_pred = self.activation(bond_length_pred) 78 | bond_length_pred = 
self.bl_layers[l](bond_length_pred) 79 | 80 |         # bond_angle 81 |         bond_angle_pred = x_atoms 82 |         for l in range(self.L): 83 |             bond_angle_pred = self.ba_layers[l](bond_angle_pred) 84 |             bond_angle_pred = self.activation(bond_angle_pred) 85 |         bond_angle_pred = self.ba_layers[self.L](bond_angle_pred) 86 | 87 |         # dihedral_angle 88 |         dihedral_angle_pred = edge_attr 89 |         for l in range(self.L): 90 |             dihedral_angle_pred = self.da_layers[l](dihedral_angle_pred) 91 |             dihedral_angle_pred = self.activation(dihedral_angle_pred) 92 |         dihedral_angle_pred = self.da_layers[self.L](dihedral_angle_pred) 93 | 94 |         # total energy 95 |         # graph_rep = self.pooling_fun(batch.x, batch.batch) 96 | 97 |         x_frags_pooled = scatter_add(src=x_frags, index=batch['frag_batch'], dim=0) 98 |         x_atoms_pooled = scatter_add(src=x_atoms, index=batch['batch'], dim=0) 99 | 100 |         graph_rep = torch.cat((x_atoms_pooled, x_frags_pooled), 1) 101 |         for l in range(self.L): 102 |             graph_rep = self.FC_layers[l](graph_rep) 103 |             graph_rep = self.activation(graph_rep) 104 |         graph_rep = self.FC_layers[self.L](graph_rep) 105 | 106 | 107 |         return bond_length_pred, bond_angle_pred, dihedral_angle_pred, graph_rep 108 | 109 | 110 | 111 | 112 | class FragNetPreTrain(nn.Module): 113 | 114 |     def __init__(self, num_layer=4, drop_ratio=0.15, num_heads=4, emb_dim=128, 115 |                  atom_features=167, frag_features=167, edge_features=16): 116 |         super(FragNetPreTrain, self).__init__() 117 | 118 |         # self.pretrain = FragNet(num_layer=num_layer, drop_ratio=drop_ratio) 119 |         self.pretrain = FragNet(num_layer=num_layer, drop_ratio=drop_ratio, num_heads=num_heads, emb_dim=emb_dim, 120 |                                 atom_features=atom_features, frag_features=frag_features, edge_features=edge_features) 121 |         self.head = PretrainTask(128, 1) 122 | 123 | 124 |     def forward(self, batch): 125 | 126 |         x_atoms, x_frags, e_edge, e_fedge = self.pretrain(batch) 127 |         bond_length_pred, bond_angle_pred, dihedral_angle_pred, graph_rep = self.head(x_atoms, x_frags, e_edge, batch) 128 | 129 |         return bond_length_pred, bond_angle_pred, dihedral_angle_pred, graph_rep -------------------------------------------------------------------------------- /fragnet/train/pretrain/pretrain_utils.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import torch 3 | 4 | class Trainer: 5 |     def __init__(self, loss_fn=None): 6 |         self.loss_fn = loss_fn 7 | 8 |     # single output regression 9 |     def train(self, model, loader, optimizer, device): 10 |         model.train() 11 |         total_loss = 0 12 |         for batch in loader: 13 |             for k,v in batch.items(): 14 |                 batch[k] = batch[k].to(device) 15 |             optimizer.zero_grad() 16 |             bond_length_pred, bond_angle_pred, dihedral_angle_pred, graph_rep = model(batch) 17 |             bond_length_true = batch['bnd_lngth'] 18 |             bond_angle_true = batch['bnd_angl'] 19 |             dihedral_angle_true = batch['dh_angl'] 20 |             E = batch['y'] 21 | 22 |             loss_lngth = self.loss_fn(bond_length_pred, bond_length_true) 23 |             loss_angle = self.loss_fn(bond_angle_pred, bond_angle_true) 24 |             loss_dihedral = self.loss_fn(dihedral_angle_pred, dihedral_angle_true) 25 |             loss_E = self.loss_fn(graph_rep.view(-1), E) 26 |             loss = loss_lngth + loss_angle + loss_dihedral + loss_E 27 | 28 |             loss.backward() 29 |             total_loss += loss.item() 30 |             optimizer.step() 31 |         return total_loss / len(loader.dataset) 32 | 33 |     def validate(self, loader, model, device): 34 |         model.eval() 35 |         total_loss = 0 36 |         with torch.no_grad(): 37 | 38 |             for batch in loader: 39 |                 for k,v in batch.items(): 40 |                     batch[k] = batch[k].to(device) 41 | 42 
| 43 |                 bond_length_pred, bond_angle_pred, dihedral_angle_pred, graph_rep = model(batch) 44 |                 bond_length_true = batch['bnd_lngth'] 45 |                 bond_angle_true = batch['bnd_angl'] 46 |                 dihedral_angle_true = batch['dh_angl'] 47 |                 E = batch['y'] 48 | 49 |                 loss_lngth = self.loss_fn(bond_length_pred, bond_length_true) 50 |                 loss_angle = self.loss_fn(bond_angle_pred, bond_angle_true) 51 |                 loss_dihedral = self.loss_fn(dihedral_angle_pred, dihedral_angle_true) 52 |                 loss_E = self.loss_fn(graph_rep.view(-1), E) 53 |                 loss = loss_lngth + loss_angle + loss_dihedral + loss_E 54 | 55 |                 total_loss += loss.item() 56 |         return total_loss / len(loader.dataset) 57 | 58 | 59 | def load_prop_data(args): 60 | 61 |     fs = [] 62 |     for f in args.pretrain.prop_files: 63 |         fs.append(pd.read_csv(f)) 64 |     fs = pd.concat(fs, axis=0) 65 |     fs.reset_index(drop=True, inplace=True) 66 | 67 |     fs = fs.loc[:, ['smiles'] + args.pretrain.props] 68 | 69 |     prop_dict = dict(zip(fs.smiles, fs.iloc[:, 1:].values.tolist())) 70 | 71 |     return prop_dict, fs 72 | 73 | def add_props_to_ds(ds, prop_dict): 74 |     for d in ds: 75 |         smiles = d.smiles 76 |         props = prop_dict[smiles] 77 |         d.y = torch.tensor([props], dtype=torch.float) -------------------------------------------------------------------------------- /fragnet/vizualize/app.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | import pandas as pd 3 | from pathlib import Path 4 | import base64 5 | from fragnet.vizualize.viz import FragNetVizApp 6 | from fragnet.vizualize.model import FragNetPreTrainViz 7 | from streamlit_ketcher import st_ketcher 8 | from fragnet.vizualize.model_attr import get_attr_image 9 | # Initial page config 10 | 11 | st.set_page_config( 12 |     page_title='FragNet Vizualize', 13 |     layout="wide", 14 |     initial_sidebar_state="expanded", 15 | ) 16 | 17 | # Thanks to streamlitopedia for the following code snippet 18 | 19 | def img_to_bytes(img_path): 20 |     img_bytes = Path(img_path).read_bytes() 21 |     encoded = base64.b64encode(img_bytes).decode() 22 |     return encoded 23 | 24 | # sidebar 25 | def input_callback(): 26 |     st.session_state.input = st.session_state.my_input 27 | # def cs_sidebar(): 28 | 29 | def predict_cdrp(smiles, cell_line, cell_line_df): 30 |     gene_expr = cell_line_df.loc[cell_line,:].values 31 |     viz.calc_weights_cdrp(smiles, gene_expr) 32 |     prop_prediction = -1 33 |     return viz, prop_prediction 34 | 35 | 36 | 37 | def resolve_prop_model(prop_type): 38 | 39 |     if prop_type == 'Solubility': 40 |         model_config = './fragnet/exps/ft/pnnl_full/fragnet_hpdl_exp1s_h4pt4_10/config_exp100.yaml' 41 |         chkpt_path = './fragnet/exps/ft/pnnl_full/fragnet_hpdl_exp1s_h4pt4_10/ft_100.pt' 42 |         # model_config = './fragnet/exps/ft/pnnl_set2/fragnet_hpdl_exp1s_h4pt4_10/config_exp100.yaml' 43 |         # chkpt_path = '../fragnet/exps/ft/pnnl_set2/fragnet_hpdl_exp1s_h4pt4_10/ft_100.pt' 44 | 45 |         viz = FragNetVizApp(model_config, chkpt_path) 46 | 47 |         prop_prediction = viz.calc_weights(selected) 48 | 49 | 50 |     elif prop_type == 'Lipophilicity': 51 |         model_config = './fragnet/exps/ft/lipo/fragnet_hpdl_exp1s_pt4_30/config_exp100.yaml' 52 |         chkpt_path = './fragnet/exps/ft/lipo/fragnet_hpdl_exp1s_pt4_30/ft_100.pt' 53 |         viz = FragNetVizApp(model_config, chkpt_path) 54 | 55 |         prop_prediction = viz.calc_weights(selected) 56 | 57 |     elif prop_type == 'Energy': 58 |         model_config = '../fragnet/fragnet/exps/pt/unimol_exp1s4/config.yaml' 59 |         chkpt_path = '../fragnet/fragnet/exps/pt/unimol_exp1s4/pt.pt' 60 |         viz = FragNetVizApp(model_config, chkpt_path, 'energy')
61 | prop_prediction = viz.calc_weights(selected) 62 | 63 | return viz, prop_prediction, model_config, chkpt_path 64 | 65 | def resolve_DRP(smiles, cell_line, cell_line_df): 66 | 67 | model_config = '../fragnet/fragnet/exps/ft/gdsc/fragnet_hpdl_exp1s_pt4_30/config_exp100.yaml' 68 | chkpt_path = '../fragnet/fragnet/exps/ft/gdsc/fragnet_hpdl_exp1s_pt4_30/ft_100.pt' 69 | viz = FragNetVizApp(model_config, chkpt_path,'cdrp') 70 | 71 | # viz, prop_prediction = predict_cdrp(smiles=selected, cell_line=cell_line, cell_line_df=cell_line_df) 72 | gene_expr = cell_line_df.loc[cell_line,:].values 73 | viz.calc_weights_cdrp(smiles, gene_expr) 74 | prop_prediction = -1 75 | 76 | return viz, prop_prediction, model_config, chkpt_path 77 | 78 | 79 | # st.sidebar.markdown('''[](https://streamlit.io/)'''.format(img_to_bytes("logomark_website.png")), unsafe_allow_html=True) 80 | st.sidebar.header('FragNet Vizualize') 81 | 82 | prop_type = st.sidebar.radio( 83 | "Select the Property type", 84 | # ["Solubility", "Lipophilicity", "Energy", "DRP"], 85 | ["Solubility", "Lipophilicity"], 86 | # captions = ["In logS units", "Lipophilicity", "Energy", "Drug Response Prediction"] 87 | captions = ["In logS units", "Lipophilicity"] 88 | ) 89 | 90 | # def input_callback(): 91 | # st.session_state.input = st.session_state.my_input 92 | # selected = st.text_input("Input Your Own SMILES :", key="my_input",on_change=input_callback,args=None) 93 | 94 | selected = st.sidebar.text_input("Input SMILES :", key="my_input",on_change=input_callback,args=None, 95 | value="CC1(C)CC(O)CC(C)(C)N1[O]") 96 | selected = st_ketcher( selected ) 97 | 98 | # if prop_type=="DRP": 99 | 100 | # cell_line = st.sidebar.selectbox( 101 | # 'Select the cell line identifier', 102 | # ['DATA.906826', 103 | # 'DATA.687983', 104 | # 'DATA.910927', 105 | # 'DATA.1240138', 106 | # 'DATA.1240139', 107 | # 'DATA.906792', 108 | # 'DATA.910688', 109 | # 'DATA.1240135', 110 | # 'DATA.1290812', 111 | # 'DATA.907045', 112 | # 'DATA.906861', 113 | # 'DATA.906830', 114 | # 'DATA.909750', 115 | # 'DATA.1240137', 116 | # 'DATA.753552', 117 | # 'DATA.907065', 118 | # 'DATA.925338', 119 | # 'DATA.1290809', 120 | # 'DATA.949158', 121 | # 'DATA.924110']) 122 | # cell_line='DATA.924110' 123 | # cell_line_df = pd.read_csv('../fragnet/fragnet/assets/cell_line_data.csv', index_col=0) 124 | 125 | # st.sidebar.write(f'selected cell line: {cell_line}') 126 | 127 | col1, col2, col3 = st.columns(3) 128 | 129 | if prop_type in ["Solubility", "Lipophilicity", "Energy"]: 130 | viz, prop_prediction, model_config, chkpt_path = resolve_prop_model(prop_type) 131 | # elif prop_type == "DRP": 132 | # viz, prop_prediction, model_config, chkpt_path = resolve_DRP(selected, cell_line, cell_line_df) 133 | 134 | 135 | hide_bond_weights = st.sidebar.checkbox("Hide bond weights") 136 | hide_atom_weights = st.sidebar.checkbox("Hide atom weights") 137 | 138 | with col1: 139 | # root='/Users/pana982/projects/esmi/models/fragnet/fragnet/' 140 | 141 | png, atom_weights = viz.vizualize_atom_weights(hide_bond_weights, hide_atom_weights) 142 | col1.image(png, caption='Atom Weights') 143 | 144 | # png_attr = get_attr_image(selected) 145 | # col1.image(png_attr, caption='Fragment Attributions') 146 | 147 | attn_atoms = pd.DataFrame(atom_weights) 148 | attn_atoms.index.rename('Atom Index', inplace=True) 149 | attn_atoms.columns = ['Atom Weights'] 150 | col1.dataframe(attn_atoms) 151 | 152 | 153 | 154 | 155 | 156 | if prop_type == "Solubility": 157 | st.sidebar.write(f"Predicted Solubility (logS): 
{prop_prediction:.4f}") 158 | if prop_type == "Lipophilicity": 159 | st.sidebar.write(f"Predicted Lipophilicity: {prop_prediction:.4f}") 160 | if prop_type == "Energy": 161 | st.sidebar.write(f"Predicted Energy: {prop_prediction:.4f}") 162 | 163 | hide_bond_weights=False 164 | png_frag_attn, png_frag_highlight, frag_w, connection_w, atoms_in_frags = viz.frag_weight_highlight() 165 | 166 | with col2: 167 | 168 | png_attr = get_attr_image(selected, model_config, chkpt_path) 169 | col2.image(png_attr, caption='Fragment Attributions') 170 | 171 | 172 | col2.image(png_frag_highlight, caption='Fragments') 173 | st.write("Atoms in each Fragment") 174 | df_atoms_in_frags = pd.DataFrame(dict([ (k,pd.Series(v)) for k,v in atoms_in_frags.items() ])).T 175 | df_atoms_in_frags.index.rename('Fragment', inplace=True) 176 | st.dataframe(df_atoms_in_frags) 177 | 178 | with col3: 179 | 180 | 181 | st.image(png_frag_attn, caption='Fragment Weights') 182 | 183 | 184 | st.write('Fragment Weight Values') 185 | st.dataframe(frag_w) 186 | 187 | st.write('Fragment Connection Weight Values') 188 | st.dataframe(connection_w) 189 | 190 | st.sidebar.write(f"Current smiles: {selected}") -------------------------------------------------------------------------------- /fragnet/vizualize/config.py: -------------------------------------------------------------------------------- 1 | # MODEL_CONFIG_PATH = '../fragnet_edge/exps/ft/esol/fragnet_hpdl_exp1s_pt4_30' 2 | # MODEL_CONFIG = f'{MODEL_CONFIG_PATH}/config_exp100.yaml' 3 | # MODEL_PATH = f'{MODEL_CONFIG_PATH}/ft_100.pt' 4 | 5 | 6 | MODEL_CONFIG_PATH = '../fragnet_edge/exps/ft/pnnl_set2/' 7 | MODEL_CONFIG = f'{MODEL_CONFIG_PATH}/exp1s_h4pt4.yaml' 8 | MODEL_PATH = f'{MODEL_CONFIG_PATH}/h4/ft.pt' 9 | 10 | 11 | 12 | -------------------------------------------------------------------------------- /install_cpu.sh: -------------------------------------------------------------------------------- 1 | python3.11 -m venv ~/.env/fragnet 2 | source ~/.env/fragnet/bin/activate 3 | pip install --upgrade pip 4 | pip install -r requirements.txt 5 | pip install torch-scatter -f https://data.pyg.org/whl/torch-2.4.0+cpu.html 6 | pip install . 7 | 8 | mkdir -p fragnet/finetune_data/moleculenet/esol/raw/ 9 | 10 | wget -O fragnet/finetune_data/moleculenet/esol/raw/delaney-processed.csv https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/delaney-processed.csv 11 | 12 | python fragnet/data_create/create_pretrain_datasets.py --save_path fragnet/pretrain_data/esol --data_type exp1s --maxiters 500 --raw_data_path fragnet/finetune_data/moleculenet/esol/raw/delaney-processed.csv 13 | 14 | python fragnet/data_create/create_finetune_datasets.py --dataset_name moleculenet --dataset_subset esol --use_molebert True --output_dir fragnet/finetune_data/moleculenet_exp1s --data_dir fragnet/finetune_data/moleculenet --data_type exp1s 15 | -------------------------------------------------------------------------------- /install_gpu.sh: -------------------------------------------------------------------------------- 1 | python -m venv ~/.env/fragnet 2 | source ~/.env/fragnet/bin/activate 3 | pip install --upgrade pip 4 | pip install -r requirements.txt 5 | pip install torch-scatter -f https://data.pyg.org/whl/torch-2.4.0+cu121.html 6 | pip install . 
7 | 8 | mkdir -p fragnet/finetune_data/moleculenet/esol/raw/ 9 | 10 | wget -O fragnet/finetune_data/moleculenet/esol/raw/delaney-processed.csv https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/delaney-processed.csv 11 | 12 | python fragnet/data_create/create_pretrain_datasets.py --save_path fragnet/pretrain_data/esol --data_type exp1s --maxiters 500 --raw_data_path fragnet/finetune_data/moleculenet/esol/raw/delaney-processed.csv 13 | 14 | python fragnet/data_create/create_finetune_datasets.py --dataset_name moleculenet --dataset_subset esol --use_molebert True --output_dir fragnet/finetune_data/moleculenet_exp1s --data_dir fragnet/finetune_data/moleculenet --data_type exp1s 15 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | lmdb==1.4.1 2 | ipython 3 | matplotlib==3.9.2 4 | networkx==2.8.8 5 | numpy==1.24.1 6 | ogb==1.2.0 7 | omegaconf==2.3.0 8 | optuna==3.5.0 9 | pandas==2.2.3 10 | parmap==1.7.0 11 | Pillow==10.4.0 12 | PyYAML==6.0.2 13 | rdkit==2023.9.6 14 | scikit_learn==1.3.2 15 | scipy==1.14.1 16 | streamlit==1.38.0 17 | streamlit_ketcher==0.0.1 18 | torch==2.4.0 19 | pytorch_lightning 20 | torch_geometric==2.6.1 21 | tqdm==4.66.1 22 | tensorboard 23 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | 3 | setup( 4 |     # Application name: 5 |     name="fragnet", 6 | 7 |     # Version number (initial): 8 |     version="0.1.0", 9 | 10 |     # Application author details: 11 |     author="Gihan Panapitiya", 12 |     author_email="gihan.panapitiya@pnnl.gov", 13 | 14 |     # Packages 15 |     #packages=["mpet","mpet.utils", "mpet.models", "mpet.doa"], 16 | 17 |     # Include additional files into the package 18 |     include_package_data=True, 19 | 20 |     # Details 21 |     url="", 22 | 23 |     # 24 |     # license="LICENSE.txt", 25 |     description="FragNet: A GNN with Four Layers of Interpretability", 26 | 27 |     # long_description=open("README.txt").read(), 28 | 29 |     # Dependent packages (distributions) 30 |     # install_requires=["openbabel"], 31 | ) 32 | 33 | --------------------------------------------------------------------------------