├── .gitignore ├── LICENSE ├── README.md ├── dataset ├── dwar-iBond.tsv ├── novartis_acid.tsv ├── novartis_base.tsv ├── sampl6.tsv ├── sampl7.tsv └── sampl8.tsv ├── enumerator ├── example_out.tsv ├── main.py ├── simple_smarts_pattern.tsv └── smarts_pattern.tsv ├── finetune_pka.sh ├── image ├── inference.png ├── overview.png ├── performance.png └── protensemble.png ├── infer_free_energy.sh ├── infer_pka.sh ├── pretrain_pka.sh ├── scripts ├── infer_mean_ensemble.py └── preprocess_pka.py └── unimol ├── __init__.py ├── data ├── __init__.py ├── conformer_sample_dataset.py ├── coord_pad_dataset.py ├── cropping_dataset.py ├── data_utils.py ├── distance_dataset.py ├── key_dataset.py ├── lmdb_dataset.py ├── mask_points_dataset.py ├── normalize_dataset.py ├── pka_input_dataset.py ├── remove_hydrogen_dataset.py └── tta_dataset.py ├── examples ├── dict.txt └── dict_charge.txt ├── infer.py ├── losses ├── __init__.py ├── mlm_loss.py └── reg_loss.py ├── models ├── __init__.py ├── transformer_encoder_with_pair.py ├── unimol.py └── unimol_pka.py └── tasks ├── __init__.py ├── unimol_free_energy.py ├── unimol_mlm.py └── unimol_pka.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Uni-p*K*a 2 | The official implementation of the model Uni-p*K*a in the paper Bridging Machine Learning and Thermodynamics for Accurate p*K*a Prediction. 
3 | 4 | Interactive demo with available model weights at https://bohrium.dp.tech/notebooks/38543442597 5 | 6 | Published paper at [[JACS Au](https://pubs.acs.org/doi/10.1021/jacsau.4c00271)] | Relevant preprint at [[ChemRxiv](https://chemrxiv.org/engage/chemrxiv/article-details/64e8da3879853bbd786ca4eb)] | Small molecule protonation state ranking demo at [[Bohrium App](https://bohrium.dp.tech/apps/uni-pka)] | Full datasets at [[AISSquare](https://www.aissquare.com/datasets/detail?pageType=datasets&name=Uni-pKa-Dataset)] 7 | 8 | This machine-learning-based p*K*a prediction model achieves state-of-the-art accuracy on several drug-like small molecule macro-p*K*a datasets. 9 | ![Uni-p*K*a's performance](image/performance.png) 10 | 11 | The two core components of the Uni-p*K*a framework are 12 | 13 | - A microstate enumerator to systematically build the protonation 14 | ensemble from a single structure. 15 | 16 | - A molecular machine learning model to predict the free energy of each single structure. 17 | 18 | The model reaches the expected accuracy in the inference stage after comprehensive data preparation by the enumerator, pretraining on the ChEMBL dataset, and finetuning on our Dwar-iBond dataset. 19 | 20 | ![Overview of the Uni-p*K*a framework](image/overview.png) 21 | 22 | ## Microstate Enumerator 23 | 24 | ### Introduction 25 | 26 | It uses an iterative template-matching algorithm to enumerate all the microstates in adjacent macrostates of a molecule's protonation ensemble, starting from at least one microstate stored as SMILES. 27 | 28 | The protonation template `smarts_pattern.tsv` modifies and augments the one in the paper [MolGpka: A Web Server for Small Molecule pKa Prediction Using a Graph-Convolutional Neural Network](https://pubs.acs.org/doi/10.1021/acs.jcim.1c00075) and its open-source implementation (MIT license) in the GitHub repository [MolGpKa](https://github.com/Xundrug/MolGpKa/blob/master/src/utils/smarts_pattern.tsv). 29 | 30 | ### Usage 31 | 32 | The recommended environment is 33 | ```yaml 34 | python = 3.8.13 35 | rdkit = 2021.09.5 36 | numpy = 1.20.3 37 | pandas = 1.5.2 38 | ``` 39 | 40 | #### Reconstruct a plain p*K*a dataset to the Uni-p*K*a standard macro-p*K*a format with fully enumerated microstates 41 | 42 | ```shell 43 | cd enumerator 44 | python main.py reconstruct -i <input> -o <output> -m <mode> 45 | ``` 46 | 47 | The `<input>` dataset is assumed to be a CSV-like file with a column storing SMILES. Two cases are allowed for each entry in the dataset. 48 | 49 | 1. It contains only one SMILES. The Enumerator helps to build the protonated/deprotonated macrostate and complete the original macrostate. 50 | - When `<mode>` is "A", it will be considered as an acid (thrown into the A pool). 51 | - When `<mode>` is "B", it will be considered as a base (thrown into the B pool). 52 | 2. It contains a string like "A1,...,Am>>B1,...,Bn", where A1,...,Am are comma-separated SMILES of microstates in the acid macrostate (all thrown into the A pool), and B1,...,Bn are comma-separated SMILES of microstates in the base macrostate (all thrown into the B pool). The Enumerator helps to complete both. 53 | 54 | ![A/B mode of the microstate enumerator](image/protensemble.png) 55 | 56 | The `<mode>`, "A" (default) or "B", determines which pool (A/B) holds the reference structures and serves as the starting point of the enumeration. 57 | 58 | The `<output>` dataset is then constructed after the enumeration.
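For illustration, a minimal input entry and the kind of row the reconstruction is expected to produce might look like the following (a hypothetical acetic acid entry, tab-separated, tabs shown as spaces; the SMILES and p*K*a value are used only as an example). The real datasets in `dataset/` (e.g. `dwar-iBond.tsv`, `sampl6.tsv`) follow the same "acid microstates >> base microstates" convention and keep any extra columns such as `TARGET`:

```
# <input> processed with -m A (hypothetical entry)
SMILES	TARGET
CC(=O)O	4.76

# <output> in the Uni-pKa standard macro-pKa format
SMILES	TARGET
CC(=O)O>>CC(=O)[O-]	4.76
```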
59 | 60 | #### Build protonation ensembles from single molecules 61 | 62 | Example: 63 | ```shell 64 | cd enumerator 65 | python main.py ensemble -i ../dataset/sampl6.tsv -o example_out.tsv -u 2 -l -2 -t simple_smarts_pattern.tsv 66 | ``` 67 | 68 | The input dataset here is the SAMPL6 dataset as an example. A reconstructed p*K*a dataset, or any molecular dataset with a "SMILES" column of single-molecule SMILES, is supported as input. In the output file, such as `example_out.tsv`, the columns include the original SMILES and the macrostates of total charge between the upper bound set by `-u` (default +2) and the lower bound set by `-l` (default -2). A simpler template, `simple_smarts_pattern.tsv`, is provided here for cleaner protonation ensembles; it discards some structural motifs that are rare in aqueous solution. 69 | 70 | ## Machine Learning Model 71 | 72 | ### Introduction 73 | 74 | It is a [Uni-Mol](https://github.com/dptech-corp/Uni-Mol)-based neural network. By embedding the neural network into the thermodynamic relationship between free energy and p*K*a throughout the training and inference stages, the framework preserves physical consistency and adapts to multiple tasks. (A minimal numerical sketch of this relation is shown below, after the pretraining notes.) 75 | 76 | ![Uni-p*K*a inference workflow](image/inference.png) 77 | 78 | ### Usage 79 | 80 | #### Dependencies 81 | 82 | The dependencies of Uni-p*K*a are the same as those of Uni-Mol. 83 | 84 | - [Uni-Core](https://github.com/dptech-corp/Uni-Core), check its [Installation Documentation](https://github.com/dptech-corp/Uni-Core#installation). 85 | - rdkit==2022.9.3, install via `pip install rdkit-pypi==2022.9.3` 86 | 87 | The recommended environment is the Docker image: 88 | 89 | ``` 90 | docker pull dptechnology/unimol:latest-pytorch1.11.0-cuda11.3 91 | ``` 92 | 93 | See details in the [Uni-Mol](https://github.com/dptech-corp/Uni-Mol/tree/main/unimol#dependencies) repository. 94 | 95 | 96 | ### Ready-to-run training workflow 97 | 98 | #### Data 99 | 100 | The raw data can be downloaded from [[AISSquare](https://www.aissquare.com/datasets/detail?pageType=datasets&name=Uni-pKa-Dataset)]. 101 | 102 | 103 | #### Pretrain with ChEMBL 104 | 105 | First, preprocess the ChEMBL training and validation sets, and then pretrain the model: 106 | 107 | ```bash 108 | # Preprocess training set 109 | python ./scripts/preprocess_pka.py --raw-csv-file Datasets/tsv/chembl_train.tsv --processed-lmdb-dir chembl --task-name train 110 | 111 | # Preprocess validation set 112 | python ./scripts/preprocess_pka.py --raw-csv-file Datasets/tsv/chembl_valid.tsv --processed-lmdb-dir chembl --task-name valid 113 | 114 | # Copy the necessary dict file 115 | cp -r unimol/examples/* chembl 116 | 117 | # Pretrain the model 118 | bash pretrain_pka.sh 119 | ``` 120 | 121 | Note: The `head_name` in the subsequent scripts must match the `task_name` in `pretrain_pka.sh`.
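As a rough illustration of the thermodynamic relationship mentioned in the Machine Learning Model introduction above, the sketch below assembles a macro-p*K*a from per-microstate free energies with a Boltzmann-weighted (log-sum-exp) combination over each macrostate. It is only a toy example: the function name `toy_macro_pka`, the unit convention (free energies already expressed in p*K*a units, i.e. RT ln 10, with the proton contribution absorbed), and the numbers are assumptions made for illustration; the repository's actual training and inference code lives in `unimol/` and the shell scripts above.

```python
import numpy as np
from scipy.special import logsumexp


def toy_macro_pka(g_acid, g_base):
    """Toy macro-pKa from per-microstate free energies (illustrative only).

    Assumes the free energies are already expressed in pKa units (RT*ln10)
    with the proton reference absorbed; this is a sketch of the
    thermodynamic bookkeeping, not the repository's actual formulation.
    """
    g_acid = np.asarray(g_acid, dtype=float)
    g_base = np.asarray(g_base, dtype=float)
    ln10 = np.log(10.0)
    # Boltzmann-weighted contribution of each macrostate, obtained as a
    # log-sum-exp over its microstate free energies.
    log_pool_acid = logsumexp(-g_acid * ln10) / ln10
    log_pool_base = logsumexp(-g_base * ln10) / ln10
    # The macro-pKa is the difference between the two macrostate terms.
    return log_pool_acid - log_pool_base


# Two acid microstates and one base microstate (arbitrary toy numbers).
print(toy_macro_pka([3.2, 4.0], [8.1]))  # ~4.96
```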
122 | 123 | 124 | #### Finetune with dwar-iBond 125 | 126 | Next, preprocess the dwar-iBond dataset and finetune the model: 127 | 128 | ```bash 129 | # Preprocess 130 | python ./scripts/preprocess_pka.py --raw-csv-file Datasets/tsv/dwar-iBond.tsv --processed-lmdb-dir dwar --task-name dwar-iBond 131 | 132 | # Copy the necessary dict file 133 | cp -r unimol/examples/* dwar 134 | 135 | # Finetune the model 136 | bash finetune_pka.sh 137 | ``` 138 | 139 | #### Infer p*K*a 140 | 141 | Infer with the finetuned model, taking novartis_acid as an example: 142 | 143 | ```bash 144 | # Preprocess 145 | python ./scripts/preprocess_pka.py --raw-csv-file Datasets/tsv/novartis_acid.tsv --processed-lmdb-dir novartis_acid --task-name novartis_acid 146 | 147 | # Copy the necessary examples from unimol 148 | cp -r unimol/examples/* novartis_acid 149 | 150 | # Run inference 151 | bash infer_pka.sh 152 | ``` 153 | To test with other external test datasets, it may be necessary to modify `data_path`, `infer_task`, and `results_path` in `infer_pka.sh`. 154 | 155 | #### Obtain the result files and calculate the metrics 156 | After inference, extract the results to CSV files and calculate the performance metrics (e.g., MAE, RMSE) on the results: 157 | 158 | ```bash 159 | python ./scripts/infer_mean_ensemble.py --task pka --nfolds 5 --results-path novartis_acid_results 160 | ``` 161 | 162 | The metrics are calculated using the average of the 5-fold model predictions. -------------------------------------------------------------------------------- /dataset/sampl6.tsv: -------------------------------------------------------------------------------- 1 | SMILES TARGET ref. 2 | 0 c1cc2c(cc1O)c3c(o2)C(=O)NCCC3>>O=C1NCCCc2c1oc1ccc([O-])cc21,O=C1[N-]CCCc2c1oc1ccc(O)cc21 9.53 SAMPL6 SM01 pKa 1 3 | 1 FC(F)(F)c1cccc([NH2+]c2ncnc3ccccc23)c1,FC(F)(F)c1cccc(Nc2nc[nH+]c3ccccc23)c1,FC(F)(F)c1cccc(Nc2[nH+]cnc3ccccc23)c1>>c1ccc2c(c1)c(ncn2)Nc3cccc(c3)C(F)(F)F 5.03 SAMPL6 SM02 pKa 1 4 | 2 c1ccc(cc1)Cc2nnc(s2)NC(=O)c3cccs3>>O=C(Nc1nnc([CH-]c2ccccc2)s1)c1cccs1,O=C([N-]c1nnc(Cc2ccccc2)s1)c1cccs1 7.02 SAMPL6 SM03 pKa 1 5 | 3 Clc1ccc(C[NH2+]c2ncnc3ccccc23)cc1,Clc1ccc(CNc2nc[nH+]c3ccccc23)cc1,Clc1ccc(CNc2[nH+]cnc3ccccc23)cc1>>c1ccc2c(c1)c(ncn2)NCc3ccc(cc3)Cl 6.02 SAMPL6 SM04 pKa 1 6 | 4 [OH+]=C(Nc1ccccc1N1CCCCC1)c1ccc(Cl)o1,O=C(Nc1ccccc1[NH+]1CCCCC1)c1ccc(Cl)o1>>c1ccc(c(c1)NC(=O)c2ccc(o2)Cl)N3CCCCC3 4.59 SAMPL6 SM05 pKa 1 7 | 5 O=C(Nc1cccc2ccc[nH+]c12)c1cncc(Br)c1,O=C(Nc1cccc2cccnc12)c1c[nH+]cc(Br)c1,[OH+]=C(Nc1cccc2cccnc12)c1cncc(Br)c1>>c1cc2cccnc2c(c1)NC(=O)c3cc(cnc3)Br 3.03 SAMPL6 SM06 pKa 1 8 | 6 c1cc2cccnc2c(c1)NC(=O)c3cc(cnc3)Br>>O=C([N-]c1cccc2cccnc12)c1cncc(Br)c1 11.74 SAMPL6 SM06 pKa 2 9 | 7 c1ccc(CNc2nc[nH+]c3ccccc23)cc1,c1ccc(C[NH2+]c2ncnc3ccccc23)cc1,c1ccc(CNc2[nH+]cnc3ccccc23)cc1>>c1ccc(cc1)CNc2c3ccccc3ncn2 6.08 SAMPL6 SM07 pKa 1 10 | 8 Cc1ccc2c(c1)c(c(c(=O)[nH]2)CC(=O)O)c3ccccc3>>Cc1ccc2[nH]c(=O)c([CH-]C(=O)O)c(-c3ccccc3)c2c1,Cc1ccc2[nH]c(=O)c(CC(=O)[O-])c(-c3ccccc3)c2c1,Cc1ccc2[n-]c(=O)c(CC(=O)O)c(-c3ccccc3)c2c1 4.22 SAMPL6 SM08 pKa 1 11 | 9 COc1cccc(Nc2nc[nH+]c3ccccc23)c1,COc1cccc([NH2+]c2ncnc3ccccc23)c1,COc1cccc(Nc2[nH+]cnc3ccccc23)c1>>COc1cccc(c1)Nc2c3ccccc3ncn2 5.37 SAMPL6 SM09 pKa 1 12 | 10 c1ccc(cc1)C(=O)NCC(=O)Nc2nc3ccccc3s2>>O=C(CNC(=O)c1ccccc1)[N-]c1nc2ccccc2s1,O=C(C[N-]C(=O)c1ccccc1)Nc1nc2ccccc2s1 9.02 SAMPL6 SM10 pKa 1 13 | 11 [NH3+]c1ncnc2c1cnn2-c1ccccc1,Nc1ncnc2c1c[nH+]n2-c1ccccc1,Nc1[nH+]cnc2c1cnn2-c1ccccc1,Nc1nc[nH+]c2c1cnn2-c1ccccc1>>c1ccc(cc1)n2c3c(cn2)c(ncn3)N 3.89 SAMPL6 SM11 pKa 1 14 | 12 
Clc1cccc(Nc2nc[nH+]c3ccccc23)c1,Clc1cccc([NH2+]c2ncnc3ccccc23)c1,Clc1cccc(Nc2[nH+]cnc3ccccc23)c1>>c1ccc2c(c1)c(ncn2)Nc3cccc(c3)Cl 5.28 SAMPL6 SM12 pKa 1 15 | 13 COc1cc2nc[nH+]c(Nc3cccc(C)c3)c2cc1OC,COc1cc2ncnc([NH2+]c3cccc(C)c3)c2cc1OC,COc1cc2[nH+]cnc(Nc3cccc(C)c3)c2cc1OC>>Cc1cccc(c1)Nc2c3cc(c(cc3ncn2)OC)OC 5.77 SAMPL6 SM13 pKa 1 16 | 14 [NH3+]c1ccc2c(c1)[nH+]cn2-c1ccccc1>>[NH3+]c1ccc2c(c1)ncn2-c1ccccc1,Nc1ccc2c(c1)[nH+]cn2-c1ccccc1 2.58 SAMPL6 SM14 pKa 1 17 | 15 [NH3+]c1ccc2c(c1)ncn2-c1ccccc1,Nc1ccc2c(c1)[nH+]cn2-c1ccccc1>>c1ccc(cc1)n2cnc3c2ccc(c3)N 5.3 SAMPL6 SM14 pKa 2 18 | 16 Oc1ccc(-n2c[nH+]c3ccccc32)cc1>>c1ccc2c(c1)ncn2c3ccc(cc3)O,[O-]c1ccc(-n2c[nH+]c3ccccc32)cc1 4.7 SAMPL6 SM15 pKa 1 19 | 17 c1ccc2c(c1)ncn2c3ccc(cc3)O,[O-]c1ccc(-n2c[nH+]c3ccccc32)cc1>>[O-]c1ccc(-n2cnc3ccccc32)cc1 8.94 SAMPL6 SM15 pKa 2 20 | 18 O=C(Nc1cc[nH+]cc1)c1c(Cl)cccc1Cl,[OH+]=C(Nc1ccncc1)c1c(Cl)cccc1Cl>>c1cc(c(c(c1)Cl)C(=O)Nc2ccncc2)Cl 5.37 SAMPL6 SM16 pKa 1 21 | 19 c1cc(c(c(c1)Cl)C(=O)Nc2ccncc2)Cl>>O=C([N-]c1ccncc1)c1c(Cl)cccc1Cl 10.65 SAMPL6 SM16 pKa 2 22 | 20 c1ccc(CSc2nnc(-c3cc[nH+]cc3)o2)cc1,c1ccc(CSc2[nH+]nc(-c3ccncc3)o2)cc1,c1ccc(CSc2n[nH+]c(-c3ccncc3)o2)cc1>>c1ccc(cc1)CSc2nnc(o2)c3ccncc3 3.16 SAMPL6 SM17 pKa 1 23 | 21 O=c1[nH]c(CCC(=[OH+])Nc2ncc(Cc3ccc(F)c(F)c3)s2)nc2ccccc12,O=C(CCc1nc2ccccc2c(=[OH+])[nH]1)Nc1ncc(Cc2ccc(F)c(F)c2)s1,O=C(CCc1[nH]c(=O)c2ccccc2[nH+]1)Nc1ncc(Cc2ccc(F)c(F)c2)s1,O=C(CCc1nc2ccccc2c(=O)[nH]1)Nc1[nH+]cc(Cc2ccc(F)c(F)c2)s1>>c1ccc2c(c1)c(=O)[nH]c(n2)CCC(=O)Nc3ncc(s3)Cc4ccc(c(c4)F)F 2.15 SAMPL6 SM18 pKa 1 24 | 22 c1ccc2c(c1)c(=O)[nH]c(n2)CCC(=O)Nc3ncc(s3)Cc4ccc(c(c4)F)F>>O=C(CCc1nc2ccccc2c(=O)[n-]1)Nc1ncc(Cc2ccc(F)c(F)c2)s1,O=C(CCc1nc2ccccc2c(=O)[nH]1)[N-]c1ncc(Cc2ccc(F)c(F)c2)s1 9.58 SAMPL6 SM18 pKa 2 25 | 23 O=C(CCc1nc2ccccc2c(=O)[n-]1)Nc1ncc(Cc2ccc(F)c(F)c2)s1,O=C(CCc1nc2ccccc2c(=O)[nH]1)[N-]c1ncc(Cc2ccc(F)c(F)c2)s1>>O=C(CCc1nc2ccccc2c(=O)[n-]1)[N-]c1ncc(Cc2ccc(F)c(F)c2)s1 11.02 SAMPL6 SM18 pKa 3 26 | 24 CCOc1ccc2c(c1)sc(n2)NC(=O)Cc3ccc(c(c3)Cl)Cl>>CCOc1ccc2nc(NC(=O)[CH-]c3ccc(Cl)c(Cl)c3)sc2c1,CCOc1ccc2nc([N-]C(=O)Cc3ccc(Cl)c(Cl)c3)sc2c1 9.56 SAMPL6 SM19 pKa 1 27 | 25 c1cc(cc(c1)OCc2ccc(cc2Cl)Cl)/C=C/3\C(=O)NC(=O)S3>>O=C1[N-]C(=O)/C(=C\c2cccc(OCc3ccc(Cl)cc3Cl)c2)S1 5.7 SAMPL6 SM20 pKa 1 28 | 26 Fc1cnc(Nc2cccc(Br)c2)nc1[NH2+]c1cccc(Br)c1,Fc1cnc([NH2+]c2cccc(Br)c2)nc1Nc1cccc(Br)c1,Fc1cnc(Nc2cccc(Br)c2)[nH+]c1Nc1cccc(Br)c1,Fc1c[nH+]c(Nc2cccc(Br)c2)nc1Nc1cccc(Br)c1>>c1cc(cc(c1)Br)Nc2c(cnc(n2)Nc3cccc(c3)Br)F 4.1 SAMPL6 SM21 pKa 1 29 | 27 Oc1c(I)cc(I)c2ccc[nH+]c12>>c1cc2c(cc(c(c2nc1)O)I)I,[O-]c1c(I)cc(I)c2ccc[nH+]c12 2.4 SAMPL6 SM22 pKa 1 30 | 28 c1cc2c(cc(c(c2nc1)O)I)I,[O-]c1c(I)cc(I)c2ccc[nH+]c12>>[O-]c1c(I)cc(I)c2cccnc12 7.43 SAMPL6 SM22 pKa 2 31 | 29 CCOC(=O)c1ccc(Nc2cc(C)nc(Nc3ccc(C(=O)OCC)cc3)[nH+]2)cc1,CCOC(=O)c1ccc(Nc2cc(C)nc([NH2+]c3ccc(C(=O)OCC)cc3)n2)cc1,CCOC(=O)c1ccc(Nc2nc(C)cc(Nc3ccc(C(=[OH+])OCC)cc3)n2)cc1,CCOC(=O)c1ccc(Nc2nc(C)cc([NH2+]c3ccc(C(=O)OCC)cc3)n2)cc1,CCOC(=O)c1ccc(Nc2cc(C)[nH+]c(Nc3ccc(C(=O)OCC)cc3)n2)cc1,CCOC(=O)c1ccc(Nc2cc(C)nc(Nc3ccc(C(=[OH+])OCC)cc3)n2)cc1>>CCOC(=O)c1ccc(cc1)Nc2cc(nc(n2)Nc3ccc(cc3)C(=O)OCC)C 5.45 SAMPL6 SM23 pKa 1 32 | 30 COc1ccc(-c2oc3[nH+]cnc(NCCO)c3c2-c2ccc(OC)cc2)cc1,COc1ccc(-c2oc3ncnc(NCC[OH2+])c3c2-c2ccc(OC)cc2)cc1,COc1ccc(-c2oc3ncnc([NH2+]CCO)c3c2-c2ccc(OC)cc2)cc1,COc1ccc(-c2oc3nc[nH+]c(NCCO)c3c2-c2ccc(OC)cc2)cc1>>COc1ccc(cc1)c2c3c(ncnc3oc2c4ccc(cc4)OC)NCCO 2.6 SAMPL6 SM24 pKa 1 33 | -------------------------------------------------------------------------------- /dataset/sampl7.tsv: 
-------------------------------------------------------------------------------- 1 | SMILES TARGET ref. 2 | 0 O=C(NS(C1=CC=CC=C1)(=O)=O)CCC2=CC=CC=C2>>O=C(CCc1ccccc1)[N-]S(=O)(=O)c1ccccc1 4.49 SAMPL7 SM25 pKa 3 | 1 O=S(CCC1=CC=CC=C1)(NC(C)=O)=O>>CC(=O)[N-]S(=O)(=O)CCc1ccccc1 4.91 SAMPL7 SM26 pKa 4 | 2 O=S(CCC1=CC=CC=C1)(NC2(C)COC2)=O>>CC1([N-]S(=O)(=O)CCc2ccccc2)COC1 10.45 SAMPL7 SM27 pKa 5 | 3 CS(NC1(COC1)CCC2=CC=CC=C2)(=O)=O>>CS(=O)(=O)[N-]C1(CCc2ccccc2)COC1 10.05 SAMPL7 SM29 pKa 6 | 4 O=S(NC1(COC1)CCC2=CC=CC=C2)(C3=CC=CC=C3)=O>>O=S(=O)([N-]C1(CCc2ccccc2)COC1)c1ccccc1 10.29 SAMPL7 SM30 pKa 7 | 5 O=S(NC1(COC1)CCC2=CC=CC=C2)(N(C)C)=O>>CN(C)S(=O)(=O)[N-]C1(CCc2ccccc2)COC1 11.02 SAMPL7 SM31 pKa 8 | 6 CS(NC1(CSC1)CCC2=CC=CC=C2)(=O)=O>>CS(=O)(=O)[N-]C1(CCc2ccccc2)CSC1 10.45 SAMPL7 SM32 pKa 9 | 7 O=S(NC1(CSC1)CCC2=CC=CC=C2)(N(C)C)=O>>CN(C)S(=O)(=O)[N-]C1(CCc2ccccc2)CSC1 11.93 SAMPL7 SM34 pKa 10 | 8 CS(N[C@@]1(C[S+]([O-])C1)CCC2=CC=CC=C2)(=O)=O,CS(=O)(=O)[N-]C1(CCc2ccccc2)C[S+](O)C1>>CS(=O)(=O)[N-]C1(CCc2ccccc2)C[S+]([O-])C1 9.87 SAMPL7 SM35 pKa 11 | 9 O=S(N[C@@]1(C[S+]([O-])C1)CCC2=CC=CC=C2)(C3=CC=CC=C3)=O,O=S(=O)([N-]C1(CCc2ccccc2)C[S+](O)C1)c1ccccc1>>O=S(=O)([N-]C1(CCc2ccccc2)C[S+]([O-])C1)c1ccccc1 9.8 SAMPL7 SM36 pKa 12 | 10 O=S(N[C@@]1(C[S+]([O-])C1)CCC2=CC=CC=C2)(N(C)C)=O,CN(C)S(=O)(=O)[N-]C1(CCc2ccccc2)C[S+](O)C1>>CN(C)S(=O)(=O)[N-]C1(CCc2ccccc2)C[S+]([O-])C1 10.33 SAMPL7 SM37 pKa 13 | 11 CS(NC1(CS(C1)(=O)=O)CCC2=CC=CC=C2)(=O)=O>>CS(=O)(=O)[N-]C1(CCc2ccccc2)CS(=O)(=O)C1 9.44 SAMPL7 SM38 pKa 14 | 12 O=S(NC1(CS(C1)(=O)=O)CCC2=CC=CC=C2)(C3=CC=CC=C3)=O>>O=S1(=O)CC(CCc2ccccc2)([N-]S(=O)(=O)c2ccccc2)C1 10.22 SAMPL7 SM39 pKa 15 | 13 O=S(NC1(CS(C1)(=O)=O)CCC2=CC=CC=C2)(N(C)C)=O>>CN(C)S(=O)(=O)[N-]C1(CCc2ccccc2)CS(=O)(=O)C1 9.58 SAMPL7 SM40 pKa 16 | 14 O=S(NC1=NOC(C2=CC=CC=C2)=C1)(C)=O>>CS(=O)(=O)[N-]c1cc(-c2ccccc2)on1 5.22 SAMPL7 SM41 pKa 17 | 15 O=S(NC1=NOC(C2=CC=CC=C2)=C1)(C3=CC=CC=C3)=O>>O=S(=O)([N-]c1cc(-c2ccccc2)on1)c1ccccc1 6.62 SAMPL7 SM42 pKa 18 | 16 O=S(NC1=NOC(C2=CC=CC=C2)=C1)(N(C)C)=O>>CN(C)S(=O)(=O)[N-]c1cc(-c2ccccc2)on1 5.62 SAMPL7 SM43 pKa 19 | 17 O=S(NC(N=N1)=CN1C2=CC=CC=C2)(C)=O>>CS(=O)(=O)[N-]c1cn(-c2ccccc2)nn1 6.34 SAMPL7 SM44 pKa 20 | 18 O=S(NC(N=N1)=CN1C2=CC=CC=C2)(C3=CC=CC=C3)=O>>O=S(=O)([N-]c1cn(-c2ccccc2)nn1)c1ccccc1 5.93 SAMPL7 SM45 pKa 21 | 19 O=S(NC(N=N1)=CN1C2=CC=CC=C2)(N(C)C)=O>>CN(C)S(=O)(=O)[N-]c1cn(-c2ccccc2)nn1 6.42 SAMPL7 SM46 pKa 22 | -------------------------------------------------------------------------------- /dataset/sampl8.tsv: -------------------------------------------------------------------------------- 1 | SMILES TARGET ref. 
2 | 0 O=C(O)c1ccccc1[NH2+]c1cccc(C(F)(F)F)c1,OC(=[OH+])c1ccccc1Nc1cccc(C(F)(F)F)c1>>OC(=O)c1ccccc1Nc1cccc(c1)C(F)(F)F,O=C([O-])c1ccccc1[NH2+]c1cccc(C(F)(F)F)c1 2.54 SAMPL8 SAMPL8-1 pKa 3 | 1 OC(=O)c1ccccc1Nc1cccc(c1)C(F)(F)F,O=C([O-])c1ccccc1[NH2+]c1cccc(C(F)(F)F)c1>>O=C([O-])c1ccccc1Nc1cccc(C(F)(F)F)c1,O=C(O)c1ccccc1[N-]c1cccc(C(F)(F)F)c1 5.01 SAMPL8 SAMPL8-1 pKa 4 | 2 CS(=O)(=O)c1ccc(CCC(O)=O)cc1>>CS(=O)(=O)c1ccc(CCC(=O)[O-])cc1 4.41 SAMPL8 SAMPL8-2 pKa 5 | 3 NS(=O)(=O)c1cc(C(O)=O)c(NCc2ccco2)cc1Cl,NS(=O)(=O)c1cc(C(=O)[O-])c([NH2+]Cc2ccco2)cc1Cl,[NH3+]S(=O)(=O)c1cc(C(=O)[O-])c(NCc2ccco2)cc1Cl>>[NH-]S(=O)(=O)c1cc(C(=O)O)c(NCc2ccco2)cc1Cl,NS(=O)(=O)c1cc(C(=O)O)c([N-]Cc2ccco2)cc1Cl,NS(=O)(=O)c1cc(C(=O)[O-])c(NCc2ccco2)cc1Cl 4.0 SAMPL8 SAMPL8-3 pKa 6 | 4 Cc1sc(Nc2ccc(C#[NH+])c(Cl)c2)nc1C(=O)[O-],Cc1sc([NH2+]c2ccc(C#N)c(Cl)c2)nc1C(=O)[O-],Cc1sc(Nc2ccc(C#N)c(Cl)c2)[nH+]c1C(=O)[O-],Cc1sc(Nc2ccc(C#N)c(Cl)c2)nc1C(O)=O>>Cc1sc([N-]c2ccc(C#N)c(Cl)c2)nc1C(=O)O,Cc1sc(Nc2ccc(C#N)c(Cl)c2)nc1C(=O)[O-] 5.77 SAMPL8 SAMPL8-4 pKa 7 | 5 O=C(O)Cc1ccccc1Nc1c(Cl)cccc1Cl,O=C([O-])Cc1ccccc1[NH2+]c1c(Cl)cccc1Cl>>[O-]C(=O)Cc1ccccc1Nc1c(Cl)cccc1Cl,O=C(O)[CH-]c1ccccc1Nc1c(Cl)cccc1Cl 3.92 SAMPL8 SAMPL8-5 pKa 8 | 6 Cc1ccc(Cl)cc1NCc1ccc(s1)C(O)=O,Cc1ccc(Cl)cc1[NH2+]Cc1ccc(C(=O)[O-])s1>>Cc1ccc(Cl)cc1NCc1ccc(C(=O)[O-])s1 4.17 SAMPL8 SAMPL8-6 pKa 9 | 7 [NH3+]c1nc2ccc(Br)cc2n1CC1(O)CCOCC1,Nc1nc2ccc(Br)cc2n1CC1([OH2+])CCOCC1,Nc1[nH+]c2ccc(Br)cc2n1CC1(O)CCOCC1>>Nc1nc2ccc(Br)cc2n1CC1(O)CCOCC1 6.63 SAMPL8 SAMPL8-7 pKa 10 | 8 COC(=O)c1cn2cccc(C(F)(F)F)c2[nH+]1,COC(=[OH+])c1cn2cccc(C(F)(F)F)c2n1>>COC(=O)c1cn2cccc(c2n1)C(F)(F)F 2.78 SAMPL8 SAMPL8-8 pKa 11 | 9 Nc1[nH+]c2ccc(Br)cc2n1CC1(O)CCCCC1,[NH3+]c1nc2ccc(Br)cc2n1CC1(O)CCCCC1,Nc1nc2ccc(Br)cc2n1CC1([OH2+])CCCCC1>>Nc1nc2ccc(Br)cc2n1CC1(O)CCCCC1 6.08 SAMPL8 SAMPL8-9 pKa 12 | 10 CS(=O)(=[OH+])c1cccc(-c2ccc(CNCc3c(F)cccc3Cl)cc2)c1,CS(=O)(=O)c1cccc(-c2ccc(C[NH2+]Cc3c(F)cccc3Cl)cc2)c1>>CS(=O)(=O)c1cccc(c1)-c1ccc(CNCc2c(F)cccc2Cl)cc1 7.71 SAMPL8 SAMPL8-10 pKa 13 | 11 Cc1ccc(Cc2cnc([NH3+])nc2N2CCOCC2)cc1,Cc1ccc(Cc2c[nH+]c(N)nc2N2CCOCC2)cc1,Cc1ccc(Cc2cnc(N)nc2[NH+]2CCOCC2)cc1,Cc1ccc(Cc2cnc(N)[nH+]c2N2CCOCC2)cc1>>Cc1ccc(Cc2cnc(N)nc2N2CCOCC2)cc1 6.98 SAMPL8 SAMPL8-12 pKa 14 | 12 Cc1cc(Cc2cnc(N)nc2[NH3+])cc(C(C)(C)C)c1O,Cc1cc(Cc2c[nH+]c(N)[nH+]c2N)cc(C(C)(C)C)c1[O-],Cc1cc(Cc2cnc(N)[nH+]c2N)cc(C(C)(C)C)c1O,Cc1cc(Cc2cnc([NH3+])nc2[NH3+])cc(C(C)(C)C)c1[O-],Cc1cc(Cc2cnc(N)[nH+]c2[NH3+])cc(C(C)(C)C)c1[O-],Cc1cc(Cc2c[nH+]c(N)nc2N)cc(C(C)(C)C)c1O,Cc1cc(Cc2cnc([NH3+])[nH+]c2N)cc(C(C)(C)C)c1[O-],Cc1cc(Cc2c[nH+]c([NH3+])nc2N)cc(C(C)(C)C)c1[O-],Cc1cc(Cc2c[nH+]c(N)nc2[NH3+])cc(C(C)(C)C)c1[O-],Cc1cc(Cc2cnc([NH3+])nc2N)cc(C(C)(C)C)c1O>>Cc1cc(Cc2cnc(N)[nH+]c2N)cc(C(C)(C)C)c1[O-],Cc1cc(Cc2cnc([NH3+])nc2N)cc(C(C)(C)C)c1[O-],Cc1cc(Cc2cnc(N)nc2N)cc(c1O)C(C)(C)C,Cc1cc(Cc2cnc(N)nc2[NH3+])cc(C(C)(C)C)c1[O-],Cc1cc(Cc2c[nH+]c(N)nc2N)cc(C(C)(C)C)c1[O-] 7.27 SAMPL8 SAMPL8-14 pKa 15 | 13 Cc1ccccc1[NH2+]c1nc(Cl)nc2ccccc12,Cc1ccccc1Nc1nc(Cl)[nH+]c2ccccc12,Cc1ccccc1Nc1[nH+]c(Cl)nc2ccccc12>>Cc1ccccc1Nc1nc(Cl)nc2ccccc12 2.54 SAMPL8 SAMPL8-15 pKa 16 | 14 CC(C)(C)OC(=[OH+])NCc1nc2ccccc2[nH]1,CC(C)(C)OC(=O)NCc1[nH]c2ccccc2[nH+]1>>CC(C)(C)OC(=O)NCc1nc2ccccc2[nH]1 5.1 SAMPL8 SAMPL8-16 pKa 17 | 15 Nc1nc2ccc(Br)cc2n1CC1(C[OH2+])CCOCC1,Nc1[nH+]c2ccc(Br)cc2n1CC1(CO)CCOCC1,[NH3+]c1nc2ccc(Br)cc2n1CC1(CO)CCOCC1>>Nc1nc2ccc(Br)cc2n1CC1(CO)CCOCC1 6.58 SAMPL8 SAMPL8-17 pKa 18 | 16 COc1ccc(Nc2nc(Cl)[nH+]c3ccccc23)cc1OC,COc1ccc([NH2+]c2nc(Cl)nc3ccccc23)cc1OC,COc1ccc(Nc2[nH+]c(Cl)nc3ccccc23)cc1OC>>COc1ccc(Nc2nc(Cl)nc3ccccc23)cc1OC 2.72 
SAMPL8 SAMPL8-18 pKa 19 | 17 COc1ncc(Cc2c[nH]c(SCc3nc4ccccc4n3C)[nH+]c2=O)cn1,COc1ncc(Cc2c[nH]c(SCc3[nH+]c4ccccc4n3C)nc2=O)cn1,COc1ncc(Cc2c[nH]c(SCc3nc4ccccc4n3C)nc2=O)c[nH+]1,COc1ncc(Cc2c[nH]c(SCc3nc4ccccc4n3C)nc2=[OH+])cn1>>COc1ncc(Cc2c[nH]c(SCc3nc4ccccc4n3C)nc2=O)cn1 4.93 SAMPL8 SAMPL8-19 pKa 20 | 17 COc1ncc(Cc2c[nH]c(SCc3nc4ccccc4n3C)nc2=O)cn1>>COc1ncc(Cc2c[n-]c(SCc3nc4ccccc4n3C)nc2=O)cn1 6.99 SAMPL8 SAMPL8-19 pKa 21 | 19 [NH3+]c1n[nH]c2nc(-c3ccccc3)c(Cl)cc12,Nc1[nH+][nH]c2nc(-c3ccccc3)c(Cl)cc12,Nc1n[nH]c2[nH+]c(-c3ccccc3)c(Cl)cc12>>Nc1n[nH]c2nc(c(Cl)cc12)-c1ccccc1 2.44 SAMPL8 SAMPL8-20 pKa 22 | 20 Nc1n[nH]c2nc(c(Cl)cc12)-c1ccccc1>>Nc1n[n-]c2nc(-c3ccccc3)c(Cl)cc12,[NH-]c1n[nH]c2nc(-c3ccccc3)c(Cl)cc12 11.44 SAMPL8 SAMPL8-20 pKa 23 | 21 COc1cc(Cc2c(OC)nc([NH3+])[nH+]c2N)cc(OC)c1[O-],COc1cc(Cc2c(OC)nc(N)[nH+]c2N)cc(OC)c1O,COc1cc(Cc2c(N)nc([NH3+])[nH+]c2OC)cc(OC)c1[O-],COc1cc(Cc2c(N)nc(N)[nH+]c2OC)cc(OC)c1O,COc1cc(Cc2c(N)[nH+]c(N)[nH+]c2OC)cc(OC)c1[O-],COc1cc(Cc2c([NH3+])nc(N)nc2OC)cc(OC)c1O,COc1cc(Cc2c(N)nc([NH3+])nc2OC)cc(OC)c1O,COc1cc(Cc2c([NH3+])nc([NH3+])nc2OC)cc(OC)c1[O-],COc1cc(Cc2c([NH3+])nc(N)[nH+]c2OC)cc(OC)c1[O-],COc1cc(Cc2c(OC)nc(N)[nH+]c2[NH3+])cc(OC)c1[O-]>>COc1cc(Cc2c(N)nc([NH3+])nc2OC)cc(OC)c1[O-],COc1cc(Cc2c(N)nc(N)nc2OC)cc(OC)c1O,COc1cc(Cc2c(N)nc(N)[nH+]c2OC)cc(OC)c1[O-],COc1cc(Cc2c(OC)nc(N)[nH+]c2N)cc(OC)c1[O-],COc1cc(Cc2c([NH3+])nc(N)nc2OC)cc(OC)c1[O-] 5.38 SAMPL8 SAMPL8-21 pKa 24 | 22 COc1cc2[nH+]c(Cl)nc(N)c2cc1OC,COc1cc2nc(Cl)nc([NH3+])c2cc1OC,COc1cc2nc(Cl)[nH+]c(N)c2cc1OC>>COc1cc2nc(Cl)nc(N)c2cc1OC 3.36 SAMPL8 SAMPL8-22 pKa 25 | 23 Cc1[nH+]c2cc(O)ccc2s1>>Cc1[nH+]c2cc([O-])ccc2s1,Cc1nc2cc(O)ccc2s1 2.65 SAMPL8 SAMPL8-23 pKa 26 | 24 Cc1[nH+]c2cc([O-])ccc2s1,Cc1nc2cc(O)ccc2s1>>Cc1nc2cc([O-])ccc2s1 9.02 SAMPL8 SAMPL8-23 pKa 27 | -------------------------------------------------------------------------------- /enumerator/example_out.tsv: -------------------------------------------------------------------------------- 1 | -2 -1 0 1 2 SMILES 2 | 0 O=C1NCCCc2c1oc1ccc([O-])cc21 O=C1NCCCc2c1oc1ccc(O)cc21 c1cc2c(cc1O)c3c(o2)C(=O)NCCC3 3 | 1 FC(F)(F)c1cccc(Nc2ncnc3ccccc23)c1 FC(F)(F)c1cccc([NH2+]c2ncnc3ccccc23)c1,FC(F)(F)c1cccc(Nc2[nH+]cnc3ccccc23)c1,FC(F)(F)c1cccc(Nc2nc[nH+]c3ccccc23)c1 FC(F)(F)c1cccc([NH2+]c2nc[nH+]c3ccccc23)c1,FC(F)(F)c1cccc(Nc2[nH+]c[nH+]c3ccccc23)c1,FC(F)(F)c1cccc([NH2+]c2[nH+]cnc3ccccc23)c1 FC(F)(F)c1cccc([NH2+]c2ncnc3ccccc23)c1 4 | 2 O=C(Nc1nnc(Cc2ccccc2)s1)c1cccs1 O=C(Nc1n[nH+]c(Cc2ccccc2)s1)c1cccs1,O=C(Nc1[nH+]nc(Cc2ccccc2)s1)c1cccs1 c1ccc(cc1)Cc2nnc(s2)NC(=O)c3cccs3 5 | 3 Clc1ccc(CNc2ncnc3ccccc23)cc1 Clc1ccc(CNc2[nH+]cnc3ccccc23)cc1,Clc1ccc(CNc2nc[nH+]c3ccccc23)cc1,Clc1ccc(C[NH2+]c2ncnc3ccccc23)cc1 Clc1ccc(CNc2[nH+]c[nH+]c3ccccc23)cc1,Clc1ccc(C[NH2+]c2[nH+]cnc3ccccc23)cc1,Clc1ccc(C[NH2+]c2nc[nH+]c3ccccc23)cc1 Clc1ccc(C[NH2+]c2ncnc3ccccc23)cc1 6 | 4 [OH+]=C(Nc1ccccc1N1CCCCC1)c1ccc(Cl)o1 [OH+]=C(Nc1ccccc1[NH+]1CCCCC1)c1ccc(Cl)o1 [OH+]=C(Nc1ccccc1N1CCCCC1)c1ccc(Cl)o1 7 | 5 O=C(Nc1cccc2cccnc12)c1cncc(Br)c1 O=C(Nc1cccc2ccc[nH+]c12)c1cncc(Br)c1,O=C(Nc1cccc2cccnc12)c1c[nH+]cc(Br)c1 O=C(Nc1cccc2ccc[nH+]c12)c1c[nH+]cc(Br)c1 O=C(Nc1cccc2ccc[nH+]c12)c1cncc(Br)c1 8 | 6 O=C(Nc1cccc2cccnc12)c1cncc(Br)c1 O=C(Nc1cccc2cccnc12)c1c[nH+]cc(Br)c1,O=C(Nc1cccc2ccc[nH+]c12)c1cncc(Br)c1 c1cc2cccnc2c(c1)NC(=O)c3cc(cnc3)Br 9 | 7 c1ccc(CNc2ncnc3ccccc23)cc1 c1ccc(CNc2nc[nH+]c3ccccc23)cc1,c1ccc(CNc2[nH+]cnc3ccccc23)cc1,c1ccc(C[NH2+]c2ncnc3ccccc23)cc1 c1ccc(CNc2[nH+]c[nH+]c3ccccc23)cc1,c1ccc(C[NH2+]c2[nH+]cnc3ccccc23)cc1,c1ccc(C[NH2+]c2nc[nH+]c3ccccc23)cc1 
c1ccc(CNc2nc[nH+]c3ccccc23)cc1 10 | 8 Cc1ccc2[nH]c(=O)c(CC(=O)[O-])c(-c3ccccc3)c2c1,Cc1ccc2[n-]c(=O)c(CC(=O)O)c(-c3ccccc3)c2c1 Cc1ccc2[nH]c(=O)c(CC(=O)O)c(-c3ccccc3)c2c1 Cc1ccc2c(c1)c(c(c(=O)[nH]2)CC(=O)O)c3ccccc3 11 | 9 COc1cccc(Nc2ncnc3ccccc23)c1 COc1cccc(Nc2[nH+]cnc3ccccc23)c1,COc1cccc(Nc2nc[nH+]c3ccccc23)c1,COc1cccc([NH2+]c2ncnc3ccccc23)c1 COc1cccc(Nc2[nH+]c[nH+]c3ccccc23)c1,COc1cccc([NH2+]c2[nH+]cnc3ccccc23)c1,COc1cccc([NH2+]c2nc[nH+]c3ccccc23)c1 COc1cccc(Nc2nc[nH+]c3ccccc23)c1 12 | 10 O=C(CNC(=O)c1ccccc1)Nc1nc2ccccc2s1 O=C(CNC(=O)c1ccccc1)Nc1[nH+]c2ccccc2s1 c1ccc(cc1)C(=O)NCC(=O)Nc2nc3ccccc3s2 13 | 11 Nc1ncnc2c1cnn2-c1ccccc1 [NH3+]c1ncnc2c1cnn2-c1ccccc1,Nc1ncnc2c1c[nH+]n2-c1ccccc1,Nc1nc[nH+]c2c1cnn2-c1ccccc1,Nc1[nH+]cnc2c1cnn2-c1ccccc1 Nc1[nH+]c[nH+]c2c1cnn2-c1ccccc1,Nc1nc[nH+]c2c1c[nH+]n2-c1ccccc1,Nc1[nH+]cnc2c1c[nH+]n2-c1ccccc1,[NH3+]c1nc[nH+]c2c1cnn2-c1ccccc1,[NH3+]c1[nH+]cnc2c1cnn2-c1ccccc1,[NH3+]c1ncnc2c1c[nH+]n2-c1ccccc1 [NH3+]c1ncnc2c1cnn2-c1ccccc1 14 | 12 Clc1cccc(Nc2ncnc3ccccc23)c1 Clc1cccc([NH2+]c2ncnc3ccccc23)c1,Clc1cccc(Nc2nc[nH+]c3ccccc23)c1,Clc1cccc(Nc2[nH+]cnc3ccccc23)c1 Clc1cccc([NH2+]c2[nH+]cnc3ccccc23)c1,Clc1cccc([NH2+]c2nc[nH+]c3ccccc23)c1,Clc1cccc(Nc2[nH+]c[nH+]c3ccccc23)c1 Clc1cccc(Nc2nc[nH+]c3ccccc23)c1 15 | 13 COc1cc2ncnc(Nc3cccc(C)c3)c2cc1OC COc1cc2nc[nH+]c(Nc3cccc(C)c3)c2cc1OC,COc1cc2[nH+]cnc(Nc3cccc(C)c3)c2cc1OC,COc1cc2ncnc([NH2+]c3cccc(C)c3)c2cc1OC COc1cc2nc[nH+]c([NH2+]c3cccc(C)c3)c2cc1OC,COc1cc2[nH+]c[nH+]c(Nc3cccc(C)c3)c2cc1OC,COc1cc2[nH+]cnc([NH2+]c3cccc(C)c3)c2cc1OC COc1cc2nc[nH+]c(Nc3cccc(C)c3)c2cc1OC 16 | 14 Nc1ccc2c(c1)ncn2-c1ccccc1 [NH3+]c1ccc2c(c1)ncn2-c1ccccc1,Nc1ccc2c(c1)[nH+]cn2-c1ccccc1 [NH3+]c1ccc2c(c1)[nH+]cn2-c1ccccc1 [NH3+]c1ccc2c(c1)[nH+]cn2-c1ccccc1 17 | 15 Nc1ccc2c(c1)ncn2-c1ccccc1 [NH3+]c1ccc2c(c1)ncn2-c1ccccc1,Nc1ccc2c(c1)[nH+]cn2-c1ccccc1 [NH3+]c1ccc2c(c1)[nH+]cn2-c1ccccc1 [NH3+]c1ccc2c(c1)ncn2-c1ccccc1 18 | 16 [O-]c1ccc(-n2cnc3ccccc32)cc1 Oc1ccc(-n2cnc3ccccc32)cc1,[O-]c1ccc(-n2c[nH+]c3ccccc32)cc1 Oc1ccc(-n2c[nH+]c3ccccc32)cc1 Oc1ccc(-n2c[nH+]c3ccccc32)cc1 19 | 17 [O-]c1ccc(-n2cnc3ccccc32)cc1 Oc1ccc(-n2cnc3ccccc32)cc1,[O-]c1ccc(-n2c[nH+]c3ccccc32)cc1 Oc1ccc(-n2c[nH+]c3ccccc32)cc1 c1ccc2c(c1)ncn2c3ccc(cc3)O 20 | 18 O=C(Nc1ccncc1)c1c(Cl)cccc1Cl O=C(Nc1cc[nH+]cc1)c1c(Cl)cccc1Cl O=C(Nc1cc[nH+]cc1)c1c(Cl)cccc1Cl 21 | 19 O=C(Nc1ccncc1)c1c(Cl)cccc1Cl O=C(Nc1cc[nH+]cc1)c1c(Cl)cccc1Cl c1cc(c(c(c1)Cl)C(=O)Nc2ccncc2)Cl 22 | 20 c1ccc(CSc2nnc(-c3ccncc3)o2)cc1 c1ccc(CSc2n[nH+]c(-c3ccncc3)o2)cc1,c1ccc(CSc2[nH+]nc(-c3ccncc3)o2)cc1,c1ccc(CSc2nnc(-c3cc[nH+]cc3)o2)cc1 c1ccc(CSc2[nH+][nH+]c(-c3ccncc3)o2)cc1,c1ccc(CSc2n[nH+]c(-c3cc[nH+]cc3)o2)cc1,c1ccc(CSc2[nH+]nc(-c3cc[nH+]cc3)o2)cc1 c1ccc(CSc2nnc(-c3cc[nH+]cc3)o2)cc1 23 | 21 O=c1[nH]c(CCC(=[OH+])Nc2ncc(Cc3ccc(F)c(F)c3)s2)nc2ccccc12 O=c1[nH]c(CCC(=[OH+])Nc2ncc(Cc3ccc(F)c(F)c3)s2)[nH+]c2ccccc12,O=c1[nH]c(CCC(=[OH+])Nc2[nH+]cc(Cc3ccc(F)c(F)c3)s2)nc2ccccc12 O=c1[nH]c(CCC(=[OH+])Nc2ncc(Cc3ccc(F)c(F)c3)s2)nc2ccccc12 24 | 22 O=C(CCc1nc2ccccc2c(=O)[n-]1)Nc1ncc(Cc2ccc(F)c(F)c2)s1 O=C(CCc1nc2ccccc2c(=O)[nH]1)Nc1ncc(Cc2ccc(F)c(F)c2)s1 O=C(CCc1[nH]c(=O)c2ccccc2[nH+]1)Nc1ncc(Cc2ccc(F)c(F)c2)s1,O=C(CCc1nc2ccccc2c(=O)[nH]1)Nc1[nH+]cc(Cc2ccc(F)c(F)c2)s1 c1ccc2c(c1)c(=O)[nH]c(n2)CCC(=O)Nc3ncc(s3)Cc4ccc(c(c4)F)F 25 | 23 O=C(CCc1nc2ccccc2c(=O)[n-]1)Nc1ncc(Cc2ccc(F)c(F)c2)s1 O=C(CCc1nc2ccccc2c(=O)[nH]1)Nc1ncc(Cc2ccc(F)c(F)c2)s1 O=C(CCc1[nH]c(=O)c2ccccc2[nH+]1)Nc1ncc(Cc2ccc(F)c(F)c2)s1,O=C(CCc1nc2ccccc2c(=O)[nH]1)Nc1[nH+]cc(Cc2ccc(F)c(F)c2)s1 O=C(CCc1nc2ccccc2c(=O)[n-]1)Nc1ncc(Cc2ccc(F)c(F)c2)s1 26 | 24 
CCOc1ccc2nc(NC(=O)Cc3ccc(Cl)c(Cl)c3)sc2c1 CCOc1ccc2[nH+]c(NC(=O)Cc3ccc(Cl)c(Cl)c3)sc2c1 CCOc1ccc2c(c1)sc(n2)NC(=O)Cc3ccc(c(c3)Cl)Cl 27 | 25 O=C1NC(=O)/C(=C\c2cccc(OCc3ccc(Cl)cc3Cl)c2)S1 c1cc(cc(c1)OCc2ccc(cc2Cl)Cl)/C=C/3\C(=O)NC(=O)S3 28 | 26 Fc1cnc(Nc2cccc(Br)c2)nc1Nc1cccc(Br)c1 Fc1cnc(Nc2cccc(Br)c2)nc1[NH2+]c1cccc(Br)c1,Fc1cnc([NH2+]c2cccc(Br)c2)nc1Nc1cccc(Br)c1,Fc1cnc(Nc2cccc(Br)c2)[nH+]c1Nc1cccc(Br)c1,Fc1c[nH+]c(Nc2cccc(Br)c2)nc1Nc1cccc(Br)c1 Fc1cnc([NH2+]c2cccc(Br)c2)nc1[NH2+]c1cccc(Br)c1,Fc1c[nH+]c([NH2+]c2cccc(Br)c2)nc1Nc1cccc(Br)c1,Fc1c[nH+]c(Nc2cccc(Br)c2)[nH+]c1Nc1cccc(Br)c1,Fc1c[nH+]c(Nc2cccc(Br)c2)nc1[NH2+]c1cccc(Br)c1,Fc1cnc([NH2+]c2cccc(Br)c2)[nH+]c1Nc1cccc(Br)c1,Fc1cnc(Nc2cccc(Br)c2)[nH+]c1[NH2+]c1cccc(Br)c1 Fc1cnc(Nc2cccc(Br)c2)nc1[NH2+]c1cccc(Br)c1 29 | 27 [O-]c1c(I)cc(I)c2cccnc12 Oc1c(I)cc(I)c2cccnc12,[O-]c1c(I)cc(I)c2ccc[nH+]c12 Oc1c(I)cc(I)c2ccc[nH+]c12 Oc1c(I)cc(I)c2ccc[nH+]c12 30 | 28 [O-]c1c(I)cc(I)c2cccnc12 Oc1c(I)cc(I)c2cccnc12,[O-]c1c(I)cc(I)c2ccc[nH+]c12 Oc1c(I)cc(I)c2ccc[nH+]c12 c1cc2c(cc(c(c2nc1)O)I)I 31 | 29 CCOC(=O)c1ccc(Nc2cc(C)nc(Nc3ccc(C(=O)OCC)cc3)n2)cc1 CCOC(=O)c1ccc(Nc2cc(C)[nH+]c(Nc3ccc(C(=O)OCC)cc3)n2)cc1,CCOC(=O)c1ccc(Nc2cc(C)nc([NH2+]c3ccc(C(=O)OCC)cc3)n2)cc1,CCOC(=O)c1ccc(Nc2nc(C)cc([NH2+]c3ccc(C(=O)OCC)cc3)n2)cc1,CCOC(=O)c1ccc(Nc2cc(C)nc(Nc3ccc(C(=O)OCC)cc3)[nH+]2)cc1 CCOC(=O)c1ccc([NH2+]c2cc(C)nc([NH2+]c3ccc(C(=O)OCC)cc3)n2)cc1,CCOC(=O)c1ccc(Nc2cc(C)[nH+]c(Nc3ccc(C(=O)OCC)cc3)[nH+]2)cc1,CCOC(=O)c1ccc(Nc2nc([NH2+]c3ccc(C(=O)OCC)cc3)cc(C)[nH+]2)cc1,CCOC(=O)c1ccc(Nc2cc(C)[nH+]c([NH2+]c3ccc(C(=O)OCC)cc3)n2)cc1,CCOC(=O)c1ccc(Nc2cc(C)nc([NH2+]c3ccc(C(=O)OCC)cc3)[nH+]2)cc1,CCOC(=O)c1ccc(Nc2nc(C)cc([NH2+]c3ccc(C(=O)OCC)cc3)[nH+]2)cc1 CCOC(=O)c1ccc(Nc2cc(C)nc(Nc3ccc(C(=O)OCC)cc3)[nH+]2)cc1 32 | 30 COc1ccc(-c2oc3ncnc(NCCO)c3c2-c2ccc(OC)cc2)cc1 COc1ccc(-c2oc3[nH+]cnc(NCCO)c3c2-c2ccc(OC)cc2)cc1,COc1ccc(-c2oc3ncnc([NH2+]CCO)c3c2-c2ccc(OC)cc2)cc1,COc1ccc(-c2oc3nc[nH+]c(NCCO)c3c2-c2ccc(OC)cc2)cc1 COc1ccc(-c2oc3nc[nH+]c([NH2+]CCO)c3c2-c2ccc(OC)cc2)cc1,COc1ccc(-c2oc3[nH+]c[nH+]c(NCCO)c3c2-c2ccc(OC)cc2)cc1,COc1ccc(-c2oc3[nH+]cnc([NH2+]CCO)c3c2-c2ccc(OC)cc2)cc1 COc1ccc(-c2oc3[nH+]cnc(NCCO)c3c2-c2ccc(OC)cc2)cc1 33 | -------------------------------------------------------------------------------- /enumerator/main.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, List, Dict, Callable 2 | from collections import OrderedDict 3 | from argparse import ArgumentParser 4 | 5 | import pandas as pd 6 | import numpy as np 7 | from tqdm import tqdm, trange 8 | from rdkit.Chem import MolFromSmarts, AddHs, MolFromSmiles, SanitizeMol, Mol, CanonSmiles, MolToSmiles, RemoveHs, RWMol, GetFormalCharge 9 | from rdkit.RDLogger import DisableLog 10 | 11 | 12 | # Silence! 
13 | DisableLog('rdApp.*') 14 | 15 | 16 | # Unreasonable chemical structures 17 | FILTER_PATTERNS = list(map(MolFromSmarts, [ 18 | "[#6X5]", 19 | "[#7X5]", 20 | "[#8X4]", 21 | "[*r]=[*r]=[*r]", 22 | "[#1]-[*+1]~[*-1]", 23 | "[#1]-[*+1]=,:[*]-,:[*-1]", 24 | "[#1]-[*+1]-,:[*]=,:[*-1]", 25 | "[*+2]", 26 | "[*-2]", 27 | "[#1]-[#8+1].[#8-1,#7-1,#6-1]", 28 | "[#1]-[#7+1,#8+1].[#7-1,#6-1]", 29 | "[#1]-[#8+1].[#8-1,#6-1]", 30 | "[#1]-[#7+1].[#8-1]-[C](-[C,#1])(-[C,#1])", 31 | # "[#6;!$([#6]-,:[*]=,:[*]);!$([#6]-,:[#7,#8,#16])]=[C](-[O,N,S]-[#1])", 32 | # "[#6]-,=[C](-[O,N,S])(-[O,N,S]-[#1])", 33 | "[OX1]=[C]-[OH2+1]", 34 | "[NX1,NX2H1,NX3H2]=[C]-[O]-[H]", 35 | "[#6-1]=[*]-[*]", 36 | "[cX2-1]", 37 | "[N+1](=O)-[O]-[H]" 38 | ])) 39 | 40 | 41 | def _read_dataset(dataset_file: str)-> pd.DataFrame: 42 | try: 43 | dataset = pd.read_csv(dataset_file, sep="\t", index_col=False) 44 | except pd.errors.ParserError: 45 | dataset = pd.read_csv(dataset_file, sep=",", index_col=False) 46 | return dataset 47 | 48 | 49 | def read_dataset(dataset_file: str, column: str="SMILES", mode=None) -> Tuple[List[List[str]], List[List[str]], pd.DataFrame]: 50 | ''' 51 | Read an acid/base dataset. 52 | 53 | Params: 54 | ---- 55 | `dataset_file`: The path of a csv-like dataset with columns separated by `\t`. 56 | 57 | `column`: The name of the column storing SMILES. 58 | 59 | `mode`: 60 | - `None` if every entry is acid/base pair recorded as [acid SMILES]>>[basic SMILES]. 61 | - `A` if every entry stores acid structures as [acid SMILES]. 62 | - `B` if every entry stores base structures as [base SMILES]. 63 | 64 | Return: 65 | ---- 66 | acid SMILES collections, base SMILES collections, the dataset as `pandas.Dataframe` 67 | ''' 68 | dataset = _read_dataset(dataset_file) 69 | smis_A, smis_B = [], [] 70 | for smi in dataset[column]: 71 | if ">>" in smi: 72 | ab_smi = smi.split(">>") 73 | smis_A.append(ab_smi[0].split(",")) 74 | smis_B.append(ab_smi[1].split(",")) 75 | else: 76 | if mode == "A": 77 | smis_A.append([smi]), smis_B.append([]) 78 | elif mode == "B": 79 | smis_A.append([]), smis_B.append([smi]) 80 | else: 81 | raise ValueError 82 | return smis_A, smis_B, dataset 83 | 84 | 85 | def read_template(template_file: str) -> Tuple[pd.DataFrame, pd.DataFrame]: 86 | ''' 87 | Read a protonation template. 
88 | 89 | Params: 90 | ---- 91 | `template_file`: path of `.csv`-like template, with columns of substructure names, SMARTS patterns, protonation indices and acid/base flags 92 | 93 | Return: 94 | ---- 95 | `template_a2b`, `template_b2a`: acid to base and base to acid templates 96 | ''' 97 | template = pd.read_csv(template_file, sep="\t") 98 | template_a2b = template[template.Acid_or_base == "A"] 99 | template_b2a = template[template.Acid_or_base == "B"] 100 | return template_a2b, template_b2a 101 | 102 | 103 | def match_template(template: pd.DataFrame, mol: Mol, verbose: bool=False) -> list: 104 | ''' 105 | Find protonation site using templates 106 | 107 | Params: 108 | ---- 109 | `template`: `pandas.Dataframe` with columns of substructure names, SMARTS patterns, protonation indices and acid/base flags 110 | 111 | `mol`: Molecule 112 | 113 | `verbose`: Boolean flag for printing matching results 114 | 115 | Return: 116 | ---- 117 | A set of matched indices to be (de)protonated 118 | ''' 119 | mol = AddHs(mol) 120 | matches = [] 121 | for idx, name, smarts, index, acid_base in template.itertuples(): 122 | pattern = MolFromSmarts(smarts) 123 | match = mol.GetSubstructMatches(pattern) 124 | if len(match) == 0: 125 | continue 126 | else: 127 | index = int(index) 128 | for m in match: 129 | matches.append(m[index]) 130 | if verbose: 131 | print(f"find index {m[index]} in pattern {name} smarts {smarts}") 132 | return list(set(matches)) 133 | 134 | 135 | def prot(mol: Mol, idx: int, mode: str) -> Mol: 136 | ''' 137 | Protonate / Deprotonate a molecule at a specified site 138 | 139 | Params: 140 | ---- 141 | `mol`: Molecule 142 | 143 | `idx`: Index of reaction 144 | 145 | `mode`: `a2b` means deprotonization, with a hydrogen atom or a heavy atom at `idx`; `b2a` means protonization, with a heavy atom at `idx` 146 | 147 | Return: 148 | ---- 149 | `mol_prot`: (De)protonated molecule 150 | ''' 151 | mw = RWMol(mol) 152 | if mode == "a2b": 153 | atom_H = mw.GetAtomWithIdx(idx) 154 | if atom_H.GetAtomicNum() == 1: 155 | atom_A = atom_H.GetNeighbors()[0] 156 | charge_A = atom_A.GetFormalCharge() 157 | atom_A.SetFormalCharge(charge_A - 1) 158 | mw.RemoveAtom(idx) 159 | mol_prot = mw.GetMol() 160 | else: 161 | charge_H = atom_H.GetFormalCharge() 162 | numH_H = atom_H.GetTotalNumHs() 163 | atom_H.SetFormalCharge(charge_H - 1) 164 | atom_H.SetNumExplicitHs(numH_H - 1) 165 | atom_H.UpdatePropertyCache() 166 | mol_prot = AddHs(mw) 167 | elif mode == "b2a": 168 | atom_B = mw.GetAtomWithIdx(idx) 169 | charge_B = atom_B.GetFormalCharge() 170 | atom_B.SetFormalCharge(charge_B + 1) 171 | numH_B = atom_B.GetNumExplicitHs() 172 | atom_B.SetNumExplicitHs(numH_B + 1) 173 | mol_prot = AddHs(mw) 174 | SanitizeMol(mol_prot) 175 | mol_prot = MolFromSmiles(MolToSmiles(mol_prot)) 176 | mol_prot = AddHs(mol_prot) 177 | return mol_prot 178 | 179 | 180 | def prot_template(template: pd.DataFrame, smi: str, mode: str) -> Tuple[List[int], List[str]]: 181 | """ 182 | Protonate / Deprotonate a SMILES at every found site in the template 183 | 184 | Params: 185 | ---- 186 | `template`: `pandas.Dataframe` with columns of substructure names, SMARTS patterns, protonation indices and acid/base flags 187 | 188 | `smi`: The SMILES to be processed 189 | 190 | `mode`: `a2b` means deprotonization, with a hydrogen atom or a heavy atom at `idx`; `b2a` means protonization, with a heavy atom at `idx` 191 | """ 192 | mol = MolFromSmiles(smi) 193 | sites = match_template(template, mol) 194 | smis = [] 195 | for site in sites: 196 | 
smis.append(CanonSmiles(MolToSmiles(RemoveHs(prot(mol, site, mode))))) 197 | return sites, list(set(smis)) 198 | 199 | 200 | def sanitize_checker(smi: str, filter_patterns: List[Mol], verbose: bool=False) -> bool: 201 | """ 202 | Check if a SMILES can be sanitized and does not contain unreasonable chemical structures. 203 | 204 | Params: 205 | ---- 206 | `smi`: The SMILES to be check. 207 | 208 | `filter_patterns`: Unreasonable chemical structures. 209 | 210 | `verbose`: If True, matched unreasonable chemical structures will be printed. 211 | 212 | Return: 213 | ---- 214 | If the SMILES should be filtered. 215 | """ 216 | mol = AddHs(MolFromSmiles(smi)) 217 | for pattern in filter_patterns: 218 | match = mol.GetSubstructMatches(pattern) 219 | if match: 220 | if verbose: 221 | print(f"pattern {pattern}") 222 | return False 223 | try: 224 | SanitizeMol(mol) 225 | except: 226 | print("cannot sanitize") 227 | return False 228 | return True 229 | 230 | 231 | def sanitize_filter(smis: List[str], filter_patterns: List[Mol]=FILTER_PATTERNS) -> List[str]: 232 | """ 233 | A filter for SMILES can be sanitized and does not contain unreasonable chemical structures. 234 | 235 | Params: 236 | ---- 237 | `smis`: The list of SMILES. 238 | 239 | `filter_patterns`: Unreasonable chemical structures. 240 | 241 | Return: 242 | ---- 243 | The list of SMILES filtered. 244 | """ 245 | def _checker(smi): 246 | return sanitize_checker(smi, filter_patterns) 247 | return list(filter(_checker, smis)) 248 | 249 | 250 | def cnt_stereo_atom(smi: str) -> int: 251 | """ 252 | Count the stereo atoms in a SMILES 253 | """ 254 | mol = MolFromSmiles(smi) 255 | return sum([str(atom.GetChiralTag()) != "CHI_UNSPECIFIED" for atom in mol.GetAtoms()]) 256 | 257 | 258 | def stereo_filter(smis: List[str]) -> List[str]: 259 | """ 260 | A filter against SMILES losing stereochemical information in structure processing. 261 | """ 262 | filtered_smi_dict = dict() 263 | for smi in smis: 264 | nonstereo_smi = CanonSmiles(smi, useChiral=0) 265 | stereo_cnt = cnt_stereo_atom(smi) 266 | if nonstereo_smi not in filtered_smi_dict: 267 | filtered_smi_dict[nonstereo_smi] = (smi, stereo_cnt) 268 | else: 269 | if stereo_cnt > filtered_smi_dict[nonstereo_smi][1]: 270 | filtered_smi_dict[nonstereo_smi] = (smi, stereo_cnt) 271 | return [value[0] for value in filtered_smi_dict.values()] 272 | 273 | 274 | def make_filter(filter_param: OrderedDict) -> Callable: 275 | """ 276 | Make a sequential SMILES filter 277 | 278 | Params: 279 | ---- 280 | `filter_param`: An `collections.OrderedDict` whose keys are single filter functions and the corresponding values are their parameter dictionary. 281 | 282 | Return: 283 | ---- 284 | The sequential filter function 285 | """ 286 | def seq_filter(smis): 287 | for single_filter, param in filter_param.items(): 288 | smis = single_filter(smis, **param) 289 | return smis 290 | return seq_filter 291 | 292 | 293 | def enumerate_template(smi: str, template_a2b: pd.DataFrame, template_b2a: pd.DataFrame, mode: str="A", maxiter: int=2, verbose: int=0, filter_patterns: List[Mol]=FILTER_PATTERNS) -> Tuple[List[str], List[str]]: 294 | """ 295 | Enumerate all the (de)protonation results of one SMILES. 296 | 297 | Params: 298 | ---- 299 | `smi`: The smiles to be processed. 300 | 301 | `template_a2b`: `pandas.Dataframe` with columns of substructure names, SMARTS patterns, deprotonation indices and acid flags. 
302 | 303 | `template_b2a`: `pandas.Dataframe` with columns of substructure names, SMARTS patterns, protonation indices and base flags. 304 | 305 | `mode`: 306 | - "A": `smi` is an acid to be deprotonated. 307 | - "B": `smi` is a base to be protonated. 308 | 309 | `maxiter`: Max iteration number of template matching and microstate pool growth. 310 | 311 | `verbose`: 312 | - 0: Silent mode. 313 | - 1: Print the length of microstate pools in each iteration. 314 | - 2: Print the content of microstate pools in each iteration. 315 | 316 | `filter_patterns`: Unreasonable chemical structures. 317 | 318 | Return: 319 | ---- 320 | A microstate pool and B microstate pool after enumeration. 321 | """ 322 | if isinstance(smi, str): 323 | smis = [smi] 324 | else: 325 | smis = list(smi) 326 | 327 | enum_func = lambda x: [x] # TODO: Tautomerism enumeration 328 | 329 | if mode == "A": 330 | smis_A_pool, smis_B_pool = smis, [] 331 | elif mode == "B": 332 | smis_A_pool, smis_B_pool = [], smis 333 | pool_length_A = -1 334 | pool_length_B = -1 335 | filters = make_filter({ 336 | sanitize_filter: {"filter_patterns": filter_patterns}, 337 | stereo_filter: {} 338 | }) 339 | i = 0 340 | while (len(smis_A_pool) != pool_length_A or len(smis_B_pool) != pool_length_B) and i < maxiter: 341 | pool_length_A, pool_length_B = len(smis_A_pool), len(smis_B_pool) 342 | if verbose > 0: 343 | print(f"iter {i}: {pool_length_A} acid, {pool_length_B} base") 344 | if verbose > 1: 345 | print(f"iter {i}, acid: {smis_A_pool}, base: {smis_B_pool}") 346 | if (mode == "A" and (i + 1) % 2) or (mode == "B" and i % 2): 347 | smis_A_tmp_pool = [] 348 | for smi in smis_A_pool: 349 | smis_B_pool += filters(prot_template(template_a2b, smi, "a2b")[1]) 350 | smis_A_tmp_pool += filters([CanonSmiles(MolToSmiles(mol)) for mol in enum_func(MolFromSmiles(smi))]) 351 | smis_A_pool += smis_A_tmp_pool 352 | elif (mode == "B" and (i + 1) % 2) or (mode == "A" and i % 2): 353 | smis_B_tmp_pool = [] 354 | for smi in smis_B_pool: 355 | smis_A_pool += filters(prot_template(template_b2a, smi, "b2a")[1]) 356 | smis_B_tmp_pool += filters([CanonSmiles(MolToSmiles(mol)) for mol in enum_func(MolFromSmiles(smi))]) 357 | smis_B_pool += smis_B_tmp_pool 358 | smis_A_pool = filters(smis_A_pool) 359 | smis_B_pool = filters(smis_B_pool) 360 | smis_A_pool = list(set(smis_A_pool)) 361 | smis_B_pool = list(set(smis_B_pool)) 362 | i += 1 363 | if verbose > 0: 364 | print(f"iter {i}: {pool_length_A} acid, {pool_length_B} base") 365 | if verbose > 1: 366 | print(f"iter {i}, acid: {smis_A_pool}, base: {smis_B_pool}") 367 | smis_A_pool = list(map(CanonSmiles, smis_A_pool)) 368 | smis_B_pool = list(map(CanonSmiles, smis_B_pool)) 369 | return smis_A_pool, smis_B_pool 370 | 371 | 372 | def check_dataset(dataset_file: str) -> None: 373 | """ 374 | Check if every entry in the dataset is valid under Uni-pKa standard format. 
375 | """ 376 | print(f"Checking reconstructed dataset {dataset_file}") 377 | dataset = _read_dataset(dataset_file) 378 | for i in trange(len(dataset)): 379 | try: 380 | a_smi, b_smi = dataset.iloc[i]["SMILES"].split(">>") 381 | except: 382 | print(f"missing '>>' in index {i}") 383 | continue 384 | if not a_smi: 385 | print(f"missing acid smiles in index {i}") 386 | continue 387 | if not b_smi: 388 | print(f"missing base smiles in index {i}") 389 | continue 390 | for smi in a_smi.split(",") + b_smi.split(","): 391 | if not smi: 392 | print(f"empty smiles in index {i}") 393 | else: 394 | try: 395 | mol = AddHs(MolFromSmiles(smi)) 396 | assert mol is not None 397 | except: 398 | print(f"invalid smiles {smi} in index {i}") 399 | 400 | 401 | def enum_dataset(input_file: str, output_file: str, template: str, mode: str, column:str, maxiter: int) -> pd.DataFrame: 402 | """ 403 | Enumerate the full macrostate and reconstruct the pairwise acid/base dataset from a molecule-wise or pairwise acid/base dataset. 404 | 405 | Params: 406 | ---- 407 | `input_file`: The path of input dataset. 408 | 409 | `output_file`: The path of output dataset. 410 | 411 | `mode`: 412 | - "A": Enumeration is started from the acid. 413 | - "B": Enumeration is started from the base. 414 | 415 | `column`: The name of the column storing SMILES. 416 | 417 | `maxiter`: Max iteration number of template matching and microstate pool growth. 418 | 419 | Return: 420 | ---- 421 | The reconstructed dataset. 422 | """ 423 | print(f"Reconstructing {input_file} with the template {template} from {mode} microstates") 424 | smis_A, smis_B, dataset = read_dataset(input_file, column=column, mode=mode) 425 | 426 | template_a2b, template_b2a = read_template(template) 427 | 428 | if mode == "A": 429 | smis_I = smis_A 430 | elif mode == "B": 431 | smis_I = smis_B 432 | 433 | SMILES_col = [] 434 | for i, smis in tqdm(enumerate(smis_I), total=len(smis_I)): 435 | try: 436 | smis_a, smis_b = enumerate_template(smis, template_a2b, template_b2a, maxiter=maxiter, mode=mode) 437 | except: 438 | print(f"failed to enumerate {smis}: enum error") 439 | raise ValueError 440 | if not smis_a: 441 | if not smis_A[i]: 442 | print(f"failed to enumerate {smis}: no A states") 443 | raise ValueError 444 | else: 445 | smis_a = smis_A[i] 446 | if not smis_b: 447 | if not smis_B[i]: 448 | print(f"failed to enumerate {smis}: no B states") 449 | raise ValueError 450 | else: 451 | smis_b = smis_B[i] 452 | SMILES_col.append(",".join(smis_a) + ">>" + ",".join(smis_b)) 453 | 454 | dataset["SMILES"] = SMILES_col 455 | dataset.to_csv(output_file, sep="\t") 456 | check_dataset(output_file) 457 | return dataset 458 | 459 | 460 | def enum_ensemble(input_file: str, output_file: str, template: str, upper: str, lower: str, column: str, maxiter: int) -> pd.DataFrame: 461 | """ 462 | Enumerate the full macrostate and reconstruct the pairwise acid/base dataset from a molecule-wise or pairwise acid/base dataset. 463 | 464 | Params: 465 | ---- 466 | `input_file`: The path of input molecules. 467 | 468 | `output_file`: The path of output ensembles. 469 | 470 | `upper`: The maximum total charge of protonation ensembles. 471 | 472 | `lower`: The minimum total charge of protonation ensembles. 473 | 474 | `column`: The name of the column storing SMILES. 475 | 476 | `maxiter`: Max iteration number of template matching and microstate pool growth. 477 | 478 | Return: 479 | ---- 480 | The enumerated protonation ensembles. 
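    Example (illustrative; the file names are placeholders):
    ----
    `enum_ensemble("mols.tsv", "ensembles.tsv", "simple_smarts_pattern.tsv", upper=2, lower=-2, column="SMILES", maxiter=10)`
    reads the SMILES column of `mols.tsv` and writes a tab-separated table holding
    the original SMILES plus one column per total charge from `lower` to `upper`,
    each cell listing the comma-separated microstates found at that charge (empty
    if none).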
481 | """ 482 | print(f"Enumerating the protonation ensemble for {input_file} with the template {template} with maximum charge {upper} and minimum charge {lower}") 483 | smis, _, _ = read_dataset(input_file, column, "A") 484 | smis = [smi[0] for smi in smis] 485 | template_a2b, template_b2a = read_template(template) 486 | ensembles = {i: [] for i in range(lower, upper + 1)} 487 | 488 | for smi in tqdm(smis): 489 | 490 | ensemble = dict() 491 | q0 = GetFormalCharge(MolFromSmiles(smi)) 492 | ensemble[q0] = [smi] 493 | 494 | smis_0 = [smi] 495 | 496 | if q0 > lower: 497 | smis_0, smis_b1 = enumerate_template(smis_0, template_a2b, template_b2a, maxiter=maxiter, mode="A") 498 | if smis_b1: 499 | ensemble[q0 - 1] = smis_b1 500 | for q in range(q0 - 2, lower, -1): 501 | if q + 1 in ensemble: 502 | _, smis_b = enumerate_template(ensemble[q + 1], template_a2b, template_b2a, maxiter=maxiter, mode="A") 503 | if smis_b: 504 | ensemble[q] = smis_b 505 | 506 | if q0 < upper: 507 | smis_a1, smis_0 = enumerate_template(smis_0, template_a2b, template_b2a, maxiter=maxiter, mode="B") 508 | if smis_a1: 509 | ensemble[q0 + 1] = smis_a1 510 | for q in range(q0 + 2, upper): 511 | if q - 1 in ensemble: 512 | smis_a, _ = enumerate_template(ensemble[q - 1], template_a2b, template_b2a, maxiter=maxiter, mode="B") 513 | if smis_a: 514 | ensemble[q] = smis_a 515 | 516 | ensemble[q0] = smis_0 517 | 518 | for q in ensembles: 519 | if q in ensemble: 520 | ensembles[q].append(",".join(ensemble[q])) 521 | else: 522 | ensembles[q].append(None) 523 | 524 | ensembles[column] = smis 525 | ensembles = pd.DataFrame(ensembles) 526 | ensembles.to_csv(output_file, sep="\t") 527 | return ensembles 528 | 529 | 530 | if __name__ == "__main__": 531 | parser = ArgumentParser() 532 | subparser = parser.add_subparsers(dest="command") 533 | 534 | parser_check = subparser.add_parser("check") 535 | parser_check.add_argument("-i", "--input", type=str) 536 | 537 | parser_enum = subparser.add_parser("reconstruct") 538 | parser_enum.add_argument("-i", "--input", type=str) 539 | parser_enum.add_argument("-o", "--output", type=str) 540 | parser_enum.add_argument("-t", "--template", type=str, default="smarts_pattern.tsv") 541 | parser_enum.add_argument("-n", "--maxiter", type=int, default=10) 542 | parser_enum.add_argument("-c", "--column", type=str, default="SMILES") 543 | parser_enum.add_argument("-m", "--mode", type=str, default="A") 544 | 545 | parser_enum = subparser.add_parser("ensemble") 546 | parser_enum.add_argument("-i", "--input", type=str) 547 | parser_enum.add_argument("-o", "--output", type=str) 548 | parser_enum.add_argument("-t", "--template", type=str, default="simple_smarts_pattern.tsv") 549 | parser_enum.add_argument("-u", "--upper", type=int, default=2) 550 | parser_enum.add_argument("-l", "--lower", type=int, default=-2) 551 | parser_enum.add_argument("-n", "--maxiter", type=int, default=10) 552 | parser_enum.add_argument("-c", "--column", type=str, default="SMILES") 553 | 554 | args = parser.parse_args() 555 | if args.command == "check": 556 | check_dataset(args.input) 557 | elif args.command == "reconstruct": 558 | enum_dataset( 559 | input_file=args.input, 560 | output_file=args.output, 561 | template=args.template, 562 | mode=args.mode, 563 | column=args.column, 564 | maxiter=args.maxiter 565 | ) 566 | elif args.command == "ensemble": 567 | enum_ensemble( 568 | input_file=args.input, 569 | output_file=args.output, 570 | template=args.template, 571 | upper=args.upper, 572 | lower=args.lower, 573 | column=args.column, 574 | 
maxiter=args.maxiter 575 | ) 576 | -------------------------------------------------------------------------------- /enumerator/simple_smarts_pattern.tsv: -------------------------------------------------------------------------------- 1 | Substructure SMARTS Index Acid_or_base 2 | Sulfate monoether [SX4:0](=[O:1])(=[O:2])(-[O:3])-[OX2:4]-[H:5] 4 A 3 | Sulfate monoether [SX4:0](=[O:1])(=[O:2])(-[O:3])-[O-1:4] 4 B 4 | Sulfonic acid [SX4:0](=[O:1])(=[O:2])(-[#6,#7:3])-[OX2:4]-[H:5] 4 A 5 | Sulfonic acid [SX4:0](=[O:1])(=[O:2])(-[#6,#7:3])-[O-1:4] 4 B 6 | Sulfinic acid [SX3:0](=[O:1])(-[#6,#7:2])-[OX2:3]-[H:4] 3 A 7 | Sulfinic acid [SX3:0](=[O:1])(-[0#6,#7:2])-[O-1:3] 3 B 8 | Seleninic acid [SeX3:0](=[O:1])(-[#6,#7:2])-[OX2:3]-[H:4] 3 A 9 | Seleninic acid [SeX3:0](=[O:1])(-[#6,#7:2])-[O-1:3] 3 B 10 | Selenenic acid [SeX2:0]-[OX2:1]-[H:2] 1 A 11 | Selenenic acid [SeX2:0]-[O-1:1] 1 B 12 | Arsonic acid [AsX4:0](=[O:1])(-[#6,#7:2])-[OX2:3]-[H:4] 3 A 13 | Arsonic acid [AsX4:0](=[O:1])(-[#6,#7:2])-[O-1:3] 3 B 14 | Thiosulfuric acid [S:0]~[SX4:1](~[O:2])(~[O:3])-[O:4]-[H:5] 4 A 15 | Thiosulfuric acid [S:0]~[SX4:1](~[O:2])(~[O:3])-[O-1:4] 4 B 16 | Phosph(o/i)nic acid [PX4:0](=[O:1])(-[OX2:2]-[H:5])(-[#1,#6,#7,#8:3])(-[#1,#6,#7,#8:4]) 2 A 17 | Phosph(o/i)nic acid [PX4:0](=[O:1])(-[O-1:2])(-[#1,#6,#7,#8:3])(-[#1,#6,#7,#8:4]) 2 B 18 | Phosphate (mono/di)ether [PX4:0](=[O:1])(-[O:2])(-[O:3])-[OX2:4]-[H:5] 4 A 19 | Phosphate (mono/di)ether [PX4:0](=[O:1])(-[O:2])(-[O:3])-[O-1:4] 4 B 20 | Carboxyl acid [$([#6]=[#8,#7]),$(C#N):0]-[OX2:1]-[H:2] 1 A 21 | Carboxyl acid [$([#6]=[#8,#7]),$(C#N):0]-[O-1:1] 1 B 22 | Carboxyl acid enol [C:0]=[C:1](-[OX2:2]-[H:3])-[OX2:4]-[H:5] 4 A 23 | Carboxyl acid enol [C:0]=[C:1](-[OX2:2]-[H:3])-[O-1:4] 4 B 24 | Carbo(di)thioic acid [CX3:0](=[O,S:1])-[SX2,OX2:2]-[H:3] 2 A 25 | Carbo(di)thioic acid [CX3:0](=[O,S:1])-[S-1,O-1:2] 2 B 26 | Carboxyl acid vinylogue [O:0]=[C:1]-[C:2]=[C:3]-[OX2:4]-[H:5] 4 A 27 | Carboxyl acid vinylogue [O:0]=[C:1]-[C:2]=[C:3]-[O-1:4] 4 B 28 | Thiol/Thiophenol [#6,#7:0]-[SX2:1]-[H:2] 1 A 29 | Thiol/Thiophenol [#6,#7:0]-[S-1:1] 1 B 30 | Phenol [c,n:0]-[OX2:1]-[H:2] 1 A 31 | Phenol [c,n:0]-[O-1:1] 1 B 32 | Hydroperoxide/Hydroxyl amine [O,N:0]-[OX2:1]-[H:2] 1 A 33 | Hydroperoxide/Hydroxyl amine [O,N:0]-[O-1:1] 1 B 34 | Azole [#7:0]1(-[H:5])-,:[#7,#6:1]=,:[#7,#6:2]-,:[#7,#6:3]=,:[#7,#6:4]-,:1 0 A 35 | Azole [#7-1:0]1-,:[#7,#6:1]=,:[#7,#6:2]-,:[#7,#6:3]=,:[#7,#6:4]-,:1 0 B 36 | Aza-aromatics [n:0]-[H:1] 0 A 37 | Aza-aromatics [n-1,n+0X2:0] 0 B 38 | Oxime [$([#7]:,=[#6,#7]),$([#7]:,=[#6,#7]:,-[#6,#7]:,=[#6,#7]):0]-[OX2,NX3:1]-[H:2] 1 A 39 | Oxime [$([#7]:,=[#6,#7]),$([#7]:,=[#6,#7]:,-[#6,#7]:,=[#6,#7]):0]-[O-1,NX2-1:1] 1 B 40 | Amine [NX4+1:0](-[H:4])(-[CX4,c,#7,#8,#1,S,$(C=C),Cl:1])(-[CX4,c,#7,#8,#1,S,$(C=C):2])(-[CX4,c,#7,#8,#1,S,$(C=C):3]) 0 A 41 | Amine [NX3:0](-[CX4,c,#7,#8,#1,S,$(C=C),Cl:1])(-[CX4,c,#7,#8,#1,S,$(C=C):2])(-[CX4,c,#7,#8,#1,S,$(C=C):3]) 0 B 42 | Imine [#6,#7,P,S:0]=[NX3+1:1](-[H:2]) 1 A 43 | Imine [#6,#7,P,S:0]=[NX2:1] 1 B 44 | Amide [$([#7]=[#7,#8]),$(c:c:c:c:[#7+1]):0]-[NX3:1]-[H:2] 1 A 45 | Amide [$([#7]=[#7,#8]),$(c:c:c:c:[#7+1]):0]-[NX2-1:1] 1 B 46 | Amide imine [$([#6]-,:[O,S,#7]),N+1:0]=,:[NX2:1]-[H:2] 1 A 47 | Amide imine [$([#6]-,:[O,S,#7]),N+1:0]=,:[NX1-1:1] 1 B 48 | Sulfamide [SX4:0](=[O:1])(=[O:2])-[NX3:3]-[H:4] 3 A 49 | Sulfamide [SX4:0](=[O:1])(=[O:2])-[NX2-1:3] 3 B 50 | Phosphamide [PX4:0](=[O:1])-[NX3:2]-[H:3] 2 A 51 | Phosphamide [PX4:0](=[O:1])-[NX2-1:2] 2 B 52 | Enol 
[$([#6]=,:[#7,#8]),$(C#N),#7+1,$([S]=[O]),OH1:0]-[#6:1]:,=[#6:2]-[OX2:3]-[H:4] 3 A 53 | Enol [$([#6]=,:[#7,#8]),$(C#N),#7+1,$([S]=[O]),OH1:0]-[#6:1]:,=[#6:2]-[O-1:3] 3 B 54 | Hydrocyanic acid [N:0]#[C:1]-[H:2] 1 A 55 | Hydrocyanic acid [N:0]#[C-1:1] 1 B 56 | Selenol [SeX2:0]-[H:1] 0 A 57 | Selenol [SeX1-1:0] 0 B -------------------------------------------------------------------------------- /enumerator/smarts_pattern.tsv: -------------------------------------------------------------------------------- 1 | Substructure SMARTS Index Acid_or_base 2 | Sulfate monoether [SX4:0](=[O:1])(=[O:2])(-[O:3])-[OX2:4]-[H:5] 4 A 3 | Sulfate monoether [SX4:0](=[O:1])(=[O:2])(-[O:3])-[O-1:4] 4 B 4 | Sulfonic acid [SX4:0](=[O:1])(=[O:2])(-[#6,#7:3])-[OX2:4]-[H:5] 4 A 5 | Sulfonic acid [SX4:0](=[O:1])(=[O:2])(-[#6,#7:3])-[O-1:4] 4 B 6 | Sulfinic acid [SX3:0](=[O:1])(-[#6,#7:2])-[OX2:3]-[H:4] 3 A 7 | Sulfinic acid [SX3:0](=[O:1])(-[0#6,#7:2])-[O-1:3] 3 B 8 | Seleninic acid [SeX3:0](=[O:1])(-[#6,#7:2])-[OX2:3]-[H:4] 3 A 9 | Seleninic acid [SeX3:0](=[O:1])(-[#6,#7:2])-[O-1:3] 3 B 10 | Selenenic acid [SeX2:0]-[OX2:1]-[H:2] 1 A 11 | Selenenic acid [SeX2:0]-[O-1:1] 1 B 12 | Arsonic acid [AsX4:0](=[O:1])(-[#6,#7:2])-[OX2:3]-[H:4] 3 A 13 | Arsonic acid [AsX4:0](=[O:1])(-[#6,#7:2])-[O-1:3] 3 B 14 | Thiosulfuric acid [S:0]~[SX4:1](~[O:2])(~[O:3])-[O:4]-[H:5] 4 A 15 | Thiosulfuric acid [S:0]~[SX4:1](~[O:2])(~[O:3])-[O-1:4] 4 B 16 | Phosph(o/i)nic acid [PX4:0](=[O:1])(-[OX2:2]-[H:5])(-[#1,#6,#7,#8:3])(-[#1,#6,#7,#8:4]) 2 A 17 | Phosph(o/i)nic acid [PX4:0](=[O:1])(-[O-1:2])(-[#1,#6,#7,#8:3])(-[#1,#6,#7,#8:4]) 2 B 18 | Phosphate (mono/di)ether [PX4:0](=[O:1])(-[O:2])(-[O:3])-[OX2:4]-[H:5] 4 A 19 | Phosphate (mono/di)ether [PX4:0](=[O:1])(-[O:2])(-[O:3])-[O-1:4] 4 B 20 | Carboxyl acid [$([#6]=[#8,#7]),$(C#N):0]-[OX2:1]-[H:2] 1 A 21 | Carboxyl acid [$([#6]=[#8,#7]),$(C#N):0]-[O-1:1] 1 B 22 | Carboxyl acid enol [C:0]=[C:1](-[OX2:2]-[H:3])-[OX2:4]-[H:5] 4 A 23 | Carboxyl acid enol [C:0]=[C:1](-[OX2:2]-[H:3])-[O-1:4] 4 B 24 | Carbo(di)thioic acid [CX3:0](=[O,S:1])-[SX2,OX2:2]-[H:3] 2 A 25 | Carbo(di)thioic acid [CX3:0](=[O,S:1])-[S-1,O-1:2] 2 B 26 | Carboxyl acid vinylogue [O:0]=[C:1]-[C:2]=[C:3]-[OX2:4]-[H:5] 4 A 27 | Carboxyl acid vinylogue [O:0]=[C:1]-[C:2]=[C:3]-[O-1:4] 4 B 28 | Thiol/Thiophenol [#6,#7:0]-[SX2:1]-[H:2] 1 A 29 | Thiol/Thiophenol [#6,#7:0]-[S-1:1] 1 B 30 | Phenol [c,n:0]-[OX2:1]-[H:2] 1 A 31 | Phenol [c,n:0]-[O-1:1] 1 B 32 | Alcohol [$([CX4]-[$([#6]=,:[#7,#8]),$([#6]=,:[#6]-,:[#6]=,:[#7,#8]),$(C#N),O,#7+1,$(C-Cl),$(C#C),$(C-O),S:0]),$([CH2]-[#1,CH3]):0]-[OX2:1]-[H:2] 1 A 33 | Alcohol [$([CX4]-[$([#6]=,:[#7,#8]),$([#6]=,:[#6]-,:[#6]=,:[#7,#8]),$(C#N),O,#7+1,$(C-Cl),$(C#C),$(C-O),S:0]),$([CH2]-[#1,CH3]):0]-[O-1:1] 1 B 34 | Hydroxypyridine [n:0]:[c:1]-[OH2+1:2]-[H:3] 2 A 35 | Hydroxypyridine [n:0]:[c:1]-[OX2H1:2] 2 B 36 | Methylpyridine [n:0](-[C:1]=[O:2]):[c:3]:[c:4]:[c:5]-[CX4:6]-[H:7] 6 A 37 | Methylpyridine [n:0](-[C:1]=[O:2]):[c:3]:[c:4]:[c:5]-[CX3-1:6] 6 B 38 | Hydroperoxide/Hydroxyl amine [O,N:0]-[OX2:1]-[H:2] 1 A 39 | Hydroperoxide/Hydroxyl amine [O,N:0]-[O-1:1] 1 B 40 | Azole [#7:0]1(-[H:5])-,:[#7,#6:1]=,:[#7,#6:2]-,:[#7,#6:3]=,:[#7,#6:4]-,:1 0 A 41 | Azole [#7-1:0]1-,:[#7,#6:1]=,:[#7,#6:2]-,:[#7,#6:3]=,:[#7,#6:4]-,:1 0 B 42 | Aza-aromatics [n:0]-[H:1] 0 A 43 | Aza-aromatics [n-1,n+0X2:0] 0 B 44 | N-substitute aza-aromatics [n+1:0]-[CX4:1]-[H:2] 1 A 45 | N-substitute aza-aromatics [n+1:0]-[CX3-1:1] 1 B 46 | Oxime [$([#7]:,=[#6,#7]),$([#7]:,=[#6,#7]:,-[#6,#7]:,=[#6,#7]):0]-[OX2,NX3:1]-[H:2] 1 A 47 | 
Oxime [$([#7]:,=[#6,#7]),$([#7]:,=[#6,#7]:,-[#6,#7]:,=[#6,#7]):0]-[O-1,NX2-1:1] 1 B 48 | Amine [NX4+1:0](-[H:4])(-[CX4,c,#7,#8,#1,S,$(C=C),Cl:1])(-[CX4,c,#7,#8,#1,S,$(C=C):2])(-[CX4,c,#7,#8,#1,S,$(C=C):3]) 0 A 49 | Amine [NX3:0](-[CX4,c,#7,#8,#1,S,$(C=C),Cl:1])(-[CX4,c,#7,#8,#1,S,$(C=C):2])(-[CX4,c,#7,#8,#1,S,$(C=C):3]) 0 B 50 | Imine [#6,#7,P,S:0]=[NX3+1:1](-[H:2]) 1 A 51 | Imine [#6,#7,P,S:0]=[NX2:1] 1 B 52 | Amide [$([#6]=,:[O,S,#7:0]),$([#7]=[#7,#8]),$([#6]:,=[#6]:,-[#7]=[#7,#8]),$(c:c:c:c:[#7+1]):0]-[NX3:1]-[H:2] 1 A 53 | Amide [$([#6]=,:[O,S,#7:0]),$([#7]=[#7,#8]),$([#6]:,=[#6]:,-[#7]=[#7,#8]),$(c:c:c:c:[#7+1]):0]-[NX2-1:1] 1 B 54 | Amide imine [$([#6]-,:[O,S,#7]),N+1:0]=,:[NX2:1]-[H:2] 1 A 55 | Amide imine [$([#6]-,:[O,S,#7]),N+1:0]=,:[NX1-1:1] 1 B 56 | Sulfamide [SX4:0](=[O:1])(=[O:2])-[NX3:3]-[H:4] 3 A 57 | Sulfamide [SX4:0](=[O:1])(=[O:2])-[NX2-1:3] 3 B 58 | Phosphamide [PX4:0](=[O:1])-[NX3:2]-[H:3] 2 A 59 | Phosphamide [PX4:0](=[O:1])-[NX2-1:2] 2 B 60 | Amide vinylogue [NX3:0](-[H:5])-,:[#6:1]=,:[#6:2]-,:[$([#6]=,:[#7,#8]),$(C#N):3] 0 A 61 | Amide vinylogue [NX2-1:0]-,:[#6:1]=,:[#6:2]-,:[$([#6]=,:[#7,#8]),$(C#N):3] 0 B 62 | Di Carbonyl βH [$([#6,#7]=,:[#7,#8]),$(C#N),$([#6]=,:[#6]-,:[$([#6]=,:[#7,#8]),$(C#N)]),$(P=O),$(S=O),S+1,Cl,F,O,c:0]-,:[#6X4:1](-[H:3])-,:[$([#6]=,:[#7,#8]),$(C#N),$(P=O),$(S=O):2] 1 A 63 | Di Carbonyl βH [$([#6,#7]=,:[#7,#8]),$(C#N),$([#6]=,:[#6]-,:[$([#6]=,:[#7,#8]),$(C#N)]),$(P=O),$(S=O),S+1,Cl,F,O,c:0]-,:[#6X3-1:1]-,:[$([#6]=,:[#7,#8]),$(C#N),$(P=O),$(S=O):2] 1 B 64 | Carbonyl βH [$([#6](=O)(-,:[#7+1,#6,#1])(-,:[#6,#1])),$([N](=O)(-O)),$(P=O),$(C=C-C(=O)-[#6,#1]):0]-,:[#6X4:1]-[H:2] 1 A 65 | Carbonyl βH [$([#6](=O)(-,:[#7+1,#6,#1])(-,:[#6,#1])),$([N](=O)(-O)),$(P=O),$(C=C-C(=O)-[#6,#1]):0]-,:[#6X3-1:1] 1 B 66 | Carbonyl allene [O:0]=[C:1]-[C:2]=[C:3]=[CX3:4]-[H:5] 4 A 67 | Carbonyl allene [O:0]=[C:1]-[C:2]=[C:3]=[CX2-1:4] 4 B 68 | Enol [$([#6]=,:[#7,#8]),$(C#N),#7+1,$([S]=[O]),c,$(C=C),OH1:0]-[#6:1]:,=[#6:2]-[OX2:3]-[H:4] 3 A 69 | Enol [$([#6]=,:[#7,#8]),$(C#N),#7+1,$([S]=[O]),c,$(C=C),OH1:0]-[#6:1]:,=[#6:2]-[O-1:3] 3 B 70 | Enol [#6:0]=[#6:1](-[$(C=O),$(C(=C)-[OH1]):2])-[OX2:3]-[H:4] 3 A 71 | Enol [#6:0]=[#6:1](-[$(C=O),$(C(=C)-[OH1]):2])-[O-1:3] 3 B 72 | Acyl group [#6:0](-[O,N:1])=[OX2+1:2]-[H:3] 2 A 73 | Acyl group [#6:0](-[O,N:1])=[OX1:2] 2 B 74 | Sulfoxide [S+1:0](-[OX2:1]-[H:4])(-[#6:2])(-[#6:3]) 1 A 75 | Sulfoxide [S+1:0](-[O-1:1])(-[#6:2])(-[#6:3]) 1 B 76 | Sulfoxide [S:0](=[OX2+1:1]-[H:4])(-[#6:2])(-[#6:3]) 1 A 77 | Sulfoxide [S:0](=[OX1:1])(-[#6:2])(-[#6:3]) 1 B 78 | Sulfoxide [S:0](=[OX2+1:1]-[H:3])(=[#6:2]) 1 A 79 | Sulfoxide [S:0](=[OX1:1])(=[#6:2]) 1 B 80 | Hydrocyanic acid [N:0]#[C:1]-[H:2] 1 A 81 | Hydrocyanic acid [N:0]#[C-1:1] 1 B 82 | Phosphoryl group [PX4:0]=[OX2+1:1]-[H:3] 1 A 83 | Phosphoryl group [PX4:0]=[OX1:1] 1 B 84 | Selenonyl group [Se:0]=[OX2+1:1]-[H:3] 1 A 85 | Selenonyl group [Se:0]=[OX1:1] 1 B 86 | Arsenyl group [AsX4:0]=[OX2+1:1]-[H:2] 1 A 87 | Arsenyl group [AsX4:0]=[OX1:1] 1 B 88 | Carboxyl group [#6X3:0](:,-[O,#7,S:1])=[OX2+1,SX2+1:2]-[H:3] 2 A 89 | Carboxyl group [#6X3:0](:,-[O,#7,S:1])=[OX1,SX1:2] 2 B 90 | Carboxyl group vinylogue [#6X3:0](:,-[#6:1]:,=[#6:2]:,-[O,#7,S:3])=[OX2+1,SX2+1:4]-[H:5] 4 A 91 | Carboxyl group vinylogue [#6X3:0](:,-[#6:1]:,=[#6:2]:,-[O,#7,S:3])=[OX1,SX1:4] 4 B 92 | Carbonyl group [#6X3:0](:,-[#1,#6:1])(:,-[#1,#6:2])=[OX2+1:3]-[H:4] 3 A 93 | Carbonyl group [#6X3:0](:,-[#1,#6:1])(:,-[#1,#6:2])=[OX1:3] 3 B 94 | Cyano group [C:0]#[N:1]-[H:2] 1 A 95 | Cyano group [C:0]#[NX1:1] 1 B 96 | Hydroxyl 
group [CX4:0](-[#6,#1:1])(-[#6,#1:2])(-[#6,#1:3])-[OH2+1:4] 4 A 97 | Hydroxyl group [CX4:0](-[#6,#1:1])(-[#6,#1:2])(-[#6,#1:3])-[OX2:4]-[H:5] 4 B 98 | Selenol [SeX2:0]-[H:1] 0 A 99 | Selenol [SeX1-1:0] 0 B 100 | Borate [BX3:0]-[OX2:1]-[H:2] 1 A 101 | Borate [BX3:0]-[O-1:1] 1 B 102 | Bromomethane [Br:0]-[CH3:1]-[H:2] 1 A 103 | Bromomethane [Br:0]-[CH2-1:1] 1 B 104 | Cyclopentadiene [#6X4:0](-[#1:5])1-,:[#6:1]=,:[#6:2]-,:[#6:3]=,:[#6:4]-,:1 0 A 105 | Cyclopentadiene [#6X3-1:0]1-,:[#6:1]=,:[#6:2]-,:[#6:3]=,:[#6:4]-,:1 0 B 106 | Tin alkyl [N+:0]-[CX4:1](-[H:3])-[SnX4:2] 1 A 107 | Tin alkyl [N+:0]-[CX3-1:1]-[SnX4:2] 1 B -------------------------------------------------------------------------------- /finetune_pka.sh: -------------------------------------------------------------------------------- 1 | data_path='dwar' 2 | MASTER_PORT=10090 3 | task_name="dwar-iBond" 4 | head_name='chembl' 5 | weight_path='pretrain_save/checkpoint_best.pt' 6 | n_gpu=1 7 | save_dir='finetune_save' 8 | 9 | # train params 10 | seed=0 11 | nfolds=5 12 | cv_seed=42 13 | task_num=1 14 | loss_func="finetune_mse" 15 | dict_name='dict.txt' 16 | charge_dict_name='dict_charge.txt' 17 | only_polar=-1 18 | conf_size=11 19 | local_batch_size=16 20 | lr=3e-4 21 | bs=32 22 | epoch=20 23 | dropout=0.1 24 | warmup=0.06 25 | 26 | for ((fold=0;fold<$nfolds;fold++)) 27 | do 28 | export NCCL_ASYNC_ERROR_HANDLING=1 29 | export OMP_NUM_THREADS=1 30 | echo "params setting lr: $lr, bs: $bs, epoch: $epoch, dropout: $dropout, warmup: $warmup, cv_seed: $cv_seed, fold: $fold" 31 | update_freq=`expr $bs / $local_batch_size` 32 | fold_save_dir="fold_${fold}" 33 | model_dir="${save_dir}/${fold_save_dir}" 34 | python -m torch.distributed.launch --nproc_per_node=$n_gpu --master_port=$MASTER_PORT \ 35 | $(which unicore-train) $data_path --task-name $task_name --user-dir ./unimol --train-subset train --valid-subset valid \ 36 | --conf-size $conf_size --nfolds $nfolds --fold $fold --cv-seed $cv_seed\ 37 | --num-workers 8 --ddp-backend=c10d \ 38 | --dict-name $dict_name --charge-dict-name $charge_dict_name \ 39 | --task mol_pka --loss $loss_func --arch unimol_pka \ 40 | --classification-head-name $head_name --num-classes $task_num \ 41 | --optimizer adam --adam-betas '(0.9, 0.99)' --adam-eps 1e-6 --clip-norm 1.0 \ 42 | --lr-scheduler polynomial_decay --lr $lr --warmup-ratio $warmup --max-epoch $epoch --batch-size $local_batch_size --pooler-dropout $dropout\ 43 | --update-freq $update_freq --seed $seed \ 44 | --fp16 --fp16-init-scale 4 --fp16-scale-window 256 \ 45 | --log-interval 100 --log-format simple \ 46 | --finetune-from-model $weight_path \ 47 | --validate-interval 1 --keep-last-epochs 1 \ 48 | --all-gather-list-size 102400 \ 49 | --save-dir $model_dir \ 50 | --best-checkpoint-metric valid_rmse --patience 2000 \ 51 | --only-polar $only_polar --split-mode cross_valid 52 | done -------------------------------------------------------------------------------- /image/inference.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dptech-corp/Uni-pKa/7cfbf6532d7e14f427abaed5859bd1001b4b6377/image/inference.png -------------------------------------------------------------------------------- /image/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dptech-corp/Uni-pKa/7cfbf6532d7e14f427abaed5859bd1001b4b6377/image/overview.png -------------------------------------------------------------------------------- 
/image/performance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dptech-corp/Uni-pKa/7cfbf6532d7e14f427abaed5859bd1001b4b6377/image/performance.png -------------------------------------------------------------------------------- /image/protensemble.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dptech-corp/Uni-pKa/7cfbf6532d7e14f427abaed5859bd1001b4b6377/image/protensemble.png -------------------------------------------------------------------------------- /infer_free_energy.sh: -------------------------------------------------------------------------------- 1 | data_path='./unimol/examples' 2 | infer_task='fe_example' 3 | results_path='fe_results' 4 | head_name='chembl_small' 5 | conf_size=11 6 | dict_name='dict.txt' 7 | charge_dict_name='dict_charge.txt' 8 | task_num=1 9 | batch_size=16 10 | model_path='dwar_finetune' 11 | loss_func="infer_free_energy" 12 | nfolds=5 13 | only_polar=-1 14 | 15 | for ((fold=0;fold<$nfolds;fold++)) 16 | do 17 | python ./unimol/infer.py --user-dir ./unimol ${data_path} --task-name $infer_task --valid-subset $infer_task \ 18 | --results-path $results_path/fold_${fold} \ 19 | --num-workers 8 --ddp-backend=c10d --batch-size $batch_size \ 20 | --task mol_free_energy --loss $loss_func --arch unimol \ 21 | --classification-head-name $head_name --num-classes $task_num \ 22 | --dict-name $dict_name --charge-dict-name $charge_dict_name --conf-size $conf_size \ 23 | --only-polar $only_polar \ 24 | --path $model_path/fold_$fold/checkpoint_best.pt \ 25 | --fp16 --fp16-init-scale 4 --fp16-scale-window 256 \ 26 | --log-interval 50 --log-format simple --required-batch-size-multiple 1 27 | done -------------------------------------------------------------------------------- /infer_pka.sh: -------------------------------------------------------------------------------- 1 | data_path='novartis_acid' 2 | infer_task='novartis_acid' 3 | results_path='novartis_acid_results' 4 | model_path='finetune_save' 5 | head_name='chembl' 6 | conf_size=11 7 | dict_name='dict.txt' 8 | charge_dict_name='dict_charge.txt' 9 | task_num=1 10 | batch_size=16 11 | loss_func="finetune_mse" 12 | nfolds=5 13 | only_polar=-1 14 | 15 | for ((fold=0;fold<$nfolds;fold++)) 16 | do 17 | python ./unimol/infer.py --user-dir ./unimol ${data_path} --task-name $infer_task --valid-subset $infer_task \ 18 | --results-path $results_path/fold_${fold} \ 19 | --num-workers 8 --ddp-backend=c10d --batch-size $batch_size \ 20 | --task mol_pka --loss $loss_func --arch unimol_pka \ 21 | --classification-head-name $head_name --num-classes $task_num \ 22 | --dict-name $dict_name --charge-dict-name $charge_dict_name --conf-size $conf_size \ 23 | --only-polar $only_polar \ 24 | --path $model_path/fold_$fold/checkpoint_best.pt \ 25 | --fp16 --fp16-init-scale 4 --fp16-scale-window 256 \ 26 | --log-interval 50 --log-format simple --required-batch-size-multiple 1 27 | done -------------------------------------------------------------------------------- /pretrain_pka.sh: -------------------------------------------------------------------------------- 1 | data_path="chembl" 2 | task_name="chembl" 3 | n_gpu=8 4 | save_dir="pretrain_save" 5 | tmp_save_dir="tmp_save" 6 | task_num=1 7 | loss_func="pretrain_mlm" 8 | dict_name="dict.txt" 9 | charge_dict_name="dict_charge.txt" 10 | only_polar=-1 11 | conf_size=11 12 | split_mode="predefine" 13 | MASTER_PORT=10090 14 | 15 | # train params 16 | 
local_batch_size=16 17 | batch_size=16 18 | lr=1e-4 19 | epoch=100 20 | dropout=0.1 21 | warmup=0.06 22 | seed=0 23 | mask_prob=0.05 24 | update_freq=`expr $batch_size / $local_batch_size` 25 | global_batch_size=`expr $batch_size \* $n_gpu \* $update_freq` 26 | echo "params setting lr: $lr, bs: $global_batch_size, epoch: $epoch, dropout: $dropout, warmup: $warmup, seed: $seed" 27 | 28 | # loss 29 | masked_token_loss=1 30 | masked_charge_loss=2 31 | masked_coord_loss=2 32 | masked_dist_loss=1 33 | x_norm_loss=0.01 34 | delta_pair_repr_norm_loss=0.01 35 | 36 | 37 | export NCCL_ASYNC_ERROR_HANDLING=1 38 | export OMP_NUM_THREADS=1 39 | python -m torch.distributed.launch --nproc_per_node=$n_gpu --master_port=$MASTER_PORT \ 40 | $(which unicore-train) $data_path --user-dir ./unimol --train-subset train --valid-subset valid \ 41 | --conf-size $conf_size \ 42 | --num-workers 8 --ddp-backend=c10d \ 43 | --dict-name $dict_name --charge-dict-name $charge_dict_name \ 44 | --task mol_pka_mlm --loss $loss_func --arch unimol_pka \ 45 | --classification-head-name $task_name --num-classes $task_num \ 46 | --optimizer adam --adam-betas "(0.9, 0.99)" --adam-eps 1e-6 --clip-norm 1.0 \ 47 | --lr-scheduler polynomial_decay --lr $lr --warmup-ratio $warmup --max-epoch $epoch --batch-size $local_batch_size --pooler-dropout $dropout\ 48 | --update-freq $update_freq --seed $seed \ 49 | --fp16 --fp16-init-scale 4 --fp16-scale-window 256 \ 50 | --keep-last-epochs 1 \ 51 | --log-interval 100 --log-format simple \ 52 | --validate-interval 1 \ 53 | --save-dir $save_dir --tmp-save-dir $tmp_save_dir --tensorboard-logdir $save_dir/tsb \ 54 | --best-checkpoint-metric valid_rmse --patience 2000 \ 55 | --only-polar $only_polar --mask-prob $mask_prob \ 56 | --masked-token-loss $masked_token_loss --masked-coord-loss $masked_coord_loss --masked-dist-loss $masked_dist_loss \ 57 | --masked-charge-loss $masked_charge_loss --x-norm-loss $x_norm_loss --delta-pair-repr-norm-loss $delta_pair_repr_norm_loss 58 | 59 | -------------------------------------------------------------------------------- /scripts/infer_mean_ensemble.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 
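# Averages the per-SMILES predictions over the cross-validation folds written by
# infer_pka.sh / infer_free_energy.sh and, for pKa targets, reports MAE and RMSE.
# Illustrative invocation (the results path follows infer_pka.sh and may differ in
# your setup):
#   python scripts/infer_mean_ensemble.py --results-path novartis_acid_results --nfolds 5 --task pka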
4 | 5 | import pandas as pd 6 | import os 7 | import argparse 8 | import numpy as np 9 | import glob 10 | 11 | 12 | def cal_metrics(df): 13 | mae = np.abs(df["predict"] - df["target"]).mean() 14 | mse = ((df["predict"] - df["target"]) ** 2).mean() 15 | rmse = np.sqrt(mse) 16 | return mae, rmse 17 | 18 | 19 | def get_csv_results(results_path, nfolds, task): 20 | 21 | all_smi_list, all_predict_list, all_target_list = [], [], [] 22 | 23 | for fold_idx in range(nfolds): 24 | print(f"Processing fold {fold_idx}...") 25 | fold_path = os.path.join(results_path, f'fold_{fold_idx}') 26 | pkl_files = glob.glob(f"{fold_path}/*.pkl") 27 | fold_data = pd.read_pickle(pkl_files[0]) 28 | 29 | smi_list, predict_list, target_list = [], [], [] 30 | for batch in fold_data: 31 | sz = batch["bsz"] 32 | for i in range(sz): 33 | smi_list.append(batch["smi_name"][i]) 34 | predict_list.append(batch["predict"][i].cpu().item()) 35 | target_list.append(batch["target"][i].cpu().item()) 36 | fold_df = pd.DataFrame({"smiles": smi_list, "predict": predict_list, "target": target_list}) 37 | fold_df.to_csv(f'{fold_path}/fold_{fold_idx}.csv',index=False, sep='\t') 38 | 39 | # for final combined results 40 | all_smi_list.extend(smi_list) 41 | all_predict_list.extend(predict_list) 42 | all_target_list.extend(target_list) 43 | 44 | print(f"Combining results from {nfolds} folds into a single file...") 45 | combined_df = pd.DataFrame({"smiles": all_smi_list, "predict": all_predict_list, "target": all_target_list}) 46 | combined_df.to_csv(f'{results_path}/all_results.csv', index=False, sep='\t') 47 | 48 | print(f"Calculating mean results for each SMILES...") 49 | mean_results = combined_df.groupby('smiles', as_index=False).agg({ 50 | 'predict': 'mean', 51 | 'target': 'mean' 52 | }) 53 | mean_results.to_csv(f'{results_path}/mean_results.csv', index=False, sep='\t') 54 | if task == 'pka': 55 | print(f"MAE and RMSE for this task...") 56 | mae, rmse = cal_metrics(mean_results) 57 | print(f'MAE: {round(mae, 4)}, RMSE: {round(rmse, 4)}') 58 | print(f"Done!") 59 | 60 | 61 | def main(): 62 | parser = argparse.ArgumentParser(description='Model infer result mean ensemble') 63 | parser.add_argument( 64 | '--results-path', 65 | type=str, 66 | default='results', 67 | help='path to save infer results' 68 | ) 69 | parser.add_argument( 70 | "--nfolds", 71 | default=5, 72 | type=int, 73 | help="cross validation split folds" 74 | ) 75 | parser.add_argument( 76 | "--task", 77 | default='pka', 78 | type=str, 79 | choices=['pka', 'free_energy'] 80 | ) 81 | args = parser.parse_args() 82 | get_csv_results(args.results_path, args.nfolds, args.task) 83 | 84 | 85 | if __name__ == "__main__": 86 | main() -------------------------------------------------------------------------------- /scripts/preprocess_pka.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 
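# Converts a Uni-pKa tsv dataset (SMILES written as "acid_1,...>>base_1,..." plus an
# optional TARGET column) into an LMDB of RDKit conformers (10 3D + 1 2D per
# microstate) for training or inference.
# Illustrative invocation (the tsv path is a placeholder; --task-name must be one of
# the argparse choices below):
#   python scripts/preprocess_pka.py --raw-csv-file dataset/dwar-iBond.tsv --processed-lmdb-dir dwar --task-name dwar-iBond --nthreads 16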
4 | 5 | import os 6 | import pickle 7 | import lmdb 8 | import pandas as pd 9 | import numpy as np 10 | from rdkit import Chem 11 | from tqdm import tqdm 12 | from rdkit.Chem import AllChem 13 | from rdkit.Chem.Scaffolds import MurckoScaffold 14 | from rdkit import RDLogger 15 | RDLogger.DisableLog('rdApp.*') 16 | import warnings 17 | warnings.filterwarnings(action='ignore') 18 | from multiprocessing import Pool 19 | import argparse 20 | 21 | 22 | def smi2scaffold(smi): 23 | try: 24 | return MurckoScaffold.MurckoScaffoldSmiles( 25 | smiles=smi, includeChirality=True) 26 | except: 27 | print("failed to generate scaffold with smiles: {}".format(smi)) 28 | return smi 29 | 30 | 31 | def smi2_2Dcoords(smi): 32 | mol = Chem.MolFromSmiles(smi) 33 | mol = AllChem.AddHs(mol) 34 | AllChem.Compute2DCoords(mol) 35 | coordinates = mol.GetConformer().GetPositions().astype(np.float32) 36 | len(mol.GetAtoms()) == len(coordinates), "2D coordinates shape is not align with {}".format(smi) 37 | return coordinates 38 | 39 | 40 | def smi2_3Dcoords(smi,cnt, gen_mode='mmff'): 41 | assert gen_mode in ['mmff', 'no_mmff'] 42 | mol = Chem.MolFromSmiles(smi) 43 | mol = AllChem.AddHs(mol) 44 | coordinate_list=[] 45 | for seed in range(cnt): 46 | try: 47 | res = AllChem.EmbedMolecule(mol, randomSeed=seed) 48 | if res == 0: 49 | try: 50 | if gen_mode == 'mmff': 51 | AllChem.MMFFOptimizeMolecule(mol) # some conformer can not use MMFF optimize 52 | coordinates = mol.GetConformer().GetPositions() 53 | except: 54 | print("Failed to generate 3D, replace with 2D") 55 | coordinates = smi2_2Dcoords(smi) 56 | 57 | elif res == -1: 58 | mol_tmp = Chem.MolFromSmiles(smi) 59 | AllChem.EmbedMolecule(mol_tmp, maxAttempts=5000, randomSeed=seed) 60 | mol_tmp = AllChem.AddHs(mol_tmp, addCoords=True) 61 | try: 62 | if gen_mode == 'mmff': 63 | AllChem.MMFFOptimizeMolecule(mol_tmp) # some conformer can not use MMFF optimize 64 | coordinates = mol_tmp.GetConformer().GetPositions() 65 | except: 66 | print("Failed to generate 3D, replace with 2D") 67 | coordinates = smi2_2Dcoords(smi) 68 | except: 69 | print("Failed to generate 3D, replace with 2D") 70 | coordinates = smi2_2Dcoords(smi) 71 | 72 | assert len(mol.GetAtoms()) == len(coordinates), "3D coordinates shape is not align with {}".format(smi) 73 | coordinate_list.append(coordinates.astype(np.float32)) 74 | return coordinate_list 75 | 76 | 77 | def smi2metadata(smi, cnt, gen_mode): # input: single smi; output: molecule metadata (atom, charge, mol. smi, coords...) 
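    # Returns a dict with keys 'atoms' (element symbols, hydrogens included),
    # 'charges' (per-atom formal charges), 'coordinates' (cnt 3D conformers plus one
    # 2D conformer, or cnt+1 copies of the 2D conformer when the molecule has more
    # than 400 atoms), 'mol', 'smi' and 'scaffold' (Murcko scaffold of the input).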
78 | scaffold = smi2scaffold(smi) 79 | mol = Chem.MolFromSmiles(smi) 80 | 81 | if len(mol.GetAtoms()) > 400: 82 | coordinate_list = [smi2_2Dcoords(smi)] * (cnt+1) 83 | print("atom num >400,use 2D coords",smi) 84 | else: 85 | # gen cnt num 3D conf 86 | coordinate_list = smi2_3Dcoords(smi,cnt, gen_mode) 87 | # gen 1 2D conf 88 | coordinate_list.append(smi2_2Dcoords(smi).astype(np.float32)) 89 | mol = AllChem.AddHs(mol) 90 | atoms, charges = [], [] 91 | for atom in mol.GetAtoms(): 92 | atoms.append(atom.GetSymbol()) 93 | charges.append(atom.GetFormalCharge()) 94 | 95 | return {'atoms': atoms,'charges': charges, 'coordinates': coordinate_list, 'mol':mol,'smi': smi, 'scaffold': scaffold} 96 | 97 | 98 | def inner_smi2coords_pka(content): 99 | smi_all, target = content 100 | cnt = 10 # conformer num,all==11, 10 3d + 1 2d 101 | gen_mode = 'mmff' # 'mmff', 'no_mmff' 102 | 103 | # get single smi from original SMILES 104 | smi_list_a, smi_list_b = smi_all.split('>>') 105 | smi_list_a = smi_list_a.split(',') 106 | smi_list_b = smi_list_b.split(',') 107 | 108 | # get whole datapoint metadata 109 | metadata_a, metadata_b = [], [] 110 | for i in range(len(smi_list_a)): 111 | metadata_a.append(smi2metadata(smi_list_a[i], cnt, gen_mode)) 112 | for i in range(len(smi_list_b)): 113 | metadata_b.append(smi2metadata(smi_list_b[i], cnt, gen_mode)) 114 | 115 | return pickle.dumps({'ori_smi': smi_all, 'metadata_a': metadata_a, 'metadata_b': metadata_b, 'target': target}, protocol=-1) 116 | 117 | 118 | def smi2coords(content): 119 | try: 120 | return inner_smi2coords_pka(content) 121 | except: 122 | print("failed smiles: {}".format(content[0])) 123 | return None 124 | 125 | 126 | def load_rawdata_pka(input_csv): 127 | 128 | # read tsv file 129 | df = pd.read_csv(input_csv, sep='\t') 130 | smi_col = 'SMILES' 131 | target_col = 'TARGET' 132 | if target_col not in df.columns: 133 | # If not exist, add "-1.0" as a placeholder. 
134 | df["TARGET"] = -1.0 135 | col_list = [smi_col, target_col] 136 | df = df[col_list] 137 | print(f'raw_data size: {df.shape[0]}') 138 | return df 139 | 140 | def write_lmdb(task_name, input_csv, output_dir='.', nthreads=16): 141 | 142 | df = load_rawdata_pka(input_csv) 143 | content_list = zip(*[df[c].values.tolist() for c in df]) 144 | lmdb_name = '{}.lmdb'.format(task_name) 145 | os.makedirs(output_dir, exist_ok=True) 146 | output_name = os.path.join(output_dir, lmdb_name) 147 | try: 148 | os.remove(output_name) 149 | except: 150 | pass 151 | env_new = lmdb.open( 152 | output_name, 153 | subdir=False, 154 | readonly=False, 155 | lock=False, 156 | readahead=False, 157 | meminit=False, 158 | max_readers=1, 159 | map_size=int(100e9), 160 | ) 161 | txn_write = env_new.begin(write=True) 162 | with Pool(nthreads) as pool: 163 | i = 0 164 | for inner_output in tqdm(pool.imap(smi2coords, content_list), total=len(df)): 165 | if inner_output is not None: 166 | txn_write.put(f'{i}'.encode("ascii"), inner_output) 167 | i += 1 168 | print('{} process {} lines'.format(lmdb_name, i)) 169 | txn_write.commit() 170 | env_new.close() 171 | 172 | def main(): 173 | parser = argparse.ArgumentParser( 174 | description="use rdkit to generate conformers" 175 | ) 176 | parser.add_argument( 177 | "--raw-csv-file", 178 | type=str, 179 | default="Datasets/tsv/chembl_train.tsv", 180 | help="the original data csv file path", 181 | ) 182 | parser.add_argument( 183 | "--processed-lmdb-dir", 184 | type=str, 185 | default="chembl", 186 | help="dir of the processed lmdb data", 187 | ) 188 | parser.add_argument("--nthreads", type=int, default=22, help="num of threads") 189 | parser.add_argument( 190 | "--task-name", 191 | type=str, 192 | default="train", 193 | help="name of the lmdb file; train and valid for chembl", 194 | choices=['train', 'valid', 'dwar-iBond', 'novartis_acid', 'novartis_base', 'sampl6', 'sampl7', 'sampl8'] 195 | ) 196 | args = parser.parse_args() 197 | write_lmdb(task_name = args.task_name, input_csv=args.raw_csv_file, output_dir=args.processed_lmdb_dir, nthreads = args.nthreads) 198 | 199 | 200 | if __name__ == '__main__': 201 | main() 202 | -------------------------------------------------------------------------------- /unimol/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import unimol.tasks 3 | import unimol.data 4 | import unimol.models 5 | import unimol.losses 6 | -------------------------------------------------------------------------------- /unimol/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .key_dataset import KeyDataset 2 | from .normalize_dataset import ( 3 | NormalizeDataset, 4 | ) 5 | from .remove_hydrogen_dataset import ( 6 | RemoveHydrogenDataset, 7 | ) 8 | from .tta_dataset import ( 9 | TTADataset, 10 | TTAPKADataset, 11 | ) 12 | from .cropping_dataset import CroppingDataset 13 | 14 | from .distance_dataset import ( 15 | DistanceDataset, 16 | EdgeTypeDataset, 17 | ) 18 | from .conformer_sample_dataset import ( 19 | ConformerSamplePKADataset, 20 | ) 21 | from .coord_pad_dataset import RightPadDatasetCoord 22 | from .lmdb_dataset import ( 23 | FoldLMDBDataset, 24 | StackedLMDBDataset, 25 | SplitLMDBDataset, 26 | ) 27 | from .pka_input_dataset import ( 28 | PKAInputDataset, 29 | PKAMLMInputDataset, 30 | ) 31 | from .mask_points_dataset import MaskPointsDataset 32 | 33 | __all__ = [] 
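# Rough sketch of how these wrappers are meant to compose (the authoritative
# pipeline is assembled in unimol/tasks and may differ in order and options):
# LMDB entries are split into folds (FoldLMDBDataset / SplitLMDBDataset /
# StackedLMDBDataset), fields are pulled out by key (KeyDataset), one stored
# conformer is sampled per microstate (ConformerSamplePKADataset), hydrogens are
# optionally removed and oversized molecules cropped (RemoveHydrogenDataset,
# CroppingDataset), coordinates are centered (NormalizeDataset), pairwise distance
# and edge-type features are derived (DistanceDataset, EdgeTypeDataset), masking is
# applied for the masked-token pretraining objective (MaskPointsDataset), and the
# acid/base microstates of each record are grouped and padded into batches
# (PKAInputDataset / PKAMLMInputDataset, RightPadDatasetCoord).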
-------------------------------------------------------------------------------- /unimol/data/conformer_sample_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | import numpy as np 6 | from functools import lru_cache 7 | from unicore.data import BaseWrapperDataset 8 | from . import data_utils 9 | 10 | 11 | class ConformerSamplePKADataset(BaseWrapperDataset): 12 | def __init__(self, dataset, seed, atoms, coordinates, charges, id="ori_smi"): 13 | self.dataset = dataset 14 | self.seed = seed 15 | self.atoms = atoms 16 | self.coordinates = coordinates 17 | self.charges = charges 18 | self.id = id 19 | self.set_epoch(None) 20 | 21 | def set_epoch(self, epoch, **unused): 22 | super().set_epoch(epoch) 23 | self.epoch = epoch 24 | 25 | @lru_cache(maxsize=16) 26 | def __cached_item__(self, index: int, epoch: int): 27 | atoms = np.array(self.dataset[index][self.atoms]) 28 | charges = np.array(self.dataset[index][self.charges]) 29 | assert len(atoms) > 0, 'atoms: {}, charges: {}, coordinates: {}, id: {}'.format(atoms, charges, coordinates, id) 30 | size = len(self.dataset[index][self.coordinates]) 31 | with data_utils.numpy_seed(self.seed, epoch, index): 32 | sample_idx = np.random.randint(size) 33 | coordinates = self.dataset[index][self.coordinates][sample_idx] 34 | return {"atoms": atoms, "coordinates": coordinates.astype(np.float32),"charges":charges,"id": self.id} 35 | 36 | def __getitem__(self, index: int): 37 | return self.__cached_item__(index, self.epoch) 38 | -------------------------------------------------------------------------------- /unimol/data/coord_pad_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | from unicore.data import BaseWrapperDataset 6 | 7 | 8 | def collate_tokens_coords( 9 | values, 10 | pad_idx, 11 | left_pad=False, 12 | pad_to_length=None, 13 | pad_to_multiple=1, 14 | ): 15 | """Convert a list of 1d tensors into a padded 2d tensor.""" 16 | size = max(v.size(0) for v in values) 17 | size = size if pad_to_length is None else max(size, pad_to_length) 18 | if pad_to_multiple != 1 and size % pad_to_multiple != 0: 19 | size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple) 20 | res = values[0].new(len(values), size, 3).fill_(pad_idx) 21 | 22 | def copy_tensor(src, dst): 23 | assert dst.numel() == src.numel() 24 | dst.copy_(src) 25 | 26 | for i, v in enumerate(values): 27 | copy_tensor(v, res[i][size - len(v) :, :] if left_pad else res[i][: len(v), :]) 28 | return res 29 | 30 | 31 | class RightPadDatasetCoord(BaseWrapperDataset): 32 | def __init__(self, dataset, pad_idx, left_pad=False): 33 | super().__init__(dataset) 34 | self.pad_idx = pad_idx 35 | self.left_pad = left_pad 36 | 37 | def collater(self, samples): 38 | return collate_tokens_coords( 39 | samples, self.pad_idx, left_pad=self.left_pad, pad_to_multiple=8 40 | ) 41 | -------------------------------------------------------------------------------- /unimol/data/cropping_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 
2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | import numpy as np 6 | from functools import lru_cache 7 | import logging 8 | from unicore.data import BaseWrapperDataset 9 | from . import data_utils 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | class CroppingDataset(BaseWrapperDataset): 15 | def __init__(self, dataset, seed, atoms, coordinates, charges, max_atoms=256): 16 | self.dataset = dataset 17 | self.seed = seed 18 | self.atoms = atoms 19 | self.coordinates = coordinates 20 | self.charges = charges 21 | self.max_atoms = max_atoms 22 | self.set_epoch(None) 23 | 24 | def set_epoch(self, epoch, **unused): 25 | super().set_epoch(epoch) 26 | self.epoch = epoch 27 | 28 | @lru_cache(maxsize=16) 29 | def __cached_item__(self, index: int, epoch: int): 30 | dd = self.dataset[index].copy() 31 | atoms = dd[self.atoms] 32 | coordinates = dd[self.coordinates] 33 | charges = dd[self.charges] 34 | if self.max_atoms and len(atoms) > self.max_atoms: 35 | with data_utils.numpy_seed(self.seed, epoch, index): 36 | index = np.random.choice(len(atoms), self.max_atoms, replace=False) 37 | atoms = np.array(atoms)[index] 38 | coordinates = coordinates[index] 39 | charges = charges[index] 40 | dd[self.atoms] = atoms 41 | dd[self.coordinates] = coordinates.astype(np.float32) 42 | dd[self.charges] = charges 43 | return dd 44 | 45 | def __getitem__(self, index: int): 46 | return self.__cached_item__(index, self.epoch) 47 | -------------------------------------------------------------------------------- /unimol/data/data_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | import numpy as np 6 | import contextlib 7 | 8 | 9 | @contextlib.contextmanager 10 | def numpy_seed(seed, *addl_seeds): 11 | """Context manager which seeds the NumPy PRNG with the specified seed and 12 | restores the state afterward""" 13 | if seed is None: 14 | yield 15 | return 16 | if len(addl_seeds) > 0: 17 | seed = int(hash((seed, *addl_seeds)) % 1e6) 18 | state = np.random.get_state() 19 | np.random.seed(seed) 20 | try: 21 | yield 22 | finally: 23 | np.random.set_state(state) 24 | -------------------------------------------------------------------------------- /unimol/data/distance_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 
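# DistanceDataset turns an (N, 3) coordinate tensor into the dense N x N Euclidean
# distance matrix. EdgeTypeDataset encodes every ordered atom pair (i, j) as the
# single integer token_i * num_types + token_j, so num_types atom tokens give
# num_types**2 distinct pair types. Illustrative arithmetic: with num_types = 30,
# tokens 5 and 7 map to edge types 5*30+7 = 157 and 7*30+5 = 215.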
4 | 5 | import numpy as np 6 | import torch 7 | from scipy.spatial import distance_matrix 8 | from functools import lru_cache 9 | from unicore.data import BaseWrapperDataset 10 | 11 | 12 | class DistanceDataset(BaseWrapperDataset): 13 | def __init__(self, dataset): 14 | super().__init__(dataset) 15 | self.dataset = dataset 16 | 17 | @lru_cache(maxsize=16) 18 | def __getitem__(self, idx): 19 | pos = self.dataset[idx].view(-1, 3).numpy() 20 | dist = distance_matrix(pos, pos).astype(np.float32) 21 | return torch.from_numpy(dist) 22 | 23 | 24 | class EdgeTypeDataset(BaseWrapperDataset): 25 | def __init__(self, dataset: torch.utils.data.Dataset, num_types: int): 26 | self.dataset = dataset 27 | self.num_types = num_types 28 | 29 | @lru_cache(maxsize=16) 30 | def __getitem__(self, index: int): 31 | node_input = self.dataset[index].clone() 32 | offset = node_input.view(-1, 1) * self.num_types + node_input.view(1, -1) 33 | return offset 34 | -------------------------------------------------------------------------------- /unimol/data/key_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | from functools import lru_cache 6 | from unicore.data import BaseWrapperDataset 7 | 8 | 9 | class KeyDataset(BaseWrapperDataset): 10 | def __init__(self, dataset, key): 11 | self.dataset = dataset 12 | self.key = key 13 | 14 | def __len__(self): 15 | return len(self.dataset) 16 | 17 | @lru_cache(maxsize=16) 18 | def __getitem__(self, idx): 19 | return self.dataset[idx][self.key] 20 | -------------------------------------------------------------------------------- /unimol/data/lmdb_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | 6 | import lmdb 7 | import os 8 | import pickle 9 | from functools import lru_cache 10 | from . 
import data_utils 11 | import numpy as np 12 | import logging 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | class FoldLMDBDataset: 18 | def __init__(self, dataset, seed, cur_fold, nfolds=5, cache_fold_info=None): 19 | super().__init__() 20 | self.dataset = dataset 21 | if cache_fold_info is None: 22 | self.keys = [] 23 | self.fold_start = [] 24 | self.fold_end = [] 25 | self.init_random_split(dataset, seed, nfolds) 26 | else: 27 | # use cache fold info 28 | self.keys, self.fold_start, self.fold_end = cache_fold_info 29 | self.cur_fold = cur_fold 30 | self._len = self.fold_end[cur_fold] - self.fold_start[cur_fold] 31 | assert len(self.fold_end) == len(self.fold_start) == nfolds 32 | 33 | def init_random_split(self, dataset, seed, nfolds): 34 | with data_utils.numpy_seed(seed): 35 | self.keys = np.random.permutation(len(dataset)) 36 | average_size = (len(dataset) + nfolds - 1) // nfolds 37 | cur_size = 0 38 | for i in range(nfolds): 39 | self.fold_start.append(cur_size) 40 | cur_size = min(cur_size + average_size, len(dataset)) 41 | self.fold_end.append(cur_size) 42 | 43 | def get_fold_info(self): 44 | return self.keys, self.fold_start, self.fold_end 45 | 46 | def __len__(self): 47 | return self._len 48 | 49 | @lru_cache(maxsize=16) 50 | def __getitem__(self, idx): 51 | global_idx = idx + self.fold_start[self.cur_fold] 52 | return self.dataset[self.keys[global_idx]] 53 | 54 | 55 | class StackedLMDBDataset: 56 | def __init__(self, datasets): 57 | self._len = 0 58 | self.datasets = [] 59 | self.idx_to_file = {} 60 | self.idx_offset = [] 61 | for dataset in datasets: 62 | self.datasets.append(dataset) 63 | for i in range(len(dataset)): 64 | self.idx_to_file[i + self._len] = len(self.datasets) - 1 65 | self.idx_offset.append(self._len) 66 | self._len += len(dataset) 67 | 68 | def __len__(self): 69 | return self._len 70 | 71 | @lru_cache(maxsize=16) 72 | def __getitem__(self, idx): 73 | file_idx = self.idx_to_file[idx] 74 | sub_idx = idx - self.idx_offset[file_idx] 75 | return self.datasets[file_idx][sub_idx] 76 | 77 | 78 | class SplitLMDBDataset: 79 | # train:valid = 9:1 80 | def __init__(self, dataset, seed, cur_fold, cache_fold_info= None,frac_train=0.9, frac_valid=0.1): 81 | super().__init__() 82 | self.dataset = dataset 83 | np.testing.assert_almost_equal(frac_train + frac_valid, 1.0) 84 | frac = [frac_train,frac_valid] 85 | if cache_fold_info is None: 86 | self.keys = [] 87 | self.fold_start = [] 88 | self.fold_end = [] 89 | self.init_random_split(dataset, seed, frac) 90 | else: 91 | # use cache fold info 92 | self.keys, self.fold_start, self.fold_end = cache_fold_info 93 | self.cur_fold = cur_fold 94 | self._len = self.fold_end[cur_fold] - self.fold_start[cur_fold] 95 | assert len(self.fold_end) == len(self.fold_start) == 3 96 | 97 | def init_random_split(self, dataset, seed, frac): 98 | with data_utils.numpy_seed(seed): 99 | self.keys = np.random.permutation(len(dataset)) 100 | frac_train,frac_valid = frac 101 | #average_size = (len(dataset) + nfolds - 1) // nfolds 102 | fold_size = [int(frac_train * len(dataset)), len(dataset)- int(frac_train * len(dataset))] 103 | assert sum(fold_size) == len(dataset) 104 | cur_size = 0 105 | for i in range(len(fold_size)): 106 | self.fold_start.append(cur_size) 107 | cur_size = min(cur_size + fold_size[i], len(dataset)) 108 | self.fold_end.append(cur_size) 109 | 110 | def get_fold_info(self): 111 | return self.keys, self.fold_start, self.fold_end 112 | 113 | def __len__(self): 114 | return self._len 115 | 116 | @lru_cache(maxsize=16) 
117 | def __getitem__(self, idx): 118 | global_idx = idx + self.fold_start[self.cur_fold] 119 | return self.dataset[self.keys[global_idx]] 120 | -------------------------------------------------------------------------------- /unimol/data/mask_points_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | from functools import lru_cache 6 | 7 | import numpy as np 8 | import torch 9 | from unicore.data import Dictionary 10 | from unicore.data import BaseWrapperDataset 11 | from . import data_utils 12 | 13 | 14 | class MaskPointsDataset(BaseWrapperDataset): 15 | def __init__( 16 | self, 17 | dataset: torch.utils.data.Dataset, 18 | coord_dataset: torch.utils.data.Dataset, 19 | charge_dataset: torch.utils.data.Dataset, 20 | vocab: Dictionary, 21 | charge_vocab: Dictionary, 22 | pad_idx: int, 23 | charge_pad_idx: int, 24 | mask_idx: int, 25 | charge_mask_idx: int, 26 | noise_type: str, 27 | noise: float = 1.0, 28 | seed: int = 1, 29 | mask_prob: float = 0.15, 30 | leave_unmasked_prob: float = 0.1, 31 | random_token_prob: float = 0.1, 32 | ): 33 | assert 0.0 < mask_prob < 1.0 34 | assert 0.0 <= random_token_prob <= 1.0 35 | assert 0.0 <= leave_unmasked_prob <= 1.0 36 | assert random_token_prob + leave_unmasked_prob <= 1.0 37 | 38 | self.dataset = dataset 39 | self.coord_dataset = coord_dataset 40 | self.charge_dataset = charge_dataset 41 | self.vocab = vocab 42 | self.charge_vocab = charge_vocab 43 | self.pad_idx = pad_idx 44 | self.charge_pad_idx = charge_pad_idx 45 | self.mask_idx = mask_idx 46 | self.charge_mask_idx = charge_mask_idx 47 | self.noise_type = noise_type 48 | self.noise = noise 49 | self.seed = seed 50 | self.mask_prob = mask_prob 51 | self.leave_unmasked_prob = leave_unmasked_prob 52 | self.random_token_prob = random_token_prob 53 | 54 | if random_token_prob > 0.0: 55 | weights = np.ones(len(self.vocab)) 56 | weights[vocab.special_index()] = 0 57 | self.weights = weights / weights.sum() 58 | # for charge 59 | charge_weights = np.ones(len(self.charge_vocab)) 60 | charge_weights[charge_vocab.special_index()] = 0 61 | self.charge_weights = charge_weights / charge_weights.sum() 62 | 63 | self.epoch = None 64 | if self.noise_type == "trunc_normal": 65 | self.noise_f = lambda num_mask: np.clip( 66 | np.random.randn(num_mask, 3) * self.noise, 67 | a_min=-self.noise * 2.0, 68 | a_max=self.noise * 2.0, 69 | ) 70 | elif self.noise_type == "normal": 71 | self.noise_f = lambda num_mask: np.random.randn(num_mask, 3) * self.noise 72 | elif self.noise_type == "uniform": 73 | self.noise_f = lambda num_mask: np.random.uniform( 74 | low=-self.noise, high=self.noise, size=(num_mask, 3) 75 | ) 76 | else: 77 | self.noise_f = lambda num_mask: 0.0 78 | 79 | def set_epoch(self, epoch, **unused): 80 | super().set_epoch(epoch) 81 | self.coord_dataset.set_epoch(epoch) 82 | self.dataset.set_epoch(epoch) 83 | self.charge_dataset.set_epoch(epoch) 84 | self.epoch = epoch 85 | 86 | def __getitem__(self, index: int): 87 | return self.__getitem_cached__(self.epoch, index) 88 | 89 | @lru_cache(maxsize=16) 90 | def __getitem_cached__(self, epoch: int, index: int): 91 | ret = {} 92 | with data_utils.numpy_seed(self.seed, epoch, index): 93 | item = self.dataset[index] 94 | coord = self.coord_dataset[index] 95 | charge = self.charge_dataset[index] 96 | sz = len(item) 97 | # don't allow empty sequence 98 | 
assert sz > 0 99 | # decide elements to mask 100 | num_mask = int( 101 | # add a random number for probabilistic rounding 102 | self.mask_prob * sz 103 | + np.random.rand() 104 | ) 105 | mask_idc = np.random.choice(sz, num_mask, replace=False) 106 | mask = np.full(sz, False) 107 | mask[mask_idc] = True 108 | ret["targets"] = np.full(len(mask), self.pad_idx) 109 | ret["targets"][mask] = item[mask] 110 | ret["targets"] = torch.from_numpy(ret["targets"]).long() 111 | # for charge 112 | ret["charge_targets"] = np.full(len(mask), self.charge_pad_idx) 113 | ret["charge_targets"][mask] = charge[mask] 114 | ret["charge_targets"] = torch.from_numpy(ret["charge_targets"]).long() 115 | 116 | # decide unmasking and random replacement 117 | rand_or_unmask_prob = self.random_token_prob + self.leave_unmasked_prob 118 | if rand_or_unmask_prob > 0.0: 119 | rand_or_unmask = mask & (np.random.rand(sz) < rand_or_unmask_prob) 120 | if self.random_token_prob == 0.0: 121 | unmask = rand_or_unmask 122 | rand_mask = None 123 | elif self.leave_unmasked_prob == 0.0: 124 | unmask = None 125 | rand_mask = rand_or_unmask 126 | else: 127 | unmask_prob = self.leave_unmasked_prob / rand_or_unmask_prob 128 | decision = np.random.rand(sz) < unmask_prob 129 | unmask = rand_or_unmask & decision 130 | rand_mask = rand_or_unmask & (~decision) 131 | else: 132 | unmask = rand_mask = None 133 | 134 | if unmask is not None: 135 | mask = mask ^ unmask 136 | 137 | new_item = np.copy(item) 138 | new_item[mask] = self.mask_idx 139 | 140 | num_mask = mask.astype(np.int32).sum() 141 | new_coord = np.copy(coord) 142 | new_coord[mask, :] += self.noise_f(num_mask) 143 | 144 | # for charge mask 145 | new_charge = np.copy(charge) 146 | new_charge[mask] = self.charge_mask_idx 147 | 148 | if rand_mask is not None: 149 | num_rand = rand_mask.sum() 150 | if num_rand > 0: 151 | new_item[rand_mask] = np.random.choice( 152 | len(self.vocab), 153 | num_rand, 154 | p=self.weights, 155 | ) 156 | # for charge 157 | new_charge[rand_mask] = np.random.choice( 158 | len(self.charge_vocab), 159 | num_rand, 160 | p=self.charge_weights, 161 | ) 162 | ret["atoms"] = torch.from_numpy(new_item).long() 163 | ret["coordinates"] = torch.from_numpy(new_coord).float() 164 | ret["charges"] = torch.from_numpy(new_charge).long() 165 | return ret 166 | 167 | -------------------------------------------------------------------------------- /unimol/data/normalize_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | import numpy as np 6 | from functools import lru_cache 7 | from unicore.data import BaseWrapperDataset 8 | 9 | 10 | class NormalizeDataset(BaseWrapperDataset): 11 | def __init__(self, dataset, coordinates, normalize_coord=True): 12 | self.dataset = dataset 13 | self.coordinates = coordinates 14 | self.normalize_coord = normalize_coord # normalize the coordinates. 
15 | self.set_epoch(None) 16 | 17 | def set_epoch(self, epoch, **unused): 18 | super().set_epoch(epoch) 19 | self.epoch = epoch 20 | 21 | @lru_cache(maxsize=16) 22 | def __cached_item__(self, index: int, epoch: int): 23 | dd = self.dataset[index].copy() 24 | coordinates = dd[self.coordinates] 25 | # normalize 26 | if self.normalize_coord: 27 | coordinates = coordinates - coordinates.mean(axis=0) 28 | dd[self.coordinates] = coordinates.astype(np.float32) 29 | return dd 30 | 31 | def __getitem__(self, index: int): 32 | return self.__cached_item__(index, self.epoch) 33 | -------------------------------------------------------------------------------- /unimol/data/pka_input_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | import numpy as np 6 | from functools import lru_cache 7 | from unicore.data import BaseWrapperDataset 8 | import collections 9 | import torch 10 | from itertools import chain 11 | from unicore.data.data_utils import collate_tokens, collate_tokens_2d 12 | from .coord_pad_dataset import collate_tokens_coords 13 | 14 | 15 | class PKAInputDataset(BaseWrapperDataset): 16 | def __init__(self, idx2key, src_tokens, src_charges, src_coord, src_distance, src_edge_type, token_pad_idx, charge_pad_idx, split='train', conf_size=10): 17 | self.idx2key = idx2key 18 | self.dataset = src_tokens 19 | self.src_tokens = src_tokens 20 | self.src_charges = src_charges 21 | self.src_coord = src_coord 22 | self.src_distance = src_distance 23 | self.src_edge_type = src_edge_type 24 | self.token_pad_idx = token_pad_idx 25 | self.charge_pad_idx = charge_pad_idx 26 | self.split = split 27 | self.conf_size = conf_size 28 | self.left_pad = False 29 | self._init_rec2mol() 30 | self.set_epoch(None) 31 | 32 | def set_epoch(self, epoch, **unused): 33 | super().set_epoch(epoch) 34 | self.epoch = epoch 35 | 36 | def _init_rec2mol(self): 37 | self.rec2mol = collections.defaultdict(list) 38 | if self.split in ['train','train.small']: 39 | total_sz = len(self.idx2key) 40 | for i in range(total_sz): 41 | smi_idx, _ = self.idx2key[i] 42 | self.rec2mol[smi_idx].append(i) 43 | else: 44 | total_sz = len(self.idx2key) 45 | for i in range(total_sz): 46 | smi_idx, _ = self.idx2key[i] 47 | self.rec2mol[smi_idx].extend([i * self.conf_size + j for j in range(self.conf_size)]) 48 | 49 | 50 | def __len__(self): 51 | return len(self.rec2mol) 52 | 53 | @lru_cache(maxsize=16) 54 | def __cached_item__(self, index: int, epoch: int): 55 | mol_list = self.rec2mol[index] 56 | src_tokens_list = [] 57 | src_charges_list = [] 58 | src_coord_list = [] 59 | src_distance_list = [] 60 | src_edge_type_list = [] 61 | for i in mol_list: 62 | src_tokens_list.append(self.src_tokens[i]) 63 | src_charges_list.append(self.src_charges[i]) 64 | src_coord_list.append(self.src_coord[i]) 65 | src_distance_list.append(self.src_distance[i]) 66 | src_edge_type_list.append(self.src_edge_type[i]) 67 | 68 | return src_tokens_list, src_charges_list,src_coord_list,src_distance_list,src_edge_type_list 69 | 70 | def __getitem__(self, index: int): 71 | return self.__cached_item__(index, self.epoch) 72 | 73 | def collater(self, samples): 74 | batch = [len(samples[i][0]) for i in range(len(samples))] 75 | 76 | src_tokens, src_charges, src_coord, src_distance, src_edge_type = [list(chain.from_iterable(i)) for i in zip(*samples)] 77 | src_tokens = 
collate_tokens(src_tokens, self.token_pad_idx, left_pad=self.left_pad, pad_to_multiple=8) 78 | src_charges = collate_tokens(src_charges, self.charge_pad_idx, left_pad=self.left_pad, pad_to_multiple=8) 79 | src_coord = collate_tokens_coords(src_coord, 0, left_pad=self.left_pad, pad_to_multiple=8) 80 | src_distance = collate_tokens_2d(src_distance, 0, left_pad=self.left_pad, pad_to_multiple=8) 81 | src_edge_type = collate_tokens_2d(src_edge_type, 0, left_pad=self.left_pad, pad_to_multiple=8) 82 | 83 | return src_tokens, src_charges, src_coord, src_distance, src_edge_type, batch 84 | 85 | 86 | class PKAMLMInputDataset(BaseWrapperDataset): 87 | def __init__(self, idx2key, src_tokens, src_charges, src_coord, src_distance, src_edge_type, token_targets, charge_targets, dist_targets, coord_targets, token_pad_idx, charge_pad_idx, split='train', conf_size=10): 88 | self.idx2key = idx2key 89 | self.dataset = src_tokens 90 | self.src_tokens = src_tokens 91 | self.src_charges = src_charges 92 | self.src_coord = src_coord 93 | self.src_distance = src_distance 94 | self.src_edge_type = src_edge_type 95 | self.token_targets = token_targets 96 | self.charge_targets = charge_targets 97 | self.dist_targets = dist_targets 98 | self.coord_targets = coord_targets 99 | self.token_pad_idx = token_pad_idx 100 | self.charge_pad_idx = charge_pad_idx 101 | self.split = split 102 | self.conf_size = conf_size 103 | self.left_pad = False 104 | self._init_rec2mol() 105 | self.set_epoch(None) 106 | 107 | def set_epoch(self, epoch, **unused): 108 | super().set_epoch(epoch) 109 | self.epoch = epoch 110 | 111 | def _init_rec2mol(self): 112 | self.rec2mol = collections.defaultdict(list) 113 | if self.split in ['train','train.small']: 114 | total_sz = len(self.idx2key) 115 | for i in range(total_sz): 116 | smi_idx, _ = self.idx2key[i] 117 | self.rec2mol[smi_idx].append(i) 118 | else: 119 | total_sz = len(self.idx2key) 120 | for i in range(total_sz): 121 | smi_idx, _ = self.idx2key[i] 122 | self.rec2mol[smi_idx].extend([i * self.conf_size + j for j in range(self.conf_size)]) 123 | 124 | 125 | def __len__(self): 126 | return len(self.rec2mol) 127 | 128 | @lru_cache(maxsize=16) 129 | def __cached_item__(self, index: int, epoch: int): 130 | mol_list = self.rec2mol[index] 131 | src_tokens_list = [] 132 | src_charges_list = [] 133 | src_coord_list = [] 134 | src_distance_list = [] 135 | src_edge_type_list = [] 136 | token_targets_list = [] 137 | charge_targets_list = [] 138 | coord_targets_list = [] 139 | dist_targets_list = [] 140 | for i in mol_list: 141 | src_tokens_list.append(self.src_tokens[i]) 142 | src_charges_list.append(self.src_charges[i]) 143 | src_coord_list.append(self.src_coord[i]) 144 | src_distance_list.append(self.src_distance[i]) 145 | src_edge_type_list.append(self.src_edge_type[i]) 146 | token_targets_list.append(self.token_targets[i]) 147 | charge_targets_list.append(self.charge_targets[i]) 148 | coord_targets_list.append(self.coord_targets[i]) 149 | dist_targets_list.append(self.dist_targets[i]) 150 | 151 | return src_tokens_list, src_charges_list,src_coord_list,src_distance_list,src_edge_type_list, token_targets_list, charge_targets_list, coord_targets_list, dist_targets_list 152 | 153 | def __getitem__(self, index: int): 154 | return self.__cached_item__(index, self.epoch) 155 | 156 | def collater(self, samples): 157 | batch = [len(samples[i][0]) for i in range(len(samples))] 158 | 159 | src_tokens, src_charges, src_coord, src_distance, src_edge_type, token_targets, charge_targets, coord_targets, 
dist_targets = [list(chain.from_iterable(i)) for i in zip(*samples)] 160 | src_tokens = collate_tokens(src_tokens, self.token_pad_idx, left_pad=self.left_pad, pad_to_multiple=8) 161 | src_charges = collate_tokens(src_charges, self.charge_pad_idx, left_pad=self.left_pad, pad_to_multiple=8) 162 | src_coord = collate_tokens_coords(src_coord, 0, left_pad=self.left_pad, pad_to_multiple=8) 163 | src_distance = collate_tokens_2d(src_distance, 0, left_pad=self.left_pad, pad_to_multiple=8) 164 | src_edge_type = collate_tokens_2d(src_edge_type, 0, left_pad=self.left_pad, pad_to_multiple=8) 165 | token_targets = collate_tokens(token_targets, self.token_pad_idx, left_pad=self.left_pad, pad_to_multiple=8) 166 | charge_targets = collate_tokens(charge_targets, self.charge_pad_idx, left_pad=self.left_pad, pad_to_multiple=8) 167 | coord_targets = collate_tokens_coords(coord_targets, 0, left_pad=self.left_pad, pad_to_multiple=8) 168 | dist_targets = collate_tokens_2d(dist_targets, 0, left_pad=self.left_pad, pad_to_multiple=8) 169 | 170 | return src_tokens, src_charges, src_coord, src_distance, src_edge_type, batch, charge_targets, coord_targets, dist_targets, token_targets -------------------------------------------------------------------------------- /unimol/data/remove_hydrogen_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | import numpy as np 6 | from functools import lru_cache 7 | from unicore.data import BaseWrapperDataset 8 | 9 | 10 | class RemoveHydrogenDataset(BaseWrapperDataset): 11 | def __init__( 12 | self, 13 | dataset, 14 | atoms, 15 | coordinates, 16 | charges, 17 | remove_hydrogen=False, 18 | remove_polar_hydrogen=False, 19 | ): 20 | self.dataset = dataset 21 | self.atoms = atoms 22 | self.coordinates = coordinates 23 | self.charges = charges 24 | self.remove_hydrogen = remove_hydrogen 25 | self.remove_polar_hydrogen = remove_polar_hydrogen 26 | self.set_epoch(None) 27 | 28 | def set_epoch(self, epoch, **unused): 29 | super().set_epoch(epoch) 30 | self.epoch = epoch 31 | 32 | @lru_cache(maxsize=16) 33 | def __cached_item__(self, index: int, epoch: int): 34 | dd = self.dataset[index].copy() 35 | atoms = dd[self.atoms] 36 | coordinates = dd[self.coordinates] 37 | charges = dd[self.charges] 38 | 39 | if self.remove_hydrogen: 40 | mask_hydrogen = atoms != "H" 41 | atoms = atoms[mask_hydrogen] 42 | coordinates = coordinates[mask_hydrogen] 43 | charges = charges[mask_hydrogen] 44 | if not self.remove_hydrogen and self.remove_polar_hydrogen: 45 | end_idx = 0 46 | for i, atom in enumerate(atoms[::-1]): 47 | if atom != "H": 48 | break 49 | else: 50 | end_idx = i + 1 51 | if end_idx != 0: 52 | atoms = atoms[:-end_idx] 53 | coordinates = coordinates[:-end_idx] 54 | charges = charges[:-end_idx] 55 | dd[self.atoms] = atoms 56 | dd[self.coordinates] = coordinates.astype(np.float32) 57 | dd[self.charges] = charges 58 | return dd 59 | 60 | def __getitem__(self, index: int): 61 | return self.__cached_item__(index, self.epoch) 62 | -------------------------------------------------------------------------------- /unimol/data/tta_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 
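#
# Test-time augmentation (TTA) wrappers.
# TTADataset expands each molecule into `conf_size` entries, one per stored
# conformer: molecule index = index // conf_size, conformer index = index % conf_size.
# TTAPKADataset instead flattens every entry of the per-molecule `metadata`
# list into one indexable dataset and exposes the index mapping through
# get_idx2key(), which the pKa input datasets use to regroup entries that
# belong to the same molecule.
#
# Minimal usage sketch (variable names are hypothetical; assumes each raw
# item provides "atoms", "coordinates" as a list of conformers, "charges",
# "ori_smi" and "target", as the code below expects):
#
#   tta = TTADataset(raw_lmdb_dataset, seed=1, atoms="atoms",
#                    coordinates="coordinates", charges="charges",
#                    id="ori_smi", conf_size=10)
#   sample = tta[0]  # conformer 0 of molecule 0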
4 | 5 | import numpy as np 6 | from functools import lru_cache 7 | from unicore.data import BaseWrapperDataset 8 | 9 | 10 | class TTADataset(BaseWrapperDataset): 11 | def __init__(self, dataset, seed, atoms, coordinates, charges, id="ori_smi", conf_size=10): 12 | self.dataset = dataset 13 | self.seed = seed 14 | self.atoms = atoms 15 | self.coordinates = coordinates 16 | self.charges = charges 17 | self.id = id 18 | self.conf_size = conf_size 19 | self.set_epoch(None) 20 | 21 | def set_epoch(self, epoch, **unused): 22 | super().set_epoch(epoch) 23 | self.epoch = epoch 24 | 25 | def __len__(self): 26 | return len(self.dataset) * self.conf_size 27 | 28 | @lru_cache(maxsize=16) 29 | def __cached_item__(self, index: int, epoch: int): 30 | mol_idx = index // self.conf_size 31 | coord_idx = index % self.conf_size 32 | atoms = np.array(self.dataset[mol_idx][self.atoms]) 33 | charges = np.array(self.dataset[mol_idx][self.charges]) 34 | coordinates = np.array(self.dataset[mol_idx][self.coordinates][coord_idx]) 35 | id = self.dataset[mol_idx][self.id] 36 | target = self.dataset[mol_idx]["target"] 37 | smi = self.dataset[mol_idx][self.id] 38 | return { 39 | "atoms": atoms, 40 | "coordinates": coordinates.astype(np.float32), 41 | "charges": charges.astype(str), 42 | "target": target, 43 | "smi":smi, 44 | "target": target, 45 | "id": id, 46 | } 47 | 48 | def __getitem__(self, index: int): 49 | return self.__cached_item__(index, self.epoch) 50 | 51 | 52 | class TTAPKADataset(BaseWrapperDataset): 53 | def __init__(self, dataset, seed, metadata, atoms, coordinates, charges, id="ori_smi"): 54 | self.dataset = dataset 55 | self.seed = seed 56 | self.metadata = metadata 57 | self.atoms = atoms 58 | self.coordinates = coordinates 59 | self.charges = charges 60 | self.id = id 61 | self._init_idx() 62 | self.set_epoch(None) 63 | 64 | def set_epoch(self, epoch, **unused): 65 | super().set_epoch(epoch) 66 | self.epoch = epoch 67 | 68 | def _init_idx(self): 69 | self.idx2key = {} 70 | total_sz = 0 71 | for i in range(len(self.dataset)): 72 | size = len(self.dataset[i][self.metadata]) 73 | for j in range(size): 74 | self.idx2key[total_sz] = (i, j) 75 | total_sz += 1 76 | self.total_sz = total_sz 77 | 78 | def get_idx2key(self): 79 | return self.idx2key 80 | 81 | def __len__(self): 82 | return self.total_sz 83 | 84 | @lru_cache(maxsize=16) 85 | def __cached_item__(self, index: int, epoch: int): 86 | smi_idx, mol_idx = self.idx2key[index] 87 | atoms = np.array(self.dataset[smi_idx][self.metadata][mol_idx][self.atoms]) 88 | coordinates = np.array(self.dataset[smi_idx][self.metadata][mol_idx][self.coordinates]) 89 | charges = np.array(self.dataset[smi_idx][self.metadata][mol_idx][self.charges]) 90 | smi = self.dataset[smi_idx]["ori_smi"] 91 | id = self.dataset[smi_idx][self.id] 92 | target = self.dataset[smi_idx]["target"] 93 | return { 94 | "atoms": atoms, 95 | "coordinates": coordinates.astype(np.float32), 96 | "charges": charges.astype(str), 97 | "smi": smi, 98 | "target": target, 99 | "id": id, 100 | } 101 | 102 | def __getitem__(self, index: int): 103 | return self.__cached_item__(index, self.epoch) 104 | -------------------------------------------------------------------------------- /unimol/examples/dict.txt: -------------------------------------------------------------------------------- 1 | [PAD] 2 | [CLS] 3 | [SEP] 4 | [UNK] 5 | C 6 | N 7 | O 8 | S 9 | H 10 | Cl 11 | F 12 | Br 13 | I 14 | Si 15 | P 16 | B 17 | Na 18 | K 19 | Al 20 | Ca 21 | Sn 22 | As 23 | Hg 24 | Fe 25 | Zn 26 | Cr 27 | Se 28 | Gd 29 | Au 30 
| Li -------------------------------------------------------------------------------- /unimol/examples/dict_charge.txt: -------------------------------------------------------------------------------- 1 | [PAD] 2 | [CLS] 3 | [SEP] 4 | [UNK] 5 | 1 6 | 0 7 | -1 -------------------------------------------------------------------------------- /unimol/infer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # Copyright (c) DP Techonology, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | import os 9 | import sys 10 | import pickle 11 | import torch 12 | from unicore import checkpoint_utils, distributed_utils, options, utils 13 | from unicore.logging import progress_bar 14 | from unicore import tasks 15 | 16 | logging.basicConfig( 17 | format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 18 | datefmt="%Y-%m-%d %H:%M:%S", 19 | level=os.environ.get("LOGLEVEL", "INFO").upper(), 20 | stream=sys.stdout, 21 | ) 22 | logger = logging.getLogger("unimol.inference") 23 | 24 | 25 | def main(args): 26 | 27 | assert ( 28 | args.batch_size is not None 29 | ), "Must specify batch size either with --batch-size" 30 | 31 | use_fp16 = args.fp16 32 | use_cuda = torch.cuda.is_available() and not args.cpu 33 | 34 | if use_cuda: 35 | torch.cuda.set_device(args.device_id) 36 | 37 | if args.distributed_world_size > 1: 38 | data_parallel_world_size = distributed_utils.get_data_parallel_world_size() 39 | data_parallel_rank = distributed_utils.get_data_parallel_rank() 40 | else: 41 | data_parallel_world_size = 1 42 | data_parallel_rank = 0 43 | 44 | # Load model 45 | logger.info("loading model(s) from {}".format(args.path)) 46 | state = checkpoint_utils.load_checkpoint_to_cpu(args.path) 47 | task = tasks.setup_task(args) 48 | model = task.build_model(args) 49 | model.load_state_dict(state["model"], strict=False) 50 | 51 | # Move models to GPU 52 | if use_fp16: 53 | model.half() 54 | if use_cuda: 55 | model.cuda() 56 | 57 | # Print args 58 | logger.info(args) 59 | 60 | # Build loss 61 | loss = task.build_loss(args) 62 | loss.eval() 63 | 64 | for subset in args.valid_subset.split(","): 65 | try: 66 | task.load_dataset(subset, combine=False, epoch=1) 67 | dataset = task.dataset(subset) 68 | except KeyError: 69 | raise Exception("Cannot find dataset: " + subset) 70 | 71 | if not os.path.exists(args.results_path): 72 | os.makedirs(args.results_path) 73 | fname = (os.path.abspath(args.path)).split("/")[-2] 74 | save_path = os.path.join(args.results_path, fname + "_" + subset + ".out.pkl") 75 | # Initialize data iterator 76 | itr = task.get_batch_iterator( 77 | dataset=dataset, 78 | batch_size=args.batch_size, 79 | ignore_invalid_inputs=True, 80 | required_batch_size_multiple=args.required_batch_size_multiple, 81 | seed=args.seed, 82 | num_shards=data_parallel_world_size, 83 | shard_id=data_parallel_rank, 84 | num_workers=args.num_workers, 85 | data_buffer_size=args.data_buffer_size, 86 | ).next_epoch_itr(shuffle=False) 87 | progress = progress_bar.progress_bar( 88 | itr, 89 | log_format=args.log_format, 90 | log_interval=args.log_interval, 91 | prefix=f"valid on '{subset}' subset", 92 | default_log_format=("tqdm" if not args.no_progress_bar else "simple"), 93 | ) 94 | log_outputs = [] 95 | for i, sample in enumerate(progress): 96 | sample = utils.move_to_cuda(sample) if use_cuda else sample 97 | if len(sample) == 
0: 98 | continue 99 | _, _, log_output = task.valid_step(sample, model, loss, test=True) 100 | progress.log({}, step=i) 101 | log_outputs.append(log_output) 102 | pickle.dump(log_outputs, open(save_path, "wb")) 103 | logger.info("Done inference! ") 104 | return None 105 | 106 | 107 | def cli_main(): 108 | parser = options.get_validation_parser() 109 | options.add_model_args(parser) 110 | args = options.parse_args_and_arch(parser) 111 | 112 | distributed_utils.call_main(args, main) 113 | 114 | 115 | if __name__ == "__main__": 116 | cli_main() 117 | -------------------------------------------------------------------------------- /unimol/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import importlib 3 | 4 | # automatically import any Python files in the criterions/ directory 5 | for file in sorted(Path(__file__).parent.glob("*.py")): 6 | if not file.name.startswith("_"): 7 | importlib.import_module("unimol.losses." + file.name[:-3]) 8 | -------------------------------------------------------------------------------- /unimol/losses/mlm_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | import math 6 | import torch 7 | import torch.nn.functional as F 8 | import pandas as pd 9 | import numpy as np 10 | from unicore import metrics 11 | from unicore.losses import UnicoreLoss, register_loss 12 | 13 | 14 | @register_loss("pretrain_mlm") 15 | class PretrainMLMLoss(UnicoreLoss): 16 | def __init__(self, task): 17 | super().__init__(task) 18 | self.padding_idx = task.dictionary.pad() 19 | self.charge_padding_idx = task.charge_dictionary.pad() 20 | self.seed = task.seed 21 | self.dist_mean = 6.174412864984603 22 | self.dist_std = 216.17030997643033 23 | 24 | def forward(self, model, sample, reduce=True): 25 | """Compute the loss for the given sample. 
26 | 27 | Returns a tuple with three elements: 28 | 1) the loss 29 | 2) the sample size, which is used as the denominator for the gradient 30 | 3) logging outputs to display while training 31 | """ 32 | masked_tokens_a = sample["net_input_a"][-1].ne(self.padding_idx) 33 | all_output_a = model( 34 | sample["net_input_a"], 35 | classification_head_name=self.args.classification_head_name, 36 | encoder_masked_tokens=masked_tokens_a 37 | ) 38 | masked_tokens_b = sample["net_input_b"][-1].ne(self.padding_idx) 39 | all_output_b = model( 40 | sample["net_input_b"], 41 | classification_head_name=self.args.classification_head_name, 42 | encoder_masked_tokens=masked_tokens_b 43 | ) 44 | net_output_a, batch_a = all_output_a[:2] 45 | net_output_b, batch_b = all_output_b[:2] 46 | 47 | loss, predict = self.compute_loss(model, net_output_a, net_output_b, batch_a, batch_b, sample, reduce=reduce) 48 | sample_size = sample["target"]["finetune_target"].size(0) 49 | if not self.training: 50 | if self.task.mean and self.task.std: 51 | targets_mean = torch.tensor(self.task.mean, device=predict.device) 52 | targets_std = torch.tensor(self.task.std, device=predict.device) 53 | predict = predict * targets_std + targets_mean 54 | logging_output = { 55 | "pka_loss": loss.data, 56 | "predict": predict.view(-1, self.args.num_classes).data, 57 | "target": sample["target"]["finetune_target"] 58 | .view(-1, self.args.num_classes) 59 | .data, 60 | "smi_name": sample["id"], 61 | "sample_size": sample_size, 62 | "num_task": self.args.num_classes, 63 | "conf_size": self.args.conf_size, 64 | "bsz": sample["target"]["finetune_target"].size(0), 65 | } 66 | else: 67 | logging_output = { 68 | "pka_loss": loss.data, 69 | "sample_size": sample_size, 70 | "bsz": sample["target"]["finetune_target"].size(0), 71 | } 72 | 73 | loss, logging_output = self.compute_mlm_loss(loss, all_output_a, masked_tokens_a, logging_output, reduce= reduce) 74 | loss, logging_output = self.compute_mlm_loss(loss, all_output_b, masked_tokens_b, logging_output, reduce= reduce) 75 | logging_output['loss'] = loss.data 76 | 77 | return loss, sample_size, logging_output 78 | 79 | def compute_loss(self, model, net_output_a, net_output_b, batch_a, batch_b, sample, reduce=True): 80 | free_energy_a = net_output_a.view(-1, self.args.num_classes).float() 81 | free_energy_b = net_output_b.view(-1, self.args.num_classes).float() 82 | if not self.training: 83 | def compute_agg_free_energy(free_energy, batch): 84 | split_tensor_list = torch.split(free_energy, self.args.conf_size, dim=0) 85 | mean_tensor_list = [torch.mean(x, dim=0, keepdim=True) for x in split_tensor_list] 86 | agg_free_energy = torch.cat(mean_tensor_list, dim=0) 87 | agg_batch = [x//self.args.conf_size for x in batch] 88 | return agg_free_energy, agg_batch 89 | free_energy_a, batch_a = compute_agg_free_energy(free_energy_a, batch_a) 90 | free_energy_b, batch_b = compute_agg_free_energy(free_energy_b, batch_b) 91 | 92 | free_energy_a_padded = torch.nn.utils.rnn.pad_sequence( 93 | torch.split(free_energy_a, batch_a), 94 | padding_value=float("inf") 95 | ) 96 | free_energy_b_padded = torch.nn.utils.rnn.pad_sequence( 97 | torch.split(free_energy_b, batch_b), 98 | padding_value=float("inf") 99 | ) 100 | predicts = ( 101 | torch.logsumexp(-free_energy_a_padded, dim=0)- 102 | torch.logsumexp(-free_energy_b_padded, dim=0) 103 | ) / torch.log(torch.tensor([10.0])).item() 104 | 105 | targets = ( 106 | sample["target"]["finetune_target"].view(-1, self.args.num_classes).float() 107 | ) 108 | if self.task.mean and 
self.task.std: 109 | targets_mean = torch.tensor(self.task.mean, device=targets.device) 110 | targets_std = torch.tensor(self.task.std, device=targets.device) 111 | targets = (targets - targets_mean) / targets_std 112 | loss = F.mse_loss( 113 | predicts, 114 | targets, 115 | reduction="sum" if reduce else "none", 116 | ) 117 | return loss, predicts 118 | 119 | def compute_mlm_loss(self, loss, all_output, masked_tokens, logging_output, reduce= True): 120 | (_, _, 121 | logits_encoder, charge_logits, encoder_distance, encoder_coord, x_norm, delta_encoder_pair_rep_norm, 122 | token_targets, charge_targets, coord_targets, dist_targets) = all_output 123 | sample_size = masked_tokens.long().sum() 124 | 125 | if self.args.masked_token_loss > 0: 126 | target = token_targets 127 | if masked_tokens is not None: 128 | target = target[masked_tokens] 129 | masked_token_loss = F.nll_loss( 130 | F.log_softmax(logits_encoder, dim=-1, dtype=torch.float32), 131 | target, 132 | ignore_index=self.padding_idx, 133 | reduction="sum" if reduce else "none", 134 | ) 135 | masked_pred = logits_encoder.argmax(dim=-1) 136 | masked_hit = (masked_pred == target).long().sum() 137 | masked_cnt = sample_size 138 | if 'masked_token_loss' in logging_output: 139 | logging_output['seq_len'] += token_targets.size(1) * token_targets.size(0) 140 | logging_output['masked_token_loss'] += masked_token_loss.data 141 | logging_output['masked_token_hit'] += masked_hit.data 142 | logging_output['masked_token_cnt'] += masked_cnt 143 | else: 144 | logging_output['seq_len'] = token_targets.size(1) * token_targets.size(0) 145 | logging_output['masked_token_loss'] = masked_token_loss.data 146 | logging_output['masked_token_hit'] = masked_hit.data 147 | logging_output['masked_token_cnt'] = masked_cnt 148 | loss += masked_token_loss * self.args.masked_token_loss 149 | 150 | if self.args.masked_charge_loss > 0: 151 | target = charge_targets 152 | if masked_tokens is not None: 153 | target = target[masked_tokens] 154 | masked_charge_loss = F.nll_loss( 155 | F.log_softmax(charge_logits, dim=-1, dtype=torch.float32), 156 | target, 157 | ignore_index=self.charge_padding_idx, 158 | reduction="sum" if reduce else "none", 159 | ) 160 | masked_pred = charge_logits.argmax(dim=-1) 161 | masked_hit = (masked_pred == target).long().sum() 162 | masked_cnt = sample_size 163 | if 'masked_charge_loss' in logging_output: 164 | logging_output['masked_charge_loss'] += masked_charge_loss.data 165 | logging_output['masked_charge_hit'] += masked_hit.data 166 | logging_output['masked_charge_cnt'] += masked_cnt 167 | else: 168 | logging_output['masked_charge_loss'] = masked_charge_loss.data 169 | logging_output['masked_charge_hit'] = masked_hit.data 170 | logging_output['masked_charge_cnt'] = masked_cnt 171 | loss += masked_charge_loss * self.args.masked_charge_loss 172 | 173 | if self.args.masked_coord_loss > 0: 174 | # real = mask + delta 175 | masked_coord_loss = F.smooth_l1_loss( 176 | encoder_coord[masked_tokens].view(-1, 3).float(), 177 | coord_targets[masked_tokens].view(-1, 3), 178 | reduction="sum" if reduce else "none", 179 | beta=1.0, 180 | ) 181 | loss = loss + masked_coord_loss * self.args.masked_coord_loss 182 | # restore the scale of loss for displaying 183 | if 'masked_coord_loss' in logging_output: 184 | logging_output["masked_coord_loss"] += masked_coord_loss.data 185 | else: 186 | logging_output["masked_coord_loss"] = masked_coord_loss.data 187 | 188 | if self.args.masked_dist_loss > 0: 189 | dist_masked_tokens = masked_tokens 190 | 
masked_dist_loss = self.cal_dist_loss( 191 | encoder_distance, dist_masked_tokens, dist_targets, reduce=reduce, normalize=True, 192 | ) 193 | loss = loss + masked_dist_loss * self.args.masked_dist_loss 194 | if 'masked_dist_loss' in logging_output: 195 | logging_output["masked_dist_loss"] += masked_dist_loss.data 196 | else: 197 | logging_output["masked_dist_loss"] = masked_dist_loss.data 198 | 199 | if self.args.x_norm_loss > 0 and x_norm is not None: 200 | loss = loss + self.args.x_norm_loss * x_norm 201 | if 'x_norm_loss' in logging_output: 202 | logging_output["x_norm_loss"] += x_norm.data 203 | else: 204 | logging_output["x_norm_loss"] = x_norm.data 205 | 206 | if ( 207 | self.args.delta_pair_repr_norm_loss > 0 208 | and delta_encoder_pair_rep_norm is not None 209 | ): 210 | loss = ( 211 | loss + self.args.delta_pair_repr_norm_loss * delta_encoder_pair_rep_norm 212 | ) 213 | if 'delta_pair_repr_norm_loss' in logging_output: 214 | logging_output[ 215 | "delta_pair_repr_norm_loss" 216 | ] += delta_encoder_pair_rep_norm.data 217 | else: 218 | logging_output[ 219 | "delta_pair_repr_norm_loss" 220 | ] = delta_encoder_pair_rep_norm.data 221 | 222 | return loss, logging_output 223 | 224 | @staticmethod 225 | def reduce_metrics(logging_outputs, split="valid") -> None: 226 | """Aggregate logging outputs from data parallel training.""" 227 | loss_sum = sum(log.get("loss", 0) for log in logging_outputs) 228 | sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) 229 | if "valid" in split or "test" in split: 230 | sample_size *= logging_outputs[0].get("conf_size",0) 231 | bsz = sum(log.get("bsz", 0) for log in logging_outputs) 232 | seq_len = sum(log.get("seq_len", 0) for log in logging_outputs) 233 | # we divide by log(2) to convert the loss from base e to base 2 234 | metrics.log_scalar( 235 | "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 236 | ) 237 | metrics.log_scalar("seq_len", seq_len / bsz, 1, round=3) 238 | pka_loss = sum(log.get("pka_loss", 0) for log in logging_outputs) 239 | metrics.log_scalar( 240 | "pka_loss", pka_loss / sample_size, sample_size, round=3 241 | ) 242 | masked_token_loss = sum(log.get("masked_token_loss", 0) for log in logging_outputs) 243 | if masked_token_loss >0: 244 | metrics.log_scalar( 245 | "masked_token_loss", masked_token_loss / sample_size, sample_size, round=3 246 | ) 247 | masked_acc = sum( 248 | log.get("masked_token_hit", 0) for log in logging_outputs 249 | ) / sum(log.get("masked_token_cnt", 0) for log in logging_outputs) 250 | metrics.log_scalar("masked_token_acc", masked_acc, sample_size, round=3) 251 | 252 | masked_charge_loss = sum(log.get("masked_charge_loss", 0) for log in logging_outputs) 253 | if masked_charge_loss >0: 254 | metrics.log_scalar( 255 | "masked_charge_loss", masked_charge_loss / sample_size, sample_size, round=3 256 | ) 257 | masked_acc = sum( 258 | log.get("masked_charge_hit", 0) for log in logging_outputs 259 | ) / sum(log.get("masked_charge_cnt", 0) for log in logging_outputs) 260 | metrics.log_scalar("masked_charge_acc", masked_acc, sample_size, round=3) 261 | 262 | masked_coord_loss = sum( 263 | log.get("masked_coord_loss", 0) for log in logging_outputs 264 | ) 265 | if masked_coord_loss > 0: 266 | metrics.log_scalar( 267 | "masked_coord_loss", 268 | masked_coord_loss / sample_size, 269 | sample_size, 270 | round=3, 271 | ) 272 | 273 | masked_dist_loss = sum( 274 | log.get("masked_dist_loss", 0) for log in logging_outputs 275 | ) 276 | if masked_dist_loss > 0: 277 | metrics.log_scalar( 
278 | "masked_dist_loss", masked_dist_loss / sample_size, sample_size, round=3 279 | ) 280 | 281 | x_norm_loss = sum(log.get("x_norm_loss", 0) for log in logging_outputs) 282 | if x_norm_loss > 0: 283 | metrics.log_scalar( 284 | "x_norm_loss", x_norm_loss / sample_size, sample_size, round=3 285 | ) 286 | 287 | delta_pair_repr_norm_loss = sum( 288 | log.get("delta_pair_repr_norm_loss", 0) for log in logging_outputs 289 | ) 290 | if delta_pair_repr_norm_loss > 0: 291 | metrics.log_scalar( 292 | "delta_pair_repr_norm_loss", 293 | delta_pair_repr_norm_loss / sample_size, 294 | sample_size, 295 | round=3, 296 | ) 297 | 298 | if "valid" in split or "test" in split: 299 | sample_size //= logging_outputs[0].get("conf_size",0) 300 | predicts = torch.cat([log.get("predict") for log in logging_outputs], dim=0) 301 | if predicts.size(-1) == 1: 302 | # single label regression task, add aggregate acc and loss score 303 | targets = torch.cat( 304 | [log.get("target", 0) for log in logging_outputs], dim=0 305 | ) 306 | smi_list = [ 307 | item for log in logging_outputs for item in log.get("smi_name") 308 | ] 309 | df = pd.DataFrame( 310 | { 311 | "predict": predicts.view(-1).cpu(), 312 | "target": targets.view(-1).cpu(), 313 | "smi": smi_list, 314 | } 315 | ) 316 | mae = np.abs(df["predict"] - df["target"]).mean() 317 | mse = ((df["predict"] - df["target"]) ** 2).mean() 318 | 319 | metrics.log_scalar(f"{split}_mae", mae, sample_size, round=3) 320 | metrics.log_scalar(f"{split}_mse", mse, sample_size, round=3) 321 | metrics.log_scalar( 322 | f"{split}_rmse", np.sqrt(mse), sample_size, round=4 323 | ) 324 | 325 | @staticmethod 326 | def logging_outputs_can_be_summed(is_train) -> bool: 327 | """ 328 | Whether the logging outputs returned by `forward` can be summed 329 | across workers prior to calling `reduce_metrics`. Setting this 330 | to True will improves distributed training speed. 331 | """ 332 | return is_train 333 | 334 | def cal_dist_loss(self, dist, dist_masked_tokens, dist_targets, reduce= True, normalize=False): 335 | masked_distance = dist[dist_masked_tokens, :] 336 | masked_distance_target = dist_targets[dist_masked_tokens] 337 | non_pad_pos = masked_distance_target > 0 338 | if normalize: 339 | masked_distance_target = ( 340 | masked_distance_target.float() - self.dist_mean 341 | ) / self.dist_std 342 | masked_dist_loss = F.smooth_l1_loss( 343 | masked_distance[non_pad_pos].view(-1).float(), 344 | masked_distance_target[non_pad_pos].view(-1), 345 | reduction="sum" if reduce else "none", 346 | beta=1.0, 347 | ) 348 | return masked_dist_loss 349 | -------------------------------------------------------------------------------- /unimol/losses/reg_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | import math 6 | import torch 7 | import torch.nn.functional as F 8 | import pandas as pd 9 | import numpy as np 10 | from unicore import metrics 11 | from unicore.losses import UnicoreLoss, register_loss 12 | 13 | 14 | @register_loss("finetune_mse") 15 | class FinetuneMSELoss(UnicoreLoss): 16 | def __init__(self, task): 17 | super().__init__(task) 18 | 19 | def forward(self, model, sample, reduce=True): 20 | """Compute the loss for the given sample. 
21 | 22 | Returns a tuple with three elements: 23 | 1) the loss 24 | 2) the sample size, which is used as the denominator for the gradient 25 | 3) logging outputs to display while training 26 | """ 27 | net_output_a, batch_a = model( 28 | sample["net_input_a"], 29 | classification_head_name=self.args.classification_head_name, 30 | features_only = True, 31 | ) 32 | net_output_b, batch_b = model( 33 | sample["net_input_b"], 34 | classification_head_name=self.args.classification_head_name, 35 | features_only = True, 36 | ) 37 | 38 | loss, predict = self.compute_loss(model, net_output_a, net_output_b, batch_a, batch_b, sample, reduce=reduce) 39 | sample_size = sample["target"]["finetune_target"].size(0) 40 | if not self.training: 41 | if self.task.mean and self.task.std: 42 | targets_mean = torch.tensor(self.task.mean, device=predict.device) 43 | targets_std = torch.tensor(self.task.std, device=predict.device) 44 | predict = predict * targets_std + targets_mean 45 | logging_output = { 46 | "loss": loss.data, 47 | "predict": predict.view(-1, self.args.num_classes).data, 48 | "target": sample["target"]["finetune_target"] 49 | .view(-1, self.args.num_classes) 50 | .data, 51 | "smi_name": sample["id"], 52 | "sample_size": sample_size, 53 | "num_task": self.args.num_classes, 54 | "conf_size": self.args.conf_size, 55 | "bsz": sample["target"]["finetune_target"].size(0), 56 | } 57 | else: 58 | logging_output = { 59 | "loss": loss.data, 60 | "sample_size": sample_size, 61 | "bsz": sample["target"]["finetune_target"].size(0), 62 | } 63 | return loss, sample_size, logging_output 64 | 65 | def compute_loss(self, model, net_output_a, net_output_b, batch_a, batch_b, sample, reduce=True): 66 | free_energy_a = net_output_a.view(-1, self.args.num_classes).float() 67 | free_energy_b = net_output_b.view(-1, self.args.num_classes).float() 68 | if not self.training: 69 | def compute_agg_free_energy(free_energy, batch): 70 | split_tensor_list = torch.split(free_energy, self.args.conf_size, dim=0) 71 | mean_tensor_list = [torch.mean(x, dim=0, keepdim=True) for x in split_tensor_list] 72 | agg_free_energy = torch.cat(mean_tensor_list, dim=0) 73 | agg_batch = [x//self.args.conf_size for x in batch] 74 | return agg_free_energy, agg_batch 75 | free_energy_a, batch_a = compute_agg_free_energy(free_energy_a, batch_a) 76 | free_energy_b, batch_b = compute_agg_free_energy(free_energy_b, batch_b) 77 | 78 | free_energy_a_padded = torch.nn.utils.rnn.pad_sequence( 79 | torch.split(free_energy_a, batch_a), 80 | padding_value=float("inf") 81 | ) 82 | free_energy_b_padded = torch.nn.utils.rnn.pad_sequence( 83 | torch.split(free_energy_b, batch_b), 84 | padding_value=float("inf") 85 | ) 86 | predicts = ( 87 | torch.logsumexp(-free_energy_a_padded, dim=0)- 88 | torch.logsumexp(-free_energy_b_padded, dim=0) 89 | ) / torch.log(torch.tensor([10.0])).item() 90 | 91 | targets = ( 92 | sample["target"]["finetune_target"].view(-1, self.args.num_classes).float() 93 | ) 94 | if self.task.mean and self.task.std: 95 | targets_mean = torch.tensor(self.task.mean, device=targets.device) 96 | targets_std = torch.tensor(self.task.std, device=targets.device) 97 | targets = (targets - targets_mean) / targets_std 98 | loss = F.mse_loss( 99 | predicts, 100 | targets, 101 | reduction="sum" if reduce else "none", 102 | ) 103 | return loss, predicts 104 | 105 | @staticmethod 106 | def reduce_metrics(logging_outputs, split="valid") -> None: 107 | """Aggregate logging outputs from data parallel training.""" 108 | loss_sum = sum(log.get("loss", 0) for log 
in logging_outputs) 109 | sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) 110 | # we divide by log(2) to convert the loss from base e to base 2 111 | metrics.log_scalar( 112 | "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 113 | ) 114 | if "valid" in split or "test" in split: 115 | predicts = torch.cat([log.get("predict") for log in logging_outputs], dim=0) 116 | if predicts.size(-1) == 1: 117 | # single label regression task, add aggregate acc and loss score 118 | targets = torch.cat( 119 | [log.get("target", 0) for log in logging_outputs], dim=0 120 | ) 121 | smi_list = [ 122 | item for log in logging_outputs for item in log.get("smi_name") 123 | ] 124 | df = pd.DataFrame( 125 | { 126 | "predict": predicts.view(-1).cpu(), 127 | "target": targets.view(-1).cpu(), 128 | "smi": smi_list, 129 | } 130 | ) 131 | mae = np.abs(df["predict"] - df["target"]).mean() 132 | mse = ((df["predict"] - df["target"]) ** 2).mean() 133 | metrics.log_scalar(f"{split}_mae", mae, sample_size, round=3) 134 | metrics.log_scalar(f"{split}_mse", mse, sample_size, round=3) 135 | metrics.log_scalar( 136 | f"{split}_rmse", np.sqrt(mse), sample_size, round=4 137 | ) 138 | 139 | @staticmethod 140 | def logging_outputs_can_be_summed(is_train) -> bool: 141 | """ 142 | Whether the logging outputs returned by `forward` can be summed 143 | across workers prior to calling `reduce_metrics`. Setting this 144 | to True will improves distributed training speed. 145 | """ 146 | return is_train 147 | 148 | 149 | @register_loss("infer_free_energy") 150 | class InferFreeEnergyLoss(UnicoreLoss): 151 | def __init__(self, task): 152 | super().__init__(task) 153 | 154 | def forward(self, model, sample, reduce=True): 155 | """Compute the loss for the given sample. 156 | 157 | Returns a tuple with three elements: 158 | 1) the loss 159 | 2) the sample size, which is used as the denominator for the gradient 160 | 3) logging outputs to display while training 161 | """ 162 | net_output = model( 163 | **sample["net_input"], 164 | classification_head_name=self.args.classification_head_name, 165 | features_only=True, 166 | ) 167 | reg_output = net_output[0] 168 | loss = torch.tensor([0.01], device=sample["target"]["finetune_target"].device) 169 | sample_size = sample["target"]["finetune_target"].size(0) 170 | if not self.training: 171 | logging_output = { 172 | "loss": loss.data, 173 | "predict": reg_output.view(-1, self.args.num_classes).data, 174 | "target": sample["target"]["finetune_target"] 175 | .view(-1, self.args.num_classes) 176 | .data, 177 | "smi_name": sample["smi_name"], 178 | "sample_size": sample_size, 179 | "num_task": self.args.num_classes, 180 | "conf_size": self.args.conf_size, 181 | "bsz": sample["target"]["finetune_target"].size(0), 182 | } 183 | return loss, sample_size, logging_output 184 | 185 | @staticmethod 186 | def reduce_metrics(logging_outputs, split="valid") -> None: 187 | """Aggregate logging outputs from data parallel training.""" 188 | loss_sum = sum(log.get("loss", 0) for log in logging_outputs) 189 | sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) 190 | # we divide by log(2) to convert the loss from base e to base 2 191 | metrics.log_scalar( 192 | "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 193 | ) 194 | 195 | @staticmethod 196 | def logging_outputs_can_be_summed(is_train) -> bool: 197 | """ 198 | Whether the logging outputs returned by `forward` can be summed 199 | across workers prior to calling `reduce_metrics`. 
Setting this 200 | to True will improves distributed training speed. 201 | """ 202 | return is_train 203 | -------------------------------------------------------------------------------- /unimol/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .unimol_pka import UniMolPKAModel 2 | from .transformer_encoder_with_pair import TransformerEncoderWithPair 3 | from .unimol import UniMolModel -------------------------------------------------------------------------------- /unimol/models/transformer_encoder_with_pair.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | from typing import Optional 6 | 7 | import math 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from unicore.modules import TransformerEncoderLayer, LayerNorm 12 | 13 | 14 | class TransformerEncoderWithPair(nn.Module): 15 | def __init__( 16 | self, 17 | encoder_layers: int = 6, 18 | embed_dim: int = 768, 19 | ffn_embed_dim: int = 3072, 20 | attention_heads: int = 8, 21 | emb_dropout: float = 0.1, 22 | dropout: float = 0.1, 23 | attention_dropout: float = 0.1, 24 | activation_dropout: float = 0.0, 25 | max_seq_len: int = 256, 26 | activation_fn: str = "gelu", 27 | post_ln: bool = False, 28 | no_final_head_layer_norm: bool = False, 29 | ) -> None: 30 | 31 | super().__init__() 32 | self.emb_dropout = emb_dropout 33 | self.max_seq_len = max_seq_len 34 | self.embed_dim = embed_dim 35 | self.attention_heads = attention_heads 36 | self.emb_layer_norm = LayerNorm(self.embed_dim) 37 | if not post_ln: 38 | self.final_layer_norm = LayerNorm(self.embed_dim) 39 | else: 40 | self.final_layer_norm = None 41 | 42 | if not no_final_head_layer_norm: 43 | self.final_head_layer_norm = LayerNorm(attention_heads) 44 | else: 45 | self.final_head_layer_norm = None 46 | 47 | self.layers = nn.ModuleList( 48 | [ 49 | TransformerEncoderLayer( 50 | embed_dim=self.embed_dim, 51 | ffn_embed_dim=ffn_embed_dim, 52 | attention_heads=attention_heads, 53 | dropout=dropout, 54 | attention_dropout=attention_dropout, 55 | activation_dropout=activation_dropout, 56 | activation_fn=activation_fn, 57 | post_ln=post_ln, 58 | ) 59 | for _ in range(encoder_layers) 60 | ] 61 | ) 62 | 63 | def forward( 64 | self, 65 | emb: torch.Tensor, 66 | attn_mask: Optional[torch.Tensor] = None, 67 | padding_mask: Optional[torch.Tensor] = None, 68 | ) -> torch.Tensor: 69 | 70 | bsz = emb.size(0) 71 | seq_len = emb.size(1) 72 | x = self.emb_layer_norm(emb) 73 | x = F.dropout(x, p=self.emb_dropout, training=self.training) 74 | 75 | # account for padding while computing the representation 76 | if padding_mask is not None: 77 | x = x * (1 - padding_mask.unsqueeze(-1).type_as(x)) 78 | input_attn_mask = attn_mask 79 | input_padding_mask = padding_mask 80 | 81 | def fill_attn_mask(attn_mask, padding_mask, fill_val=float("-inf")): 82 | if attn_mask is not None and padding_mask is not None: 83 | # merge key_padding_mask and attn_mask 84 | attn_mask = attn_mask.view(x.size(0), -1, seq_len, seq_len) 85 | attn_mask.masked_fill_( 86 | padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), 87 | fill_val, 88 | ) 89 | attn_mask = attn_mask.view(-1, seq_len, seq_len) 90 | padding_mask = None 91 | return attn_mask, padding_mask 92 | 93 | assert attn_mask is not None 94 | attn_mask, padding_mask = 
fill_attn_mask(attn_mask, padding_mask) 95 | 96 | for i in range(len(self.layers)): 97 | x, attn_mask, _ = self.layers[i]( 98 | x, padding_mask=padding_mask, attn_bias=attn_mask, return_attn=True 99 | ) 100 | 101 | def norm_loss(x, eps=1e-10, tolerance=1.0): 102 | x = x.float() 103 | max_norm = x.shape[-1] ** 0.5 104 | norm = torch.sqrt(torch.sum(x**2, dim=-1) + eps) 105 | error = torch.nn.functional.relu((norm - max_norm).abs() - tolerance) 106 | return error 107 | 108 | def masked_mean(mask, value, dim=-1, eps=1e-10): 109 | return ( 110 | torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim)) 111 | ).mean() 112 | 113 | x_norm = norm_loss(x) 114 | if input_padding_mask is not None: 115 | token_mask = 1.0 - input_padding_mask.float() 116 | else: 117 | token_mask = torch.ones_like(x_norm, device=x_norm.device) 118 | x_norm = masked_mean(token_mask, x_norm) 119 | 120 | if self.final_layer_norm is not None: 121 | x = self.final_layer_norm(x) 122 | 123 | delta_pair_repr = attn_mask - input_attn_mask 124 | delta_pair_repr, _ = fill_attn_mask(delta_pair_repr, input_padding_mask, 0) 125 | attn_mask = ( 126 | attn_mask.view(bsz, -1, seq_len, seq_len).permute(0, 2, 3, 1).contiguous() 127 | ) 128 | delta_pair_repr = ( 129 | delta_pair_repr.view(bsz, -1, seq_len, seq_len) 130 | .permute(0, 2, 3, 1) 131 | .contiguous() 132 | ) 133 | 134 | pair_mask = token_mask[..., None] * token_mask[..., None, :] 135 | delta_pair_repr_norm = norm_loss(delta_pair_repr) 136 | delta_pair_repr_norm = masked_mean( 137 | pair_mask, delta_pair_repr_norm, dim=(-1, -2) 138 | ) 139 | 140 | if self.final_head_layer_norm is not None: 141 | delta_pair_repr = self.final_head_layer_norm(delta_pair_repr) 142 | 143 | return x, attn_mask, delta_pair_repr, x_norm, delta_pair_repr_norm 144 | -------------------------------------------------------------------------------- /unimol/models/unimol.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 
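#
# Backbone encoder; UniMolPKAModel in unimol_pka.py wraps an instance of it.
# UniMolModel (registered as "unimol") embeds atom tokens and formal-charge
# tokens, sums the two embeddings, turns pairwise distances into an attention
# bias via a Gaussian basis (GaussianLayer followed by NonLinearHead), and
# runs TransformerEncoderWithPair on top. Heads added with
# register_classification_head() read the first ([CLS]) position.
#
# Rough shape sketch (assuming batch size B, N atoms, H attention heads,
# K = 128 Gaussian kernels):
#   src_tokens, src_charges    : (B, N)      -> summed embeddings (B, N, embed_dim)
#   src_distance, src_edge_type: (B, N, N)   -> gbf features      (B, N, N, K)
#   gbf_proj output, permuted  : (B, H, N, N) -> attention bias   (B*H, N, N)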
4 | 5 | import logging 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from unicore import utils 10 | from unicore.models import BaseUnicoreModel, register_model, register_model_architecture 11 | from unicore.modules import LayerNorm, init_bert_params 12 | from .transformer_encoder_with_pair import TransformerEncoderWithPair 13 | from typing import Dict, Any, List 14 | 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | @register_model("unimol") 20 | class UniMolModel(BaseUnicoreModel): 21 | @staticmethod 22 | def add_args(parser): 23 | """Add model-specific arguments to the parser.""" 24 | parser.add_argument( 25 | "--encoder-layers", type=int, metavar="L", help="num encoder layers" 26 | ) 27 | parser.add_argument( 28 | "--encoder-embed-dim", 29 | type=int, 30 | metavar="H", 31 | help="encoder embedding dimension", 32 | ) 33 | parser.add_argument( 34 | "--encoder-ffn-embed-dim", 35 | type=int, 36 | metavar="F", 37 | help="encoder embedding dimension for FFN", 38 | ) 39 | parser.add_argument( 40 | "--encoder-attention-heads", 41 | type=int, 42 | metavar="A", 43 | help="num encoder attention heads", 44 | ) 45 | parser.add_argument( 46 | "--activation-fn", 47 | choices=utils.get_available_activation_fns(), 48 | help="activation function to use", 49 | ) 50 | parser.add_argument( 51 | "--pooler-activation-fn", 52 | choices=utils.get_available_activation_fns(), 53 | help="activation function to use for pooler layer", 54 | ) 55 | parser.add_argument( 56 | "--emb-dropout", 57 | type=float, 58 | metavar="D", 59 | help="dropout probability for embeddings", 60 | ) 61 | parser.add_argument( 62 | "--dropout", type=float, metavar="D", help="dropout probability" 63 | ) 64 | parser.add_argument( 65 | "--attention-dropout", 66 | type=float, 67 | metavar="D", 68 | help="dropout probability for attention weights", 69 | ) 70 | parser.add_argument( 71 | "--activation-dropout", 72 | type=float, 73 | metavar="D", 74 | help="dropout probability after activation in FFN", 75 | ) 76 | parser.add_argument( 77 | "--pooler-dropout", 78 | type=float, 79 | metavar="D", 80 | help="dropout probability in the masked_lm pooler layers", 81 | ) 82 | parser.add_argument( 83 | "--max-seq-len", type=int, help="number of positional embeddings to learn" 84 | ) 85 | parser.add_argument( 86 | "--post-ln", type=bool, help="use post layernorm or pre layernorm" 87 | ) 88 | parser.add_argument( 89 | "--x-norm-loss", 90 | type=float, 91 | metavar="D", 92 | help="x norm loss ratio", 93 | ) 94 | parser.add_argument( 95 | "--delta-pair-repr-norm-loss", 96 | type=float, 97 | metavar="D", 98 | help="delta encoder pair repr norm loss ratio", 99 | ) 100 | parser.add_argument( 101 | "--mode", 102 | type=str, 103 | default="train", 104 | choices=["train", "infer"], 105 | ) 106 | 107 | def __init__(self, args, dictionary, charge_dictionary): 108 | super().__init__() 109 | base_architecture(args) 110 | self.args = args 111 | self.padding_idx = dictionary.pad() 112 | self.charge_padding_idx = charge_dictionary.pad() 113 | self.embed_tokens = nn.Embedding( 114 | len(dictionary), args.encoder_embed_dim, self.padding_idx 115 | ) 116 | self.embed_charges = nn.Embedding( 117 | len(charge_dictionary), args.encoder_embed_dim, self.charge_padding_idx 118 | ) 119 | self._num_updates = None 120 | self.encoder = TransformerEncoderWithPair( 121 | encoder_layers=args.encoder_layers, 122 | embed_dim=args.encoder_embed_dim, 123 | ffn_embed_dim=args.encoder_ffn_embed_dim, 124 | 
attention_heads=args.encoder_attention_heads, 125 | emb_dropout=args.emb_dropout, 126 | dropout=args.dropout, 127 | attention_dropout=args.attention_dropout, 128 | activation_dropout=args.activation_dropout, 129 | max_seq_len=args.max_seq_len, 130 | activation_fn=args.activation_fn, 131 | no_final_head_layer_norm=args.delta_pair_repr_norm_loss < 0, 132 | ) 133 | 134 | K = 128 135 | n_edge_type = len(dictionary) * len(dictionary) 136 | self.gbf_proj = NonLinearHead( 137 | K, args.encoder_attention_heads, args.activation_fn 138 | ) 139 | self.gbf = GaussianLayer(K, n_edge_type) 140 | 141 | self.classification_heads = nn.ModuleDict() 142 | self.apply(init_bert_params) 143 | 144 | @classmethod 145 | def build_model(cls, args, task): 146 | """Build a new model instance.""" 147 | return cls(args, task.dictionary, task.charge_dictionary) 148 | 149 | def forward( 150 | self, 151 | src_tokens, 152 | src_charges, 153 | src_distance, 154 | src_coord, 155 | src_edge_type, 156 | encoder_masked_tokens=None, 157 | classification_head_name=None, 158 | **kwargs 159 | ): 160 | 161 | padding_mask = src_tokens.eq(self.padding_idx) 162 | if not padding_mask.any(): 163 | padding_mask = None 164 | x = self.embed_tokens(src_tokens) 165 | 166 | charge_padding_mask = src_charges.eq(self.charge_padding_idx) 167 | if not charge_padding_mask.any(): 168 | padding_mask = None 169 | charges_emb = self.embed_charges(src_charges) 170 | # involve charge info 171 | x += charges_emb 172 | 173 | def get_dist_features(dist, et): 174 | n_node = dist.size(-1) 175 | gbf_feature = self.gbf(dist, et) 176 | gbf_result = self.gbf_proj(gbf_feature) 177 | graph_attn_bias = gbf_result 178 | graph_attn_bias = graph_attn_bias.permute(0, 3, 1, 2).contiguous() 179 | graph_attn_bias = graph_attn_bias.view(-1, n_node, n_node) 180 | return graph_attn_bias 181 | 182 | graph_attn_bias = get_dist_features(src_distance, src_edge_type) 183 | ( 184 | encoder_rep, 185 | encoder_pair_rep, 186 | delta_encoder_pair_rep, 187 | x_norm, 188 | delta_encoder_pair_rep_norm, 189 | ) = self.encoder(x, padding_mask=padding_mask, attn_mask=graph_attn_bias) 190 | encoder_pair_rep[encoder_pair_rep == float("-inf")] = 0 191 | 192 | if classification_head_name is not None: 193 | logits = self.classification_heads[classification_head_name](encoder_rep) 194 | if self.args.mode == 'infer': 195 | return encoder_rep, encoder_pair_rep 196 | else: 197 | return ( 198 | logits, 199 | x_norm, 200 | delta_encoder_pair_rep_norm, 201 | ) 202 | 203 | def register_classification_head( 204 | self, name, num_classes=None, inner_dim=None, **kwargs 205 | ): 206 | """Register a classification head.""" 207 | if name in self.classification_heads: 208 | prev_num_classes = self.classification_heads[name].out_proj.out_features 209 | prev_inner_dim = self.classification_heads[name].dense.out_features 210 | if num_classes != prev_num_classes or inner_dim != prev_inner_dim: 211 | logger.warning( 212 | 're-registering head "{}" with num_classes {} (prev: {}) ' 213 | "and inner_dim {} (prev: {})".format( 214 | name, num_classes, prev_num_classes, inner_dim, prev_inner_dim 215 | ) 216 | ) 217 | self.classification_heads[name] = ClassificationHead( 218 | input_dim=self.args.encoder_embed_dim, 219 | inner_dim=inner_dim or self.args.encoder_embed_dim, 220 | num_classes=num_classes, 221 | activation_fn=self.args.pooler_activation_fn, 222 | pooler_dropout=self.args.pooler_dropout, 223 | ) 224 | 225 | def set_num_updates(self, num_updates): 226 | """State from trainer to pass along to model at every 
update.""" 227 | self._num_updates = num_updates 228 | 229 | def get_num_updates(self): 230 | return self._num_updates 231 | 232 | 233 | class ClassificationHead(nn.Module): 234 | """Head for sentence-level classification tasks.""" 235 | 236 | def __init__( 237 | self, 238 | input_dim, 239 | inner_dim, 240 | num_classes, 241 | activation_fn, 242 | pooler_dropout, 243 | ): 244 | super().__init__() 245 | self.dense = nn.Linear(input_dim, inner_dim) 246 | self.activation_fn = utils.get_activation_fn(activation_fn) 247 | self.dropout = nn.Dropout(p=pooler_dropout) 248 | self.out_proj = nn.Linear(inner_dim, num_classes) 249 | 250 | def forward(self, features, **kwargs): 251 | x = features[:, 0, :] # take token (equiv. to [CLS]) 252 | x = self.dropout(x) 253 | x = self.dense(x) 254 | x = self.activation_fn(x) 255 | x = self.dropout(x) 256 | x = self.out_proj(x) 257 | return x 258 | 259 | 260 | class NonLinearHead(nn.Module): 261 | """Head for simple classification tasks.""" 262 | 263 | def __init__( 264 | self, 265 | input_dim, 266 | out_dim, 267 | activation_fn, 268 | hidden=None, 269 | ): 270 | super().__init__() 271 | hidden = input_dim if not hidden else hidden 272 | self.linear1 = nn.Linear(input_dim, hidden) 273 | self.linear2 = nn.Linear(hidden, out_dim) 274 | self.activation_fn = utils.get_activation_fn(activation_fn) 275 | 276 | def forward(self, x): 277 | x = self.linear1(x) 278 | x = self.activation_fn(x) 279 | x = self.linear2(x) 280 | return x 281 | 282 | 283 | @torch.jit.script 284 | def gaussian(x, mean, std): 285 | pi = 3.14159 286 | a = (2 * pi) ** 0.5 287 | return torch.exp(-0.5 * (((x - mean) / std) ** 2)) / (a * std) 288 | 289 | 290 | class GaussianLayer(nn.Module): 291 | def __init__(self, K=128, edge_types=1024): 292 | super().__init__() 293 | self.K = K 294 | self.means = nn.Embedding(1, K) 295 | self.stds = nn.Embedding(1, K) 296 | self.mul = nn.Embedding(edge_types, 1) 297 | self.bias = nn.Embedding(edge_types, 1) 298 | nn.init.uniform_(self.means.weight, 0, 3) 299 | nn.init.uniform_(self.stds.weight, 0, 3) 300 | nn.init.constant_(self.bias.weight, 0) 301 | nn.init.constant_(self.mul.weight, 1) 302 | 303 | def forward(self, x, edge_type): 304 | mul = self.mul(edge_type).type_as(x) 305 | bias = self.bias(edge_type).type_as(x) 306 | x = mul * x.unsqueeze(-1) + bias 307 | x = x.expand(-1, -1, -1, self.K) 308 | mean = self.means.weight.float().view(-1) 309 | std = self.stds.weight.float().view(-1).abs() + 1e-5 310 | return gaussian(x.float(), mean, std).type_as(self.means.weight) 311 | 312 | 313 | @register_model_architecture("unimol", "unimol") 314 | def base_architecture(args): 315 | args.encoder_layers = getattr(args, "encoder_layers", 15) 316 | args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) 317 | args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) 318 | args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 64) 319 | args.dropout = getattr(args, "dropout", 0.1) 320 | args.emb_dropout = getattr(args, "emb_dropout", 0.1) 321 | args.attention_dropout = getattr(args, "attention_dropout", 0.1) 322 | args.activation_dropout = getattr(args, "activation_dropout", 0.0) 323 | args.pooler_dropout = getattr(args, "pooler_dropout", 0.0) 324 | args.max_seq_len = getattr(args, "max_seq_len", 512) 325 | args.activation_fn = getattr(args, "activation_fn", "gelu") 326 | args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh") 327 | args.post_ln = getattr(args, "post_ln", False) 328 | args.x_norm_loss = getattr(args, 
"x_norm_loss", -1.0) 329 | args.delta_pair_repr_norm_loss = getattr(args, "delta_pair_repr_norm_loss", -1.0) 330 | 331 | 332 | @register_model_architecture("unimol", "unimol_base") 333 | def unimol_base_architecture(args): 334 | base_architecture(args) 335 | -------------------------------------------------------------------------------- /unimol/models/unimol_pka.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | import logging 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from unicore import utils 10 | from unicore.models import BaseUnicoreModel, register_model, register_model_architecture 11 | from unicore.modules import LayerNorm, init_bert_params 12 | from .unimol import UniMolModel, ClassificationHead, NonLinearHead 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | @register_model("unimol_pka") 18 | class UniMolPKAModel(BaseUnicoreModel): 19 | @staticmethod 20 | def add_args(parser): 21 | """Add model-specific arguments to the parser.""" 22 | parser.add_argument( 23 | "--masked-token-loss", 24 | type=float, 25 | metavar="D", 26 | help="mask loss ratio", 27 | ) 28 | parser.add_argument( 29 | "--masked-charge-loss", 30 | type=float, 31 | metavar="D", 32 | help="mask charge loss ratio", 33 | ) 34 | parser.add_argument( 35 | "--masked-dist-loss", 36 | type=float, 37 | metavar="D", 38 | help="masked distance loss ratio", 39 | ) 40 | parser.add_argument( 41 | "--masked-coord-loss", 42 | type=float, 43 | metavar="D", 44 | help="masked coord loss ratio", 45 | ) 46 | parser.add_argument( 47 | "--x-norm-loss", 48 | type=float, 49 | metavar="D", 50 | help="x norm loss ratio", 51 | ) 52 | parser.add_argument( 53 | "--delta-pair-repr-norm-loss", 54 | type=float, 55 | metavar="D", 56 | help="delta encoder pair repr norm loss ratio", 57 | ) 58 | parser.add_argument( 59 | "--pooler-dropout", 60 | type=float, 61 | metavar="D", 62 | help="dropout probability in the masked_lm pooler layers", 63 | ) 64 | 65 | def __init__(self, args, dictionary, charge_dictionary): 66 | super().__init__() 67 | unimol_pka_architecture(args) 68 | self.args = args 69 | self.unimol = UniMolModel(self.args, dictionary, charge_dictionary) 70 | if args.masked_token_loss > 0: 71 | self.lm_head = MaskLMHead( 72 | embed_dim=args.encoder_embed_dim, 73 | output_dim=len(dictionary), 74 | activation_fn=args.activation_fn, 75 | weight=None, 76 | ) 77 | if args.masked_charge_loss > 0: 78 | self.charge_lm_head = MaskLMHead( 79 | embed_dim=args.encoder_embed_dim, 80 | output_dim=len(charge_dictionary), 81 | activation_fn=args.activation_fn, 82 | weight=None, 83 | ) 84 | if args.masked_coord_loss > 0: 85 | self.pair2coord_proj = NonLinearHead( 86 | args.encoder_attention_heads, 1, args.activation_fn 87 | ) 88 | if args.masked_dist_loss > 0: 89 | self.dist_head = DistanceHead( 90 | args.encoder_attention_heads, args.activation_fn 91 | ) 92 | 93 | @classmethod 94 | def build_model(cls, args, task): 95 | """Build a new model instance.""" 96 | return cls(args, task.dictionary, task.charge_dictionary) 97 | 98 | def forward( 99 | self, 100 | input_metadata, 101 | classification_head_name=None, 102 | encoder_masked_tokens=None, 103 | features_only=False, 104 | **kwargs 105 | ): 106 | if not features_only: 107 | src_tokens, src_charges, src_coord, src_distance, src_edge_type, batch, 
charge_targets, coord_targets, dist_targets, token_targets = input_metadata 108 | else: 109 | src_tokens, src_charges, src_coord, src_distance, src_edge_type, batch = input_metadata 110 | charge_targets, coord_targets, dist_targets, token_targets = None, None, None, None 111 | padding_mask = src_tokens.eq(self.unimol.padding_idx) 112 | if not padding_mask.any(): 113 | padding_mask = None 114 | x = self.unimol.embed_tokens(src_tokens) 115 | 116 | charge_padding_mask = src_charges.eq(self.unimol.charge_padding_idx) 117 | if not charge_padding_mask.any(): 118 | padding_mask = None 119 | charges_emb = self.unimol.embed_charges(src_charges) 120 | # involve charge info 121 | x += charges_emb 122 | 123 | def get_dist_features(dist, et): 124 | n_node = dist.size(-1) 125 | gbf_feature = self.unimol.gbf(dist, et) 126 | gbf_result = self.unimol.gbf_proj(gbf_feature) 127 | graph_attn_bias = gbf_result 128 | graph_attn_bias = graph_attn_bias.permute(0, 3, 1, 2).contiguous() 129 | graph_attn_bias = graph_attn_bias.view(-1, n_node, n_node) 130 | return graph_attn_bias 131 | 132 | graph_attn_bias = get_dist_features(src_distance, src_edge_type) 133 | ( 134 | encoder_rep, 135 | encoder_pair_rep, 136 | delta_encoder_pair_rep, 137 | x_norm, 138 | delta_encoder_pair_rep_norm, 139 | ) = self.unimol.encoder(x, padding_mask=padding_mask, attn_mask=graph_attn_bias) 140 | encoder_pair_rep[encoder_pair_rep == float("-inf")] = 0 141 | 142 | encoder_distance = None 143 | encoder_coord = None 144 | 145 | if not features_only: 146 | if self.args.masked_token_loss > 0: 147 | logits = self.lm_head(encoder_rep, encoder_masked_tokens) 148 | if self.args.masked_charge_loss > 0: 149 | charge_logits = self.charge_lm_head(encoder_rep, encoder_masked_tokens) 150 | if self.args.masked_coord_loss > 0: 151 | coords_emb = src_coord 152 | if padding_mask is not None: 153 | atom_num = (torch.sum(1 - padding_mask.type_as(x), dim=1) - 1).view( 154 | -1, 1, 1, 1 155 | ) 156 | else: 157 | atom_num = src_coord.shape[1] - 1 158 | delta_pos = coords_emb.unsqueeze(1) - coords_emb.unsqueeze(2) 159 | attn_probs = self.pair2coord_proj(delta_encoder_pair_rep) 160 | coord_update = delta_pos / atom_num * attn_probs 161 | coord_update = torch.sum(coord_update, dim=2) 162 | encoder_coord = coords_emb + coord_update 163 | if self.args.masked_dist_loss > 0: 164 | encoder_distance = self.dist_head(encoder_pair_rep) 165 | 166 | if classification_head_name is not None: 167 | cls_logits = self.unimol.classification_heads[classification_head_name](encoder_rep) 168 | if not features_only: 169 | return ( 170 | cls_logits, batch, 171 | logits, charge_logits, encoder_distance, encoder_coord, x_norm, delta_encoder_pair_rep_norm, 172 | token_targets, charge_targets, coord_targets, dist_targets 173 | ) 174 | else: 175 | return cls_logits, batch 176 | 177 | def register_classification_head( 178 | self, name, num_classes=None, inner_dim=None, **kwargs 179 | ): 180 | """Register a classification head.""" 181 | if name in self.unimol.classification_heads: 182 | prev_num_classes = self.unimol.classification_heads[name].out_proj.out_features 183 | prev_inner_dim = self.unimol.classification_heads[name].dense.out_features 184 | if num_classes != prev_num_classes or inner_dim != prev_inner_dim: 185 | logger.warning( 186 | 're-registering head "{}" with num_classes {} (prev: {}) ' 187 | "and inner_dim {} (prev: {})".format( 188 | name, num_classes, prev_num_classes, inner_dim, prev_inner_dim 189 | ) 190 | ) 191 | self.unimol.classification_heads[name] = 
ClassificationHead( 192 | input_dim=self.args.encoder_embed_dim, 193 | inner_dim=inner_dim or self.args.encoder_embed_dim, 194 | num_classes=num_classes, 195 | activation_fn=self.args.pooler_activation_fn, 196 | pooler_dropout=self.args.pooler_dropout, 197 | ) 198 | 199 | def set_num_updates(self, num_updates): 200 | """State from trainer to pass along to model at every update.""" 201 | self._num_updates = num_updates 202 | 203 | def get_num_updates(self): 204 | return self._num_updates 205 | 206 | 207 | class MaskLMHead(nn.Module): 208 | """Head for masked language modeling.""" 209 | 210 | def __init__(self, embed_dim, output_dim, activation_fn, weight=None): 211 | super().__init__() 212 | self.dense = nn.Linear(embed_dim, embed_dim) 213 | self.activation_fn = utils.get_activation_fn(activation_fn) 214 | self.layer_norm = LayerNorm(embed_dim) 215 | 216 | if weight is None: 217 | weight = nn.Linear(embed_dim, output_dim, bias=False).weight 218 | self.weight = weight 219 | self.bias = nn.Parameter(torch.zeros(output_dim)) 220 | 221 | def forward(self, features, masked_tokens=None, **kwargs): 222 | # Only project the masked tokens while training, 223 | # saves both memory and computation 224 | if masked_tokens is not None: 225 | features = features[masked_tokens, :] 226 | 227 | x = self.dense(features) 228 | x = self.activation_fn(x) 229 | x = self.layer_norm(x) 230 | # project back to size of vocabulary with bias 231 | x = F.linear(x, self.weight) + self.bias 232 | return x 233 | 234 | 235 | class DistanceHead(nn.Module): 236 | def __init__( 237 | self, 238 | heads, 239 | activation_fn, 240 | ): 241 | super().__init__() 242 | self.dense = nn.Linear(heads, heads) 243 | self.layer_norm = nn.LayerNorm(heads) 244 | self.out_proj = nn.Linear(heads, 1) 245 | self.activation_fn = utils.get_activation_fn(activation_fn) 246 | 247 | def forward(self, x): 248 | bsz, seq_len, seq_len, _ = x.size() 249 | # x[x == float('-inf')] = 0 250 | x = self.dense(x) 251 | x = self.activation_fn(x) 252 | x = self.layer_norm(x) 253 | x = self.out_proj(x).view(bsz, seq_len, seq_len) 254 | x = (x + x.transpose(-1, -2)) * 0.5 255 | return x 256 | 257 | 258 | @register_model_architecture("unimol_pka", "unimol_pka") 259 | def unimol_pka_architecture(args): 260 | def base_architecture(args): 261 | args.encoder_layers = getattr(args, "encoder_layers", 15) 262 | args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) 263 | args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) 264 | args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 64) 265 | args.dropout = getattr(args, "dropout", 0.1) 266 | args.emb_dropout = getattr(args, "emb_dropout", 0.1) 267 | args.attention_dropout = getattr(args, "attention_dropout", 0.1) 268 | args.activation_dropout = getattr(args, "activation_dropout", 0.0) 269 | args.pooler_dropout = getattr(args, "pooler_dropout", 0.0) 270 | args.max_seq_len = getattr(args, "max_seq_len", 512) 271 | args.activation_fn = getattr(args, "activation_fn", "gelu") 272 | args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh") 273 | args.post_ln = getattr(args, "post_ln", False) 274 | args.masked_token_loss = getattr(args, "masked_token_loss", -1.0) 275 | args.masked_charge_loss = getattr(args, "masked_charge_loss", -1.0) 276 | args.masked_coord_loss = getattr(args, "masked_coord_loss", -1.0) 277 | args.masked_dist_loss = getattr(args, "masked_dist_loss", -1.0) 278 | args.x_norm_loss = getattr(args, "x_norm_loss", -1.0) 279 | 
args.delta_pair_repr_norm_loss = getattr(args, "delta_pair_repr_norm_loss", -1.0) 280 | 281 | base_architecture(args) 282 | -------------------------------------------------------------------------------- /unimol/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import importlib 3 | 4 | # automatically import any Python files in the criterions/ directory 5 | for file in sorted(Path(__file__).parent.glob("*.py")): 6 | if not file.name.startswith("_"): 7 | importlib.import_module("unimol.tasks." + file.name[:-3]) 8 | -------------------------------------------------------------------------------- /unimol/tasks/unimol_free_energy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | import logging 6 | import os 7 | 8 | import numpy as np 9 | from unicore.data import ( 10 | Dictionary, 11 | NestedDictionaryDataset, 12 | LMDBDataset, 13 | AppendTokenDataset, 14 | PrependTokenDataset, 15 | RightPadDataset, 16 | SortDataset, 17 | TokenizeDataset, 18 | RightPadDataset2D, 19 | RawLabelDataset, 20 | RawArrayDataset, 21 | FromNumpyDataset, 22 | ) 23 | from unimol.data import ( 24 | KeyDataset, 25 | DistanceDataset, 26 | EdgeTypeDataset, 27 | RemoveHydrogenDataset, 28 | NormalizeDataset, 29 | CroppingDataset, 30 | RightPadDatasetCoord, 31 | data_utils, 32 | ) 33 | 34 | from unimol.data.tta_dataset import TTADataset 35 | from unicore.tasks import UnicoreTask, register_task 36 | 37 | 38 | logger = logging.getLogger(__name__) 39 | 40 | 41 | @register_task("mol_free_energy") 42 | class UniMolFreeEnergyTask(UnicoreTask): 43 | """Task for training transformer auto-encoder models.""" 44 | 45 | @staticmethod 46 | def add_args(parser): 47 | """Add task-specific arguments to the parser.""" 48 | parser.add_argument("data", help="downstream data path") 49 | parser.add_argument("--task-name", type=str, help="downstream task name") 50 | parser.add_argument( 51 | "--classification-head-name", 52 | default="classification", 53 | help="finetune downstream task name", 54 | ) 55 | parser.add_argument( 56 | "--num-classes", 57 | default=1, 58 | type=int, 59 | help="finetune downstream task classes numbers", 60 | ) 61 | parser.add_argument("--no-shuffle", action="store_true", help="shuffle data") 62 | parser.add_argument( 63 | "--conf-size", 64 | default=10, 65 | type=int, 66 | help="number of conformers generated with each molecule", 67 | ) 68 | parser.add_argument( 69 | "--remove-hydrogen", 70 | action="store_true", 71 | help="remove hydrogen atoms", 72 | ) 73 | parser.add_argument( 74 | "--remove-polar-hydrogen", 75 | action="store_true", 76 | help="remove polar hydrogen atoms", 77 | ) 78 | parser.add_argument( 79 | "--max-atoms", 80 | type=int, 81 | default=256, 82 | help="selected maximum number of atoms in a molecule", 83 | ) 84 | parser.add_argument( 85 | "--dict-name", 86 | default="dict.txt", 87 | help="dictionary file", 88 | ) 89 | parser.add_argument( 90 | "--charge-dict-name", 91 | default="dict_charge.txt", 92 | help="dictionary file", 93 | ) 94 | parser.add_argument( 95 | "--only-polar", 96 | default=1, 97 | type=int, 98 | help="1: only reserve polar hydrogen; 0: no hydrogen; -1: all hydrogen ", 99 | ) 100 | 101 | def __init__(self, args, dictionary, charge_dictionary): 102 | super().__init__(args) 103 | self.dictionary = 
dictionary 104 | self.charge_dictionary = charge_dictionary 105 | self.seed = args.seed 106 | # add mask token 107 | self.mask_idx = dictionary.add_symbol("[MASK]", is_special=True) 108 | self.charge_mask_idx = charge_dictionary.add_symbol("[MASK]", is_special=True) 109 | if self.args.only_polar > 0: 110 | self.args.remove_polar_hydrogen = True 111 | elif self.args.only_polar < 0: 112 | self.args.remove_polar_hydrogen = False 113 | else: 114 | self.args.remove_hydrogen = True 115 | 116 | @classmethod 117 | def setup_task(cls, args, **kwargs): 118 | dictionary = Dictionary.load(os.path.join(args.data, args.dict_name)) 119 | logger.info("dictionary: {} types".format(len(dictionary))) 120 | charge_dictionary = Dictionary.load(os.path.join(args.data, args.charge_dict_name)) 121 | logger.info("charge dictionary: {} types".format(len(charge_dictionary))) 122 | return cls(args, dictionary, charge_dictionary) 123 | 124 | def load_dataset(self, split, **kwargs): 125 | """Load a given dataset split. 126 | Args: 127 | split (str): name of the data scoure (e.g., train) 128 | """ 129 | split_path = os.path.join(self.args.data, self.args.task_name, split + ".lmdb") 130 | dataset = LMDBDataset(split_path) 131 | dataset = TTADataset( 132 | dataset, self.args.seed, "atoms", "coordinates", "charges", "smi", self.args.conf_size 133 | ) 134 | tgt_dataset = KeyDataset(dataset, "target") 135 | smi_dataset = KeyDataset(dataset, "smi") 136 | 137 | dataset = RemoveHydrogenDataset( 138 | dataset, 139 | "atoms", 140 | "coordinates", 141 | "charges", 142 | self.args.remove_hydrogen, 143 | self.args.remove_polar_hydrogen, 144 | ) 145 | dataset = CroppingDataset( 146 | dataset, self.seed, "atoms", "coordinates", "charges", self.args.max_atoms 147 | ) 148 | dataset = NormalizeDataset(dataset, "coordinates", normalize_coord=True) 149 | src_dataset = KeyDataset(dataset, "atoms") 150 | src_dataset = TokenizeDataset( 151 | src_dataset, self.dictionary, max_seq_len=self.args.max_seq_len 152 | ) 153 | src_charge_dataset = KeyDataset(dataset, "charges") 154 | src_charge_dataset = TokenizeDataset( 155 | src_charge_dataset, self.charge_dictionary, max_seq_len=self.args.max_seq_len 156 | ) 157 | coord_dataset = KeyDataset(dataset, "coordinates") 158 | 159 | def PrependAndAppend(dataset, pre_token, app_token): 160 | dataset = PrependTokenDataset(dataset, pre_token) 161 | return AppendTokenDataset(dataset, app_token) 162 | 163 | src_dataset = PrependAndAppend( 164 | src_dataset, self.dictionary.bos(), self.dictionary.eos() 165 | ) 166 | src_charge_dataset = PrependAndAppend( 167 | src_charge_dataset, self.charge_dictionary.bos(), self.charge_dictionary.eos() 168 | ) 169 | edge_type = EdgeTypeDataset(src_dataset, len(self.dictionary)) 170 | coord_dataset = FromNumpyDataset(coord_dataset) 171 | coord_dataset = PrependAndAppend(coord_dataset, 0.0, 0.0) 172 | distance_dataset = DistanceDataset(coord_dataset) 173 | 174 | nest_dataset = NestedDictionaryDataset( 175 | { 176 | "net_input": { 177 | "src_tokens": RightPadDataset( 178 | src_dataset, 179 | pad_idx=self.dictionary.pad(), 180 | ), 181 | "src_charges": RightPadDataset( 182 | src_charge_dataset, 183 | pad_idx=self.charge_dictionary.pad(), 184 | ), 185 | "src_coord": RightPadDatasetCoord( 186 | coord_dataset, 187 | pad_idx=0, 188 | ), 189 | "src_distance": RightPadDataset2D( 190 | distance_dataset, 191 | pad_idx=0, 192 | ), 193 | "src_edge_type": RightPadDataset2D( 194 | edge_type, 195 | pad_idx=0, 196 | ), 197 | }, 198 | "target": { 199 | "finetune_target": 
RawLabelDataset(tgt_dataset), 200 | }, 201 | "smi_name": RawArrayDataset(smi_dataset), 202 | }, 203 | ) 204 | if not self.args.no_shuffle and split == "train": 205 | with data_utils.numpy_seed(self.args.seed): 206 | shuffle = np.random.permutation(len(src_dataset)) 207 | 208 | self.datasets[split] = SortDataset( 209 | nest_dataset, 210 | sort_order=[shuffle], 211 | ) 212 | else: 213 | self.datasets[split] = nest_dataset 214 | 215 | def build_model(self, args): 216 | from unicore import models 217 | 218 | model = models.build_model(args, self) 219 | model.register_classification_head( 220 | self.args.classification_head_name, 221 | num_classes=self.args.num_classes, 222 | ) 223 | return model 224 | -------------------------------------------------------------------------------- /unimol/tasks/unimol_mlm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | import logging 6 | import os 7 | 8 | import numpy as np 9 | from unicore.data import ( 10 | Dictionary, 11 | NestedDictionaryDataset, 12 | LMDBDataset, 13 | AppendTokenDataset, 14 | PrependTokenDataset, 15 | SortDataset, 16 | TokenizeDataset, 17 | RawLabelDataset, 18 | FromNumpyDataset, 19 | ) 20 | from unimol.data import ( 21 | KeyDataset, 22 | ConformerSamplePKADataset, 23 | PKAInputDataset, 24 | PKAMLMInputDataset, 25 | DistanceDataset, 26 | EdgeTypeDataset, 27 | RemoveHydrogenDataset, 28 | NormalizeDataset, 29 | CroppingDataset, 30 | FoldLMDBDataset, 31 | StackedLMDBDataset, 32 | SplitLMDBDataset, 33 | data_utils, 34 | MaskPointsDataset, 35 | ) 36 | 37 | from unimol.data.tta_dataset import TTADataset, TTAPKADataset 38 | from unicore.tasks import UnicoreTask, register_task 39 | 40 | 41 | logger = logging.getLogger(__name__) 42 | 43 | 44 | @register_task("mol_pka_mlm") 45 | class UniMolPKAMLMTask(UnicoreTask): 46 | """Task for training transformer auto-encoder models.""" 47 | 48 | @staticmethod 49 | def add_args(parser): 50 | """Add task-specific arguments to the parser.""" 51 | parser.add_argument("data", help="downstream data path") 52 | parser.add_argument( 53 | "--mask-prob", 54 | default=0.15, 55 | type=float, 56 | help="probability of replacing a token with mask", 57 | ) 58 | parser.add_argument( 59 | "--leave-unmasked-prob", 60 | default=0.05, 61 | type=float, 62 | help="probability that a masked token is unmasked", 63 | ) 64 | parser.add_argument( 65 | "--random-token-prob", 66 | default=0.05, 67 | type=float, 68 | help="probability of replacing a token with a random token", 69 | ) 70 | parser.add_argument( 71 | "--noise-type", 72 | default="uniform", 73 | choices=["trunc_normal", "uniform", "normal", "none"], 74 | help="noise type in coordinate noise", 75 | ) 76 | parser.add_argument( 77 | "--noise", 78 | default=1.0, 79 | type=float, 80 | help="coordinate noise for masked atoms", 81 | ) 82 | parser.add_argument("--task-name", type=str, help="downstream task name") 83 | parser.add_argument( 84 | "--classification-head-name", 85 | default="classification", 86 | help="finetune downstream task name", 87 | ) 88 | parser.add_argument( 89 | "--num-classes", 90 | default=1, 91 | type=int, 92 | help="finetune downstream task classes numbers", 93 | ) 94 | parser.add_argument("--no-shuffle", action="store_true", help="shuffle data") 95 | parser.add_argument( 96 | "--conf-size", 97 | default=10, 98 | type=int, 99 | help="number of 
conformers generated with each molecule", 100 | ) 101 | parser.add_argument( 102 | "--remove-hydrogen", 103 | action="store_true", 104 | help="remove hydrogen atoms", 105 | ) 106 | parser.add_argument( 107 | "--remove-polar-hydrogen", 108 | action="store_true", 109 | help="remove polar hydrogen atoms", 110 | ) 111 | parser.add_argument( 112 | "--max-atoms", 113 | type=int, 114 | default=256, 115 | help="selected maximum number of atoms in a molecule", 116 | ) 117 | parser.add_argument( 118 | "--dict-name", 119 | default="dict.txt", 120 | help="dictionary file", 121 | ) 122 | parser.add_argument( 123 | "--charge-dict-name", 124 | default="dict_charge.txt", 125 | help="dictionary file", 126 | ) 127 | parser.add_argument( 128 | "--only-polar", 129 | default=1, 130 | type=int, 131 | help="1: only reserve polar hydrogen; 0: no hydrogen; -1: all hydrogen ", 132 | ) 133 | parser.add_argument( 134 | '--split-mode', 135 | type=str, 136 | default='predefine', 137 | choices=['predefine', 'cross_valid', 'random', 'infer'], 138 | ) 139 | parser.add_argument( 140 | "--nfolds", 141 | default=5, 142 | type=int, 143 | help="cross validation split folds" 144 | ) 145 | parser.add_argument( 146 | "--fold", 147 | default=0, 148 | type=int, 149 | help='local fold used as validation set, and other folds will be used as train set' 150 | ) 151 | parser.add_argument( 152 | "--cv-seed", 153 | default=42, 154 | type=int, 155 | help="random seed used to do cross validation splits" 156 | ) 157 | 158 | def __init__(self, args, dictionary, charge_dictionary): 159 | super().__init__(args) 160 | self.dictionary = dictionary 161 | self.charge_dictionary = charge_dictionary 162 | self.seed = args.seed 163 | # add mask token 164 | self.mask_idx = dictionary.add_symbol("[MASK]", is_special=True) 165 | self.charge_mask_idx = charge_dictionary.add_symbol("[MASK]", is_special=True) 166 | if self.args.only_polar > 0: 167 | self.args.remove_polar_hydrogen = True 168 | elif self.args.only_polar < 0: 169 | self.args.remove_polar_hydrogen = False 170 | else: 171 | self.args.remove_hydrogen = True 172 | if self.args.split_mode !='predefine': 173 | self.__init_data() 174 | 175 | def __init_data(self): 176 | data_path = os.path.join(self.args.data, self.args.task_name + '.lmdb') 177 | raw_dataset = LMDBDataset(data_path) 178 | if self.args.split_mode == 'cross_valid': 179 | train_folds = [] 180 | for _fold in range(self.args.nfolds): 181 | if _fold == 0: 182 | cache_fold_info = FoldLMDBDataset(raw_dataset, self.args.cv_seed, _fold, nfolds=self.args.nfolds).get_fold_info() 183 | if _fold == self.args.fold: 184 | self.valid_dataset = FoldLMDBDataset(raw_dataset, self.args.cv_seed, _fold, nfolds=self.args.nfolds, cache_fold_info=cache_fold_info) 185 | if _fold != self.args.fold: 186 | train_folds.append(FoldLMDBDataset(raw_dataset, self.args.cv_seed, _fold, nfolds=self.args.nfolds, cache_fold_info=cache_fold_info)) 187 | self.train_dataset = StackedLMDBDataset(train_folds) 188 | elif self.args.split_mode == 'random': 189 | cache_fold_info = SplitLMDBDataset(raw_dataset, self.args.seed, 0).get_fold_info() 190 | self.train_dataset = SplitLMDBDataset(raw_dataset, self.args.seed, 0, cache_fold_info=cache_fold_info) 191 | self.valid_dataset = SplitLMDBDataset(raw_dataset, self.args.seed, 1, cache_fold_info=cache_fold_info) 192 | 193 | @classmethod 194 | def setup_task(cls, args, **kwargs): 195 | dictionary = Dictionary.load(os.path.join(args.data, args.dict_name)) 196 | logger.info("dictionary: {} types".format(len(dictionary))) 197 | 
charge_dictionary = Dictionary.load(os.path.join(args.data, args.charge_dict_name)) 198 | logger.info("charge dictionary: {} types".format(len(charge_dictionary))) 199 | return cls(args, dictionary, charge_dictionary) 200 | 201 | def load_dataset(self, split, **kwargs): 202 | """Load a given dataset split. 203 | Args: 204 | split (str): name of the data scoure (e.g., train) 205 | """ 206 | self.split = split 207 | if self.args.split_mode != 'predefine': 208 | if split == 'train': 209 | dataset = self.train_dataset 210 | elif split == 'valid': 211 | dataset =self.valid_dataset 212 | else: 213 | split_path = os.path.join(self.args.data, split + ".lmdb") 214 | dataset = LMDBDataset(split_path) 215 | tgt_dataset = KeyDataset(dataset, "target") 216 | if split in ['train', 'train.small']: 217 | tgt_list = [tgt_dataset[i] for i in range(len(tgt_dataset))] 218 | self.mean = sum(tgt_list) / len(tgt_list) 219 | self.std = 1 220 | elif split in ['novartis_acid', 'novartis_base', 'sampl6', 'sampl7', 'sampl8']: 221 | self.mean = 6.504894871171601 # precompute from dwar_8228 full set 222 | self.std = 1 223 | id_dataset = KeyDataset(dataset, "ori_smi") 224 | 225 | def GetPKAInput(dataset, metadata_key): 226 | mol_dataset = TTAPKADataset(dataset, self.args.seed, metadata_key, "atoms", "coordinates", "charges") 227 | idx2key = mol_dataset.get_idx2key() 228 | if split in ["train","train.small"]: 229 | sample_dataset = ConformerSamplePKADataset( 230 | mol_dataset, self.args.seed, "atoms", "coordinates", "charges", self.args.conf_size 231 | ) 232 | else: 233 | sample_dataset = TTADataset( 234 | mol_dataset, self.args.seed, "atoms", "coordinates","charges","id", self.args.conf_size 235 | ) 236 | 237 | sample_dataset = RemoveHydrogenDataset( 238 | sample_dataset, 239 | "atoms", 240 | "coordinates", 241 | "charges", 242 | self.args.remove_hydrogen, 243 | self.args.remove_polar_hydrogen, 244 | ) 245 | sample_dataset = CroppingDataset( 246 | sample_dataset, self.seed, "atoms", "coordinates","charges", self.args.max_atoms 247 | ) 248 | sample_dataset = NormalizeDataset(sample_dataset, "coordinates", normalize_coord=True) 249 | src_dataset = KeyDataset(sample_dataset, "atoms") 250 | src_dataset = TokenizeDataset( 251 | src_dataset, self.dictionary, max_seq_len=self.args.max_seq_len 252 | ) 253 | src_charge_dataset = KeyDataset(sample_dataset, "charges") 254 | src_charge_dataset = TokenizeDataset( 255 | src_charge_dataset, self.charge_dictionary, max_seq_len=self.args.max_seq_len 256 | ) 257 | coord_dataset = KeyDataset(sample_dataset, "coordinates") 258 | expand_dataset = MaskPointsDataset( 259 | src_dataset, 260 | coord_dataset, 261 | src_charge_dataset, 262 | self.dictionary, 263 | self.charge_dictionary, 264 | pad_idx=self.dictionary.pad(), 265 | charge_pad_idx=self.charge_dictionary.pad(), 266 | mask_idx=self.mask_idx, 267 | charge_mask_idx=self.charge_mask_idx, 268 | noise_type=self.args.noise_type, 269 | noise=self.args.noise, 270 | seed=self.seed, 271 | mask_prob=self.args.mask_prob, 272 | leave_unmasked_prob=self.args.leave_unmasked_prob, 273 | random_token_prob=self.args.random_token_prob, 274 | ) 275 | 276 | def PrependAndAppend(dataset, pre_token, app_token): 277 | dataset = PrependTokenDataset(dataset, pre_token) 278 | return AppendTokenDataset(dataset, app_token) 279 | 280 | encoder_token_dataset = KeyDataset(expand_dataset, "atoms") 281 | encoder_target_dataset = KeyDataset(expand_dataset, "targets") 282 | encoder_coord_dataset = KeyDataset(expand_dataset, "coordinates") 283 | encoder_charge_dataset = 
KeyDataset(expand_dataset, "charges") 284 | encoder_charge_target_dataset = KeyDataset(expand_dataset, "charge_targets") 285 | 286 | src_dataset = PrependAndAppend( 287 | encoder_token_dataset, self.dictionary.bos(), self.dictionary.eos() 288 | ) 289 | src_charge_dataset = PrependAndAppend( 290 | encoder_charge_dataset, self.charge_dictionary.bos(), self.charge_dictionary.eos() 291 | ) 292 | token_tgt_dataset = PrependAndAppend( 293 | encoder_target_dataset, self.dictionary.pad(), self.dictionary.pad() 294 | ) 295 | charge_tgt_dataset = PrependAndAppend( 296 | encoder_charge_target_dataset, self.charge_dictionary.pad(), self.charge_dictionary.pad() 297 | ) 298 | encoder_coord_dataset = PrependAndAppend(encoder_coord_dataset, 0.0, 0.0) 299 | encoder_distance_dataset = DistanceDataset(encoder_coord_dataset) 300 | 301 | edge_type = EdgeTypeDataset(src_dataset, len(self.dictionary)) 302 | coord_dataset = FromNumpyDataset(coord_dataset) 303 | coord_dataset = PrependAndAppend(coord_dataset, 0.0, 0.0) 304 | distance_dataset = DistanceDataset(coord_dataset) 305 | 306 | return PKAMLMInputDataset(idx2key, src_dataset, src_charge_dataset, encoder_coord_dataset, encoder_distance_dataset, edge_type, token_tgt_dataset, charge_tgt_dataset, distance_dataset, coord_dataset, self.dictionary.pad(), self.charge_dictionary.pad(), split, self.args.conf_size) 307 | 308 | input_a_dataset = GetPKAInput(dataset, "metadata_a") 309 | input_b_dataset = GetPKAInput(dataset, "metadata_b") 310 | 311 | nest_dataset = NestedDictionaryDataset( 312 | { 313 | "net_input_a": input_a_dataset, 314 | "net_input_b": input_b_dataset, 315 | "target": { 316 | "finetune_target": RawLabelDataset(tgt_dataset), 317 | }, 318 | "id": id_dataset, 319 | }, 320 | ) 321 | 322 | if not self.args.no_shuffle and split in ["train","train.small"]: 323 | with data_utils.numpy_seed(self.args.seed): 324 | shuffle = np.random.permutation(len(id_dataset)) 325 | 326 | self.datasets[split] = SortDataset( 327 | nest_dataset, 328 | sort_order=[shuffle], 329 | ) 330 | else: 331 | self.datasets[split] = nest_dataset 332 | 333 | def build_model(self, args): 334 | from unicore import models 335 | 336 | model = models.build_model(args, self) 337 | model.register_classification_head( 338 | self.args.classification_head_name, 339 | num_classes=self.args.num_classes, 340 | ) 341 | return model 342 | -------------------------------------------------------------------------------- /unimol/tasks/unimol_pka.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 
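For readers who want to see the masking scheme that the mol_pka_mlm task above configures through MaskPointsDataset, the sketch below restates the usual BERT-style corruption in plain NumPy, following the argument semantics shown above (--mask-prob, --leave-unmasked-prob, --random-token-prob, and uniform coordinate noise of magnitude --noise). It is only an illustration written for this note; the function and variable names are hypothetical, and the exact proportioning inside MaskPointsDataset may differ.

import numpy as np

def mask_atoms(tokens, coords, vocab_size, mask_idx, pad_idx,
               mask_prob=0.15, leave_unmasked_prob=0.05,
               random_token_prob=0.05, noise=1.0, rng=None):
    # tokens: (n,) integer atom-type ids; coords: (n, 3) float coordinates.
    # Returns corrupted tokens/coords plus a target array that keeps the
    # original id at the selected positions and pad_idx everywhere else.
    rng = rng or np.random.default_rng(0)
    n = len(tokens)
    selected = rng.random(n) < mask_prob                   # positions the masked losses use
    targets = np.where(selected, tokens, pad_idx)

    new_tokens = tokens.copy()
    new_coords = coords.copy()
    r = rng.random(n)
    use_random = selected & (r < random_token_prob)        # swap in a random vocabulary token
    keep = selected & ~use_random & (r < random_token_prob + leave_unmasked_prob)
    use_mask = selected & ~use_random & ~keep               # replace with [MASK]
    new_tokens[use_random] = rng.integers(0, vocab_size, size=int(use_random.sum()))
    new_tokens[use_mask] = mask_idx
    # corrupt the coordinates of every selected atom (the "uniform" noise type)
    new_coords[selected] += rng.uniform(-noise, noise, size=(int(selected.sum()), 3))
    return new_tokens, new_coords, targets

The same idea applies to the charge tokens, which the task masks with their own dictionary and charge_mask_idx.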
4 | 5 | import logging 6 | import os 7 | 8 | import numpy as np 9 | from unicore.data import ( 10 | Dictionary, 11 | NestedDictionaryDataset, 12 | LMDBDataset, 13 | AppendTokenDataset, 14 | PrependTokenDataset, 15 | SortDataset, 16 | TokenizeDataset, 17 | RawLabelDataset, 18 | FromNumpyDataset, 19 | ) 20 | from unimol.data import ( 21 | KeyDataset, 22 | ConformerSamplePKADataset, 23 | PKAInputDataset, 24 | DistanceDataset, 25 | EdgeTypeDataset, 26 | RemoveHydrogenDataset, 27 | NormalizeDataset, 28 | CroppingDataset, 29 | FoldLMDBDataset, 30 | StackedLMDBDataset, 31 | SplitLMDBDataset, 32 | data_utils, 33 | ) 34 | 35 | from unimol.data.tta_dataset import TTADataset, TTAPKADataset 36 | from unicore.tasks import UnicoreTask, register_task 37 | 38 | 39 | logger = logging.getLogger(__name__) 40 | 41 | 42 | @register_task("mol_pka") 43 | class UniMolPKATask(UnicoreTask): 44 | """Task for training transformer auto-encoder models.""" 45 | 46 | @staticmethod 47 | def add_args(parser): 48 | """Add task-specific arguments to the parser.""" 49 | parser.add_argument("data", help="downstream data path") 50 | parser.add_argument("--task-name", type=str, help="downstream task name") 51 | parser.add_argument( 52 | "--classification-head-name", 53 | default="classification", 54 | help="finetune downstream task name", 55 | ) 56 | parser.add_argument( 57 | "--num-classes", 58 | default=1, 59 | type=int, 60 | help="finetune downstream task classes numbers", 61 | ) 62 | parser.add_argument("--no-shuffle", action="store_true", help="shuffle data") 63 | parser.add_argument( 64 | "--conf-size", 65 | default=10, 66 | type=int, 67 | help="number of conformers generated with each molecule", 68 | ) 69 | parser.add_argument( 70 | "--remove-hydrogen", 71 | action="store_true", 72 | help="remove hydrogen atoms", 73 | ) 74 | parser.add_argument( 75 | "--remove-polar-hydrogen", 76 | action="store_true", 77 | help="remove polar hydrogen atoms", 78 | ) 79 | parser.add_argument( 80 | "--max-atoms", 81 | type=int, 82 | default=256, 83 | help="selected maximum number of atoms in a molecule", 84 | ) 85 | parser.add_argument( 86 | "--dict-name", 87 | default="dict.txt", 88 | help="dictionary file", 89 | ) 90 | parser.add_argument( 91 | "--charge-dict-name", 92 | default="dict_charge.txt", 93 | help="dictionary file", 94 | ) 95 | parser.add_argument( 96 | "--only-polar", 97 | default=1, 98 | type=int, 99 | help="1: only reserve polar hydrogen; 0: no hydrogen; -1: all hydrogen ", 100 | ) 101 | parser.add_argument( 102 | '--split-mode', 103 | type=str, 104 | default='predefine', 105 | choices=['predefine', 'cross_valid', 'random', 'infer'], 106 | ) 107 | parser.add_argument( 108 | "--nfolds", 109 | default=5, 110 | type=int, 111 | help="cross validation split folds" 112 | ) 113 | parser.add_argument( 114 | "--fold", 115 | default=0, 116 | type=int, 117 | help='local fold used as validation set, and other folds will be used as train set' 118 | ) 119 | parser.add_argument( 120 | "--cv-seed", 121 | default=42, 122 | type=int, 123 | help="random seed used to do cross validation splits" 124 | ) 125 | 126 | def __init__(self, args, dictionary, charge_dictionary): 127 | super().__init__(args) 128 | self.dictionary = dictionary 129 | self.charge_dictionary = charge_dictionary 130 | self.seed = args.seed 131 | # add mask token 132 | self.mask_idx = dictionary.add_symbol("[MASK]", is_special=True) 133 | self.charge_mask_idx = charge_dictionary.add_symbol("[MASK]", is_special=True) 134 | if self.args.only_polar > 0: 135 | 
self.args.remove_polar_hydrogen = True 136 | elif self.args.only_polar < 0: 137 | self.args.remove_polar_hydrogen = False 138 | else: 139 | self.args.remove_hydrogen = True 140 | if self.args.split_mode !='predefine': 141 | self.__init_data() 142 | 143 | def __init_data(self): 144 | data_path = os.path.join(self.args.data, self.args.task_name + '.lmdb') 145 | raw_dataset = LMDBDataset(data_path) 146 | if self.args.split_mode == 'cross_valid': 147 | train_folds = [] 148 | for _fold in range(self.args.nfolds): 149 | if _fold == 0: 150 | cache_fold_info = FoldLMDBDataset(raw_dataset, self.args.cv_seed, _fold, nfolds=self.args.nfolds).get_fold_info() 151 | if _fold == self.args.fold: 152 | self.valid_dataset = FoldLMDBDataset(raw_dataset, self.args.cv_seed, _fold, nfolds=self.args.nfolds, cache_fold_info=cache_fold_info) 153 | if _fold != self.args.fold: 154 | train_folds.append(FoldLMDBDataset(raw_dataset, self.args.cv_seed, _fold, nfolds=self.args.nfolds, cache_fold_info=cache_fold_info)) 155 | self.train_dataset = StackedLMDBDataset(train_folds) 156 | elif self.args.split_mode == 'random': 157 | cache_fold_info = SplitLMDBDataset(raw_dataset, self.args.seed, 0).get_fold_info() 158 | self.train_dataset = SplitLMDBDataset(raw_dataset, self.args.seed, 0, cache_fold_info=cache_fold_info) 159 | self.valid_dataset = SplitLMDBDataset(raw_dataset, self.args.seed, 1, cache_fold_info=cache_fold_info) 160 | 161 | @classmethod 162 | def setup_task(cls, args, **kwargs): 163 | dictionary = Dictionary.load(os.path.join(args.data, args.dict_name)) 164 | logger.info("dictionary: {} types".format(len(dictionary))) 165 | charge_dictionary = Dictionary.load(os.path.join(args.data, args.charge_dict_name)) 166 | logger.info("charge dictionary: {} types".format(len(charge_dictionary))) 167 | return cls(args, dictionary, charge_dictionary) 168 | 169 | def load_dataset(self, split, **kwargs): 170 | """Load a given dataset split. 
171 | Args: 172 | split (str): name of the data scoure (e.g., train) 173 | """ 174 | self.split = split 175 | if self.args.split_mode != 'predefine': 176 | if split == 'train': 177 | dataset = self.train_dataset 178 | elif split == 'valid': 179 | dataset =self.valid_dataset 180 | else: 181 | split_path = os.path.join(self.args.data, split + ".lmdb") 182 | dataset = LMDBDataset(split_path) 183 | tgt_dataset = KeyDataset(dataset, "target") 184 | if split in ['train', 'train.small']: 185 | tgt_list = [tgt_dataset[i] for i in range(len(tgt_dataset))] 186 | self.mean = sum(tgt_list) / len(tgt_list) 187 | self.std = 1 188 | elif split in ['novartis_acid', 'novartis_base', 'sampl6', 'sampl7', 'sampl8']: 189 | self.mean = 6.504894871171601 # precompute from dwar_8228 full set 190 | self.std = 1 191 | id_dataset = KeyDataset(dataset, "ori_smi") 192 | 193 | def GetPKAInput(dataset, metadata_key): 194 | mol_dataset = TTAPKADataset(dataset, self.args.seed, metadata_key, "atoms", "coordinates", "charges") 195 | idx2key = mol_dataset.get_idx2key() 196 | if split in ["train","train.small"]: 197 | sample_dataset = ConformerSamplePKADataset( 198 | mol_dataset, self.args.seed, "atoms", "coordinates", "charges" 199 | ) 200 | else: 201 | sample_dataset = TTADataset( 202 | mol_dataset, self.args.seed, "atoms", "coordinates","charges","id", self.args.conf_size 203 | ) 204 | 205 | sample_dataset = RemoveHydrogenDataset( 206 | sample_dataset, 207 | "atoms", 208 | "coordinates", 209 | "charges", 210 | self.args.remove_hydrogen, 211 | self.args.remove_polar_hydrogen, 212 | ) 213 | sample_dataset = CroppingDataset( 214 | sample_dataset, self.seed, "atoms", "coordinates","charges", self.args.max_atoms 215 | ) 216 | sample_dataset = NormalizeDataset(sample_dataset, "coordinates", normalize_coord=True) 217 | src_dataset = KeyDataset(sample_dataset, "atoms") 218 | src_dataset = TokenizeDataset( 219 | src_dataset, self.dictionary, max_seq_len=self.args.max_seq_len 220 | ) 221 | src_charge_dataset = KeyDataset(sample_dataset, "charges") 222 | src_charge_dataset = TokenizeDataset( 223 | src_charge_dataset, self.charge_dictionary, max_seq_len=self.args.max_seq_len 224 | ) 225 | coord_dataset = KeyDataset(sample_dataset, "coordinates") 226 | 227 | def PrependAndAppend(dataset, pre_token, app_token): 228 | dataset = PrependTokenDataset(dataset, pre_token) 229 | return AppendTokenDataset(dataset, app_token) 230 | 231 | src_dataset = PrependAndAppend( 232 | src_dataset, self.dictionary.bos(), self.dictionary.eos() 233 | ) 234 | src_charge_dataset = PrependAndAppend( 235 | src_charge_dataset, self.charge_dictionary.bos(), self.charge_dictionary.eos() 236 | ) 237 | edge_type = EdgeTypeDataset(src_dataset, len(self.dictionary)) 238 | coord_dataset = FromNumpyDataset(coord_dataset) 239 | coord_dataset = PrependAndAppend(coord_dataset, 0.0, 0.0) 240 | distance_dataset = DistanceDataset(coord_dataset) 241 | 242 | return PKAInputDataset(idx2key, src_dataset, src_charge_dataset, coord_dataset, distance_dataset, edge_type, self.dictionary.pad(), self.charge_dictionary.pad(), split, self.args.conf_size) 243 | 244 | input_a_dataset = GetPKAInput(dataset, "metadata_a") 245 | input_b_dataset = GetPKAInput(dataset, "metadata_b") 246 | 247 | nest_dataset = NestedDictionaryDataset( 248 | { 249 | "net_input_a": input_a_dataset, 250 | "net_input_b": input_b_dataset, 251 | "target": { 252 | "finetune_target": RawLabelDataset(tgt_dataset), 253 | }, 254 | "id": id_dataset, 255 | }, 256 | ) 257 | 258 | if not self.args.no_shuffle and split in 
["train","train.small"]: 259 | with data_utils.numpy_seed(self.args.seed): 260 | shuffle = np.random.permutation(len(id_dataset)) 261 | 262 | self.datasets[split] = SortDataset( 263 | nest_dataset, 264 | sort_order=[shuffle], 265 | ) 266 | else: 267 | self.datasets[split] = nest_dataset 268 | 269 | def build_model(self, args): 270 | from unicore import models 271 | 272 | model = models.build_model(args, self) 273 | model.register_classification_head( 274 | self.args.classification_head_name, 275 | num_classes=self.args.num_classes, 276 | ) 277 | return model 278 | --------------------------------------------------------------------------------