├── .github ├── CODEOWNERS ├── CODE_OF_CONDUCT.md ├── PULL_REQUEST_TEMPLATE.md ├── SECURITY.md └── workflows │ ├── code-check.yml │ ├── doc.yml │ ├── release.yml │ └── test.yml ├── .gitignore ├── DATA_LICENSE ├── LICENSE ├── README.md ├── docs ├── api │ ├── safe.md │ ├── safe.models.md │ └── safe.viz.md ├── assets │ ├── js │ │ └── google-analytics.js │ ├── safe-construction.svg │ └── safe-tasks.svg ├── cli.md ├── data_license.md ├── index.md ├── license.md └── tutorials │ ├── design-with-safe.ipynb │ ├── extracting-representation-molfeat.ipynb │ ├── getting-started.ipynb │ ├── how-it-works.ipynb │ └── load-from-wandb.ipynb ├── env.yml ├── expts ├── config │ └── accelerate.yaml ├── scripts │ ├── slurm-data-build.sh │ ├── slurm-notebook.sh │ ├── slurm-tokenizer-train-custom.sh │ ├── slurm-tokenizer-train-small.sh │ ├── slurm-tokenizer-train.sh │ ├── train-small.sh │ └── train.sh └── tokenizer │ ├── _tokenizer-custom-mini-test.json │ └── tokenizer-custom.json ├── mkdocs.yml ├── pyproject.toml ├── safe ├── __init__.py ├── _exception.py ├── _pattern.py ├── converter.py ├── io.py ├── sample.py ├── tokenizer.py ├── trainer │ ├── __init__.py │ ├── cli.py │ ├── collator.py │ ├── configs │ │ ├── __init__.py │ │ └── default_config.json │ ├── data_utils.py │ ├── model.py │ └── trainer_utils.py ├── utils.py └── viz.py └── tests ├── test_hgf_load.py ├── test_import.py ├── test_notebooks.py └── test_safe.py /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @maclandrol 2 | -------------------------------------------------------------------------------- /.github/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## Changelogs 2 | 3 | - _enumerate the changes of that PR._ 4 | 5 | --- 6 | 7 | _Checklist:_ 8 | 9 | - [ ] _Add tests to cover the fixed bug(s) or the new introduced feature(s) (if appropriate)._ 10 | - [ ] _Update the API documentation if a new function is added, or an existing one is deleted. Eventually consider making a new tutorial for new features._ 11 | - [ ] _Write concise and explanatory changelogs below._ 12 | - [ ] _If possible, assign one of the following labels to the PR: `feature`, `fix` or `test` (or ask a maintainer to do it for you)._ 13 | 14 | --- 15 | 16 | _discussion related to that PR_ -------------------------------------------------------------------------------- /.github/SECURITY.md: -------------------------------------------------------------------------------- 1 | # Security Policy 2 | 3 | Please report any security-related issues directly to emmanuel.noutahi@hotmail.ca 4 | -------------------------------------------------------------------------------- /.github/workflows/code-check.yml: -------------------------------------------------------------------------------- 1 | name: code-check 2 | 3 | on: 4 | push: 5 | branches: 6 | - "main" 7 | pull_request: 8 | branches: 9 | - "*" 10 | 11 | jobs: 12 | python-format-black: 13 | name: Python lint [black] 14 | runs-on: ubuntu-latest 15 | steps: 16 | - name: Checkout the code 17 | uses: actions/checkout@v3 18 | 19 | - name: Set up Python 20 | uses: actions/setup-python@v4 21 | with: 22 | python-version: "3.10" 23 | 24 | - name: Install black 25 | run: | 26 | pip install black>=24 27 | 28 | - name: Lint 29 | run: black --check . 30 | 31 | python-lint-ruff: 32 | name: Python lint [ruff] 33 | runs-on: ubuntu-latest 34 | steps: 35 | - name: Checkout the code 36 | uses: actions/checkout@v3 37 | 38 | - name: Set up Python 39 | uses: actions/setup-python@v4 40 | with: 41 | python-version: "3.10" 42 | 43 | - name: Install ruff 44 | run: | 45 | pip install ruff 46 | 47 | - name: Lint 48 | run: ruff check . 49 | -------------------------------------------------------------------------------- /.github/workflows/doc.yml: -------------------------------------------------------------------------------- 1 | name: doc 2 | 3 | on: 4 | push: 5 | branches: ["main"] 6 | 7 | # Prevent doc action on `main` to conflict with each others. 8 | concurrency: 9 | group: doc-${{ github.ref }} 10 | cancel-in-progress: true 11 | 12 | jobs: 13 | doc: 14 | runs-on: "ubuntu-latest" 15 | timeout-minutes: 30 16 | 17 | defaults: 18 | run: 19 | shell: bash -l {0} 20 | 21 | steps: 22 | - name: Checkout the code 23 | uses: actions/checkout@v3 24 | 25 | - name: Setup mamba 26 | uses: mamba-org/setup-micromamba@v1 27 | with: 28 | environment-file: env.yml 29 | environment-name: my_env 30 | cache-environment: true 31 | cache-downloads: true 32 | 33 | - name: Install library 34 | run: python -m pip install --no-deps . 35 | 36 | - name: Configure git 37 | run: | 38 | git config --global user.name "${GITHUB_ACTOR}" 39 | git config --global user.email "${GITHUB_ACTOR}@users.noreply.github.com" 40 | 41 | - name: Deploy the doc 42 | run: | 43 | echo "Get the gh-pages branch" 44 | git fetch origin gh-pages 45 | 46 | echo "Build and deploy the doc on main" 47 | mike deploy --push main 48 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: release 2 | 3 | on: 4 | workflow_dispatch: 5 | inputs: 6 | release-version: 7 | description: "A valid Semver version string" 8 | required: true 9 | 10 | permissions: 11 | contents: write 12 | pull-requests: write 13 | 14 | jobs: 15 | release: 16 | # Do not release if not triggered from the default branch 17 | if: github.ref == format('refs/heads/{0}', github.event.repository.default_branch) 18 | 19 | runs-on: ubuntu-latest 20 | timeout-minutes: 30 21 | 22 | defaults: 23 | run: 24 | shell: bash -l {0} 25 | 26 | steps: 27 | - name: Checkout the code 28 | uses: actions/checkout@v3 29 | 30 | - name: Setup mamba 31 | uses: mamba-org/setup-micromamba@v1 32 | with: 33 | environment-file: env.yml 34 | environment-name: my_env 35 | cache-environment: true 36 | cache-downloads: true 37 | create-args: >- 38 | pip 39 | semver 40 | python-build 41 | setuptools_scm 42 | 43 | - name: Check the version is valid semver 44 | run: | 45 | RELEASE_VERSION="${{ inputs.release-version }}" 46 | 47 | { 48 | pysemver check $RELEASE_VERSION 49 | } || { 50 | echo "The version '$RELEASE_VERSION' is not a valid Semver version string." 51 | echo "Please use a valid semver version string. More details at https://semver.org/" 52 | echo "The release process is aborted." 53 | exit 1 54 | } 55 | 56 | - name: Check the version is higher than the latest one 57 | run: | 58 | # Retrieve the git tags first 59 | git fetch --prune --unshallow --tags &> /dev/null 60 | 61 | RELEASE_VERSION="${{ inputs.release-version }}" 62 | LATEST_VERSION=$(git describe --abbrev=0 --tags) 63 | 64 | IS_HIGHER_VERSION=$(pysemver compare $RELEASE_VERSION $LATEST_VERSION) 65 | 66 | if [ "$IS_HIGHER_VERSION" != "1" ]; then 67 | echo "The version '$RELEASE_VERSION' is not higher than the latest version '$LATEST_VERSION'." 68 | echo "The release process is aborted." 69 | exit 1 70 | fi 71 | 72 | - name: Build Changelog 73 | id: github_release 74 | uses: mikepenz/release-changelog-builder-action@v4 75 | env: 76 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 77 | with: 78 | toTag: "main" 79 | 80 | - name: Configure git 81 | run: | 82 | git config --global user.name "${GITHUB_ACTOR}" 83 | git config --global user.email "${GITHUB_ACTOR}@users.noreply.github.com" 84 | 85 | - name: Create and push git tag 86 | env: 87 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 88 | run: | 89 | # Tag the release 90 | git tag -a "${{ inputs.release-version }}" -m "Release version ${{ inputs.release-version }}" 91 | 92 | # Checkout the git tag 93 | git checkout "${{ inputs.release-version }}" 94 | 95 | # Push the modified changelogs 96 | git push origin main 97 | 98 | # Push the tags 99 | git push origin "${{ inputs.release-version }}" 100 | 101 | - name: Install library 102 | run: python -m pip install --no-deps . 103 | 104 | - name: Build the wheel and sdist 105 | run: python -m build --no-isolation 106 | 107 | - name: Publish package to PyPI 108 | uses: pypa/gh-action-pypi-publish@release/v1 109 | with: 110 | password: ${{ secrets.PYPI_API_TOKEN }} 111 | packages-dir: dist/ 112 | 113 | - name: Create GitHub Release 114 | uses: softprops/action-gh-release@de2c0eb89ae2a093876385947365aca7b0e5f844 115 | with: 116 | tag_name: ${{ inputs.release-version }} 117 | body: ${{steps.github_release.outputs.changelog}} 118 | 119 | - name: Deploy the doc 120 | run: | 121 | echo "Get the gh-pages branch" 122 | git fetch origin gh-pages 123 | 124 | echo "Build and deploy the doc on ${{ inputs.release-version }}" 125 | mike deploy --push stable 126 | mike deploy --push ${{ inputs.release-version }} 127 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: test 2 | 3 | on: 4 | push: 5 | branches: ["main"] 6 | tags: ["*"] 7 | pull_request: 8 | branches: 9 | - "*" 10 | - "!gh-pages" 11 | schedule: 12 | - cron: "0 4 * * MON" 13 | 14 | jobs: 15 | test: 16 | strategy: 17 | fail-fast: false 18 | matrix: 19 | python-version: ["3.11"] 20 | 21 | runs-on: "ubuntu-latest" 22 | timeout-minutes: 30 23 | 24 | defaults: 25 | run: 26 | shell: bash -l {0} 27 | 28 | name: python=${{ matrix.python-version }} 29 | 30 | steps: 31 | - name: Checkout the code 32 | uses: actions/checkout@v3 33 | 34 | - name: Setup mamba 35 | uses: mamba-org/setup-micromamba@v1 36 | with: 37 | environment-file: env.yml 38 | environment-name: my_env 39 | cache-environment: true 40 | cache-downloads: true 41 | create-args: >- 42 | python=${{ matrix.python-version }} 43 | 44 | - name: Install library 45 | run: python -m pip install --no-deps -e . 46 | 47 | - name: Run tests 48 | run: pytest 49 | 50 | - name: Test CLI 51 | run: safe-train --help 52 | 53 | - name: Test building the doc 54 | run: mkdocs build 55 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.env 2 | cov.xml 3 | 4 | .vscode/ 5 | 6 | .ipynb_checkpoints/ 7 | 8 | *.py[cod] 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Packages 14 | *.egg 15 | *.egg-info 16 | dist 17 | build 18 | eggs 19 | parts 20 | bin 21 | var 22 | sdist 23 | develop-eggs 24 | .installed.cfg 25 | lib 26 | lib64 27 | 28 | # Installer logs 29 | pip-log.txt 30 | 31 | # Unit test / coverage reports 32 | .coverage* 33 | .tox 34 | nosetests.xml 35 | htmlcov 36 | 37 | # Translations 38 | *.mo 39 | 40 | # Mr Developer 41 | .mr.developer.cfg 42 | .project 43 | .pydevproject 44 | 45 | # Complexity 46 | output/*.html 47 | output/*/index.html 48 | 49 | # Sphinx 50 | docs/_build 51 | 52 | MANIFEST 53 | 54 | *.tif 55 | 56 | # Rever 57 | rever/ 58 | 59 | # Dev notebooks 60 | notebooks/ 61 | 62 | # MkDocs 63 | site/ 64 | 65 | .vscode 66 | 67 | .idea/ 68 | 69 | data/ 70 | output/ 71 | wandb/ 72 | oracle/ 73 | expts/models/ 74 | expts/dev-data/ 75 | expts/notebooks/ 76 | -------------------------------------------------------------------------------- /DATA_LICENSE: -------------------------------------------------------------------------------- 1 | # Creative Commons Attribution 4.0 International License (CC BY 4.0) 2 | 3 | This work is licensed under the Creative Commons Attribution 4.0 International License. 4 | 5 | To view a copy of this license, visit http://creativecommons.org/licenses/by/4.0/ or send a letter to Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2023 Emmanuel Noutahi 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

:safety_vest: SAFE

2 |

Sequential Attachment-based Fragment Embedding (SAFE) is a novel molecular line notation that represents molecules as an unordered sequence of fragment blocks to improve molecule design using generative models.

3 | 4 |
5 |
6 | 7 |
8 |
9 | 10 |

11 | 12 | Paper 13 | | 14 | 15 | Docs 16 | | 17 | 18 | 🤗 Model 19 | | 20 | 21 | 🤗 Training Dataset 22 | 23 |

24 | 25 | --- 26 | 27 |
28 | 29 | [![PyPI](https://img.shields.io/pypi/v/safe-mol)](https://pypi.org/project/safe-mol/) 30 | [![Conda](https://img.shields.io/conda/v/conda-forge/safe-mol?label=conda&color=success)](https://anaconda.org/conda-forge/safe-mol) 31 | [![PyPI - Downloads](https://img.shields.io/pypi/dm/safe-mol)](https://pypi.org/project/safe-mol/) 32 | [![Conda](https://img.shields.io/conda/dn/conda-forge/safe-mol)](https://anaconda.org/conda-forge/safe-mol) 33 | [![Code license](https://img.shields.io/badge/Code%20License-Apache_2.0-green.svg)](https://github.com/datamol-io/safe/blob/main/LICENSE) 34 | [![Data License](https://img.shields.io/badge/Data%20License-CC%20BY%204.0-red.svg)](https://github.com/datamol-io/safe/blob/main/DATA_LICENSE) 35 | [![GitHub Repo stars](https://img.shields.io/github/stars/datamol-io/safe)](https://github.com/datamol-io/safe/stargazers) 36 | [![GitHub Repo stars](https://img.shields.io/github/forks/datamol-io/safe)](https://github.com/datamol-io/safe/network/members) 37 | [![arXiv](https://img.shields.io/badge/arXiv-2310.10773-b31b1b.svg)](https://arxiv.org/pdf/2310.10773.pdf) 38 | 39 | [![test](https://github.com/datamol-io/safe/actions/workflows/test.yml/badge.svg)](https://github.com/datamol-io/safe/actions/workflows/test.yml) 40 | [![release](https://github.com/datamol-io/safe/actions/workflows/release.yml/badge.svg)](https://github.com/datamol-io/safe/actions/workflows/release.yml) 41 | [![code-check](https://github.com/datamol-io/safe/actions/workflows/code-check.yml/badge.svg)](https://github.com/datamol-io/safe/actions/workflows/code-check.yml) 42 | [![doc](https://github.com/datamol-io/safe/actions/workflows/doc.yml/badge.svg)](https://github.com/datamol-io/safe/actions/workflows/doc.yml) 43 | 44 | ## Overview of SAFE 45 | 46 | SAFE _is the_ deep learning molecular representation. It's an encoding leveraging a peculiarity in the decoding schemes of SMILES, to allow representation of molecules as a contiguous sequence of connected fragments. SAFE strings are valid SMILES strings, and thus are able to preserve the same amount of information. The intuitive representation of molecules as an ordered sequence of connected fragments greatly simplifies the following tasks often encountered in molecular design: 47 | 48 | - _de novo_ design 49 | - superstructure generation 50 | - scaffold decoration 51 | - motif extension 52 | - linker generation 53 | - scaffold morphing. 54 | 55 | The construction of a SAFE strings requires defining a molecular fragmentation algorithm. By default, we use [BRICS], but any other fragmentation algorithm can be used. The image below illustrates the process of building a SAFE string. The resulting string is a valid SMILES that can be read by [datamol](https://github.com/datamol-io/datamol) or [RDKit](https://github.com/rdkit/rdkit). 56 | 57 |
58 |
59 | 60 |
61 | 62 | ## News 🚀 63 | 64 | #### 💥 2024/01/15 💥 65 | 1. [@IanAWatson](https://github.com/IanAWatson) has a C++ implementation of SAFE in [LillyMol](https://github.com/IanAWatson/LillyMol/tree/bazel_version_float) that is quite fast and use a custom fragmentation algorithm. Follow the installation instruction on the repo and checkout the docs of the CLI here: [docs/Molecule_Tools/SAFE.md](https://github.com/IanAWatson/LillyMol/blob/bazel_version_float/docs/Molecule_Tools/SAFE.md) 66 | 67 | 68 | ### Installation 69 | 70 | You can install `safe` using pip: 71 | 72 | ```bash 73 | pip install safe-mol 74 | ``` 75 | 76 | You can use conda/mamba: 77 | 78 | ```bash 79 | mamba install -c conda-forge safe-mol 80 | ``` 81 | 82 | #### 2024/11/22 83 | NOTE: Installation might cause issues like no detection of GPUs (which can be checked by `torch.cuda.is_available()`) and sengmentation error due to mismatch between installed and driver cuda versions. In that case, follow these steps: 84 | 85 | Create a new environment using conda: 86 | 87 | ```bash 88 | conda create -n env_safe python=3.12 89 | conda activate env_safe 90 | ``` 91 | 92 | Check nvidia driver version on machine by running `nvcc --version` or `nvidia-smi` commands 93 | 94 | Install pytorch with compatible cuda versions (from `https://pytorch.org/get-started/locally/`) and safe-mol: 95 | 96 | ```bash 97 | conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia 98 | conda install -c conda-forge safe-mol 99 | ``` 100 | 101 | ### Datasets and Models 102 | 103 | | Type | Name | Infos | Size | Comment | 104 | | ---------------------- | ------------------------------------------------------------------------------ | ---------- | ----- | -------------------- | 105 | | Model | [datamol-io/safe-gpt](https://huggingface.co/datamol-io/safe-gpt) | 87M params | 350M | Default model | 106 | | Training Dataset | [datamol-io/safe-gpt](https://huggingface.co/datasets/datamol-io/safe-gpt) | 1.1B rows | 250GB | Training dataset | 107 | | Drug Benchmark Dataset | [datamol-io/safe-drugs](https://huggingface.co/datasets/datamol-io/safe-drugs) | 26 rows | 20 kB | Benchmarking dataset | 108 | 109 | ## Usage 110 | 111 | Please refer to the [documentation](https://safe-docs.datamol.io/), which contains tutorials for getting started with `safe` and detailed descriptions of the functions provided, as well as an example of how to get started with SAFE-GPT. 112 | 113 | ### API 114 | 115 | We summarize some key functions provided by the `safe` package below. 116 | 117 | | Function | Description | 118 | | ------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | 119 | | `safe.encode` | Translates a SMILES string into its corresponding SAFE string. | 120 | | `safe.decode` | Translates a SAFE string into its corresponding SMILES string. The SAFE decoder just augment RDKit's `Chem.MolFromSmiles` with an optional correction argument to take care of missing hydrogen bonds. | 121 | | `safe.split` | Tokenizes a SAFE string to build a generative model. | 122 | 123 | ### Examples 124 | 125 | #### Translation between SAFE and SMILES representations 126 | 127 | ```python 128 | import safe 129 | 130 | ibuprofen = "CC(Cc1ccc(cc1)C(C(=O)O)C)C" 131 | 132 | # SMILES -> SAFE -> SMILES translation 133 | try: 134 | ibuprofen_sf = safe.encode(ibuprofen) # c12ccc3cc1.C3(C)C(=O)O.CC(C)C2 135 | ibuprofen_smi = safe.decode(ibuprofen_sf, canonical=True) # CC(C)Cc1ccc(C(C)C(=O)O)cc1 136 | except safe.EncoderError: 137 | pass 138 | except safe.DecoderError: 139 | pass 140 | 141 | ibuprofen_tokens = list(safe.split(ibuprofen_sf)) 142 | ``` 143 | 144 | ### Training/Finetuning a (new) model 145 | 146 | A command line interface is available to train a new model, please run `safe-train --help`. You can also provide an existing checkpoint to continue training or finetune on you own dataset. 147 | 148 | For example: 149 | 150 | ```bash 151 | safe-train --config \ 152 | --model-path \ 153 | --tokenizer \ 154 | --dataset \ 155 | --num_labels 9 \ 156 | --torch_compile True \ 157 | --optim "adamw_torch" \ 158 | --learning_rate 1e-5 \ 159 | --prop_loss_coeff 1e-3 \ 160 | --gradient_accumulation_steps 1 \ 161 | --output_dir "" \ 162 | --max_steps 5 163 | ``` 164 | 165 | ## References 166 | 167 | If you use this repository, please cite the following related [paper](https://arxiv.org/abs/2310.10773#): 168 | 169 | ```bib 170 | @misc{noutahi2023gotta, 171 | title={Gotta be SAFE: A New Framework for Molecular Design}, 172 | author={Emmanuel Noutahi and Cristian Gabellini and Michael Craig and Jonathan S. C Lim and Prudencio Tossou}, 173 | year={2023}, 174 | eprint={2310.10773}, 175 | archivePrefix={arXiv}, 176 | primaryClass={cs.LG} 177 | } 178 | ``` 179 | 180 | ## License 181 | 182 | The training dataset is licensed under CC BY 4.0. See [DATA_LICENSE](DATA_LICENSE) for details. This code base is licensed under the Apache-2.0 license. See [LICENSE](LICENSE) for details. 183 | 184 | Note that the model weights of **SAFE-GPT** are exclusively licensed for research purposes (CC BY-NC 4.0). 185 | 186 | ## Development lifecycle 187 | 188 | ### Setup dev environment 189 | 190 | ```bash 191 | mamba create -n safe -f env.yml 192 | mamba activate safe 193 | 194 | pip install --no-deps -e . 195 | ``` 196 | 197 | ### Tests 198 | 199 | You can run tests locally with: 200 | 201 | ```bash 202 | pytest 203 | ``` 204 | -------------------------------------------------------------------------------- /docs/api/safe.md: -------------------------------------------------------------------------------- 1 | ## SAFE Encoder-Decoder 2 | 3 | ::: safe.converter 4 | options: 5 | members: 6 | - encode 7 | - decode 8 | - SAFEConverter 9 | show_root_heading: false 10 | 11 | 12 | --- 13 | 14 | ## SAFE Design 15 | 16 | ::: safe.sample 17 | options: 18 | members: 19 | - SAFEDesign 20 | show_root_heading: false 21 | 22 | 23 | --- 24 | 25 | ## SAFE Tokenizer 26 | 27 | ::: safe.tokenizer 28 | options: 29 | members: 30 | - SAFESplitter 31 | - SAFETokenizer 32 | show_root_heading: false 33 | 34 | --- 35 | 36 | ## Utils 37 | 38 | ::: safe.utils -------------------------------------------------------------------------------- /docs/api/safe.models.md: -------------------------------------------------------------------------------- 1 | ## Config File 2 | 3 | The input config file for training a `SAFE` model is very similar to the GPT2 config file, with the addition of an optional `num_labels` attribute for training with descriptors regularization. 4 | 5 | ```json 6 | { 7 | "activation_function": "gelu_new", 8 | "attn_pdrop": 0.1, 9 | "bos_token_id": 10000, 10 | "embd_pdrop": 0.1, 11 | "eos_token_id": 1, 12 | "initializer_range": 0.02, 13 | "layer_norm_epsilon": 1e-05, 14 | "model_type": "gpt2", 15 | "n_embd": 768, 16 | "n_head": 12, 17 | "n_inner": null, 18 | "n_layer": 12, 19 | "n_positions": 1024, 20 | "reorder_and_upcast_attn": false, 21 | "resid_pdrop": 0.1, 22 | "scale_attn_by_inverse_layer_idx": false, 23 | "scale_attn_weights": true, 24 | "summary_activation": "tanh", 25 | "summary_first_dropout": 0.1, 26 | "summary_proj_to_labels": true, 27 | "summary_type": "cls_index", 28 | "summary_hidden_size": 128, 29 | "summary_use_proj": true, 30 | "transformers_version": "4.31.0", 31 | "use_cache": true, 32 | "vocab_size": 10000, 33 | "num_labels": 9 34 | } 35 | ``` 36 | 37 | 38 | ## SAFE Model 39 | ::: safe.trainer.model 40 | 41 | --- 42 | 43 | ## Trainer 44 | ::: safe.trainer.trainer_utils 45 | 46 | --- 47 | 48 | ## Data Collator 49 | ::: safe.trainer.collator 50 | 51 | --- 52 | 53 | ## Data Utils 54 | ::: safe.trainer.data_utils 55 | 56 | 57 | -------------------------------------------------------------------------------- /docs/api/safe.viz.md: -------------------------------------------------------------------------------- 1 | ::: safe.viz 2 | options: 3 | members: 4 | - to_image 5 | show_root_heading: false -------------------------------------------------------------------------------- /docs/assets/js/google-analytics.js: -------------------------------------------------------------------------------- 1 | var gtag_id = "G-3XLJELJ2TF"; 2 | 3 | var script = document.createElement("script"); 4 | script.src = "https://www.googletagmanager.com/gtag/js?id=" + gtag_id; 5 | document.head.appendChild(script); 6 | 7 | window.dataLayer = window.dataLayer || []; 8 | function gtag(){dataLayer.push(arguments);} 9 | gtag('js', new Date()); 10 | gtag('config', gtag_id); 11 | -------------------------------------------------------------------------------- /docs/data_license.md: -------------------------------------------------------------------------------- 1 | ``` 2 | {!DATA_LICENSE!} 3 | ``` 4 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 |

🦺 SAFE

2 |

Sequential Attachment-based Fragment Embedding (SAFE) is a novel molecular line notation that represents molecules as an unordered sequence of fragment blocks to improve molecule design using generative models.

3 | 4 |
5 |
6 | 7 |
8 |
9 | 10 |

11 | 12 | Paper 13 | | 14 | 15 | Docs 16 | | 17 | 18 | 🤗 Model 19 | | 20 | 21 | 🤗 Training Dataset 22 | 23 |

24 | 25 | --- 26 | 27 |
28 | 29 | [![PyPI](https://img.shields.io/pypi/v/safe-mol)](https://pypi.org/project/safe-mol/) 30 | [![Conda](https://img.shields.io/conda/v/conda-forge/safe-mol?label=conda&color=success)](https://anaconda.org/conda-forge/safe-mol) 31 | [![PyPI - Downloads](https://img.shields.io/pypi/dm/safe-mol)](https://pypi.org/project/safe-mol/) 32 | [![Conda](https://img.shields.io/conda/dn/conda-forge/safe-mol)](https://anaconda.org/conda-forge/safe-mol) 33 | [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/safe-mol)](https://pypi.org/project/safe-mol/) 34 | [![Code license](https://img.shields.io/badge/Code%20License-Apache_2.0-green.svg)](https://github.com/datamol-io/safe/blob/main/LICENSE) 35 | [![Data License](https://img.shields.io/badge/Data%20License-CC%20BY%204.0-red.svg)](https://github.com/datamol-io/safe/blob/main/DATA_LICENSE) 36 | [![GitHub Repo stars](https://img.shields.io/github/stars/datamol-io/safe)](https://github.com/datamol-io/safe/stargazers) 37 | [![GitHub Repo stars](https://img.shields.io/github/forks/datamol-io/safe)](https://github.com/datamol-io/safe/network/members) 38 | [![arXiv](https://img.shields.io/badge/arXiv-2310.10773-b31b1b.svg)](https://arxiv.org/pdf/2310.10773.pdf) 39 | 40 | [![test](https://github.com/datamol-io/safe/actions/workflows/test.yml/badge.svg)](https://github.com/datamol-io/safe/actions/workflows/test.yml) 41 | [![release](https://github.com/datamol-io/safe/actions/workflows/release.yml/badge.svg)](https://github.com/datamol-io/safe/actions/workflows/release.yml) 42 | [![code-check](https://github.com/datamol-io/safe/actions/workflows/code-check.yml/badge.svg)](https://github.com/datamol-io/safe/actions/workflows/code-check.yml) 43 | [![doc](https://github.com/datamol-io/safe/actions/workflows/doc.yml/badge.svg)](https://github.com/datamol-io/safe/actions/workflows/doc.yml) 44 | 45 | ## Overview of SAFE 46 | 47 | SAFE _is the_ deep learning molecular representation. It's an encoding leveraging a peculiarity in the decoding schemes of SMILES, to allow representation of molecules as a contiguous sequence of connected fragments. SAFE strings are valid SMILES strings, and thus are able to preserve the same amount of information. The intuitive representation of molecules as an ordered sequence of connected fragments greatly simplifies the following tasks often encountered in molecular design: 48 | 49 | - _de novo_ design 50 | - superstructure generation 51 | - scaffold decoration 52 | - motif extension 53 | - linker generation 54 | - scaffold morphing. 55 | 56 | The construction of a SAFE strings requires defining a molecular fragmentation algorithm. By default, we use [BRICS], but any other fragmentation algorithm can be used. The image below illustrates the process of building a SAFE string. The resulting string is a valid SMILES that can be read by [datamol](https://github.com/datamol-io/datamol) or [RDKit](https://github.com/rdkit/rdkit). 57 | 58 |
59 |
60 | 61 |
62 | 63 | 64 | ## News 🚀 65 | 66 | #### 💥 2024/01/15 💥 67 | 1. [@IanAWatson](https://github.com/IanAWatson) has a C++ implementation of SAFE in [LillyMol](https://github.com/IanAWatson/LillyMol/tree/bazel_version_float) that is quite fast and use a custom fragmentation algorithm. Follow the installation instruction on the repo and checkout the docs of the CLI here: [docs/Molecule_Tools/SAFE.md](https://github.com/IanAWatson/LillyMol/blob/bazel_version_float/docs/Molecule_Tools/SAFE.md) 68 | 69 | ### Installation 70 | 71 | You can install `safe` using pip: 72 | 73 | ```bash 74 | pip install safe-mol 75 | ``` 76 | 77 | You can use conda/mamba: 78 | 79 | ```bash 80 | mamba install -c conda-forge safe-mol 81 | ``` 82 | 83 | ### Datasets and Models 84 | 85 | | Type | Name | Infos | Size | Comment | 86 | | ---------------------- | ------------------------------------------------------------------------------ | ---------- | ----- | -------------------- | 87 | | Model | [datamol-io/safe-gpt](https://huggingface.co/datamol-io/safe-gpt) | 87M params | 350M | Default model | 88 | | Training Dataset | [datamol-io/safe-gpt](https://huggingface.co/datasets/datamol-io/safe-gpt) | 1.1B rows | 250GB | Training dataset | 89 | | Drug Benchmark Dataset | [datamol-io/safe-drugs](https://huggingface.co/datasets/datamol-io/safe-drugs) | 26 rows | 20 kB | Benchmarking dataset | 90 | 91 | ## Usage 92 | 93 | 94 | The tutorials in the [documentation](https://safe-docs.datamol.io/) can help you get started with `safe` and `SAFE-GPT`. 95 | 96 | ### API 97 | 98 | We summarize some key functions provided by the `safe` package below. 99 | 100 | | Function | Description | 101 | | ------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | 102 | | `safe.encode` | Translates a SMILES string into its corresponding SAFE string. | 103 | | `safe.decode` | Translates a SAFE string into its corresponding SMILES string. The SAFE decoder just augment RDKit's `Chem.MolFromSmiles` with an optional correction argument to take care of missing hydrogens bonds. | 104 | | `safe.split` | Tokenizes a SAFE string to build a generative model. | 105 | 106 | ### Examples 107 | 108 | #### Translation between SAFE and SMILES representations 109 | 110 | ```python 111 | import safe 112 | 113 | ibuprofen = "CC(Cc1ccc(cc1)C(C(=O)O)C)C" 114 | 115 | # SMILES -> SAFE -> SMILES translation 116 | try: 117 | ibuprofen_sf = safe.encode(ibuprofen) # c12ccc3cc1.C3(C)C(=O)O.CC(C)C2 118 | ibuprofen_smi = safe.decode(ibuprofen_sf, canonical=True) # CC(C)Cc1ccc(C(C)C(=O)O)cc1 119 | except safe.EncoderError: 120 | pass 121 | except safe.DecoderError: 122 | pass 123 | 124 | ibuprofen_tokens = list(safe.split(ibuprofen_sf)) 125 | ``` 126 | 127 | ### Training/Finetuning a (new) model 128 | 129 | A command line interface is available to train a new model, please run `safe-train --help`. You can also provide an existing checkpoint to continue training or finetune on you own dataset. 130 | 131 | For example: 132 | 133 | ```bash 134 | safe-train --config \ 135 | --model-path \ 136 | --tokenizer \ 137 | --dataset \ 138 | --num_labels 9 \ 139 | --torch_compile True \ 140 | --optim "adamw_torch" \ 141 | --learning_rate 1e-5 \ 142 | --prop_loss_coeff 1e-3 \ 143 | --gradient_accumulation_steps 1 \ 144 | --output_dir "" \ 145 | --max_steps 5 146 | ``` 147 | 148 | 149 | ## References 150 | 151 | If you use this repository, please cite the following related [paper](https://arxiv.org/abs/2310.10773#): 152 | 153 | ```bib 154 | @misc{noutahi2023gotta, 155 | title={Gotta be SAFE: A New Framework for Molecular Design}, 156 | author={Emmanuel Noutahi and Cristian Gabellini and Michael Craig and Jonathan S. C Lim and Prudencio Tossou}, 157 | year={2023}, 158 | eprint={2310.10773}, 159 | archivePrefix={arXiv}, 160 | primaryClass={cs.LG} 161 | } 162 | ``` 163 | 164 | ## License 165 | 166 | Note that all data and model weights of **SAFE** are exclusively licensed for research purposes. The accompanying dataset is licensed under CC BY 4.0, which permits solely non-commercial usage. See [DATA_LICENSE](data_license.md) for details. 167 | 168 | This code base is licensed under the Apache-2.0 license. See [LICENSE](license.md) for details. 169 | 170 | ## Development lifecycle 171 | 172 | ### Setup dev environment 173 | 174 | ```bash 175 | mamba create -n safe -f env.yml 176 | mamba activate safe 177 | 178 | pip install --no-deps -e . 179 | ``` 180 | 181 | ### Tests 182 | 183 | You can run tests locally with: 184 | 185 | ```bash 186 | pytest 187 | ``` 188 | -------------------------------------------------------------------------------- /docs/license.md: -------------------------------------------------------------------------------- 1 | ``` 2 | {!LICENSE!} 3 | ``` 4 | -------------------------------------------------------------------------------- /docs/tutorials/load-from-wandb.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 2, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "from safe.sample import SAFEDesign" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 3, 25 | "metadata": {}, 26 | "outputs": [ 27 | { 28 | "name": "stderr", 29 | "output_type": "stream", 30 | "text": [ 31 | "/Users/emmanuel.noutahi/miniconda3/envs/safe/lib/python3.12/site-packages/huggingface_hub/file_download.py:1150: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", 32 | " warnings.warn(\n" 33 | ] 34 | } 35 | ], 36 | "source": [ 37 | "model = SAFEDesign.load_default()" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": {}, 43 | "source": [ 44 | "## Upload models to wandb\n", 45 | "\n", 46 | "SAFE models can be uploaded to wandb with the `upload_to_wandb` function. You can define a general \"SAFE_WANDB_PROJECT\" env variable to save all of your models to that project. \n", 47 | "\n", 48 | "Make sure that you are login into your wandb account:\n", 49 | "\n", 50 | "```bash\n", 51 | "wandb login --relogin $WANDB_API_KEY\n", 52 | "```" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 4, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "from safe.io import upload_to_wandb" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 5, 67 | "metadata": {}, 68 | "outputs": [ 69 | { 70 | "name": "stdout", 71 | "output_type": "stream", 72 | "text": [ 73 | "env: WANDB_SILENT=False\n", 74 | "env: SAFE_WANDB_PROJECT=safe-models\n", 75 | "[2024-09-10 13:42:46,004] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to mps (auto detect)\n" 76 | ] 77 | }, 78 | { 79 | "name": "stderr", 80 | "output_type": "stream", 81 | "text": [ 82 | "W0910 13:42:46.257000 8343047168 torch/distributed/elastic/multiprocessing/redirects.py:28] NOTE: Redirects are currently not supported in Windows or MacOs.\n", 83 | "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n", 84 | "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mmaclandrol\u001b[0m (\u001b[33mvalencelabs\u001b[0m). Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" 85 | ] 86 | }, 87 | { 88 | "data": { 89 | "text/html": [ 90 | "wandb version 0.17.9 is available! To upgrade, please run:\n", 91 | " $ pip install wandb --upgrade" 92 | ], 93 | "text/plain": [ 94 | "" 95 | ] 96 | }, 97 | "metadata": {}, 98 | "output_type": "display_data" 99 | }, 100 | { 101 | "data": { 102 | "text/html": [ 103 | "Tracking run with wandb version 0.16.6" 104 | ], 105 | "text/plain": [ 106 | "" 107 | ] 108 | }, 109 | "metadata": {}, 110 | "output_type": "display_data" 111 | }, 112 | { 113 | "data": { 114 | "text/html": [ 115 | "Run data is saved locally in /Users/emmanuel.noutahi/Code/safe/nb/wandb/run-20240910_134247-72wmn5st" 116 | ], 117 | "text/plain": [ 118 | "" 119 | ] 120 | }, 121 | "metadata": {}, 122 | "output_type": "display_data" 123 | }, 124 | { 125 | "data": { 126 | "text/html": [ 127 | "Syncing run absurd-disco-1 to Weights & Biases (docs)
" 128 | ], 129 | "text/plain": [ 130 | "" 131 | ] 132 | }, 133 | "metadata": {}, 134 | "output_type": "display_data" 135 | }, 136 | { 137 | "data": { 138 | "text/html": [ 139 | " View project at https://wandb.ai/valencelabs/safe-models" 140 | ], 141 | "text/plain": [ 142 | "" 143 | ] 144 | }, 145 | "metadata": {}, 146 | "output_type": "display_data" 147 | }, 148 | { 149 | "data": { 150 | "text/html": [ 151 | " View run at https://wandb.ai/valencelabs/safe-models/runs/72wmn5st" 152 | ], 153 | "text/plain": [ 154 | "" 155 | ] 156 | }, 157 | "metadata": {}, 158 | "output_type": "display_data" 159 | }, 160 | { 161 | "name": "stderr", 162 | "output_type": "stream", 163 | "text": [ 164 | "\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (/var/folders/rl/wwcfdj4x0pg293bfqszl970r0000gq/T/tmpy8r3mrzg)... Done. 0.6s\n" 165 | ] 166 | }, 167 | { 168 | "data": { 169 | "application/vnd.jupyter.widget-view+json": { 170 | "model_id": "a9466f1c60dc43d18aceca8b74f753f4", 171 | "version_major": 2, 172 | "version_minor": 0 173 | }, 174 | "text/plain": [ 175 | "VBox(children=(Label(value='333.221 MB of 333.221 MB uploaded\\r'), FloatProgress(value=1.0, max=1.0)))" 176 | ] 177 | }, 178 | "metadata": {}, 179 | "output_type": "display_data" 180 | }, 181 | { 182 | "data": { 183 | "text/html": [ 184 | " View run absurd-disco-1 at: https://wandb.ai/valencelabs/safe-models/runs/72wmn5st
View project at: https://wandb.ai/valencelabs/safe-models
Synced 6 W&B file(s), 0 media file(s), 5 artifact file(s) and 0 other file(s)" 185 | ], 186 | "text/plain": [ 187 | "" 188 | ] 189 | }, 190 | "metadata": {}, 191 | "output_type": "display_data" 192 | }, 193 | { 194 | "data": { 195 | "text/html": [ 196 | "Find logs at: ./wandb/run-20240910_134247-72wmn5st/logs" 197 | ], 198 | "text/plain": [ 199 | "" 200 | ] 201 | }, 202 | "metadata": {}, 203 | "output_type": "display_data" 204 | } 205 | ], 206 | "source": [ 207 | "%env WANDB_SILENT=False\n", 208 | "%env SAFE_WANDB_PROJECT=safe-models\n", 209 | "\n", 210 | "upload_to_wandb(model.model, model.tokenizer, artifact_name=\"default-safe-zinc\", slicer=\"BRICS/Partition\", aliases=[\"paper\"])" 211 | ] 212 | }, 213 | { 214 | "cell_type": "markdown", 215 | "metadata": {}, 216 | "source": [ 217 | "## Loading models from wandb" 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": 3, 223 | "metadata": {}, 224 | "outputs": [ 225 | { 226 | "name": "stdout", 227 | "output_type": "stream", 228 | "text": [ 229 | "env: SAFE_MODEL_ROOT=/Users/emmanuel.noutahi/.cache/wandb/safe/\n" 230 | ] 231 | }, 232 | { 233 | "name": "stderr", 234 | "output_type": "stream", 235 | "text": [ 236 | "\u001b[34m\u001b[1mwandb\u001b[0m: Downloading large artifact default-safe-zinc:latest, 333.22MB. 5 files... \n", 237 | "\u001b[34m\u001b[1mwandb\u001b[0m: 5 of 5 files downloaded. \n", 238 | "Done. 0:0:0.7\n" 239 | ] 240 | } 241 | ], 242 | "source": [ 243 | "%env SAFE_MODEL_ROOT=/Users/emmanuel.noutahi/.cache/wandb/safe/\n", 244 | "designer = SAFEDesign.load_from_wandb(\"safe-models/default-safe-zinc\")" 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": 4, 250 | "metadata": {}, 251 | "outputs": [ 252 | { 253 | "data": { 254 | "application/vnd.jupyter.widget-view+json": { 255 | "model_id": "96de1483b3894961a8ec2690df4f8ace", 256 | "version_major": 2, 257 | "version_minor": 0 258 | }, 259 | "text/plain": [ 260 | " 0%| | 0/1 [00:001` or unset `early_stopping`.\n", 273 | " warnings.warn(\n" 274 | ] 275 | }, 276 | { 277 | "data": { 278 | "text/plain": [ 279 | "['C[C@]1(C(=O)N2CCC[C@@H](NC(=O)C#CC3CC3)CC2)CCNC1=O',\n", 280 | " 'CN(C(=O)CN1CC[NH+](C[C@@H](O)Cn2cc([N+](=O)[O-])cn2)CC1)c1ccccc1',\n", 281 | " 'CC[C@@H](C)C[C@@H]([NH3+])C(=O)N(CC)C[C@@H]1CCOC1',\n", 282 | " 'Cc1nnc2n1C[C@H](CNC(=O)Nc1cc(Cl)ccc1Cl)CC2',\n", 283 | " 'CCc1cccc(CC)c1NC(=O)[C@H](C)OC(=O)CCc1nc2ccccc2o1',\n", 284 | " 'Cc1cc(OC[C@H](O)C[NH2+]C[C@@H]2C[C@H](O)CN2Cc2ccccc2)ccc1F',\n", 285 | " 'Cc1c(Cl)cccc1N=C(O)CN=C(O)COC(=O)c1csc(-c2ccccc2)n1',\n", 286 | " 'CCc1nc(CCNC(=O)N[C@@H]2CCc3nnnn3CC2)cs1',\n", 287 | " 'C[C@@]1(C(=O)N[C@H]2CCCCCN(C(=O)c3cc(C4CC4)no3)C2)C=CCC1',\n", 288 | " 'Cc1cc(-c2cc(-c3cnn(C)c3)c3c(N)ncnc3n2)ccc1F']" 289 | ] 290 | }, 291 | "execution_count": 4, 292 | "metadata": {}, 293 | "output_type": "execute_result" 294 | } 295 | ], 296 | "source": [ 297 | "designer.de_novo_generation(10)" 298 | ] 299 | } 300 | ], 301 | "metadata": { 302 | "kernelspec": { 303 | "display_name": "safe", 304 | "language": "python", 305 | "name": "python3" 306 | }, 307 | "language_info": { 308 | "codemirror_mode": { 309 | "name": "ipython", 310 | "version": 3 311 | }, 312 | "file_extension": ".py", 313 | "mimetype": "text/x-python", 314 | "name": "python", 315 | "nbconvert_exporter": "python", 316 | "pygments_lexer": "ipython3", 317 | "version": "3.12.5" 318 | } 319 | }, 320 | "nbformat": 4, 321 | "nbformat_minor": 2 322 | } 323 | -------------------------------------------------------------------------------- /env.yml: -------------------------------------------------------------------------------- 1 | channels: 2 | - conda-forge 3 | 4 | dependencies: 5 | - python >=3.9 6 | - pip 7 | - tqdm 8 | - loguru 9 | - typer 10 | - universal_pathlib 11 | 12 | # Scientific 13 | - datamol 14 | - pandas <=2.1.1 15 | - numpy 16 | - pytorch >=2.0 17 | - transformers 18 | - datasets 19 | - tokenizers 20 | - accelerate >=0.33 # for accelerator_config update 21 | - evaluate 22 | - wandb 23 | - huggingface_hub 24 | 25 | # Optional 26 | - deepspeed 27 | 28 | # dev 29 | - black >=24 30 | - ruff 31 | - pytest >=6.0 32 | - jupyterlab 33 | - nbconvert 34 | - ipywidgets 35 | 36 | - pip: 37 | - mkdocs <1.6.0 38 | - mkdocs-material >=7.1.1 39 | - mkdocs-material-extensions 40 | - mkdocstrings 41 | - mkdocstrings-python 42 | - mkdocs-jupyter 43 | - markdown-include 44 | - mdx_truly_sane_lists 45 | - mike >=1.0.0 46 | -------------------------------------------------------------------------------- /expts/config/accelerate.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | deepspeed_config: 3 | gradient_accumulation_steps: 2 4 | gradient_clipping: 1.0 5 | offload_optimizer_device: cpu 6 | offload_param_device: cpu 7 | zero3_init_flag: true 8 | zero_stage: 2 9 | distributed_type: DEEPSPEED 10 | downcast_bf16: 'no' 11 | dynamo_config: 12 | dynamo_backend: INDUCTOR 13 | machine_rank: 0 14 | main_training_function: main 15 | mixed_precision: 'no' 16 | num_machines: 1 17 | num_processes: 2 18 | rdzv_backend: static 19 | same_network: true 20 | tpu_env: [] 21 | tpu_use_cluster: false 22 | tpu_use_sudo: false 23 | use_cpu: false 24 | -------------------------------------------------------------------------------- /expts/scripts/slurm-data-build.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | ## Name of your SLURM job 4 | #SBATCH --job-name=jupyter 5 | #SBATCH --cpus-per-task=64 6 | #SBATCH --mem=200G 7 | #SBATCH --time=1-12:00 8 | 9 | set -ex 10 | # The below env variables can eventually help setting up your workload. 11 | # In a SLURM job, you CANNOT use `conda activate` and instead MUST use: 12 | source activate safe 13 | python scripts/build_dataset.py -------------------------------------------------------------------------------- /expts/scripts/slurm-notebook.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | ## Name of your SLURM job 4 | #SBATCH --job-name=jupyter 5 | #SBATCH --cpus-per-task=32 6 | #SBATCH --mem=100G 7 | #SBATCH --time=1-12:00 8 | 9 | set -ex 10 | # The below env variables can eventually help setting up your workload. 11 | # In a SLURM job, you CANNOT use `conda activate` and instead MUST use: 12 | source activate safe 13 | jupyter lab /home/emmanuel/safe/ -------------------------------------------------------------------------------- /expts/scripts/slurm-tokenizer-train-custom.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | ## Name of your SLURM job 4 | #SBATCH --job-name=split-train-safe-tokenizer 5 | #SBATCH --output=/home/emmanuel/safe/expts/output/job_split_%x_%a.out 6 | #SBATCH --error=/home/emmanuel/safe/expts/output/job_split_%x_%a.out 7 | #SBATCH --open-mode=append 8 | #SBATCH --cpus-per-task=32 9 | #SBATCH --mem=80G 10 | #SBATCH --time=48:00:00 11 | 12 | set -ex 13 | # The below env variables can eventually help setting up your workload. 14 | # In a SLURM job, you CANNOT use `conda activate` and instead MUST use: 15 | source activate safe 16 | 17 | TOKENIZER_TYPE="bpe" 18 | DEFAULT_DATASET="/storage/shared_data/cristian/preprocessed_zinc_unichem/train_filtered/" 19 | DATASET="${1:-$DEFAULT_DATASET}" 20 | #DATASET="/home/emmanuel/safe/expts/notebook/tmp_data/proc_data" 21 | OUTPUT="/home/emmanuel/safe/expts/tokenizer/tokenizer-custom.json" 22 | VOCAB_SIZE="10000" 23 | TEXT_COLUMN="input" 24 | BATCH_SIZE="1000" 25 | N_EXAMPLES="500000000" 26 | 27 | python scripts/tokenizer_trainer.py --tokenizer_type $TOKENIZER_TYPE \ 28 | --dataset $DATASET --text_column $TEXT_COLUMN \ 29 | --vocab_size $VOCAB_SIZE --batch_size $BATCH_SIZE \ 30 | --outfile $OUTPUT --splitter 'safe' --n_examples $N_EXAMPLES --tokenizer_name "safe-custom" -------------------------------------------------------------------------------- /expts/scripts/slurm-tokenizer-train-small.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | ## Name of your SLURM job 4 | #SBATCH --job-name=small-safe-tokenizer 5 | #SBATCH --output=/home/emmanuel/safe/expts/output/job_test_%x_%a.out 6 | #SBATCH --error=/home/emmanuel/safe/expts/output/job_test_%x_%a.out 7 | #SBATCH --open-mode=append 8 | #SBATCH --cpus-per-task=16 9 | #SBATCH --mem=32G 10 | #SBATCH --time=2:00:00 11 | 12 | set -ex 13 | # The below env variables can eventually help setting up your workload. 14 | # In a SLURM job, you CANNOT use `conda activate` and instead MUST use: 15 | source activate safe 16 | 17 | TOKENIZER_TYPE="bpe" 18 | DEFAULT_DATASET="/storage/shared_data/cristian/preprocessed_zinc_unichem/train_filtered/" 19 | DATASET="${1:-$DEFAULT_DATASET}" 20 | #DATASET="/home/emmanuel/safe/expts/notebook/tmp_data/proc_data" 21 | OUTPUT="/home/emmanuel/safe/expts/tokenizer/tokenizer-custom-test.json" 22 | VOCAB_SIZE="10000" 23 | TEXT_COLUMN="input" 24 | BATCH_SIZE="1000" 25 | N_EXAMPLES="1000000" 26 | 27 | python scripts/tokenizer_trainer.py --tokenizer_type $TOKENIZER_TYPE \ 28 | --dataset $DATASET --text_column $TEXT_COLUMN \ 29 | --vocab_size $VOCAB_SIZE --batch_size $BATCH_SIZE \ 30 | --outfile $OUTPUT --n_examples $N_EXAMPLES -------------------------------------------------------------------------------- /expts/scripts/slurm-tokenizer-train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | ## Name of your SLURM job 4 | #SBATCH --job-name=bpe-train-safe-tokenizer 5 | #SBATCH --output=/home/emmanuel/safe/expts/output/job_bpe_%x_%a.out 6 | #SBATCH --error=/home/emmanuel/safe/expts/output/job_bpe_%x_%a.out 7 | #SBATCH --open-mode=append 8 | #SBATCH --cpus-per-task=32 9 | #SBATCH --mem=80G 10 | #SBATCH --time=48:00:00 11 | 12 | set -ex 13 | # The below env variables can eventually help setting up your workload. 14 | # In a SLURM job, you CANNOT use `conda activate` and instead MUST use: 15 | source activate safe 16 | 17 | TOKENIZER_TYPE="bpe" 18 | DEFAULT_DATASET="/storage/shared_data/cristian/preprocessed_zinc_unichem/train_filtered/" 19 | DATASET="${1:-$DEFAULT_DATASET}" 20 | #DATASET="/home/emmanuel/safe/expts/notebook/tmp_data/proc_data" 21 | OUTPUT="/home/emmanuel/safe/expts/tokenizer/tokenizer.json" 22 | VOCAB_SIZE="10000" 23 | TEXT_COLUMN="input" 24 | BATCH_SIZE="1500" 25 | N_EXAMPLES="500000000" 26 | 27 | python scripts/tokenizer_trainer.py --tokenizer_type $TOKENIZER_TYPE \ 28 | --dataset $DATASET --text_column $TEXT_COLUMN \ 29 | --vocab_size $VOCAB_SIZE --batch_size $BATCH_SIZE \ 30 | --outfile $OUTPUT --n_examples $N_EXAMPLES -------------------------------------------------------------------------------- /expts/scripts/train-small.sh: -------------------------------------------------------------------------------- 1 | accelerate launch --config_file config/accelerate.yaml \ 2 | scripts/model_trainer.py --tokenizer "tokenizer/tokenizer-custom.json" \ 3 | --dataset data/ --text_column "input" \ 4 | --is_tokenized False --streaming True \ 5 | --num_labels 1 --include_descriptors False \ 6 | --gradient_accumulation_steps 2 --wandb_watch 'gradients' \ 7 | --per_device_train_batch_size 32 --num_train_epochs 5 --save_steps 2000 --save_total_limit 10 \ 8 | --eval_accumulation_steps 100 --logging_steps 200 --logging_first_step True \ 9 | --save_safetensors True --do_train True --output_dir output/test/ \ 10 | --learning_rate 5e-4 --warmup_steps 1000 --gradient_checkpointing True --max_steps 15_000 11 | -------------------------------------------------------------------------------- /expts/scripts/train.sh: -------------------------------------------------------------------------------- 1 | accelerate launch --config_file config/accelerate.yaml \ 2 | scripts/model_trainer.py --tokenizer "tokenizer/tokenizer-custom.json" \ 3 | --dataset ~/data/ --text_column "input" \ 4 | --is_tokenized False --streaming True \ 5 | --num_labels 1 --include_descriptors False \ 6 | --gradient_accumulation_steps 2 --wandb_watch 'gradients' \ 7 | --per_device_train_batch_size 64 --num_train_epochs 2 --save_steps 5000 --save_total_limit 10 \ 8 | --eval_accumulation_steps 100 --logging_steps 500 --logging_first_step True \ 9 | --save_safetensors True --do_train True --output_dir output/safe/ \ 10 | --learning_rate 5e-5 --warmup_steps 2500 --gradient_checkpointing True --max_steps 30_000_000 -------------------------------------------------------------------------------- /expts/tokenizer/_tokenizer-custom-mini-test.json: -------------------------------------------------------------------------------- 1 | {"version": "1.0", "truncation": null, "padding": null, "added_tokens": [{"id": 0, "content": "[PAD]", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false, "special": true}, {"id": 1, "content": "[CLS]", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false, "special": true}, {"id": 2, "content": "[SEP]", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false, "special": true}, {"id": 3, "content": "[MASK]", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false, "special": true}, {"id": 4, "content": "[UNK]", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false, "special": true}], "normalizer": null, "pre_tokenizer": {"type": "Whitespace"}, "post_processor": {"type": "TemplateProcessing", "single": [{"SpecialToken": {"id": "[CLS]", "type_id": 0}}, {"Sequence": {"id": "A", "type_id": 0}}, {"SpecialToken": {"id": "[SEP]", "type_id": 0}}], "pair": [{"SpecialToken": {"id": "[CLS]", "type_id": 0}}, {"Sequence": {"id": "A", "type_id": 0}}, {"SpecialToken": {"id": "[SEP]", "type_id": 0}}, {"Sequence": {"id": "B", "type_id": 1}}, {"SpecialToken": {"id": "[SEP]", "type_id": 1}}], "special_tokens": {"[CLS]": {"id": "[CLS]", "ids": [1], "tokens": ["[CLS]"]}, "[SEP]": {"id": "[SEP]", "ids": [2], "tokens": ["[SEP]"]}}}, "decoder": {"type": "BPEDecoder", "suffix": ""}, "model": {"type": "BPE", "dropout": null, "unk_token": "[UNK]", "continuing_subword_prefix": null, "end_of_word_suffix": null, "fuse_unk": false, "byte_fallback": false, "vocab": {"[UNK]": 0, "[CLS]": 1, "[SEP]": 2, "[PAD]": 3, "[MASK]": 4, "#": 5, "%": 6, "(": 7, ")": 8, "+": 9, "-": 10, ".": 11, "/": 12, "0": 13, "1": 14, "2": 15, "3": 16, "4": 17, "5": 18, "6": 19, "7": 20, "8": 21, "9": 22, "=": 23, "@": 24, "A": 25, "B": 26, "C": 27, "D": 28, "E": 29, "F": 30, "G": 31, "H": 32, "I": 33, "K": 34, "L": 35, "M": 36, "N": 37, "O": 38, "P": 39, "R": 40, "S": 41, "T": 42, "U": 43, "V": 44, "W": 45, "X": 46, "Y": 47, "Z": 48, "[": 49, "\\": 50, "]": 51, "a": 52, "b": 53, "c": 54, "d": 55, "e": 56, "f": 57, "g": 58, "h": 59, "i": 60, "l": 61, "m": 62, "n": 63, "o": 64, "p": 65, "r": 66, "s": 67, "t": 68, "u": 69, "y": 70, "H]": 71, "[C": 72, "%1": 73, "[C@": 74, "[C@@": 75, "[C@H]": 76, "[C@@H]": 77, "Cl": 78, "%10": 79, "%11": 80, "[n": 81, "+]": 82, "[N": 83, "[nH]": 84, "-]": 85, "%12": 86, "Br": 87, "[O": 88, "[O-]": 89, "[N+]": 90, "%13": 91, "%2": 92, "%14": 93, "[C@]": 94, "[NH": 95, "%15": 96, "[C@@]": 97, "%16": 98, "[S": 99, "[NH+]": 100, "[Si": 101, "%17": 102, "[N-]": 103, "[Si]": 104, "%18": 105, "%3": 106, "%19": 107, "H+]": 108, "2+]": 109, "[nH+]": 110, "[NH2+]": 111, "%20": 112, "%21": 113, "[2": 114, "[2H]": 115, "%22": 116, "%23": 117, "[n+]": 118, "%4": 119, "%24": 120, "[CH": 121, "2]": 122, "%25": 123, "[C-]": 124, "%26": 125, "%27": 126, "[B": 127, "%28": 128, "[P": 129, "[c": 130, "[Na": 131, "%5": 132, "%29": 133, "[Cl": 134, "+2]": 135, "[Na+]": 136, "%30": 137, "[C]": 138, "[Cl-]": 139, "%31": 140, "%32": 141, "3+]": 142, "[NH3+]": 143, "@]": 144, "[O+]": 145, "[CH-]": 146, "3]": 147, "%33": 148, "[c-]": 149, "[CH2]": 150, "%34": 151, "[CH]": 152, "e]": 153, "[I": 154, "%6": 155, "%35": 156, "[B+]": 157, "%36": 158, "[SiH": 159, "%37": 160, "[O]": 161, "[n-]": 162, "%38": 163, "2-]": 164, "[P+]": 165, "[B-]": 166, "[Z": 167, "[CH2-]": 168, "@@]": 169, "%39": 170, "[B]": 171, "[Si-]": 172, "[A": 173, "[H": 174, "[SiH]": 175, "+3]": 176, "%40": 177, "[S-]": 178, "%7": 179, "%41": 180, "[1": 181, "[c]": 182, "[F": 183, "[K": 184, "[Se]": 185, "[S@]": 186, "%42": 187, "[S+]": 188, "H-]": 189, "[SiH2]": 190, "[T": 191, "[Br": 192, "r]": 193, "[S@@]": 194, "[K+]": 195, "%43": 196, "[Y": 197, "[HH]": 198, "[Y]": 199, "[Br-]": 200, "%44": 201, "[N]": 202, "[cH-]": 203, "[R": 204, "[H]": 205, "[OH+]": 206, "%45": 207, "3-]": 208, "[PH]": 209, "n]": 210, "%8": 211, "[NH-]": 212, "[CH3-]": 213, "%46": 214, "[Pt": 215, "[I-]": 216, "[L": 217, "[CH+]": 218, "%47": 219, "[Zr": 220, "[Li": 221, "[C+]": 222, "[Ti": 223, "[13": 224, "%48": 225, "[SiH3]": 226, "o+]": 227, "[o+]": 228, "%49": 229, "[Na]": 230, "[M": 231, "[CH3]": 232, "[I+]": 233, "[U": 234, "%50": 235, "[Ir]": 236, "%9": 237, "[G": 238, "[Li+]": 239, "%51": 240, "[Cu": 241, "4]": 242, "n+2]": 243, "%52": 244, "[Zr+2]": 245, "+4]": 246, "[Sn]": 247, "[As": 248, "[13C": 249, "%53": 250, "%54": 251, "[Co": 252, "%55": 253, "[s": 254, "%56": 255, "[Fe": 256, "[U+2]": 257, "[Al": 258, "%57": 259, "[Pt+2]": 260, "[SH]": 261, "c]": 262, "[P@]": 263, "[Pd": 264, "%58": 265, "[Ni": 266, "[W": 267, "%59": 268, "[As]": 269, "[Cl+3]": 270, "[Ge]": 271, "[Fe+2]": 272, "[P@@]": 273, "[3": 274, "[Zn+2]": 275, "H2]": 276, "[3H]": 277, "[Ru": 278, "[Cl+]": 279, "[Pt]": 280, "[F-]": 281, "%60": 282, "[Zr]": 283, "[se]": 284, "%61": 285, "[W]": 286, "[Cu+2]": 287, "a+2]": 288, "%62": 289, "@+]": 290, "[Mg": 291, "[Co]": 292, "[Si+]": 293, "%63": 294, "[Ac]": 295, "H2+]": 296, "[CH2+]": 297, "%64": 298, "[Pd+2]": 299, "[F+]": 300, "[Ti]": 301, "%65": 302, "[P-]": 303, "[Ti+2]": 304, "-2]": 305, "[Hf": 306, "[OH2+]": 307, "[Fe]": 308, "[Mg+2]": 309, "[Sn": 310, "%66": 311, "[Ca+2]": 312, "[Te]": 313, "[Ni+2]": 314, "r+3]": 315, "b]": 316, "%67": 317, "[Re]": 318, "[Ni]": 319, "[V": 320, "[Al]": 321, "[Rh": 322, "[13C]": 323, "[NH]": 324, "%68": 325, "%69": 326, "[Zr+4]": 327, "[N@+]": 328, "%70": 329, "[Co+2]": 330, "[Ru+2]": 331, "@H]": 332, "[O-2]": 333, "[Zn]": 334, "[S]": 335, "[Mo": 336, "[U]": 337, "%71": 338, "[OH-]": 339, "%72": 340, "%73": 341, "[Li]": 342, "[Ti+4]": 343, "[Ir+3]": 344, "%74": 345, "[Cu]": 346, "[Pd]": 347, "H3]": 348, "[Sn+]": 349, "b+]": 350, "[Cu+]": 351, "[s+]": 352, "[18": 353, "[V]": 354, "a]": 355, "[Cr]": 356, "[Au": 357, "%75": 358, "%76": 359, "%77": 360, "cH]": 361, "%78": 362, "[K]": 363, "[Rb+]": 364, "[Sn+2]": 365, "[Ti+3]": 366, "[Br+]": 367, "[Se": 368, "[Al+3]": 369, "%79": 370, "[14": 371, "%80": 372, "[13CH2]": 373, "[Tl": 374, "[Al+]": 375, "g+]": 376, "[Mo]": 377, "F]": 378, "[Pb]": 379, "[15": 380, "[Hf]": 381, "[18F]": 382, "%81": 383, "[Mn+2]": 384, "[Rh]": 385, "[13cH]": 386, "@@+]": 387, "[N@@+]": 388, "%82": 389, "[Fe+3]": 390, "[H+]": 391, "f]": 392, "[H-]": 393, "%83": 394, "[Mn]": 395, "m]": 396, "[Sb": 397, "[13c]": 398, "[Ge": 399, "[14C": 400, "%84": 401, "%85": 402, "[13CH3]": 403, "@@H]": 404, "[PH+]": 405, "[Rf]": 406, "[Tl]": 407, "[15N": 408, "[Ba+2]": 409, "g]": 410, "[PH2]": 411, "%86": 412, "%87": 413, "%88": 414, "%89": 415, "%90": 416, "[Gd": 417, "[Au+]": 418, "5]": 419, "I]": 420, "[Os": 421, "%91": 422, "%92": 423, "[E": 424, "r+2]": 425, "[NH2-]": 426, "%93": 427, "%94": 428, "%95": 429, "[AsH]": 430, "+5]": 431, "[IH+]": 432, "[p": 433, "[Bi": 434, "[Ag+]": 435, "[12": 436, "[15N]": 437, "[13C@@H]": 438, "[Ru]": 439, "a+3]": 440, "[Sb]": 441, "[In]": 442, "%96": 443, "[13CH]": 444, "[W+2]": 445, "[Cr+2]": 446, "[As-]": 447, "[Hf+4]": 448, "[SeH]": 449, "[pH]": 450, "s+]": 451, "[Cs+]": 452, "[Pt+4]": 453, "Dy": 454, "[Dy": 455, "[BH": 456, "[P]": 457, "[Hg]": 458, "[11": 459, "[AsH2]": 460, "[Sb+5]": 461, "[SH+]": 462, "[Hg+]": 463, "%97": 464, "%98": 465, "%99": 466, "[13C@H]": 467, "[Co+3]": 468, "[Gd]": 469, "n+3]": 470, "[Nd": 471, "[Pr]": 472, "[11C": 473, "[Cr+3]": 474, "[Ca]": 475, "[BH-]": 476, "[1C": 477, "[Ga]": 478, "[Eu": 479, "[Bi]": 480, "[Cd": 481, "[ClH+]": 482, "[Ru+]": 483, "[Mg]": 484, "[V+2]": 485, "5I]": 486, "[cH+]": 487, "[I]": 488, "[Y+3]": 489, "[Zr+3]": 490, "[Sn+4]": 491, "[Os]": 492, "[125I]": 493, "[Dy]": 494, "[Nd+3]": 495, "[Cd+2]": 496, "[Cm]": 497, "[S-2]": 498, "[Fm]": 499, "[Ti+]": 500, "[Ga+3]": 501, "[Rh+2]": 502, "[Rh+3]": 503, "[Se-]": 504, "[t": 505, "n+]": 506, "o]": 507, "[Si@@]": 508, "[Hf+2]": 509, "[18O": 510, "[Sb+3]": 511, "[te]": 512, "99": 513, "[99": 514, "[c+]": 515, "[Ag]": 516, "[Al+2]": 517, "[Ru+3]": 518, "[Mg+]": 519, "[Sn+3]": 520, "[Au+3]": 521, "[Tl+]": 522, "[14C]": 523, "[Gd+3]": 524, "[BH3-]": 525, "[11CH3]": 526, "[1CH2]": 527, "[99T": 528, "[Si@H]": 529, "[B@": 530, "[PH": 531, "[SiH-]": 532, "[Mn+3]": 533, "[Ru+4]": 534, "b+3]": 535, "[SH2+]": 536, "[Si@]": 537, "[SiH4]": 538, "[Tc]": 539, "[Pt+]": 540, "[Pd+]": 541, "[Au]": 542, "[Ge-]": 543, "[99Tc]": 544, "e+]": 545, "[N@]": 546, "[PH-]": 547, "[Rb]": 548, "[La+3]": 549, "[15n]": 550, "[Eu]": 551, "+6": 552, "[SH2]": 553, "[PH2+]": 554, "[In+3]": 555, "[In+]": 556, "[Mo+2]": 557, "[Mo+3]": 558, "[Se+]": 559, "[14CH2]": 560, "[Os+]": 561, "[18O]": 562, "+6]": 563, "1I]": 564, "[o": 565, "b+2]": 566, "c+3]": 567, "s]": 568, "[Ce": 569, "[N@@]": 570, "[Sc]": 571, "[Po]": 572, "[As+]": 573, "[Hf+]": 574, "[GeH]": 575, "[GeH2]": 576, "[14C@@H]": 577, "[Os+2]": 578, "[Es]": 579, "[Eu+3]": 580, "[oH+]": 581, "4I]": 582, "[b": 583, "i]": 584, "[Sr+2]": 585, "[SiH2-]": 586, "[Zn+]": 587, "[1cH]": 588, "[Ta]": 589, "[131I]": 590, "[U+4]": 591, "[Rh+]": 592, "[Mo+4]": 593, "[Ge@]": 594, "[14C@H]": 595, "[124I]": 596, "[Dy+3]": 597, "[18OH]": 598, "[B@@": 599, "[PH3+]": 600, "[b-]": 601, "3I]": 602, "e+2]": 603, "g+2]": 604, "r+]": 605, "t]": 606, "[Co+]": 607, "[Cf]": 608, "[Nb]": 609, "[Sm]": 610, "[Be+2]": 611, "[Pb+]": 612, "[Pb+2]": 613, "[At]": 614, "[Ho]": 615, "[Hg+2]": 616, "[Tc": 617, "[Y+]": 618, "[La]": 619, "[Zr+]": 620, "[si]": 621, "[Fe+]": 622, "[Al-]": 623, "[Ni+3]": 624, "[14c]": 625, "[14cH]": 626, "[Sb+2]": 627, "[Ge+]": 628, "[123I]": 629, "[BH2-]": 630, "[11CH2]": 631, "[B@-]": 632, "[PH2-]": 633, "-5]": 634, "3-2]": 635, "Xe]": 636, "[Xe]": 637, "a+]": 638, "m+3]": 639, "u]": 640, "u+3]": 641, "[Cr": 642, "[SH-]": 643, "[Sc+3]": 644, "[Si@@H]": 645, "[BH]": 646, "[Ba]": 647, "[Pa]": 648, "[Cl]": 649, "[IH]": 650, "[I+2]": 651, "[Ir+]": 652, "[Ar]": 653, "[Ta": 654, "[Te+]": 655, "[Yb+3]": 656, "[Ra]": 657, "[Lu+3]": 658, "[Ga+2]": 659, "[se+]": 660, "[Ni+]": 661, "[Hf+3]": 662, "[V+3]": 663, "[Ge@@]": 664, "[15NH]": 665, "[15N+]": 666, "[Er]": 667, "[p+]": 668, "[Bi+]": 669, "[Bi+3]": 670, "[BH3-2]": 671, "[Nd]": 672, "[1CH3]": 673, "[Ce+3]": 674, "[B@@-]": 675, "+7": 676, "2P": 677, "6N": 678, "O]": 679, "S]": 680, "[7": 681, "eH]": 682, "o+3]": 683, "[N+2]": 684, "[No]": 685, "[Nb+3]": 686, "[Nb+2]": 687, "Br]": 688, "[OH]": 689, "[Sr]": 690, "[S@+]": 691, "[Sb+]": 692, "[Sm+3]": 693, "[Si+4]": 694, "[CH3+]": 695, "[B-5]": 696, "[Po+]": 697, "[Pr+3]": 698, "[PH3]": 699, "[Pb+3]": 700, "[Cl+2]": 701, "[In+2]": 702, "[Ho+3]": 703, "[17": 704, "[1H]": 705, "[16N": 706, "[Tm]": 707, "[Tb+3]": 708, "[Tc+3]": 709, "[TeH]": 710, "[Mo+]": 711, "[U+3]": 712, "[Cu+4]": 713, "[sH+]": 714, "[Fe-": 715, "[Pd+3]": 716, "[W+4]": 717, "[32P": 718, "[Ru+6]": 719, "[V+]": 720, "[Mo+6]": 721, "[SeH2]": 722, "[Se-2]": 723, "[GeH3]": 724, "[14CH]": 725, "[14C@]": 726, "[14CH3]": 727, "[15NH2]": 728, "[Er+3]": 729, "[Bi+2]": 730, "[11C]": 731, "[Eu+2]": 732, "[Ce+4]": 733, "[B@@H-]": 734, "[Tc+4]": 735, "[Ta+5]": 736, "+7]": 737, "[32P]": 738, "13": 739, "1I": 740, "3S": 741, "4-]": 742, "5S]": 743, "5Br]": 744, "67": 745, "6S]": 746, "6Br]": 747, "7L": 748, "9P": 749, "Bi]": 750, "Cu]": 751, "OH]": 752, "[67": 753, "c+]": 754, "e+5]": 755, "f+2]": 756, "h+4]": 757, "n-]": 758, "n+5]": 759, "p]": 760, "[Ce]": 761, "[Ce+]": 762, "[Cr+]": 763, "[Cf+2]": 764, "[N+3]": 765, "[Np]": 766, "[O+2]": 767, "[SH3]": 768, "[213": 769, "[Ba+]": 770, "[P@+]": 771, "[Pu]": 772, "[ClH2+]": 773, "[Ir+2]": 774, "[In-]": 775, "[Am]": 776, "[1c]": 777, "[1OH]": 778, "[F]": 779, "[Ta+3]": 780, "[Tm+3]": 781, "[Th+4]": 782, "[Br+2]": 783, "[Re+]": 784, "[Ra+]": 785, "[Re+5]": 786, "[Pt+6]": 787, "[Mn+5]": 788, "[U-]": 789, "[Ga+]": 790, "[Cu-]": 791, "[Cu+6]": 792, "[Cu-5]": 793, "[AsH2+]": 794, "[13C@]": 795, "[Ni-2]": 796, "[W+]": 797, "[W-]": 798, "[35S]": 799, "[36S]": 800, "[Ru-]": 801, "[V+5]": 802, "[149P": 803, "[153S": 804, "[Gd+2]": 805, "[Os+5]": 806, "[Os+6]": 807, "[121I]": 808, "[BH4-]": 809, "[11O]": 810, "[111I": 811, "[1C]": 812, "[1CH]": 813, "[Cd]": 814, "[99Tc+3]": 815, "[99Tc+]": 816, "[B@H-]": 817, "[PH5]": 818, "[Ce+2]": 819, "[Tc+2]": 820, "[Tc+7]": 821, "[Cr-]": 822, "[Cr+4]": 823, "[Cr+7]": 824, "[Ta+4]": 825, "[75Br]": 826, "[76Br]": 827, "[17O]": 828, "[177L": 829, "[16N]": 830, "[16N+]": 831, "[Fe-3]": 832, "[Fe-4]": 833, "[67Cu]": 834, "[213Bi]": 835, "[149Pm]": 836, "[153Sm]": 837, "[111In]": 838, "[177Lu]": 839}, "merges": ["H ]", "[ C", "% 1", "[C @", "[C@ @", "[C@ H]", "[C@@ H]", "C l", "%1 0", "%1 1", "[ n", "+ ]", "[ N", "[n H]", "- ]", "%1 2", "B r", "[ O", "[O -]", "[N +]", "%1 3", "% 2", "%1 4", "[C@ ]", "[N H", "%1 5", "[C@@ ]", "%1 6", "[ S", "[NH +]", "[S i", "%1 7", "[N -]", "[Si ]", "%1 8", "% 3", "%1 9", "H +]", "2 +]", "[n H+]", "[NH 2+]", "%2 0", "%2 1", "[ 2", "[2 H]", "%2 2", "%2 3", "[n +]", "% 4", "%2 4", "[C H", "2 ]", "%2 5", "[C -]", "%2 6", "%2 7", "[ B", "%2 8", "[ P", "[ c", "[N a", "% 5", "%2 9", "[C l", "+ 2]", "[Na +]", "%3 0", "[C ]", "[Cl -]", "%3 1", "%3 2", "3 +]", "[NH 3+]", "@ ]", "[O +]", "[CH -]", "3 ]", "%3 3", "[c -]", "[CH 2]", "%3 4", "[C H]", "e ]", "[ I", "% 6", "%3 5", "[B +]", "%3 6", "[Si H", "%3 7", "[O ]", "[n -]", "%3 8", "2 -]", "[P +]", "[B -]", "[ Z", "[CH 2-]", "@ @]", "%3 9", "[B ]", "[Si -]", "[ A", "[ H", "[Si H]", "+ 3]", "%4 0", "[S -]", "% 7", "%4 1", "[ 1", "[c ]", "[ F", "[ K", "[S e]", "[S @]", "%4 2", "[S +]", "H -]", "[SiH 2]", "[ T", "[ Br", "r ]", "[S @@]", "[K +]", "%4 3", "[ Y", "[H H]", "[Y ]", "[Br -]", "%4 4", "[N ]", "[c H-]", "[ R", "[ H]", "[O H+]", "%4 5", "3 -]", "[P H]", "n ]", "% 8", "[NH -]", "[CH 3-]", "%4 6", "[P t", "[I -]", "[ L", "[C H+]", "%4 7", "[Z r", "[L i", "[C +]", "[T i", "[1 3", "%4 8", "[SiH 3]", "o +]", "[ o+]", "%4 9", "[Na ]", "[ M", "[CH 3]", "[I +]", "[ U", "%5 0", "[I r]", "% 9", "[ G", "[Li +]", "%5 1", "[C u", "4 ]", "n +2]", "%5 2", "[Zr +2]", "+ 4]", "[S n]", "[A s", "[13 C", "%5 3", "%5 4", "[C o", "%5 5", "[ s", "%5 6", "[F e", "[U +2]", "[A l", "%5 7", "[Pt +2]", "[S H]", "c ]", "[P @]", "[P d", "%5 8", "[N i", "[ W", "%5 9", "[As ]", "[Cl +3]", "[G e]", "[Fe +2]", "[P @@]", "[ 3", "[Z n+2]", "H 2]", "[3 H]", "[R u", "[Cl +]", "[Pt ]", "[F -]", "%6 0", "[Z r]", "[s e]", "%6 1", "[W ]", "[Cu +2]", "a +2]", "%6 2", "@ +]", "[M g", "[Co ]", "[Si +]", "%6 3", "[A c]", "H 2+]", "[CH 2+]", "%6 4", "[Pd +2]", "[F +]", "[Ti ]", "%6 5", "[P -]", "[Ti +2]", "- 2]", "[H f", "[O H2+]", "[F e]", "[Mg +2]", "[S n", "%6 6", "[C a+2]", "[T e]", "[Ni +2]", "r +3]", "b ]", "%6 7", "[R e]", "[Ni ]", "[ V", "[Al ]", "[R h", "[13C ]", "[N H]", "%6 8", "%6 9", "[Zr +4]", "[N @+]", "%7 0", "[Co +2]", "[Ru +2]", "@ H]", "[O -2]", "[Z n]", "[S ]", "[M o", "[U ]", "%7 1", "[O H-]", "%7 2", "%7 3", "[Li ]", "[Ti +4]", "[I r+3]", "%7 4", "[Cu ]", "[Pd ]", "H 3]", "[Sn +]", "b +]", "[Cu +]", "[s +]", "[1 8", "[V ]", "a ]", "[C r]", "[A u", "%7 5", "%7 6", "%7 7", "c H]", "%7 8", "[K ]", "[R b+]", "[S n+2]", "[Ti +3]", "[Br +]", "[S e", "[Al +3]", "%7 9", "[1 4", "%8 0", "[13C H2]", "[T l", "[Al +]", "g +]", "[Mo ]", "F ]", "[P b]", "[1 5", "[Hf ]", "[18 F]", "%8 1", "[M n+2]", "[Rh ]", "[13 cH]", "@ @+]", "[N @@+]", "%8 2", "[Fe +3]", "[ H+]", "f ]", "[H -]", "%8 3", "[M n]", "m ]", "[S b", "[13 c]", "[G e", "[14 C", "%8 4", "%8 5", "[13C H3]", "@ @H]", "[P H+]", "[R f]", "[Tl ]", "[15 N", "[B a+2]", "g ]", "[P H2]", "%8 6", "%8 7", "%8 8", "%8 9", "%9 0", "[G d", "[Au +]", "5 ]", "I ]", "[O s", "%9 1", "%9 2", "[ E", "r +2]", "[NH 2-]", "%9 3", "%9 4", "%9 5", "[As H]", "+ 5]", "[I H+]", "[ p", "[B i", "[A g+]", "[1 2", "[15N ]", "[13C @@H]", "[Ru ]", "a +3]", "[S b]", "[I n]", "%9 6", "[13C H]", "[W +2]", "[C r+2]", "[As -]", "[Hf +4]", "[Se H]", "[p H]", "s +]", "[C s+]", "[Pt +4]", "D y", "[ Dy", "[B H", "[P ]", "[H g]", "[1 1", "[As H2]", "[Sb +5]", "[S H+]", "[H g+]", "%9 7", "%9 8", "%9 9", "[13C @H]", "[Co +3]", "[Gd ]", "n +3]", "[N d", "[P r]", "[11 C", "[C r+3]", "[C a]", "[B H-]", "[1 C", "[G a]", "[E u", "[Bi ]", "[C d", "[Cl H+]", "[Ru +]", "[Mg ]", "[V +2]", "5 I]", "[c H+]", "[I ]", "[Y +3]", "[Zr +3]", "[Sn +4]", "[Os ]", "[12 5I]", "[Dy ]", "[Nd +3]", "[Cd +2]", "[C m]", "[S -2]", "[F m]", "[Ti +]", "[G a+3]", "[Rh +2]", "[Rh +3]", "[Se -]", "[ t", "n +]", "o ]", "[Si @@]", "[Hf +2]", "[18 O", "[Sb +3]", "[t e]", "9 9", "[ 99", "[c +]", "[A g]", "[Al +2]", "[Ru +3]", "[Mg +]", "[Sn +3]", "[Au +3]", "[Tl +]", "[14C ]", "[Gd +3]", "[BH 3-]", "[11C H3]", "[1C H2]", "[99 T", "[Si @H]", "[B @", "[P H", "[SiH -]", "[M n+3]", "[Ru +4]", "b +3]", "[S H2+]", "[Si @]", "[SiH 4]", "[T c]", "[Pt +]", "[Pd +]", "[Au ]", "[Ge -]", "[99T c]", "e +]", "[N @]", "[P H-]", "[R b]", "[L a+3]", "[15 n]", "[Eu ]", "+ 6", "[S H2]", "[P H2+]", "[I n+3]", "[I n+]", "[Mo +2]", "[Mo +3]", "[Se +]", "[14C H2]", "[Os +]", "[18O ]", "+6 ]", "1 I]", "[ o", "b +2]", "c +3]", "s ]", "[C e", "[N @@]", "[S c]", "[P o]", "[As +]", "[Hf +]", "[Ge H]", "[Ge H2]", "[14C @@H]", "[Os +2]", "[E s]", "[Eu +3]", "[o H+]", "4 I]", "[ b", "i ]", "[S r+2]", "[SiH 2-]", "[Z n+]", "[1 cH]", "[T a]", "[13 1I]", "[U +4]", "[Rh +]", "[Mo +4]", "[Ge @]", "[14C @H]", "[12 4I]", "[Dy +3]", "[18O H]", "[B@ @", "[PH 3+]", "[b -]", "3 I]", "e +2]", "g +2]", "r +]", "t ]", "[C o+]", "[C f]", "[N b]", "[S m]", "[B e+2]", "[P b+]", "[P b+2]", "[A t]", "[H o]", "[H g+2]", "[T c", "[Y +]", "[L a]", "[Zr +]", "[s i]", "[Fe +]", "[Al -]", "[Ni +3]", "[14 c]", "[14 cH]", "[Sb +2]", "[Ge +]", "[12 3I]", "[BH 2-]", "[11C H2]", "[B@ -]", "[PH 2-]", "- 5]", "3 -2]", "X e]", "[ Xe]", "a +]", "m +3]", "u ]", "u +3]", "[C r", "[S H-]", "[S c+3]", "[Si @@H]", "[B H]", "[B a]", "[P a]", "[Cl ]", "[I H]", "[I +2]", "[I r+]", "[A r]", "[T a", "[T e+]", "[Y b+3]", "[R a]", "[L u+3]", "[G a+2]", "[s e+]", "[Ni +]", "[Hf +3]", "[V +3]", "[Ge @@]", "[15N H]", "[15N +]", "[E r]", "[p +]", "[Bi +]", "[Bi +3]", "[BH 3-2]", "[Nd ]", "[1C H3]", "[Ce +3]", "[B@@ -]", "+ 7", "2 P", "6 N", "O ]", "S ]", "[ 7", "e H]", "o +3]", "[N +2]", "[N o]", "[N b+3]", "[N b+2]", "Br ]", "[O H]", "[S r]", "[S @+]", "[S b+]", "[S m+3]", "[Si +4]", "[CH 3+]", "[B -5]", "[P o+]", "[P r+3]", "[P H3]", "[P b+3]", "[Cl +2]", "[I n+2]", "[H o+3]", "[1 7", "[1 H]", "[1 6N", "[T m]", "[T b+3]", "[T c+3]", "[T eH]", "[M o+]", "[U +3]", "[Cu +4]", "[s H+]", "[Fe -", "[Pd +3]", "[W +4]", "[3 2P", "[Ru +6]", "[V +]", "[Mo +6]", "[Se H2]", "[Se -2]", "[Ge H3]", "[14C H]", "[14C @]", "[14C H3]", "[15N H2]", "[E r+3]", "[Bi +2]", "[11C ]", "[Eu +2]", "[Ce +4]", "[B@@ H-]", "[Tc +4]", "[Ta +5]", "+7 ]", "[32P ]", "1 3", "1 I", "3 S", "4 -]", "5 S]", "5 Br]", "6 7", "6 S]", "6 Br]", "7 L", "9 P", "B i]", "C u]", "O H]", "[ 67", "c +]", "e +5]", "f +2]", "h +4]", "n -]", "n +5]", "p ]", "[C e]", "[C e+]", "[C r+]", "[C f+2]", "[N +3]", "[N p]", "[O +2]", "[S H3]", "[2 13", "[B a+]", "[P @+]", "[P u]", "[Cl H2+]", "[I r+2]", "[I n-]", "[A m]", "[1 c]", "[1 OH]", "[F ]", "[T a+3]", "[T m+3]", "[T h+4]", "[Br +2]", "[R e+]", "[R a+]", "[R e+5]", "[Pt +6]", "[M n+5]", "[U -]", "[G a+]", "[Cu -]", "[Cu +6]", "[Cu -5]", "[As H2+]", "[13C @]", "[Ni -2]", "[W +]", "[W -]", "[3 5S]", "[3 6S]", "[Ru -]", "[V +5]", "[14 9P", "[15 3S", "[Gd +2]", "[Os +5]", "[Os +6]", "[12 1I]", "[BH 4-]", "[11 O]", "[11 1I", "[1C ]", "[1C H]", "[Cd ]", "[99T c+3]", "[99T c+]", "[B@ H-]", "[PH 5]", "[Ce +2]", "[Tc +2]", "[Tc +7]", "[Cr -]", "[Cr +4]", "[Cr +7]", "[Ta +4]", "[7 5Br]", "[7 6Br]", "[17 O]", "[17 7L", "[16N ]", "[16N +]", "[Fe- 3]", "[Fe- 4]", "[67 Cu]", "[213 Bi]", "[149P m]", "[153S m]", "[111I n]", "[177L u]"]}, "custom_pre_tokenizer": true, "tokenizer_type": "bpe", "tokenizer_attrs": {"pad_token": "[PAD]", "cls_token": "[CLS]", "sep_token": "[SEP]", "mask_token": "[MASK]", "unk_token": "[UNK]", "eos_token": "[SEP]", "bos_token": "[CLS]"}} -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: "SAFE" 2 | site_description: "Gotta be SAFE: a new framework for molecular design" 3 | site_url: "https://github.com/datamol-io/safe" 4 | repo_url: "https://github.com/datamol-io/safe" 5 | repo_name: "datamol-io/safe" 6 | copyright: Copyright 2023 Valence Labs 7 | 8 | remote_branch: "gh-pages" 9 | docs_dir: "docs" 10 | use_directory_urls: false 11 | strict: true 12 | 13 | nav: 14 | - Overview: index.md 15 | - Tutorials: 16 | - Getting Started: tutorials/getting-started.ipynb 17 | - Molecular design: tutorials/design-with-safe.ipynb 18 | - How it works: tutorials/how-it-works.ipynb 19 | - Extracting representation (molfeat): tutorials/extracting-representation-molfeat.ipynb 20 | - API: 21 | - SAFE: api/safe.md 22 | - Visualization: api/safe.viz.md 23 | - Model training: api/safe.models.md 24 | - CLI: cli.md 25 | - License: license.md 26 | - Data License: data_license.md 27 | 28 | theme: 29 | name: material 30 | features: 31 | - navigation.expand 32 | 33 | extra_javascript: 34 | - assets/js/google-analytics.js 35 | 36 | markdown_extensions: 37 | - admonition 38 | - markdown_include.include 39 | - pymdownx.emoji 40 | - pymdownx.highlight 41 | - pymdownx.magiclink 42 | - pymdownx.superfences 43 | - pymdownx.tabbed 44 | - pymdownx.tasklist 45 | - pymdownx.details 46 | # For `tab_length=2` in the markdown extension 47 | # See https://github.com/mkdocs/mkdocs/issues/545 48 | - mdx_truly_sane_lists 49 | - toc: 50 | permalink: true 51 | 52 | watch: 53 | - safe/ 54 | 55 | plugins: 56 | - search 57 | - mkdocstrings: 58 | handlers: 59 | python: 60 | import: 61 | - https://docs.python.org/3/objects.inv 62 | setup_commands: 63 | - import sys 64 | - import safe 65 | - sys.path.append("docs") 66 | - sys.path.append("safe") 67 | selection: 68 | new_path_syntax: true 69 | rendering: 70 | show_root_heading: false 71 | heading_level: 2 72 | show_if_no_docstring: true 73 | options: 74 | docstring_options: 75 | ignore_init_summary: false 76 | docstring_section_style: list 77 | merge_init_into_class: true 78 | show_root_heading: false 79 | show_root_full_path: false 80 | show_signature_annotations: true 81 | show_symbol_type_heading: true 82 | show_symbol_type_toc: true 83 | signature_crossrefs: true 84 | 85 | - mkdocs-jupyter: 86 | execute: False 87 | remove_tag_config: 88 | remove_cell_tags: [remove_cell] 89 | remove_all_outputs_tags: [remove_output] 90 | remove_input_tags: [remove_input] 91 | 92 | - mike: 93 | version_selector: true 94 | 95 | extra: 96 | version: 97 | provider: mike 98 | 99 | social: 100 | - icon: fontawesome/brands/github 101 | link: https://github.com/datamol-io 102 | - icon: fontawesome/brands/twitter 103 | link: https://twitter.com/datamol_io 104 | - icon: fontawesome/brands/python 105 | link: https://pypi.org/project/safe-mol/ 106 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "setuptools-scm"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "safe-mol" 7 | description = "Implementation of the 'Gotta be SAFE: a new framework for molecular design' paper" 8 | dynamic = ["version"] 9 | authors = [{ name = "Emmanuel Noutahi", email = "emmanuel.noutahi@gmail.com" }] 10 | readme = "README.md" 11 | license = { text = "Apache-2.0" } 12 | requires-python = ">=3.9" 13 | classifiers = [ 14 | "Development Status :: 5 - Production/Stable", 15 | "Intended Audience :: Developers", 16 | "Intended Audience :: Healthcare Industry", 17 | "Intended Audience :: Science/Research", 18 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 19 | "Topic :: Scientific/Engineering :: Bio-Informatics", 20 | "Topic :: Scientific/Engineering :: Information Analysis", 21 | "Topic :: Scientific/Engineering :: Medical Science Apps.", 22 | "Natural Language :: English", 23 | "Operating System :: OS Independent", 24 | "Programming Language :: Python", 25 | "Programming Language :: Python :: 3", 26 | "Programming Language :: Python :: 3.9", 27 | "Programming Language :: Python :: 3.9", 28 | "Programming Language :: Python :: 3.10", 29 | "Programming Language :: Python :: 3.11", 30 | ] 31 | 32 | keywords = ["safe", "smiles", "de novo", "design", "molecules"] 33 | dependencies = [ 34 | "tqdm", 35 | "loguru", 36 | "typer", 37 | "universal_pathlib", 38 | "datamol", 39 | "numpy", 40 | "torch>=2.0", 41 | "transformers", 42 | "datasets", 43 | "tokenizers", 44 | "accelerate", 45 | "evaluate", 46 | "wandb", 47 | "huggingface-hub", 48 | "rdkit" 49 | ] 50 | 51 | [project.urls] 52 | "Source Code" = "https://github.com/datamol-io/safe" 53 | "Bug Tracker" = "https://github.com/datamol-io/safe/issues" 54 | Documentation = "https://safe-docs.datamol.io/" 55 | 56 | 57 | [project.scripts] 58 | safe-train = "safe.trainer.cli:main" 59 | 60 | [tool.setuptools] 61 | include-package-data = true 62 | zip-safe = false 63 | license-files = ["LICENSE"] 64 | 65 | [tool.setuptools_scm] 66 | fallback_version = "dev" 67 | 68 | [tool.setuptools.packages.find] 69 | where = ["."] 70 | include = ["safe", "safe.*"] 71 | exclude = [] 72 | namespaces = true 73 | 74 | [tool.black] 75 | line-length = 100 76 | target-version = ['py310', 'py311'] 77 | include = '\.pyi?$' 78 | 79 | [tool.pytest.ini_options] 80 | minversion = "6.0" 81 | addopts = "--verbose --color yes" 82 | testpaths = ["tests"] 83 | 84 | 85 | [tool.ruff] 86 | line-length = 120 87 | # Enable Pyflakes `E` and `F` codes by default. 88 | lint.select = [ 89 | "E", 90 | "W", # see: https://pypi.org/project/pycodestyle 91 | "F", # see: https://pypi.org/project/pyflakes 92 | ] 93 | lint.extend-select = [ 94 | "C4", # see: https://pypi.org/project/flake8-comprehensions 95 | "SIM", # see: https://pypi.org/project/flake8-simplify 96 | "RET", # see: https://pypi.org/project/flake8-return 97 | "PT", # see: https://pypi.org/project/flake8-pytest-style 98 | ] 99 | lint.ignore = [ 100 | "E731", # Do not assign a lambda expression, use a def 101 | "S108", 102 | "F401", 103 | "S105", 104 | "E501", 105 | "E722", 106 | ] 107 | # Exclude a variety of commonly ignored directories. 108 | exclude = [".git", "docs", "_notebooks"] 109 | lint.ignore-init-module-imports = true 110 | -------------------------------------------------------------------------------- /safe/__init__.py: -------------------------------------------------------------------------------- 1 | from . import trainer, utils 2 | from ._exception import SAFEDecodeError, SAFEEncodeError, SAFEFragmentationError 3 | from .converter import SAFEConverter, decode, encode 4 | from .sample import SAFEDesign 5 | from .tokenizer import SAFETokenizer, split 6 | from .viz import to_image 7 | from .io import upload_to_wandb 8 | -------------------------------------------------------------------------------- /safe/_exception.py: -------------------------------------------------------------------------------- 1 | class SAFEDecodeError(Exception): 2 | """Raised when a string cannot be decoded with the given encoding.""" 3 | 4 | pass 5 | 6 | 7 | class SAFEEncodeError(Exception): 8 | """Raised when a molecule cannot be encoded using SAFE.""" 9 | 10 | pass 11 | 12 | 13 | class SAFEFragmentationError(Exception): 14 | """Raised when a the slicing algorithm return empty bonds.""" 15 | 16 | pass 17 | -------------------------------------------------------------------------------- /safe/_pattern.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | from typing import List 3 | from typing import Optional 4 | import re 5 | import datamol as dm 6 | import torch 7 | import numpy as np 8 | import safe as sf 9 | import torch.nn.functional as F 10 | import transformers 11 | 12 | from loguru import logger 13 | from tqdm.auto import tqdm 14 | import contextlib 15 | 16 | 17 | class PatternConstraint: 18 | """ 19 | Sampling decorator for pretrained SAFE models. 20 | This implementation is inspired by the Sanofi decorator with: 21 | 1. new generalization to different tokenizers 22 | 2. support for a subset of smarts notations 23 | 3. speed improvements by dropping unnecessary steps 24 | 25 | !!! note 26 | For smarts based constraints, it's important to understand that the constraints 27 | are strong sampling suggestions and not necessarily the final result, meaning that they can 28 | fail. 29 | 30 | !!! warning 31 | Ring constraints should be interpreted as "Extended Ring System" constraints. 32 | Thus [r6] means an atom in an environment of 6 atoms and more that contains a ring system, 33 | instead of an atom in a ring of size 6. 34 | """ 35 | 36 | ATTACHMENT_POINT_TOKEN = "\\*" 37 | ATTACHMENT_POINTS = [ 38 | # parse any * not preceeded by "[" or ":" and not followed by "]" or ":" as attachment 39 | r"(? 0: 127 | random_mol = dm.randomize_atoms(random_mol) 128 | try: 129 | out.add(dm.to_smarts(random_mol)) 130 | except Exception as e: 131 | logger.error(e) 132 | n -= 1 133 | return out 134 | 135 | def _find_ring_tokens(self): 136 | """Find all possible ring tokens in the vocab.""" 137 | ring_token_ids = [] 138 | for tk, tk_ids in self.tokenizer.tokenizer.get_vocab().items(): 139 | try: 140 | _ = int(tk.lstrip("%")) 141 | ring_token_ids.append(tk_ids) 142 | except ValueError: 143 | pass 144 | return ring_token_ids 145 | 146 | def _prepare_scaffold(self): 147 | """Prepare scaffold for decoration.""" 148 | return self.input_scaffold 149 | 150 | def is_branch_closer(self, pos_or_token: Union[int, str]): 151 | """Check whether a token is a branch closer.""" 152 | if isinstance(pos_or_token, int): 153 | return self.tokens[pos_or_token] == self.branch_closer 154 | return pos_or_token == self.branch_closer 155 | 156 | def is_branch_opener(self, pos_or_token: Union[int, str]): 157 | """Check whether a token is a branch opener.""" 158 | if isinstance(pos_or_token, int): 159 | return self.tokens[pos_or_token] == self.branch_opener 160 | return pos_or_token == self.branch_opener 161 | 162 | def __len__(self): 163 | """Get length of the tokenized scaffold.""" 164 | if not self._is_initialized: 165 | raise ValueError("Decorator is not initialized yet") 166 | return len(self.tokens) 167 | 168 | def _initialize(self): 169 | """ 170 | Initialize the current scaffold decorator object with the scaffold object. 171 | The initialization will also set and validate the vocab object to use for the scaffold decoration, 172 | """ 173 | self._is_initialized = False 174 | pretrained_tokenizer = self.tokenizer.get_pretrained() 175 | encoding_tokens = [ 176 | token 177 | for token, _ in pretrained_tokenizer._tokenizer.pre_tokenizer.pre_tokenize_str( 178 | self.scaffold 179 | ) 180 | ] + [pretrained_tokenizer.eos_token] 181 | encoding_token_ids = [ 182 | pretrained_tokenizer.convert_tokens_to_ids(x) for x in encoding_tokens 183 | ] 184 | # encodings.tokens contains BOS and EOS 185 | 186 | self.branch_opener_id = self.tokenizer.tokenizer.token_to_id(self.branch_opener) 187 | self.branch_closer_id = self.tokenizer.tokenizer.token_to_id(self.branch_closer) 188 | linker_size = {} 189 | # convert the full vocab into mol constraints 190 | vocab_as_constraints = [ 191 | self._parse_token_as_mol(self.tokenizer.tokenizer.id_to_token(i)) 192 | for i in range(len(self.tokenizer)) 193 | ] 194 | token_masks = {} 195 | actions = {} 196 | tokens = [] 197 | ids = [] 198 | all_tokens = self.tokenizer.tokenizer.get_vocab().keys() 199 | unk_token_id = pretrained_tokenizer.unk_token_id 200 | unknown_token_map = {} 201 | # we include the stop token 202 | for pos in range(len(encoding_token_ids)): 203 | token_id = encoding_token_ids[pos] 204 | token = encoding_tokens[pos] 205 | 206 | # if it is not an unknown token, then it can just be rollout 207 | # note that we are not using all special tokens, just unknown 208 | # and as such, you would need to have a well defined vocab 209 | cur_mask = torch.ones(len(vocab_as_constraints), dtype=torch.bool) 210 | constraints = None 211 | ring_token = None 212 | if token_id == unk_token_id: 213 | # this here can be one of the case we want: 214 | # we need to check if the token is a ring and whether it has other constraints. 215 | ring_match = re.match(self.HAS_RING_TOKEN, token) 216 | if ring_match and token.count("r") != 1: 217 | raise ValueError("Multiple ring constraints in a single token is not supported") 218 | if ring_match: 219 | ring_size = ring_match.group(2).strip("r:") 220 | ring_size = int(ring_size) if ring_size else self.min_ring_size 221 | linker_size[pos - 1] = ring_size 222 | # since we have filled the ring constraints already, we need to remove it from the token format 223 | ring_token = self._remove_ring_constraint(token) 224 | if self._is_attachment(token) or ( 225 | ring_token is not None and self._is_attachment(ring_token) 226 | ): 227 | # the mask would be handled by the attachment algorithm in decorate 228 | actions[token] = "attach" 229 | unknown_token_map[pos] = sf.utils.standardize_attach(token) 230 | elif self._is_constraint(token, all_tokens): 231 | actions[token] = "constraint" 232 | constraints = self._parse_token_as_mol(token) 233 | if constraints is not None: 234 | cur_mask = self._mask_tokens(constraints, vocab_as_constraints) 235 | else: 236 | # this means that we need to sample the exact token 237 | # and disallow all the other 238 | cur_mask = torch.zeros(len(vocab_as_constraints), dtype=torch.bool) 239 | cur_mask[token_id] = 1 240 | token_masks[token] = cur_mask 241 | tokens.append(token) 242 | ids.append(token_id) 243 | self.linker_size = linker_size 244 | self.token_masks = token_masks 245 | self.actions = actions 246 | self.tokens = tokens 247 | self.ids = ids 248 | self._is_initialized = True 249 | self.unknown_token_map = unknown_token_map 250 | self.ring_token_ids = self._find_ring_tokens() 251 | 252 | def _is_attachment(self, token: str): 253 | """Check whether a token is should be an attachment or not""" 254 | # What I define as attachment is a token that is not a constraint and not a ring 255 | # basically the classic "[*]" written however you like it 256 | return any(re.match(attach_regex, token) for attach_regex in self.ATTACHMENT_POINTS) 257 | 258 | def _is_constraint(self, token: str, vocab_list: Optional[List[str]] = None): 259 | """Check whether a token is a constraint 260 | Args: 261 | token: token to check whether 262 | vocab: optional vocab to check against 263 | """ 264 | if vocab_list is None: 265 | vocab_list = [] 266 | tk_constraints = re.match(self.IS_CONSTRAINT, token) or token not in vocab_list 267 | return tk_constraints is not None 268 | 269 | def _remove_ring_constraint(self, token): 270 | """Remove ring constraints from a token""" 271 | token = re.sub(r"((&|,|;)?r\d*)*(&|,|;)?", r"\3", token) 272 | token = re.sub(r"(\[[&;,]?)(.*)([&,;]?\])", r"[\2]", token) 273 | if token == "[]": 274 | return "[*]" 275 | return token 276 | 277 | def _parse_token_as_mol(self, token): 278 | """ 279 | Parse a token as a valid molecular pattern 280 | """ 281 | tk_mol = None 282 | with dm.without_rdkit_log(): 283 | try: 284 | tk_mol = dm.from_smarts(token) 285 | except: 286 | tk_mol = dm.to_mol(token, sanitize=True) 287 | # didn't work, try with second strategy 288 | if tk_mol is None: 289 | tk_mol = dm.to_mol(token, sanitize=False) 290 | try: 291 | tk_mol = dm.from_smarts(dm.smiles_as_smarts(tk_mol)) 292 | except: 293 | tk_mol = None 294 | return tk_mol 295 | 296 | def _mask_tokens(self, constraint: List[dm.Mol], vocab_mols: List[dm.Mol]): 297 | """Mask the prediction to enforce some constraints 298 | 299 | Args: 300 | constraint: constraint found in the scaffold 301 | vocab_mols: list of mol queries (convertible ones) from the vocab 302 | 303 | Returns: 304 | mask: mask for valid tokens that match constraints. 1 means keep, 0 means mask 305 | """ 306 | mask = torch.zeros(len(vocab_mols), dtype=torch.bool) 307 | with dm.without_rdkit_log(): 308 | for ind, tk_mol in enumerate(vocab_mols): 309 | if tk_mol is not None: 310 | with contextlib.suppress(Exception): 311 | mask[ind] = int(tk_mol.HasSubstructMatch(constraint)) 312 | return mask 313 | 314 | 315 | class PatternSampler: 316 | """ 317 | Implements a pattern-constrained sequence sampler for Autoregressive transformer models using a PatternConstraint. 318 | 319 | Args: 320 | model: Pretrained model used for generation. 321 | pattern_decorator: The PatternConstraint object that provides the scaffold and sampling constraints. 322 | min_linker_size: Minimum size of the linker. 323 | max_steps_to_eos: Maximum steps to end-of-sequence token. 324 | max_length: Maximum length of the generated sequences. 325 | """ 326 | 327 | def __init__( 328 | self, 329 | model, 330 | pattern_decorator, 331 | min_linker_size: int = 3, 332 | max_steps_to_eos: int = 50, 333 | max_length: int = 128, 334 | ): 335 | self.model = model 336 | self.pattern_decorator = pattern_decorator 337 | self.tokenizer = self.pattern_decorator.tokenizer.get_pretrained() 338 | self.min_linker_size = min_linker_size 339 | self.max_steps_to_eos = max_steps_to_eos 340 | self.model.eval() 341 | self.max_length = max_length 342 | 343 | def nll_loss(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: 344 | """ 345 | Custom Negative Log Likelihood (NLL) loss that returns loss per example. 346 | 347 | Args: 348 | inputs (torch.Tensor): Log probabilities of each class, shape (batch_size, num_classes). 349 | targets (torch.Tensor): Target class indices, shape (batch_size). 350 | 351 | Returns: 352 | torch.Tensor: Loss for each example, shape (batch_size). 353 | """ 354 | target_expanded = ( 355 | torch.zeros(inputs.size()).cuda() 356 | if torch.cuda.is_available() 357 | else torch.zeros(inputs.size()) 358 | ) 359 | target_expanded.scatter_(1, targets.contiguous().view(-1, 1).data, 1.0) 360 | loss = target_expanded * inputs 361 | return torch.sum(loss, dim=1) 362 | 363 | def sample_scaffolds( 364 | self, n_samples: int = 100, n_trials: int = 1, random_seed: Optional[int] = None 365 | ): 366 | """ 367 | Sample a batch of sequences based on the scaffold provided by the PatternConstraint. 368 | 369 | Args: 370 | n_samples: Number of sequences to sample. 371 | n_trials: Number of sampling trials to perform. 372 | random_seed: Seed for random number generation. 373 | 374 | Returns: 375 | List[str]: List of sampled sequences as strings. 376 | """ 377 | if random_seed is not None: 378 | torch.manual_seed(random_seed) 379 | 380 | sampled_mols = [] 381 | for _ in tqdm(range(n_trials), leave=False): 382 | generated_sequences, *_ = self._generate(n_samples) 383 | sampled_mols.extend( 384 | [ 385 | self._as_scaffold(self.tokenizer.decode(seq, skip_special_tokens=False)) 386 | for seq in generated_sequences 387 | ] 388 | ) 389 | return sampled_mols 390 | 391 | def _as_scaffold(self, scaff: str) -> str: 392 | """ 393 | Converts the generated sequence to a valid scaffold by replacing unknown tokens. 394 | 395 | Args: 396 | scaff: The generated sequence string. 397 | 398 | Returns: 399 | str: The scaffold string with unknown tokens replaced. 400 | """ 401 | out = scaff.replace(self.tokenizer.eos_token, "") 402 | splitted_out = [ 403 | token for token, _ in self.tokenizer._tokenizer.pre_tokenizer.pre_tokenize_str(out) 404 | ] 405 | for pos, query in self.pattern_decorator.unknown_token_map.items(): 406 | splitted_out[pos] = query 407 | return "".join(splitted_out) 408 | 409 | def _generate(self, batch_size: int, max_length: Optional[int] = None): 410 | """ 411 | Generate sequences with custom constraints using the model and PatternConstraint. 412 | 413 | Args: 414 | batch_size: Number of sequences to generate. 415 | max_length: Maximum length of the sequence. 416 | 417 | Returns: 418 | Tuple: Generated sequences, log probabilities, and entropies. 419 | """ 420 | sequences = [] 421 | if max_length is None: 422 | max_length = self.max_length 423 | 424 | start_token = torch.full((batch_size, 1), self.tokenizer.bos_token_id, dtype=torch.long) 425 | finished = torch.zeros(batch_size, dtype=bool) 426 | log_probs = torch.zeros(batch_size) 427 | entropy = torch.zeros(batch_size) 428 | 429 | if torch.cuda.is_available(): 430 | log_probs = log_probs.cuda() 431 | entropy = entropy.cuda() 432 | start_token = start_token.cuda() 433 | finished = finished.cuda() 434 | 435 | max_dec_steps = max_length - len(self.pattern_decorator) 436 | if max_dec_steps < 0: 437 | raise ValueError("Step size negative due to scaffold being longer than max_length") 438 | 439 | input_ids = start_token 440 | trackers = torch.zeros(batch_size, dtype=torch.int) # Tracks the position in the scaffold 441 | 442 | for step in range(max_length): 443 | current_tokens = [self.pattern_decorator.tokens[index] for index in trackers] 444 | action_i = [ 445 | self.pattern_decorator.actions.setdefault( 446 | self.pattern_decorator.tokens[index], "roll" 447 | ) 448 | for index in trackers 449 | ] 450 | 451 | # Pass through model 452 | outputs = self.model(input_ids) 453 | logits = outputs.logits[:, -1, :] 454 | log_prob = torch.log_softmax(logits, dim=-1) 455 | probs = log_prob.exp() 456 | 457 | decoder_input = torch.multinomial(probs, num_samples=1).squeeze(1).view(-1) 458 | 459 | for i in range(batch_size): 460 | if action_i[i] == "constraint": 461 | mask = self.pattern_decorator.token_masks[current_tokens[i]].to( 462 | input_ids.device 463 | ) 464 | prob_i = self.pattern_decorator._logprobs_to_probs(log_prob[i, :], mask) 465 | if prob_i.sum() == 0: 466 | random_choice = torch.nonzero(mask.squeeze()).squeeze().cpu().numpy() 467 | decoder_input[i] = np.random.choice(random_choice) 468 | else: 469 | decoder_input[i] = torch.multinomial(prob_i, num_samples=1).view(-1) 470 | trackers[i] += int(decoder_input[i] != self.tokenizer.eos_token_id) 471 | else: 472 | decoder_input[i] = self.pattern_decorator.ids[trackers[i].item()] 473 | trackers[i] += int(decoder_input[i] != self.tokenizer.eos_token_id) 474 | 475 | sequences.append(decoder_input.unsqueeze(-1)) 476 | input_ids = torch.cat((input_ids, decoder_input.unsqueeze(-1)), dim=1) 477 | log_probs += self.nll_loss(log_prob, decoder_input) 478 | entropy += -torch.sum(log_prob * probs, dim=-1) 479 | 480 | eos_sampled = (decoder_input == self.tokenizer.eos_token_id).bool() 481 | finished = torch.ge(finished + eos_sampled, 1) 482 | if torch.prod(finished) == 1: 483 | break 484 | sequences = torch.cat(sequences, dim=1) 485 | return sequences, log_probs, entropy 486 | -------------------------------------------------------------------------------- /safe/converter.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import re 3 | from collections import Counter 4 | from contextlib import suppress 5 | from typing import Callable, List, Optional, Union 6 | 7 | import datamol as dm 8 | import numpy as np 9 | from rdkit import Chem 10 | from rdkit.Chem import BRICS 11 | from loguru import logger 12 | 13 | from ._exception import SAFEDecodeError, SAFEEncodeError, SAFEFragmentationError 14 | from .utils import standardize_attach 15 | 16 | 17 | class SAFEConverter: 18 | """Molecule line notation conversion from SMILES to SAFE 19 | 20 | A SAFE representation is a string based representation of a molecule decomposition into fragment components, 21 | separated by a dot ('.'). Note that each component (fragment) might not be a valid molecule by themselves, 22 | unless explicitely correct to add missing hydrogens. 23 | 24 | !!! note "Slicing algorithms" 25 | 26 | By default SAFE strings are generated using `BRICS`, however, the following alternative are supported: 27 | 28 | * [Hussain-Rea (`hr`)](https://pubs.acs.org/doi/10.1021/ci900450m) 29 | * [RECAP (`recap`)](https://pubmed.ncbi.nlm.nih.gov/9611787/) 30 | * [RDKit's MMPA (`mmpa`)](https://www.rdkit.org/docs/source/rdkit.Chem.rdMMPA.html) 31 | * Any possible attachment points (`attach`) 32 | 33 | Furthermore, you can also provide your own slicing algorithm, which should return a pair of atoms 34 | corresponding to the bonds to break. 35 | 36 | """ 37 | 38 | SUPPORTED_SLICERS = ["hr", "rotatable", "recap", "mmpa", "attach", "brics"] 39 | __SLICE_SMARTS = { 40 | "hr": ["[*]!@-[*]"], # any non ring single bond 41 | "recap": [ 42 | "[$([C;!$(C([#7])[#7])](=!@[O]))]!@[$([#7;+0;!D1])]", 43 | "[$(C=!@O)]!@[$([O;+0])]", 44 | "[$([N;!D1;+0;!$(N-C=[#7,#8,#15,#16])](-!@[*]))]-!@[$([*])]", 45 | "[$(C(=!@O)([#7;+0;D2,D3])!@[#7;+0;D2,D3])]!@[$([#7;+0;D2,D3])]", 46 | "[$([O;+0](-!@[#6!$(C=O)])-!@[#6!$(C=O)])]-!@[$([#6!$(C=O)])]", 47 | "C=!@C", 48 | "[N;+1;D4]!@[#6]", 49 | "[$([n;+0])]-!@C", 50 | "[$([O]=[C]-@[N;+0])]-!@[$([C])]", 51 | "c-!@c", 52 | "[$([#7;+0;D2,D3])]-!@[$([S](=[O])=[O])]", 53 | ], 54 | "mmpa": ["[#6+0;!$(*=,#[!#6])]!@!=!#[*]"], # classical mmpa slicing smarts 55 | "attach": ["[*]!@[*]"], # any potential attachment point, including hydrogens when explicit 56 | "rotatable": ["[!$(*#*)&!D1]-&!@[!$(*#*)&!D1]"], 57 | } 58 | 59 | def __init__( 60 | self, 61 | slicer: Optional[Union[str, List[str], Callable]] = "brics", 62 | require_hs: Optional[bool] = None, 63 | use_original_opener_for_attach: bool = True, 64 | ignore_stereo: bool = False, 65 | ): 66 | """Constructor for the SAFE converter 67 | 68 | Args: 69 | slicer: slicer algorithm to use for encoding. 70 | Can either be one of the supported slicing algorithm (SUPPORTED_SLICERS) 71 | or a custom callable that returns the bond ids that can be sliced. 72 | require_hs: whether the slicing algorithm require the molecule to have hydrogen explictly added. 73 | `attach` slicer requires adding hydrogens. 74 | use_original_opener_for_attach: whether to use the original branch opener digit when adding back 75 | mapping number to attachment points, or use simple enumeration. 76 | ignore_stereo: RDKIT does not support some particular SAFE subset when stereochemistry is defined. 77 | 78 | """ 79 | self.slicer = slicer 80 | if isinstance(slicer, str) and slicer.lower() in self.SUPPORTED_SLICERS: 81 | self.slicer = self.__SLICE_SMARTS.get(slicer.lower(), slicer) 82 | if self.slicer != "brics" and isinstance(self.slicer, str): 83 | self.slicer = [self.slicer] 84 | if isinstance(self.slicer, (list, tuple)): 85 | self.slicer = [dm.from_smarts(x) for x in self.slicer] 86 | if any(x is None for x in self.slicer): 87 | raise ValueError(f"Slicer: {slicer} cannot be valid") 88 | self.require_hs = require_hs or (slicer == "attach") 89 | self.use_original_opener_for_attach = use_original_opener_for_attach 90 | self.ignore_stereo = ignore_stereo 91 | 92 | @staticmethod 93 | def randomize(mol: dm.Mol, rng: Optional[int] = None): 94 | """Randomize the position of the atoms in a mol. 95 | 96 | Args: 97 | mol: molecules to randomize 98 | rng: optional seed to use 99 | """ 100 | if isinstance(rng, int): 101 | rng = np.random.default_rng(rng) 102 | if mol.GetNumAtoms() == 0: 103 | return mol 104 | atom_indices = list(range(mol.GetNumAtoms())) 105 | atom_indices = rng.permutation(atom_indices).tolist() 106 | return Chem.RenumberAtoms(mol, atom_indices) 107 | 108 | @classmethod 109 | def _find_branch_number(cls, inp: str): 110 | """Find the branch number and ring closure in the SMILES representation using regexp 111 | 112 | Args: 113 | inp: input smiles 114 | """ 115 | inp = re.sub(r"\[.*?\]", "", inp) # noqa 116 | matching_groups = re.findall(r"((?<=%)\d{2})|((? 0: 305 | mol = Chem.FragmentOnBonds( 306 | mol, 307 | bonds, 308 | dummyLabels=[(i + bond_map_id, i + bond_map_id) for i in range(len(bonds))], 309 | ) 310 | # here we need to be clever and disable rooted atom as the atom with mapping 311 | 312 | frags = list(Chem.GetMolFrags(mol, asMols=True)) 313 | if randomize: 314 | frags = rng.permutation(frags).tolist() 315 | elif canonical: 316 | frags = sorted( 317 | frags, 318 | key=lambda x: x.GetNumAtoms(), 319 | reverse=True, 320 | ) 321 | 322 | frags_str = [] 323 | for frag in frags: 324 | non_map_atom_idxs = [ 325 | atom.GetIdx() for atom in frag.GetAtoms() if atom.GetAtomicNum() != 0 326 | ] 327 | frags_str.append( 328 | Chem.MolToSmiles( 329 | frag, 330 | isomericSmiles=True, 331 | canonical=True, # needs to always be true 332 | rootedAtAtom=non_map_atom_idxs[0], 333 | ) 334 | ) 335 | 336 | scaffold_str = ".".join(frags_str) 337 | # EN: fix for https://github.com/datamol-io/safe/issues/37 338 | # we were using the wrong branch number count which did not take into account 339 | # possible change in digit utilization after bond slicing 340 | scf_branch_num = self._find_branch_number(scaffold_str) + branch_numbers 341 | 342 | # don't capture atom mapping in the scaffold 343 | attach_pos = set(re.findall(r"(\[\d+\*\]|!\[[^:]*:\d+\])", scaffold_str)) 344 | if canonical: 345 | attach_pos = sorted(attach_pos) 346 | starting_num = 1 if len(scf_branch_num) == 0 else max(scf_branch_num) + 1 347 | for attach in attach_pos: 348 | val = str(starting_num) if starting_num < 10 else f"%{starting_num}" 349 | # we cannot have anything of the form "\([@=-#-$/\]*\d+\)" 350 | attach_regexp = re.compile(r"(" + re.escape(attach) + r")") 351 | scaffold_str = attach_regexp.sub(val, scaffold_str) 352 | starting_num += 1 353 | 354 | # now we need to remove all the parenthesis around digit only number 355 | wrong_attach = re.compile(r"\(([\%\d]*)\)") 356 | scaffold_str = wrong_attach.sub(r"\g<1>", scaffold_str) 357 | # furthermore, we autoapply rdkit-compatible digit standardization. 358 | if rdkit_safe: 359 | pattern = r"\(([=-@#\/\\]{0,2})(%?\d{1,2})\)" 360 | replacement = r"\g<1>\g<2>" 361 | scaffold_str = re.sub(pattern, replacement, scaffold_str) 362 | if not self.ignore_stereo and has_stereo_bonds and not dm.same_mol(scaffold_str, inp): 363 | logger.warning( 364 | "Ignoring stereo is disabled, but molecule has stereochemistry interferring with SAFE representation" 365 | ) 366 | return scaffold_str 367 | 368 | 369 | def encode( 370 | inp: Union[str, dm.Mol], 371 | canonical: bool = True, 372 | randomize: Optional[bool] = False, 373 | seed: Optional[int] = None, 374 | slicer: Optional[Union[List[str], str, Callable]] = None, 375 | require_hs: Optional[bool] = None, 376 | constraints: Optional[List[dm.Mol]] = None, 377 | ignore_stereo: Optional[bool] = False, 378 | ): 379 | """ 380 | Convert input smiles to SAFE representation 381 | 382 | Args: 383 | inp: input smiles 384 | canonical: whether to return canonical SAFE string. Defaults to True 385 | randomize: whether to randomize the safe string encoding. Will be ignored if canonical is provided 386 | seed: optional seed to use when allowing randomization of the SAFE encoding. 387 | slicer: slicer algorithm to use for encoding. Defaults to "brics". 388 | require_hs: whether the slicing algorithm require the molecule to have hydrogen explictly added. 389 | constraints: List of molecules or pattern to preserve during the SAFE construction. 390 | ignore_stereo: RDKIT does not support some particular SAFE subset when stereochemistry is defined. 391 | """ 392 | if slicer is None: 393 | slicer = "brics" 394 | with dm.without_rdkit_log(): 395 | safe_obj = SAFEConverter(slicer=slicer, require_hs=require_hs, ignore_stereo=ignore_stereo) 396 | try: 397 | encoded = safe_obj.encoder( 398 | inp, 399 | canonical=canonical, 400 | randomize=randomize, 401 | constraints=constraints, 402 | seed=seed, 403 | ) 404 | except SAFEFragmentationError as e: 405 | raise e 406 | except Exception as e: 407 | raise SAFEEncodeError(f"Failed to encode {inp} with {slicer}") from e 408 | return encoded 409 | 410 | 411 | def decode( 412 | safe_str: str, 413 | as_mol: bool = False, 414 | canonical: bool = False, 415 | fix: bool = True, 416 | remove_added_hs: bool = True, 417 | remove_dummies: bool = True, 418 | ignore_errors: bool = False, 419 | ): 420 | """Convert input SAFE representation to smiles 421 | Args: 422 | safe_str: input SAFE representation to decode as a valid molecule or smiles 423 | as_mol: whether to return a molecule object or a smiles string 424 | canonical: whether to return a canonical smiles or a randomized smiles 425 | fix: whether to fix the SAFE representation to take into account non-connected attachment points 426 | remove_added_hs: whether to remove the hydrogen atoms that have been added to fix the string. 427 | remove_dummies: whether to remove dummy atoms from the SAFE representation 428 | ignore_errors: whether to ignore error and return None on decoding failure or raise an error 429 | 430 | """ 431 | with dm.without_rdkit_log(): 432 | safe_obj = SAFEConverter() 433 | try: 434 | decoded = safe_obj.decoder( 435 | safe_str, 436 | as_mol=as_mol, 437 | canonical=canonical, 438 | fix=fix, 439 | remove_dummies=remove_dummies, 440 | remove_added_hs=remove_added_hs, 441 | ) 442 | 443 | except Exception as e: 444 | if ignore_errors: 445 | return None 446 | raise SAFEDecodeError(f"Failed to decode {safe_str}") from e 447 | return decoded 448 | -------------------------------------------------------------------------------- /safe/io.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List 2 | 3 | import tempfile 4 | import os 5 | import contextlib 6 | import torch 7 | import wandb 8 | import fsspec 9 | 10 | from transformers import PreTrainedModel, is_torch_available 11 | from transformers.processing_utils import PushToHubMixin 12 | 13 | 14 | def upload_to_wandb( 15 | model: PreTrainedModel, 16 | tokenizer, 17 | artifact_name: str, 18 | wandb_project_name: Optional[str] = "safe-models", 19 | artifact_type: str = "model", 20 | slicer: Optional[str] = None, 21 | aliases: Optional[List[str]] = None, 22 | **init_args, 23 | ): 24 | """ 25 | Uploads a model and tokenizer to a specified Weights and Biases (wandb) project. 26 | 27 | Args: 28 | model (PreTrainedModel): The model to be uploaded (instance of PreTrainedModel). 29 | tokenizer: The tokenizer associated with the model. 30 | artifact_name (str): The name of the wandb artifact to create. 31 | wandb_project_name (Optional[str]): The name of the wandb project. Defaults to 'safe-model'. 32 | artifact_type (str): The type of artifact (e.g., 'model'). Defaults to 'model'. 33 | slicer (Optional[str]): Optional metadata field that can store a slicing method. 34 | aliases (Optional[List[str]]): List of aliases to assign to this artifact version. 35 | **init_args: Additional arguments to pass into `wandb.init()`. 36 | """ 37 | 38 | with tempfile.TemporaryDirectory() as tmpdirname: 39 | # Paths to save model and tokenizer 40 | model_path = tokenizer_path = tmpdirname 41 | architecture_file = os.path.join(tmpdirname, "architecture.txt") 42 | with fsspec.open(architecture_file, "w+") as f: 43 | f.write(str(model)) 44 | 45 | model.save_pretrained(model_path) 46 | with contextlib.suppress(Exception): 47 | tokenizer.save_pretrained(tokenizer_path) 48 | tokenizer.save(os.path.join(tokenizer_path, "tokenizer.json")) 49 | 50 | info_dict = {"slicer": slicer} 51 | model_config = None 52 | if hasattr(model, "config") and model.config is not None: 53 | model_config = ( 54 | model.config.to_dict() if not isinstance(model.config, dict) else model.config 55 | ) 56 | info_dict.update(model_config) 57 | 58 | if hasattr(model, "peft_config") and model.peft_config is not None: 59 | info_dict.update({"peft_config": model.peft_config}) 60 | 61 | with contextlib.suppress(Exception): 62 | info_dict["model/num_parameters"] = model.num_parameters() 63 | 64 | init_args.setdefault("config", info_dict) 65 | run = wandb.init(project=os.getenv("SAFE_WANDB_PROJECT", wandb_project_name), **init_args) 66 | 67 | artifact = wandb.Artifact( 68 | name=artifact_name, 69 | type=artifact_type, 70 | metadata={ 71 | "model_config": model_config, 72 | "num_parameters": info_dict.get("model/num_parameters"), 73 | "initial_model": True, 74 | }, 75 | ) 76 | 77 | # Add model and tokenizer directories to the artifact 78 | artifact.add_dir(tmpdirname) 79 | run.log_artifact(artifact, aliases=aliases) 80 | 81 | # Finish the wandb run 82 | run.finish() 83 | -------------------------------------------------------------------------------- /safe/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .. import utils 2 | from . import model 3 | -------------------------------------------------------------------------------- /safe/trainer/cli.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import sys 4 | import uuid 5 | from dataclasses import dataclass, field 6 | from typing import Literal, Optional 7 | 8 | import datasets 9 | import evaluate 10 | import torch 11 | import transformers 12 | from loguru import logger 13 | from transformers import AutoConfig, AutoTokenizer, TrainingArguments, set_seed 14 | from transformers.trainer_utils import get_last_checkpoint 15 | from transformers.utils.logging import log_levels as LOG_LEVELS 16 | 17 | import safe 18 | from safe.tokenizer import SAFETokenizer 19 | from safe.trainer.collator import SAFECollator 20 | from safe.trainer.data_utils import get_dataset 21 | from safe.trainer.model import SAFEDoubleHeadsModel 22 | from safe.trainer.trainer_utils import SAFETrainer 23 | 24 | CURRENT_DIR = os.path.join(safe.__path__[0], "trainer") 25 | 26 | 27 | @dataclass 28 | class ModelArguments: 29 | model_path: str = field( 30 | default=None, 31 | metadata={ 32 | "help": "Optional model path or model name to use as a starting point for the safe model" 33 | }, 34 | ) 35 | config: Optional[str] = field( 36 | default=None, metadata={"help": "Path to the default config file to use for the safe model"} 37 | ) 38 | 39 | tokenizer: str = ( 40 | field( 41 | default=None, 42 | metadata={"help": "Path to the trained tokenizer to use to build a safe model"}, 43 | ), 44 | ) 45 | num_labels: Optional[int] = field( 46 | default=None, metadata={"help": "Optional number of labels for the descriptors"} 47 | ) 48 | include_descriptors: Optional[bool] = field( 49 | default=False, 50 | metadata={"help": "Whether to train with descriptors if they are available or Not"}, 51 | ) 52 | prop_loss_coeff: Optional[float] = field( 53 | default=1e-2, metadata={"help": "coefficient for the propery loss"} 54 | ) 55 | model_hub_name: Optional[str] = field( 56 | default="maclandrol/safe-gpt2", 57 | metadata={"help": "Name of the model when uploading to huggingface"}, 58 | ) 59 | 60 | wandb_project: Optional[str] = field( 61 | default="safe-gpt2", 62 | metadata={"help": "Name of the wandb project to use to log the SAFE model parameter"}, 63 | ) 64 | wandb_watch: Optional[Literal["gradients", "all"]] = field( 65 | default=None, metadata={"help": "Whether to watch the wandb models or not"} 66 | ) 67 | cache_dir: Optional[str] = field( 68 | default=None, 69 | metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}, 70 | ) 71 | torch_dtype: Optional[str] = field( 72 | default=None, 73 | metadata={ 74 | "help": ( 75 | "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the " 76 | "dtype will be automatically derived from the model's weights." 77 | ), 78 | "choices": ["auto", "bfloat16", "float16", "float32"], 79 | }, 80 | ) 81 | 82 | low_cpu_mem_usage: bool = field( 83 | default=False, 84 | metadata={ 85 | "help": ( 86 | "It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded." 87 | "set True will benefit LLM loading time and RAM consumption. Only valid when loading a pretrained model" 88 | ) 89 | }, 90 | ) 91 | model_max_length: int = field( 92 | default=1024, 93 | metadata={ 94 | "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated) up to that value." 95 | }, 96 | ) 97 | 98 | 99 | @dataclass 100 | class DataArguments: 101 | dataset: str = field( 102 | default=None, 103 | metadata={"help": "Path to the preprocessed dataset to use for the safe model building"}, 104 | ) 105 | is_tokenized: Optional[bool] = field( 106 | default=False, 107 | metadata={"help": "whether the dataset submitted as input is already tokenized or not"}, 108 | ) 109 | 110 | streaming: Optional[bool] = field( 111 | default=False, metadata={"help": "Whether to use a streaming dataset or not"} 112 | ) 113 | 114 | text_column: Optional[str] = field( 115 | default="inputs", metadata={"help": "Column containing text data to process."} 116 | ) 117 | 118 | max_train_samples: Optional[int] = field( 119 | default=None, metadata={"help": "Maximum number of training sample to use."} 120 | ) 121 | 122 | max_eval_samples: Optional[int] = field( 123 | default=None, metadata={"help": "Maximum number of evaluation sample to use."} 124 | ) 125 | 126 | train_split: Optional[str] = field( 127 | default="train", 128 | metadata={ 129 | "help": "Training splits to use. You can use train+validation for example to include both train and validation split in the training" 130 | }, 131 | ) 132 | 133 | property_column: Optional[str] = field( 134 | default=None, 135 | metadata={ 136 | "help": "Column containing the descriptors information. Default to None to use `mc_labels`" 137 | }, 138 | ) 139 | 140 | 141 | def train(model_args, data_args, training_args): 142 | """Train a new model from scratch""" 143 | if training_args.should_log: 144 | # The default of training_args.log_level is passive, so we set log level at info here to have that default. 145 | transformers.utils.logging.set_verbosity_info() 146 | 147 | log_level = training_args.get_process_log_level() 148 | if log_level is None or log_level == "passive": 149 | log_level = "info" 150 | 151 | _LOG_LEVEL = {v: k for k, v in LOG_LEVELS.items()} 152 | logger.remove() 153 | logger.add(sys.stderr, level=_LOG_LEVEL.get(log_level, log_level).upper()) 154 | transformers.utils.logging.set_verbosity(log_level) 155 | transformers.utils.logging.enable_default_handler() 156 | transformers.utils.logging.enable_explicit_format() 157 | 158 | # Log on each process the small summary: 159 | logger.info( 160 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu} " 161 | + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}" 162 | ) 163 | 164 | logger.info(f"Training/evaluation parameters: {training_args}") 165 | 166 | # Detecting last checkpoint. 167 | last_checkpoint = None 168 | if ( 169 | os.path.isdir(training_args.output_dir) 170 | and training_args.do_train 171 | and not training_args.overwrite_output_dir 172 | ): 173 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 174 | if last_checkpoint is not None and training_args.resume_from_checkpoint is None: 175 | logger.info( 176 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 177 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 178 | ) 179 | 180 | # Check if parameter passed or if set within environ 181 | # Only overwrite environ if wandb param passed 182 | wandb_run_name = f"safe-model-{uuid.uuid4().hex[:8]}" 183 | if model_args.wandb_project: 184 | os.environ["WANDB_PROJECT"] = model_args.wandb_project 185 | training_args.report_to = ["wandb"] 186 | if model_args.wandb_watch: 187 | os.environ["WANDB_WATCH"] = model_args.wandb_watch 188 | if model_args.wandb_watch == "all": 189 | os.environ["WANDB_LOG_MODEL"] = "end" 190 | 191 | training_args.run_name = wandb_run_name 192 | training_args.remove_unused_columns = False 193 | # load tokenizer and model 194 | 195 | set_seed(training_args.seed) 196 | # load the tokenizer 197 | if model_args.tokenizer.endswith(".json"): 198 | tokenizer = SAFETokenizer.load(model_args.tokenizer) 199 | else: 200 | try: 201 | tokenizer = SAFETokenizer.load(model_args.tokenizer) 202 | except: 203 | tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer) 204 | 205 | # load dataset 206 | with training_args.main_process_first(): 207 | # if the dataset is streaming we tokenize on the fly, it would be faster 208 | dataset = get_dataset( 209 | data_args.dataset, 210 | tokenizer=(None if data_args.is_tokenized or data_args.streaming else tokenizer), 211 | streaming=data_args.streaming, 212 | tokenize_column=data_args.text_column, 213 | max_length=model_args.model_max_length, 214 | property_column=data_args.property_column, 215 | ) 216 | 217 | if data_args.max_train_samples is not None: 218 | dataset["train"] = dataset["train"].take(data_args.max_train_samples) 219 | if data_args.max_eval_samples is not None: 220 | for k in dataset: 221 | if k != "train": 222 | dataset[k] = dataset[k].take(data_args.max_eval_samples) 223 | 224 | eval_dataset_key_name = "validation" if "validation" in dataset else "test" 225 | 226 | train_dataset = dataset["train"] 227 | if data_args.train_split is not None: 228 | train_dataset = datasets.concatenate_datasets( 229 | [dataset[x] for x in data_args.train_split.split("+")] 230 | ) 231 | if eval_dataset_key_name in data_args.train_split.split("+"): 232 | eval_dataset_key_name = None 233 | 234 | data_collator = SAFECollator( 235 | tokenizer=tokenizer, 236 | input_key=data_args.text_column, 237 | max_length=model_args.model_max_length, 238 | include_descriptors=model_args.include_descriptors, 239 | property_key="mc_labels", 240 | ) 241 | pretrained_tokenizer = data_collator.get_tokenizer() 242 | config = model_args.config 243 | 244 | if config is None: 245 | config = os.path.join(CURRENT_DIR, "configs/default_config.json") 246 | config = AutoConfig.from_pretrained(config, cache_dir=model_args.cache_dir) 247 | 248 | if model_args.num_labels is not None: 249 | config.num_labels = int(model_args.num_labels) 250 | 251 | config.vocab_size = len(tokenizer) 252 | if model_args.model_max_length is not None: 253 | config.max_position_embeddings = model_args.model_max_length 254 | try: 255 | config.bos_token_id = tokenizer.bos_token_id 256 | config.eos_token_id = tokenizer.eos_token_id 257 | config.pad_token_id = tokenizer.pad_token_id 258 | except: 259 | config.bos_token_id = pretrained_tokenizer.bos_token_id 260 | config.eos_token_id = pretrained_tokenizer.eos_token_id 261 | config.pad_token_id = pretrained_tokenizer.pad_token_id 262 | 263 | if model_args.model_path is not None: 264 | torch_dtype = ( 265 | model_args.torch_dtype 266 | if model_args.torch_dtype in ["auto", None] 267 | else getattr(torch, model_args.torch_dtype) 268 | ) 269 | model = SAFEDoubleHeadsModel.from_pretrained( 270 | model_args.model_path, 271 | config=config, 272 | cache_dir=model_args.cache_dir, 273 | low_cpu_mem_usage=model_args.low_cpu_mem_usage, 274 | torch_dtype=torch_dtype, 275 | ) 276 | 277 | else: 278 | model = SAFEDoubleHeadsModel(config) 279 | 280 | # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch 281 | # on a small vocab and want a smaller embedding size, remove this test. 282 | embedding_size = model.get_input_embeddings().weight.shape[0] 283 | if len(tokenizer) > embedding_size: 284 | model.resize_token_embeddings(len(tokenizer)) 285 | 286 | n_params = sum({p.data_ptr(): p.numel() for p in model.parameters()}.values()) 287 | logger.info(f"Training new model from scratch - Total size={n_params/2**20:.2f}M params") 288 | 289 | def preprocess_logits_for_metrics(logits, labels): 290 | prop_logits = None 291 | if isinstance(logits, tuple): 292 | # Depending on the model and config, logits may contain extra tensors, 293 | # like past_key_values, but logits always come first 294 | if len(logits) > 1: 295 | # we could have the loss twice 296 | base_ind = 0 297 | if logits[base_ind].ndim < 2: 298 | base_ind = 1 299 | label_logits = logits[base_ind].argmax(dim=-1) 300 | if len(logits) > base_ind + 1: 301 | prop_logits = logits[base_ind + 1] 302 | else: 303 | label_logits = logits.argmax(dim=-1) 304 | 305 | if prop_logits is not None: 306 | return label_logits, prop_logits 307 | return label_logits 308 | 309 | accuracy_metric = evaluate.load("accuracy") 310 | mse_metric = evaluate.load("mse") 311 | 312 | def compute_metrics(eval_preds): 313 | preds, labels = eval_preds 314 | mc_preds = None 315 | mc_labels = None 316 | if isinstance(preds, tuple): 317 | preds, mc_preds = preds 318 | if isinstance(labels, tuple): 319 | labels, mc_labels = labels 320 | # preds have the same shape as the labels, after the argmax(-1) has been calculated 321 | # by preprocess_logits_for_metrics but we need to shift the labels 322 | labels = labels[:, 1:].reshape(-1) 323 | preds = preds[:, :-1].reshape(-1) 324 | results = accuracy_metric.compute(predictions=preds, references=labels) 325 | if mc_preds is not None and mc_labels is not None: 326 | results_mse = mse_metric.compute( 327 | predictions=mc_preds.reshape(-1), references=mc_labels.reshape(-1) 328 | ) 329 | results.update(results_mse) 330 | return results 331 | 332 | if model_args.include_descriptors: 333 | training_args.label_names = ["labels", "mc_labels"] 334 | 335 | if training_args.label_names is None: 336 | training_args.label_names = ["labels"] 337 | # update dispatch_batches in accelerator 338 | training_args.accelerator_config.dispatch_batches = data_args.streaming is not True 339 | 340 | trainer = SAFETrainer( 341 | model=model, 342 | tokenizer=None, # we don't deal with the tokenizer at all, https://github.com/huggingface/tokenizers/issues/581 -_- 343 | train_dataset=train_dataset.shuffle(seed=(training_args.seed or 42)), 344 | eval_dataset=dataset.get(eval_dataset_key_name, None), 345 | args=training_args, 346 | prop_loss_coeff=model_args.prop_loss_coeff, 347 | compute_metrics=compute_metrics if training_args.do_eval else None, 348 | data_collator=data_collator, 349 | preprocess_logits_for_metrics=( 350 | preprocess_logits_for_metrics if training_args.do_eval else None 351 | ), 352 | ) 353 | 354 | if training_args.do_train: 355 | checkpoint = None 356 | if training_args.resume_from_checkpoint is not None: 357 | checkpoint = training_args.resume_from_checkpoint 358 | elif last_checkpoint is not None: 359 | checkpoint = last_checkpoint 360 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 361 | try: 362 | # we were unable to save the model because of the tokenizer 363 | trainer.save_model() # Saves the tokenizer too for easy upload 364 | except: 365 | model.save_pretrained(os.path.join(training_args.output_dir, "safe-model")) 366 | 367 | if training_args.push_to_hub and model_args.model_hub_name: 368 | model.push_to_hub(model_args.model_hub_name, private=True, safe_serialization=True) 369 | 370 | metrics = train_result.metrics 371 | 372 | trainer.log_metrics("train", metrics) 373 | trainer.save_metrics("train", metrics) 374 | trainer.save_state() 375 | 376 | # For convenience, we also re-save the tokenizer to the same directory, 377 | # so that you can share your model easily on huggingface.co/models =) 378 | if trainer.is_world_process_zero(): 379 | tokenizer.save(os.path.join(training_args.output_dir, "tokenizer.json")) 380 | 381 | # Evaluation 382 | if training_args.do_eval: 383 | logger.info("*** Evaluate ***") 384 | results = trainer.evaluate() 385 | try: 386 | perplexity = math.exp(results["eval_loss"]) 387 | except Exception as e: 388 | logger.error(e) 389 | perplexity = float("inf") 390 | results.update({"perplexity": perplexity}) 391 | if trainer.is_world_process_zero(): 392 | trainer.log_metrics("eval", results) 393 | trainer.save_metrics("eval", results) 394 | 395 | kwargs = {"finetuned_from": model_args.model_path, "tasks": "text-generation"} 396 | kwargs["dataset_tags"] = data_args.dataset 397 | kwargs["dataset"] = data_args.dataset 398 | kwargs["tags"] = ["safe", "datamol-io", "molecule-design", "smiles"] 399 | 400 | if training_args.push_to_hub: 401 | kwargs["private"] = True 402 | trainer.push_to_hub(**kwargs) 403 | else: 404 | trainer.create_model_card(**kwargs) 405 | 406 | 407 | def main(): 408 | parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) 409 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 410 | train(model_args, data_args, training_args) 411 | 412 | 413 | if __name__ == "__main__": 414 | main() 415 | -------------------------------------------------------------------------------- /safe/trainer/collator.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import functools 3 | from collections.abc import Mapping 4 | from typing import Any, Dict, List, Optional, Union 5 | 6 | import torch 7 | from tokenizers import Tokenizer 8 | from transformers.data.data_collator import _torch_collate_batch 9 | 10 | from safe.tokenizer import SAFETokenizer 11 | 12 | 13 | class SAFECollator: 14 | """Collate function for language modelling tasks 15 | 16 | 17 | !!! note 18 | The collate function is based on the default DataCollatorForLanguageModeling in huggingface 19 | see: https://github.com/huggingface/transformers/blob/v4.19.2/src/transformers/data/data_collator.py 20 | """ 21 | 22 | def __init__( 23 | self, 24 | tokenizer: Tokenizer, 25 | pad_to_multiple_of: Optional[int] = None, 26 | input_key: str = "inputs", 27 | label_key: str = "labels", 28 | property_key: str = "descriptors", 29 | include_descriptors: bool = False, 30 | max_length: Optional[int] = None, 31 | ): 32 | """ 33 | Default collator for huggingface transformers in izanagi. 34 | 35 | Args: 36 | tokenizer: Huggingface tokenizer 37 | input_key: key to use for input ids 38 | label_key: key to use for labels 39 | property_key: key to use for properties 40 | include_descriptors: whether to include training on descriptors or not 41 | pad_to_multiple_of: pad to multiple of this value 42 | """ 43 | 44 | self.tokenizer = tokenizer 45 | self.pad_to_multiple_of = pad_to_multiple_of 46 | self.input_key = input_key 47 | self.label_key = label_key 48 | self.property_key = property_key 49 | self.include_descriptors = include_descriptors 50 | self.max_length = max_length 51 | 52 | @functools.lru_cache() 53 | def get_tokenizer(self): 54 | """Get underlying tokenizer""" 55 | if isinstance(self.tokenizer, SAFETokenizer): 56 | return self.tokenizer.get_pretrained() 57 | return self.tokenizer 58 | 59 | def __call__(self, samples: List[Union[List[int], Any, Dict[str, Any]]]): 60 | """ 61 | Call collate function 62 | 63 | Args: 64 | samples: list of examples 65 | """ 66 | # Handle dict or lists with proper padding and conversion to tensor. 67 | tokenizer = self.get_tokenizer() 68 | 69 | # examples = samples 70 | examples = copy.deepcopy(samples) 71 | inputs = [example.pop(self.input_key, None) for example in examples] 72 | mc_labels = ( 73 | torch.tensor([example.pop(self.property_key, None) for example in examples]).float() 74 | if self.property_key in examples[0] 75 | else None 76 | ) 77 | 78 | if "input_ids" not in examples[0] and inputs is not None: 79 | batch = tokenizer( 80 | inputs, 81 | return_tensors="pt", 82 | padding=True, 83 | truncation=True, 84 | max_length=self.max_length, 85 | pad_to_multiple_of=self.pad_to_multiple_of, 86 | ) 87 | else: 88 | batch = tokenizer.pad( 89 | examples, 90 | return_tensors="pt", 91 | padding=True, 92 | pad_to_multiple_of=self.pad_to_multiple_of, 93 | max_length=self.max_length, 94 | ) 95 | 96 | # If special token mask has been preprocessed, pop it from the dict. 97 | batch.pop("special_tokens_mask", None) 98 | labels = batch.get(self.label_key, batch["input_ids"].clone()) 99 | if tokenizer.pad_token_id is not None: 100 | labels[labels == tokenizer.pad_token_id] = -100 101 | batch[self.label_key] = labels 102 | 103 | if mc_labels is not None and self.include_descriptors: 104 | batch.update( 105 | { 106 | "mc_labels": mc_labels, 107 | # "input_text": inputs, 108 | } 109 | ) 110 | return batch 111 | -------------------------------------------------------------------------------- /safe/trainer/configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamol-io/safe/84a4697e6d89792fe7c870fe65b3f43e28191725/safe/trainer/configs/__init__.py -------------------------------------------------------------------------------- /safe/trainer/configs/default_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "activation_function": "gelu_new", 3 | "attn_pdrop": 0.1, 4 | "bos_token_id": 10000, 5 | "embd_pdrop": 0.1, 6 | "eos_token_id": 1, 7 | "initializer_range": 0.02, 8 | "layer_norm_epsilon": 1e-05, 9 | "model_type": "gpt2", 10 | "n_embd": 768, 11 | "n_head": 12, 12 | "n_inner": null, 13 | "n_layer": 12, 14 | "n_positions": 1024, 15 | "reorder_and_upcast_attn": false, 16 | "resid_pdrop": 0.1, 17 | "scale_attn_by_inverse_layer_idx": false, 18 | "scale_attn_weights": true, 19 | "summary_activation": "relu", 20 | "summary_first_dropout": 0.1, 21 | "summary_proj_to_labels": true, 22 | "summary_hidden_size": 64, 23 | "summary_type": "cls_index", 24 | "summary_use_proj": true, 25 | "transformers_version": "4.31.0", 26 | "use_cache": true, 27 | "vocab_size": 10000, 28 | "num_labels": 18 29 | } -------------------------------------------------------------------------------- /safe/trainer/data_utils.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from collections.abc import Mapping 3 | from functools import partial 4 | from typing import Any, Callable, Dict, Optional, Union 5 | 6 | import datasets 7 | import upath 8 | from tqdm.auto import tqdm 9 | 10 | from safe.tokenizer import SAFETokenizer 11 | 12 | 13 | def take(n, iterable): 14 | "Return first n items of the iterable as a list" 15 | return list(itertools.islice(iterable, n)) 16 | 17 | 18 | def get_dataset_column_names(dataset: Union[datasets.Dataset, datasets.IterableDataset, Mapping]): 19 | """Get the column names in a dataset 20 | 21 | Args: 22 | dataset: dataset to get the column names from 23 | 24 | """ 25 | if isinstance(dataset, (datasets.IterableDatasetDict, Mapping)): 26 | column_names = {split: dataset[split].column_names for split in dataset} 27 | else: 28 | column_names = dataset.column_names 29 | if isinstance(column_names, dict): 30 | column_names = list(column_names.values())[0] 31 | return column_names 32 | 33 | 34 | def tokenize_fn( 35 | row: Dict[str, Any], 36 | tokenizer: Callable, 37 | tokenize_column: str = "inputs", 38 | max_length: Optional[int] = None, 39 | padding: bool = False, 40 | ): 41 | """Perform the tokenization of a row 42 | Args: 43 | row: row to tokenize 44 | tokenizer: tokenizer to use 45 | tokenize_column: column to tokenize 46 | max_length: maximum size of the tokenized sequence 47 | padding: whether to pad the sequence 48 | """ 49 | # there's probably a way to do this with the tokenizer settings 50 | # but again, gotta move fast 51 | 52 | fast_tokenizer = ( 53 | tokenizer.get_pretrained() if isinstance(tokenizer, SAFETokenizer) else tokenizer 54 | ) 55 | 56 | return fast_tokenizer( 57 | row[tokenize_column], 58 | truncation=(max_length is not None), 59 | max_length=max_length, 60 | padding=padding, 61 | return_tensors=None, 62 | ) 63 | 64 | 65 | def batch_iterator(datasets, batch_size=100, n_examples=None, column="inputs"): 66 | if isinstance(datasets, Mapping): 67 | datasets = list(datasets.values()) 68 | 69 | if not isinstance(datasets, (list, tuple)): 70 | datasets = [datasets] 71 | 72 | for dataset in datasets: 73 | iter_dataset = iter(dataset) 74 | if n_examples is not None and n_examples > 0: 75 | for _ in tqdm(range(0, n_examples, batch_size)): 76 | out = [next(iter_dataset)[column] for _ in range(batch_size)] 77 | yield out 78 | else: 79 | for out in tqdm(iter(partial(take, batch_size, iter_dataset), [])): 80 | yield [x[column] for x in out] 81 | 82 | 83 | def get_dataset( 84 | data_path, 85 | name: Optional[str] = None, 86 | tokenizer: Optional[Callable] = None, 87 | cache_dir: Optional[str] = None, 88 | streaming: bool = True, 89 | use_auth_token: bool = False, 90 | tokenize_column: Optional[str] = "inputs", 91 | property_column: Optional[str] = "descriptors", 92 | max_length: Optional[int] = None, 93 | num_shards=1024, 94 | ): 95 | """Get the datasets from the config file""" 96 | raw_datasets = {} 97 | if data_path is not None: 98 | data_path = upath.UPath(str(data_path)) 99 | 100 | if data_path.exists(): 101 | # then we need to load from disk 102 | data_path = str(data_path) 103 | # for some reason, the datasets package is not able to load the dataset 104 | # because the split where not originally proposed 105 | raw_datasets = datasets.load_from_disk(data_path) 106 | 107 | if streaming: 108 | if isinstance(raw_datasets, datasets.DatasetDict): 109 | previous_num_examples = {k: len(dt) for k, dt in raw_datasets.items()} 110 | raw_datasets = datasets.IterableDatasetDict( 111 | { 112 | k: dt.to_iterable_dataset(num_shards=num_shards) 113 | for k, dt in raw_datasets.items() 114 | } 115 | ) 116 | for k, dt in raw_datasets.items(): 117 | if previous_num_examples[k] is not None: 118 | setattr(dt, "num_examples", previous_num_examples[k]) 119 | else: 120 | num_examples = len(raw_datasets) 121 | raw_datasets = raw_datasets.to_iterable_dataset(num_shards=num_shards) 122 | setattr(raw_datasets, "num_examples", num_examples) 123 | 124 | else: 125 | data_path = str(data_path) 126 | raw_datasets = datasets.load_dataset( 127 | data_path, 128 | name=name, 129 | cache_dir=cache_dir, 130 | use_auth_token=True if use_auth_token else None, 131 | streaming=streaming, 132 | ) 133 | # that means we need to return a tokenized version of the dataset 134 | 135 | if property_column not in ["mc_labels", None]: 136 | raw_datasets = raw_datasets.rename_column(property_column, "mc_labels") 137 | 138 | columns_to_remove = None 139 | if tokenize_column is not None: 140 | columns_to_remove = [ 141 | x 142 | for x in (get_dataset_column_names(raw_datasets) or []) 143 | if x not in [tokenize_column, "mc_labels"] and "label" not in x 144 | ] or None 145 | 146 | if tokenizer is None: 147 | if columns_to_remove is not None: 148 | raw_datasets = raw_datasets.remove_columns(columns_to_remove) 149 | return raw_datasets 150 | 151 | return raw_datasets.map( 152 | partial( 153 | tokenize_fn, 154 | tokenizer=tokenizer, 155 | tokenize_column=tokenize_column, 156 | max_length=max_length, 157 | ), 158 | batched=True, 159 | remove_columns=columns_to_remove, 160 | ) 161 | -------------------------------------------------------------------------------- /safe/trainer/model.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Optional, Tuple, Union 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import CrossEntropyLoss, MSELoss 6 | from transformers import GPT2DoubleHeadsModel, PretrainedConfig 7 | from transformers.activations import get_activation 8 | from transformers.models.gpt2.modeling_gpt2 import ( 9 | _CONFIG_FOR_DOC, 10 | GPT2_INPUTS_DOCSTRING, 11 | GPT2DoubleHeadsModelOutput, 12 | add_start_docstrings_to_model_forward, 13 | replace_return_docstrings, 14 | ) 15 | 16 | 17 | class PropertyHead(torch.nn.Module): 18 | r""" 19 | Compute a single vector summary of a sequence hidden states. 20 | 21 | Args: 22 | config ([`PretrainedConfig`]): 23 | The config used by the model. Relevant arguments in the config class of the model are (refer to the actual 24 | config class of your model for the default values it uses): 25 | 26 | - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are: 27 | 28 | - `"last"` -- Take the last token hidden state (like XLNet) 29 | - `"first"` -- Take the first token hidden state (like Bert) 30 | - `"mean"` -- Take the mean of all tokens hidden states 31 | - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2) 32 | 33 | - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output, 34 | another string, or `None` to add no activation. 35 | """ 36 | 37 | def __init__(self, config: PretrainedConfig): 38 | super().__init__() 39 | 40 | self.summary_type = getattr(config, "summary_type", "cls_index") 41 | self.summary = torch.nn.Identity() 42 | last_hidden_size = config.hidden_size 43 | 44 | if getattr(config, "summary_hidden_size", None) and config.summary_hidden_size > 0: 45 | self.summary = nn.Linear(config.hidden_size, config.summary_hidden_size) 46 | last_hidden_size = config.summary_hidden_size 47 | 48 | activation_string = getattr(config, "summary_activation", None) 49 | self.activation: Callable = ( 50 | get_activation(activation_string) if activation_string else nn.Identity() 51 | ) 52 | 53 | self.out = torch.nn.Identity() 54 | if getattr(config, "num_labels", None) and config.num_labels > 0: 55 | num_labels = config.num_labels 56 | self.out = nn.Linear(last_hidden_size, num_labels) 57 | 58 | def forward( 59 | self, 60 | hidden_states: torch.FloatTensor, 61 | cls_index: Optional[torch.LongTensor] = None, 62 | ) -> torch.FloatTensor: 63 | """ 64 | Compute a single vector summary of a sequence hidden states. 65 | 66 | Args: 67 | hidden_states: `torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`) 68 | The hidden states of the last layer. 69 | cls_index: `torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` 70 | where ... are optional leading dimensions of `hidden_states`, *optional* 71 | Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token. 72 | 73 | Returns: 74 | `torch.FloatTensor`: The summary of the sequence hidden states. 75 | """ 76 | if self.summary_type == "last": 77 | output = hidden_states[:, -1] 78 | elif self.summary_type == "first": 79 | output = hidden_states[:, 0] 80 | elif self.summary_type == "mean": 81 | output = hidden_states.mean(dim=1) 82 | elif self.summary_type == "cls_index": 83 | # if cls_index is None: 84 | # cls_index = torch.full_like( 85 | # hidden_states[..., :1, :], 86 | # hidden_states.shape[-2] - 1, 87 | # dtype=torch.long, 88 | # ) 89 | # else: 90 | # cls_index = cls_index.unsqueeze(-1).unsqueeze(-1) 91 | # cls_index = cls_index.expand( 92 | # (-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),) 93 | # ) 94 | 95 | # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states 96 | # output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size) 97 | batch_size = hidden_states.shape[0] 98 | output = hidden_states.squeeze()[torch.arange(batch_size), cls_index] 99 | else: 100 | raise NotImplementedError 101 | 102 | output = self.summary(output) 103 | output = self.activation(output) 104 | return self.out(output) 105 | 106 | 107 | class SAFEDoubleHeadsModel(GPT2DoubleHeadsModel): 108 | """The safe model is a dual head GPT2 model with a language modeling head and an optional multi-task regression head""" 109 | 110 | def __init__(self, config): 111 | self.num_labels = getattr(config, "num_labels", None) 112 | super().__init__(config) 113 | self.config.num_labels = self.num_labels 114 | del self.multiple_choice_head 115 | self.multiple_choice_head = PropertyHead(config) 116 | 117 | @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) 118 | @replace_return_docstrings(output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC) 119 | def forward( 120 | self, 121 | input_ids: Optional[torch.LongTensor] = None, 122 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, 123 | attention_mask: Optional[torch.FloatTensor] = None, 124 | token_type_ids: Optional[torch.LongTensor] = None, 125 | position_ids: Optional[torch.LongTensor] = None, 126 | head_mask: Optional[torch.FloatTensor] = None, 127 | inputs_embeds: Optional[torch.FloatTensor] = None, 128 | mc_token_ids: Optional[torch.LongTensor] = None, 129 | labels: Optional[torch.LongTensor] = None, 130 | mc_labels: Optional[torch.LongTensor] = None, 131 | use_cache: Optional[bool] = None, 132 | output_attentions: Optional[bool] = None, 133 | output_hidden_states: Optional[bool] = None, 134 | return_dict: Optional[bool] = None, 135 | inputs: Optional[Any] = None, # do not remove because of trainer 136 | encoder_hidden_states: Optional[torch.Tensor] = None, 137 | **kwargs, 138 | ) -> Union[Tuple, GPT2DoubleHeadsModelOutput]: 139 | r""" 140 | 141 | Args: 142 | mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input): 143 | Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) - 144 | 1]`. 145 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 146 | Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set 147 | `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to 148 | `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size - 1]` 149 | mc_labels (`torch.LongTensor` of shape `(batch_size, n_tasks)`, *optional*): 150 | Labels for computing the supervized loss for regularization. 151 | inputs: List of inputs, put here because the trainer removes information not in signature 152 | Returns: 153 | output (GPT2DoubleHeadsModelOutput): output of the model 154 | """ 155 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 156 | transformer_outputs = self.transformer( 157 | input_ids, 158 | past_key_values=past_key_values, 159 | attention_mask=attention_mask, 160 | token_type_ids=token_type_ids, 161 | position_ids=position_ids, 162 | head_mask=head_mask, 163 | inputs_embeds=inputs_embeds, 164 | use_cache=use_cache, 165 | output_attentions=output_attentions, 166 | output_hidden_states=output_hidden_states, 167 | return_dict=return_dict, 168 | encoder_hidden_states=encoder_hidden_states, 169 | ) 170 | 171 | hidden_states = transformer_outputs[0] 172 | lm_logits = self.lm_head(hidden_states) 173 | 174 | if mc_token_ids is None and self.config.pad_token_id is not None and input_ids is not None: 175 | mc_token_ids = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to( 176 | lm_logits.device 177 | ) 178 | 179 | # Set device for model parallelism 180 | if self.model_parallel: 181 | torch.cuda.set_device(self.transformer.first_device) 182 | hidden_states = hidden_states.to(self.lm_head.weight.device) 183 | 184 | mc_loss = None 185 | mc_logits = None 186 | if mc_labels is not None and getattr(self.config, "num_labels", 0) > 0: 187 | mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1) 188 | mc_labels = mc_labels.to(mc_logits.device) 189 | loss_fct = MSELoss() 190 | mc_loss = loss_fct( 191 | mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1, mc_logits.size(-1)) 192 | ) 193 | 194 | lm_loss = None 195 | if labels is not None: 196 | labels = labels.to(lm_logits.device) 197 | shift_logits = lm_logits[..., :-1, :].contiguous() 198 | shift_labels = labels[..., 1:].contiguous() 199 | loss_fct = CrossEntropyLoss() 200 | lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 201 | 202 | if not return_dict: 203 | output = (lm_logits, mc_logits) + transformer_outputs[1:] 204 | return ( 205 | lm_loss, 206 | mc_loss, 207 | ) + output 208 | 209 | return GPT2DoubleHeadsModelOutput( 210 | loss=lm_loss, 211 | mc_loss=mc_loss, 212 | logits=lm_logits, 213 | mc_logits=mc_logits, 214 | past_key_values=transformer_outputs.past_key_values, 215 | hidden_states=transformer_outputs.hidden_states, 216 | attentions=transformer_outputs.attentions, 217 | ) 218 | -------------------------------------------------------------------------------- /safe/trainer/trainer_utils.py: -------------------------------------------------------------------------------- 1 | from transformers import Trainer 2 | from transformers.modeling_utils import unwrap_model 3 | from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES 4 | from transformers.trainer import _is_peft_model 5 | 6 | 7 | class SAFETrainer(Trainer): 8 | """ 9 | Custom trainer for training SAFE model. 10 | 11 | This custom trainer changes the loss function to support the property head 12 | 13 | """ 14 | 15 | def __init__(self, *args, prop_loss_coeff: float = 1e-3, **kwargs): 16 | super().__init__(*args, **kwargs) 17 | self.prop_loss_coeff = prop_loss_coeff 18 | 19 | def compute_loss(self, model, inputs, return_outputs=False): 20 | """ 21 | How the loss is computed by Trainer. By default, all models return the loss in the first element. 22 | """ 23 | labels = ( 24 | inputs.pop("labels") if self.label_smoother is not None and "labels" in inputs else None 25 | ) 26 | outputs = model(**inputs) 27 | # Save past state if it exists 28 | # TODO: this needs to be fixed and made cleaner later. 29 | if self.args.past_index >= 0: 30 | self._past = outputs[self.args.past_index] 31 | 32 | if labels is not None: 33 | unwrapped_model = self.accelerator.unwrap_model(model) 34 | if _is_peft_model(unwrapped_model): 35 | model_name = unwrapped_model.base_model.model._get_name() 36 | else: 37 | model_name = unwrapped_model._get_name() 38 | if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values(): 39 | loss = self.label_smoother(outputs, labels, shift_labels=True) 40 | else: 41 | loss = self.label_smoother(outputs, labels) 42 | else: 43 | if isinstance(outputs, dict) and "loss" not in outputs: 44 | raise ValueError( 45 | "The model did not return a loss from the inputs, only the following keys: " 46 | f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}." 47 | ) 48 | # We don't use .loss here since the model may return tuples instead of ModelOutput. 49 | loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] 50 | 51 | mc_loss = outputs.get("mc_loss", None) if isinstance(outputs, dict) else outputs[1] 52 | if mc_loss is not None: 53 | loss = loss + self.prop_loss_coeff * mc_loss 54 | return (loss, outputs) if return_outputs else loss 55 | -------------------------------------------------------------------------------- /safe/utils.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import random 3 | import re 4 | from collections import deque 5 | from contextlib import contextmanager, suppress 6 | from functools import partial 7 | from itertools import combinations, compress 8 | from typing import Any, List, Optional, Tuple, Union 9 | 10 | import datamol as dm 11 | import networkx as nx 12 | import numpy as np 13 | from loguru import logger 14 | from networkx.utils import py_random_state 15 | from rdkit import Chem 16 | from rdkit.Chem import Atom, EditableMol 17 | from rdkit.Chem.rdChemReactions import ReactionFromSmarts 18 | from rdkit.Chem.rdmolops import AdjustQueryParameters, AdjustQueryProperties, ReplaceCore 19 | 20 | import safe as sf 21 | 22 | __implicit_carbon_query = dm.from_smarts("[#6;h]") 23 | __mmpa_query = dm.from_smarts("[*;!$(*=,#[!#6])]!@!=!#[*]") 24 | 25 | _SMILES_ATTACHMENT_POINT_TOKEN = "\\*" 26 | # The following allows discovering and extracting valid dummy atoms in smiles/smarts. 27 | _SMILES_ATTACHMENT_POINTS = [ 28 | # parse any * not preceeded by "[" or ":"" and not followed by "]" or ":" as attachment 29 | r"(?>([*:1][*:2].[*:3][*:4])" 51 | ) 52 | 53 | def __init__( 54 | self, 55 | shortest_linker: bool = False, 56 | min_linker_size: int = 0, 57 | require_ring_system: bool = True, 58 | verbose: bool = False, 59 | ): 60 | """ 61 | Constructor of bond slicer. 62 | 63 | Args: 64 | shortest_linker: whether to consider longuest or shortest linker. 65 | Does not have any effect when expected_head group is provided during splitting 66 | min_linker_size: minimum linker size 67 | require_ring_system: whether all fragment needs to have a ring system 68 | verbose: whether to allow verbosity in logging 69 | """ 70 | 71 | self.bond_splitters = [dm.from_smarts(x) for x in self.BOND_SPLITTERS] 72 | self.shortest_linker = shortest_linker 73 | self.min_linker_size = min_linker_size 74 | self.require_ring_system = require_ring_system 75 | self.verbose = verbose 76 | 77 | def get_ring_system(self, mol: dm.Mol): 78 | """Get the list of ring system from a molecule 79 | 80 | Args: 81 | mol: input molecule for which we are computing the ring system 82 | """ 83 | mol.UpdatePropertyCache() 84 | ri = mol.GetRingInfo() 85 | systems = [] 86 | for ring in ri.AtomRings(): 87 | ring_atoms = set(ring) 88 | cur_system = [] # keep a track of ring system 89 | for system in systems: 90 | if len(ring_atoms.intersection(system)) > 0: 91 | ring_atoms = ring_atoms.union(system) # merge ring system that overlap 92 | else: 93 | cur_system.append(system) 94 | cur_system.append(ring_atoms) 95 | systems = cur_system 96 | return systems 97 | 98 | def _bond_selection_from_max_cuts(self, bond_list: List[int], dist_mat: np.ndarray): 99 | """Select bonds based on maximum number of cuts allowed""" 100 | # for now we are just implementing to 2 max cuts algorithms 101 | if self.MAX_CUTS != 2: 102 | raise ValueError(f"Only MAX_CUTS=2 is supported, got {self.MAX_CUTS}") 103 | 104 | bond_pdist = np.full((len(bond_list), len(bond_list)), -1) 105 | for i in range(len(bond_list)): 106 | for j in range(i, len(bond_list)): 107 | # we get the minimum topological distance between bond to cut 108 | bond_pdist[i, j] = bond_pdist[j, i] = min( 109 | [dist_mat[a1, a2] for a1, a2 in itertools.product(bond_list[i], bond_list[j])] 110 | ) 111 | 112 | masked_bond_pdist = np.ma.masked_less_equal(bond_pdist, self.min_linker_size) 113 | 114 | if self.shortest_linker: 115 | return np.unravel_index(np.ma.argmin(masked_bond_pdist), bond_pdist.shape) 116 | return np.unravel_index(np.ma.argmax(masked_bond_pdist), bond_pdist.shape) 117 | 118 | def _get_bonds_to_cut(self, mol: dm.Mol): 119 | """Get possible bond to cuts 120 | 121 | Args: 122 | mol: input molecule 123 | """ 124 | # use this if you want to enumerate yourself the possible cuts 125 | 126 | ring_systems = self.get_ring_system(mol) 127 | candidate_bonds = [] 128 | ring_query = Chem.rdqueries.IsInRingQueryAtom() 129 | 130 | for query in self.bond_splitters: 131 | bonds = mol.GetSubstructMatches(query, uniquify=True) 132 | cur_unique_bonds = [set(cbond) for cbond in candidate_bonds] 133 | # do not accept bonds part of the same ring system or already known 134 | for b in bonds: 135 | bond_id = mol.GetBondBetweenAtoms(*b).GetIdx() 136 | bond_cut = Chem.GetMolFrags( 137 | Chem.FragmentOnBonds(mol, [bond_id], addDummies=False), asMols=True 138 | ) 139 | can_add = not self.require_ring_system or all( 140 | len(frag.GetAtomsMatchingQuery(ring_query)) > 0 for frag in bond_cut 141 | ) 142 | if can_add and not ( 143 | set(b) in cur_unique_bonds or any(x.issuperset(set(b)) for x in ring_systems) 144 | ): 145 | candidate_bonds.append(b) 146 | return candidate_bonds 147 | 148 | def _fragment_mol(self, mol: dm.Mol, bonds: List[dm.Bond]): 149 | """Fragment molecules on bonds and return head, linker, tail combination 150 | 151 | Args: 152 | mol: input molecule 153 | bonds: list of bonds to cut 154 | """ 155 | tmp = Chem.rdmolops.FragmentOnBonds(mol, [b.GetIdx() for b in bonds]) 156 | _frags = list(Chem.GetMolFrags(tmp, asMols=True)) 157 | # linker is the one with 2 dummy atoms 158 | linker_pos = 0 159 | for pos, _frag in enumerate(_frags): 160 | if sum([at.GetSymbol() == "*" for at in _frag.GetAtoms()]) == 2: 161 | linker_pos = pos 162 | break 163 | linker = _frags.pop(linker_pos) 164 | head, tail = _frags 165 | return (head, linker, tail) 166 | 167 | def _compute_linker_score(self, linker: dm.Mol): 168 | """Compute the score of a linker to help select between linkers""" 169 | 170 | # we need to take into account 171 | # case where we require the linker to have a ring system 172 | # case where we want the linker to be longuest or shortest 173 | 174 | # find shortest path 175 | attach1, attach2, *_ = [at.GetIdx() for at in linker.GetAtoms() if at.GetSymbol() == "*"] 176 | score = len(Chem.rdmolops.GetShortestPath(linker, attach1, attach2)) 177 | ring_query = Chem.rdqueries.IsInRingQueryAtom() 178 | linker_ring_count = len(linker.GetAtomsMatchingQuery(ring_query)) 179 | if self.require_ring_system: 180 | score *= int(linker_ring_count > 0) 181 | if score == 0: 182 | return float("inf") 183 | if not self.shortest_linker: 184 | score = 1 / score 185 | return score 186 | 187 | def __call__(self, mol: Union[dm.Mol, str], expected_head: Union[dm.Mol, str] = None): 188 | """Perform slicing of the input molecule 189 | 190 | Args: 191 | mol: input molecule 192 | expected_head: substructure that should be part of the head. 193 | The small fragment containing this substructure would be kept as head 194 | """ 195 | 196 | mol = dm.to_mol(mol) 197 | # remove salt and solution 198 | mol = dm.keep_largest_fragment(mol) 199 | Chem.rdDepictor.Compute2DCoords(mol) 200 | dist_mat = Chem.rdmolops.GetDistanceMatrix(mol) 201 | 202 | if expected_head is not None: 203 | if isinstance(expected_head, str): 204 | expected_head = dm.to_mol(expected_head) 205 | if not mol.HasSubstructMatch(expected_head): 206 | if self.verbose: 207 | logger.info( 208 | "Expected head was provided, but does not match molecules. It will be ignored" 209 | ) 210 | expected_head = None 211 | 212 | candidate_bonds = self._get_bonds_to_cut(mol) 213 | 214 | # we have all the candidate bonds we can cut 215 | # now we need to pick the most plausible bonds 216 | selected_bonds = [mol.GetBondBetweenAtoms(a1, a2) for (a1, a2) in candidate_bonds] 217 | 218 | # CASE 1: no bond to cut ==> only head 219 | if len(selected_bonds) == 0: 220 | return (mol, None, None) 221 | 222 | # CASE 2: only one bond ==> linker is empty 223 | if len(selected_bonds) == 1: 224 | # there is not linker 225 | tmp = Chem.rdmolops.FragmentOnBonds(mol, [b.GetIdx() for b in selected_bonds]) 226 | head, tail = Chem.GetMolFrags(tmp, asMols=True) 227 | return (head, None, tail) 228 | 229 | # CASE 3a: we select the most plausible bond to cut on ourselves 230 | if expected_head is None: 231 | choice = self._bond_selection_from_max_cuts(candidate_bonds, dist_mat) 232 | selected_bonds = [selected_bonds[c] for c in choice] 233 | return self._fragment_mol(mol, selected_bonds) 234 | 235 | # CASE 3b: slightly more complex case where we want the head to be the smallest graph containing the 236 | # provided substructure 237 | bond_combination = list(itertools.combinations(selected_bonds, self.MAX_CUTS)) 238 | bond_score = float("inf") 239 | linker_score = float("inf") 240 | head, linker, tail = (None, None, None) 241 | for split_bonds in bond_combination: 242 | cur_head, cur_linker, cur_tail = self._fragment_mol(mol, split_bonds) 243 | # head can also be tail 244 | head_match = cur_head.GetSubstructMatch(expected_head) 245 | tail_match = cur_tail.GetSubstructMatch(expected_head) 246 | if not head_match and not tail_match: 247 | continue 248 | if not head_match and tail_match: 249 | cur_head, cur_tail = cur_tail, cur_head 250 | cur_bond_score = cur_head.GetNumHeavyAtoms() 251 | # compute linker score 252 | cur_linker_score = self._compute_linker_score(cur_linker) 253 | if (cur_bond_score < bond_score) or ( 254 | cur_bond_score < self._BOND_BUFFER + bond_score and cur_linker_score < linker_score 255 | ): 256 | head, linker, tail = cur_head, cur_linker, cur_tail 257 | bond_score = cur_bond_score 258 | linker_score = cur_linker_score 259 | 260 | return (head, linker, tail) 261 | 262 | @classmethod 263 | def link_fragments( 264 | cls, linker: Union[dm.Mol, str], head: Union[dm.Mol, str], tail: Union[dm.Mol, str] 265 | ): 266 | """Link fragments together using the provided linker 267 | 268 | Args: 269 | linker: linker to use 270 | head: head fragment 271 | tail: tail fragment 272 | """ 273 | if isinstance(linker, dm.Mol): 274 | linker = dm.to_smiles(linker) 275 | linker = standardize_attach(linker) 276 | reactants = [dm.to_mol(head), dm.to_mol(tail), dm.to_mol(linker)] 277 | return dm.reactions.apply_reaction( 278 | cls._MERGING_RXN, reactants, as_smiles=True, sanitize=True, product_index=0 279 | ) 280 | 281 | 282 | @contextmanager 283 | def attr_as(obj: Any, field: str, value: Any): 284 | """Temporary replace the value of an object 285 | 286 | Args: 287 | obj: object to temporary patch 288 | field: name of the key to change 289 | value: value of key to be temporary changed 290 | """ 291 | old_value = getattr(obj, field, None) 292 | setattr(obj, field, value) 293 | yield 294 | with suppress(TypeError): 295 | setattr(obj, field, old_value) 296 | 297 | 298 | def _selective_add_hs(mol: dm.Mol, fraction_hs: Optional[bool] = None): 299 | """Custom addition of hydrogens to a molecule 300 | This version of hydrogen bond adding only at max 1 hydrogen per atom 301 | 302 | Args: 303 | mol: molecule to split 304 | fraction_hs: proportion of random atom to which we will add explicit hydrogens 305 | """ 306 | 307 | carbon_with_implicit_atoms = mol.GetSubstructMatches(__implicit_carbon_query, uniquify=True) 308 | carbon_with_implicit_atoms = [x[0] for x in carbon_with_implicit_atoms] 309 | carbon_with_implicit_atoms = list(set(carbon_with_implicit_atoms)) 310 | # we get a proportion of the carbon we can extend 311 | if fraction_hs is not None and fraction_hs > 0: 312 | fraction_hs = np.ceil(fraction_hs * len(carbon_with_implicit_atoms)) 313 | fraction_hs = int(np.clip(fraction_hs, 1, len(carbon_with_implicit_atoms))) 314 | carbon_with_implicit_atoms = random.sample(carbon_with_implicit_atoms, k=fraction_hs) 315 | carbon_with_implicit_atoms = [int(x) for x in carbon_with_implicit_atoms] 316 | emol = EditableMol(mol) 317 | for atom_id in carbon_with_implicit_atoms: 318 | h_atom = emol.AddAtom(Atom("H")) 319 | emol.AddBond(atom_id, h_atom, dm.SINGLE_BOND) 320 | return emol.GetMol() 321 | 322 | 323 | @py_random_state("seed") 324 | def mol_partition( 325 | mol: dm.Mol, query: Optional[dm.Mol] = None, seed: Optional[int] = None, **kwargs: Any 326 | ): 327 | """Partition a molecule into fragments using a bond query 328 | 329 | Args: 330 | mol: molecule to split 331 | query: bond query to use for splitting 332 | seed: random seed 333 | kwargs: additional arguments to pass to the partitioning algorithm 334 | 335 | """ 336 | resolution = kwargs.get("resolution", 1.0) 337 | threshold = kwargs.get("threshold", 1e-7) 338 | weight = kwargs.get("weight", "weight") 339 | 340 | if query is None: 341 | query = __mmpa_query 342 | 343 | G = dm.graph.to_graph(mol) 344 | bond_partition = [ 345 | tuple(sorted(match)) for match in mol.GetSubstructMatches(query, uniquify=True) 346 | ] 347 | 348 | def get_relevant_edges(e1, e2): 349 | return tuple(sorted([e1, e2])) not in bond_partition 350 | 351 | subgraphs = nx.subgraph_view(G, filter_edge=get_relevant_edges) 352 | 353 | partition = [{u} for u in G.nodes()] 354 | inner_partition = sorted(nx.connected_components(subgraphs), key=lambda x: min(x)) 355 | mod = nx.algorithms.community.modularity( 356 | G, inner_partition, resolution=resolution, weight=weight 357 | ) 358 | is_directed = G.is_directed() 359 | graph = G.__class__() 360 | graph.add_nodes_from(G) 361 | graph.add_weighted_edges_from(G.edges(data=weight, default=1)) 362 | graph = nx.algorithms.community.louvain._gen_graph(graph, inner_partition) 363 | m = graph.size(weight="weight") 364 | partition, inner_partition, improvement = nx.algorithms.community.louvain._one_level( 365 | graph, m, inner_partition, resolution, is_directed, seed 366 | ) 367 | improvement = True 368 | while improvement: 369 | # gh-5901 protect the sets in the yielded list from further manipulation here 370 | yield [s.copy() for s in partition] 371 | new_mod = nx.algorithms.community.modularity( 372 | graph, inner_partition, resolution=resolution, weight="weight" 373 | ) 374 | if new_mod - mod <= threshold: 375 | return 376 | mod = new_mod 377 | graph = nx.algorithms.community.louvain._gen_graph(graph, inner_partition) 378 | partition, inner_partition, improvement = nx.algorithms.community.louvain._one_level( 379 | graph, m, partition, resolution, is_directed, seed 380 | ) 381 | 382 | 383 | def find_partition_edges(G: nx.Graph, partition: List[List]) -> List[Tuple]: 384 | """ 385 | Find the edges connecting the subgraphs in a given partition of a graph. 386 | 387 | Args: 388 | G (networkx.Graph): The original graph. 389 | partition (list of list of nodes): The partition of the graph where each element is a list of nodes representing a subgraph. 390 | 391 | Returns: 392 | list: A list of edges connecting the subgraphs in the partition. 393 | """ 394 | partition_edges = [] 395 | for subgraph1, subgraph2 in combinations(partition, 2): 396 | edges = nx.edge_boundary(G, subgraph1, subgraph2) 397 | partition_edges.extend(edges) 398 | return partition_edges 399 | 400 | 401 | def fragment_aware_spliting(mol: dm.Mol, fraction_hs: Optional[bool] = None, **kwargs: Any): 402 | """Custom splitting algorithm for dataset building. 403 | 404 | This slicing strategy will cut any bond including bonding with hydrogens 405 | However, only one cut per atom is allowed 406 | 407 | Args: 408 | mol: molecule to split 409 | fraction_hs: proportion of random atom to which we will add explicit hydrogens 410 | kwargs: additional arguments to pass to the partitioning algorithm 411 | """ 412 | random.seed(kwargs.get("seed", 1)) 413 | mol = dm.to_mol(mol, remove_hs=False) 414 | mol = _selective_add_hs(mol, fraction_hs=fraction_hs) 415 | graph = dm.graph.to_graph(mol) 416 | d = mol_partition(mol, **kwargs) 417 | q = deque(d) 418 | partition = q.pop() 419 | return find_partition_edges(graph, partition) 420 | 421 | 422 | def convert_to_safe( 423 | mol: dm.Mol, 424 | canonical: bool = False, 425 | randomize: bool = False, 426 | seed: Optional[int] = 1, 427 | slicer: str = "brics", 428 | split_fragment: bool = True, 429 | fraction_hs: bool = None, 430 | resolution: Optional[float] = 0.5, 431 | ): 432 | """Convert a molecule to a safe representation 433 | 434 | Args: 435 | mol: molecule to convert 436 | canonical: whether to use canonical encoding 437 | randomize: whether to randomize the encoding 438 | seed: random seed 439 | slicer: the slicer to use for fragmentation 440 | split_fragment: whether to split fragments 441 | fraction_hs: proportion of random atom to which we will add explicit hydrogens 442 | resolution: resolution for the partitioning algorithm 443 | seed: random seed 444 | """ 445 | x = None 446 | try: 447 | x = sf.encode(mol, canonical=canonical, randomize=randomize, slicer=slicer, seed=seed) 448 | except sf.SAFEFragmentationError: 449 | if split_fragment: 450 | if "." in mol: 451 | return None 452 | try: 453 | x = sf.encode( 454 | mol, 455 | canonical=False, 456 | randomize=randomize, 457 | seed=seed, 458 | slicer=partial( 459 | fragment_aware_spliting, 460 | fraction_hs=fraction_hs, 461 | resolution=resolution, 462 | seed=seed, 463 | ), 464 | ) 465 | except (sf.SAFEEncodeError, sf.SAFEFragmentationError): 466 | # logger.exception(e) 467 | return x 468 | # we need to resplit using attachment point but here we are only adding 469 | except sf.SAFEEncodeError: 470 | return x 471 | return x 472 | 473 | 474 | def compute_side_chains(mol: dm.Mol, core: dm.Mol, label_by_index: bool = False): 475 | """Compute the side chain of a molecule given a core 476 | 477 | !!! note "Finding the side chains" 478 | The algorithm to find the side chains from core assumes that the core we get as input has attachment points. 479 | Those attachment points are never considered as part of the query, rather they are used to define the attachment points 480 | on the side chains. Removing the attachment points from the core is exactly the same as keeping them. 481 | 482 | ```python 483 | mol = "CC1=C(C(=NO1)C2=CC=CC=C2Cl)C(=O)NC3C4N(C3=O)C(C(S4)(C)C)C(=O)O" 484 | core0 = "CC1(C)CN2C(CC2=O)S1" 485 | core1 = "CC1(C)SC2C(-*)C(=O)N2C1-*" 486 | core2 = "CC1N2C(SC1(C)C)C(N)C2=O" 487 | side_chain = compute_side_chain(core=core0, mol=mol) 488 | dm.to_image([side_chain, core0, mol]) 489 | ``` 490 | Therefore on the above, core0 and core1 are equivalent for the molecule `mol`, but core2 is not. 491 | 492 | Args: 493 | mol: molecule to split 494 | core: core to use for deriving the side chains 495 | """ 496 | 497 | if isinstance(mol, str): 498 | mol = dm.to_mol(mol) 499 | if isinstance(core, str): 500 | core = dm.to_mol(core) 501 | core_query_param = AdjustQueryParameters() 502 | core_query_param.makeDummiesQueries = True 503 | core_query_param.adjustDegree = False 504 | core_query_param.aromatizeIfPossible = True 505 | core_query_param.makeBondsGeneric = False 506 | core_query = AdjustQueryProperties(core, core_query_param) 507 | return ReplaceCore( 508 | mol, core_query, labelByIndex=label_by_index, replaceDummies=False, requireDummyMatch=False 509 | ) 510 | 511 | 512 | def list_individual_attach_points(mol: dm.Mol, depth: Optional[int] = None): 513 | """List all individual attachement points. 514 | 515 | We do not allow multiple attachment points per substitution position. 516 | 517 | Args: 518 | mol: molecule for which we need to open the attachment points 519 | 520 | """ 521 | ATTACHING_RXN = ReactionFromSmarts("[*;h;!$([*][#0]):1]>>[*:1][*]") 522 | mols = [mol] 523 | curated_prods = set() 524 | num_attachs = len(mol.GetSubstructMatches(dm.from_smarts("[*;h:1]"), uniquify=True)) 525 | depth = depth or 1 526 | depth = min(max(depth, 1), num_attachs) 527 | while depth > 0: 528 | prods = set() 529 | for mol in mols: 530 | mol = dm.to_mol(mol) 531 | for p in ATTACHING_RXN.RunReactants((mol,)): 532 | try: 533 | m = dm.sanitize_mol(p[0]) 534 | sm = dm.to_smiles(m, canonical=True) 535 | sm = dm.reactions.add_brackets_to_attachment_points(sm) 536 | prods.add(dm.reactions.convert_attach_to_isotope(sm, as_smiles=True)) 537 | except Exception as e: 538 | logger.error(e) 539 | curated_prods.update(prods) 540 | mols = prods 541 | depth -= 1 542 | return list(curated_prods) 543 | 544 | 545 | def filter_by_substructure_constraints( 546 | sequences: List[Union[str, dm.Mol]], substruct: Union[str, dm.Mol], n_jobs: int = -1 547 | ): 548 | """Check whether the input substructures are present in each of the molecule in the sequences 549 | 550 | Args: 551 | sequences: list of molecules to validate 552 | substruct: substructure to use as query 553 | n_jobs: number of jobs to use for parallelization 554 | 555 | """ 556 | 557 | if isinstance(substruct, str): 558 | substruct = standardize_attach(substruct) 559 | substruct = dm.from_smarts(substruct) 560 | 561 | def _check_match(mol): 562 | with suppress(Exception): 563 | mol = dm.to_mol(mol) 564 | return mol.HasSubstructMatch(substruct) 565 | return False 566 | 567 | matches = dm.parallelized(_check_match, sequences, n_jobs=n_jobs) 568 | return list(compress(sequences, matches)) 569 | 570 | 571 | def standardize_attach(inputs: str, standard_attach: str = "[*]"): 572 | """Standardize the attachment points of a molecule 573 | 574 | Args: 575 | inputs: input molecule 576 | standard_attach: standard attachment point to use 577 | """ 578 | 579 | for attach_regex in _SMILES_ATTACHMENT_POINTS: 580 | inputs = re.sub(attach_regex, standard_attach, inputs) 581 | return inputs 582 | -------------------------------------------------------------------------------- /safe/viz.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from typing import Any, Optional, Tuple, Union 3 | 4 | import datamol as dm 5 | import matplotlib.pyplot as plt 6 | 7 | import safe as sf 8 | 9 | 10 | def to_image( 11 | safe_str: str, 12 | fragments: Optional[Union[str, dm.Mol]] = None, 13 | legend: Union[str, None] = None, 14 | mol_size: Union[Tuple[int, int], int] = (300, 300), 15 | use_svg: Optional[bool] = True, 16 | highlight_mode: Optional[str] = "lasso", 17 | highlight_bond_width_multiplier: int = 12, 18 | **kwargs: Any, 19 | ): 20 | """Display a safe string by highlighting the fragments that make it. 21 | 22 | Args: 23 | safe_str: the safe string to display 24 | fragments: list of fragment to highlight on the molecules. If None, will use safe decomposition of the molecule. 25 | legend: A string to use as the legend under the molecule. 26 | mol_size: The size of the image to be returned 27 | use_svg: Whether to return an svg or png image 28 | highlight_mode: the highlight mode to use. One of ["lasso", "fill", "color"]. If None, no highlight will be shown 29 | highlight_bond_width_multiplier: the multiplier to use for the bond width when using the 'fill' mode 30 | **kwargs: Additional arguments to pass to the drawing function. See RDKit 31 | documentation related to `MolDrawOptions` for more details at 32 | https://www.rdkit.org/docs/source/rdkit.Chem.Draw.rdMolDraw2D.html. 33 | 34 | """ 35 | 36 | kwargs["legends"] = legend 37 | kwargs["mol_size"] = mol_size 38 | kwargs["use_svg"] = use_svg 39 | if highlight_bond_width_multiplier is not None: 40 | kwargs["highlightBondWidthMultiplier"] = highlight_bond_width_multiplier 41 | 42 | if highlight_mode == "color": 43 | kwargs["continuousHighlight"] = False 44 | kwargs["circleAtoms"] = kwargs.get("circleAtoms", False) or False 45 | 46 | if isinstance(fragments, (str, dm.Mol)): 47 | fragments = [fragments] 48 | 49 | if fragments is None and highlight_mode is not None: 50 | fragments = [ 51 | sf.decode(x, as_mol=False, remove_dummies=False, ignore_errors=False) 52 | for x in safe_str.split(".") 53 | ] 54 | elif fragments and len(fragments) > 0: 55 | parsed_fragments = [] 56 | for fg in fragments: 57 | if isinstance(fg, str) and dm.to_mol(fg) is None: 58 | fg = sf.decode(fg, as_mol=False, remove_dummies=False, ignore_errors=False) 59 | parsed_fragments.append(fg) 60 | fragments = parsed_fragments 61 | else: 62 | fragments = [] 63 | mol = dm.to_mol(safe_str, remove_hs=False) 64 | cm = plt.get_cmap("gist_rainbow") 65 | current_colors = [cm(1.0 * i / len(fragments)) for i in range(len(fragments))] 66 | 67 | if highlight_mode == "lasso": 68 | return dm.viz.lasso_highlight_image(mol, fragments, **kwargs) 69 | 70 | atom_indices = [] 71 | bond_indices = [] 72 | atom_colors = {} 73 | bond_colors = {} 74 | 75 | for i, frag in enumerate(fragments): 76 | frag = dm.from_smarts(frag) 77 | atom_matches, bond_matches = dm.substructure_matching_bonds(mol, frag) 78 | atom_matches = list(itertools.chain(*atom_matches)) 79 | bond_matches = list(itertools.chain(*bond_matches)) 80 | atom_indices.extend(atom_matches) 81 | bond_indices.extend(bond_matches) 82 | atom_colors.update({x: current_colors[i] for x in atom_matches}) 83 | bond_colors.update({x: current_colors[i] for x in bond_matches}) 84 | 85 | return dm.viz.to_image( 86 | mol, 87 | highlight_atom=[atom_indices], 88 | highlight_bond=[bond_indices], 89 | highlightAtomColors=[atom_colors], 90 | highlightBondColors=[bond_colors], 91 | **kwargs, 92 | ) 93 | -------------------------------------------------------------------------------- /tests/test_hgf_load.py: -------------------------------------------------------------------------------- 1 | from safe.sample import SAFEDesign 2 | from safe.tokenizer import SAFETokenizer 3 | from safe.trainer.model import SAFEDoubleHeadsModel 4 | 5 | 6 | def test_load_default_safe_model(): 7 | model = SAFEDoubleHeadsModel.from_pretrained("datamol-io/safe-gpt") 8 | assert model is not None 9 | assert isinstance(model, SAFEDoubleHeadsModel) 10 | 11 | 12 | def test_load_default_safe_tokenizer(): 13 | tokenizer = SAFETokenizer.from_pretrained("datamol-io/safe-gpt") 14 | assert isinstance(tokenizer, SAFETokenizer) 15 | 16 | 17 | def test_check_molecule_sampling(): 18 | designer = SAFEDesign.load_default(verbose=True) 19 | generated = designer.de_novo_generation(sanitize=True, n_samples_per_trial=10) 20 | assert len(generated) > 0 21 | -------------------------------------------------------------------------------- /tests/test_import.py: -------------------------------------------------------------------------------- 1 | def test_import(): 2 | import safe 3 | -------------------------------------------------------------------------------- /tests/test_notebooks.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | import nbformat 4 | import pytest 5 | from nbconvert.preprocessors.execute import ExecutePreprocessor 6 | 7 | ROOT_DIR = pathlib.Path(__file__).parent.resolve() 8 | 9 | TUTORIALS_DIR = ROOT_DIR.parent / "docs" / "tutorials" 10 | DISABLE_NOTEBOOKS = [] 11 | NOTEBOOK_PATHS = sorted(TUTORIALS_DIR.glob("*.ipynb")) 12 | NOTEBOOK_PATHS = list(filter(lambda x: x.name not in DISABLE_NOTEBOOKS, NOTEBOOK_PATHS)) 13 | 14 | # Discard some notebooks 15 | NOTEBOOKS_TO_DISCARD = ["extracting-representation-molfeat.ipynb", "load-from-wandb.ipynb"] 16 | NOTEBOOK_PATHS = list(filter(lambda x: x.name not in NOTEBOOKS_TO_DISCARD, NOTEBOOK_PATHS)) 17 | 18 | 19 | @pytest.mark.parametrize("nb_path", NOTEBOOK_PATHS, ids=[str(n.name) for n in NOTEBOOK_PATHS]) 20 | def test_notebook(nb_path): 21 | # Setup and configure the processor to execute the notebook 22 | ep = ExecutePreprocessor(timeout=600, kernel_name="python3") 23 | 24 | # Open the notebook 25 | with open(nb_path) as f: 26 | nb = nbformat.read(f, as_version=nbformat.NO_CONVERT) 27 | 28 | # Execute the notebook 29 | ep.preprocess(nb, {"metadata": {"path": TUTORIALS_DIR}}) 30 | -------------------------------------------------------------------------------- /tests/test_safe.py: -------------------------------------------------------------------------------- 1 | import datamol as dm 2 | import numpy as np 3 | import pytest 4 | 5 | import safe 6 | 7 | 8 | def test_safe_encoding(): 9 | celecoxib = "Cc1ccc(-c2cc(C(F)(F)F)nn2-c2ccc(S(N)(=O)=O)cc2)cc1" 10 | expected_encodings = "c13ccc(S(N)(=O)=O)cc1.Cc1ccc4cc1.c14cc5nn13.C5(F)(F)F" 11 | safe_celecoxib = safe.encode(celecoxib, canonical=True) 12 | dec_celecoxib = safe.decode(safe_celecoxib) 13 | assert safe_celecoxib.count(".") == 3 # 3 fragments 14 | # we compare length since digits can be random 15 | assert len(safe_celecoxib) == len(expected_encodings) 16 | assert dm.same_mol(celecoxib, safe_celecoxib) 17 | assert dm.same_mol(celecoxib, dec_celecoxib) 18 | 19 | 20 | def test_safe_fragment_randomization(): 21 | celecoxib = "Cc1ccc(-c2cc(C(F)(F)F)nn2-c2ccc(S(N)(=O)=O)cc2)cc1" 22 | safe_celecoxib = safe.encode(celecoxib) 23 | fragments = safe_celecoxib.split(".") 24 | randomized_fragment_safe_str = np.random.permutation(fragments).tolist() 25 | randomized_fragment_safe_str = ".".join(randomized_fragment_safe_str) 26 | assert dm.same_mol(celecoxib, randomized_fragment_safe_str) 27 | 28 | 29 | def test_randomized_encoder(): 30 | celecoxib = "Cc1ccc(-c2cc(C(F)(F)F)nn2-c2ccc(S(N)(=O)=O)cc2)cc1" 31 | output = set() 32 | for i in range(5): 33 | out = safe.encode(celecoxib, canonical=False, randomize=True, seed=i) 34 | output.add(out) 35 | assert len(output) > 1 36 | 37 | 38 | def test_custom_encoder(): 39 | smart_slicer = ["[r]-;!@[r]"] 40 | celecoxib = "Cc1ccc(-c2cc(C(F)(F)F)nn2-c2ccc(S(N)(=O)=O)cc2)cc1" 41 | safe_str = safe.encode(celecoxib, canonical=True, slicer=smart_slicer) 42 | assert dm.same_mol(celecoxib, safe_str) 43 | 44 | 45 | def test_safe_decoder(): 46 | celecoxib = "Cc1ccc(-c2cc(C(F)(F)F)nn2-c2ccc(S(N)(=O)=O)cc2)cc1" 47 | safe_str = safe.encode(celecoxib) 48 | fragments = safe_str.split(".") 49 | decoded_fragments = [safe.decode(x, fix=True) for x in fragments] 50 | assert [dm.to_mol(x) for x in fragments] == [None] * len(fragments) 51 | assert all(x is not None for x in decoded_fragments) 52 | 53 | 54 | def test_rdkit_smiles_parser_issues(): 55 | # see https://github.com/datamol-io/safe/issues/22 56 | input_sm = r"C(=C/c1ccccc1)\CCc1ccccc1" 57 | slicer = "brics" 58 | safe_obj = safe.SAFEConverter(slicer=slicer, require_hs=False) 59 | with dm.without_rdkit_log(): 60 | failing_encoded = safe_obj.encoder( 61 | input_sm, 62 | canonical=True, 63 | randomize=False, 64 | rdkit_safe=False, 65 | ) 66 | working_encoded = safe_obj.encoder( 67 | input_sm, 68 | canonical=True, 69 | randomize=False, 70 | rdkit_safe=True, 71 | ) 72 | working_decoded = safe.decode(working_encoded) 73 | working_no_stero = dm.remove_stereochemistry(dm.to_mol(input_sm)) 74 | input_mol = dm.remove_stereochemistry(dm.to_mol(working_decoded)) 75 | assert safe.decode(failing_encoded) is None 76 | assert working_decoded is not None 77 | assert dm.same_mol(working_no_stero, input_mol) 78 | 79 | 80 | @pytest.mark.parametrize( 81 | "input_sm", 82 | [ 83 | "O=C(CN1CC[NH2+]CC1)N1CCCCC1", 84 | "[NH3+]Cc1ccccc1", 85 | "c1cc2c(cc1[C@@H]1CCC[NH2+]1)OCCO2", 86 | "[13C]1CCCCC1C[238U]C[NH3+]", 87 | "COC[CH2:1][CH2:2]O[CH:2]C[OH:3]", 88 | ], 89 | ) 90 | def test_bracket_smiles_issues(input_sm): 91 | slicer = "brics" 92 | safe_obj = safe.SAFEConverter(slicer=slicer, require_hs=False) 93 | fragments = [] 94 | with dm.without_rdkit_log(): 95 | safe_str = safe_obj.encoder( 96 | input_sm, 97 | canonical=True, 98 | ) 99 | for fragment in safe_str.split("."): 100 | f = safe_obj.decoder( 101 | fragment, 102 | as_mol=False, 103 | canonical=True, 104 | fix=True, 105 | remove_dummies=True, 106 | remove_added_hs=True, 107 | ) 108 | fragments.append(f) 109 | input_mol = dm.to_mol(input_sm) 110 | assert safe.decode(safe_str) is not None 111 | assert dm.same_mol(dm.to_mol(safe_str), input_mol) 112 | assert None not in fragments 113 | 114 | 115 | def test_fused_ring_issue(): 116 | FUSED_RING_LIST = [ 117 | "[H][C@@]12CC[C@@]3(CCC(=O)O3)[C@@]1(C)CC[C@@]1([H])[C@@]2([H])[C@@]([H])(CC2=CC(=O)CC[C@]12C)SC(C)=O", 118 | "[H][C@@]12C[C@H](C)[C@](OC(=O)CC)(C(=O)COC(=O)CC)[C@@]1(C)C[C@H](O)[C@@]1(Cl)[C@@]2([H])CCC2=CC(=O)C=C[C@]12C", 119 | "[H][C@@]12CC[C@@](O)(C#C)[C@@]1(CC)CC[C@]1([H])[C@@]3([H])CCC(=O)C=C3CC[C@@]21[H]", 120 | ] 121 | for fused_ring in FUSED_RING_LIST: 122 | output_string = safe.decode(safe.encode(fused_ring)) 123 | assert dm.same_mol(fused_ring, output_string) 124 | 125 | 126 | def test_stereochemistry_issue(): 127 | STEREO_MOL_LIST = [ 128 | "CC(=C\\c1ccccc1)/N=C/C(=O)O", 129 | "CC(=C/c1ccccc1)/N=C/C(=O)O", 130 | "CC(=C\\c1ccccc1)/N=C\\C(=O)O", 131 | "CC(=C/c1ccccc1)/N=C\\C(=O)O", 132 | "CC(=Cc1ccccc1)N=CC(=O)O", 133 | "Cc1ccc(-n2c(C)cc(/C=N/Nc3ccc([N+](=O)[O-])cn3)c2C)c(C)c1", 134 | "Cc1ccc(-n2c(C)cc(/C=N\\Nc3ccc([N+](=O)[O-])cn3)c2C)c(C)c1", 135 | ] 136 | for mol in STEREO_MOL_LIST: 137 | output_string = safe.encode(mol, ignore_stereo=False, slicer="rotatable") 138 | assert dm.same_mol(mol, output_string) 139 | 140 | # now let's test failure case where we fail because we split on a double bond 141 | output = safe.encode(STEREO_MOL_LIST[0], ignore_stereo=False, slicer="brics") 142 | assert dm.same_mol(STEREO_MOL_LIST[0], output) is False 143 | same_stereo = [dm.remove_stereochemistry(dm.to_mol(x)) for x in [output, STEREO_MOL_LIST[0]]] 144 | assert dm.same_mol(same_stereo[0], same_stereo[1]) 145 | 146 | # check if we ignore the stereo 147 | output = safe.encode(STEREO_MOL_LIST[0], ignore_stereo=True, slicer="brics") 148 | assert dm.same_mol(dm.remove_stereochemistry(dm.to_mol(STEREO_MOL_LIST[0])), output) 149 | --------------------------------------------------------------------------------