├── README.md
├── configs
│   └── property_info.json
├── data
│   ├── __init__.py
│   ├── collate_fn.py
│   ├── examples
│   │   └── zinc_example.txt
│   ├── tokenizer.py
│   └── zinc250k
│       ├── all_250k_rndm_zinc_drugs_clean_3.txt
│       └── valid_250k_rndm_zinc_drugs_clean_3.txt
├── data_finetune
│   └── README.md
├── entrypoints
│   ├── __init__.py
│   ├── generation
│   │   ├── __init__.py
│   │   ├── conditional
│   │   │   ├── __init__.py
│   │   │   ├── generate.py
│   │   │   └── visualize.py
│   │   ├── evaluation
│   │   │   ├── __init__.py
│   │   │   ├── evaluate_moses_few_shot_sampling.py
│   │   │   └── evaluate_zinc250k_few_shot_sampling.py
│   │   └── unconditional
│   │       ├── __init__.py
│   │       ├── generate.py
│   │       └── visualize.py
│   └── representation
│       ├── __init__.py
│       └── finetune.py
├── graphsgpt.svg
├── jupyter_notebooks
│   ├── analysis
│   │   ├── README-Clustering.md
│   │   ├── clustering.ipynb
│   │   ├── hybridization.ipynb
│   │   └── interpolation.ipynb
│   └── example_pipeline.ipynb
├── models
│   ├── __init__.py
│   ├── graphsgpt
│   │   ├── __init__.py
│   │   ├── configuration_graphsgpt.py
│   │   ├── generation_utils.py
│   │   ├── modeling_graphsgpt.py
│   │   └── orf.py
│   └── graphsgpt_cond_gen
│       ├── __init__.py
│       ├── configuration_graphsgpt_cond_gen.py
│       └── modelling_graphsgpt_cond_gen.py
├── moses
│   ├── NP_Score
│   │   ├── README
│   │   ├── __init__.py
│   │   ├── npscorer.py
│   │   └── publicnp.model.gz
│   ├── README.md
│   ├── SA_Score
│   │   ├── README
│   │   ├── __init__.py
│   │   ├── fpscores.pkl.gz
│   │   └── sascorer.py
│   ├── __init__.py
│   ├── data
│   │   ├── test.csv.gz
│   │   ├── test_scaffolds.csv.gz
│   │   ├── test_scaffolds_stats.npz
│   │   ├── test_stats.npz
│   │   └── train.csv.gz
│   ├── dataset.py
│   ├── mcf.csv
│   ├── metric_utils.py
│   ├── metrics.py
│   ├── utils.py
│   └── wehi_pains.csv
├── requirements.txt
├── scripts
│   ├── generation
│   │   ├── conditional
│   │   │   ├── README-Generation-Cond.md
│   │   │   ├── examples
│   │   │   │   ├── generate_scaffold_logp.sh
│   │   │   │   ├── generate_scaffold_qed.sh
│   │   │   │   └── generate_scaffold_sa.sh
│   │   │   └── visualize.sh
│   │   ├── evaluation
│   │   │   ├── moses.sh
│   │   │   └── zinc250k.sh
│   │   └── unconditional
│   │       ├── README-Generation-Uncond.md
│   │       ├── examples
│   │       │   ├── generate_flexible.sh
│   │       │   └── generate_strict.sh
│   │       └── visualize.sh
│   └── representation
│       └── finetune.sh
├── setup.py
└── utils
    ├── __init__.py
    ├── accuracy.py
    ├── algorithms.py
    ├── io.py
    ├── molecule.py
    ├── operations
    │   ├── __init__.py
    │   ├── operation_dataframe.py
    │   ├── operation_dict.py
    │   ├── operation_list.py
    │   ├── operation_number.py
    │   ├── operation_string.py
    │   └── operation_tensor.py
    ├── property_scores
    │   ├── __init__.py
    │   ├── fpscores.pkl.gz
    │   ├── nspdk.py
    │   ├── sascorer.py
    │   ├── scoring_func.py
    │   └── tanimoto_similarity.py
    └── representation
        ├── __init__.py
        ├── data_interface.py
        ├── dict.txt
        ├── dictionary.py
        ├── fingerprint_interface.py
        ├── graphsgpt_finetune_model.py
        ├── lmdb_dataset.py
        ├── load_dataset.py
        ├── logger.py
        └── model_interface.py
/README.md:
--------------------------------------------------------------------------------
1 | # [GraphsGPT] A Graph is Worth $K$ Words: Euclideanizing Graph using Pure Transformer (ICML 2024)
2 |
3 | **[Zhangyang Gao](https://scholar.google.com/citations?user=4SclT-QAAAAJ)\*, [Daize Dong](https://daizedong.github.io/)\*, [Cheng Tan](https://chengtan9907.github.io/), [Jun Xia](https://junxia97.github.io/), [Bozhen Hu](https://scholar.google.com/citations?user=6FZh9C8AAAAJ), [Stan Z. Li](https://scholar.google.com/citations?user=Y-nyLGIAAAAJ)**
4 |
5 | Published at *The 41st International Conference on Machine Learning (ICML 2024)*.
6 |
7 | [arXiv:2402.02464](https://arxiv.org/abs/2402.02464)
8 |
9 | ## Introduction
10 |
11 | Can we model Non-Euclidean graphs as pure language or even Euclidean vectors while retaining their inherent information? The Non-Euclidean property has posed a long-standing challenge in graph modeling. Despite recent efforts by graph neural networks and graph transformers to encode graphs as Euclidean vectors, recovering the original graph from the vectors remains a challenge.
12 | In this paper, we introduce GraphsGPT, featuring a Graph2Seq encoder that transforms Non-Euclidean graphs into learnable GraphWords in the Euclidean space, along with a GraphGPT decoder that reconstructs the original graph from GraphWords to ensure information equivalence. We pretrain GraphsGPT on 100M molecules and obtain some interesting findings:
13 |
14 | - The pretrained Graph2Seq excels in graph representation learning, achieving state-of-the-art results on $8/9$ graph classification and regression tasks.
15 | - The pretrained GraphGPT serves as a strong graph generator, as demonstrated by its ability to perform both few-shot and conditional graph generation.
16 | - Graph2Seq+GraphGPT enables effective graph mixup in the Euclidean space, overcoming previously known Non-Euclidean challenges.
17 | - The edge-centric pretraining framework GraphsGPT demonstrates its efficacy in graph domain tasks, excelling in both representation and generation.
18 |
19 | 
20 |
21 |
22 |
23 | ## Installation
24 |
25 | To get started with GraphsGPT, please run the following commands to set up the environment.
26 |
27 | ```bash
28 | git clone git@github.com:A4Bio/GraphsGPT.git --depth=1
29 | cd GraphsGPT
30 | conda create --name graphsgpt python=3.12
31 | conda activate graphsgpt
32 | pip install -e .[dev]
33 | pip install -r requirements.txt
34 | ```
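
After installation, you can optionally verify that the main classes are importable from the repository root (the module paths follow the repository layout above):

```python
# Optional sanity check: these classes are the ones used throughout ./entrypoints.
from data.tokenizer import GraphsGPTTokenizer
from models.graphsgpt.modeling_graphsgpt import GraphsGPTForCausalLM

print(GraphsGPTTokenizer.__name__, GraphsGPTForCausalLM.__name__)
```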
35 |
36 |
37 |
38 | ## Quickstart
39 |
40 | We provide several Jupyter notebooks in `./jupyter_notebooks`, along with their corresponding online Google Colaboratory notebooks. You can run them for a quick start.
41 |
42 | |                        |                       Jupyter Notebook                       | Google Colaboratory |
43 | | :--------------------: | :-----------------------------------------------------------: | :-----------------: |
44 | |  *GraphsGPT Pipeline*   | [example_pipeline.ipynb](jupyter_notebooks%2Fexample_pipeline.ipynb) |                     |
45 | |  Clustering Analysis    | [clustering.ipynb](jupyter_notebooks%2Fanalysis%2Fclustering.ipynb) |                     |
46 | | Hybridization Analysis  | [hybridization.ipynb](jupyter_notebooks%2Fanalysis%2Fhybridization.ipynb) |                     |
47 | | Interpolation Analysis  | [interpolation.ipynb](jupyter_notebooks%2Fanalysis%2Finterpolation.ipynb) |                     |
48 |
49 |
50 |
51 | ## Checkpoints
52 |
53 | The model [checkpoints](https://huggingface.co/collections/DaizeDong/graphsgpt-65efe70c326a1a5bd35c2fcc) can be downloaded from the 🤗 Hugging Face Hub. We provide both the foundational pretrained models with different numbers of Graph Words $\mathcal{W}$ (GraphsGPT-nW) and the conditional version with one Graph Word (GraphsGPT-1W-C).
54 |
55 | |   Model Name   |    Model Type    |                            Model Checkpoint                             |
56 | | :------------: | :--------------: | :----------------------------------------------------------------------: |
57 | |  GraphsGPT-1W  | Foundation Model | [DaizeDong/GraphsGPT-1W](https://huggingface.co/DaizeDong/GraphsGPT-1W) |
58 | |  GraphsGPT-2W  | Foundation Model | [DaizeDong/GraphsGPT-2W](https://huggingface.co/DaizeDong/GraphsGPT-2W) |
59 | |  GraphsGPT-4W  | Foundation Model | [DaizeDong/GraphsGPT-4W](https://huggingface.co/DaizeDong/GraphsGPT-4W) |
60 | |  GraphsGPT-8W  | Foundation Model | [DaizeDong/GraphsGPT-8W](https://huggingface.co/DaizeDong/GraphsGPT-8W) |
61 | | GraphsGPT-1W-C | Finetuned Model  | [DaizeDong/GraphsGPT-1W-C](https://huggingface.co/DaizeDong/GraphsGPT-1W-C) |
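
As a minimal usage sketch (mirroring the tokenizer/model interfaces and the padding collator used by the scripts in `entrypoints/generation/evaluation/`; the SMILES strings below are arbitrary examples), a checkpoint can be loaded and a batch of molecules encoded into Graph Words as follows:

```python
import torch

from data.collate_fn import tensor_dict_stack_padding_collater
from data.tokenizer import GraphsGPTTokenizer
from models.graphsgpt.modeling_graphsgpt import GraphsGPTForCausalLM

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the pretrained tokenizer and model from the Hugging Face Hub.
tokenizer = GraphsGPTTokenizer.from_pretrained("DaizeDong/GraphsGPT-1W")
model = GraphsGPTForCausalLM.from_pretrained("DaizeDong/GraphsGPT-1W").to(device)

# Encode SMILES strings and pad them into a batch, as the evaluation scripts do.
collator = tensor_dict_stack_padding_collater(0, tensor_keys_to_create_mask=["input_ids"])
encoded = [tokenizer.encode(s, return_tensors="pt") for s in ["CCO", "c1ccccc1"]]
encoded = [e for e in encoded if e is not None]  # encoding may fail for some inputs
inputs, mask = collator(encoded)
inputs["attention_mask"] = mask["input_ids"]
inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}

# Encode the batch into Graph Words ("fingerprint tokens").
with torch.inference_mode():
    graph_words = model.encode_to_fingerprints(**inputs)
print(graph_words.shape)  # (num_molecules, num_graph_words, hidden_size)
```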
62 |
63 |
64 |
65 | ## Representation Experiments
66 |
67 | You should first [download](https://github.com/A4Bio/GraphsGPT/releases/tag/data) the configurations and data for finetuning, and put them in `./data_finetune`. (We also include the finetuned checkpoints in the `model_zoom.zip` file for a quick test.)
68 |
69 | To evaluate the representation performance of the Graph2Seq Encoder, please run:
70 |
71 | ```bash
72 | bash ./scripts/representation/finetune.sh
73 | ```
74 |
75 | You can also toggle the `--mixup_strategy` option to perform graph mixup with Graph2Seq.
76 |
77 |
78 |
79 | ## Generation Experiments
80 |
81 | For unconditional generation with the GraphGPT Decoder, please refer to [README-Generation-Uncond.md](scripts%2Fgeneration%2Funconditional%2FREADME-Generation-Uncond.md).
82 | 
83 | For conditional generation with the GraphGPT-C Decoder, please refer to [README-Generation-Cond.md](scripts%2Fgeneration%2Fconditional%2FREADME-Generation-Cond.md).
84 |
85 | To evaluate the few-shot generation performance of the GraphGPT Decoder, please run:
86 |
87 | ```bash
88 | bash ./scripts/generation/evaluation/moses.sh
89 | bash ./scripts/generation/evaluation/zinc250k.sh
90 | ```
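
Under the hood, these scripts encode reference molecules into Graph Words, perturb them with Gaussian noise (`--sample_std`), and decode the generated graphs back to SMILES. A condensed sketch, continuing from the loading example in the Checkpoints section (the batching/retry logic and several generation arguments from the scripts are omitted here and assumed to keep their defaults):

```python
# Few-shot sampling sketch: perturb the encoded Graph Words, then decode.
noisy_words = torch.normal(mean=graph_words, std=1.0)

with torch.inference_mode():
    results = model.generate_from_fingerprints(
        fingerprint_tokens=noisy_words,
        bond_dict=tokenizer.bond_dict,
        strict_generation=True,
        do_sample=False,
        check_first_node=True,
        check_atom_valence=True,
        fix_aromatic_bond=True,
        use_cache=False,
    )

results = [r for r in results if r is not None]  # generation may fail for some samples
batched, mask = collator(results)
batched["attention_mask"] = mask["input_ids"]
mols, smiles_list = tokenizer.decode(batched, kekulize=True)
print(smiles_list)
```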
91 |
92 |
93 |
94 | ## Citation
95 |
96 | ```latex
97 | @article{gao2024graph,
98 | title={A Graph is Worth $K$ Words: Euclideanizing Graph using Pure Transformer},
99 | author={Gao, Zhangyang and Dong, Daize and Tan, Cheng and Xia, Jun and Hu, Bozhen and Li, Stan Z},
100 | journal={arXiv preprint arXiv:2402.02464},
101 | year={2024}
102 | }
103 | ```
104 |
105 | ## Contact Us
106 | If you have any questions, please contact:
107 |
108 | - Zhangyang Gao: gaozhangyang@westlake.edu.cn
109 |
110 | - Daize Dong: dzdong2019@gmail.com
111 |
--------------------------------------------------------------------------------
/configs/property_info.json:
--------------------------------------------------------------------------------
1 | {
2 | "qed": {
3 | "mean": 0.7195562004332452,
4 | "std": 0.13004599771146944
5 | },
6 | "sa": {
7 | "mean": 0.7612138562611798,
8 | "std": 0.06663695435624298
9 | },
10 | "logp": {
11 | "mean": 2.570799410827321,
12 | "std": 1.3405790580975692
13 | }
14 | }
15 |
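
These entries are the dataset mean and standard deviation of the QED, SA, and logP properties used for conditional generation. As a sketch of how such statistics are typically applied (how the conditional generation code actually consumes this file may differ, and `normalize_property` below is an illustrative helper, not a repository function):

```python
import json

# Standardize a raw property value with the statistics from configs/property_info.json.
with open("configs/property_info.json") as f:
    property_info = json.load(f)

def normalize_property(name: str, value: float) -> float:
    stats = property_info[name]
    return (value - stats["mean"]) / stats["std"]

print(normalize_property("qed", 0.85))  # (0.85 - 0.7196) / 0.1300 ≈ 1.0
```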
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/A4Bio/GraphsGPT/e71647c5a9e19322b4166bb7922c604778a71e1e/data/__init__.py
--------------------------------------------------------------------------------
/data/collate_fn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.utils.rnn as rnn_utils
3 | from typing import List
4 |
5 |
6 | def identity_collator(examples):
7 | return examples
8 |
9 |
10 | def tensor_stack_collator(examples):
11 | """
12 | examples: list of tensors.
13 | input: [tensor1, tensor2, ..., tensorN]
14 | output: stacked_tensor
15 | """
16 | return torch.stack(examples, dim=0)
17 |
18 |
19 | class tensor_stack_padding_collater:
20 | """
21 | examples: list of tensors.
22 | input: [tensor1, tensor2, ..., tensorN]
23 | output: padded_tensor
24 | """
25 |
26 | def __init__(self, padding_id, padding_position="right", return_padding_mask=True):
27 | assert padding_position in ("left", "right")
28 | self.padding_id = padding_id
29 | self.padding_position = padding_position
30 | self.return_padding_mask = return_padding_mask
31 |
32 | def __call__(self, examples):
33 | dtype = examples[0].dtype
34 | if self.padding_position == "right":
35 | padded_examples = rnn_utils.pad_sequence(examples, batch_first=True, padding_value=self.padding_id)
36 | elif self.padding_position == "left": # This will take about twice the time compared to right padding
37 | flipped_examples = [torch.flip(tensor, dims=[0]) for tensor in examples]
38 | padded_examples_flip = rnn_utils.pad_sequence(flipped_examples, batch_first=True, padding_value=self.padding_id)
39 | padded_examples = torch.flip(padded_examples_flip, dims=[1])
40 | else:
41 | raise NotImplementedError
42 | padded_examples = padded_examples.to(dtype)
43 |
44 | if self.return_padding_mask:
45 | padding_mask = (padded_examples != self.padding_id)
46 | return padded_examples, padding_mask
47 | else:
48 | return padded_examples
49 |
50 |
51 | def tensor_lists_stack_collator(examples):
52 | """
53 | examples: list of tensor lists.
54 | input:
55 | [
56 | [tensor1, tensor1, ..., tensor1],
57 | [tensor2, tensor2, ..., tensor2],
58 | ...
59 | [tensorN, tensorN, ..., tensorN],
60 | ]
61 | output:
62 | [
63 | stacked_tensor1,
64 | stacked_tensor2,
65 | ...
66 | stacked_tensorN,
67 | ]
68 | """
69 | return [torch.stack([tensor_list[i] for tensor_list in examples], dim=0) for i in range(len(examples[0]))]
70 |
71 |
72 | class tensor_lists_stack_padding_collater:
73 | def __init__(self, padding_id, padding_position="right", return_padding_mask=True, tensor_ids_to_create_mask: List = None):
74 | assert padding_position in ("left", "right")
75 | self.padding_id = padding_id
76 | self.padding_position = padding_position
77 | self.return_padding_mask = return_padding_mask
78 |
79 | # set indices of tensors in list to create "padding_mask"
80 | # if set to "None", then "padding_mask" will be created for all keys in dict
81 | self.tensor_ids_to_create_mask = tensor_ids_to_create_mask
82 |
83 | def __call__(self, examples):
84 | """
85 | examples: list of tensor lists.
86 | input:
87 | [
88 | [tensor1, tensor1, ..., tensor1],
89 | [tensor2, tensor2, ..., tensor2],
90 | ...
91 | [tensorN, tensorN, ..., tensorN],
92 | ]
93 | output:
94 | [
95 | padded_tensor1,
96 | padded_tensor2,
97 | ...
98 | padded_tensorN,
99 | ]
100 | """
101 | num_tensors = len(examples[0])
102 | padded_tensors = []
103 | padding_masks = []
104 |
105 | for i in range(num_tensors):
106 |             dtype = examples[0][i].dtype
107 | if self.padding_position == "right":
108 | tensors = [tensor_list[i] for tensor_list in examples]
109 | padded_tensor = rnn_utils.pad_sequence(tensors, batch_first=True, padding_value=self.padding_id)
110 | elif self.padding_position == "left": # This will take about twice the time compared to right padding
111 | flipped_tensors = [torch.flip(tensor_list[i], dims=[0]) for tensor_list in examples]
112 | flipped_padded_tensors = rnn_utils.pad_sequence(flipped_tensors, batch_first=True, padding_value=self.padding_id)
113 | padded_tensor = torch.flip(flipped_padded_tensors, dims=[1])
114 | else:
115 | raise NotImplementedError
116 | padded_tensor = padded_tensor.to(dtype)
117 |
118 | padded_tensors.append(padded_tensor)
119 |
120 | if self.return_padding_mask:
121 | if self.tensor_ids_to_create_mask is None or i in self.tensor_ids_to_create_mask:
122 | padding_masks.append(padded_tensors[i] != self.padding_id)
123 | else:
124 | padding_masks.append(None)
125 |
126 | if self.return_padding_mask:
127 | return padded_tensors, padding_masks
128 | else:
129 | return padded_tensors
130 |
131 |
132 | def tensor_dicts_stack_collator(examples):
133 | """
134 | examples: list of tensor dicts.
135 | input:
136 | [
137 | {
138 | "key1": tensor1,
139 | "key2": tensor2,
140 | ...
141 | "keyN": tensorN,
142 | },
143 | {
144 | "key1": tensor1,
145 | "key2": tensor2,
146 | ...
147 | "keyN": tensorN,
148 | },
149 | ...
150 | {
151 | "key1": tensor1,
152 | "key2": tensor2,
153 | ...
154 | "keyN": tensorN,
155 | }
156 | ]
157 | output:
158 | {
159 | "key1": stacked_tensor1,
160 | "key2": stacked_tensor2,
161 | ...
162 | "keyN": stacked_tensorN,
163 | }
164 | """
165 | return {key: torch.stack([tensor_dict[key] for tensor_dict in examples], dim=0) for key in examples[0].keys()}
166 |
167 |
168 | class tensor_dict_stack_padding_collater:
169 | def __init__(self, padding_id, padding_position="right", return_padding_mask=True, tensor_keys_to_create_mask: List = None):
170 | assert padding_position in ("left", "right")
171 | self.padding_id = padding_id
172 | self.padding_position = padding_position
173 | self.return_padding_mask = return_padding_mask
174 |
175 | # set keys of tensors in dict to create "padding_mask"
176 | # if set to "None", then "padding_mask" will be created for all keys in dict
177 | self.tensor_keys_to_create_mask = tensor_keys_to_create_mask
178 |
179 | def __call__(self, examples):
180 | """
181 | examples: list of tensor (or other types) dicts.
182 | input:
183 | [
184 | {
185 | "key0": int,
186 | "key1": tensor1,
187 | "key2": tensor2,
188 | ...
189 | "keyN": tensorN,
190 | },
191 | {
192 | "key0": int,
193 | "key1": tensor1,
194 | "key2": tensor2,
195 | ...
196 | "keyN": tensorN,
197 | },
198 | ...
199 | {
200 | "key0": int,
201 | "key1": tensor1,
202 | "key2": tensor2,
203 | ...
204 | "keyN": tensorN,
205 | }
206 | ]
207 | output:
208 | {
209 | "key0": [int, int, ..., int],
210 | "key1": padded_tensor1,
211 | "key2": padded_tensor2,
212 | ...
213 | "keyN": padded_tensorN,
214 | }
215 | """
216 | keys = examples[0].keys()
217 | padded_tensors = {}
218 | padding_masks = {}
219 |
220 | for key in keys:
221 | if isinstance(examples[0][key], torch.Tensor):
222 | if self.padding_position == "right":
223 | tensors = [tensor_dict[key] for tensor_dict in examples]
224 | padded_tensor = rnn_utils.pad_sequence(tensors, batch_first=True, padding_value=self.padding_id)
225 | elif self.padding_position == "left": # This will take about twice the time compared to right padding
226 | flipped_tensors = [torch.flip(tensor_dict[key], dims=[0]) for tensor_dict in examples]
227 | flipped_padded_tensors = rnn_utils.pad_sequence(flipped_tensors, batch_first=True, padding_value=self.padding_id)
228 | padded_tensor = torch.flip(flipped_padded_tensors, dims=[1])
229 | else:
230 | raise NotImplementedError
231 | else: # not tensor type, return as a list
232 | padded_tensor = [tensor_dict[key] for tensor_dict in examples]
233 |
234 | padded_tensors[key] = padded_tensor
235 |
236 | if self.return_padding_mask and isinstance(examples[0][key], torch.Tensor):
237 | if self.tensor_keys_to_create_mask is None or key in self.tensor_keys_to_create_mask:
238 | padding_masks[key] = (padded_tensors[key] != self.padding_id)
239 | else:
240 | padding_masks[key] = None
241 |
242 | if self.return_padding_mask:
243 | return padded_tensors, padding_masks
244 | else:
245 | return padded_tensors
246 |
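
Usage sketch for `tensor_dict_stack_padding_collater`, the collator used by the generation evaluation scripts (the example values below are illustrative):

```python
import torch

from data.collate_fn import tensor_dict_stack_padding_collater

# Pad a batch of variable-length examples; a padding mask is built only for "input_ids".
collator = tensor_dict_stack_padding_collater(padding_id=0, tensor_keys_to_create_mask=["input_ids"])
batch = [
    {"input_ids": torch.tensor([1, 2, 3]), "smiles": "CCO"},
    {"input_ids": torch.tensor([4, 5]), "smiles": "C"},
]
padded, masks = collator(batch)
print(padded["input_ids"])  # tensor([[1, 2, 3], [4, 5, 0]])
print(masks["input_ids"])   # tensor([[ True,  True,  True], [ True,  True, False]])
print(padded["smiles"])     # non-tensor values are gathered into a list: ['CCO', 'C']
```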
--------------------------------------------------------------------------------
/data_finetune/README.md:
--------------------------------------------------------------------------------
1 | Please [download](https://github.com/A4Bio/GraphsGPT/releases/tag/data) the configurations and data for finetuning and put them in this folder.
2 |
3 | After downloading, the structure of this folder should look like:
4 |
5 | ````
6 | -- data_finetune
7 | -- model_zoom (checkpoints and finetuning configs)
8 | -- bace
9 | -- bbbp
10 | -- hiv
11 | ...
12 | -- molecular_property_prediction (finetuning data)
13 | -- bace
14 | -- bbbp
15 | -- esol
16 | ...
17 | ````
--------------------------------------------------------------------------------
/entrypoints/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/A4Bio/GraphsGPT/e71647c5a9e19322b4166bb7922c604778a71e1e/entrypoints/__init__.py
--------------------------------------------------------------------------------
/entrypoints/generation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/A4Bio/GraphsGPT/e71647c5a9e19322b4166bb7922c604778a71e1e/entrypoints/generation/__init__.py
--------------------------------------------------------------------------------
/entrypoints/generation/conditional/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/A4Bio/GraphsGPT/e71647c5a9e19322b4166bb7922c604778a71e1e/entrypoints/generation/conditional/__init__.py
--------------------------------------------------------------------------------
/entrypoints/generation/conditional/visualize.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | from argparse import ArgumentParser
4 | from rdkit import RDLogger
5 | from tqdm import tqdm
6 |
7 | from data.tokenizer import GraphsGPTTokenizer
8 | from utils.io import delete_file_or_dir, create_dir, save_mol_png, save_empty_png, summary_property_from_file, summary_property_from_all_files
9 | from utils.molecule import get_molecule_standard_scaffold
10 | from utils.operations.operation_string import str2bool
11 | from utils.property_scores.scoring_func import get_qed, get_sa, get_logp, get_is_valid
12 |
13 | RDLogger.DisableLog('rdApp.*')
14 |
15 |
16 | def main(args):
17 | tokenizer = GraphsGPTTokenizer.from_pretrained(args.model_name_or_path)
18 |
19 | file_list = os.listdir(args.generation_results_dir)
20 | file_list = sorted(file_list)
21 | file_list = file_list[args.file_begin_index:args.file_end_index]
22 |
23 | all_smiles = []
24 |
25 | delete_file_or_dir(args.save_dir)
26 |
27 | """visualize"""
28 | for file_name in tqdm(file_list):
29 | success_cnt = 0
30 | invalid_cnt = 0
31 | fail_cnt = 0
32 |
33 | valid_list = []
34 | qed_list = []
35 | sa_list = []
36 | logp_list = []
37 | smiles_list = []
38 | scaffold_smiles_list = []
39 |
40 | save_dir = os.path.join(args.save_dir, f"{file_name.split('.')[0]}")
41 | create_dir(save_dir)
42 |
43 | """read file"""
44 | file_path = os.path.join(args.generation_results_dir, file_name)
45 | with open(file_path, 'rb') as f:
46 | result_list = pickle.load(f)
47 | if not isinstance(result_list, (list, tuple)):
48 | result_list = [result_list]
49 |
50 | """save png"""
51 | for i, result in enumerate(result_list):
52 | print(i)
53 | save_img_file = os.path.join(save_dir, f"{i}.png")
54 |
55 | if result is not None:
56 | mol, smiles = tokenizer.decode(result, kekulize=True) # kekulize the molecule for score calculation
57 |
58 | if mol is None:
59 | valid_list.append(None)
60 | qed_list.append(None)
61 | sa_list.append(None)
62 | logp_list.append(None)
63 | smiles_list.append(None)
64 | scaffold_smiles_list.append(None)
65 | if args.save_images:
66 | save_empty_png(save_img_file)
67 | invalid_cnt += 1
68 | else:
69 | valid = get_is_valid(mol)
70 | qed = get_qed(mol)
71 | sa = get_sa(mol)
72 | logp = get_logp(mol)
73 |
74 | scaffold = get_molecule_standard_scaffold(mol, normalizer=tokenizer.normalizer)
75 | scaffold_smiles = tokenizer._convert_molecule_to_standard_smiles(scaffold)
76 |
77 | valid_list.append(valid)
78 | qed_list.append(qed)
79 | sa_list.append(sa)
80 | logp_list.append(logp)
81 | smiles_list.append(smiles)
82 | scaffold_smiles_list.append(scaffold_smiles)
83 |
84 | if args.save_images:
85 | save_mol_png(mol, save_img_file)
86 | success_cnt += 1
87 | else:
88 | valid_list.append(None)
89 | qed_list.append(None)
90 | sa_list.append(None)
91 | logp_list.append(None)
92 | smiles_list.append(None)
93 | scaffold_smiles_list.append(None)
94 | if args.save_images:
95 | save_empty_png(save_img_file)
96 | fail_cnt += 1
97 |
98 | all_smiles.extend(smiles_list)
99 |
100 | """save statistics"""
101 | with open(os.path.join(save_dir, "count.txt"), 'a') as f:
102 | f.write(f"Success count: {success_cnt}\n")
103 | f.write(f"Invalid count: {invalid_cnt}\n")
104 | f.write(f"Fail count: {fail_cnt}\n")
105 |
106 | with open(os.path.join(save_dir, "valid.txt"), 'a') as f:
107 | for valid in valid_list:
108 | f.write(f"{valid}\n")
109 |
110 | with open(os.path.join(save_dir, "qed.txt"), 'a') as f:
111 | for qed in qed_list:
112 | f.write(f"{qed}\n")
113 |
114 | with open(os.path.join(save_dir, "sa.txt"), 'a') as f:
115 | for sa in sa_list:
116 | f.write(f"{sa}\n")
117 |
118 | with open(os.path.join(save_dir, "logp.txt"), 'a') as f:
119 | for logp in logp_list:
120 | f.write(f"{logp}\n")
121 |
122 | with open(os.path.join(save_dir, "smiles.txt"), 'a') as f:
123 | for smiles in smiles_list:
124 | f.write(f"{smiles}\n")
125 |
126 | with open(os.path.join(save_dir, "scaffold_smiles.txt"), 'a') as f:
127 | for scaffold_smiles in scaffold_smiles_list:
128 | f.write(f"{scaffold_smiles}\n")
129 |
130 | mean_qed, std_qed, num_qed = summary_property_from_file(os.path.join(save_dir, "qed.txt"))
131 | mean_sa, std_sa, num_sa = summary_property_from_file(os.path.join(save_dir, "sa.txt"))
132 | mean_logp, std_logp, num_logp = summary_property_from_file(os.path.join(save_dir, "logp.txt"))
133 |
134 | valid_num = sum([1 for valid in valid_list if valid])
135 | total_num = len([smiles for smiles in smiles_list if smiles is not None])
136 |
137 | with open(os.path.join(save_dir, "summary.txt"), 'w') as f:
138 | f.write(f"Summary QED: mean={format(mean_qed, '.3f')}, std={format(std_qed, '.3f')}, total_cnt={num_qed}\n")
139 | f.write(f"Summary SA: mean={format(mean_sa, '.3f')}, std={format(std_sa, '.3f')}, total_cnt={num_sa}\n")
140 | f.write(f"Summary logP: mean={format(mean_logp, '.3f')}, std={format(std_logp, '.3f')}, total_cnt={num_logp}\n")
141 | f.write(f"Summary validity: {format(valid_num / total_num * 100, '.2f')}% ({valid_num}/{total_num})\n")
142 |
143 | """summarize the results"""
144 | mean_qed, std_qed, num_qed = summary_property_from_all_files(args.save_dir, "qed.txt", value_type="float")
145 | mean_sa, std_sa, num_sa = summary_property_from_all_files(args.save_dir, "sa.txt", value_type="float")
146 | mean_logp, std_logp, num_logp = summary_property_from_all_files(args.save_dir, "logp.txt", value_type="float")
147 | mean_validity, std_validity, num_validity = summary_property_from_all_files(args.save_dir, "valid.txt", value_type="bool")
148 |
149 | all_total_num = len([smiles for smiles in all_smiles if smiles is not None])
150 |
151 | with open(os.path.join(args.save_dir, "final_summary.txt"), 'w') as f:
152 | f.write(f"Summary QED: mean={format(mean_qed, '.3f')}, std={format(std_qed, '.3f')}, total_cnt={num_qed}\n")
153 | f.write(f"Summary SA: mean={format(mean_sa, '.3f')}, std={format(std_sa, '.3f')}, total_cnt={num_sa}\n")
154 | f.write(f"Summary logP: mean={format(mean_logp, '.3f')}, std={format(std_logp, '.3f')}, total_cnt={num_logp}\n")
155 | f.write(f"Summary validity: {format(mean_validity * 100, '.2f')}% ({round(mean_validity * all_total_num)}/{all_total_num})\n")
156 |
157 |
158 | if __name__ == '__main__':
159 | parser = ArgumentParser()
160 | parser.add_argument("--model_name_or_path", default="DaizeDong/GraphsGPT-1W-C", type=str)
161 | parser.add_argument('--generation_results_dir', default=None, type=str)
162 | parser.add_argument('--save_dir', default=None, type=str)
163 | parser.add_argument('--save_images', default="True", type=str)
164 |
165 | parser.add_argument('--file_begin_index', default=0, type=int)
166 | parser.add_argument('--file_end_index', default=2, type=int)
167 | args = parser.parse_args()
168 | args.save_images = str2bool(args.save_images)
169 | print(args)
170 | main(args)
171 |
--------------------------------------------------------------------------------
/entrypoints/generation/evaluation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/A4Bio/GraphsGPT/e71647c5a9e19322b4166bb7922c604778a71e1e/entrypoints/generation/evaluation/__init__.py
--------------------------------------------------------------------------------
/entrypoints/generation/evaluation/evaluate_moses_few_shot_sampling.py:
--------------------------------------------------------------------------------
1 | import os.path
2 |
3 | import argparse
4 | import random
5 | import torch
6 | from tqdm import tqdm
7 |
8 | import moses
9 | from data.collate_fn import tensor_dict_stack_padding_collater
10 | from data.tokenizer import GraphsGPTTokenizer
11 | from models.graphsgpt.modeling_graphsgpt import GraphsGPTForCausalLM
12 | from utils.io import create_dir, save_json, delete_file_or_dir
13 | from utils.operations.operation_list import split_list
14 | from utils.operations.operation_tensor import move_tensors_to_device
15 |
16 |
17 | def main(args):
18 | random.seed(0)
19 | torch.manual_seed(0)
20 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21 |
22 | """load data & model"""
23 | train_smiles = moses.get_dataset("train")
24 | # test_smiles = moses.get_dataset("test")
25 | # test_scaffolds = moses.get_dataset("test_scaffolds")
26 |
27 | pad_collator = tensor_dict_stack_padding_collater(0, tensor_keys_to_create_mask=["input_ids"])
28 |
29 | tokenizer = GraphsGPTTokenizer.from_pretrained(args.model_name_or_path)
30 | model = GraphsGPTForCausalLM.from_pretrained(args.model_name_or_path)
31 | model.to(device)
32 |
33 | """generate according to reference fingerprints"""
34 | all_smiles = []
35 |
36 | # initialize the variable for batch-optimization of few-shot generation
37 | max_shots_in_a_batch = args.batch_size_valid // args.num_samples_each_shot
38 | next_shot_id = 0
39 | finished_shots = set() # shots that have done the generation
40 | candidate_shots = {} # shots for the next generation forward
41 |
42 | with torch.inference_mode():
43 | while len(finished_shots) < args.num_shots:
44 | """prepare initial fingerprint tokens"""
45 | extra_inputs_num = max_shots_in_a_batch - len(candidate_shots)
46 |
47 | if extra_inputs_num > 0: # need some new inputs to fill up the batch
48 | # construct inputs
49 | inputs = []
50 | for i in range(next_shot_id, next_shot_id + extra_inputs_num):
51 | this_shot_inputs = tokenizer.encode(train_smiles[i], return_tensors="pt")
52 | if this_shot_inputs is None: # the encoding may fail
53 | extra_inputs_num -= 1
54 | continue
55 | move_tensors_to_device(this_shot_inputs, device)
56 | inputs.append(this_shot_inputs)
57 | inputs, mask = pad_collator(inputs)
58 | inputs["attention_mask"] = mask["input_ids"]
59 |
60 | # repeat fingerprint tokens
61 | fingerprint_tokens = model.encode_to_fingerprints(**inputs)
62 | fingerprint_tokens = fingerprint_tokens.repeat_interleave(args.num_samples_each_shot, dim=0) # (extra_inputs_num * num_samples_each_shot, num_fingerprints, hidden_size)
63 | each_shot_fingerprint_tokens = fingerprint_tokens.split(args.num_samples_each_shot, dim=0) # each is (num_samples_each_shot, num_fingerprints, hidden_size)
64 |
65 | # add to shots
66 | for shot_id, this_shot_fingerprint_tokens in enumerate(each_shot_fingerprint_tokens, start=next_shot_id):
67 | candidate_shots[shot_id] = {
68 | "sample_times": 0,
69 | "fingerprint": this_shot_fingerprint_tokens,
70 | "generation_results": [],
71 | }
72 | print(f"Total {len(candidate_shots)} candidate shots in this iteration!")
73 |
74 | # update next_shot_id
75 | next_shot_id += extra_inputs_num
76 | print(f"Encoded {extra_inputs_num} new fingerprints!")
77 |
78 | """aggregate fingerprint tokens for candidate shots"""
79 | shot_ids_in_batch = []
80 | aggregated_fingerprint_tokens = []
81 | for shot_id, candidate_shot in candidate_shots.items():
82 | shot_ids_in_batch.append(shot_id)
83 | aggregated_fingerprint_tokens.append(candidate_shot["fingerprint"])
84 | aggregated_fingerprint_tokens = torch.cat(aggregated_fingerprint_tokens, dim=0) # (max_shots_in_a_batch * num_samples_each_shot, num_fingerprints, hidden_size)
85 |
86 | """random sampling & generate"""
87 | # For each shot, we try to randomly sample fingerprints for at most (max_sample_times * num_samples_each_shot) times.
88 |             # If the number of sampling tries exceeds the limit, we stop sampling this shot.
89 | generate_fingerprint_tokens = torch.normal(mean=aggregated_fingerprint_tokens, std=args.sample_std)
90 | generated_results: list = model.generate_from_fingerprints(
91 | fingerprint_tokens=generate_fingerprint_tokens,
92 | bond_dict=tokenizer.bond_dict,
93 | input_ids_list=None,
94 | graph_position_ids_1_list=None,
95 | graph_position_ids_2_list=None,
96 | identifier_ids_list=None,
97 | strict_generation=True,
98 | do_sample=False,
99 | topk=1,
100 | temperature=1.0,
101 | max_atoms=None,
102 | similarity_threshold=0.5,
103 | check_first_node=True,
104 | check_atom_valence=True,
105 | fix_aromatic_bond=True,
106 | use_cache=False,
107 | save_failed=False,
108 | show_progress=True,
109 | verbose=False,
110 | )
111 | each_shot_generate_results = split_list(generated_results, args.num_samples_each_shot)
112 |
113 | """add results to corresponding shots"""
114 | for shot_id, this_shot_generate_results in zip(shot_ids_in_batch, each_shot_generate_results):
115 | this_shot_generate_results = [move_tensors_to_device(result, "cpu") for result in this_shot_generate_results if result is not None] # remove None results
116 | candidate_shots[shot_id]["generation_results"].extend(this_shot_generate_results)
117 | candidate_shots[shot_id]["sample_times"] += 1
118 | print(f"Added {len(this_shot_generate_results)} generated results for shot {shot_id}. Now: {len(candidate_shots[shot_id]['generation_results'])}")
119 |
120 | """gather results & remove finished shots"""
121 | this_iter_finished_shot_ids = []
122 | this_iter_finished_results = []
123 | for shot_id, candidate_shot in tqdm(candidate_shots.items(), desc="Aggregating results"):
124 |                 if candidate_shot["sample_times"] >= args.max_sample_times or len(candidate_shots[shot_id]["generation_results"]) >= args.num_samples_each_shot:  # exceeded max tries / enough results
125 | this_iter_finished_shot_ids.append(shot_id)
126 | this_iter_finished_results.extend(candidate_shots[shot_id]["generation_results"][:args.num_samples_each_shot]) # add at most "num_samples_each_shot" results
127 | if len(finished_shots) + len(this_iter_finished_shot_ids) >= args.num_shots:
128 | break
129 |
130 | for shot_id in this_iter_finished_shot_ids:
131 | print(f"Finished shot {shot_id}. Final samples: {len(candidate_shots[shot_id]['generation_results'])}")
132 | candidate_shots.pop(shot_id) # remove finished shots
133 | finished_shots.add(shot_id)
134 | print(f"Now total finished shots: {len(finished_shots)}")
135 |
136 | print("Decoding to SMILES...")
137 | if len(this_iter_finished_results) > 0:
138 | this_iter_finished_results_batched, mask = pad_collator(this_iter_finished_results)
139 | this_iter_finished_results_batched["attention_mask"] = mask["input_ids"]
140 | decoded_mol_list, decoded_smiles_list = tokenizer.decode(this_iter_finished_results_batched, kekulize=True, nprocs=None)
141 | # You can use the following line to accelerate the decoding. However, it may occasionally raise errors.
142 | # decoded_mol_list, decoded_smiles_list = tokenizer.decode(this_iter_finished_results_batched, kekulize=True, nprocs=args.num_processes)
143 | all_smiles.extend(decoded_smiles_list)
144 | print(f"Now total generated SMILES: {len(all_smiles)}")
145 |
146 | """get metrics"""
147 | print("Getting metrics...")
148 | metrics = moses.get_all_metrics(
149 | all_smiles,
150 | n_jobs=args.num_processes,
151 | device=device,
152 | batch_size=args.batch_size_valid,
153 | )
154 | print(metrics)
155 |
156 | delete_file_or_dir(args.save_path)
157 | create_dir(args.save_path)
158 | save_metric_file = os.path.join(args.save_path, "metrics.json")
159 | save_json(metrics, save_metric_file)
160 | print(f"Metrics saved to {save_metric_file}")
161 |
162 | """save results"""
163 | save_all_smiles_file = os.path.join(args.save_path, f"all_smiles.txt")
164 | with open(save_all_smiles_file, "w") as f:
165 | for smiles in all_smiles:
166 | f.write(smiles + "\n")
167 | print(f"Generated SMILES saved to {save_all_smiles_file}")
168 |
169 |
170 | if __name__ == "__main__":
171 | parser = argparse.ArgumentParser()
172 | parser.add_argument("--model_name_or_path", default="DaizeDong/GraphsGPT-1W", type=str, help="Path to the GraphsGPT hugging face model.")
173 | parser.add_argument("--save_path", default="./results/unconditional/moses", type=str, help="Path to save the evaluation results.")
174 |
175 | parser.add_argument('--batch_size_valid', default=8192, type=int, help="Number of samples per batch.")
176 | parser.add_argument("--sample_std", default=1.0, type=float, help="The standard deviation for sampling.")
177 |     parser.add_argument('--max_sample_times', default=10, type=int, help="The maximum number of sampling attempts for each shot. Shots that still lack enough successful generations after this many attempts will be discarded.")
178 | parser.add_argument("--num_shots", default=100000, type=int, help="The number of shots for reference.")
179 | parser.add_argument("--num_samples_each_shot", default=1, type=int, help="The number of generated samples for each shot.")
180 |
181 | parser.add_argument("--num_processes", default=32, type=int, help="Number of parallel processes for decoding & metric calculation.")
182 | args = parser.parse_args()
183 |
184 | print(args)
185 | main(args)
186 |
--------------------------------------------------------------------------------
/entrypoints/generation/evaluation/evaluate_zinc250k_few_shot_sampling.py:
--------------------------------------------------------------------------------
1 | import os.path
2 |
3 | import argparse
4 | import random
5 | import torch
6 | from tqdm import tqdm
7 |
8 | import moses
9 | from data.collate_fn import tensor_dict_stack_padding_collater
10 | from data.tokenizer import GraphsGPTTokenizer
11 | from models.graphsgpt.modeling_graphsgpt import GraphsGPTForCausalLM
12 | from utils.io import create_dir, save_json, delete_file_or_dir
13 | from utils.operations.operation_list import split_list
14 | from utils.operations.operation_tensor import move_tensors_to_device
15 | from utils.property_scores.nspdk import get_npsdk
16 |
17 |
18 | def main(args):
19 | random.seed(0)
20 | torch.manual_seed(0)
21 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22 |
23 | """load data & model"""
24 | train_file = "data/zinc250k/all_250k_rndm_zinc_drugs_clean_3.txt"
25 | valid_file = "data/zinc250k/valid_250k_rndm_zinc_drugs_clean_3.txt"
26 |
27 | with open(train_file, "r") as file:
28 | train_smiles = [data.replace("\n", "") for data in file.readlines()]
29 | random.shuffle(train_smiles)
30 | with open(valid_file, "r") as file:
31 | valid_smiles = [data.replace("\n", "") for data in file.readlines()]
32 |
33 | pad_collator = tensor_dict_stack_padding_collater(0, tensor_keys_to_create_mask=["input_ids"])
34 |
35 | tokenizer = GraphsGPTTokenizer.from_pretrained(args.model_name_or_path)
36 | model = GraphsGPTForCausalLM.from_pretrained(args.model_name_or_path)
37 | model.to(device)
38 |
39 | """generate according to reference fingerprints"""
40 | all_smiles = []
41 |
42 | # initialize the variable for batch-optimization of few-shot generation
43 | max_shots_in_a_batch = args.batch_size_valid // args.num_samples_each_shot
44 | next_shot_id = 0
45 | finished_shots = set() # shots that have done the generation
46 | candidate_shots = {} # shots for the next generation forward
47 |
48 | with torch.inference_mode():
49 | while len(finished_shots) < args.num_shots:
50 | """prepare initial fingerprint tokens"""
51 | extra_inputs_num = max_shots_in_a_batch - len(candidate_shots)
52 |
53 | if extra_inputs_num > 0: # need some new inputs to fill up the batch
54 | # construct inputs
55 | inputs = []
56 | for i in range(next_shot_id, next_shot_id + extra_inputs_num):
57 | this_shot_inputs = tokenizer.encode(train_smiles[i], return_tensors="pt")
58 | if this_shot_inputs is None: # the encoding may fail
59 | extra_inputs_num -= 1
60 | continue
61 | move_tensors_to_device(this_shot_inputs, device)
62 | inputs.append(this_shot_inputs)
63 | inputs, mask = pad_collator(inputs)
64 | inputs["attention_mask"] = mask["input_ids"]
65 |
66 | # repeat fingerprint tokens
67 | fingerprint_tokens = model.encode_to_fingerprints(**inputs)
68 | fingerprint_tokens = fingerprint_tokens.repeat_interleave(args.num_samples_each_shot, dim=0) # (extra_inputs_num * num_samples_each_shot, num_fingerprints, hidden_size)
69 | each_shot_fingerprint_tokens = fingerprint_tokens.split(args.num_samples_each_shot, dim=0) # each is (num_samples_each_shot, num_fingerprints, hidden_size)
70 |
71 | # add to shots
72 | for shot_id, this_shot_fingerprint_tokens in enumerate(each_shot_fingerprint_tokens, start=next_shot_id):
73 | candidate_shots[shot_id] = {
74 | "sample_times": 0,
75 | "fingerprint": this_shot_fingerprint_tokens,
76 | "generation_results": [],
77 | }
78 | print(f"Total {len(candidate_shots)} candidate shots in this iteration!")
79 |
80 | # update next_shot_id
81 | next_shot_id += extra_inputs_num
82 | print(f"Encoded {extra_inputs_num} new fingerprints!")
83 |
84 | """aggregate fingerprint tokens for candidate shots"""
85 | shot_ids_in_batch = []
86 | aggregated_fingerprint_tokens = []
87 | for shot_id, candidate_shot in candidate_shots.items():
88 | shot_ids_in_batch.append(shot_id)
89 | aggregated_fingerprint_tokens.append(candidate_shot["fingerprint"])
90 | aggregated_fingerprint_tokens = torch.cat(aggregated_fingerprint_tokens, dim=0) # (max_shots_in_a_batch * num_samples_each_shot, num_fingerprints, hidden_size)
91 |
92 | """random sampling & generate"""
93 | # For each shot, we try to randomly sample fingerprints for at most (max_sample_times * num_samples_each_shot) times.
94 |             # If the number of sampling tries exceeds the limit, we stop sampling this shot.
95 | generate_fingerprint_tokens = torch.normal(mean=aggregated_fingerprint_tokens, std=args.sample_std)
96 | generated_results: list = model.generate_from_fingerprints(
97 | fingerprint_tokens=generate_fingerprint_tokens,
98 | bond_dict=tokenizer.bond_dict,
99 | input_ids_list=None,
100 | graph_position_ids_1_list=None,
101 | graph_position_ids_2_list=None,
102 | identifier_ids_list=None,
103 | strict_generation=True,
104 | do_sample=False,
105 | topk=1,
106 | temperature=1.0,
107 | max_atoms=None,
108 | similarity_threshold=0.5,
109 | check_first_node=True,
110 | check_atom_valence=True,
111 | fix_aromatic_bond=True,
112 | use_cache=False,
113 | save_failed=False,
114 | show_progress=True,
115 | verbose=False,
116 | )
117 | each_shot_generate_results = split_list(generated_results, args.num_samples_each_shot)
118 |
119 | """add results to corresponding shots"""
120 | for shot_id, this_shot_generate_results in zip(shot_ids_in_batch, each_shot_generate_results):
121 | this_shot_generate_results = [move_tensors_to_device(result, "cpu") for result in this_shot_generate_results if result is not None] # remove None results
122 | candidate_shots[shot_id]["generation_results"].extend(this_shot_generate_results)
123 | candidate_shots[shot_id]["sample_times"] += 1
124 | print(f"Added {len(this_shot_generate_results)} generated results for shot {shot_id}. Now: {len(candidate_shots[shot_id]['generation_results'])}")
125 |
126 | """gather results & remove finished shots"""
127 | this_iter_finished_shot_ids = []
128 | this_iter_finished_results = []
129 | for shot_id, candidate_shot in tqdm(candidate_shots.items(), desc="Aggregating results"):
130 |                 if candidate_shot["sample_times"] >= args.max_sample_times or len(candidate_shots[shot_id]["generation_results"]) >= args.num_samples_each_shot:  # exceeded max tries / enough results
131 | this_iter_finished_shot_ids.append(shot_id)
132 | this_iter_finished_results.extend(candidate_shots[shot_id]["generation_results"][:args.num_samples_each_shot]) # add a maximum of "num_samples_each_shot" results
133 | if len(finished_shots) + len(this_iter_finished_shot_ids) >= args.num_shots:
134 | break
135 |
136 | for shot_id in this_iter_finished_shot_ids:
137 | print(f"Finished shot {shot_id}. Final samples: {len(candidate_shots[shot_id]['generation_results'])}")
138 | candidate_shots.pop(shot_id) # remove finished shots
139 | finished_shots.add(shot_id)
140 | print(f"Now total finished shots: {len(finished_shots)}")
141 |
142 | print("Decoding to SMILES...")
143 | if len(this_iter_finished_results) > 0:
144 | this_iter_finished_results_batched, mask = pad_collator(this_iter_finished_results)
145 | this_iter_finished_results_batched["attention_mask"] = mask["input_ids"]
146 | decoded_mol_list, decoded_smiles_list = tokenizer.decode(this_iter_finished_results_batched, kekulize=True, nprocs=None)
147 | # You can use the following line to accelerate the decoding. However, it may occasionally raise errors.
148 | # decoded_mol_list, decoded_smiles_list = tokenizer.decode(this_iter_finished_results_batched, kekulize=True, nprocs=args.num_processes)
149 | all_smiles.extend(decoded_smiles_list)
150 | print(f"Now total generated SMILES: {len(all_smiles)}")
151 |
152 | """get metrics"""
153 | print("Getting metrics...")
154 | metrics = moses.get_all_metrics(
155 | all_smiles,
156 | n_jobs=args.num_processes,
157 | device=device,
158 | batch_size=args.batch_size_valid,
159 | test=valid_smiles,
160 | test_scaffolds=valid_smiles,
161 | )
162 |
163 | # 🔍 NPSDK
164 | npsdk_mmd = get_npsdk(all_smiles, valid_smiles, n_jobs=args.num_processes)
165 | metrics["NPSDK"] = npsdk_mmd
166 |
167 | print(metrics)
168 |
169 | delete_file_or_dir(args.save_path)
170 | create_dir(args.save_path)
171 | save_metric_file = os.path.join(args.save_path, "metrics.json")
172 | save_json(metrics, save_metric_file)
173 | print(f"Metrics saved to {save_metric_file}")
174 |
175 | """save results"""
176 | save_all_smiles_file = os.path.join(args.save_path, f"all_smiles.txt")
177 | with open(save_all_smiles_file, "w") as f:
178 | for smiles in all_smiles:
179 | f.write(smiles + "\n")
180 | print(f"Generated SMILES saved to {save_all_smiles_file}")
181 |
182 |
183 | if __name__ == "__main__":
184 | parser = argparse.ArgumentParser()
185 | parser.add_argument("--model_name_or_path", default="DaizeDong/GraphsGPT-1W", type=str, help="Path to the GraphsGPT hugging face model.")
186 | parser.add_argument("--save_path", default="./results/unconditional/zinc250k", type=str, help="Path to save the evaluation results.")
187 |
188 | parser.add_argument('--batch_size_valid', default=8192, type=int, help="Number of samples per batch.")
189 | parser.add_argument("--sample_std", default=1.0, type=float, help="The standard deviation for sampling.")
190 |     parser.add_argument('--max_sample_times', default=10, type=int, help="The maximum number of sampling attempts for each shot. Shots that still lack enough successful generations after this many attempts will be discarded.")
191 | parser.add_argument("--num_shots", default=100000, type=int, help="The number of shots for reference.")
192 | parser.add_argument("--num_samples_each_shot", default=1, type=int, help="The number of generated samples for each shot.")
193 |
194 | parser.add_argument("--num_processes", default=32, type=int, help="Number of parallel processes for decoding & metric calculation.")
195 | args = parser.parse_args()
196 |
197 | print(args)
198 | main(args)
199 |
--------------------------------------------------------------------------------
/entrypoints/generation/unconditional/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/A4Bio/GraphsGPT/e71647c5a9e19322b4166bb7922c604778a71e1e/entrypoints/generation/unconditional/__init__.py
--------------------------------------------------------------------------------
/entrypoints/generation/unconditional/visualize.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | from argparse import ArgumentParser
4 | from tqdm import tqdm
5 |
6 | from data.tokenizer import GraphsGPTTokenizer
7 | from utils.io import delete_file_or_dir, create_dir, save_mol_png, save_empty_png
8 | from utils.operations.operation_string import str2bool
9 |
10 |
11 | def main(args):
12 | tokenizer = GraphsGPTTokenizer.from_pretrained(args.model_name_or_path)
13 |
14 | file_list = os.listdir(args.generation_results_dir)
15 | file_list = sorted(file_list)
16 | file_list = file_list[args.file_begin_index:args.file_end_index]
17 |
18 | """visualize"""
19 | for file_name in tqdm(file_list):
20 | success_cnt = 0
21 | invalid_cnt = 0
22 | fail_cnt = 0
23 | smiles_list = []
24 |
25 | save_dir = os.path.join(args.save_dir, f"{file_name.split('.')[0]}")
26 | delete_file_or_dir(save_dir)
27 | create_dir(save_dir)
28 |
29 | """read file"""
30 | file_path = os.path.join(args.generation_results_dir, file_name)
31 | with open(file_path, 'rb') as f:
32 | result_list = pickle.load(f)
33 | if not isinstance(result_list, (list, tuple)):
34 | result_list = [result_list]
35 |
36 | """save png"""
37 | for i, result in enumerate(result_list):
38 | print(i)
39 | save_img_file = os.path.join(save_dir, f"{i}.png")
40 |
41 | if result is not None:
42 | mol, smiles = tokenizer.decode(result, kekulize=True) # kekulize the molecule for score calculation
43 |
44 | if mol is None:
45 | smiles_list.append(None)
46 | if args.save_images:
47 | save_empty_png(save_img_file)
48 | invalid_cnt += 1
49 | else:
50 | smiles_list.append(smiles)
51 | if args.save_images:
52 | save_mol_png(mol, save_img_file)
53 | # print(f"Molecule '{smiles}' saved to '{save_img_file}'.")
54 | success_cnt += 1
55 | else:
56 | smiles_list.append(None)
57 | if args.save_images:
58 | save_empty_png(save_img_file)
59 | fail_cnt += 1
60 |
61 | """save statistics"""
62 | with open(os.path.join(save_dir, "count.txt"), 'a') as f:
63 | f.write(f"Success count: {success_cnt}\n")
64 | f.write(f"Invalid count: {invalid_cnt}\n")
65 | f.write(f"Fail count: {fail_cnt}\n")
66 |
67 | with open(os.path.join(save_dir, "smiles.txt"), 'a') as f:
68 | for smiles in smiles_list:
69 | f.write(f"{smiles}\n")
70 |
71 |
72 | if __name__ == '__main__':
73 | parser = ArgumentParser()
74 | parser.add_argument("--model_name_or_path", default="DaizeDong/GraphsGPT-1W", type=str)
75 | parser.add_argument('--generation_results_dir', default=None, type=str)
76 | parser.add_argument('--save_dir', default=None, type=str)
77 | parser.add_argument('--save_images', default="True", type=str)
78 |
79 | parser.add_argument('--file_begin_index', default=0, type=int)
80 | parser.add_argument('--file_end_index', default=2, type=int)
81 | args = parser.parse_args()
82 | args.save_images = str2bool(args.save_images)
83 | print(args)
84 | main(args)
85 |
--------------------------------------------------------------------------------
/entrypoints/representation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/A4Bio/GraphsGPT/e71647c5a9e19322b4166bb7922c604778a71e1e/entrypoints/representation/__init__.py
--------------------------------------------------------------------------------
/entrypoints/representation/finetune.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import math
4 | import os
5 | import pytorch_lightning as pl
6 | import pytorch_lightning.callbacks as plc
7 | import pytorch_lightning.loggers as plog
8 | import sys
9 | import torch
10 | import warnings
11 | from omegaconf import OmegaConf
12 | from pytorch_lightning.trainer import Trainer
13 |
14 | sys.path.append(os.getcwd())
15 | warnings.filterwarnings("ignore")
16 |
17 | from utils.representation.logger import SetupCallback
18 | from utils.representation.data_interface import DInterface
19 | from utils.representation.model_interface import MInterface
20 |
21 | model_zoom = {
22 | 'tox21': ['./data_finetune/model_zoom/tox21/tox21.yaml',
23 | './data_finetune/model_zoom/tox21/tox21.pth'],
24 | 'toxcast': ['./data_finetune/model_zoom/toxcast/toxcast.yaml',
25 | './data_finetune/model_zoom/toxcast/toxcast.pth'],
26 | 'bbbp': ['./data_finetune/model_zoom/bbbp/bbbp.yaml',
27 | './data_finetune/model_zoom/bbbp/bbbp.pth'],
28 | 'sider': ['./data_finetune/model_zoom/sider/sider.yaml',
29 | './data_finetune/model_zoom/sider/sider.pth'],
30 | 'hiv': ['./data_finetune/model_zoom/hiv/hiv.yaml',
31 | './data_finetune/model_zoom/hiv/hiv.pth'],
32 | 'bace': ['./data_finetune/model_zoom/bace/bace.yaml',
33 | './data_finetune/model_zoom/bace/bace.pth']
34 | }
35 |
36 |
37 | def create_parser():
38 | parser = argparse.ArgumentParser()
39 | # Set-up parameters
40 | parser.add_argument('--res_dir', default="./results/representation", type=str)
41 | parser.add_argument('--ex_name', default='debug', type=str)
42 | parser.add_argument('--check_val_every_n_epoch', default=1, type=int)
43 |
44 | parser.add_argument('--task_name', default='bace', choices=['bace', 'bbbp', 'clintox', 'sider', 'tox21', 'toxcast', 'hiv', 'esol', 'freesolv', 'lipo'])
45 | parser.add_argument('--mixup_strategy', default='no_mix_pretrain', choices=['mix_embed', 'mix_graph', 'no_mix_pretrain', 'no_mix_vanilla'])
46 | parser.add_argument('--lr_scheduler', default='cosine')
47 | parser.add_argument('--offline', default=0, type=int)
48 |
49 | # dataset parameters
50 | parser.add_argument('--num_workers', default=8, type=int)
51 | parser.add_argument('--seed', default=1, type=int)
52 | parser.add_argument('--count', default=0, type=int)
53 | parser.add_argument('--multiple_conformation', default=1, type=int)
54 | parser.add_argument('--only_polar', default=-1, type=int)
55 | parser.add_argument('--remove_polar_hydrogen', default=False, type=bool)
56 | parser.add_argument('--remove_hydrogen', default=False, type=bool)
57 | parser.add_argument('--fingerprint', default='graphsgpt', type=str)
58 | parser.add_argument('--data', default='./data_finetune/molecular_property_prediction', type=str)
59 | parser.add_argument('--conf_size', default=11, type=int)
60 | parser.add_argument('--max_atoms', default=256, type=int)
61 | parser.add_argument('--self_prob', default=0.1, type=float)
62 | parser.add_argument('--no_shuffle', default=False, type=bool)
63 | parser.add_argument('--mix_times', default=1, type=int)
64 |
65 | # Training parameters
66 | parser.add_argument('--batch_size', default=128, type=int)
67 | parser.add_argument('--epoch', default=50, type=int, help='end epoch')
68 | parser.add_argument('--lr', default=1e-4, type=float, help='Learning rate')
69 |     parser.add_argument('--warmup_ratio', default=0.06, type=float, help='warmup ratio')
70 |     parser.add_argument('--ema_decay', default=0.999, type=float, help='EMA decay rate')
71 |     parser.add_argument('--pos_weight', default=99, type=float, help='positive class weight')
72 |
73 | # Model parameters
74 | parser.add_argument('--encoder_layers', default=15, type=int)
75 | parser.add_argument('--embed_dim', default=512, type=int)
76 | parser.add_argument('--ffn_embed_dim', default=2048, type=int)
77 | parser.add_argument('--attention_heads', default=64, type=int)
78 | parser.add_argument('--emb_dropout', default=0.1, type=float)
79 | parser.add_argument('--dropout', default=0.1, type=float)
80 | parser.add_argument('--attention_dropout', default=0.1, type=float)
81 | parser.add_argument('--activation_dropout', default=0.0, type=float)
82 | parser.add_argument('--max_seq_len', default=512, type=int)
83 | parser.add_argument('--activation_fn', default='gelu', type=str)
84 | parser.add_argument('--post_ln', default=False, type=bool)
85 | parser.add_argument('--no_final_head_layer_norm', default=True, type=bool)
86 | parser.add_argument('--mixup_alpha', default=1.0, type=float)
87 | parser.add_argument('--beta', default=1.0, type=float)
88 | parser.add_argument('--encoder_embed_dim', default=512, type=int)
89 | parser.add_argument('--pooler_dropout', default=0.0, type=float)
90 | parser.add_argument('--num_classes', default=2, type=int) # need to be changed
91 | parser.add_argument('--loss_type', default='mixup_multi_task_BCE', type=str) # need to be changed
92 | parser.add_argument('--checkpoint_metric', default='test_auc', type=str) # need to be changed
93 | args = parser.parse_args()
94 | return args
95 |
96 |
97 | def load_callbacks(args):
98 | callbacks = []
99 |
100 | logdir = str(os.path.join(args.res_dir, args.ex_name))
101 |
102 | ckptdir = os.path.join(logdir, "checkpoints")
103 |
104 | modes = {'test_auc': 'max', 'test_rmse': 'min', 'test_mae': 'min'}
105 |
106 | metric = args.checkpoint_metric
107 | sv_filename = 'best-{epoch:02d}-{' + metric + ':.3f}'
108 | callbacks.append(plc.ModelCheckpoint(
109 | monitor=metric,
110 | filename=sv_filename,
111 | save_top_k=5,
112 | mode=modes[metric],
113 | save_last=True,
114 | dirpath=ckptdir,
115 | verbose=True,
116 | every_n_epochs=args.check_val_every_n_epoch,
117 | ))
118 |
119 | now = datetime.datetime.now().strftime("%m-%dT%H-%M-%S")
120 | cfgdir = os.path.join(logdir, "configs")
121 | callbacks.append(
122 | SetupCallback(
123 | now=now,
124 | logdir=logdir,
125 | ckptdir=ckptdir,
126 | cfgdir=cfgdir,
127 | config=args.__dict__,
128 | argv_content=sys.argv + ["gpus: {}".format(torch.cuda.device_count())], )
129 | )
130 |
131 | if args.lr_scheduler:
132 | callbacks.append(plc.LearningRateMonitor(
133 | logging_interval=None))
134 | return callbacks
135 |
136 |
137 | def main():
138 | args = create_parser()
139 | params = OmegaConf.load(model_zoom[args.task_name][0])
140 | config = args.__dict__
141 | config.update(params)
142 | logger = plog.WandbLogger(
143 | project='graphsgpt_mixup2',
144 | name=args.ex_name,
145 | save_dir=str(os.path.join(args.res_dir, args.ex_name)),
146 | offline=True,
147 | id="_".join(args.ex_name.split("/")),
148 | entity="gaozhangyang"
149 | )
150 |
151 | pl.seed_everything(args.seed)
152 |
153 | data_module = DInterface(**vars(args))
154 | data_module.setup()
155 |
156 | gpu_count = torch.cuda.device_count()
157 | args.steps_per_epoch = math.ceil(len(data_module.trainset) / args.batch_size / gpu_count)
158 |     print(f"steps_per_epoch {args.steps_per_epoch}, gpu_count {gpu_count}, batch_size {args.batch_size}")
159 |
160 | model = MInterface(**vars(args))
161 | params = torch.load(model_zoom[args.task_name][1])
162 | params = {k.replace('_forward_module.', ''): v for k, v in params.items()}
163 | model.load_state_dict(params)
164 |
165 | trainer_config = {
166 | 'gpus': -1, # Use all available GPUs
167 | 'max_epochs': args.epoch, # Maximum number of epochs to train for
168 | 'num_nodes': 1, # Number of nodes to use for distributed training
169 |         "strategy": 'deepspeed_stage_2',  # 'ddp' or 'deepspeed_stage_2'
170 | "precision": '32', # "bf16", 16
171 | 'accelerator': 'gpu', # Use distributed data parallel
172 | 'callbacks': load_callbacks(args),
173 | 'logger': logger,
174 | 'gradient_clip_val': 1.0
175 | }
176 |
177 | trainer_opt = argparse.Namespace(**trainer_config)
178 | trainer = Trainer.from_argparse_args(trainer_opt)
179 |
180 | trainer.test(model, data_module)
181 | print(trainer_config)
182 |
183 |
184 | if __name__ == "__main__":
185 | main()
186 |
--------------------------------------------------------------------------------
/jupyter_notebooks/analysis/README-Clustering.md:
--------------------------------------------------------------------------------
1 | # Hyper-Parameters for Clustering
2 |
3 | Total number of samples should be set to **32768**.
4 |
5 |
6 |
7 | ## UMAP
8 |
9 | ### For Clustering
10 |
11 | | Model | total_vis_sample_num | n_neighbors | min_dist | n_components |
12 | |--------------|----------------------|-------------|----------|--------------|
13 | | GraphsGPT-1W | 32768 | 40 | 0.05 | 2 |
14 | | GraphsGPT-2W | 32768 | 40 | 0.05 | 2 |
15 | | GraphsGPT-4W | 32768 | 100 | 0.05 | 2 |
16 | | GraphsGPT-8W | 32768 | 100 | 0.05 | 2 |
17 |
18 | ### For Visualization
19 |
20 | | Model | total_vis_sample_num | n_neighbors | min_dist | n_components |
21 | |--------------|----------------------|-------------|----------|--------------|
22 | | GraphsGPT-1W | 32768 | 40 | 0.8 | 2 |
23 | | GraphsGPT-2W | 32768 | 40 | 0.8 | 2 |
24 | | GraphsGPT-4W | 32768 | 40 | 0.7 | 2 |
25 | | GraphsGPT-8W | 32768 | 40 | 0.7 | 2 |
26 |
27 |
28 |
29 | ## HDBSCAN
30 |
31 | | Model | min_cluster_size | min_samples | cluster_selection_epsilon | alpha |
32 | |--------------|------------------|-------------|---------------------------|-------|
33 | | GraphsGPT-1W | 48 | 64 | 0.25 | 1.0 |
34 | | GraphsGPT-2W | 48 | 64 | 0.25 | 1.0 |
35 | | GraphsGPT-4W | 32 | 48 | 0.2 | 1.0 |
36 | | GraphsGPT-8W | 32 | 48 | 0.2 | 1.0 |
37 |
38 |
39 |
40 |
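41 | ## Usage Sketch
42 |
43 | A minimal, illustrative sketch (not part of the repository pipeline) of where the hyper-parameters above plug into `umap-learn` and `hdbscan`. The `graph_words` array is a placeholder for the encoded fingerprint embeddings; the values shown correspond to the GraphsGPT-1W rows.
44 |
45 | ```python
46 | import hdbscan
47 | import numpy as np
48 | import umap
49 |
50 | # Placeholder for the (total_vis_sample_num, hidden_dim) graph word embeddings.
51 | graph_words = np.random.rand(32768, 512).astype(np.float32)
52 |
53 | # UMAP with the "For Clustering" setting of GraphsGPT-1W.
54 | reducer = umap.UMAP(n_neighbors=40, min_dist=0.05, n_components=2)
55 | embedding_2d = reducer.fit_transform(graph_words)
56 |
57 | # HDBSCAN with the GraphsGPT-1W setting.
58 | clusterer = hdbscan.HDBSCAN(min_cluster_size=48, min_samples=64,
59 |                             cluster_selection_epsilon=0.25, alpha=1.0)
60 | labels = clusterer.fit_predict(embedding_2d)  # -1 marks noise points
61 | ```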
--------------------------------------------------------------------------------
/jupyter_notebooks/example_pipeline.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "source": "# Pipeline of GraphsGPT with Hugging Face Transformers",
6 | "metadata": {
7 | "collapsed": false
8 | },
9 | "id": "3e188af549c37dc3"
10 | },
11 | {
12 | "cell_type": "markdown",
13 | "source": "### Configurations",
14 | "metadata": {
15 | "collapsed": false
16 | },
17 | "id": "e5ecf87e00701167"
18 | },
19 | {
20 | "cell_type": "code",
21 | "source": [
22 | "import torch\n",
23 | "\n",
24 | "model_name_or_path = \"DaizeDong/GraphsGPT-8W\"\n",
25 | "smiles_file = \"../data/examples/zinc_example.txt\"\n",
26 | "\n",
27 | "batch_size = 1024\n",
28 | "max_batches = 4\n",
29 | "device = \"cuda\" if torch.cuda.is_available() else \"cpu\""
30 | ],
31 | "metadata": {
32 | "collapsed": false
33 | },
34 | "id": "ef4b44c94c086bd1",
35 | "outputs": [],
36 | "execution_count": null
37 | },
38 | {
39 | "cell_type": "markdown",
40 | "source": [
41 | "### Load SMILES"
42 | ],
43 | "metadata": {
44 | "collapsed": false
45 | },
46 | "id": "84357216224aed01"
47 | },
48 | {
49 | "cell_type": "code",
50 | "source": [
51 | "with open(smiles_file, \"r\", encoding=\"utf-8\") as f:\n",
52 | " smiles_list = f.readlines()\n",
53 | "smiles_list = [smiles.removesuffix(\"\\n\") for smiles in smiles_list]\n",
54 | "\n",
55 | "print(f\"Total SMILES loaded: {len(smiles_list)}\")\n",
56 | "for i in range(10):\n",
57 | " print(f\"Example SMILES {i}: {smiles_list[i]}\")"
58 | ],
59 | "metadata": {
60 | "collapsed": false
61 | },
62 | "id": "1c903203360b16",
63 | "outputs": [],
64 | "execution_count": null
65 | },
66 | {
67 | "cell_type": "markdown",
68 | "source": [
69 | "### Load Model & Tokenizer"
70 | ],
71 | "metadata": {
72 | "collapsed": false
73 | },
74 | "id": "4263d29cbbe85ea1"
75 | },
76 | {
77 | "cell_type": "code",
78 | "source": [
79 | "from models.graphsgpt.modeling_graphsgpt import GraphsGPTForCausalLM\n",
80 | "from data.tokenizer import GraphsGPTTokenizer\n",
81 | "\n",
82 | "model = GraphsGPTForCausalLM.from_pretrained(model_name_or_path)\n",
83 | "tokenizer = GraphsGPTTokenizer.from_pretrained(model_name_or_path)\n",
84 | "\n",
85 | "print(model.state_dict().keys())\n",
86 | "print(f\"Total parameters: {sum(x.numel() for x in model.parameters())}\")"
87 | ],
88 | "metadata": {
89 | "collapsed": false
90 | },
91 | "id": "eed2630d3a5bf75a",
92 | "outputs": [],
93 | "execution_count": null
94 | },
95 | {
96 | "cell_type": "markdown",
97 | "source": "### Encode SMILES into Fingerprint Embeddings (Graph Words)",
98 | "metadata": {
99 | "collapsed": false
100 | },
101 | "id": "6a43509f85c9ef0c"
102 | },
103 | {
104 | "cell_type": "code",
105 | "source": [
106 | "from utils.operations.operation_tensor import move_tensors_to_device\n",
107 | "from utils.operations.operation_list import split_list_with_yield\n",
108 | "\n",
109 | "batch_count = 0\n",
110 | "fingerprints_lists = []\n",
111 | "\n",
112 | "model.to(device)\n",
113 | "model.eval()\n",
114 | "with torch.no_grad():\n",
115 | " for batched_smiles in split_list_with_yield(smiles_list, batch_size):\n",
116 | " inputs = tokenizer.batch_encode(batched_smiles, return_tensors=\"pt\")\n",
117 | " move_tensors_to_device(inputs, device)\n",
118 | "\n",
119 | " fingerprint_tokens = model.encode_to_fingerprints(**inputs) # (batch_size, num_fingerprints, hidden_dim)\n",
120 | " fingerprints_lists.append(fingerprint_tokens)\n",
121 | "\n",
122 | " batch_count += 1\n",
123 | " if batch_count >= max_batches:\n",
124 | " break\n",
125 | "\n",
126 | "print(f\"Encoded {batch_count * batch_size} molecules in total\")"
127 | ],
128 | "metadata": {
129 | "collapsed": false
130 | },
131 | "id": "7a04d04186c8a9c",
132 | "outputs": [],
133 | "execution_count": null
134 | },
135 | {
136 | "cell_type": "markdown",
137 | "source": "### Recover Molecule Sequences through Generation",
138 | "metadata": {
139 | "collapsed": false
140 | },
141 | "id": "3ec3683b0fa19abb"
142 | },
143 | {
144 | "cell_type": "code",
145 | "source": [
146 | "all_results = []\n",
147 | "\n",
148 | "for fingerprints in fingerprints_lists:\n",
149 | " generation_result = model.generate_from_fingerprints(\n",
150 | " fingerprint_tokens=fingerprints,\n",
151 | " bond_dict=tokenizer.bond_dict,\n",
152 | " strict_generation=True,\n",
153 | " max_atoms=None,\n",
154 | " similarity_threshold=0.5,\n",
155 | " check_first_node=True,\n",
156 | " check_atom_valence=False,\n",
157 | " fix_aromatic_bond=False,\n",
158 | " use_cache=False,\n",
159 | "        save_failed=True,  # save the partially generated result even if the full generation fails\n",
160 | " show_progress=True,\n",
161 | " verbose=True,\n",
162 | " )\n",
163 | " all_results.extend(generation_result)\n",
164 | "\n",
165 | "print(\"Done.\")\n",
166 | "print(f\"#### Generated {len(all_results)} molecules\")"
167 | ],
168 | "metadata": {
169 | "collapsed": false
170 | },
171 | "id": "8b5324f8eb8e35de",
172 | "outputs": [],
173 | "execution_count": null
174 | },
175 | {
176 | "cell_type": "markdown",
177 | "source": [
178 | "### Decode Sequences back to SMILES"
179 | ],
180 | "metadata": {
181 | "collapsed": false
182 | },
183 | "id": "921df1d54d1bb705"
184 | },
185 | {
186 | "cell_type": "code",
187 | "source": [
188 | "from rdkit.Chem import Draw\n",
189 | "\n",
190 | "\n",
191 | "def show_mol_png(mol, size=(512, 512)):\n",
192 | " img = Draw.MolToImage(mol, size=size)\n",
193 | " img.show()\n",
194 | " img.close()\n",
195 | "\n",
196 | "\n",
197 | "decoded_mols = []\n",
198 | "decoded_smiles = []\n",
199 | "\n",
200 | "for result in all_results:\n",
201 | " if result is not None:\n",
202 | " mol, smiles = tokenizer.decode(result)\n",
203 | " decoded_mols.append(mol)\n",
204 | " decoded_smiles.append(smiles)\n",
205 | " else:\n",
206 | " decoded_mols.append(None)\n",
207 | " decoded_smiles.append(None)\n",
208 | "\n",
209 | "# visualize the first 10 results\n",
210 | "for i in range(10):\n",
211 | " print(f\"Original SMILES {i}: {smiles_list[i]}\")\n",
212 | " print(f\"Decoded SMILES {i}: {decoded_smiles[i]}\")\n",
213 | " show_mol_png(decoded_mols[i])"
214 | ],
215 | "metadata": {
216 | "collapsed": false
217 | },
218 | "id": "ef67a4e812de11e3",
219 | "outputs": [],
220 | "execution_count": null
221 | },
222 | {
223 | "cell_type": "code",
224 | "source": [],
225 | "metadata": {
226 | "collapsed": false
227 | },
228 | "id": "c86d3fce01319503",
229 | "outputs": [],
230 | "execution_count": null
231 | }
232 | ],
233 | "metadata": {
234 | "kernelspec": {
235 | "display_name": "Python 3",
236 | "language": "python",
237 | "name": "python3"
238 | },
239 | "language_info": {
240 | "codemirror_mode": {
241 | "name": "ipython",
242 | "version": 2
243 | },
244 | "file_extension": ".py",
245 | "mimetype": "text/x-python",
246 | "name": "python",
247 | "nbconvert_exporter": "python",
248 | "pygments_lexer": "ipython2",
249 | "version": "2.7.6"
250 | }
251 | },
252 | "nbformat": 4,
253 | "nbformat_minor": 5
254 | }
255 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/A4Bio/GraphsGPT/e71647c5a9e19322b4166bb7922c604778a71e1e/models/__init__.py
--------------------------------------------------------------------------------
/models/graphsgpt/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/A4Bio/GraphsGPT/e71647c5a9e19322b4166bb7922c604778a71e1e/models/graphsgpt/__init__.py
--------------------------------------------------------------------------------
/models/graphsgpt/configuration_graphsgpt.py:
--------------------------------------------------------------------------------
1 | """GraphsGPT Model Configuration"""
2 |
3 | from transformers.configuration_utils import PretrainedConfig
4 | from transformers.utils import logging
5 |
6 | logger = logging.get_logger(__name__)
7 |
8 |
9 | class GraphsGPTConfig(PretrainedConfig):
10 | model_type = "graphs_gpt"
11 |
12 | def __init__(
13 | self,
14 | atom_vocab_size=118, # number of atoms
15 | bond_vocab_size=92, # number of bonds
16 | pad_token_id=0,
17 | share_embeddings=False,
18 | # --------------------- #
19 | node_loss_weight=1.0,
20 | connection_loss_weight=1.0,
21 | connection_loss_type="contrastive", # classification contrastive
22 | adaptive_position_length=False,
23 | # --------------------- #
24 | num_fingerprints=8,
25 | position_feature_size=128,
26 | hidden_size=512,
27 | intermediate_size=2048,
28 | num_hidden_layers=8,
29 | num_attention_heads=8,
30 | hidden_act="silu",
31 | # --------------------- #
32 | initializer_method="hidden", # manual hidden hidden-layer
33 | initializer_range=0.02, # useful only when "initializer_method" is "manual"
34 | rms_norm_eps=1e-6,
35 | gradient_checkpointing=False,
36 | **kwargs,
37 | ):
38 | self.atom_vocab_size = atom_vocab_size
39 | self.bond_vocab_size = bond_vocab_size
40 | self.share_embeddings = share_embeddings
41 |
42 | self.node_loss_weight = node_loss_weight
43 | self.connection_loss_weight = connection_loss_weight
44 | self.connection_loss_type = connection_loss_type
45 | self.adaptive_position_length = adaptive_position_length
46 |
47 | self.num_fingerprints = num_fingerprints
48 | self.position_feature_size = position_feature_size
49 | self.hidden_size = hidden_size
50 | self.intermediate_size = intermediate_size
51 | self.num_hidden_layers = num_hidden_layers
52 | self.num_attention_heads = num_attention_heads
53 | self.hidden_act = hidden_act
54 |
55 | self.initializer_method = initializer_method
56 | self.initializer_range = initializer_range
57 | self.rms_norm_eps = rms_norm_eps
58 | self.gradient_checkpointing = gradient_checkpointing
59 |
60 | super().__init__(
61 | pad_token_id=pad_token_id,
62 | tie_word_embeddings=False,
63 | # **kwargs,
64 | )
65 |
--------------------------------------------------------------------------------
/models/graphsgpt/generation_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from typing import Tuple
4 |
5 | from utils.algorithms import FindCycles
6 |
7 | VALENCE_LIMIT = {
8 | # ATOMIC_NUM: MAX_VALENCE
9 | 5: 3, # B
10 | 6: 4, # C
11 | 7: 3, # N
12 |     8: 2,  # O
13 | 9: 1, # F
14 | 14: 4, # Si
15 | 16: 6, # S
16 | 17: 1, # Cl
17 | 35: 1, # Br
18 | 53: 1, # I
19 | }
20 |
21 |
22 | def get_atom_ids_from_bond_id(inverse_bond_dict, bond_id: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
23 | atom_ids = inverse_bond_dict[bond_id.item()][:2]
24 | return (
25 | torch.tensor(atom_ids[0], dtype=bond_id.dtype, device=bond_id.device).unsqueeze(0),
26 | torch.tensor(atom_ids[1], dtype=bond_id.dtype, device=bond_id.device).unsqueeze(0)
27 | )
28 |
29 |
30 | def get_another_atom_id_from_existing_bond(inverse_bond_dict, bond_id: torch.Tensor, connected_atom_id_1: torch.Tensor) -> torch.Tensor:
31 | # bond_id: id in the bond_dict
32 | # connected_atom_id_1: atomic_num of the existing connected atom
33 | info = inverse_bond_dict[bond_id.item()] # (min_atomic_num, max_atomic_num, bond_type)
34 | if not (connected_atom_id_1.item() == info[0] or connected_atom_id_1.item() == info[1]):
35 | raise ValueError
36 | elif connected_atom_id_1.item() == info[0]:
37 | return torch.tensor(info[1], dtype=bond_id.dtype, device=bond_id.device).unsqueeze(0)
38 | else:
39 | return torch.tensor(info[0], dtype=bond_id.dtype, device=bond_id.device).unsqueeze(0)
40 |
41 |
42 | def check_bond_connectivity_begin(inverse_bond_dict, bond_id, atom_id):
43 | # bond_id: id in the bond_dict
44 | # atom_id: atomic_num
45 | info = inverse_bond_dict[bond_id] # (min_atomic_num, max_atomic_num, bond_type)
46 | if atom_id == info[0] or atom_id == info[1]: # "+1" for getting the atomic num
47 | return True
48 | else:
49 | return False
50 |
51 |
52 | def check_bond_connectivity_both_sides(inverse_bond_dict, bond_id, atom_id1, atom_id2):
53 | # bond_id: id in the bond_dict
54 | # atom_id: atomic_num
55 | if atom_id1 <= atom_id2:
56 | min_atomic_num = atom_id1
57 | max_atomic_num = atom_id2
58 | else:
59 | min_atomic_num = atom_id2
60 | max_atomic_num = atom_id1
61 |
62 | info = inverse_bond_dict[bond_id] # (min_atomic_num, max_atomic_num, bond_type)
63 | if min_atomic_num == info[0] and max_atomic_num == info[1]:
64 | return True
65 | else:
66 | return False
67 |
68 |
69 | def check_bond_in_graph(graph_position_ids_1, graph_position_ids_2, connection_1, connection_2):
70 | connection_1_in_1 = (graph_position_ids_1 == connection_1)
71 | connection_1_in_2 = (graph_position_ids_2 == connection_1)
72 | connection_2_in_1 = (graph_position_ids_1 == connection_2)
73 | connection_2_in_2 = (graph_position_ids_2 == connection_2)
74 | if torch.any(connection_1_in_1 & connection_2_in_2) or torch.any(connection_1_in_2 & connection_2_in_1):
75 | return True
76 | else:
77 | return False
78 |
79 |
80 | def get_valence(input_ids, graph_position_ids_1, graph_position_ids_2, connection_id, inverse_bond_dict, bond_mask):
81 | """
82 | Return the total valence for an atom by checking its connected bonds
83 | For each connected aromatic bond with 1.5 valence, the real valence for the atom can be adjusted by a value of ±0.5
84 | """
85 | bond_ids_in_dict = input_ids[bond_mask] - 1
86 | bond_connections_1 = graph_position_ids_1[bond_mask]
87 | bond_connections_2 = graph_position_ids_2[bond_mask]
88 | bond_connections_mask = (bond_connections_1 == connection_id) | (bond_connections_2 == connection_id)
89 | connected_bond_ids = bond_ids_in_dict[bond_connections_mask].tolist()
90 |
91 | total_valence = 0
92 | valence_offset = 0
93 | for bond_type in connected_bond_ids:
94 | this_valence = inverse_bond_dict[bond_type][2]
95 | total_valence += this_valence
96 | if this_valence == 1.5: # aromatic bond
97 | valence_offset += 0.5
98 |
99 | if int(valence_offset / 0.5) % 2 == 0:
100 | total_valence += int(valence_offset / 0.5) % 2
101 | valence_offset = 0
102 | else:
103 | valence_offset = 0.5
104 |
105 | return total_valence, valence_offset
106 |
107 |
108 | def fix_dissociative_aromatic_bond(input_ids, graph_position_ids_1, graph_position_ids_2, identifier_ids, inverse_bond_dict, bond_dict):
109 | """
110 | Replace invalid aromatic bonds with corresponding single/double/triple bonds.
111 | Invalid aromatic bonds: non-ring & incomplete ring
112 | """
113 | atom_num = torch.sum(identifier_ids).item()
114 | bond_mask = ~identifier_ids
115 |
116 | indices_all_positions = torch.arange(input_ids.shape[0], device=input_ids.device, dtype=torch.int64)
117 | indices_bonds = indices_all_positions[bond_mask]
118 |
119 | # Create connection -> bond_type mapping
120 | bond_ids_in_dict = (input_ids.clone()[bond_mask] - 1).tolist()
121 |
122 | bond_connections_1 = graph_position_ids_1[bond_mask].tolist()
123 | bond_connections_2 = graph_position_ids_2[bond_mask].tolist()
124 | bond_connections_all = [
125 | (bond_connections_1[i], bond_connections_2[i]) if bond_connections_1[i] <= bond_connections_2[i] else (bond_connections_2[i], bond_connections_1[i]) # sort node id
126 | for i in range(len(bond_connections_1))
127 | ]
128 |
129 | connection_bond_mapping = {}
130 | connection_index_mapping = {} # record the positions for each connection
131 |
132 | for i, connection in enumerate(bond_connections_all):
133 | connection_bond_mapping[connection] = inverse_bond_dict[bond_ids_in_dict[i]]
134 | connection_index_mapping[connection] = indices_bonds[i]
135 |
136 | # Find cycles in the molecule
137 | adjacency_matrix = np.zeros((atom_num, atom_num), dtype=np.int8)
138 | for connection in bond_connections_all:
139 | adjacency_matrix[connection[0], connection[1]] = 1
140 | adjacency_matrix[connection[1], connection[0]] = 1
141 | cycle_finder = FindCycles(adjacency_matrix)
142 | cycles = cycle_finder.find_cycles()
143 |
144 | # Check the bonds in each cycle
145 | # If there is any cycle with all aromatic bonds, then all bonds in it are marked as valid
146 | valid_aromatic_connections = set()
147 |
148 | for cycle in cycles:
149 | is_aromatic = []
150 | not_aromatic = []
151 |
152 | # Check the validity
153 | for node_id in range(len(cycle)):
154 | if node_id < len(cycle) - 1:
155 | connection = (cycle[node_id], cycle[node_id + 1]) if cycle[node_id] <= cycle[node_id + 1] else (cycle[node_id + 1], cycle[node_id])
156 | else:
157 | connection = (cycle[node_id], cycle[0]) if cycle[node_id] <= cycle[0] else (cycle[0], cycle[node_id])
158 |
159 | begin_atomic_num, end_atomic_num, bond_type = connection_bond_mapping[connection]
160 |
161 | if bond_type == 1.5:
162 | is_aromatic.append(connection)
163 | else:
164 | not_aromatic.append(connection)
165 |
166 | if len(not_aromatic) == 0: # all bonds are aromatic
167 | for connection in is_aromatic:
168 | valid_aromatic_connections.add(connection)
169 |
170 | # Change invalid aromatic bonds into single/double/triple bonds
171 | for connection in bond_connections_all:
172 | begin_atomic_num, end_atomic_num, bond_type = connection_bond_mapping[connection]
173 | if bond_type == 1.5 and connection not in valid_aromatic_connections: # invalid aromatic
174 | index = connection_index_mapping[connection]
175 | if (begin_atomic_num, end_atomic_num, 1.0) in bond_dict: # single
176 | new_bond_id = bond_dict[(begin_atomic_num, end_atomic_num, 1.0)] + 1
177 | elif (begin_atomic_num, end_atomic_num, 2.0) in bond_dict: # double
178 | new_bond_id = bond_dict[(begin_atomic_num, end_atomic_num, 2.0)] + 1
179 | elif (begin_atomic_num, end_atomic_num, 3.0) in bond_dict: # triple
180 | new_bond_id = bond_dict[(begin_atomic_num, end_atomic_num, 3.0)] + 1
181 | else: # this bond is incorrigible!
182 | continue
183 | input_ids[index] = new_bond_id
184 |
185 | return input_ids, graph_position_ids_1, graph_position_ids_2, identifier_ids
186 |
--------------------------------------------------------------------------------
/models/graphsgpt/orf.py:
--------------------------------------------------------------------------------
1 | """
2 | Modified from https://github.com/lucidrains/performer-pytorch/blob/main/performer_pytorch/performer_pytorch.py
3 | """
4 |
5 | import torch
6 | import torch.nn.functional as F
7 |
8 |
9 | @torch.no_grad()
10 | def orthogonal_matrix_chunk(cols, device=None):
11 | unstructured_block = torch.randn((cols, cols), device=device)
12 | q, r = torch.linalg.qr(unstructured_block, mode='reduced')
13 | return q.transpose(0, 1) # [cols, cols]
14 |
15 |
16 | @torch.no_grad()
17 | def gaussian_orthogonal_random_matrix(nb_columns, nb_rows, random_shuffle=False, device=None, dtype=torch.float32):
18 | """create 2D Gaussian orthogonal matrix"""
19 | nb_full_blocks = int(nb_rows / nb_columns)
20 |
21 | block_list = []
22 |
23 | for _ in range(nb_full_blocks):
24 | q = orthogonal_matrix_chunk(nb_columns, device=device)
25 | block_list.append(q)
26 |
27 | remaining_rows = nb_rows - nb_full_blocks * nb_columns
28 | if remaining_rows > 0:
29 | # q = orthogonal_matrix_chunk_batched(nb_samples, nb_columns, device=device)
30 | block_list.append(torch.zeros((nb_columns, remaining_rows), device=device))
31 |
32 | final_matrix = torch.cat(block_list, dim=1).type(dtype)
33 | final_matrix = F.normalize(final_matrix, p=2, dim=1)
34 |
35 | if random_shuffle:
36 | _, indices = torch.rand((final_matrix.shape[1],), device=device).sort(dim=0)
37 | indices = indices.unsqueeze(0).expand(final_matrix.shape)
38 | final_matrix = torch.gather(final_matrix, dim=1, index=indices)
39 |
40 | return final_matrix # (nb_columns, nb_rows)
41 |
42 |
43 | @torch.no_grad()
44 | def orthogonal_matrix_chunk_batched(bsz, cols, device=None):
45 | unstructured_block = torch.randn((bsz, cols, cols), device=device)
46 | q, r = torch.linalg.qr(unstructured_block, mode='reduced')
47 | return q.transpose(1, 2) # [bsz, cols, cols]
48 |
49 |
50 | @torch.no_grad()
51 | def gaussian_orthogonal_random_matrix_batched(nb_samples, nb_columns, nb_rows, random_shuffle=False, device=None, dtype=torch.float32):
52 | """create 2D Gaussian orthogonal matrix"""
53 | nb_full_blocks = int(nb_rows / nb_columns)
54 |
55 | block_list = []
56 |
57 | for _ in range(nb_full_blocks):
58 | q = orthogonal_matrix_chunk_batched(nb_samples, nb_columns, device=device)
59 | block_list.append(q)
60 |
61 | remaining_rows = nb_rows - nb_full_blocks * nb_columns
62 | if remaining_rows > 0:
63 | # q = orthogonal_matrix_chunk_batched(nb_samples, nb_columns, device=device)
64 | block_list.append(torch.zeros((nb_samples, nb_columns, remaining_rows), device=device))
65 |
66 | final_matrix = torch.cat(block_list, dim=2).type(dtype)
67 | final_matrix = F.normalize(final_matrix, p=2, dim=2)
68 |
69 | if random_shuffle:
70 | _, indices = torch.rand((final_matrix.shape[0], final_matrix.shape[2]), device=device).sort(dim=1)
71 | indices = indices.unsqueeze(1).expand(final_matrix.shape)
72 | final_matrix = torch.gather(final_matrix, dim=2, index=indices)
73 |
74 | return final_matrix # (nb_samples, nb_columns, nb_rows)
75 |
76 |
77 | if __name__ == "__main__":
78 | gaussian_orthogonal_random_matrix(37, 128)
79 | gaussian_orthogonal_random_matrix_batched(256, 37, 128)
80 |
--------------------------------------------------------------------------------
/models/graphsgpt_cond_gen/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/A4Bio/GraphsGPT/e71647c5a9e19322b4166bb7922c604778a71e1e/models/graphsgpt_cond_gen/__init__.py
--------------------------------------------------------------------------------
/models/graphsgpt_cond_gen/configuration_graphsgpt_cond_gen.py:
--------------------------------------------------------------------------------
1 | """GraphsGPT Conditioned Model Configuration"""
2 |
3 | from transformers.configuration_utils import PretrainedConfig
4 | from transformers.utils import logging
5 |
6 | logger = logging.get_logger(__name__)
7 |
8 |
9 | class GraphsGPTConditionedConfig(PretrainedConfig):
10 | model_type = "graphs_gpt_conditioned"
11 |
12 | def __init__(
13 | self,
14 | atom_vocab_size=118, # number of atoms
15 | bond_vocab_size=92, # number of bonds
16 | pad_token_id=0,
17 | share_embeddings=False,
18 | # --------------------- #
19 | node_loss_weight=1.0,
20 | connection_loss_weight=1.0,
21 | connection_loss_type="contrastive", # classification contrastive
22 | adaptive_position_length=False,
23 | # --------------------- #
24 | num_properties=3, # 🔍
25 | num_fingerprints=8,
26 | position_feature_size=128,
27 | hidden_size=512,
28 | intermediate_size=2048,
29 | num_hidden_layers=8,
30 | num_attention_heads=8,
31 | hidden_act="silu",
32 | # --------------------- #
33 | initializer_method="hidden", # manual hidden hidden-layer
34 | initializer_range=0.02, # useful only when "initializer_method" is "manual"
35 | rms_norm_eps=1e-6,
36 | gradient_checkpointing=False,
37 | **kwargs,
38 | ):
39 | self.atom_vocab_size = atom_vocab_size
40 | self.bond_vocab_size = bond_vocab_size
41 | self.share_embeddings = share_embeddings
42 |
43 | self.node_loss_weight = node_loss_weight
44 | self.connection_loss_weight = connection_loss_weight
45 | self.connection_loss_type = connection_loss_type
46 | self.adaptive_position_length = adaptive_position_length
47 |
48 | self.num_properties = num_properties # 🔍
49 | self.num_fingerprints = num_fingerprints
50 | self.position_feature_size = position_feature_size
51 | self.hidden_size = hidden_size
52 | self.intermediate_size = intermediate_size
53 | self.num_hidden_layers = num_hidden_layers
54 | self.num_attention_heads = num_attention_heads
55 | self.hidden_act = hidden_act
56 |
57 | self.initializer_method = initializer_method
58 | self.initializer_range = initializer_range
59 | self.rms_norm_eps = rms_norm_eps
60 | self.gradient_checkpointing = gradient_checkpointing
61 |
62 | super().__init__(
63 | pad_token_id=pad_token_id,
64 | tie_word_embeddings=False,
65 | # **kwargs,
66 | )
67 |
--------------------------------------------------------------------------------
/moses/NP_Score/README:
--------------------------------------------------------------------------------
1 | RDKit-based implementation of the method described in:
2 |
3 | Natural Product-likeness Score and Its Application for Prioritization of Compound Libraries
4 | Peter Ertl, Silvio Roggo, and Ansgar Schuffenhauer
5 | Journal of Chemical Information and Modeling, 48, 68-74 (2008)
6 | http://pubs.acs.org/doi/abs/10.1021/ci700286x
7 |
8 | Contribution from Peter Ertl
9 |
10 |
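11 | Illustrative usage (a minimal sketch, assuming the moses package is importable):
12 |
13 |     from rdkit import Chem
14 |     from moses.NP_Score import npscorer
15 |
16 |     fscore = npscorer.readNPModel()
17 |     mol = Chem.MolFromSmiles("c1ccccc1O")  # phenol, as an example input
18 |     print(npscorer.scoreMol(mol, fscore))  # NP-likeness score in roughly -5..5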
--------------------------------------------------------------------------------
/moses/NP_Score/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/A4Bio/GraphsGPT/e71647c5a9e19322b4166bb7922c604778a71e1e/moses/NP_Score/__init__.py
--------------------------------------------------------------------------------
/moses/NP_Score/npscorer.py:
--------------------------------------------------------------------------------
1 | #
2 | # calculation of natural product-likeness as described in:
3 | #
4 | # Natural Product-likeness Score and Its Application for Prioritization of
5 | # Compound Libraries
6 | # Peter Ertl, Silvio Roggo, and Ansgar Schuffenhauer
7 | # Journal of Chemical Information and Modeling, 48, 68-74 (2008)
8 | # http://pubs.acs.org/doi/abs/10.1021/ci700286x
9 | #
10 | # for the training of this model only openly available data have been used
11 | # ~50,000 natural products collected from various open databases
12 | # ~1 million drug-like molecules from ZINC as a "non-NP background"
13 | #
14 | # peter ertl, august 2015
15 | #
16 |
17 | from __future__ import print_function
18 |
19 | import os.path
20 | from collections import namedtuple
21 |
22 | import gzip
23 | import math
24 | import pickle
25 | import sys
26 | from rdkit import Chem
27 | from rdkit.Chem import rdMolDescriptors
28 |
29 | _fscores = None
30 |
31 |
32 | def readNPModel(filename=os.path.join(os.path.dirname(__file__),
33 | 'publicnp.model.gz')):
34 | """Reads and returns the scoring model,
35 | which has to be passed to the scoring functions."""
36 | global _fscores
37 | _fscores = pickle.load(gzip.open(filename))
38 | return _fscores
39 |
40 |
41 | def scoreMolWConfidence(mol, fscore):
42 | """Next to the NP Likeness Score, this function outputs a confidence value
43 |     between 0..1 that describes how many fragments of the tested molecule
44 | were found in the model data set (1: all fragments were found).
45 |
46 | Returns namedtuple NPLikeness(nplikeness, confidence)"""
47 |
48 | if mol is None:
49 | raise ValueError('invalid molecule')
50 | fp = rdMolDescriptors.GetMorganFingerprint(mol, 2)
51 | bits = fp.GetNonzeroElements()
52 |
53 | # calculating the score
54 | score = 0.0
55 | bits_found = 0
56 | for bit in bits:
57 | if bit in fscore:
58 | bits_found += 1
59 | score += fscore[bit]
60 |
61 | score /= float(mol.GetNumAtoms())
62 | confidence = float(bits_found / len(bits))
63 |
64 | # preventing score explosion for exotic molecules
65 | if score > 4:
66 | score = 4. + math.log10(score - 4. + 1.)
67 | elif score < -4:
68 | score = -4. - math.log10(-4. - score + 1.)
69 | NPLikeness = namedtuple("NPLikeness", "nplikeness,confidence")
70 | return NPLikeness(score, confidence)
71 |
72 |
73 | def scoreMol(mol, fscore=None):
74 | """Calculates the Natural Product Likeness of a molecule.
75 |
76 | Returns the score as float in the range -5..5."""
77 | if _fscores is None:
78 | readNPModel()
79 | fscore = fscore or _fscores
80 | return scoreMolWConfidence(mol, fscore).nplikeness
81 |
82 |
83 | def processMols(fscore, suppl):
84 | print("calculating ...", file=sys.stderr)
85 | n = 0
86 | for m in suppl:
87 | if m is None:
88 | continue
89 |
90 | n += 1
91 | score = "%.3f" % scoreMol(m, fscore)
92 |
93 | smiles = Chem.MolToSmiles(m, True)
94 | name = m.GetProp('_Name')
95 | print(smiles + "\t" + name + "\t" + score)
96 |
97 | print("finished, " + str(n) + " molecules processed", file=sys.stderr)
98 |
99 |
100 | if __name__ == '__main__':
101 | fscore = readNPModel() # fills fscore
102 |
103 | suppl = Chem.SmilesMolSupplier(
104 | sys.argv[1], smilesColumn=0, nameColumn=1, titleLine=False
105 | )
106 | processMols(fscore, suppl)
107 |
108 | #
109 | # Copyright (c) 2015, Novartis Institutes for BioMedical Research Inc.
110 | # All rights reserved.
111 | #
112 | # Redistribution and use in source and binary forms, with or without
113 | # modification, are permitted provided that the following conditions are
114 | # met:
115 | #
116 | # * Redistributions of source code must retain the above copyright
117 | # notice, this list of conditions and the following disclaimer.
118 | # * Redistributions in binary form must reproduce the above
119 | # copyright notice, this list of conditions and the following
120 | # disclaimer in the documentation and/or other materials provided
121 | # with the distribution.
122 | # * Neither the name of Novartis Institutes for BioMedical Research Inc.
123 | # nor the names of its contributors may be used to endorse or promote
124 | # products derived from this software without specific prior written
125 | # permission.
126 | #
127 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
128 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
129 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
130 | # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
131 | # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
132 | # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
133 | # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
134 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
135 | # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
136 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
137 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
138 | #
139 |
--------------------------------------------------------------------------------
/moses/NP_Score/publicnp.model.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/A4Bio/GraphsGPT/e71647c5a9e19322b4166bb7922c604778a71e1e/moses/NP_Score/publicnp.model.gz
--------------------------------------------------------------------------------
/moses/README.md:
--------------------------------------------------------------------------------
1 | This is a modified version of [Molecular Sets (MOSES)](https://github.com/molecularsets/moses), adapted for portability and a minimal implementation.
2 |
3 | Changes include:
4 |
5 | 1. Cleaned the code for training & evaluating the built-in models, making it compatible with the latest versions of Python (>=3.10).
6 | 2. Fixed the [compatibility issue](https://github.com/molecularsets/moses/pull/111) with the latest versions of Pandas (>=2.2.2).
7 |
8 |
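9 | A minimal usage sketch (the SMILES below are placeholders; loading the default splits and computing FCD may take a while):
10 |
11 | ```python
12 | from moses import get_all_metrics
13 |
14 | gen_smiles = ["CCO", "c1ccccc1", "CC(=O)O"]  # placeholder generated molecules
15 | metrics = get_all_metrics(gen_smiles, k=[2], n_jobs=1, device="cpu")
16 | print(metrics["valid"], metrics["unique@2"])
17 | ```
18 |
19 | Precomputed `ptest` / `ptest_scaffolds` statistics can be passed to `get_all_metrics` to speed up repeated evaluations.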
--------------------------------------------------------------------------------
/moses/SA_Score/README:
--------------------------------------------------------------------------------
1 | RDKit-based implementation of the method described in:
2 |
3 | Estimation of Synthetic Accessibility Score of Drug-like Molecules based on Molecular Complexity and Fragment Contributions
4 | Peter Ertl and Ansgar Schuffenhauer
5 | Journal of Cheminformatics 1:8 (2009)
6 | http://www.jcheminf.com/content/1/1/8
7 |
8 | Contribution from Peter Ertl and Greg Landrum
9 |
10 |
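11 | Illustrative usage (a minimal sketch, assuming the moses package is importable):
12 |
13 |     from rdkit import Chem
14 |     from moses.SA_Score import sascorer
15 |
16 |     mol = Chem.MolFromSmiles("CC(=O)Oc1ccccc1C(=O)O")  # aspirin, as an example input
17 |     print(sascorer.calculateScore(mol))  # SA score on a 1 (easy) to 10 (hard) scale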
--------------------------------------------------------------------------------
/moses/SA_Score/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/A4Bio/GraphsGPT/e71647c5a9e19322b4166bb7922c604778a71e1e/moses/SA_Score/__init__.py
--------------------------------------------------------------------------------
/moses/SA_Score/fpscores.pkl.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/A4Bio/GraphsGPT/e71647c5a9e19322b4166bb7922c604778a71e1e/moses/SA_Score/fpscores.pkl.gz
--------------------------------------------------------------------------------
/moses/SA_Score/sascorer.py:
--------------------------------------------------------------------------------
1 | #
2 | # calculation of synthetic accessibility score as described in:
3 | #
4 | # Estimation of Synthetic Accessibility Score of Drug-like Molecules based on
5 | # Molecular Complexity and Fragment Contributions
6 | # Peter Ertl and Ansgar Schuffenhauer
7 | # Journal of Cheminformatics 1:8 (2009)
8 | # http://www.jcheminf.com/content/1/1/8
9 | #
10 | # several small modifications to the original paper are included
11 | # particularly slightly different formula for macrocyclic penalty
12 | # and taking into account also molecule symmetry (fingerprint density)
13 | #
14 | # for a set of 10k diverse molecules the agreement between the original method
15 | # as implemented in PipelinePilot and this implementation is r2 = 0.97
16 | #
17 | # peter ertl & greg landrum, september 2013
18 | #
19 | from __future__ import print_function
20 |
21 | import os.path as op
22 |
23 | import math
24 | import pickle
25 | from rdkit import Chem
26 | from rdkit.Chem import rdMolDescriptors
27 | from rdkit.six import iteritems
28 |
29 | _fscores = None
30 |
31 |
32 | def readFragmentScores(name='fpscores'):
33 | import gzip
34 | global _fscores
35 | # generate the full path filename:
36 | if name == "fpscores":
37 | name = op.join(op.dirname(__file__), name)
38 | _fscores = pickle.load(gzip.open('%s.pkl.gz' % name))
39 | outDict = {}
40 | for i in _fscores:
41 | for j in range(1, len(i)):
42 | outDict[i[j]] = float(i[0])
43 | _fscores = outDict
44 |
45 |
46 | def numBridgeheadsAndSpiro(mol, ri=None):
47 | nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol)
48 | nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol)
49 | return nBridgehead, nSpiro
50 |
51 |
52 | def calculateScore(m):
53 | if _fscores is None:
54 | readFragmentScores()
55 |
56 | # fragment score
57 | fp = rdMolDescriptors.GetMorganFingerprint(
58 | m, 2 # <- 2 is the *radius* of the circular fingerprint
59 | )
60 | fps = fp.GetNonzeroElements()
61 | score1 = 0.
62 | nf = 0
63 | for bitId, v in iteritems(fps):
64 | nf += v
65 | sfp = bitId
66 | score1 += _fscores.get(sfp, -4) * v
67 | score1 /= nf
68 |
69 | # features score
70 | nAtoms = m.GetNumAtoms()
71 | nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True))
72 | ri = m.GetRingInfo()
73 | nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri)
74 | nMacrocycles = 0
75 | for x in ri.AtomRings():
76 | if len(x) > 8:
77 | nMacrocycles += 1
78 |
79 | sizePenalty = nAtoms ** 1.005 - nAtoms
80 | stereoPenalty = math.log10(nChiralCenters + 1)
81 | spiroPenalty = math.log10(nSpiro + 1)
82 | bridgePenalty = math.log10(nBridgeheads + 1)
83 | macrocyclePenalty = 0.
84 | # ---------------------------------------
85 | # This differs from the paper, which defines:
86 | # macrocyclePenalty = math.log10(nMacrocycles+1)
87 | # This form generates better results when 2 or more macrocycles are present
88 | if nMacrocycles > 0:
89 | macrocyclePenalty = math.log10(2)
90 |
91 | score2 = (0. - sizePenalty - stereoPenalty -
92 | spiroPenalty - bridgePenalty - macrocyclePenalty)
93 |
94 | # correction for the fingerprint density
95 | # not in the original publication, added in version 1.1
96 | # to make highly symmetrical molecules easier to synthesise
97 | score3 = 0.
98 | if nAtoms > len(fps):
99 | score3 = math.log(float(nAtoms) / len(fps)) * .5
100 |
101 | sascore = score1 + score2 + score3
102 |
103 | # need to transform "raw" value into scale between 1 and 10
104 | min = -4.0
105 | max = 2.5
106 | sascore = 11. - (sascore - min + 1) / (max - min) * 9.
107 | # smooth the 10-end
108 | if sascore > 8.:
109 | sascore = 8. + math.log(sascore + 1. - 9.)
110 | if sascore > 10.:
111 | sascore = 10.0
112 | elif sascore < 1.:
113 | sascore = 1.0
114 |
115 | return sascore
116 |
117 |
118 | def processMols(mols):
119 | print('smiles\tName\tsa_score')
120 | for m in mols:
121 | if m is None:
122 | continue
123 |
124 | s = calculateScore(m)
125 |
126 | smiles = Chem.MolToSmiles(m)
127 | print(smiles + "\t" + m.GetProp('_Name') + "\t%3f" % s)
128 |
129 |
130 | if __name__ == '__main__':
131 | import sys
132 | import time
133 |
134 | t1 = time.time()
135 | readFragmentScores("fpscores")
136 | t2 = time.time()
137 |
138 | suppl = Chem.SmilesMolSupplier(sys.argv[1])
139 | t3 = time.time()
140 | processMols(suppl)
141 | t4 = time.time()
142 |
143 | print('Reading took %.2f seconds. Calculating took %.2f seconds' % (
144 | (t2 - t1), (t4 - t3)),
145 | file=sys.stderr)
146 |
147 | #
148 | # Copyright (c) 2013, Novartis Institutes for BioMedical Research Inc.
149 | # All rights reserved.
150 | #
151 | # Redistribution and use in source and binary forms, with or without
152 | # modification, are permitted provided that the following conditions are
153 | # met:
154 | #
155 | # * Redistributions of source code must retain the above copyright
156 | # notice, this list of conditions and the following disclaimer.
157 | # * Redistributions in binary form must reproduce the above
158 | # copyright notice, this list of conditions and the following
159 | # disclaimer in the documentation and/or other materials provided
160 | # with the distribution.
161 | # * Neither the name of Novartis Institutes for BioMedical Research Inc.
162 | # nor the names of its contributors may be used to endorse or promote
163 | # products derived from this software without specific prior written
164 | # permission.
165 | #
166 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
167 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
168 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
169 | # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
170 | # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
171 | # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
172 | # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
173 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
174 | # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
175 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
176 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
177 | #
178 |
--------------------------------------------------------------------------------
/moses/__init__.py:
--------------------------------------------------------------------------------
1 | from .dataset import get_dataset, get_statistics
2 | from .metrics import get_all_metrics
3 |
4 | __version__ = '0.3.1'
5 | __all__ = ["get_dataset", "get_statistics", "get_all_metrics"]
6 |
--------------------------------------------------------------------------------
/moses/data/test.csv.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/A4Bio/GraphsGPT/e71647c5a9e19322b4166bb7922c604778a71e1e/moses/data/test.csv.gz
--------------------------------------------------------------------------------
/moses/data/test_scaffolds.csv.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/A4Bio/GraphsGPT/e71647c5a9e19322b4166bb7922c604778a71e1e/moses/data/test_scaffolds.csv.gz
--------------------------------------------------------------------------------
/moses/data/test_scaffolds_stats.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/A4Bio/GraphsGPT/e71647c5a9e19322b4166bb7922c604778a71e1e/moses/data/test_scaffolds_stats.npz
--------------------------------------------------------------------------------
/moses/data/test_stats.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/A4Bio/GraphsGPT/e71647c5a9e19322b4166bb7922c604778a71e1e/moses/data/test_stats.npz
--------------------------------------------------------------------------------
/moses/data/train.csv.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/A4Bio/GraphsGPT/e71647c5a9e19322b4166bb7922c604778a71e1e/moses/data/train.csv.gz
--------------------------------------------------------------------------------
/moses/dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import pandas as pd
4 |
5 | AVAILABLE_SPLITS = ['train', 'test', 'test_scaffolds']
6 |
7 | BASE_PATH = os.path.dirname(__file__)
8 |
9 |
10 | # BASE_PATH = os.path.join(BASE_PATH, "..")
11 |
12 |
13 | def get_dataset(split='train'):
14 | """
15 | Loads MOSES dataset
16 |
17 | Arguments:
18 | split (str): split to load. Must be
19 | one of: 'train', 'test', 'test_scaffolds'
20 |
21 | Returns:
22 | list with SMILES strings
23 | """
24 | if split not in AVAILABLE_SPLITS:
25 | raise ValueError(
26 | f"Unknown split {split}. "
27 | f"Available splits: {AVAILABLE_SPLITS}"
28 | )
29 |     path = os.path.join(BASE_PATH, 'data', split + '.csv.gz')
30 |     smiles = pd.read_csv(path, compression='gzip')['SMILES'].values
31 |     return smiles
32 |
33 |
34 | def get_statistics(split='test'):
35 |     path = os.path.join(BASE_PATH, 'data', split + '_stats.npz')
36 |     return np.load(path, allow_pickle=True)['stats'].item()
37 |
--------------------------------------------------------------------------------
/moses/mcf.csv:
--------------------------------------------------------------------------------
1 | names,smarts
2 | MCF1,[#6]=&!@[#6]-[#6]#[#7]
3 | MCF2,[#6]=&!@[#6]-[#16](=[#8])=[#8]
4 | MCF3,[#6]=&!@[#6&!H0]-&!@[#6](=[#8])-&!@[#7]
5 | MCF4,"[H]C([H])([#6])[F,Cl,Br,I]"
6 | MCF5,[#6]1-[#8]-[#6]-1
7 | MCF6,[#6]-[#7]=[#6]=[#8]
8 | MCF7,[#6&!H0]=[#8]
9 | MCF8,"[#6](=&!@[#7&!H0])-&!@[#6,#7,#8,#16]"
10 | MCF9,[#6]1-[#7]-[#6]-1
11 | MCF10,[#6]~&!@[#7]~&!@[#7]~&!@[#6]
12 | MCF11,[#7]=&!@[#7]
13 | MCF12,[H][#6]-1=[#6]([H])-[#6]=[#6](-*)-[#8]-1
14 | MCF13,[H][#6]-1=[#6]([H])-[#6]=[#6](-*)-[#16]-1
15 | MCF14,"[#17,#35,#53]-c(:*):[!#1!#6]:*"
16 | MCF15,[H][#7]([H])-[#6]-1=[#6]-[#6]=[#6]-[#6]=[#6]-1
17 | MCF16,[#16]~[#16]
18 | MCF17,[#7]~&!@[#7]~&!@[#7]
19 | MCF18,[#7]-&!@[#6&!H0&!H1]-&!@[#7]
20 | MCF19,[#6&!H0](-&!@[#8])-&!@[#8]
21 | MCF20,[#35].[#35].[#35]
22 | MCF21,[#17].[#17].[#17].[#17]
23 | MCF22,[#9].[#9].[#9].[#9].[#9].[#9].[#9]
24 |
--------------------------------------------------------------------------------
/moses/metric_utils.py:
--------------------------------------------------------------------------------
1 | from collections import Counter
2 |
3 | import numpy as np
4 | import os
5 | import pandas as pd
6 | import scipy.sparse
7 | import torch
8 | from functools import partial
9 | from rdkit import Chem
10 | from rdkit.Chem import AllChem
11 | from rdkit.Chem import Descriptors
12 | from rdkit.Chem import MACCSkeys
13 | from rdkit.Chem.AllChem import GetMorganFingerprintAsBitVect as Morgan
14 | from rdkit.Chem.QED import qed
15 | from rdkit.Chem.Scaffolds import MurckoScaffold
16 |
17 | from .NP_Score import npscorer
18 | from .SA_Score import sascorer
19 | from .utils import mapper, get_mol
20 |
21 | _base_dir = os.path.split(__file__)[0]
22 | _mcf = pd.read_csv(os.path.join(_base_dir, 'mcf.csv'))
23 | _pains = pd.read_csv(os.path.join(_base_dir, 'wehi_pains.csv'), names=['smarts', 'names'])
24 | _filters = [Chem.MolFromSmarts(x) for x in pd.concat([_mcf, _pains], sort=True)['smarts'].values] # fixed for pandas 2.2.2
25 |
26 |
27 | def canonic_smiles(smiles_or_mol):
28 | mol = get_mol(smiles_or_mol)
29 | if mol is None:
30 | return None
31 | return Chem.MolToSmiles(mol)
32 |
33 |
34 | def logP(mol):
35 | """
36 | Computes RDKit's logP
37 | """
38 | return Chem.Crippen.MolLogP(mol)
39 |
40 |
41 | def SA(mol):
42 | """
43 | Computes RDKit's Synthetic Accessibility score
44 | """
45 | return sascorer.calculateScore(mol)
46 |
47 |
48 | def NP(mol):
49 | """
50 | Computes RDKit's Natural Product-likeness score
51 | """
52 | return npscorer.scoreMol(mol)
53 |
54 |
55 | def QED(mol):
56 | """
57 | Computes RDKit's QED score
58 | """
59 | return qed(mol)
60 |
61 |
62 | def weight(mol):
63 | """
64 | Computes molecular weight for given molecule.
65 | Returns float,
66 | """
67 | return Descriptors.MolWt(mol)
68 |
69 |
70 | def get_n_rings(mol):
71 | """
72 | Computes the number of rings in a molecule
73 | """
74 | return mol.GetRingInfo().NumRings()
75 |
76 |
77 | def fragmenter(mol):
78 | """
79 | fragment mol using BRICS and return smiles list
80 | """
81 | fgs = AllChem.FragmentOnBRICSBonds(get_mol(mol))
82 | fgs_smi = Chem.MolToSmiles(fgs).split(".")
83 | return fgs_smi
84 |
85 |
86 | def compute_fragments(mol_list, n_jobs=1):
87 | """
88 |     fragment a list of mols using BRICS and return a Counter of fragment SMILES
89 | """
90 | fragments = Counter()
91 | for mol_frag in mapper(n_jobs)(fragmenter, mol_list):
92 | fragments.update(mol_frag)
93 | return fragments
94 |
95 |
96 | def compute_scaffolds(mol_list, n_jobs=1, min_rings=2):
97 | """
98 |     Extracts a scaffold from a molecule in the form of a canonical SMILES
99 | """
100 | scaffolds = Counter()
101 | map_ = mapper(n_jobs)
102 | scaffolds = Counter(
103 | map_(partial(compute_scaffold, min_rings=min_rings), mol_list))
104 | if None in scaffolds:
105 | scaffolds.pop(None)
106 | return scaffolds
107 |
108 |
109 | def compute_scaffold(mol, min_rings=2):
110 | mol = get_mol(mol)
111 | try:
112 | scaffold = MurckoScaffold.GetScaffoldForMol(mol)
113 | except (ValueError, RuntimeError):
114 | return None
115 | n_rings = get_n_rings(scaffold)
116 | scaffold_smiles = Chem.MolToSmiles(scaffold)
117 | if scaffold_smiles == '' or n_rings < min_rings:
118 | return None
119 | return scaffold_smiles
120 |
121 |
122 | def average_agg_tanimoto(stock_vecs, gen_vecs,
123 | batch_size=5000, agg='max',
124 | device='cpu', p=1):
125 | """
126 | For each molecule in gen_vecs finds closest molecule in stock_vecs.
127 |     Returns the average Tanimoto score between these molecules
128 |
129 | Parameters:
130 | stock_vecs: numpy array
131 | gen_vecs: numpy array
132 | agg: max or mean
133 | p: power for averaging: (mean x^p)^(1/p)
134 | """
135 | assert agg in ['max', 'mean'], "Can aggregate only max or mean"
136 | agg_tanimoto = np.zeros(len(gen_vecs))
137 | total = np.zeros(len(gen_vecs))
138 | for j in range(0, stock_vecs.shape[0], batch_size):
139 | x_stock = torch.tensor(stock_vecs[j:j + batch_size]).to(device).float()
140 | for i in range(0, gen_vecs.shape[0], batch_size):
141 | y_gen = torch.tensor(gen_vecs[i:i + batch_size]).to(device).float()
142 | y_gen = y_gen.transpose(0, 1)
143 | tp = torch.mm(x_stock, y_gen)
144 | jac = (tp / (x_stock.sum(1, keepdim=True) +
145 | y_gen.sum(0, keepdim=True) - tp)).cpu().numpy()
146 | jac[np.isnan(jac)] = 1
147 | if p != 1:
148 | jac = jac ** p
149 | if agg == 'max':
150 | agg_tanimoto[i:i + y_gen.shape[1]] = np.maximum(
151 | agg_tanimoto[i:i + y_gen.shape[1]], jac.max(0))
152 | elif agg == 'mean':
153 | agg_tanimoto[i:i + y_gen.shape[1]] += jac.sum(0)
154 | total[i:i + y_gen.shape[1]] += jac.shape[0]
155 | if agg == 'mean':
156 | agg_tanimoto /= total
157 | if p != 1:
158 | agg_tanimoto = (agg_tanimoto) ** (1 / p)
159 | return np.mean(agg_tanimoto)
160 |
161 |
162 | def fingerprint(smiles_or_mol, fp_type='maccs', dtype=None, morgan__r=2,
163 | morgan__n=1024, *args, **kwargs):
164 | """
165 | Generates fingerprint for SMILES
166 | If smiles is invalid, returns None
167 | Returns numpy array of fingerprint bits
168 |
169 | Parameters:
170 |         smiles_or_mol: SMILES string or RDKit molecule
171 |         fp_type: type of fingerprint: [MACCS|morgan]
172 | dtype: if not None, specifies the dtype of returned array
173 | """
174 | fp_type = fp_type.lower()
175 | molecule = get_mol(smiles_or_mol, *args, **kwargs)
176 | if molecule is None:
177 | return None
178 | if fp_type == 'maccs':
179 | keys = MACCSkeys.GenMACCSKeys(molecule)
180 | keys = np.array(keys.GetOnBits())
181 | fingerprint = np.zeros(166, dtype='uint8')
182 | if len(keys) != 0:
183 | fingerprint[keys - 1] = 1 # We drop 0-th key that is always zero
184 | elif fp_type == 'morgan':
185 | fingerprint = np.asarray(Morgan(molecule, morgan__r, nBits=morgan__n),
186 | dtype='uint8')
187 | else:
188 | raise ValueError("Unknown fingerprint type {}".format(fp_type))
189 | if dtype is not None:
190 | fingerprint = fingerprint.astype(dtype)
191 | return fingerprint
192 |
193 |
194 | def fingerprints(smiles_mols_array, n_jobs=1, already_unique=False, *args,
195 | **kwargs):
196 | '''
197 |     Computes fingerprints of a SMILES np.array/list/pd.Series with n_jobs workers
198 |     e.g. fingerprints(smiles_mols_array, fp_type='morgan', n_jobs=10)
199 |     Inserts np.NaN into rows corresponding to invalid SMILES.
200 |     IMPORTANT: if there is at least one np.NaN, the dtype will be float
201 |     Parameters:
202 |         smiles_mols_array: list/array/pd.Series of SMILES or already computed
203 |             RDKit molecules
204 |         n_jobs: number of parallel workers to execute
205 |         already_unique: flag for performance reasons, if the SMILES array is big
206 |             and already unique. Its value is set to True if smiles_mols_array
207 |             contains RDKit molecules already.
208 | '''
209 | if isinstance(smiles_mols_array, pd.Series):
210 | smiles_mols_array = smiles_mols_array.values
211 | else:
212 | smiles_mols_array = np.asarray(smiles_mols_array)
213 | if not isinstance(smiles_mols_array[0], str):
214 | already_unique = True
215 |
216 | if not already_unique:
217 | smiles_mols_array, inv_index = np.unique(smiles_mols_array,
218 | return_inverse=True)
219 |
220 | fps = mapper(n_jobs)(
221 | partial(fingerprint, *args, **kwargs), smiles_mols_array
222 | )
223 |
224 | length = 1
225 | for fp in fps:
226 | if fp is not None:
227 | length = fp.shape[-1]
228 | first_fp = fp
229 | break
230 | fps = [fp if fp is not None else np.array([np.NaN]).repeat(length)[None, :]
231 | for fp in fps]
232 | if scipy.sparse.issparse(first_fp):
233 | fps = scipy.sparse.vstack(fps).tocsr()
234 | else:
235 | fps = np.vstack(fps)
236 | if not already_unique:
237 | return fps[inv_index]
238 | return fps
239 |
240 |
241 | def mol_passes_filters(mol,
242 | allowed=None,
243 | isomericSmiles=False):
244 | """
245 | Checks if mol
246 | * passes MCF and PAINS filters,
247 | * has only allowed atoms
248 | * is not charged
249 | """
250 | allowed = allowed or {'C', 'N', 'S', 'O', 'F', 'Cl', 'Br', 'H'}
251 | mol = get_mol(mol)
252 | if mol is None:
253 | return False
254 | ring_info = mol.GetRingInfo()
255 | if ring_info.NumRings() != 0 and any(
256 | len(x) >= 8 for x in ring_info.AtomRings()
257 | ):
258 | return False
259 | h_mol = Chem.AddHs(mol)
260 | if any(atom.GetFormalCharge() != 0 for atom in mol.GetAtoms()):
261 | return False
262 | if any(atom.GetSymbol() not in allowed for atom in mol.GetAtoms()):
263 | return False
264 | if any(h_mol.HasSubstructMatch(smarts) for smarts in _filters):
265 | return False
266 | smiles = Chem.MolToSmiles(mol, isomericSmiles=isomericSmiles)
267 | if smiles is None or len(smiles) == 0:
268 | return False
269 | if Chem.MolFromSmiles(smiles) is None:
270 | return False
271 | return True
272 |
--------------------------------------------------------------------------------
/moses/metrics.py:
--------------------------------------------------------------------------------
1 | from multiprocessing import Pool
2 |
3 | import numpy as np
4 | import warnings
5 | from fcd_torch import FCD as FCDMetric
6 | from scipy.spatial.distance import cosine as cos_distance
7 | from scipy.stats import wasserstein_distance
8 |
9 | from .dataset import get_dataset, get_statistics
10 | from .metric_utils import compute_fragments, average_agg_tanimoto, compute_scaffolds, fingerprints, get_mol, canonic_smiles, mol_passes_filters, logP, QED, SA, weight
11 | from .utils import mapper, disable_rdkit_log, enable_rdkit_log
12 |
13 |
14 | def get_all_metrics(gen, k=None, n_jobs=1,
15 | device='cpu', batch_size=512, pool=None,
16 | test=None, test_scaffolds=None,
17 | ptest=None, ptest_scaffolds=None,
18 | train=None):
19 | """
20 | Computes all available metrics between test (scaffold test)
21 | and generated sets of SMILES.
22 | Parameters:
23 | gen: list of generated SMILES
24 | k: int or list with values for unique@k. Will calculate number of
25 | unique molecules in the first k molecules. Default [1000, 10000]
26 | n_jobs: number of workers for parallel processing
27 | device: 'cpu' or 'cuda:n', where n is GPU device number
28 | batch_size: batch size for FCD metric
29 | pool: optional multiprocessing pool to use for parallelization
30 |
31 | test (None or list): test SMILES. If None, will load
32 | a default test set
33 | test_scaffolds (None or list): scaffold test SMILES. If None, will
34 | load a default scaffold test set
35 | ptest (None or dict): precalculated statistics of the test set. If
36 | None, will load default test statistics. If you specified a custom
37 | test set, default test statistics will be ignored
38 | ptest_scaffolds (None or dict): precalculated statistics of the
39 |             scaffold test set. If None, will load default scaffold test
40 | statistics. If you specified a custom test set, default test
41 | statistics will be ignored
42 | train (None or list): train SMILES. If None, will load a default
43 | train set
44 | Available metrics:
45 | * %valid
46 | * %unique@k
47 | * Frechet ChemNet Distance (FCD)
48 | * Fragment similarity (Frag)
49 | * Scaffold similarity (Scaf)
50 | * Similarity to nearest neighbour (SNN)
51 | * Internal diversity (IntDiv)
52 | * Internal diversity 2: using square root of mean squared
53 | Tanimoto similarity (IntDiv2)
54 | * %passes filters (Filters)
55 | * Distribution difference for logP, SA, QED, weight
56 | * Novelty (molecules not present in train)
57 | """
58 | if test is None:
59 | if ptest is not None:
60 | raise ValueError(
61 | "You cannot specify custom test "
62 | "statistics for default test set")
63 | test = get_dataset('test')
64 | ptest = get_statistics('test')
65 |
66 | if test_scaffolds is None:
67 | if ptest_scaffolds is not None:
68 | raise ValueError(
69 | "You cannot specify custom scaffold test "
70 | "statistics for default scaffold test set")
71 | test_scaffolds = get_dataset('test_scaffolds')
72 | ptest_scaffolds = get_statistics('test_scaffolds')
73 |
74 | train = train or get_dataset('train')
75 |
76 | if k is None:
77 | k = [1000, 10000]
78 | disable_rdkit_log()
79 | metrics = {}
80 | close_pool = False
81 | if pool is None:
82 | if n_jobs != 1:
83 | pool = Pool(n_jobs)
84 | close_pool = True
85 | else:
86 | pool = 1
87 | metrics['valid'] = fraction_valid(gen, n_jobs=pool)
88 | gen = remove_invalid(gen, canonize=True)
89 | if not isinstance(k, (list, tuple)):
90 | k = [k]
91 | for _k in k:
92 | metrics['unique@{}'.format(_k)] = fraction_unique(gen, _k, pool)
93 |
94 | if ptest is None:
95 | ptest = compute_intermediate_statistics(test, n_jobs=n_jobs,
96 | device=device,
97 | batch_size=batch_size,
98 | pool=pool)
99 | if test_scaffolds is not None and ptest_scaffolds is None:
100 | ptest_scaffolds = compute_intermediate_statistics(
101 | test_scaffolds, n_jobs=n_jobs,
102 | device=device, batch_size=batch_size,
103 | pool=pool
104 | )
105 | mols = mapper(pool)(get_mol, gen)
106 | kwargs = {'n_jobs': pool, 'device': device, 'batch_size': batch_size}
107 | kwargs_fcd = {'n_jobs': n_jobs, 'device': device, 'batch_size': batch_size}
108 | metrics['FCD/Test'] = FCDMetric(**kwargs_fcd)(gen=gen, pref=ptest['FCD'])
109 | metrics['SNN/Test'] = SNNMetric(**kwargs)(gen=mols, pref=ptest['SNN'])
110 | metrics['Frag/Test'] = FragMetric(**kwargs)(gen=mols, pref=ptest['Frag'])
111 | metrics['Scaf/Test'] = ScafMetric(**kwargs)(gen=mols, pref=ptest['Scaf'])
112 | if ptest_scaffolds is not None:
113 | metrics['FCD/TestSF'] = FCDMetric(**kwargs_fcd)(
114 | gen=gen, pref=ptest_scaffolds['FCD']
115 | )
116 | metrics['SNN/TestSF'] = SNNMetric(**kwargs)(
117 | gen=mols, pref=ptest_scaffolds['SNN']
118 | )
119 | metrics['Frag/TestSF'] = FragMetric(**kwargs)(
120 | gen=mols, pref=ptest_scaffolds['Frag']
121 | )
122 | metrics['Scaf/TestSF'] = ScafMetric(**kwargs)(
123 | gen=mols, pref=ptest_scaffolds['Scaf']
124 | )
125 |
126 | metrics['IntDiv'] = internal_diversity(mols, pool, device=device)
127 | metrics['IntDiv2'] = internal_diversity(mols, pool, device=device, p=2)
128 | metrics['Filters'] = fraction_passes_filters(mols, pool)
129 |
130 | # Properties
131 | for name, func in [('logP', logP), ('SA', SA),
132 | ('QED', QED),
133 | ('weight', weight)]:
134 | metrics[name] = WassersteinMetric(func, **kwargs)(
135 | gen=mols, pref=ptest[name])
136 |
137 | if train is not None:
138 | metrics['Novelty'] = novelty(mols, train, pool)
139 | enable_rdkit_log()
140 | if close_pool:
141 | pool.close()
142 | pool.join()
143 | return metrics
144 |
145 |
146 | def compute_intermediate_statistics(smiles, n_jobs=1, device='cpu',
147 | batch_size=512, pool=None):
148 | """
149 | The function precomputes statistics such as mean and variance for FCD, etc.
150 | It is useful to compute the statistics for test and scaffold test sets to
151 |     speed up metrics calculation.
152 | """
153 | close_pool = False
154 | if pool is None:
155 | if n_jobs != 1:
156 | pool = Pool(n_jobs)
157 | close_pool = True
158 | else:
159 | pool = 1
160 | statistics = {}
161 | mols = mapper(pool)(get_mol, smiles)
162 | kwargs = {'n_jobs': pool, 'device': device, 'batch_size': batch_size}
163 | kwargs_fcd = {'n_jobs': n_jobs, 'device': device, 'batch_size': batch_size}
164 | statistics['FCD'] = FCDMetric(**kwargs_fcd).precalc(smiles)
165 | statistics['SNN'] = SNNMetric(**kwargs).precalc(mols)
166 | statistics['Frag'] = FragMetric(**kwargs).precalc(mols)
167 | statistics['Scaf'] = ScafMetric(**kwargs).precalc(mols)
168 | for name, func in [('logP', logP), ('SA', SA),
169 | ('QED', QED),
170 | ('weight', weight)]:
171 | statistics[name] = WassersteinMetric(func, **kwargs).precalc(mols)
172 | if close_pool:
173 | pool.terminate()
174 | return statistics
175 |
176 |
177 | def fraction_passes_filters(gen, n_jobs=1):
178 | """
179 | Computes the fraction of molecules that pass filters:
180 | * MCF
181 | * PAINS
182 | * Only allowed atoms ('C','N','S','O','F','Cl','Br','H')
183 | * No charges
184 | """
185 | passes = mapper(n_jobs)(mol_passes_filters, gen)
186 | return np.mean(passes)
187 |
188 |
189 | def internal_diversity(gen, n_jobs=1, device='cpu', fp_type='morgan',
190 | gen_fps=None, p=1):
191 | """
192 | Computes internal diversity as:
193 | 1/|A|^2 sum_{x, y in AxA} (1-tanimoto(x, y))
194 | """
195 | if gen_fps is None:
196 | gen_fps = fingerprints(gen, fp_type=fp_type, n_jobs=n_jobs)
197 | return 1 - (average_agg_tanimoto(gen_fps, gen_fps,
198 | agg='mean', device=device, p=p)).mean()
199 |
200 |
201 | def fraction_unique(gen, k=None, n_jobs=1, check_validity=True):
202 | """
203 |     Computes the fraction of unique molecules
204 | Parameters:
205 | gen: list of SMILES
206 | k: compute unique@k
207 | n_jobs: number of threads for calculation
208 | check_validity: raises ValueError if invalid molecules are present
209 | """
210 | if k is not None:
211 | if len(gen) < k:
212 | warnings.warn(
213 |                 "Can't compute unique@{}. ".format(k) +
214 | "gen contains only {} molecules".format(len(gen))
215 | )
216 | gen = gen[:k]
217 | canonic = set(mapper(n_jobs)(canonic_smiles, gen))
218 | if None in canonic and check_validity:
219 | raise ValueError("Invalid molecule passed to unique@k")
220 | return len(canonic) / len(gen)
221 |
222 |
223 | def fraction_valid(gen, n_jobs=1):
224 | """
225 |     Computes the fraction of valid molecules
226 | Parameters:
227 | gen: list of SMILES
228 | n_jobs: number of threads for calculation
229 | """
230 | gen = mapper(n_jobs)(get_mol, gen)
231 | return 1 - gen.count(None) / len(gen)
232 |
233 |
234 | def novelty(gen, train, n_jobs=1):
235 | gen_smiles = mapper(n_jobs)(canonic_smiles, gen)
236 | gen_smiles_set = set(gen_smiles) - {None}
237 | train_set = set(train)
238 | return len(gen_smiles_set - train_set) / len(gen_smiles_set)
239 |
240 |
241 | def remove_invalid(gen, canonize=True, n_jobs=1):
242 | """
243 | Removes invalid molecules from the dataset
244 | """
245 | if not canonize:
246 | mols = mapper(n_jobs)(get_mol, gen)
247 | return [gen_ for gen_, mol in zip(gen, mols) if mol is not None]
248 | return [x for x in mapper(n_jobs)(canonic_smiles, gen) if
249 | x is not None]
250 |
251 |
252 | class Metric:
253 | def __init__(self, n_jobs=1, device='cpu', batch_size=512, **kwargs):
254 | self.n_jobs = n_jobs
255 | self.device = device
256 | self.batch_size = batch_size
257 |         for k, v in kwargs.items():
258 | setattr(self, k, v)
259 |
260 | def __call__(self, ref=None, gen=None, pref=None, pgen=None):
261 | assert (ref is None) != (pref is None), "specify ref xor pref"
262 | assert (gen is None) != (pgen is None), "specify gen xor pgen"
263 | if pref is None:
264 | pref = self.precalc(ref)
265 | if pgen is None:
266 | pgen = self.precalc(gen)
267 | return self.metric(pref, pgen)
268 |
269 |     def precalc(self, molecules):
270 | raise NotImplementedError
271 |
272 | def metric(self, pref, pgen):
273 | raise NotImplementedError
274 |
275 |
276 | class SNNMetric(Metric):
277 | """
278 | Computes average max similarities of gen SMILES to ref SMILES
279 | """
280 |
281 | def __init__(self, fp_type='morgan', **kwargs):
282 | self.fp_type = fp_type
283 | super().__init__(**kwargs)
284 |
285 | def precalc(self, mols):
286 | return {'fps': fingerprints(mols, n_jobs=self.n_jobs,
287 | fp_type=self.fp_type)}
288 |
289 | def metric(self, pref, pgen):
290 | return average_agg_tanimoto(pref['fps'], pgen['fps'],
291 | device=self.device)
292 |
293 |
294 | def cos_similarity(ref_counts, gen_counts):
295 | """
296 | Computes cosine similarity between
297 | dictionaries of form {name: count}. Non-present
298 | elements are considered zero:
299 |
300 |     sim = <r, g> / ||r|| / ||g||
301 | """
302 | if len(ref_counts) == 0 or len(gen_counts) == 0:
303 | return np.nan
304 | keys = np.unique(list(ref_counts.keys()) + list(gen_counts.keys()))
305 | ref_vec = np.array([ref_counts.get(k, 0) for k in keys])
306 | gen_vec = np.array([gen_counts.get(k, 0) for k in keys])
307 | return 1 - cos_distance(ref_vec, gen_vec)
308 |
309 |
310 | class FragMetric(Metric):
311 | def precalc(self, mols):
312 | return {'frag': compute_fragments(mols, n_jobs=self.n_jobs)}
313 |
314 | def metric(self, pref, pgen):
315 | return cos_similarity(pref['frag'], pgen['frag'])
316 |
317 |
318 | class ScafMetric(Metric):
319 | def precalc(self, mols):
320 | return {'scaf': compute_scaffolds(mols, n_jobs=self.n_jobs)}
321 |
322 | def metric(self, pref, pgen):
323 | return cos_similarity(pref['scaf'], pgen['scaf'])
324 |
325 |
326 | class WassersteinMetric(Metric):
327 | def __init__(self, func=None, **kwargs):
328 | self.func = func
329 | super().__init__(**kwargs)
330 |
331 | def precalc(self, mols):
332 | if self.func is not None:
333 | values = mapper(self.n_jobs)(self.func, mols)
334 | else:
335 | values = mols
336 | return {'values': values}
337 |
338 | def metric(self, pref, pgen):
339 | return wasserstein_distance(
340 | pref['values'], pgen['values']
341 | )
342 |
--------------------------------------------------------------------------------
/moses/utils.py:
--------------------------------------------------------------------------------
1 | from multiprocessing import Pool
2 |
3 | from rdkit import Chem
4 | from rdkit import rdBase
5 |
6 |
7 | # https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
8 |
9 |
10 | def mapper(n_jobs):
11 | '''
12 | Returns function for map call.
13 | If n_jobs == 1, will use standard map
14 | If n_jobs > 1, will use multiprocessing pool
15 | If n_jobs is a pool object, will return its map function
16 | '''
17 | if n_jobs == 1:
18 | def _mapper(*args, **kwargs):
19 | return list(map(*args, **kwargs))
20 |
21 | return _mapper
22 | if isinstance(n_jobs, int):
23 | pool = Pool(n_jobs)
24 |
25 | def _mapper(*args, **kwargs):
26 | try:
27 | result = pool.map(*args, **kwargs)
28 | finally:
29 | pool.terminate()
30 | return result
31 |
32 | return _mapper
33 | return n_jobs.map
34 |
35 |
36 | def disable_rdkit_log():
37 | rdBase.DisableLog('rdApp.*')
38 |
39 |
40 | def enable_rdkit_log():
41 | rdBase.EnableLog('rdApp.*')
42 |
43 |
44 | def get_mol(smiles_or_mol):
45 | '''
46 | Loads SMILES/molecule into RDKit's object
47 | '''
48 | if isinstance(smiles_or_mol, str):
49 | if len(smiles_or_mol) == 0:
50 | return None
51 | mol = Chem.MolFromSmiles(smiles_or_mol)
52 | if mol is None:
53 | return None
54 | try:
55 | Chem.SanitizeMol(mol)
56 | except ValueError:
57 | return None
58 | return mol
59 | return smiles_or_mol
60 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.3.0
2 | transformers==4.41.1
3 | pytorch_lightning==1.9.5
4 | deepspeed==0.14.2
5 | rdkit==2023.9.6
6 | fcd_torch==1.0.7
7 | lmdb==1.4.1
8 | omegaconf==2.3.0
9 | pandas==2.2.2
10 | umap-learn==0.5.6
11 | hdbscan==0.8.36
12 | jupyter==1.0.0
13 | wandb==0.17.0
--------------------------------------------------------------------------------
/scripts/generation/conditional/README-Generation-Cond.md:
--------------------------------------------------------------------------------
1 | # Conditional Generation with GraphGPT-C Decoder
2 |
3 | The conditional generation of GraphGPT-C is similar to the unconditional version, with a few extra configurations that control the target properties. For the other configurations, refer to [Unconditional Generation](../unconditional/README-Generation-Uncond.md).
4 |
5 |
6 |
7 | ## Extra Generation Configurations
8 |
9 | - **Conditions:**
10 | - `value_qed`: `(float) None` (The target QED value for generated molecules. The model will not condition on this property if not specified.)
11 | - `value_sa`: `(float) None` (The target SA score for generated molecules. The model will not condition on this property if not specified.)
12 | - `value_logp`: `(float) None` (The target logP value for generated molecules. The model will not condition on this property if not specified.)
13 | - `scaffold_smiles`: `(str) None` (The target scaffold SMILES. The model will not condition on this property if not specified.)
14 |
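For instance, to condition on a target QED value together with a scaffold, pass the corresponding flags to `entrypoints/generation/conditional/generate.py`. The following is a minimal sketch mirroring the example scripts (the output directory is a placeholder; see the scripts for the full flag set):

````bash
python entrypoints/generation/conditional/generate.py \
    --model_name_or_path "DaizeDong/GraphsGPT-1W-C" \
    --smiles_file "./data/examples/zinc_example.txt" \
    --property_info_file "./configs/property_info.json" \
    --value_qed 0.5 \
    --scaffold_smiles "c1ccccc1" \
    --save_dir "./results/conditional/demo"
````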
15 |
16 |
17 | ### Configurations in the Paper
18 |
19 | We use the following configuration to test the ability of GraphGPT to condition on molecular properties:
20 |
21 | ````bash
22 | strict_generation="False"
23 | fix_aromatic_bond="True"
24 |
25 | do_sample="False"
26 |
27 | check_first_node="True"
28 | check_atom_valence="True"
29 | ````
30 |
31 | Example scripts can be found in `scripts/generation/conditional/examples`.
32 |
33 |
34 |
35 | ### Generate with More Diversity
36 |
37 | You can further turn on the *probabilistic sampling* for more diversity:
38 |
39 | ````bash
40 | strict_generation="False"
41 | fix_aromatic_bond="True"
42 |
43 | do_sample="True"
44 | top_k=4
45 | temperature=1.0
46 |
47 | check_first_node="True"
48 | check_atom_valence="True"
49 | ````
50 |
51 |
52 |
53 | ## Evaluate & Visualize the Results
54 |
55 | Run `scripts/generation/conditional/visualize.sh`.
56 |
57 | The mean and variance of the property values of the generated molecules will also be saved to `summary.txt`.
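If you want to recompute such a summary outside the script, the per-molecule property files written during visualization can be aggregated directly. Below is a minimal sketch, assuming one value per line with `None` marking failed molecules (the same layout handled by `summary_property_from_file` in `utils/io.py`); the file path is purely illustrative:

````python
import numpy as np

def summarize_property_file(file_path):
    # One property value per line; "None" marks molecules whose property
    # could not be computed (mirrors utils/io.py::summary_property_from_file).
    values = []
    with open(file_path, "r") as f:
        for line in f:
            line = line.strip()
            if line and line != "None":
                values.append(float(line))
    if not values:
        return None
    arr = np.array(values)
    return arr.mean(), arr.std(), len(arr)

# Illustrative path under the visualization save_dir:
# print(summarize_property_file("./results/conditional/visualization_scaffold_logp/logp.txt"))
````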
58 |
59 |
60 |
61 |
--------------------------------------------------------------------------------
/scripts/generation/conditional/examples/generate_scaffold_logp.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | model_name_or_path="DaizeDong/GraphsGPT-1W-C"
4 | #model_name_or_path="/mnt/petrelfs/dongdaize.d/workspace/graphsgpt/ckpts/GraphsGPT-1W-C"
5 | save_dir="./results/conditional/generation_scaffold_logp" # change this
6 | smiles_file="./data/examples/zinc_example.txt"
7 | num_batches=10
8 | batch_size=1024
9 | seed=0
10 |
11 | property_info_file="./configs/property_info.json"
12 | value_logp=0.0 # change this
13 | scaffold_smiles="c1ccccc1" # change this
14 |
15 | strict_generation="False"
16 | fix_aromatic_bond="True"
17 | do_sample="False"
18 | check_first_node="True"
19 | check_atom_valence="True"
20 |
21 | save_results="True"
22 | save_failed="False"
23 |
24 | python entrypoints/generation/conditional/generate.py \
25 | --model_name_or_path ${model_name_or_path} \
26 | --save_dir ${save_dir} \
27 | --smiles_file ${smiles_file} \
28 | --num_batches ${num_batches} \
29 | --batch_size ${batch_size} \
30 | --seed ${seed} \
31 | --property_info_file ${property_info_file} \
32 | --value_logp ${value_logp} \
33 | --scaffold_smiles ${scaffold_smiles} \
34 | --strict_generation ${strict_generation} \
35 | --do_sample ${do_sample} \
36 | --check_first_node ${check_first_node} \
37 | --check_atom_valence ${check_atom_valence} \
38 | --fix_aromatic_bond ${fix_aromatic_bond} \
39 | --save_results ${save_results} \
40 | --save_failed ${save_failed}
41 |
--------------------------------------------------------------------------------
/scripts/generation/conditional/examples/generate_scaffold_qed.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | model_name_or_path="DaizeDong/GraphsGPT-1W-C"
4 | save_dir="./results/conditional/generation_scaffold_qed" # change this
5 | smiles_file="./data/examples/zinc_example.txt"
6 | num_batches=10
7 | batch_size=1024
8 | seed=0
9 |
10 | property_info_file="./configs/property_info.json"
11 | value_qed=0.5 # change this
12 | scaffold_smiles="c1ccccc1" # change this
13 |
14 | strict_generation="False"
15 | fix_aromatic_bond="True"
16 | do_sample="False"
17 | check_first_node="True"
18 | check_atom_valence="True"
19 |
20 | save_results="True"
21 | save_failed="False"
22 |
23 | python entrypoints/generation/conditional/generate.py \
24 | --model_name_or_path ${model_name_or_path} \
25 | --save_dir ${save_dir} \
26 | --smiles_file ${smiles_file} \
27 | --num_batches ${num_batches} \
28 | --batch_size ${batch_size} \
29 | --seed ${seed} \
30 | --property_info_file ${property_info_file} \
31 | --value_qed ${value_qed} \
32 | --scaffold_smiles ${scaffold_smiles} \
33 | --strict_generation ${strict_generation} \
34 | --do_sample ${do_sample} \
35 | --check_first_node ${check_first_node} \
36 | --check_atom_valence ${check_atom_valence} \
37 | --fix_aromatic_bond ${fix_aromatic_bond} \
38 | --save_results ${save_results} \
39 | --save_failed ${save_failed}
40 |
--------------------------------------------------------------------------------
/scripts/generation/conditional/examples/generate_scaffold_sa.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | model_name_or_path="DaizeDong/GraphsGPT-1W-C"
4 | save_dir="./results/conditional/generation_scaffold_sa" # change this
5 | smiles_file="./data/examples/zinc_example.txt"
6 | num_batches=10
7 | batch_size=1024
8 | seed=0
9 |
10 | property_info_file="./configs/property_info.json"
11 | value_sa=0.7 # change this
12 | scaffold_smiles="c1ccccc1" # change this
13 |
14 | strict_generation="False"
15 | fix_aromatic_bond="True"
16 | do_sample="False"
17 | check_first_node="True"
18 | check_atom_valence="True"
19 |
20 | save_results="True"
21 | save_failed="False"
22 |
23 | python entrypoints/generation/conditional/generate.py \
24 | --model_name_or_path ${model_name_or_path} \
25 | --save_dir ${save_dir} \
26 | --smiles_file ${smiles_file} \
27 | --num_batches ${num_batches} \
28 | --batch_size ${batch_size} \
29 | --seed ${seed} \
30 | --property_info_file ${property_info_file} \
31 | --value_sa ${value_sa} \
32 | --scaffold_smiles ${scaffold_smiles} \
33 | --strict_generation ${strict_generation} \
34 | --do_sample ${do_sample} \
35 | --check_first_node ${check_first_node} \
36 | --check_atom_valence ${check_atom_valence} \
37 | --fix_aromatic_bond ${fix_aromatic_bond} \
38 | --save_results ${save_results} \
39 | --save_failed ${save_failed}
40 |
--------------------------------------------------------------------------------
/scripts/generation/conditional/visualize.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | ### Conditional ###
3 |
4 | model_name_or_path="DaizeDong/GraphsGPT-1W-C"
5 | generation_results_dir="./results/conditional/generation_scaffold_logp/generated_results" # change this
6 | save_dir="./results/conditional/visualization_scaffold_logp" # change this
7 | save_images="False"
8 |
9 | # visualize all files to gather property info
10 | file_begin_index=0
11 | file_end_index=10
12 |
13 | python entrypoints/generation/conditional/visualize.py \
14 | --model_name_or_path ${model_name_or_path} \
15 | --generation_results_dir ${generation_results_dir} \
16 | --save_dir ${save_dir} \
17 | --save_images ${save_images} \
18 | --file_begin_index ${file_begin_index} \
19 | --file_end_index ${file_end_index}
20 |
--------------------------------------------------------------------------------
/scripts/generation/evaluation/moses.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | model_name_or_path="DaizeDong/GraphsGPT-1W"
4 | save_path="./results/unconditional/moses" # change this
5 |
6 | batch_size_valid=8196
7 | sample_std=1.0
8 | max_sample_times=10
9 | num_shots=100000
10 | num_samples_each_shot=1
11 |
12 | num_processes=32
13 |
14 | python entrypoints/generation/evaluation/evaluate_moses_few_shot_sampling.py \
15 | --model_name_or_path ${model_name_or_path} \
16 | --save_path ${save_path} \
17 | --batch_size_valid ${batch_size_valid} \
18 | --sample_std ${sample_std} \
19 | --max_sample_times ${max_sample_times} \
20 | --num_shots ${num_shots} \
21 | --num_samples_each_shot ${num_samples_each_shot} \
22 | --num_processes ${num_processes}
23 |
--------------------------------------------------------------------------------
/scripts/generation/evaluation/zinc250k.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | model_name_or_path="DaizeDong/GraphsGPT-1W"
4 | save_path="./results/unconditional/zinc250k" # change this
5 |
6 | batch_size_valid=8196
7 | sample_std=1.0
8 | max_sample_times=10
9 | num_shots=100000
10 | num_samples_each_shot=1
11 |
12 | num_processes=32
13 |
14 | python entrypoints/generation/evaluation/evaluate_zinc250k_few_shot_sampling.py \
15 | --model_name_or_path ${model_name_or_path} \
16 | --save_path ${save_path} \
17 | --batch_size_valid ${batch_size_valid} \
18 | --sample_std ${sample_std} \
19 | --max_sample_times ${max_sample_times} \
20 | --num_shots ${num_shots} \
21 | --num_samples_each_shot ${num_samples_each_shot} \
22 | --num_processes ${num_processes}
23 |
--------------------------------------------------------------------------------
/scripts/generation/unconditional/README-Generation-Uncond.md:
--------------------------------------------------------------------------------
1 | # Unconditional Generation with GraphGPT Decoder
2 |
3 | The generation of GraphsGPT is controlled by multiple adjustable configurations. Refer to the following descriptions to adapt the generation to your needs.
4 |
5 | ## Generation Configurations
6 |
7 | - **Auto-fix Toggle:**
8 | - `strict_generation`: `(bool) True` (Whether to tolerate validity exceptions during generation. Setting this to `False` enables flexible generation that automatically fixes invalid predictions and ensures maximum effectiveness.)
9 | - `fix_aromatic_bond`: `(bool) False` (Whether to fix the dissociative aromatic bonds in the generated molecules.)
10 |
11 |
12 | - **Sampling Strategy:**
13 | - `do_sample`: `(bool) False` (Whether to use probabilistic sampling for bond predictions. Setting this to `True` introduces more randomness.)
14 | - `top_k`: `(int) None` (The range of top predictions for probability sampling. Available when `do_sample` is `True`.)
15 | - `temperature`: `(float) 1.0` (Temperature to adjust the probability distribution. Available when `do_sample` is `True`.)
16 |
17 |
18 | - **Hyperparameters:**
19 | - `max_atoms`: `(int) None` (The maximum number of atoms for generation.)
20 | - `similarity_threshold`: `(float) 0.5` (Threshold for classifying whether a generated atom is new or old.)
21 |
22 |
23 | - **Other Check Terms:**
24 | - `check_first_node`: `(bool) True` (Whether to check the consistency between the predicted beginning atom and the first bond, and fix the order of the beginning two atoms.)
25 | - `check_atom_valence`: `(bool) False` (Whether to check the validity regarding the valence of the atoms connected to the predicted bonds.)
26 |
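All of the options above are passed to `entrypoints/generation/unconditional/generate.py` as command-line flags with string-valued booleans, as in the example scripts. A minimal sketch (the output directory is a placeholder; see `scripts/generation/unconditional/examples` for complete invocations):

````bash
python entrypoints/generation/unconditional/generate.py \
    --model_name_or_path "DaizeDong/GraphsGPT-1W" \
    --smiles_file "./data/examples/zinc_example.txt" \
    --save_dir "./results/unconditional/demo" \
    --strict_generation "True" \
    --fix_aromatic_bond "False" \
    --do_sample "False" \
    --check_first_node "True" \
    --check_atom_valence "False"
````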
27 | For reference, we provide some example configurations to use under different circumstances.
28 |
29 | ### Validate the Pretraining Performance
30 |
31 | To validate the pretraining performance, the generation should involve no randomness; both *auto-fix* and *probabilistic sampling* should be turned off:
32 |
33 | ````bash
34 | strict_generation="True"
35 | fix_aromatic_bond="False"
36 |
37 | do_sample="False"
38 |
39 | check_first_node="True"
40 | check_atom_valence="False"
41 | ````
42 |
43 | An example script can be found in `scripts/generation/unconditional/examples/generate_strict.sh`.
44 |
45 | ### Generate with More Effectiveness
46 |
47 | To generate more effective results, you can turn on the *auto-fix*:
48 |
49 | ````bash
50 | strict_generation="False"
51 | fix_aromatic_bond="True"
52 |
53 | do_sample="False"
54 |
55 | check_first_node="True"
56 | check_atom_valence="False"
57 | ````
58 |
59 | An example script can be found in `scripts/generation/unconditional/examples/generate_flexible.sh`.
60 |
61 | ### Generate with More Diversity (Further Finetuning Needed)
62 |
63 | To generate more diverse results, you can further turn on the *probabilistic sampling* and *valence check*. This requires the encoded Graph Words $\mathcal{W}$ to carry variable information, which needs an extra round of finetuning.
64 |
65 | ````bash
66 | strict_generation="False"
67 | fix_aromatic_bond="True"
68 |
69 | do_sample="True"
70 | top_k=4
71 | temperature=1.0
72 |
73 | check_first_node="True"
74 | check_atom_valence="True"
75 | ````
76 |
77 | ## Visualize the Results
78 |
79 | Run `scripts/generation/unconditional/visualize.sh`.
80 |
81 |
--------------------------------------------------------------------------------
/scripts/generation/unconditional/examples/generate_flexible.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | model_name_or_path="DaizeDong/GraphsGPT-1W"
4 | save_dir="./results/unconditional/generation" # change this
5 | smiles_file="./data/examples/zinc_example.txt"
6 | num_batches=10
7 | batch_size=1024
8 | seed=0
9 |
10 | strict_generation="False"
11 | fix_aromatic_bond="True"
12 | do_sample="True"
13 | check_first_node="True"
14 | check_atom_valence="False"
15 |
16 | save_results="True"
17 | save_failed="False"
18 |
19 | python entrypoints/generation/unconditional/generate.py \
20 | --model_name_or_path ${model_name_or_path} \
21 | --save_dir ${save_dir} \
22 | --smiles_file ${smiles_file} \
23 | --num_batches ${num_batches} \
24 | --batch_size ${batch_size} \
25 | --seed ${seed} \
26 | --strict_generation ${strict_generation} \
27 | --do_sample ${do_sample} \
28 | --check_first_node ${check_first_node} \
29 | --check_atom_valence ${check_atom_valence} \
30 | --fix_aromatic_bond ${fix_aromatic_bond} \
31 | --save_results ${save_results} \
32 | --save_failed ${save_failed}
33 |
--------------------------------------------------------------------------------
/scripts/generation/unconditional/examples/generate_strict.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | model_name_or_path="DaizeDong/GraphsGPT-1W"
4 | save_dir="./results/unconditional/generation" # change this
5 | smiles_file="./data/examples/zinc_example.txt"
6 | num_batches=10
7 | batch_size=1024
8 | seed=0
9 |
10 | strict_generation="True"
11 | fix_aromatic_bond="False"
12 | do_sample="False"
13 | check_first_node="True"
14 | check_atom_valence="False"
15 |
16 | save_results="True"
17 | save_failed="False"
18 |
19 | python entrypoints/generation/unconditional/generate.py \
20 | --model_name_or_path ${model_name_or_path} \
21 | --save_dir ${save_dir} \
22 | --smiles_file ${smiles_file} \
23 | --num_batches ${num_batches} \
24 | --batch_size ${batch_size} \
25 | --seed ${seed} \
26 | --strict_generation ${strict_generation} \
27 | --do_sample ${do_sample} \
28 | --check_first_node ${check_first_node} \
29 | --check_atom_valence ${check_atom_valence} \
30 | --fix_aromatic_bond ${fix_aromatic_bond} \
31 | --save_results ${save_results} \
32 | --save_failed ${save_failed}
33 |
--------------------------------------------------------------------------------
/scripts/generation/unconditional/visualize.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | ### Unconditional ###
3 |
4 | model_name_or_path="DaizeDong/GraphsGPT-1W"
5 | generation_results_dir="./results/unconditional/generation/generated_results" # change this
6 | save_dir="./results/unconditional/visualization" # change this
7 | save_images="True"
8 |
9 | # only visualize 2 files to save time & storage
10 | file_begin_index=0
11 | file_end_index=2
12 |
13 | python entrypoints/generation/unconditional/visualize.py \
14 | --model_name_or_path ${model_name_or_path} \
15 | --generation_results_dir ${generation_results_dir} \
16 | --save_dir ${save_dir} \
17 | --save_images ${save_images} \
18 | --file_begin_index ${file_begin_index} \
19 | --file_end_index ${file_end_index}
20 |
--------------------------------------------------------------------------------
/scripts/representation/finetune.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python entrypoints/representation/finetune.py --task_name tox21
4 | python entrypoints/representation/finetune.py --task_name toxcast
5 | python entrypoints/representation/finetune.py --task_name bbbp
6 | python entrypoints/representation/finetune.py --task_name sider
7 | python entrypoints/representation/finetune.py --task_name hiv
8 | python entrypoints/representation/finetune.py --task_name bace
9 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name='graphsgpt',
5 | version='1.0',
6 | packages=find_packages(),
7 | install_requires=[],
8 | classifiers=[
9 | 'Programming Language :: Python :: 3',
10 | 'Programming Language :: Python :: 3.6',
11 | 'Programming Language :: Python :: 3.7',
12 | 'Programming Language :: Python :: 3.8',
13 | ],
14 | )
15 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/A4Bio/GraphsGPT/e71647c5a9e19322b4166bb7922c604778a71e1e/utils/__init__.py
--------------------------------------------------------------------------------
/utils/accuracy.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import no_grad
3 |
4 |
5 | @no_grad()
6 | def classification_accuracy(
7 | logits: torch.FloatTensor,
8 | labels: torch.LongTensor,
9 | ):
10 | prediction = torch.argmax(logits, dim=1)
11 | correct = (prediction == labels).float()
12 | return torch.mean(correct)
13 |
--------------------------------------------------------------------------------
/utils/algorithms.py:
--------------------------------------------------------------------------------
1 | from collections import deque
2 |
3 | import numpy as np
4 |
5 |
6 | def find_factors_with_minimal_sum(number):
7 | if number == 1:
8 | return (1, 1)
9 |
10 | # Initialize variables to keep track of the factors with the minimal sum
11 | min_sum = float("inf")
12 | min_factors = None
13 |
14 | # Iterate through potential factors from 1 to half of the number
15 | for factor1 in range(1, number // 2 + 1):
16 | factor2 = number // factor1
17 |
18 | # Check if factor1 * factor2 is equal to the original number
19 | if factor1 * factor2 == number:
20 | current_sum = factor1 + factor2
21 |
22 | # Update the minimum sum and factors if the current sum is smaller
23 | if current_sum < min_sum:
24 | min_sum = current_sum
25 | min_factors = (factor1, factor2)
26 |
27 | return min_factors
28 |
29 |
30 | class FindCycles:
31 | """
32 | Example:
33 | adjacency_matrix = np.array([
34 | [0, 1, 0, 0, 1, 1, 0],
35 | [1, 0, 1, 1, 0, 0, 0],
36 | [0, 1, 0, 1, 0, 0, 0],
37 | [0, 1, 1, 0, 1, 0, 0],
38 | [1, 0, 0, 1, 0, 0, 0],
39 | [1, 0, 0, 0, 0, 0, 1],
40 | [0, 0, 0, 0, 0, 1, 0],
41 | ])
42 |
43 | finder = FindCycles(adjacency_matrix)
44 | cycles = finder.find_cycles()
45 | """
46 |
47 | def __init__(self, adj_matrix):
48 | self.graph = adj_matrix
49 | self.num_vertices = len(adj_matrix)
50 | self.visited = [False] * self.num_vertices
51 | self.cycles = set() # Set to store unique cycles
52 |
53 | def find_cycles(self):
54 | # Remove nodes with degree 1
55 | self.remove_degree_one_nodes()
56 |
57 | # Perform DFS for remaining nodes
58 | for vertex in range(self.num_vertices):
59 | if not self.visited[vertex]:
60 | self.dfs(vertex, vertex, [])
61 |
62 | # Deduplicate
63 | self.deduplicate_cycles()
64 |
65 | return list(self.cycles)
66 |
67 | def remove_degree_one_nodes(self):
68 | # Use a deque to efficiently process nodes with degree 1
69 | queue = deque([i for i in range(self.num_vertices) if np.sum(self.graph[i]) == 1])
70 |
71 | while queue:
72 | node = queue.popleft()
73 | self.graph[node, node] = 0 # Mark the node as removed
74 |
75 | # Decrease the degree of its neighbor
76 | neighbor = np.argmax(self.graph[node])
77 | self.graph[node, neighbor] = 0
78 | self.graph[neighbor, node] = 0
79 |
80 | # If the neighbor now has degree 1, add it to the queue
81 | if np.sum(self.graph[neighbor]) == 1:
82 | queue.append(neighbor)
83 |
84 | def dfs(self, start, current, path):
85 | self.visited[current] = True
86 | path.append(current)
87 |
88 | for neighbor in range(self.num_vertices):
89 | if self.graph[current, neighbor] == 1:
90 | if neighbor == start and len(path) > 2:
91 | # Found a cycle with at least 3 vertices
92 | self.cycles.add(tuple(path))
93 | elif not self.visited[neighbor] and neighbor > start:
94 | # Continue DFS only if the neighbor has not been removed
95 | self.dfs(start, neighbor, path)
96 |
97 | path.pop()
98 | self.visited[current] = False
99 |
100 | def deduplicate_cycles(self):
101 | all_cycles = self.cycles
102 | record_set = set()
103 |
104 | self.cycles = []
105 | for cycle in all_cycles:
106 | sorted_cycle = tuple(sorted(cycle))
107 | if sorted_cycle not in record_set:
108 | record_set.add(sorted_cycle)
109 | self.cycles.append(cycle)
110 |
--------------------------------------------------------------------------------
/utils/io.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import gzip
3 | import json
4 | import lzma
5 | import numpy as np
6 | import os
7 | import pickle
8 | import shutil
9 | from rdkit import Chem
10 | from rdkit.Chem import Draw
11 | from typing import Union, List, Dict
12 |
13 | from utils.operations.operation_string import extract_numbers
14 |
15 |
16 | def create_dir(dir):
17 | if not os.path.exists(dir):
18 | os.makedirs(dir)
19 |
20 |
21 | def delete_file_or_dir(dir):
22 | if os.path.isfile(dir):
23 | os.remove(dir)
24 | elif os.path.exists(dir):
25 | shutil.rmtree(dir)
26 | else:
27 | pass
28 |
29 |
30 | def save_compressed_file_7z(data, file_path): # 7z
31 | create_dir(os.path.dirname(file_path))
32 | with lzma.open(file_path, "wb") as file:
33 | pickle.dump(data, file)
34 |
35 |
36 | def load_compressed_file_7z(file_path): # 7z
37 | with lzma.open(file_path, "rb") as file:
38 | data = pickle.load(file)
39 | return data
40 |
41 |
42 | def save_compressed_file_gz(data, file_path, compresslevel=6): # gz
43 | create_dir(os.path.dirname(file_path))
44 | with gzip.open(file_path, "wb", compresslevel=compresslevel) as file:
45 | pickle.dump(data, file)
46 |
47 |
48 | def load_compressed_file_gz(file_path): # gz
49 | with gzip.open(file_path, "rb") as file:
50 | data = pickle.load(file)
51 | return data
52 |
53 |
54 | def read_csv(file_path, has_header=True) -> Union[List[List], List[Dict]]:
55 | """
56 | Read a CSV file and return its content.
57 |
58 | Args:
59 | - file_path (str): Path to the CSV file.
60 | - has_header (bool): Whether the CSV file has a header. Default is True.
61 |
62 | Returns:
63 | - list of list or dict: Content of the CSV file.
64 | If has_header is True, return a list of dictionaries;
65 | if has_header is False, return a list of lists.
66 | """
67 | data = []
68 | with open(file_path, newline='', encoding='utf-8') as f:
69 | if has_header:
70 | csvreader = csv.DictReader(f)
71 | for row in csvreader:
72 | data.append(dict(row))
73 | else:
74 | csvreader = csv.reader(f)
75 | for row in csvreader:
76 | data.append(row)
77 | return data
78 |
79 |
80 | def load_json(file_path):
81 | with open(file_path, "r", encoding="utf8") as f:
82 | data = json.load(f)
83 | return data
84 |
85 |
86 | def save_json(data, file_path, indent=4, **kwargs):
87 | create_dir(os.path.dirname(file_path))
88 | with open(file_path, "w", encoding="utf8") as f:
89 | f.write(f"{json.dumps(data, ensure_ascii=False, indent=indent, **kwargs)}\n")
90 |
91 |
92 | def load_jsonl(file_path) -> list:
93 | data = []
94 | with open(file_path, "r", encoding="utf8") as f:
95 | for line in f:
96 | try:
97 | data.append(json.loads(line))
98 | except json.JSONDecodeError as e:
99 | print(f"Error decoding line: {line}")
100 | continue
101 | return data
102 |
103 |
104 | def save_jsonl(data, file_path, **kwargs):
105 | create_dir(os.path.dirname(file_path))
106 | with open(file_path, "w", encoding="utf8") as f:
107 | for ins in data:
108 | f.write(f"{json.dumps(ins, ensure_ascii=False, **kwargs)}\n")
109 |
110 |
111 | def compress_png_image(image_path, print_info=False):
112 | import cv2
113 | img = cv2.imread(image_path, cv2.IMREAD_COLOR)
114 | cv2.imwrite(image_path, img, [cv2.IMWRITE_PNG_COMPRESSION, 9])
115 | if print_info:
116 | print(f'Done for "{image_path}".')
117 |
118 |
119 | """for this project"""
120 |
121 |
122 | def find_best_model(model_dir):
123 | best_model_file = None
124 | best_opt_steps = 0
125 | best_val_loss = float("inf")
126 |
127 | for root, dirs, files in os.walk(model_dir):
128 | for file in files:
129 | file_postfix = file.split(".")[-1]
130 | file_name = file.replace("." + file_postfix, "")
131 |
132 | if file_postfix == "ckpt" and "best" in file_name:
133 | # Example: "best-opt_steps=091999-val_loss=0.1109.ckpt"
134 | this_opt_steps, this_val_loss = extract_numbers(file_name)
135 |
136 | if this_val_loss < best_val_loss:
137 | # model with the minimal val_loss
138 | best_model_file = os.path.join(root, file)
139 | best_opt_steps = this_opt_steps
140 | best_val_loss = this_val_loss
141 |
142 | elif this_val_loss == best_val_loss and this_opt_steps > best_opt_steps:
143 | # model with the largest opt_steps
144 | best_model_file = os.path.join(root, file)
145 | best_opt_steps = this_opt_steps
146 | best_val_loss = this_val_loss
147 |
148 | return best_model_file
149 |
150 |
151 | def find_last_model(model_dir):
152 | last_model_file = None
153 |
154 | for root, dirs, files in os.walk(model_dir):
155 | for file in files:
156 | if file == "last.ckpt":
157 | last_model_file = os.path.join(root, file)
158 | break
159 |
160 | return last_model_file
161 |
162 |
163 | def get_avg_acc_from_file(save_file_match_num):
164 | acc_list = []
165 | with open(save_file_match_num, "r") as f:
166 | lines = f.readlines()
167 | for line in lines:
168 | # example line: "Accuracy of sample 0: 100.00% (1024/1024)"
169 | # example line: "Consistency of sample 0: 100.00% (523776/523776)"
170 | numbers = extract_numbers(line)
171 | acc_list.append(numbers[1])
172 | return sum(acc_list) / len(acc_list)
173 |
174 |
175 | def save_mol_png(mol, save_path, size=(512, 512)):
176 | img = Draw.MolToImage(mol, size=size)
177 | img.save(save_path)
178 | img.close()
179 |
180 |
181 | def save_empty_png(save_path, size=(512, 512)):
182 | img = Draw.MolToImage(Chem.Mol(), size=size)
183 | img.save(save_path)
184 | img.close()
185 |
186 |
187 | def summary_property_from_file(file_name):
188 | value_list = []
189 | with open(file_name, "r") as f:
190 | lines = f.readlines()
191 | for line in lines:
192 | line = line.removesuffix("\n")
193 | if line != "None":
194 | number = float(line)
195 | value_list.append(number)
196 | if len(value_list) > 0:
197 | values = np.array(value_list)
198 | return values.mean(), values.std(), len(value_list)
199 | else:
200 | return -1, -1, -1
201 |
202 |
203 | def summary_property_from_all_files(search_dir, file_name, value_type="float"):
204 | value_list = []
205 | for root, dirs, files in os.walk(search_dir):
206 | for file in files:
207 | if file == file_name:
208 | with open(os.path.join(root, file), "r") as f:
209 | lines = f.readlines()
210 | for line in lines:
211 | line = line.removesuffix("\n")
212 | if line != "None":
213 | if value_type == "float":
214 | number = float(line)
215 | elif value_type == "bool":
216 | number = (line == "True")
217 | value_list.append(number)
218 | if len(value_list) > 0:
219 | values = np.array(value_list)
220 | return values.mean(), values.std(), len(value_list)
221 | else:
222 | return -1, -1, -1
223 |
--------------------------------------------------------------------------------
/utils/molecule.py:
--------------------------------------------------------------------------------
1 | from rdkit.Chem import MolStandardize
2 | from rdkit.Chem.Scaffolds import MurckoScaffold
3 |
4 |
5 | def get_molecule_standard_scaffold(mol, normalizer=None):
6 | if normalizer is None:
7 | normalizer = MolStandardize.normalize.Normalizer()
8 |
9 | if mol is not None:
10 | try:
11 | scaffold = MurckoScaffold.GetScaffoldForMol(mol)
12 |
13 | try:
14 | standardized_scaffold = normalizer.normalize(scaffold)
15 | except:
16 | standardized_scaffold = scaffold
17 |
18 | return standardized_scaffold
19 | except:
20 | return None
21 | else:
22 | return None
23 |
--------------------------------------------------------------------------------
/utils/operations/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/A4Bio/GraphsGPT/e71647c5a9e19322b4166bb7922c604778a71e1e/utils/operations/__init__.py
--------------------------------------------------------------------------------
/utils/operations/operation_dataframe.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 |
3 |
4 | def chunk_dataframe(df, num_chunks):
5 | """
6 | Split the input DataFrame into a specified number of chunks.
7 |
8 | Args:
9 | - df (pd.DataFrame): The DataFrame to be evenly divided.
10 | - num_chunks (int): The desired number of chunks.
11 |
12 | Returns:
13 | - list of pd.DataFrame: List of chunks after even division.
14 |
15 | Example:
16 | >>> input_df = pd.DataFrame({'A': range(1, 10)})
17 | >>> num_chunks = 5
18 | >>> result = chunk_dataframe(input_df, num_chunks)
19 | >>> for chunk in result:
20 | >>> print(chunk) # Output: DataFrame chunks printed one by one
21 | """
22 | avg_chunk_size = len(df) // num_chunks
23 | remainder = len(df) % num_chunks
24 |
25 | chunks = []
26 | start = 0
27 | for _ in range(num_chunks):
28 | chunk_size = avg_chunk_size + 1 if remainder > 0 else avg_chunk_size
29 | chunks.append(df.iloc[start:start + chunk_size])
30 | start += chunk_size
31 | remainder -= 1
32 |
33 | return chunks
34 |
35 |
36 | def chunk_dataframe_with_yield(df, num_chunks):
37 | """
38 | Split the input DataFrame into a specified number of chunks using a generator.
39 |
40 | Args:
41 | - df (pd.DataFrame): The DataFrame to be evenly divided.
42 | - num_chunks (int): The desired number of chunks.
43 |
44 | Yields:
45 | - pd.DataFrame: DataFrame chunks yielded one at a time.
46 |
47 | Example:
48 | >>> input_df = pd.DataFrame({'A': range(1, 10)})
49 | >>> num_chunks = 5
50 | >>> for chunk in chunk_dataframe_with_yield(input_df, num_chunks):
51 | >>> print(chunk) # Output: DataFrame chunks printed one by one
52 | """
53 | avg_chunk_size = len(df) // num_chunks
54 | remainder = len(df) % num_chunks
55 |
56 | start = 0
57 | for _ in range(num_chunks):
58 | chunk_size = avg_chunk_size + 1 if remainder > 0 else avg_chunk_size
59 | yield df.iloc[start:start + chunk_size]
60 | start += chunk_size
61 | remainder -= 1
62 |
--------------------------------------------------------------------------------
/utils/operations/operation_dict.py:
--------------------------------------------------------------------------------
1 | def reverse_dict(input_dict, aggregate_same_results=True):
2 | output_dict = {}
3 | for key, value in input_dict.items():
4 | if value not in output_dict:
5 | output_dict[value] = key
6 | else:
7 | if aggregate_same_results:
8 | if not isinstance(output_dict[value], list):
9 | output_dict[value] = [output_dict[value]]
10 | output_dict[value].append(key)
11 | else:
12 | raise ValueError("Input dictionary does not satisfy the one-to-one mapping condition.")
13 | return output_dict
14 |
--------------------------------------------------------------------------------
/utils/operations/operation_list.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from typing import Union, List
3 |
4 |
5 | def chunk_list(input_list, num_chunks):
6 | """
7 | Split the input list into a specified number of chunks.
8 |
9 | Args:
10 | - input_list (list): The list to be evenly divided.
11 | - num_chunks (int): The desired number of chunks.
12 |
13 | Returns:
14 | - list of lists: List of chunks after even division.
15 |
16 | Example:
17 | >>> input_list = [1, 2, 3, 4, 5, 6, 7, 8, 9]
18 | >>> num_chunks = 5
19 | >>> result = chunk_list(input_list, num_chunks)
20 | >>> print(result) # Output: [[1, 2], [3, 4], [5, 6], [7, 8], [9]]
21 | """
22 | avg_chunk_size = len(input_list) // num_chunks
23 | remainder = len(input_list) % num_chunks
24 |
25 | chunks = []
26 | start = 0
27 | for _ in range(num_chunks):
28 | chunk_size = avg_chunk_size + 1 if remainder > 0 else avg_chunk_size
29 | chunks.append(input_list[start:start + chunk_size])
30 | start += chunk_size
31 | remainder -= 1
32 |
33 | return chunks
34 |
35 |
36 | def chunk_list_with_yield(input_list, num_chunks):
37 | """
38 | Split the input list into a specified number of chunks using a generator.
39 |
40 | Args:
41 | - input_list (list): The list to be evenly divided.
42 | - num_chunks (int): The desired number of chunks.
43 |
44 | Yields:
45 | - list of lists: Chunks yielded one at a time.
46 |
47 | Example:
48 | >>> input_list = [1, 2, 3, 4, 5, 6, 7, 8, 9]
49 | >>> num_chunks = 5
50 | >>> for chunk in chunk_list_with_yield(input_list, num_chunks):
51 | >>> print(chunk) # Output: [1, 2] [3, 4] [5, 6] [7, 8] [9]
52 | """
53 | avg_chunk_size = len(input_list) // num_chunks
54 | remainder = len(input_list) % num_chunks
55 |
56 | start = 0
57 | for _ in range(num_chunks):
58 | chunk_size = avg_chunk_size + 1 if remainder > 0 else avg_chunk_size
59 | yield input_list[start:start + chunk_size]
60 | start += chunk_size
61 | remainder -= 1
62 |
63 |
64 | def split_list(input_list, split_length, drop_last=False):
65 | """
66 | Split a list into sublists, each with a length of split_length.
67 |
68 | Args:
69 | - input_list (list): The list to be split.
70 | - split_length (int): Length of each sublist.
71 | - drop_last (bool): Whether to drop the last sublist if its length is insufficient. Default is False.
72 |
73 | Returns:
74 | - list of lists: List of split sublists.
75 |
76 | Example:
77 | >>> input_list = [1, 2, 3, 4, 5, 6, 7, 8, 9]
78 | >>> split_length = 5
79 | >>> result = split_list(input_list, split_length, drop_last=False)
80 | >>> print(result) # Output: [[1, 2, 3, 4, 5], [6, 7, 8, 9]]
81 | """
82 | if split_length <= 0:
83 | raise ValueError("split_length must be a positive integer!")
84 |
85 | num_elements = len(input_list)
86 | num_splits = num_elements // split_length
87 |
88 | sublists = [input_list[i * split_length: (i + 1) * split_length] for i in range(num_splits)]
89 |
90 | if not drop_last and num_splits * split_length < num_elements:
91 | sublists.append(input_list[num_splits * split_length:])
92 |
93 | return sublists
94 |
95 |
96 | def split_list_with_yield(input_list, split_length, drop_last=False):
97 | """
98 | Split a list into sublists using a generator, each with a length of split_length.
99 |
100 | Args:
101 | - input_list (list): The list to be split.
102 | - split_length (int): Length of each sublist.
103 | - drop_last (bool): Whether to drop the last sublist if its length is insufficient. Default is False.
104 |
105 | Yields:
106 | - list of lists: Sublists yielded one at a time.
107 |
108 | Example:
109 | >>> input_list = [1, 2, 3, 4, 5, 6, 7, 8, 9]
110 | >>> split_length = 5
111 | >>> result = split_list_with_yield(input_list, split_length, drop_last=False)
112 | >>> for sublist in result:
113 | >>> print(sublist) # Output: [1, 2, 3, 4, 5] [6, 7, 8, 9]
114 | """
115 | if split_length <= 0:
116 | raise ValueError("split_length must be a positive integer!")
117 |
118 | num_elements = len(input_list)
119 | num_splits = num_elements // split_length
120 |
121 | start = 0
122 | for _ in range(num_splits):
123 | sublist = input_list[start: start + split_length]
124 | yield sublist
125 | start += split_length
126 |
127 | if not drop_last and start < num_elements:
128 | sublist = input_list[start:]
129 | yield sublist
130 |
131 |
132 | def replicate_elements(input_list, num_copies: Union[int, float, List[int]]):
133 | """
134 | Replicate each element in the original list a fixed or variable number of times,
135 | with support for decimals indicating a proportional number of additional elements.
136 |
137 | Args:
138 | - input_list (list): The original list of elements.
139 | - num_copies: The number of times each element should be replicated.
140 | It can take multiple forms:
141 | - An integer: Each element in the list is replicated this many times.
142 | - A float: The integer part dictates the fixed number of times each element is replicated.
143 | The fractional part is used to determine a proportional number of extra elements to be
144 | randomly chosen and replicated once. For example, if num_copies is 2.5 and the list
145 | has 4 elements, each element is replicated 2 times, and additionally, 2 (50% of 4)
146 | randomly chosen elements are replicated once more.
147 | - A list of integers: Each element in the input list is replicated according to the
148 | corresponding number of times specified in this list. This allows for variable replication
149 | per element. The length of this list must match the length of the input list.
150 |
151 | Returns:
152 | - list: The new list with replicated elements.
153 | """
154 | if isinstance(num_copies, (int, float)):
155 | # Fixed replication, integer part and proportional part handled
156 | int_part = int(num_copies) # Integer part for the definite copies
157 | num_copies_list = [int_part] * len(input_list)
158 |
159 | if isinstance(num_copies, float):
160 | frac_part = num_copies - int_part # Fractional part for the proportional extra copies
161 | extra_copies_count = round(frac_part * len(input_list))
162 |
163 | if extra_copies_count > 0:
164 | # Choose the items to be replicated
165 | extra_num_copies_array = np.concatenate((
166 | np.ones(extra_copies_count, dtype=int),
167 | np.zeros(len(input_list) - extra_copies_count, dtype=int)
168 | ), axis=0)
169 | np.random.shuffle(extra_num_copies_array)
170 |
171 | # Add to the "num_copies_list"
172 | num_copies_list = (np.array(num_copies_list) + extra_num_copies_array).tolist()
173 |
174 | elif isinstance(num_copies, list):
175 | # Variable replication based on the list
176 | if len(input_list) != len(num_copies):
177 | raise ValueError("Lengths of input_list and num_copies_list must be the same.")
178 |
179 | num_copies_list = num_copies
180 |
181 | else:
182 | raise ValueError("Invalid type for num_copies. It should be an int, float, or a list.")
183 |
184 | new_list = []
185 | for item, num_copies_item in zip(input_list, num_copies_list):
186 | new_list.extend([item] * num_copies_item)
187 |
188 | return new_list
189 |
190 |
191 | def all_elements_equal(input_list):
192 | """Check if all elements in the list are equal."""
193 | if not input_list:
194 | return True
195 | return len(set(input_list)) == 1
196 |
197 |
198 | def mean_value_of_elements(input_list):
199 | """Return the mean value of elements in the list."""
200 | total_value = 0
201 | total_cnt = 0
202 | for element in input_list:
203 | if element is not None:
204 | total_value += element
205 | total_cnt += 1
206 | if total_cnt > 0:
207 | return total_value / total_cnt
208 | else:
209 | return 0
210 |
--------------------------------------------------------------------------------
/utils/operations/operation_number.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 |
3 |
4 | def normalize_value(value: Union[int, float], mean, std):
5 | if value is not None:
6 | if std != 0:
7 | return (value - mean) / std
8 | else:
9 | return value
10 | else:
11 | return None
12 |
13 |
14 | def denormalize_value(normalized_value: Union[int, float], mean, std):
15 | if normalized_value is not None:
16 | if std != 0:
17 | return (normalized_value * std) + mean
18 | else:
19 | return normalized_value
20 | else:
21 | return None
22 |
--------------------------------------------------------------------------------
/utils/operations/operation_string.py:
--------------------------------------------------------------------------------
1 | import re
2 | from argparse import ArgumentTypeError
3 |
4 |
5 | def str2bool(v):
6 | if isinstance(v, bool):
7 | return v
8 | if v.lower() in ("yes", "true", "t", "y", "1"):
9 | return True
10 | elif v.lower() in ("no", "false", "f", "n", "0"):
11 | return False
12 | else:
13 | raise ArgumentTypeError("Boolean value expected.")
14 |
15 |
16 | def string2number_list(string, sep=","):
17 | if isinstance(string, list) or string is None:
18 | return string
19 | else:
20 | split_string = string.split(sep)
21 | return [float(num) if "." in num else int(num) for num in split_string]
22 |
23 |
24 | def extract_numbers(string):
25 | """Extract numbers (int, float) from a given string."""
26 | pattern = r"[-+]?\d*\.\d+|\d+"
27 | matches = re.findall(pattern, string)
28 | numbers = [float(match) if '.' in match else int(match) for match in matches]
29 | return numbers
30 |
31 |
32 | def calculate_non_ascii_ratio(string):
33 | """Calculate the non-ASCII ratio of a given string."""
34 | if len(string) == 0:
35 | non_ascii_ratio = 0.0
36 | else:
37 | non_ascii_count = sum(1 for char in string if ord(char) >= 128)
38 | non_ascii_ratio = non_ascii_count / len(string)
39 | return non_ascii_ratio
40 |
41 |
42 | def remove_non_ascii_code(string):
43 | """Use a regular expression to remove all non-ASCII characters"""
44 | string = re.sub(r'[^\x00-\x7F]+', '', string)
45 | return string
46 |
47 |
48 | def replace_non_ascii_code(string):
49 | """
50 | Replace common non-ASCII characters with their ASCII counterparts in the given string.
51 |
52 | :param string: Input string with non-ASCII characters.
53 | :return: String with non-ASCII characters replaced.
54 | """
55 | string = re.sub(r'“|”', "\"", string)
56 | string = re.sub(r'‘|’', "\'", string)
57 | string = re.sub(r'—|–', "-", string)
58 | string = re.sub(r'…', "...", string)
59 |
60 | return string
61 |
--------------------------------------------------------------------------------
/utils/operations/operation_tensor.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def move_tensors_to_device(input, device):
5 | if input is None:
6 | return input
7 |
8 | elif isinstance(input, dict):
9 | for key, value in input.items():
10 | if isinstance(value, torch.Tensor):
11 | input[key] = value.to(device)
12 | return input
13 |
14 | elif isinstance(input, list):
15 | for i in range(len(input)):
16 | if isinstance(input[i], torch.Tensor):
17 | input[i] = input[i].to(device)
18 | return input
19 |
20 | elif isinstance(input, torch.Tensor):
21 | return input.to(device)
22 |
23 | else:
24 | raise TypeError(input)
25 |
26 |
27 | def tensor2numbers(input):
28 | if input is None:
29 | return input
30 |
31 | elif isinstance(input, dict):
32 | for key, value in input.items():
33 | if isinstance(value, torch.Tensor):
34 | input[key] = value.tolist()
35 | return input
36 |
37 | elif isinstance(input, list):
38 | for i in range(len(input)):
39 | if isinstance(input[i], torch.Tensor):
40 | input[i] = input[i].tolist()
41 | return input
42 |
43 | elif isinstance(input, torch.Tensor):
44 | return input.tolist()
45 |
46 | else:
47 | raise TypeError(input)
48 |
49 |
50 | def turn_last_true_mask_to_false(mask, true_mask_cnt=None):
51 | """Turn the last true value to false for each row in a mask matrix."""
52 | # mask: shape(batch_size, seq_len)
53 | if true_mask_cnt is None:
54 | true_mask_cnt = torch.sum(mask, dim=1).unsqueeze(1)
55 | turn_position_indices = (mask.cumsum(dim=1) == true_mask_cnt)
56 | converted_mask = mask.clone()
57 | converted_mask[turn_position_indices] = False
58 | return converted_mask
59 |
60 |
61 | def turn_first_true_mask_to_false(mask):
62 | """Turn the first true value to false for each row in a mask matrix."""
63 | # mask: shape(batch_size, seq_len)
64 | turn_position_indices = (mask.cumsum(dim=1) == 1)
65 | converted_mask = mask.clone()
66 | converted_mask[turn_position_indices] = False
67 | return converted_mask
68 |
69 |
70 | def last_true_position(mask):
71 | """Return the index of the last true value in each row in a mask matrix."""
72 | # mask: shape(batch_size, seq_len)
73 | true_mask_cnt = torch.sum(mask, dim=1).unsqueeze(1)
74 | last_true_mask = (mask.cumsum(dim=1) == true_mask_cnt) & mask
75 | last_true_position = last_true_mask.nonzero()[:, 1].unsqueeze(1)
76 | return last_true_position
77 |
78 |
79 | def pass_kernel_function(tensor, criterion, allow_nan=False):
80 | if criterion == "plain":
81 | return tensor
82 | elif criterion == "sqrt":
83 | if not allow_nan and torch.any(tensor < 0):
84 | raise ValueError("Detected negative value in the tensor! This will cause the result to be \"nan\"!")
85 | return torch.sqrt(tensor)
86 | elif criterion == "l1":
87 | return torch.abs(tensor)
88 | elif criterion == "l2":
89 | return tensor * tensor
90 | else:
91 | raise NotImplementedError
92 |
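A short, self-contained sketch of the mask helpers above (importable as `utils.operations.operation_tensor` from the repository root); the tensors are illustrative:

```python
import torch

from utils.operations.operation_tensor import last_true_position, turn_last_true_mask_to_false

mask = torch.tensor([[True, True, True, False],
                     [True, False, True, True]])

# Flip the last True of each row to False.
print(turn_last_true_mask_to_false(mask))
# tensor([[ True,  True, False, False],
#         [ True, False,  True, False]])

# Column index of the last True in each row.
print(last_true_position(mask))
# tensor([[2],
#         [3]])
```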
--------------------------------------------------------------------------------
/utils/property_scores/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/A4Bio/GraphsGPT/e71647c5a9e19322b4166bb7922c604778a71e1e/utils/property_scores/__init__.py
--------------------------------------------------------------------------------
/utils/property_scores/fpscores.pkl.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/A4Bio/GraphsGPT/e71647c5a9e19322b4166bb7922c604778a71e1e/utils/property_scores/fpscores.pkl.gz
--------------------------------------------------------------------------------
/utils/property_scores/nspdk.py:
--------------------------------------------------------------------------------
1 | import multiprocessing
2 |
3 | import networkx as nx
4 | import numpy as np
5 | from eden.graph import vectorize
6 | from rdkit import Chem
7 | from sklearn.metrics.pairwise import pairwise_kernels
8 | from tqdm import tqdm
9 |
10 |
11 | def single_mol_to_nx(mol):
12 | if mol is not None:
13 | G = nx.Graph()
14 | for atom in mol.GetAtoms():
15 | G.add_node(atom.GetIdx(), label=atom.GetSymbol())
16 | for bond in mol.GetBonds():
17 | G.add_edge(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx(), label=int(bond.GetBondTypeAsDouble()))
18 | return G
19 | return None
20 |
21 |
22 | def mols_to_nx(mols, n_jobs=None):
23 | # convert with multiprocessing support
24 | if n_jobs is not None:
25 | pool = multiprocessing.Pool(processes=n_jobs)
26 | nx_graphs = pool.map(single_mol_to_nx, mols)
27 | pool.close()
28 | pool.join()
29 | return [graph for graph in tqdm(nx_graphs, desc='Converting Molecules to Graphs') if graph is not None]
30 | else:
31 | nx_graphs = [single_mol_to_nx(mol) for mol in tqdm(mols, desc='Converting Molecules to Graphs')]
32 | return [graph for graph in nx_graphs if graph is not None]
33 |
34 |
35 | def single_graph_to_vector(graph):
36 | return vectorize(graph, complexity=4, discrete=True).toarray()
37 |
38 |
39 | def compute_nspdk_mmd(samples1, samples2, metric='linear', n_jobs=None):
40 | # code adapted from https://github.com/idea-iitd/graphgen/blob/master/metrics/mmd.py
41 | # convert with multiprocessing support
42 | if n_jobs is not None:
43 | pool = multiprocessing.Pool(processes=n_jobs)
44 | vectors1 = pool.map(single_graph_to_vector, [[sample] for sample in samples1])
45 | vectors2 = pool.map(single_graph_to_vector, [[sample] for sample in samples2])
46 | pool.close()
47 | pool.join()
48 | vectors1 = np.concatenate([vector for vector in tqdm(vectors1, desc='Vectorization...')], axis=0)
49 | vectors2 = np.concatenate([vector for vector in tqdm(vectors2, desc='Vectorization...')], axis=0)
50 | else:
51 | print("Vectorization...")
52 | vectors1 = vectorize(samples1, complexity=4, discrete=True).toarray()
53 | vectors2 = vectorize(samples2, complexity=4, discrete=True).toarray()
54 |
55 | print("Computing X...")
56 | X = pairwise_kernels(vectors1, None, metric=metric, n_jobs=n_jobs)
57 | print(f"X={X}")
58 |
59 | print("Computing Y...")
60 | Y = pairwise_kernels(vectors2, None, metric=metric, n_jobs=n_jobs)
61 | print(f"Y={Y}")
62 |
63 | print("Computing Z...")
64 | Z = pairwise_kernels(vectors1, vectors2, metric=metric, n_jobs=n_jobs)
65 | print(f"Z={Z}")
66 |
67 | return np.average(X) + np.average(Y) - 2 * np.average(Z)
68 |
69 |
70 | def get_npsdk(smiles_list1, smiles_list2, metric='linear', n_jobs=None):
71 | nx_graphs1 = mols_to_nx([Chem.MolFromSmiles(smile) for smile in tqdm(smiles_list1, desc='Converting SMILES to Molecules')], n_jobs=n_jobs)
72 | nx_graphs2 = mols_to_nx([Chem.MolFromSmiles(smile) for smile in tqdm(smiles_list2, desc='Converting SMILES to Molecules')], n_jobs=n_jobs)
73 | nx_graphs1_remove_empty = [G for G in nx_graphs1 if not G.number_of_nodes() == 0]
74 | nx_graphs2_remove_empty = [G for G in nx_graphs2 if not G.number_of_nodes() == 0]
75 | nspdk_mmd = compute_nspdk_mmd(nx_graphs1_remove_empty, nx_graphs2_remove_empty, metric=metric, n_jobs=n_jobs)
76 | return nspdk_mmd
77 |
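A hedged usage sketch for `get_npsdk`: it needs RDKit and the `eden` package imported above, and the SMILES lists below are illustrative placeholders for a reference set and a generated set:

```python
from utils.property_scores.nspdk import get_npsdk

reference_smiles = ["CCO", "c1ccccc1", "CC(=O)O"]
generated_smiles = ["CCN", "c1ccncc1", "CC(=O)N"]

# Smaller values indicate a closer match between the two graph distributions.
mmd = get_npsdk(reference_smiles, generated_smiles, metric="linear", n_jobs=None)
print(f"NSPDK MMD: {mmd:.4f}")
```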
--------------------------------------------------------------------------------
/utils/property_scores/sascorer.py:
--------------------------------------------------------------------------------
1 | #
2 | # calculation of synthetic accessibility score as described in:
3 | #
4 | # Estimation of Synthetic Accessibility Score of Drug-like Molecules based on Molecular Complexity and Fragment Contributions
5 | # Peter Ertl and Ansgar Schuffenhauer
6 | # Journal of Cheminformatics 1:8 (2009)
7 | # http://www.jcheminf.com/content/1/1/8
8 | #
9 | # several small modifications to the original paper are included
10 | # particularly slightly different formula for macrocyclic penalty
11 | # and taking into account also molecule symmetry (fingerprint density)
12 | #
13 | # for a set of 10k diverse molecules the agreement between the original method
14 | # as implemented in PipelinePilot and this implementation is r2 = 0.97
15 | #
16 | # peter ertl & greg landrum, september 2013
17 | #
18 | from __future__ import print_function
19 |
20 | import _pickle as cPickle
21 | import os.path as op
22 |
23 | import math
24 | from rdkit import Chem
25 | from rdkit.Chem import rdMolDescriptors
26 | 
27 |
28 | _fscores = None
29 |
30 |
31 | def readFragmentScores(name='fpscores'):
32 | import gzip
33 | global _fscores
34 | # generate the full path filename:
35 | if name == "fpscores":
36 | name = op.join(op.dirname(__file__), name)
37 | _fscores = cPickle.load(gzip.open('%s.pkl.gz' % name))
38 | outDict = {}
39 | for i in _fscores:
40 | for j in range(1, len(i)):
41 | outDict[i[j]] = float(i[0])
42 | _fscores = outDict
43 |
44 |
45 | def numBridgeheadsAndSpiro(mol, ri=None):
46 | nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol)
47 | nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol)
48 | return nBridgehead, nSpiro
49 |
50 |
51 | def calculateScore(m):
52 | if _fscores is None:
53 | readFragmentScores()
54 |
55 | # fragment score
56 | fp = rdMolDescriptors.GetMorganFingerprint(m,
57 | 2) # <- 2 is the *radius* of the circular fingerprint
58 | fps = fp.GetNonzeroElements()
59 | score1 = 0.
60 | nf = 0
61 |     for bitId, v in fps.items():
62 | nf += v
63 | sfp = bitId
64 | score1 += _fscores.get(sfp, -4) * v
65 | score1 /= nf
66 |
67 | # features score
68 | nAtoms = m.GetNumAtoms()
69 | nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True))
70 | ri = m.GetRingInfo()
71 | nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri)
72 | nMacrocycles = 0
73 | for x in ri.AtomRings():
74 | if len(x) > 8:
75 | nMacrocycles += 1
76 |
77 | sizePenalty = nAtoms ** 1.005 - nAtoms
78 | stereoPenalty = math.log10(nChiralCenters + 1)
79 | spiroPenalty = math.log10(nSpiro + 1)
80 | bridgePenalty = math.log10(nBridgeheads + 1)
81 | macrocyclePenalty = 0.
82 | # ---------------------------------------
83 | # This differs from the paper, which defines:
84 | # macrocyclePenalty = math.log10(nMacrocycles+1)
85 | # This form generates better results when 2 or more macrocycles are present
86 | if nMacrocycles > 0:
87 | macrocyclePenalty = math.log10(2)
88 |
89 | score2 = 0. - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty
90 |
91 | # correction for the fingerprint density
92 | # not in the original publication, added in version 1.1
93 | # to make highly symmetrical molecules easier to synthesize
94 | score3 = 0.
95 | if nAtoms > len(fps):
96 | score3 = math.log(float(nAtoms) / len(fps)) * .5
97 |
98 | sascore = score1 + score2 + score3
99 |
100 | # need to transform "raw" value into scale between 1 and 10
101 | min = -4.0
102 | max = 2.5
103 | sascore = 11. - (sascore - min + 1) / (max - min) * 9.
104 | # smooth the 10-end
105 | if sascore > 8.:
106 | sascore = 8. + math.log(sascore + 1. - 9.)
107 | if sascore > 10.:
108 | sascore = 10.0
109 | elif sascore < 1.:
110 | sascore = 1.0
111 |
112 | return sascore
113 |
114 |
115 | def processMols(mols):
116 | print('smiles\tName\tsa_score')
117 | for i, m in enumerate(mols):
118 | if m is None:
119 | continue
120 |
121 | s = calculateScore(m)
122 |
123 | smiles = Chem.MolToSmiles(m)
124 | print(smiles + "\t" + m.GetProp('_Name') + "\t%3f" % s)
125 |
126 |
127 | if __name__ == '__main__':
128 | import sys, time
129 |
130 | t1 = time.time()
131 | readFragmentScores("fpscores")
132 | t2 = time.time()
133 |
134 | suppl = Chem.SmilesMolSupplier(sys.argv[1])
135 | t3 = time.time()
136 | processMols(suppl)
137 | t4 = time.time()
138 |
139 | print('Reading took %.2f seconds. Calculating took %.2f seconds' % ((t2 - t1), (t4 - t3)),
140 | file=sys.stderr)
141 |
142 |
143 | #
144 | # Copyright (c) 2013, Novartis Institutes for BioMedical Research Inc.
145 | # All rights reserved.
146 | #
147 | # Redistribution and use in source and binary forms, with or without
148 | # modification, are permitted provided that the following conditions are
149 | # met:
150 | #
151 | # * Redistributions of source code must retain the above copyright
152 | # notice, this list of conditions and the following disclaimer.
153 | # * Redistributions in binary form must reproduce the above
154 | # copyright notice, this list of conditions and the following
155 | # disclaimer in the documentation and/or other materials provided
156 | # with the distribution.
157 | # * Neither the name of Novartis Institutes for BioMedical Research Inc.
158 | # nor the names of its contributors may be used to endorse or promote
159 | # products derived from this software without specific prior written permission.
160 | #
161 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
162 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
163 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
164 | # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
165 | # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
166 | # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
167 | # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
168 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
169 | # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
170 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
171 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
172 | #
173 |
174 | def compute_sa_score(rdmol):
175 | rdmol = Chem.MolFromSmiles(Chem.MolToSmiles(rdmol))
176 | sa = calculateScore(rdmol)
177 | sa_norm = round((10 - sa) / 9, 2)
178 | return sa_norm
179 |
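A minimal sketch of the SA-score helpers above, assuming RDKit is installed and the bundled `fpscores.pkl.gz` sits next to this module:

```python
from rdkit import Chem

from utils.property_scores.sascorer import calculateScore, compute_sa_score

mol = Chem.MolFromSmiles("CC(=O)Oc1ccccc1C(=O)O")  # aspirin
print(calculateScore(mol))    # raw Ertl/Schuffenhauer score, roughly in [1, 10]
print(compute_sa_score(mol))  # normalized to [0, 1]; higher means easier to synthesize
```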
--------------------------------------------------------------------------------
/utils/property_scores/scoring_func.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from copy import deepcopy
3 | from rdkit import Chem
4 | from rdkit.Chem import AllChem, Descriptors, Crippen, Lipinski, Mol
5 | from rdkit.Chem.QED import qed
6 |
7 | from utils.property_scores.sascorer import compute_sa_score
8 | from utils.property_scores.tanimoto_similarity import MolsTanimotoSimilarity
9 |
10 |
11 | def fix_explicit_hs(mol: Mol) -> Mol:
12 | # rdkit has a problem with implicit hs. By default there are only explicit hs.
13 | # This is a hack to fix this error
14 | for a in mol.GetAtoms():
15 | a.SetNoImplicit(False)
16 |
17 | mol = Chem.AddHs(mol, explicitOnly=True)
18 | mol = Chem.RemoveHs(mol)
19 |
20 | Chem.SanitizeMol(mol)
21 | return mol
22 |
23 |
24 | def get_basic(mol):
25 | n_atoms = len(mol.GetAtoms())
26 | n_bonds = len(mol.GetBonds())
27 | n_rings = len(Chem.GetSymmSSSR(mol))
28 | weight = Descriptors.ExactMolWt(mol)
29 | return n_atoms, n_bonds, n_rings, weight
30 |
31 |
32 | def get_is_valid(mol):
33 | smiles = Chem.MolToSmiles(mol)
34 | mol = Chem.MolFromSmiles(smiles)
35 | if mol is None:
36 | return False
37 | try:
38 | Chem.SanitizeMol(mol)
39 | except ValueError:
40 | return False
41 | return True
42 |
43 |
44 | def get_qed(mol):
45 | # mol = fix_explicit_hs(mol)
46 | # return qed(mol)
47 | try:
48 | mol = deepcopy(mol)
49 | try:
50 | AllChem.Kekulize(mol, clearAromaticFlags=True)
51 | except:
52 | pass
53 | mol.UpdatePropertyCache(strict=False)
54 | return qed(mol)
55 | except Exception as e:
56 | print(e)
57 | return None
58 |
59 |
60 | def get_sa(mol):
61 | try:
62 | mol = deepcopy(mol)
63 | try:
64 | AllChem.Kekulize(mol, clearAromaticFlags=True)
65 | except:
66 | pass
67 | mol.UpdatePropertyCache(strict=False)
68 | return compute_sa_score(mol)
69 | except Exception as e:
70 | print(e)
71 | return None
72 |
73 |
74 | def get_logp(mol):
75 | try:
76 | mol = deepcopy(mol)
77 | try:
78 | AllChem.Kekulize(mol, clearAromaticFlags=True)
79 | except:
80 | pass
81 | mol.UpdatePropertyCache(strict=False)
82 | return Crippen.MolLogP(mol)
83 | except Exception as e:
84 | print(e)
85 | return None
86 |
87 |
88 | def get_lipinski(mol):
89 | try:
90 | mol = deepcopy(mol)
91 | try:
92 | Chem.SanitizeMol(mol)
93 | except:
94 | pass
95 | mol.UpdatePropertyCache(strict=False)
96 | rule_1 = Descriptors.ExactMolWt(mol) < 500
97 | rule_2 = Lipinski.NumHDonors(mol) <= 5
98 | rule_3 = Lipinski.NumHAcceptors(mol) <= 10
99 | logp = Crippen.MolLogP(mol)
100 | rule_4 = (logp >= -2) & (logp <= 5)
101 | rule_5 = Chem.rdMolDescriptors.CalcNumRotatableBonds(mol) <= 10
102 | return np.sum([int(a) for a in [rule_1, rule_2, rule_3, rule_4, rule_5]])
103 | except Exception as e:
104 | print(e)
105 | return None
106 |
107 |
108 | def get_hacc(mol):
109 | return Lipinski.NumHAcceptors(mol)
110 |
111 |
112 | def get_hdon(mol):
113 | return Lipinski.NumHDonors(mol)
114 |
115 |
116 | def get_rdkit_rmsd(mol, n_conf=20, random_seed=42):
117 | """
118 | Calculate the alignment of generated mol and rdkit predicted mol
119 | Return the rmsd (max, min, median) of the `n_conf` rdkit conformers
120 | """
121 | mol = deepcopy(mol)
122 | Chem.SanitizeMol(mol)
123 | mol3d = Chem.AddHs(mol)
124 | rmsd_list = []
125 | # predict 3d
126 | confIds = AllChem.EmbedMultipleConfs(mol3d, n_conf, randomSeed=random_seed)
127 | for confId in confIds:
128 | AllChem.UFFOptimizeMolecule(mol3d, confId=confId)
129 | rmsd = Chem.rdMolAlign.GetBestRMS(mol, mol3d, refId=confId)
130 | rmsd_list.append(rmsd)
131 | # mol3d = Chem.RemoveHs(mol3d)
132 | rmsd_list = np.array(rmsd_list)
133 | return [np.max(rmsd_list), np.min(rmsd_list), np.median(rmsd_list)]
134 |
135 |
136 | def get_tanimoto_similarity(mol_list):
137 | tanimoto_calculator = MolsTanimotoSimilarity(mol_list)
138 | similarity_matrix = tanimoto_calculator.get_tanimoto_similarity_matrix()
139 | top_similarities = tanimoto_calculator.get_top_tanimoto_similarities()
140 |
141 | return similarity_matrix, top_similarities
142 |
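A minimal usage sketch of the scoring helpers above on a single RDKit molecule (aspirin here, purely as an example):

```python
from rdkit import Chem

from utils.property_scores.scoring_func import get_basic, get_lipinski, get_logp, get_qed, get_sa

mol = Chem.MolFromSmiles("CC(=O)Oc1ccccc1C(=O)O")  # aspirin
print(get_basic(mol))     # (n_atoms, n_bonds, n_rings, exact_mol_wt)
print(get_qed(mol))       # drug-likeness, in [0, 1]
print(get_sa(mol))        # normalized synthetic accessibility
print(get_logp(mol))      # Crippen logP
print(get_lipinski(mol))  # number of Lipinski rules satisfied, 0-5
```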
--------------------------------------------------------------------------------
/utils/property_scores/tanimoto_similarity.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from rdkit import DataStructs
3 | from rdkit.Chem import AllChem
4 |
5 |
6 | class MolsTanimotoSimilarity:
7 | def __init__(self, mols):
8 | self.mols = mols
9 | self.fingerprints = [AllChem.GetMorganFingerprintAsBitVect(mol, 2) for mol in mols] # get fingerprints
10 | self.tanimoto_similarity_matrix = np.zeros((len(mols), len(mols)))
11 |
12 | def get_tanimoto_similarity_matrix(self):
13 | # calculate similarity for each mol pair
14 | for i in range(0, len(self.mols)):
15 | for j in range(i + 1, len(self.mols)):
16 | similarity = DataStructs.TanimotoSimilarity(self.fingerprints[i], self.fingerprints[j])
17 | self.tanimoto_similarity_matrix[i, j] = similarity
18 |
19 | # complete the similarity matrix
20 | lower_indices = np.tril_indices(len(self.mols), -1)
21 | self.tanimoto_similarity_matrix[lower_indices] = self.tanimoto_similarity_matrix.T[lower_indices]
22 | return self.tanimoto_similarity_matrix
23 |
24 | def get_top_tanimoto_similarities(self):
25 | top_similarities = np.max(self.tanimoto_similarity_matrix, axis=1)
26 | return top_similarities.tolist()
27 |
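A short sketch of `MolsTanimotoSimilarity`; note that `get_top_tanimoto_similarities` reads the matrix filled in by `get_tanimoto_similarity_matrix`, so it should be called second:

```python
from rdkit import Chem

from utils.property_scores.tanimoto_similarity import MolsTanimotoSimilarity

mols = [Chem.MolFromSmiles(s) for s in ("CCO", "CCN", "c1ccccc1")]
calculator = MolsTanimotoSimilarity(mols)
matrix = calculator.get_tanimoto_similarity_matrix()  # symmetric, zero diagonal
top = calculator.get_top_tanimoto_similarities()      # per-molecule maximum similarity
print(matrix)
print(top)
```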
--------------------------------------------------------------------------------
/utils/representation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/A4Bio/GraphsGPT/e71647c5a9e19322b4166bb7922c604778a71e1e/utils/representation/__init__.py
--------------------------------------------------------------------------------
/utils/representation/data_interface.py:
--------------------------------------------------------------------------------
1 | import pytorch_lightning as pl
2 | import torch
3 | from torch.utils.data import DataLoader
4 |
5 | from .load_dataset import DatasetTask
6 |
7 |
8 | class MyDataLoader(DataLoader):
9 | def __init__(self, dataset, batch_size=64, num_workers=8, data_task=None, split='train', *args, **kwargs):
10 | super().__init__(dataset, batch_size=batch_size, num_workers=num_workers, *args, **kwargs)
11 | self.pretrain_device = 'cuda:0'
12 | self.split = split
13 | self.data_task = data_task
14 |
15 | def __iter__(self):
16 | for batch in super().__iter__():
17 | try:
18 | self.pretrain_device = f'cuda:{torch.distributed.get_rank()}'
19 | except:
20 | self.pretrain_device = 'cuda:0'
21 |
22 | stream = torch.cuda.Stream(
23 | self.pretrain_device
24 | )
25 | with torch.cuda.stream(stream):
26 | sample = {}
27 | for key in batch[0].keys():
28 | sample[key] = [one[key] for one in batch]
29 | if type(sample[key][0]) == torch.Tensor:
30 | sample[key] = torch.stack(sample[key], dim=0).cuda(non_blocking=True, device=self.pretrain_device)
31 |
32 | yield sample
33 |
34 |
35 | def collate_fn(batch):
36 | return batch
37 |
38 |
39 | class DInterface_base(pl.LightningDataModule):
40 | def __init__(self, **kwargs):
41 | super().__init__()
42 | self.save_hyperparameters()
43 | self.batch_size = self.hparams.batch_size
44 | print("batch_size", self.batch_size)
45 |
46 | def setup(self, stage=None):
47 | # Assign train/val datasets for use in dataloaders
48 | if stage == 'fit' or stage is None:
49 | self.trainset = self.data_task.datasets['train']
50 | self.valset = self.data_task.datasets['valid']
51 |
52 | # Assign test dataset for use in dataloader(s)
53 | if stage == 'test' or stage is None:
54 | self.testset = self.data_task.datasets['test']
55 |
56 | def train_dataloader(self):
57 |         return DataLoader(self.trainset, batch_size=self.batch_size, num_workers=self.hparams.num_workers, shuffle=True, prefetch_factor=3)
58 |
59 | def val_dataloader(self):
60 |         return DataLoader(self.valset, batch_size=self.batch_size, num_workers=self.hparams.num_workers, shuffle=False)
61 |
62 | def test_dataloader(self):
63 |         return DataLoader(self.testset, batch_size=self.batch_size, num_workers=self.hparams.num_workers, shuffle=False)
64 |
65 |
66 | class DInterface(DInterface_base):
67 | def __init__(self, **kwargs):
68 | super().__init__(**kwargs)
69 | self.save_hyperparameters()
70 | self.data_task = DatasetTask(self.hparams)
71 | self.data_task.load_dataset_mix_enc_dec('train')
72 | self.data_task.load_dataset_mix_enc_dec('valid')
73 | self.data_task.load_dataset_mix_enc_dec('test')
74 |
75 | def train_dataloader(self):
76 | return MyDataLoader(self.trainset, batch_size=self.batch_size, split='train', num_workers=self.hparams.num_workers, data_task=self.data_task, shuffle=True, prefetch_factor=8, pin_memory=True, collate_fn=collate_fn)
77 |
78 | def val_dataloader(self):
79 | return MyDataLoader(self.testset, batch_size=self.batch_size, split='valid', num_workers=self.hparams.num_workers, data_task=self.data_task, shuffle=False, pin_memory=True, collate_fn=collate_fn)
80 |
81 | def test_dataloader(self):
82 | return MyDataLoader(self.testset, batch_size=self.batch_size, split='test', num_workers=self.hparams.num_workers, data_task=self.data_task, shuffle=False, pin_memory=True, collate_fn=collate_fn)
83 |
--------------------------------------------------------------------------------
/utils/representation/dict.txt:
--------------------------------------------------------------------------------
1 | [PAD]
2 | [CLS]
3 | [SEP]
4 | [UNK]
5 | C
6 | N
7 | O
8 | S
9 | H
10 | Cl
11 | F
12 | Br
13 | I
14 | Si
15 | P
16 | B
17 | Na
18 | K
19 | Al
20 | Ca
21 | Sn
22 | As
23 | Hg
24 | Fe
25 | Zn
26 | Cr
27 | Se
28 | Gd
29 | Au
30 | Li
31 | [MASK]
--------------------------------------------------------------------------------
/utils/representation/dictionary.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) DP Technology.
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | import logging
7 |
8 | import numpy as np
9 |
10 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name
11 |
12 |
13 | class Dictionary:
14 | """A mapping from symbols to consecutive integers"""
15 |
16 | def __init__(
17 | self,
18 | *, # begin keyword-only arguments
19 | bos="[CLS]",
20 | pad="[PAD]",
21 | eos="[SEP]",
22 | unk="[UNK]",
23 | extra_special_symbols=None,
24 | ):
25 | self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
26 | self.symbols = []
27 | self.count = []
28 | self.indices = {}
29 | self.specials = set()
30 | self.specials.add(bos)
31 | self.specials.add(unk)
32 | self.specials.add(pad)
33 | self.specials.add(eos)
34 |
35 | def __eq__(self, other):
36 | return self.indices == other.indices
37 |
38 | def __getitem__(self, idx):
39 | if idx < len(self.symbols):
40 | return self.symbols[idx]
41 | return self.unk_word
42 |
43 | def __len__(self):
44 | """Returns the number of symbols in the dictionary"""
45 | return len(self.symbols)
46 |
47 | def __contains__(self, sym):
48 | return sym in self.indices
49 |
50 | def vec_index(self, a):
51 | return np.vectorize(self.index)(a)
52 |
53 | def index(self, sym):
54 | """Returns the index of the specified symbol"""
55 | assert isinstance(sym, str)
56 | if sym in self.indices:
57 | return self.indices[sym]
58 | return self.indices[self.unk_word]
59 |
60 | def special_index(self):
61 | return [self.index(x) for x in self.specials]
62 |
63 | def add_symbol(self, word, n=1, overwrite=False, is_special=False):
64 | """Adds a word to the dictionary"""
65 | if is_special:
66 | self.specials.add(word)
67 | if word in self.indices and not overwrite:
68 | idx = self.indices[word]
69 | self.count[idx] = self.count[idx] + n
70 | return idx
71 | else:
72 | idx = len(self.symbols)
73 | self.indices[word] = idx
74 | self.symbols.append(word)
75 | self.count.append(n)
76 | return idx
77 |
78 | def bos(self):
79 | """Helper to get index of beginning-of-sentence symbol"""
80 | return self.index(self.bos_word)
81 |
82 | def pad(self):
83 | """Helper to get index of pad symbol"""
84 | return self.index(self.pad_word)
85 |
86 | def eos(self):
87 | """Helper to get index of end-of-sentence symbol"""
88 | return self.index(self.eos_word)
89 |
90 | def unk(self):
91 | """Helper to get index of unk symbol"""
92 | return self.index(self.unk_word)
93 |
94 | @classmethod
95 | def load(cls, f):
96 | """Loads the dictionary from a text file with the format:
97 |
98 | ```
99 |         <symbol0> <count0>
100 |         <symbol1> <count1>
101 | ...
102 | ```
103 | """
104 | d = cls()
105 | d.add_from_file(f)
106 | return d
107 |
108 | def add_from_file(self, f):
109 | """
110 | Loads a pre-existing dictionary from a text file and adds its symbols
111 | to this instance.
112 | """
113 | if isinstance(f, str):
114 | try:
115 | with open(f, "r", encoding="utf-8") as fd:
116 | self.add_from_file(fd)
117 | except FileNotFoundError as fnfe:
118 | raise fnfe
119 | except UnicodeError:
120 | raise Exception(
121 | "Incorrect encoding detected in {}, please "
122 | "rebuild the dataset".format(f)
123 | )
124 | return
125 |
126 | lines = f.readlines()
127 |
128 | for line_idx, line in enumerate(lines):
129 | try:
130 | splits = line.rstrip().rsplit(" ", 1)
131 | line = splits[0]
132 | field = splits[1] if len(splits) > 1 else str(len(lines) - line_idx)
133 | if field == "#overwrite":
134 | overwrite = True
135 | line, field = line.rsplit(" ", 1)
136 | else:
137 | overwrite = False
138 | count = int(field)
139 | word = line
140 | if word in self and not overwrite:
141 | logger.info(
142 | "Duplicate word found when loading Dictionary: '{}', index is {}.".format(word, self.indices[word])
143 | )
144 | else:
145 | self.add_symbol(word, n=count, overwrite=overwrite)
146 | except ValueError:
147 | raise ValueError(
148 |                     "Incorrect dictionary format, expected '<token> <cnt> [flags]'"
149 | )
150 |
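A self-contained sketch of the `Dictionary` API, feeding `add_from_file` the whitespace-separated `<symbol> <count>` format it expects (here via an in-memory file object):

```python
import io

from utils.representation.dictionary import Dictionary

d = Dictionary()
for special in ("[CLS]", "[PAD]", "[SEP]", "[UNK]"):
    d.add_symbol(special, is_special=True)
d.add_from_file(io.StringIO("C 100\nN 50\nO 25\n"))

print(len(d))         # 7 symbols
print(d.index("C"))   # 4
print(d.index("Xx"))  # unknown symbol falls back to the [UNK] index, 3
print(d.vec_index([["C", "O"], ["N", "N"]]))  # vectorized lookup -> [[4 6] [5 5]]
```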
--------------------------------------------------------------------------------
/utils/representation/fingerprint_interface.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from rdkit import Chem
5 | from rdkit.Chem import rdFingerprintGenerator
6 |
7 | from data.tokenizer import GraphsGPTTokenizer
8 | from models.graphsgpt.modeling_graphsgpt import GraphsGPTForCausalLM
9 |
10 | CLASS_OF_FPGEN = {
11 | "fingerprint-morgan": rdFingerprintGenerator.GetMorganGenerator,
12 | "fingerprint-rdkit": rdFingerprintGenerator.GetRDKitFPGenerator,
13 | "fingerprint-ap": rdFingerprintGenerator.GetAtomPairGenerator,
14 | "fingerprint-tt": rdFingerprintGenerator.GetTopologicalTorsionGenerator,
15 | }
16 |
17 |
18 | class FPModel(nn.Module):
19 | def __init__(self, model_name) -> None:
20 | super().__init__()
21 | self.model_name = model_name
22 | if model_name == 'graphsgpt':
23 | self.model = GraphsGPTForCausalLM.from_pretrained("DaizeDong/GraphsGPT-1W")
24 | self.tokenizer = GraphsGPTTokenizer.from_pretrained("DaizeDong/GraphsGPT-1W")
25 | self.model.cuda()
26 | self.model.eval()
27 |
28 | if model_name == 'morgan':
29 | self.radius = 2
30 | self.fpsize = 2048
31 | self.count = 0
32 |
33 | self.memory = {}
34 |
35 | def gen_fp(self, mol, fpgen, count=False):
36 | if count:
37 | g = fpgen.GetCountFingerprintAsNumPy
38 | else:
39 | g = fpgen.GetFingerprintAsNumPy
40 | return np.array(g(mol))
41 |
42 | def forward(self, smiles_list):
43 | if all([one in self.memory for one in smiles_list]):
44 | fingerprint_tokens = torch.stack([self.memory[one] for one in smiles_list], dim=0)
45 | return fingerprint_tokens
46 |
47 | if self.model_name == 'graphsgpt':
48 | with torch.no_grad():
49 | batch = self.tokenizer.batch_encode(smiles_list, return_tensors="pt")
50 | inputs = batch['batched_tokens']
51 | inputs = {k: v.cuda() for k, v in inputs.items()}
52 | fingerprint_tokens = self.model.encode_to_fingerprints(**inputs)
53 | fingerprint_tokens = fingerprint_tokens[:, 0].cpu()
54 |
55 | if self.model_name == 'morgan':
56 | fps = []
57 | fpgen = CLASS_OF_FPGEN["fingerprint-morgan"](radius=self.radius, fpSize=self.fpsize)
58 | for i in range(len(smiles_list)):
59 | mol = Chem.MolFromSmiles(smiles_list[i])
60 | fps.append(self.gen_fp(mol, fpgen, count=self.count))
61 | fingerprint_embeddings = np.array(fps).astype(float)
62 |
63 | fingerprint_tokens = torch.from_numpy(fingerprint_embeddings)
64 |
65 | for i, one in enumerate(smiles_list):
66 | self.memory[one] = fingerprint_tokens[i]
67 |
68 | return fingerprint_tokens
69 |
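A hedged sketch of the Morgan branch of `FPModel`, which runs on CPU and does not download the GraphsGPT checkpoint; the SMILES are placeholders:

```python
from utils.representation.fingerprint_interface import FPModel

model = FPModel("morgan")
smiles = ["CCO", "c1ccccc1", "CC(=O)O"]

fps = model(smiles)        # tensor of shape (3, 2048), one bit vector per molecule
print(fps.shape, fps.sum(dim=1))

fps_again = model(smiles)  # second call is served from the in-memory cache
```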
--------------------------------------------------------------------------------
/utils/representation/graphsgpt_finetune_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from data.tokenizer import GraphsGPTTokenizer
5 | from models.graphsgpt.configuration_graphsgpt import GraphsGPTConfig
6 | from models.graphsgpt.modeling_graphsgpt import GraphsGPTForCausalLM
7 |
8 |
9 | class GraphsGPT(nn.Module):
10 | def __init__(self, task_name, mixup_strategy, pooler_dropout):
11 | super().__init__()
12 | if mixup_strategy == 'no_mix_vanilla':
13 | config = GraphsGPTConfig.from_pretrained("DaizeDong/GraphsGPT-1W")
14 | self.model = GraphsGPTForCausalLM(config)
15 | else:
16 | self.model = GraphsGPTForCausalLM.from_pretrained("DaizeDong/GraphsGPT-1W")
17 | self.tokenizer = GraphsGPTTokenizer.from_pretrained("DaizeDong/GraphsGPT-1W")
18 |
19 | if task_name in ['bace', 'bbbp', 'clintox', 'hiv']:
20 | num_classes = 2
21 |
22 | if task_name in ['esol', 'freesolv', 'lipo', 'qm7dft']:
23 | num_classes = 1
24 |
25 | if task_name in ['qm8dft']:
26 | num_classes = 12
27 |
28 | if task_name in ['qm9dft']:
29 | num_classes = 3
30 |
31 | if task_name in ['sider']:
32 | num_classes = 27
33 |
34 | if task_name in ['tox21']:
35 | num_classes = 12
36 |
37 | if task_name in ['toxcast']:
38 | num_classes = 617
39 |
40 | if task_name in ['muv']:
41 | num_classes = 17
42 |
43 | self.classification_heads = ClassificationHead(
44 | input_dim=512,
45 | inner_dim=512,
46 | num_classes=num_classes,
47 | pooler_dropout=pooler_dropout,
48 | )
49 |
50 | def forward(self, mols_list, mix_embeds=None, mixup_lam=None, mixup_index=None):
51 | batch = self.tokenizer.batch_encode(mols_list, return_tensors="pt")
52 | inputs = {k: v.cuda() for k, v in batch.items()}
53 | fingerprint_tokens = self.model.encode_to_fingerprints(**inputs)
54 | logit_output = self.classification_heads(fingerprint_tokens)
55 | return {'logit_output': logit_output}
56 |
57 |
58 | class ClassificationHead(nn.Module):
59 | """Head for sentence-level classification tasks."""
60 |
61 | def __init__(
62 | self,
63 | input_dim,
64 | inner_dim,
65 | num_classes,
66 | pooler_dropout,
67 | ):
68 | super().__init__()
69 | self.dense = nn.Linear(input_dim, inner_dim)
70 | self.activation_fn = torch.tanh
71 | self.dropout = nn.Dropout(p=pooler_dropout)
72 | self.out_proj = nn.Linear(inner_dim, num_classes)
73 |
74 | def forward(self, features, **kwargs):
75 |         x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
76 | x = self.dropout(x)
77 | x = self.dense(x)
78 | x = self.activation_fn(x)
79 | x = self.dropout(x)
80 | x = self.out_proj(x)
81 | return x
82 |
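A sketch of `ClassificationHead` on its own: it pools the first fingerprint token and maps it to per-task logits. The shapes below are illustrative, and the import assumes the repository's dependencies are installed so the module loads cleanly:

```python
import torch

from utils.representation.graphsgpt_finetune_model import ClassificationHead

head = ClassificationHead(input_dim=512, inner_dim=512, num_classes=2, pooler_dropout=0.1)
fingerprint_tokens = torch.randn(8, 1, 512)  # (batch, num_fingerprint_tokens, hidden)
logits = head(fingerprint_tokens)
print(logits.shape)  # torch.Size([8, 2])
```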
--------------------------------------------------------------------------------
/utils/representation/lmdb_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) DP Technology.
2 | # This source code is licensed under the MIT license found in the
3 | # LICENSE file in the root directory of this source tree.
4 |
5 |
6 | import lmdb
7 | import logging
8 | import os
9 | import pickle
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 |
14 | class LMDBDataset:
15 | def __init__(self, db_path):
16 | self.db_path = db_path
17 | assert os.path.isfile(self.db_path), "{} not found".format(self.db_path)
18 | env = self.connect_db(self.db_path)
19 |
20 | with env.begin() as txn:
21 | self._keys = list(txn.cursor().iternext(values=False))
22 |
23 | def connect_db(self, lmdb_path, save_to_self=False):
24 | env = lmdb.open(
25 | lmdb_path,
26 | subdir=False,
27 | readonly=True,
28 | lock=False,
29 | readahead=False,
30 | meminit=False,
31 | max_readers=256,
32 | )
33 | if not save_to_self:
34 | return env
35 | else:
36 | self.env = env
37 |
38 | def __len__(self):
39 | return len(self._keys)
40 |
41 | def __getitem__(self, idx):
42 | env = self.connect_db(self.db_path)
43 | datapoint_pickled = env.begin().get(self._keys[idx])
44 |
45 | # avoid open too many files error.
46 | env.close()
47 | del env
48 |
49 | data = pickle.loads(datapoint_pickled)
50 | return data
51 |
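A hedged sketch of `LMDBDataset`; `data_finetune/bace/train.lmdb` is a hypothetical path standing in for any LMDB file written with `subdir=False`:

```python
from utils.representation.lmdb_dataset import LMDBDataset

dataset = LMDBDataset("data_finetune/bace/train.lmdb")  # hypothetical path
print(len(dataset))
sample = dataset[0]  # each access re-opens and closes the environment
print(type(sample))
```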
--------------------------------------------------------------------------------
/utils/representation/logger.py:
--------------------------------------------------------------------------------
1 | import os
2 | from omegaconf import OmegaConf
3 | from pytorch_lightning.callbacks import Callback
4 |
5 |
6 | class SetupCallback(Callback):
7 | def __init__(self, now, logdir, ckptdir, cfgdir, config, argv_content=None):
8 | super().__init__()
9 | self.now = now
10 | self.logdir = logdir
11 | self.ckptdir = ckptdir
12 | self.cfgdir = cfgdir
13 | self.config = config
14 |
15 | self.argv_content = argv_content
16 |
17 | def on_fit_start(self, trainer, pl_module):
18 | # Create logdirs and save configs
19 | os.makedirs(self.logdir, exist_ok=True)
20 | os.makedirs(self.ckptdir, exist_ok=True)
21 | os.makedirs(self.cfgdir, exist_ok=True)
22 |
23 | print("Project config")
24 | print(OmegaConf.to_yaml(self.config))
25 | OmegaConf.save(self.config,
26 | os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))
27 |
28 | with open(os.path.join(self.logdir, "argv_content.txt"), "w") as f:
29 | f.write(str(self.argv_content))
30 |
--------------------------------------------------------------------------------
/utils/representation/model_interface.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | import numpy as np
3 | import os
4 | import pytorch_lightning as pl
5 | import torch
6 | import torch.nn.functional as F
7 | from sklearn.metrics import roc_auc_score
8 | from torch import nn as nn
9 | from torch.optim import lr_scheduler as lrs
10 |
11 | from .graphsgpt_finetune_model import GraphsGPT
12 | from .load_dataset import task_metainfo
13 |
14 |
15 | def softmax_cross_entropy_with_softtarget(input, target, reduction='mean'):
16 | """
17 | :param input: (batch, *)
18 | :param target: (batch, *) same shape as input, each item must be a valid distribution: target[i, :].sum() == 1.
19 | """
20 | logprobs = torch.nn.functional.log_softmax(input.view(input.shape[0], -1), dim=1)
21 | batchloss = - torch.sum(target.view(target.shape[0], -1) * logprobs, dim=1)
22 | if reduction == 'none':
23 | return batchloss
24 | elif reduction == 'mean':
25 | return torch.mean(batchloss)
26 | elif reduction == 'sum':
27 | return torch.sum(batchloss)
28 | else:
29 | raise NotImplementedError('Unsupported reduction mode.')
30 |
31 |
32 | class MInterface_base(pl.LightningModule):
33 | def __init__(self, model_name=None, loss=None, lr=None, **kargs):
34 | super().__init__()
35 | self.save_hyperparameters()
36 | self.load_model()
37 | self.configure_loss()
38 | os.makedirs(os.path.join(self.hparams.res_dir, self.hparams.ex_name), exist_ok=True)
39 |
40 | def forward(self, input):
41 | pass
42 |
43 | def training_step(self, batch, batch_idx, **kwargs):
44 | pass
45 |
46 | def validation_step(self, batch, batch_idx):
47 | pass
48 |
49 | def test_step(self, batch, batch_idx):
50 | # Here we just reuse the validation_step for testing
51 | return self.validation_step(batch, batch_idx)
52 |
53 | def on_validation_epoch_end(self):
54 | # Make the Progress Bar leave there
55 | self.print('')
56 |
57 | def get_schedular(self, optimizer, lr_scheduler='onecycle'):
58 | if lr_scheduler == 'step':
59 | scheduler = lrs.StepLR(optimizer,
60 | step_size=self.hparams.lr_decay_steps,
61 | gamma=self.hparams.lr_decay_rate)
62 | elif lr_scheduler == 'cosine':
63 | scheduler = lrs.CosineAnnealingLR(optimizer,
64 | T_max=max(self.hparams.epoch / 5, 1))
65 | elif lr_scheduler == 'onecycle':
66 | scheduler = lrs.OneCycleLR(optimizer, max_lr=self.hparams.lr, steps_per_epoch=self.hparams.steps_per_epoch, epochs=self.hparams.epoch, three_phase=False)
67 | elif lr_scheduler == 'polynomial':
68 | scheduler = PolynomialDecayLRSchedule(optimizer, warmup_ratio=self.hparams.warmup_ratio, total_num_update=self.hparams.epoch * self.hparams.steps_per_epoch, lr=self.hparams.lr, end_learning_rate=0.0, power=1.0)
69 | else:
70 | raise ValueError('Invalid lr_scheduler type!')
71 |
72 | return scheduler
73 |
74 | def configure_optimizers(self):
75 | if hasattr(self.hparams, 'weight_decay'):
76 | weight_decay = self.hparams.weight_decay
77 | else:
78 | weight_decay = 0
79 |
80 | optimizer_g = torch.optim.AdamW(self.model.parameters(), lr=self.hparams.lr, weight_decay=weight_decay, betas=(0.9, 0.99), eps=1e-8)
81 |
82 | schecular_g = self.get_schedular(optimizer_g, self.hparams.lr_scheduler)
83 |
84 | return [optimizer_g], [{"scheduler": schecular_g, "interval": "step"}]
85 |
86 | def lr_scheduler_step(self, *args, **kwargs):
87 | scheduler = self.lr_schedulers()
88 | scheduler.step()
89 |
90 | def configure_devices(self):
91 | self.device = torch.device(self.hparams.device)
92 |
93 | def configure_loss(self):
94 | self.loss_function = nn.CrossEntropyLoss(reduction='none')
95 |
96 | def load_model(self):
97 | self.model = None
98 |
99 | def instancialize(self, Model, **other_args):
100 | """ Instancialize a model using the corresponding parameters
101 | from self.hparams dictionary. You can also input any args
102 | to overwrite the corresponding value in self.hparams.
103 | """
104 |         class_args = inspect.getfullargspec(Model.__init__).args[1:]
105 | inkeys = self.hparams.keys()
106 | args1 = {}
107 | for arg in class_args:
108 | if arg in inkeys:
109 | args1[arg] = getattr(self.hparams, arg)
110 | args1.update(other_args)
111 | return Model(**args1)
112 |
113 |
114 | class MInterface(MInterface_base):
115 | def __init__(self, model_name=None, loss=None, lr=None, **kargs):
116 | super().__init__()
117 | self.save_hyperparameters()
118 | self.load_model()
119 | self.configure_loss()
120 | os.makedirs(os.path.join(self.hparams.res_dir, self.hparams.ex_name), exist_ok=True)
121 | self.targets = []
122 | self.preds = []
123 |
124 | def forward(self, batch, mode='eval'):
125 | results = self.model(batch['mol_raw'])
126 | targets = batch["target_raw"]
127 | logit_output = results['logit_output']
128 |
129 | results_mix = self.compute_loss(logit_output, targets, loss_type=self.hparams.loss_type)
130 | loss = results_mix['loss']
131 |
132 | return {'loss': loss, 'logits': logit_output, 'targets': targets}
133 |
134 | def training_step(self, batch, batch_idx, **kwargs):
135 | results = self(batch, 'train')
136 | self.log_dict({"train_loss": results['loss']}, on_epoch=True, prog_bar=True)
137 | return results['loss']
138 |
139 | def validation_step(self, batch, batch_idx):
140 | results = self(batch, 'eval')
141 | self.targets.append(results['targets'].float().cpu().numpy())
142 | self.preds.append(results['logits'].float().cpu().numpy())
143 | metrics = self.compute_metrics(results)
144 | metrics.update({'loss': results['loss']})
145 | val_metrics = {'test_' + key: val for key, val in metrics.items()}
146 | self.log_dict(val_metrics)
147 | return metrics
148 |
149 | def test_step(self, batch, batch_idx):
150 | metrics = self.validation_step(batch, batch_idx)
151 | return metrics
152 |
153 | def training_epoch_end(self, outputs):
154 | pass
155 |
156 | def validation_epoch_end(self, outputs):
157 | metrics = self.compute_metrics(outputs)
158 | val_metrics = {'test_' + key: val for key, val in metrics.items()}
159 | self.log_dict(val_metrics)
160 |
161 | self.targets = []
162 | self.preds = []
163 | return self.log_dict
164 |
165 | def test_epoch_end(self, outputs):
166 | metrics = self.compute_metrics(outputs)
167 | val_metrics = {'test_' + key: val for key, val in metrics.items()}
168 | self.log_dict(val_metrics)
169 |
170 | self.targets = []
171 | self.preds = []
172 | return self.log_dict
173 |
174 | def load_model(self):
175 | # =========== graphsgpt model
176 | self.model = GraphsGPT(self.hparams.task_name, self.hparams.mixup_strategy, self.hparams.pooler_dropout)
177 |
178 | def compute_metrics(self, outputs):
179 | targets = np.concatenate(self.targets)
180 | preds = np.concatenate(self.preds)
181 |
182 | if self.hparams.loss_type == 'mixup_cross_entropy':
183 | targets = targets.argmax(axis=-1)
184 | if self.hparams.num_classes == 2:
185 | if np.unique(targets).shape[0] == 1:
186 | auc = 0.0
187 | else:
188 | auc = roc_auc_score(targets, preds[:, 1])
189 | return {'auc': auc}
190 |
191 | if self.hparams.loss_type == 'mixup_mse':
192 | mean = task_metainfo[self.hparams.task_name]["mean"]
193 | std = task_metainfo[self.hparams.task_name]["std"]
194 | predicts = preds * std + mean
195 | mae = np.abs(predicts - targets).mean()
196 | mse = ((predicts - targets) ** 2).mean()
197 | return {'mae': mae, 'mse': mse, 'rmse': np.sqrt(mse)}
198 |
199 | if self.hparams.loss_type == 'mixup_smooth_mae':
200 | mean = task_metainfo[self.hparams.task_name]["mean"]
201 | std = task_metainfo[self.hparams.task_name]["std"]
202 | predicts = preds * std + mean
203 | mae = np.abs(predicts - targets).mean()
204 | mse = ((predicts - targets) ** 2).mean()
205 | return {'mae': mae, 'mse': mse, 'rmse': np.sqrt(mse)}
206 |
207 | if self.hparams.loss_type == 'mixup_multi_task_BCE':
208 | def sigmoid(z):
209 | return 1 / (1 + np.exp(-z))
210 |
211 | probs = sigmoid(preds)
212 | agg_auc_list = []
213 | for i in range(targets.shape[1]):
214 | if np.sum(targets[:, i] == 1) > 0 and np.sum(targets[:, i] == 0) > 0:
215 | # ignore nan values
216 | is_labeled = targets[:, i] > -0.5
217 | agg_auc_list.append(
218 | roc_auc_score(targets[is_labeled, i], probs[is_labeled, i])
219 | )
220 |
221 | auc = sum(agg_auc_list) / (len(agg_auc_list) + 1e-8)
222 | return {'auc': auc}
223 |
224 | def cross_entropy_loss(self, net_output, targets, reduce=True):
225 | targets = targets.view(-1)
226 | lprobs = F.log_softmax(net_output, dim=-1)
227 | lprobs = lprobs.view(-1, lprobs.size(-1))
228 | loss = F.nll_loss(
229 | lprobs,
230 | targets,
231 | reduction="sum" if reduce else "none",
232 | )
233 |
234 | return {'loss': loss}
235 |
236 | def multi_task_BCE(self, logit_output, targets, reduce=True):
237 | is_labeled = targets > -0.5
238 | pos_mask = targets[is_labeled].float().view(-1)
239 | loss = F.binary_cross_entropy_with_logits(
240 | logit_output[is_labeled].float().view(-1),
241 | targets[is_labeled].float().view(-1),
242 | reduction="sum" if reduce else "none",
243 | pos_weight=((1 - pos_mask) * 9 + 1)
244 | )
245 | return {'loss': loss}
246 |
247 | def mse(self, logit_output, targets, reduce=True):
248 | predicts_normed = logit_output.view(-1, self.hparams.num_classes).float()
249 | targets = (
250 | targets.view(-1, self.hparams.num_classes).float()
251 | )
252 |
253 | mean = task_metainfo[self.hparams.task_name]["mean"]
254 | std = task_metainfo[self.hparams.task_name]["std"]
255 |
256 | targets_mean = torch.tensor(mean, device=targets.device)
257 | targets_std = torch.tensor(std, device=targets.device)
258 | targets_normed = (targets - targets_mean) / targets_std
259 | loss = F.mse_loss(
260 | predicts_normed,
261 | targets_normed,
262 | reduction="sum" if reduce else "none",
263 | )
264 | return {'loss': loss}
265 |
266 | def smooth_mse(self, logit_output, targets, reduce=True):
267 | predicts_normed = logit_output.view(-1, self.hparams.num_classes).float()
268 | targets = (
269 | targets.view(-1, self.hparams.num_classes).float()
270 | )
271 | mean = task_metainfo[self.hparams.task_name]["mean"]
272 | std = task_metainfo[self.hparams.task_name]["std"]
273 | targets_mean = torch.tensor(mean, device=targets.device)
274 | targets_std = torch.tensor(std, device=targets.device)
275 | targets_normed = (targets - targets_mean) / targets_std
276 | loss = F.smooth_l1_loss(
277 | predicts_normed,
278 | targets_normed,
279 | reduction="sum" if reduce else "none",
280 | )
281 | return {'loss': loss}
282 |
283 | def compute_loss(self, logit_output, targets, loss_type='mixup_cross_entropy'):
284 | if loss_type == 'mixup_cross_entropy':
285 | loss = softmax_cross_entropy_with_softtarget(logit_output, targets, reduction='mean')
286 | return {'loss': loss}
287 |
288 | if loss_type == 'mixup_mse':
289 | return self.mse(logit_output, targets, reduce=True)
290 |
291 | if loss_type == 'mixup_smooth_mae':
292 | return self.smooth_mse(logit_output, targets, reduce=True)
293 |
294 | if loss_type == 'mixup_multi_task_BCE':
295 | return self.multi_task_BCE(logit_output, targets, reduce=True)
296 |
297 |
298 | class PolynomialDecayLRSchedule(lrs.LRScheduler):
299 | def __init__(self, optimizer, warmup_ratio, total_num_update, lr, end_learning_rate, power, last_epoch=-1):
300 | self.warmup_ratio = warmup_ratio # 2532
301 | self.warmup_updates = int(self.warmup_ratio * total_num_update)
302 | self.total_num_update = total_num_update # 42200
303 | self.lr = lr
304 | self.warmup_factor = 1.0 / self.warmup_updates if self.warmup_updates > 0 else 1
305 | self.end_learning_rate = end_learning_rate
306 | self.power = power
307 | super(PolynomialDecayLRSchedule, self).__init__(optimizer, last_epoch)
308 |
309 | def get_lr(self):
310 | if self._step_count < self.warmup_updates:
311 | self.warmup_factor = self._step_count / float(self.warmup_updates)
312 | return [self.warmup_factor * self.lr]
313 | elif self.last_epoch >= self.total_num_update:
314 | return [self.end_learning_rate]
315 | else:
316 | lr_range = self.lr - self.end_learning_rate
317 | pct_remaining = 1 - (self._step_count - self.warmup_updates) / (self.total_num_update - self.warmup_updates)
318 | lr = lr_range * pct_remaining ** self.power + self.end_learning_rate
319 | return [lr]
320 |
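A self-contained sketch of `PolynomialDecayLRSchedule` with an illustrative optimizer and hyperparameters: linear warmup over the first 10% of updates, then (with `power=1.0`) linear decay to `end_learning_rate`. It assumes the repository's dependencies are installed so the module imports cleanly:

```python
import torch
from torch import nn

from utils.representation.model_interface import PolynomialDecayLRSchedule

model = nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = PolynomialDecayLRSchedule(
    optimizer, warmup_ratio=0.1, total_num_update=100,
    lr=1e-3, end_learning_rate=0.0, power=1.0,
)

seen = []
for _ in range(100):
    optimizer.step()
    scheduler.step()
    seen.append(scheduler.get_last_lr()[0])
print(max(seen), seen[-1])  # peaks near lr after warmup, decays to end_learning_rate
```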
--------------------------------------------------------------------------------