├── README.md ├── configs └── property_info.json ├── data ├── __init__.py ├── collate_fn.py ├── examples │ └── zinc_example.txt ├── tokenizer.py └── zinc250k │ ├── all_250k_rndm_zinc_drugs_clean_3.txt │ └── valid_250k_rndm_zinc_drugs_clean_3.txt ├── data_finetune └── README.md ├── entrypoints ├── __init__.py ├── generation │ ├── __init__.py │ ├── conditional │ │ ├── __init__.py │ │ ├── generate.py │ │ └── visualize.py │ ├── evaluation │ │ ├── __init__.py │ │ ├── evaluate_moses_few_shot_sampling.py │ │ └── evaluate_zinc250k_few_shot_sampling.py │ └── unconditional │ │ ├── __init__.py │ │ ├── generate.py │ │ └── visualize.py └── representation │ ├── __init__.py │ └── finetune.py ├── graphsgpt.svg ├── jupyter_notebooks ├── analysis │ ├── README-Clustering.md │ ├── clustering.ipynb │ ├── hybridization.ipynb │ └── interpolation.ipynb └── example_pipeline.ipynb ├── models ├── __init__.py ├── graphsgpt │ ├── __init__.py │ ├── configuration_graphsgpt.py │ ├── generation_utils.py │ ├── modeling_graphsgpt.py │ └── orf.py └── graphsgpt_cond_gen │ ├── __init__.py │ ├── configuration_graphsgpt_cond_gen.py │ └── modelling_graphsgpt_cond_gen.py ├── moses ├── NP_Score │ ├── README │ ├── __init__.py │ ├── npscorer.py │ └── publicnp.model.gz ├── README.md ├── SA_Score │ ├── README │ ├── __init__.py │ ├── fpscores.pkl.gz │ └── sascorer.py ├── __init__.py ├── data │ ├── test.csv.gz │ ├── test_scaffolds.csv.gz │ ├── test_scaffolds_stats.npz │ ├── test_stats.npz │ └── train.csv.gz ├── dataset.py ├── mcf.csv ├── metric_utils.py ├── metrics.py ├── utils.py └── wehi_pains.csv ├── requirements.txt ├── scripts ├── generation │ ├── conditional │ │ ├── README-Generation-Cond.md │ │ ├── examples │ │ │ ├── generate_scaffold_logp.sh │ │ │ ├── generate_scaffold_qed.sh │ │ │ └── generate_scaffold_sa.sh │ │ └── visualize.sh │ ├── evaluation │ │ ├── moses.sh │ │ └── zinc250k.sh │ └── unconditional │ │ ├── README-Generation-Uncond.md │ │ ├── examples │ │ ├── generate_flexible.sh │ │ └── generate_strict.sh │ │ └── visualize.sh └── representation │ └── finetune.sh ├── setup.py └── utils ├── __init__.py ├── accuracy.py ├── algorithms.py ├── io.py ├── molecule.py ├── operations ├── __init__.py ├── operation_dataframe.py ├── operation_dict.py ├── operation_list.py ├── operation_number.py ├── operation_string.py └── operation_tensor.py ├── property_scores ├── __init__.py ├── fpscores.pkl.gz ├── nspdk.py ├── sascorer.py ├── scoring_func.py └── tanimoto_similarity.py └── representation ├── __init__.py ├── data_interface.py ├── dict.txt ├── dictionary.py ├── fingerprint_interface.py ├── graphsgpt_finetune_model.py ├── lmdb_dataset.py ├── load_dataset.py ├── logger.py └── model_interface.py /README.md: -------------------------------------------------------------------------------- 1 | # [GraphsGPT] A Graph is Worth $K$ Words:
Euclideanizing Graph using Pure Transformer (ICML2024)
2 | 
3 | **[Zhangyang Gao](https://scholar.google.com/citations?user=4SclT-QAAAAJ)\*, [Daize Dong](https://daizedong.github.io/)\*, [Cheng Tan](https://chengtan9907.github.io/), [Jun Xia](https://junxia97.github.io/), [Bozhen Hu](https://scholar.google.com/citations?user=6FZh9C8AAAAJ), [Stan Z. Li](https://scholar.google.com/citations?user=Y-nyLGIAAAAJ)**
4 | 
5 | Published at *The 41st International Conference on Machine Learning (ICML 2024)*.
6 | 
7 | [![arXiv](https://img.shields.io/badge/arXiv-2402.02464-b31b1b.svg?style=plastic)](https://arxiv.org/abs/2402.02464)
8 | 
9 | ## Introduction
10 | 
11 | Can we model Non-Euclidean graphs as pure language, or even as Euclidean vectors, while retaining their inherent information? The Non-Euclidean property has posed a long-term challenge in graph modeling. Despite recent efforts of graph neural networks and graph transformers to encode graphs as Euclidean vectors, recovering the original graph from these vectors remains a challenge.
12 | In this paper, we introduce GraphsGPT, featuring a Graph2Seq encoder that transforms Non-Euclidean graphs into learnable GraphWords in the Euclidean space, along with a GraphGPT decoder that reconstructs the original graph from GraphWords to ensure information equivalence. We pretrain GraphsGPT on 100M molecules and obtain some interesting findings:
13 | 
14 | - The pretrained Graph2Seq excels in graph representation learning, achieving state-of-the-art results on $8/9$ graph classification and regression tasks.
15 | - The pretrained GraphGPT serves as a strong graph generator, demonstrated by its ability to perform both few-shot and conditional graph generation.
16 | - Graph2Seq+GraphGPT enables effective graph mixup in the Euclidean space, overcoming previously known Non-Euclidean challenges.
17 | - The edge-centric pretraining framework GraphsGPT demonstrates its efficacy in graph domain tasks, excelling in both representation and generation.
18 | 
19 | ![graphsgpt.svg](graphsgpt.svg)
20 | 
21 | 
22 | 
23 | ## Installation
24 | 
25 | To get started with GraphsGPT, please run the following commands to set up the environment.
26 | 
27 | ```bash
28 | git clone git@github.com:A4Bio/GraphsGPT.git --depth=1
29 | cd GraphsGPT
30 | conda create --name graphsgpt python=3.12
31 | conda activate graphsgpt
32 | pip install -e .[dev]
33 | pip install -r requirements.txt
34 | ```
35 | 
36 | 
37 | 
38 | ## Quickstart
39 | 
40 | We provide some Jupyter Notebooks in `./jupyter_notebooks`, along with their corresponding online Google Colaboratory Notebooks. You can run them for a quick start.
41 | 
42 | |                        | Jupyter Notebook | Google Colaboratory |
43 | | :--------------------: | :--------------: | :-----------------: |
44 | | *GraphsGPT Pipeline*   | [example_pipeline.ipynb](jupyter_notebooks%2Fexample_pipeline.ipynb) | Open In Colab |
45 | | Clustering Analysis    | [clustering.ipynb](jupyter_notebooks%2Fanalysis%2Fclustering.ipynb) | Open In Colab |
46 | | Hybridization Analysis | [hybridization.ipynb](jupyter_notebooks%2Fanalysis%2Fhybridization.ipynb) | Open In Colab |
47 | | Interpolation Analysis | [interpolation.ipynb](jupyter_notebooks%2Fanalysis%2Finterpolation.ipynb) | Open In Colab |
48 | 
49 | 
50 | 
51 | ## Checkpoints
52 | 
53 | The model [checkpoints](https://huggingface.co/collections/DaizeDong/graphsgpt-65efe70c326a1a5bd35c2fcc) can be downloaded from 🤗 Transformers.
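They are loaded through the standard 🤗 `from_pretrained` interface. The sketch below mirrors the pattern in `jupyter_notebooks/example_pipeline.ipynb`; the checkpoint name can be any of the models in the table below, and the example SMILES (plus the assumption that `batch_encode` accepts it) is purely illustrative:

```python
import torch

from data.tokenizer import GraphsGPTTokenizer
from models.graphsgpt.modeling_graphsgpt import GraphsGPTForCausalLM

model_name_or_path = "DaizeDong/GraphsGPT-1W"  # any checkpoint listed below
tokenizer = GraphsGPTTokenizer.from_pretrained(model_name_or_path)
model = GraphsGPTForCausalLM.from_pretrained(model_name_or_path)
model.eval()

# Encode an illustrative SMILES string (aspirin) into Graph Words, i.e. fingerprint embeddings.
inputs = tokenizer.batch_encode(["CC(=O)OC1=CC=CC=C1C(=O)O"], return_tensors="pt")
with torch.no_grad():
    fingerprint_tokens = model.encode_to_fingerprints(**inputs)  # (batch_size, num_fingerprints, hidden_dim)
print(fingerprint_tokens.shape)
```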
We provide both the foundational pretrained models with different number of Graph Words $\mathcal{W}$ (GraphsGPT-nW), and the conditional version with one Graph Word (GraphsGPT-1W-C). 54 | 55 | | Model Name | Model Type | Model Checkpoint | 56 | | :------------: | :--------------: | :----------------------------------------------------------: | 57 | | GraphsGPT-1W | Foundation Model | | 58 | | GraphsGPT-2W | Foundation Model | | 59 | | GraphsGPT-4W | Foundation Model | | 60 | | GraphsGPT-8W | Foundation Model | | 61 | | GraphsGPT-1W-C | Finetuned Model | | 62 | 63 | 64 | 65 | ## Representation Experiments 66 | 67 | You should first [download](https://github.com/A4Bio/GraphsGPT/releases/tag/data) the configurations and data for finetuning, and put them in `./data_finetune`. (We also include the finetuned checkpoints in the `model_zoom.zip` file for a quick test.) 68 | 69 | To evaluate the representation performance of the Graph2Seq Encoder, please run: 70 | 71 | ```bash 72 | bash ./scripts/representation/finetune.sh 73 | ``` 74 | 75 | You can also toggle the `--mixup_strategy` for graph mixup using Graph2Seq. 76 | 77 | 78 | 79 | ## Generation Experiments 80 | 81 | For the unconditional generation with GraphGPT Decoder, please refer to [README-Generation-Uncond.md](scripts%2Fgeneration%2Funconditional%2FREADME-Generation-Uncond.md). 82 | 83 | For the conditional generation with GraphGPT-C Decoder, please refer to [README-Generation-Cond.md](scripts%2Fgeneration%2Fconditional%2FREADME-Generation-Cond.md). 84 | 85 | To evaluate the few-shots generation performance of GraphGPT Decoder, please run: 86 | 87 | ```bash 88 | bash ./scripts/generation/evaluation/moses.sh 89 | bash ./scripts/generation/evaluation/zinc250k.sh 90 | ``` 91 | 92 | 93 | 94 | ## Citation 95 | 96 | ```latex 97 | @article{gao2024graph, 98 | title={A Graph is Worth $K$ Words: Euclideanizing Graph using Pure Transformer}, 99 | author={Gao, Zhangyang and Dong, Daize and Tan, Cheng and Xia, Jun and Hu, Bozhen and Li, Stan Z}, 100 | journal={arXiv preprint arXiv:2402.02464}, 101 | year={2024} 102 | } 103 | ``` 104 | 105 | ## Contact Us 106 | If you have any questions, please contact: 107 | 108 | - Zhangyang Gao: gaozhangyang@westlake.edu.cn 109 | 110 | - Daize Dong: dzdong2019@gmail.com 111 | -------------------------------------------------------------------------------- /configs/property_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "qed": { 3 | "mean": 0.7195562004332452, 4 | "std": 0.13004599771146944 5 | }, 6 | "sa": { 7 | "mean": 0.7612138562611798, 8 | "std": 0.06663695435624298 9 | }, 10 | "logp": { 11 | "mean": 2.570799410827321, 12 | "std": 1.3405790580975692 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/A4Bio/GraphsGPT/e71647c5a9e19322b4166bb7922c604778a71e1e/data/__init__.py -------------------------------------------------------------------------------- /data/collate_fn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.utils.rnn as rnn_utils 3 | from typing import List 4 | 5 | 6 | def identity_collator(examples): 7 | return examples 8 | 9 | 10 | def tensor_stack_collator(examples): 11 | """ 12 | examples: list of tensors. 
13 | input: [tensor1, tensor2, ..., tensorN] 14 | output: stacked_tensor 15 | """ 16 | return torch.stack(examples, dim=0) 17 | 18 | 19 | class tensor_stack_padding_collater: 20 | """ 21 | examples: list of tensors. 22 | input: [tensor1, tensor2, ..., tensorN] 23 | output: padded_tensor 24 | """ 25 | 26 | def __init__(self, padding_id, padding_position="right", return_padding_mask=True): 27 | assert padding_position in ("left", "right") 28 | self.padding_id = padding_id 29 | self.padding_position = padding_position 30 | self.return_padding_mask = return_padding_mask 31 | 32 | def __call__(self, examples): 33 | dtype = examples[0].dtype 34 | if self.padding_position == "right": 35 | padded_examples = rnn_utils.pad_sequence(examples, batch_first=True, padding_value=self.padding_id) 36 | elif self.padding_position == "left": # This will take about twice the time compared to right padding 37 | flipped_examples = [torch.flip(tensor, dims=[0]) for tensor in examples] 38 | padded_examples_flip = rnn_utils.pad_sequence(flipped_examples, batch_first=True, padding_value=self.padding_id) 39 | padded_examples = torch.flip(padded_examples_flip, dims=[1]) 40 | else: 41 | raise NotImplementedError 42 | padded_examples = padded_examples.to(dtype) 43 | 44 | if self.return_padding_mask: 45 | padding_mask = (padded_examples != self.padding_id) 46 | return padded_examples, padding_mask 47 | else: 48 | return padded_examples 49 | 50 | 51 | def tensor_lists_stack_collator(examples): 52 | """ 53 | examples: list of tensor lists. 54 | input: 55 | [ 56 | [tensor1, tensor1, ..., tensor1], 57 | [tensor2, tensor2, ..., tensor2], 58 | ... 59 | [tensorN, tensorN, ..., tensorN], 60 | ] 61 | output: 62 | [ 63 | stacked_tensor1, 64 | stacked_tensor2, 65 | ... 66 | stacked_tensorN, 67 | ] 68 | """ 69 | return [torch.stack([tensor_list[i] for tensor_list in examples], dim=0) for i in range(len(examples[0]))] 70 | 71 | 72 | class tensor_lists_stack_padding_collater: 73 | def __init__(self, padding_id, padding_position="right", return_padding_mask=True, tensor_ids_to_create_mask: List = None): 74 | assert padding_position in ("left", "right") 75 | self.padding_id = padding_id 76 | self.padding_position = padding_position 77 | self.return_padding_mask = return_padding_mask 78 | 79 | # set indices of tensors in list to create "padding_mask" 80 | # if set to "None", then "padding_mask" will be created for all keys in dict 81 | self.tensor_ids_to_create_mask = tensor_ids_to_create_mask 82 | 83 | def __call__(self, examples): 84 | """ 85 | examples: list of tensor lists. 86 | input: 87 | [ 88 | [tensor1, tensor1, ..., tensor1], 89 | [tensor2, tensor2, ..., tensor2], 90 | ... 91 | [tensorN, tensorN, ..., tensorN], 92 | ] 93 | output: 94 | [ 95 | padded_tensor1, 96 | padded_tensor2, 97 | ... 
98 | padded_tensorN, 99 | ] 100 | """ 101 | num_tensors = len(examples[0]) 102 | padded_tensors = [] 103 | padding_masks = [] 104 | 105 | for i in range(num_tensors): 106 | dtype = examples[0][0].dtype 107 | if self.padding_position == "right": 108 | tensors = [tensor_list[i] for tensor_list in examples] 109 | padded_tensor = rnn_utils.pad_sequence(tensors, batch_first=True, padding_value=self.padding_id) 110 | elif self.padding_position == "left": # This will take about twice the time compared to right padding 111 | flipped_tensors = [torch.flip(tensor_list[i], dims=[0]) for tensor_list in examples] 112 | flipped_padded_tensors = rnn_utils.pad_sequence(flipped_tensors, batch_first=True, padding_value=self.padding_id) 113 | padded_tensor = torch.flip(flipped_padded_tensors, dims=[1]) 114 | else: 115 | raise NotImplementedError 116 | padded_tensor = padded_tensor.to(dtype) 117 | 118 | padded_tensors.append(padded_tensor) 119 | 120 | if self.return_padding_mask: 121 | if self.tensor_ids_to_create_mask is None or i in self.tensor_ids_to_create_mask: 122 | padding_masks.append(padded_tensors[i] != self.padding_id) 123 | else: 124 | padding_masks.append(None) 125 | 126 | if self.return_padding_mask: 127 | return padded_tensors, padding_masks 128 | else: 129 | return padded_tensors 130 | 131 | 132 | def tensor_dicts_stack_collator(examples): 133 | """ 134 | examples: list of tensor dicts. 135 | input: 136 | [ 137 | { 138 | "key1": tensor1, 139 | "key2": tensor2, 140 | ... 141 | "keyN": tensorN, 142 | }, 143 | { 144 | "key1": tensor1, 145 | "key2": tensor2, 146 | ... 147 | "keyN": tensorN, 148 | }, 149 | ... 150 | { 151 | "key1": tensor1, 152 | "key2": tensor2, 153 | ... 154 | "keyN": tensorN, 155 | } 156 | ] 157 | output: 158 | { 159 | "key1": stacked_tensor1, 160 | "key2": stacked_tensor2, 161 | ... 162 | "keyN": stacked_tensorN, 163 | } 164 | """ 165 | return {key: torch.stack([tensor_dict[key] for tensor_dict in examples], dim=0) for key in examples[0].keys()} 166 | 167 | 168 | class tensor_dict_stack_padding_collater: 169 | def __init__(self, padding_id, padding_position="right", return_padding_mask=True, tensor_keys_to_create_mask: List = None): 170 | assert padding_position in ("left", "right") 171 | self.padding_id = padding_id 172 | self.padding_position = padding_position 173 | self.return_padding_mask = return_padding_mask 174 | 175 | # set keys of tensors in dict to create "padding_mask" 176 | # if set to "None", then "padding_mask" will be created for all keys in dict 177 | self.tensor_keys_to_create_mask = tensor_keys_to_create_mask 178 | 179 | def __call__(self, examples): 180 | """ 181 | examples: list of tensor (or other types) dicts. 182 | input: 183 | [ 184 | { 185 | "key0": int, 186 | "key1": tensor1, 187 | "key2": tensor2, 188 | ... 189 | "keyN": tensorN, 190 | }, 191 | { 192 | "key0": int, 193 | "key1": tensor1, 194 | "key2": tensor2, 195 | ... 196 | "keyN": tensorN, 197 | }, 198 | ... 199 | { 200 | "key0": int, 201 | "key1": tensor1, 202 | "key2": tensor2, 203 | ... 204 | "keyN": tensorN, 205 | } 206 | ] 207 | output: 208 | { 209 | "key0": [int, int, ..., int], 210 | "key1": padded_tensor1, 211 | "key2": padded_tensor2, 212 | ... 
213 | "keyN": padded_tensorN, 214 | } 215 | """ 216 | keys = examples[0].keys() 217 | padded_tensors = {} 218 | padding_masks = {} 219 | 220 | for key in keys: 221 | if isinstance(examples[0][key], torch.Tensor): 222 | if self.padding_position == "right": 223 | tensors = [tensor_dict[key] for tensor_dict in examples] 224 | padded_tensor = rnn_utils.pad_sequence(tensors, batch_first=True, padding_value=self.padding_id) 225 | elif self.padding_position == "left": # This will take about twice the time compared to right padding 226 | flipped_tensors = [torch.flip(tensor_dict[key], dims=[0]) for tensor_dict in examples] 227 | flipped_padded_tensors = rnn_utils.pad_sequence(flipped_tensors, batch_first=True, padding_value=self.padding_id) 228 | padded_tensor = torch.flip(flipped_padded_tensors, dims=[1]) 229 | else: 230 | raise NotImplementedError 231 | else: # not tensor type, return as a list 232 | padded_tensor = [tensor_dict[key] for tensor_dict in examples] 233 | 234 | padded_tensors[key] = padded_tensor 235 | 236 | if self.return_padding_mask and isinstance(examples[0][key], torch.Tensor): 237 | if self.tensor_keys_to_create_mask is None or key in self.tensor_keys_to_create_mask: 238 | padding_masks[key] = (padded_tensors[key] != self.padding_id) 239 | else: 240 | padding_masks[key] = None 241 | 242 | if self.return_padding_mask: 243 | return padded_tensors, padding_masks 244 | else: 245 | return padded_tensors 246 | -------------------------------------------------------------------------------- /data_finetune/README.md: -------------------------------------------------------------------------------- 1 | Please [download](https://github.com/A4Bio/GraphsGPT/releases/tag/data) the configurations and data for finetuning and put them in this folder. 2 | 3 | Finally, the structure of this folder should be like: 4 | 5 | ```` 6 | -- data_finetune 7 | -- model_zoom (checkpoints and finetuning configs) 8 | -- bace 9 | -- bbbp 10 | -- hiv 11 | ... 12 | -- molecular_property_prediction (finetuning data) 13 | -- bace 14 | -- bbbp 15 | -- esol 16 | ... 
17 | ```` -------------------------------------------------------------------------------- /entrypoints/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/A4Bio/GraphsGPT/e71647c5a9e19322b4166bb7922c604778a71e1e/entrypoints/__init__.py -------------------------------------------------------------------------------- /entrypoints/generation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/A4Bio/GraphsGPT/e71647c5a9e19322b4166bb7922c604778a71e1e/entrypoints/generation/__init__.py -------------------------------------------------------------------------------- /entrypoints/generation/conditional/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/A4Bio/GraphsGPT/e71647c5a9e19322b4166bb7922c604778a71e1e/entrypoints/generation/conditional/__init__.py -------------------------------------------------------------------------------- /entrypoints/generation/conditional/visualize.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from argparse import ArgumentParser 4 | from rdkit import RDLogger 5 | from tqdm import tqdm 6 | 7 | from data.tokenizer import GraphsGPTTokenizer 8 | from utils.io import delete_file_or_dir, create_dir, save_mol_png, save_empty_png, summary_property_from_file, summary_property_from_all_files 9 | from utils.molecule import get_molecule_standard_scaffold 10 | from utils.operations.operation_string import str2bool 11 | from utils.property_scores.scoring_func import get_qed, get_sa, get_logp, get_is_valid 12 | 13 | RDLogger.DisableLog('rdApp.*') 14 | 15 | 16 | def main(args): 17 | tokenizer = GraphsGPTTokenizer.from_pretrained(args.model_name_or_path) 18 | 19 | file_list = os.listdir(args.generation_results_dir) 20 | file_list = sorted(file_list) 21 | file_list = file_list[args.file_begin_index:args.file_end_index] 22 | 23 | all_smiles = [] 24 | 25 | delete_file_or_dir(args.save_dir) 26 | 27 | """visualize""" 28 | for file_name in tqdm(file_list): 29 | success_cnt = 0 30 | invalid_cnt = 0 31 | fail_cnt = 0 32 | 33 | valid_list = [] 34 | qed_list = [] 35 | sa_list = [] 36 | logp_list = [] 37 | smiles_list = [] 38 | scaffold_smiles_list = [] 39 | 40 | save_dir = os.path.join(args.save_dir, f"{file_name.split('.')[0]}") 41 | create_dir(save_dir) 42 | 43 | """read file""" 44 | file_path = os.path.join(args.generation_results_dir, file_name) 45 | with open(file_path, 'rb') as f: 46 | result_list = pickle.load(f) 47 | if not isinstance(result_list, (list, tuple)): 48 | result_list = [result_list] 49 | 50 | """save png""" 51 | for i, result in enumerate(result_list): 52 | print(i) 53 | save_img_file = os.path.join(save_dir, f"{i}.png") 54 | 55 | if result is not None: 56 | mol, smiles = tokenizer.decode(result, kekulize=True) # kekulize the molecule for score calculation 57 | 58 | if mol is None: 59 | valid_list.append(None) 60 | qed_list.append(None) 61 | sa_list.append(None) 62 | logp_list.append(None) 63 | smiles_list.append(None) 64 | scaffold_smiles_list.append(None) 65 | if args.save_images: 66 | save_empty_png(save_img_file) 67 | invalid_cnt += 1 68 | else: 69 | valid = get_is_valid(mol) 70 | qed = get_qed(mol) 71 | sa = get_sa(mol) 72 | logp = get_logp(mol) 73 | 74 | scaffold = get_molecule_standard_scaffold(mol, normalizer=tokenizer.normalizer) 75 | scaffold_smiles = 
tokenizer._convert_molecule_to_standard_smiles(scaffold) 76 | 77 | valid_list.append(valid) 78 | qed_list.append(qed) 79 | sa_list.append(sa) 80 | logp_list.append(logp) 81 | smiles_list.append(smiles) 82 | scaffold_smiles_list.append(scaffold_smiles) 83 | 84 | if args.save_images: 85 | save_mol_png(mol, save_img_file) 86 | success_cnt += 1 87 | else: 88 | valid_list.append(None) 89 | qed_list.append(None) 90 | sa_list.append(None) 91 | logp_list.append(None) 92 | smiles_list.append(None) 93 | scaffold_smiles_list.append(None) 94 | if args.save_images: 95 | save_empty_png(save_img_file) 96 | fail_cnt += 1 97 | 98 | all_smiles.extend(smiles_list) 99 | 100 | """save statistics""" 101 | with open(os.path.join(save_dir, "count.txt"), 'a') as f: 102 | f.write(f"Success count: {success_cnt}\n") 103 | f.write(f"Invalid count: {invalid_cnt}\n") 104 | f.write(f"Fail count: {fail_cnt}\n") 105 | 106 | with open(os.path.join(save_dir, "valid.txt"), 'a') as f: 107 | for valid in valid_list: 108 | f.write(f"{valid}\n") 109 | 110 | with open(os.path.join(save_dir, "qed.txt"), 'a') as f: 111 | for qed in qed_list: 112 | f.write(f"{qed}\n") 113 | 114 | with open(os.path.join(save_dir, "sa.txt"), 'a') as f: 115 | for sa in sa_list: 116 | f.write(f"{sa}\n") 117 | 118 | with open(os.path.join(save_dir, "logp.txt"), 'a') as f: 119 | for logp in logp_list: 120 | f.write(f"{logp}\n") 121 | 122 | with open(os.path.join(save_dir, "smiles.txt"), 'a') as f: 123 | for smiles in smiles_list: 124 | f.write(f"{smiles}\n") 125 | 126 | with open(os.path.join(save_dir, "scaffold_smiles.txt"), 'a') as f: 127 | for scaffold_smiles in scaffold_smiles_list: 128 | f.write(f"{scaffold_smiles}\n") 129 | 130 | mean_qed, std_qed, num_qed = summary_property_from_file(os.path.join(save_dir, "qed.txt")) 131 | mean_sa, std_sa, num_sa = summary_property_from_file(os.path.join(save_dir, "sa.txt")) 132 | mean_logp, std_logp, num_logp = summary_property_from_file(os.path.join(save_dir, "logp.txt")) 133 | 134 | valid_num = sum([1 for valid in valid_list if valid]) 135 | total_num = len([smiles for smiles in smiles_list if smiles is not None]) 136 | 137 | with open(os.path.join(save_dir, "summary.txt"), 'w') as f: 138 | f.write(f"Summary QED: mean={format(mean_qed, '.3f')}, std={format(std_qed, '.3f')}, total_cnt={num_qed}\n") 139 | f.write(f"Summary SA: mean={format(mean_sa, '.3f')}, std={format(std_sa, '.3f')}, total_cnt={num_sa}\n") 140 | f.write(f"Summary logP: mean={format(mean_logp, '.3f')}, std={format(std_logp, '.3f')}, total_cnt={num_logp}\n") 141 | f.write(f"Summary validity: {format(valid_num / total_num * 100, '.2f')}% ({valid_num}/{total_num})\n") 142 | 143 | """summarize the results""" 144 | mean_qed, std_qed, num_qed = summary_property_from_all_files(args.save_dir, "qed.txt", value_type="float") 145 | mean_sa, std_sa, num_sa = summary_property_from_all_files(args.save_dir, "sa.txt", value_type="float") 146 | mean_logp, std_logp, num_logp = summary_property_from_all_files(args.save_dir, "logp.txt", value_type="float") 147 | mean_validity, std_validity, num_validity = summary_property_from_all_files(args.save_dir, "valid.txt", value_type="bool") 148 | 149 | all_total_num = len([smiles for smiles in all_smiles if smiles is not None]) 150 | 151 | with open(os.path.join(args.save_dir, "final_summary.txt"), 'w') as f: 152 | f.write(f"Summary QED: mean={format(mean_qed, '.3f')}, std={format(std_qed, '.3f')}, total_cnt={num_qed}\n") 153 | f.write(f"Summary SA: mean={format(mean_sa, '.3f')}, std={format(std_sa, '.3f')}, 
total_cnt={num_sa}\n") 154 | f.write(f"Summary logP: mean={format(mean_logp, '.3f')}, std={format(std_logp, '.3f')}, total_cnt={num_logp}\n") 155 | f.write(f"Summary validity: {format(mean_validity * 100, '.2f')}% ({round(mean_validity * all_total_num)}/{all_total_num})\n") 156 | 157 | 158 | if __name__ == '__main__': 159 | parser = ArgumentParser() 160 | parser.add_argument("--model_name_or_path", default="DaizeDong/GraphsGPT-1W-C", type=str) 161 | parser.add_argument('--generation_results_dir', default=None, type=str) 162 | parser.add_argument('--save_dir', default=None, type=str) 163 | parser.add_argument('--save_images', default="True", type=str) 164 | 165 | parser.add_argument('--file_begin_index', default=0, type=int) 166 | parser.add_argument('--file_end_index', default=2, type=int) 167 | args = parser.parse_args() 168 | args.save_images = str2bool(args.save_images) 169 | print(args) 170 | main(args) 171 | -------------------------------------------------------------------------------- /entrypoints/generation/evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/A4Bio/GraphsGPT/e71647c5a9e19322b4166bb7922c604778a71e1e/entrypoints/generation/evaluation/__init__.py -------------------------------------------------------------------------------- /entrypoints/generation/evaluation/evaluate_moses_few_shot_sampling.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | 3 | import argparse 4 | import random 5 | import torch 6 | from tqdm import tqdm 7 | 8 | import moses 9 | from data.collate_fn import tensor_dict_stack_padding_collater 10 | from data.tokenizer import GraphsGPTTokenizer 11 | from models.graphsgpt.modeling_graphsgpt import GraphsGPTForCausalLM 12 | from utils.io import create_dir, save_json, delete_file_or_dir 13 | from utils.operations.operation_list import split_list 14 | from utils.operations.operation_tensor import move_tensors_to_device 15 | 16 | 17 | def main(args): 18 | random.seed(0) 19 | torch.manual_seed(0) 20 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 21 | 22 | """load data & model""" 23 | train_smiles = moses.get_dataset("train") 24 | # test_smiles = moses.get_dataset("test") 25 | # test_scaffolds = moses.get_dataset("test_scaffolds") 26 | 27 | pad_collator = tensor_dict_stack_padding_collater(0, tensor_keys_to_create_mask=["input_ids"]) 28 | 29 | tokenizer = GraphsGPTTokenizer.from_pretrained(args.model_name_or_path) 30 | model = GraphsGPTForCausalLM.from_pretrained(args.model_name_or_path) 31 | model.to(device) 32 | 33 | """generate according to reference fingerprints""" 34 | all_smiles = [] 35 | 36 | # initialize the variable for batch-optimization of few-shot generation 37 | max_shots_in_a_batch = args.batch_size_valid // args.num_samples_each_shot 38 | next_shot_id = 0 39 | finished_shots = set() # shots that have done the generation 40 | candidate_shots = {} # shots for the next generation forward 41 | 42 | with torch.inference_mode(): 43 | while len(finished_shots) < args.num_shots: 44 | """prepare initial fingerprint tokens""" 45 | extra_inputs_num = max_shots_in_a_batch - len(candidate_shots) 46 | 47 | if extra_inputs_num > 0: # need some new inputs to fill up the batch 48 | # construct inputs 49 | inputs = [] 50 | for i in range(next_shot_id, next_shot_id + extra_inputs_num): 51 | this_shot_inputs = tokenizer.encode(train_smiles[i], return_tensors="pt") 52 | if this_shot_inputs is None: # 
the encoding may fail 53 | extra_inputs_num -= 1 54 | continue 55 | move_tensors_to_device(this_shot_inputs, device) 56 | inputs.append(this_shot_inputs) 57 | inputs, mask = pad_collator(inputs) 58 | inputs["attention_mask"] = mask["input_ids"] 59 | 60 | # repeat fingerprint tokens 61 | fingerprint_tokens = model.encode_to_fingerprints(**inputs) 62 | fingerprint_tokens = fingerprint_tokens.repeat_interleave(args.num_samples_each_shot, dim=0) # (extra_inputs_num * num_samples_each_shot, num_fingerprints, hidden_size) 63 | each_shot_fingerprint_tokens = fingerprint_tokens.split(args.num_samples_each_shot, dim=0) # each is (num_samples_each_shot, num_fingerprints, hidden_size) 64 | 65 | # add to shots 66 | for shot_id, this_shot_fingerprint_tokens in enumerate(each_shot_fingerprint_tokens, start=next_shot_id): 67 | candidate_shots[shot_id] = { 68 | "sample_times": 0, 69 | "fingerprint": this_shot_fingerprint_tokens, 70 | "generation_results": [], 71 | } 72 | print(f"Total {len(candidate_shots)} candidate shots in this iteration!") 73 | 74 | # update next_shot_id 75 | next_shot_id += extra_inputs_num 76 | print(f"Encoded {extra_inputs_num} new fingerprints!") 77 | 78 | """aggregate fingerprint tokens for candidate shots""" 79 | shot_ids_in_batch = [] 80 | aggregated_fingerprint_tokens = [] 81 | for shot_id, candidate_shot in candidate_shots.items(): 82 | shot_ids_in_batch.append(shot_id) 83 | aggregated_fingerprint_tokens.append(candidate_shot["fingerprint"]) 84 | aggregated_fingerprint_tokens = torch.cat(aggregated_fingerprint_tokens, dim=0) # (max_shots_in_a_batch * num_samples_each_shot, num_fingerprints, hidden_size) 85 | 86 | """random sampling & generate""" 87 | # For each shot, we try to randomly sample fingerprints for at most (max_sample_times * num_samples_each_shot) times. 88 | # If the sampling trys exceed the limit, we stop sampling this shot. 89 | generate_fingerprint_tokens = torch.normal(mean=aggregated_fingerprint_tokens, std=args.sample_std) 90 | generated_results: list = model.generate_from_fingerprints( 91 | fingerprint_tokens=generate_fingerprint_tokens, 92 | bond_dict=tokenizer.bond_dict, 93 | input_ids_list=None, 94 | graph_position_ids_1_list=None, 95 | graph_position_ids_2_list=None, 96 | identifier_ids_list=None, 97 | strict_generation=True, 98 | do_sample=False, 99 | topk=1, 100 | temperature=1.0, 101 | max_atoms=None, 102 | similarity_threshold=0.5, 103 | check_first_node=True, 104 | check_atom_valence=True, 105 | fix_aromatic_bond=True, 106 | use_cache=False, 107 | save_failed=False, 108 | show_progress=True, 109 | verbose=False, 110 | ) 111 | each_shot_generate_results = split_list(generated_results, args.num_samples_each_shot) 112 | 113 | """add results to corresponding shots""" 114 | for shot_id, this_shot_generate_results in zip(shot_ids_in_batch, each_shot_generate_results): 115 | this_shot_generate_results = [move_tensors_to_device(result, "cpu") for result in this_shot_generate_results if result is not None] # remove None results 116 | candidate_shots[shot_id]["generation_results"].extend(this_shot_generate_results) 117 | candidate_shots[shot_id]["sample_times"] += 1 118 | print(f"Added {len(this_shot_generate_results)} generated results for shot {shot_id}. 
Now: {len(candidate_shots[shot_id]['generation_results'])}") 119 | 120 | """gather results & remove finished shots""" 121 | this_iter_finished_shot_ids = [] 122 | this_iter_finished_results = [] 123 | for shot_id, candidate_shot in tqdm(candidate_shots.items(), desc="Aggregating results"): 124 | if candidate_shot["sample_times"] >= args.max_sample_times or len(candidate_shots[shot_id]["generation_results"]) >= args.num_samples_each_shot: # exceed max trys / results are enough 125 | this_iter_finished_shot_ids.append(shot_id) 126 | this_iter_finished_results.extend(candidate_shots[shot_id]["generation_results"][:args.num_samples_each_shot]) # add at most "num_samples_each_shot" results 127 | if len(finished_shots) + len(this_iter_finished_shot_ids) >= args.num_shots: 128 | break 129 | 130 | for shot_id in this_iter_finished_shot_ids: 131 | print(f"Finished shot {shot_id}. Final samples: {len(candidate_shots[shot_id]['generation_results'])}") 132 | candidate_shots.pop(shot_id) # remove finished shots 133 | finished_shots.add(shot_id) 134 | print(f"Now total finished shots: {len(finished_shots)}") 135 | 136 | print("Decoding to SMILES...") 137 | if len(this_iter_finished_results) > 0: 138 | this_iter_finished_results_batched, mask = pad_collator(this_iter_finished_results) 139 | this_iter_finished_results_batched["attention_mask"] = mask["input_ids"] 140 | decoded_mol_list, decoded_smiles_list = tokenizer.decode(this_iter_finished_results_batched, kekulize=True, nprocs=None) 141 | # You can use the following line to accelerate the decoding. However, it may occasionally raise errors. 142 | # decoded_mol_list, decoded_smiles_list = tokenizer.decode(this_iter_finished_results_batched, kekulize=True, nprocs=args.num_processes) 143 | all_smiles.extend(decoded_smiles_list) 144 | print(f"Now total generated SMILES: {len(all_smiles)}") 145 | 146 | """get metrics""" 147 | print("Getting metrics...") 148 | metrics = moses.get_all_metrics( 149 | all_smiles, 150 | n_jobs=args.num_processes, 151 | device=device, 152 | batch_size=args.batch_size_valid, 153 | ) 154 | print(metrics) 155 | 156 | delete_file_or_dir(args.save_path) 157 | create_dir(args.save_path) 158 | save_metric_file = os.path.join(args.save_path, "metrics.json") 159 | save_json(metrics, save_metric_file) 160 | print(f"Metrics saved to {save_metric_file}") 161 | 162 | """save results""" 163 | save_all_smiles_file = os.path.join(args.save_path, f"all_smiles.txt") 164 | with open(save_all_smiles_file, "w") as f: 165 | for smiles in all_smiles: 166 | f.write(smiles + "\n") 167 | print(f"Generated SMILES saved to {save_all_smiles_file}") 168 | 169 | 170 | if __name__ == "__main__": 171 | parser = argparse.ArgumentParser() 172 | parser.add_argument("--model_name_or_path", default="DaizeDong/GraphsGPT-1W", type=str, help="Path to the GraphsGPT hugging face model.") 173 | parser.add_argument("--save_path", default="./results/unconditional/moses", type=str, help="Path to save the evaluation results.") 174 | 175 | parser.add_argument('--batch_size_valid', default=8192, type=int, help="Number of samples per batch.") 176 | parser.add_argument("--sample_std", default=1.0, type=float, help="The standard deviation for sampling.") 177 | parser.add_argument('--max_sample_times', default=10, type=int, help="The maximum number of attempts to sample for each shot. 
Shots with insufficient successful generated results exceeding this number of attempts will be discarded.") 178 | parser.add_argument("--num_shots", default=100000, type=int, help="The number of shots for reference.") 179 | parser.add_argument("--num_samples_each_shot", default=1, type=int, help="The number of generated samples for each shot.") 180 | 181 | parser.add_argument("--num_processes", default=32, type=int, help="Number of parallel processes for decoding & metric calculation.") 182 | args = parser.parse_args() 183 | 184 | print(args) 185 | main(args) 186 | -------------------------------------------------------------------------------- /entrypoints/generation/evaluation/evaluate_zinc250k_few_shot_sampling.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | 3 | import argparse 4 | import random 5 | import torch 6 | from tqdm import tqdm 7 | 8 | import moses 9 | from data.collate_fn import tensor_dict_stack_padding_collater 10 | from data.tokenizer import GraphsGPTTokenizer 11 | from models.graphsgpt.modeling_graphsgpt import GraphsGPTForCausalLM 12 | from utils.io import create_dir, save_json, delete_file_or_dir 13 | from utils.operations.operation_list import split_list 14 | from utils.operations.operation_tensor import move_tensors_to_device 15 | from utils.property_scores.nspdk import get_npsdk 16 | 17 | 18 | def main(args): 19 | random.seed(0) 20 | torch.manual_seed(0) 21 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 22 | 23 | """load data & model""" 24 | train_file = "data/zinc250k/all_250k_rndm_zinc_drugs_clean_3.txt" 25 | valid_file = "data/zinc250k/valid_250k_rndm_zinc_drugs_clean_3.txt" 26 | 27 | with open(train_file, "r") as file: 28 | train_smiles = [data.replace("\n", "") for data in file.readlines()] 29 | random.shuffle(train_smiles) 30 | with open(valid_file, "r") as file: 31 | valid_smiles = [data.replace("\n", "") for data in file.readlines()] 32 | 33 | pad_collator = tensor_dict_stack_padding_collater(0, tensor_keys_to_create_mask=["input_ids"]) 34 | 35 | tokenizer = GraphsGPTTokenizer.from_pretrained(args.model_name_or_path) 36 | model = GraphsGPTForCausalLM.from_pretrained(args.model_name_or_path) 37 | model.to(device) 38 | 39 | """generate according to reference fingerprints""" 40 | all_smiles = [] 41 | 42 | # initialize the variable for batch-optimization of few-shot generation 43 | max_shots_in_a_batch = args.batch_size_valid // args.num_samples_each_shot 44 | next_shot_id = 0 45 | finished_shots = set() # shots that have done the generation 46 | candidate_shots = {} # shots for the next generation forward 47 | 48 | with torch.inference_mode(): 49 | while len(finished_shots) < args.num_shots: 50 | """prepare initial fingerprint tokens""" 51 | extra_inputs_num = max_shots_in_a_batch - len(candidate_shots) 52 | 53 | if extra_inputs_num > 0: # need some new inputs to fill up the batch 54 | # construct inputs 55 | inputs = [] 56 | for i in range(next_shot_id, next_shot_id + extra_inputs_num): 57 | this_shot_inputs = tokenizer.encode(train_smiles[i], return_tensors="pt") 58 | if this_shot_inputs is None: # the encoding may fail 59 | extra_inputs_num -= 1 60 | continue 61 | move_tensors_to_device(this_shot_inputs, device) 62 | inputs.append(this_shot_inputs) 63 | inputs, mask = pad_collator(inputs) 64 | inputs["attention_mask"] = mask["input_ids"] 65 | 66 | # repeat fingerprint tokens 67 | fingerprint_tokens = model.encode_to_fingerprints(**inputs) 68 | fingerprint_tokens = 
fingerprint_tokens.repeat_interleave(args.num_samples_each_shot, dim=0) # (extra_inputs_num * num_samples_each_shot, num_fingerprints, hidden_size) 69 | each_shot_fingerprint_tokens = fingerprint_tokens.split(args.num_samples_each_shot, dim=0) # each is (num_samples_each_shot, num_fingerprints, hidden_size) 70 | 71 | # add to shots 72 | for shot_id, this_shot_fingerprint_tokens in enumerate(each_shot_fingerprint_tokens, start=next_shot_id): 73 | candidate_shots[shot_id] = { 74 | "sample_times": 0, 75 | "fingerprint": this_shot_fingerprint_tokens, 76 | "generation_results": [], 77 | } 78 | print(f"Total {len(candidate_shots)} candidate shots in this iteration!") 79 | 80 | # update next_shot_id 81 | next_shot_id += extra_inputs_num 82 | print(f"Encoded {extra_inputs_num} new fingerprints!") 83 | 84 | """aggregate fingerprint tokens for candidate shots""" 85 | shot_ids_in_batch = [] 86 | aggregated_fingerprint_tokens = [] 87 | for shot_id, candidate_shot in candidate_shots.items(): 88 | shot_ids_in_batch.append(shot_id) 89 | aggregated_fingerprint_tokens.append(candidate_shot["fingerprint"]) 90 | aggregated_fingerprint_tokens = torch.cat(aggregated_fingerprint_tokens, dim=0) # (max_shots_in_a_batch * num_samples_each_shot, num_fingerprints, hidden_size) 91 | 92 | """random sampling & generate""" 93 | # For each shot, we try to randomly sample fingerprints for at most (max_sample_times * num_samples_each_shot) times. 94 | # If the sampling trys exceed the limit, we stop sampling this shot. 95 | generate_fingerprint_tokens = torch.normal(mean=aggregated_fingerprint_tokens, std=args.sample_std) 96 | generated_results: list = model.generate_from_fingerprints( 97 | fingerprint_tokens=generate_fingerprint_tokens, 98 | bond_dict=tokenizer.bond_dict, 99 | input_ids_list=None, 100 | graph_position_ids_1_list=None, 101 | graph_position_ids_2_list=None, 102 | identifier_ids_list=None, 103 | strict_generation=True, 104 | do_sample=False, 105 | topk=1, 106 | temperature=1.0, 107 | max_atoms=None, 108 | similarity_threshold=0.5, 109 | check_first_node=True, 110 | check_atom_valence=True, 111 | fix_aromatic_bond=True, 112 | use_cache=False, 113 | save_failed=False, 114 | show_progress=True, 115 | verbose=False, 116 | ) 117 | each_shot_generate_results = split_list(generated_results, args.num_samples_each_shot) 118 | 119 | """add results to corresponding shots""" 120 | for shot_id, this_shot_generate_results in zip(shot_ids_in_batch, each_shot_generate_results): 121 | this_shot_generate_results = [move_tensors_to_device(result, "cpu") for result in this_shot_generate_results if result is not None] # remove None results 122 | candidate_shots[shot_id]["generation_results"].extend(this_shot_generate_results) 123 | candidate_shots[shot_id]["sample_times"] += 1 124 | print(f"Added {len(this_shot_generate_results)} generated results for shot {shot_id}. 
Now: {len(candidate_shots[shot_id]['generation_results'])}") 125 | 126 | """gather results & remove finished shots""" 127 | this_iter_finished_shot_ids = [] 128 | this_iter_finished_results = [] 129 | for shot_id, candidate_shot in tqdm(candidate_shots.items(), desc="Aggregating results"): 130 | if candidate_shot["sample_times"] >= args.max_sample_times or len(candidate_shots[shot_id]["generation_results"]) >= args.num_samples_each_shot: # exceed max trys / results are enough 131 | this_iter_finished_shot_ids.append(shot_id) 132 | this_iter_finished_results.extend(candidate_shots[shot_id]["generation_results"][:args.num_samples_each_shot]) # add a maximum of "num_samples_each_shot" results 133 | if len(finished_shots) + len(this_iter_finished_shot_ids) >= args.num_shots: 134 | break 135 | 136 | for shot_id in this_iter_finished_shot_ids: 137 | print(f"Finished shot {shot_id}. Final samples: {len(candidate_shots[shot_id]['generation_results'])}") 138 | candidate_shots.pop(shot_id) # remove finished shots 139 | finished_shots.add(shot_id) 140 | print(f"Now total finished shots: {len(finished_shots)}") 141 | 142 | print("Decoding to SMILES...") 143 | if len(this_iter_finished_results) > 0: 144 | this_iter_finished_results_batched, mask = pad_collator(this_iter_finished_results) 145 | this_iter_finished_results_batched["attention_mask"] = mask["input_ids"] 146 | decoded_mol_list, decoded_smiles_list = tokenizer.decode(this_iter_finished_results_batched, kekulize=True, nprocs=None) 147 | # You can use the following line to accelerate the decoding. However, it may occasionally raise errors. 148 | # decoded_mol_list, decoded_smiles_list = tokenizer.decode(this_iter_finished_results_batched, kekulize=True, nprocs=args.num_processes) 149 | all_smiles.extend(decoded_smiles_list) 150 | print(f"Now total generated SMILES: {len(all_smiles)}") 151 | 152 | """get metrics""" 153 | print("Getting metrics...") 154 | metrics = moses.get_all_metrics( 155 | all_smiles, 156 | n_jobs=args.num_processes, 157 | device=device, 158 | batch_size=args.batch_size_valid, 159 | test=valid_smiles, 160 | test_scaffolds=valid_smiles, 161 | ) 162 | 163 | # 🔍 NPSDK 164 | npsdk_mmd = get_npsdk(all_smiles, valid_smiles, n_jobs=args.num_processes) 165 | metrics["NPSDK"] = npsdk_mmd 166 | 167 | print(metrics) 168 | 169 | delete_file_or_dir(args.save_path) 170 | create_dir(args.save_path) 171 | save_metric_file = os.path.join(args.save_path, "metrics.json") 172 | save_json(metrics, save_metric_file) 173 | print(f"Metrics saved to {save_metric_file}") 174 | 175 | """save results""" 176 | save_all_smiles_file = os.path.join(args.save_path, f"all_smiles.txt") 177 | with open(save_all_smiles_file, "w") as f: 178 | for smiles in all_smiles: 179 | f.write(smiles + "\n") 180 | print(f"Generated SMILES saved to {save_all_smiles_file}") 181 | 182 | 183 | if __name__ == "__main__": 184 | parser = argparse.ArgumentParser() 185 | parser.add_argument("--model_name_or_path", default="DaizeDong/GraphsGPT-1W", type=str, help="Path to the GraphsGPT hugging face model.") 186 | parser.add_argument("--save_path", default="./results/unconditional/zinc250k", type=str, help="Path to save the evaluation results.") 187 | 188 | parser.add_argument('--batch_size_valid', default=8192, type=int, help="Number of samples per batch.") 189 | parser.add_argument("--sample_std", default=1.0, type=float, help="The standard deviation for sampling.") 190 | parser.add_argument('--max_sample_times', default=10, type=int, help="The maximum number of attempts to sample 
for each shot. Shots with insufficient successful generated results exceeding this number of attempts will be discarded.") 191 | parser.add_argument("--num_shots", default=100000, type=int, help="The number of shots for reference.") 192 | parser.add_argument("--num_samples_each_shot", default=1, type=int, help="The number of generated samples for each shot.") 193 | 194 | parser.add_argument("--num_processes", default=32, type=int, help="Number of parallel processes for decoding & metric calculation.") 195 | args = parser.parse_args() 196 | 197 | print(args) 198 | main(args) 199 | -------------------------------------------------------------------------------- /entrypoints/generation/unconditional/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/A4Bio/GraphsGPT/e71647c5a9e19322b4166bb7922c604778a71e1e/entrypoints/generation/unconditional/__init__.py -------------------------------------------------------------------------------- /entrypoints/generation/unconditional/visualize.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from argparse import ArgumentParser 4 | from tqdm import tqdm 5 | 6 | from data.tokenizer import GraphsGPTTokenizer 7 | from utils.io import delete_file_or_dir, create_dir, save_mol_png, save_empty_png 8 | from utils.operations.operation_string import str2bool 9 | 10 | 11 | def main(args): 12 | tokenizer = GraphsGPTTokenizer.from_pretrained(args.model_name_or_path) 13 | 14 | file_list = os.listdir(args.generation_results_dir) 15 | file_list = sorted(file_list) 16 | file_list = file_list[args.file_begin_index:args.file_end_index] 17 | 18 | """visualize""" 19 | for file_name in tqdm(file_list): 20 | success_cnt = 0 21 | invalid_cnt = 0 22 | fail_cnt = 0 23 | smiles_list = [] 24 | 25 | save_dir = os.path.join(args.save_dir, f"{file_name.split('.')[0]}") 26 | delete_file_or_dir(save_dir) 27 | create_dir(save_dir) 28 | 29 | """read file""" 30 | file_path = os.path.join(args.generation_results_dir, file_name) 31 | with open(file_path, 'rb') as f: 32 | result_list = pickle.load(f) 33 | if not isinstance(result_list, (list, tuple)): 34 | result_list = [result_list] 35 | 36 | """save png""" 37 | for i, result in enumerate(result_list): 38 | print(i) 39 | save_img_file = os.path.join(save_dir, f"{i}.png") 40 | 41 | if result is not None: 42 | mol, smiles = tokenizer.decode(result, kekulize=True) # kekulize the molecule for score calculation 43 | 44 | if mol is None: 45 | smiles_list.append(None) 46 | if args.save_images: 47 | save_empty_png(save_img_file) 48 | invalid_cnt += 1 49 | else: 50 | smiles_list.append(smiles) 51 | if args.save_images: 52 | save_mol_png(mol, save_img_file) 53 | # print(f"Molecule '{smiles}' saved to '{save_img_file}'.") 54 | success_cnt += 1 55 | else: 56 | smiles_list.append(None) 57 | if args.save_images: 58 | save_empty_png(save_img_file) 59 | fail_cnt += 1 60 | 61 | """save statistics""" 62 | with open(os.path.join(save_dir, "count.txt"), 'a') as f: 63 | f.write(f"Success count: {success_cnt}\n") 64 | f.write(f"Invalid count: {invalid_cnt}\n") 65 | f.write(f"Fail count: {fail_cnt}\n") 66 | 67 | with open(os.path.join(save_dir, "smiles.txt"), 'a') as f: 68 | for smiles in smiles_list: 69 | f.write(f"{smiles}\n") 70 | 71 | 72 | if __name__ == '__main__': 73 | parser = ArgumentParser() 74 | parser.add_argument("--model_name_or_path", default="DaizeDong/GraphsGPT-1W", type=str) 75 | 
parser.add_argument('--generation_results_dir', default=None, type=str) 76 | parser.add_argument('--save_dir', default=None, type=str) 77 | parser.add_argument('--save_images', default="True", type=str) 78 | 79 | parser.add_argument('--file_begin_index', default=0, type=int) 80 | parser.add_argument('--file_end_index', default=2, type=int) 81 | args = parser.parse_args() 82 | args.save_images = str2bool(args.save_images) 83 | print(args) 84 | main(args) 85 | -------------------------------------------------------------------------------- /entrypoints/representation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/A4Bio/GraphsGPT/e71647c5a9e19322b4166bb7922c604778a71e1e/entrypoints/representation/__init__.py -------------------------------------------------------------------------------- /entrypoints/representation/finetune.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import math 4 | import os 5 | import pytorch_lightning as pl 6 | import pytorch_lightning.callbacks as plc 7 | import pytorch_lightning.loggers as plog 8 | import sys 9 | import torch 10 | import warnings 11 | from omegaconf import OmegaConf 12 | from pytorch_lightning.trainer import Trainer 13 | 14 | sys.path.append(os.getcwd()) 15 | warnings.filterwarnings("ignore") 16 | 17 | from utils.representation.logger import SetupCallback 18 | from utils.representation.data_interface import DInterface 19 | from utils.representation.model_interface import MInterface 20 | 21 | model_zoom = { 22 | 'tox21': ['./data_finetune/model_zoom/tox21/tox21.yaml', 23 | './data_finetune/model_zoom/tox21/tox21.pth'], 24 | 'toxcast': ['./data_finetune/model_zoom/toxcast/toxcast.yaml', 25 | './data_finetune/model_zoom/toxcast/toxcast.pth'], 26 | 'bbbp': ['./data_finetune/model_zoom/bbbp/bbbp.yaml', 27 | './data_finetune/model_zoom/bbbp/bbbp.pth'], 28 | 'sider': ['./data_finetune/model_zoom/sider/sider.yaml', 29 | './data_finetune/model_zoom/sider/sider.pth'], 30 | 'hiv': ['./data_finetune/model_zoom/hiv/hiv.yaml', 31 | './data_finetune/model_zoom/hiv/hiv.pth'], 32 | 'bace': ['./data_finetune/model_zoom/bace/bace.yaml', 33 | './data_finetune/model_zoom/bace/bace.pth'] 34 | } 35 | 36 | 37 | def create_parser(): 38 | parser = argparse.ArgumentParser() 39 | # Set-up parameters 40 | parser.add_argument('--res_dir', default="./results/representation", type=str) 41 | parser.add_argument('--ex_name', default='debug', type=str) 42 | parser.add_argument('--check_val_every_n_epoch', default=1, type=int) 43 | 44 | parser.add_argument('--task_name', default='bace', choices=['bace', 'bbbp', 'clintox', 'sider', 'tox21', 'toxcast', 'hiv', 'esol', 'freesolv', 'lipo']) 45 | parser.add_argument('--mixup_strategy', default='no_mix_pretrain', choices=['mix_embed', 'mix_graph', 'no_mix_pretrain', 'no_mix_vanilla']) 46 | parser.add_argument('--lr_scheduler', default='cosine') 47 | parser.add_argument('--offline', default=0, type=int) 48 | 49 | # dataset parameters 50 | parser.add_argument('--num_workers', default=8, type=int) 51 | parser.add_argument('--seed', default=1, type=int) 52 | parser.add_argument('--count', default=0, type=int) 53 | parser.add_argument('--multiple_conformation', default=1, type=int) 54 | parser.add_argument('--only_polar', default=-1, type=int) 55 | parser.add_argument('--remove_polar_hydrogen', default=False, type=bool) 56 | parser.add_argument('--remove_hydrogen', default=False, 
type=bool) 57 | parser.add_argument('--fingerprint', default='graphsgpt', type=str) 58 | parser.add_argument('--data', default='./data_finetune/molecular_property_prediction', type=str) 59 | parser.add_argument('--conf_size', default=11, type=int) 60 | parser.add_argument('--max_atoms', default=256, type=int) 61 | parser.add_argument('--self_prob', default=0.1, type=float) 62 | parser.add_argument('--no_shuffle', default=False, type=bool) 63 | parser.add_argument('--mix_times', default=1, type=int) 64 | 65 | # Training parameters 66 | parser.add_argument('--batch_size', default=128, type=int) 67 | parser.add_argument('--epoch', default=50, type=int, help='end epoch') 68 | parser.add_argument('--lr', default=1e-4, type=float, help='Learning rate') 69 | parser.add_argument('--warmup_ratio', default=0.06, type=float, help='warmup rate') 70 | parser.add_argument('--ema_decay', default=0.999, type=float, help='warmup rate') 71 | parser.add_argument('--pos_weight', default=99, type=float, help='warmup rate') 72 | 73 | # Model parameters 74 | parser.add_argument('--encoder_layers', default=15, type=int) 75 | parser.add_argument('--embed_dim', default=512, type=int) 76 | parser.add_argument('--ffn_embed_dim', default=2048, type=int) 77 | parser.add_argument('--attention_heads', default=64, type=int) 78 | parser.add_argument('--emb_dropout', default=0.1, type=float) 79 | parser.add_argument('--dropout', default=0.1, type=float) 80 | parser.add_argument('--attention_dropout', default=0.1, type=float) 81 | parser.add_argument('--activation_dropout', default=0.0, type=float) 82 | parser.add_argument('--max_seq_len', default=512, type=int) 83 | parser.add_argument('--activation_fn', default='gelu', type=str) 84 | parser.add_argument('--post_ln', default=False, type=bool) 85 | parser.add_argument('--no_final_head_layer_norm', default=True, type=bool) 86 | parser.add_argument('--mixup_alpha', default=1.0, type=float) 87 | parser.add_argument('--beta', default=1.0, type=float) 88 | parser.add_argument('--encoder_embed_dim', default=512, type=int) 89 | parser.add_argument('--pooler_dropout', default=0.0, type=float) 90 | parser.add_argument('--num_classes', default=2, type=int) # need to be changed 91 | parser.add_argument('--loss_type', default='mixup_multi_task_BCE', type=str) # need to be changed 92 | parser.add_argument('--checkpoint_metric', default='test_auc', type=str) # need to be changed 93 | args = parser.parse_args() 94 | return args 95 | 96 | 97 | def load_callbacks(args): 98 | callbacks = [] 99 | 100 | logdir = str(os.path.join(args.res_dir, args.ex_name)) 101 | 102 | ckptdir = os.path.join(logdir, "checkpoints") 103 | 104 | modes = {'test_auc': 'max', 'test_rmse': 'min', 'test_mae': 'min'} 105 | 106 | metric = args.checkpoint_metric 107 | sv_filename = 'best-{epoch:02d}-{' + metric + ':.3f}' 108 | callbacks.append(plc.ModelCheckpoint( 109 | monitor=metric, 110 | filename=sv_filename, 111 | save_top_k=5, 112 | mode=modes[metric], 113 | save_last=True, 114 | dirpath=ckptdir, 115 | verbose=True, 116 | every_n_epochs=args.check_val_every_n_epoch, 117 | )) 118 | 119 | now = datetime.datetime.now().strftime("%m-%dT%H-%M-%S") 120 | cfgdir = os.path.join(logdir, "configs") 121 | callbacks.append( 122 | SetupCallback( 123 | now=now, 124 | logdir=logdir, 125 | ckptdir=ckptdir, 126 | cfgdir=cfgdir, 127 | config=args.__dict__, 128 | argv_content=sys.argv + ["gpus: {}".format(torch.cuda.device_count())], ) 129 | ) 130 | 131 | if args.lr_scheduler: 132 | callbacks.append(plc.LearningRateMonitor( 133 | 
logging_interval=None)) 134 | return callbacks 135 | 136 | 137 | def main(): 138 | args = create_parser() 139 | params = OmegaConf.load(model_zoom[args.task_name][0]) 140 | config = args.__dict__ 141 | config.update(params) 142 | logger = plog.WandbLogger( 143 | project='graphsgpt_mixup2', 144 | name=args.ex_name, 145 | save_dir=str(os.path.join(args.res_dir, args.ex_name)), 146 | offline=True, 147 | id="_".join(args.ex_name.split("/")), 148 | entity="gaozhangyang" 149 | ) 150 | 151 | pl.seed_everything(args.seed) 152 | 153 | data_module = DInterface(**vars(args)) 154 | data_module.setup() 155 | 156 | gpu_count = torch.cuda.device_count() 157 | args.steps_per_epoch = math.ceil(len(data_module.trainset) / args.batch_size / gpu_count) 158 | print(f"steps_per_epoch {args.steps_per_epoch}, gpu_count {gpu_count}, batch_size{args.batch_size}") 159 | 160 | model = MInterface(**vars(args)) 161 | params = torch.load(model_zoom[args.task_name][1]) 162 | params = {k.replace('_forward_module.', ''): v for k, v in params.items()} 163 | model.load_state_dict(params) 164 | 165 | trainer_config = { 166 | 'gpus': -1, # Use all available GPUs 167 | 'max_epochs': args.epoch, # Maximum number of epochs to train for 168 | 'num_nodes': 1, # Number of nodes to use for distributed training 169 | "strategy": 'deepspeed_stage_2', # 'ddp', 'deepspeed_stage_2 170 | "precision": '32', # "bf16", 16 171 | 'accelerator': 'gpu', # Use distributed data parallel 172 | 'callbacks': load_callbacks(args), 173 | 'logger': logger, 174 | 'gradient_clip_val': 1.0 175 | } 176 | 177 | trainer_opt = argparse.Namespace(**trainer_config) 178 | trainer = Trainer.from_argparse_args(trainer_opt) 179 | 180 | trainer.test(model, data_module) 181 | print(trainer_config) 182 | 183 | 184 | if __name__ == "__main__": 185 | main() 186 | -------------------------------------------------------------------------------- /jupyter_notebooks/analysis/README-Clustering.md: -------------------------------------------------------------------------------- 1 | # Hyper-Parameters for Clustering 2 | 3 | Total number of samples should be set to **32768**. 
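As a point of reference, the settings in the tables below plug into the third-party `umap-learn` and `hdbscan` packages roughly as follows. This is a minimal sketch rather than the notebook's exact pipeline (see `clustering.ipynb`): the `fingerprints.npy` file is a hypothetical dump of Graph2Seq fingerprints flattened to shape `(total_vis_sample_num, hidden_size)`, and applying HDBSCAN to the clustering-oriented UMAP embedding is an assumption.

```python
import numpy as np
import umap     # pip install umap-learn
import hdbscan  # pip install hdbscan

# Hypothetical dump of Graph2Seq fingerprints, shape (total_vis_sample_num, hidden_size).
fingerprints = np.load("fingerprints.npy")

# "For Clustering" UMAP setting of GraphsGPT-1W (see the UMAP table below).
embedding = umap.UMAP(n_neighbors=40, min_dist=0.05, n_components=2).fit_transform(fingerprints)

# HDBSCAN setting of GraphsGPT-1W (see the HDBSCAN table below).
clusterer = hdbscan.HDBSCAN(min_cluster_size=48, min_samples=64,
                            cluster_selection_epsilon=0.25, alpha=1.0)
labels = clusterer.fit_predict(embedding)  # label -1 marks noise points
print(f"Clusters found: {labels.max() + 1}")
```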
4 | 5 | 6 | 7 | ## UMAP 8 | 9 | ### For Clustering 10 | 11 | | Model | total_vis_sample_num | n_neighbors | min_dist | n_components | 12 | |--------------|----------------------|-------------|----------|--------------| 13 | | GraphsGPT-1W | 32768 | 40 | 0.05 | 2 | 14 | | GraphsGPT-2W | 32768 | 40 | 0.05 | 2 | 15 | | GraphsGPT-4W | 32768 | 100 | 0.05 | 2 | 16 | | GraphsGPT-8W | 32768 | 100 | 0.05 | 2 | 17 | 18 | ### For Visualization 19 | 20 | | Model | total_vis_sample_num | n_neighbors | min_dist | n_components | 21 | |--------------|----------------------|-------------|----------|--------------| 22 | | GraphsGPT-1W | 32768 | 40 | 0.8 | 2 | 23 | | GraphsGPT-2W | 32768 | 40 | 0.8 | 2 | 24 | | GraphsGPT-4W | 32768 | 40 | 0.7 | 2 | 25 | | GraphsGPT-8W | 32768 | 40 | 0.7 | 2 | 26 | 27 | 28 | 29 | ## HDBSCAN 30 | 31 | | Model | min_cluster_size | min_samples | cluster_selection_epsilon | alpha | 32 | |--------------|------------------|-------------|---------------------------|-------| 33 | | GraphsGPT-1W | 48 | 64 | 0.25 | 1.0 | 34 | | GraphsGPT-2W | 48 | 64 | 0.25 | 1.0 | 35 | | GraphsGPT-4W | 32 | 48 | 0.2 | 1.0 | 36 | | GraphsGPT-8W | 32 | 48 | 0.2 | 1.0 | 37 | 38 | 39 | 40 | -------------------------------------------------------------------------------- /jupyter_notebooks/example_pipeline.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "source": "# Pipeline of GraphsGPT with Hugging Face Transformers", 6 | "metadata": { 7 | "collapsed": false 8 | }, 9 | "id": "3e188af549c37dc3" 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "source": "### Configurations", 14 | "metadata": { 15 | "collapsed": false 16 | }, 17 | "id": "e5ecf87e00701167" 18 | }, 19 | { 20 | "cell_type": "code", 21 | "source": [ 22 | "import torch\n", 23 | "\n", 24 | "model_name_or_path = \"DaizeDong/GraphsGPT-8W\"\n", 25 | "smiles_file = \"../data/examples/zinc_example.txt\"\n", 26 | "\n", 27 | "batch_size = 1024\n", 28 | "max_batches = 4\n", 29 | "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"" 30 | ], 31 | "metadata": { 32 | "collapsed": false 33 | }, 34 | "id": "ef4b44c94c086bd1", 35 | "outputs": [], 36 | "execution_count": null 37 | }, 38 | { 39 | "cell_type": "markdown", 40 | "source": [ 41 | "### Load SMILES" 42 | ], 43 | "metadata": { 44 | "collapsed": false 45 | }, 46 | "id": "84357216224aed01" 47 | }, 48 | { 49 | "cell_type": "code", 50 | "source": [ 51 | "with open(smiles_file, \"r\", encoding=\"utf-8\") as f:\n", 52 | " smiles_list = f.readlines()\n", 53 | "smiles_list = [smiles.removesuffix(\"\\n\") for smiles in smiles_list]\n", 54 | "\n", 55 | "print(f\"Total SMILES loaded: {len(smiles_list)}\")\n", 56 | "for i in range(10):\n", 57 | " print(f\"Example SMILES {i}: {smiles_list[i]}\")" 58 | ], 59 | "metadata": { 60 | "collapsed": false 61 | }, 62 | "id": "1c903203360b16", 63 | "outputs": [], 64 | "execution_count": null 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "source": [ 69 | "### Load Model & Tokenizer" 70 | ], 71 | "metadata": { 72 | "collapsed": false 73 | }, 74 | "id": "4263d29cbbe85ea1" 75 | }, 76 | { 77 | "cell_type": "code", 78 | "source": [ 79 | "from models.graphsgpt.modeling_graphsgpt import GraphsGPTForCausalLM\n", 80 | "from data.tokenizer import GraphsGPTTokenizer\n", 81 | "\n", 82 | "model = GraphsGPTForCausalLM.from_pretrained(model_name_or_path)\n", 83 | "tokenizer = GraphsGPTTokenizer.from_pretrained(model_name_or_path)\n", 84 | "\n", 85 | 
"print(model.state_dict().keys())\n", 86 | "print(f\"Total paramerters: {sum(x.numel() for x in model.parameters())}\")" 87 | ], 88 | "metadata": { 89 | "collapsed": false 90 | }, 91 | "id": "eed2630d3a5bf75a", 92 | "outputs": [], 93 | "execution_count": null 94 | }, 95 | { 96 | "cell_type": "markdown", 97 | "source": "### Encode SMILES into Fingerprint Embeddings (Graph Words)", 98 | "metadata": { 99 | "collapsed": false 100 | }, 101 | "id": "6a43509f85c9ef0c" 102 | }, 103 | { 104 | "cell_type": "code", 105 | "source": [ 106 | "from utils.operations.operation_tensor import move_tensors_to_device\n", 107 | "from utils.operations.operation_list import split_list_with_yield\n", 108 | "\n", 109 | "batch_count = 0\n", 110 | "fingerprints_lists = []\n", 111 | "\n", 112 | "model.to(device)\n", 113 | "model.eval()\n", 114 | "with torch.no_grad():\n", 115 | " for batched_smiles in split_list_with_yield(smiles_list, batch_size):\n", 116 | " inputs = tokenizer.batch_encode(batched_smiles, return_tensors=\"pt\")\n", 117 | " move_tensors_to_device(inputs, device)\n", 118 | "\n", 119 | " fingerprint_tokens = model.encode_to_fingerprints(**inputs) # (batch_size, num_fingerprints, hidden_dim)\n", 120 | " fingerprints_lists.append(fingerprint_tokens)\n", 121 | "\n", 122 | " batch_count += 1\n", 123 | " if batch_count >= max_batches:\n", 124 | " break\n", 125 | "\n", 126 | "print(f\"Encoded total {batch_count * batch_size} molecules\")" 127 | ], 128 | "metadata": { 129 | "collapsed": false 130 | }, 131 | "id": "7a04d04186c8a9c", 132 | "outputs": [], 133 | "execution_count": null 134 | }, 135 | { 136 | "cell_type": "markdown", 137 | "source": "### Recover Molecule Sequences through Generation", 138 | "metadata": { 139 | "collapsed": false 140 | }, 141 | "id": "3ec3683b0fa19abb" 142 | }, 143 | { 144 | "cell_type": "code", 145 | "source": [ 146 | "all_results = []\n", 147 | "\n", 148 | "for fingerprints in fingerprints_lists:\n", 149 | " generation_result = model.generate_from_fingerprints(\n", 150 | " fingerprint_tokens=fingerprints,\n", 151 | " bond_dict=tokenizer.bond_dict,\n", 152 | " strict_generation=True,\n", 153 | " max_atoms=None,\n", 154 | " similarity_threshold=0.5,\n", 155 | " check_first_node=True,\n", 156 | " check_atom_valence=False,\n", 157 | " fix_aromatic_bond=False,\n", 158 | " use_cache=False,\n", 159 | " save_failed=True, # save the generated partial result even the full generation failed\n", 160 | " show_progress=True,\n", 161 | " verbose=True,\n", 162 | " )\n", 163 | " all_results.extend(generation_result)\n", 164 | "\n", 165 | "print(\"Done.\")\n", 166 | "print(f\"#### Generated {len(all_results)} molecules\")" 167 | ], 168 | "metadata": { 169 | "collapsed": false 170 | }, 171 | "id": "8b5324f8eb8e35de", 172 | "outputs": [], 173 | "execution_count": null 174 | }, 175 | { 176 | "cell_type": "markdown", 177 | "source": [ 178 | "### Decode Sequences back to SMILES" 179 | ], 180 | "metadata": { 181 | "collapsed": false 182 | }, 183 | "id": "921df1d54d1bb705" 184 | }, 185 | { 186 | "cell_type": "code", 187 | "source": [ 188 | "from rdkit.Chem import Draw\n", 189 | "\n", 190 | "\n", 191 | "def show_mol_png(mol, size=(512, 512)):\n", 192 | " img = Draw.MolToImage(mol, size=size)\n", 193 | " img.show()\n", 194 | " img.close()\n", 195 | "\n", 196 | "\n", 197 | "decoded_mols = []\n", 198 | "decoded_smiles = []\n", 199 | "\n", 200 | "for result in all_results:\n", 201 | " if result is not None:\n", 202 | " mol, smiles = tokenizer.decode(result)\n", 203 | " decoded_mols.append(mol)\n", 204 | " 
decoded_smiles.append(smiles)\n", 205 | " else:\n", 206 | " decoded_mols.append(None)\n", 207 | " decoded_smiles.append(None)\n", 208 | "\n", 209 | "# visualize the first 10 results\n", 210 | "for i in range(10):\n", 211 | " print(f\"Original SMILES {i}: {smiles_list[i]}\")\n", 212 | " print(f\"Decoded SMILES {i}: {decoded_smiles[i]}\")\n", 213 | " show_mol_png(decoded_mols[i])" 214 | ], 215 | "metadata": { 216 | "collapsed": false 217 | }, 218 | "id": "ef67a4e812de11e3", 219 | "outputs": [], 220 | "execution_count": null 221 | }, 222 | { 223 | "cell_type": "code", 224 | "source": [], 225 | "metadata": { 226 | "collapsed": false 227 | }, 228 | "id": "c86d3fce01319503", 229 | "outputs": [], 230 | "execution_count": null 231 | } 232 | ], 233 | "metadata": { 234 | "kernelspec": { 235 | "display_name": "Python 3", 236 | "language": "python", 237 | "name": "python3" 238 | }, 239 | "language_info": { 240 | "codemirror_mode": { 241 | "name": "ipython", 242 | "version": 2 243 | }, 244 | "file_extension": ".py", 245 | "mimetype": "text/x-python", 246 | "name": "python", 247 | "nbconvert_exporter": "python", 248 | "pygments_lexer": "ipython2", 249 | "version": "2.7.6" 250 | } 251 | }, 252 | "nbformat": 4, 253 | "nbformat_minor": 5 254 | } 255 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/A4Bio/GraphsGPT/e71647c5a9e19322b4166bb7922c604778a71e1e/models/__init__.py -------------------------------------------------------------------------------- /models/graphsgpt/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/A4Bio/GraphsGPT/e71647c5a9e19322b4166bb7922c604778a71e1e/models/graphsgpt/__init__.py -------------------------------------------------------------------------------- /models/graphsgpt/configuration_graphsgpt.py: -------------------------------------------------------------------------------- 1 | """GraphsGPT Model Configuration""" 2 | 3 | from transformers.configuration_utils import PretrainedConfig 4 | from transformers.utils import logging 5 | 6 | logger = logging.get_logger(__name__) 7 | 8 | 9 | class GraphsGPTConfig(PretrainedConfig): 10 | model_type = "graphs_gpt" 11 | 12 | def __init__( 13 | self, 14 | atom_vocab_size=118, # number of atoms 15 | bond_vocab_size=92, # number of bonds 16 | pad_token_id=0, 17 | share_embeddings=False, 18 | # --------------------- # 19 | node_loss_weight=1.0, 20 | connection_loss_weight=1.0, 21 | connection_loss_type="contrastive", # classification contrastive 22 | adaptive_position_length=False, 23 | # --------------------- # 24 | num_fingerprints=8, 25 | position_feature_size=128, 26 | hidden_size=512, 27 | intermediate_size=2048, 28 | num_hidden_layers=8, 29 | num_attention_heads=8, 30 | hidden_act="silu", 31 | # --------------------- # 32 | initializer_method="hidden", # manual hidden hidden-layer 33 | initializer_range=0.02, # useful only when "initializer_method" is "manual" 34 | rms_norm_eps=1e-6, 35 | gradient_checkpointing=False, 36 | **kwargs, 37 | ): 38 | self.atom_vocab_size = atom_vocab_size 39 | self.bond_vocab_size = bond_vocab_size 40 | self.share_embeddings = share_embeddings 41 | 42 | self.node_loss_weight = node_loss_weight 43 | self.connection_loss_weight = connection_loss_weight 44 | self.connection_loss_type = connection_loss_type 45 | self.adaptive_position_length = adaptive_position_length 
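# Note (added comment): the fields below size the Transformer backbone: num_fingerprints Graph Word (fingerprint) tokens, hidden_size-dimensional embeddings, and num_hidden_layers blocks with num_attention_heads heads each.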
46 | 47 | self.num_fingerprints = num_fingerprints 48 | self.position_feature_size = position_feature_size 49 | self.hidden_size = hidden_size 50 | self.intermediate_size = intermediate_size 51 | self.num_hidden_layers = num_hidden_layers 52 | self.num_attention_heads = num_attention_heads 53 | self.hidden_act = hidden_act 54 | 55 | self.initializer_method = initializer_method 56 | self.initializer_range = initializer_range 57 | self.rms_norm_eps = rms_norm_eps 58 | self.gradient_checkpointing = gradient_checkpointing 59 | 60 | super().__init__( 61 | pad_token_id=pad_token_id, 62 | tie_word_embeddings=False, 63 | # **kwargs, 64 | ) 65 | -------------------------------------------------------------------------------- /models/graphsgpt/generation_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from typing import Tuple 4 | 5 | from utils.algorithms import FindCycles 6 | 7 | VALENCE_LIMIT = { 8 | # ATOMIC_NUM: MAX_VALENCE 9 | 5: 3, # B 10 | 6: 4, # C 11 | 7: 3, # N 12 | 8: 2, # 0 13 | 9: 1, # F 14 | 14: 4, # Si 15 | 16: 6, # S 16 | 17: 1, # Cl 17 | 35: 1, # Br 18 | 53: 1, # I 19 | } 20 | 21 | 22 | def get_atom_ids_from_bond_id(inverse_bond_dict, bond_id: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 23 | atom_ids = inverse_bond_dict[bond_id.item()][:2] 24 | return ( 25 | torch.tensor(atom_ids[0], dtype=bond_id.dtype, device=bond_id.device).unsqueeze(0), 26 | torch.tensor(atom_ids[1], dtype=bond_id.dtype, device=bond_id.device).unsqueeze(0) 27 | ) 28 | 29 | 30 | def get_another_atom_id_from_existing_bond(inverse_bond_dict, bond_id: torch.Tensor, connected_atom_id_1: torch.Tensor) -> torch.Tensor: 31 | # bond_id: id in the bond_dict 32 | # connected_atom_id_1: atomic_num of the existing connected atom 33 | info = inverse_bond_dict[bond_id.item()] # (min_atomic_num, max_atomic_num, bond_type) 34 | if not (connected_atom_id_1.item() == info[0] or connected_atom_id_1.item() == info[1]): 35 | raise ValueError 36 | elif connected_atom_id_1.item() == info[0]: 37 | return torch.tensor(info[1], dtype=bond_id.dtype, device=bond_id.device).unsqueeze(0) 38 | else: 39 | return torch.tensor(info[0], dtype=bond_id.dtype, device=bond_id.device).unsqueeze(0) 40 | 41 | 42 | def check_bond_connectivity_begin(inverse_bond_dict, bond_id, atom_id): 43 | # bond_id: id in the bond_dict 44 | # atom_id: atomic_num 45 | info = inverse_bond_dict[bond_id] # (min_atomic_num, max_atomic_num, bond_type) 46 | if atom_id == info[0] or atom_id == info[1]: # "+1" for getting the atomic num 47 | return True 48 | else: 49 | return False 50 | 51 | 52 | def check_bond_connectivity_both_sides(inverse_bond_dict, bond_id, atom_id1, atom_id2): 53 | # bond_id: id in the bond_dict 54 | # atom_id: atomic_num 55 | if atom_id1 <= atom_id2: 56 | min_atomic_num = atom_id1 57 | max_atomic_num = atom_id2 58 | else: 59 | min_atomic_num = atom_id2 60 | max_atomic_num = atom_id1 61 | 62 | info = inverse_bond_dict[bond_id] # (min_atomic_num, max_atomic_num, bond_type) 63 | if min_atomic_num == info[0] and max_atomic_num == info[1]: 64 | return True 65 | else: 66 | return False 67 | 68 | 69 | def check_bond_in_graph(graph_position_ids_1, graph_position_ids_2, connection_1, connection_2): 70 | connection_1_in_1 = (graph_position_ids_1 == connection_1) 71 | connection_1_in_2 = (graph_position_ids_2 == connection_1) 72 | connection_2_in_1 = (graph_position_ids_1 == connection_2) 73 | connection_2_in_2 = (graph_position_ids_2 == connection_2) 74 | if 
torch.any(connection_1_in_1 & connection_2_in_2) or torch.any(connection_1_in_2 & connection_2_in_1): 75 | return True 76 | else: 77 | return False 78 | 79 | 80 | def get_valence(input_ids, graph_position_ids_1, graph_position_ids_2, connection_id, inverse_bond_dict, bond_mask): 81 | """ 82 | Return the total valence for an atom by checking its connected bonds 83 | For each connected aromatic bond with 1.5 valence, the real valence for the atom can be adjusted by a value of ±0.5 84 | """ 85 | bond_ids_in_dict = input_ids[bond_mask] - 1 86 | bond_connections_1 = graph_position_ids_1[bond_mask] 87 | bond_connections_2 = graph_position_ids_2[bond_mask] 88 | bond_connections_mask = (bond_connections_1 == connection_id) | (bond_connections_2 == connection_id) 89 | connected_bond_ids = bond_ids_in_dict[bond_connections_mask].tolist() 90 | 91 | total_valence = 0 92 | valence_offset = 0 93 | for bond_type in connected_bond_ids: 94 | this_valence = inverse_bond_dict[bond_type][2] 95 | total_valence += this_valence 96 | if this_valence == 1.5: # aromatic bond 97 | valence_offset += 0.5 98 | 99 | if int(valence_offset / 0.5) % 2 == 0: 100 | total_valence += int(valence_offset / 0.5) % 2 101 | valence_offset = 0 102 | else: 103 | valence_offset = 0.5 104 | 105 | return total_valence, valence_offset 106 | 107 | 108 | def fix_dissociative_aromatic_bond(input_ids, graph_position_ids_1, graph_position_ids_2, identifier_ids, inverse_bond_dict, bond_dict): 109 | """ 110 | Replace invalid aromatic bonds with corresponding single/double/triple bonds. 111 | Invalid aromatic bonds: non-ring & incomplete ring 112 | """ 113 | atom_num = torch.sum(identifier_ids).item() 114 | bond_mask = ~identifier_ids 115 | 116 | indices_all_positions = torch.arange(input_ids.shape[0], device=input_ids.device, dtype=torch.int64) 117 | indices_bonds = indices_all_positions[bond_mask] 118 | 119 | # Create connection -> bond_type mapping 120 | bond_ids_in_dict = (input_ids.clone()[bond_mask] - 1).tolist() 121 | 122 | bond_connections_1 = graph_position_ids_1[bond_mask].tolist() 123 | bond_connections_2 = graph_position_ids_2[bond_mask].tolist() 124 | bond_connections_all = [ 125 | (bond_connections_1[i], bond_connections_2[i]) if bond_connections_1[i] <= bond_connections_2[i] else (bond_connections_2[i], bond_connections_1[i]) # sort node id 126 | for i in range(len(bond_connections_1)) 127 | ] 128 | 129 | connection_bond_mapping = {} 130 | connection_index_mapping = {} # record the positions for each connection 131 | 132 | for i, connection in enumerate(bond_connections_all): 133 | connection_bond_mapping[connection] = inverse_bond_dict[bond_ids_in_dict[i]] 134 | connection_index_mapping[connection] = indices_bonds[i] 135 | 136 | # Find cycles in the molecule 137 | adjacency_matrix = np.zeros((atom_num, atom_num), dtype=np.int8) 138 | for connection in bond_connections_all: 139 | adjacency_matrix[connection[0], connection[1]] = 1 140 | adjacency_matrix[connection[1], connection[0]] = 1 141 | cycle_finder = FindCycles(adjacency_matrix) 142 | cycles = cycle_finder.find_cycles() 143 | 144 | # Check the bonds in each cycle 145 | # If there is any cycle with all aromatic bonds, then all bonds in it are marked as valid 146 | valid_aromatic_connections = set() 147 | 148 | for cycle in cycles: 149 | is_aromatic = [] 150 | not_aromatic = [] 151 | 152 | # Check the validity 153 | for node_id in range(len(cycle)): 154 | if node_id < len(cycle) - 1: 155 | connection = (cycle[node_id], cycle[node_id + 1]) if cycle[node_id] <= cycle[node_id + 
1] else (cycle[node_id + 1], cycle[node_id]) 156 | else: 157 | connection = (cycle[node_id], cycle[0]) if cycle[node_id] <= cycle[0] else (cycle[0], cycle[node_id]) 158 | 159 | begin_atomic_num, end_atomic_num, bond_type = connection_bond_mapping[connection] 160 | 161 | if bond_type == 1.5: 162 | is_aromatic.append(connection) 163 | else: 164 | not_aromatic.append(connection) 165 | 166 | if len(not_aromatic) == 0: # all bonds are aromatic 167 | for connection in is_aromatic: 168 | valid_aromatic_connections.add(connection) 169 | 170 | # Change invalid aromatic bonds into single/double/triple bonds 171 | for connection in bond_connections_all: 172 | begin_atomic_num, end_atomic_num, bond_type = connection_bond_mapping[connection] 173 | if bond_type == 1.5 and connection not in valid_aromatic_connections: # invalid aromatic 174 | index = connection_index_mapping[connection] 175 | if (begin_atomic_num, end_atomic_num, 1.0) in bond_dict: # single 176 | new_bond_id = bond_dict[(begin_atomic_num, end_atomic_num, 1.0)] + 1 177 | elif (begin_atomic_num, end_atomic_num, 2.0) in bond_dict: # double 178 | new_bond_id = bond_dict[(begin_atomic_num, end_atomic_num, 2.0)] + 1 179 | elif (begin_atomic_num, end_atomic_num, 3.0) in bond_dict: # triple 180 | new_bond_id = bond_dict[(begin_atomic_num, end_atomic_num, 3.0)] + 1 181 | else: # this bond is incorrigible! 182 | continue 183 | input_ids[index] = new_bond_id 184 | 185 | return input_ids, graph_position_ids_1, graph_position_ids_2, identifier_ids 186 | -------------------------------------------------------------------------------- /models/graphsgpt/orf.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/lucidrains/performer-pytorch/blob/main/performer_pytorch/performer_pytorch.py 3 | """ 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | 9 | @torch.no_grad() 10 | def orthogonal_matrix_chunk(cols, device=None): 11 | unstructured_block = torch.randn((cols, cols), device=device) 12 | q, r = torch.linalg.qr(unstructured_block, mode='reduced') 13 | return q.transpose(0, 1) # [cols, cols] 14 | 15 | 16 | @torch.no_grad() 17 | def gaussian_orthogonal_random_matrix(nb_columns, nb_rows, random_shuffle=False, device=None, dtype=torch.float32): 18 | """create 2D Gaussian orthogonal matrix""" 19 | nb_full_blocks = int(nb_rows / nb_columns) 20 | 21 | block_list = [] 22 | 23 | for _ in range(nb_full_blocks): 24 | q = orthogonal_matrix_chunk(nb_columns, device=device) 25 | block_list.append(q) 26 | 27 | remaining_rows = nb_rows - nb_full_blocks * nb_columns 28 | if remaining_rows > 0: 29 | # q = orthogonal_matrix_chunk_batched(nb_samples, nb_columns, device=device) 30 | block_list.append(torch.zeros((nb_columns, remaining_rows), device=device)) 31 | 32 | final_matrix = torch.cat(block_list, dim=1).type(dtype) 33 | final_matrix = F.normalize(final_matrix, p=2, dim=1) 34 | 35 | if random_shuffle: 36 | _, indices = torch.rand((final_matrix.shape[1],), device=device).sort(dim=0) 37 | indices = indices.unsqueeze(0).expand(final_matrix.shape) 38 | final_matrix = torch.gather(final_matrix, dim=1, index=indices) 39 | 40 | return final_matrix # (nb_columns, nb_rows) 41 | 42 | 43 | @torch.no_grad() 44 | def orthogonal_matrix_chunk_batched(bsz, cols, device=None): 45 | unstructured_block = torch.randn((bsz, cols, cols), device=device) 46 | q, r = torch.linalg.qr(unstructured_block, mode='reduced') 47 | return q.transpose(1, 2) # [bsz, cols, cols] 48 | 49 | 50 | @torch.no_grad() 51 | 
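# Note (added comment): batched variant of the function above, returning one orthogonal random-feature matrix per sample with shape (nb_samples, nb_columns, nb_rows).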
def gaussian_orthogonal_random_matrix_batched(nb_samples, nb_columns, nb_rows, random_shuffle=False, device=None, dtype=torch.float32): 52 | """create 2D Gaussian orthogonal matrix""" 53 | nb_full_blocks = int(nb_rows / nb_columns) 54 | 55 | block_list = [] 56 | 57 | for _ in range(nb_full_blocks): 58 | q = orthogonal_matrix_chunk_batched(nb_samples, nb_columns, device=device) 59 | block_list.append(q) 60 | 61 | remaining_rows = nb_rows - nb_full_blocks * nb_columns 62 | if remaining_rows > 0: 63 | # q = orthogonal_matrix_chunk_batched(nb_samples, nb_columns, device=device) 64 | block_list.append(torch.zeros((nb_samples, nb_columns, remaining_rows), device=device)) 65 | 66 | final_matrix = torch.cat(block_list, dim=2).type(dtype) 67 | final_matrix = F.normalize(final_matrix, p=2, dim=2) 68 | 69 | if random_shuffle: 70 | _, indices = torch.rand((final_matrix.shape[0], final_matrix.shape[2]), device=device).sort(dim=1) 71 | indices = indices.unsqueeze(1).expand(final_matrix.shape) 72 | final_matrix = torch.gather(final_matrix, dim=2, index=indices) 73 | 74 | return final_matrix # (nb_samples, nb_columns, nb_rows) 75 | 76 | 77 | if __name__ == "__main__": 78 | gaussian_orthogonal_random_matrix(37, 128) 79 | gaussian_orthogonal_random_matrix_batched(256, 37, 128) 80 | -------------------------------------------------------------------------------- /models/graphsgpt_cond_gen/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/A4Bio/GraphsGPT/e71647c5a9e19322b4166bb7922c604778a71e1e/models/graphsgpt_cond_gen/__init__.py -------------------------------------------------------------------------------- /models/graphsgpt_cond_gen/configuration_graphsgpt_cond_gen.py: -------------------------------------------------------------------------------- 1 | """GraphsGPT Conditioned Model Configuration""" 2 | 3 | from transformers.configuration_utils import PretrainedConfig 4 | from transformers.utils import logging 5 | 6 | logger = logging.get_logger(__name__) 7 | 8 | 9 | class GraphsGPTConditionedConfig(PretrainedConfig): 10 | model_type = "graphs_gpt_conditioned" 11 | 12 | def __init__( 13 | self, 14 | atom_vocab_size=118, # number of atoms 15 | bond_vocab_size=92, # number of bonds 16 | pad_token_id=0, 17 | share_embeddings=False, 18 | # --------------------- # 19 | node_loss_weight=1.0, 20 | connection_loss_weight=1.0, 21 | connection_loss_type="contrastive", # classification contrastive 22 | adaptive_position_length=False, 23 | # --------------------- # 24 | num_properties=3, # 🔍 25 | num_fingerprints=8, 26 | position_feature_size=128, 27 | hidden_size=512, 28 | intermediate_size=2048, 29 | num_hidden_layers=8, 30 | num_attention_heads=8, 31 | hidden_act="silu", 32 | # --------------------- # 33 | initializer_method="hidden", # manual hidden hidden-layer 34 | initializer_range=0.02, # useful only when "initializer_method" is "manual" 35 | rms_norm_eps=1e-6, 36 | gradient_checkpointing=False, 37 | **kwargs, 38 | ): 39 | self.atom_vocab_size = atom_vocab_size 40 | self.bond_vocab_size = bond_vocab_size 41 | self.share_embeddings = share_embeddings 42 | 43 | self.node_loss_weight = node_loss_weight 44 | self.connection_loss_weight = connection_loss_weight 45 | self.connection_loss_type = connection_loss_type 46 | self.adaptive_position_length = adaptive_position_length 47 | 48 | self.num_properties = num_properties # 🔍 49 | self.num_fingerprints = num_fingerprints 50 | self.position_feature_size = position_feature_size 51 | 
self.hidden_size = hidden_size 52 | self.intermediate_size = intermediate_size 53 | self.num_hidden_layers = num_hidden_layers 54 | self.num_attention_heads = num_attention_heads 55 | self.hidden_act = hidden_act 56 | 57 | self.initializer_method = initializer_method 58 | self.initializer_range = initializer_range 59 | self.rms_norm_eps = rms_norm_eps 60 | self.gradient_checkpointing = gradient_checkpointing 61 | 62 | super().__init__( 63 | pad_token_id=pad_token_id, 64 | tie_word_embeddings=False, 65 | # **kwargs, 66 | ) 67 | -------------------------------------------------------------------------------- /moses/NP_Score/README: -------------------------------------------------------------------------------- 1 | RDKit-based implementation of the method described in: 2 | 3 | Natural Product-likeness Score and Its Application for Prioritization of Compound Libraries 4 | Peter Ertl, Silvio Roggo, and Ansgar Schuffenhauer 5 | Journal of Chemical Information and Modeling, 48, 68-74 (2008) 6 | http://pubs.acs.org/doi/abs/10.1021/ci700286x 7 | 8 | Contribution from Peter Ertl 9 | 10 | -------------------------------------------------------------------------------- /moses/NP_Score/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/A4Bio/GraphsGPT/e71647c5a9e19322b4166bb7922c604778a71e1e/moses/NP_Score/__init__.py -------------------------------------------------------------------------------- /moses/NP_Score/npscorer.py: -------------------------------------------------------------------------------- 1 | # 2 | # calculation of natural product-likeness as described in: 3 | # 4 | # Natural Product-likeness Score and Its Application for Prioritization of 5 | # Compound Libraries 6 | # Peter Ertl, Silvio Roggo, and Ansgar Schuffenhauer 7 | # Journal of Chemical Information and Modeling, 48, 68-74 (2008) 8 | # http://pubs.acs.org/doi/abs/10.1021/ci700286x 9 | # 10 | # for the training of this model only openly available data have been used 11 | # ~50,000 natural products collected from various open databases 12 | # ~1 million drug-like molecules from ZINC as a "non-NP background" 13 | # 14 | # peter ertl, august 2015 15 | # 16 | 17 | from __future__ import print_function 18 | 19 | import os.path 20 | from collections import namedtuple 21 | 22 | import gzip 23 | import math 24 | import pickle 25 | import sys 26 | from rdkit import Chem 27 | from rdkit.Chem import rdMolDescriptors 28 | 29 | _fscores = None 30 | 31 | 32 | def readNPModel(filename=os.path.join(os.path.dirname(__file__), 33 | 'publicnp.model.gz')): 34 | """Reads and returns the scoring model, 35 | which has to be passed to the scoring functions.""" 36 | global _fscores 37 | _fscores = pickle.load(gzip.open(filename)) 38 | return _fscores 39 | 40 | 41 | def scoreMolWConfidence(mol, fscore): 42 | """Next to the NP Likeness Score, this function outputs a confidence value 43 | between 0..1 that descibes how many fragments of the tested molecule 44 | were found in the model data set (1: all fragments were found). 
45 | 46 | Returns namedtuple NPLikeness(nplikeness, confidence)""" 47 | 48 | if mol is None: 49 | raise ValueError('invalid molecule') 50 | fp = rdMolDescriptors.GetMorganFingerprint(mol, 2) 51 | bits = fp.GetNonzeroElements() 52 | 53 | # calculating the score 54 | score = 0.0 55 | bits_found = 0 56 | for bit in bits: 57 | if bit in fscore: 58 | bits_found += 1 59 | score += fscore[bit] 60 | 61 | score /= float(mol.GetNumAtoms()) 62 | confidence = float(bits_found / len(bits)) 63 | 64 | # preventing score explosion for exotic molecules 65 | if score > 4: 66 | score = 4. + math.log10(score - 4. + 1.) 67 | elif score < -4: 68 | score = -4. - math.log10(-4. - score + 1.) 69 | NPLikeness = namedtuple("NPLikeness", "nplikeness,confidence") 70 | return NPLikeness(score, confidence) 71 | 72 | 73 | def scoreMol(mol, fscore=None): 74 | """Calculates the Natural Product Likeness of a molecule. 75 | 76 | Returns the score as float in the range -5..5.""" 77 | if _fscores is None: 78 | readNPModel() 79 | fscore = fscore or _fscores 80 | return scoreMolWConfidence(mol, fscore).nplikeness 81 | 82 | 83 | def processMols(fscore, suppl): 84 | print("calculating ...", file=sys.stderr) 85 | n = 0 86 | for m in suppl: 87 | if m is None: 88 | continue 89 | 90 | n += 1 91 | score = "%.3f" % scoreMol(m, fscore) 92 | 93 | smiles = Chem.MolToSmiles(m, True) 94 | name = m.GetProp('_Name') 95 | print(smiles + "\t" + name + "\t" + score) 96 | 97 | print("finished, " + str(n) + " molecules processed", file=sys.stderr) 98 | 99 | 100 | if __name__ == '__main__': 101 | fscore = readNPModel() # fills fscore 102 | 103 | suppl = Chem.SmilesMolSupplier( 104 | sys.argv[1], smilesColumn=0, nameColumn=1, titleLine=False 105 | ) 106 | processMols(fscore, suppl) 107 | 108 | # 109 | # Copyright (c) 2015, Novartis Institutes for BioMedical Research Inc. 110 | # All rights reserved. 111 | # 112 | # Redistribution and use in source and binary forms, with or without 113 | # modification, are permitted provided that the following conditions are 114 | # met: 115 | # 116 | # * Redistributions of source code must retain the above copyright 117 | # notice, this list of conditions and the following disclaimer. 118 | # * Redistributions in binary form must reproduce the above 119 | # copyright notice, this list of conditions and the following 120 | # disclaimer in the documentation and/or other materials provided 121 | # with the distribution. 122 | # * Neither the name of Novartis Institutes for BioMedical Research Inc. 123 | # nor the names of its contributors may be used to endorse or promote 124 | # products derived from this software without specific prior written 125 | # permission. 126 | # 127 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 128 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 129 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 130 | # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 131 | # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 132 | # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 133 | # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 134 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 135 | # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 136 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 137 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
138 | # 139 | -------------------------------------------------------------------------------- /moses/NP_Score/publicnp.model.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/A4Bio/GraphsGPT/e71647c5a9e19322b4166bb7922c604778a71e1e/moses/NP_Score/publicnp.model.gz -------------------------------------------------------------------------------- /moses/README.md: -------------------------------------------------------------------------------- 1 | This is a modified version of [Molecular Sets (MOSES)](https://github.com/molecularsets/moses) for portability and minimal implementation. 2 | 3 | Changes include: 4 | 5 | 1. Cleaned the codes for training & evaluating built-in models, which makes it compatible with the latest versions of Python (>=3.10). 6 | 2. Fixed the [compatibility issue](https://github.com/molecularsets/moses/pull/111) with the latest versions of Pandas (>=2.2.2). 7 | 8 | -------------------------------------------------------------------------------- /moses/SA_Score/README: -------------------------------------------------------------------------------- 1 | RDKit-based implementation of the method described in: 2 | 3 | Estimation of Synthetic Accessibility Score of Drug-like Molecules based on Molecular Complexity and Fragment Contributions 4 | Peter Ertl and Ansgar Schuffenhauer 5 | Journal of Cheminformatics 1:8 (2009) 6 | http://www.jcheminf.com/content/1/1/8 7 | 8 | Contribution from Peter Ertl and Greg Landrum 9 | 10 | -------------------------------------------------------------------------------- /moses/SA_Score/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/A4Bio/GraphsGPT/e71647c5a9e19322b4166bb7922c604778a71e1e/moses/SA_Score/__init__.py -------------------------------------------------------------------------------- /moses/SA_Score/fpscores.pkl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/A4Bio/GraphsGPT/e71647c5a9e19322b4166bb7922c604778a71e1e/moses/SA_Score/fpscores.pkl.gz -------------------------------------------------------------------------------- /moses/SA_Score/sascorer.py: -------------------------------------------------------------------------------- 1 | # 2 | # calculation of synthetic accessibility score as described in: 3 | # 4 | # Estimation of Synthetic Accessibility Score of Drug-like Molecules based on 5 | # Molecular Complexity and Fragment Contributions 6 | # Peter Ertl and Ansgar Schuffenhauer 7 | # Journal of Cheminformatics 1:8 (2009) 8 | # http://www.jcheminf.com/content/1/1/8 9 | # 10 | # several small modifications to the original paper are included 11 | # particularly slightly different formula for marocyclic penalty 12 | # and taking into account also molecule symmetry (fingerprint density) 13 | # 14 | # for a set of 10k diverse molecules the agreement between the original method 15 | # as implemented in PipelinePilot and this implementation is r2 = 0.97 16 | # 17 | # peter ertl & greg landrum, september 2013 18 | # 19 | from __future__ import print_function 20 | 21 | import os.path as op 22 | 23 | import math 24 | import pickle 25 | from rdkit import Chem 26 | from rdkit.Chem import rdMolDescriptors 27 | from rdkit.six import iteritems 28 | 29 | _fscores = None 30 | 31 | 32 | def readFragmentScores(name='fpscores'): 33 | import gzip 34 | global _fscores 35 | # generate the full path filename: 36 | if name == 
"fpscores": 37 | name = op.join(op.dirname(__file__), name) 38 | _fscores = pickle.load(gzip.open('%s.pkl.gz' % name)) 39 | outDict = {} 40 | for i in _fscores: 41 | for j in range(1, len(i)): 42 | outDict[i[j]] = float(i[0]) 43 | _fscores = outDict 44 | 45 | 46 | def numBridgeheadsAndSpiro(mol, ri=None): 47 | nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol) 48 | nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol) 49 | return nBridgehead, nSpiro 50 | 51 | 52 | def calculateScore(m): 53 | if _fscores is None: 54 | readFragmentScores() 55 | 56 | # fragment score 57 | fp = rdMolDescriptors.GetMorganFingerprint( 58 | m, 2 # <- 2 is the *radius* of the circular fingerprint 59 | ) 60 | fps = fp.GetNonzeroElements() 61 | score1 = 0. 62 | nf = 0 63 | for bitId, v in iteritems(fps): 64 | nf += v 65 | sfp = bitId 66 | score1 += _fscores.get(sfp, -4) * v 67 | score1 /= nf 68 | 69 | # features score 70 | nAtoms = m.GetNumAtoms() 71 | nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True)) 72 | ri = m.GetRingInfo() 73 | nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri) 74 | nMacrocycles = 0 75 | for x in ri.AtomRings(): 76 | if len(x) > 8: 77 | nMacrocycles += 1 78 | 79 | sizePenalty = nAtoms ** 1.005 - nAtoms 80 | stereoPenalty = math.log10(nChiralCenters + 1) 81 | spiroPenalty = math.log10(nSpiro + 1) 82 | bridgePenalty = math.log10(nBridgeheads + 1) 83 | macrocyclePenalty = 0. 84 | # --------------------------------------- 85 | # This differs from the paper, which defines: 86 | # macrocyclePenalty = math.log10(nMacrocycles+1) 87 | # This form generates better results when 2 or more macrocycles are present 88 | if nMacrocycles > 0: 89 | macrocyclePenalty = math.log10(2) 90 | 91 | score2 = (0. - sizePenalty - stereoPenalty - 92 | spiroPenalty - bridgePenalty - macrocyclePenalty) 93 | 94 | # correction for the fingerprint density 95 | # not in the original publication, added in version 1.1 96 | # to make highly symmetrical molecules easier to synthetise 97 | score3 = 0. 98 | if nAtoms > len(fps): 99 | score3 = math.log(float(nAtoms) / len(fps)) * .5 100 | 101 | sascore = score1 + score2 + score3 102 | 103 | # need to transform "raw" value into scale between 1 and 10 104 | min = -4.0 105 | max = 2.5 106 | sascore = 11. - (sascore - min + 1) / (max - min) * 9. 107 | # smooth the 10-end 108 | if sascore > 8.: 109 | sascore = 8. + math.log(sascore + 1. - 9.) 110 | if sascore > 10.: 111 | sascore = 10.0 112 | elif sascore < 1.: 113 | sascore = 1.0 114 | 115 | return sascore 116 | 117 | 118 | def processMols(mols): 119 | print('smiles\tName\tsa_score') 120 | for m in mols: 121 | if m is None: 122 | continue 123 | 124 | s = calculateScore(m) 125 | 126 | smiles = Chem.MolToSmiles(m) 127 | print(smiles + "\t" + m.GetProp('_Name') + "\t%3f" % s) 128 | 129 | 130 | if __name__ == '__main__': 131 | import sys 132 | import time 133 | 134 | t1 = time.time() 135 | readFragmentScores("fpscores") 136 | t2 = time.time() 137 | 138 | suppl = Chem.SmilesMolSupplier(sys.argv[1]) 139 | t3 = time.time() 140 | processMols(suppl) 141 | t4 = time.time() 142 | 143 | print('Reading took %.2f seconds. Calculating took %.2f seconds' % ( 144 | (t2 - t1), (t4 - t3)), 145 | file=sys.stderr) 146 | 147 | # 148 | # Copyright (c) 2013, Novartis Institutes for BioMedical Research Inc. 149 | # All rights reserved. 
150 | # 151 | # Redistribution and use in source and binary forms, with or without 152 | # modification, are permitted provided that the following conditions are 153 | # met: 154 | # 155 | # * Redistributions of source code must retain the above copyright 156 | # notice, this list of conditions and the following disclaimer. 157 | # * Redistributions in binary form must reproduce the above 158 | # copyright notice, this list of conditions and the following 159 | # disclaimer in the documentation and/or other materials provided 160 | # with the distribution. 161 | # * Neither the name of Novartis Institutes for BioMedical Research Inc. 162 | # nor the names of its contributors may be used to endorse or promote 163 | # products derived from this software without specific prior written 164 | # permission. 165 | # 166 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 167 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 168 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 169 | # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 170 | # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 171 | # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 172 | # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 173 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 174 | # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 175 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 176 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 177 | # 178 | -------------------------------------------------------------------------------- /moses/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import get_dataset, get_statistics 2 | from .metrics import get_all_metrics 3 | 4 | __version__ = '0.3.1' 5 | __all__ = ["get_dataset", "get_statistics", "get_all_metrics"] 6 | -------------------------------------------------------------------------------- /moses/data/test.csv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/A4Bio/GraphsGPT/e71647c5a9e19322b4166bb7922c604778a71e1e/moses/data/test.csv.gz -------------------------------------------------------------------------------- /moses/data/test_scaffolds.csv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/A4Bio/GraphsGPT/e71647c5a9e19322b4166bb7922c604778a71e1e/moses/data/test_scaffolds.csv.gz -------------------------------------------------------------------------------- /moses/data/test_scaffolds_stats.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/A4Bio/GraphsGPT/e71647c5a9e19322b4166bb7922c604778a71e1e/moses/data/test_scaffolds_stats.npz -------------------------------------------------------------------------------- /moses/data/test_stats.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/A4Bio/GraphsGPT/e71647c5a9e19322b4166bb7922c604778a71e1e/moses/data/test_stats.npz -------------------------------------------------------------------------------- /moses/data/train.csv.gz: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/A4Bio/GraphsGPT/e71647c5a9e19322b4166bb7922c604778a71e1e/moses/data/train.csv.gz -------------------------------------------------------------------------------- /moses/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import pandas as pd 4 | 5 | AVAILABLE_SPLITS = ['train', 'test', 'test_scaffolds'] 6 | 7 | BASE_PATH = os.path.dirname(__file__) 8 | 9 | 10 | # BASE_PATH = os.path.join(BASE_PATH, "..") 11 | 12 | 13 | def get_dataset(split='train'): 14 | """ 15 | Loads MOSES dataset 16 | 17 | Arguments: 18 | split (str): split to load. Must be 19 | one of: 'train', 'test', 'test_scaffolds' 20 | 21 | Returns: 22 | list with SMILES strings 23 | """ 24 | if split not in AVAILABLE_SPLITS: 25 | raise ValueError( 26 | f"Unknown split {split}. " 27 | f"Available splits: {AVAILABLE_SPLITS}" 28 | ) 29 | if split not in AVAILABLE_SPLITS: 30 | raise ValueError( 31 | f"Unknown split {split}. " 32 | f"Available splits: {AVAILABLE_SPLITS}") 33 | path = os.path.join(BASE_PATH, 'data', split + '.csv.gz') 34 | smiles = pd.read_csv(path, compression='gzip')['SMILES'].values 35 | return smiles 36 | 37 | 38 | def get_statistics(split='test'): 39 | path = os.path.join(BASE_PATH, 'data', split + '_stats.npz') 40 | return np.load(path, allow_pickle=True)['stats'].item() 41 | -------------------------------------------------------------------------------- /moses/mcf.csv: -------------------------------------------------------------------------------- 1 | names,smarts 2 | MCF1,[#6]=&!@[#6]-[#6]#[#7] 3 | MCF2,[#6]=&!@[#6]-[#16](=[#8])=[#8] 4 | MCF3,[#6]=&!@[#6&!H0]-&!@[#6](=[#8])-&!@[#7] 5 | MCF4,"[H]C([H])([#6])[F,Cl,Br,I]" 6 | MCF5,[#6]1-[#8]-[#6]-1 7 | MCF6,[#6]-[#7]=[#6]=[#8] 8 | MCF7,[#6&!H0]=[#8] 9 | MCF8,"[#6](=&!@[#7&!H0])-&!@[#6,#7,#8,#16]" 10 | MCF9,[#6]1-[#7]-[#6]-1 11 | MCF10,[#6]~&!@[#7]~&!@[#7]~&!@[#6] 12 | MCF11,[#7]=&!@[#7] 13 | MCF12,[H][#6]-1=[#6]([H])-[#6]=[#6](-*)-[#8]-1 14 | MCF13,[H][#6]-1=[#6]([H])-[#6]=[#6](-*)-[#16]-1 15 | MCF14,"[#17,#35,#53]-c(:*):[!#1!#6]:*" 16 | MCF15,[H][#7]([H])-[#6]-1=[#6]-[#6]=[#6]-[#6]=[#6]-1 17 | MCF16,[#16]~[#16] 18 | MCF17,[#7]~&!@[#7]~&!@[#7] 19 | MCF18,[#7]-&!@[#6&!H0&!H1]-&!@[#7] 20 | MCF19,[#6&!H0](-&!@[#8])-&!@[#8] 21 | MCF20,[#35].[#35].[#35] 22 | MCF21,[#17].[#17].[#17].[#17] 23 | MCF22,[#9].[#9].[#9].[#9].[#9].[#9].[#9] 24 | -------------------------------------------------------------------------------- /moses/metric_utils.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | 3 | import numpy as np 4 | import os 5 | import pandas as pd 6 | import scipy.sparse 7 | import torch 8 | from functools import partial 9 | from rdkit import Chem 10 | from rdkit.Chem import AllChem 11 | from rdkit.Chem import Descriptors 12 | from rdkit.Chem import MACCSkeys 13 | from rdkit.Chem.AllChem import GetMorganFingerprintAsBitVect as Morgan 14 | from rdkit.Chem.QED import qed 15 | from rdkit.Chem.Scaffolds import MurckoScaffold 16 | 17 | from .NP_Score import npscorer 18 | from .SA_Score import sascorer 19 | from .utils import mapper, get_mol 20 | 21 | _base_dir = os.path.split(__file__)[0] 22 | _mcf = pd.read_csv(os.path.join(_base_dir, 'mcf.csv')) 23 | _pains = pd.read_csv(os.path.join(_base_dir, 'wehi_pains.csv'), names=['smarts', 'names']) 24 | _filters = [Chem.MolFromSmarts(x) for x in pd.concat([_mcf, _pains], sort=True)['smarts'].values] # fixed for pandas 2.2.2 25 | 26 | 27 | def 
canonic_smiles(smiles_or_mol): 28 | mol = get_mol(smiles_or_mol) 29 | if mol is None: 30 | return None 31 | return Chem.MolToSmiles(mol) 32 | 33 | 34 | def logP(mol): 35 | """ 36 | Computes RDKit's logP 37 | """ 38 | return Chem.Crippen.MolLogP(mol) 39 | 40 | 41 | def SA(mol): 42 | """ 43 | Computes RDKit's Synthetic Accessibility score 44 | """ 45 | return sascorer.calculateScore(mol) 46 | 47 | 48 | def NP(mol): 49 | """ 50 | Computes RDKit's Natural Product-likeness score 51 | """ 52 | return npscorer.scoreMol(mol) 53 | 54 | 55 | def QED(mol): 56 | """ 57 | Computes RDKit's QED score 58 | """ 59 | return qed(mol) 60 | 61 | 62 | def weight(mol): 63 | """ 64 | Computes molecular weight for given molecule. 65 | Returns float, 66 | """ 67 | return Descriptors.MolWt(mol) 68 | 69 | 70 | def get_n_rings(mol): 71 | """ 72 | Computes the number of rings in a molecule 73 | """ 74 | return mol.GetRingInfo().NumRings() 75 | 76 | 77 | def fragmenter(mol): 78 | """ 79 | fragment mol using BRICS and return smiles list 80 | """ 81 | fgs = AllChem.FragmentOnBRICSBonds(get_mol(mol)) 82 | fgs_smi = Chem.MolToSmiles(fgs).split(".") 83 | return fgs_smi 84 | 85 | 86 | def compute_fragments(mol_list, n_jobs=1): 87 | """ 88 | fragment list of mols using BRICS and return smiles list 89 | """ 90 | fragments = Counter() 91 | for mol_frag in mapper(n_jobs)(fragmenter, mol_list): 92 | fragments.update(mol_frag) 93 | return fragments 94 | 95 | 96 | def compute_scaffolds(mol_list, n_jobs=1, min_rings=2): 97 | """ 98 | Extracts a scafold from a molecule in a form of a canonic SMILES 99 | """ 100 | scaffolds = Counter() 101 | map_ = mapper(n_jobs) 102 | scaffolds = Counter( 103 | map_(partial(compute_scaffold, min_rings=min_rings), mol_list)) 104 | if None in scaffolds: 105 | scaffolds.pop(None) 106 | return scaffolds 107 | 108 | 109 | def compute_scaffold(mol, min_rings=2): 110 | mol = get_mol(mol) 111 | try: 112 | scaffold = MurckoScaffold.GetScaffoldForMol(mol) 113 | except (ValueError, RuntimeError): 114 | return None 115 | n_rings = get_n_rings(scaffold) 116 | scaffold_smiles = Chem.MolToSmiles(scaffold) 117 | if scaffold_smiles == '' or n_rings < min_rings: 118 | return None 119 | return scaffold_smiles 120 | 121 | 122 | def average_agg_tanimoto(stock_vecs, gen_vecs, 123 | batch_size=5000, agg='max', 124 | device='cpu', p=1): 125 | """ 126 | For each molecule in gen_vecs finds closest molecule in stock_vecs. 
127 | Returns average tanimoto score for between these molecules 128 | 129 | Parameters: 130 | stock_vecs: numpy array 131 | gen_vecs: numpy array 132 | agg: max or mean 133 | p: power for averaging: (mean x^p)^(1/p) 134 | """ 135 | assert agg in ['max', 'mean'], "Can aggregate only max or mean" 136 | agg_tanimoto = np.zeros(len(gen_vecs)) 137 | total = np.zeros(len(gen_vecs)) 138 | for j in range(0, stock_vecs.shape[0], batch_size): 139 | x_stock = torch.tensor(stock_vecs[j:j + batch_size]).to(device).float() 140 | for i in range(0, gen_vecs.shape[0], batch_size): 141 | y_gen = torch.tensor(gen_vecs[i:i + batch_size]).to(device).float() 142 | y_gen = y_gen.transpose(0, 1) 143 | tp = torch.mm(x_stock, y_gen) 144 | jac = (tp / (x_stock.sum(1, keepdim=True) + 145 | y_gen.sum(0, keepdim=True) - tp)).cpu().numpy() 146 | jac[np.isnan(jac)] = 1 147 | if p != 1: 148 | jac = jac ** p 149 | if agg == 'max': 150 | agg_tanimoto[i:i + y_gen.shape[1]] = np.maximum( 151 | agg_tanimoto[i:i + y_gen.shape[1]], jac.max(0)) 152 | elif agg == 'mean': 153 | agg_tanimoto[i:i + y_gen.shape[1]] += jac.sum(0) 154 | total[i:i + y_gen.shape[1]] += jac.shape[0] 155 | if agg == 'mean': 156 | agg_tanimoto /= total 157 | if p != 1: 158 | agg_tanimoto = (agg_tanimoto) ** (1 / p) 159 | return np.mean(agg_tanimoto) 160 | 161 | 162 | def fingerprint(smiles_or_mol, fp_type='maccs', dtype=None, morgan__r=2, 163 | morgan__n=1024, *args, **kwargs): 164 | """ 165 | Generates fingerprint for SMILES 166 | If smiles is invalid, returns None 167 | Returns numpy array of fingerprint bits 168 | 169 | Parameters: 170 | smiles: SMILES string 171 | type: type of fingerprint: [MACCS|morgan] 172 | dtype: if not None, specifies the dtype of returned array 173 | """ 174 | fp_type = fp_type.lower() 175 | molecule = get_mol(smiles_or_mol, *args, **kwargs) 176 | if molecule is None: 177 | return None 178 | if fp_type == 'maccs': 179 | keys = MACCSkeys.GenMACCSKeys(molecule) 180 | keys = np.array(keys.GetOnBits()) 181 | fingerprint = np.zeros(166, dtype='uint8') 182 | if len(keys) != 0: 183 | fingerprint[keys - 1] = 1 # We drop 0-th key that is always zero 184 | elif fp_type == 'morgan': 185 | fingerprint = np.asarray(Morgan(molecule, morgan__r, nBits=morgan__n), 186 | dtype='uint8') 187 | else: 188 | raise ValueError("Unknown fingerprint type {}".format(fp_type)) 189 | if dtype is not None: 190 | fingerprint = fingerprint.astype(dtype) 191 | return fingerprint 192 | 193 | 194 | def fingerprints(smiles_mols_array, n_jobs=1, already_unique=False, *args, 195 | **kwargs): 196 | ''' 197 | Computes fingerprints of smiles np.array/list/pd.Series with n_jobs workers 198 | e.g.fingerprints(smiles_mols_array, type='morgan', n_jobs=10) 199 | Inserts np.NaN to rows corresponding to incorrect smiles. 200 | IMPORTANT: if there is at least one np.NaN, the dtype would be float 201 | Parameters: 202 | smiles_mols_array: list/array/pd.Series of smiles or already computed 203 | RDKit molecules 204 | n_jobs: number of parralel workers to execute 205 | already_unique: flag for performance reasons, if smiles array is big 206 | and already unique. Its value is set to True if smiles_mols_array 207 | contain RDKit molecules already. 
208 | ''' 209 | if isinstance(smiles_mols_array, pd.Series): 210 | smiles_mols_array = smiles_mols_array.values 211 | else: 212 | smiles_mols_array = np.asarray(smiles_mols_array) 213 | if not isinstance(smiles_mols_array[0], str): 214 | already_unique = True 215 | 216 | if not already_unique: 217 | smiles_mols_array, inv_index = np.unique(smiles_mols_array, 218 | return_inverse=True) 219 | 220 | fps = mapper(n_jobs)( 221 | partial(fingerprint, *args, **kwargs), smiles_mols_array 222 | ) 223 | 224 | length = 1 225 | for fp in fps: 226 | if fp is not None: 227 | length = fp.shape[-1] 228 | first_fp = fp 229 | break 230 | fps = [fp if fp is not None else np.array([np.NaN]).repeat(length)[None, :] 231 | for fp in fps] 232 | if scipy.sparse.issparse(first_fp): 233 | fps = scipy.sparse.vstack(fps).tocsr() 234 | else: 235 | fps = np.vstack(fps) 236 | if not already_unique: 237 | return fps[inv_index] 238 | return fps 239 | 240 | 241 | def mol_passes_filters(mol, 242 | allowed=None, 243 | isomericSmiles=False): 244 | """ 245 | Checks if mol 246 | * passes MCF and PAINS filters, 247 | * has only allowed atoms 248 | * is not charged 249 | """ 250 | allowed = allowed or {'C', 'N', 'S', 'O', 'F', 'Cl', 'Br', 'H'} 251 | mol = get_mol(mol) 252 | if mol is None: 253 | return False 254 | ring_info = mol.GetRingInfo() 255 | if ring_info.NumRings() != 0 and any( 256 | len(x) >= 8 for x in ring_info.AtomRings() 257 | ): 258 | return False 259 | h_mol = Chem.AddHs(mol) 260 | if any(atom.GetFormalCharge() != 0 for atom in mol.GetAtoms()): 261 | return False 262 | if any(atom.GetSymbol() not in allowed for atom in mol.GetAtoms()): 263 | return False 264 | if any(h_mol.HasSubstructMatch(smarts) for smarts in _filters): 265 | return False 266 | smiles = Chem.MolToSmiles(mol, isomericSmiles=isomericSmiles) 267 | if smiles is None or len(smiles) == 0: 268 | return False 269 | if Chem.MolFromSmiles(smiles) is None: 270 | return False 271 | return True 272 | -------------------------------------------------------------------------------- /moses/metrics.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import Pool 2 | 3 | import numpy as np 4 | import warnings 5 | from fcd_torch import FCD as FCDMetric 6 | from scipy.spatial.distance import cosine as cos_distance 7 | from scipy.stats import wasserstein_distance 8 | 9 | from .dataset import get_dataset, get_statistics 10 | from .metric_utils import compute_fragments, average_agg_tanimoto, compute_scaffolds, fingerprints, get_mol, canonic_smiles, mol_passes_filters, logP, QED, SA, weight 11 | from .utils import mapper, disable_rdkit_log, enable_rdkit_log 12 | 13 | 14 | def get_all_metrics(gen, k=None, n_jobs=1, 15 | device='cpu', batch_size=512, pool=None, 16 | test=None, test_scaffolds=None, 17 | ptest=None, ptest_scaffolds=None, 18 | train=None): 19 | """ 20 | Computes all available metrics between test (scaffold test) 21 | and generated sets of SMILES. 22 | Parameters: 23 | gen: list of generated SMILES 24 | k: int or list with values for unique@k. Will calculate number of 25 | unique molecules in the first k molecules. Default [1000, 10000] 26 | n_jobs: number of workers for parallel processing 27 | device: 'cpu' or 'cuda:n', where n is GPU device number 28 | batch_size: batch size for FCD metric 29 | pool: optional multiprocessing pool to use for parallelization 30 | 31 | test (None or list): test SMILES. If None, will load 32 | a default test set 33 | test_scaffolds (None or list): scaffold test SMILES. 
If None, will 34 | load a default scaffold test set 35 | ptest (None or dict): precalculated statistics of the test set. If 36 | None, will load default test statistics. If you specified a custom 37 | test set, default test statistics will be ignored 38 | ptest_scaffolds (None or dict): precalculated statistics of the 39 | scaffold test set If None, will load default scaffold test 40 | statistics. If you specified a custom test set, default test 41 | statistics will be ignored 42 | train (None or list): train SMILES. If None, will load a default 43 | train set 44 | Available metrics: 45 | * %valid 46 | * %unique@k 47 | * Frechet ChemNet Distance (FCD) 48 | * Fragment similarity (Frag) 49 | * Scaffold similarity (Scaf) 50 | * Similarity to nearest neighbour (SNN) 51 | * Internal diversity (IntDiv) 52 | * Internal diversity 2: using square root of mean squared 53 | Tanimoto similarity (IntDiv2) 54 | * %passes filters (Filters) 55 | * Distribution difference for logP, SA, QED, weight 56 | * Novelty (molecules not present in train) 57 | """ 58 | if test is None: 59 | if ptest is not None: 60 | raise ValueError( 61 | "You cannot specify custom test " 62 | "statistics for default test set") 63 | test = get_dataset('test') 64 | ptest = get_statistics('test') 65 | 66 | if test_scaffolds is None: 67 | if ptest_scaffolds is not None: 68 | raise ValueError( 69 | "You cannot specify custom scaffold test " 70 | "statistics for default scaffold test set") 71 | test_scaffolds = get_dataset('test_scaffolds') 72 | ptest_scaffolds = get_statistics('test_scaffolds') 73 | 74 | train = train or get_dataset('train') 75 | 76 | if k is None: 77 | k = [1000, 10000] 78 | disable_rdkit_log() 79 | metrics = {} 80 | close_pool = False 81 | if pool is None: 82 | if n_jobs != 1: 83 | pool = Pool(n_jobs) 84 | close_pool = True 85 | else: 86 | pool = 1 87 | metrics['valid'] = fraction_valid(gen, n_jobs=pool) 88 | gen = remove_invalid(gen, canonize=True) 89 | if not isinstance(k, (list, tuple)): 90 | k = [k] 91 | for _k in k: 92 | metrics['unique@{}'.format(_k)] = fraction_unique(gen, _k, pool) 93 | 94 | if ptest is None: 95 | ptest = compute_intermediate_statistics(test, n_jobs=n_jobs, 96 | device=device, 97 | batch_size=batch_size, 98 | pool=pool) 99 | if test_scaffolds is not None and ptest_scaffolds is None: 100 | ptest_scaffolds = compute_intermediate_statistics( 101 | test_scaffolds, n_jobs=n_jobs, 102 | device=device, batch_size=batch_size, 103 | pool=pool 104 | ) 105 | mols = mapper(pool)(get_mol, gen) 106 | kwargs = {'n_jobs': pool, 'device': device, 'batch_size': batch_size} 107 | kwargs_fcd = {'n_jobs': n_jobs, 'device': device, 'batch_size': batch_size} 108 | metrics['FCD/Test'] = FCDMetric(**kwargs_fcd)(gen=gen, pref=ptest['FCD']) 109 | metrics['SNN/Test'] = SNNMetric(**kwargs)(gen=mols, pref=ptest['SNN']) 110 | metrics['Frag/Test'] = FragMetric(**kwargs)(gen=mols, pref=ptest['Frag']) 111 | metrics['Scaf/Test'] = ScafMetric(**kwargs)(gen=mols, pref=ptest['Scaf']) 112 | if ptest_scaffolds is not None: 113 | metrics['FCD/TestSF'] = FCDMetric(**kwargs_fcd)( 114 | gen=gen, pref=ptest_scaffolds['FCD'] 115 | ) 116 | metrics['SNN/TestSF'] = SNNMetric(**kwargs)( 117 | gen=mols, pref=ptest_scaffolds['SNN'] 118 | ) 119 | metrics['Frag/TestSF'] = FragMetric(**kwargs)( 120 | gen=mols, pref=ptest_scaffolds['Frag'] 121 | ) 122 | metrics['Scaf/TestSF'] = ScafMetric(**kwargs)( 123 | gen=mols, pref=ptest_scaffolds['Scaf'] 124 | ) 125 | 126 | metrics['IntDiv'] = internal_diversity(mols, pool, device=device) 127 | 
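# Note (added comment): IntDiv is 1 minus the mean pairwise Tanimoto similarity within the generated set; IntDiv2 on the next line uses p=2, i.e. a root-mean-square aggregation of the similarities.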
metrics['IntDiv2'] = internal_diversity(mols, pool, device=device, p=2) 128 | metrics['Filters'] = fraction_passes_filters(mols, pool) 129 | 130 | # Properties 131 | for name, func in [('logP', logP), ('SA', SA), 132 | ('QED', QED), 133 | ('weight', weight)]: 134 | metrics[name] = WassersteinMetric(func, **kwargs)( 135 | gen=mols, pref=ptest[name]) 136 | 137 | if train is not None: 138 | metrics['Novelty'] = novelty(mols, train, pool) 139 | enable_rdkit_log() 140 | if close_pool: 141 | pool.close() 142 | pool.join() 143 | return metrics 144 | 145 | 146 | def compute_intermediate_statistics(smiles, n_jobs=1, device='cpu', 147 | batch_size=512, pool=None): 148 | """ 149 | The function precomputes statistics such as mean and variance for FCD, etc. 150 | It is useful to compute the statistics for test and scaffold test sets to 151 | speedup metrics calculation. 152 | """ 153 | close_pool = False 154 | if pool is None: 155 | if n_jobs != 1: 156 | pool = Pool(n_jobs) 157 | close_pool = True 158 | else: 159 | pool = 1 160 | statistics = {} 161 | mols = mapper(pool)(get_mol, smiles) 162 | kwargs = {'n_jobs': pool, 'device': device, 'batch_size': batch_size} 163 | kwargs_fcd = {'n_jobs': n_jobs, 'device': device, 'batch_size': batch_size} 164 | statistics['FCD'] = FCDMetric(**kwargs_fcd).precalc(smiles) 165 | statistics['SNN'] = SNNMetric(**kwargs).precalc(mols) 166 | statistics['Frag'] = FragMetric(**kwargs).precalc(mols) 167 | statistics['Scaf'] = ScafMetric(**kwargs).precalc(mols) 168 | for name, func in [('logP', logP), ('SA', SA), 169 | ('QED', QED), 170 | ('weight', weight)]: 171 | statistics[name] = WassersteinMetric(func, **kwargs).precalc(mols) 172 | if close_pool: 173 | pool.terminate() 174 | return statistics 175 | 176 | 177 | def fraction_passes_filters(gen, n_jobs=1): 178 | """ 179 | Computes the fraction of molecules that pass filters: 180 | * MCF 181 | * PAINS 182 | * Only allowed atoms ('C','N','S','O','F','Cl','Br','H') 183 | * No charges 184 | """ 185 | passes = mapper(n_jobs)(mol_passes_filters, gen) 186 | return np.mean(passes) 187 | 188 | 189 | def internal_diversity(gen, n_jobs=1, device='cpu', fp_type='morgan', 190 | gen_fps=None, p=1): 191 | """ 192 | Computes internal diversity as: 193 | 1/|A|^2 sum_{x, y in AxA} (1-tanimoto(x, y)) 194 | """ 195 | if gen_fps is None: 196 | gen_fps = fingerprints(gen, fp_type=fp_type, n_jobs=n_jobs) 197 | return 1 - (average_agg_tanimoto(gen_fps, gen_fps, 198 | agg='mean', device=device, p=p)).mean() 199 | 200 | 201 | def fraction_unique(gen, k=None, n_jobs=1, check_validity=True): 202 | """ 203 | Computes a number of unique molecules 204 | Parameters: 205 | gen: list of SMILES 206 | k: compute unique@k 207 | n_jobs: number of threads for calculation 208 | check_validity: raises ValueError if invalid molecules are present 209 | """ 210 | if k is not None: 211 | if len(gen) < k: 212 | warnings.warn( 213 | "Can't compute unique@{}.".format(k) + 214 | "gen contains only {} molecules".format(len(gen)) 215 | ) 216 | gen = gen[:k] 217 | canonic = set(mapper(n_jobs)(canonic_smiles, gen)) 218 | if None in canonic and check_validity: 219 | raise ValueError("Invalid molecule passed to unique@k") 220 | return len(canonic) / len(gen) 221 | 222 | 223 | def fraction_valid(gen, n_jobs=1): 224 | """ 225 | Computes a number of valid molecules 226 | Parameters: 227 | gen: list of SMILES 228 | n_jobs: number of threads for calculation 229 | """ 230 | gen = mapper(n_jobs)(get_mol, gen) 231 | return 1 - gen.count(None) / len(gen) 232 | 233 | 234 | def 
novelty(gen, train, n_jobs=1): 235 | gen_smiles = mapper(n_jobs)(canonic_smiles, gen) 236 | gen_smiles_set = set(gen_smiles) - {None} 237 | train_set = set(train) 238 | return len(gen_smiles_set - train_set) / len(gen_smiles_set) 239 | 240 | 241 | def remove_invalid(gen, canonize=True, n_jobs=1): 242 | """ 243 | Removes invalid molecules from the dataset 244 | """ 245 | if not canonize: 246 | mols = mapper(n_jobs)(get_mol, gen) 247 | return [gen_ for gen_, mol in zip(gen, mols) if mol is not None] 248 | return [x for x in mapper(n_jobs)(canonic_smiles, gen) if 249 | x is not None] 250 | 251 | 252 | class Metric: 253 | def __init__(self, n_jobs=1, device='cpu', batch_size=512, **kwargs): 254 | self.n_jobs = n_jobs 255 | self.device = device 256 | self.batch_size = batch_size 257 | for k, v in kwargs.items(): 258 | setattr(self, k, v) 259 | 260 | def __call__(self, ref=None, gen=None, pref=None, pgen=None): 261 | assert (ref is None) != (pref is None), "specify ref xor pref" 262 | assert (gen is None) != (pgen is None), "specify gen xor pgen" 263 | if pref is None: 264 | pref = self.precalc(ref) 265 | if pgen is None: 266 | pgen = self.precalc(gen) 267 | return self.metric(pref, pgen) 268 | 269 | def precalc(self, molecules): 270 | raise NotImplementedError 271 | 272 | def metric(self, pref, pgen): 273 | raise NotImplementedError 274 | 275 | 276 | class SNNMetric(Metric): 277 | """ 278 | Computes average max similarities of gen SMILES to ref SMILES 279 | """ 280 | 281 | def __init__(self, fp_type='morgan', **kwargs): 282 | self.fp_type = fp_type 283 | super().__init__(**kwargs) 284 | 285 | def precalc(self, mols): 286 | return {'fps': fingerprints(mols, n_jobs=self.n_jobs, 287 | fp_type=self.fp_type)} 288 | 289 | def metric(self, pref, pgen): 290 | return average_agg_tanimoto(pref['fps'], pgen['fps'], 291 | device=self.device) 292 | 293 | 294 | def cos_similarity(ref_counts, gen_counts): 295 | """ 296 | Computes cosine similarity between 297 | dictionaries of form {name: count}. 
Non-present 298 | elements are considered zero: 299 | 300 | sim = <r, g> / ||r|| / ||g|| 301 | """ 302 | if len(ref_counts) == 0 or len(gen_counts) == 0: 303 | return np.nan 304 | keys = np.unique(list(ref_counts.keys()) + list(gen_counts.keys())) 305 | ref_vec = np.array([ref_counts.get(k, 0) for k in keys]) 306 | gen_vec = np.array([gen_counts.get(k, 0) for k in keys]) 307 | return 1 - cos_distance(ref_vec, gen_vec) 308 | 309 | 310 | class FragMetric(Metric): 311 | def precalc(self, mols): 312 | return {'frag': compute_fragments(mols, n_jobs=self.n_jobs)} 313 | 314 | def metric(self, pref, pgen): 315 | return cos_similarity(pref['frag'], pgen['frag']) 316 | 317 | 318 | class ScafMetric(Metric): 319 | def precalc(self, mols): 320 | return {'scaf': compute_scaffolds(mols, n_jobs=self.n_jobs)} 321 | 322 | def metric(self, pref, pgen): 323 | return cos_similarity(pref['scaf'], pgen['scaf']) 324 | 325 | 326 | class WassersteinMetric(Metric): 327 | def __init__(self, func=None, **kwargs): 328 | self.func = func 329 | super().__init__(**kwargs) 330 | 331 | def precalc(self, mols): 332 | if self.func is not None: 333 | values = mapper(self.n_jobs)(self.func, mols) 334 | else: 335 | values = mols 336 | return {'values': values} 337 | 338 | def metric(self, pref, pgen): 339 | return wasserstein_distance( 340 | pref['values'], pgen['values'] 341 | ) 342 | -------------------------------------------------------------------------------- /moses/utils.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import Pool 2 | 3 | from rdkit import Chem 4 | from rdkit import rdBase 5 | 6 | 7 | # https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader 8 | 9 | 10 | def mapper(n_jobs): 11 | ''' 12 | Returns function for map call. 
13 | If n_jobs == 1, will use standard map 14 | If n_jobs > 1, will use multiprocessing pool 15 | If n_jobs is a pool object, will return its map function 16 | ''' 17 | if n_jobs == 1: 18 | def _mapper(*args, **kwargs): 19 | return list(map(*args, **kwargs)) 20 | 21 | return _mapper 22 | if isinstance(n_jobs, int): 23 | pool = Pool(n_jobs) 24 | 25 | def _mapper(*args, **kwargs): 26 | try: 27 | result = pool.map(*args, **kwargs) 28 | finally: 29 | pool.terminate() 30 | return result 31 | 32 | return _mapper 33 | return n_jobs.map 34 | 35 | 36 | def disable_rdkit_log(): 37 | rdBase.DisableLog('rdApp.*') 38 | 39 | 40 | def enable_rdkit_log(): 41 | rdBase.EnableLog('rdApp.*') 42 | 43 | 44 | def get_mol(smiles_or_mol): 45 | ''' 46 | Loads SMILES/molecule into RDKit's object 47 | ''' 48 | if isinstance(smiles_or_mol, str): 49 | if len(smiles_or_mol) == 0: 50 | return None 51 | mol = Chem.MolFromSmiles(smiles_or_mol) 52 | if mol is None: 53 | return None 54 | try: 55 | Chem.SanitizeMol(mol) 56 | except ValueError: 57 | return None 58 | return mol 59 | return smiles_or_mol 60 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.3.0 2 | transformers==4.41.1 3 | pytorch_lightning==1.9.5 4 | deepspeed==0.14.2 5 | rdkit==2023.9.6 6 | fcd_torch==1.0.7 7 | lmdb==1.4.1 8 | omegaconf==2.3.0 9 | pandas==2.2.2 10 | umap-learn==0.5.6 11 | hdbscan==0.8.36 12 | jupyter==1.0.0 13 | wandb==0.17.0 -------------------------------------------------------------------------------- /scripts/generation/conditional/README-Generation-Cond.md: -------------------------------------------------------------------------------- 1 | # Conditional Generation with GraphGPT-C Decoder 2 | 3 | The conditional generation of GraphGPT-C is similar to the unconditional version, with some extra configurations to control the target properties. For other configurations you can refer to [Unconditional Generation](..%2Funconditional%2FREADME-Generation-Uncond.md). 4 | 5 | 6 | 7 | ## Extra Generation Configurations 8 | 9 | - **Conditions:** 10 | - `value_qed`: `(float) None` (The target QED value for generated molecules. The model will not condition on this property if not specified.) 11 | - `value_sa`: `(float) None` (The target SA score for generated molecules. The model will not condition on this property if not specified.) 12 | - `value_logp`: `(float) None` (The target logP value for generated molecules. The model will not condition on this property if not specified.) 13 | - `scaffold_smiles`: `(str) None` (The target scaffold SMILES. The model will not condition on this property if not specified.) 14 | 15 | 16 | 17 | ### Configurations in the Paper 18 | 19 | We use the following configuration to test the ability of GraphGPT-C to condition on molecular properties: 20 | 21 | ````bash 22 | strict_generation="False" 23 | fix_aromatic_bond="True" 24 | 25 | do_sample="False" 26 | 27 | check_first_node="True" 28 | check_atom_valence="True" 29 | ```` 30 | 31 | Example scripts can be found in `scripts/generation/conditional/examples`. 
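When choosing `value_qed`, `value_sa`, and `value_logp`, it can help to first inspect the raw properties of a reference molecule. The snippet below is a minimal sketch (not part of the released scripts) that reuses the property helpers already imported by `moses/metrics.py`; it assumes the repository root is importable (e.g. after `pip install -e .`). Note that the conditional scripts also pass `configs/property_info.json` via `--property_info_file`, which presumably defines how the raw values are normalized, so treat these numbers only as a reference point.

````python
from rdkit import Chem
from moses.metric_utils import QED, SA, logP  # same helpers used by the MOSES metrics

mol = Chem.MolFromSmiles("CC(=O)Oc1ccccc1C(=O)O")  # aspirin, an arbitrary reference molecule

print("QED :", QED(mol))   # drug-likeness, in [0, 1]
print("SA  :", SA(mol))    # synthetic accessibility (lower = easier to synthesize)
print("logP:", logP(mol))  # Crippen logP
````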
32 | 33 | 34 | 35 | ### Generate with More Diversity 36 | 37 | You can further turn on the *probabilistic sampling* for more diversity: 38 | 39 | ````bash 40 | strict_generation="False" 41 | fix_aromatic_bond="True" 42 | 43 | do_sample="True" 44 | top_k=4 45 | temperature=1.0 46 | 47 | check_first_node="True" 48 | check_atom_valence="True" 49 | ```` 50 | 51 | 52 | 53 | ## Evaluate & Visualize the Results 54 | 55 | Run `scripts/generation/conditional/visualize.sh`. 56 | 57 | The mean and variance property values of generated molecules will also be saved to the `summary.txt`. 58 | 59 | 60 | 61 | -------------------------------------------------------------------------------- /scripts/generation/conditional/examples/generate_scaffold_logp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | model_name_or_path="DaizeDong/GraphsGPT-1W-C" 4 | #model_name_or_path="/mnt/petrelfs/dongdaize.d/workspace/graphsgpt/ckpts/GraphsGPT-1W-C" 5 | save_dir="./results/conditional/generation_scaffold_logp" # change this 6 | smiles_file="./data/examples/zinc_example.txt" 7 | num_batches=10 8 | batch_size=1024 9 | seed=0 10 | 11 | property_info_file="./configs/property_info.json" 12 | value_logp=0.0 # change this 13 | scaffold_smiles="c1ccccc1" # change this 14 | 15 | strict_generation="False" 16 | fix_aromatic_bond="True" 17 | do_sample="False" 18 | check_first_node="True" 19 | check_atom_valence="True" 20 | 21 | save_results="True" 22 | save_failed="False" 23 | 24 | python entrypoints/generation/conditional/generate.py \ 25 | --model_name_or_path ${model_name_or_path} \ 26 | --save_dir ${save_dir} \ 27 | --smiles_file ${smiles_file} \ 28 | --num_batches ${num_batches} \ 29 | --batch_size ${batch_size} \ 30 | --seed ${seed} \ 31 | --property_info_file ${property_info_file} \ 32 | --value_logp ${value_logp} \ 33 | --scaffold_smiles ${scaffold_smiles} \ 34 | --strict_generation ${strict_generation} \ 35 | --do_sample ${do_sample} \ 36 | --check_first_node ${check_first_node} \ 37 | --check_atom_valence ${check_atom_valence} \ 38 | --fix_aromatic_bond ${fix_aromatic_bond} \ 39 | --save_results ${save_results} \ 40 | --save_failed ${save_failed} 41 | -------------------------------------------------------------------------------- /scripts/generation/conditional/examples/generate_scaffold_qed.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | model_name_or_path="DaizeDong/GraphsGPT-1W-C" 4 | save_dir="./results/conditional/generation_scaffold_qed" # change this 5 | smiles_file="./data/examples/zinc_example.txt" 6 | num_batches=10 7 | batch_size=1024 8 | seed=0 9 | 10 | property_info_file="./configs/property_info.json" 11 | value_qed=0.5 # change this 12 | scaffold_smiles="c1ccccc1" # change this 13 | 14 | strict_generation="False" 15 | fix_aromatic_bond="True" 16 | do_sample="False" 17 | check_first_node="True" 18 | check_atom_valence="True" 19 | 20 | save_results="True" 21 | save_failed="False" 22 | 23 | python entrypoints/generation/conditional/generate.py \ 24 | --model_name_or_path ${model_name_or_path} \ 25 | --save_dir ${save_dir} \ 26 | --smiles_file ${smiles_file} \ 27 | --num_batches ${num_batches} \ 28 | --batch_size ${batch_size} \ 29 | --seed ${seed} \ 30 | --property_info_file ${property_info_file} \ 31 | --value_qed ${value_qed} \ 32 | --scaffold_smiles ${scaffold_smiles} \ 33 | --strict_generation ${strict_generation} \ 34 | --do_sample ${do_sample} \ 35 | --check_first_node 
${check_first_node} \ 36 | --check_atom_valence ${check_atom_valence} \ 37 | --fix_aromatic_bond ${fix_aromatic_bond} \ 38 | --save_results ${save_results} \ 39 | --save_failed ${save_failed} 40 | -------------------------------------------------------------------------------- /scripts/generation/conditional/examples/generate_scaffold_sa.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | model_name_or_path="DaizeDong/GraphsGPT-1W-C" 4 | save_dir="./results/conditional/generation_scaffold_sa" # change this 5 | smiles_file="./data/examples/zinc_example.txt" 6 | num_batches=10 7 | batch_size=1024 8 | seed=0 9 | 10 | property_info_file="./configs/property_info.json" 11 | value_sa=0.7 # change this 12 | scaffold_smiles="c1ccccc1" # change this 13 | 14 | strict_generation="False" 15 | fix_aromatic_bond="True" 16 | do_sample="False" 17 | check_first_node="True" 18 | check_atom_valence="True" 19 | 20 | save_results="True" 21 | save_failed="False" 22 | 23 | python entrypoints/generation/conditional/generate.py \ 24 | --model_name_or_path ${model_name_or_path} \ 25 | --save_dir ${save_dir} \ 26 | --smiles_file ${smiles_file} \ 27 | --num_batches ${num_batches} \ 28 | --batch_size ${batch_size} \ 29 | --seed ${seed} \ 30 | --property_info_file ${property_info_file} \ 31 | --value_sa ${value_sa} \ 32 | --scaffold_smiles ${scaffold_smiles} \ 33 | --strict_generation ${strict_generation} \ 34 | --do_sample ${do_sample} \ 35 | --check_first_node ${check_first_node} \ 36 | --check_atom_valence ${check_atom_valence} \ 37 | --fix_aromatic_bond ${fix_aromatic_bond} \ 38 | --save_results ${save_results} \ 39 | --save_failed ${save_failed} 40 | -------------------------------------------------------------------------------- /scripts/generation/conditional/visualize.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ### Conditional ### 3 | 4 | model_name_or_path="DaizeDong/GraphsGPT-1W-C" 5 | generation_results_dir="./results/conditional/generation_scaffold_logp/generated_results" # change this 6 | save_dir="./results/conditional/visualization_scaffold_logp" # change this 7 | save_images="False" 8 | 9 | # visualize all files to gather property info 10 | file_begin_index=0 11 | file_end_index=10 12 | 13 | python entrypoints/generation/conditional/visualize.py \ 14 | --model_name_or_path ${model_name_or_path} \ 15 | --generation_results_dir ${generation_results_dir} \ 16 | --save_dir ${save_dir} \ 17 | --save_images ${save_images} \ 18 | --file_begin_index ${file_begin_index} \ 19 | --file_end_index ${file_end_index} 20 | -------------------------------------------------------------------------------- /scripts/generation/evaluation/moses.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | model_name_or_path="DaizeDong/GraphsGPT-1W" 4 | save_path="./results/unconditional/moses" # change this 5 | 6 | batch_size_valid=8196 7 | sample_std=1.0 8 | max_sample_times=10 9 | num_shots=100000 10 | num_samples_each_shot=1 11 | 12 | num_processes=32 13 | 14 | python entrypoints/generation/evaluation/evaluate_moses_few_shot_sampling.py \ 15 | --model_name_or_path ${model_name_or_path} \ 16 | --save_path ${save_path} \ 17 | --batch_size_valid ${batch_size_valid} \ 18 | --sample_std ${sample_std} \ 19 | --max_sample_times ${max_sample_times} \ 20 | --num_shots ${num_shots} \ 21 | --num_samples_each_shot ${num_samples_each_shot} \ 22 | --num_processes 
${num_processes} 23 | -------------------------------------------------------------------------------- /scripts/generation/evaluation/zinc250k.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | model_name_or_path="DaizeDong/GraphsGPT-1W" 4 | save_path="./results/unconditional/zinc250k" # change this 5 | 6 | batch_size_valid=8196 7 | sample_std=1.0 8 | max_sample_times=10 9 | num_shots=100000 10 | num_samples_each_shot=1 11 | 12 | num_processes=32 13 | 14 | python entrypoints/generation/evaluation/evaluate_zinc250k_few_shot_sampling.py \ 15 | --model_name_or_path ${model_name_or_path} \ 16 | --save_path ${save_path} \ 17 | --batch_size_valid ${batch_size_valid} \ 18 | --sample_std ${sample_std} \ 19 | --max_sample_times ${max_sample_times} \ 20 | --num_shots ${num_shots} \ 21 | --num_samples_each_shot ${num_samples_each_shot} \ 22 | --num_processes ${num_processes} 23 | -------------------------------------------------------------------------------- /scripts/generation/unconditional/README-Generation-Uncond.md: -------------------------------------------------------------------------------- 1 | # Unconditional Generation with GraphGPT Decoder 2 | 3 | The generation of GraphsGPT is controlled by multiple adjustable configurations. You can refer to the following descriptions to adjust the generation to different needs. 4 | 5 | ## Generation Configurations 6 | 7 | - **Auto-fix Toggle:** 8 | - `strict_generation`: `(bool) True` (Whether to tolerate the validity exceptions during generation. Setting to `False` will enable flexible generation that automatically fixes invalid predictions and ensure maximum effectiveness.) 9 | - `fix_aromatic_bond`: `(bool) False` (Whether to fix the dissociative aromatic bonds in the generated molecules.) 10 | 11 | 12 | - **Sampling Strategy:** 13 | - `do_sample`: `(bool) False` (Whether to use probabilistic sampling for bond predictions. Setting to `True` will enable probabilistic sampling and introduce more randomness.) 14 | - `top_k`: `(int) None` (The range of top predictions for probability sampling. Available when `do_sample` is `True`.) 15 | - `temperature`: `(float) 1.0` (Temperature to adjust the probability distribution. Available when `do_sample` is `True`.) 16 | 17 | 18 | - **Hyperparameters:** 19 | - `max_atoms`: `(int) None` (The maximum number of atoms for generation.) 20 | - `similarity_threshold`: `(float) 0.5` (Threshold for classifying whether a generated atom is new or old.) 21 | 22 | 23 | - **Other Check Terms:** 24 | - `check_first_node`: `(bool) True` (Whether to check the consistency between the predicted beginning atom and the first bond, and fix the order of the beginning two atoms.) 25 | - `check_atom_valence`: `(bool) False` (Whether to check the validity regarding the valence of the atoms connected to the predicted bonds.) 26 | 27 | For reference, we provide some example configurations to use under different circumstances. 28 | 29 | ### Validate the Pretraining Performance 30 | 31 | To validate the pretraining performance, the generation should be of no randomness, both *auto-fix* and *probabilistic sampling* should be turned off: 32 | 33 | ````bash 34 | strict_generation="True" 35 | fix_aromatic_bond="False" 36 | 37 | do_sample="False" 38 | 39 | check_first_node="True" 40 | check_atom_valence="False" 41 | ```` 42 | 43 | An example script can be found in `scripts/generation/unconditional/examples/generate_strict.sh`. 
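For a quick end-to-end check, the generation script can be chained with the visualization step. The commands below are a usage sketch that relies on the default paths inside the scripts (the default `generation_results_dir` in `visualize.sh` points at the default `save_dir` of the example generation scripts); adjust both if you change either path.

````bash
# Run from the repository root.
bash scripts/generation/unconditional/examples/generate_strict.sh
bash scripts/generation/unconditional/visualize.sh
````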
44 | 45 | ### Generate with More Effectiveness 46 | 47 | Upon generation of more effective results. You can turn on the *auto-fix*: 48 | 49 | ````bash 50 | strict_generation="False" 51 | fix_aromatic_bond="True" 52 | 53 | do_sample="False" 54 | 55 | check_first_node="True" 56 | check_atom_valence="False" 57 | ```` 58 | 59 | An example script can be found in `scripts/generation/unconditional/examples/generate_flexible.sh`. 60 | 61 | ### Generate with More Diversity (Further Finetuning Needed) 62 | 63 | To generate more diverse results. You can further turn on the *probabilistic sampling* and *valence check*. This requires the encoded Graph Words $\mathcal{W}$ to be of variable information, where an extra finetuning would be needed. 64 | 65 | ````bash 66 | strict_generation="False" 67 | fix_aromatic_bond="True" 68 | 69 | do_sample="True" 70 | top_k=4 71 | temperature=1.0 72 | 73 | check_first_node="True" 74 | check_atom_valence="True" 75 | ```` 76 | 77 | ## Visualize the Results 78 | 79 | Run `scripts/generation/unconditional/visualize.sh`. 80 | 81 | -------------------------------------------------------------------------------- /scripts/generation/unconditional/examples/generate_flexible.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | model_name_or_path="DaizeDong/GraphsGPT-1W" 4 | save_dir="./results/unconditional/generation" # change this 5 | smiles_file="./data/examples/zinc_example.txt" 6 | num_batches=10 7 | batch_size=1024 8 | seed=0 9 | 10 | strict_generation="False" 11 | fix_aromatic_bond="True" 12 | do_sample="True" 13 | check_first_node="True" 14 | check_atom_valence="False" 15 | 16 | save_results="True" 17 | save_failed="False" 18 | 19 | python entrypoints/generation/unconditional/generate.py \ 20 | --model_name_or_path ${model_name_or_path} \ 21 | --save_dir ${save_dir} \ 22 | --smiles_file ${smiles_file} \ 23 | --num_batches ${num_batches} \ 24 | --batch_size ${batch_size} \ 25 | --seed ${seed} \ 26 | --strict_generation ${strict_generation} \ 27 | --do_sample ${do_sample} \ 28 | --check_first_node ${check_first_node} \ 29 | --check_atom_valence ${check_atom_valence} \ 30 | --fix_aromatic_bond ${fix_aromatic_bond} \ 31 | --save_results ${save_results} \ 32 | --save_failed ${save_failed} 33 | -------------------------------------------------------------------------------- /scripts/generation/unconditional/examples/generate_strict.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | model_name_or_path="DaizeDong/GraphsGPT-1W" 4 | save_dir="./results/unconditional/generation" # change this 5 | smiles_file="./data/examples/zinc_example.txt" 6 | num_batches=10 7 | batch_size=1024 8 | seed=0 9 | 10 | strict_generation="True" 11 | fix_aromatic_bond="False" 12 | do_sample="False" 13 | check_first_node="True" 14 | check_atom_valence="False" 15 | 16 | save_results="True" 17 | save_failed="False" 18 | 19 | python entrypoints/generation/unconditional/generate.py \ 20 | --model_name_or_path ${model_name_or_path} \ 21 | --save_dir ${save_dir} \ 22 | --smiles_file ${smiles_file} \ 23 | --num_batches ${num_batches} \ 24 | --batch_size ${batch_size} \ 25 | --seed ${seed} \ 26 | --strict_generation ${strict_generation} \ 27 | --do_sample ${do_sample} \ 28 | --check_first_node ${check_first_node} \ 29 | --check_atom_valence ${check_atom_valence} \ 30 | --fix_aromatic_bond ${fix_aromatic_bond} \ 31 | --save_results ${save_results} \ 32 | --save_failed ${save_failed} 33 | 
-------------------------------------------------------------------------------- /scripts/generation/unconditional/visualize.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ### Unconditional ### 3 | 4 | model_name_or_path="DaizeDong/GraphsGPT-1W" 5 | generation_results_dir="./results/unconditional/generation/generated_results" # change this 6 | save_dir="./results/unconditional/visualization" # change this 7 | save_images="True" 8 | 9 | # only visualize 2 files to save time & storage 10 | file_begin_index=0 11 | file_end_index=2 12 | 13 | python entrypoints/generation/unconditional/visualize.py \ 14 | --model_name_or_path ${model_name_or_path} \ 15 | --generation_results_dir ${generation_results_dir} \ 16 | --save_dir ${save_dir} \ 17 | --save_images ${save_images} \ 18 | --file_begin_index ${file_begin_index} \ 19 | --file_end_index ${file_end_index} 20 | -------------------------------------------------------------------------------- /scripts/representation/finetune.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python entrypoints/representation/finetune.py --task_name tox21 4 | python entrypoints/representation/finetune.py --task_name toxcast 5 | python entrypoints/representation/finetune.py --task_name bbbp 6 | python entrypoints/representation/finetune.py --task_name sider 7 | python entrypoints/representation/finetune.py --task_name hiv 8 | python entrypoints/representation/finetune.py --task_name bace 9 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='graphsgpt', 5 | version='1.0', 6 | packages=find_packages(), 7 | install_requires=[], 8 | classifiers=[ 9 | 'Programming Language :: Python :: 3', 10 | 'Programming Language :: Python :: 3.6', 11 | 'Programming Language :: Python :: 3.7', 12 | 'Programming Language :: Python :: 3.8', 13 | ], 14 | ) 15 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/A4Bio/GraphsGPT/e71647c5a9e19322b4166bb7922c604778a71e1e/utils/__init__.py -------------------------------------------------------------------------------- /utils/accuracy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import no_grad 3 | 4 | 5 | @no_grad() 6 | def classification_accuracy( 7 | logits: torch.FloatTensor, 8 | labels: torch.LongTensor, 9 | ): 10 | prediction = torch.argmax(logits, dim=1) 11 | correct = (prediction == labels).float() 12 | return torch.mean(correct) 13 | -------------------------------------------------------------------------------- /utils/algorithms.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | 3 | import numpy as np 4 | 5 | 6 | def find_factors_with_minimal_sum(number): 7 | if number == 1: 8 | return (1, 1) 9 | 10 | # Initialize variables to keep track of the factors with the minimal sum 11 | min_sum = float("inf") 12 | min_factors = None 13 | 14 | # Iterate through potential factors from 1 to half of the number 15 | for factor1 in range(1, number // 2 + 1): 16 | factor2 = number // factor1 17 | 18 | # Check if factor1 * factor2 is equal 
to the original number 19 | if factor1 * factor2 == number: 20 | current_sum = factor1 + factor2 21 | 22 | # Update the minimum sum and factors if the current sum is smaller 23 | if current_sum < min_sum: 24 | min_sum = current_sum 25 | min_factors = (factor1, factor2) 26 | 27 | return min_factors 28 | 29 | 30 | class FindCycles: 31 | """ 32 | Example: 33 | adjacency_matrix = np.array([ 34 | [0, 1, 0, 0, 1, 1, 0], 35 | [1, 0, 1, 1, 0, 0, 0], 36 | [0, 1, 0, 1, 0, 0, 0], 37 | [0, 1, 1, 0, 1, 0, 0], 38 | [1, 0, 0, 1, 0, 0, 0], 39 | [1, 0, 0, 0, 0, 0, 1], 40 | [0, 0, 0, 0, 0, 1, 0], 41 | ]) 42 | 43 | finder = FindCycles(adjacency_matrix) 44 | cycles = finder.find_cycles() 45 | """ 46 | 47 | def __init__(self, adj_matrix): 48 | self.graph = adj_matrix 49 | self.num_vertices = len(adj_matrix) 50 | self.visited = [False] * self.num_vertices 51 | self.cycles = set() # Set to store unique cycles 52 | 53 | def find_cycles(self): 54 | # Remove nodes with degree 1 55 | self.remove_degree_one_nodes() 56 | 57 | # Perform DFS for remaining nodes 58 | for vertex in range(self.num_vertices): 59 | if not self.visited[vertex]: 60 | self.dfs(vertex, vertex, []) 61 | 62 | # Deduplicate 63 | self.deduplicate_cycles() 64 | 65 | return list(self.cycles) 66 | 67 | def remove_degree_one_nodes(self): 68 | # Use a deque to efficiently process nodes with degree 1 69 | queue = deque([i for i in range(self.num_vertices) if np.sum(self.graph[i]) == 1]) 70 | 71 | while queue: 72 | node = queue.popleft() 73 | self.graph[node, node] = 0 # Mark the node as removed 74 | 75 | # Decrease the degree of its neighbor 76 | neighbor = np.argmax(self.graph[node]) 77 | self.graph[node, neighbor] = 0 78 | self.graph[neighbor, node] = 0 79 | 80 | # If the neighbor now has degree 1, add it to the queue 81 | if np.sum(self.graph[neighbor]) == 1: 82 | queue.append(neighbor) 83 | 84 | def dfs(self, start, current, path): 85 | self.visited[current] = True 86 | path.append(current) 87 | 88 | for neighbor in range(self.num_vertices): 89 | if self.graph[current, neighbor] == 1: 90 | if neighbor == start and len(path) > 2: 91 | # Found a cycle with at least 3 vertices 92 | self.cycles.add(tuple(path)) 93 | elif not self.visited[neighbor] and neighbor > start: 94 | # Continue DFS only if the neighbor has not been removed 95 | self.dfs(start, neighbor, path) 96 | 97 | path.pop() 98 | self.visited[current] = False 99 | 100 | def deduplicate_cycles(self): 101 | all_cycles = self.cycles 102 | record_set = set() 103 | 104 | self.cycles = [] 105 | for cycle in all_cycles: 106 | sorted_cycle = tuple(sorted(cycle)) 107 | if sorted_cycle not in record_set: 108 | record_set.add(sorted_cycle) 109 | self.cycles.append(cycle) 110 | -------------------------------------------------------------------------------- /utils/io.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import gzip 3 | import json 4 | import lzma 5 | import numpy as np 6 | import os 7 | import pickle 8 | import shutil 9 | from rdkit import Chem 10 | from rdkit.Chem import Draw 11 | from typing import Union, List, Dict 12 | 13 | from utils.operations.operation_string import extract_numbers 14 | 15 | 16 | def create_dir(dir): 17 | if not os.path.exists(dir): 18 | os.makedirs(dir) 19 | 20 | 21 | def delete_file_or_dir(dir): 22 | if os.path.isfile(dir): 23 | os.remove(dir) 24 | elif os.path.exists(dir): 25 | shutil.rmtree(dir) 26 | else: 27 | pass 28 | 29 | 30 | def save_compressed_file_7z(data, file_path): # 7z 31 | 
create_dir(os.path.dirname(file_path)) 32 | with lzma.open(file_path, "wb") as file: 33 | pickle.dump(data, file) 34 | 35 | 36 | def load_compressed_file_7z(file_path): # 7z 37 | with lzma.open(file_path, "rb") as file: 38 | data = pickle.load(file) 39 | return data 40 | 41 | 42 | def save_compressed_file_gz(data, file_path, compresslevel=6): # gz 43 | create_dir(os.path.dirname(file_path)) 44 | with gzip.open(file_path, "wb", compresslevel=compresslevel) as file: 45 | pickle.dump(data, file) 46 | 47 | 48 | def load_compressed_file_gz(file_path): # gz 49 | with gzip.open(file_path, "rb") as file: 50 | data = pickle.load(file) 51 | return data 52 | 53 | 54 | def read_csv(file_path, has_header=True) -> Union[List[List], List[Dict]]: 55 | """ 56 | Read a CSV file and return its content. 57 | 58 | Args: 59 | - file_path (str): Path to the CSV file. 60 | - has_header (bool): Whether the CSV file has a header. Default is True. 61 | 62 | Returns: 63 | - list of list or dict: Content of the CSV file. 64 | If has_header is True, return a list of dictionaries; 65 | if has_header is False, return a list of lists. 66 | """ 67 | data = [] 68 | with open(file_path, newline='', encoding='utf-8') as f: 69 | if has_header: 70 | csvreader = csv.DictReader(f) 71 | for row in csvreader: 72 | data.append(dict(row)) 73 | else: 74 | csvreader = csv.reader(f) 75 | for row in csvreader: 76 | data.append(row) 77 | return data 78 | 79 | 80 | def load_json(file_path): 81 | with open(file_path, "r", encoding="utf8") as f: 82 | data = json.load(f) 83 | return data 84 | 85 | 86 | def save_json(data, file_path, indent=4, **kwargs): 87 | create_dir(os.path.dirname(file_path)) 88 | with open(file_path, "w", encoding="utf8") as f: 89 | f.write(f"{json.dumps(data, ensure_ascii=False, indent=indent, **kwargs)}\n") 90 | 91 | 92 | def load_jsonl(file_path) -> list: 93 | data = [] 94 | with open(file_path, "r", encoding="utf8") as f: 95 | for line in f: 96 | try: 97 | data.append(json.loads(line)) 98 | except json.JSONDecodeError as e: 99 | print(f"Error decoding line: {line}") 100 | continue 101 | return data 102 | 103 | 104 | def save_jsonl(data, file_path, **kwargs): 105 | create_dir(os.path.dirname(file_path)) 106 | with open(file_path, "w", encoding="utf8") as f: 107 | for ins in data: 108 | f.write(f"{json.dumps(ins, ensure_ascii=False, **kwargs)}\n") 109 | 110 | 111 | def compress_png_image(image_path, print_info=False): 112 | import cv2 113 | img = cv2.imread(image_path, cv2.IMREAD_COLOR) 114 | cv2.imwrite(image_path, img, [cv2.IMWRITE_PNG_COMPRESSION, 9]) 115 | if print_info: 116 | print(f'Done for "{image_path}".') 117 | 118 | 119 | """for this project""" 120 | 121 | 122 | def find_best_model(model_dir): 123 | best_model_file = None 124 | best_opt_steps = 0 125 | best_val_loss = float("inf") 126 | 127 | for root, dirs, files in os.walk(model_dir): 128 | for file in files: 129 | file_postfix = file.split(".")[-1] 130 | file_name = file.replace("." 
+ file_postfix, "") 131 | 132 | if file_postfix == "ckpt" and "best" in file_name: 133 | # Example: "best-opt_steps=091999-val_loss=0.1109.ckpt" 134 | this_opt_steps, this_val_loss = extract_numbers(file_name) 135 | 136 | if this_val_loss < best_val_loss: 137 | # model with the minimal val_loss 138 | best_model_file = os.path.join(root, file) 139 | best_opt_steps = this_opt_steps 140 | best_val_loss = this_val_loss 141 | 142 | elif this_val_loss == best_val_loss and this_opt_steps > best_opt_steps: 143 | # model with the largest opt_steps 144 | best_model_file = os.path.join(root, file) 145 | best_opt_steps = this_opt_steps 146 | best_val_loss = this_val_loss 147 | 148 | return best_model_file 149 | 150 | 151 | def find_last_model(model_dir): 152 | last_model_file = None 153 | 154 | for root, dirs, files in os.walk(model_dir): 155 | for file in files: 156 | if file == "last.ckpt": 157 | last_model_file = os.path.join(root, file) 158 | break 159 | 160 | return last_model_file 161 | 162 | 163 | def get_avg_acc_from_file(save_file_match_num): 164 | acc_list = [] 165 | with open(save_file_match_num, "r") as f: 166 | lines = f.readlines() 167 | for line in lines: 168 | # example line: "Accuracy of sample 0: 100.00% (1024/1024)" 169 | # example line: "Consistency of sample 0: 100.00% (523776/523776)" 170 | numbers = extract_numbers(line) 171 | acc_list.append(numbers[1]) 172 | return sum(acc_list) / len(acc_list) 173 | 174 | 175 | def save_mol_png(mol, save_path, size=(512, 512)): 176 | img = Draw.MolToImage(mol, size=size) 177 | img.save(save_path) 178 | img.close() 179 | 180 | 181 | def save_empty_png(save_path, size=(512, 512)): 182 | img = Draw.MolToImage(Chem.Mol(), size=size) 183 | img.save(save_path) 184 | img.close() 185 | 186 | 187 | def summary_property_from_file(file_name): 188 | value_list = [] 189 | with open(file_name, "r") as f: 190 | lines = f.readlines() 191 | for line in lines: 192 | line = line.removesuffix("\n") 193 | if line != "None": 194 | number = float(line) 195 | value_list.append(number) 196 | if len(value_list) > 0: 197 | values = np.array(value_list) 198 | return values.mean(), values.std(), len(value_list) 199 | else: 200 | return -1, -1, -1 201 | 202 | 203 | def summary_property_from_all_files(search_dir, file_name, value_type="float"): 204 | value_list = [] 205 | for root, dirs, files in os.walk(search_dir): 206 | for file in files: 207 | if file == file_name: 208 | with open(os.path.join(root, file), "r") as f: 209 | lines = f.readlines() 210 | for line in lines: 211 | line = line.removesuffix("\n") 212 | if line != "None": 213 | if value_type == "float": 214 | number = float(line) 215 | elif value_type == "bool": 216 | number = (line == "True") 217 | value_list.append(number) 218 | if len(value_list) > 0: 219 | values = np.array(value_list) 220 | return values.mean(), values.std(), len(value_list) 221 | else: 222 | return -1, -1, -1 223 | -------------------------------------------------------------------------------- /utils/molecule.py: -------------------------------------------------------------------------------- 1 | from rdkit.Chem import MolStandardize 2 | from rdkit.Chem.Scaffolds import MurckoScaffold 3 | 4 | 5 | def get_molecule_standard_scaffold(mol, normalizer=None): 6 | if normalizer is None: 7 | normalizer = MolStandardize.normalize.Normalizer() 8 | 9 | if mol is not None: 10 | try: 11 | scaffold = MurckoScaffold.GetScaffoldForMol(mol) 12 | 13 | try: 14 | standardized_scaffold = normalizer.normalize(scaffold) 15 | except: 16 | standardized_scaffold 
= scaffold 17 | 18 | return standardized_scaffold 19 | except: 20 | return None 21 | else: 22 | return None 23 | -------------------------------------------------------------------------------- /utils/operations/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/A4Bio/GraphsGPT/e71647c5a9e19322b4166bb7922c604778a71e1e/utils/operations/__init__.py -------------------------------------------------------------------------------- /utils/operations/operation_dataframe.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | 4 | def chunk_dataframe(df, num_chunks): 5 | """ 6 | Split the input DataFrame into a specified number of chunks. 7 | 8 | Args: 9 | - df (pd.DataFrame): The DataFrame to be evenly divided. 10 | - num_chunks (int): The desired number of chunks. 11 | 12 | Returns: 13 | - list of pd.DataFrame: List of chunks after even division. 14 | 15 | Example: 16 | >>> input_df = pd.DataFrame({'A': range(1, 10)}) 17 | >>> num_chunks = 5 18 | >>> result = chunk_dataframe(input_df, num_chunks) 19 | >>> for chunk in result: 20 | >>> print(chunk) # Output: DataFrame chunks printed one by one 21 | """ 22 | avg_chunk_size = len(df) // num_chunks 23 | remainder = len(df) % num_chunks 24 | 25 | chunks = [] 26 | start = 0 27 | for _ in range(num_chunks): 28 | chunk_size = avg_chunk_size + 1 if remainder > 0 else avg_chunk_size 29 | chunks.append(df.iloc[start:start + chunk_size]) 30 | start += chunk_size 31 | remainder -= 1 32 | 33 | return chunks 34 | 35 | 36 | def chunk_dataframe_with_yield(df, num_chunks): 37 | """ 38 | Split the input DataFrame into a specified number of chunks using a generator. 39 | 40 | Args: 41 | - df (pd.DataFrame): The DataFrame to be evenly divided. 42 | - num_chunks (int): The desired number of chunks. 43 | 44 | Yields: 45 | - pd.DataFrame: DataFrame chunks yielded one at a time. 
46 | 47 | Example: 48 | >>> input_df = pd.DataFrame({'A': range(1, 10)}) 49 | >>> num_chunks = 5 50 | >>> for chunk in chunk_dataframe_with_yield(input_df, num_chunks): 51 | >>> print(chunk) # Output: DataFrame chunks printed one by one 52 | """ 53 | avg_chunk_size = len(df) // num_chunks 54 | remainder = len(df) % num_chunks 55 | 56 | start = 0 57 | for _ in range(num_chunks): 58 | chunk_size = avg_chunk_size + 1 if remainder > 0 else avg_chunk_size 59 | yield df.iloc[start:start + chunk_size] 60 | start += chunk_size 61 | remainder -= 1 62 | -------------------------------------------------------------------------------- /utils/operations/operation_dict.py: -------------------------------------------------------------------------------- 1 | def reverse_dict(input_dict, aggregate_same_results=True): 2 | output_dict = {} 3 | for key, value in input_dict.items(): 4 | if value not in output_dict: 5 | output_dict[value] = key 6 | else: 7 | if aggregate_same_results: 8 | if not isinstance(output_dict[value], list): 9 | output_dict[value] = [output_dict[value]] 10 | output_dict[value].append(key) 11 | else: 12 | raise ValueError("Input dictionary does not satisfy the one-to-one mapping condition.") 13 | return output_dict 14 | -------------------------------------------------------------------------------- /utils/operations/operation_list.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import Union, List 3 | 4 | 5 | def chunk_list(input_list, num_chunks): 6 | """ 7 | Split the input list into a specified number of chunks. 8 | 9 | Args: 10 | - input_list (list): The list to be evenly divided. 11 | - num_chunks (int): The desired number of chunks. 12 | 13 | Returns: 14 | - list of lists: List of chunks after even division. 15 | 16 | Example: 17 | >>> input_list = [1, 2, 3, 4, 5, 6, 7, 8, 9] 18 | >>> num_chunks = 5 19 | >>> result = chunk_list(input_list, num_chunks) 20 | >>> print(result) # Output: [[1, 2], [3, 4], [5, 6], [7, 8], [9]] 21 | """ 22 | avg_chunk_size = len(input_list) // num_chunks 23 | remainder = len(input_list) % num_chunks 24 | 25 | chunks = [] 26 | start = 0 27 | for _ in range(num_chunks): 28 | chunk_size = avg_chunk_size + 1 if remainder > 0 else avg_chunk_size 29 | chunks.append(input_list[start:start + chunk_size]) 30 | start += chunk_size 31 | remainder -= 1 32 | 33 | return chunks 34 | 35 | 36 | def chunk_list_with_yield(input_list, num_chunks): 37 | """ 38 | Split the input list into a specified number of chunks using a generator. 39 | 40 | Args: 41 | - input_list (list): The list to be evenly divided. 42 | - num_chunks (int): The desired number of chunks. 43 | 44 | Yields: 45 | - list of lists: Chunks yielded one at a time. 46 | 47 | Example: 48 | >>> input_list = [1, 2, 3, 4, 5, 6, 7, 8, 9] 49 | >>> num_chunks = 5 50 | >>> for chunk in chunk_list_with_yield(input_list, num_chunks): 51 | >>> print(chunk) # Output: [1, 2] [3, 4] [5, 6] [7, 8] [9] 52 | """ 53 | avg_chunk_size = len(input_list) // num_chunks 54 | remainder = len(input_list) % num_chunks 55 | 56 | start = 0 57 | for _ in range(num_chunks): 58 | chunk_size = avg_chunk_size + 1 if remainder > 0 else avg_chunk_size 59 | yield input_list[start:start + chunk_size] 60 | start += chunk_size 61 | remainder -= 1 62 | 63 | 64 | def split_list(input_list, split_length, drop_last=False): 65 | """ 66 | Split a list into sublists, each with a length of split_length. 67 | 68 | Args: 69 | - input_list (list): The list to be split. 
70 | - split_length (int): Length of each sublist. 71 | - drop_last (bool): Whether to drop the last sublist if its length is insufficient. Default is False. 72 | 73 | Returns: 74 | - list of lists: List of split sublists. 75 | 76 | Example: 77 | >>> input_list = [1, 2, 3, 4, 5, 6, 7, 8, 9] 78 | >>> split_length = 5 79 | >>> result = split_list(input_list, split_length, drop_last=False) 80 | >>> print(result) # Output: [[1, 2, 3, 4, 5], [6, 7, 8, 9]] 81 | """ 82 | if split_length <= 0: 83 | raise ValueError("split_length must be a positive integer!") 84 | 85 | num_elements = len(input_list) 86 | num_splits = num_elements // split_length 87 | 88 | sublists = [input_list[i * split_length: (i + 1) * split_length] for i in range(num_splits)] 89 | 90 | if not drop_last and num_splits * split_length < num_elements: 91 | sublists.append(input_list[num_splits * split_length:]) 92 | 93 | return sublists 94 | 95 | 96 | def split_list_with_yield(input_list, split_length, drop_last=False): 97 | """ 98 | Split a list into sublists using a generator, each with a length of split_length. 99 | 100 | Args: 101 | - input_list (list): The list to be split. 102 | - split_length (int): Length of each sublist. 103 | - drop_last (bool): Whether to drop the last sublist if its length is insufficient. Default is False. 104 | 105 | Yields: 106 | - list of lists: Sublists yielded one at a time. 107 | 108 | Example: 109 | >>> input_list = [1, 2, 3, 4, 5, 6, 7, 8, 9] 110 | >>> split_length = 5 111 | >>> result = split_list_with_yield(input_list, split_length, drop_last=False) 112 | >>> for sublist in result: 113 | >>> print(sublist) # Output: [1, 2, 3, 4, 5] [6, 7, 8, 9] 114 | """ 115 | if split_length <= 0: 116 | raise ValueError("split_length must be a positive integer!") 117 | 118 | num_elements = len(input_list) 119 | num_splits = num_elements // split_length 120 | 121 | start = 0 122 | for _ in range(num_splits): 123 | sublist = input_list[start: start + split_length] 124 | yield sublist 125 | start += split_length 126 | 127 | if not drop_last and start < num_elements: 128 | sublist = input_list[start:] 129 | yield sublist 130 | 131 | 132 | def replicate_elements(input_list, num_copies: Union[int, float, List[int]]): 133 | """ 134 | Replicate each element in the original list a fixed or variable number of times, 135 | with support for decimals indicating a proportional number of additional elements. 136 | 137 | Args: 138 | - input_list (list): The original list of elements. 139 | - num_copies: The number of times each element should be replicated. 140 | It can take multiple forms: 141 | - An integer: Each element in the list is replicated this many times. 142 | - A float: The integer part dictates the fixed number of times each element is replicated. 143 | The fractional part is used to determine a proportional number of extra elements to be 144 | randomly chosen and replicated once. For example, if num_copies is 2.5 and the list 145 | has 4 elements, each element is replicated 2 times, and additionally, 2 (50% of 4) 146 | randomly chosen elements are replicated once more. 147 | - A list of integers: Each element in the input list is replicated according to the 148 | corresponding number of times specified in this list. This allows for variable replication 149 | per element. The length of this list must match the length of the input list. 150 | 151 | Returns: 152 | - list: The new list with replicated elements. 
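    Example:
    >>> replicate_elements([1, 2, 3, 4], 2)
    [1, 1, 2, 2, 3, 3, 4, 4]
    >>> replicate_elements([1, 2, 3], [1, 0, 2])
    [1, 3, 3]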
153 | """ 154 | if isinstance(num_copies, (int, float)): 155 | # Fixed replication, integer part and proportional part handled 156 | int_part = int(num_copies) # Integer part for the definite copies 157 | num_copies_list = [int_part] * len(input_list) 158 | 159 | if isinstance(num_copies, float): 160 | frac_part = num_copies - int_part # Fractional part for the proportional extra copies 161 | extra_copies_count = round(frac_part * len(input_list)) 162 | 163 | if extra_copies_count > 0: 164 | # Choose the items to be replicated 165 | extra_num_copies_array = np.concatenate(( 166 | np.ones(extra_copies_count, dtype=int), 167 | np.zeros(len(input_list) - extra_copies_count, dtype=int) 168 | ), axis=0) 169 | np.random.shuffle(extra_num_copies_array) 170 | 171 | # Add to the "num_copies_list" 172 | num_copies_list = (np.array(num_copies_list) + extra_num_copies_array).tolist() 173 | 174 | elif isinstance(num_copies, list): 175 | # Variable replication based on the list 176 | if len(input_list) != len(num_copies): 177 | raise ValueError("Lengths of input_list and num_copies_list must be the same.") 178 | 179 | num_copies_list = num_copies 180 | 181 | else: 182 | raise ValueError("Invalid type for num_copies. It should be an int, float, or a list.") 183 | 184 | new_list = [] 185 | for item, num_copies_item in zip(input_list, num_copies_list): 186 | new_list.extend([item] * num_copies_item) 187 | 188 | return new_list 189 | 190 | 191 | def all_elements_equal(input_list): 192 | """Check if all elements in the list are equal.""" 193 | if not input_list: 194 | return True 195 | return len(set(input_list)) == 1 196 | 197 | 198 | def mean_value_of_elements(input_list): 199 | """Return the mean value of elements in the list.""" 200 | total_value = 0 201 | total_cnt = 0 202 | for element in input_list: 203 | if element is not None: 204 | total_value += element 205 | total_cnt += 1 206 | if total_cnt > 0: 207 | return total_value / total_cnt 208 | else: 209 | return 0 210 | -------------------------------------------------------------------------------- /utils/operations/operation_number.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | 4 | def normalize_value(value: Union[int, float], mean, std): 5 | if value is not None: 6 | if std != 0: 7 | return (value - mean) / std 8 | else: 9 | return value 10 | else: 11 | return None 12 | 13 | 14 | def denormalize_value(normalized_value: Union[int, float], mean, std): 15 | if normalized_value is not None: 16 | if std != 0: 17 | return (normalized_value * std) + mean 18 | else: 19 | return normalized_value 20 | else: 21 | return None 22 | -------------------------------------------------------------------------------- /utils/operations/operation_string.py: -------------------------------------------------------------------------------- 1 | import re 2 | from argparse import ArgumentTypeError 3 | 4 | 5 | def str2bool(v): 6 | if isinstance(v, bool): 7 | return v 8 | if v.lower() in ("yes", "true", "t", "y", "1"): 9 | return True 10 | elif v.lower() in ("no", "false", "f", "n", "0"): 11 | return False 12 | else: 13 | raise ArgumentTypeError("Boolean value expected.") 14 | 15 | 16 | def string2number_list(string, sep=","): 17 | if isinstance(string, list) or string is None: 18 | return string 19 | else: 20 | split_string = string.split(sep) 21 | return [float(num) if "." 
in num else int(num) for num in split_string] 22 | 23 | 24 | def extract_numbers(string): 25 | """Extract numbers (int, float) from a given string.""" 26 | pattern = r"[-+]?\d*\.\d+|\d+" 27 | matches = re.findall(pattern, string) 28 | numbers = [float(match) if '.' in match else int(match) for match in matches] 29 | return numbers 30 | 31 | 32 | def calculate_non_ascii_ratio(string): 33 | """Calculate the non-ASCII ratio of a given string.""" 34 | if len(string) == 0: 35 | non_ascii_ratio = 0.0 36 | else: 37 | non_ascii_count = sum(1 for char in string if ord(char) >= 128) 38 | non_ascii_ratio = non_ascii_count / len(string) 39 | return non_ascii_ratio 40 | 41 | 42 | def remove_non_ascii_code(string): 43 | """Use a regular expression to remove all non-ASCII characters""" 44 | string = re.sub(r'[^\x00-\x7F]+', '', string) 45 | return string 46 | 47 | 48 | def replace_non_ascii_code(string): 49 | """ 50 | Replace common non-ASCII characters with their ASCII counterparts in the given string. 51 | 52 | :param string: Input string with non-ASCII characters. 53 | :return: String with non-ASCII characters replaced. 54 | """ 55 | string = re.sub(r'“|”', "\"", string) 56 | string = re.sub(r'‘|’', "\'", string) 57 | string = re.sub(r'—|–', "-", string) 58 | string = re.sub(r'…', "...", string) 59 | 60 | return string 61 | -------------------------------------------------------------------------------- /utils/operations/operation_tensor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def move_tensors_to_device(input, device): 5 | if input is None: 6 | return input 7 | 8 | elif isinstance(input, dict): 9 | for key, value in input.items(): 10 | if isinstance(value, torch.Tensor): 11 | input[key] = value.to(device) 12 | return input 13 | 14 | elif isinstance(input, list): 15 | for i in range(len(input)): 16 | if isinstance(input[i], torch.Tensor): 17 | input[i] = input[i].to(device) 18 | return input 19 | 20 | elif isinstance(input, torch.Tensor): 21 | return input.to(device) 22 | 23 | else: 24 | raise TypeError(input) 25 | 26 | 27 | def tensor2numbers(input): 28 | if input is None: 29 | return input 30 | 31 | elif isinstance(input, dict): 32 | for key, value in input.items(): 33 | if isinstance(value, torch.Tensor): 34 | input[key] = value.tolist() 35 | return input 36 | 37 | elif isinstance(input, list): 38 | for i in range(len(input)): 39 | if isinstance(input[i], torch.Tensor): 40 | input[i] = input[i].tolist() 41 | return input 42 | 43 | elif isinstance(input, torch.Tensor): 44 | return input.tolist() 45 | 46 | else: 47 | raise TypeError(input) 48 | 49 | 50 | def turn_last_true_mask_to_false(mask, true_mask_cnt=None): 51 | """Turn the last true value to false for each row in a mask matrix.""" 52 | # mask: shape(batch_size, seq_len) 53 | if true_mask_cnt is None: 54 | true_mask_cnt = torch.sum(mask, dim=1).unsqueeze(1) 55 | turn_position_indices = (mask.cumsum(dim=1) == true_mask_cnt) 56 | converted_mask = mask.clone() 57 | converted_mask[turn_position_indices] = False 58 | return converted_mask 59 | 60 | 61 | def turn_first_true_mask_to_false(mask): 62 | """Turn the first true value to false for each row in a mask matrix.""" 63 | # mask: shape(batch_size, seq_len) 64 | turn_position_indices = (mask.cumsum(dim=1) == 1) 65 | converted_mask = mask.clone() 66 | converted_mask[turn_position_indices] = False 67 | return converted_mask 68 | 69 | 70 | def last_true_position(mask): 71 | """Return the index of the last true value in each row 
in a mask matrix.""" 72 | # mask: shape(batch_size, seq_len) 73 | true_mask_cnt = torch.sum(mask, dim=1).unsqueeze(1) 74 | last_true_mask = (mask.cumsum(dim=1) == true_mask_cnt) & mask 75 | last_true_position = last_true_mask.nonzero()[:, 1].unsqueeze(1) 76 | return last_true_position 77 | 78 | 79 | def pass_kernel_function(tensor, criterion, allow_nan=False): 80 | if criterion == "plain": 81 | return tensor 82 | elif criterion == "sqrt": 83 | if not allow_nan and torch.any(tensor < 0): 84 | raise ValueError("Detected negative value in the tensor! This will cause the result to be \"nan\"!") 85 | return torch.sqrt(tensor) 86 | elif criterion == "l1": 87 | return torch.abs(tensor) 88 | elif criterion == "l2": 89 | return tensor * tensor 90 | else: 91 | raise NotImplementedError 92 | -------------------------------------------------------------------------------- /utils/property_scores/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/A4Bio/GraphsGPT/e71647c5a9e19322b4166bb7922c604778a71e1e/utils/property_scores/__init__.py -------------------------------------------------------------------------------- /utils/property_scores/fpscores.pkl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/A4Bio/GraphsGPT/e71647c5a9e19322b4166bb7922c604778a71e1e/utils/property_scores/fpscores.pkl.gz -------------------------------------------------------------------------------- /utils/property_scores/nspdk.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | 3 | import networkx as nx 4 | import numpy as np 5 | from eden.graph import vectorize 6 | from rdkit import Chem 7 | from sklearn.metrics.pairwise import pairwise_kernels 8 | from tqdm import tqdm 9 | 10 | 11 | def single_mol_to_nx(mol): 12 | if mol is not None: 13 | G = nx.Graph() 14 | for atom in mol.GetAtoms(): 15 | G.add_node(atom.GetIdx(), label=atom.GetSymbol()) 16 | for bond in mol.GetBonds(): 17 | G.add_edge(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx(), label=int(bond.GetBondTypeAsDouble())) 18 | return G 19 | return None 20 | 21 | 22 | def mols_to_nx(mols, n_jobs=None): 23 | # convert with multiprocessing support 24 | if n_jobs is not None: 25 | pool = multiprocessing.Pool(processes=n_jobs) 26 | nx_graphs = pool.map(single_mol_to_nx, mols) 27 | pool.close() 28 | pool.join() 29 | return [graph for graph in tqdm(nx_graphs, desc='Converting Molecules to Graphs') if graph is not None] 30 | else: 31 | nx_graphs = [single_mol_to_nx(mol) for mol in tqdm(mols, desc='Converting Molecules to Graphs')] 32 | return [graph for graph in nx_graphs if graph is not None] 33 | 34 | 35 | def single_graph_to_vector(graph): 36 | return vectorize(graph, complexity=4, discrete=True).toarray() 37 | 38 | 39 | def compute_nspdk_mmd(samples1, samples2, metric='linear', n_jobs=None): 40 | # code adapted from https://github.com/idea-iitd/graphgen/blob/master/metrics/mmd.py 41 | # convert with multiprocessing support 42 | if n_jobs is not None: 43 | pool = multiprocessing.Pool(processes=n_jobs) 44 | vectors1 = pool.map(single_graph_to_vector, [[sample] for sample in samples1]) 45 | vectors2 = pool.map(single_graph_to_vector, [[sample] for sample in samples2]) 46 | pool.close() 47 | pool.join() 48 | vectors1 = np.concatenate([vector for vector in tqdm(vectors1, desc='Vectorization...')], axis=0) 49 | vectors2 = np.concatenate([vector for vector in tqdm(vectors2, 
desc='Vectorization...')], axis=0) 50 | else: 51 | print("Vectorization...") 52 | vectors1 = vectorize(samples1, complexity=4, discrete=True).toarray() 53 | vectors2 = vectorize(samples2, complexity=4, discrete=True).toarray() 54 | 55 | print("Computing X...") 56 | X = pairwise_kernels(vectors1, None, metric=metric, n_jobs=n_jobs) 57 | print(f"X={X}") 58 | 59 | print("Computing Y...") 60 | Y = pairwise_kernels(vectors2, None, metric=metric, n_jobs=n_jobs) 61 | print(f"Y={Y}") 62 | 63 | print("Computing Z...") 64 | Z = pairwise_kernels(vectors1, vectors2, metric=metric, n_jobs=n_jobs) 65 | print(f"Z={Z}") 66 | 67 | return np.average(X) + np.average(Y) - 2 * np.average(Z) 68 | 69 | 70 | def get_npsdk(smiles_list1, smiles_list2, metric='linear', n_jobs=None): 71 | nx_graphs1 = mols_to_nx([Chem.MolFromSmiles(smile) for smile in tqdm(smiles_list1, desc='Converting SMILES to Molecules')], n_jobs=n_jobs) 72 | nx_graphs2 = mols_to_nx([Chem.MolFromSmiles(smile) for smile in tqdm(smiles_list2, desc='Converting SMILES to Molecules')], n_jobs=n_jobs) 73 | nx_graphs1_remove_empty = [G for G in nx_graphs1 if not G.number_of_nodes() == 0] 74 | nx_graphs2_remove_empty = [G for G in nx_graphs2 if not G.number_of_nodes() == 0] 75 | nspdk_mmd = compute_nspdk_mmd(nx_graphs1_remove_empty, nx_graphs2_remove_empty, metric=metric, n_jobs=n_jobs) 76 | return nspdk_mmd 77 | -------------------------------------------------------------------------------- /utils/property_scores/sascorer.py: -------------------------------------------------------------------------------- 1 | # 2 | # calculation of synthetic accessibility score as described in: 3 | # 4 | # Estimation of Synthetic Accessibility Score of Drug-like Molecules based on Molecular Complexity and Fragment Contributions 5 | # Peter Ertl and Ansgar Schuffenhauer 6 | # Journal of Cheminformatics 1:8 (2009) 7 | # http://www.jcheminf.com/content/1/1/8 8 | # 9 | # several small modifications to the original paper are included 10 | # particularly slightly different formula for marocyclic penalty 11 | # and taking into account also molecule symmetry (fingerprint density) 12 | # 13 | # for a set of 10k diverse molecules the agreement between the original method 14 | # as implemented in PipelinePilot and this implementation is r2 = 0.97 15 | # 16 | # peter ertl & greg landrum, september 2013 17 | # 18 | from __future__ import print_function 19 | 20 | import _pickle as cPickle 21 | import os.path as op 22 | 23 | import math 24 | from rdkit import Chem 25 | from rdkit.Chem import rdMolDescriptors 26 | from rdkit.six import iteritems 27 | 28 | _fscores = None 29 | 30 | 31 | def readFragmentScores(name='fpscores'): 32 | import gzip 33 | global _fscores 34 | # generate the full path filename: 35 | if name == "fpscores": 36 | name = op.join(op.dirname(__file__), name) 37 | _fscores = cPickle.load(gzip.open('%s.pkl.gz' % name)) 38 | outDict = {} 39 | for i in _fscores: 40 | for j in range(1, len(i)): 41 | outDict[i[j]] = float(i[0]) 42 | _fscores = outDict 43 | 44 | 45 | def numBridgeheadsAndSpiro(mol, ri=None): 46 | nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol) 47 | nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol) 48 | return nBridgehead, nSpiro 49 | 50 | 51 | def calculateScore(m): 52 | if _fscores is None: 53 | readFragmentScores() 54 | 55 | # fragment score 56 | fp = rdMolDescriptors.GetMorganFingerprint(m, 57 | 2) # <- 2 is the *radius* of the circular fingerprint 58 | fps = fp.GetNonzeroElements() 59 | score1 = 0. 
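# average the precomputed fragment contributions (from the fpscores table) over all Morgan-fingerprint bits,
# weighting each bit by its occurrence count; bits absent from the table contribute -4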
60 | nf = 0 61 | for bitId, v in iteritems(fps): 62 | nf += v 63 | sfp = bitId 64 | score1 += _fscores.get(sfp, -4) * v 65 | score1 /= nf 66 | 67 | # features score 68 | nAtoms = m.GetNumAtoms() 69 | nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True)) 70 | ri = m.GetRingInfo() 71 | nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri) 72 | nMacrocycles = 0 73 | for x in ri.AtomRings(): 74 | if len(x) > 8: 75 | nMacrocycles += 1 76 | 77 | sizePenalty = nAtoms ** 1.005 - nAtoms 78 | stereoPenalty = math.log10(nChiralCenters + 1) 79 | spiroPenalty = math.log10(nSpiro + 1) 80 | bridgePenalty = math.log10(nBridgeheads + 1) 81 | macrocyclePenalty = 0. 82 | # --------------------------------------- 83 | # This differs from the paper, which defines: 84 | # macrocyclePenalty = math.log10(nMacrocycles+1) 85 | # This form generates better results when 2 or more macrocycles are present 86 | if nMacrocycles > 0: 87 | macrocyclePenalty = math.log10(2) 88 | 89 | score2 = 0. - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty 90 | 91 | # correction for the fingerprint density 92 | # not in the original publication, added in version 1.1 93 | # to make highly symmetrical molecules easier to synthetise 94 | score3 = 0. 95 | if nAtoms > len(fps): 96 | score3 = math.log(float(nAtoms) / len(fps)) * .5 97 | 98 | sascore = score1 + score2 + score3 99 | 100 | # need to transform "raw" value into scale between 1 and 10 101 | min = -4.0 102 | max = 2.5 103 | sascore = 11. - (sascore - min + 1) / (max - min) * 9. 104 | # smooth the 10-end 105 | if sascore > 8.: 106 | sascore = 8. + math.log(sascore + 1. - 9.) 107 | if sascore > 10.: 108 | sascore = 10.0 109 | elif sascore < 1.: 110 | sascore = 1.0 111 | 112 | return sascore 113 | 114 | 115 | def processMols(mols): 116 | print('smiles\tName\tsa_score') 117 | for i, m in enumerate(mols): 118 | if m is None: 119 | continue 120 | 121 | s = calculateScore(m) 122 | 123 | smiles = Chem.MolToSmiles(m) 124 | print(smiles + "\t" + m.GetProp('_Name') + "\t%3f" % s) 125 | 126 | 127 | if __name__ == '__main__': 128 | import sys, time 129 | 130 | t1 = time.time() 131 | readFragmentScores("fpscores") 132 | t2 = time.time() 133 | 134 | suppl = Chem.SmilesMolSupplier(sys.argv[1]) 135 | t3 = time.time() 136 | processMols(suppl) 137 | t4 = time.time() 138 | 139 | print('Reading took %.2f seconds. Calculating took %.2f seconds' % ((t2 - t1), (t4 - t3)), 140 | file=sys.stderr) 141 | 142 | 143 | # 144 | # Copyright (c) 2013, Novartis Institutes for BioMedical Research Inc. 145 | # All rights reserved. 146 | # 147 | # Redistribution and use in source and binary forms, with or without 148 | # modification, are permitted provided that the following conditions are 149 | # met: 150 | # 151 | # * Redistributions of source code must retain the above copyright 152 | # notice, this list of conditions and the following disclaimer. 153 | # * Redistributions in binary form must reproduce the above 154 | # copyright notice, this list of conditions and the following 155 | # disclaimer in the documentation and/or other materials provided 156 | # with the distribution. 157 | # * Neither the name of Novartis Institutes for BioMedical Research Inc. 158 | # nor the names of its contributors may be used to endorse or promote 159 | # products derived from this software without specific prior written permission. 
160 | # 161 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 162 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 163 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 164 | # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 165 | # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 166 | # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 167 | # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 168 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 169 | # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 170 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 171 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 172 | # 173 | 174 | def compute_sa_score(rdmol): 175 | rdmol = Chem.MolFromSmiles(Chem.MolToSmiles(rdmol)) 176 | sa = calculateScore(rdmol) 177 | sa_norm = round((10 - sa) / 9, 2) 178 | return sa_norm 179 | -------------------------------------------------------------------------------- /utils/property_scores/scoring_func.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from copy import deepcopy 3 | from rdkit import Chem 4 | from rdkit.Chem import AllChem, Descriptors, Crippen, Lipinski, Mol 5 | from rdkit.Chem.QED import qed 6 | 7 | from utils.property_scores.sascorer import compute_sa_score 8 | from utils.property_scores.tanimoto_similarity import MolsTanimotoSimilarity 9 | 10 | 11 | def fix_explicit_hs(mol: Mol) -> Mol: 12 | # rdkit has a problem with implicit hs. By default there are only explicit hs. 13 | # This is a hack to fix this error 14 | for a in mol.GetAtoms(): 15 | a.SetNoImplicit(False) 16 | 17 | mol = Chem.AddHs(mol, explicitOnly=True) 18 | mol = Chem.RemoveHs(mol) 19 | 20 | Chem.SanitizeMol(mol) 21 | return mol 22 | 23 | 24 | def get_basic(mol): 25 | n_atoms = len(mol.GetAtoms()) 26 | n_bonds = len(mol.GetBonds()) 27 | n_rings = len(Chem.GetSymmSSSR(mol)) 28 | weight = Descriptors.ExactMolWt(mol) 29 | return n_atoms, n_bonds, n_rings, weight 30 | 31 | 32 | def get_is_valid(mol): 33 | smiles = Chem.MolToSmiles(mol) 34 | mol = Chem.MolFromSmiles(smiles) 35 | if mol is None: 36 | return False 37 | try: 38 | Chem.SanitizeMol(mol) 39 | except ValueError: 40 | return False 41 | return True 42 | 43 | 44 | def get_qed(mol): 45 | # mol = fix_explicit_hs(mol) 46 | # return qed(mol) 47 | try: 48 | mol = deepcopy(mol) 49 | try: 50 | AllChem.Kekulize(mol, clearAromaticFlags=True) 51 | except: 52 | pass 53 | mol.UpdatePropertyCache(strict=False) 54 | return qed(mol) 55 | except Exception as e: 56 | print(e) 57 | return None 58 | 59 | 60 | def get_sa(mol): 61 | try: 62 | mol = deepcopy(mol) 63 | try: 64 | AllChem.Kekulize(mol, clearAromaticFlags=True) 65 | except: 66 | pass 67 | mol.UpdatePropertyCache(strict=False) 68 | return compute_sa_score(mol) 69 | except Exception as e: 70 | print(e) 71 | return None 72 | 73 | 74 | def get_logp(mol): 75 | try: 76 | mol = deepcopy(mol) 77 | try: 78 | AllChem.Kekulize(mol, clearAromaticFlags=True) 79 | except: 80 | pass 81 | mol.UpdatePropertyCache(strict=False) 82 | return Crippen.MolLogP(mol) 83 | except Exception as e: 84 | print(e) 85 | return None 86 | 87 | 88 | def get_lipinski(mol): 89 | try: 90 | mol = deepcopy(mol) 91 | try: 92 | Chem.SanitizeMol(mol) 93 | except: 94 | pass 95 | mol.UpdatePropertyCache(strict=False) 96 | rule_1 = 
Descriptors.ExactMolWt(mol) < 500 97 | rule_2 = Lipinski.NumHDonors(mol) <= 5 98 | rule_3 = Lipinski.NumHAcceptors(mol) <= 10 99 | logp = Crippen.MolLogP(mol) 100 | rule_4 = (logp >= -2) & (logp <= 5) 101 | rule_5 = Chem.rdMolDescriptors.CalcNumRotatableBonds(mol) <= 10 102 | return np.sum([int(a) for a in [rule_1, rule_2, rule_3, rule_4, rule_5]]) 103 | except Exception as e: 104 | print(e) 105 | return None 106 | 107 | 108 | def get_hacc(mol): 109 | return Lipinski.NumHAcceptors(mol) 110 | 111 | 112 | def get_hdon(mol): 113 | return Lipinski.NumHDonors(mol) 114 | 115 | 116 | def get_rdkit_rmsd(mol, n_conf=20, random_seed=42): 117 | """ 118 | Calculate the alignment of generated mol and rdkit predicted mol 119 | Return the rmsd (max, min, median) of the `n_conf` rdkit conformers 120 | """ 121 | mol = deepcopy(mol) 122 | Chem.SanitizeMol(mol) 123 | mol3d = Chem.AddHs(mol) 124 | rmsd_list = [] 125 | # predict 3d 126 | confIds = AllChem.EmbedMultipleConfs(mol3d, n_conf, randomSeed=random_seed) 127 | for confId in confIds: 128 | AllChem.UFFOptimizeMolecule(mol3d, confId=confId) 129 | rmsd = Chem.rdMolAlign.GetBestRMS(mol, mol3d, refId=confId) 130 | rmsd_list.append(rmsd) 131 | # mol3d = Chem.RemoveHs(mol3d) 132 | rmsd_list = np.array(rmsd_list) 133 | return [np.max(rmsd_list), np.min(rmsd_list), np.median(rmsd_list)] 134 | 135 | 136 | def get_tanimoto_similarity(mol_list): 137 | tanimoto_calculator = MolsTanimotoSimilarity(mol_list) 138 | similarity_matrix = tanimoto_calculator.get_tanimoto_similarity_matrix() 139 | top_similarities = tanimoto_calculator.get_top_tanimoto_similarities() 140 | 141 | return similarity_matrix, top_similarities 142 | -------------------------------------------------------------------------------- /utils/property_scores/tanimoto_similarity.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from rdkit import DataStructs 3 | from rdkit.Chem import AllChem 4 | 5 | 6 | class MolsTanimotoSimilarity: 7 | def __init__(self, mols): 8 | self.mols = mols 9 | self.fingerprints = [AllChem.GetMorganFingerprintAsBitVect(mol, 2) for mol in mols] # get fingerprints 10 | self.tanimoto_similarity_matrix = np.zeros((len(mols), len(mols))) 11 | 12 | def get_tanimoto_similarity_matrix(self): 13 | # calculate similarity for each mol pair 14 | for i in range(0, len(self.mols)): 15 | for j in range(i + 1, len(self.mols)): 16 | similarity = DataStructs.TanimotoSimilarity(self.fingerprints[i], self.fingerprints[j]) 17 | self.tanimoto_similarity_matrix[i, j] = similarity 18 | 19 | # complete the similarity matrix 20 | lower_indices = np.tril_indices(len(self.mols), -1) 21 | self.tanimoto_similarity_matrix[lower_indices] = self.tanimoto_similarity_matrix.T[lower_indices] 22 | return self.tanimoto_similarity_matrix 23 | 24 | def get_top_tanimoto_similarities(self): 25 | top_similarities = np.max(self.tanimoto_similarity_matrix, axis=1) 26 | return top_similarities.tolist() 27 | -------------------------------------------------------------------------------- /utils/representation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/A4Bio/GraphsGPT/e71647c5a9e19322b4166bb7922c604778a71e1e/utils/representation/__init__.py -------------------------------------------------------------------------------- /utils/representation/data_interface.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | 
import torch 3 | from torch.utils.data import DataLoader 4 | 5 | from .load_dataset import DatasetTask 6 | 7 | 8 | class MyDataLoader(DataLoader): 9 | def __init__(self, dataset, batch_size=64, num_workers=8, data_task=None, split='train', *args, **kwargs): 10 | super().__init__(dataset, batch_size=batch_size, num_workers=num_workers, *args, **kwargs) 11 | self.pretrain_device = 'cuda:0' 12 | self.split = split 13 | self.data_task = data_task 14 | 15 | def __iter__(self): 16 | for batch in super().__iter__(): 17 | try: 18 | self.pretrain_device = f'cuda:{torch.distributed.get_rank()}' 19 | except: 20 | self.pretrain_device = 'cuda:0' 21 | 22 | stream = torch.cuda.Stream( 23 | self.pretrain_device 24 | ) 25 | with torch.cuda.stream(stream): 26 | sample = {} 27 | for key in batch[0].keys(): 28 | sample[key] = [one[key] for one in batch] 29 | if type(sample[key][0]) == torch.Tensor: 30 | sample[key] = torch.stack(sample[key], dim=0).cuda(non_blocking=True, device=self.pretrain_device) 31 | 32 | yield sample 33 | 34 | 35 | def collate_fn(batch): 36 | return batch 37 | 38 | 39 | class DInterface_base(pl.LightningDataModule): 40 | def __init__(self, **kwargs): 41 | super().__init__() 42 | self.save_hyperparameters() 43 | self.batch_size = self.hparams.batch_size 44 | print("batch_size", self.batch_size) 45 | 46 | def setup(self, stage=None): 47 | # Assign train/val datasets for use in dataloaders 48 | if stage == 'fit' or stage is None: 49 | self.trainset = self.data_task.datasets['train'] 50 | self.valset = self.data_task.datasets['valid'] 51 | 52 | # Assign test dataset for use in dataloader(s) 53 | if stage == 'test' or stage is None: 54 | self.testset = self.data_task.datasets['test'] 55 | 56 | def train_dataloader(self): 57 | return DataLoader(self.trainset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True, prefetch_factor=3) 58 | 59 | def val_dataloader(self): 60 | return DataLoader(self.valset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False) 61 | 62 | def test_dataloader(self): 63 | return DataLoader(self.testset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False) 64 | 65 | 66 | class DInterface(DInterface_base): 67 | def __init__(self, **kwargs): 68 | super().__init__(**kwargs) 69 | self.save_hyperparameters() 70 | self.data_task = DatasetTask(self.hparams) 71 | self.data_task.load_dataset_mix_enc_dec('train') 72 | self.data_task.load_dataset_mix_enc_dec('valid') 73 | self.data_task.load_dataset_mix_enc_dec('test') 74 | 75 | def train_dataloader(self): 76 | return MyDataLoader(self.trainset, batch_size=self.batch_size, split='train', num_workers=self.hparams.num_workers, data_task=self.data_task, shuffle=True, prefetch_factor=8, pin_memory=True, collate_fn=collate_fn) 77 | 78 | def val_dataloader(self): 79 | return MyDataLoader(self.testset, batch_size=self.batch_size, split='valid', num_workers=self.hparams.num_workers, data_task=self.data_task, shuffle=False, pin_memory=True, collate_fn=collate_fn) 80 | 81 | def test_dataloader(self): 82 | return MyDataLoader(self.testset, batch_size=self.batch_size, split='test', num_workers=self.hparams.num_workers, data_task=self.data_task, shuffle=False, pin_memory=True, collate_fn=collate_fn) 83 | -------------------------------------------------------------------------------- /utils/representation/dict.txt: -------------------------------------------------------------------------------- 1 | [PAD] 2 | [CLS] 3 | [SEP] 4 | [UNK] 5 | C 6 | N 7 | O 8 | S 9 | H 10 | Cl 11 | F 12 | Br 
13 | I 14 | Si 15 | P 16 | B 17 | Na 18 | K 19 | Al 20 | Ca 21 | Sn 22 | As 23 | Hg 24 | Fe 25 | Zn 26 | Cr 27 | Se 28 | Gd 29 | Au 30 | Li 31 | [MASK] -------------------------------------------------------------------------------- /utils/representation/dictionary.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import logging 7 | 8 | import numpy as np 9 | 10 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 11 | 12 | 13 | class Dictionary: 14 | """A mapping from symbols to consecutive integers""" 15 | 16 | def __init__( 17 | self, 18 | *, # begin keyword-only arguments 19 | bos="[CLS]", 20 | pad="[PAD]", 21 | eos="[SEP]", 22 | unk="[UNK]", 23 | extra_special_symbols=None, 24 | ): 25 | self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos 26 | self.symbols = [] 27 | self.count = [] 28 | self.indices = {} 29 | self.specials = set() 30 | self.specials.add(bos) 31 | self.specials.add(unk) 32 | self.specials.add(pad) 33 | self.specials.add(eos) 34 | 35 | def __eq__(self, other): 36 | return self.indices == other.indices 37 | 38 | def __getitem__(self, idx): 39 | if idx < len(self.symbols): 40 | return self.symbols[idx] 41 | return self.unk_word 42 | 43 | def __len__(self): 44 | """Returns the number of symbols in the dictionary""" 45 | return len(self.symbols) 46 | 47 | def __contains__(self, sym): 48 | return sym in self.indices 49 | 50 | def vec_index(self, a): 51 | return np.vectorize(self.index)(a) 52 | 53 | def index(self, sym): 54 | """Returns the index of the specified symbol""" 55 | assert isinstance(sym, str) 56 | if sym in self.indices: 57 | return self.indices[sym] 58 | return self.indices[self.unk_word] 59 | 60 | def special_index(self): 61 | return [self.index(x) for x in self.specials] 62 | 63 | def add_symbol(self, word, n=1, overwrite=False, is_special=False): 64 | """Adds a word to the dictionary""" 65 | if is_special: 66 | self.specials.add(word) 67 | if word in self.indices and not overwrite: 68 | idx = self.indices[word] 69 | self.count[idx] = self.count[idx] + n 70 | return idx 71 | else: 72 | idx = len(self.symbols) 73 | self.indices[word] = idx 74 | self.symbols.append(word) 75 | self.count.append(n) 76 | return idx 77 | 78 | def bos(self): 79 | """Helper to get index of beginning-of-sentence symbol""" 80 | return self.index(self.bos_word) 81 | 82 | def pad(self): 83 | """Helper to get index of pad symbol""" 84 | return self.index(self.pad_word) 85 | 86 | def eos(self): 87 | """Helper to get index of end-of-sentence symbol""" 88 | return self.index(self.eos_word) 89 | 90 | def unk(self): 91 | """Helper to get index of unk symbol""" 92 | return self.index(self.unk_word) 93 | 94 | @classmethod 95 | def load(cls, f): 96 | """Loads the dictionary from a text file with the format: 97 | 98 | ``` 99 | <symbol0> <count0> 100 | <symbol1> <count1> 101 | ... 102 | ``` 103 | """ 104 | d = cls() 105 | d.add_from_file(f) 106 | return d 107 | 108 | def add_from_file(self, f): 109 | """ 110 | Loads a pre-existing dictionary from a text file and adds its symbols 111 | to this instance.
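Each line holds a symbol, optionally followed by an integer count and the literal flag #overwrite (to replace an existing entry); when the count is omitted it defaults to len(lines) - line_idx.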
112 | """ 113 | if isinstance(f, str): 114 | try: 115 | with open(f, "r", encoding="utf-8") as fd: 116 | self.add_from_file(fd) 117 | except FileNotFoundError as fnfe: 118 | raise fnfe 119 | except UnicodeError: 120 | raise Exception( 121 | "Incorrect encoding detected in {}, please " 122 | "rebuild the dataset".format(f) 123 | ) 124 | return 125 | 126 | lines = f.readlines() 127 | 128 | for line_idx, line in enumerate(lines): 129 | try: 130 | splits = line.rstrip().rsplit(" ", 1) 131 | line = splits[0] 132 | field = splits[1] if len(splits) > 1 else str(len(lines) - line_idx) 133 | if field == "#overwrite": 134 | overwrite = True 135 | line, field = line.rsplit(" ", 1) 136 | else: 137 | overwrite = False 138 | count = int(field) 139 | word = line 140 | if word in self and not overwrite: 141 | logger.info( 142 | "Duplicate word found when loading Dictionary: '{}', index is {}.".format(word, self.indices[word]) 143 | ) 144 | else: 145 | self.add_symbol(word, n=count, overwrite=overwrite) 146 | except ValueError: 147 | raise ValueError( 148 | "Incorrect dictionary format, expected '<token> <cnt> [flags]'" 149 | ) 150 | -------------------------------------------------------------------------------- /utils/representation/fingerprint_interface.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from rdkit import Chem 5 | from rdkit.Chem import rdFingerprintGenerator 6 | 7 | from data.tokenizer import GraphsGPTTokenizer 8 | from models.graphsgpt.modeling_graphsgpt import GraphsGPTForCausalLM 9 | 10 | CLASS_OF_FPGEN = { 11 | "fingerprint-morgan": rdFingerprintGenerator.GetMorganGenerator, 12 | "fingerprint-rdkit": rdFingerprintGenerator.GetRDKitFPGenerator, 13 | "fingerprint-ap": rdFingerprintGenerator.GetAtomPairGenerator, 14 | "fingerprint-tt": rdFingerprintGenerator.GetTopologicalTorsionGenerator, 15 | } 16 | 17 | 18 | class FPModel(nn.Module): 19 | def __init__(self, model_name) -> None: 20 | super().__init__() 21 | self.model_name = model_name 22 | if model_name == 'graphsgpt': 23 | self.model = GraphsGPTForCausalLM.from_pretrained("DaizeDong/GraphsGPT-1W") 24 | self.tokenizer = GraphsGPTTokenizer.from_pretrained("DaizeDong/GraphsGPT-1W") 25 | self.model.cuda() 26 | self.model.eval() 27 | 28 | if model_name == 'morgan': 29 | self.radius = 2 30 | self.fpsize = 2048 31 | self.count = 0 32 | 33 | self.memory = {} 34 | 35 | def gen_fp(self, mol, fpgen, count=False): 36 | if count: 37 | g = fpgen.GetCountFingerprintAsNumPy 38 | else: 39 | g = fpgen.GetFingerprintAsNumPy 40 | return np.array(g(mol)) 41 | 42 | def forward(self, smiles_list): 43 | if all([one in self.memory for one in smiles_list]): 44 | fingerprint_tokens = torch.stack([self.memory[one] for one in smiles_list], dim=0) 45 | return fingerprint_tokens 46 | 47 | if self.model_name == 'graphsgpt': 48 | with torch.no_grad(): 49 | batch = self.tokenizer.batch_encode(smiles_list, return_tensors="pt") 50 | inputs = batch['batched_tokens'] 51 | inputs = {k: v.cuda() for k, v in inputs.items()} 52 | fingerprint_tokens = self.model.encode_to_fingerprints(**inputs) 53 | fingerprint_tokens = fingerprint_tokens[:, 0].cpu() 54 | 55 | if self.model_name == 'morgan': 56 | fps = [] 57 | fpgen = CLASS_OF_FPGEN["fingerprint-morgan"](radius=self.radius, fpSize=self.fpsize) 58 | for i in range(len(smiles_list)): 59 | mol = Chem.MolFromSmiles(smiles_list[i]) 60 | fps.append(self.gen_fp(mol, fpgen, count=self.count)) 61 | fingerprint_embeddings =
np.array(fps).astype(float) 62 | 63 | fingerprint_tokens = torch.from_numpy(fingerprint_embeddings) 64 | 65 | for i, one in enumerate(smiles_list): 66 | self.memory[one] = fingerprint_tokens[i] 67 | 68 | return fingerprint_tokens 69 | -------------------------------------------------------------------------------- /utils/representation/graphsgpt_finetune_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from data.tokenizer import GraphsGPTTokenizer 5 | from models.graphsgpt.configuration_graphsgpt import GraphsGPTConfig 6 | from models.graphsgpt.modeling_graphsgpt import GraphsGPTForCausalLM 7 | 8 | 9 | class GraphsGPT(nn.Module): 10 | def __init__(self, task_name, mixup_strategy, pooler_dropout): 11 | super().__init__() 12 | if mixup_strategy == 'no_mix_vanilla': 13 | config = GraphsGPTConfig.from_pretrained("DaizeDong/GraphsGPT-1W") 14 | self.model = GraphsGPTForCausalLM(config) 15 | else: 16 | self.model = GraphsGPTForCausalLM.from_pretrained("DaizeDong/GraphsGPT-1W") 17 | self.tokenizer = GraphsGPTTokenizer.from_pretrained("DaizeDong/GraphsGPT-1W") 18 | 19 | if task_name in ['bace', 'bbbp', 'clintox', 'hiv']: 20 | num_classes = 2 21 | 22 | if task_name in ['esol', 'freesolv', 'lipo', 'qm7dft']: 23 | num_classes = 1 24 | 25 | if task_name in ['qm8dft']: 26 | num_classes = 12 27 | 28 | if task_name in ['qm9dft']: 29 | num_classes = 3 30 | 31 | if task_name in ['sider']: 32 | num_classes = 27 33 | 34 | if task_name in ['tox21']: 35 | num_classes = 12 36 | 37 | if task_name in ['toxcast']: 38 | num_classes = 617 39 | 40 | if task_name in ['muv']: 41 | num_classes = 17 42 | 43 | self.classification_heads = ClassificationHead( 44 | input_dim=512, 45 | inner_dim=512, 46 | num_classes=num_classes, 47 | pooler_dropout=pooler_dropout, 48 | ) 49 | 50 | def forward(self, mols_list, mix_embeds=None, mixup_lam=None, mixup_index=None): 51 | batch = self.tokenizer.batch_encode(mols_list, return_tensors="pt") 52 | inputs = {k: v.cuda() for k, v in batch.items()} 53 | fingerprint_tokens = self.model.encode_to_fingerprints(**inputs) 54 | logit_output = self.classification_heads(fingerprint_tokens) 55 | return {'logit_output': logit_output} 56 | 57 | 58 | class ClassificationHead(nn.Module): 59 | """Head for sentence-level classification tasks.""" 60 | 61 | def __init__( 62 | self, 63 | input_dim, 64 | inner_dim, 65 | num_classes, 66 | pooler_dropout, 67 | ): 68 | super().__init__() 69 | self.dense = nn.Linear(input_dim, inner_dim) 70 | self.activation_fn = torch.tanh 71 | self.dropout = nn.Dropout(p=pooler_dropout) 72 | self.out_proj = nn.Linear(inner_dim, num_classes) 73 | 74 | def forward(self, features, **kwargs): 75 | x = features[:, 0, :] # take token (equiv. to [CLS]) 76 | x = self.dropout(x) 77 | x = self.dense(x) 78 | x = self.activation_fn(x) 79 | x = self.dropout(x) 80 | x = self.out_proj(x) 81 | return x 82 | -------------------------------------------------------------------------------- /utils/representation/lmdb_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 
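# LMDBDataset re-opens the LMDB environment on every __getitem__ and closes it right after
# reading, so long-running DataLoader workers do not accumulate open file handles.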
4 | 5 | 6 | import lmdb 7 | import logging 8 | import os 9 | import pickle 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | class LMDBDataset: 15 | def __init__(self, db_path): 16 | self.db_path = db_path 17 | assert os.path.isfile(self.db_path), "{} not found".format(self.db_path) 18 | env = self.connect_db(self.db_path) 19 | 20 | with env.begin() as txn: 21 | self._keys = list(txn.cursor().iternext(values=False)) 22 | 23 | def connect_db(self, lmdb_path, save_to_self=False): 24 | env = lmdb.open( 25 | lmdb_path, 26 | subdir=False, 27 | readonly=True, 28 | lock=False, 29 | readahead=False, 30 | meminit=False, 31 | max_readers=256, 32 | ) 33 | if not save_to_self: 34 | return env 35 | else: 36 | self.env = env 37 | 38 | def __len__(self): 39 | return len(self._keys) 40 | 41 | def __getitem__(self, idx): 42 | env = self.connect_db(self.db_path) 43 | datapoint_pickled = env.begin().get(self._keys[idx]) 44 | 45 | # avoid open too many files error. 46 | env.close() 47 | del env 48 | 49 | data = pickle.loads(datapoint_pickled) 50 | return data 51 | -------------------------------------------------------------------------------- /utils/representation/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | from omegaconf import OmegaConf 3 | from pytorch_lightning.callbacks import Callback 4 | 5 | 6 | class SetupCallback(Callback): 7 | def __init__(self, now, logdir, ckptdir, cfgdir, config, argv_content=None): 8 | super().__init__() 9 | self.now = now 10 | self.logdir = logdir 11 | self.ckptdir = ckptdir 12 | self.cfgdir = cfgdir 13 | self.config = config 14 | 15 | self.argv_content = argv_content 16 | 17 | def on_fit_start(self, trainer, pl_module): 18 | # Create logdirs and save configs 19 | os.makedirs(self.logdir, exist_ok=True) 20 | os.makedirs(self.ckptdir, exist_ok=True) 21 | os.makedirs(self.cfgdir, exist_ok=True) 22 | 23 | print("Project config") 24 | print(OmegaConf.to_yaml(self.config)) 25 | OmegaConf.save(self.config, 26 | os.path.join(self.cfgdir, "{}-project.yaml".format(self.now))) 27 | 28 | with open(os.path.join(self.logdir, "argv_content.txt"), "w") as f: 29 | f.write(str(self.argv_content)) 30 | -------------------------------------------------------------------------------- /utils/representation/model_interface.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import numpy as np 3 | import os 4 | import pytorch_lightning as pl 5 | import torch 6 | import torch.nn.functional as F 7 | from sklearn.metrics import roc_auc_score 8 | from torch import nn as nn 9 | from torch.optim import lr_scheduler as lrs 10 | 11 | from .graphsgpt_finetune_model import GraphsGPT 12 | from .load_dataset import task_metainfo 13 | 14 | 15 | def softmax_cross_entropy_with_softtarget(input, target, reduction='mean'): 16 | """ 17 | :param input: (batch, *) 18 | :param target: (batch, *) same shape as input, each item must be a valid distribution: target[i, :].sum() == 1. 
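:return: the soft-target cross entropy, -sum(target * log_softmax(input)) per sample, reduced according to `reduction` ('none', 'mean', or 'sum').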
19 | """ 20 | logprobs = torch.nn.functional.log_softmax(input.view(input.shape[0], -1), dim=1) 21 | batchloss = - torch.sum(target.view(target.shape[0], -1) * logprobs, dim=1) 22 | if reduction == 'none': 23 | return batchloss 24 | elif reduction == 'mean': 25 | return torch.mean(batchloss) 26 | elif reduction == 'sum': 27 | return torch.sum(batchloss) 28 | else: 29 | raise NotImplementedError('Unsupported reduction mode.') 30 | 31 | 32 | class MInterface_base(pl.LightningModule): 33 | def __init__(self, model_name=None, loss=None, lr=None, **kargs): 34 | super().__init__() 35 | self.save_hyperparameters() 36 | self.load_model() 37 | self.configure_loss() 38 | os.makedirs(os.path.join(self.hparams.res_dir, self.hparams.ex_name), exist_ok=True) 39 | 40 | def forward(self, input): 41 | pass 42 | 43 | def training_step(self, batch, batch_idx, **kwargs): 44 | pass 45 | 46 | def validation_step(self, batch, batch_idx): 47 | pass 48 | 49 | def test_step(self, batch, batch_idx): 50 | # Here we just reuse the validation_step for testing 51 | return self.validation_step(batch, batch_idx) 52 | 53 | def on_validation_epoch_end(self): 54 | # Make the Progress Bar leave there 55 | self.print('') 56 | 57 | def get_schedular(self, optimizer, lr_scheduler='onecycle'): 58 | if lr_scheduler == 'step': 59 | scheduler = lrs.StepLR(optimizer, 60 | step_size=self.hparams.lr_decay_steps, 61 | gamma=self.hparams.lr_decay_rate) 62 | elif lr_scheduler == 'cosine': 63 | scheduler = lrs.CosineAnnealingLR(optimizer, 64 | T_max=max(self.hparams.epoch / 5, 1)) 65 | elif lr_scheduler == 'onecycle': 66 | scheduler = lrs.OneCycleLR(optimizer, max_lr=self.hparams.lr, steps_per_epoch=self.hparams.steps_per_epoch, epochs=self.hparams.epoch, three_phase=False) 67 | elif lr_scheduler == 'polynomial': 68 | scheduler = PolynomialDecayLRSchedule(optimizer, warmup_ratio=self.hparams.warmup_ratio, total_num_update=self.hparams.epoch * self.hparams.steps_per_epoch, lr=self.hparams.lr, end_learning_rate=0.0, power=1.0) 69 | else: 70 | raise ValueError('Invalid lr_scheduler type!') 71 | 72 | return scheduler 73 | 74 | def configure_optimizers(self): 75 | if hasattr(self.hparams, 'weight_decay'): 76 | weight_decay = self.hparams.weight_decay 77 | else: 78 | weight_decay = 0 79 | 80 | optimizer_g = torch.optim.AdamW(self.model.parameters(), lr=self.hparams.lr, weight_decay=weight_decay, betas=(0.9, 0.99), eps=1e-8) 81 | 82 | schecular_g = self.get_schedular(optimizer_g, self.hparams.lr_scheduler) 83 | 84 | return [optimizer_g], [{"scheduler": schecular_g, "interval": "step"}] 85 | 86 | def lr_scheduler_step(self, *args, **kwargs): 87 | scheduler = self.lr_schedulers() 88 | scheduler.step() 89 | 90 | def configure_devices(self): 91 | self.device = torch.device(self.hparams.device) 92 | 93 | def configure_loss(self): 94 | self.loss_function = nn.CrossEntropyLoss(reduction='none') 95 | 96 | def load_model(self): 97 | self.model = None 98 | 99 | def instancialize(self, Model, **other_args): 100 | """ Instancialize a model using the corresponding parameters 101 | from self.hparams dictionary. You can also input any args 102 | to overwrite the corresponding value in self.hparams. 
103 | """ 104 | class_args = inspect.getargspec(Model.__init__).args[1:] 105 | inkeys = self.hparams.keys() 106 | args1 = {} 107 | for arg in class_args: 108 | if arg in inkeys: 109 | args1[arg] = getattr(self.hparams, arg) 110 | args1.update(other_args) 111 | return Model(**args1) 112 | 113 | 114 | class MInterface(MInterface_base): 115 | def __init__(self, model_name=None, loss=None, lr=None, **kargs): 116 | super().__init__() 117 | self.save_hyperparameters() 118 | self.load_model() 119 | self.configure_loss() 120 | os.makedirs(os.path.join(self.hparams.res_dir, self.hparams.ex_name), exist_ok=True) 121 | self.targets = [] 122 | self.preds = [] 123 | 124 | def forward(self, batch, mode='eval'): 125 | results = self.model(batch['mol_raw']) 126 | targets = batch["target_raw"] 127 | logit_output = results['logit_output'] 128 | 129 | results_mix = self.compute_loss(logit_output, targets, loss_type=self.hparams.loss_type) 130 | loss = results_mix['loss'] 131 | 132 | return {'loss': loss, 'logits': logit_output, 'targets': targets} 133 | 134 | def training_step(self, batch, batch_idx, **kwargs): 135 | results = self(batch, 'train') 136 | self.log_dict({"train_loss": results['loss']}, on_epoch=True, prog_bar=True) 137 | return results['loss'] 138 | 139 | def validation_step(self, batch, batch_idx): 140 | results = self(batch, 'eval') 141 | self.targets.append(results['targets'].float().cpu().numpy()) 142 | self.preds.append(results['logits'].float().cpu().numpy()) 143 | metrics = self.compute_metrics(results) 144 | metrics.update({'loss': results['loss']}) 145 | val_metrics = {'test_' + key: val for key, val in metrics.items()} 146 | self.log_dict(val_metrics) 147 | return metrics 148 | 149 | def test_step(self, batch, batch_idx): 150 | metrics = self.validation_step(batch, batch_idx) 151 | return metrics 152 | 153 | def training_epoch_end(self, outputs): 154 | pass 155 | 156 | def validation_epoch_end(self, outputs): 157 | metrics = self.compute_metrics(outputs) 158 | val_metrics = {'test_' + key: val for key, val in metrics.items()} 159 | self.log_dict(val_metrics) 160 | 161 | self.targets = [] 162 | self.preds = [] 163 | return self.log_dict 164 | 165 | def test_epoch_end(self, outputs): 166 | metrics = self.compute_metrics(outputs) 167 | val_metrics = {'test_' + key: val for key, val in metrics.items()} 168 | self.log_dict(val_metrics) 169 | 170 | self.targets = [] 171 | self.preds = [] 172 | return self.log_dict 173 | 174 | def load_model(self): 175 | # =========== graphsgpt model 176 | self.model = GraphsGPT(self.hparams.task_name, self.hparams.mixup_strategy, self.hparams.pooler_dropout) 177 | 178 | def compute_metrics(self, outputs): 179 | targets = np.concatenate(self.targets) 180 | preds = np.concatenate(self.preds) 181 | 182 | if self.hparams.loss_type == 'mixup_cross_entropy': 183 | targets = targets.argmax(axis=-1) 184 | if self.hparams.num_classes == 2: 185 | if np.unique(targets).shape[0] == 1: 186 | auc = 0.0 187 | else: 188 | auc = roc_auc_score(targets, preds[:, 1]) 189 | return {'auc': auc} 190 | 191 | if self.hparams.loss_type == 'mixup_mse': 192 | mean = task_metainfo[self.hparams.task_name]["mean"] 193 | std = task_metainfo[self.hparams.task_name]["std"] 194 | predicts = preds * std + mean 195 | mae = np.abs(predicts - targets).mean() 196 | mse = ((predicts - targets) ** 2).mean() 197 | return {'mae': mae, 'mse': mse, 'rmse': np.sqrt(mse)} 198 | 199 | if self.hparams.loss_type == 'mixup_smooth_mae': 200 | mean = task_metainfo[self.hparams.task_name]["mean"] 201 | std = 
task_metainfo[self.hparams.task_name]["std"] 202 | predicts = preds * std + mean 203 | mae = np.abs(predicts - targets).mean() 204 | mse = ((predicts - targets) ** 2).mean() 205 | return {'mae': mae, 'mse': mse, 'rmse': np.sqrt(mse)} 206 | 207 | if self.hparams.loss_type == 'mixup_multi_task_BCE': 208 | def sigmoid(z): 209 | return 1 / (1 + np.exp(-z)) 210 | 211 | probs = sigmoid(preds) 212 | agg_auc_list = [] 213 | for i in range(targets.shape[1]): 214 | if np.sum(targets[:, i] == 1) > 0 and np.sum(targets[:, i] == 0) > 0: 215 | # ignore nan values 216 | is_labeled = targets[:, i] > -0.5 217 | agg_auc_list.append( 218 | roc_auc_score(targets[is_labeled, i], probs[is_labeled, i]) 219 | ) 220 | 221 | auc = sum(agg_auc_list) / (len(agg_auc_list) + 1e-8) 222 | return {'auc': auc} 223 | 224 | def cross_entropy_loss(self, net_output, targets, reduce=True): 225 | targets = targets.view(-1) 226 | lprobs = F.log_softmax(net_output, dim=-1) 227 | lprobs = lprobs.view(-1, lprobs.size(-1)) 228 | loss = F.nll_loss( 229 | lprobs, 230 | targets, 231 | reduction="sum" if reduce else "none", 232 | ) 233 | 234 | return {'loss': loss} 235 | 236 | def multi_task_BCE(self, logit_output, targets, reduce=True): 237 | is_labeled = targets > -0.5 238 | pos_mask = targets[is_labeled].float().view(-1) 239 | loss = F.binary_cross_entropy_with_logits( 240 | logit_output[is_labeled].float().view(-1), 241 | targets[is_labeled].float().view(-1), 242 | reduction="sum" if reduce else "none", 243 | pos_weight=((1 - pos_mask) * 9 + 1) 244 | ) 245 | return {'loss': loss} 246 | 247 | def mse(self, logit_output, targets, reduce=True): 248 | predicts_normed = logit_output.view(-1, self.hparams.num_classes).float() 249 | targets = ( 250 | targets.view(-1, self.hparams.num_classes).float() 251 | ) 252 | 253 | mean = task_metainfo[self.hparams.task_name]["mean"] 254 | std = task_metainfo[self.hparams.task_name]["std"] 255 | 256 | targets_mean = torch.tensor(mean, device=targets.device) 257 | targets_std = torch.tensor(std, device=targets.device) 258 | targets_normed = (targets - targets_mean) / targets_std 259 | loss = F.mse_loss( 260 | predicts_normed, 261 | targets_normed, 262 | reduction="sum" if reduce else "none", 263 | ) 264 | return {'loss': loss} 265 | 266 | def smooth_mse(self, logit_output, targets, reduce=True): 267 | predicts_normed = logit_output.view(-1, self.hparams.num_classes).float() 268 | targets = ( 269 | targets.view(-1, self.hparams.num_classes).float() 270 | ) 271 | mean = task_metainfo[self.hparams.task_name]["mean"] 272 | std = task_metainfo[self.hparams.task_name]["std"] 273 | targets_mean = torch.tensor(mean, device=targets.device) 274 | targets_std = torch.tensor(std, device=targets.device) 275 | targets_normed = (targets - targets_mean) / targets_std 276 | loss = F.smooth_l1_loss( 277 | predicts_normed, 278 | targets_normed, 279 | reduction="sum" if reduce else "none", 280 | ) 281 | return {'loss': loss} 282 | 283 | def compute_loss(self, logit_output, targets, loss_type='mixup_cross_entropy'): 284 | if loss_type == 'mixup_cross_entropy': 285 | loss = softmax_cross_entropy_with_softtarget(logit_output, targets, reduction='mean') 286 | return {'loss': loss} 287 | 288 | if loss_type == 'mixup_mse': 289 | return self.mse(logit_output, targets, reduce=True) 290 | 291 | if loss_type == 'mixup_smooth_mae': 292 | return self.smooth_mse(logit_output, targets, reduce=True) 293 | 294 | if loss_type == 'mixup_multi_task_BCE': 295 | return self.multi_task_BCE(logit_output, targets, reduce=True) 296 | 297 | 298 | 
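# PolynomialDecayLRSchedule warms the learning rate up linearly over the first
# warmup_ratio * total_num_update steps, then decays it towards end_learning_rate as
#   lr(step) = (lr - end_lr) * (1 - (step - warmup) / (total - warmup)) ** power + end_lr
# so power = 1.0 gives a plain linear decay.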
class PolynomialDecayLRSchedule(lrs.LRScheduler): 299 | def __init__(self, optimizer, warmup_ratio, total_num_update, lr, end_learning_rate, power, last_epoch=-1): 300 | self.warmup_ratio = warmup_ratio # 2532 301 | self.warmup_updates = int(self.warmup_ratio * total_num_update) 302 | self.total_num_update = total_num_update # 42200 303 | self.lr = lr 304 | self.warmup_factor = 1.0 / self.warmup_updates if self.warmup_updates > 0 else 1 305 | self.end_learning_rate = end_learning_rate 306 | self.power = power 307 | super(PolynomialDecayLRSchedule, self).__init__(optimizer, last_epoch) 308 | 309 | def get_lr(self): 310 | if self._step_count < self.warmup_updates: 311 | self.warmup_factor = self._step_count / float(self.warmup_updates) 312 | return [self.warmup_factor * self.lr] 313 | elif self.last_epoch >= self.total_num_update: 314 | return [self.end_learning_rate] 315 | else: 316 | lr_range = self.lr - self.end_learning_rate 317 | pct_remaining = 1 - (self._step_count - self.warmup_updates) / (self.total_num_update - self.warmup_updates) 318 | lr = lr_range * pct_remaining ** self.power + self.end_learning_rate 319 | return [lr] 320 | --------------------------------------------------------------------------------
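For reference, a minimal usage sketch of the property-scoring helpers defined in `utils/property_scores/scoring_func.py` above. It assumes the package has been installed as described in the README and that RDKit is available; the SMILES string is only an illustrative input.

```python
from rdkit import Chem

from utils.property_scores.scoring_func import get_qed, get_sa, get_logp, get_lipinski

# Illustrative molecule (aspirin); any valid SMILES works here.
mol = Chem.MolFromSmiles("CC(=O)Oc1ccccc1C(=O)O")

print("QED:", get_qed(mol))                            # drug-likeness in [0, 1]
print("Normalized SA:", get_sa(mol))                   # synthetic accessibility mapped to [0, 1]
print("logP:", get_logp(mol))                          # Crippen logP
print("Lipinski rules satisfied:", get_lipinski(mol))  # count of satisfied rules, 0-5
```

Each helper returns `None` (after printing the exception) if RDKit fails to process the molecule, so callers should be prepared to filter out missing values.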