├── .gitignore
├── LICENSE
├── README.md
├── dataset
│   ├── dwar-iBond.tsv
│   ├── novartis_acid.tsv
│   ├── novartis_base.tsv
│   ├── sampl6.tsv
│   ├── sampl7.tsv
│   └── sampl8.tsv
├── enumerator
│   ├── example_out.tsv
│   ├── main.py
│   ├── simple_smarts_pattern.tsv
│   └── smarts_pattern.tsv
├── finetune_pka.sh
├── image
│   ├── inference.png
│   ├── overview.png
│   ├── performance.png
│   └── protensemble.png
├── infer_free_energy.sh
├── infer_pka.sh
├── pretrain_pka.sh
├── scripts
│   ├── infer_mean_ensemble.py
│   └── preprocess_pka.py
└── unimol
    ├── __init__.py
    ├── data
    │   ├── __init__.py
    │   ├── conformer_sample_dataset.py
    │   ├── coord_pad_dataset.py
    │   ├── cropping_dataset.py
    │   ├── data_utils.py
    │   ├── distance_dataset.py
    │   ├── key_dataset.py
    │   ├── lmdb_dataset.py
    │   ├── mask_points_dataset.py
    │   ├── normalize_dataset.py
    │   ├── pka_input_dataset.py
    │   ├── remove_hydrogen_dataset.py
    │   └── tta_dataset.py
    ├── examples
    │   ├── dict.txt
    │   └── dict_charge.txt
    ├── infer.py
    ├── losses
    │   ├── __init__.py
    │   ├── mlm_loss.py
    │   └── reg_loss.py
    ├── models
    │   ├── __init__.py
    │   ├── transformer_encoder_with_pair.py
    │   ├── unimol.py
    │   └── unimol_pka.py
    └── tasks
        ├── __init__.py
        ├── unimol_free_energy.py
        ├── unimol_mlm.py
        └── unimol_pka.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Uni-p*K*a
2 | The official implementation of the model Uni-p*K*a in the paper Bridging Machine Learning and Thermodynamics for Accurate p*K*a Prediction.
3 |
4 | Interactive demo with available model weights at https://bohrium.dp.tech/notebooks/38543442597
5 |
6 | Published paper at [[JACS Au](https://pubs.acs.org/doi/10.1021/jacsau.4c00271)] | Relevant preprint at [[ChemRxiv](https://chemrxiv.org/engage/chemrxiv/article-details/64e8da3879853bbd786ca4eb)] | Small molecule protonation state ranking demo at [[Bohrium App](https://bohrium.dp.tech/apps/uni-pka)] | Full datasets at [[AISSquare](https://www.aissquare.com/datasets/detail?pageType=datasets&name=Uni-pKa-Dataset)]
7 |
8 | This machine-learning-based p*K*a prediction model achieves state-of-the-art accuracy on several drug-like small-molecule macro-p*K*a datasets.
9 | ![Performance](image/performance.png)
10 |
11 | The two core components of the Uni-p*K*a framework are
12 |
13 | - A microstate enumerator to systematically build the protonation
14 | ensemble from a single structure.
15 |
16 | - A molecular machine learning model to predict the free energy of each single structure.
17 |
18 | The model reaches the reported accuracy at the inference stage after comprehensive data preparation by the enumerator, pretraining on the ChEMBL dataset, and finetuning on our dwar-iBond dataset.
19 |
20 | 
21 |
22 | ## Microstate Enumerator
23 |
24 | ### Introduction
25 |
26 | It uses an iterative template-matching algorithm to enumerate all microstates in the adjacent macrostates of a molecule's protonation ensemble, starting from at least one microstate stored as a SMILES string.
27 |
28 | The protonation template `smarts_pattern.tsv` modifies and augments the one in the paper [MolGpka: A Web Server for Small Molecule pKa Prediction Using a Graph-Convolutional Neural Network](https://pubs.acs.org/doi/10.1021/acs.jcim.1c00075) and its open-source implementation (MIT license) in the GitHub repository [MolGpKa](https://github.com/Xundrug/MolGpKa/blob/master/src/utils/smarts_pattern.tsv).
29 |
30 | ### Usage
31 |
32 | The recommended environment is
33 | ```yaml
34 | python = 3.8.13
35 | rdkit = 2021.09.5
36 | numpy = 1.20.3
37 | pandas = 1.5.2
38 | ```
39 |
40 | #### Reconstruct a plain p*K*a dataset to the Uni-p*K*a standard macro-p*K*a format with fully enumerated microstates
41 |
42 | ```shell
43 | cd enumerator
44 | python main.py reconstruct -i <input> -o <output> -m <mode>
45 | ```
46 |
47 | The `<input>` dataset is assumed to be a CSV-like file with a column storing SMILES. There are two kinds of entries allowed in the dataset.
48 |
49 | 1. It contains only one SMILES. The Enumerator helps to build the protonated/deprotonated macrostate and complete the original macrostate.
50 |     - When `<mode>` is "A", it is treated as an acid (thrown into the A pool).
51 |     - When `<mode>` is "B", it is treated as a base (thrown into the B pool).
52 | 2. It contains a string like "A1,...,Am>>B1,...,Bn", where A1,...,Am are comma-separated SMILES of microstates in the acid macrostate (all thrown into the A pool), and B1,...,Bn are comma-separated SMILES of microstates in the base macrostate (all thrown into the B pool). The Enumerator helps to complete both.
53 |
54 | 
55 |
56 | The `<mode>` flag, "A" (default) or "B", determines which pool (A/B) holds the reference structures and serves as the starting point of the enumeration.
57 |
58 | The `<output>` dataset is then constructed after the enumeration (see the example entries below).
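
For example, a case-1 entry is a single SMILES such as the hypothetical `OC(=O)c1ccccc1` (benzoic acid, used together with `-m A`), while a case-2 entry taken from `dataset/sampl6.tsv` reads `c1cc2c(cc1O)c3c(o2)C(=O)NCCC3>>O=C1NCCCc2c1oc1ccc([O-])cc21,O=C1[N-]CCCc2c1oc1ccc(O)cc21`.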
59 |
60 | #### Build protonation ensembles from single molecules
61 |
62 | Example:
63 | ```shell
64 | cd enumerator
65 | python main.py ensemble -i ../dataset/sampl6.tsv -o example_out.tsv -u 2 -l -2 -t simple_smarts_pattern.tsv
66 | ```
67 |
68 | The input dataset here is the SAMPL6 dataset, as an example. A reconstructed p*K*a dataset, or any molecular dataset with a "SMILES" column of single-molecule SMILES, is supported as input. In the output file, such as `example_out.tsv`, the columns contain the original SMILES and the macrostates with total charge between the upper bound set by `-u` (default +2) and the lower bound set by `-l` (default -2). A simpler template, `simple_smarts_pattern.tsv`, is provided for cleaner protonation ensembles; it discards some structural motifs that are rare in aqueous solution.
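
Besides the command line, the enumerator can also be called from Python. A minimal sketch (run from inside the `enumerator` directory; the SMILES here is an illustrative placeholder, not one of the dataset entries):

```python
# Enumerate the protonation microstates of a single acid with the shipped template.
from main import read_template, enumerate_template

template_a2b, template_b2a = read_template("smarts_pattern.tsv")
# mode="A" means the input SMILES is treated as the acid-side reference microstate.
acids, bases = enumerate_template("OC(=O)c1ccccc1", template_a2b, template_b2a, mode="A")
print(acids)  # acid-side microstates (includes the input)
print(bases)  # enumerated base-side (deprotonated) microstates
```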
69 |
70 | ## Machine Learning Model
71 |
72 | ### Introduction
73 |
74 | It is a [Uni-Mol](https://github.com/dptech-corp/Uni-Mol)-based neural network. By embedding the network in the thermodynamic relationship between free energy and p*K*a throughout the training and inference stages, the framework preserves physical consistency and adapts to multiple tasks.
75 |
76 | ![Protonation ensemble](image/protensemble.png)
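
As a rough illustration of that relationship (a minimal sketch, not code from this repository; the unit convention and names are assumptions), a macro-p*K*a can be assembled from predicted per-microstate free energies by Boltzmann-weighting each macrostate:

```python
import numpy as np
from scipy.special import logsumexp

def macro_pka(g_acid_microstates, g_base_microstates):
    """Combine microstate free energies into one macroscopic pKa.

    Assumes free energies are in units of RT*ln(10) (i.e. already on the pKa
    scale) and that the constant proton term is absorbed by the model.
    """
    ln10 = np.log(10)
    g_acid = np.asarray(g_acid_microstates, dtype=float)
    g_base = np.asarray(g_base_microstates, dtype=float)
    # Macrostate free energy: -1/ln10 * log(sum over microstates of 10^(-G_i)).
    g_macro_acid = -logsumexp(-g_acid * ln10) / ln10
    g_macro_base = -logsumexp(-g_base * ln10) / ln10
    # The macro-pKa of A <=> B + H+ is the macrostate free-energy difference.
    return g_macro_base - g_macro_acid

# Hypothetical numbers: two acid microstates, one base microstate.
print(macro_pka([7.2, 7.8], [12.5]))
```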
77 |
78 | ### Usage
79 |
80 | #### Dependencies
81 |
82 | The dependencies of Uni-p*K*a are the same as those of Uni-Mol.
83 |
84 | - [Uni-Core](https://github.com/dptech-corp/Uni-Core), check its [Installation Documentation](https://github.com/dptech-corp/Uni-Core#installation).
85 | - rdkit==2022.9.3, install via `pip install rdkit-pypi==2022.9.3`
86 |
87 | The recommended environment is the Docker image:
88 |
89 | ```
90 | docker pull dptechnology/unimol:latest-pytorch1.11.0-cuda11.3
91 | ```
92 |
93 | See details in [Uni-Mol](https://github.com/dptech-corp/Uni-Mol/tree/main/unimol#dependencies) repository.
94 |
95 |
96 | ### Ready-to-run training workflow
97 |
98 | #### Data
99 |
100 | The raw data can be downloaded from [[AISSquare](https://www.aissquare.com/datasets/detail?pageType=datasets&name=Uni-pKa-Dataset)].
101 |
102 |
103 | #### Pretrain with ChEMBL
104 |
105 | First, preprocess the ChEMBL training and validation sets, then pretrain the model:
106 |
107 | ```bash
108 | # Preprocess training set
109 | python ./scripts/preprocess_pka.py --raw-csv-file Datasets/tsv/chembl_train.tsv --processed-lmdb-dir chembl --task-name train
110 |
111 | # Preprocess validation set
112 | python ./scripts/preprocess_pka.py --raw-csv-file Datasets/tsv/chembl_valid.tsv --processed-lmdb-dir chembl --task-name valid
113 |
114 | # Copy the necessary dict file
115 | cp -r unimol/examples/* chembl
116 |
117 | # Pretrain the model
118 | bash pretrain_pka.sh
119 | ```
120 |
121 | Note: The `head_name` in the subsequent scripts must match the `task_name` in `pretrain_pka.sh`.
122 |
123 |
124 | #### Finetune with dwar-iBond
125 |
126 | Next, preprocess the dwar-iBond dataset and finetune the model:
127 |
128 | ```bash
129 | # Preprocess
130 | python ./scripts/preprocess_pka.py --raw-csv-file Datasets/tsv/dwar-iBond.tsv --processed-lmdb-dir dwar --task-name dwar-iBond
131 |
132 | # Copy the necessary dict file
133 | cp -r unimol/examples/* dwar
134 |
135 | # Finetune the model
136 | bash finetune_pka.sh
137 | ```
138 |
139 | #### Infer p*K*a
140 |
141 | Infer with the finetuned model, taking novartis_acid as an example:
142 |
143 | ```bash
144 | # Preprocess
145 | python ./scripts/preprocess_pka.py --raw-csv-file Datasets/tsv/novartis_acid.tsv --processed-lmdb-dir novartis_acid --task-name novartis_acid
146 |
147 | # Copy the necessary examples from unimol
148 | cp -r unimol/examples/* novartis_acid
149 |
150 | # Run inference
151 | bash infer_pka.sh
152 | ```
153 | To test with other external test datasets, it may be necessary to modify `data_path`, `infer_task`, and `results_path` in `infer_pka.sh`.
154 |
155 | #### Obtain the result files and calculate the metrics
156 | After inference, extract the results to CSV files and calculate the performance metrics (e.g., MAE, RMSE) on the results:
157 |
158 | ```bash
159 | python ./scripts/infer_mean_ensemble.py --task pka --nfolds 5 --results-path novartis_acid_results
160 | ```
161 |
162 | The metrics are calculated using the average of the 5-fold model predictions.
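
As a rough sketch of what this step computes (not the repository script; the single-CSV layout and the `TARGET`/`pred_fold*` column names are hypothetical), averaging the folds and scoring could look like:

```python
import numpy as np
import pandas as pd

def pka_metrics(csv_file: str, nfolds: int = 5) -> dict:
    """Average per-fold predictions and report MAE / RMSE against the targets."""
    df = pd.read_csv(csv_file)
    pred = df[[f"pred_fold{i}" for i in range(nfolds)]].mean(axis=1)
    err = pred - df["TARGET"]
    return {"MAE": float(err.abs().mean()), "RMSE": float(np.sqrt((err ** 2).mean()))}

# Hypothetical usage:
# print(pka_metrics("novartis_acid_results.csv"))
```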
--------------------------------------------------------------------------------
/dataset/sampl6.tsv:
--------------------------------------------------------------------------------
1 | SMILES TARGET ref.
2 | 0 c1cc2c(cc1O)c3c(o2)C(=O)NCCC3>>O=C1NCCCc2c1oc1ccc([O-])cc21,O=C1[N-]CCCc2c1oc1ccc(O)cc21 9.53 SAMPL6 SM01 pKa 1
3 | 1 FC(F)(F)c1cccc([NH2+]c2ncnc3ccccc23)c1,FC(F)(F)c1cccc(Nc2nc[nH+]c3ccccc23)c1,FC(F)(F)c1cccc(Nc2[nH+]cnc3ccccc23)c1>>c1ccc2c(c1)c(ncn2)Nc3cccc(c3)C(F)(F)F 5.03 SAMPL6 SM02 pKa 1
4 | 2 c1ccc(cc1)Cc2nnc(s2)NC(=O)c3cccs3>>O=C(Nc1nnc([CH-]c2ccccc2)s1)c1cccs1,O=C([N-]c1nnc(Cc2ccccc2)s1)c1cccs1 7.02 SAMPL6 SM03 pKa 1
5 | 3 Clc1ccc(C[NH2+]c2ncnc3ccccc23)cc1,Clc1ccc(CNc2nc[nH+]c3ccccc23)cc1,Clc1ccc(CNc2[nH+]cnc3ccccc23)cc1>>c1ccc2c(c1)c(ncn2)NCc3ccc(cc3)Cl 6.02 SAMPL6 SM04 pKa 1
6 | 4 [OH+]=C(Nc1ccccc1N1CCCCC1)c1ccc(Cl)o1,O=C(Nc1ccccc1[NH+]1CCCCC1)c1ccc(Cl)o1>>c1ccc(c(c1)NC(=O)c2ccc(o2)Cl)N3CCCCC3 4.59 SAMPL6 SM05 pKa 1
7 | 5 O=C(Nc1cccc2ccc[nH+]c12)c1cncc(Br)c1,O=C(Nc1cccc2cccnc12)c1c[nH+]cc(Br)c1,[OH+]=C(Nc1cccc2cccnc12)c1cncc(Br)c1>>c1cc2cccnc2c(c1)NC(=O)c3cc(cnc3)Br 3.03 SAMPL6 SM06 pKa 1
8 | 6 c1cc2cccnc2c(c1)NC(=O)c3cc(cnc3)Br>>O=C([N-]c1cccc2cccnc12)c1cncc(Br)c1 11.74 SAMPL6 SM06 pKa 2
9 | 7 c1ccc(CNc2nc[nH+]c3ccccc23)cc1,c1ccc(C[NH2+]c2ncnc3ccccc23)cc1,c1ccc(CNc2[nH+]cnc3ccccc23)cc1>>c1ccc(cc1)CNc2c3ccccc3ncn2 6.08 SAMPL6 SM07 pKa 1
10 | 8 Cc1ccc2c(c1)c(c(c(=O)[nH]2)CC(=O)O)c3ccccc3>>Cc1ccc2[nH]c(=O)c([CH-]C(=O)O)c(-c3ccccc3)c2c1,Cc1ccc2[nH]c(=O)c(CC(=O)[O-])c(-c3ccccc3)c2c1,Cc1ccc2[n-]c(=O)c(CC(=O)O)c(-c3ccccc3)c2c1 4.22 SAMPL6 SM08 pKa 1
11 | 9 COc1cccc(Nc2nc[nH+]c3ccccc23)c1,COc1cccc([NH2+]c2ncnc3ccccc23)c1,COc1cccc(Nc2[nH+]cnc3ccccc23)c1>>COc1cccc(c1)Nc2c3ccccc3ncn2 5.37 SAMPL6 SM09 pKa 1
12 | 10 c1ccc(cc1)C(=O)NCC(=O)Nc2nc3ccccc3s2>>O=C(CNC(=O)c1ccccc1)[N-]c1nc2ccccc2s1,O=C(C[N-]C(=O)c1ccccc1)Nc1nc2ccccc2s1 9.02 SAMPL6 SM10 pKa 1
13 | 11 [NH3+]c1ncnc2c1cnn2-c1ccccc1,Nc1ncnc2c1c[nH+]n2-c1ccccc1,Nc1[nH+]cnc2c1cnn2-c1ccccc1,Nc1nc[nH+]c2c1cnn2-c1ccccc1>>c1ccc(cc1)n2c3c(cn2)c(ncn3)N 3.89 SAMPL6 SM11 pKa 1
14 | 12 Clc1cccc(Nc2nc[nH+]c3ccccc23)c1,Clc1cccc([NH2+]c2ncnc3ccccc23)c1,Clc1cccc(Nc2[nH+]cnc3ccccc23)c1>>c1ccc2c(c1)c(ncn2)Nc3cccc(c3)Cl 5.28 SAMPL6 SM12 pKa 1
15 | 13 COc1cc2nc[nH+]c(Nc3cccc(C)c3)c2cc1OC,COc1cc2ncnc([NH2+]c3cccc(C)c3)c2cc1OC,COc1cc2[nH+]cnc(Nc3cccc(C)c3)c2cc1OC>>Cc1cccc(c1)Nc2c3cc(c(cc3ncn2)OC)OC 5.77 SAMPL6 SM13 pKa 1
16 | 14 [NH3+]c1ccc2c(c1)[nH+]cn2-c1ccccc1>>[NH3+]c1ccc2c(c1)ncn2-c1ccccc1,Nc1ccc2c(c1)[nH+]cn2-c1ccccc1 2.58 SAMPL6 SM14 pKa 1
17 | 15 [NH3+]c1ccc2c(c1)ncn2-c1ccccc1,Nc1ccc2c(c1)[nH+]cn2-c1ccccc1>>c1ccc(cc1)n2cnc3c2ccc(c3)N 5.3 SAMPL6 SM14 pKa 2
18 | 16 Oc1ccc(-n2c[nH+]c3ccccc32)cc1>>c1ccc2c(c1)ncn2c3ccc(cc3)O,[O-]c1ccc(-n2c[nH+]c3ccccc32)cc1 4.7 SAMPL6 SM15 pKa 1
19 | 17 c1ccc2c(c1)ncn2c3ccc(cc3)O,[O-]c1ccc(-n2c[nH+]c3ccccc32)cc1>>[O-]c1ccc(-n2cnc3ccccc32)cc1 8.94 SAMPL6 SM15 pKa 2
20 | 18 O=C(Nc1cc[nH+]cc1)c1c(Cl)cccc1Cl,[OH+]=C(Nc1ccncc1)c1c(Cl)cccc1Cl>>c1cc(c(c(c1)Cl)C(=O)Nc2ccncc2)Cl 5.37 SAMPL6 SM16 pKa 1
21 | 19 c1cc(c(c(c1)Cl)C(=O)Nc2ccncc2)Cl>>O=C([N-]c1ccncc1)c1c(Cl)cccc1Cl 10.65 SAMPL6 SM16 pKa 2
22 | 20 c1ccc(CSc2nnc(-c3cc[nH+]cc3)o2)cc1,c1ccc(CSc2[nH+]nc(-c3ccncc3)o2)cc1,c1ccc(CSc2n[nH+]c(-c3ccncc3)o2)cc1>>c1ccc(cc1)CSc2nnc(o2)c3ccncc3 3.16 SAMPL6 SM17 pKa 1
23 | 21 O=c1[nH]c(CCC(=[OH+])Nc2ncc(Cc3ccc(F)c(F)c3)s2)nc2ccccc12,O=C(CCc1nc2ccccc2c(=[OH+])[nH]1)Nc1ncc(Cc2ccc(F)c(F)c2)s1,O=C(CCc1[nH]c(=O)c2ccccc2[nH+]1)Nc1ncc(Cc2ccc(F)c(F)c2)s1,O=C(CCc1nc2ccccc2c(=O)[nH]1)Nc1[nH+]cc(Cc2ccc(F)c(F)c2)s1>>c1ccc2c(c1)c(=O)[nH]c(n2)CCC(=O)Nc3ncc(s3)Cc4ccc(c(c4)F)F 2.15 SAMPL6 SM18 pKa 1
24 | 22 c1ccc2c(c1)c(=O)[nH]c(n2)CCC(=O)Nc3ncc(s3)Cc4ccc(c(c4)F)F>>O=C(CCc1nc2ccccc2c(=O)[n-]1)Nc1ncc(Cc2ccc(F)c(F)c2)s1,O=C(CCc1nc2ccccc2c(=O)[nH]1)[N-]c1ncc(Cc2ccc(F)c(F)c2)s1 9.58 SAMPL6 SM18 pKa 2
25 | 23 O=C(CCc1nc2ccccc2c(=O)[n-]1)Nc1ncc(Cc2ccc(F)c(F)c2)s1,O=C(CCc1nc2ccccc2c(=O)[nH]1)[N-]c1ncc(Cc2ccc(F)c(F)c2)s1>>O=C(CCc1nc2ccccc2c(=O)[n-]1)[N-]c1ncc(Cc2ccc(F)c(F)c2)s1 11.02 SAMPL6 SM18 pKa 3
26 | 24 CCOc1ccc2c(c1)sc(n2)NC(=O)Cc3ccc(c(c3)Cl)Cl>>CCOc1ccc2nc(NC(=O)[CH-]c3ccc(Cl)c(Cl)c3)sc2c1,CCOc1ccc2nc([N-]C(=O)Cc3ccc(Cl)c(Cl)c3)sc2c1 9.56 SAMPL6 SM19 pKa 1
27 | 25 c1cc(cc(c1)OCc2ccc(cc2Cl)Cl)/C=C/3\C(=O)NC(=O)S3>>O=C1[N-]C(=O)/C(=C\c2cccc(OCc3ccc(Cl)cc3Cl)c2)S1 5.7 SAMPL6 SM20 pKa 1
28 | 26 Fc1cnc(Nc2cccc(Br)c2)nc1[NH2+]c1cccc(Br)c1,Fc1cnc([NH2+]c2cccc(Br)c2)nc1Nc1cccc(Br)c1,Fc1cnc(Nc2cccc(Br)c2)[nH+]c1Nc1cccc(Br)c1,Fc1c[nH+]c(Nc2cccc(Br)c2)nc1Nc1cccc(Br)c1>>c1cc(cc(c1)Br)Nc2c(cnc(n2)Nc3cccc(c3)Br)F 4.1 SAMPL6 SM21 pKa 1
29 | 27 Oc1c(I)cc(I)c2ccc[nH+]c12>>c1cc2c(cc(c(c2nc1)O)I)I,[O-]c1c(I)cc(I)c2ccc[nH+]c12 2.4 SAMPL6 SM22 pKa 1
30 | 28 c1cc2c(cc(c(c2nc1)O)I)I,[O-]c1c(I)cc(I)c2ccc[nH+]c12>>[O-]c1c(I)cc(I)c2cccnc12 7.43 SAMPL6 SM22 pKa 2
31 | 29 CCOC(=O)c1ccc(Nc2cc(C)nc(Nc3ccc(C(=O)OCC)cc3)[nH+]2)cc1,CCOC(=O)c1ccc(Nc2cc(C)nc([NH2+]c3ccc(C(=O)OCC)cc3)n2)cc1,CCOC(=O)c1ccc(Nc2nc(C)cc(Nc3ccc(C(=[OH+])OCC)cc3)n2)cc1,CCOC(=O)c1ccc(Nc2nc(C)cc([NH2+]c3ccc(C(=O)OCC)cc3)n2)cc1,CCOC(=O)c1ccc(Nc2cc(C)[nH+]c(Nc3ccc(C(=O)OCC)cc3)n2)cc1,CCOC(=O)c1ccc(Nc2cc(C)nc(Nc3ccc(C(=[OH+])OCC)cc3)n2)cc1>>CCOC(=O)c1ccc(cc1)Nc2cc(nc(n2)Nc3ccc(cc3)C(=O)OCC)C 5.45 SAMPL6 SM23 pKa 1
32 | 30 COc1ccc(-c2oc3[nH+]cnc(NCCO)c3c2-c2ccc(OC)cc2)cc1,COc1ccc(-c2oc3ncnc(NCC[OH2+])c3c2-c2ccc(OC)cc2)cc1,COc1ccc(-c2oc3ncnc([NH2+]CCO)c3c2-c2ccc(OC)cc2)cc1,COc1ccc(-c2oc3nc[nH+]c(NCCO)c3c2-c2ccc(OC)cc2)cc1>>COc1ccc(cc1)c2c3c(ncnc3oc2c4ccc(cc4)OC)NCCO 2.6 SAMPL6 SM24 pKa 1
33 |
--------------------------------------------------------------------------------
/dataset/sampl7.tsv:
--------------------------------------------------------------------------------
1 | SMILES TARGET ref.
2 | 0 O=C(NS(C1=CC=CC=C1)(=O)=O)CCC2=CC=CC=C2>>O=C(CCc1ccccc1)[N-]S(=O)(=O)c1ccccc1 4.49 SAMPL7 SM25 pKa
3 | 1 O=S(CCC1=CC=CC=C1)(NC(C)=O)=O>>CC(=O)[N-]S(=O)(=O)CCc1ccccc1 4.91 SAMPL7 SM26 pKa
4 | 2 O=S(CCC1=CC=CC=C1)(NC2(C)COC2)=O>>CC1([N-]S(=O)(=O)CCc2ccccc2)COC1 10.45 SAMPL7 SM27 pKa
5 | 3 CS(NC1(COC1)CCC2=CC=CC=C2)(=O)=O>>CS(=O)(=O)[N-]C1(CCc2ccccc2)COC1 10.05 SAMPL7 SM29 pKa
6 | 4 O=S(NC1(COC1)CCC2=CC=CC=C2)(C3=CC=CC=C3)=O>>O=S(=O)([N-]C1(CCc2ccccc2)COC1)c1ccccc1 10.29 SAMPL7 SM30 pKa
7 | 5 O=S(NC1(COC1)CCC2=CC=CC=C2)(N(C)C)=O>>CN(C)S(=O)(=O)[N-]C1(CCc2ccccc2)COC1 11.02 SAMPL7 SM31 pKa
8 | 6 CS(NC1(CSC1)CCC2=CC=CC=C2)(=O)=O>>CS(=O)(=O)[N-]C1(CCc2ccccc2)CSC1 10.45 SAMPL7 SM32 pKa
9 | 7 O=S(NC1(CSC1)CCC2=CC=CC=C2)(N(C)C)=O>>CN(C)S(=O)(=O)[N-]C1(CCc2ccccc2)CSC1 11.93 SAMPL7 SM34 pKa
10 | 8 CS(N[C@@]1(C[S+]([O-])C1)CCC2=CC=CC=C2)(=O)=O,CS(=O)(=O)[N-]C1(CCc2ccccc2)C[S+](O)C1>>CS(=O)(=O)[N-]C1(CCc2ccccc2)C[S+]([O-])C1 9.87 SAMPL7 SM35 pKa
11 | 9 O=S(N[C@@]1(C[S+]([O-])C1)CCC2=CC=CC=C2)(C3=CC=CC=C3)=O,O=S(=O)([N-]C1(CCc2ccccc2)C[S+](O)C1)c1ccccc1>>O=S(=O)([N-]C1(CCc2ccccc2)C[S+]([O-])C1)c1ccccc1 9.8 SAMPL7 SM36 pKa
12 | 10 O=S(N[C@@]1(C[S+]([O-])C1)CCC2=CC=CC=C2)(N(C)C)=O,CN(C)S(=O)(=O)[N-]C1(CCc2ccccc2)C[S+](O)C1>>CN(C)S(=O)(=O)[N-]C1(CCc2ccccc2)C[S+]([O-])C1 10.33 SAMPL7 SM37 pKa
13 | 11 CS(NC1(CS(C1)(=O)=O)CCC2=CC=CC=C2)(=O)=O>>CS(=O)(=O)[N-]C1(CCc2ccccc2)CS(=O)(=O)C1 9.44 SAMPL7 SM38 pKa
14 | 12 O=S(NC1(CS(C1)(=O)=O)CCC2=CC=CC=C2)(C3=CC=CC=C3)=O>>O=S1(=O)CC(CCc2ccccc2)([N-]S(=O)(=O)c2ccccc2)C1 10.22 SAMPL7 SM39 pKa
15 | 13 O=S(NC1(CS(C1)(=O)=O)CCC2=CC=CC=C2)(N(C)C)=O>>CN(C)S(=O)(=O)[N-]C1(CCc2ccccc2)CS(=O)(=O)C1 9.58 SAMPL7 SM40 pKa
16 | 14 O=S(NC1=NOC(C2=CC=CC=C2)=C1)(C)=O>>CS(=O)(=O)[N-]c1cc(-c2ccccc2)on1 5.22 SAMPL7 SM41 pKa
17 | 15 O=S(NC1=NOC(C2=CC=CC=C2)=C1)(C3=CC=CC=C3)=O>>O=S(=O)([N-]c1cc(-c2ccccc2)on1)c1ccccc1 6.62 SAMPL7 SM42 pKa
18 | 16 O=S(NC1=NOC(C2=CC=CC=C2)=C1)(N(C)C)=O>>CN(C)S(=O)(=O)[N-]c1cc(-c2ccccc2)on1 5.62 SAMPL7 SM43 pKa
19 | 17 O=S(NC(N=N1)=CN1C2=CC=CC=C2)(C)=O>>CS(=O)(=O)[N-]c1cn(-c2ccccc2)nn1 6.34 SAMPL7 SM44 pKa
20 | 18 O=S(NC(N=N1)=CN1C2=CC=CC=C2)(C3=CC=CC=C3)=O>>O=S(=O)([N-]c1cn(-c2ccccc2)nn1)c1ccccc1 5.93 SAMPL7 SM45 pKa
21 | 19 O=S(NC(N=N1)=CN1C2=CC=CC=C2)(N(C)C)=O>>CN(C)S(=O)(=O)[N-]c1cn(-c2ccccc2)nn1 6.42 SAMPL7 SM46 pKa
22 |
--------------------------------------------------------------------------------
/dataset/sampl8.tsv:
--------------------------------------------------------------------------------
1 | SMILES TARGET ref.
2 | 0 O=C(O)c1ccccc1[NH2+]c1cccc(C(F)(F)F)c1,OC(=[OH+])c1ccccc1Nc1cccc(C(F)(F)F)c1>>OC(=O)c1ccccc1Nc1cccc(c1)C(F)(F)F,O=C([O-])c1ccccc1[NH2+]c1cccc(C(F)(F)F)c1 2.54 SAMPL8 SAMPL8-1 pKa
3 | 1 OC(=O)c1ccccc1Nc1cccc(c1)C(F)(F)F,O=C([O-])c1ccccc1[NH2+]c1cccc(C(F)(F)F)c1>>O=C([O-])c1ccccc1Nc1cccc(C(F)(F)F)c1,O=C(O)c1ccccc1[N-]c1cccc(C(F)(F)F)c1 5.01 SAMPL8 SAMPL8-1 pKa
4 | 2 CS(=O)(=O)c1ccc(CCC(O)=O)cc1>>CS(=O)(=O)c1ccc(CCC(=O)[O-])cc1 4.41 SAMPL8 SAMPL8-2 pKa
5 | 3 NS(=O)(=O)c1cc(C(O)=O)c(NCc2ccco2)cc1Cl,NS(=O)(=O)c1cc(C(=O)[O-])c([NH2+]Cc2ccco2)cc1Cl,[NH3+]S(=O)(=O)c1cc(C(=O)[O-])c(NCc2ccco2)cc1Cl>>[NH-]S(=O)(=O)c1cc(C(=O)O)c(NCc2ccco2)cc1Cl,NS(=O)(=O)c1cc(C(=O)O)c([N-]Cc2ccco2)cc1Cl,NS(=O)(=O)c1cc(C(=O)[O-])c(NCc2ccco2)cc1Cl 4.0 SAMPL8 SAMPL8-3 pKa
6 | 4 Cc1sc(Nc2ccc(C#[NH+])c(Cl)c2)nc1C(=O)[O-],Cc1sc([NH2+]c2ccc(C#N)c(Cl)c2)nc1C(=O)[O-],Cc1sc(Nc2ccc(C#N)c(Cl)c2)[nH+]c1C(=O)[O-],Cc1sc(Nc2ccc(C#N)c(Cl)c2)nc1C(O)=O>>Cc1sc([N-]c2ccc(C#N)c(Cl)c2)nc1C(=O)O,Cc1sc(Nc2ccc(C#N)c(Cl)c2)nc1C(=O)[O-] 5.77 SAMPL8 SAMPL8-4 pKa
7 | 5 O=C(O)Cc1ccccc1Nc1c(Cl)cccc1Cl,O=C([O-])Cc1ccccc1[NH2+]c1c(Cl)cccc1Cl>>[O-]C(=O)Cc1ccccc1Nc1c(Cl)cccc1Cl,O=C(O)[CH-]c1ccccc1Nc1c(Cl)cccc1Cl 3.92 SAMPL8 SAMPL8-5 pKa
8 | 6 Cc1ccc(Cl)cc1NCc1ccc(s1)C(O)=O,Cc1ccc(Cl)cc1[NH2+]Cc1ccc(C(=O)[O-])s1>>Cc1ccc(Cl)cc1NCc1ccc(C(=O)[O-])s1 4.17 SAMPL8 SAMPL8-6 pKa
9 | 7 [NH3+]c1nc2ccc(Br)cc2n1CC1(O)CCOCC1,Nc1nc2ccc(Br)cc2n1CC1([OH2+])CCOCC1,Nc1[nH+]c2ccc(Br)cc2n1CC1(O)CCOCC1>>Nc1nc2ccc(Br)cc2n1CC1(O)CCOCC1 6.63 SAMPL8 SAMPL8-7 pKa
10 | 8 COC(=O)c1cn2cccc(C(F)(F)F)c2[nH+]1,COC(=[OH+])c1cn2cccc(C(F)(F)F)c2n1>>COC(=O)c1cn2cccc(c2n1)C(F)(F)F 2.78 SAMPL8 SAMPL8-8 pKa
11 | 9 Nc1[nH+]c2ccc(Br)cc2n1CC1(O)CCCCC1,[NH3+]c1nc2ccc(Br)cc2n1CC1(O)CCCCC1,Nc1nc2ccc(Br)cc2n1CC1([OH2+])CCCCC1>>Nc1nc2ccc(Br)cc2n1CC1(O)CCCCC1 6.08 SAMPL8 SAMPL8-9 pKa
12 | 10 CS(=O)(=[OH+])c1cccc(-c2ccc(CNCc3c(F)cccc3Cl)cc2)c1,CS(=O)(=O)c1cccc(-c2ccc(C[NH2+]Cc3c(F)cccc3Cl)cc2)c1>>CS(=O)(=O)c1cccc(c1)-c1ccc(CNCc2c(F)cccc2Cl)cc1 7.71 SAMPL8 SAMPL8-10 pKa
13 | 11 Cc1ccc(Cc2cnc([NH3+])nc2N2CCOCC2)cc1,Cc1ccc(Cc2c[nH+]c(N)nc2N2CCOCC2)cc1,Cc1ccc(Cc2cnc(N)nc2[NH+]2CCOCC2)cc1,Cc1ccc(Cc2cnc(N)[nH+]c2N2CCOCC2)cc1>>Cc1ccc(Cc2cnc(N)nc2N2CCOCC2)cc1 6.98 SAMPL8 SAMPL8-12 pKa
14 | 12 Cc1cc(Cc2cnc(N)nc2[NH3+])cc(C(C)(C)C)c1O,Cc1cc(Cc2c[nH+]c(N)[nH+]c2N)cc(C(C)(C)C)c1[O-],Cc1cc(Cc2cnc(N)[nH+]c2N)cc(C(C)(C)C)c1O,Cc1cc(Cc2cnc([NH3+])nc2[NH3+])cc(C(C)(C)C)c1[O-],Cc1cc(Cc2cnc(N)[nH+]c2[NH3+])cc(C(C)(C)C)c1[O-],Cc1cc(Cc2c[nH+]c(N)nc2N)cc(C(C)(C)C)c1O,Cc1cc(Cc2cnc([NH3+])[nH+]c2N)cc(C(C)(C)C)c1[O-],Cc1cc(Cc2c[nH+]c([NH3+])nc2N)cc(C(C)(C)C)c1[O-],Cc1cc(Cc2c[nH+]c(N)nc2[NH3+])cc(C(C)(C)C)c1[O-],Cc1cc(Cc2cnc([NH3+])nc2N)cc(C(C)(C)C)c1O>>Cc1cc(Cc2cnc(N)[nH+]c2N)cc(C(C)(C)C)c1[O-],Cc1cc(Cc2cnc([NH3+])nc2N)cc(C(C)(C)C)c1[O-],Cc1cc(Cc2cnc(N)nc2N)cc(c1O)C(C)(C)C,Cc1cc(Cc2cnc(N)nc2[NH3+])cc(C(C)(C)C)c1[O-],Cc1cc(Cc2c[nH+]c(N)nc2N)cc(C(C)(C)C)c1[O-] 7.27 SAMPL8 SAMPL8-14 pKa
15 | 13 Cc1ccccc1[NH2+]c1nc(Cl)nc2ccccc12,Cc1ccccc1Nc1nc(Cl)[nH+]c2ccccc12,Cc1ccccc1Nc1[nH+]c(Cl)nc2ccccc12>>Cc1ccccc1Nc1nc(Cl)nc2ccccc12 2.54 SAMPL8 SAMPL8-15 pKa
16 | 14 CC(C)(C)OC(=[OH+])NCc1nc2ccccc2[nH]1,CC(C)(C)OC(=O)NCc1[nH]c2ccccc2[nH+]1>>CC(C)(C)OC(=O)NCc1nc2ccccc2[nH]1 5.1 SAMPL8 SAMPL8-16 pKa
17 | 15 Nc1nc2ccc(Br)cc2n1CC1(C[OH2+])CCOCC1,Nc1[nH+]c2ccc(Br)cc2n1CC1(CO)CCOCC1,[NH3+]c1nc2ccc(Br)cc2n1CC1(CO)CCOCC1>>Nc1nc2ccc(Br)cc2n1CC1(CO)CCOCC1 6.58 SAMPL8 SAMPL8-17 pKa
18 | 16 COc1ccc(Nc2nc(Cl)[nH+]c3ccccc23)cc1OC,COc1ccc([NH2+]c2nc(Cl)nc3ccccc23)cc1OC,COc1ccc(Nc2[nH+]c(Cl)nc3ccccc23)cc1OC>>COc1ccc(Nc2nc(Cl)nc3ccccc23)cc1OC 2.72 SAMPL8 SAMPL8-18 pKa
19 | 17 COc1ncc(Cc2c[nH]c(SCc3nc4ccccc4n3C)[nH+]c2=O)cn1,COc1ncc(Cc2c[nH]c(SCc3[nH+]c4ccccc4n3C)nc2=O)cn1,COc1ncc(Cc2c[nH]c(SCc3nc4ccccc4n3C)nc2=O)c[nH+]1,COc1ncc(Cc2c[nH]c(SCc3nc4ccccc4n3C)nc2=[OH+])cn1>>COc1ncc(Cc2c[nH]c(SCc3nc4ccccc4n3C)nc2=O)cn1 4.93 SAMPL8 SAMPL8-19 pKa
20 | 17 COc1ncc(Cc2c[nH]c(SCc3nc4ccccc4n3C)nc2=O)cn1>>COc1ncc(Cc2c[n-]c(SCc3nc4ccccc4n3C)nc2=O)cn1 6.99 SAMPL8 SAMPL8-19 pKa
21 | 19 [NH3+]c1n[nH]c2nc(-c3ccccc3)c(Cl)cc12,Nc1[nH+][nH]c2nc(-c3ccccc3)c(Cl)cc12,Nc1n[nH]c2[nH+]c(-c3ccccc3)c(Cl)cc12>>Nc1n[nH]c2nc(c(Cl)cc12)-c1ccccc1 2.44 SAMPL8 SAMPL8-20 pKa
22 | 20 Nc1n[nH]c2nc(c(Cl)cc12)-c1ccccc1>>Nc1n[n-]c2nc(-c3ccccc3)c(Cl)cc12,[NH-]c1n[nH]c2nc(-c3ccccc3)c(Cl)cc12 11.44 SAMPL8 SAMPL8-20 pKa
23 | 21 COc1cc(Cc2c(OC)nc([NH3+])[nH+]c2N)cc(OC)c1[O-],COc1cc(Cc2c(OC)nc(N)[nH+]c2N)cc(OC)c1O,COc1cc(Cc2c(N)nc([NH3+])[nH+]c2OC)cc(OC)c1[O-],COc1cc(Cc2c(N)nc(N)[nH+]c2OC)cc(OC)c1O,COc1cc(Cc2c(N)[nH+]c(N)[nH+]c2OC)cc(OC)c1[O-],COc1cc(Cc2c([NH3+])nc(N)nc2OC)cc(OC)c1O,COc1cc(Cc2c(N)nc([NH3+])nc2OC)cc(OC)c1O,COc1cc(Cc2c([NH3+])nc([NH3+])nc2OC)cc(OC)c1[O-],COc1cc(Cc2c([NH3+])nc(N)[nH+]c2OC)cc(OC)c1[O-],COc1cc(Cc2c(OC)nc(N)[nH+]c2[NH3+])cc(OC)c1[O-]>>COc1cc(Cc2c(N)nc([NH3+])nc2OC)cc(OC)c1[O-],COc1cc(Cc2c(N)nc(N)nc2OC)cc(OC)c1O,COc1cc(Cc2c(N)nc(N)[nH+]c2OC)cc(OC)c1[O-],COc1cc(Cc2c(OC)nc(N)[nH+]c2N)cc(OC)c1[O-],COc1cc(Cc2c([NH3+])nc(N)nc2OC)cc(OC)c1[O-] 5.38 SAMPL8 SAMPL8-21 pKa
24 | 22 COc1cc2[nH+]c(Cl)nc(N)c2cc1OC,COc1cc2nc(Cl)nc([NH3+])c2cc1OC,COc1cc2nc(Cl)[nH+]c(N)c2cc1OC>>COc1cc2nc(Cl)nc(N)c2cc1OC 3.36 SAMPL8 SAMPL8-22 pKa
25 | 23 Cc1[nH+]c2cc(O)ccc2s1>>Cc1[nH+]c2cc([O-])ccc2s1,Cc1nc2cc(O)ccc2s1 2.65 SAMPL8 SAMPL8-23 pKa
26 | 24 Cc1[nH+]c2cc([O-])ccc2s1,Cc1nc2cc(O)ccc2s1>>Cc1nc2cc([O-])ccc2s1 9.02 SAMPL8 SAMPL8-23 pKa
27 |
--------------------------------------------------------------------------------
/enumerator/example_out.tsv:
--------------------------------------------------------------------------------
1 | -2 -1 0 1 2 SMILES
2 | 0 O=C1NCCCc2c1oc1ccc([O-])cc21 O=C1NCCCc2c1oc1ccc(O)cc21 c1cc2c(cc1O)c3c(o2)C(=O)NCCC3
3 | 1 FC(F)(F)c1cccc(Nc2ncnc3ccccc23)c1 FC(F)(F)c1cccc([NH2+]c2ncnc3ccccc23)c1,FC(F)(F)c1cccc(Nc2[nH+]cnc3ccccc23)c1,FC(F)(F)c1cccc(Nc2nc[nH+]c3ccccc23)c1 FC(F)(F)c1cccc([NH2+]c2nc[nH+]c3ccccc23)c1,FC(F)(F)c1cccc(Nc2[nH+]c[nH+]c3ccccc23)c1,FC(F)(F)c1cccc([NH2+]c2[nH+]cnc3ccccc23)c1 FC(F)(F)c1cccc([NH2+]c2ncnc3ccccc23)c1
4 | 2 O=C(Nc1nnc(Cc2ccccc2)s1)c1cccs1 O=C(Nc1n[nH+]c(Cc2ccccc2)s1)c1cccs1,O=C(Nc1[nH+]nc(Cc2ccccc2)s1)c1cccs1 c1ccc(cc1)Cc2nnc(s2)NC(=O)c3cccs3
5 | 3 Clc1ccc(CNc2ncnc3ccccc23)cc1 Clc1ccc(CNc2[nH+]cnc3ccccc23)cc1,Clc1ccc(CNc2nc[nH+]c3ccccc23)cc1,Clc1ccc(C[NH2+]c2ncnc3ccccc23)cc1 Clc1ccc(CNc2[nH+]c[nH+]c3ccccc23)cc1,Clc1ccc(C[NH2+]c2[nH+]cnc3ccccc23)cc1,Clc1ccc(C[NH2+]c2nc[nH+]c3ccccc23)cc1 Clc1ccc(C[NH2+]c2ncnc3ccccc23)cc1
6 | 4 [OH+]=C(Nc1ccccc1N1CCCCC1)c1ccc(Cl)o1 [OH+]=C(Nc1ccccc1[NH+]1CCCCC1)c1ccc(Cl)o1 [OH+]=C(Nc1ccccc1N1CCCCC1)c1ccc(Cl)o1
7 | 5 O=C(Nc1cccc2cccnc12)c1cncc(Br)c1 O=C(Nc1cccc2ccc[nH+]c12)c1cncc(Br)c1,O=C(Nc1cccc2cccnc12)c1c[nH+]cc(Br)c1 O=C(Nc1cccc2ccc[nH+]c12)c1c[nH+]cc(Br)c1 O=C(Nc1cccc2ccc[nH+]c12)c1cncc(Br)c1
8 | 6 O=C(Nc1cccc2cccnc12)c1cncc(Br)c1 O=C(Nc1cccc2cccnc12)c1c[nH+]cc(Br)c1,O=C(Nc1cccc2ccc[nH+]c12)c1cncc(Br)c1 c1cc2cccnc2c(c1)NC(=O)c3cc(cnc3)Br
9 | 7 c1ccc(CNc2ncnc3ccccc23)cc1 c1ccc(CNc2nc[nH+]c3ccccc23)cc1,c1ccc(CNc2[nH+]cnc3ccccc23)cc1,c1ccc(C[NH2+]c2ncnc3ccccc23)cc1 c1ccc(CNc2[nH+]c[nH+]c3ccccc23)cc1,c1ccc(C[NH2+]c2[nH+]cnc3ccccc23)cc1,c1ccc(C[NH2+]c2nc[nH+]c3ccccc23)cc1 c1ccc(CNc2nc[nH+]c3ccccc23)cc1
10 | 8 Cc1ccc2[nH]c(=O)c(CC(=O)[O-])c(-c3ccccc3)c2c1,Cc1ccc2[n-]c(=O)c(CC(=O)O)c(-c3ccccc3)c2c1 Cc1ccc2[nH]c(=O)c(CC(=O)O)c(-c3ccccc3)c2c1 Cc1ccc2c(c1)c(c(c(=O)[nH]2)CC(=O)O)c3ccccc3
11 | 9 COc1cccc(Nc2ncnc3ccccc23)c1 COc1cccc(Nc2[nH+]cnc3ccccc23)c1,COc1cccc(Nc2nc[nH+]c3ccccc23)c1,COc1cccc([NH2+]c2ncnc3ccccc23)c1 COc1cccc(Nc2[nH+]c[nH+]c3ccccc23)c1,COc1cccc([NH2+]c2[nH+]cnc3ccccc23)c1,COc1cccc([NH2+]c2nc[nH+]c3ccccc23)c1 COc1cccc(Nc2nc[nH+]c3ccccc23)c1
12 | 10 O=C(CNC(=O)c1ccccc1)Nc1nc2ccccc2s1 O=C(CNC(=O)c1ccccc1)Nc1[nH+]c2ccccc2s1 c1ccc(cc1)C(=O)NCC(=O)Nc2nc3ccccc3s2
13 | 11 Nc1ncnc2c1cnn2-c1ccccc1 [NH3+]c1ncnc2c1cnn2-c1ccccc1,Nc1ncnc2c1c[nH+]n2-c1ccccc1,Nc1nc[nH+]c2c1cnn2-c1ccccc1,Nc1[nH+]cnc2c1cnn2-c1ccccc1 Nc1[nH+]c[nH+]c2c1cnn2-c1ccccc1,Nc1nc[nH+]c2c1c[nH+]n2-c1ccccc1,Nc1[nH+]cnc2c1c[nH+]n2-c1ccccc1,[NH3+]c1nc[nH+]c2c1cnn2-c1ccccc1,[NH3+]c1[nH+]cnc2c1cnn2-c1ccccc1,[NH3+]c1ncnc2c1c[nH+]n2-c1ccccc1 [NH3+]c1ncnc2c1cnn2-c1ccccc1
14 | 12 Clc1cccc(Nc2ncnc3ccccc23)c1 Clc1cccc([NH2+]c2ncnc3ccccc23)c1,Clc1cccc(Nc2nc[nH+]c3ccccc23)c1,Clc1cccc(Nc2[nH+]cnc3ccccc23)c1 Clc1cccc([NH2+]c2[nH+]cnc3ccccc23)c1,Clc1cccc([NH2+]c2nc[nH+]c3ccccc23)c1,Clc1cccc(Nc2[nH+]c[nH+]c3ccccc23)c1 Clc1cccc(Nc2nc[nH+]c3ccccc23)c1
15 | 13 COc1cc2ncnc(Nc3cccc(C)c3)c2cc1OC COc1cc2nc[nH+]c(Nc3cccc(C)c3)c2cc1OC,COc1cc2[nH+]cnc(Nc3cccc(C)c3)c2cc1OC,COc1cc2ncnc([NH2+]c3cccc(C)c3)c2cc1OC COc1cc2nc[nH+]c([NH2+]c3cccc(C)c3)c2cc1OC,COc1cc2[nH+]c[nH+]c(Nc3cccc(C)c3)c2cc1OC,COc1cc2[nH+]cnc([NH2+]c3cccc(C)c3)c2cc1OC COc1cc2nc[nH+]c(Nc3cccc(C)c3)c2cc1OC
16 | 14 Nc1ccc2c(c1)ncn2-c1ccccc1 [NH3+]c1ccc2c(c1)ncn2-c1ccccc1,Nc1ccc2c(c1)[nH+]cn2-c1ccccc1 [NH3+]c1ccc2c(c1)[nH+]cn2-c1ccccc1 [NH3+]c1ccc2c(c1)[nH+]cn2-c1ccccc1
17 | 15 Nc1ccc2c(c1)ncn2-c1ccccc1 [NH3+]c1ccc2c(c1)ncn2-c1ccccc1,Nc1ccc2c(c1)[nH+]cn2-c1ccccc1 [NH3+]c1ccc2c(c1)[nH+]cn2-c1ccccc1 [NH3+]c1ccc2c(c1)ncn2-c1ccccc1
18 | 16 [O-]c1ccc(-n2cnc3ccccc32)cc1 Oc1ccc(-n2cnc3ccccc32)cc1,[O-]c1ccc(-n2c[nH+]c3ccccc32)cc1 Oc1ccc(-n2c[nH+]c3ccccc32)cc1 Oc1ccc(-n2c[nH+]c3ccccc32)cc1
19 | 17 [O-]c1ccc(-n2cnc3ccccc32)cc1 Oc1ccc(-n2cnc3ccccc32)cc1,[O-]c1ccc(-n2c[nH+]c3ccccc32)cc1 Oc1ccc(-n2c[nH+]c3ccccc32)cc1 c1ccc2c(c1)ncn2c3ccc(cc3)O
20 | 18 O=C(Nc1ccncc1)c1c(Cl)cccc1Cl O=C(Nc1cc[nH+]cc1)c1c(Cl)cccc1Cl O=C(Nc1cc[nH+]cc1)c1c(Cl)cccc1Cl
21 | 19 O=C(Nc1ccncc1)c1c(Cl)cccc1Cl O=C(Nc1cc[nH+]cc1)c1c(Cl)cccc1Cl c1cc(c(c(c1)Cl)C(=O)Nc2ccncc2)Cl
22 | 20 c1ccc(CSc2nnc(-c3ccncc3)o2)cc1 c1ccc(CSc2n[nH+]c(-c3ccncc3)o2)cc1,c1ccc(CSc2[nH+]nc(-c3ccncc3)o2)cc1,c1ccc(CSc2nnc(-c3cc[nH+]cc3)o2)cc1 c1ccc(CSc2[nH+][nH+]c(-c3ccncc3)o2)cc1,c1ccc(CSc2n[nH+]c(-c3cc[nH+]cc3)o2)cc1,c1ccc(CSc2[nH+]nc(-c3cc[nH+]cc3)o2)cc1 c1ccc(CSc2nnc(-c3cc[nH+]cc3)o2)cc1
23 | 21 O=c1[nH]c(CCC(=[OH+])Nc2ncc(Cc3ccc(F)c(F)c3)s2)nc2ccccc12 O=c1[nH]c(CCC(=[OH+])Nc2ncc(Cc3ccc(F)c(F)c3)s2)[nH+]c2ccccc12,O=c1[nH]c(CCC(=[OH+])Nc2[nH+]cc(Cc3ccc(F)c(F)c3)s2)nc2ccccc12 O=c1[nH]c(CCC(=[OH+])Nc2ncc(Cc3ccc(F)c(F)c3)s2)nc2ccccc12
24 | 22 O=C(CCc1nc2ccccc2c(=O)[n-]1)Nc1ncc(Cc2ccc(F)c(F)c2)s1 O=C(CCc1nc2ccccc2c(=O)[nH]1)Nc1ncc(Cc2ccc(F)c(F)c2)s1 O=C(CCc1[nH]c(=O)c2ccccc2[nH+]1)Nc1ncc(Cc2ccc(F)c(F)c2)s1,O=C(CCc1nc2ccccc2c(=O)[nH]1)Nc1[nH+]cc(Cc2ccc(F)c(F)c2)s1 c1ccc2c(c1)c(=O)[nH]c(n2)CCC(=O)Nc3ncc(s3)Cc4ccc(c(c4)F)F
25 | 23 O=C(CCc1nc2ccccc2c(=O)[n-]1)Nc1ncc(Cc2ccc(F)c(F)c2)s1 O=C(CCc1nc2ccccc2c(=O)[nH]1)Nc1ncc(Cc2ccc(F)c(F)c2)s1 O=C(CCc1[nH]c(=O)c2ccccc2[nH+]1)Nc1ncc(Cc2ccc(F)c(F)c2)s1,O=C(CCc1nc2ccccc2c(=O)[nH]1)Nc1[nH+]cc(Cc2ccc(F)c(F)c2)s1 O=C(CCc1nc2ccccc2c(=O)[n-]1)Nc1ncc(Cc2ccc(F)c(F)c2)s1
26 | 24 CCOc1ccc2nc(NC(=O)Cc3ccc(Cl)c(Cl)c3)sc2c1 CCOc1ccc2[nH+]c(NC(=O)Cc3ccc(Cl)c(Cl)c3)sc2c1 CCOc1ccc2c(c1)sc(n2)NC(=O)Cc3ccc(c(c3)Cl)Cl
27 | 25 O=C1NC(=O)/C(=C\c2cccc(OCc3ccc(Cl)cc3Cl)c2)S1 c1cc(cc(c1)OCc2ccc(cc2Cl)Cl)/C=C/3\C(=O)NC(=O)S3
28 | 26 Fc1cnc(Nc2cccc(Br)c2)nc1Nc1cccc(Br)c1 Fc1cnc(Nc2cccc(Br)c2)nc1[NH2+]c1cccc(Br)c1,Fc1cnc([NH2+]c2cccc(Br)c2)nc1Nc1cccc(Br)c1,Fc1cnc(Nc2cccc(Br)c2)[nH+]c1Nc1cccc(Br)c1,Fc1c[nH+]c(Nc2cccc(Br)c2)nc1Nc1cccc(Br)c1 Fc1cnc([NH2+]c2cccc(Br)c2)nc1[NH2+]c1cccc(Br)c1,Fc1c[nH+]c([NH2+]c2cccc(Br)c2)nc1Nc1cccc(Br)c1,Fc1c[nH+]c(Nc2cccc(Br)c2)[nH+]c1Nc1cccc(Br)c1,Fc1c[nH+]c(Nc2cccc(Br)c2)nc1[NH2+]c1cccc(Br)c1,Fc1cnc([NH2+]c2cccc(Br)c2)[nH+]c1Nc1cccc(Br)c1,Fc1cnc(Nc2cccc(Br)c2)[nH+]c1[NH2+]c1cccc(Br)c1 Fc1cnc(Nc2cccc(Br)c2)nc1[NH2+]c1cccc(Br)c1
29 | 27 [O-]c1c(I)cc(I)c2cccnc12 Oc1c(I)cc(I)c2cccnc12,[O-]c1c(I)cc(I)c2ccc[nH+]c12 Oc1c(I)cc(I)c2ccc[nH+]c12 Oc1c(I)cc(I)c2ccc[nH+]c12
30 | 28 [O-]c1c(I)cc(I)c2cccnc12 Oc1c(I)cc(I)c2cccnc12,[O-]c1c(I)cc(I)c2ccc[nH+]c12 Oc1c(I)cc(I)c2ccc[nH+]c12 c1cc2c(cc(c(c2nc1)O)I)I
31 | 29 CCOC(=O)c1ccc(Nc2cc(C)nc(Nc3ccc(C(=O)OCC)cc3)n2)cc1 CCOC(=O)c1ccc(Nc2cc(C)[nH+]c(Nc3ccc(C(=O)OCC)cc3)n2)cc1,CCOC(=O)c1ccc(Nc2cc(C)nc([NH2+]c3ccc(C(=O)OCC)cc3)n2)cc1,CCOC(=O)c1ccc(Nc2nc(C)cc([NH2+]c3ccc(C(=O)OCC)cc3)n2)cc1,CCOC(=O)c1ccc(Nc2cc(C)nc(Nc3ccc(C(=O)OCC)cc3)[nH+]2)cc1 CCOC(=O)c1ccc([NH2+]c2cc(C)nc([NH2+]c3ccc(C(=O)OCC)cc3)n2)cc1,CCOC(=O)c1ccc(Nc2cc(C)[nH+]c(Nc3ccc(C(=O)OCC)cc3)[nH+]2)cc1,CCOC(=O)c1ccc(Nc2nc([NH2+]c3ccc(C(=O)OCC)cc3)cc(C)[nH+]2)cc1,CCOC(=O)c1ccc(Nc2cc(C)[nH+]c([NH2+]c3ccc(C(=O)OCC)cc3)n2)cc1,CCOC(=O)c1ccc(Nc2cc(C)nc([NH2+]c3ccc(C(=O)OCC)cc3)[nH+]2)cc1,CCOC(=O)c1ccc(Nc2nc(C)cc([NH2+]c3ccc(C(=O)OCC)cc3)[nH+]2)cc1 CCOC(=O)c1ccc(Nc2cc(C)nc(Nc3ccc(C(=O)OCC)cc3)[nH+]2)cc1
32 | 30 COc1ccc(-c2oc3ncnc(NCCO)c3c2-c2ccc(OC)cc2)cc1 COc1ccc(-c2oc3[nH+]cnc(NCCO)c3c2-c2ccc(OC)cc2)cc1,COc1ccc(-c2oc3ncnc([NH2+]CCO)c3c2-c2ccc(OC)cc2)cc1,COc1ccc(-c2oc3nc[nH+]c(NCCO)c3c2-c2ccc(OC)cc2)cc1 COc1ccc(-c2oc3nc[nH+]c([NH2+]CCO)c3c2-c2ccc(OC)cc2)cc1,COc1ccc(-c2oc3[nH+]c[nH+]c(NCCO)c3c2-c2ccc(OC)cc2)cc1,COc1ccc(-c2oc3[nH+]cnc([NH2+]CCO)c3c2-c2ccc(OC)cc2)cc1 COc1ccc(-c2oc3[nH+]cnc(NCCO)c3c2-c2ccc(OC)cc2)cc1
33 |
--------------------------------------------------------------------------------
/enumerator/main.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, List, Dict, Callable
2 | from collections import OrderedDict
3 | from argparse import ArgumentParser
4 |
5 | import pandas as pd
6 | import numpy as np
7 | from tqdm import tqdm, trange
8 | from rdkit.Chem import MolFromSmarts, AddHs, MolFromSmiles, SanitizeMol, Mol, CanonSmiles, MolToSmiles, RemoveHs, RWMol, GetFormalCharge
9 | from rdkit.RDLogger import DisableLog
10 |
11 |
12 | # Silence!
13 | DisableLog('rdApp.*')
14 |
15 |
16 | # Unreasonable chemical structures
17 | FILTER_PATTERNS = list(map(MolFromSmarts, [
18 | "[#6X5]",
19 | "[#7X5]",
20 | "[#8X4]",
21 | "[*r]=[*r]=[*r]",
22 | "[#1]-[*+1]~[*-1]",
23 | "[#1]-[*+1]=,:[*]-,:[*-1]",
24 | "[#1]-[*+1]-,:[*]=,:[*-1]",
25 | "[*+2]",
26 | "[*-2]",
27 | "[#1]-[#8+1].[#8-1,#7-1,#6-1]",
28 | "[#1]-[#7+1,#8+1].[#7-1,#6-1]",
29 | "[#1]-[#8+1].[#8-1,#6-1]",
30 | "[#1]-[#7+1].[#8-1]-[C](-[C,#1])(-[C,#1])",
31 | # "[#6;!$([#6]-,:[*]=,:[*]);!$([#6]-,:[#7,#8,#16])]=[C](-[O,N,S]-[#1])",
32 | # "[#6]-,=[C](-[O,N,S])(-[O,N,S]-[#1])",
33 | "[OX1]=[C]-[OH2+1]",
34 | "[NX1,NX2H1,NX3H2]=[C]-[O]-[H]",
35 | "[#6-1]=[*]-[*]",
36 | "[cX2-1]",
37 | "[N+1](=O)-[O]-[H]"
38 | ]))
39 |
40 |
41 | def _read_dataset(dataset_file: str) -> pd.DataFrame:
42 | try:
43 | dataset = pd.read_csv(dataset_file, sep="\t", index_col=False)
44 | except pd.errors.ParserError:
45 | dataset = pd.read_csv(dataset_file, sep=",", index_col=False)
46 | return dataset
47 |
48 |
49 | def read_dataset(dataset_file: str, column: str="SMILES", mode=None) -> Tuple[List[List[str]], List[List[str]], pd.DataFrame]:
50 | '''
51 | Read an acid/base dataset.
52 |
53 | Params:
54 | ----
55 | `dataset_file`: The path of a csv-like dataset with columns separated by `\t`.
56 |
57 | `column`: The name of the column storing SMILES.
58 |
59 | `mode`:
60 | - `None` if every entry is acid/base pair recorded as [acid SMILES]>>[basic SMILES].
61 | - `A` if every entry stores acid structures as [acid SMILES].
62 | - `B` if every entry stores base structures as [base SMILES].
63 |
64 | Return:
65 | ----
66 | acid SMILES collections, base SMILES collections, the dataset as `pandas.Dataframe`
67 | '''
68 | dataset = _read_dataset(dataset_file)
69 | smis_A, smis_B = [], []
70 | for smi in dataset[column]:
71 | if ">>" in smi:
72 | ab_smi = smi.split(">>")
73 | smis_A.append(ab_smi[0].split(","))
74 | smis_B.append(ab_smi[1].split(","))
75 | else:
76 | if mode == "A":
77 | smis_A.append([smi]), smis_B.append([])
78 | elif mode == "B":
79 | smis_A.append([]), smis_B.append([smi])
80 | else:
81 |                 raise ValueError("mode must be 'A' or 'B' when an entry contains a single SMILES")
82 | return smis_A, smis_B, dataset
83 |
84 |
85 | def read_template(template_file: str) -> Tuple[pd.DataFrame, pd.DataFrame]:
86 | '''
87 | Read a protonation template.
88 |
89 | Params:
90 | ----
91 | `template_file`: path of `.csv`-like template, with columns of substructure names, SMARTS patterns, protonation indices and acid/base flags
92 |
93 | Return:
94 | ----
95 | `template_a2b`, `template_b2a`: acid to base and base to acid templates
96 | '''
97 | template = pd.read_csv(template_file, sep="\t")
98 | template_a2b = template[template.Acid_or_base == "A"]
99 | template_b2a = template[template.Acid_or_base == "B"]
100 | return template_a2b, template_b2a
101 |
102 |
103 | def match_template(template: pd.DataFrame, mol: Mol, verbose: bool=False) -> list:
104 | '''
105 | Find protonation site using templates
106 |
107 | Params:
108 | ----
109 | `template`: `pandas.Dataframe` with columns of substructure names, SMARTS patterns, protonation indices and acid/base flags
110 |
111 | `mol`: Molecule
112 |
113 | `verbose`: Boolean flag for printing matching results
114 |
115 | Return:
116 | ----
117 |     A list of unique matched atom indices to be (de)protonated
118 | '''
119 | mol = AddHs(mol)
120 | matches = []
121 | for idx, name, smarts, index, acid_base in template.itertuples():
122 | pattern = MolFromSmarts(smarts)
123 | match = mol.GetSubstructMatches(pattern)
124 | if len(match) == 0:
125 | continue
126 | else:
127 | index = int(index)
128 | for m in match:
129 | matches.append(m[index])
130 | if verbose:
131 | print(f"find index {m[index]} in pattern {name} smarts {smarts}")
132 | return list(set(matches))
133 |
134 |
135 | def prot(mol: Mol, idx: int, mode: str) -> Mol:
136 | '''
137 | Protonate / Deprotonate a molecule at a specified site
138 |
139 | Params:
140 | ----
141 | `mol`: Molecule
142 |
143 |     `idx`: Atom index of the (de)protonation site
144 |
145 |     `mode`: `a2b` means deprotonation, with a hydrogen atom or a heavy atom at `idx`; `b2a` means protonation, with a heavy atom at `idx`
146 |
147 | Return:
148 | ----
149 | `mol_prot`: (De)protonated molecule
150 | '''
151 | mw = RWMol(mol)
152 | if mode == "a2b":
153 | atom_H = mw.GetAtomWithIdx(idx)
154 | if atom_H.GetAtomicNum() == 1:
155 | atom_A = atom_H.GetNeighbors()[0]
156 | charge_A = atom_A.GetFormalCharge()
157 | atom_A.SetFormalCharge(charge_A - 1)
158 | mw.RemoveAtom(idx)
159 | mol_prot = mw.GetMol()
160 | else:
161 | charge_H = atom_H.GetFormalCharge()
162 | numH_H = atom_H.GetTotalNumHs()
163 | atom_H.SetFormalCharge(charge_H - 1)
164 | atom_H.SetNumExplicitHs(numH_H - 1)
165 | atom_H.UpdatePropertyCache()
166 | mol_prot = AddHs(mw)
167 | elif mode == "b2a":
168 | atom_B = mw.GetAtomWithIdx(idx)
169 | charge_B = atom_B.GetFormalCharge()
170 | atom_B.SetFormalCharge(charge_B + 1)
171 | numH_B = atom_B.GetNumExplicitHs()
172 | atom_B.SetNumExplicitHs(numH_B + 1)
173 | mol_prot = AddHs(mw)
174 | SanitizeMol(mol_prot)
175 | mol_prot = MolFromSmiles(MolToSmiles(mol_prot))
176 | mol_prot = AddHs(mol_prot)
177 | return mol_prot
178 |
179 |
180 | def prot_template(template: pd.DataFrame, smi: str, mode: str) -> Tuple[List[int], List[str]]:
181 | """
182 | Protonate / Deprotonate a SMILES at every found site in the template
183 |
184 | Params:
185 | ----
186 | `template`: `pandas.Dataframe` with columns of substructure names, SMARTS patterns, protonation indices and acid/base flags
187 |
188 | `smi`: The SMILES to be processed
189 |
190 |     `mode`: `a2b` means deprotonation, with a hydrogen atom or a heavy atom at `idx`; `b2a` means protonation, with a heavy atom at `idx`
191 | """
192 | mol = MolFromSmiles(smi)
193 | sites = match_template(template, mol)
194 | smis = []
195 | for site in sites:
196 | smis.append(CanonSmiles(MolToSmiles(RemoveHs(prot(mol, site, mode)))))
197 | return sites, list(set(smis))
198 |
199 |
200 | def sanitize_checker(smi: str, filter_patterns: List[Mol], verbose: bool=False) -> bool:
201 | """
202 | Check if a SMILES can be sanitized and does not contain unreasonable chemical structures.
203 |
204 | Params:
205 | ----
206 |     `smi`: The SMILES to be checked.
207 |
208 | `filter_patterns`: Unreasonable chemical structures.
209 |
210 | `verbose`: If True, matched unreasonable chemical structures will be printed.
211 |
212 | Return:
213 | ----
214 |     True if the SMILES passes the check (sanitizable and free of unreasonable structures), False otherwise.
215 | """
216 | mol = AddHs(MolFromSmiles(smi))
217 | for pattern in filter_patterns:
218 | match = mol.GetSubstructMatches(pattern)
219 | if match:
220 | if verbose:
221 | print(f"pattern {pattern}")
222 | return False
223 | try:
224 | SanitizeMol(mol)
225 |     except Exception:
226 | print("cannot sanitize")
227 | return False
228 | return True
229 |
230 |
231 | def sanitize_filter(smis: List[str], filter_patterns: List[Mol]=FILTER_PATTERNS) -> List[str]:
232 | """
233 |     A filter keeping only SMILES that can be sanitized and do not contain unreasonable chemical structures.
234 |
235 | Params:
236 | ----
237 | `smis`: The list of SMILES.
238 |
239 | `filter_patterns`: Unreasonable chemical structures.
240 |
241 | Return:
242 | ----
243 | The list of SMILES filtered.
244 | """
245 | def _checker(smi):
246 | return sanitize_checker(smi, filter_patterns)
247 | return list(filter(_checker, smis))
248 |
249 |
250 | def cnt_stereo_atom(smi: str) -> int:
251 | """
252 | Count the stereo atoms in a SMILES
253 | """
254 | mol = MolFromSmiles(smi)
255 | return sum([str(atom.GetChiralTag()) != "CHI_UNSPECIFIED" for atom in mol.GetAtoms()])
256 |
257 |
258 | def stereo_filter(smis: List[str]) -> List[str]:
259 | """
260 | A filter against SMILES losing stereochemical information in structure processing.
261 | """
262 | filtered_smi_dict = dict()
263 | for smi in smis:
264 | nonstereo_smi = CanonSmiles(smi, useChiral=0)
265 | stereo_cnt = cnt_stereo_atom(smi)
266 | if nonstereo_smi not in filtered_smi_dict:
267 | filtered_smi_dict[nonstereo_smi] = (smi, stereo_cnt)
268 | else:
269 | if stereo_cnt > filtered_smi_dict[nonstereo_smi][1]:
270 | filtered_smi_dict[nonstereo_smi] = (smi, stereo_cnt)
271 | return [value[0] for value in filtered_smi_dict.values()]
272 |
273 |
274 | def make_filter(filter_param: OrderedDict) -> Callable:
275 | """
276 | Make a sequential SMILES filter
277 |
278 | Params:
279 | ----
280 |     `filter_param`: A `collections.OrderedDict` whose keys are single filter functions and the corresponding values are their parameter dictionaries.
281 |
282 | Return:
283 | ----
284 | The sequential filter function
285 | """
286 | def seq_filter(smis):
287 | for single_filter, param in filter_param.items():
288 | smis = single_filter(smis, **param)
289 | return smis
290 | return seq_filter
291 |
292 |
293 | def enumerate_template(smi: str, template_a2b: pd.DataFrame, template_b2a: pd.DataFrame, mode: str="A", maxiter: int=2, verbose: int=0, filter_patterns: List[Mol]=FILTER_PATTERNS) -> Tuple[List[str], List[str]]:
294 | """
295 | Enumerate all the (de)protonation results of one SMILES.
296 |
297 | Params:
298 | ----
299 | `smi`: The smiles to be processed.
300 |
301 | `template_a2b`: `pandas.Dataframe` with columns of substructure names, SMARTS patterns, deprotonation indices and acid flags.
302 |
303 | `template_b2a`: `pandas.Dataframe` with columns of substructure names, SMARTS patterns, protonation indices and base flags.
304 |
305 | `mode`:
306 | - "A": `smi` is an acid to be deprotonated.
307 | - "B": `smi` is a base to be protonated.
308 |
309 | `maxiter`: Max iteration number of template matching and microstate pool growth.
310 |
311 | `verbose`:
312 | - 0: Silent mode.
313 | - 1: Print the length of microstate pools in each iteration.
314 | - 2: Print the content of microstate pools in each iteration.
315 |
316 | `filter_patterns`: Unreasonable chemical structures.
317 |
318 | Return:
319 | ----
320 | A microstate pool and B microstate pool after enumeration.
321 | """
322 | if isinstance(smi, str):
323 | smis = [smi]
324 | else:
325 | smis = list(smi)
326 |
327 | enum_func = lambda x: [x] # TODO: Tautomerism enumeration
328 |
329 | if mode == "A":
330 | smis_A_pool, smis_B_pool = smis, []
331 | elif mode == "B":
332 | smis_A_pool, smis_B_pool = [], smis
333 | pool_length_A = -1
334 | pool_length_B = -1
335 | filters = make_filter({
336 | sanitize_filter: {"filter_patterns": filter_patterns},
337 | stereo_filter: {}
338 | })
339 | i = 0
340 | while (len(smis_A_pool) != pool_length_A or len(smis_B_pool) != pool_length_B) and i < maxiter:
341 | pool_length_A, pool_length_B = len(smis_A_pool), len(smis_B_pool)
342 | if verbose > 0:
343 | print(f"iter {i}: {pool_length_A} acid, {pool_length_B} base")
344 | if verbose > 1:
345 | print(f"iter {i}, acid: {smis_A_pool}, base: {smis_B_pool}")
346 | if (mode == "A" and (i + 1) % 2) or (mode == "B" and i % 2):
347 | smis_A_tmp_pool = []
348 | for smi in smis_A_pool:
349 | smis_B_pool += filters(prot_template(template_a2b, smi, "a2b")[1])
350 | smis_A_tmp_pool += filters([CanonSmiles(MolToSmiles(mol)) for mol in enum_func(MolFromSmiles(smi))])
351 | smis_A_pool += smis_A_tmp_pool
352 | elif (mode == "B" and (i + 1) % 2) or (mode == "A" and i % 2):
353 | smis_B_tmp_pool = []
354 | for smi in smis_B_pool:
355 | smis_A_pool += filters(prot_template(template_b2a, smi, "b2a")[1])
356 | smis_B_tmp_pool += filters([CanonSmiles(MolToSmiles(mol)) for mol in enum_func(MolFromSmiles(smi))])
357 | smis_B_pool += smis_B_tmp_pool
358 | smis_A_pool = filters(smis_A_pool)
359 | smis_B_pool = filters(smis_B_pool)
360 | smis_A_pool = list(set(smis_A_pool))
361 | smis_B_pool = list(set(smis_B_pool))
362 | i += 1
363 | if verbose > 0:
364 | print(f"iter {i}: {pool_length_A} acid, {pool_length_B} base")
365 | if verbose > 1:
366 | print(f"iter {i}, acid: {smis_A_pool}, base: {smis_B_pool}")
367 | smis_A_pool = list(map(CanonSmiles, smis_A_pool))
368 | smis_B_pool = list(map(CanonSmiles, smis_B_pool))
369 | return smis_A_pool, smis_B_pool
370 |
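# Example call (inputs are illustrative; templates are loaded with read_template as in
# enum_dataset below):
#   template_a2b, template_b2a = read_template("smarts_pattern.tsv")
#   acids, bases = enumerate_template("CC(=O)O", template_a2b, template_b2a, mode="A")
# For acetic acid this is expected to put the neutral acid into `acids` and acetate
# ("CC(=O)[O-]") into `bases` after sanitization and stereo filtering.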
371 |
372 | def check_dataset(dataset_file: str) -> None:
373 | """
374 |     Check whether every entry in the dataset is valid under the Uni-pKa standard format.
375 | """
376 | print(f"Checking reconstructed dataset {dataset_file}")
377 | dataset = _read_dataset(dataset_file)
378 | for i in trange(len(dataset)):
379 | try:
380 | a_smi, b_smi = dataset.iloc[i]["SMILES"].split(">>")
381 | except:
382 | print(f"missing '>>' in index {i}")
383 | continue
384 | if not a_smi:
385 | print(f"missing acid smiles in index {i}")
386 | continue
387 | if not b_smi:
388 | print(f"missing base smiles in index {i}")
389 | continue
390 | for smi in a_smi.split(",") + b_smi.split(","):
391 | if not smi:
392 | print(f"empty smiles in index {i}")
393 | else:
394 | try:
395 | mol = AddHs(MolFromSmiles(smi))
396 | assert mol is not None
397 | except:
398 | print(f"invalid smiles {smi} in index {i}")
399 |
400 |
401 | def enum_dataset(input_file: str, output_file: str, template: str, mode: str, column: str, maxiter: int) -> pd.DataFrame:
402 | """
403 | Enumerate the full macrostate and reconstruct the pairwise acid/base dataset from a molecule-wise or pairwise acid/base dataset.
404 |
405 | Params:
406 | ----
407 | `input_file`: The path of input dataset.
408 |
409 | `output_file`: The path of output dataset.
410 |
411 | `mode`:
412 | - "A": Enumeration is started from the acid.
413 | - "B": Enumeration is started from the base.
414 |
415 | `column`: The name of the column storing SMILES.
416 |
417 | `maxiter`: Max iteration number of template matching and microstate pool growth.
418 |
419 | Return:
420 | ----
421 | The reconstructed dataset.
422 | """
423 | print(f"Reconstructing {input_file} with the template {template} from {mode} microstates")
424 | smis_A, smis_B, dataset = read_dataset(input_file, column=column, mode=mode)
425 |
426 | template_a2b, template_b2a = read_template(template)
427 |
428 | if mode == "A":
429 | smis_I = smis_A
430 | elif mode == "B":
431 | smis_I = smis_B
432 |
433 | SMILES_col = []
434 | for i, smis in tqdm(enumerate(smis_I), total=len(smis_I)):
435 | try:
436 | smis_a, smis_b = enumerate_template(smis, template_a2b, template_b2a, maxiter=maxiter, mode=mode)
437 | except:
438 | print(f"failed to enumerate {smis}: enum error")
439 | raise ValueError
440 | if not smis_a:
441 | if not smis_A[i]:
442 | print(f"failed to enumerate {smis}: no A states")
443 | raise ValueError
444 | else:
445 | smis_a = smis_A[i]
446 | if not smis_b:
447 | if not smis_B[i]:
448 | print(f"failed to enumerate {smis}: no B states")
449 | raise ValueError
450 | else:
451 | smis_b = smis_B[i]
452 | SMILES_col.append(",".join(smis_a) + ">>" + ",".join(smis_b))
453 |
454 | dataset["SMILES"] = SMILES_col
455 | dataset.to_csv(output_file, sep="\t")
456 | check_dataset(output_file)
457 | return dataset
458 |
459 |
460 | def enum_ensemble(input_file: str, output_file: str, template: str, upper: int, lower: int, column: str, maxiter: int) -> pd.DataFrame:
461 | """
462 |     Enumerate the protonation ensemble (microstates grouped by total formal charge) for every molecule in the input file, within the charge range from `lower` to `upper`.
463 |
464 | Params:
465 | ----
466 | `input_file`: The path of input molecules.
467 |
468 | `output_file`: The path of output ensembles.
469 |
470 | `upper`: The maximum total charge of protonation ensembles.
471 |
472 | `lower`: The minimum total charge of protonation ensembles.
473 |
474 | `column`: The name of the column storing SMILES.
475 |
476 | `maxiter`: Max iteration number of template matching and microstate pool growth.
477 |
478 | Return:
479 | ----
480 | The enumerated protonation ensembles.
481 | """
482 |     print(f"Enumerating the protonation ensemble for {input_file} with the template {template}, maximum charge {upper} and minimum charge {lower}")
483 | smis, _, _ = read_dataset(input_file, column, "A")
484 | smis = [smi[0] for smi in smis]
485 | template_a2b, template_b2a = read_template(template)
486 | ensembles = {i: [] for i in range(lower, upper + 1)}
487 |
488 | for smi in tqdm(smis):
489 |
490 | ensemble = dict()
491 | q0 = GetFormalCharge(MolFromSmiles(smi))
492 | ensemble[q0] = [smi]
493 |
494 | smis_0 = [smi]
495 |
496 | if q0 > lower:
497 | smis_0, smis_b1 = enumerate_template(smis_0, template_a2b, template_b2a, maxiter=maxiter, mode="A")
498 | if smis_b1:
499 | ensemble[q0 - 1] = smis_b1
500 |             for q in range(q0 - 2, lower - 1, -1):  # down to and including the lower charge bound
501 | if q + 1 in ensemble:
502 | _, smis_b = enumerate_template(ensemble[q + 1], template_a2b, template_b2a, maxiter=maxiter, mode="A")
503 | if smis_b:
504 | ensemble[q] = smis_b
505 |
506 | if q0 < upper:
507 | smis_a1, smis_0 = enumerate_template(smis_0, template_a2b, template_b2a, maxiter=maxiter, mode="B")
508 | if smis_a1:
509 | ensemble[q0 + 1] = smis_a1
510 |             for q in range(q0 + 2, upper + 1):  # up to and including the upper charge bound
511 | if q - 1 in ensemble:
512 | smis_a, _ = enumerate_template(ensemble[q - 1], template_a2b, template_b2a, maxiter=maxiter, mode="B")
513 | if smis_a:
514 | ensemble[q] = smis_a
515 |
516 | ensemble[q0] = smis_0
517 |
518 | for q in ensembles:
519 | if q in ensemble:
520 | ensembles[q].append(",".join(ensemble[q]))
521 | else:
522 | ensembles[q].append(None)
523 |
524 | ensembles[column] = smis
525 | ensembles = pd.DataFrame(ensembles)
526 | ensembles.to_csv(output_file, sep="\t")
527 | return ensembles
528 |
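# Output layout sketch (inferred from the code above): one row per input SMILES, one column
# per total charge from `lower` to `upper` holding a comma-joined microstate pool (or empty
# when no microstate with that charge was found), plus the original SMILES under `column`.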
529 |
530 | if __name__ == "__main__":
531 | parser = ArgumentParser()
532 | subparser = parser.add_subparsers(dest="command")
533 |
534 | parser_check = subparser.add_parser("check")
535 | parser_check.add_argument("-i", "--input", type=str)
536 |
537 | parser_enum = subparser.add_parser("reconstruct")
538 | parser_enum.add_argument("-i", "--input", type=str)
539 | parser_enum.add_argument("-o", "--output", type=str)
540 | parser_enum.add_argument("-t", "--template", type=str, default="smarts_pattern.tsv")
541 | parser_enum.add_argument("-n", "--maxiter", type=int, default=10)
542 | parser_enum.add_argument("-c", "--column", type=str, default="SMILES")
543 | parser_enum.add_argument("-m", "--mode", type=str, default="A")
544 |
545 | parser_enum = subparser.add_parser("ensemble")
546 | parser_enum.add_argument("-i", "--input", type=str)
547 | parser_enum.add_argument("-o", "--output", type=str)
548 | parser_enum.add_argument("-t", "--template", type=str, default="simple_smarts_pattern.tsv")
549 | parser_enum.add_argument("-u", "--upper", type=int, default=2)
550 | parser_enum.add_argument("-l", "--lower", type=int, default=-2)
551 | parser_enum.add_argument("-n", "--maxiter", type=int, default=10)
552 | parser_enum.add_argument("-c", "--column", type=str, default="SMILES")
553 |
554 | args = parser.parse_args()
555 | if args.command == "check":
556 | check_dataset(args.input)
557 | elif args.command == "reconstruct":
558 | enum_dataset(
559 | input_file=args.input,
560 | output_file=args.output,
561 | template=args.template,
562 | mode=args.mode,
563 | column=args.column,
564 | maxiter=args.maxiter
565 | )
566 | elif args.command == "ensemble":
567 | enum_ensemble(
568 | input_file=args.input,
569 | output_file=args.output,
570 | template=args.template,
571 | upper=args.upper,
572 | lower=args.lower,
573 | column=args.column,
574 | maxiter=args.maxiter
575 | )
576 |
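# Command-line usage sketches (flags as defined above; file names are illustrative):
#   python main.py check -i example_out.tsv
#   python main.py reconstruct -i dataset.tsv -o dataset_reconstructed.tsv -t smarts_pattern.tsv -m A
#   python main.py ensemble -i molecules.tsv -o ensembles.tsv -t simple_smarts_pattern.tsv -u 2 -l -2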
--------------------------------------------------------------------------------
/enumerator/simple_smarts_pattern.tsv:
--------------------------------------------------------------------------------
1 | Substructure SMARTS Index Acid_or_base
2 | Sulfate monoether [SX4:0](=[O:1])(=[O:2])(-[O:3])-[OX2:4]-[H:5] 4 A
3 | Sulfate monoether [SX4:0](=[O:1])(=[O:2])(-[O:3])-[O-1:4] 4 B
4 | Sulfonic acid [SX4:0](=[O:1])(=[O:2])(-[#6,#7:3])-[OX2:4]-[H:5] 4 A
5 | Sulfonic acid [SX4:0](=[O:1])(=[O:2])(-[#6,#7:3])-[O-1:4] 4 B
6 | Sulfinic acid [SX3:0](=[O:1])(-[#6,#7:2])-[OX2:3]-[H:4] 3 A
7 | Sulfinic acid [SX3:0](=[O:1])(-[#6,#7:2])-[O-1:3] 3 B
8 | Seleninic acid [SeX3:0](=[O:1])(-[#6,#7:2])-[OX2:3]-[H:4] 3 A
9 | Seleninic acid [SeX3:0](=[O:1])(-[#6,#7:2])-[O-1:3] 3 B
10 | Selenenic acid [SeX2:0]-[OX2:1]-[H:2] 1 A
11 | Selenenic acid [SeX2:0]-[O-1:1] 1 B
12 | Arsonic acid [AsX4:0](=[O:1])(-[#6,#7:2])-[OX2:3]-[H:4] 3 A
13 | Arsonic acid [AsX4:0](=[O:1])(-[#6,#7:2])-[O-1:3] 3 B
14 | Thiosulfuric acid [S:0]~[SX4:1](~[O:2])(~[O:3])-[O:4]-[H:5] 4 A
15 | Thiosulfuric acid [S:0]~[SX4:1](~[O:2])(~[O:3])-[O-1:4] 4 B
16 | Phosph(o/i)nic acid [PX4:0](=[O:1])(-[OX2:2]-[H:5])(-[#1,#6,#7,#8:3])(-[#1,#6,#7,#8:4]) 2 A
17 | Phosph(o/i)nic acid [PX4:0](=[O:1])(-[O-1:2])(-[#1,#6,#7,#8:3])(-[#1,#6,#7,#8:4]) 2 B
18 | Phosphate (mono/di)ether [PX4:0](=[O:1])(-[O:2])(-[O:3])-[OX2:4]-[H:5] 4 A
19 | Phosphate (mono/di)ether [PX4:0](=[O:1])(-[O:2])(-[O:3])-[O-1:4] 4 B
20 | Carboxyl acid [$([#6]=[#8,#7]),$(C#N):0]-[OX2:1]-[H:2] 1 A
21 | Carboxyl acid [$([#6]=[#8,#7]),$(C#N):0]-[O-1:1] 1 B
22 | Carboxyl acid enol [C:0]=[C:1](-[OX2:2]-[H:3])-[OX2:4]-[H:5] 4 A
23 | Carboxyl acid enol [C:0]=[C:1](-[OX2:2]-[H:3])-[O-1:4] 4 B
24 | Carbo(di)thioic acid [CX3:0](=[O,S:1])-[SX2,OX2:2]-[H:3] 2 A
25 | Carbo(di)thioic acid [CX3:0](=[O,S:1])-[S-1,O-1:2] 2 B
26 | Carboxyl acid vinylogue [O:0]=[C:1]-[C:2]=[C:3]-[OX2:4]-[H:5] 4 A
27 | Carboxyl acid vinylogue [O:0]=[C:1]-[C:2]=[C:3]-[O-1:4] 4 B
28 | Thiol/Thiophenol [#6,#7:0]-[SX2:1]-[H:2] 1 A
29 | Thiol/Thiophenol [#6,#7:0]-[S-1:1] 1 B
30 | Phenol [c,n:0]-[OX2:1]-[H:2] 1 A
31 | Phenol [c,n:0]-[O-1:1] 1 B
32 | Hydroperoxide/Hydroxyl amine [O,N:0]-[OX2:1]-[H:2] 1 A
33 | Hydroperoxide/Hydroxyl amine [O,N:0]-[O-1:1] 1 B
34 | Azole [#7:0]1(-[H:5])-,:[#7,#6:1]=,:[#7,#6:2]-,:[#7,#6:3]=,:[#7,#6:4]-,:1 0 A
35 | Azole [#7-1:0]1-,:[#7,#6:1]=,:[#7,#6:2]-,:[#7,#6:3]=,:[#7,#6:4]-,:1 0 B
36 | Aza-aromatics [n:0]-[H:1] 0 A
37 | Aza-aromatics [n-1,n+0X2:0] 0 B
38 | Oxime [$([#7]:,=[#6,#7]),$([#7]:,=[#6,#7]:,-[#6,#7]:,=[#6,#7]):0]-[OX2,NX3:1]-[H:2] 1 A
39 | Oxime [$([#7]:,=[#6,#7]),$([#7]:,=[#6,#7]:,-[#6,#7]:,=[#6,#7]):0]-[O-1,NX2-1:1] 1 B
40 | Amine [NX4+1:0](-[H:4])(-[CX4,c,#7,#8,#1,S,$(C=C),Cl:1])(-[CX4,c,#7,#8,#1,S,$(C=C):2])(-[CX4,c,#7,#8,#1,S,$(C=C):3]) 0 A
41 | Amine [NX3:0](-[CX4,c,#7,#8,#1,S,$(C=C),Cl:1])(-[CX4,c,#7,#8,#1,S,$(C=C):2])(-[CX4,c,#7,#8,#1,S,$(C=C):3]) 0 B
42 | Imine [#6,#7,P,S:0]=[NX3+1:1](-[H:2]) 1 A
43 | Imine [#6,#7,P,S:0]=[NX2:1] 1 B
44 | Amide [$([#7]=[#7,#8]),$(c:c:c:c:[#7+1]):0]-[NX3:1]-[H:2] 1 A
45 | Amide [$([#7]=[#7,#8]),$(c:c:c:c:[#7+1]):0]-[NX2-1:1] 1 B
46 | Amide imine [$([#6]-,:[O,S,#7]),N+1:0]=,:[NX2:1]-[H:2] 1 A
47 | Amide imine [$([#6]-,:[O,S,#7]),N+1:0]=,:[NX1-1:1] 1 B
48 | Sulfamide [SX4:0](=[O:1])(=[O:2])-[NX3:3]-[H:4] 3 A
49 | Sulfamide [SX4:0](=[O:1])(=[O:2])-[NX2-1:3] 3 B
50 | Phosphamide [PX4:0](=[O:1])-[NX3:2]-[H:3] 2 A
51 | Phosphamide [PX4:0](=[O:1])-[NX2-1:2] 2 B
52 | Enol [$([#6]=,:[#7,#8]),$(C#N),#7+1,$([S]=[O]),OH1:0]-[#6:1]:,=[#6:2]-[OX2:3]-[H:4] 3 A
53 | Enol [$([#6]=,:[#7,#8]),$(C#N),#7+1,$([S]=[O]),OH1:0]-[#6:1]:,=[#6:2]-[O-1:3] 3 B
54 | Hydrocyanic acid [N:0]#[C:1]-[H:2] 1 A
55 | Hydrocyanic acid [N:0]#[C-1:1] 1 B
56 | Selenol [SeX2:0]-[H:1] 0 A
57 | Selenol [SeX1-1:0] 0 B
--------------------------------------------------------------------------------
/enumerator/smarts_pattern.tsv:
--------------------------------------------------------------------------------
1 | Substructure SMARTS Index Acid_or_base
2 | Sulfate monoether [SX4:0](=[O:1])(=[O:2])(-[O:3])-[OX2:4]-[H:5] 4 A
3 | Sulfate monoether [SX4:0](=[O:1])(=[O:2])(-[O:3])-[O-1:4] 4 B
4 | Sulfonic acid [SX4:0](=[O:1])(=[O:2])(-[#6,#7:3])-[OX2:4]-[H:5] 4 A
5 | Sulfonic acid [SX4:0](=[O:1])(=[O:2])(-[#6,#7:3])-[O-1:4] 4 B
6 | Sulfinic acid [SX3:0](=[O:1])(-[#6,#7:2])-[OX2:3]-[H:4] 3 A
7 | Sulfinic acid [SX3:0](=[O:1])(-[#6,#7:2])-[O-1:3] 3 B
8 | Seleninic acid [SeX3:0](=[O:1])(-[#6,#7:2])-[OX2:3]-[H:4] 3 A
9 | Seleninic acid [SeX3:0](=[O:1])(-[#6,#7:2])-[O-1:3] 3 B
10 | Selenenic acid [SeX2:0]-[OX2:1]-[H:2] 1 A
11 | Selenenic acid [SeX2:0]-[O-1:1] 1 B
12 | Arsonic acid [AsX4:0](=[O:1])(-[#6,#7:2])-[OX2:3]-[H:4] 3 A
13 | Arsonic acid [AsX4:0](=[O:1])(-[#6,#7:2])-[O-1:3] 3 B
14 | Thiosulfuric acid [S:0]~[SX4:1](~[O:2])(~[O:3])-[O:4]-[H:5] 4 A
15 | Thiosulfuric acid [S:0]~[SX4:1](~[O:2])(~[O:3])-[O-1:4] 4 B
16 | Phosph(o/i)nic acid [PX4:0](=[O:1])(-[OX2:2]-[H:5])(-[#1,#6,#7,#8:3])(-[#1,#6,#7,#8:4]) 2 A
17 | Phosph(o/i)nic acid [PX4:0](=[O:1])(-[O-1:2])(-[#1,#6,#7,#8:3])(-[#1,#6,#7,#8:4]) 2 B
18 | Phosphate (mono/di)ether [PX4:0](=[O:1])(-[O:2])(-[O:3])-[OX2:4]-[H:5] 4 A
19 | Phosphate (mono/di)ether [PX4:0](=[O:1])(-[O:2])(-[O:3])-[O-1:4] 4 B
20 | Carboxyl acid [$([#6]=[#8,#7]),$(C#N):0]-[OX2:1]-[H:2] 1 A
21 | Carboxyl acid [$([#6]=[#8,#7]),$(C#N):0]-[O-1:1] 1 B
22 | Carboxyl acid enol [C:0]=[C:1](-[OX2:2]-[H:3])-[OX2:4]-[H:5] 4 A
23 | Carboxyl acid enol [C:0]=[C:1](-[OX2:2]-[H:3])-[O-1:4] 4 B
24 | Carbo(di)thioic acid [CX3:0](=[O,S:1])-[SX2,OX2:2]-[H:3] 2 A
25 | Carbo(di)thioic acid [CX3:0](=[O,S:1])-[S-1,O-1:2] 2 B
26 | Carboxyl acid vinylogue [O:0]=[C:1]-[C:2]=[C:3]-[OX2:4]-[H:5] 4 A
27 | Carboxyl acid vinylogue [O:0]=[C:1]-[C:2]=[C:3]-[O-1:4] 4 B
28 | Thiol/Thiophenol [#6,#7:0]-[SX2:1]-[H:2] 1 A
29 | Thiol/Thiophenol [#6,#7:0]-[S-1:1] 1 B
30 | Phenol [c,n:0]-[OX2:1]-[H:2] 1 A
31 | Phenol [c,n:0]-[O-1:1] 1 B
32 | Alcohol [$([CX4]-[$([#6]=,:[#7,#8]),$([#6]=,:[#6]-,:[#6]=,:[#7,#8]),$(C#N),O,#7+1,$(C-Cl),$(C#C),$(C-O),S:0]),$([CH2]-[#1,CH3]):0]-[OX2:1]-[H:2] 1 A
33 | Alcohol [$([CX4]-[$([#6]=,:[#7,#8]),$([#6]=,:[#6]-,:[#6]=,:[#7,#8]),$(C#N),O,#7+1,$(C-Cl),$(C#C),$(C-O),S:0]),$([CH2]-[#1,CH3]):0]-[O-1:1] 1 B
34 | Hydroxypyridine [n:0]:[c:1]-[OH2+1:2]-[H:3] 2 A
35 | Hydroxypyridine [n:0]:[c:1]-[OX2H1:2] 2 B
36 | Methylpyridine [n:0](-[C:1]=[O:2]):[c:3]:[c:4]:[c:5]-[CX4:6]-[H:7] 6 A
37 | Methylpyridine [n:0](-[C:1]=[O:2]):[c:3]:[c:4]:[c:5]-[CX3-1:6] 6 B
38 | Hydroperoxide/Hydroxyl amine [O,N:0]-[OX2:1]-[H:2] 1 A
39 | Hydroperoxide/Hydroxyl amine [O,N:0]-[O-1:1] 1 B
40 | Azole [#7:0]1(-[H:5])-,:[#7,#6:1]=,:[#7,#6:2]-,:[#7,#6:3]=,:[#7,#6:4]-,:1 0 A
41 | Azole [#7-1:0]1-,:[#7,#6:1]=,:[#7,#6:2]-,:[#7,#6:3]=,:[#7,#6:4]-,:1 0 B
42 | Aza-aromatics [n:0]-[H:1] 0 A
43 | Aza-aromatics [n-1,n+0X2:0] 0 B
44 | N-substitute aza-aromatics [n+1:0]-[CX4:1]-[H:2] 1 A
45 | N-substitute aza-aromatics [n+1:0]-[CX3-1:1] 1 B
46 | Oxime [$([#7]:,=[#6,#7]),$([#7]:,=[#6,#7]:,-[#6,#7]:,=[#6,#7]):0]-[OX2,NX3:1]-[H:2] 1 A
47 | Oxime [$([#7]:,=[#6,#7]),$([#7]:,=[#6,#7]:,-[#6,#7]:,=[#6,#7]):0]-[O-1,NX2-1:1] 1 B
48 | Amine [NX4+1:0](-[H:4])(-[CX4,c,#7,#8,#1,S,$(C=C),Cl:1])(-[CX4,c,#7,#8,#1,S,$(C=C):2])(-[CX4,c,#7,#8,#1,S,$(C=C):3]) 0 A
49 | Amine [NX3:0](-[CX4,c,#7,#8,#1,S,$(C=C),Cl:1])(-[CX4,c,#7,#8,#1,S,$(C=C):2])(-[CX4,c,#7,#8,#1,S,$(C=C):3]) 0 B
50 | Imine [#6,#7,P,S:0]=[NX3+1:1](-[H:2]) 1 A
51 | Imine [#6,#7,P,S:0]=[NX2:1] 1 B
52 | Amide [$([#6]=,:[O,S,#7:0]),$([#7]=[#7,#8]),$([#6]:,=[#6]:,-[#7]=[#7,#8]),$(c:c:c:c:[#7+1]):0]-[NX3:1]-[H:2] 1 A
53 | Amide [$([#6]=,:[O,S,#7:0]),$([#7]=[#7,#8]),$([#6]:,=[#6]:,-[#7]=[#7,#8]),$(c:c:c:c:[#7+1]):0]-[NX2-1:1] 1 B
54 | Amide imine [$([#6]-,:[O,S,#7]),N+1:0]=,:[NX2:1]-[H:2] 1 A
55 | Amide imine [$([#6]-,:[O,S,#7]),N+1:0]=,:[NX1-1:1] 1 B
56 | Sulfamide [SX4:0](=[O:1])(=[O:2])-[NX3:3]-[H:4] 3 A
57 | Sulfamide [SX4:0](=[O:1])(=[O:2])-[NX2-1:3] 3 B
58 | Phosphamide [PX4:0](=[O:1])-[NX3:2]-[H:3] 2 A
59 | Phosphamide [PX4:0](=[O:1])-[NX2-1:2] 2 B
60 | Amide vinylogue [NX3:0](-[H:5])-,:[#6:1]=,:[#6:2]-,:[$([#6]=,:[#7,#8]),$(C#N):3] 0 A
61 | Amide vinylogue [NX2-1:0]-,:[#6:1]=,:[#6:2]-,:[$([#6]=,:[#7,#8]),$(C#N):3] 0 B
62 | Di Carbonyl βH [$([#6,#7]=,:[#7,#8]),$(C#N),$([#6]=,:[#6]-,:[$([#6]=,:[#7,#8]),$(C#N)]),$(P=O),$(S=O),S+1,Cl,F,O,c:0]-,:[#6X4:1](-[H:3])-,:[$([#6]=,:[#7,#8]),$(C#N),$(P=O),$(S=O):2] 1 A
63 | Di Carbonyl βH [$([#6,#7]=,:[#7,#8]),$(C#N),$([#6]=,:[#6]-,:[$([#6]=,:[#7,#8]),$(C#N)]),$(P=O),$(S=O),S+1,Cl,F,O,c:0]-,:[#6X3-1:1]-,:[$([#6]=,:[#7,#8]),$(C#N),$(P=O),$(S=O):2] 1 B
64 | Carbonyl βH [$([#6](=O)(-,:[#7+1,#6,#1])(-,:[#6,#1])),$([N](=O)(-O)),$(P=O),$(C=C-C(=O)-[#6,#1]):0]-,:[#6X4:1]-[H:2] 1 A
65 | Carbonyl βH [$([#6](=O)(-,:[#7+1,#6,#1])(-,:[#6,#1])),$([N](=O)(-O)),$(P=O),$(C=C-C(=O)-[#6,#1]):0]-,:[#6X3-1:1] 1 B
66 | Carbonyl allene [O:0]=[C:1]-[C:2]=[C:3]=[CX3:4]-[H:5] 4 A
67 | Carbonyl allene [O:0]=[C:1]-[C:2]=[C:3]=[CX2-1:4] 4 B
68 | Enol [$([#6]=,:[#7,#8]),$(C#N),#7+1,$([S]=[O]),c,$(C=C),OH1:0]-[#6:1]:,=[#6:2]-[OX2:3]-[H:4] 3 A
69 | Enol [$([#6]=,:[#7,#8]),$(C#N),#7+1,$([S]=[O]),c,$(C=C),OH1:0]-[#6:1]:,=[#6:2]-[O-1:3] 3 B
70 | Enol [#6:0]=[#6:1](-[$(C=O),$(C(=C)-[OH1]):2])-[OX2:3]-[H:4] 3 A
71 | Enol [#6:0]=[#6:1](-[$(C=O),$(C(=C)-[OH1]):2])-[O-1:3] 3 B
72 | Acyl group [#6:0](-[O,N:1])=[OX2+1:2]-[H:3] 2 A
73 | Acyl group [#6:0](-[O,N:1])=[OX1:2] 2 B
74 | Sulfoxide [S+1:0](-[OX2:1]-[H:4])(-[#6:2])(-[#6:3]) 1 A
75 | Sulfoxide [S+1:0](-[O-1:1])(-[#6:2])(-[#6:3]) 1 B
76 | Sulfoxide [S:0](=[OX2+1:1]-[H:4])(-[#6:2])(-[#6:3]) 1 A
77 | Sulfoxide [S:0](=[OX1:1])(-[#6:2])(-[#6:3]) 1 B
78 | Sulfoxide [S:0](=[OX2+1:1]-[H:3])(=[#6:2]) 1 A
79 | Sulfoxide [S:0](=[OX1:1])(=[#6:2]) 1 B
80 | Hydrocyanic acid [N:0]#[C:1]-[H:2] 1 A
81 | Hydrocyanic acid [N:0]#[C-1:1] 1 B
82 | Phosphoryl group [PX4:0]=[OX2+1:1]-[H:3] 1 A
83 | Phosphoryl group [PX4:0]=[OX1:1] 1 B
84 | Selenonyl group [Se:0]=[OX2+1:1]-[H:3] 1 A
85 | Selenonyl group [Se:0]=[OX1:1] 1 B
86 | Arsenyl group [AsX4:0]=[OX2+1:1]-[H:2] 1 A
87 | Arsenyl group [AsX4:0]=[OX1:1] 1 B
88 | Carboxyl group [#6X3:0](:,-[O,#7,S:1])=[OX2+1,SX2+1:2]-[H:3] 2 A
89 | Carboxyl group [#6X3:0](:,-[O,#7,S:1])=[OX1,SX1:2] 2 B
90 | Carboxyl group vinylogue [#6X3:0](:,-[#6:1]:,=[#6:2]:,-[O,#7,S:3])=[OX2+1,SX2+1:4]-[H:5] 4 A
91 | Carboxyl group vinylogue [#6X3:0](:,-[#6:1]:,=[#6:2]:,-[O,#7,S:3])=[OX1,SX1:4] 4 B
92 | Carbonyl group [#6X3:0](:,-[#1,#6:1])(:,-[#1,#6:2])=[OX2+1:3]-[H:4] 3 A
93 | Carbonyl group [#6X3:0](:,-[#1,#6:1])(:,-[#1,#6:2])=[OX1:3] 3 B
94 | Cyano group [C:0]#[N:1]-[H:2] 1 A
95 | Cyano group [C:0]#[NX1:1] 1 B
96 | Hydroxyl group [CX4:0](-[#6,#1:1])(-[#6,#1:2])(-[#6,#1:3])-[OH2+1:4] 4 A
97 | Hydroxyl group [CX4:0](-[#6,#1:1])(-[#6,#1:2])(-[#6,#1:3])-[OX2:4]-[H:5] 4 B
98 | Selenol [SeX2:0]-[H:1] 0 A
99 | Selenol [SeX1-1:0] 0 B
100 | Borate [BX3:0]-[OX2:1]-[H:2] 1 A
101 | Borate [BX3:0]-[O-1:1] 1 B
102 | Bromomethane [Br:0]-[CH3:1]-[H:2] 1 A
103 | Bromomethane [Br:0]-[CH2-1:1] 1 B
104 | Cyclopentadiene [#6X4:0](-[#1:5])1-,:[#6:1]=,:[#6:2]-,:[#6:3]=,:[#6:4]-,:1 0 A
105 | Cyclopentadiene [#6X3-1:0]1-,:[#6:1]=,:[#6:2]-,:[#6:3]=,:[#6:4]-,:1 0 B
106 | Tin alkyl [N+:0]-[CX4:1](-[H:3])-[SnX4:2] 1 A
107 | Tin alkyl [N+:0]-[CX3-1:1]-[SnX4:2] 1 B
--------------------------------------------------------------------------------
/finetune_pka.sh:
--------------------------------------------------------------------------------
1 | data_path='dwar'
2 | MASTER_PORT=10090
3 | task_name="dwar-iBond"
4 | head_name='chembl'
5 | weight_path='pretrain_save/checkpoint_best.pt'
6 | n_gpu=1
7 | save_dir='finetune_save'
8 |
9 | # train params
10 | seed=0
11 | nfolds=5
12 | cv_seed=42
13 | task_num=1
14 | loss_func="finetune_mse"
15 | dict_name='dict.txt'
16 | charge_dict_name='dict_charge.txt'
17 | only_polar=-1
18 | conf_size=11
19 | local_batch_size=16
20 | lr=3e-4
21 | bs=32
22 | epoch=20
23 | dropout=0.1
24 | warmup=0.06
25 |
26 | for ((fold=0;fold<$nfolds;fold++))
27 | do
28 | export NCCL_ASYNC_ERROR_HANDLING=1
29 | export OMP_NUM_THREADS=1
30 | echo "params setting lr: $lr, bs: $bs, epoch: $epoch, dropout: $dropout, warmup: $warmup, cv_seed: $cv_seed, fold: $fold"
31 | update_freq=`expr $bs / $local_batch_size`
32 | fold_save_dir="fold_${fold}"
33 | model_dir="${save_dir}/${fold_save_dir}"
34 | python -m torch.distributed.launch --nproc_per_node=$n_gpu --master_port=$MASTER_PORT \
35 | $(which unicore-train) $data_path --task-name $task_name --user-dir ./unimol --train-subset train --valid-subset valid \
36 | --conf-size $conf_size --nfolds $nfolds --fold $fold --cv-seed $cv_seed\
37 | --num-workers 8 --ddp-backend=c10d \
38 | --dict-name $dict_name --charge-dict-name $charge_dict_name \
39 | --task mol_pka --loss $loss_func --arch unimol_pka \
40 | --classification-head-name $head_name --num-classes $task_num \
41 | --optimizer adam --adam-betas '(0.9, 0.99)' --adam-eps 1e-6 --clip-norm 1.0 \
42 | --lr-scheduler polynomial_decay --lr $lr --warmup-ratio $warmup --max-epoch $epoch --batch-size $local_batch_size --pooler-dropout $dropout\
43 | --update-freq $update_freq --seed $seed \
44 | --fp16 --fp16-init-scale 4 --fp16-scale-window 256 \
45 | --log-interval 100 --log-format simple \
46 | --finetune-from-model $weight_path \
47 | --validate-interval 1 --keep-last-epochs 1 \
48 | --all-gather-list-size 102400 \
49 | --save-dir $model_dir \
50 | --best-checkpoint-metric valid_rmse --patience 2000 \
51 | --only-polar $only_polar --split-mode cross_valid
52 | done
--------------------------------------------------------------------------------
/image/inference.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dptech-corp/Uni-pKa/7cfbf6532d7e14f427abaed5859bd1001b4b6377/image/inference.png
--------------------------------------------------------------------------------
/image/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dptech-corp/Uni-pKa/7cfbf6532d7e14f427abaed5859bd1001b4b6377/image/overview.png
--------------------------------------------------------------------------------
/image/performance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dptech-corp/Uni-pKa/7cfbf6532d7e14f427abaed5859bd1001b4b6377/image/performance.png
--------------------------------------------------------------------------------
/image/protensemble.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dptech-corp/Uni-pKa/7cfbf6532d7e14f427abaed5859bd1001b4b6377/image/protensemble.png
--------------------------------------------------------------------------------
/infer_free_energy.sh:
--------------------------------------------------------------------------------
1 | data_path='./unimol/examples'
2 | infer_task='fe_example'
3 | results_path='fe_results'
4 | head_name='chembl_small'
5 | conf_size=11
6 | dict_name='dict.txt'
7 | charge_dict_name='dict_charge.txt'
8 | task_num=1
9 | batch_size=16
10 | model_path='dwar_finetune'
11 | loss_func="infer_free_energy"
12 | nfolds=5
13 | only_polar=-1
14 |
15 | for ((fold=0;fold<$nfolds;fold++))
16 | do
17 | python ./unimol/infer.py --user-dir ./unimol ${data_path} --task-name $infer_task --valid-subset $infer_task \
18 | --results-path $results_path/fold_${fold} \
19 | --num-workers 8 --ddp-backend=c10d --batch-size $batch_size \
20 | --task mol_free_energy --loss $loss_func --arch unimol \
21 | --classification-head-name $head_name --num-classes $task_num \
22 | --dict-name $dict_name --charge-dict-name $charge_dict_name --conf-size $conf_size \
23 | --only-polar $only_polar \
24 | --path $model_path/fold_$fold/checkpoint_best.pt \
25 | --fp16 --fp16-init-scale 4 --fp16-scale-window 256 \
26 | --log-interval 50 --log-format simple --required-batch-size-multiple 1
27 | done
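# After all folds finish, the per-fold predictions can be aggregated with the helper script
# in scripts/ (arguments as defined in scripts/infer_mean_ensemble.py):
#   python scripts/infer_mean_ensemble.py --results-path $results_path --nfolds $nfolds --task free_energy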
--------------------------------------------------------------------------------
/infer_pka.sh:
--------------------------------------------------------------------------------
1 | data_path='novartis_acid'
2 | infer_task='novartis_acid'
3 | results_path='novartis_acid_results'
4 | model_path='finetune_save'
5 | head_name='chembl'
6 | conf_size=11
7 | dict_name='dict.txt'
8 | charge_dict_name='dict_charge.txt'
9 | task_num=1
10 | batch_size=16
11 | loss_func="finetune_mse"
12 | nfolds=5
13 | only_polar=-1
14 |
15 | for ((fold=0;fold<$nfolds;fold++))
16 | do
17 | python ./unimol/infer.py --user-dir ./unimol ${data_path} --task-name $infer_task --valid-subset $infer_task \
18 | --results-path $results_path/fold_${fold} \
19 | --num-workers 8 --ddp-backend=c10d --batch-size $batch_size \
20 | --task mol_pka --loss $loss_func --arch unimol_pka \
21 | --classification-head-name $head_name --num-classes $task_num \
22 | --dict-name $dict_name --charge-dict-name $charge_dict_name --conf-size $conf_size \
23 | --only-polar $only_polar \
24 | --path $model_path/fold_$fold/checkpoint_best.pt \
25 | --fp16 --fp16-init-scale 4 --fp16-scale-window 256 \
26 | --log-interval 50 --log-format simple --required-batch-size-multiple 1
27 | done
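# After all folds finish, average the per-fold predictions and report MAE/RMSE with
# (arguments as defined in scripts/infer_mean_ensemble.py):
#   python scripts/infer_mean_ensemble.py --results-path $results_path --nfolds $nfolds --task pka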
--------------------------------------------------------------------------------
/pretrain_pka.sh:
--------------------------------------------------------------------------------
1 | data_path="chembl"
2 | task_name="chembl"
3 | n_gpu=8
4 | save_dir="pretrain_save"
5 | tmp_save_dir="tmp_save"
6 | task_num=1
7 | loss_func="pretrain_mlm"
8 | dict_name="dict.txt"
9 | charge_dict_name="dict_charge.txt"
10 | only_polar=-1
11 | conf_size=11
12 | split_mode="predefine"
13 | MASTER_PORT=10090
14 |
15 | # train params
16 | local_batch_size=16
17 | batch_size=16
18 | lr=1e-4
19 | epoch=100
20 | dropout=0.1
21 | warmup=0.06
22 | seed=0
23 | mask_prob=0.05
24 | update_freq=`expr $batch_size / $local_batch_size`
25 | global_batch_size=`expr $batch_size \* $n_gpu \* $update_freq`
26 | echo "params setting lr: $lr, bs: $global_batch_size, epoch: $epoch, dropout: $dropout, warmup: $warmup, seed: $seed"
27 |
28 | # loss
29 | masked_token_loss=1
30 | masked_charge_loss=2
31 | masked_coord_loss=2
32 | masked_dist_loss=1
33 | x_norm_loss=0.01
34 | delta_pair_repr_norm_loss=0.01
35 |
36 |
37 | export NCCL_ASYNC_ERROR_HANDLING=1
38 | export OMP_NUM_THREADS=1
39 | python -m torch.distributed.launch --nproc_per_node=$n_gpu --master_port=$MASTER_PORT \
40 | $(which unicore-train) $data_path --user-dir ./unimol --train-subset train --valid-subset valid \
41 | --conf-size $conf_size \
42 | --num-workers 8 --ddp-backend=c10d \
43 | --dict-name $dict_name --charge-dict-name $charge_dict_name \
44 | --task mol_pka_mlm --loss $loss_func --arch unimol_pka \
45 | --classification-head-name $task_name --num-classes $task_num \
46 | --optimizer adam --adam-betas "(0.9, 0.99)" --adam-eps 1e-6 --clip-norm 1.0 \
47 | --lr-scheduler polynomial_decay --lr $lr --warmup-ratio $warmup --max-epoch $epoch --batch-size $local_batch_size --pooler-dropout $dropout\
48 | --update-freq $update_freq --seed $seed \
49 | --fp16 --fp16-init-scale 4 --fp16-scale-window 256 \
50 | --keep-last-epochs 1 \
51 | --log-interval 100 --log-format simple \
52 | --validate-interval 1 \
53 | --save-dir $save_dir --tmp-save-dir $tmp_save_dir --tensorboard-logdir $save_dir/tsb \
54 | --best-checkpoint-metric valid_rmse --patience 2000 \
55 | --only-polar $only_polar --mask-prob $mask_prob \
56 | --masked-token-loss $masked_token_loss --masked-coord-loss $masked_coord_loss --masked-dist-loss $masked_dist_loss \
57 | --masked-charge-loss $masked_charge_loss --x-norm-loss $x_norm_loss --delta-pair-repr-norm-loss $delta_pair_repr_norm_loss
58 |
59 |
--------------------------------------------------------------------------------
/scripts/infer_mean_ensemble.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) DP Technology.
2 | # This source code is licensed under the MIT license found in the
3 | # LICENSE file in the root directory of this source tree.
4 |
5 | import pandas as pd
6 | import os
7 | import argparse
8 | import numpy as np
9 | import glob
10 |
11 |
12 | def cal_metrics(df):
13 | mae = np.abs(df["predict"] - df["target"]).mean()
14 | mse = ((df["predict"] - df["target"]) ** 2).mean()
15 | rmse = np.sqrt(mse)
16 | return mae, rmse
17 |
18 |
19 | def get_csv_results(results_path, nfolds, task):
20 |
21 | all_smi_list, all_predict_list, all_target_list = [], [], []
22 |
23 | for fold_idx in range(nfolds):
24 | print(f"Processing fold {fold_idx}...")
25 | fold_path = os.path.join(results_path, f'fold_{fold_idx}')
26 | pkl_files = glob.glob(f"{fold_path}/*.pkl")
27 | fold_data = pd.read_pickle(pkl_files[0])
28 |
29 | smi_list, predict_list, target_list = [], [], []
30 | for batch in fold_data:
31 | sz = batch["bsz"]
32 | for i in range(sz):
33 | smi_list.append(batch["smi_name"][i])
34 | predict_list.append(batch["predict"][i].cpu().item())
35 | target_list.append(batch["target"][i].cpu().item())
36 | fold_df = pd.DataFrame({"smiles": smi_list, "predict": predict_list, "target": target_list})
37 | fold_df.to_csv(f'{fold_path}/fold_{fold_idx}.csv',index=False, sep='\t')
38 |
39 | # for final combined results
40 | all_smi_list.extend(smi_list)
41 | all_predict_list.extend(predict_list)
42 | all_target_list.extend(target_list)
43 |
44 | print(f"Combining results from {nfolds} folds into a single file...")
45 | combined_df = pd.DataFrame({"smiles": all_smi_list, "predict": all_predict_list, "target": all_target_list})
46 | combined_df.to_csv(f'{results_path}/all_results.csv', index=False, sep='\t')
47 |
48 | print(f"Calculating mean results for each SMILES...")
49 | mean_results = combined_df.groupby('smiles', as_index=False).agg({
50 | 'predict': 'mean',
51 | 'target': 'mean'
52 | })
53 | mean_results.to_csv(f'{results_path}/mean_results.csv', index=False, sep='\t')
54 | if task == 'pka':
55 | print(f"MAE and RMSE for this task...")
56 | mae, rmse = cal_metrics(mean_results)
57 | print(f'MAE: {round(mae, 4)}, RMSE: {round(rmse, 4)}')
58 | print(f"Done!")
59 |
60 |
61 | def main():
62 | parser = argparse.ArgumentParser(description='Model infer result mean ensemble')
63 | parser.add_argument(
64 | '--results-path',
65 | type=str,
66 | default='results',
67 | help='path to save infer results'
68 | )
69 | parser.add_argument(
70 | "--nfolds",
71 | default=5,
72 | type=int,
73 | help="cross validation split folds"
74 | )
75 | parser.add_argument(
76 | "--task",
77 | default='pka',
78 | type=str,
79 | choices=['pka', 'free_energy']
80 | )
81 | args = parser.parse_args()
82 | get_csv_results(args.results_path, args.nfolds, args.task)
83 |
84 |
85 | if __name__ == "__main__":
86 | main()
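# Example invocation (the results path is illustrative and must match --results-path used
# by the inference scripts, which write one fold_<i> directory per fold):
#   python scripts/infer_mean_ensemble.py --results-path novartis_acid_results --nfolds 5 --task pka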
--------------------------------------------------------------------------------
/scripts/preprocess_pka.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) DP Technology.
2 | # This source code is licensed under the MIT license found in the
3 | # LICENSE file in the root directory of this source tree.
4 |
5 | import os
6 | import pickle
7 | import lmdb
8 | import pandas as pd
9 | import numpy as np
10 | from rdkit import Chem
11 | from tqdm import tqdm
12 | from rdkit.Chem import AllChem
13 | from rdkit.Chem.Scaffolds import MurckoScaffold
14 | from rdkit import RDLogger
15 | RDLogger.DisableLog('rdApp.*')
16 | import warnings
17 | warnings.filterwarnings(action='ignore')
18 | from multiprocessing import Pool
19 | import argparse
20 |
21 |
22 | def smi2scaffold(smi):
23 | try:
24 | return MurckoScaffold.MurckoScaffoldSmiles(
25 | smiles=smi, includeChirality=True)
26 | except:
27 | print("failed to generate scaffold with smiles: {}".format(smi))
28 | return smi
29 |
30 |
31 | def smi2_2Dcoords(smi):
32 | mol = Chem.MolFromSmiles(smi)
33 | mol = AllChem.AddHs(mol)
34 | AllChem.Compute2DCoords(mol)
35 | coordinates = mol.GetConformer().GetPositions().astype(np.float32)
36 |     assert len(mol.GetAtoms()) == len(coordinates), "2D coordinates shape does not align with {}".format(smi)
37 | return coordinates
38 |
39 |
40 | def smi2_3Dcoords(smi,cnt, gen_mode='mmff'):
41 | assert gen_mode in ['mmff', 'no_mmff']
42 | mol = Chem.MolFromSmiles(smi)
43 | mol = AllChem.AddHs(mol)
44 | coordinate_list=[]
45 | for seed in range(cnt):
46 | try:
47 | res = AllChem.EmbedMolecule(mol, randomSeed=seed)
48 | if res == 0:
49 | try:
50 | if gen_mode == 'mmff':
51 |                         AllChem.MMFFOptimizeMolecule(mol) # some conformers cannot be optimized with MMFF
52 | coordinates = mol.GetConformer().GetPositions()
53 | except:
54 | print("Failed to generate 3D, replace with 2D")
55 | coordinates = smi2_2Dcoords(smi)
56 |
57 | elif res == -1:
58 | mol_tmp = Chem.MolFromSmiles(smi)
59 | AllChem.EmbedMolecule(mol_tmp, maxAttempts=5000, randomSeed=seed)
60 | mol_tmp = AllChem.AddHs(mol_tmp, addCoords=True)
61 | try:
62 | if gen_mode == 'mmff':
63 |                         AllChem.MMFFOptimizeMolecule(mol_tmp) # some conformers cannot be optimized with MMFF
64 | coordinates = mol_tmp.GetConformer().GetPositions()
65 | except:
66 | print("Failed to generate 3D, replace with 2D")
67 | coordinates = smi2_2Dcoords(smi)
68 | except:
69 | print("Failed to generate 3D, replace with 2D")
70 | coordinates = smi2_2Dcoords(smi)
71 |
72 |         assert len(mol.GetAtoms()) == len(coordinates), "3D coordinates shape does not align with {}".format(smi)
73 | coordinate_list.append(coordinates.astype(np.float32))
74 | return coordinate_list
75 |
76 |
77 | def smi2metadata(smi, cnt, gen_mode): # input: single smi; output: molecule metadata (atom, charge, mol. smi, coords...)
78 | scaffold = smi2scaffold(smi)
79 | mol = Chem.MolFromSmiles(smi)
80 |
81 | if len(mol.GetAtoms()) > 400:
82 | coordinate_list = [smi2_2Dcoords(smi)] * (cnt+1)
83 |         print("atom num > 400, use 2D coords", smi)
84 | else:
85 | # gen cnt num 3D conf
86 | coordinate_list = smi2_3Dcoords(smi,cnt, gen_mode)
87 | # gen 1 2D conf
88 | coordinate_list.append(smi2_2Dcoords(smi).astype(np.float32))
89 | mol = AllChem.AddHs(mol)
90 | atoms, charges = [], []
91 | for atom in mol.GetAtoms():
92 | atoms.append(atom.GetSymbol())
93 | charges.append(atom.GetFormalCharge())
94 |
95 | return {'atoms': atoms,'charges': charges, 'coordinates': coordinate_list, 'mol':mol,'smi': smi, 'scaffold': scaffold}
96 |
97 |
98 | def inner_smi2coords_pka(content):
99 | smi_all, target = content
100 |     cnt = 10  # number of 3D conformers; 11 total per molecule = 10 3D + 1 2D
101 | gen_mode = 'mmff' # 'mmff', 'no_mmff'
102 |
103 | # get single smi from original SMILES
104 | smi_list_a, smi_list_b = smi_all.split('>>')
105 | smi_list_a = smi_list_a.split(',')
106 | smi_list_b = smi_list_b.split(',')
107 |
108 | # get whole datapoint metadata
109 | metadata_a, metadata_b = [], []
110 | for i in range(len(smi_list_a)):
111 | metadata_a.append(smi2metadata(smi_list_a[i], cnt, gen_mode))
112 | for i in range(len(smi_list_b)):
113 | metadata_b.append(smi2metadata(smi_list_b[i], cnt, gen_mode))
114 |
115 | return pickle.dumps({'ori_smi': smi_all, 'metadata_a': metadata_a, 'metadata_b': metadata_b, 'target': target}, protocol=-1)
116 |
117 |
118 | def smi2coords(content):
119 | try:
120 | return inner_smi2coords_pka(content)
121 | except:
122 | print("failed smiles: {}".format(content[0]))
123 | return None
124 |
125 |
126 | def load_rawdata_pka(input_csv):
127 |
128 | # read tsv file
129 | df = pd.read_csv(input_csv, sep='\t')
130 | smi_col = 'SMILES'
131 | target_col = 'TARGET'
132 | if target_col not in df.columns:
133 |         # If the target column does not exist, add -1.0 as a placeholder.
134 | df["TARGET"] = -1.0
135 | col_list = [smi_col, target_col]
136 | df = df[col_list]
137 | print(f'raw_data size: {df.shape[0]}')
138 | return df
139 |
140 | def write_lmdb(task_name, input_csv, output_dir='.', nthreads=16):
141 |
142 | df = load_rawdata_pka(input_csv)
143 | content_list = zip(*[df[c].values.tolist() for c in df])
144 | lmdb_name = '{}.lmdb'.format(task_name)
145 | os.makedirs(output_dir, exist_ok=True)
146 | output_name = os.path.join(output_dir, lmdb_name)
147 | try:
148 | os.remove(output_name)
149 | except:
150 | pass
151 | env_new = lmdb.open(
152 | output_name,
153 | subdir=False,
154 | readonly=False,
155 | lock=False,
156 | readahead=False,
157 | meminit=False,
158 | max_readers=1,
159 | map_size=int(100e9),
160 | )
161 | txn_write = env_new.begin(write=True)
162 | with Pool(nthreads) as pool:
163 | i = 0
164 | for inner_output in tqdm(pool.imap(smi2coords, content_list), total=len(df)):
165 | if inner_output is not None:
166 | txn_write.put(f'{i}'.encode("ascii"), inner_output)
167 | i += 1
168 | print('{} process {} lines'.format(lmdb_name, i))
169 | txn_write.commit()
170 | env_new.close()
171 |
172 | def main():
173 | parser = argparse.ArgumentParser(
174 | description="use rdkit to generate conformers"
175 | )
176 | parser.add_argument(
177 | "--raw-csv-file",
178 | type=str,
179 | default="Datasets/tsv/chembl_train.tsv",
180 |         help="the original data tsv file path",
181 | )
182 | parser.add_argument(
183 | "--processed-lmdb-dir",
184 | type=str,
185 | default="chembl",
186 | help="dir of the processed lmdb data",
187 | )
188 | parser.add_argument("--nthreads", type=int, default=22, help="num of threads")
189 | parser.add_argument(
190 | "--task-name",
191 | type=str,
192 | default="train",
193 | help="name of the lmdb file; train and valid for chembl",
194 | choices=['train', 'valid', 'dwar-iBond', 'novartis_acid', 'novartis_base', 'sampl6', 'sampl7', 'sampl8']
195 | )
196 | args = parser.parse_args()
197 | write_lmdb(task_name = args.task_name, input_csv=args.raw_csv_file, output_dir=args.processed_lmdb_dir, nthreads = args.nthreads)
198 |
199 |
200 | if __name__ == '__main__':
201 | main()
202 |
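# Example invocation (paths follow the argparse defaults above and are illustrative; the
# --processed-lmdb-dir should match data_path in the training/inference shell scripts):
#   python scripts/preprocess_pka.py --raw-csv-file Datasets/tsv/chembl_train.tsv --processed-lmdb-dir chembl --task-name train --nthreads 16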
--------------------------------------------------------------------------------
/unimol/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import unimol.tasks
3 | import unimol.data
4 | import unimol.models
5 | import unimol.losses
6 |
--------------------------------------------------------------------------------
/unimol/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .key_dataset import KeyDataset
2 | from .normalize_dataset import (
3 | NormalizeDataset,
4 | )
5 | from .remove_hydrogen_dataset import (
6 | RemoveHydrogenDataset,
7 | )
8 | from .tta_dataset import (
9 | TTADataset,
10 | TTAPKADataset,
11 | )
12 | from .cropping_dataset import CroppingDataset
13 |
14 | from .distance_dataset import (
15 | DistanceDataset,
16 | EdgeTypeDataset,
17 | )
18 | from .conformer_sample_dataset import (
19 | ConformerSamplePKADataset,
20 | )
21 | from .coord_pad_dataset import RightPadDatasetCoord
22 | from .lmdb_dataset import (
23 | FoldLMDBDataset,
24 | StackedLMDBDataset,
25 | SplitLMDBDataset,
26 | )
27 | from .pka_input_dataset import (
28 | PKAInputDataset,
29 | PKAMLMInputDataset,
30 | )
31 | from .mask_points_dataset import MaskPointsDataset
32 |
33 | __all__ = []
--------------------------------------------------------------------------------
/unimol/data/conformer_sample_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) DP Technology.
2 | # This source code is licensed under the MIT license found in the
3 | # LICENSE file in the root directory of this source tree.
4 |
5 | import numpy as np
6 | from functools import lru_cache
7 | from unicore.data import BaseWrapperDataset
8 | from . import data_utils
9 |
10 |
11 | class ConformerSamplePKADataset(BaseWrapperDataset):
12 | def __init__(self, dataset, seed, atoms, coordinates, charges, id="ori_smi"):
13 | self.dataset = dataset
14 | self.seed = seed
15 | self.atoms = atoms
16 | self.coordinates = coordinates
17 | self.charges = charges
18 | self.id = id
19 | self.set_epoch(None)
20 |
21 | def set_epoch(self, epoch, **unused):
22 | super().set_epoch(epoch)
23 | self.epoch = epoch
24 |
25 | @lru_cache(maxsize=16)
26 | def __cached_item__(self, index: int, epoch: int):
27 | atoms = np.array(self.dataset[index][self.atoms])
28 | charges = np.array(self.dataset[index][self.charges])
29 |         assert len(atoms) > 0, 'atoms: {}, charges: {}, id: {}'.format(atoms, charges, self.id)
30 | size = len(self.dataset[index][self.coordinates])
31 | with data_utils.numpy_seed(self.seed, epoch, index):
32 | sample_idx = np.random.randint(size)
33 | coordinates = self.dataset[index][self.coordinates][sample_idx]
34 | return {"atoms": atoms, "coordinates": coordinates.astype(np.float32),"charges":charges,"id": self.id}
35 |
36 | def __getitem__(self, index: int):
37 | return self.__cached_item__(index, self.epoch)
38 |
--------------------------------------------------------------------------------
/unimol/data/coord_pad_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) DP Technology.
2 | # This source code is licensed under the MIT license found in the
3 | # LICENSE file in the root directory of this source tree.
4 |
5 | from unicore.data import BaseWrapperDataset
6 |
7 |
8 | def collate_tokens_coords(
9 | values,
10 | pad_idx,
11 | left_pad=False,
12 | pad_to_length=None,
13 | pad_to_multiple=1,
14 | ):
15 | """Convert a list of 1d tensors into a padded 2d tensor."""
16 | size = max(v.size(0) for v in values)
17 | size = size if pad_to_length is None else max(size, pad_to_length)
18 | if pad_to_multiple != 1 and size % pad_to_multiple != 0:
19 | size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple)
20 | res = values[0].new(len(values), size, 3).fill_(pad_idx)
21 |
22 | def copy_tensor(src, dst):
23 | assert dst.numel() == src.numel()
24 | dst.copy_(src)
25 |
26 | for i, v in enumerate(values):
27 | copy_tensor(v, res[i][size - len(v) :, :] if left_pad else res[i][: len(v), :])
28 | return res
29 |
30 |
31 | class RightPadDatasetCoord(BaseWrapperDataset):
32 | def __init__(self, dataset, pad_idx, left_pad=False):
33 | super().__init__(dataset)
34 | self.pad_idx = pad_idx
35 | self.left_pad = left_pad
36 |
37 | def collater(self, samples):
38 | return collate_tokens_coords(
39 | samples, self.pad_idx, left_pad=self.left_pad, pad_to_multiple=8
40 | )
41 |
--------------------------------------------------------------------------------
/unimol/data/cropping_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) DP Technology.
2 | # This source code is licensed under the MIT license found in the
3 | # LICENSE file in the root directory of this source tree.
4 |
5 | import numpy as np
6 | from functools import lru_cache
7 | import logging
8 | from unicore.data import BaseWrapperDataset
9 | from . import data_utils
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 |
14 | class CroppingDataset(BaseWrapperDataset):
15 | def __init__(self, dataset, seed, atoms, coordinates, charges, max_atoms=256):
16 | self.dataset = dataset
17 | self.seed = seed
18 | self.atoms = atoms
19 | self.coordinates = coordinates
20 | self.charges = charges
21 | self.max_atoms = max_atoms
22 | self.set_epoch(None)
23 |
24 | def set_epoch(self, epoch, **unused):
25 | super().set_epoch(epoch)
26 | self.epoch = epoch
27 |
28 | @lru_cache(maxsize=16)
29 | def __cached_item__(self, index: int, epoch: int):
30 | dd = self.dataset[index].copy()
31 | atoms = dd[self.atoms]
32 | coordinates = dd[self.coordinates]
33 | charges = dd[self.charges]
34 | if self.max_atoms and len(atoms) > self.max_atoms:
35 | with data_utils.numpy_seed(self.seed, epoch, index):
36 | index = np.random.choice(len(atoms), self.max_atoms, replace=False)
37 | atoms = np.array(atoms)[index]
38 | coordinates = coordinates[index]
39 | charges = charges[index]
40 | dd[self.atoms] = atoms
41 | dd[self.coordinates] = coordinates.astype(np.float32)
42 | dd[self.charges] = charges
43 | return dd
44 |
45 | def __getitem__(self, index: int):
46 | return self.__cached_item__(index, self.epoch)
47 |
--------------------------------------------------------------------------------
/unimol/data/data_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) DP Technology.
2 | # This source code is licensed under the MIT license found in the
3 | # LICENSE file in the root directory of this source tree.
4 |
5 | import numpy as np
6 | import contextlib
7 |
8 |
9 | @contextlib.contextmanager
10 | def numpy_seed(seed, *addl_seeds):
11 | """Context manager which seeds the NumPy PRNG with the specified seed and
12 | restores the state afterward"""
13 | if seed is None:
14 | yield
15 | return
16 | if len(addl_seeds) > 0:
17 | seed = int(hash((seed, *addl_seeds)) % 1e6)
18 | state = np.random.get_state()
19 | np.random.seed(seed)
20 | try:
21 | yield
22 | finally:
23 | np.random.set_state(state)
24 |
--------------------------------------------------------------------------------
/unimol/data/distance_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) DP Technology.
2 | # This source code is licensed under the MIT license found in the
3 | # LICENSE file in the root directory of this source tree.
4 |
5 | import numpy as np
6 | import torch
7 | from scipy.spatial import distance_matrix
8 | from functools import lru_cache
9 | from unicore.data import BaseWrapperDataset
10 |
11 |
12 | class DistanceDataset(BaseWrapperDataset):
13 | def __init__(self, dataset):
14 | super().__init__(dataset)
15 | self.dataset = dataset
16 |
17 | @lru_cache(maxsize=16)
18 | def __getitem__(self, idx):
19 | pos = self.dataset[idx].view(-1, 3).numpy()
20 | dist = distance_matrix(pos, pos).astype(np.float32)
21 | return torch.from_numpy(dist)
22 |
23 |
24 | class EdgeTypeDataset(BaseWrapperDataset):
25 | def __init__(self, dataset: torch.utils.data.Dataset, num_types: int):
26 | self.dataset = dataset
27 | self.num_types = num_types
28 |
29 | @lru_cache(maxsize=16)
30 | def __getitem__(self, index: int):
31 | node_input = self.dataset[index].clone()
32 | offset = node_input.view(-1, 1) * self.num_types + node_input.view(1, -1)
33 | return offset
34 |
--------------------------------------------------------------------------------
/unimol/data/key_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) DP Technology.
2 | # This source code is licensed under the MIT license found in the
3 | # LICENSE file in the root directory of this source tree.
4 |
5 | from functools import lru_cache
6 | from unicore.data import BaseWrapperDataset
7 |
8 |
9 | class KeyDataset(BaseWrapperDataset):
10 | def __init__(self, dataset, key):
11 | self.dataset = dataset
12 | self.key = key
13 |
14 | def __len__(self):
15 | return len(self.dataset)
16 |
17 | @lru_cache(maxsize=16)
18 | def __getitem__(self, idx):
19 | return self.dataset[idx][self.key]
20 |
--------------------------------------------------------------------------------
/unimol/data/lmdb_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) DP Technology.
2 | # This source code is licensed under the MIT license found in the
3 | # LICENSE file in the root directory of this source tree.
4 |
5 |
6 | import lmdb
7 | import os
8 | import pickle
9 | from functools import lru_cache
10 | from . import data_utils
11 | import numpy as np
12 | import logging
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 |
17 | class FoldLMDBDataset:
18 | def __init__(self, dataset, seed, cur_fold, nfolds=5, cache_fold_info=None):
19 | super().__init__()
20 | self.dataset = dataset
21 | if cache_fold_info is None:
22 | self.keys = []
23 | self.fold_start = []
24 | self.fold_end = []
25 | self.init_random_split(dataset, seed, nfolds)
26 | else:
27 | # use cache fold info
28 | self.keys, self.fold_start, self.fold_end = cache_fold_info
29 | self.cur_fold = cur_fold
30 | self._len = self.fold_end[cur_fold] - self.fold_start[cur_fold]
31 | assert len(self.fold_end) == len(self.fold_start) == nfolds
32 |
33 | def init_random_split(self, dataset, seed, nfolds):
34 | with data_utils.numpy_seed(seed):
35 | self.keys = np.random.permutation(len(dataset))
36 | average_size = (len(dataset) + nfolds - 1) // nfolds
37 | cur_size = 0
38 | for i in range(nfolds):
39 | self.fold_start.append(cur_size)
40 | cur_size = min(cur_size + average_size, len(dataset))
41 | self.fold_end.append(cur_size)
42 |
43 | def get_fold_info(self):
44 | return self.keys, self.fold_start, self.fold_end
45 |
46 | def __len__(self):
47 | return self._len
48 |
49 | @lru_cache(maxsize=16)
50 | def __getitem__(self, idx):
51 | global_idx = idx + self.fold_start[self.cur_fold]
52 | return self.dataset[self.keys[global_idx]]
53 |
54 |
55 | class StackedLMDBDataset:
56 | def __init__(self, datasets):
57 | self._len = 0
58 | self.datasets = []
59 | self.idx_to_file = {}
60 | self.idx_offset = []
61 | for dataset in datasets:
62 | self.datasets.append(dataset)
63 | for i in range(len(dataset)):
64 | self.idx_to_file[i + self._len] = len(self.datasets) - 1
65 | self.idx_offset.append(self._len)
66 | self._len += len(dataset)
67 |
68 | def __len__(self):
69 | return self._len
70 |
71 | @lru_cache(maxsize=16)
72 | def __getitem__(self, idx):
73 | file_idx = self.idx_to_file[idx]
74 | sub_idx = idx - self.idx_offset[file_idx]
75 | return self.datasets[file_idx][sub_idx]
76 |
77 |
78 | class SplitLMDBDataset:
79 | # train:valid = 9:1
80 | def __init__(self, dataset, seed, cur_fold, cache_fold_info= None,frac_train=0.9, frac_valid=0.1):
81 | super().__init__()
82 | self.dataset = dataset
83 | np.testing.assert_almost_equal(frac_train + frac_valid, 1.0)
84 | frac = [frac_train,frac_valid]
85 | if cache_fold_info is None:
86 | self.keys = []
87 | self.fold_start = []
88 | self.fold_end = []
89 | self.init_random_split(dataset, seed, frac)
90 | else:
91 | # use cache fold info
92 | self.keys, self.fold_start, self.fold_end = cache_fold_info
93 | self.cur_fold = cur_fold
94 | self._len = self.fold_end[cur_fold] - self.fold_start[cur_fold]
95 | assert len(self.fold_end) == len(self.fold_start) == 3
96 |
97 | def init_random_split(self, dataset, seed, frac):
98 | with data_utils.numpy_seed(seed):
99 | self.keys = np.random.permutation(len(dataset))
100 | frac_train,frac_valid = frac
101 | #average_size = (len(dataset) + nfolds - 1) // nfolds
102 | fold_size = [int(frac_train * len(dataset)), len(dataset)- int(frac_train * len(dataset))]
103 | assert sum(fold_size) == len(dataset)
104 | cur_size = 0
105 | for i in range(len(fold_size)):
106 | self.fold_start.append(cur_size)
107 | cur_size = min(cur_size + fold_size[i], len(dataset))
108 | self.fold_end.append(cur_size)
109 |
110 | def get_fold_info(self):
111 | return self.keys, self.fold_start, self.fold_end
112 |
113 | def __len__(self):
114 | return self._len
115 |
116 | @lru_cache(maxsize=16)
117 | def __getitem__(self, idx):
118 | global_idx = idx + self.fold_start[self.cur_fold]
119 | return self.dataset[self.keys[global_idx]]
120 |
--------------------------------------------------------------------------------
/unimol/data/mask_points_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) DP Technology.
2 | # This source code is licensed under the MIT license found in the
3 | # LICENSE file in the root directory of this source tree.
4 |
5 | from functools import lru_cache
6 |
7 | import numpy as np
8 | import torch
9 | from unicore.data import Dictionary
10 | from unicore.data import BaseWrapperDataset
11 | from . import data_utils
12 |
13 |
14 | class MaskPointsDataset(BaseWrapperDataset):
15 | def __init__(
16 | self,
17 | dataset: torch.utils.data.Dataset,
18 | coord_dataset: torch.utils.data.Dataset,
19 | charge_dataset: torch.utils.data.Dataset,
20 | vocab: Dictionary,
21 | charge_vocab: Dictionary,
22 | pad_idx: int,
23 | charge_pad_idx: int,
24 | mask_idx: int,
25 | charge_mask_idx: int,
26 | noise_type: str,
27 | noise: float = 1.0,
28 | seed: int = 1,
29 | mask_prob: float = 0.15,
30 | leave_unmasked_prob: float = 0.1,
31 | random_token_prob: float = 0.1,
32 | ):
33 | assert 0.0 < mask_prob < 1.0
34 | assert 0.0 <= random_token_prob <= 1.0
35 | assert 0.0 <= leave_unmasked_prob <= 1.0
36 | assert random_token_prob + leave_unmasked_prob <= 1.0
37 |
38 | self.dataset = dataset
39 | self.coord_dataset = coord_dataset
40 | self.charge_dataset = charge_dataset
41 | self.vocab = vocab
42 | self.charge_vocab = charge_vocab
43 | self.pad_idx = pad_idx
44 | self.charge_pad_idx = charge_pad_idx
45 | self.mask_idx = mask_idx
46 | self.charge_mask_idx = charge_mask_idx
47 | self.noise_type = noise_type
48 | self.noise = noise
49 | self.seed = seed
50 | self.mask_prob = mask_prob
51 | self.leave_unmasked_prob = leave_unmasked_prob
52 | self.random_token_prob = random_token_prob
53 |
54 | if random_token_prob > 0.0:
55 | weights = np.ones(len(self.vocab))
56 | weights[vocab.special_index()] = 0
57 | self.weights = weights / weights.sum()
58 | # for charge
59 | charge_weights = np.ones(len(self.charge_vocab))
60 | charge_weights[charge_vocab.special_index()] = 0
61 | self.charge_weights = charge_weights / charge_weights.sum()
62 |
63 | self.epoch = None
64 | if self.noise_type == "trunc_normal":
65 | self.noise_f = lambda num_mask: np.clip(
66 | np.random.randn(num_mask, 3) * self.noise,
67 | a_min=-self.noise * 2.0,
68 | a_max=self.noise * 2.0,
69 | )
70 | elif self.noise_type == "normal":
71 | self.noise_f = lambda num_mask: np.random.randn(num_mask, 3) * self.noise
72 | elif self.noise_type == "uniform":
73 | self.noise_f = lambda num_mask: np.random.uniform(
74 | low=-self.noise, high=self.noise, size=(num_mask, 3)
75 | )
76 | else:
77 | self.noise_f = lambda num_mask: 0.0
78 |
79 | def set_epoch(self, epoch, **unused):
80 | super().set_epoch(epoch)
81 | self.coord_dataset.set_epoch(epoch)
82 | self.dataset.set_epoch(epoch)
83 | self.charge_dataset.set_epoch(epoch)
84 | self.epoch = epoch
85 |
86 | def __getitem__(self, index: int):
87 | return self.__getitem_cached__(self.epoch, index)
88 |
89 | @lru_cache(maxsize=16)
90 | def __getitem_cached__(self, epoch: int, index: int):
91 | ret = {}
92 | with data_utils.numpy_seed(self.seed, epoch, index):
93 | item = self.dataset[index]
94 | coord = self.coord_dataset[index]
95 | charge = self.charge_dataset[index]
96 | sz = len(item)
97 | # don't allow empty sequence
98 | assert sz > 0
99 | # decide elements to mask
100 | num_mask = int(
101 | # add a random number for probabilistic rounding
102 | self.mask_prob * sz
103 | + np.random.rand()
104 | )
105 | mask_idc = np.random.choice(sz, num_mask, replace=False)
106 | mask = np.full(sz, False)
107 | mask[mask_idc] = True
108 | ret["targets"] = np.full(len(mask), self.pad_idx)
109 | ret["targets"][mask] = item[mask]
110 | ret["targets"] = torch.from_numpy(ret["targets"]).long()
111 | # for charge
112 | ret["charge_targets"] = np.full(len(mask), self.charge_pad_idx)
113 | ret["charge_targets"][mask] = charge[mask]
114 | ret["charge_targets"] = torch.from_numpy(ret["charge_targets"]).long()
115 |
116 | # decide unmasking and random replacement
117 | rand_or_unmask_prob = self.random_token_prob + self.leave_unmasked_prob
118 | if rand_or_unmask_prob > 0.0:
119 | rand_or_unmask = mask & (np.random.rand(sz) < rand_or_unmask_prob)
120 | if self.random_token_prob == 0.0:
121 | unmask = rand_or_unmask
122 | rand_mask = None
123 | elif self.leave_unmasked_prob == 0.0:
124 | unmask = None
125 | rand_mask = rand_or_unmask
126 | else:
127 | unmask_prob = self.leave_unmasked_prob / rand_or_unmask_prob
128 | decision = np.random.rand(sz) < unmask_prob
129 | unmask = rand_or_unmask & decision
130 | rand_mask = rand_or_unmask & (~decision)
131 | else:
132 | unmask = rand_mask = None
133 |
134 | if unmask is not None:
135 | mask = mask ^ unmask
136 |
137 | new_item = np.copy(item)
138 | new_item[mask] = self.mask_idx
139 |
140 | num_mask = mask.astype(np.int32).sum()
141 | new_coord = np.copy(coord)
142 | new_coord[mask, :] += self.noise_f(num_mask)
143 |
144 | # for charge mask
145 | new_charge = np.copy(charge)
146 | new_charge[mask] = self.charge_mask_idx
147 |
148 | if rand_mask is not None:
149 | num_rand = rand_mask.sum()
150 | if num_rand > 0:
151 | new_item[rand_mask] = np.random.choice(
152 | len(self.vocab),
153 | num_rand,
154 | p=self.weights,
155 | )
156 | # for charge
157 | new_charge[rand_mask] = np.random.choice(
158 | len(self.charge_vocab),
159 | num_rand,
160 | p=self.charge_weights,
161 | )
162 | ret["atoms"] = torch.from_numpy(new_item).long()
163 | ret["coordinates"] = torch.from_numpy(new_coord).float()
164 | ret["charges"] = torch.from_numpy(new_charge).long()
165 | return ret
166 |
167 |
--------------------------------------------------------------------------------
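
Note: a minimal, self-contained sketch (not part of the repository) of the BERT-style masking scheme implemented in the dataset above. `num_mask` uses probabilistic rounding so that on average `mask_prob * sz` atoms are selected, and the selected positions are then split into leave-unmasked and random-replacement subsets; all names below are illustrative.

import numpy as np

def make_masks(sz, mask_prob=0.15, leave_unmasked_prob=0.05,
               random_token_prob=0.05, rng=np.random):
    # probabilistic rounding: on average mask_prob * sz positions are selected
    num_mask = int(mask_prob * sz + rng.rand())
    mask = np.zeros(sz, dtype=bool)
    mask[rng.choice(sz, num_mask, replace=False)] = True
    # of the selected positions, some keep their original token (unmask) and some
    # receive a random token (rand_mask); the rest are overwritten with [MASK]
    rand_or_unmask_prob = leave_unmasked_prob + random_token_prob
    rand_or_unmask = mask & (rng.rand(sz) < rand_or_unmask_prob)
    decision = rng.rand(sz) < leave_unmasked_prob / rand_or_unmask_prob
    unmask = rand_or_unmask & decision
    rand_mask = rand_or_unmask & ~decision
    # first return value: positions overwritten with [MASK]
    # (rand_mask positions are re-sampled afterwards, mirroring the dataset code)
    return mask ^ unmask, rand_mask

mask_positions, rand_positions = make_masks(20)
print(mask_positions.sum(), rand_positions.sum())
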
/unimol/data/normalize_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) DP Technology.
2 | # This source code is licensed under the MIT license found in the
3 | # LICENSE file in the root directory of this source tree.
4 |
5 | import numpy as np
6 | from functools import lru_cache
7 | from unicore.data import BaseWrapperDataset
8 |
9 |
10 | class NormalizeDataset(BaseWrapperDataset):
11 | def __init__(self, dataset, coordinates, normalize_coord=True):
12 | self.dataset = dataset
13 | self.coordinates = coordinates
14 | self.normalize_coord = normalize_coord # normalize the coordinates.
15 | self.set_epoch(None)
16 |
17 | def set_epoch(self, epoch, **unused):
18 | super().set_epoch(epoch)
19 | self.epoch = epoch
20 |
21 | @lru_cache(maxsize=16)
22 | def __cached_item__(self, index: int, epoch: int):
23 | dd = self.dataset[index].copy()
24 | coordinates = dd[self.coordinates]
25 | # normalize
26 | if self.normalize_coord:
27 | coordinates = coordinates - coordinates.mean(axis=0)
28 | dd[self.coordinates] = coordinates.astype(np.float32)
29 | return dd
30 |
31 | def __getitem__(self, index: int):
32 | return self.__cached_item__(index, self.epoch)
33 |
--------------------------------------------------------------------------------
/unimol/data/pka_input_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) DP Technology.
2 | # This source code is licensed under the MIT license found in the
3 | # LICENSE file in the root directory of this source tree.
4 |
5 | import numpy as np
6 | from functools import lru_cache
7 | from unicore.data import BaseWrapperDataset
8 | import collections
9 | import torch
10 | from itertools import chain
11 | from unicore.data.data_utils import collate_tokens, collate_tokens_2d
12 | from .coord_pad_dataset import collate_tokens_coords
13 |
14 |
15 | class PKAInputDataset(BaseWrapperDataset):
16 | def __init__(self, idx2key, src_tokens, src_charges, src_coord, src_distance, src_edge_type, token_pad_idx, charge_pad_idx, split='train', conf_size=10):
17 | self.idx2key = idx2key
18 | self.dataset = src_tokens
19 | self.src_tokens = src_tokens
20 | self.src_charges = src_charges
21 | self.src_coord = src_coord
22 | self.src_distance = src_distance
23 | self.src_edge_type = src_edge_type
24 | self.token_pad_idx = token_pad_idx
25 | self.charge_pad_idx = charge_pad_idx
26 | self.split = split
27 | self.conf_size = conf_size
28 | self.left_pad = False
29 | self._init_rec2mol()
30 | self.set_epoch(None)
31 |
32 | def set_epoch(self, epoch, **unused):
33 | super().set_epoch(epoch)
34 | self.epoch = epoch
35 |
36 | def _init_rec2mol(self):
37 | self.rec2mol = collections.defaultdict(list)
38 | if self.split in ['train','train.small']:
39 | total_sz = len(self.idx2key)
40 | for i in range(total_sz):
41 | smi_idx, _ = self.idx2key[i]
42 | self.rec2mol[smi_idx].append(i)
43 | else:
44 | total_sz = len(self.idx2key)
45 | for i in range(total_sz):
46 | smi_idx, _ = self.idx2key[i]
47 | self.rec2mol[smi_idx].extend([i * self.conf_size + j for j in range(self.conf_size)])
48 |
49 |
50 | def __len__(self):
51 | return len(self.rec2mol)
52 |
53 | @lru_cache(maxsize=16)
54 | def __cached_item__(self, index: int, epoch: int):
55 | mol_list = self.rec2mol[index]
56 | src_tokens_list = []
57 | src_charges_list = []
58 | src_coord_list = []
59 | src_distance_list = []
60 | src_edge_type_list = []
61 | for i in mol_list:
62 | src_tokens_list.append(self.src_tokens[i])
63 | src_charges_list.append(self.src_charges[i])
64 | src_coord_list.append(self.src_coord[i])
65 | src_distance_list.append(self.src_distance[i])
66 | src_edge_type_list.append(self.src_edge_type[i])
67 |
68 |         return src_tokens_list, src_charges_list, src_coord_list, src_distance_list, src_edge_type_list
69 |
70 | def __getitem__(self, index: int):
71 | return self.__cached_item__(index, self.epoch)
72 |
73 | def collater(self, samples):
74 | batch = [len(samples[i][0]) for i in range(len(samples))]
75 |
76 | src_tokens, src_charges, src_coord, src_distance, src_edge_type = [list(chain.from_iterable(i)) for i in zip(*samples)]
77 | src_tokens = collate_tokens(src_tokens, self.token_pad_idx, left_pad=self.left_pad, pad_to_multiple=8)
78 | src_charges = collate_tokens(src_charges, self.charge_pad_idx, left_pad=self.left_pad, pad_to_multiple=8)
79 | src_coord = collate_tokens_coords(src_coord, 0, left_pad=self.left_pad, pad_to_multiple=8)
80 | src_distance = collate_tokens_2d(src_distance, 0, left_pad=self.left_pad, pad_to_multiple=8)
81 | src_edge_type = collate_tokens_2d(src_edge_type, 0, left_pad=self.left_pad, pad_to_multiple=8)
82 |
83 | return src_tokens, src_charges, src_coord, src_distance, src_edge_type, batch
84 |
85 |
86 | class PKAMLMInputDataset(BaseWrapperDataset):
87 | def __init__(self, idx2key, src_tokens, src_charges, src_coord, src_distance, src_edge_type, token_targets, charge_targets, dist_targets, coord_targets, token_pad_idx, charge_pad_idx, split='train', conf_size=10):
88 | self.idx2key = idx2key
89 | self.dataset = src_tokens
90 | self.src_tokens = src_tokens
91 | self.src_charges = src_charges
92 | self.src_coord = src_coord
93 | self.src_distance = src_distance
94 | self.src_edge_type = src_edge_type
95 | self.token_targets = token_targets
96 | self.charge_targets = charge_targets
97 | self.dist_targets = dist_targets
98 | self.coord_targets = coord_targets
99 | self.token_pad_idx = token_pad_idx
100 | self.charge_pad_idx = charge_pad_idx
101 | self.split = split
102 | self.conf_size = conf_size
103 | self.left_pad = False
104 | self._init_rec2mol()
105 | self.set_epoch(None)
106 |
107 | def set_epoch(self, epoch, **unused):
108 | super().set_epoch(epoch)
109 | self.epoch = epoch
110 |
111 | def _init_rec2mol(self):
112 | self.rec2mol = collections.defaultdict(list)
113 | if self.split in ['train','train.small']:
114 | total_sz = len(self.idx2key)
115 | for i in range(total_sz):
116 | smi_idx, _ = self.idx2key[i]
117 | self.rec2mol[smi_idx].append(i)
118 | else:
119 | total_sz = len(self.idx2key)
120 | for i in range(total_sz):
121 | smi_idx, _ = self.idx2key[i]
122 | self.rec2mol[smi_idx].extend([i * self.conf_size + j for j in range(self.conf_size)])
123 |
124 |
125 | def __len__(self):
126 | return len(self.rec2mol)
127 |
128 | @lru_cache(maxsize=16)
129 | def __cached_item__(self, index: int, epoch: int):
130 | mol_list = self.rec2mol[index]
131 | src_tokens_list = []
132 | src_charges_list = []
133 | src_coord_list = []
134 | src_distance_list = []
135 | src_edge_type_list = []
136 | token_targets_list = []
137 | charge_targets_list = []
138 | coord_targets_list = []
139 | dist_targets_list = []
140 | for i in mol_list:
141 | src_tokens_list.append(self.src_tokens[i])
142 | src_charges_list.append(self.src_charges[i])
143 | src_coord_list.append(self.src_coord[i])
144 | src_distance_list.append(self.src_distance[i])
145 | src_edge_type_list.append(self.src_edge_type[i])
146 | token_targets_list.append(self.token_targets[i])
147 | charge_targets_list.append(self.charge_targets[i])
148 | coord_targets_list.append(self.coord_targets[i])
149 | dist_targets_list.append(self.dist_targets[i])
150 |
151 |         return src_tokens_list, src_charges_list, src_coord_list, src_distance_list, src_edge_type_list, token_targets_list, charge_targets_list, coord_targets_list, dist_targets_list
152 |
153 | def __getitem__(self, index: int):
154 | return self.__cached_item__(index, self.epoch)
155 |
156 | def collater(self, samples):
157 | batch = [len(samples[i][0]) for i in range(len(samples))]
158 |
159 | src_tokens, src_charges, src_coord, src_distance, src_edge_type, token_targets, charge_targets, coord_targets, dist_targets = [list(chain.from_iterable(i)) for i in zip(*samples)]
160 | src_tokens = collate_tokens(src_tokens, self.token_pad_idx, left_pad=self.left_pad, pad_to_multiple=8)
161 | src_charges = collate_tokens(src_charges, self.charge_pad_idx, left_pad=self.left_pad, pad_to_multiple=8)
162 | src_coord = collate_tokens_coords(src_coord, 0, left_pad=self.left_pad, pad_to_multiple=8)
163 | src_distance = collate_tokens_2d(src_distance, 0, left_pad=self.left_pad, pad_to_multiple=8)
164 | src_edge_type = collate_tokens_2d(src_edge_type, 0, left_pad=self.left_pad, pad_to_multiple=8)
165 | token_targets = collate_tokens(token_targets, self.token_pad_idx, left_pad=self.left_pad, pad_to_multiple=8)
166 | charge_targets = collate_tokens(charge_targets, self.charge_pad_idx, left_pad=self.left_pad, pad_to_multiple=8)
167 | coord_targets = collate_tokens_coords(coord_targets, 0, left_pad=self.left_pad, pad_to_multiple=8)
168 | dist_targets = collate_tokens_2d(dist_targets, 0, left_pad=self.left_pad, pad_to_multiple=8)
169 |
170 | return src_tokens, src_charges, src_coord, src_distance, src_edge_type, batch, charge_targets, coord_targets, dist_targets, token_targets
--------------------------------------------------------------------------------
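
Note: the collaters above flatten the per-molecule microstate (and, at evaluation time, conformer) lists into one padded batch and keep `batch` so the loss can split predictions back per molecule. Below is a minimal sketch of that flattening using plain PyTorch padding (the repository itself uses unicore's `collate_tokens`, which additionally pads to a multiple of 8); the tensors are illustrative.

import torch
from itertools import chain

# each sample is the per-microstate token list returned by __cached_item__
samples = [
    [torch.tensor([1, 4, 5, 2])],                          # molecule with 1 microstate
    [torch.tensor([1, 4, 2]), torch.tensor([1, 4, 6, 2])], # molecule with 2 microstates
]
batch = [len(s) for s in samples]           # -> [1, 2], used later to split predictions
flat = list(chain.from_iterable(samples))   # 3 sequences fed to the encoder as one batch
padded = torch.nn.utils.rnn.pad_sequence(flat, batch_first=True, padding_value=0)
print(batch, padded.shape)                  # [1, 2] torch.Size([3, 4])
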
/unimol/data/remove_hydrogen_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) DP Technology.
2 | # This source code is licensed under the MIT license found in the
3 | # LICENSE file in the root directory of this source tree.
4 |
5 | import numpy as np
6 | from functools import lru_cache
7 | from unicore.data import BaseWrapperDataset
8 |
9 |
10 | class RemoveHydrogenDataset(BaseWrapperDataset):
11 | def __init__(
12 | self,
13 | dataset,
14 | atoms,
15 | coordinates,
16 | charges,
17 | remove_hydrogen=False,
18 | remove_polar_hydrogen=False,
19 | ):
20 | self.dataset = dataset
21 | self.atoms = atoms
22 | self.coordinates = coordinates
23 | self.charges = charges
24 | self.remove_hydrogen = remove_hydrogen
25 | self.remove_polar_hydrogen = remove_polar_hydrogen
26 | self.set_epoch(None)
27 |
28 | def set_epoch(self, epoch, **unused):
29 | super().set_epoch(epoch)
30 | self.epoch = epoch
31 |
32 | @lru_cache(maxsize=16)
33 | def __cached_item__(self, index: int, epoch: int):
34 | dd = self.dataset[index].copy()
35 | atoms = dd[self.atoms]
36 | coordinates = dd[self.coordinates]
37 | charges = dd[self.charges]
38 |
39 | if self.remove_hydrogen:
40 | mask_hydrogen = atoms != "H"
41 | atoms = atoms[mask_hydrogen]
42 | coordinates = coordinates[mask_hydrogen]
43 | charges = charges[mask_hydrogen]
44 | if not self.remove_hydrogen and self.remove_polar_hydrogen:
45 | end_idx = 0
46 | for i, atom in enumerate(atoms[::-1]):
47 | if atom != "H":
48 | break
49 | else:
50 | end_idx = i + 1
51 | if end_idx != 0:
52 | atoms = atoms[:-end_idx]
53 | coordinates = coordinates[:-end_idx]
54 | charges = charges[:-end_idx]
55 | dd[self.atoms] = atoms
56 | dd[self.coordinates] = coordinates.astype(np.float32)
57 | dd[self.charges] = charges
58 | return dd
59 |
60 | def __getitem__(self, index: int):
61 | return self.__cached_item__(index, self.epoch)
62 |
--------------------------------------------------------------------------------
/unimol/data/tta_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) DP Technology.
2 | # This source code is licensed under the MIT license found in the
3 | # LICENSE file in the root directory of this source tree.
4 |
5 | import numpy as np
6 | from functools import lru_cache
7 | from unicore.data import BaseWrapperDataset
8 |
9 |
10 | class TTADataset(BaseWrapperDataset):
11 | def __init__(self, dataset, seed, atoms, coordinates, charges, id="ori_smi", conf_size=10):
12 | self.dataset = dataset
13 | self.seed = seed
14 | self.atoms = atoms
15 | self.coordinates = coordinates
16 | self.charges = charges
17 | self.id = id
18 | self.conf_size = conf_size
19 | self.set_epoch(None)
20 |
21 | def set_epoch(self, epoch, **unused):
22 | super().set_epoch(epoch)
23 | self.epoch = epoch
24 |
25 | def __len__(self):
26 | return len(self.dataset) * self.conf_size
27 |
28 | @lru_cache(maxsize=16)
29 | def __cached_item__(self, index: int, epoch: int):
30 | mol_idx = index // self.conf_size
31 | coord_idx = index % self.conf_size
32 | atoms = np.array(self.dataset[mol_idx][self.atoms])
33 | charges = np.array(self.dataset[mol_idx][self.charges])
34 | coordinates = np.array(self.dataset[mol_idx][self.coordinates][coord_idx])
35 | id = self.dataset[mol_idx][self.id]
36 | target = self.dataset[mol_idx]["target"]
37 | smi = self.dataset[mol_idx][self.id]
38 | return {
39 | "atoms": atoms,
40 | "coordinates": coordinates.astype(np.float32),
41 | "charges": charges.astype(str),
42 |             "target": target,
43 |             "smi": smi,
44 |             "id": id,
45 |         }
46 | 
47 |     def __getitem__(self, index: int):
48 |         return self.__cached_item__(index, self.epoch)
49 | 
51 |
52 | class TTAPKADataset(BaseWrapperDataset):
53 | def __init__(self, dataset, seed, metadata, atoms, coordinates, charges, id="ori_smi"):
54 | self.dataset = dataset
55 | self.seed = seed
56 | self.metadata = metadata
57 | self.atoms = atoms
58 | self.coordinates = coordinates
59 | self.charges = charges
60 | self.id = id
61 | self._init_idx()
62 | self.set_epoch(None)
63 |
64 | def set_epoch(self, epoch, **unused):
65 | super().set_epoch(epoch)
66 | self.epoch = epoch
67 |
68 | def _init_idx(self):
69 | self.idx2key = {}
70 | total_sz = 0
71 | for i in range(len(self.dataset)):
72 | size = len(self.dataset[i][self.metadata])
73 | for j in range(size):
74 | self.idx2key[total_sz] = (i, j)
75 | total_sz += 1
76 | self.total_sz = total_sz
77 |
78 | def get_idx2key(self):
79 | return self.idx2key
80 |
81 | def __len__(self):
82 | return self.total_sz
83 |
84 | @lru_cache(maxsize=16)
85 | def __cached_item__(self, index: int, epoch: int):
86 | smi_idx, mol_idx = self.idx2key[index]
87 | atoms = np.array(self.dataset[smi_idx][self.metadata][mol_idx][self.atoms])
88 | coordinates = np.array(self.dataset[smi_idx][self.metadata][mol_idx][self.coordinates])
89 | charges = np.array(self.dataset[smi_idx][self.metadata][mol_idx][self.charges])
90 | smi = self.dataset[smi_idx]["ori_smi"]
91 | id = self.dataset[smi_idx][self.id]
92 | target = self.dataset[smi_idx]["target"]
93 | return {
94 | "atoms": atoms,
95 | "coordinates": coordinates.astype(np.float32),
96 | "charges": charges.astype(str),
97 | "smi": smi,
98 | "target": target,
99 | "id": id,
100 | }
101 |
102 | def __getitem__(self, index: int):
103 | return self.__cached_item__(index, self.epoch)
104 |
--------------------------------------------------------------------------------
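
Note: `TTADataset` expands every molecule into `conf_size` conformers, so a flat dataloader index maps back to a (molecule, conformer) pair by integer division, while `TTAPKADataset.idx2key` flattens (SMILES, microstate) pairs in the same spirit. A one-line illustration (values are arbitrary):

conf_size = 10
index = 37                                     # flat index seen by the dataloader
mol_idx, coord_idx = divmod(index, conf_size)  # -> (3, 7): 8th conformer of the 4th molecule
assert index == mol_idx * conf_size + coord_idx
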
/unimol/examples/dict.txt:
--------------------------------------------------------------------------------
1 | [PAD]
2 | [CLS]
3 | [SEP]
4 | [UNK]
5 | C
6 | N
7 | O
8 | S
9 | H
10 | Cl
11 | F
12 | Br
13 | I
14 | Si
15 | P
16 | B
17 | Na
18 | K
19 | Al
20 | Ca
21 | Sn
22 | As
23 | Hg
24 | Fe
25 | Zn
26 | Cr
27 | Se
28 | Gd
29 | Au
30 | Li
--------------------------------------------------------------------------------
/unimol/examples/dict_charge.txt:
--------------------------------------------------------------------------------
1 | [PAD]
2 | [CLS]
3 | [SEP]
4 | [UNK]
5 | 1
6 | 0
7 | -1
--------------------------------------------------------------------------------
/unimol/infer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3 -u
2 | # Copyright (c) DP Technology, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import logging
8 | import os
9 | import sys
10 | import pickle
11 | import torch
12 | from unicore import checkpoint_utils, distributed_utils, options, utils
13 | from unicore.logging import progress_bar
14 | from unicore import tasks
15 |
16 | logging.basicConfig(
17 | format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
18 | datefmt="%Y-%m-%d %H:%M:%S",
19 | level=os.environ.get("LOGLEVEL", "INFO").upper(),
20 | stream=sys.stdout,
21 | )
22 | logger = logging.getLogger("unimol.inference")
23 |
24 |
25 | def main(args):
26 |
27 | assert (
28 | args.batch_size is not None
29 |     ), "Must specify batch size with --batch-size"
30 |
31 | use_fp16 = args.fp16
32 | use_cuda = torch.cuda.is_available() and not args.cpu
33 |
34 | if use_cuda:
35 | torch.cuda.set_device(args.device_id)
36 |
37 | if args.distributed_world_size > 1:
38 | data_parallel_world_size = distributed_utils.get_data_parallel_world_size()
39 | data_parallel_rank = distributed_utils.get_data_parallel_rank()
40 | else:
41 | data_parallel_world_size = 1
42 | data_parallel_rank = 0
43 |
44 | # Load model
45 | logger.info("loading model(s) from {}".format(args.path))
46 | state = checkpoint_utils.load_checkpoint_to_cpu(args.path)
47 | task = tasks.setup_task(args)
48 | model = task.build_model(args)
49 | model.load_state_dict(state["model"], strict=False)
50 |
51 | # Move models to GPU
52 | if use_fp16:
53 | model.half()
54 | if use_cuda:
55 | model.cuda()
56 |
57 | # Print args
58 | logger.info(args)
59 |
60 | # Build loss
61 | loss = task.build_loss(args)
62 | loss.eval()
63 |
64 | for subset in args.valid_subset.split(","):
65 | try:
66 | task.load_dataset(subset, combine=False, epoch=1)
67 | dataset = task.dataset(subset)
68 | except KeyError:
69 | raise Exception("Cannot find dataset: " + subset)
70 |
71 | if not os.path.exists(args.results_path):
72 | os.makedirs(args.results_path)
73 | fname = (os.path.abspath(args.path)).split("/")[-2]
74 | save_path = os.path.join(args.results_path, fname + "_" + subset + ".out.pkl")
75 | # Initialize data iterator
76 | itr = task.get_batch_iterator(
77 | dataset=dataset,
78 | batch_size=args.batch_size,
79 | ignore_invalid_inputs=True,
80 | required_batch_size_multiple=args.required_batch_size_multiple,
81 | seed=args.seed,
82 | num_shards=data_parallel_world_size,
83 | shard_id=data_parallel_rank,
84 | num_workers=args.num_workers,
85 | data_buffer_size=args.data_buffer_size,
86 | ).next_epoch_itr(shuffle=False)
87 | progress = progress_bar.progress_bar(
88 | itr,
89 | log_format=args.log_format,
90 | log_interval=args.log_interval,
91 | prefix=f"valid on '{subset}' subset",
92 | default_log_format=("tqdm" if not args.no_progress_bar else "simple"),
93 | )
94 | log_outputs = []
95 | for i, sample in enumerate(progress):
96 | sample = utils.move_to_cuda(sample) if use_cuda else sample
97 | if len(sample) == 0:
98 | continue
99 | _, _, log_output = task.valid_step(sample, model, loss, test=True)
100 | progress.log({}, step=i)
101 | log_outputs.append(log_output)
102 | pickle.dump(log_outputs, open(save_path, "wb"))
103 | logger.info("Done inference! ")
104 | return None
105 |
106 |
107 | def cli_main():
108 | parser = options.get_validation_parser()
109 | options.add_model_args(parser)
110 | args = options.parse_args_and_arch(parser)
111 |
112 | distributed_utils.call_main(args, main)
113 |
114 |
115 | if __name__ == "__main__":
116 | cli_main()
117 |
--------------------------------------------------------------------------------
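
Note: `infer.py` pickles the list of per-batch `logging_output` dicts produced by the losses below to `<results_path>/<checkpoint_dir>_<subset>.out.pkl`; `scripts/infer_mean_ensemble.py` performs the actual aggregation. A hedged sketch of reading such a file follows (the path is hypothetical, and it assumes a single-task regression run whose saved tensors can be loaded on the current device):

import pickle
import torch

with open("results/checkpoint_best_test.out.pkl", "rb") as f:   # hypothetical path
    log_outputs = pickle.load(f)

# "predict" and "target" mirror the keys of the logging_output dicts in the losses
predicts = torch.cat([log["predict"] for log in log_outputs], dim=0).view(-1)
targets = torch.cat([log["target"] for log in log_outputs], dim=0).view(-1)
mae = (predicts - targets).abs().mean().item()
print(f"MAE over {predicts.numel()} molecules: {mae:.3f}")
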
/unimol/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import importlib
3 |
4 | # automatically import any Python files in the losses/ directory
5 | for file in sorted(Path(__file__).parent.glob("*.py")):
6 | if not file.name.startswith("_"):
7 | importlib.import_module("unimol.losses." + file.name[:-3])
8 |
--------------------------------------------------------------------------------
/unimol/losses/mlm_loss.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) DP Technology.
2 | # This source code is licensed under the MIT license found in the
3 | # LICENSE file in the root directory of this source tree.
4 |
5 | import math
6 | import torch
7 | import torch.nn.functional as F
8 | import pandas as pd
9 | import numpy as np
10 | from unicore import metrics
11 | from unicore.losses import UnicoreLoss, register_loss
12 |
13 |
14 | @register_loss("pretrain_mlm")
15 | class PretrainMLMLoss(UnicoreLoss):
16 | def __init__(self, task):
17 | super().__init__(task)
18 | self.padding_idx = task.dictionary.pad()
19 | self.charge_padding_idx = task.charge_dictionary.pad()
20 | self.seed = task.seed
21 | self.dist_mean = 6.174412864984603
22 | self.dist_std = 216.17030997643033
23 |
24 | def forward(self, model, sample, reduce=True):
25 | """Compute the loss for the given sample.
26 |
27 | Returns a tuple with three elements:
28 | 1) the loss
29 | 2) the sample size, which is used as the denominator for the gradient
30 | 3) logging outputs to display while training
31 | """
32 | masked_tokens_a = sample["net_input_a"][-1].ne(self.padding_idx)
33 | all_output_a = model(
34 | sample["net_input_a"],
35 | classification_head_name=self.args.classification_head_name,
36 | encoder_masked_tokens=masked_tokens_a
37 | )
38 | masked_tokens_b = sample["net_input_b"][-1].ne(self.padding_idx)
39 | all_output_b = model(
40 | sample["net_input_b"],
41 | classification_head_name=self.args.classification_head_name,
42 | encoder_masked_tokens=masked_tokens_b
43 | )
44 | net_output_a, batch_a = all_output_a[:2]
45 | net_output_b, batch_b = all_output_b[:2]
46 |
47 | loss, predict = self.compute_loss(model, net_output_a, net_output_b, batch_a, batch_b, sample, reduce=reduce)
48 | sample_size = sample["target"]["finetune_target"].size(0)
49 | if not self.training:
50 | if self.task.mean and self.task.std:
51 | targets_mean = torch.tensor(self.task.mean, device=predict.device)
52 | targets_std = torch.tensor(self.task.std, device=predict.device)
53 | predict = predict * targets_std + targets_mean
54 | logging_output = {
55 | "pka_loss": loss.data,
56 | "predict": predict.view(-1, self.args.num_classes).data,
57 | "target": sample["target"]["finetune_target"]
58 | .view(-1, self.args.num_classes)
59 | .data,
60 | "smi_name": sample["id"],
61 | "sample_size": sample_size,
62 | "num_task": self.args.num_classes,
63 | "conf_size": self.args.conf_size,
64 | "bsz": sample["target"]["finetune_target"].size(0),
65 | }
66 | else:
67 | logging_output = {
68 | "pka_loss": loss.data,
69 | "sample_size": sample_size,
70 | "bsz": sample["target"]["finetune_target"].size(0),
71 | }
72 |
73 |         loss, logging_output = self.compute_mlm_loss(loss, all_output_a, masked_tokens_a, logging_output, reduce=reduce)
74 |         loss, logging_output = self.compute_mlm_loss(loss, all_output_b, masked_tokens_b, logging_output, reduce=reduce)
75 | logging_output['loss'] = loss.data
76 |
77 | return loss, sample_size, logging_output
78 |
79 | def compute_loss(self, model, net_output_a, net_output_b, batch_a, batch_b, sample, reduce=True):
80 | free_energy_a = net_output_a.view(-1, self.args.num_classes).float()
81 | free_energy_b = net_output_b.view(-1, self.args.num_classes).float()
82 | if not self.training:
83 | def compute_agg_free_energy(free_energy, batch):
84 | split_tensor_list = torch.split(free_energy, self.args.conf_size, dim=0)
85 | mean_tensor_list = [torch.mean(x, dim=0, keepdim=True) for x in split_tensor_list]
86 | agg_free_energy = torch.cat(mean_tensor_list, dim=0)
87 | agg_batch = [x//self.args.conf_size for x in batch]
88 | return agg_free_energy, agg_batch
89 | free_energy_a, batch_a = compute_agg_free_energy(free_energy_a, batch_a)
90 | free_energy_b, batch_b = compute_agg_free_energy(free_energy_b, batch_b)
91 |
92 | free_energy_a_padded = torch.nn.utils.rnn.pad_sequence(
93 | torch.split(free_energy_a, batch_a),
94 | padding_value=float("inf")
95 | )
96 | free_energy_b_padded = torch.nn.utils.rnn.pad_sequence(
97 | torch.split(free_energy_b, batch_b),
98 | padding_value=float("inf")
99 | )
100 | predicts = (
101 | torch.logsumexp(-free_energy_a_padded, dim=0)-
102 | torch.logsumexp(-free_energy_b_padded, dim=0)
103 | ) / torch.log(torch.tensor([10.0])).item()
104 |
105 | targets = (
106 | sample["target"]["finetune_target"].view(-1, self.args.num_classes).float()
107 | )
108 | if self.task.mean and self.task.std:
109 | targets_mean = torch.tensor(self.task.mean, device=targets.device)
110 | targets_std = torch.tensor(self.task.std, device=targets.device)
111 | targets = (targets - targets_mean) / targets_std
112 | loss = F.mse_loss(
113 | predicts,
114 | targets,
115 | reduction="sum" if reduce else "none",
116 | )
117 | return loss, predicts
118 |
119 |     def compute_mlm_loss(self, loss, all_output, masked_tokens, logging_output, reduce=True):
120 | (_, _,
121 | logits_encoder, charge_logits, encoder_distance, encoder_coord, x_norm, delta_encoder_pair_rep_norm,
122 | token_targets, charge_targets, coord_targets, dist_targets) = all_output
123 | sample_size = masked_tokens.long().sum()
124 |
125 | if self.args.masked_token_loss > 0:
126 | target = token_targets
127 | if masked_tokens is not None:
128 | target = target[masked_tokens]
129 | masked_token_loss = F.nll_loss(
130 | F.log_softmax(logits_encoder, dim=-1, dtype=torch.float32),
131 | target,
132 | ignore_index=self.padding_idx,
133 | reduction="sum" if reduce else "none",
134 | )
135 | masked_pred = logits_encoder.argmax(dim=-1)
136 | masked_hit = (masked_pred == target).long().sum()
137 | masked_cnt = sample_size
138 | if 'masked_token_loss' in logging_output:
139 | logging_output['seq_len'] += token_targets.size(1) * token_targets.size(0)
140 | logging_output['masked_token_loss'] += masked_token_loss.data
141 | logging_output['masked_token_hit'] += masked_hit.data
142 | logging_output['masked_token_cnt'] += masked_cnt
143 | else:
144 | logging_output['seq_len'] = token_targets.size(1) * token_targets.size(0)
145 | logging_output['masked_token_loss'] = masked_token_loss.data
146 | logging_output['masked_token_hit'] = masked_hit.data
147 | logging_output['masked_token_cnt'] = masked_cnt
148 | loss += masked_token_loss * self.args.masked_token_loss
149 |
150 | if self.args.masked_charge_loss > 0:
151 | target = charge_targets
152 | if masked_tokens is not None:
153 | target = target[masked_tokens]
154 | masked_charge_loss = F.nll_loss(
155 | F.log_softmax(charge_logits, dim=-1, dtype=torch.float32),
156 | target,
157 | ignore_index=self.charge_padding_idx,
158 | reduction="sum" if reduce else "none",
159 | )
160 | masked_pred = charge_logits.argmax(dim=-1)
161 | masked_hit = (masked_pred == target).long().sum()
162 | masked_cnt = sample_size
163 | if 'masked_charge_loss' in logging_output:
164 | logging_output['masked_charge_loss'] += masked_charge_loss.data
165 | logging_output['masked_charge_hit'] += masked_hit.data
166 | logging_output['masked_charge_cnt'] += masked_cnt
167 | else:
168 | logging_output['masked_charge_loss'] = masked_charge_loss.data
169 | logging_output['masked_charge_hit'] = masked_hit.data
170 | logging_output['masked_charge_cnt'] = masked_cnt
171 | loss += masked_charge_loss * self.args.masked_charge_loss
172 |
173 | if self.args.masked_coord_loss > 0:
174 | # real = mask + delta
175 | masked_coord_loss = F.smooth_l1_loss(
176 | encoder_coord[masked_tokens].view(-1, 3).float(),
177 | coord_targets[masked_tokens].view(-1, 3),
178 | reduction="sum" if reduce else "none",
179 | beta=1.0,
180 | )
181 | loss = loss + masked_coord_loss * self.args.masked_coord_loss
182 | # restore the scale of loss for displaying
183 | if 'masked_coord_loss' in logging_output:
184 | logging_output["masked_coord_loss"] += masked_coord_loss.data
185 | else:
186 | logging_output["masked_coord_loss"] = masked_coord_loss.data
187 |
188 | if self.args.masked_dist_loss > 0:
189 | dist_masked_tokens = masked_tokens
190 | masked_dist_loss = self.cal_dist_loss(
191 | encoder_distance, dist_masked_tokens, dist_targets, reduce=reduce, normalize=True,
192 | )
193 | loss = loss + masked_dist_loss * self.args.masked_dist_loss
194 | if 'masked_dist_loss' in logging_output:
195 | logging_output["masked_dist_loss"] += masked_dist_loss.data
196 | else:
197 | logging_output["masked_dist_loss"] = masked_dist_loss.data
198 |
199 | if self.args.x_norm_loss > 0 and x_norm is not None:
200 | loss = loss + self.args.x_norm_loss * x_norm
201 | if 'x_norm_loss' in logging_output:
202 | logging_output["x_norm_loss"] += x_norm.data
203 | else:
204 | logging_output["x_norm_loss"] = x_norm.data
205 |
206 | if (
207 | self.args.delta_pair_repr_norm_loss > 0
208 | and delta_encoder_pair_rep_norm is not None
209 | ):
210 | loss = (
211 | loss + self.args.delta_pair_repr_norm_loss * delta_encoder_pair_rep_norm
212 | )
213 | if 'delta_pair_repr_norm_loss' in logging_output:
214 | logging_output[
215 | "delta_pair_repr_norm_loss"
216 | ] += delta_encoder_pair_rep_norm.data
217 | else:
218 | logging_output[
219 | "delta_pair_repr_norm_loss"
220 | ] = delta_encoder_pair_rep_norm.data
221 |
222 | return loss, logging_output
223 |
224 | @staticmethod
225 | def reduce_metrics(logging_outputs, split="valid") -> None:
226 | """Aggregate logging outputs from data parallel training."""
227 | loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
228 | sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
229 | if "valid" in split or "test" in split:
230 | sample_size *= logging_outputs[0].get("conf_size",0)
231 | bsz = sum(log.get("bsz", 0) for log in logging_outputs)
232 | seq_len = sum(log.get("seq_len", 0) for log in logging_outputs)
233 | # we divide by log(2) to convert the loss from base e to base 2
234 | metrics.log_scalar(
235 | "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
236 | )
237 | metrics.log_scalar("seq_len", seq_len / bsz, 1, round=3)
238 | pka_loss = sum(log.get("pka_loss", 0) for log in logging_outputs)
239 | metrics.log_scalar(
240 | "pka_loss", pka_loss / sample_size, sample_size, round=3
241 | )
242 | masked_token_loss = sum(log.get("masked_token_loss", 0) for log in logging_outputs)
243 |         if masked_token_loss > 0:
244 | metrics.log_scalar(
245 | "masked_token_loss", masked_token_loss / sample_size, sample_size, round=3
246 | )
247 | masked_acc = sum(
248 | log.get("masked_token_hit", 0) for log in logging_outputs
249 | ) / sum(log.get("masked_token_cnt", 0) for log in logging_outputs)
250 | metrics.log_scalar("masked_token_acc", masked_acc, sample_size, round=3)
251 |
252 | masked_charge_loss = sum(log.get("masked_charge_loss", 0) for log in logging_outputs)
253 |         if masked_charge_loss > 0:
254 | metrics.log_scalar(
255 | "masked_charge_loss", masked_charge_loss / sample_size, sample_size, round=3
256 | )
257 | masked_acc = sum(
258 | log.get("masked_charge_hit", 0) for log in logging_outputs
259 | ) / sum(log.get("masked_charge_cnt", 0) for log in logging_outputs)
260 | metrics.log_scalar("masked_charge_acc", masked_acc, sample_size, round=3)
261 |
262 | masked_coord_loss = sum(
263 | log.get("masked_coord_loss", 0) for log in logging_outputs
264 | )
265 | if masked_coord_loss > 0:
266 | metrics.log_scalar(
267 | "masked_coord_loss",
268 | masked_coord_loss / sample_size,
269 | sample_size,
270 | round=3,
271 | )
272 |
273 | masked_dist_loss = sum(
274 | log.get("masked_dist_loss", 0) for log in logging_outputs
275 | )
276 | if masked_dist_loss > 0:
277 | metrics.log_scalar(
278 | "masked_dist_loss", masked_dist_loss / sample_size, sample_size, round=3
279 | )
280 |
281 | x_norm_loss = sum(log.get("x_norm_loss", 0) for log in logging_outputs)
282 | if x_norm_loss > 0:
283 | metrics.log_scalar(
284 | "x_norm_loss", x_norm_loss / sample_size, sample_size, round=3
285 | )
286 |
287 | delta_pair_repr_norm_loss = sum(
288 | log.get("delta_pair_repr_norm_loss", 0) for log in logging_outputs
289 | )
290 | if delta_pair_repr_norm_loss > 0:
291 | metrics.log_scalar(
292 | "delta_pair_repr_norm_loss",
293 | delta_pair_repr_norm_loss / sample_size,
294 | sample_size,
295 | round=3,
296 | )
297 |
298 | if "valid" in split or "test" in split:
299 | sample_size //= logging_outputs[0].get("conf_size",0)
300 | predicts = torch.cat([log.get("predict") for log in logging_outputs], dim=0)
301 | if predicts.size(-1) == 1:
302 | # single label regression task, add aggregate acc and loss score
303 | targets = torch.cat(
304 | [log.get("target", 0) for log in logging_outputs], dim=0
305 | )
306 | smi_list = [
307 | item for log in logging_outputs for item in log.get("smi_name")
308 | ]
309 | df = pd.DataFrame(
310 | {
311 | "predict": predicts.view(-1).cpu(),
312 | "target": targets.view(-1).cpu(),
313 | "smi": smi_list,
314 | }
315 | )
316 | mae = np.abs(df["predict"] - df["target"]).mean()
317 | mse = ((df["predict"] - df["target"]) ** 2).mean()
318 |
319 | metrics.log_scalar(f"{split}_mae", mae, sample_size, round=3)
320 | metrics.log_scalar(f"{split}_mse", mse, sample_size, round=3)
321 | metrics.log_scalar(
322 | f"{split}_rmse", np.sqrt(mse), sample_size, round=4
323 | )
324 |
325 | @staticmethod
326 | def logging_outputs_can_be_summed(is_train) -> bool:
327 | """
328 | Whether the logging outputs returned by `forward` can be summed
329 | across workers prior to calling `reduce_metrics`. Setting this
330 |         to True will improve distributed training speed.
331 | """
332 | return is_train
333 |
334 |     def cal_dist_loss(self, dist, dist_masked_tokens, dist_targets, reduce=True, normalize=False):
335 | masked_distance = dist[dist_masked_tokens, :]
336 | masked_distance_target = dist_targets[dist_masked_tokens]
337 | non_pad_pos = masked_distance_target > 0
338 | if normalize:
339 | masked_distance_target = (
340 | masked_distance_target.float() - self.dist_mean
341 | ) / self.dist_std
342 | masked_dist_loss = F.smooth_l1_loss(
343 | masked_distance[non_pad_pos].view(-1).float(),
344 | masked_distance_target[non_pad_pos].view(-1),
345 | reduction="sum" if reduce else "none",
346 | beta=1.0,
347 | )
348 | return masked_dist_loss
349 |
--------------------------------------------------------------------------------
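
Note: both pKa losses turn the per-microstate free energies of the two ensembles (`net_input_a` / `net_input_b`) into a prediction via log-sum-exp, i.e. predict = (logsumexp(-G_a) - logsumexp(-G_b)) / ln 10, averaging over conformers first at evaluation time. A minimal numeric sketch (not part of the repository) of that aggregation for one molecule; the values are illustrative.

import torch

# predicted free energies for one molecule: 2 microstates in ensemble a, 3 in ensemble b
g_a = torch.tensor([1.3, 2.0])
g_b = torch.tensor([4.1, 3.8, 5.0])

# ensemble free energies are combined with logsumexp; dividing by ln(10)
# converts the natural-log difference into pKa units
pka = (torch.logsumexp(-g_a, dim=0) - torch.logsumexp(-g_b, dim=0)) / torch.log(torch.tensor(10.0))
print(pka.item())
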
/unimol/losses/reg_loss.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) DP Technology.
2 | # This source code is licensed under the MIT license found in the
3 | # LICENSE file in the root directory of this source tree.
4 |
5 | import math
6 | import torch
7 | import torch.nn.functional as F
8 | import pandas as pd
9 | import numpy as np
10 | from unicore import metrics
11 | from unicore.losses import UnicoreLoss, register_loss
12 |
13 |
14 | @register_loss("finetune_mse")
15 | class FinetuneMSELoss(UnicoreLoss):
16 | def __init__(self, task):
17 | super().__init__(task)
18 |
19 | def forward(self, model, sample, reduce=True):
20 | """Compute the loss for the given sample.
21 |
22 | Returns a tuple with three elements:
23 | 1) the loss
24 | 2) the sample size, which is used as the denominator for the gradient
25 | 3) logging outputs to display while training
26 | """
27 | net_output_a, batch_a = model(
28 | sample["net_input_a"],
29 | classification_head_name=self.args.classification_head_name,
30 |             features_only=True,
31 | )
32 | net_output_b, batch_b = model(
33 | sample["net_input_b"],
34 | classification_head_name=self.args.classification_head_name,
35 |             features_only=True,
36 | )
37 |
38 | loss, predict = self.compute_loss(model, net_output_a, net_output_b, batch_a, batch_b, sample, reduce=reduce)
39 | sample_size = sample["target"]["finetune_target"].size(0)
40 | if not self.training:
41 | if self.task.mean and self.task.std:
42 | targets_mean = torch.tensor(self.task.mean, device=predict.device)
43 | targets_std = torch.tensor(self.task.std, device=predict.device)
44 | predict = predict * targets_std + targets_mean
45 | logging_output = {
46 | "loss": loss.data,
47 | "predict": predict.view(-1, self.args.num_classes).data,
48 | "target": sample["target"]["finetune_target"]
49 | .view(-1, self.args.num_classes)
50 | .data,
51 | "smi_name": sample["id"],
52 | "sample_size": sample_size,
53 | "num_task": self.args.num_classes,
54 | "conf_size": self.args.conf_size,
55 | "bsz": sample["target"]["finetune_target"].size(0),
56 | }
57 | else:
58 | logging_output = {
59 | "loss": loss.data,
60 | "sample_size": sample_size,
61 | "bsz": sample["target"]["finetune_target"].size(0),
62 | }
63 | return loss, sample_size, logging_output
64 |
65 | def compute_loss(self, model, net_output_a, net_output_b, batch_a, batch_b, sample, reduce=True):
66 | free_energy_a = net_output_a.view(-1, self.args.num_classes).float()
67 | free_energy_b = net_output_b.view(-1, self.args.num_classes).float()
68 | if not self.training:
69 | def compute_agg_free_energy(free_energy, batch):
70 | split_tensor_list = torch.split(free_energy, self.args.conf_size, dim=0)
71 | mean_tensor_list = [torch.mean(x, dim=0, keepdim=True) for x in split_tensor_list]
72 | agg_free_energy = torch.cat(mean_tensor_list, dim=0)
73 | agg_batch = [x//self.args.conf_size for x in batch]
74 | return agg_free_energy, agg_batch
75 | free_energy_a, batch_a = compute_agg_free_energy(free_energy_a, batch_a)
76 | free_energy_b, batch_b = compute_agg_free_energy(free_energy_b, batch_b)
77 |
78 | free_energy_a_padded = torch.nn.utils.rnn.pad_sequence(
79 | torch.split(free_energy_a, batch_a),
80 | padding_value=float("inf")
81 | )
82 | free_energy_b_padded = torch.nn.utils.rnn.pad_sequence(
83 | torch.split(free_energy_b, batch_b),
84 | padding_value=float("inf")
85 | )
86 | predicts = (
87 | torch.logsumexp(-free_energy_a_padded, dim=0)-
88 | torch.logsumexp(-free_energy_b_padded, dim=0)
89 | ) / torch.log(torch.tensor([10.0])).item()
90 |
91 | targets = (
92 | sample["target"]["finetune_target"].view(-1, self.args.num_classes).float()
93 | )
94 | if self.task.mean and self.task.std:
95 | targets_mean = torch.tensor(self.task.mean, device=targets.device)
96 | targets_std = torch.tensor(self.task.std, device=targets.device)
97 | targets = (targets - targets_mean) / targets_std
98 | loss = F.mse_loss(
99 | predicts,
100 | targets,
101 | reduction="sum" if reduce else "none",
102 | )
103 | return loss, predicts
104 |
105 | @staticmethod
106 | def reduce_metrics(logging_outputs, split="valid") -> None:
107 | """Aggregate logging outputs from data parallel training."""
108 | loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
109 | sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
110 | # we divide by log(2) to convert the loss from base e to base 2
111 | metrics.log_scalar(
112 | "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
113 | )
114 | if "valid" in split or "test" in split:
115 | predicts = torch.cat([log.get("predict") for log in logging_outputs], dim=0)
116 | if predicts.size(-1) == 1:
117 | # single label regression task, add aggregate acc and loss score
118 | targets = torch.cat(
119 | [log.get("target", 0) for log in logging_outputs], dim=0
120 | )
121 | smi_list = [
122 | item for log in logging_outputs for item in log.get("smi_name")
123 | ]
124 | df = pd.DataFrame(
125 | {
126 | "predict": predicts.view(-1).cpu(),
127 | "target": targets.view(-1).cpu(),
128 | "smi": smi_list,
129 | }
130 | )
131 | mae = np.abs(df["predict"] - df["target"]).mean()
132 | mse = ((df["predict"] - df["target"]) ** 2).mean()
133 | metrics.log_scalar(f"{split}_mae", mae, sample_size, round=3)
134 | metrics.log_scalar(f"{split}_mse", mse, sample_size, round=3)
135 | metrics.log_scalar(
136 | f"{split}_rmse", np.sqrt(mse), sample_size, round=4
137 | )
138 |
139 | @staticmethod
140 | def logging_outputs_can_be_summed(is_train) -> bool:
141 | """
142 | Whether the logging outputs returned by `forward` can be summed
143 | across workers prior to calling `reduce_metrics`. Setting this
144 |         to True will improve distributed training speed.
145 | """
146 | return is_train
147 |
148 |
149 | @register_loss("infer_free_energy")
150 | class InferFreeEnergyLoss(UnicoreLoss):
151 | def __init__(self, task):
152 | super().__init__(task)
153 |
154 | def forward(self, model, sample, reduce=True):
155 | """Compute the loss for the given sample.
156 |
157 | Returns a tuple with three elements:
158 | 1) the loss
159 | 2) the sample size, which is used as the denominator for the gradient
160 | 3) logging outputs to display while training
161 | """
162 | net_output = model(
163 | **sample["net_input"],
164 | classification_head_name=self.args.classification_head_name,
165 | features_only=True,
166 | )
167 | reg_output = net_output[0]
168 | loss = torch.tensor([0.01], device=sample["target"]["finetune_target"].device)
169 | sample_size = sample["target"]["finetune_target"].size(0)
170 | if not self.training:
171 | logging_output = {
172 | "loss": loss.data,
173 | "predict": reg_output.view(-1, self.args.num_classes).data,
174 | "target": sample["target"]["finetune_target"]
175 | .view(-1, self.args.num_classes)
176 | .data,
177 | "smi_name": sample["smi_name"],
178 | "sample_size": sample_size,
179 | "num_task": self.args.num_classes,
180 | "conf_size": self.args.conf_size,
181 | "bsz": sample["target"]["finetune_target"].size(0),
182 | }
183 | return loss, sample_size, logging_output
184 |
185 | @staticmethod
186 | def reduce_metrics(logging_outputs, split="valid") -> None:
187 | """Aggregate logging outputs from data parallel training."""
188 | loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
189 | sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
190 | # we divide by log(2) to convert the loss from base e to base 2
191 | metrics.log_scalar(
192 | "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
193 | )
194 |
195 | @staticmethod
196 | def logging_outputs_can_be_summed(is_train) -> bool:
197 | """
198 | Whether the logging outputs returned by `forward` can be summed
199 | across workers prior to calling `reduce_metrics`. Setting this
200 |         to True will improve distributed training speed.
201 | """
202 | return is_train
203 |
--------------------------------------------------------------------------------
/unimol/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .unimol_pka import UniMolPKAModel
2 | from .transformer_encoder_with_pair import TransformerEncoderWithPair
3 | from .unimol import UniMolModel
--------------------------------------------------------------------------------
/unimol/models/transformer_encoder_with_pair.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) DP Technology.
2 | # This source code is licensed under the MIT license found in the
3 | # LICENSE file in the root directory of this source tree.
4 |
5 | from typing import Optional, Tuple
6 |
7 | import math
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | from unicore.modules import TransformerEncoderLayer, LayerNorm
12 |
13 |
14 | class TransformerEncoderWithPair(nn.Module):
15 | def __init__(
16 | self,
17 | encoder_layers: int = 6,
18 | embed_dim: int = 768,
19 | ffn_embed_dim: int = 3072,
20 | attention_heads: int = 8,
21 | emb_dropout: float = 0.1,
22 | dropout: float = 0.1,
23 | attention_dropout: float = 0.1,
24 | activation_dropout: float = 0.0,
25 | max_seq_len: int = 256,
26 | activation_fn: str = "gelu",
27 | post_ln: bool = False,
28 | no_final_head_layer_norm: bool = False,
29 | ) -> None:
30 |
31 | super().__init__()
32 | self.emb_dropout = emb_dropout
33 | self.max_seq_len = max_seq_len
34 | self.embed_dim = embed_dim
35 | self.attention_heads = attention_heads
36 | self.emb_layer_norm = LayerNorm(self.embed_dim)
37 | if not post_ln:
38 | self.final_layer_norm = LayerNorm(self.embed_dim)
39 | else:
40 | self.final_layer_norm = None
41 |
42 | if not no_final_head_layer_norm:
43 | self.final_head_layer_norm = LayerNorm(attention_heads)
44 | else:
45 | self.final_head_layer_norm = None
46 |
47 | self.layers = nn.ModuleList(
48 | [
49 | TransformerEncoderLayer(
50 | embed_dim=self.embed_dim,
51 | ffn_embed_dim=ffn_embed_dim,
52 | attention_heads=attention_heads,
53 | dropout=dropout,
54 | attention_dropout=attention_dropout,
55 | activation_dropout=activation_dropout,
56 | activation_fn=activation_fn,
57 | post_ln=post_ln,
58 | )
59 | for _ in range(encoder_layers)
60 | ]
61 | )
62 |
63 | def forward(
64 | self,
65 | emb: torch.Tensor,
66 | attn_mask: Optional[torch.Tensor] = None,
67 | padding_mask: Optional[torch.Tensor] = None,
68 |     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
69 |
70 | bsz = emb.size(0)
71 | seq_len = emb.size(1)
72 | x = self.emb_layer_norm(emb)
73 | x = F.dropout(x, p=self.emb_dropout, training=self.training)
74 |
75 | # account for padding while computing the representation
76 | if padding_mask is not None:
77 | x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))
78 | input_attn_mask = attn_mask
79 | input_padding_mask = padding_mask
80 |
81 | def fill_attn_mask(attn_mask, padding_mask, fill_val=float("-inf")):
82 | if attn_mask is not None and padding_mask is not None:
83 | # merge key_padding_mask and attn_mask
84 | attn_mask = attn_mask.view(x.size(0), -1, seq_len, seq_len)
85 | attn_mask.masked_fill_(
86 | padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
87 | fill_val,
88 | )
89 | attn_mask = attn_mask.view(-1, seq_len, seq_len)
90 | padding_mask = None
91 | return attn_mask, padding_mask
92 |
93 | assert attn_mask is not None
94 | attn_mask, padding_mask = fill_attn_mask(attn_mask, padding_mask)
95 |
96 | for i in range(len(self.layers)):
97 | x, attn_mask, _ = self.layers[i](
98 | x, padding_mask=padding_mask, attn_bias=attn_mask, return_attn=True
99 | )
100 |
101 | def norm_loss(x, eps=1e-10, tolerance=1.0):
102 | x = x.float()
103 | max_norm = x.shape[-1] ** 0.5
104 | norm = torch.sqrt(torch.sum(x**2, dim=-1) + eps)
105 | error = torch.nn.functional.relu((norm - max_norm).abs() - tolerance)
106 | return error
107 |
108 | def masked_mean(mask, value, dim=-1, eps=1e-10):
109 | return (
110 | torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim))
111 | ).mean()
112 |
113 | x_norm = norm_loss(x)
114 | if input_padding_mask is not None:
115 | token_mask = 1.0 - input_padding_mask.float()
116 | else:
117 | token_mask = torch.ones_like(x_norm, device=x_norm.device)
118 | x_norm = masked_mean(token_mask, x_norm)
119 |
120 | if self.final_layer_norm is not None:
121 | x = self.final_layer_norm(x)
122 |
123 | delta_pair_repr = attn_mask - input_attn_mask
124 | delta_pair_repr, _ = fill_attn_mask(delta_pair_repr, input_padding_mask, 0)
125 | attn_mask = (
126 | attn_mask.view(bsz, -1, seq_len, seq_len).permute(0, 2, 3, 1).contiguous()
127 | )
128 | delta_pair_repr = (
129 | delta_pair_repr.view(bsz, -1, seq_len, seq_len)
130 | .permute(0, 2, 3, 1)
131 | .contiguous()
132 | )
133 |
134 | pair_mask = token_mask[..., None] * token_mask[..., None, :]
135 | delta_pair_repr_norm = norm_loss(delta_pair_repr)
136 | delta_pair_repr_norm = masked_mean(
137 | pair_mask, delta_pair_repr_norm, dim=(-1, -2)
138 | )
139 |
140 | if self.final_head_layer_norm is not None:
141 | delta_pair_repr = self.final_head_layer_norm(delta_pair_repr)
142 |
143 | return x, attn_mask, delta_pair_repr, x_norm, delta_pair_repr_norm
144 |
--------------------------------------------------------------------------------
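
Note: a minimal smoke test (not part of the repository) of the encoder above, assuming Uni-Core and the package's other dependencies are installed. The per-head pairwise attention bias stands in for the Gaussian distance features that `UniMolModel` normally produces; all sizes below are illustrative.

import torch
from unimol.models import TransformerEncoderWithPair

encoder = TransformerEncoderWithPair(
    encoder_layers=2, embed_dim=64, ffn_embed_dim=128, attention_heads=4
)
bsz, seq_len, heads = 2, 5, 4
emb = torch.randn(bsz, seq_len, 64)                      # token + charge embeddings
attn_bias = torch.zeros(bsz * heads, seq_len, seq_len)   # per-head pairwise bias
padding_mask = torch.zeros(bsz, seq_len, dtype=torch.bool)
x, pair_rep, delta_pair_rep, x_norm, delta_norm = encoder(
    emb, attn_mask=attn_bias, padding_mask=padding_mask
)
print(x.shape, pair_rep.shape)   # torch.Size([2, 5, 64]) torch.Size([2, 5, 5, 4])
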
/unimol/models/unimol.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) DP Technology.
2 | # This source code is licensed under the MIT license found in the
3 | # LICENSE file in the root directory of this source tree.
4 |
5 | import logging
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from unicore import utils
10 | from unicore.models import BaseUnicoreModel, register_model, register_model_architecture
11 | from unicore.modules import LayerNorm, init_bert_params
12 | from .transformer_encoder_with_pair import TransformerEncoderWithPair
13 | from typing import Dict, Any, List
14 |
15 |
16 | logger = logging.getLogger(__name__)
17 |
18 |
19 | @register_model("unimol")
20 | class UniMolModel(BaseUnicoreModel):
21 | @staticmethod
22 | def add_args(parser):
23 | """Add model-specific arguments to the parser."""
24 | parser.add_argument(
25 | "--encoder-layers", type=int, metavar="L", help="num encoder layers"
26 | )
27 | parser.add_argument(
28 | "--encoder-embed-dim",
29 | type=int,
30 | metavar="H",
31 | help="encoder embedding dimension",
32 | )
33 | parser.add_argument(
34 | "--encoder-ffn-embed-dim",
35 | type=int,
36 | metavar="F",
37 | help="encoder embedding dimension for FFN",
38 | )
39 | parser.add_argument(
40 | "--encoder-attention-heads",
41 | type=int,
42 | metavar="A",
43 | help="num encoder attention heads",
44 | )
45 | parser.add_argument(
46 | "--activation-fn",
47 | choices=utils.get_available_activation_fns(),
48 | help="activation function to use",
49 | )
50 | parser.add_argument(
51 | "--pooler-activation-fn",
52 | choices=utils.get_available_activation_fns(),
53 | help="activation function to use for pooler layer",
54 | )
55 | parser.add_argument(
56 | "--emb-dropout",
57 | type=float,
58 | metavar="D",
59 | help="dropout probability for embeddings",
60 | )
61 | parser.add_argument(
62 | "--dropout", type=float, metavar="D", help="dropout probability"
63 | )
64 | parser.add_argument(
65 | "--attention-dropout",
66 | type=float,
67 | metavar="D",
68 | help="dropout probability for attention weights",
69 | )
70 | parser.add_argument(
71 | "--activation-dropout",
72 | type=float,
73 | metavar="D",
74 | help="dropout probability after activation in FFN",
75 | )
76 | parser.add_argument(
77 | "--pooler-dropout",
78 | type=float,
79 | metavar="D",
80 | help="dropout probability in the masked_lm pooler layers",
81 | )
82 | parser.add_argument(
83 | "--max-seq-len", type=int, help="number of positional embeddings to learn"
84 | )
85 | parser.add_argument(
86 | "--post-ln", type=bool, help="use post layernorm or pre layernorm"
87 | )
88 | parser.add_argument(
89 | "--x-norm-loss",
90 | type=float,
91 | metavar="D",
92 | help="x norm loss ratio",
93 | )
94 | parser.add_argument(
95 | "--delta-pair-repr-norm-loss",
96 | type=float,
97 | metavar="D",
98 | help="delta encoder pair repr norm loss ratio",
99 | )
100 | parser.add_argument(
101 | "--mode",
102 | type=str,
103 | default="train",
104 | choices=["train", "infer"],
105 | )
106 |
107 | def __init__(self, args, dictionary, charge_dictionary):
108 | super().__init__()
109 | base_architecture(args)
110 | self.args = args
111 | self.padding_idx = dictionary.pad()
112 | self.charge_padding_idx = charge_dictionary.pad()
113 | self.embed_tokens = nn.Embedding(
114 | len(dictionary), args.encoder_embed_dim, self.padding_idx
115 | )
116 | self.embed_charges = nn.Embedding(
117 | len(charge_dictionary), args.encoder_embed_dim, self.charge_padding_idx
118 | )
119 | self._num_updates = None
120 | self.encoder = TransformerEncoderWithPair(
121 | encoder_layers=args.encoder_layers,
122 | embed_dim=args.encoder_embed_dim,
123 | ffn_embed_dim=args.encoder_ffn_embed_dim,
124 | attention_heads=args.encoder_attention_heads,
125 | emb_dropout=args.emb_dropout,
126 | dropout=args.dropout,
127 | attention_dropout=args.attention_dropout,
128 | activation_dropout=args.activation_dropout,
129 | max_seq_len=args.max_seq_len,
130 | activation_fn=args.activation_fn,
131 | no_final_head_layer_norm=args.delta_pair_repr_norm_loss < 0,
132 | )
133 |
134 | K = 128
135 | n_edge_type = len(dictionary) * len(dictionary)
136 | self.gbf_proj = NonLinearHead(
137 | K, args.encoder_attention_heads, args.activation_fn
138 | )
139 | self.gbf = GaussianLayer(K, n_edge_type)
140 |
141 | self.classification_heads = nn.ModuleDict()
142 | self.apply(init_bert_params)
143 |
144 | @classmethod
145 | def build_model(cls, args, task):
146 | """Build a new model instance."""
147 | return cls(args, task.dictionary, task.charge_dictionary)
148 |
149 | def forward(
150 | self,
151 | src_tokens,
152 | src_charges,
153 | src_distance,
154 | src_coord,
155 | src_edge_type,
156 | encoder_masked_tokens=None,
157 | classification_head_name=None,
158 | **kwargs
159 | ):
160 |
161 | padding_mask = src_tokens.eq(self.padding_idx)
162 | if not padding_mask.any():
163 | padding_mask = None
164 | x = self.embed_tokens(src_tokens)
165 |
166 | charge_padding_mask = src_charges.eq(self.charge_padding_idx)
167 | if not charge_padding_mask.any():
168 |             charge_padding_mask = None
169 | charges_emb = self.embed_charges(src_charges)
170 | # involve charge info
171 | x += charges_emb
172 |
173 | def get_dist_features(dist, et):
174 | n_node = dist.size(-1)
175 | gbf_feature = self.gbf(dist, et)
176 | gbf_result = self.gbf_proj(gbf_feature)
177 | graph_attn_bias = gbf_result
178 | graph_attn_bias = graph_attn_bias.permute(0, 3, 1, 2).contiguous()
179 | graph_attn_bias = graph_attn_bias.view(-1, n_node, n_node)
180 | return graph_attn_bias
181 |
182 | graph_attn_bias = get_dist_features(src_distance, src_edge_type)
183 | (
184 | encoder_rep,
185 | encoder_pair_rep,
186 | delta_encoder_pair_rep,
187 | x_norm,
188 | delta_encoder_pair_rep_norm,
189 | ) = self.encoder(x, padding_mask=padding_mask, attn_mask=graph_attn_bias)
190 | encoder_pair_rep[encoder_pair_rep == float("-inf")] = 0
191 |
192 | if classification_head_name is not None:
193 | logits = self.classification_heads[classification_head_name](encoder_rep)
194 | if self.args.mode == 'infer':
195 | return encoder_rep, encoder_pair_rep
196 | else:
197 | return (
198 | logits,
199 | x_norm,
200 | delta_encoder_pair_rep_norm,
201 | )
202 |
203 | def register_classification_head(
204 | self, name, num_classes=None, inner_dim=None, **kwargs
205 | ):
206 | """Register a classification head."""
207 | if name in self.classification_heads:
208 | prev_num_classes = self.classification_heads[name].out_proj.out_features
209 | prev_inner_dim = self.classification_heads[name].dense.out_features
210 | if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
211 | logger.warning(
212 | 're-registering head "{}" with num_classes {} (prev: {}) '
213 | "and inner_dim {} (prev: {})".format(
214 | name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
215 | )
216 | )
217 | self.classification_heads[name] = ClassificationHead(
218 | input_dim=self.args.encoder_embed_dim,
219 | inner_dim=inner_dim or self.args.encoder_embed_dim,
220 | num_classes=num_classes,
221 | activation_fn=self.args.pooler_activation_fn,
222 | pooler_dropout=self.args.pooler_dropout,
223 | )
224 |
225 | def set_num_updates(self, num_updates):
226 | """State from trainer to pass along to model at every update."""
227 | self._num_updates = num_updates
228 |
229 | def get_num_updates(self):
230 | return self._num_updates
231 |
232 |
233 | class ClassificationHead(nn.Module):
234 | """Head for sentence-level classification tasks."""
235 |
236 | def __init__(
237 | self,
238 | input_dim,
239 | inner_dim,
240 | num_classes,
241 | activation_fn,
242 | pooler_dropout,
243 | ):
244 | super().__init__()
245 | self.dense = nn.Linear(input_dim, inner_dim)
246 | self.activation_fn = utils.get_activation_fn(activation_fn)
247 | self.dropout = nn.Dropout(p=pooler_dropout)
248 | self.out_proj = nn.Linear(inner_dim, num_classes)
249 |
250 | def forward(self, features, **kwargs):
251 | x = features[:, 0, :] # take token (equiv. to [CLS])
252 | x = self.dropout(x)
253 | x = self.dense(x)
254 | x = self.activation_fn(x)
255 | x = self.dropout(x)
256 | x = self.out_proj(x)
257 | return x
258 |
259 |
260 | class NonLinearHead(nn.Module):
261 | """Head for simple classification tasks."""
262 |
263 | def __init__(
264 | self,
265 | input_dim,
266 | out_dim,
267 | activation_fn,
268 | hidden=None,
269 | ):
270 | super().__init__()
271 | hidden = input_dim if not hidden else hidden
272 | self.linear1 = nn.Linear(input_dim, hidden)
273 | self.linear2 = nn.Linear(hidden, out_dim)
274 | self.activation_fn = utils.get_activation_fn(activation_fn)
275 |
276 | def forward(self, x):
277 | x = self.linear1(x)
278 | x = self.activation_fn(x)
279 | x = self.linear2(x)
280 | return x
281 |
282 |
283 | @torch.jit.script
284 | def gaussian(x, mean, std):
285 | pi = 3.14159
286 | a = (2 * pi) ** 0.5
287 | return torch.exp(-0.5 * (((x - mean) / std) ** 2)) / (a * std)
288 |
289 |
290 | class GaussianLayer(nn.Module):
291 | def __init__(self, K=128, edge_types=1024):
292 | super().__init__()
293 | self.K = K
294 | self.means = nn.Embedding(1, K)
295 | self.stds = nn.Embedding(1, K)
296 | self.mul = nn.Embedding(edge_types, 1)
297 | self.bias = nn.Embedding(edge_types, 1)
298 | nn.init.uniform_(self.means.weight, 0, 3)
299 | nn.init.uniform_(self.stds.weight, 0, 3)
300 | nn.init.constant_(self.bias.weight, 0)
301 | nn.init.constant_(self.mul.weight, 1)
302 |
303 | def forward(self, x, edge_type):
304 | mul = self.mul(edge_type).type_as(x)
305 | bias = self.bias(edge_type).type_as(x)
306 | x = mul * x.unsqueeze(-1) + bias
307 | x = x.expand(-1, -1, -1, self.K)
308 | mean = self.means.weight.float().view(-1)
309 | std = self.stds.weight.float().view(-1).abs() + 1e-5
310 | return gaussian(x.float(), mean, std).type_as(self.means.weight)
311 |
312 |
313 | @register_model_architecture("unimol", "unimol")
314 | def base_architecture(args):
315 | args.encoder_layers = getattr(args, "encoder_layers", 15)
316 | args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
317 | args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
318 | args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 64)
319 | args.dropout = getattr(args, "dropout", 0.1)
320 | args.emb_dropout = getattr(args, "emb_dropout", 0.1)
321 | args.attention_dropout = getattr(args, "attention_dropout", 0.1)
322 | args.activation_dropout = getattr(args, "activation_dropout", 0.0)
323 | args.pooler_dropout = getattr(args, "pooler_dropout", 0.0)
324 | args.max_seq_len = getattr(args, "max_seq_len", 512)
325 | args.activation_fn = getattr(args, "activation_fn", "gelu")
326 | args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh")
327 | args.post_ln = getattr(args, "post_ln", False)
328 | args.x_norm_loss = getattr(args, "x_norm_loss", -1.0)
329 | args.delta_pair_repr_norm_loss = getattr(args, "delta_pair_repr_norm_loss", -1.0)
330 |
331 |
332 | @register_model_architecture("unimol", "unimol_base")
333 | def unimol_base_architecture(args):
334 | base_architecture(args)
335 |
--------------------------------------------------------------------------------
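The gaussian kernel and GaussianLayer above expand every pairwise distance into K radial-basis features, which are then projected into the graph attention bias (see get_dist_features in unimol_pka.py below). A minimal standalone sketch of the shapes involved, using plain torch with illustrative sizes and randomly initialised kernel parameters, not the repository's actual code path:

import torch

def gaussian(x, mean, std):
    # same kernel as above: exp(-((x - mean)^2) / (2 std^2)) / (sqrt(2 pi) std)
    a = (2 * 3.14159) ** 0.5
    return torch.exp(-0.5 * (((x - mean) / std) ** 2)) / (a * std)

B, N, K = 2, 5, 128                              # batch, atoms, number of kernels (illustrative)
dist = torch.rand(B, N, N)                       # pairwise distances, like src_distance
mean = torch.empty(K).uniform_(0, 3)             # kernel centres, mirroring the uniform_(0, 3) init
std = torch.empty(K).uniform_(0, 3).abs() + 1e-5 # kernel widths, kept strictly positive
feat = gaussian(dist.unsqueeze(-1), mean, std)   # broadcasts to (B, N, N, K)
print(feat.shape)                                # torch.Size([2, 5, 5, 128])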
/unimol/models/unimol_pka.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) DP Technology.
2 | # This source code is licensed under the MIT license found in the
3 | # LICENSE file in the root directory of this source tree.
4 |
5 | import logging
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from unicore import utils
10 | from unicore.models import BaseUnicoreModel, register_model, register_model_architecture
11 | from unicore.modules import LayerNorm, init_bert_params
12 | from .unimol import UniMolModel, ClassificationHead, NonLinearHead
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 |
17 | @register_model("unimol_pka")
18 | class UniMolPKAModel(BaseUnicoreModel):
19 | @staticmethod
20 | def add_args(parser):
21 | """Add model-specific arguments to the parser."""
22 | parser.add_argument(
23 | "--masked-token-loss",
24 | type=float,
25 | metavar="D",
26 | help="mask loss ratio",
27 | )
28 | parser.add_argument(
29 | "--masked-charge-loss",
30 | type=float,
31 | metavar="D",
32 | help="mask charge loss ratio",
33 | )
34 | parser.add_argument(
35 | "--masked-dist-loss",
36 | type=float,
37 | metavar="D",
38 | help="masked distance loss ratio",
39 | )
40 | parser.add_argument(
41 | "--masked-coord-loss",
42 | type=float,
43 | metavar="D",
44 | help="masked coord loss ratio",
45 | )
46 | parser.add_argument(
47 | "--x-norm-loss",
48 | type=float,
49 | metavar="D",
50 | help="x norm loss ratio",
51 | )
52 | parser.add_argument(
53 | "--delta-pair-repr-norm-loss",
54 | type=float,
55 | metavar="D",
56 | help="delta encoder pair repr norm loss ratio",
57 | )
58 | parser.add_argument(
59 | "--pooler-dropout",
60 | type=float,
61 | metavar="D",
62 | help="dropout probability in the masked_lm pooler layers",
63 | )
64 |
65 | def __init__(self, args, dictionary, charge_dictionary):
66 | super().__init__()
67 | unimol_pka_architecture(args)
68 | self.args = args
69 | self.unimol = UniMolModel(self.args, dictionary, charge_dictionary)
70 | if args.masked_token_loss > 0:
71 | self.lm_head = MaskLMHead(
72 | embed_dim=args.encoder_embed_dim,
73 | output_dim=len(dictionary),
74 | activation_fn=args.activation_fn,
75 | weight=None,
76 | )
77 | if args.masked_charge_loss > 0:
78 | self.charge_lm_head = MaskLMHead(
79 | embed_dim=args.encoder_embed_dim,
80 | output_dim=len(charge_dictionary),
81 | activation_fn=args.activation_fn,
82 | weight=None,
83 | )
84 | if args.masked_coord_loss > 0:
85 | self.pair2coord_proj = NonLinearHead(
86 | args.encoder_attention_heads, 1, args.activation_fn
87 | )
88 | if args.masked_dist_loss > 0:
89 | self.dist_head = DistanceHead(
90 | args.encoder_attention_heads, args.activation_fn
91 | )
92 |
93 | @classmethod
94 | def build_model(cls, args, task):
95 | """Build a new model instance."""
96 | return cls(args, task.dictionary, task.charge_dictionary)
97 |
98 | def forward(
99 | self,
100 | input_metadata,
101 | classification_head_name=None,
102 | encoder_masked_tokens=None,
103 | features_only=False,
104 | **kwargs
105 | ):
106 | if not features_only:
107 | src_tokens, src_charges, src_coord, src_distance, src_edge_type, batch, charge_targets, coord_targets, dist_targets, token_targets = input_metadata
108 | else:
109 | src_tokens, src_charges, src_coord, src_distance, src_edge_type, batch = input_metadata
110 | charge_targets, coord_targets, dist_targets, token_targets = None, None, None, None
111 | padding_mask = src_tokens.eq(self.unimol.padding_idx)
112 | if not padding_mask.any():
113 | padding_mask = None
114 | x = self.unimol.embed_tokens(src_tokens)
115 |
116 | charge_padding_mask = src_charges.eq(self.unimol.charge_padding_idx)
117 | if not charge_padding_mask.any():
118 |             charge_padding_mask = None
119 | charges_emb = self.unimol.embed_charges(src_charges)
120 | # involve charge info
121 | x += charges_emb
122 |
123 | def get_dist_features(dist, et):
124 | n_node = dist.size(-1)
125 | gbf_feature = self.unimol.gbf(dist, et)
126 | gbf_result = self.unimol.gbf_proj(gbf_feature)
127 | graph_attn_bias = gbf_result
128 | graph_attn_bias = graph_attn_bias.permute(0, 3, 1, 2).contiguous()
129 | graph_attn_bias = graph_attn_bias.view(-1, n_node, n_node)
130 | return graph_attn_bias
131 |
132 | graph_attn_bias = get_dist_features(src_distance, src_edge_type)
133 | (
134 | encoder_rep,
135 | encoder_pair_rep,
136 | delta_encoder_pair_rep,
137 | x_norm,
138 | delta_encoder_pair_rep_norm,
139 | ) = self.unimol.encoder(x, padding_mask=padding_mask, attn_mask=graph_attn_bias)
140 | encoder_pair_rep[encoder_pair_rep == float("-inf")] = 0
141 |
142 | encoder_distance = None
143 | encoder_coord = None
144 |
145 | if not features_only:
146 | if self.args.masked_token_loss > 0:
147 | logits = self.lm_head(encoder_rep, encoder_masked_tokens)
148 | if self.args.masked_charge_loss > 0:
149 | charge_logits = self.charge_lm_head(encoder_rep, encoder_masked_tokens)
150 | if self.args.masked_coord_loss > 0:
151 | coords_emb = src_coord
152 | if padding_mask is not None:
153 | atom_num = (torch.sum(1 - padding_mask.type_as(x), dim=1) - 1).view(
154 | -1, 1, 1, 1
155 | )
156 | else:
157 | atom_num = src_coord.shape[1] - 1
158 | delta_pos = coords_emb.unsqueeze(1) - coords_emb.unsqueeze(2)
159 | attn_probs = self.pair2coord_proj(delta_encoder_pair_rep)
160 | coord_update = delta_pos / atom_num * attn_probs
161 | coord_update = torch.sum(coord_update, dim=2)
162 | encoder_coord = coords_emb + coord_update
163 | if self.args.masked_dist_loss > 0:
164 | encoder_distance = self.dist_head(encoder_pair_rep)
165 |
166 | if classification_head_name is not None:
167 | cls_logits = self.unimol.classification_heads[classification_head_name](encoder_rep)
168 | if not features_only:
169 | return (
170 | cls_logits, batch,
171 | logits, charge_logits, encoder_distance, encoder_coord, x_norm, delta_encoder_pair_rep_norm,
172 | token_targets, charge_targets, coord_targets, dist_targets
173 | )
174 | else:
175 | return cls_logits, batch
176 |
177 | def register_classification_head(
178 | self, name, num_classes=None, inner_dim=None, **kwargs
179 | ):
180 | """Register a classification head."""
181 | if name in self.unimol.classification_heads:
182 | prev_num_classes = self.unimol.classification_heads[name].out_proj.out_features
183 | prev_inner_dim = self.unimol.classification_heads[name].dense.out_features
184 | if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
185 | logger.warning(
186 | 're-registering head "{}" with num_classes {} (prev: {}) '
187 | "and inner_dim {} (prev: {})".format(
188 | name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
189 | )
190 | )
191 | self.unimol.classification_heads[name] = ClassificationHead(
192 | input_dim=self.args.encoder_embed_dim,
193 | inner_dim=inner_dim or self.args.encoder_embed_dim,
194 | num_classes=num_classes,
195 | activation_fn=self.args.pooler_activation_fn,
196 | pooler_dropout=self.args.pooler_dropout,
197 | )
198 |
199 | def set_num_updates(self, num_updates):
200 | """State from trainer to pass along to model at every update."""
201 | self._num_updates = num_updates
202 |
203 | def get_num_updates(self):
204 | return self._num_updates
205 |
206 |
207 | class MaskLMHead(nn.Module):
208 | """Head for masked language modeling."""
209 |
210 | def __init__(self, embed_dim, output_dim, activation_fn, weight=None):
211 | super().__init__()
212 | self.dense = nn.Linear(embed_dim, embed_dim)
213 | self.activation_fn = utils.get_activation_fn(activation_fn)
214 | self.layer_norm = LayerNorm(embed_dim)
215 |
216 | if weight is None:
217 | weight = nn.Linear(embed_dim, output_dim, bias=False).weight
218 | self.weight = weight
219 | self.bias = nn.Parameter(torch.zeros(output_dim))
220 |
221 | def forward(self, features, masked_tokens=None, **kwargs):
222 | # Only project the masked tokens while training,
223 | # saves both memory and computation
224 | if masked_tokens is not None:
225 | features = features[masked_tokens, :]
226 |
227 | x = self.dense(features)
228 | x = self.activation_fn(x)
229 | x = self.layer_norm(x)
230 | # project back to size of vocabulary with bias
231 | x = F.linear(x, self.weight) + self.bias
232 | return x
233 |
234 |
235 | class DistanceHead(nn.Module):
236 | def __init__(
237 | self,
238 | heads,
239 | activation_fn,
240 | ):
241 | super().__init__()
242 | self.dense = nn.Linear(heads, heads)
243 | self.layer_norm = nn.LayerNorm(heads)
244 | self.out_proj = nn.Linear(heads, 1)
245 | self.activation_fn = utils.get_activation_fn(activation_fn)
246 |
247 | def forward(self, x):
248 | bsz, seq_len, seq_len, _ = x.size()
249 | # x[x == float('-inf')] = 0
250 | x = self.dense(x)
251 | x = self.activation_fn(x)
252 | x = self.layer_norm(x)
253 | x = self.out_proj(x).view(bsz, seq_len, seq_len)
254 | x = (x + x.transpose(-1, -2)) * 0.5
255 | return x
256 |
257 |
258 | @register_model_architecture("unimol_pka", "unimol_pka")
259 | def unimol_pka_architecture(args):
260 | def base_architecture(args):
261 | args.encoder_layers = getattr(args, "encoder_layers", 15)
262 | args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
263 | args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
264 | args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 64)
265 | args.dropout = getattr(args, "dropout", 0.1)
266 | args.emb_dropout = getattr(args, "emb_dropout", 0.1)
267 | args.attention_dropout = getattr(args, "attention_dropout", 0.1)
268 | args.activation_dropout = getattr(args, "activation_dropout", 0.0)
269 | args.pooler_dropout = getattr(args, "pooler_dropout", 0.0)
270 | args.max_seq_len = getattr(args, "max_seq_len", 512)
271 | args.activation_fn = getattr(args, "activation_fn", "gelu")
272 | args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh")
273 | args.post_ln = getattr(args, "post_ln", False)
274 | args.masked_token_loss = getattr(args, "masked_token_loss", -1.0)
275 | args.masked_charge_loss = getattr(args, "masked_charge_loss", -1.0)
276 | args.masked_coord_loss = getattr(args, "masked_coord_loss", -1.0)
277 | args.masked_dist_loss = getattr(args, "masked_dist_loss", -1.0)
278 | args.x_norm_loss = getattr(args, "x_norm_loss", -1.0)
279 | args.delta_pair_repr_norm_loss = getattr(args, "delta_pair_repr_norm_loss", -1.0)
280 |
281 | base_architecture(args)
282 |
--------------------------------------------------------------------------------
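MaskLMHead above gathers only the masked positions before projecting them back to vocabulary logits, which keeps the projection cheap during pretraining. A small self-contained sketch of that gather-then-project pattern with random tensors and assumed sizes (512-dim embeddings, a 30-symbol dictionary):

import torch
import torch.nn as nn
import torch.nn.functional as F

embed_dim, vocab = 512, 30                           # assumed sizes
dense = nn.Linear(embed_dim, embed_dim)
layer_norm = nn.LayerNorm(embed_dim)
weight = nn.Linear(embed_dim, vocab, bias=False).weight
bias = torch.zeros(vocab)

features = torch.randn(4, 20, embed_dim)             # (batch, seq, embed) encoder output
masked_tokens = torch.zeros(4, 20, dtype=torch.bool)
masked_tokens[:, 3] = True                           # pretend position 3 is masked in every sample

x = features[masked_tokens, :]                       # keep only the masked positions: (4, embed)
x = layer_norm(F.gelu(dense(x)))                     # dense -> activation -> layer norm, as in forward()
logits = F.linear(x, weight) + bias                  # (4, vocab) logits for the masked atoms
print(logits.shape)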
/unimol/tasks/__init__.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import importlib
3 |
4 | # automatically import any Python files in the tasks/ directory
5 | for file in sorted(Path(__file__).parent.glob("*.py")):
6 | if not file.name.startswith("_"):
7 | importlib.import_module("unimol.tasks." + file.name[:-3])
8 |
--------------------------------------------------------------------------------
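Because of the import loop above, any .py file in unimol/tasks/ whose name does not start with an underscore is imported when the package loads, so its register_task decorator runs automatically. A hypothetical skeleton (the file name, task name, and class are made up for illustration):

# unimol/tasks/my_task.py  (hypothetical)
from unicore.tasks import UnicoreTask, register_task

@register_task("my_task")          # made-up task name; picked up without editing __init__.py
class MyTask(UnicoreTask):
    @staticmethod
    def add_args(parser):
        parser.add_argument("data", help="data path")

    def load_dataset(self, split, **kwargs):
        raise NotImplementedError  # dataset construction goes here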
/unimol/tasks/unimol_free_energy.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) DP Technology.
2 | # This source code is licensed under the MIT license found in the
3 | # LICENSE file in the root directory of this source tree.
4 |
5 | import logging
6 | import os
7 |
8 | import numpy as np
9 | from unicore.data import (
10 | Dictionary,
11 | NestedDictionaryDataset,
12 | LMDBDataset,
13 | AppendTokenDataset,
14 | PrependTokenDataset,
15 | RightPadDataset,
16 | SortDataset,
17 | TokenizeDataset,
18 | RightPadDataset2D,
19 | RawLabelDataset,
20 | RawArrayDataset,
21 | FromNumpyDataset,
22 | )
23 | from unimol.data import (
24 | KeyDataset,
25 | DistanceDataset,
26 | EdgeTypeDataset,
27 | RemoveHydrogenDataset,
28 | NormalizeDataset,
29 | CroppingDataset,
30 | RightPadDatasetCoord,
31 | data_utils,
32 | )
33 |
34 | from unimol.data.tta_dataset import TTADataset
35 | from unicore.tasks import UnicoreTask, register_task
36 |
37 |
38 | logger = logging.getLogger(__name__)
39 |
40 |
41 | @register_task("mol_free_energy")
42 | class UniMolFreeEnergyTask(UnicoreTask):
43 | """Task for training transformer auto-encoder models."""
44 |
45 | @staticmethod
46 | def add_args(parser):
47 | """Add task-specific arguments to the parser."""
48 | parser.add_argument("data", help="downstream data path")
49 | parser.add_argument("--task-name", type=str, help="downstream task name")
50 | parser.add_argument(
51 | "--classification-head-name",
52 | default="classification",
53 | help="finetune downstream task name",
54 | )
55 | parser.add_argument(
56 | "--num-classes",
57 | default=1,
58 | type=int,
59 | help="finetune downstream task classes numbers",
60 | )
61 | parser.add_argument("--no-shuffle", action="store_true", help="shuffle data")
62 | parser.add_argument(
63 | "--conf-size",
64 | default=10,
65 | type=int,
66 | help="number of conformers generated with each molecule",
67 | )
68 | parser.add_argument(
69 | "--remove-hydrogen",
70 | action="store_true",
71 | help="remove hydrogen atoms",
72 | )
73 | parser.add_argument(
74 | "--remove-polar-hydrogen",
75 | action="store_true",
76 | help="remove polar hydrogen atoms",
77 | )
78 | parser.add_argument(
79 | "--max-atoms",
80 | type=int,
81 | default=256,
82 | help="selected maximum number of atoms in a molecule",
83 | )
84 | parser.add_argument(
85 | "--dict-name",
86 | default="dict.txt",
87 | help="dictionary file",
88 | )
89 | parser.add_argument(
90 | "--charge-dict-name",
91 | default="dict_charge.txt",
92 | help="dictionary file",
93 | )
94 | parser.add_argument(
95 | "--only-polar",
96 | default=1,
97 | type=int,
98 | help="1: only reserve polar hydrogen; 0: no hydrogen; -1: all hydrogen ",
99 | )
100 |
101 | def __init__(self, args, dictionary, charge_dictionary):
102 | super().__init__(args)
103 | self.dictionary = dictionary
104 | self.charge_dictionary = charge_dictionary
105 | self.seed = args.seed
106 | # add mask token
107 | self.mask_idx = dictionary.add_symbol("[MASK]", is_special=True)
108 | self.charge_mask_idx = charge_dictionary.add_symbol("[MASK]", is_special=True)
109 | if self.args.only_polar > 0:
110 | self.args.remove_polar_hydrogen = True
111 | elif self.args.only_polar < 0:
112 | self.args.remove_polar_hydrogen = False
113 | else:
114 | self.args.remove_hydrogen = True
115 |
116 | @classmethod
117 | def setup_task(cls, args, **kwargs):
118 | dictionary = Dictionary.load(os.path.join(args.data, args.dict_name))
119 | logger.info("dictionary: {} types".format(len(dictionary)))
120 | charge_dictionary = Dictionary.load(os.path.join(args.data, args.charge_dict_name))
121 | logger.info("charge dictionary: {} types".format(len(charge_dictionary)))
122 | return cls(args, dictionary, charge_dictionary)
123 |
124 | def load_dataset(self, split, **kwargs):
125 | """Load a given dataset split.
126 | Args:
127 |             split (str): name of the data source (e.g., train)
128 | """
129 | split_path = os.path.join(self.args.data, self.args.task_name, split + ".lmdb")
130 | dataset = LMDBDataset(split_path)
131 | dataset = TTADataset(
132 | dataset, self.args.seed, "atoms", "coordinates", "charges", "smi", self.args.conf_size
133 | )
134 | tgt_dataset = KeyDataset(dataset, "target")
135 | smi_dataset = KeyDataset(dataset, "smi")
136 |
137 | dataset = RemoveHydrogenDataset(
138 | dataset,
139 | "atoms",
140 | "coordinates",
141 | "charges",
142 | self.args.remove_hydrogen,
143 | self.args.remove_polar_hydrogen,
144 | )
145 | dataset = CroppingDataset(
146 | dataset, self.seed, "atoms", "coordinates", "charges", self.args.max_atoms
147 | )
148 | dataset = NormalizeDataset(dataset, "coordinates", normalize_coord=True)
149 | src_dataset = KeyDataset(dataset, "atoms")
150 | src_dataset = TokenizeDataset(
151 | src_dataset, self.dictionary, max_seq_len=self.args.max_seq_len
152 | )
153 | src_charge_dataset = KeyDataset(dataset, "charges")
154 | src_charge_dataset = TokenizeDataset(
155 | src_charge_dataset, self.charge_dictionary, max_seq_len=self.args.max_seq_len
156 | )
157 | coord_dataset = KeyDataset(dataset, "coordinates")
158 |
159 | def PrependAndAppend(dataset, pre_token, app_token):
160 | dataset = PrependTokenDataset(dataset, pre_token)
161 | return AppendTokenDataset(dataset, app_token)
162 |
163 | src_dataset = PrependAndAppend(
164 | src_dataset, self.dictionary.bos(), self.dictionary.eos()
165 | )
166 | src_charge_dataset = PrependAndAppend(
167 | src_charge_dataset, self.charge_dictionary.bos(), self.charge_dictionary.eos()
168 | )
169 | edge_type = EdgeTypeDataset(src_dataset, len(self.dictionary))
170 | coord_dataset = FromNumpyDataset(coord_dataset)
171 | coord_dataset = PrependAndAppend(coord_dataset, 0.0, 0.0)
172 | distance_dataset = DistanceDataset(coord_dataset)
173 |
174 | nest_dataset = NestedDictionaryDataset(
175 | {
176 | "net_input": {
177 | "src_tokens": RightPadDataset(
178 | src_dataset,
179 | pad_idx=self.dictionary.pad(),
180 | ),
181 | "src_charges": RightPadDataset(
182 | src_charge_dataset,
183 | pad_idx=self.charge_dictionary.pad(),
184 | ),
185 | "src_coord": RightPadDatasetCoord(
186 | coord_dataset,
187 | pad_idx=0,
188 | ),
189 | "src_distance": RightPadDataset2D(
190 | distance_dataset,
191 | pad_idx=0,
192 | ),
193 | "src_edge_type": RightPadDataset2D(
194 | edge_type,
195 | pad_idx=0,
196 | ),
197 | },
198 | "target": {
199 | "finetune_target": RawLabelDataset(tgt_dataset),
200 | },
201 | "smi_name": RawArrayDataset(smi_dataset),
202 | },
203 | )
204 | if not self.args.no_shuffle and split == "train":
205 | with data_utils.numpy_seed(self.args.seed):
206 | shuffle = np.random.permutation(len(src_dataset))
207 |
208 | self.datasets[split] = SortDataset(
209 | nest_dataset,
210 | sort_order=[shuffle],
211 | )
212 | else:
213 | self.datasets[split] = nest_dataset
214 |
215 | def build_model(self, args):
216 | from unicore import models
217 |
218 | model = models.build_model(args, self)
219 | model.register_classification_head(
220 | self.args.classification_head_name,
221 | num_classes=self.args.num_classes,
222 | )
223 | return model
224 |
--------------------------------------------------------------------------------
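In load_dataset above, the coordinates get a dummy 0.0 row prepended and appended for the BOS/EOS tokens before the distance matrix is built. A short numpy sketch of the resulting shapes, assuming DistanceDataset computes pairwise Euclidean distances (as its name suggests):

import numpy as np

coords = np.random.rand(4, 3).astype(np.float32)            # 4 heavy atoms, xyz
pad = np.zeros((1, 3), dtype=np.float32)                    # placeholder rows for BOS / EOS
coords = np.concatenate([pad, coords, pad], axis=0)         # (6, 3), like PrependAndAppend(..., 0.0, 0.0)
dist = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)
print(dist.shape)                                           # (6, 6), the per-molecule src_distance shape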
/unimol/tasks/unimol_mlm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) DP Technology.
2 | # This source code is licensed under the MIT license found in the
3 | # LICENSE file in the root directory of this source tree.
4 |
5 | import logging
6 | import os
7 |
8 | import numpy as np
9 | from unicore.data import (
10 | Dictionary,
11 | NestedDictionaryDataset,
12 | LMDBDataset,
13 | AppendTokenDataset,
14 | PrependTokenDataset,
15 | SortDataset,
16 | TokenizeDataset,
17 | RawLabelDataset,
18 | FromNumpyDataset,
19 | )
20 | from unimol.data import (
21 | KeyDataset,
22 | ConformerSamplePKADataset,
23 | PKAInputDataset,
24 | PKAMLMInputDataset,
25 | DistanceDataset,
26 | EdgeTypeDataset,
27 | RemoveHydrogenDataset,
28 | NormalizeDataset,
29 | CroppingDataset,
30 | FoldLMDBDataset,
31 | StackedLMDBDataset,
32 | SplitLMDBDataset,
33 | data_utils,
34 | MaskPointsDataset,
35 | )
36 |
37 | from unimol.data.tta_dataset import TTADataset, TTAPKADataset
38 | from unicore.tasks import UnicoreTask, register_task
39 |
40 |
41 | logger = logging.getLogger(__name__)
42 |
43 |
44 | @register_task("mol_pka_mlm")
45 | class UniMolPKAMLMTask(UnicoreTask):
46 | """Task for training transformer auto-encoder models."""
47 |
48 | @staticmethod
49 | def add_args(parser):
50 | """Add task-specific arguments to the parser."""
51 | parser.add_argument("data", help="downstream data path")
52 | parser.add_argument(
53 | "--mask-prob",
54 | default=0.15,
55 | type=float,
56 | help="probability of replacing a token with mask",
57 | )
58 | parser.add_argument(
59 | "--leave-unmasked-prob",
60 | default=0.05,
61 | type=float,
62 | help="probability that a masked token is unmasked",
63 | )
64 | parser.add_argument(
65 | "--random-token-prob",
66 | default=0.05,
67 | type=float,
68 | help="probability of replacing a token with a random token",
69 | )
70 | parser.add_argument(
71 | "--noise-type",
72 | default="uniform",
73 | choices=["trunc_normal", "uniform", "normal", "none"],
74 | help="noise type in coordinate noise",
75 | )
76 | parser.add_argument(
77 | "--noise",
78 | default=1.0,
79 | type=float,
80 | help="coordinate noise for masked atoms",
81 | )
82 | parser.add_argument("--task-name", type=str, help="downstream task name")
83 | parser.add_argument(
84 | "--classification-head-name",
85 | default="classification",
86 | help="finetune downstream task name",
87 | )
88 | parser.add_argument(
89 | "--num-classes",
90 | default=1,
91 | type=int,
92 | help="finetune downstream task classes numbers",
93 | )
94 | parser.add_argument("--no-shuffle", action="store_true", help="shuffle data")
95 | parser.add_argument(
96 | "--conf-size",
97 | default=10,
98 | type=int,
99 | help="number of conformers generated with each molecule",
100 | )
101 | parser.add_argument(
102 | "--remove-hydrogen",
103 | action="store_true",
104 | help="remove hydrogen atoms",
105 | )
106 | parser.add_argument(
107 | "--remove-polar-hydrogen",
108 | action="store_true",
109 | help="remove polar hydrogen atoms",
110 | )
111 | parser.add_argument(
112 | "--max-atoms",
113 | type=int,
114 | default=256,
115 | help="selected maximum number of atoms in a molecule",
116 | )
117 | parser.add_argument(
118 | "--dict-name",
119 | default="dict.txt",
120 | help="dictionary file",
121 | )
122 | parser.add_argument(
123 | "--charge-dict-name",
124 | default="dict_charge.txt",
125 | help="dictionary file",
126 | )
127 | parser.add_argument(
128 | "--only-polar",
129 | default=1,
130 | type=int,
131 | help="1: only reserve polar hydrogen; 0: no hydrogen; -1: all hydrogen ",
132 | )
133 | parser.add_argument(
134 | '--split-mode',
135 | type=str,
136 | default='predefine',
137 | choices=['predefine', 'cross_valid', 'random', 'infer'],
138 | )
139 | parser.add_argument(
140 | "--nfolds",
141 | default=5,
142 | type=int,
143 | help="cross validation split folds"
144 | )
145 | parser.add_argument(
146 | "--fold",
147 | default=0,
148 | type=int,
149 |             help='fold used as the validation set; the remaining folds are used as the training set'
150 | )
151 | parser.add_argument(
152 | "--cv-seed",
153 | default=42,
154 | type=int,
155 | help="random seed used to do cross validation splits"
156 | )
157 |
158 | def __init__(self, args, dictionary, charge_dictionary):
159 | super().__init__(args)
160 | self.dictionary = dictionary
161 | self.charge_dictionary = charge_dictionary
162 | self.seed = args.seed
163 | # add mask token
164 | self.mask_idx = dictionary.add_symbol("[MASK]", is_special=True)
165 | self.charge_mask_idx = charge_dictionary.add_symbol("[MASK]", is_special=True)
166 | if self.args.only_polar > 0:
167 | self.args.remove_polar_hydrogen = True
168 | elif self.args.only_polar < 0:
169 | self.args.remove_polar_hydrogen = False
170 | else:
171 | self.args.remove_hydrogen = True
172 |         if self.args.split_mode != 'predefine':
173 | self.__init_data()
174 |
175 | def __init_data(self):
176 | data_path = os.path.join(self.args.data, self.args.task_name + '.lmdb')
177 | raw_dataset = LMDBDataset(data_path)
178 | if self.args.split_mode == 'cross_valid':
179 | train_folds = []
180 | for _fold in range(self.args.nfolds):
181 | if _fold == 0:
182 | cache_fold_info = FoldLMDBDataset(raw_dataset, self.args.cv_seed, _fold, nfolds=self.args.nfolds).get_fold_info()
183 | if _fold == self.args.fold:
184 | self.valid_dataset = FoldLMDBDataset(raw_dataset, self.args.cv_seed, _fold, nfolds=self.args.nfolds, cache_fold_info=cache_fold_info)
185 | if _fold != self.args.fold:
186 | train_folds.append(FoldLMDBDataset(raw_dataset, self.args.cv_seed, _fold, nfolds=self.args.nfolds, cache_fold_info=cache_fold_info))
187 | self.train_dataset = StackedLMDBDataset(train_folds)
188 | elif self.args.split_mode == 'random':
189 | cache_fold_info = SplitLMDBDataset(raw_dataset, self.args.seed, 0).get_fold_info()
190 | self.train_dataset = SplitLMDBDataset(raw_dataset, self.args.seed, 0, cache_fold_info=cache_fold_info)
191 | self.valid_dataset = SplitLMDBDataset(raw_dataset, self.args.seed, 1, cache_fold_info=cache_fold_info)
192 |
193 | @classmethod
194 | def setup_task(cls, args, **kwargs):
195 | dictionary = Dictionary.load(os.path.join(args.data, args.dict_name))
196 | logger.info("dictionary: {} types".format(len(dictionary)))
197 | charge_dictionary = Dictionary.load(os.path.join(args.data, args.charge_dict_name))
198 | logger.info("charge dictionary: {} types".format(len(charge_dictionary)))
199 | return cls(args, dictionary, charge_dictionary)
200 |
201 | def load_dataset(self, split, **kwargs):
202 | """Load a given dataset split.
203 | Args:
204 |             split (str): name of the data source (e.g., train)
205 | """
206 | self.split = split
207 | if self.args.split_mode != 'predefine':
208 | if split == 'train':
209 | dataset = self.train_dataset
210 | elif split == 'valid':
211 |                 dataset = self.valid_dataset
212 | else:
213 | split_path = os.path.join(self.args.data, split + ".lmdb")
214 | dataset = LMDBDataset(split_path)
215 | tgt_dataset = KeyDataset(dataset, "target")
216 | if split in ['train', 'train.small']:
217 | tgt_list = [tgt_dataset[i] for i in range(len(tgt_dataset))]
218 | self.mean = sum(tgt_list) / len(tgt_list)
219 | self.std = 1
220 | elif split in ['novartis_acid', 'novartis_base', 'sampl6', 'sampl7', 'sampl8']:
221 |             self.mean = 6.504894871171601  # precomputed from the full dwar_8228 set
222 | self.std = 1
223 | id_dataset = KeyDataset(dataset, "ori_smi")
224 |
225 | def GetPKAInput(dataset, metadata_key):
226 | mol_dataset = TTAPKADataset(dataset, self.args.seed, metadata_key, "atoms", "coordinates", "charges")
227 | idx2key = mol_dataset.get_idx2key()
228 | if split in ["train","train.small"]:
229 | sample_dataset = ConformerSamplePKADataset(
230 | mol_dataset, self.args.seed, "atoms", "coordinates", "charges", self.args.conf_size
231 | )
232 | else:
233 | sample_dataset = TTADataset(
234 | mol_dataset, self.args.seed, "atoms", "coordinates","charges","id", self.args.conf_size
235 | )
236 |
237 | sample_dataset = RemoveHydrogenDataset(
238 | sample_dataset,
239 | "atoms",
240 | "coordinates",
241 | "charges",
242 | self.args.remove_hydrogen,
243 | self.args.remove_polar_hydrogen,
244 | )
245 | sample_dataset = CroppingDataset(
246 | sample_dataset, self.seed, "atoms", "coordinates","charges", self.args.max_atoms
247 | )
248 | sample_dataset = NormalizeDataset(sample_dataset, "coordinates", normalize_coord=True)
249 | src_dataset = KeyDataset(sample_dataset, "atoms")
250 | src_dataset = TokenizeDataset(
251 | src_dataset, self.dictionary, max_seq_len=self.args.max_seq_len
252 | )
253 | src_charge_dataset = KeyDataset(sample_dataset, "charges")
254 | src_charge_dataset = TokenizeDataset(
255 | src_charge_dataset, self.charge_dictionary, max_seq_len=self.args.max_seq_len
256 | )
257 | coord_dataset = KeyDataset(sample_dataset, "coordinates")
258 | expand_dataset = MaskPointsDataset(
259 | src_dataset,
260 | coord_dataset,
261 | src_charge_dataset,
262 | self.dictionary,
263 | self.charge_dictionary,
264 | pad_idx=self.dictionary.pad(),
265 | charge_pad_idx=self.charge_dictionary.pad(),
266 | mask_idx=self.mask_idx,
267 | charge_mask_idx=self.charge_mask_idx,
268 | noise_type=self.args.noise_type,
269 | noise=self.args.noise,
270 | seed=self.seed,
271 | mask_prob=self.args.mask_prob,
272 | leave_unmasked_prob=self.args.leave_unmasked_prob,
273 | random_token_prob=self.args.random_token_prob,
274 | )
275 |
276 | def PrependAndAppend(dataset, pre_token, app_token):
277 | dataset = PrependTokenDataset(dataset, pre_token)
278 | return AppendTokenDataset(dataset, app_token)
279 |
280 | encoder_token_dataset = KeyDataset(expand_dataset, "atoms")
281 | encoder_target_dataset = KeyDataset(expand_dataset, "targets")
282 | encoder_coord_dataset = KeyDataset(expand_dataset, "coordinates")
283 | encoder_charge_dataset = KeyDataset(expand_dataset, "charges")
284 | encoder_charge_target_dataset = KeyDataset(expand_dataset, "charge_targets")
285 |
286 | src_dataset = PrependAndAppend(
287 | encoder_token_dataset, self.dictionary.bos(), self.dictionary.eos()
288 | )
289 | src_charge_dataset = PrependAndAppend(
290 | encoder_charge_dataset, self.charge_dictionary.bos(), self.charge_dictionary.eos()
291 | )
292 | token_tgt_dataset = PrependAndAppend(
293 | encoder_target_dataset, self.dictionary.pad(), self.dictionary.pad()
294 | )
295 | charge_tgt_dataset = PrependAndAppend(
296 | encoder_charge_target_dataset, self.charge_dictionary.pad(), self.charge_dictionary.pad()
297 | )
298 | encoder_coord_dataset = PrependAndAppend(encoder_coord_dataset, 0.0, 0.0)
299 | encoder_distance_dataset = DistanceDataset(encoder_coord_dataset)
300 |
301 | edge_type = EdgeTypeDataset(src_dataset, len(self.dictionary))
302 | coord_dataset = FromNumpyDataset(coord_dataset)
303 | coord_dataset = PrependAndAppend(coord_dataset, 0.0, 0.0)
304 | distance_dataset = DistanceDataset(coord_dataset)
305 |
306 | return PKAMLMInputDataset(idx2key, src_dataset, src_charge_dataset, encoder_coord_dataset, encoder_distance_dataset, edge_type, token_tgt_dataset, charge_tgt_dataset, distance_dataset, coord_dataset, self.dictionary.pad(), self.charge_dictionary.pad(), split, self.args.conf_size)
307 |
308 | input_a_dataset = GetPKAInput(dataset, "metadata_a")
309 | input_b_dataset = GetPKAInput(dataset, "metadata_b")
310 |
311 | nest_dataset = NestedDictionaryDataset(
312 | {
313 | "net_input_a": input_a_dataset,
314 | "net_input_b": input_b_dataset,
315 | "target": {
316 | "finetune_target": RawLabelDataset(tgt_dataset),
317 | },
318 | "id": id_dataset,
319 | },
320 | )
321 |
322 | if not self.args.no_shuffle and split in ["train","train.small"]:
323 | with data_utils.numpy_seed(self.args.seed):
324 | shuffle = np.random.permutation(len(id_dataset))
325 |
326 | self.datasets[split] = SortDataset(
327 | nest_dataset,
328 | sort_order=[shuffle],
329 | )
330 | else:
331 | self.datasets[split] = nest_dataset
332 |
333 | def build_model(self, args):
334 | from unicore import models
335 |
336 | model = models.build_model(args, self)
337 | model.register_classification_head(
338 | self.args.classification_head_name,
339 | num_classes=self.args.num_classes,
340 | )
341 | return model
342 |
--------------------------------------------------------------------------------
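With --split-mode cross_valid, __init_data above assigns every record to one of --nfolds folds derived from --cv-seed; the fold matching --fold becomes the validation set and the remaining folds are stacked into the training set. A plain-numpy sketch of that split (the actual FoldLMDBDataset bookkeeping may differ in detail):

import numpy as np

def split_folds(n_samples, nfolds, cv_seed, fold):
    rng = np.random.RandomState(cv_seed)
    idx = rng.permutation(n_samples)                 # one shuffle shared by all folds
    chunks = np.array_split(idx, nfolds)             # nfolds roughly equal chunks
    valid = chunks[fold]                             # --fold is held out for validation
    train = np.concatenate([c for i, c in enumerate(chunks) if i != fold])
    return train, valid

train_idx, valid_idx = split_folds(n_samples=100, nfolds=5, cv_seed=42, fold=0)
print(len(train_idx), len(valid_idx))                # 80 20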
/unimol/tasks/unimol_pka.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) DP Technology.
2 | # This source code is licensed under the MIT license found in the
3 | # LICENSE file in the root directory of this source tree.
4 |
5 | import logging
6 | import os
7 |
8 | import numpy as np
9 | from unicore.data import (
10 | Dictionary,
11 | NestedDictionaryDataset,
12 | LMDBDataset,
13 | AppendTokenDataset,
14 | PrependTokenDataset,
15 | SortDataset,
16 | TokenizeDataset,
17 | RawLabelDataset,
18 | FromNumpyDataset,
19 | )
20 | from unimol.data import (
21 | KeyDataset,
22 | ConformerSamplePKADataset,
23 | PKAInputDataset,
24 | DistanceDataset,
25 | EdgeTypeDataset,
26 | RemoveHydrogenDataset,
27 | NormalizeDataset,
28 | CroppingDataset,
29 | FoldLMDBDataset,
30 | StackedLMDBDataset,
31 | SplitLMDBDataset,
32 | data_utils,
33 | )
34 |
35 | from unimol.data.tta_dataset import TTADataset, TTAPKADataset
36 | from unicore.tasks import UnicoreTask, register_task
37 |
38 |
39 | logger = logging.getLogger(__name__)
40 |
41 |
42 | @register_task("mol_pka")
43 | class UniMolPKATask(UnicoreTask):
44 | """Task for training transformer auto-encoder models."""
45 |
46 | @staticmethod
47 | def add_args(parser):
48 | """Add task-specific arguments to the parser."""
49 | parser.add_argument("data", help="downstream data path")
50 | parser.add_argument("--task-name", type=str, help="downstream task name")
51 | parser.add_argument(
52 | "--classification-head-name",
53 | default="classification",
54 | help="finetune downstream task name",
55 | )
56 | parser.add_argument(
57 | "--num-classes",
58 | default=1,
59 | type=int,
60 | help="finetune downstream task classes numbers",
61 | )
62 | parser.add_argument("--no-shuffle", action="store_true", help="shuffle data")
63 | parser.add_argument(
64 | "--conf-size",
65 | default=10,
66 | type=int,
67 | help="number of conformers generated with each molecule",
68 | )
69 | parser.add_argument(
70 | "--remove-hydrogen",
71 | action="store_true",
72 | help="remove hydrogen atoms",
73 | )
74 | parser.add_argument(
75 | "--remove-polar-hydrogen",
76 | action="store_true",
77 | help="remove polar hydrogen atoms",
78 | )
79 | parser.add_argument(
80 | "--max-atoms",
81 | type=int,
82 | default=256,
83 | help="selected maximum number of atoms in a molecule",
84 | )
85 | parser.add_argument(
86 | "--dict-name",
87 | default="dict.txt",
88 | help="dictionary file",
89 | )
90 | parser.add_argument(
91 | "--charge-dict-name",
92 | default="dict_charge.txt",
93 | help="dictionary file",
94 | )
95 | parser.add_argument(
96 | "--only-polar",
97 | default=1,
98 | type=int,
99 | help="1: only reserve polar hydrogen; 0: no hydrogen; -1: all hydrogen ",
100 | )
101 | parser.add_argument(
102 | '--split-mode',
103 | type=str,
104 | default='predefine',
105 | choices=['predefine', 'cross_valid', 'random', 'infer'],
106 | )
107 | parser.add_argument(
108 | "--nfolds",
109 | default=5,
110 | type=int,
111 | help="cross validation split folds"
112 | )
113 | parser.add_argument(
114 | "--fold",
115 | default=0,
116 | type=int,
117 |             help='fold used as the validation set; the remaining folds are used as the training set'
118 | )
119 | parser.add_argument(
120 | "--cv-seed",
121 | default=42,
122 | type=int,
123 | help="random seed used to do cross validation splits"
124 | )
125 |
126 | def __init__(self, args, dictionary, charge_dictionary):
127 | super().__init__(args)
128 | self.dictionary = dictionary
129 | self.charge_dictionary = charge_dictionary
130 | self.seed = args.seed
131 | # add mask token
132 | self.mask_idx = dictionary.add_symbol("[MASK]", is_special=True)
133 | self.charge_mask_idx = charge_dictionary.add_symbol("[MASK]", is_special=True)
134 | if self.args.only_polar > 0:
135 | self.args.remove_polar_hydrogen = True
136 | elif self.args.only_polar < 0:
137 | self.args.remove_polar_hydrogen = False
138 | else:
139 | self.args.remove_hydrogen = True
140 |         if self.args.split_mode != 'predefine':
141 | self.__init_data()
142 |
143 | def __init_data(self):
144 | data_path = os.path.join(self.args.data, self.args.task_name + '.lmdb')
145 | raw_dataset = LMDBDataset(data_path)
146 | if self.args.split_mode == 'cross_valid':
147 | train_folds = []
148 | for _fold in range(self.args.nfolds):
149 | if _fold == 0:
150 | cache_fold_info = FoldLMDBDataset(raw_dataset, self.args.cv_seed, _fold, nfolds=self.args.nfolds).get_fold_info()
151 | if _fold == self.args.fold:
152 | self.valid_dataset = FoldLMDBDataset(raw_dataset, self.args.cv_seed, _fold, nfolds=self.args.nfolds, cache_fold_info=cache_fold_info)
153 | if _fold != self.args.fold:
154 | train_folds.append(FoldLMDBDataset(raw_dataset, self.args.cv_seed, _fold, nfolds=self.args.nfolds, cache_fold_info=cache_fold_info))
155 | self.train_dataset = StackedLMDBDataset(train_folds)
156 | elif self.args.split_mode == 'random':
157 | cache_fold_info = SplitLMDBDataset(raw_dataset, self.args.seed, 0).get_fold_info()
158 | self.train_dataset = SplitLMDBDataset(raw_dataset, self.args.seed, 0, cache_fold_info=cache_fold_info)
159 | self.valid_dataset = SplitLMDBDataset(raw_dataset, self.args.seed, 1, cache_fold_info=cache_fold_info)
160 |
161 | @classmethod
162 | def setup_task(cls, args, **kwargs):
163 | dictionary = Dictionary.load(os.path.join(args.data, args.dict_name))
164 | logger.info("dictionary: {} types".format(len(dictionary)))
165 | charge_dictionary = Dictionary.load(os.path.join(args.data, args.charge_dict_name))
166 | logger.info("charge dictionary: {} types".format(len(charge_dictionary)))
167 | return cls(args, dictionary, charge_dictionary)
168 |
169 | def load_dataset(self, split, **kwargs):
170 | """Load a given dataset split.
171 | Args:
172 |             split (str): name of the data source (e.g., train)
173 | """
174 | self.split = split
175 | if self.args.split_mode != 'predefine':
176 | if split == 'train':
177 | dataset = self.train_dataset
178 | elif split == 'valid':
179 |                 dataset = self.valid_dataset
180 | else:
181 | split_path = os.path.join(self.args.data, split + ".lmdb")
182 | dataset = LMDBDataset(split_path)
183 | tgt_dataset = KeyDataset(dataset, "target")
184 | if split in ['train', 'train.small']:
185 | tgt_list = [tgt_dataset[i] for i in range(len(tgt_dataset))]
186 | self.mean = sum(tgt_list) / len(tgt_list)
187 | self.std = 1
188 | elif split in ['novartis_acid', 'novartis_base', 'sampl6', 'sampl7', 'sampl8']:
189 |             self.mean = 6.504894871171601  # precomputed from the full dwar_8228 set
190 | self.std = 1
191 | id_dataset = KeyDataset(dataset, "ori_smi")
192 |
193 | def GetPKAInput(dataset, metadata_key):
194 | mol_dataset = TTAPKADataset(dataset, self.args.seed, metadata_key, "atoms", "coordinates", "charges")
195 | idx2key = mol_dataset.get_idx2key()
196 | if split in ["train","train.small"]:
197 | sample_dataset = ConformerSamplePKADataset(
198 | mol_dataset, self.args.seed, "atoms", "coordinates", "charges"
199 | )
200 | else:
201 | sample_dataset = TTADataset(
202 | mol_dataset, self.args.seed, "atoms", "coordinates","charges","id", self.args.conf_size
203 | )
204 |
205 | sample_dataset = RemoveHydrogenDataset(
206 | sample_dataset,
207 | "atoms",
208 | "coordinates",
209 | "charges",
210 | self.args.remove_hydrogen,
211 | self.args.remove_polar_hydrogen,
212 | )
213 | sample_dataset = CroppingDataset(
214 | sample_dataset, self.seed, "atoms", "coordinates","charges", self.args.max_atoms
215 | )
216 | sample_dataset = NormalizeDataset(sample_dataset, "coordinates", normalize_coord=True)
217 | src_dataset = KeyDataset(sample_dataset, "atoms")
218 | src_dataset = TokenizeDataset(
219 | src_dataset, self.dictionary, max_seq_len=self.args.max_seq_len
220 | )
221 | src_charge_dataset = KeyDataset(sample_dataset, "charges")
222 | src_charge_dataset = TokenizeDataset(
223 | src_charge_dataset, self.charge_dictionary, max_seq_len=self.args.max_seq_len
224 | )
225 | coord_dataset = KeyDataset(sample_dataset, "coordinates")
226 |
227 | def PrependAndAppend(dataset, pre_token, app_token):
228 | dataset = PrependTokenDataset(dataset, pre_token)
229 | return AppendTokenDataset(dataset, app_token)
230 |
231 | src_dataset = PrependAndAppend(
232 | src_dataset, self.dictionary.bos(), self.dictionary.eos()
233 | )
234 | src_charge_dataset = PrependAndAppend(
235 | src_charge_dataset, self.charge_dictionary.bos(), self.charge_dictionary.eos()
236 | )
237 | edge_type = EdgeTypeDataset(src_dataset, len(self.dictionary))
238 | coord_dataset = FromNumpyDataset(coord_dataset)
239 | coord_dataset = PrependAndAppend(coord_dataset, 0.0, 0.0)
240 | distance_dataset = DistanceDataset(coord_dataset)
241 |
242 | return PKAInputDataset(idx2key, src_dataset, src_charge_dataset, coord_dataset, distance_dataset, edge_type, self.dictionary.pad(), self.charge_dictionary.pad(), split, self.args.conf_size)
243 |
244 | input_a_dataset = GetPKAInput(dataset, "metadata_a")
245 | input_b_dataset = GetPKAInput(dataset, "metadata_b")
246 |
247 | nest_dataset = NestedDictionaryDataset(
248 | {
249 | "net_input_a": input_a_dataset,
250 | "net_input_b": input_b_dataset,
251 | "target": {
252 | "finetune_target": RawLabelDataset(tgt_dataset),
253 | },
254 | "id": id_dataset,
255 | },
256 | )
257 |
258 | if not self.args.no_shuffle and split in ["train","train.small"]:
259 | with data_utils.numpy_seed(self.args.seed):
260 | shuffle = np.random.permutation(len(id_dataset))
261 |
262 | self.datasets[split] = SortDataset(
263 | nest_dataset,
264 | sort_order=[shuffle],
265 | )
266 | else:
267 | self.datasets[split] = nest_dataset
268 |
269 | def build_model(self, args):
270 | from unicore import models
271 |
272 | model = models.build_model(args, self)
273 | model.register_classification_head(
274 | self.args.classification_head_name,
275 | num_classes=self.args.num_classes,
276 | )
277 | return model
278 |
--------------------------------------------------------------------------------
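load_dataset above records self.mean and self.std for the pKa targets: the training-split mean for train/train.small, or the precomputed dwar_8228 mean (about 6.505) for the external test sets. Assuming the regression loss standardises targets with these statistics and the inverse transform is applied to model outputs at inference, a minimal sketch:

mean, std = 6.504894871171601, 1.0                  # values recorded by load_dataset above

def standardise(pka):
    return (pka - mean) / std                        # what the loss would train against (assumption)

def unstandardise(pred):
    return pred * std + mean                         # map a model output back to pKa units

print(round(standardise(7.4), 3))                    # 0.895
print(round(unstandardise(standardise(7.4)), 3))     # 7.4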