├── .github
├── CODEOWNERS
├── CODE_OF_CONDUCT.md
├── PULL_REQUEST_TEMPLATE.md
├── SECURITY.md
└── workflows
│ ├── code-check.yml
│ ├── doc.yml
│ ├── release.yml
│ └── test.yml
├── .gitignore
├── DATA_LICENSE
├── LICENSE
├── README.md
├── docs
├── api
│ ├── safe.md
│ ├── safe.models.md
│ └── safe.viz.md
├── assets
│ ├── js
│ │ └── google-analytics.js
│ ├── safe-construction.svg
│ └── safe-tasks.svg
├── cli.md
├── data_license.md
├── index.md
├── license.md
└── tutorials
│ ├── design-with-safe.ipynb
│ ├── extracting-representation-molfeat.ipynb
│ ├── getting-started.ipynb
│ ├── how-it-works.ipynb
│ └── load-from-wandb.ipynb
├── env.yml
├── expts
├── config
│ └── accelerate.yaml
├── scripts
│ ├── slurm-data-build.sh
│ ├── slurm-notebook.sh
│ ├── slurm-tokenizer-train-custom.sh
│ ├── slurm-tokenizer-train-small.sh
│ ├── slurm-tokenizer-train.sh
│ ├── train-small.sh
│ └── train.sh
└── tokenizer
│ ├── _tokenizer-custom-mini-test.json
│ └── tokenizer-custom.json
├── mkdocs.yml
├── pyproject.toml
├── safe
├── __init__.py
├── _exception.py
├── _pattern.py
├── converter.py
├── io.py
├── sample.py
├── tokenizer.py
├── trainer
│ ├── __init__.py
│ ├── cli.py
│ ├── collator.py
│ ├── configs
│ │ ├── __init__.py
│ │ └── default_config.json
│ ├── data_utils.py
│ ├── model.py
│ └── trainer_utils.py
├── utils.py
└── viz.py
└── tests
├── test_hgf_load.py
├── test_import.py
├── test_notebooks.py
└── test_safe.py
/.github/CODEOWNERS:
--------------------------------------------------------------------------------
1 | * @maclandrol
2 |
--------------------------------------------------------------------------------
/.github/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Microsoft Open Source Code of Conduct
2 |
3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
4 |
5 | Resources:
6 |
7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
10 |
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | ## Changelogs
2 |
3 | - _enumerate the changes of that PR._
4 |
5 | ---
6 |
7 | _Checklist:_
8 |
9 | - [ ] _Add tests to cover the fixed bug(s) or the new introduced feature(s) (if appropriate)._
10 | - [ ] _Update the API documentation if a new function is added, or an existing one is deleted. Eventually consider making a new tutorial for new features._
11 | - [ ] _Write concise and explanatory changelogs below._
12 | - [ ] _If possible, assign one of the following labels to the PR: `feature`, `fix` or `test` (or ask a maintainer to do it for you)._
13 |
14 | ---
15 |
16 | _discussion related to that PR_
--------------------------------------------------------------------------------
/.github/SECURITY.md:
--------------------------------------------------------------------------------
1 | # Security Policy
2 |
3 | Please report any security-related issues directly to emmanuel.noutahi@hotmail.ca
4 |
--------------------------------------------------------------------------------
/.github/workflows/code-check.yml:
--------------------------------------------------------------------------------
1 | name: code-check
2 |
3 | on:
4 | push:
5 | branches:
6 | - "main"
7 | pull_request:
8 | branches:
9 | - "*"
10 |
11 | jobs:
12 | python-format-black:
13 | name: Python lint [black]
14 | runs-on: ubuntu-latest
15 | steps:
16 | - name: Checkout the code
17 | uses: actions/checkout@v3
18 |
19 | - name: Set up Python
20 | uses: actions/setup-python@v4
21 | with:
22 | python-version: "3.10"
23 |
24 | - name: Install black
25 | run: |
26 | pip install black>=24
27 |
28 | - name: Lint
29 | run: black --check .
30 |
31 | python-lint-ruff:
32 | name: Python lint [ruff]
33 | runs-on: ubuntu-latest
34 | steps:
35 | - name: Checkout the code
36 | uses: actions/checkout@v3
37 |
38 | - name: Set up Python
39 | uses: actions/setup-python@v4
40 | with:
41 | python-version: "3.10"
42 |
43 | - name: Install ruff
44 | run: |
45 | pip install ruff
46 |
47 | - name: Lint
48 | run: ruff check .
49 |
--------------------------------------------------------------------------------
/.github/workflows/doc.yml:
--------------------------------------------------------------------------------
1 | name: doc
2 |
3 | on:
4 | push:
5 | branches: ["main"]
6 |
7 | # Prevent doc actions on `main` from conflicting with each other.
8 | concurrency:
9 | group: doc-${{ github.ref }}
10 | cancel-in-progress: true
11 |
12 | jobs:
13 | doc:
14 | runs-on: "ubuntu-latest"
15 | timeout-minutes: 30
16 |
17 | defaults:
18 | run:
19 | shell: bash -l {0}
20 |
21 | steps:
22 | - name: Checkout the code
23 | uses: actions/checkout@v3
24 |
25 | - name: Setup mamba
26 | uses: mamba-org/setup-micromamba@v1
27 | with:
28 | environment-file: env.yml
29 | environment-name: my_env
30 | cache-environment: true
31 | cache-downloads: true
32 |
33 | - name: Install library
34 | run: python -m pip install --no-deps .
35 |
36 | - name: Configure git
37 | run: |
38 | git config --global user.name "${GITHUB_ACTOR}"
39 | git config --global user.email "${GITHUB_ACTOR}@users.noreply.github.com"
40 |
41 | - name: Deploy the doc
42 | run: |
43 | echo "Get the gh-pages branch"
44 | git fetch origin gh-pages
45 |
46 | echo "Build and deploy the doc on main"
47 | mike deploy --push main
48 |
--------------------------------------------------------------------------------
/.github/workflows/release.yml:
--------------------------------------------------------------------------------
1 | name: release
2 |
3 | on:
4 | workflow_dispatch:
5 | inputs:
6 | release-version:
7 | description: "A valid Semver version string"
8 | required: true
9 |
10 | permissions:
11 | contents: write
12 | pull-requests: write
13 |
14 | jobs:
15 | release:
16 | # Do not release if not triggered from the default branch
17 | if: github.ref == format('refs/heads/{0}', github.event.repository.default_branch)
18 |
19 | runs-on: ubuntu-latest
20 | timeout-minutes: 30
21 |
22 | defaults:
23 | run:
24 | shell: bash -l {0}
25 |
26 | steps:
27 | - name: Checkout the code
28 | uses: actions/checkout@v3
29 |
30 | - name: Setup mamba
31 | uses: mamba-org/setup-micromamba@v1
32 | with:
33 | environment-file: env.yml
34 | environment-name: my_env
35 | cache-environment: true
36 | cache-downloads: true
37 | create-args: >-
38 | pip
39 | semver
40 | python-build
41 | setuptools_scm
42 |
43 | - name: Check the version is valid semver
44 | run: |
45 | RELEASE_VERSION="${{ inputs.release-version }}"
46 |
47 | {
48 | pysemver check $RELEASE_VERSION
49 | } || {
50 | echo "The version '$RELEASE_VERSION' is not a valid Semver version string."
51 | echo "Please use a valid semver version string. More details at https://semver.org/"
52 | echo "The release process is aborted."
53 | exit 1
54 | }
55 |
56 | - name: Check the version is higher than the latest one
57 | run: |
58 | # Retrieve the git tags first
59 | git fetch --prune --unshallow --tags &> /dev/null
60 |
61 | RELEASE_VERSION="${{ inputs.release-version }}"
62 | LATEST_VERSION=$(git describe --abbrev=0 --tags)
63 |
64 | IS_HIGHER_VERSION=$(pysemver compare $RELEASE_VERSION $LATEST_VERSION)
65 |
66 | if [ "$IS_HIGHER_VERSION" != "1" ]; then
67 | echo "The version '$RELEASE_VERSION' is not higher than the latest version '$LATEST_VERSION'."
68 | echo "The release process is aborted."
69 | exit 1
70 | fi
71 |
72 | - name: Build Changelog
73 | id: github_release
74 | uses: mikepenz/release-changelog-builder-action@v4
75 | env:
76 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
77 | with:
78 | toTag: "main"
79 |
80 | - name: Configure git
81 | run: |
82 | git config --global user.name "${GITHUB_ACTOR}"
83 | git config --global user.email "${GITHUB_ACTOR}@users.noreply.github.com"
84 |
85 | - name: Create and push git tag
86 | env:
87 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
88 | run: |
89 | # Tag the release
90 | git tag -a "${{ inputs.release-version }}" -m "Release version ${{ inputs.release-version }}"
91 |
92 | # Checkout the git tag
93 | git checkout "${{ inputs.release-version }}"
94 |
95 | # Push the modified changelogs
96 | git push origin main
97 |
98 | # Push the tags
99 | git push origin "${{ inputs.release-version }}"
100 |
101 | - name: Install library
102 | run: python -m pip install --no-deps .
103 |
104 | - name: Build the wheel and sdist
105 | run: python -m build --no-isolation
106 |
107 | - name: Publish package to PyPI
108 | uses: pypa/gh-action-pypi-publish@release/v1
109 | with:
110 | password: ${{ secrets.PYPI_API_TOKEN }}
111 | packages-dir: dist/
112 |
113 | - name: Create GitHub Release
114 | uses: softprops/action-gh-release@de2c0eb89ae2a093876385947365aca7b0e5f844
115 | with:
116 | tag_name: ${{ inputs.release-version }}
117 | body: ${{steps.github_release.outputs.changelog}}
118 |
119 | - name: Deploy the doc
120 | run: |
121 | echo "Get the gh-pages branch"
122 | git fetch origin gh-pages
123 |
124 | echo "Build and deploy the doc on ${{ inputs.release-version }}"
125 | mike deploy --push stable
126 | mike deploy --push ${{ inputs.release-version }}
127 |
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | name: test
2 |
3 | on:
4 | push:
5 | branches: ["main"]
6 | tags: ["*"]
7 | pull_request:
8 | branches:
9 | - "*"
10 | - "!gh-pages"
11 | schedule:
12 | - cron: "0 4 * * MON"
13 |
14 | jobs:
15 | test:
16 | strategy:
17 | fail-fast: false
18 | matrix:
19 | python-version: ["3.11"]
20 |
21 | runs-on: "ubuntu-latest"
22 | timeout-minutes: 30
23 |
24 | defaults:
25 | run:
26 | shell: bash -l {0}
27 |
28 | name: python=${{ matrix.python-version }}
29 |
30 | steps:
31 | - name: Checkout the code
32 | uses: actions/checkout@v3
33 |
34 | - name: Setup mamba
35 | uses: mamba-org/setup-micromamba@v1
36 | with:
37 | environment-file: env.yml
38 | environment-name: my_env
39 | cache-environment: true
40 | cache-downloads: true
41 | create-args: >-
42 | python=${{ matrix.python-version }}
43 |
44 | - name: Install library
45 | run: python -m pip install --no-deps -e .
46 |
47 | - name: Run tests
48 | run: pytest
49 |
50 | - name: Test CLI
51 | run: safe-train --help
52 |
53 | - name: Test building the doc
54 | run: mkdocs build
55 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.env
2 | cov.xml
3 |
4 | .vscode/
5 |
6 | .ipynb_checkpoints/
7 |
8 | *.py[cod]
9 |
10 | # C extensions
11 | *.so
12 |
13 | # Packages
14 | *.egg
15 | *.egg-info
16 | dist
17 | build
18 | eggs
19 | parts
20 | bin
21 | var
22 | sdist
23 | develop-eggs
24 | .installed.cfg
25 | lib
26 | lib64
27 |
28 | # Installer logs
29 | pip-log.txt
30 |
31 | # Unit test / coverage reports
32 | .coverage*
33 | .tox
34 | nosetests.xml
35 | htmlcov
36 |
37 | # Translations
38 | *.mo
39 |
40 | # Mr Developer
41 | .mr.developer.cfg
42 | .project
43 | .pydevproject
44 |
45 | # Complexity
46 | output/*.html
47 | output/*/index.html
48 |
49 | # Sphinx
50 | docs/_build
51 |
52 | MANIFEST
53 |
54 | *.tif
55 |
56 | # Rever
57 | rever/
58 |
59 | # Dev notebooks
60 | notebooks/
61 |
62 | # MkDocs
63 | site/
64 |
65 | .vscode
66 |
67 | .idea/
68 |
69 | data/
70 | output/
71 | wandb/
72 | oracle/
73 | expts/models/
74 | expts/dev-data/
75 | expts/notebooks/
76 |
--------------------------------------------------------------------------------
/DATA_LICENSE:
--------------------------------------------------------------------------------
1 | # Creative Commons Attribution 4.0 International License (CC BY 4.0)
2 |
3 | This work is licensed under the Creative Commons Attribution 4.0 International License.
4 |
5 | To view a copy of this license, visit http://creativecommons.org/licenses/by/4.0/ or send a letter to Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2023 Emmanuel Noutahi
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
:safety_vest: SAFE
2 | Sequential Attachment-based Fragment Embedding (SAFE) is a novel molecular line notation that represents molecules as an unordered sequence of fragment blocks to improve molecule design using generative models.
3 |
4 |
5 |
6 |

7 |
8 |
9 |
10 |
11 |
12 | Paper
13 | |
14 |
15 | Docs
16 | |
17 |
18 | 🤗 Model
19 | |
20 |
21 | 🤗 Training Dataset
22 |
23 |
24 |
25 | ---
26 |
27 |
28 |
29 | [](https://pypi.org/project/safe-mol/)
30 | [](https://anaconda.org/conda-forge/safe-mol)
31 | [](https://pypi.org/project/safe-mol/)
32 | [](https://anaconda.org/conda-forge/safe-mol)
33 | [](https://github.com/datamol-io/safe/blob/main/LICENSE)
34 | [](https://github.com/datamol-io/safe/blob/main/DATA_LICENSE)
35 | [](https://github.com/datamol-io/safe/stargazers)
36 | [](https://github.com/datamol-io/safe/network/members)
37 | [](https://arxiv.org/pdf/2310.10773.pdf)
38 |
39 | [](https://github.com/datamol-io/safe/actions/workflows/test.yml)
40 | [](https://github.com/datamol-io/safe/actions/workflows/release.yml)
41 | [](https://github.com/datamol-io/safe/actions/workflows/code-check.yml)
42 | [](https://github.com/datamol-io/safe/actions/workflows/doc.yml)
43 |
44 | ## Overview of SAFE
45 |
46 | SAFE _is the_ deep learning molecular representation. It's an encoding leveraging a peculiarity in the decoding schemes of SMILES, to allow representation of molecules as a contiguous sequence of connected fragments. SAFE strings are valid SMILES strings, and thus are able to preserve the same amount of information. The intuitive representation of molecules as an ordered sequence of connected fragments greatly simplifies the following tasks often encountered in molecular design:
47 |
48 | - _de novo_ design
49 | - superstructure generation
50 | - scaffold decoration
51 | - motif extension
52 | - linker generation
53 | - scaffold morphing.
54 |
55 | The construction of a SAFE string requires defining a molecular fragmentation algorithm. By default, we use [BRICS], but any other fragmentation algorithm can be used. The image below illustrates the process of building a SAFE string. The resulting string is a valid SMILES that can be read by [datamol](https://github.com/datamol-io/datamol) or [RDKit](https://github.com/rdkit/rdkit).
56 |
57 |
58 |
59 |

60 |
61 |
62 | ## News 🚀
63 |
64 | #### 💥 2024/01/15 💥
65 | 1. [@IanAWatson](https://github.com/IanAWatson) has a C++ implementation of SAFE in [LillyMol](https://github.com/IanAWatson/LillyMol/tree/bazel_version_float) that is quite fast and uses a custom fragmentation algorithm. Follow the installation instructions on the repo and check out the docs of the CLI here: [docs/Molecule_Tools/SAFE.md](https://github.com/IanAWatson/LillyMol/blob/bazel_version_float/docs/Molecule_Tools/SAFE.md)
66 |
67 |
68 | ### Installation
69 |
70 | You can install `safe` using pip:
71 |
72 | ```bash
73 | pip install safe-mol
74 | ```
75 |
76 | You can use conda/mamba:
77 |
78 | ```bash
79 | mamba install -c conda-forge safe-mol
80 | ```
81 |
82 | #### 2024/11/22
83 | NOTE: Installation might cause issues such as GPUs not being detected (which can be checked with `torch.cuda.is_available()`) and segmentation faults due to a mismatch between the installed CUDA version and the driver's CUDA version. In that case, follow these steps:
84 |
85 | Create a new environment using conda:
86 |
87 | ```bash
88 | conda create -n env_safe python=3.12
89 | conda activate env_safe
90 | ```
91 |
92 | Check the NVIDIA driver version on the machine by running the `nvcc --version` or `nvidia-smi` commands
93 |
94 | Install pytorch with compatible cuda versions (from `https://pytorch.org/get-started/locally/`) and safe-mol:
95 |
96 | ```bash
97 | conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
98 | conda install -c conda-forge safe-mol
99 | ```
100 |
101 | ### Datasets and Models
102 |
103 | | Type | Name | Infos | Size | Comment |
104 | | ---------------------- | ------------------------------------------------------------------------------ | ---------- | ----- | -------------------- |
105 | | Model | [datamol-io/safe-gpt](https://huggingface.co/datamol-io/safe-gpt) | 87M params | 350M | Default model |
106 | | Training Dataset | [datamol-io/safe-gpt](https://huggingface.co/datasets/datamol-io/safe-gpt) | 1.1B rows | 250GB | Training dataset |
107 | | Drug Benchmark Dataset | [datamol-io/safe-drugs](https://huggingface.co/datasets/datamol-io/safe-drugs) | 26 rows | 20 kB | Benchmarking dataset |
108 |
109 | ## Usage
110 |
111 | Please refer to the [documentation](https://safe-docs.datamol.io/), which contains tutorials for getting started with `safe` and detailed descriptions of the functions provided, as well as an example of how to get started with SAFE-GPT.
112 |
113 | ### API
114 |
115 | We summarize some key functions provided by the `safe` package below.
116 |
117 | | Function | Description |
118 | | ------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
119 | | `safe.encode` | Translates a SMILES string into its corresponding SAFE string. |
120 | | `safe.decode` | Translates a SAFE string into its corresponding SMILES string. The SAFE decoder simply augments RDKit's `Chem.MolFromSmiles` with an optional correction argument to take care of missing hydrogen bonds. |
121 | | `safe.split` | Tokenizes a SAFE string to build a generative model. |
122 |
123 | ### Examples
124 |
125 | #### Translation between SAFE and SMILES representations
126 |
127 | ```python
128 | import safe
129 |
130 | ibuprofen = "CC(Cc1ccc(cc1)C(C(=O)O)C)C"
131 |
132 | # SMILES -> SAFE -> SMILES translation
133 | try:
134 | ibuprofen_sf = safe.encode(ibuprofen) # c12ccc3cc1.C3(C)C(=O)O.CC(C)C2
135 | ibuprofen_smi = safe.decode(ibuprofen_sf, canonical=True) # CC(C)Cc1ccc(C(C)C(=O)O)cc1
136 | except safe.EncoderError:
137 | pass
138 | except safe.DecoderError:
139 | pass
140 |
141 | ibuprofen_tokens = list(safe.split(ibuprofen_sf))
142 | ```
143 |
144 | ### Training/Finetuning a (new) model
145 |
146 | A command line interface is available to train a new model; please run `safe-train --help`. You can also provide an existing checkpoint to continue training or finetune on your own dataset.
147 |
148 | For example:
149 |
150 | ```bash
151 | safe-train --config \
152 | --model-path \
153 | --tokenizer \
154 | --dataset \
155 | --num_labels 9 \
156 | --torch_compile True \
157 | --optim "adamw_torch" \
158 | --learning_rate 1e-5 \
159 | --prop_loss_coeff 1e-3 \
160 | --gradient_accumulation_steps 1 \
161 | --output_dir "" \
162 | --max_steps 5
163 | ```
164 |
165 | ## References
166 |
167 | If you use this repository, please cite the following related [paper](https://arxiv.org/abs/2310.10773#):
168 |
169 | ```bib
170 | @misc{noutahi2023gotta,
171 | title={Gotta be SAFE: A New Framework for Molecular Design},
172 | author={Emmanuel Noutahi and Cristian Gabellini and Michael Craig and Jonathan S. C Lim and Prudencio Tossou},
173 | year={2023},
174 | eprint={2310.10773},
175 | archivePrefix={arXiv},
176 | primaryClass={cs.LG}
177 | }
178 | ```
179 |
180 | ## License
181 |
182 | The training dataset is licensed under CC BY 4.0. See [DATA_LICENSE](DATA_LICENSE) for details. This code base is licensed under the Apache-2.0 license. See [LICENSE](LICENSE) for details.
183 |
184 | Note that the model weights of **SAFE-GPT** are exclusively licensed for research purposes (CC BY-NC 4.0).
185 |
186 | ## Development lifecycle
187 |
188 | ### Setup dev environment
189 |
190 | ```bash
191 | mamba env create -n safe -f env.yml
192 | mamba activate safe
193 |
194 | pip install --no-deps -e .
195 | ```
196 |
197 | ### Tests
198 |
199 | You can run tests locally with:
200 |
201 | ```bash
202 | pytest
203 | ```
204 |
--------------------------------------------------------------------------------
/docs/api/safe.md:
--------------------------------------------------------------------------------
1 | ## SAFE Encoder-Decoder
2 |
3 | ::: safe.converter
4 | options:
5 | members:
6 | - encode
7 | - decode
8 | - SAFEConverter
9 | show_root_heading: false
10 |
11 |
12 | ---
13 |
14 | ## SAFE Design
15 |
16 | ::: safe.sample
17 | options:
18 | members:
19 | - SAFEDesign
20 | show_root_heading: false
21 |
22 |
23 | ---
24 |
25 | ## SAFE Tokenizer
26 |
27 | ::: safe.tokenizer
28 | options:
29 | members:
30 | - SAFESplitter
31 | - SAFETokenizer
32 | show_root_heading: false
33 |
34 | ---
35 |
36 | ## Utils
37 |
38 | ::: safe.utils
--------------------------------------------------------------------------------
/docs/api/safe.models.md:
--------------------------------------------------------------------------------
1 | ## Config File
2 |
3 | The input config file for training a `SAFE` model is very similar to the GPT2 config file, with the addition of an optional `num_labels` attribute for training with descriptors regularization.
4 |
5 | ```json
6 | {
7 | "activation_function": "gelu_new",
8 | "attn_pdrop": 0.1,
9 | "bos_token_id": 10000,
10 | "embd_pdrop": 0.1,
11 | "eos_token_id": 1,
12 | "initializer_range": 0.02,
13 | "layer_norm_epsilon": 1e-05,
14 | "model_type": "gpt2",
15 | "n_embd": 768,
16 | "n_head": 12,
17 | "n_inner": null,
18 | "n_layer": 12,
19 | "n_positions": 1024,
20 | "reorder_and_upcast_attn": false,
21 | "resid_pdrop": 0.1,
22 | "scale_attn_by_inverse_layer_idx": false,
23 | "scale_attn_weights": true,
24 | "summary_activation": "tanh",
25 | "summary_first_dropout": 0.1,
26 | "summary_proj_to_labels": true,
27 | "summary_type": "cls_index",
28 | "summary_hidden_size": 128,
29 | "summary_use_proj": true,
30 | "transformers_version": "4.31.0",
31 | "use_cache": true,
32 | "vocab_size": 10000,
33 | "num_labels": 9
34 | }
35 | ```
36 |
37 |
38 | ## SAFE Model
39 | ::: safe.trainer.model
40 |
41 | ---
42 |
43 | ## Trainer
44 | ::: safe.trainer.trainer_utils
45 |
46 | ---
47 |
48 | ## Data Collator
49 | ::: safe.trainer.collator
50 |
51 | ---
52 |
53 | ## Data Utils
54 | ::: safe.trainer.data_utils
55 |
56 |
57 |
--------------------------------------------------------------------------------
/docs/api/safe.viz.md:
--------------------------------------------------------------------------------
1 | ::: safe.viz
2 | options:
3 | members:
4 | - to_image
5 | show_root_heading: false
--------------------------------------------------------------------------------
/docs/assets/js/google-analytics.js:
--------------------------------------------------------------------------------
// Google Analytics (gtag.js) bootstrap for the documentation site.
// Measurement ID used by both the loader URL and the config call below.
var gtag_id = "G-3XLJELJ2TF";

// Inject the async gtag.js loader for this measurement ID into <head>.
var script = document.createElement("script");
script.src = "https://www.googletagmanager.com/gtag/js?id=" + gtag_id;
document.head.appendChild(script);

// Standard gtag stub: commands pushed onto dataLayer are queued and
// replayed once the loader script above finishes loading.
window.dataLayer = window.dataLayer || [];
function gtag(){dataLayer.push(arguments);}
gtag('js', new Date());
gtag('config', gtag_id);
11 |
--------------------------------------------------------------------------------
/docs/data_license.md:
--------------------------------------------------------------------------------
1 | ```
2 | {!DATA_LICENSE!}
3 | ```
4 |
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | 🦺 SAFE
2 | Sequential Attachment-based Fragment Embedding (SAFE) is a novel molecular line notation that represents molecules as an unordered sequence of fragment blocks to improve molecule design using generative models.
3 |
4 |
5 |
6 |

7 |
8 |
9 |
10 |
11 |
12 | Paper
13 | |
14 |
15 | Docs
16 | |
17 |
18 | 🤗 Model
19 | |
20 |
21 | 🤗 Training Dataset
22 |
23 |
24 |
25 | ---
26 |
27 |
28 |
29 | [](https://pypi.org/project/safe-mol/)
30 | [](https://anaconda.org/conda-forge/safe-mol)
31 | [](https://pypi.org/project/safe-mol/)
32 | [](https://anaconda.org/conda-forge/safe-mol)
33 | [](https://pypi.org/project/safe-mol/)
34 | [](https://github.com/datamol-io/safe/blob/main/LICENSE)
35 | [](https://github.com/datamol-io/safe/blob/main/DATA_LICENSE)
36 | [](https://github.com/datamol-io/safe/stargazers)
37 | [](https://github.com/datamol-io/safe/network/members)
38 | [](https://arxiv.org/pdf/2310.10773.pdf)
39 |
40 | [](https://github.com/datamol-io/safe/actions/workflows/test.yml)
41 | [](https://github.com/datamol-io/safe/actions/workflows/release.yml)
42 | [](https://github.com/datamol-io/safe/actions/workflows/code-check.yml)
43 | [](https://github.com/datamol-io/safe/actions/workflows/doc.yml)
44 |
45 | ## Overview of SAFE
46 |
47 | SAFE _is the_ deep learning molecular representation. It's an encoding leveraging a peculiarity in the decoding schemes of SMILES, to allow representation of molecules as a contiguous sequence of connected fragments. SAFE strings are valid SMILES strings, and thus are able to preserve the same amount of information. The intuitive representation of molecules as an ordered sequence of connected fragments greatly simplifies the following tasks often encountered in molecular design:
48 |
49 | - _de novo_ design
50 | - superstructure generation
51 | - scaffold decoration
52 | - motif extension
53 | - linker generation
54 | - scaffold morphing.
55 |
56 | The construction of a SAFE string requires defining a molecular fragmentation algorithm. By default, we use [BRICS], but any other fragmentation algorithm can be used. The image below illustrates the process of building a SAFE string. The resulting string is a valid SMILES that can be read by [datamol](https://github.com/datamol-io/datamol) or [RDKit](https://github.com/rdkit/rdkit).
57 |
58 |
59 |
60 |

61 |
62 |
63 |
64 | ## News 🚀
65 |
66 | #### 💥 2024/01/15 💥
67 | 1. [@IanAWatson](https://github.com/IanAWatson) has a C++ implementation of SAFE in [LillyMol](https://github.com/IanAWatson/LillyMol/tree/bazel_version_float) that is quite fast and uses a custom fragmentation algorithm. Follow the installation instructions on the repo and check out the docs of the CLI here: [docs/Molecule_Tools/SAFE.md](https://github.com/IanAWatson/LillyMol/blob/bazel_version_float/docs/Molecule_Tools/SAFE.md)
68 |
69 | ### Installation
70 |
71 | You can install `safe` using pip:
72 |
73 | ```bash
74 | pip install safe-mol
75 | ```
76 |
77 | You can use conda/mamba:
78 |
79 | ```bash
80 | mamba install -c conda-forge safe-mol
81 | ```
82 |
83 | ### Datasets and Models
84 |
85 | | Type | Name | Infos | Size | Comment |
86 | | ---------------------- | ------------------------------------------------------------------------------ | ---------- | ----- | -------------------- |
87 | | Model | [datamol-io/safe-gpt](https://huggingface.co/datamol-io/safe-gpt) | 87M params | 350M | Default model |
88 | | Training Dataset | [datamol-io/safe-gpt](https://huggingface.co/datasets/datamol-io/safe-gpt) | 1.1B rows | 250GB | Training dataset |
89 | | Drug Benchmark Dataset | [datamol-io/safe-drugs](https://huggingface.co/datasets/datamol-io/safe-drugs) | 26 rows | 20 kB | Benchmarking dataset |
90 |
91 | ## Usage
92 |
93 |
94 | The tutorials in the [documentation](https://safe-docs.datamol.io/) can help you get started with `safe` and `SAFE-GPT`.
95 |
96 | ### API
97 |
98 | We summarize some key functions provided by the `safe` package below.
99 |
100 | | Function | Description |
101 | | ------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
102 | | `safe.encode` | Translates a SMILES string into its corresponding SAFE string. |
103 | | `safe.decode` | Translates a SAFE string into its corresponding SMILES string. The SAFE decoder just augments RDKit's `Chem.MolFromSmiles` with an optional correction argument to take care of missing hydrogens. |
104 | | `safe.split` | Tokenizes a SAFE string to build a generative model. |
105 |
106 | ### Examples
107 |
108 | #### Translation between SAFE and SMILES representations
109 |
110 | ```python
111 | import safe
112 |
113 | ibuprofen = "CC(Cc1ccc(cc1)C(C(=O)O)C)C"
114 |
115 | # SMILES -> SAFE -> SMILES translation
116 | try:
117 | ibuprofen_sf = safe.encode(ibuprofen) # c12ccc3cc1.C3(C)C(=O)O.CC(C)C2
118 | ibuprofen_smi = safe.decode(ibuprofen_sf, canonical=True) # CC(C)Cc1ccc(C(C)C(=O)O)cc1
119 | except safe.EncoderError:
120 | pass
121 | except safe.DecoderError:
122 | pass
123 |
124 | ibuprofen_tokens = list(safe.split(ibuprofen_sf))
125 | ```
126 |
127 | ### Training/Finetuning a (new) model
128 |
129 | A command line interface is available to train a new model; please run `safe-train --help`. You can also provide an existing checkpoint to continue training or finetune on your own dataset.
130 |
131 | For example:
132 |
133 | ```bash
134 | safe-train --config <path-to-config-file> \
135 | --model-path <path-to-model-or-checkpoint> \
136 | --tokenizer <path-to-tokenizer> \
137 | --dataset <path-to-dataset> \
138 | --num_labels 9 \
139 | --torch_compile True \
140 | --optim "adamw_torch" \
141 | --learning_rate 1e-5 \
142 | --prop_loss_coeff 1e-3 \
143 | --gradient_accumulation_steps 1 \
144 | --output_dir "" \
145 | --max_steps 5
146 | ```
147 |
148 |
149 | ## References
150 |
151 | If you use this repository, please cite the following related [paper](https://arxiv.org/abs/2310.10773#):
152 |
153 | ```bib
154 | @misc{noutahi2023gotta,
155 | title={Gotta be SAFE: A New Framework for Molecular Design},
156 | author={Emmanuel Noutahi and Cristian Gabellini and Michael Craig and Jonathan S. C Lim and Prudencio Tossou},
157 | year={2023},
158 | eprint={2310.10773},
159 | archivePrefix={arXiv},
160 | primaryClass={cs.LG}
161 | }
162 | ```
163 |
164 | ## License
165 |
166 | Note that the model weights of **SAFE-GPT** are exclusively licensed for research purposes (CC BY-NC 4.0). The accompanying training dataset is licensed under CC BY 4.0. See [DATA_LICENSE](data_license.md) for details.
167 |
168 | This code base is licensed under the Apache-2.0 license. See [LICENSE](license.md) for details.
169 |
170 | ## Development lifecycle
171 |
172 | ### Setup dev environment
173 |
174 | ```bash
175 | mamba env create -n safe -f env.yml
176 | mamba activate safe
177 |
178 | pip install --no-deps -e .
179 | ```
180 |
181 | ### Tests
182 |
183 | You can run tests locally with:
184 |
185 | ```bash
186 | pytest
187 | ```
188 |
--------------------------------------------------------------------------------
/docs/license.md:
--------------------------------------------------------------------------------
1 | ```
2 | {!LICENSE!}
3 | ```
4 |
--------------------------------------------------------------------------------
/docs/tutorials/load-from-wandb.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext autoreload\n",
10 | "%autoreload 2"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": 2,
16 | "metadata": {},
17 | "outputs": [],
18 | "source": [
19 | "from safe.sample import SAFEDesign"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": 3,
25 | "metadata": {},
26 | "outputs": [
27 | {
28 | "name": "stderr",
29 | "output_type": "stream",
30 | "text": [
31 | "/Users/emmanuel.noutahi/miniconda3/envs/safe/lib/python3.12/site-packages/huggingface_hub/file_download.py:1150: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
32 | " warnings.warn(\n"
33 | ]
34 | }
35 | ],
36 | "source": [
37 | "model = SAFEDesign.load_default()"
38 | ]
39 | },
40 | {
41 | "cell_type": "markdown",
42 | "metadata": {},
43 | "source": [
44 | "## Upload models to wandb\n",
45 | "\n",
46 | "SAFE models can be uploaded to wandb with the `upload_to_wandb` function. You can define a general \"SAFE_WANDB_PROJECT\" env variable to save all of your models to that project. \n",
47 | "\n",
48 | "Make sure that you are login into your wandb account:\n",
49 | "\n",
50 | "```bash\n",
51 | "wandb login --relogin $WANDB_API_KEY\n",
52 | "```"
53 | ]
54 | },
55 | {
56 | "cell_type": "code",
57 | "execution_count": 4,
58 | "metadata": {},
59 | "outputs": [],
60 | "source": [
61 | "from safe.io import upload_to_wandb"
62 | ]
63 | },
64 | {
65 | "cell_type": "code",
66 | "execution_count": 5,
67 | "metadata": {},
68 | "outputs": [
69 | {
70 | "name": "stdout",
71 | "output_type": "stream",
72 | "text": [
73 | "env: WANDB_SILENT=False\n",
74 | "env: SAFE_WANDB_PROJECT=safe-models\n",
75 | "[2024-09-10 13:42:46,004] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to mps (auto detect)\n"
76 | ]
77 | },
78 | {
79 | "name": "stderr",
80 | "output_type": "stream",
81 | "text": [
82 | "W0910 13:42:46.257000 8343047168 torch/distributed/elastic/multiprocessing/redirects.py:28] NOTE: Redirects are currently not supported in Windows or MacOs.\n",
83 | "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n",
84 | "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mmaclandrol\u001b[0m (\u001b[33mvalencelabs\u001b[0m). Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
85 | ]
86 | },
87 | {
88 | "data": {
89 | "text/html": [
90 | "wandb version 0.17.9 is available! To upgrade, please run:\n",
91 | " $ pip install wandb --upgrade"
92 | ],
93 | "text/plain": [
94 | ""
95 | ]
96 | },
97 | "metadata": {},
98 | "output_type": "display_data"
99 | },
100 | {
101 | "data": {
102 | "text/html": [
103 | "Tracking run with wandb version 0.16.6"
104 | ],
105 | "text/plain": [
106 | ""
107 | ]
108 | },
109 | "metadata": {},
110 | "output_type": "display_data"
111 | },
112 | {
113 | "data": {
114 | "text/html": [
115 | "Run data is saved locally in /Users/emmanuel.noutahi/Code/safe/nb/wandb/run-20240910_134247-72wmn5st
"
116 | ],
117 | "text/plain": [
118 | ""
119 | ]
120 | },
121 | "metadata": {},
122 | "output_type": "display_data"
123 | },
124 | {
125 | "data": {
126 | "text/html": [
127 | "Syncing run absurd-disco-1 to Weights & Biases (docs)
"
128 | ],
129 | "text/plain": [
130 | ""
131 | ]
132 | },
133 | "metadata": {},
134 | "output_type": "display_data"
135 | },
136 | {
137 | "data": {
138 | "text/html": [
139 | " View project at https://wandb.ai/valencelabs/safe-models"
140 | ],
141 | "text/plain": [
142 | ""
143 | ]
144 | },
145 | "metadata": {},
146 | "output_type": "display_data"
147 | },
148 | {
149 | "data": {
150 | "text/html": [
151 | " View run at https://wandb.ai/valencelabs/safe-models/runs/72wmn5st"
152 | ],
153 | "text/plain": [
154 | ""
155 | ]
156 | },
157 | "metadata": {},
158 | "output_type": "display_data"
159 | },
160 | {
161 | "name": "stderr",
162 | "output_type": "stream",
163 | "text": [
164 | "\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (/var/folders/rl/wwcfdj4x0pg293bfqszl970r0000gq/T/tmpy8r3mrzg)... Done. 0.6s\n"
165 | ]
166 | },
167 | {
168 | "data": {
169 | "application/vnd.jupyter.widget-view+json": {
170 | "model_id": "a9466f1c60dc43d18aceca8b74f753f4",
171 | "version_major": 2,
172 | "version_minor": 0
173 | },
174 | "text/plain": [
175 | "VBox(children=(Label(value='333.221 MB of 333.221 MB uploaded\\r'), FloatProgress(value=1.0, max=1.0)))"
176 | ]
177 | },
178 | "metadata": {},
179 | "output_type": "display_data"
180 | },
181 | {
182 | "data": {
183 | "text/html": [
184 | " View run absurd-disco-1 at: https://wandb.ai/valencelabs/safe-models/runs/72wmn5st
View project at: https://wandb.ai/valencelabs/safe-models
Synced 6 W&B file(s), 0 media file(s), 5 artifact file(s) and 0 other file(s)"
185 | ],
186 | "text/plain": [
187 | ""
188 | ]
189 | },
190 | "metadata": {},
191 | "output_type": "display_data"
192 | },
193 | {
194 | "data": {
195 | "text/html": [
196 | "Find logs at: ./wandb/run-20240910_134247-72wmn5st/logs
"
197 | ],
198 | "text/plain": [
199 | ""
200 | ]
201 | },
202 | "metadata": {},
203 | "output_type": "display_data"
204 | }
205 | ],
206 | "source": [
207 | "%env WANDB_SILENT=False\n",
208 | "%env SAFE_WANDB_PROJECT=safe-models\n",
209 | "\n",
210 | "upload_to_wandb(model.model, model.tokenizer, artifact_name=\"default-safe-zinc\", slicer=\"BRICS/Partition\", aliases=[\"paper\"])"
211 | ]
212 | },
213 | {
214 | "cell_type": "markdown",
215 | "metadata": {},
216 | "source": [
217 | "## Loading models from wandb"
218 | ]
219 | },
220 | {
221 | "cell_type": "code",
222 | "execution_count": 3,
223 | "metadata": {},
224 | "outputs": [
225 | {
226 | "name": "stdout",
227 | "output_type": "stream",
228 | "text": [
229 | "env: SAFE_MODEL_ROOT=/Users/emmanuel.noutahi/.cache/wandb/safe/\n"
230 | ]
231 | },
232 | {
233 | "name": "stderr",
234 | "output_type": "stream",
235 | "text": [
236 | "\u001b[34m\u001b[1mwandb\u001b[0m: Downloading large artifact default-safe-zinc:latest, 333.22MB. 5 files... \n",
237 | "\u001b[34m\u001b[1mwandb\u001b[0m: 5 of 5 files downloaded. \n",
238 | "Done. 0:0:0.7\n"
239 | ]
240 | }
241 | ],
242 | "source": [
243 | "%env SAFE_MODEL_ROOT=/Users/emmanuel.noutahi/.cache/wandb/safe/\n",
244 | "designer = SAFEDesign.load_from_wandb(\"safe-models/default-safe-zinc\")"
245 | ]
246 | },
247 | {
248 | "cell_type": "code",
249 | "execution_count": 4,
250 | "metadata": {},
251 | "outputs": [
252 | {
253 | "data": {
254 | "application/vnd.jupyter.widget-view+json": {
255 | "model_id": "96de1483b3894961a8ec2690df4f8ace",
256 | "version_major": 2,
257 | "version_minor": 0
258 | },
259 | "text/plain": [
260 | " 0%| | 0/1 [00:00, ?it/s]"
261 | ]
262 | },
263 | "metadata": {},
264 | "output_type": "display_data"
265 | },
266 | {
267 | "name": "stderr",
268 | "output_type": "stream",
269 | "text": [
270 | "/Users/emmanuel.noutahi/miniconda3/envs/safe/lib/python3.12/site-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n",
271 | " warnings.warn(\n",
272 | "/Users/emmanuel.noutahi/miniconda3/envs/safe/lib/python3.12/site-packages/transformers/generation/configuration_utils.py:615: UserWarning: `num_beams` is set to 1. However, `early_stopping` is set to `True` -- this flag is only used in beam-based generation modes. You should set `num_beams>1` or unset `early_stopping`.\n",
273 | " warnings.warn(\n"
274 | ]
275 | },
276 | {
277 | "data": {
278 | "text/plain": [
279 | "['C[C@]1(C(=O)N2CCC[C@@H](NC(=O)C#CC3CC3)CC2)CCNC1=O',\n",
280 | " 'CN(C(=O)CN1CC[NH+](C[C@@H](O)Cn2cc([N+](=O)[O-])cn2)CC1)c1ccccc1',\n",
281 | " 'CC[C@@H](C)C[C@@H]([NH3+])C(=O)N(CC)C[C@@H]1CCOC1',\n",
282 | " 'Cc1nnc2n1C[C@H](CNC(=O)Nc1cc(Cl)ccc1Cl)CC2',\n",
283 | " 'CCc1cccc(CC)c1NC(=O)[C@H](C)OC(=O)CCc1nc2ccccc2o1',\n",
284 | " 'Cc1cc(OC[C@H](O)C[NH2+]C[C@@H]2C[C@H](O)CN2Cc2ccccc2)ccc1F',\n",
285 | " 'Cc1c(Cl)cccc1N=C(O)CN=C(O)COC(=O)c1csc(-c2ccccc2)n1',\n",
286 | " 'CCc1nc(CCNC(=O)N[C@@H]2CCc3nnnn3CC2)cs1',\n",
287 | " 'C[C@@]1(C(=O)N[C@H]2CCCCCN(C(=O)c3cc(C4CC4)no3)C2)C=CCC1',\n",
288 | " 'Cc1cc(-c2cc(-c3cnn(C)c3)c3c(N)ncnc3n2)ccc1F']"
289 | ]
290 | },
291 | "execution_count": 4,
292 | "metadata": {},
293 | "output_type": "execute_result"
294 | }
295 | ],
296 | "source": [
297 | "designer.de_novo_generation(10)"
298 | ]
299 | }
300 | ],
301 | "metadata": {
302 | "kernelspec": {
303 | "display_name": "safe",
304 | "language": "python",
305 | "name": "python3"
306 | },
307 | "language_info": {
308 | "codemirror_mode": {
309 | "name": "ipython",
310 | "version": 3
311 | },
312 | "file_extension": ".py",
313 | "mimetype": "text/x-python",
314 | "name": "python",
315 | "nbconvert_exporter": "python",
316 | "pygments_lexer": "ipython3",
317 | "version": "3.12.5"
318 | }
319 | },
320 | "nbformat": 4,
321 | "nbformat_minor": 2
322 | }
323 |
--------------------------------------------------------------------------------
/env.yml:
--------------------------------------------------------------------------------
1 | channels:
2 | - conda-forge
3 |
4 | dependencies:
5 | - python >=3.9
6 | - pip
7 | - tqdm
8 | - loguru
9 | - typer
10 | - universal_pathlib
11 |
12 | # Scientific
13 | - datamol
14 | - pandas <=2.1.1
15 | - numpy
16 | - pytorch >=2.0
17 | - transformers
18 | - datasets
19 | - tokenizers
20 | - accelerate >=0.33 # for accelerator_config update
21 | - evaluate
22 | - wandb
23 | - huggingface_hub
24 |
25 | # Optional
26 | - deepspeed
27 |
28 | # dev
29 | - black >=24
30 | - ruff
31 | - pytest >=6.0
32 | - jupyterlab
33 | - nbconvert
34 | - ipywidgets
35 |
36 | - pip:
37 | - mkdocs <1.6.0
38 | - mkdocs-material >=7.1.1
39 | - mkdocs-material-extensions
40 | - mkdocstrings
41 | - mkdocstrings-python
42 | - mkdocs-jupyter
43 | - markdown-include
44 | - mdx_truly_sane_lists
45 | - mike >=1.0.0
46 |
--------------------------------------------------------------------------------
/expts/config/accelerate.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | deepspeed_config:
3 | gradient_accumulation_steps: 2
4 | gradient_clipping: 1.0
5 | offload_optimizer_device: cpu
6 | offload_param_device: cpu
7 | zero3_init_flag: true
8 | zero_stage: 2
9 | distributed_type: DEEPSPEED
10 | downcast_bf16: 'no'
11 | dynamo_config:
12 | dynamo_backend: INDUCTOR
13 | machine_rank: 0
14 | main_training_function: main
15 | mixed_precision: 'no'
16 | num_machines: 1
17 | num_processes: 2
18 | rdzv_backend: static
19 | same_network: true
20 | tpu_env: []
21 | tpu_use_cluster: false
22 | tpu_use_sudo: false
23 | use_cpu: false
24 |
--------------------------------------------------------------------------------
/expts/scripts/slurm-data-build.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Build the SAFE training dataset on a SLURM node.

## Name of your SLURM job
# Fixed: was "jupyter" (copy-pasted from slurm-notebook.sh); this job
# builds the dataset, so give it a name that identifies it in the queue.
#SBATCH --job-name=data-build
#SBATCH --cpus-per-task=64
#SBATCH --mem=200G
#SBATCH --time=1-12:00

set -ex
# The below env variables can eventually help setting up your workload.
# In a SLURM job, you CANNOT use `conda activate` and instead MUST use:
source activate safe
python scripts/build_dataset.py
--------------------------------------------------------------------------------
/expts/scripts/slurm-notebook.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | ## Name of your SLURM job
4 | #SBATCH --job-name=jupyter
5 | #SBATCH --cpus-per-task=32
6 | #SBATCH --mem=100G
7 | #SBATCH --time=1-12:00
8 |
9 | set -ex
10 | # The below env variables can eventually help setting up your workload.
11 | # In a SLURM job, you CANNOT use `conda activate` and instead MUST use:
12 | source activate safe
13 | jupyter lab /home/emmanuel/safe/
--------------------------------------------------------------------------------
/expts/scripts/slurm-tokenizer-train-custom.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Train the custom ("safe" splitter) BPE tokenizer on the full dataset.

## Name of your SLURM job
#SBATCH --job-name=split-train-safe-tokenizer
#SBATCH --output=/home/emmanuel/safe/expts/output/job_split_%x_%a.out
#SBATCH --error=/home/emmanuel/safe/expts/output/job_split_%x_%a.out
#SBATCH --open-mode=append
#SBATCH --cpus-per-task=32
#SBATCH --mem=80G
#SBATCH --time=48:00:00

set -ex
# The below env variables can eventually help setting up your workload.
# In a SLURM job, you CANNOT use `conda activate` and instead MUST use:
source activate safe

TOKENIZER_TYPE="bpe"
DEFAULT_DATASET="/storage/shared_data/cristian/preprocessed_zinc_unichem/train_filtered/"
# First positional argument overrides the dataset path.
DATASET="${1:-$DEFAULT_DATASET}"
#DATASET="/home/emmanuel/safe/expts/notebook/tmp_data/proc_data"
OUTPUT="/home/emmanuel/safe/expts/tokenizer/tokenizer-custom.json"
VOCAB_SIZE="10000"
TEXT_COLUMN="input"
BATCH_SIZE="1000"
N_EXAMPLES="500000000"

# Expansions are quoted so paths containing whitespace or glob
# characters are passed through intact instead of being word-split.
python scripts/tokenizer_trainer.py --tokenizer_type "$TOKENIZER_TYPE" \
    --dataset "$DATASET" --text_column "$TEXT_COLUMN" \
    --vocab_size "$VOCAB_SIZE" --batch_size "$BATCH_SIZE" \
    --outfile "$OUTPUT" --splitter 'safe' --n_examples "$N_EXAMPLES" --tokenizer_name "safe-custom"
--------------------------------------------------------------------------------
/expts/scripts/slurm-tokenizer-train-small.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Quick tokenizer-training smoke test on a 1M-example subset.

## Name of your SLURM job
#SBATCH --job-name=small-safe-tokenizer
#SBATCH --output=/home/emmanuel/safe/expts/output/job_test_%x_%a.out
#SBATCH --error=/home/emmanuel/safe/expts/output/job_test_%x_%a.out
#SBATCH --open-mode=append
#SBATCH --cpus-per-task=16
#SBATCH --mem=32G
#SBATCH --time=2:00:00

set -ex
# The below env variables can eventually help setting up your workload.
# In a SLURM job, you CANNOT use `conda activate` and instead MUST use:
source activate safe

TOKENIZER_TYPE="bpe"
DEFAULT_DATASET="/storage/shared_data/cristian/preprocessed_zinc_unichem/train_filtered/"
# First positional argument overrides the dataset path.
DATASET="${1:-$DEFAULT_DATASET}"
#DATASET="/home/emmanuel/safe/expts/notebook/tmp_data/proc_data"
OUTPUT="/home/emmanuel/safe/expts/tokenizer/tokenizer-custom-test.json"
VOCAB_SIZE="10000"
TEXT_COLUMN="input"
BATCH_SIZE="1000"
N_EXAMPLES="1000000"

# Expansions are quoted so paths containing whitespace or glob
# characters are passed through intact instead of being word-split.
python scripts/tokenizer_trainer.py --tokenizer_type "$TOKENIZER_TYPE" \
    --dataset "$DATASET" --text_column "$TEXT_COLUMN" \
    --vocab_size "$VOCAB_SIZE" --batch_size "$BATCH_SIZE" \
    --outfile "$OUTPUT" --n_examples "$N_EXAMPLES"
--------------------------------------------------------------------------------
/expts/scripts/slurm-tokenizer-train.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Train the default BPE tokenizer on the full dataset.

## Name of your SLURM job
#SBATCH --job-name=bpe-train-safe-tokenizer
#SBATCH --output=/home/emmanuel/safe/expts/output/job_bpe_%x_%a.out
#SBATCH --error=/home/emmanuel/safe/expts/output/job_bpe_%x_%a.out
#SBATCH --open-mode=append
#SBATCH --cpus-per-task=32
#SBATCH --mem=80G
#SBATCH --time=48:00:00

set -ex
# The below env variables can eventually help setting up your workload.
# In a SLURM job, you CANNOT use `conda activate` and instead MUST use:
source activate safe

TOKENIZER_TYPE="bpe"
DEFAULT_DATASET="/storage/shared_data/cristian/preprocessed_zinc_unichem/train_filtered/"
# First positional argument overrides the dataset path.
DATASET="${1:-$DEFAULT_DATASET}"
#DATASET="/home/emmanuel/safe/expts/notebook/tmp_data/proc_data"
OUTPUT="/home/emmanuel/safe/expts/tokenizer/tokenizer.json"
VOCAB_SIZE="10000"
TEXT_COLUMN="input"
BATCH_SIZE="1500"
N_EXAMPLES="500000000"

# Expansions are quoted so paths containing whitespace or glob
# characters are passed through intact instead of being word-split.
python scripts/tokenizer_trainer.py --tokenizer_type "$TOKENIZER_TYPE" \
    --dataset "$DATASET" --text_column "$TEXT_COLUMN" \
    --vocab_size "$VOCAB_SIZE" --batch_size "$BATCH_SIZE" \
    --outfile "$OUTPUT" --n_examples "$N_EXAMPLES"
--------------------------------------------------------------------------------
/expts/scripts/train-small.sh:
--------------------------------------------------------------------------------
1 | accelerate launch --config_file config/accelerate.yaml \
2 | scripts/model_trainer.py --tokenizer "tokenizer/tokenizer-custom.json" \
3 | --dataset data/ --text_column "input" \
4 | --is_tokenized False --streaming True \
5 | --num_labels 1 --include_descriptors False \
6 | --gradient_accumulation_steps 2 --wandb_watch 'gradients' \
7 | --per_device_train_batch_size 32 --num_train_epochs 5 --save_steps 2000 --save_total_limit 10 \
8 | --eval_accumulation_steps 100 --logging_steps 200 --logging_first_step True \
9 | --save_safetensors True --do_train True --output_dir output/test/ \
10 | --learning_rate 5e-4 --warmup_steps 1000 --gradient_checkpointing True --max_steps 15_000
11 |
--------------------------------------------------------------------------------
/expts/scripts/train.sh:
--------------------------------------------------------------------------------
1 | accelerate launch --config_file config/accelerate.yaml \
2 | scripts/model_trainer.py --tokenizer "tokenizer/tokenizer-custom.json" \
3 | --dataset ~/data/ --text_column "input" \
4 | --is_tokenized False --streaming True \
5 | --num_labels 1 --include_descriptors False \
6 | --gradient_accumulation_steps 2 --wandb_watch 'gradients' \
7 | --per_device_train_batch_size 64 --num_train_epochs 2 --save_steps 5000 --save_total_limit 10 \
8 | --eval_accumulation_steps 100 --logging_steps 500 --logging_first_step True \
9 | --save_safetensors True --do_train True --output_dir output/safe/ \
10 | --learning_rate 5e-5 --warmup_steps 2500 --gradient_checkpointing True --max_steps 30_000_000
--------------------------------------------------------------------------------
/expts/tokenizer/_tokenizer-custom-mini-test.json:
--------------------------------------------------------------------------------
1 | {"version": "1.0", "truncation": null, "padding": null, "added_tokens": [{"id": 0, "content": "[PAD]", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false, "special": true}, {"id": 1, "content": "[CLS]", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false, "special": true}, {"id": 2, "content": "[SEP]", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false, "special": true}, {"id": 3, "content": "[MASK]", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false, "special": true}, {"id": 4, "content": "[UNK]", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false, "special": true}], "normalizer": null, "pre_tokenizer": {"type": "Whitespace"}, "post_processor": {"type": "TemplateProcessing", "single": [{"SpecialToken": {"id": "[CLS]", "type_id": 0}}, {"Sequence": {"id": "A", "type_id": 0}}, {"SpecialToken": {"id": "[SEP]", "type_id": 0}}], "pair": [{"SpecialToken": {"id": "[CLS]", "type_id": 0}}, {"Sequence": {"id": "A", "type_id": 0}}, {"SpecialToken": {"id": "[SEP]", "type_id": 0}}, {"Sequence": {"id": "B", "type_id": 1}}, {"SpecialToken": {"id": "[SEP]", "type_id": 1}}], "special_tokens": {"[CLS]": {"id": "[CLS]", "ids": [1], "tokens": ["[CLS]"]}, "[SEP]": {"id": "[SEP]", "ids": [2], "tokens": ["[SEP]"]}}}, "decoder": {"type": "BPEDecoder", "suffix": ""}, "model": {"type": "BPE", "dropout": null, "unk_token": "[UNK]", "continuing_subword_prefix": null, "end_of_word_suffix": null, "fuse_unk": false, "byte_fallback": false, "vocab": {"[UNK]": 0, "[CLS]": 1, "[SEP]": 2, "[PAD]": 3, "[MASK]": 4, "#": 5, "%": 6, "(": 7, ")": 8, "+": 9, "-": 10, ".": 11, "/": 12, "0": 13, "1": 14, "2": 15, "3": 16, "4": 17, "5": 18, "6": 19, "7": 20, "8": 21, "9": 22, "=": 23, "@": 24, "A": 25, "B": 26, "C": 27, "D": 28, "E": 29, "F": 30, "G": 31, "H": 32, "I": 33, "K": 34, "L": 35, "M": 36, "N": 37, "O": 38, "P": 39, "R": 40, "S": 41, "T": 42, "U": 43, "V": 44, "W": 45, 
"X": 46, "Y": 47, "Z": 48, "[": 49, "\\": 50, "]": 51, "a": 52, "b": 53, "c": 54, "d": 55, "e": 56, "f": 57, "g": 58, "h": 59, "i": 60, "l": 61, "m": 62, "n": 63, "o": 64, "p": 65, "r": 66, "s": 67, "t": 68, "u": 69, "y": 70, "H]": 71, "[C": 72, "%1": 73, "[C@": 74, "[C@@": 75, "[C@H]": 76, "[C@@H]": 77, "Cl": 78, "%10": 79, "%11": 80, "[n": 81, "+]": 82, "[N": 83, "[nH]": 84, "-]": 85, "%12": 86, "Br": 87, "[O": 88, "[O-]": 89, "[N+]": 90, "%13": 91, "%2": 92, "%14": 93, "[C@]": 94, "[NH": 95, "%15": 96, "[C@@]": 97, "%16": 98, "[S": 99, "[NH+]": 100, "[Si": 101, "%17": 102, "[N-]": 103, "[Si]": 104, "%18": 105, "%3": 106, "%19": 107, "H+]": 108, "2+]": 109, "[nH+]": 110, "[NH2+]": 111, "%20": 112, "%21": 113, "[2": 114, "[2H]": 115, "%22": 116, "%23": 117, "[n+]": 118, "%4": 119, "%24": 120, "[CH": 121, "2]": 122, "%25": 123, "[C-]": 124, "%26": 125, "%27": 126, "[B": 127, "%28": 128, "[P": 129, "[c": 130, "[Na": 131, "%5": 132, "%29": 133, "[Cl": 134, "+2]": 135, "[Na+]": 136, "%30": 137, "[C]": 138, "[Cl-]": 139, "%31": 140, "%32": 141, "3+]": 142, "[NH3+]": 143, "@]": 144, "[O+]": 145, "[CH-]": 146, "3]": 147, "%33": 148, "[c-]": 149, "[CH2]": 150, "%34": 151, "[CH]": 152, "e]": 153, "[I": 154, "%6": 155, "%35": 156, "[B+]": 157, "%36": 158, "[SiH": 159, "%37": 160, "[O]": 161, "[n-]": 162, "%38": 163, "2-]": 164, "[P+]": 165, "[B-]": 166, "[Z": 167, "[CH2-]": 168, "@@]": 169, "%39": 170, "[B]": 171, "[Si-]": 172, "[A": 173, "[H": 174, "[SiH]": 175, "+3]": 176, "%40": 177, "[S-]": 178, "%7": 179, "%41": 180, "[1": 181, "[c]": 182, "[F": 183, "[K": 184, "[Se]": 185, "[S@]": 186, "%42": 187, "[S+]": 188, "H-]": 189, "[SiH2]": 190, "[T": 191, "[Br": 192, "r]": 193, "[S@@]": 194, "[K+]": 195, "%43": 196, "[Y": 197, "[HH]": 198, "[Y]": 199, "[Br-]": 200, "%44": 201, "[N]": 202, "[cH-]": 203, "[R": 204, "[H]": 205, "[OH+]": 206, "%45": 207, "3-]": 208, "[PH]": 209, "n]": 210, "%8": 211, "[NH-]": 212, "[CH3-]": 213, "%46": 214, "[Pt": 215, "[I-]": 216, "[L": 217, 
"[CH+]": 218, "%47": 219, "[Zr": 220, "[Li": 221, "[C+]": 222, "[Ti": 223, "[13": 224, "%48": 225, "[SiH3]": 226, "o+]": 227, "[o+]": 228, "%49": 229, "[Na]": 230, "[M": 231, "[CH3]": 232, "[I+]": 233, "[U": 234, "%50": 235, "[Ir]": 236, "%9": 237, "[G": 238, "[Li+]": 239, "%51": 240, "[Cu": 241, "4]": 242, "n+2]": 243, "%52": 244, "[Zr+2]": 245, "+4]": 246, "[Sn]": 247, "[As": 248, "[13C": 249, "%53": 250, "%54": 251, "[Co": 252, "%55": 253, "[s": 254, "%56": 255, "[Fe": 256, "[U+2]": 257, "[Al": 258, "%57": 259, "[Pt+2]": 260, "[SH]": 261, "c]": 262, "[P@]": 263, "[Pd": 264, "%58": 265, "[Ni": 266, "[W": 267, "%59": 268, "[As]": 269, "[Cl+3]": 270, "[Ge]": 271, "[Fe+2]": 272, "[P@@]": 273, "[3": 274, "[Zn+2]": 275, "H2]": 276, "[3H]": 277, "[Ru": 278, "[Cl+]": 279, "[Pt]": 280, "[F-]": 281, "%60": 282, "[Zr]": 283, "[se]": 284, "%61": 285, "[W]": 286, "[Cu+2]": 287, "a+2]": 288, "%62": 289, "@+]": 290, "[Mg": 291, "[Co]": 292, "[Si+]": 293, "%63": 294, "[Ac]": 295, "H2+]": 296, "[CH2+]": 297, "%64": 298, "[Pd+2]": 299, "[F+]": 300, "[Ti]": 301, "%65": 302, "[P-]": 303, "[Ti+2]": 304, "-2]": 305, "[Hf": 306, "[OH2+]": 307, "[Fe]": 308, "[Mg+2]": 309, "[Sn": 310, "%66": 311, "[Ca+2]": 312, "[Te]": 313, "[Ni+2]": 314, "r+3]": 315, "b]": 316, "%67": 317, "[Re]": 318, "[Ni]": 319, "[V": 320, "[Al]": 321, "[Rh": 322, "[13C]": 323, "[NH]": 324, "%68": 325, "%69": 326, "[Zr+4]": 327, "[N@+]": 328, "%70": 329, "[Co+2]": 330, "[Ru+2]": 331, "@H]": 332, "[O-2]": 333, "[Zn]": 334, "[S]": 335, "[Mo": 336, "[U]": 337, "%71": 338, "[OH-]": 339, "%72": 340, "%73": 341, "[Li]": 342, "[Ti+4]": 343, "[Ir+3]": 344, "%74": 345, "[Cu]": 346, "[Pd]": 347, "H3]": 348, "[Sn+]": 349, "b+]": 350, "[Cu+]": 351, "[s+]": 352, "[18": 353, "[V]": 354, "a]": 355, "[Cr]": 356, "[Au": 357, "%75": 358, "%76": 359, "%77": 360, "cH]": 361, "%78": 362, "[K]": 363, "[Rb+]": 364, "[Sn+2]": 365, "[Ti+3]": 366, "[Br+]": 367, "[Se": 368, "[Al+3]": 369, "%79": 370, "[14": 371, "%80": 372, "[13CH2]": 373, 
"[Tl": 374, "[Al+]": 375, "g+]": 376, "[Mo]": 377, "F]": 378, "[Pb]": 379, "[15": 380, "[Hf]": 381, "[18F]": 382, "%81": 383, "[Mn+2]": 384, "[Rh]": 385, "[13cH]": 386, "@@+]": 387, "[N@@+]": 388, "%82": 389, "[Fe+3]": 390, "[H+]": 391, "f]": 392, "[H-]": 393, "%83": 394, "[Mn]": 395, "m]": 396, "[Sb": 397, "[13c]": 398, "[Ge": 399, "[14C": 400, "%84": 401, "%85": 402, "[13CH3]": 403, "@@H]": 404, "[PH+]": 405, "[Rf]": 406, "[Tl]": 407, "[15N": 408, "[Ba+2]": 409, "g]": 410, "[PH2]": 411, "%86": 412, "%87": 413, "%88": 414, "%89": 415, "%90": 416, "[Gd": 417, "[Au+]": 418, "5]": 419, "I]": 420, "[Os": 421, "%91": 422, "%92": 423, "[E": 424, "r+2]": 425, "[NH2-]": 426, "%93": 427, "%94": 428, "%95": 429, "[AsH]": 430, "+5]": 431, "[IH+]": 432, "[p": 433, "[Bi": 434, "[Ag+]": 435, "[12": 436, "[15N]": 437, "[13C@@H]": 438, "[Ru]": 439, "a+3]": 440, "[Sb]": 441, "[In]": 442, "%96": 443, "[13CH]": 444, "[W+2]": 445, "[Cr+2]": 446, "[As-]": 447, "[Hf+4]": 448, "[SeH]": 449, "[pH]": 450, "s+]": 451, "[Cs+]": 452, "[Pt+4]": 453, "Dy": 454, "[Dy": 455, "[BH": 456, "[P]": 457, "[Hg]": 458, "[11": 459, "[AsH2]": 460, "[Sb+5]": 461, "[SH+]": 462, "[Hg+]": 463, "%97": 464, "%98": 465, "%99": 466, "[13C@H]": 467, "[Co+3]": 468, "[Gd]": 469, "n+3]": 470, "[Nd": 471, "[Pr]": 472, "[11C": 473, "[Cr+3]": 474, "[Ca]": 475, "[BH-]": 476, "[1C": 477, "[Ga]": 478, "[Eu": 479, "[Bi]": 480, "[Cd": 481, "[ClH+]": 482, "[Ru+]": 483, "[Mg]": 484, "[V+2]": 485, "5I]": 486, "[cH+]": 487, "[I]": 488, "[Y+3]": 489, "[Zr+3]": 490, "[Sn+4]": 491, "[Os]": 492, "[125I]": 493, "[Dy]": 494, "[Nd+3]": 495, "[Cd+2]": 496, "[Cm]": 497, "[S-2]": 498, "[Fm]": 499, "[Ti+]": 500, "[Ga+3]": 501, "[Rh+2]": 502, "[Rh+3]": 503, "[Se-]": 504, "[t": 505, "n+]": 506, "o]": 507, "[Si@@]": 508, "[Hf+2]": 509, "[18O": 510, "[Sb+3]": 511, "[te]": 512, "99": 513, "[99": 514, "[c+]": 515, "[Ag]": 516, "[Al+2]": 517, "[Ru+3]": 518, "[Mg+]": 519, "[Sn+3]": 520, "[Au+3]": 521, "[Tl+]": 522, "[14C]": 523, "[Gd+3]": 524, 
"[BH3-]": 525, "[11CH3]": 526, "[1CH2]": 527, "[99T": 528, "[Si@H]": 529, "[B@": 530, "[PH": 531, "[SiH-]": 532, "[Mn+3]": 533, "[Ru+4]": 534, "b+3]": 535, "[SH2+]": 536, "[Si@]": 537, "[SiH4]": 538, "[Tc]": 539, "[Pt+]": 540, "[Pd+]": 541, "[Au]": 542, "[Ge-]": 543, "[99Tc]": 544, "e+]": 545, "[N@]": 546, "[PH-]": 547, "[Rb]": 548, "[La+3]": 549, "[15n]": 550, "[Eu]": 551, "+6": 552, "[SH2]": 553, "[PH2+]": 554, "[In+3]": 555, "[In+]": 556, "[Mo+2]": 557, "[Mo+3]": 558, "[Se+]": 559, "[14CH2]": 560, "[Os+]": 561, "[18O]": 562, "+6]": 563, "1I]": 564, "[o": 565, "b+2]": 566, "c+3]": 567, "s]": 568, "[Ce": 569, "[N@@]": 570, "[Sc]": 571, "[Po]": 572, "[As+]": 573, "[Hf+]": 574, "[GeH]": 575, "[GeH2]": 576, "[14C@@H]": 577, "[Os+2]": 578, "[Es]": 579, "[Eu+3]": 580, "[oH+]": 581, "4I]": 582, "[b": 583, "i]": 584, "[Sr+2]": 585, "[SiH2-]": 586, "[Zn+]": 587, "[1cH]": 588, "[Ta]": 589, "[131I]": 590, "[U+4]": 591, "[Rh+]": 592, "[Mo+4]": 593, "[Ge@]": 594, "[14C@H]": 595, "[124I]": 596, "[Dy+3]": 597, "[18OH]": 598, "[B@@": 599, "[PH3+]": 600, "[b-]": 601, "3I]": 602, "e+2]": 603, "g+2]": 604, "r+]": 605, "t]": 606, "[Co+]": 607, "[Cf]": 608, "[Nb]": 609, "[Sm]": 610, "[Be+2]": 611, "[Pb+]": 612, "[Pb+2]": 613, "[At]": 614, "[Ho]": 615, "[Hg+2]": 616, "[Tc": 617, "[Y+]": 618, "[La]": 619, "[Zr+]": 620, "[si]": 621, "[Fe+]": 622, "[Al-]": 623, "[Ni+3]": 624, "[14c]": 625, "[14cH]": 626, "[Sb+2]": 627, "[Ge+]": 628, "[123I]": 629, "[BH2-]": 630, "[11CH2]": 631, "[B@-]": 632, "[PH2-]": 633, "-5]": 634, "3-2]": 635, "Xe]": 636, "[Xe]": 637, "a+]": 638, "m+3]": 639, "u]": 640, "u+3]": 641, "[Cr": 642, "[SH-]": 643, "[Sc+3]": 644, "[Si@@H]": 645, "[BH]": 646, "[Ba]": 647, "[Pa]": 648, "[Cl]": 649, "[IH]": 650, "[I+2]": 651, "[Ir+]": 652, "[Ar]": 653, "[Ta": 654, "[Te+]": 655, "[Yb+3]": 656, "[Ra]": 657, "[Lu+3]": 658, "[Ga+2]": 659, "[se+]": 660, "[Ni+]": 661, "[Hf+3]": 662, "[V+3]": 663, "[Ge@@]": 664, "[15NH]": 665, "[15N+]": 666, "[Er]": 667, "[p+]": 668, "[Bi+]": 669, 
"[Bi+3]": 670, "[BH3-2]": 671, "[Nd]": 672, "[1CH3]": 673, "[Ce+3]": 674, "[B@@-]": 675, "+7": 676, "2P": 677, "6N": 678, "O]": 679, "S]": 680, "[7": 681, "eH]": 682, "o+3]": 683, "[N+2]": 684, "[No]": 685, "[Nb+3]": 686, "[Nb+2]": 687, "Br]": 688, "[OH]": 689, "[Sr]": 690, "[S@+]": 691, "[Sb+]": 692, "[Sm+3]": 693, "[Si+4]": 694, "[CH3+]": 695, "[B-5]": 696, "[Po+]": 697, "[Pr+3]": 698, "[PH3]": 699, "[Pb+3]": 700, "[Cl+2]": 701, "[In+2]": 702, "[Ho+3]": 703, "[17": 704, "[1H]": 705, "[16N": 706, "[Tm]": 707, "[Tb+3]": 708, "[Tc+3]": 709, "[TeH]": 710, "[Mo+]": 711, "[U+3]": 712, "[Cu+4]": 713, "[sH+]": 714, "[Fe-": 715, "[Pd+3]": 716, "[W+4]": 717, "[32P": 718, "[Ru+6]": 719, "[V+]": 720, "[Mo+6]": 721, "[SeH2]": 722, "[Se-2]": 723, "[GeH3]": 724, "[14CH]": 725, "[14C@]": 726, "[14CH3]": 727, "[15NH2]": 728, "[Er+3]": 729, "[Bi+2]": 730, "[11C]": 731, "[Eu+2]": 732, "[Ce+4]": 733, "[B@@H-]": 734, "[Tc+4]": 735, "[Ta+5]": 736, "+7]": 737, "[32P]": 738, "13": 739, "1I": 740, "3S": 741, "4-]": 742, "5S]": 743, "5Br]": 744, "67": 745, "6S]": 746, "6Br]": 747, "7L": 748, "9P": 749, "Bi]": 750, "Cu]": 751, "OH]": 752, "[67": 753, "c+]": 754, "e+5]": 755, "f+2]": 756, "h+4]": 757, "n-]": 758, "n+5]": 759, "p]": 760, "[Ce]": 761, "[Ce+]": 762, "[Cr+]": 763, "[Cf+2]": 764, "[N+3]": 765, "[Np]": 766, "[O+2]": 767, "[SH3]": 768, "[213": 769, "[Ba+]": 770, "[P@+]": 771, "[Pu]": 772, "[ClH2+]": 773, "[Ir+2]": 774, "[In-]": 775, "[Am]": 776, "[1c]": 777, "[1OH]": 778, "[F]": 779, "[Ta+3]": 780, "[Tm+3]": 781, "[Th+4]": 782, "[Br+2]": 783, "[Re+]": 784, "[Ra+]": 785, "[Re+5]": 786, "[Pt+6]": 787, "[Mn+5]": 788, "[U-]": 789, "[Ga+]": 790, "[Cu-]": 791, "[Cu+6]": 792, "[Cu-5]": 793, "[AsH2+]": 794, "[13C@]": 795, "[Ni-2]": 796, "[W+]": 797, "[W-]": 798, "[35S]": 799, "[36S]": 800, "[Ru-]": 801, "[V+5]": 802, "[149P": 803, "[153S": 804, "[Gd+2]": 805, "[Os+5]": 806, "[Os+6]": 807, "[121I]": 808, "[BH4-]": 809, "[11O]": 810, "[111I": 811, "[1C]": 812, "[1CH]": 813, "[Cd]": 814, 
"[99Tc+3]": 815, "[99Tc+]": 816, "[B@H-]": 817, "[PH5]": 818, "[Ce+2]": 819, "[Tc+2]": 820, "[Tc+7]": 821, "[Cr-]": 822, "[Cr+4]": 823, "[Cr+7]": 824, "[Ta+4]": 825, "[75Br]": 826, "[76Br]": 827, "[17O]": 828, "[177L": 829, "[16N]": 830, "[16N+]": 831, "[Fe-3]": 832, "[Fe-4]": 833, "[67Cu]": 834, "[213Bi]": 835, "[149Pm]": 836, "[153Sm]": 837, "[111In]": 838, "[177Lu]": 839}, "merges": ["H ]", "[ C", "% 1", "[C @", "[C@ @", "[C@ H]", "[C@@ H]", "C l", "%1 0", "%1 1", "[ n", "+ ]", "[ N", "[n H]", "- ]", "%1 2", "B r", "[ O", "[O -]", "[N +]", "%1 3", "% 2", "%1 4", "[C@ ]", "[N H", "%1 5", "[C@@ ]", "%1 6", "[ S", "[NH +]", "[S i", "%1 7", "[N -]", "[Si ]", "%1 8", "% 3", "%1 9", "H +]", "2 +]", "[n H+]", "[NH 2+]", "%2 0", "%2 1", "[ 2", "[2 H]", "%2 2", "%2 3", "[n +]", "% 4", "%2 4", "[C H", "2 ]", "%2 5", "[C -]", "%2 6", "%2 7", "[ B", "%2 8", "[ P", "[ c", "[N a", "% 5", "%2 9", "[C l", "+ 2]", "[Na +]", "%3 0", "[C ]", "[Cl -]", "%3 1", "%3 2", "3 +]", "[NH 3+]", "@ ]", "[O +]", "[CH -]", "3 ]", "%3 3", "[c -]", "[CH 2]", "%3 4", "[C H]", "e ]", "[ I", "% 6", "%3 5", "[B +]", "%3 6", "[Si H", "%3 7", "[O ]", "[n -]", "%3 8", "2 -]", "[P +]", "[B -]", "[ Z", "[CH 2-]", "@ @]", "%3 9", "[B ]", "[Si -]", "[ A", "[ H", "[Si H]", "+ 3]", "%4 0", "[S -]", "% 7", "%4 1", "[ 1", "[c ]", "[ F", "[ K", "[S e]", "[S @]", "%4 2", "[S +]", "H -]", "[SiH 2]", "[ T", "[ Br", "r ]", "[S @@]", "[K +]", "%4 3", "[ Y", "[H H]", "[Y ]", "[Br -]", "%4 4", "[N ]", "[c H-]", "[ R", "[ H]", "[O H+]", "%4 5", "3 -]", "[P H]", "n ]", "% 8", "[NH -]", "[CH 3-]", "%4 6", "[P t", "[I -]", "[ L", "[C H+]", "%4 7", "[Z r", "[L i", "[C +]", "[T i", "[1 3", "%4 8", "[SiH 3]", "o +]", "[ o+]", "%4 9", "[Na ]", "[ M", "[CH 3]", "[I +]", "[ U", "%5 0", "[I r]", "% 9", "[ G", "[Li +]", "%5 1", "[C u", "4 ]", "n +2]", "%5 2", "[Zr +2]", "+ 4]", "[S n]", "[A s", "[13 C", "%5 3", "%5 4", "[C o", "%5 5", "[ s", "%5 6", "[F e", "[U +2]", "[A l", "%5 7", "[Pt +2]", "[S H]", "c ]", "[P @]", "[P d", 
"%5 8", "[N i", "[ W", "%5 9", "[As ]", "[Cl +3]", "[G e]", "[Fe +2]", "[P @@]", "[ 3", "[Z n+2]", "H 2]", "[3 H]", "[R u", "[Cl +]", "[Pt ]", "[F -]", "%6 0", "[Z r]", "[s e]", "%6 1", "[W ]", "[Cu +2]", "a +2]", "%6 2", "@ +]", "[M g", "[Co ]", "[Si +]", "%6 3", "[A c]", "H 2+]", "[CH 2+]", "%6 4", "[Pd +2]", "[F +]", "[Ti ]", "%6 5", "[P -]", "[Ti +2]", "- 2]", "[H f", "[O H2+]", "[F e]", "[Mg +2]", "[S n", "%6 6", "[C a+2]", "[T e]", "[Ni +2]", "r +3]", "b ]", "%6 7", "[R e]", "[Ni ]", "[ V", "[Al ]", "[R h", "[13C ]", "[N H]", "%6 8", "%6 9", "[Zr +4]", "[N @+]", "%7 0", "[Co +2]", "[Ru +2]", "@ H]", "[O -2]", "[Z n]", "[S ]", "[M o", "[U ]", "%7 1", "[O H-]", "%7 2", "%7 3", "[Li ]", "[Ti +4]", "[I r+3]", "%7 4", "[Cu ]", "[Pd ]", "H 3]", "[Sn +]", "b +]", "[Cu +]", "[s +]", "[1 8", "[V ]", "a ]", "[C r]", "[A u", "%7 5", "%7 6", "%7 7", "c H]", "%7 8", "[K ]", "[R b+]", "[S n+2]", "[Ti +3]", "[Br +]", "[S e", "[Al +3]", "%7 9", "[1 4", "%8 0", "[13C H2]", "[T l", "[Al +]", "g +]", "[Mo ]", "F ]", "[P b]", "[1 5", "[Hf ]", "[18 F]", "%8 1", "[M n+2]", "[Rh ]", "[13 cH]", "@ @+]", "[N @@+]", "%8 2", "[Fe +3]", "[ H+]", "f ]", "[H -]", "%8 3", "[M n]", "m ]", "[S b", "[13 c]", "[G e", "[14 C", "%8 4", "%8 5", "[13C H3]", "@ @H]", "[P H+]", "[R f]", "[Tl ]", "[15 N", "[B a+2]", "g ]", "[P H2]", "%8 6", "%8 7", "%8 8", "%8 9", "%9 0", "[G d", "[Au +]", "5 ]", "I ]", "[O s", "%9 1", "%9 2", "[ E", "r +2]", "[NH 2-]", "%9 3", "%9 4", "%9 5", "[As H]", "+ 5]", "[I H+]", "[ p", "[B i", "[A g+]", "[1 2", "[15N ]", "[13C @@H]", "[Ru ]", "a +3]", "[S b]", "[I n]", "%9 6", "[13C H]", "[W +2]", "[C r+2]", "[As -]", "[Hf +4]", "[Se H]", "[p H]", "s +]", "[C s+]", "[Pt +4]", "D y", "[ Dy", "[B H", "[P ]", "[H g]", "[1 1", "[As H2]", "[Sb +5]", "[S H+]", "[H g+]", "%9 7", "%9 8", "%9 9", "[13C @H]", "[Co +3]", "[Gd ]", "n +3]", "[N d", "[P r]", "[11 C", "[C r+3]", "[C a]", "[B H-]", "[1 C", "[G a]", "[E u", "[Bi ]", "[C d", "[Cl H+]", "[Ru +]", "[Mg ]", "[V +2]", "5 I]", "[c 
H+]", "[I ]", "[Y +3]", "[Zr +3]", "[Sn +4]", "[Os ]", "[12 5I]", "[Dy ]", "[Nd +3]", "[Cd +2]", "[C m]", "[S -2]", "[F m]", "[Ti +]", "[G a+3]", "[Rh +2]", "[Rh +3]", "[Se -]", "[ t", "n +]", "o ]", "[Si @@]", "[Hf +2]", "[18 O", "[Sb +3]", "[t e]", "9 9", "[ 99", "[c +]", "[A g]", "[Al +2]", "[Ru +3]", "[Mg +]", "[Sn +3]", "[Au +3]", "[Tl +]", "[14C ]", "[Gd +3]", "[BH 3-]", "[11C H3]", "[1C H2]", "[99 T", "[Si @H]", "[B @", "[P H", "[SiH -]", "[M n+3]", "[Ru +4]", "b +3]", "[S H2+]", "[Si @]", "[SiH 4]", "[T c]", "[Pt +]", "[Pd +]", "[Au ]", "[Ge -]", "[99T c]", "e +]", "[N @]", "[P H-]", "[R b]", "[L a+3]", "[15 n]", "[Eu ]", "+ 6", "[S H2]", "[P H2+]", "[I n+3]", "[I n+]", "[Mo +2]", "[Mo +3]", "[Se +]", "[14C H2]", "[Os +]", "[18O ]", "+6 ]", "1 I]", "[ o", "b +2]", "c +3]", "s ]", "[C e", "[N @@]", "[S c]", "[P o]", "[As +]", "[Hf +]", "[Ge H]", "[Ge H2]", "[14C @@H]", "[Os +2]", "[E s]", "[Eu +3]", "[o H+]", "4 I]", "[ b", "i ]", "[S r+2]", "[SiH 2-]", "[Z n+]", "[1 cH]", "[T a]", "[13 1I]", "[U +4]", "[Rh +]", "[Mo +4]", "[Ge @]", "[14C @H]", "[12 4I]", "[Dy +3]", "[18O H]", "[B@ @", "[PH 3+]", "[b -]", "3 I]", "e +2]", "g +2]", "r +]", "t ]", "[C o+]", "[C f]", "[N b]", "[S m]", "[B e+2]", "[P b+]", "[P b+2]", "[A t]", "[H o]", "[H g+2]", "[T c", "[Y +]", "[L a]", "[Zr +]", "[s i]", "[Fe +]", "[Al -]", "[Ni +3]", "[14 c]", "[14 cH]", "[Sb +2]", "[Ge +]", "[12 3I]", "[BH 2-]", "[11C H2]", "[B@ -]", "[PH 2-]", "- 5]", "3 -2]", "X e]", "[ Xe]", "a +]", "m +3]", "u ]", "u +3]", "[C r", "[S H-]", "[S c+3]", "[Si @@H]", "[B H]", "[B a]", "[P a]", "[Cl ]", "[I H]", "[I +2]", "[I r+]", "[A r]", "[T a", "[T e+]", "[Y b+3]", "[R a]", "[L u+3]", "[G a+2]", "[s e+]", "[Ni +]", "[Hf +3]", "[V +3]", "[Ge @@]", "[15N H]", "[15N +]", "[E r]", "[p +]", "[Bi +]", "[Bi +3]", "[BH 3-2]", "[Nd ]", "[1C H3]", "[Ce +3]", "[B@@ -]", "+ 7", "2 P", "6 N", "O ]", "S ]", "[ 7", "e H]", "o +3]", "[N +2]", "[N o]", "[N b+3]", "[N b+2]", "Br ]", "[O H]", "[S r]", "[S @+]", "[S b+]", 
"[S m+3]", "[Si +4]", "[CH 3+]", "[B -5]", "[P o+]", "[P r+3]", "[P H3]", "[P b+3]", "[Cl +2]", "[I n+2]", "[H o+3]", "[1 7", "[1 H]", "[1 6N", "[T m]", "[T b+3]", "[T c+3]", "[T eH]", "[M o+]", "[U +3]", "[Cu +4]", "[s H+]", "[Fe -", "[Pd +3]", "[W +4]", "[3 2P", "[Ru +6]", "[V +]", "[Mo +6]", "[Se H2]", "[Se -2]", "[Ge H3]", "[14C H]", "[14C @]", "[14C H3]", "[15N H2]", "[E r+3]", "[Bi +2]", "[11C ]", "[Eu +2]", "[Ce +4]", "[B@@ H-]", "[Tc +4]", "[Ta +5]", "+7 ]", "[32P ]", "1 3", "1 I", "3 S", "4 -]", "5 S]", "5 Br]", "6 7", "6 S]", "6 Br]", "7 L", "9 P", "B i]", "C u]", "O H]", "[ 67", "c +]", "e +5]", "f +2]", "h +4]", "n -]", "n +5]", "p ]", "[C e]", "[C e+]", "[C r+]", "[C f+2]", "[N +3]", "[N p]", "[O +2]", "[S H3]", "[2 13", "[B a+]", "[P @+]", "[P u]", "[Cl H2+]", "[I r+2]", "[I n-]", "[A m]", "[1 c]", "[1 OH]", "[F ]", "[T a+3]", "[T m+3]", "[T h+4]", "[Br +2]", "[R e+]", "[R a+]", "[R e+5]", "[Pt +6]", "[M n+5]", "[U -]", "[G a+]", "[Cu -]", "[Cu +6]", "[Cu -5]", "[As H2+]", "[13C @]", "[Ni -2]", "[W +]", "[W -]", "[3 5S]", "[3 6S]", "[Ru -]", "[V +5]", "[14 9P", "[15 3S", "[Gd +2]", "[Os +5]", "[Os +6]", "[12 1I]", "[BH 4-]", "[11 O]", "[11 1I", "[1C ]", "[1C H]", "[Cd ]", "[99T c+3]", "[99T c+]", "[B@ H-]", "[PH 5]", "[Ce +2]", "[Tc +2]", "[Tc +7]", "[Cr -]", "[Cr +4]", "[Cr +7]", "[Ta +4]", "[7 5Br]", "[7 6Br]", "[17 O]", "[17 7L", "[16N ]", "[16N +]", "[Fe- 3]", "[Fe- 4]", "[67 Cu]", "[213 Bi]", "[149P m]", "[153S m]", "[111I n]", "[177L u]"]}, "custom_pre_tokenizer": true, "tokenizer_type": "bpe", "tokenizer_attrs": {"pad_token": "[PAD]", "cls_token": "[CLS]", "sep_token": "[SEP]", "mask_token": "[MASK]", "unk_token": "[UNK]", "eos_token": "[SEP]", "bos_token": "[CLS]"}}
--------------------------------------------------------------------------------
/mkdocs.yml:
--------------------------------------------------------------------------------
1 | site_name: "SAFE"
2 | site_description: "Gotta be SAFE: a new framework for molecular design"
3 | site_url: "https://github.com/datamol-io/safe"
4 | repo_url: "https://github.com/datamol-io/safe"
5 | repo_name: "datamol-io/safe"
6 | copyright: Copyright 2023 Valence Labs
7 |
8 | remote_branch: "gh-pages"
9 | docs_dir: "docs"
10 | use_directory_urls: false
11 | strict: true
12 |
13 | nav:
14 | - Overview: index.md
15 | - Tutorials:
16 | - Getting Started: tutorials/getting-started.ipynb
17 | - Molecular design: tutorials/design-with-safe.ipynb
18 | - How it works: tutorials/how-it-works.ipynb
19 | - Extracting representation (molfeat): tutorials/extracting-representation-molfeat.ipynb
20 | - API:
21 | - SAFE: api/safe.md
22 | - Visualization: api/safe.viz.md
23 | - Model training: api/safe.models.md
24 | - CLI: cli.md
25 | - License: license.md
26 | - Data License: data_license.md
27 |
28 | theme:
29 | name: material
30 | features:
31 | - navigation.expand
32 |
33 | extra_javascript:
34 | - assets/js/google-analytics.js
35 |
36 | markdown_extensions:
37 | - admonition
38 | - markdown_include.include
39 | - pymdownx.emoji
40 | - pymdownx.highlight
41 | - pymdownx.magiclink
42 | - pymdownx.superfences
43 | - pymdownx.tabbed
44 | - pymdownx.tasklist
45 | - pymdownx.details
46 | # For `tab_length=2` in the markdown extension
47 | # See https://github.com/mkdocs/mkdocs/issues/545
48 | - mdx_truly_sane_lists
49 | - toc:
50 | permalink: true
51 |
52 | watch:
53 | - safe/
54 |
55 | plugins:
56 | - search
57 | - mkdocstrings:
58 | handlers:
59 | python:
60 | import:
61 | - https://docs.python.org/3/objects.inv
62 | setup_commands:
63 | - import sys
64 | - import safe
65 | - sys.path.append("docs")
66 | - sys.path.append("safe")
67 | selection:
68 | new_path_syntax: true
69 | rendering:
70 | show_root_heading: false
71 | heading_level: 2
72 | show_if_no_docstring: true
73 | options:
74 | docstring_options:
75 | ignore_init_summary: false
76 | docstring_section_style: list
77 | merge_init_into_class: true
78 | show_root_heading: false
79 | show_root_full_path: false
80 | show_signature_annotations: true
81 | show_symbol_type_heading: true
82 | show_symbol_type_toc: true
83 | signature_crossrefs: true
84 |
85 | - mkdocs-jupyter:
86 | execute: False
87 | remove_tag_config:
88 | remove_cell_tags: [remove_cell]
89 | remove_all_outputs_tags: [remove_output]
90 | remove_input_tags: [remove_input]
91 |
92 | - mike:
93 | version_selector: true
94 |
95 | extra:
96 | version:
97 | provider: mike
98 |
99 | social:
100 | - icon: fontawesome/brands/github
101 | link: https://github.com/datamol-io
102 | - icon: fontawesome/brands/twitter
103 | link: https://twitter.com/datamol_io
104 | - icon: fontawesome/brands/python
105 | link: https://pypi.org/project/safe-mol/
106 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools", "setuptools-scm"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "safe-mol"
7 | description = "Implementation of the 'Gotta be SAFE: a new framework for molecular design' paper"
8 | dynamic = ["version"]
9 | authors = [{ name = "Emmanuel Noutahi", email = "emmanuel.noutahi@gmail.com" }]
10 | readme = "README.md"
11 | license = { text = "Apache-2.0" }
12 | requires-python = ">=3.9"
13 | classifiers = [
14 | "Development Status :: 5 - Production/Stable",
15 | "Intended Audience :: Developers",
16 | "Intended Audience :: Healthcare Industry",
17 | "Intended Audience :: Science/Research",
18 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
19 | "Topic :: Scientific/Engineering :: Bio-Informatics",
20 | "Topic :: Scientific/Engineering :: Information Analysis",
21 | "Topic :: Scientific/Engineering :: Medical Science Apps.",
22 | "Natural Language :: English",
23 | "Operating System :: OS Independent",
24 | "Programming Language :: Python",
25 | "Programming Language :: Python :: 3",
26 |     "Programming Language :: Python :: 3.9",
28 | "Programming Language :: Python :: 3.10",
29 | "Programming Language :: Python :: 3.11",
30 | ]
31 |
32 | keywords = ["safe", "smiles", "de novo", "design", "molecules"]
33 | dependencies = [
34 | "tqdm",
35 | "loguru",
36 | "typer",
37 | "universal_pathlib",
38 | "datamol",
39 | "numpy",
40 | "torch>=2.0",
41 | "transformers",
42 | "datasets",
43 | "tokenizers",
44 | "accelerate",
45 | "evaluate",
46 | "wandb",
47 | "huggingface-hub",
48 | "rdkit"
49 | ]
50 |
51 | [project.urls]
52 | "Source Code" = "https://github.com/datamol-io/safe"
53 | "Bug Tracker" = "https://github.com/datamol-io/safe/issues"
54 | Documentation = "https://safe-docs.datamol.io/"
55 |
56 |
57 | [project.scripts]
58 | safe-train = "safe.trainer.cli:main"
59 |
60 | [tool.setuptools]
61 | include-package-data = true
62 | zip-safe = false
63 | license-files = ["LICENSE"]
64 |
65 | [tool.setuptools_scm]
66 | fallback_version = "dev"
67 |
68 | [tool.setuptools.packages.find]
69 | where = ["."]
70 | include = ["safe", "safe.*"]
71 | exclude = []
72 | namespaces = true
73 |
74 | [tool.black]
75 | line-length = 100
76 | target-version = ['py310', 'py311']
77 | include = '\.pyi?$'
78 |
79 | [tool.pytest.ini_options]
80 | minversion = "6.0"
81 | addopts = "--verbose --color yes"
82 | testpaths = ["tests"]
83 |
84 |
85 | [tool.ruff]
86 | line-length = 120
87 | # Enable Pyflakes `E` and `F` codes by default.
88 | lint.select = [
89 | "E",
90 | "W", # see: https://pypi.org/project/pycodestyle
91 | "F", # see: https://pypi.org/project/pyflakes
92 | ]
93 | lint.extend-select = [
94 | "C4", # see: https://pypi.org/project/flake8-comprehensions
95 | "SIM", # see: https://pypi.org/project/flake8-simplify
96 | "RET", # see: https://pypi.org/project/flake8-return
97 | "PT", # see: https://pypi.org/project/flake8-pytest-style
98 | ]
99 | lint.ignore = [
100 | "E731", # Do not assign a lambda expression, use a def
101 | "S108",
102 | "F401",
103 | "S105",
104 | "E501",
105 | "E722",
106 | ]
107 | # Exclude a variety of commonly ignored directories.
108 | exclude = [".git", "docs", "_notebooks"]
109 | lint.ignore-init-module-imports = true
110 |
--------------------------------------------------------------------------------
/safe/__init__.py:
--------------------------------------------------------------------------------
1 | from . import trainer, utils
2 | from ._exception import SAFEDecodeError, SAFEEncodeError, SAFEFragmentationError
3 | from .converter import SAFEConverter, decode, encode
4 | from .sample import SAFEDesign
5 | from .tokenizer import SAFETokenizer, split
6 | from .viz import to_image
7 | from .io import upload_to_wandb
8 |
--------------------------------------------------------------------------------
/safe/_exception.py:
--------------------------------------------------------------------------------
class SAFEDecodeError(Exception):
    """Signals that a SAFE string could not be decoded with the given encoding."""
5 |
6 |
class SAFEEncodeError(Exception):
    """Signals that a molecule could not be encoded into the SAFE representation."""
11 |
12 |
class SAFEFragmentationError(Exception):
    """Signals that the slicing algorithm returned an empty set of bonds."""
17 |
--------------------------------------------------------------------------------
/safe/_pattern.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 | from typing import List
3 | from typing import Optional
4 | import re
5 | import datamol as dm
6 | import torch
7 | import numpy as np
8 | import safe as sf
9 | import torch.nn.functional as F
10 | import transformers
11 |
12 | from loguru import logger
13 | from tqdm.auto import tqdm
14 | import contextlib
15 |
16 |
17 | class PatternConstraint:
18 | """
19 | Sampling decorator for pretrained SAFE models.
20 | This implementation is inspired by the Sanofi decorator with:
21 | 1. new generalization to different tokenizers
22 | 2. support for a subset of smarts notations
23 | 3. speed improvements by dropping unnecessary steps
24 |
25 | !!! note
26 | For smarts based constraints, it's important to understand that the constraints
27 | are strong sampling suggestions and not necessarily the final result, meaning that they can
28 | fail.
29 |
30 | !!! warning
31 | Ring constraints should be interpreted as "Extended Ring System" constraints.
32 | Thus [r6] means an atom in an environment of 6 atoms and more that contains a ring system,
33 | instead of an atom in a ring of size 6.
34 | """
35 |
36 | ATTACHMENT_POINT_TOKEN = "\\*"
37 | ATTACHMENT_POINTS = [
38 | # parse any * not preceeded by "[" or ":" and not followed by "]" or ":" as attachment
39 | r"(? 0:
127 | random_mol = dm.randomize_atoms(random_mol)
128 | try:
129 | out.add(dm.to_smarts(random_mol))
130 | except Exception as e:
131 | logger.error(e)
132 | n -= 1
133 | return out
134 |
135 | def _find_ring_tokens(self):
136 | """Find all possible ring tokens in the vocab."""
137 | ring_token_ids = []
138 | for tk, tk_ids in self.tokenizer.tokenizer.get_vocab().items():
139 | try:
140 | _ = int(tk.lstrip("%"))
141 | ring_token_ids.append(tk_ids)
142 | except ValueError:
143 | pass
144 | return ring_token_ids
145 |
146 | def _prepare_scaffold(self):
147 | """Prepare scaffold for decoration."""
148 | return self.input_scaffold
149 |
150 | def is_branch_closer(self, pos_or_token: Union[int, str]):
151 | """Check whether a token is a branch closer."""
152 | if isinstance(pos_or_token, int):
153 | return self.tokens[pos_or_token] == self.branch_closer
154 | return pos_or_token == self.branch_closer
155 |
156 | def is_branch_opener(self, pos_or_token: Union[int, str]):
157 | """Check whether a token is a branch opener."""
158 | if isinstance(pos_or_token, int):
159 | return self.tokens[pos_or_token] == self.branch_opener
160 | return pos_or_token == self.branch_opener
161 |
162 | def __len__(self):
163 | """Get length of the tokenized scaffold."""
164 | if not self._is_initialized:
165 | raise ValueError("Decorator is not initialized yet")
166 | return len(self.tokens)
167 |
    def _initialize(self):
        """
        Initialize the current scaffold decorator object with the scaffold object.

        Tokenizes the scaffold with the pretrained tokenizer's pre-tokenizer, then
        pre-computes for every position the boolean mask of vocabulary tokens that
        may be sampled there, plus the action ("attach" or "constraint") needed for
        tokens unknown to the vocabulary. ``_is_initialized`` is only set to True
        after everything has been computed and validated.
        """
        self._is_initialized = False
        pretrained_tokenizer = self.tokenizer.get_pretrained()
        # Pre-tokenize the scaffold directly (instead of encoding it) so that
        # out-of-vocab constraint tokens are preserved as-is, then append EOS.
        encoding_tokens = [
            token
            for token, _ in pretrained_tokenizer._tokenizer.pre_tokenizer.pre_tokenize_str(
                self.scaffold
            )
        ] + [pretrained_tokenizer.eos_token]
        encoding_token_ids = [
            pretrained_tokenizer.convert_tokens_to_ids(x) for x in encoding_tokens
        ]
        # encodings.tokens contains BOS and EOS

        self.branch_opener_id = self.tokenizer.tokenizer.token_to_id(self.branch_opener)
        self.branch_closer_id = self.tokenizer.tokenizer.token_to_id(self.branch_closer)
        linker_size = {}  # position -> required ring/linker size parsed from ring constraints
        # convert the full vocab into mol constraints
        # (entries are None for tokens that cannot be parsed as molecular patterns)
        vocab_as_constraints = [
            self._parse_token_as_mol(self.tokenizer.tokenizer.id_to_token(i))
            for i in range(len(self.tokenizer))
        ]
        token_masks = {}  # scaffold token -> boolean mask over the vocabulary
        actions = {}  # scaffold token -> "attach" or "constraint"
        tokens = []  # scaffold tokens, in order
        ids = []  # matching token ids, in order
        all_tokens = self.tokenizer.tokenizer.get_vocab().keys()
        unk_token_id = pretrained_tokenizer.unk_token_id
        unknown_token_map = {}  # position -> standardized attachment token
        # we include the stop token
        for pos in range(len(encoding_token_ids)):
            token_id = encoding_token_ids[pos]
            token = encoding_tokens[pos]

            # if it is not an unknown token, then it can just be rollout
            # note that we are not using all special tokens, just unknown
            # and as such, you would need to have a well defined vocab
            cur_mask = torch.ones(len(vocab_as_constraints), dtype=torch.bool)
            constraints = None
            ring_token = None
            if token_id == unk_token_id:
                # this here can be one of the case we want:
                # we need to check if the token is a ring and whether it has other constraints.
                ring_match = re.match(self.HAS_RING_TOKEN, token)
                if ring_match and token.count("r") != 1:
                    raise ValueError("Multiple ring constraints in a single token is not supported")
                if ring_match:
                    # An empty ring size (bare "r") falls back to the configured minimum.
                    ring_size = ring_match.group(2).strip("r:")
                    ring_size = int(ring_size) if ring_size else self.min_ring_size
                    linker_size[pos - 1] = ring_size
                    # since we have filled the ring constraints already, we need to remove it from the token format
                    ring_token = self._remove_ring_constraint(token)
                if self._is_attachment(token) or (
                    ring_token is not None and self._is_attachment(ring_token)
                ):
                    # the mask would be handled by the attachment algorithm in decorate
                    actions[token] = "attach"
                    unknown_token_map[pos] = sf.utils.standardize_attach(token)
                elif self._is_constraint(token, all_tokens):
                    actions[token] = "constraint"
                    constraints = self._parse_token_as_mol(token)
                    if constraints is not None:
                        # Only vocabulary tokens matching the constraint pattern stay allowed.
                        cur_mask = self._mask_tokens(constraints, vocab_as_constraints)
            else:
                # this means that we need to sample the exact token
                # and disallow all the other
                cur_mask = torch.zeros(len(vocab_as_constraints), dtype=torch.bool)
                cur_mask[token_id] = 1
            token_masks[token] = cur_mask
            tokens.append(token)
            ids.append(token_id)
        self.linker_size = linker_size
        self.token_masks = token_masks
        self.actions = actions
        self.tokens = tokens
        self.ids = ids
        self._is_initialized = True
        self.unknown_token_map = unknown_token_map
        self.ring_token_ids = self._find_ring_tokens()
251 |
252 | def _is_attachment(self, token: str):
253 | """Check whether a token is should be an attachment or not"""
254 | # What I define as attachment is a token that is not a constraint and not a ring
255 | # basically the classic "[*]" written however you like it
256 | return any(re.match(attach_regex, token) for attach_regex in self.ATTACHMENT_POINTS)
257 |
258 | def _is_constraint(self, token: str, vocab_list: Optional[List[str]] = None):
259 | """Check whether a token is a constraint
260 | Args:
261 | token: token to check whether
262 | vocab: optional vocab to check against
263 | """
264 | if vocab_list is None:
265 | vocab_list = []
266 | tk_constraints = re.match(self.IS_CONSTRAINT, token) or token not in vocab_list
267 | return tk_constraints is not None
268 |
269 | def _remove_ring_constraint(self, token):
270 | """Remove ring constraints from a token"""
271 | token = re.sub(r"((&|,|;)?r\d*)*(&|,|;)?", r"\3", token)
272 | token = re.sub(r"(\[[&;,]?)(.*)([&,;]?\])", r"[\2]", token)
273 | if token == "[]":
274 | return "[*]"
275 | return token
276 |
    def _parse_token_as_mol(self, token):
        """
        Parse a token as a valid molecular pattern.

        Tries, in order: (1) parse the token directly as SMARTS; (2) on failure,
        parse it as a sanitized SMILES; (3) if still unresolved, parse without
        sanitization and round-trip through a SMILES->SMARTS conversion.
        Returns None when every strategy fails. RDKit logging is silenced for
        the whole attempt.
        """
        tk_mol = None
        with dm.without_rdkit_log():
            try:
                tk_mol = dm.from_smarts(token)
            except:
                # SMARTS parsing raised: fall back to parsing as a SMILES.
                tk_mol = dm.to_mol(token, sanitize=True)
            # didn't work, try with second strategy
            if tk_mol is None:
                tk_mol = dm.to_mol(token, sanitize=False)
                try:
                    tk_mol = dm.from_smarts(dm.smiles_as_smarts(tk_mol))
                except:
                    # Give up: the token cannot be expressed as a molecular pattern.
                    tk_mol = None
        return tk_mol
295 |
296 | def _mask_tokens(self, constraint: List[dm.Mol], vocab_mols: List[dm.Mol]):
297 | """Mask the prediction to enforce some constraints
298 |
299 | Args:
300 | constraint: constraint found in the scaffold
301 | vocab_mols: list of mol queries (convertible ones) from the vocab
302 |
303 | Returns:
304 | mask: mask for valid tokens that match constraints. 1 means keep, 0 means mask
305 | """
306 | mask = torch.zeros(len(vocab_mols), dtype=torch.bool)
307 | with dm.without_rdkit_log():
308 | for ind, tk_mol in enumerate(vocab_mols):
309 | if tk_mol is not None:
310 | with contextlib.suppress(Exception):
311 | mask[ind] = int(tk_mol.HasSubstructMatch(constraint))
312 | return mask
313 |
314 |
class PatternSampler:
    """
    Implements a pattern-constrained sequence sampler for Autoregressive transformer models using a PatternConstraint.

    Args:
        model: Pretrained model used for generation.
        pattern_decorator: The PatternConstraint object that provides the scaffold and sampling constraints.
        min_linker_size: Minimum size of the linker.
        max_steps_to_eos: Maximum steps to end-of-sequence token.
        max_length: Maximum length of the generated sequences.
    """

    def __init__(
        self,
        model,
        pattern_decorator,
        min_linker_size: int = 3,
        max_steps_to_eos: int = 50,
        max_length: int = 128,
    ):
        self.model = model
        self.pattern_decorator = pattern_decorator
        # Underlying pretrained (HF-style) tokenizer wrapped by the decorator's tokenizer.
        self.tokenizer = self.pattern_decorator.tokenizer.get_pretrained()
        self.min_linker_size = min_linker_size
        self.max_steps_to_eos = max_steps_to_eos
        # Sampling only: keep the model in eval mode for the sampler's lifetime.
        self.model.eval()
        self.max_length = max_length

    def nll_loss(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Custom Negative Log Likelihood (NLL) loss that returns loss per example.

        NOTE(review): the returned value is `sum(one_hot(target) * log_prob)`,
        i.e. the signed (non-positive) log-probability of the target rather than
        its negation — `_generate` accumulates it as-is; confirm intended sign.

        Args:
            inputs (torch.Tensor): Log probabilities of each class, shape (batch_size, num_classes).
            targets (torch.Tensor): Target class indices, shape (batch_size).

        Returns:
            torch.Tensor: Loss for each example, shape (batch_size).
        """
        # One-hot encode the targets (on GPU when available, matching model outputs).
        target_expanded = (
            torch.zeros(inputs.size()).cuda()
            if torch.cuda.is_available()
            else torch.zeros(inputs.size())
        )
        target_expanded.scatter_(1, targets.contiguous().view(-1, 1).data, 1.0)
        # The one-hot mask selects the log-probability of each target class.
        loss = target_expanded * inputs
        return torch.sum(loss, dim=1)

    def sample_scaffolds(
        self, n_samples: int = 100, n_trials: int = 1, random_seed: Optional[int] = None
    ):
        """
        Sample a batch of sequences based on the scaffold provided by the PatternConstraint.

        Args:
            n_samples: Number of sequences to sample.
            n_trials: Number of sampling trials to perform.
            random_seed: Seed for random number generation.

        Returns:
            List[str]: List of sampled sequences as strings.
        """
        if random_seed is not None:
            torch.manual_seed(random_seed)

        sampled_mols = []
        for _ in tqdm(range(n_trials), leave=False):
            # `_generate` also returns per-sequence log-probs and entropies;
            # only the token ids are needed here.
            generated_sequences, *_ = self._generate(n_samples)
            sampled_mols.extend(
                [
                    self._as_scaffold(self.tokenizer.decode(seq, skip_special_tokens=False))
                    for seq in generated_sequences
                ]
            )
        return sampled_mols

    def _as_scaffold(self, scaff: str) -> str:
        """
        Converts the generated sequence to a valid scaffold by replacing unknown tokens.

        Args:
            scaff: The generated sequence string.

        Returns:
            str: The scaffold string with unknown tokens replaced.
        """
        out = scaff.replace(self.tokenizer.eos_token, "")
        # Re-split with the pre-tokenizer so positions line up with the indices
        # recorded in the decorator's `unknown_token_map`.
        splitted_out = [
            token for token, _ in self.tokenizer._tokenizer.pre_tokenizer.pre_tokenize_str(out)
        ]
        for pos, query in self.pattern_decorator.unknown_token_map.items():
            splitted_out[pos] = query
        return "".join(splitted_out)

    def _generate(self, batch_size: int, max_length: Optional[int] = None):
        """
        Generate sequences with custom constraints using the model and PatternConstraint.

        Args:
            batch_size: Number of sequences to generate.
            max_length: Maximum length of the sequence.

        Returns:
            Tuple: Generated sequences, log probabilities, and entropies.
        """
        sequences = []
        if max_length is None:
            max_length = self.max_length

        start_token = torch.full((batch_size, 1), self.tokenizer.bos_token_id, dtype=torch.long)
        finished = torch.zeros(batch_size, dtype=bool)
        log_probs = torch.zeros(batch_size)
        entropy = torch.zeros(batch_size)

        if torch.cuda.is_available():
            log_probs = log_probs.cuda()
            entropy = entropy.cuda()
            start_token = start_token.cuda()
            finished = finished.cuda()

        # NOTE(review): `max_dec_steps` only feeds this sanity check; the loop
        # below still iterates up to `max_length` — confirm that is intended.
        max_dec_steps = max_length - len(self.pattern_decorator)
        if max_dec_steps < 0:
            raise ValueError("Step size negative due to scaffold being longer than max_length")

        input_ids = start_token
        trackers = torch.zeros(batch_size, dtype=torch.int)  # Tracks the position in the scaffold

        for step in range(max_length):
            current_tokens = [self.pattern_decorator.tokens[index] for index in trackers]
            # Per-sequence action for the current scaffold position; note that
            # `setdefault` also registers "roll" in the decorator's `actions`
            # dict for tokens not seen before (mutates shared state).
            action_i = [
                self.pattern_decorator.actions.setdefault(
                    self.pattern_decorator.tokens[index], "roll"
                )
                for index in trackers
            ]

            # Pass through model
            outputs = self.model(input_ids)
            logits = outputs.logits[:, -1, :]
            log_prob = torch.log_softmax(logits, dim=-1)
            probs = log_prob.exp()

            # Default next token: free sampling from the model distribution.
            decoder_input = torch.multinomial(probs, num_samples=1).squeeze(1).view(-1)

            for i in range(batch_size):
                if action_i[i] == "constraint":
                    # Restrict sampling to vocab entries allowed by this token's mask.
                    mask = self.pattern_decorator.token_masks[current_tokens[i]].to(
                        input_ids.device
                    )
                    prob_i = self.pattern_decorator._logprobs_to_probs(log_prob[i, :], mask)
                    if prob_i.sum() == 0:
                        # Model assigns no mass to any allowed token: fall back
                        # to a uniform pick among the allowed indices.
                        random_choice = torch.nonzero(mask.squeeze()).squeeze().cpu().numpy()
                        decoder_input[i] = np.random.choice(random_choice)
                    else:
                        decoder_input[i] = torch.multinomial(prob_i, num_samples=1).view(-1)
                    trackers[i] += int(decoder_input[i] != self.tokenizer.eos_token_id)
                else:
                    # "roll": emit the fixed scaffold token at this position.
                    decoder_input[i] = self.pattern_decorator.ids[trackers[i].item()]
                    trackers[i] += int(decoder_input[i] != self.tokenizer.eos_token_id)

            sequences.append(decoder_input.unsqueeze(-1))
            input_ids = torch.cat((input_ids, decoder_input.unsqueeze(-1)), dim=1)
            log_probs += self.nll_loss(log_prob, decoder_input)
            entropy += -torch.sum(log_prob * probs, dim=-1)

            # A sequence is finished once it has ever emitted EOS; stop early
            # when every sequence in the batch is finished.
            eos_sampled = (decoder_input == self.tokenizer.eos_token_id).bool()
            finished = torch.ge(finished + eos_sampled, 1)
            if torch.prod(finished) == 1:
                break
        sequences = torch.cat(sequences, dim=1)
        return sequences, log_probs, entropy
486 |
--------------------------------------------------------------------------------
/safe/converter.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | import re
3 | from collections import Counter
4 | from contextlib import suppress
5 | from typing import Callable, List, Optional, Union
6 |
7 | import datamol as dm
8 | import numpy as np
9 | from rdkit import Chem
10 | from rdkit.Chem import BRICS
11 | from loguru import logger
12 |
13 | from ._exception import SAFEDecodeError, SAFEEncodeError, SAFEFragmentationError
14 | from .utils import standardize_attach
15 |
16 |
17 | class SAFEConverter:
18 | """Molecule line notation conversion from SMILES to SAFE
19 |
20 | A SAFE representation is a string based representation of a molecule decomposition into fragment components,
21 | separated by a dot ('.'). Note that each component (fragment) might not be a valid molecule by themselves,
22 | unless explicitely correct to add missing hydrogens.
23 |
24 | !!! note "Slicing algorithms"
25 |
26 | By default SAFE strings are generated using `BRICS`, however, the following alternative are supported:
27 |
28 | * [Hussain-Rea (`hr`)](https://pubs.acs.org/doi/10.1021/ci900450m)
29 | * [RECAP (`recap`)](https://pubmed.ncbi.nlm.nih.gov/9611787/)
30 | * [RDKit's MMPA (`mmpa`)](https://www.rdkit.org/docs/source/rdkit.Chem.rdMMPA.html)
31 | * Any possible attachment points (`attach`)
32 |
33 | Furthermore, you can also provide your own slicing algorithm, which should return a pair of atoms
34 | corresponding to the bonds to break.
35 |
36 | """
37 |
38 | SUPPORTED_SLICERS = ["hr", "rotatable", "recap", "mmpa", "attach", "brics"]
39 | __SLICE_SMARTS = {
40 | "hr": ["[*]!@-[*]"], # any non ring single bond
41 | "recap": [
42 | "[$([C;!$(C([#7])[#7])](=!@[O]))]!@[$([#7;+0;!D1])]",
43 | "[$(C=!@O)]!@[$([O;+0])]",
44 | "[$([N;!D1;+0;!$(N-C=[#7,#8,#15,#16])](-!@[*]))]-!@[$([*])]",
45 | "[$(C(=!@O)([#7;+0;D2,D3])!@[#7;+0;D2,D3])]!@[$([#7;+0;D2,D3])]",
46 | "[$([O;+0](-!@[#6!$(C=O)])-!@[#6!$(C=O)])]-!@[$([#6!$(C=O)])]",
47 | "C=!@C",
48 | "[N;+1;D4]!@[#6]",
49 | "[$([n;+0])]-!@C",
50 | "[$([O]=[C]-@[N;+0])]-!@[$([C])]",
51 | "c-!@c",
52 | "[$([#7;+0;D2,D3])]-!@[$([S](=[O])=[O])]",
53 | ],
54 | "mmpa": ["[#6+0;!$(*=,#[!#6])]!@!=!#[*]"], # classical mmpa slicing smarts
55 | "attach": ["[*]!@[*]"], # any potential attachment point, including hydrogens when explicit
56 | "rotatable": ["[!$(*#*)&!D1]-&!@[!$(*#*)&!D1]"],
57 | }
58 |
59 | def __init__(
60 | self,
61 | slicer: Optional[Union[str, List[str], Callable]] = "brics",
62 | require_hs: Optional[bool] = None,
63 | use_original_opener_for_attach: bool = True,
64 | ignore_stereo: bool = False,
65 | ):
66 | """Constructor for the SAFE converter
67 |
68 | Args:
69 | slicer: slicer algorithm to use for encoding.
70 | Can either be one of the supported slicing algorithm (SUPPORTED_SLICERS)
71 | or a custom callable that returns the bond ids that can be sliced.
72 | require_hs: whether the slicing algorithm require the molecule to have hydrogen explictly added.
73 | `attach` slicer requires adding hydrogens.
74 | use_original_opener_for_attach: whether to use the original branch opener digit when adding back
75 | mapping number to attachment points, or use simple enumeration.
76 | ignore_stereo: RDKIT does not support some particular SAFE subset when stereochemistry is defined.
77 |
78 | """
79 | self.slicer = slicer
80 | if isinstance(slicer, str) and slicer.lower() in self.SUPPORTED_SLICERS:
81 | self.slicer = self.__SLICE_SMARTS.get(slicer.lower(), slicer)
82 | if self.slicer != "brics" and isinstance(self.slicer, str):
83 | self.slicer = [self.slicer]
84 | if isinstance(self.slicer, (list, tuple)):
85 | self.slicer = [dm.from_smarts(x) for x in self.slicer]
86 | if any(x is None for x in self.slicer):
87 | raise ValueError(f"Slicer: {slicer} cannot be valid")
88 | self.require_hs = require_hs or (slicer == "attach")
89 | self.use_original_opener_for_attach = use_original_opener_for_attach
90 | self.ignore_stereo = ignore_stereo
91 |
92 | @staticmethod
93 | def randomize(mol: dm.Mol, rng: Optional[int] = None):
94 | """Randomize the position of the atoms in a mol.
95 |
96 | Args:
97 | mol: molecules to randomize
98 | rng: optional seed to use
99 | """
100 | if isinstance(rng, int):
101 | rng = np.random.default_rng(rng)
102 | if mol.GetNumAtoms() == 0:
103 | return mol
104 | atom_indices = list(range(mol.GetNumAtoms()))
105 | atom_indices = rng.permutation(atom_indices).tolist()
106 | return Chem.RenumberAtoms(mol, atom_indices)
107 |
108 | @classmethod
109 | def _find_branch_number(cls, inp: str):
110 | """Find the branch number and ring closure in the SMILES representation using regexp
111 |
112 | Args:
113 | inp: input smiles
114 | """
115 | inp = re.sub(r"\[.*?\]", "", inp) # noqa
116 | matching_groups = re.findall(r"((?<=%)\d{2})|((? 0:
305 | mol = Chem.FragmentOnBonds(
306 | mol,
307 | bonds,
308 | dummyLabels=[(i + bond_map_id, i + bond_map_id) for i in range(len(bonds))],
309 | )
310 | # here we need to be clever and disable rooted atom as the atom with mapping
311 |
312 | frags = list(Chem.GetMolFrags(mol, asMols=True))
313 | if randomize:
314 | frags = rng.permutation(frags).tolist()
315 | elif canonical:
316 | frags = sorted(
317 | frags,
318 | key=lambda x: x.GetNumAtoms(),
319 | reverse=True,
320 | )
321 |
322 | frags_str = []
323 | for frag in frags:
324 | non_map_atom_idxs = [
325 | atom.GetIdx() for atom in frag.GetAtoms() if atom.GetAtomicNum() != 0
326 | ]
327 | frags_str.append(
328 | Chem.MolToSmiles(
329 | frag,
330 | isomericSmiles=True,
331 | canonical=True, # needs to always be true
332 | rootedAtAtom=non_map_atom_idxs[0],
333 | )
334 | )
335 |
336 | scaffold_str = ".".join(frags_str)
337 | # EN: fix for https://github.com/datamol-io/safe/issues/37
338 | # we were using the wrong branch number count which did not take into account
339 | # possible change in digit utilization after bond slicing
340 | scf_branch_num = self._find_branch_number(scaffold_str) + branch_numbers
341 |
342 | # don't capture atom mapping in the scaffold
343 | attach_pos = set(re.findall(r"(\[\d+\*\]|!\[[^:]*:\d+\])", scaffold_str))
344 | if canonical:
345 | attach_pos = sorted(attach_pos)
346 | starting_num = 1 if len(scf_branch_num) == 0 else max(scf_branch_num) + 1
347 | for attach in attach_pos:
348 | val = str(starting_num) if starting_num < 10 else f"%{starting_num}"
349 | # we cannot have anything of the form "\([@=-#-$/\]*\d+\)"
350 | attach_regexp = re.compile(r"(" + re.escape(attach) + r")")
351 | scaffold_str = attach_regexp.sub(val, scaffold_str)
352 | starting_num += 1
353 |
354 | # now we need to remove all the parenthesis around digit only number
355 | wrong_attach = re.compile(r"\(([\%\d]*)\)")
356 | scaffold_str = wrong_attach.sub(r"\g<1>", scaffold_str)
357 | # furthermore, we autoapply rdkit-compatible digit standardization.
358 | if rdkit_safe:
359 | pattern = r"\(([=-@#\/\\]{0,2})(%?\d{1,2})\)"
360 | replacement = r"\g<1>\g<2>"
361 | scaffold_str = re.sub(pattern, replacement, scaffold_str)
362 | if not self.ignore_stereo and has_stereo_bonds and not dm.same_mol(scaffold_str, inp):
363 | logger.warning(
364 | "Ignoring stereo is disabled, but molecule has stereochemistry interferring with SAFE representation"
365 | )
366 | return scaffold_str
367 |
368 |
def encode(
    inp: Union[str, dm.Mol],
    canonical: bool = True,
    randomize: Optional[bool] = False,
    seed: Optional[int] = None,
    slicer: Optional[Union[List[str], str, Callable]] = None,
    require_hs: Optional[bool] = None,
    constraints: Optional[List[dm.Mol]] = None,
    ignore_stereo: Optional[bool] = False,
):
    """
    Convert input smiles to SAFE representation

    Args:
        inp: input smiles
        canonical: whether to return canonical SAFE string. Defaults to True
        randomize: whether to randomize the safe string encoding. Will be ignored if canonical is provided
        seed: optional seed to use when allowing randomization of the SAFE encoding.
        slicer: slicer algorithm to use for encoding. Defaults to "brics".
        require_hs: whether the slicing algorithm require the molecule to have hydrogen explictly added.
        constraints: List of molecules or pattern to preserve during the SAFE construction.
        ignore_stereo: RDKIT does not support some particular SAFE subset when stereochemistry is defined.
    """
    if slicer is None:
        slicer = "brics"
    with dm.without_rdkit_log():
        converter = SAFEConverter(slicer=slicer, require_hs=require_hs, ignore_stereo=ignore_stereo)
        try:
            return converter.encoder(
                inp,
                canonical=canonical,
                randomize=randomize,
                constraints=constraints,
                seed=seed,
            )
        except SAFEFragmentationError:
            # Fragmentation failures carry their own error type: propagate as-is.
            raise
        except Exception as e:
            # Everything else is wrapped so callers see a single encode error type.
            raise SAFEEncodeError(f"Failed to encode {inp} with {slicer}") from e
409 |
410 |
def decode(
    safe_str: str,
    as_mol: bool = False,
    canonical: bool = False,
    fix: bool = True,
    remove_added_hs: bool = True,
    remove_dummies: bool = True,
    ignore_errors: bool = False,
):
    """Convert input SAFE representation to smiles
    Args:
        safe_str: input SAFE representation to decode as a valid molecule or smiles
        as_mol: whether to return a molecule object or a smiles string
        canonical: whether to return a canonical smiles or a randomized smiles
        fix: whether to fix the SAFE representation to take into account non-connected attachment points
        remove_added_hs: whether to remove the hydrogen atoms that have been added to fix the string.
        remove_dummies: whether to remove dummy atoms from the SAFE representation
        ignore_errors: whether to ignore error and return None on decoding failure or raise an error

    """
    with dm.without_rdkit_log():
        converter = SAFEConverter()
        try:
            return converter.decoder(
                safe_str,
                as_mol=as_mol,
                canonical=canonical,
                fix=fix,
                remove_dummies=remove_dummies,
                remove_added_hs=remove_added_hs,
            )
        except Exception as e:
            # Either swallow the failure (best-effort mode) or surface a
            # decode-specific error with the original cause chained.
            if ignore_errors:
                return None
            raise SAFEDecodeError(f"Failed to decode {safe_str}") from e
448 |
--------------------------------------------------------------------------------
/safe/io.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List
2 |
3 | import tempfile
4 | import os
5 | import contextlib
6 | import torch
7 | import wandb
8 | import fsspec
9 |
10 | from transformers import PreTrainedModel, is_torch_available
11 | from transformers.processing_utils import PushToHubMixin
12 |
13 |
def upload_to_wandb(
    model: PreTrainedModel,
    tokenizer,
    artifact_name: str,
    wandb_project_name: Optional[str] = "safe-models",
    artifact_type: str = "model",
    slicer: Optional[str] = None,
    aliases: Optional[List[str]] = None,
    **init_args,
):
    """
    Uploads a model and tokenizer to a specified Weights and Biases (wandb) project.

    The model and tokenizer are serialized into a temporary directory, which is
    then logged as a single wandb artifact. The `SAFE_WANDB_PROJECT` environment
    variable, when set, overrides `wandb_project_name`.

    Args:
        model (PreTrainedModel): The model to be uploaded (instance of PreTrainedModel).
        tokenizer: The tokenizer associated with the model.
        artifact_name (str): The name of the wandb artifact to create.
        wandb_project_name (Optional[str]): The name of the wandb project. Defaults to 'safe-models'.
        artifact_type (str): The type of artifact (e.g., 'model'). Defaults to 'model'.
        slicer (Optional[str]): Optional metadata field that can store a slicing method.
        aliases (Optional[List[str]]): List of aliases to assign to this artifact version.
        **init_args: Additional arguments to pass into `wandb.init()`.
    """

    with tempfile.TemporaryDirectory() as tmpdirname:
        # Paths to save model and tokenizer
        model_path = tokenizer_path = tmpdirname
        # Human-readable dump of the model architecture, bundled with the artifact.
        architecture_file = os.path.join(tmpdirname, "architecture.txt")
        with fsspec.open(architecture_file, "w+") as f:
            f.write(str(model))

        model.save_pretrained(model_path)
        # Best-effort tokenizer serialization: some tokenizer implementations
        # only support one of the two save APIs, so failures are ignored.
        with contextlib.suppress(Exception):
            tokenizer.save_pretrained(tokenizer_path)
            tokenizer.save(os.path.join(tokenizer_path, "tokenizer.json"))

        # Metadata recorded in the run config.
        info_dict = {"slicer": slicer}
        model_config = None
        if hasattr(model, "config") and model.config is not None:
            model_config = (
                model.config.to_dict() if not isinstance(model.config, dict) else model.config
            )
            info_dict.update(model_config)

        # Include PEFT adapter configuration when the model carries one.
        if hasattr(model, "peft_config") and model.peft_config is not None:
            info_dict.update({"peft_config": model.peft_config})

        with contextlib.suppress(Exception):
            info_dict["model/num_parameters"] = model.num_parameters()

        # Caller-supplied `config` in init_args takes precedence over info_dict.
        init_args.setdefault("config", info_dict)
        run = wandb.init(project=os.getenv("SAFE_WANDB_PROJECT", wandb_project_name), **init_args)

        artifact = wandb.Artifact(
            name=artifact_name,
            type=artifact_type,
            metadata={
                "model_config": model_config,
                "num_parameters": info_dict.get("model/num_parameters"),
                "initial_model": True,
            },
        )

        # Add model and tokenizer directories to the artifact
        artifact.add_dir(tmpdirname)
        run.log_artifact(artifact, aliases=aliases)

        # Finish the wandb run
        run.finish()
83 |
--------------------------------------------------------------------------------
/safe/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | from .. import utils
2 | from . import model
3 |
--------------------------------------------------------------------------------
/safe/trainer/cli.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import sys
4 | import uuid
5 | from dataclasses import dataclass, field
6 | from typing import Literal, Optional
7 |
8 | import datasets
9 | import evaluate
10 | import torch
11 | import transformers
12 | from loguru import logger
13 | from transformers import AutoConfig, AutoTokenizer, TrainingArguments, set_seed
14 | from transformers.trainer_utils import get_last_checkpoint
15 | from transformers.utils.logging import log_levels as LOG_LEVELS
16 |
17 | import safe
18 | from safe.tokenizer import SAFETokenizer
19 | from safe.trainer.collator import SAFECollator
20 | from safe.trainer.data_utils import get_dataset
21 | from safe.trainer.model import SAFEDoubleHeadsModel
22 | from safe.trainer.trainer_utils import SAFETrainer
23 |
# Absolute path of the `safe.trainer` sub-package; used to locate the bundled default config.
CURRENT_DIR = os.path.join(safe.__path__[0], "trainer")
25 |
26 |
@dataclass
class ModelArguments:
    """Command-line arguments describing the model, tokenizer and logging setup."""

    model_path: str = field(
        default=None,
        metadata={
            "help": "Optional model path or model name to use as a starting point for the safe model"
        },
    )
    config: Optional[str] = field(
        default=None, metadata={"help": "Path to the default config file to use for the safe model"}
    )
    # Bug fix: this field was previously wrapped in a one-element tuple (stray
    # parentheses and trailing comma around `field(...)`), so the dataclass
    # default was `(Field(...),)` instead of None.
    tokenizer: str = field(
        default=None,
        metadata={"help": "Path to the trained tokenizer to use to build a safe model"},
    )
    num_labels: Optional[int] = field(
        default=None, metadata={"help": "Optional number of labels for the descriptors"}
    )
    include_descriptors: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to train with descriptors if they are available or Not"},
    )
    # Typo fix in the help string: "propery" -> "property".
    prop_loss_coeff: Optional[float] = field(
        default=1e-2, metadata={"help": "coefficient for the property loss"}
    )
    model_hub_name: Optional[str] = field(
        default="maclandrol/safe-gpt2",
        metadata={"help": "Name of the model when uploading to huggingface"},
    )

    wandb_project: Optional[str] = field(
        default="safe-gpt2",
        metadata={"help": "Name of the wandb project to use to log the SAFE model parameter"},
    )
    wandb_watch: Optional[Literal["gradients", "all"]] = field(
        default=None, metadata={"help": "Whether to watch the wandb models or not"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from s3"},
    )
    torch_dtype: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
                "dtype will be automatically derived from the model's weights."
            ),
            "choices": ["auto", "bfloat16", "float16", "float32"],
        },
    )

    low_cpu_mem_usage: bool = field(
        default=False,
        metadata={
            "help": (
                "It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded."
                "set True will benefit LLM loading time and RAM consumption. Only valid when loading a pretrained model"
            )
        },
    )
    model_max_length: int = field(
        default=1024,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated) up to that value."
        },
    )
97 |
98 |
@dataclass
class DataArguments:
    """Command-line arguments describing the dataset and how to read it."""

    dataset: str = field(
        default=None,
        metadata={"help": "Path to the preprocessed dataset to use for the safe model building"},
    )
    is_tokenized: Optional[bool] = field(
        default=False,
        metadata={"help": "whether the dataset submitted as input is already tokenized or not"},
    )
    streaming: Optional[bool] = field(
        default=False, metadata={"help": "Whether to use a streaming dataset or not"}
    )
    text_column: Optional[str] = field(
        default="inputs", metadata={"help": "Column containing text data to process."}
    )
    max_train_samples: Optional[int] = field(
        default=None, metadata={"help": "Maximum number of training sample to use."}
    )
    max_eval_samples: Optional[int] = field(
        default=None, metadata={"help": "Maximum number of evaluation sample to use."}
    )
    train_split: Optional[str] = field(
        default="train",
        metadata={
            "help": "Training splits to use. You can use train+validation for example to include both train and validation split in the training"
        },
    )
    property_column: Optional[str] = field(
        default=None,
        metadata={
            "help": "Column containing the descriptors information. Default to None to use `mc_labels`"
        },
    )
140 |
141 | def train(model_args, data_args, training_args):
142 | """Train a new model from scratch"""
143 | if training_args.should_log:
144 | # The default of training_args.log_level is passive, so we set log level at info here to have that default.
145 | transformers.utils.logging.set_verbosity_info()
146 |
147 | log_level = training_args.get_process_log_level()
148 | if log_level is None or log_level == "passive":
149 | log_level = "info"
150 |
151 | _LOG_LEVEL = {v: k for k, v in LOG_LEVELS.items()}
152 | logger.remove()
153 | logger.add(sys.stderr, level=_LOG_LEVEL.get(log_level, log_level).upper())
154 | transformers.utils.logging.set_verbosity(log_level)
155 | transformers.utils.logging.enable_default_handler()
156 | transformers.utils.logging.enable_explicit_format()
157 |
158 | # Log on each process the small summary:
159 | logger.info(
160 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu} "
161 | + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
162 | )
163 |
164 | logger.info(f"Training/evaluation parameters: {training_args}")
165 |
166 | # Detecting last checkpoint.
167 | last_checkpoint = None
168 | if (
169 | os.path.isdir(training_args.output_dir)
170 | and training_args.do_train
171 | and not training_args.overwrite_output_dir
172 | ):
173 | last_checkpoint = get_last_checkpoint(training_args.output_dir)
174 | if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
175 | logger.info(
176 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
177 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
178 | )
179 |
180 | # Check if parameter passed or if set within environ
181 | # Only overwrite environ if wandb param passed
182 | wandb_run_name = f"safe-model-{uuid.uuid4().hex[:8]}"
183 | if model_args.wandb_project:
184 | os.environ["WANDB_PROJECT"] = model_args.wandb_project
185 | training_args.report_to = ["wandb"]
186 | if model_args.wandb_watch:
187 | os.environ["WANDB_WATCH"] = model_args.wandb_watch
188 | if model_args.wandb_watch == "all":
189 | os.environ["WANDB_LOG_MODEL"] = "end"
190 |
191 | training_args.run_name = wandb_run_name
192 | training_args.remove_unused_columns = False
193 | # load tokenizer and model
194 |
195 | set_seed(training_args.seed)
196 | # load the tokenizer
197 | if model_args.tokenizer.endswith(".json"):
198 | tokenizer = SAFETokenizer.load(model_args.tokenizer)
199 | else:
200 | try:
201 | tokenizer = SAFETokenizer.load(model_args.tokenizer)
202 | except:
203 | tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer)
204 |
205 | # load dataset
206 | with training_args.main_process_first():
207 | # if the dataset is streaming we tokenize on the fly, it would be faster
208 | dataset = get_dataset(
209 | data_args.dataset,
210 | tokenizer=(None if data_args.is_tokenized or data_args.streaming else tokenizer),
211 | streaming=data_args.streaming,
212 | tokenize_column=data_args.text_column,
213 | max_length=model_args.model_max_length,
214 | property_column=data_args.property_column,
215 | )
216 |
217 | if data_args.max_train_samples is not None:
218 | dataset["train"] = dataset["train"].take(data_args.max_train_samples)
219 | if data_args.max_eval_samples is not None:
220 | for k in dataset:
221 | if k != "train":
222 | dataset[k] = dataset[k].take(data_args.max_eval_samples)
223 |
224 | eval_dataset_key_name = "validation" if "validation" in dataset else "test"
225 |
226 | train_dataset = dataset["train"]
227 | if data_args.train_split is not None:
228 | train_dataset = datasets.concatenate_datasets(
229 | [dataset[x] for x in data_args.train_split.split("+")]
230 | )
231 | if eval_dataset_key_name in data_args.train_split.split("+"):
232 | eval_dataset_key_name = None
233 |
234 | data_collator = SAFECollator(
235 | tokenizer=tokenizer,
236 | input_key=data_args.text_column,
237 | max_length=model_args.model_max_length,
238 | include_descriptors=model_args.include_descriptors,
239 | property_key="mc_labels",
240 | )
241 | pretrained_tokenizer = data_collator.get_tokenizer()
242 | config = model_args.config
243 |
244 | if config is None:
245 | config = os.path.join(CURRENT_DIR, "configs/default_config.json")
246 | config = AutoConfig.from_pretrained(config, cache_dir=model_args.cache_dir)
247 |
248 | if model_args.num_labels is not None:
249 | config.num_labels = int(model_args.num_labels)
250 |
251 | config.vocab_size = len(tokenizer)
252 | if model_args.model_max_length is not None:
253 | config.max_position_embeddings = model_args.model_max_length
254 | try:
255 | config.bos_token_id = tokenizer.bos_token_id
256 | config.eos_token_id = tokenizer.eos_token_id
257 | config.pad_token_id = tokenizer.pad_token_id
258 | except:
259 | config.bos_token_id = pretrained_tokenizer.bos_token_id
260 | config.eos_token_id = pretrained_tokenizer.eos_token_id
261 | config.pad_token_id = pretrained_tokenizer.pad_token_id
262 |
263 | if model_args.model_path is not None:
264 | torch_dtype = (
265 | model_args.torch_dtype
266 | if model_args.torch_dtype in ["auto", None]
267 | else getattr(torch, model_args.torch_dtype)
268 | )
269 | model = SAFEDoubleHeadsModel.from_pretrained(
270 | model_args.model_path,
271 | config=config,
272 | cache_dir=model_args.cache_dir,
273 | low_cpu_mem_usage=model_args.low_cpu_mem_usage,
274 | torch_dtype=torch_dtype,
275 | )
276 |
277 | else:
278 | model = SAFEDoubleHeadsModel(config)
279 |
280 | # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
281 | # on a small vocab and want a smaller embedding size, remove this test.
282 | embedding_size = model.get_input_embeddings().weight.shape[0]
283 | if len(tokenizer) > embedding_size:
284 | model.resize_token_embeddings(len(tokenizer))
285 |
286 | n_params = sum({p.data_ptr(): p.numel() for p in model.parameters()}.values())
287 | logger.info(f"Training new model from scratch - Total size={n_params/2**20:.2f}M params")
288 |
289 | def preprocess_logits_for_metrics(logits, labels):
290 | prop_logits = None
291 | if isinstance(logits, tuple):
292 | # Depending on the model and config, logits may contain extra tensors,
293 | # like past_key_values, but logits always come first
294 | if len(logits) > 1:
295 | # we could have the loss twice
296 | base_ind = 0
297 | if logits[base_ind].ndim < 2:
298 | base_ind = 1
299 | label_logits = logits[base_ind].argmax(dim=-1)
300 | if len(logits) > base_ind + 1:
301 | prop_logits = logits[base_ind + 1]
302 | else:
303 | label_logits = logits.argmax(dim=-1)
304 |
305 | if prop_logits is not None:
306 | return label_logits, prop_logits
307 | return label_logits
308 |
309 | accuracy_metric = evaluate.load("accuracy")
310 | mse_metric = evaluate.load("mse")
311 |
def compute_metrics(eval_preds):
    """Compute accuracy (and MSE for descriptors) on an evaluation batch.

    Args:
        eval_preds: ``(predictions, labels)`` pair; each side may itself be a
            tuple carrying the descriptor predictions/labels
    """
    preds, labels = eval_preds
    mc_preds, mc_labels = None, None
    if isinstance(preds, tuple):
        preds, mc_preds = preds
    if isinstance(labels, tuple):
        labels, mc_labels = labels
    # predictions were already argmax-ed by preprocess_logits_for_metrics and
    # share the labels' shape, so only the labels need to be shifted
    shifted_labels = labels[:, 1:].reshape(-1)
    shifted_preds = preds[:, :-1].reshape(-1)
    metrics = accuracy_metric.compute(predictions=shifted_preds, references=shifted_labels)
    if mc_preds is not None and mc_labels is not None:
        metrics.update(
            mse_metric.compute(
                predictions=mc_preds.reshape(-1), references=mc_labels.reshape(-1)
            )
        )
    return metrics
331 |
332 | if model_args.include_descriptors:
333 | training_args.label_names = ["labels", "mc_labels"]
334 |
335 | if training_args.label_names is None:
336 | training_args.label_names = ["labels"]
337 | # update dispatch_batches in accelerator
338 | training_args.accelerator_config.dispatch_batches = data_args.streaming is not True
339 |
340 | trainer = SAFETrainer(
341 | model=model,
342 | tokenizer=None, # we don't deal with the tokenizer at all, https://github.com/huggingface/tokenizers/issues/581 -_-
343 | train_dataset=train_dataset.shuffle(seed=(training_args.seed or 42)),
344 | eval_dataset=dataset.get(eval_dataset_key_name, None),
345 | args=training_args,
346 | prop_loss_coeff=model_args.prop_loss_coeff,
347 | compute_metrics=compute_metrics if training_args.do_eval else None,
348 | data_collator=data_collator,
349 | preprocess_logits_for_metrics=(
350 | preprocess_logits_for_metrics if training_args.do_eval else None
351 | ),
352 | )
353 |
354 | if training_args.do_train:
355 | checkpoint = None
356 | if training_args.resume_from_checkpoint is not None:
357 | checkpoint = training_args.resume_from_checkpoint
358 | elif last_checkpoint is not None:
359 | checkpoint = last_checkpoint
360 | train_result = trainer.train(resume_from_checkpoint=checkpoint)
361 | try:
362 | # we were unable to save the model because of the tokenizer
363 | trainer.save_model() # Saves the tokenizer too for easy upload
364 | except:
365 | model.save_pretrained(os.path.join(training_args.output_dir, "safe-model"))
366 |
367 | if training_args.push_to_hub and model_args.model_hub_name:
368 | model.push_to_hub(model_args.model_hub_name, private=True, safe_serialization=True)
369 |
370 | metrics = train_result.metrics
371 |
372 | trainer.log_metrics("train", metrics)
373 | trainer.save_metrics("train", metrics)
374 | trainer.save_state()
375 |
376 | # For convenience, we also re-save the tokenizer to the same directory,
377 | # so that you can share your model easily on huggingface.co/models =)
378 | if trainer.is_world_process_zero():
379 | tokenizer.save(os.path.join(training_args.output_dir, "tokenizer.json"))
380 |
381 | # Evaluation
382 | if training_args.do_eval:
383 | logger.info("*** Evaluate ***")
384 | results = trainer.evaluate()
385 | try:
386 | perplexity = math.exp(results["eval_loss"])
387 | except Exception as e:
388 | logger.error(e)
389 | perplexity = float("inf")
390 | results.update({"perplexity": perplexity})
391 | if trainer.is_world_process_zero():
392 | trainer.log_metrics("eval", results)
393 | trainer.save_metrics("eval", results)
394 |
395 | kwargs = {"finetuned_from": model_args.model_path, "tasks": "text-generation"}
396 | kwargs["dataset_tags"] = data_args.dataset
397 | kwargs["dataset"] = data_args.dataset
398 | kwargs["tags"] = ["safe", "datamol-io", "molecule-design", "smiles"]
399 |
400 | if training_args.push_to_hub:
401 | kwargs["private"] = True
402 | trainer.push_to_hub(**kwargs)
403 | else:
404 | trainer.create_model_card(**kwargs)
405 |
406 |
def main():
    """CLI entry point: parse the command line into the three argument
    dataclasses and launch training."""
    arg_parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    parsed_args = arg_parser.parse_args_into_dataclasses()
    train(*parsed_args)


if __name__ == "__main__":
    main()
415 |
--------------------------------------------------------------------------------
/safe/trainer/collator.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import functools
3 | from collections.abc import Mapping
4 | from typing import Any, Dict, List, Optional, Union
5 |
6 | import torch
7 | from tokenizers import Tokenizer
8 | from transformers.data.data_collator import _torch_collate_batch
9 |
10 | from safe.tokenizer import SAFETokenizer
11 |
12 |
class SAFECollator:
    """Collate function for language modelling tasks


    !!! note
        The collate function is based on the default DataCollatorForLanguageModeling in huggingface
        see: https://github.com/huggingface/transformers/blob/v4.19.2/src/transformers/data/data_collator.py
    """

    def __init__(
        self,
        tokenizer: Tokenizer,
        pad_to_multiple_of: Optional[int] = None,
        input_key: str = "inputs",
        label_key: str = "labels",
        property_key: str = "descriptors",
        include_descriptors: bool = False,
        max_length: Optional[int] = None,
    ):
        """
        Default collator for huggingface transformers for SAFE training.

        Args:
            tokenizer: Huggingface tokenizer
            pad_to_multiple_of: pad to multiple of this value
            input_key: key to use for input ids
            label_key: key to use for labels
            property_key: key to use for properties
            include_descriptors: whether to include training on descriptors or not
            max_length: maximum length to truncate the tokenized sequences to
        """

        self.tokenizer = tokenizer
        self.pad_to_multiple_of = pad_to_multiple_of
        self.input_key = input_key
        self.label_key = label_key
        self.property_key = property_key
        self.include_descriptors = include_descriptors
        self.max_length = max_length
        # cache for the underlying fast tokenizer; a plain instance attribute
        # is used instead of functools.lru_cache on the method, which would
        # keep the collator alive through the cache (ruff B019)
        self._pretrained_tokenizer = None

    def get_tokenizer(self):
        """Get underlying tokenizer, resolving a SAFETokenizer wrapper once."""
        if self._pretrained_tokenizer is None:
            if isinstance(self.tokenizer, SAFETokenizer):
                self._pretrained_tokenizer = self.tokenizer.get_pretrained()
            else:
                self._pretrained_tokenizer = self.tokenizer
        return self._pretrained_tokenizer

    def __call__(self, samples: List[Union[List[int], Any, Dict[str, Any]]]):
        """
        Call collate function

        Args:
            samples: list of examples
        """
        # Handle dict or lists with proper padding and conversion to tensor.
        tokenizer = self.get_tokenizer()

        # work on a copy since keys are popped from the examples below
        examples = copy.deepcopy(samples)
        inputs = [example.pop(self.input_key, None) for example in examples]
        # NOTE: the membership test is evaluated before the pops inside the
        # comprehension, so it sees the original keys of the first example
        mc_labels = (
            torch.tensor([example.pop(self.property_key, None) for example in examples]).float()
            if self.property_key in examples[0]
            else None
        )

        if "input_ids" not in examples[0] and inputs is not None:
            # raw text inputs: tokenize with padding and truncation
            batch = tokenizer(
                inputs,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=self.max_length,
                pad_to_multiple_of=self.pad_to_multiple_of,
            )
        else:
            # pre-tokenized examples: only pad to a uniform length
            batch = tokenizer.pad(
                examples,
                return_tensors="pt",
                padding=True,
                pad_to_multiple_of=self.pad_to_multiple_of,
                max_length=self.max_length,
            )

        # If special token mask has been preprocessed, pop it from the dict.
        batch.pop("special_tokens_mask", None)
        labels = batch.get(self.label_key)
        if labels is None:
            # clone lazily: a dict.get default argument is always evaluated,
            # so the previous form cloned input_ids even when labels existed
            labels = batch["input_ids"].clone()
        if tokenizer.pad_token_id is not None:
            # padding positions are ignored by the LM loss
            labels[labels == tokenizer.pad_token_id] = -100
        batch[self.label_key] = labels

        if mc_labels is not None and self.include_descriptors:
            batch.update({"mc_labels": mc_labels})
        return batch
111 |
--------------------------------------------------------------------------------
/safe/trainer/configs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamol-io/safe/84a4697e6d89792fe7c870fe65b3f43e28191725/safe/trainer/configs/__init__.py
--------------------------------------------------------------------------------
/safe/trainer/configs/default_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "activation_function": "gelu_new",
3 | "attn_pdrop": 0.1,
4 | "bos_token_id": 10000,
5 | "embd_pdrop": 0.1,
6 | "eos_token_id": 1,
7 | "initializer_range": 0.02,
8 | "layer_norm_epsilon": 1e-05,
9 | "model_type": "gpt2",
10 | "n_embd": 768,
11 | "n_head": 12,
12 | "n_inner": null,
13 | "n_layer": 12,
14 | "n_positions": 1024,
15 | "reorder_and_upcast_attn": false,
16 | "resid_pdrop": 0.1,
17 | "scale_attn_by_inverse_layer_idx": false,
18 | "scale_attn_weights": true,
19 | "summary_activation": "relu",
20 | "summary_first_dropout": 0.1,
21 | "summary_proj_to_labels": true,
22 | "summary_hidden_size": 64,
23 | "summary_type": "cls_index",
24 | "summary_use_proj": true,
25 | "transformers_version": "4.31.0",
26 | "use_cache": true,
27 | "vocab_size": 10000,
28 | "num_labels": 18
29 | }
--------------------------------------------------------------------------------
/safe/trainer/data_utils.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | from collections.abc import Mapping
3 | from functools import partial
4 | from typing import Any, Callable, Dict, Optional, Union
5 |
6 | import datasets
7 | import upath
8 | from tqdm.auto import tqdm
9 |
10 | from safe.tokenizer import SAFETokenizer
11 |
12 |
def take(n, iterable):
    """Return the first ``n`` items of ``iterable`` as a list."""
    sliced = itertools.islice(iterable, n)
    return list(sliced)
15 | return list(itertools.islice(iterable, n))
16 |
17 |
def get_dataset_column_names(dataset: Union[datasets.Dataset, datasets.IterableDataset, Mapping]):
    """Get the column names in a dataset

    Args:
        dataset: dataset to get the column names from

    """
    if isinstance(dataset, (datasets.IterableDatasetDict, Mapping)):
        names = {split: dataset[split].column_names for split in dataset}
    else:
        names = dataset.column_names
    # per-split dicts are collapsed to the first split's column names
    if isinstance(names, dict):
        names = list(names.values())[0]
    return names
32 |
33 |
def tokenize_fn(
    row: Dict[str, Any],
    tokenizer: Callable,
    tokenize_column: str = "inputs",
    max_length: Optional[int] = None,
    padding: bool = False,
):
    """Tokenize the content of a single dataset row.

    Args:
        row: row to tokenize
        tokenizer: tokenizer to use
        tokenize_column: column to tokenize
        max_length: maximum size of the tokenized sequence
        padding: whether to pad the sequence
    """
    # unwrap a SAFETokenizer to get the underlying fast tokenizer
    if isinstance(tokenizer, SAFETokenizer):
        fast_tokenizer = tokenizer.get_pretrained()
    else:
        fast_tokenizer = tokenizer

    return fast_tokenizer(
        row[tokenize_column],
        truncation=(max_length is not None),
        max_length=max_length,
        padding=padding,
        return_tensors=None,
    )
63 |
64 |
def batch_iterator(datasets, batch_size=100, n_examples=None, column="inputs"):
    """Iterate over one or several datasets and yield batches of column values.

    Args:
        datasets: a dataset, a list/tuple of datasets, or a mapping of datasets
        batch_size: number of examples per yielded batch
        n_examples: optional cap on the number of examples to draw per dataset
        column: column to extract from each example

    Yields:
        list of values from ``column`` (at most ``batch_size`` long)
    """
    if isinstance(datasets, Mapping):
        datasets = list(datasets.values())

    if not isinstance(datasets, (list, tuple)):
        datasets = [datasets]

    for dataset in datasets:
        iter_dataset = iter(dataset)
        if n_examples is not None and n_examples > 0:
            for _ in tqdm(range(0, n_examples, batch_size)):
                # use take() instead of repeated next(): when the dataset has
                # fewer than n_examples rows, next() raises StopIteration
                # inside this generator, which PEP 479 turns into RuntimeError
                rows = take(batch_size, iter_dataset)
                if not rows:
                    break
                yield [row[column] for row in rows]
        else:
            for out in tqdm(iter(partial(take, batch_size, iter_dataset), [])):
                yield [x[column] for x in out]
81 |
82 |
def get_dataset(
    data_path,
    name: Optional[str] = None,
    tokenizer: Optional[Callable] = None,
    cache_dir: Optional[str] = None,
    streaming: bool = True,
    use_auth_token: bool = False,
    tokenize_column: Optional[str] = "inputs",
    property_column: Optional[str] = "descriptors",
    max_length: Optional[int] = None,
    num_shards: int = 1024,
):
    """Get the datasets from the config file

    Args:
        data_path: local path (loaded with ``datasets.load_from_disk``) or a
            dataset name (loaded with ``datasets.load_dataset``)
        name: optional configuration name forwarded to ``load_dataset``
        tokenizer: optional tokenizer; when provided, a tokenized dataset is returned
        cache_dir: cache directory forwarded to ``load_dataset``
        streaming: whether to stream the dataset (a local on-disk dataset is
            converted to an iterable dataset when True)
        use_auth_token: whether to authenticate against the huggingface hub
        tokenize_column: column holding the text to tokenize
        property_column: column holding the descriptors; renamed to ``mc_labels``
        max_length: maximum tokenized sequence length
        num_shards: number of shards used when converting a local dataset to an
            iterable dataset

    Returns:
        the loaded (and optionally tokenized) dataset
    """
    raw_datasets = {}
    if data_path is not None:
        data_path = upath.UPath(str(data_path))

        if data_path.exists():
            # then we need to load from disk
            data_path = str(data_path)
            # for some reason, the datasets package is not able to load the dataset
            # because the splits were not originally proposed
            raw_datasets = datasets.load_from_disk(data_path)

            if streaming:
                if isinstance(raw_datasets, datasets.DatasetDict):
                    # remember split sizes: converting to an iterable dataset loses len()
                    previous_num_examples = {k: len(dt) for k, dt in raw_datasets.items()}
                    raw_datasets = datasets.IterableDatasetDict(
                        {
                            k: dt.to_iterable_dataset(num_shards=num_shards)
                            for k, dt in raw_datasets.items()
                        }
                    )
                    for k, dt in raw_datasets.items():
                        if previous_num_examples[k] is not None:
                            setattr(dt, "num_examples", previous_num_examples[k])
                else:
                    num_examples = len(raw_datasets)
                    raw_datasets = raw_datasets.to_iterable_dataset(num_shards=num_shards)
                    setattr(raw_datasets, "num_examples", num_examples)

        else:
            data_path = str(data_path)
            raw_datasets = datasets.load_dataset(
                data_path,
                name=name,
                cache_dir=cache_dir,
                use_auth_token=True if use_auth_token else None,
                streaming=streaming,
            )
    # that means we need to return a tokenized version of the dataset

    if property_column not in ["mc_labels", None]:
        # standardize the descriptor column name expected by the model/collator
        raw_datasets = raw_datasets.rename_column(property_column, "mc_labels")

    columns_to_remove = None
    if tokenize_column is not None:
        # drop everything except the text column, mc_labels and label-like columns
        columns_to_remove = [
            x
            for x in (get_dataset_column_names(raw_datasets) or [])
            if x not in [tokenize_column, "mc_labels"] and "label" not in x
        ] or None

    if tokenizer is None:
        if columns_to_remove is not None:
            raw_datasets = raw_datasets.remove_columns(columns_to_remove)
        return raw_datasets

    return raw_datasets.map(
        partial(
            tokenize_fn,
            tokenizer=tokenizer,
            tokenize_column=tokenize_column,
            max_length=max_length,
        ),
        batched=True,
        remove_columns=columns_to_remove,
    )
161 |
--------------------------------------------------------------------------------
/safe/trainer/model.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, Optional, Tuple, Union
2 |
3 | import torch
4 | from torch import nn
5 | from torch.nn import CrossEntropyLoss, MSELoss
6 | from transformers import GPT2DoubleHeadsModel, PretrainedConfig
7 | from transformers.activations import get_activation
8 | from transformers.models.gpt2.modeling_gpt2 import (
9 | _CONFIG_FOR_DOC,
10 | GPT2_INPUTS_DOCSTRING,
11 | GPT2DoubleHeadsModelOutput,
12 | add_start_docstrings_to_model_forward,
13 | replace_return_docstrings,
14 | )
15 |
16 |
class PropertyHead(torch.nn.Module):
    r"""
    Compute a single vector summary of a sequence hidden states.

    Args:
        config ([`PretrainedConfig`]):
            The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
            config class of your model for the default values it uses):

            - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:

                - `"last"` -- Take the last token hidden state (like XLNet)
                - `"first"` -- Take the first token hidden state (like Bert)
                - `"mean"` -- Take the mean of all tokens hidden states
                - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)

            - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
              another string, or `None` to add no activation.
    """

    def __init__(self, config: "PretrainedConfig"):
        super().__init__()

        self.summary_type = getattr(config, "summary_type", "cls_index")
        self.summary = torch.nn.Identity()
        last_hidden_size = config.hidden_size

        # optional bottleneck projection before the output layer
        if getattr(config, "summary_hidden_size", None) and config.summary_hidden_size > 0:
            self.summary = nn.Linear(config.hidden_size, config.summary_hidden_size)
            last_hidden_size = config.summary_hidden_size

        activation_string = getattr(config, "summary_activation", None)
        self.activation: Callable = (
            get_activation(activation_string) if activation_string else nn.Identity()
        )

        # one output unit per descriptor/property target
        self.out = torch.nn.Identity()
        if getattr(config, "num_labels", None) and config.num_labels > 0:
            num_labels = config.num_labels
            self.out = nn.Linear(last_hidden_size, num_labels)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        cls_index: Optional[torch.LongTensor] = None,
    ) -> torch.FloatTensor:
        """
        Compute a single vector summary of a sequence hidden states.

        Args:
            hidden_states: `torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`)
                The hidden states of the last layer.
            cls_index: `torch.LongTensor` of shape `[batch_size]`, *optional*
                Used if `summary_type == "cls_index"`; defaults to the last token
                of each sequence when not provided.

        Returns:
            `torch.FloatTensor`: The summary of the sequence hidden states.
        """
        if self.summary_type == "last":
            output = hidden_states[:, -1]
        elif self.summary_type == "first":
            output = hidden_states[:, 0]
        elif self.summary_type == "mean":
            output = hidden_states.mean(dim=1)
        elif self.summary_type == "cls_index":
            batch_size = hidden_states.shape[0]
            if cls_index is None:
                # default to the last token position of each sequence
                cls_index = torch.full(
                    (batch_size,),
                    hidden_states.shape[1] - 1,
                    dtype=torch.long,
                    device=hidden_states.device,
                )
            # Index the classification token directly. The previous
            # implementation called hidden_states.squeeze() first, which drops
            # ALL size-1 dimensions and therefore corrupts the result whenever
            # batch_size == 1 (the batch dimension itself was squeezed away).
            output = hidden_states[
                torch.arange(batch_size, device=hidden_states.device), cls_index
            ]
        else:
            raise NotImplementedError

        output = self.summary(output)
        output = self.activation(output)
        return self.out(output)
105 |
106 |
class SAFEDoubleHeadsModel(GPT2DoubleHeadsModel):
    """The safe model is a dual head GPT2 model with a language modeling head and an optional multi-task regression head"""

    def __init__(self, config):
        # stash num_labels before super().__init__ and restore it afterwards —
        # presumably the GPT2 initialization can reset it on the config; TODO confirm
        self.num_labels = getattr(config, "num_labels", None)
        super().__init__(config)
        self.config.num_labels = self.num_labels
        # replace GPT2's multiple-choice classification head with a
        # descriptor-regression PropertyHead
        del self.multiple_choice_head
        self.multiple_choice_head = PropertyHead(config)

    @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        mc_token_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        mc_labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        inputs: Optional[Any] = None,  # do not remove because of trainer
        encoder_hidden_states: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[Tuple, GPT2DoubleHeadsModelOutput]:
        r"""

        Args:
            mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input):
                Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) -
                1]`.
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
                `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to
                `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size - 1]`
            mc_labels (`torch.LongTensor` of shape `(batch_size, n_tasks)`, *optional*):
                Labels for computing the supervized loss for regularization.
            inputs: List of inputs, put here because the trainer removes information not in signature
        Returns:
            output (GPT2DoubleHeadsModelOutput): output of the model
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            encoder_hidden_states=encoder_hidden_states,
        )

        hidden_states = transformer_outputs[0]
        lm_logits = self.lm_head(hidden_states)

        # default the property-head summary index to the position of the last
        # non-padding token of each sequence
        if mc_token_ids is None and self.config.pad_token_id is not None and input_ids is not None:
            mc_token_ids = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(
                lm_logits.device
            )

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.transformer.first_device)
            hidden_states = hidden_states.to(self.lm_head.weight.device)

        # descriptor regression loss (MSE), computed only when mc_labels are
        # provided and the config declares a positive number of targets
        mc_loss = None
        mc_logits = None
        if mc_labels is not None and getattr(self.config, "num_labels", 0) > 0:
            mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)
            mc_labels = mc_labels.to(mc_logits.device)
            loss_fct = MSELoss()
            mc_loss = loss_fct(
                mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1, mc_logits.size(-1))
            )

        # standard causal LM loss: shift so that tokens < n predict token n
        lm_loss = None
        if labels is not None:
            labels = labels.to(lm_logits.device)
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            # tuple layout: (lm_loss, mc_loss, lm_logits, mc_logits, ...);
            # either loss may be None when the corresponding labels are absent
            output = (lm_logits, mc_logits) + transformer_outputs[1:]
            return (
                lm_loss,
                mc_loss,
            ) + output

        return GPT2DoubleHeadsModelOutput(
            loss=lm_loss,
            mc_loss=mc_loss,
            logits=lm_logits,
            mc_logits=mc_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
218 |
--------------------------------------------------------------------------------
/safe/trainer/trainer_utils.py:
--------------------------------------------------------------------------------
1 | from transformers import Trainer
2 | from transformers.modeling_utils import unwrap_model
3 | from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
4 | from transformers.trainer import _is_peft_model
5 |
6 |
class SAFETrainer(Trainer):
    """
    Custom trainer for training SAFE model.

    This custom trainer changes the loss function to support the property head

    """

    def __init__(self, *args, prop_loss_coeff: float = 1e-3, **kwargs):
        """
        Args:
            *args: positional arguments forwarded to `transformers.Trainer`
            prop_loss_coeff: weight applied to the property (descriptor)
                regression loss before adding it to the language modeling loss
            **kwargs: keyword arguments forwarded to `transformers.Trainer`
        """
        super().__init__(*args, **kwargs)
        self.prop_loss_coeff = prop_loss_coeff

    def compute_loss(self, model, inputs, return_outputs=False):
        """
        How the loss is computed by Trainer. By default, all models return the loss in the first element.
        """
        # pop the labels when label smoothing is active: the smoother computes
        # the LM loss itself from the model outputs
        labels = (
            inputs.pop("labels") if self.label_smoother is not None and "labels" in inputs else None
        )
        outputs = model(**inputs)
        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is not None:
            unwrapped_model = self.accelerator.unwrap_model(model)
            if _is_peft_model(unwrapped_model):
                model_name = unwrapped_model.base_model.model._get_name()
            else:
                model_name = unwrapped_model._get_name()
            # causal LM models need the labels shifted by one position
            if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
                loss = self.label_smoother(outputs, labels, shift_labels=True)
            else:
                loss = self.label_smoother(outputs, labels)
        else:
            if isinstance(outputs, dict) and "loss" not in outputs:
                raise ValueError(
                    "The model did not return a loss from the inputs, only the following keys: "
                    f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
                )
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

        # add the weighted descriptor-regression loss when the model computed
        # one (SAFEDoubleHeadsModel returns it as `mc_loss` / second element)
        mc_loss = outputs.get("mc_loss", None) if isinstance(outputs, dict) else outputs[1]
        if mc_loss is not None:
            loss = loss + self.prop_loss_coeff * mc_loss
        return (loss, outputs) if return_outputs else loss
55 |
--------------------------------------------------------------------------------
/safe/utils.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | import random
3 | import re
4 | from collections import deque
5 | from contextlib import contextmanager, suppress
6 | from functools import partial
7 | from itertools import combinations, compress
8 | from typing import Any, List, Optional, Tuple, Union
9 |
10 | import datamol as dm
11 | import networkx as nx
12 | import numpy as np
13 | from loguru import logger
14 | from networkx.utils import py_random_state
15 | from rdkit import Chem
16 | from rdkit.Chem import Atom, EditableMol
17 | from rdkit.Chem.rdChemReactions import ReactionFromSmarts
18 | from rdkit.Chem.rdmolops import AdjustQueryParameters, AdjustQueryProperties, ReplaceCore
19 |
20 | import safe as sf
21 |
22 | __implicit_carbon_query = dm.from_smarts("[#6;h]")
23 | __mmpa_query = dm.from_smarts("[*;!$(*=,#[!#6])]!@!=!#[*]")
24 |
25 | _SMILES_ATTACHMENT_POINT_TOKEN = "\\*"
26 | # The following allows discovering and extracting valid dummy atoms in smiles/smarts.
27 | _SMILES_ATTACHMENT_POINTS = [
28 | # parse any * not preceded by "[" or ":" and not followed by "]" or ":" as attachment
29 | r"(?>([*:1][*:2].[*:3][*:4])"
51 | )
52 |
def __init__(
    self,
    shortest_linker: bool = False,
    min_linker_size: int = 0,
    require_ring_system: bool = True,
    verbose: bool = False,
):
    """
    Constructor of bond slicer.

    Args:
        shortest_linker: whether to consider the longest or shortest linker.
            Does not have any effect when expected_head group is provided during splitting
        min_linker_size: minimum linker size
        require_ring_system: whether all fragments need to have a ring system
        verbose: whether to allow verbosity in logging
    """
    self.shortest_linker = shortest_linker
    self.min_linker_size = min_linker_size
    self.require_ring_system = require_ring_system
    self.verbose = verbose
    # compile the SMARTS bond queries once at construction time
    self.bond_splitters = [dm.from_smarts(pattern) for pattern in self.BOND_SPLITTERS]
76 |
def get_ring_system(self, mol: dm.Mol):
    """Get the list of ring systems from a molecule

    Args:
        mol: input molecule for which we are computing the ring system
    """
    mol.UpdatePropertyCache()
    ring_info = mol.GetRingInfo()
    systems = []
    for ring in ring_info.AtomRings():
        merged = set(ring)
        # fold every overlapping system into the current ring, keep the others
        disjoint = []
        for system in systems:
            if merged & system:
                merged |= system
            else:
                disjoint.append(system)
        systems = disjoint + [merged]
    return systems
97 |
def _bond_selection_from_max_cuts(self, bond_list: List[int], dist_mat: np.ndarray):
    """Select the pair of bonds to cut based on their topological distance."""
    # only the two-cuts algorithm is implemented for now
    if self.MAX_CUTS != 2:
        raise ValueError(f"Only MAX_CUTS=2 is supported, got {self.MAX_CUTS}")

    n_bonds = len(bond_list)
    bond_pdist = np.full((n_bonds, n_bonds), -1)
    for i, j in itertools.combinations_with_replacement(range(n_bonds), 2):
        # minimum topological distance between any atom pair of the two bonds
        shortest = min(
            dist_mat[a1, a2] for a1, a2 in itertools.product(bond_list[i], bond_list[j])
        )
        bond_pdist[i, j] = bond_pdist[j, i] = shortest

    # mask out pairs whose linker would be smaller than the minimum size
    masked_bond_pdist = np.ma.masked_less_equal(bond_pdist, self.min_linker_size)

    picker = np.ma.argmin if self.shortest_linker else np.ma.argmax
    return np.unravel_index(picker(masked_bond_pdist), bond_pdist.shape)
117 |
def _get_bonds_to_cut(self, mol: dm.Mol):
    """Enumerate the candidate bonds that can be cut

    Args:
        mol: input molecule
    """
    # use this if you want to enumerate yourself the possible cuts

    ring_systems = self.get_ring_system(mol)
    candidate_bonds = []
    ring_query = Chem.rdqueries.IsInRingQueryAtom()

    for splitter in self.bond_splitters:
        matches = mol.GetSubstructMatches(splitter, uniquify=True)
        known_bonds = [set(pair) for pair in candidate_bonds]
        for atom_pair in matches:
            bond_idx = mol.GetBondBetweenAtoms(*atom_pair).GetIdx()
            fragments = Chem.GetMolFrags(
                Chem.FragmentOnBonds(mol, [bond_idx], addDummies=False), asMols=True
            )
            # optionally require every resulting fragment to contain a ring
            allowed = not self.require_ring_system or all(
                len(frag.GetAtomsMatchingQuery(ring_query)) > 0 for frag in fragments
            )
            # reject bonds that are already known or lie inside one ring system
            already_known = set(atom_pair) in known_bonds
            in_ring_system = any(system.issuperset(set(atom_pair)) for system in ring_systems)
            if allowed and not already_known and not in_ring_system:
                candidate_bonds.append(atom_pair)
    return candidate_bonds
147 |
def _fragment_mol(self, mol: dm.Mol, bonds: List[dm.Bond]):
    """Fragment a molecule on the given bonds and return the (head, linker, tail) pieces

    Args:
        mol: input molecule
        bonds: list of bonds to cut
    """
    fragmented = Chem.rdmolops.FragmentOnBonds(mol, [bond.GetIdx() for bond in bonds])
    pieces = list(Chem.GetMolFrags(fragmented, asMols=True))
    # the linker is the fragment carrying two dummy (attachment) atoms;
    # fall back to the first fragment when none matches
    linker_idx = 0
    for idx, piece in enumerate(pieces):
        dummy_count = sum(atom.GetSymbol() == "*" for atom in piece.GetAtoms())
        if dummy_count == 2:
            linker_idx = idx
            break
    linker = pieces.pop(linker_idx)
    head, tail = pieces
    return (head, linker, tail)
166 |
def _compute_linker_score(self, linker: dm.Mol):
    """Compute the score of a linker to help select between linkers"""
    # score is based on the shortest topological path between the two
    # attachment points, optionally zeroed when a required ring is missing
    attachments = [atom.GetIdx() for atom in linker.GetAtoms() if atom.GetSymbol() == "*"]
    attach1, attach2, *_ = attachments
    score = len(Chem.rdmolops.GetShortestPath(linker, attach1, attach2))

    ring_query = Chem.rdqueries.IsInRingQueryAtom()
    ring_atom_count = len(linker.GetAtomsMatchingQuery(ring_query))
    if self.require_ring_system:
        score *= int(ring_atom_count > 0)

    if score == 0:
        return float("inf")
    if not self.shortest_linker:
        # invert so that lower stays better when preferring the longest linker
        score = 1 / score
    return score
186 |
    def __call__(self, mol: Union[dm.Mol, str], expected_head: Optional[Union[dm.Mol, str]] = None):
        """Perform slicing of the input molecule

        Args:
            mol: input molecule
            expected_head: substructure that should be part of the head.
                The small fragment containing this substructure would be kept as head

        Returns:
            A ``(head, linker, tail)`` tuple; ``linker`` and/or ``tail`` are None
            when fewer than two cut bonds are found.
        """

        mol = dm.to_mol(mol)
        # remove salt and solution
        mol = dm.keep_largest_fragment(mol)
        Chem.rdDepictor.Compute2DCoords(mol)
        dist_mat = Chem.rdmolops.GetDistanceMatrix(mol)

        if expected_head is not None:
            if isinstance(expected_head, str):
                expected_head = dm.to_mol(expected_head)
            # a head query that does not match the molecule cannot guide the cut
            if not mol.HasSubstructMatch(expected_head):
                if self.verbose:
                    logger.info(
                        "Expected head was provided, but does not match molecules. It will be ignored"
                    )
                expected_head = None

        candidate_bonds = self._get_bonds_to_cut(mol)

        # we have all the candidate bonds we can cut
        # now we need to pick the most plausible bonds
        selected_bonds = [mol.GetBondBetweenAtoms(a1, a2) for (a1, a2) in candidate_bonds]

        # CASE 1: no bond to cut ==> only head
        if len(selected_bonds) == 0:
            return (mol, None, None)

        # CASE 2: only one bond ==> linker is empty
        if len(selected_bonds) == 1:
            # there is not linker
            tmp = Chem.rdmolops.FragmentOnBonds(mol, [b.GetIdx() for b in selected_bonds])
            head, tail = Chem.GetMolFrags(tmp, asMols=True)
            return (head, None, tail)

        # CASE 3a: we select the most plausible bond to cut on ourselves
        if expected_head is None:
            choice = self._bond_selection_from_max_cuts(candidate_bonds, dist_mat)
            selected_bonds = [selected_bonds[c] for c in choice]
            return self._fragment_mol(mol, selected_bonds)

        # CASE 3b: slightly more complex case where we want the head to be the smallest graph containing the
        # provided substructure
        bond_combination = list(itertools.combinations(selected_bonds, self.MAX_CUTS))
        bond_score = float("inf")
        linker_score = float("inf")
        head, linker, tail = (None, None, None)
        for split_bonds in bond_combination:
            cur_head, cur_linker, cur_tail = self._fragment_mol(mol, split_bonds)
            # head can also be tail
            head_match = cur_head.GetSubstructMatch(expected_head)
            tail_match = cur_tail.GetSubstructMatch(expected_head)
            if not head_match and not tail_match:
                continue
            if not head_match and tail_match:
                cur_head, cur_tail = cur_tail, cur_head
            # smaller head is better: we look for the minimal fragment containing the query
            cur_bond_score = cur_head.GetNumHeavyAtoms()
            # compute linker score
            cur_linker_score = self._compute_linker_score(cur_linker)
            # accept a slightly larger head (within _BOND_BUFFER) when the linker scores better
            if (cur_bond_score < bond_score) or (
                cur_bond_score < self._BOND_BUFFER + bond_score and cur_linker_score < linker_score
            ):
                head, linker, tail = cur_head, cur_linker, cur_tail
                bond_score = cur_bond_score
                linker_score = cur_linker_score

        return (head, linker, tail)
261 |
262 | @classmethod
263 | def link_fragments(
264 | cls, linker: Union[dm.Mol, str], head: Union[dm.Mol, str], tail: Union[dm.Mol, str]
265 | ):
266 | """Link fragments together using the provided linker
267 |
268 | Args:
269 | linker: linker to use
270 | head: head fragment
271 | tail: tail fragment
272 | """
273 | if isinstance(linker, dm.Mol):
274 | linker = dm.to_smiles(linker)
275 | linker = standardize_attach(linker)
276 | reactants = [dm.to_mol(head), dm.to_mol(tail), dm.to_mol(linker)]
277 | return dm.reactions.apply_reaction(
278 | cls._MERGING_RXN, reactants, as_smiles=True, sanitize=True, product_index=0
279 | )
280 |
281 |
@contextmanager
def attr_as(obj: Any, field: str, value: Any):
    """Temporarily replace the value of an object's attribute.

    The attribute is set to ``value`` for the duration of the ``with`` block
    and restored afterwards, even if the body raises an exception.

    Args:
        obj: object to temporary patch
        field: name of the key to change
        value: value of key to be temporary changed
    """
    old_value = getattr(obj, field, None)
    setattr(obj, field, value)
    try:
        yield
    finally:
        # restore the previous value; TypeError is suppressed to keep the
        # original best-effort semantics (e.g. read-only attributes)
        with suppress(TypeError):
            setattr(obj, field, old_value)
296 |
297 |
def _selective_add_hs(mol: dm.Mol, fraction_hs: Optional[float] = None):
    """Custom addition of hydrogens to a molecule
    This version of hydrogen bond adding only at max 1 hydrogen per atom

    Args:
        mol: molecule to split
        fraction_hs: proportion of random atom to which we will add explicit hydrogens.
            If None or <= 0, every eligible carbon receives one explicit hydrogen.
    """

    # carbons that still carry implicit hydrogens we can make explicit
    carbon_with_implicit_atoms = mol.GetSubstructMatches(__implicit_carbon_query, uniquify=True)
    carbon_with_implicit_atoms = [x[0] for x in carbon_with_implicit_atoms]
    carbon_with_implicit_atoms = list(set(carbon_with_implicit_atoms))
    # we get a proportion of the carbon we can extend
    if fraction_hs is not None and fraction_hs > 0:
        fraction_hs = np.ceil(fraction_hs * len(carbon_with_implicit_atoms))
        # clamp: always at least one atom, never more than available
        fraction_hs = int(np.clip(fraction_hs, 1, len(carbon_with_implicit_atoms)))
        carbon_with_implicit_atoms = random.sample(carbon_with_implicit_atoms, k=fraction_hs)
    carbon_with_implicit_atoms = [int(x) for x in carbon_with_implicit_atoms]
    emol = EditableMol(mol)
    for atom_id in carbon_with_implicit_atoms:
        h_atom = emol.AddAtom(Atom("H"))
        emol.AddBond(atom_id, h_atom, dm.SINGLE_BOND)
    return emol.GetMol()
321 |
322 |
@py_random_state("seed")
def mol_partition(
    mol: dm.Mol, query: Optional[dm.Mol] = None, seed: Optional[int] = None, **kwargs: Any
):
    """Partition a molecule into fragments using a bond query

    This is a generator yielding successive Louvain community partitions of the
    molecular graph. NOTE(review): it relies on private networkx helpers
    (``louvain._gen_graph`` / ``_one_level``), so behavior is tied to the
    installed networkx version.

    Args:
        mol: molecule to split
        query: bond query to use for splitting
        seed: random seed
        kwargs: additional arguments to pass to the partitioning algorithm
            (``resolution``, ``threshold`` and ``weight`` are recognized)

    """
    resolution = kwargs.get("resolution", 1.0)
    threshold = kwargs.get("threshold", 1e-7)
    weight = kwargs.get("weight", "weight")

    if query is None:
        query = __mmpa_query

    G = dm.graph.to_graph(mol)
    # bonds matching the query are the candidate cut points
    bond_partition = [
        tuple(sorted(match)) for match in mol.GetSubstructMatches(query, uniquify=True)
    ]

    def get_relevant_edges(e1, e2):
        # keep only edges that are NOT candidate cut bonds
        return tuple(sorted([e1, e2])) not in bond_partition

    subgraphs = nx.subgraph_view(G, filter_edge=get_relevant_edges)

    partition = [{u} for u in G.nodes()]  # overwritten by _one_level below
    inner_partition = sorted(nx.connected_components(subgraphs), key=lambda x: min(x))
    mod = nx.algorithms.community.modularity(
        G, inner_partition, resolution=resolution, weight=weight
    )
    is_directed = G.is_directed()
    graph = G.__class__()
    graph.add_nodes_from(G)
    graph.add_weighted_edges_from(G.edges(data=weight, default=1))
    graph = nx.algorithms.community.louvain._gen_graph(graph, inner_partition)
    m = graph.size(weight="weight")
    partition, inner_partition, improvement = nx.algorithms.community.louvain._one_level(
        graph, m, inner_partition, resolution, is_directed, seed
    )
    improvement = True
    while improvement:
        # gh-5901 protect the sets in the yielded list from further manipulation here
        yield [s.copy() for s in partition]
        new_mod = nx.algorithms.community.modularity(
            graph, inner_partition, resolution=resolution, weight="weight"
        )
        # stop once the modularity gain falls below the threshold
        if new_mod - mod <= threshold:
            return
        mod = new_mod
        graph = nx.algorithms.community.louvain._gen_graph(graph, inner_partition)
        partition, inner_partition, improvement = nx.algorithms.community.louvain._one_level(
            graph, m, partition, resolution, is_directed, seed
        )
381 |
382 |
def find_partition_edges(G: nx.Graph, partition: List[List]) -> List[Tuple]:
    """
    Find the edges connecting the subgraphs in a given partition of a graph.

    Args:
        G (networkx.Graph): The original graph.
        partition (list of list of nodes): The partition of the graph where each element is a list of nodes representing a subgraph.

    Returns:
        list: A list of edges connecting the subgraphs in the partition.
    """
    # every unordered pair of partition blocks contributes its boundary edges
    return [
        edge
        for block_a, block_b in combinations(partition, 2)
        for edge in nx.edge_boundary(G, block_a, block_b)
    ]
399 |
400 |
def fragment_aware_spliting(mol: dm.Mol, fraction_hs: Optional[float] = None, **kwargs: Any):
    """Custom splitting algorithm for dataset building.

    This slicing strategy will cut any bond including bonding with hydrogens
    However, only one cut per atom is allowed

    Args:
        mol: molecule to split
        fraction_hs: proportion of random atom to which we will add explicit hydrogens
        kwargs: additional arguments to pass to the partitioning algorithm

    Returns:
        The list of inter-fragment edges of the final partition.
    """
    random.seed(kwargs.get("seed", 1))
    mol = dm.to_mol(mol, remove_hs=False)
    mol = _selective_add_hs(mol, fraction_hs=fraction_hs)
    graph = dm.graph.to_graph(mol)
    # only the LAST partition yielded by mol_partition is used; maxlen=1 keeps
    # just that one instead of materializing every intermediate partition
    final_partition = deque(mol_partition(mol, **kwargs), maxlen=1).pop()
    return find_partition_edges(graph, final_partition)
420 |
421 |
def convert_to_safe(
    mol: dm.Mol,
    canonical: bool = False,
    randomize: bool = False,
    seed: Optional[int] = 1,
    slicer: str = "brics",
    split_fragment: bool = True,
    fraction_hs: Optional[float] = None,
    resolution: Optional[float] = 0.5,
):
    """Convert a molecule to a safe representation

    Args:
        mol: molecule to convert
        canonical: whether to use canonical encoding
        randomize: whether to randomize the encoding
        seed: random seed
        slicer: the slicer to use for fragmentation
        split_fragment: whether to retry with the fragment-aware splitter when slicing fails
        fraction_hs: proportion of random atom to which we will add explicit hydrogens
        resolution: resolution for the partitioning algorithm

    Returns:
        The SAFE string, or None when encoding fails.
    """
    x = None
    try:
        x = sf.encode(mol, canonical=canonical, randomize=randomize, slicer=slicer, seed=seed)
    except sf.SAFEFragmentationError:
        if split_fragment:
            # multi-fragment inputs are not supported by the fallback splitter
            if "." in mol:
                return None
            try:
                x = sf.encode(
                    mol,
                    canonical=False,
                    randomize=randomize,
                    seed=seed,
                    slicer=partial(
                        fragment_aware_spliting,
                        fraction_hs=fraction_hs,
                        resolution=resolution,
                        seed=seed,
                    ),
                )
            except (sf.SAFEEncodeError, sf.SAFEFragmentationError):
                # logger.exception(e)
                return x
        # we need to resplit using attachment point but here we are only adding
    except sf.SAFEEncodeError:
        return x
    return x
472 |
473 |
def compute_side_chains(mol: dm.Mol, core: dm.Mol, label_by_index: bool = False):
    """Compute the side chain of a molecule given a core

    !!! note "Finding the side chains"
        The algorithm to find the side chains from core assumes that the core we get as input has attachment points.
        Those attachment points are never considered as part of the query, rather they are used to define the attachment points
        on the side chains. Removing the attachment points from the core is exactly the same as keeping them.

        ```python
        mol = "CC1=C(C(=NO1)C2=CC=CC=C2Cl)C(=O)NC3C4N(C3=O)C(C(S4)(C)C)C(=O)O"
        core0 = "CC1(C)CN2C(CC2=O)S1"
        core1 = "CC1(C)SC2C(-*)C(=O)N2C1-*"
        core2 = "CC1N2C(SC1(C)C)C(N)C2=O"
        side_chain = compute_side_chain(core=core0, mol=mol)
        dm.to_image([side_chain, core0, mol])
        ```
        Therefore on the above, core0 and core1 are equivalent for the molecule `mol`, but core2 is not.

    Args:
        mol: molecule to split
        core: core to use for deriving the side chains
    """

    # accept SMILES input for both molecule and core
    mol = dm.to_mol(mol) if isinstance(mol, str) else mol
    core = dm.to_mol(core) if isinstance(core, str) else core

    # make dummy atoms act as queries so attachment points on the core match anything
    query_params = AdjustQueryParameters()
    query_params.makeDummiesQueries = True
    query_params.adjustDegree = False
    query_params.aromatizeIfPossible = True
    query_params.makeBondsGeneric = False
    adjusted_core = AdjustQueryProperties(core, query_params)
    return ReplaceCore(
        mol,
        adjusted_core,
        labelByIndex=label_by_index,
        replaceDummies=False,
        requireDummyMatch=False,
    )
510 |
511 |
def list_individual_attach_points(mol: dm.Mol, depth: Optional[int] = None):
    """List all individual attachement points.

    We do not allow multiple attachment points per substitution position.

    Args:
        mol: molecule for which we need to open the attachment points
        depth: number of successive attachment-point openings to perform;
            clamped to [1, number of available positions]

    """
    # reaction opening one attachment point on any atom with an implicit H
    # that is not already bonded to a dummy atom
    ATTACHING_RXN = ReactionFromSmarts("[*;h;!$([*][#0]):1]>>[*:1][*]")
    mols = [mol]
    curated_prods = set()
    num_attachs = len(mol.GetSubstructMatches(dm.from_smarts("[*;h:1]"), uniquify=True))
    depth = depth or 1
    depth = min(max(depth, 1), num_attachs)
    while depth > 0:
        prods = set()
        for mol in mols:
            mol = dm.to_mol(mol)
            for p in ATTACHING_RXN.RunReactants((mol,)):
                try:
                    m = dm.sanitize_mol(p[0])
                    sm = dm.to_smiles(m, canonical=True)
                    sm = dm.reactions.add_brackets_to_attachment_points(sm)
                    # isotope encoding keeps attachment points distinguishable
                    prods.add(dm.reactions.convert_attach_to_isotope(sm, as_smiles=True))
                except Exception as e:
                    logger.error(e)
        curated_prods.update(prods)
        mols = prods
        depth -= 1
    return list(curated_prods)
543 |
544 |
def filter_by_substructure_constraints(
    sequences: List[Union[str, dm.Mol]], substruct: Union[str, dm.Mol], n_jobs: int = -1
):
    """Keep only the sequences whose molecules contain the query substructure.

    Args:
        sequences: list of molecules to validate
        substruct: substructure to use as query
        n_jobs: number of jobs to use for parallelization

    """

    if isinstance(substruct, str):
        substruct = dm.from_smarts(standardize_attach(substruct))

    def _has_match(candidate):
        # any parsing or matching failure counts as "no match"
        with suppress(Exception):
            candidate_mol = dm.to_mol(candidate)
            return candidate_mol.HasSubstructMatch(substruct)
        return False

    match_flags = dm.parallelized(_has_match, sequences, n_jobs=n_jobs)
    return list(compress(sequences, match_flags))
569 |
570 |
def standardize_attach(inputs: str, standard_attach: str = "[*]"):
    """Normalize every attachment-point token of a molecule string.

    Args:
        inputs: input molecule
        standard_attach: standard attachment point to use
    """

    result = inputs
    for pattern in _SMILES_ATTACHMENT_POINTS:
        result = re.sub(pattern, standard_attach, result)
    return result
582 |
--------------------------------------------------------------------------------
/safe/viz.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | from typing import Any, Optional, Tuple, Union
3 |
4 | import datamol as dm
5 | import matplotlib.pyplot as plt
6 |
7 | import safe as sf
8 |
9 |
def to_image(
    safe_str: str,
    fragments: Optional[Union[str, dm.Mol]] = None,
    legend: Union[str, None] = None,
    mol_size: Union[Tuple[int, int], int] = (300, 300),
    use_svg: Optional[bool] = True,
    highlight_mode: Optional[str] = "lasso",
    highlight_bond_width_multiplier: int = 12,
    **kwargs: Any,
):
    """Display a safe string by highlighting the fragments that make it.

    Args:
        safe_str: the safe string to display
        fragments: list of fragment to highlight on the molecules. If None, will use safe decomposition of the molecule.
        legend: A string to use as the legend under the molecule.
        mol_size: The size of the image to be returned
        use_svg: Whether to return an svg or png image
        highlight_mode: the highlight mode to use. One of ["lasso", "fill", "color"]. If None, no highlight will be shown
        highlight_bond_width_multiplier: the multiplier to use for the bond width when using the 'fill' mode
        **kwargs: Additional arguments to pass to the drawing function. See RDKit
            documentation related to `MolDrawOptions` for more details at
            https://www.rdkit.org/docs/source/rdkit.Chem.Draw.rdMolDraw2D.html.

    """

    # forward the display options to the datamol drawing helpers via kwargs
    kwargs["legends"] = legend
    kwargs["mol_size"] = mol_size
    kwargs["use_svg"] = use_svg
    if highlight_bond_width_multiplier is not None:
        kwargs["highlightBondWidthMultiplier"] = highlight_bond_width_multiplier

    if highlight_mode == "color":
        kwargs["continuousHighlight"] = False
        kwargs["circleAtoms"] = kwargs.get("circleAtoms", False) or False

    if isinstance(fragments, (str, dm.Mol)):
        fragments = [fragments]

    # no explicit fragments: derive them from the SAFE decomposition itself
    if fragments is None and highlight_mode is not None:
        fragments = [
            sf.decode(x, as_mol=False, remove_dummies=False, ignore_errors=False)
            for x in safe_str.split(".")
        ]
    elif fragments and len(fragments) > 0:
        parsed_fragments = []
        for fg in fragments:
            # strings that are not valid SMILES are assumed to be SAFE fragments
            if isinstance(fg, str) and dm.to_mol(fg) is None:
                fg = sf.decode(fg, as_mol=False, remove_dummies=False, ignore_errors=False)
            parsed_fragments.append(fg)
        fragments = parsed_fragments
    else:
        fragments = []
    mol = dm.to_mol(safe_str, remove_hs=False)
    # one distinct rainbow color per fragment
    cm = plt.get_cmap("gist_rainbow")
    current_colors = [cm(1.0 * i / len(fragments)) for i in range(len(fragments))]

    if highlight_mode == "lasso":
        return dm.viz.lasso_highlight_image(mol, fragments, **kwargs)

    atom_indices = []
    bond_indices = []
    atom_colors = {}
    bond_colors = {}

    # collect matched atoms/bonds of each fragment and assign its color
    for i, frag in enumerate(fragments):
        frag = dm.from_smarts(frag)
        atom_matches, bond_matches = dm.substructure_matching_bonds(mol, frag)
        atom_matches = list(itertools.chain(*atom_matches))
        bond_matches = list(itertools.chain(*bond_matches))
        atom_indices.extend(atom_matches)
        bond_indices.extend(bond_matches)
        atom_colors.update({x: current_colors[i] for x in atom_matches})
        bond_colors.update({x: current_colors[i] for x in bond_matches})

    return dm.viz.to_image(
        mol,
        highlight_atom=[atom_indices],
        highlight_bond=[bond_indices],
        highlightAtomColors=[atom_colors],
        highlightBondColors=[bond_colors],
        **kwargs,
    )
--------------------------------------------------------------------------------
/tests/test_hgf_load.py:
--------------------------------------------------------------------------------
1 | from safe.sample import SAFEDesign
2 | from safe.tokenizer import SAFETokenizer
3 | from safe.trainer.model import SAFEDoubleHeadsModel
4 |
5 |
def test_load_default_safe_model():
    """The pretrained SAFE model should load from the HuggingFace hub."""
    model = SAFEDoubleHeadsModel.from_pretrained("datamol-io/safe-gpt")
    assert model is not None
    assert isinstance(model, SAFEDoubleHeadsModel)
10 |
11 |
def test_load_default_safe_tokenizer():
    """The pretrained SAFE tokenizer should load from the HuggingFace hub."""
    tokenizer = SAFETokenizer.from_pretrained("datamol-io/safe-gpt")
    assert isinstance(tokenizer, SAFETokenizer)
15 |
16 |
def test_check_molecule_sampling():
    """De novo generation with the default designer should yield molecules."""
    designer = SAFEDesign.load_default(verbose=True)
    generated = designer.de_novo_generation(sanitize=True, n_samples_per_trial=10)
    assert len(generated) > 0
21 |
--------------------------------------------------------------------------------
/tests/test_import.py:
--------------------------------------------------------------------------------
def test_import():
    """The safe package should be importable without side-effect errors."""
    import safe
3 |
--------------------------------------------------------------------------------
/tests/test_notebooks.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 |
3 | import nbformat
4 | import pytest
5 | from nbconvert.preprocessors.execute import ExecutePreprocessor
6 |
# Resolve the tutorial notebooks relative to this test file
ROOT_DIR = pathlib.Path(__file__).parent.resolve()

TUTORIALS_DIR = ROOT_DIR.parent / "docs" / "tutorials"
DISABLE_NOTEBOOKS = []  # notebooks temporarily disabled by file name
NOTEBOOK_PATHS = sorted(TUTORIALS_DIR.glob("*.ipynb"))
NOTEBOOK_PATHS = list(filter(lambda x: x.name not in DISABLE_NOTEBOOKS, NOTEBOOK_PATHS))

# Discard some notebooks (require external services / extra dependencies)
NOTEBOOKS_TO_DISCARD = ["extracting-representation-molfeat.ipynb", "load-from-wandb.ipynb"]
NOTEBOOK_PATHS = list(filter(lambda x: x.name not in NOTEBOOKS_TO_DISCARD, NOTEBOOK_PATHS))
17 |
18 |
@pytest.mark.parametrize("nb_path", NOTEBOOK_PATHS, ids=[str(n.name) for n in NOTEBOOK_PATHS])
def test_notebook(nb_path):
    """Execute a tutorial notebook end-to-end; any failing cell fails the test."""
    executor = ExecutePreprocessor(timeout=600, kernel_name="python3")

    with open(nb_path) as handle:
        notebook = nbformat.read(handle, as_version=nbformat.NO_CONVERT)

    executor.preprocess(notebook, {"metadata": {"path": TUTORIALS_DIR}})
30 |
--------------------------------------------------------------------------------
/tests/test_safe.py:
--------------------------------------------------------------------------------
1 | import datamol as dm
2 | import numpy as np
3 | import pytest
4 |
5 | import safe
6 |
7 |
def test_safe_encoding():
    """Round-trip a molecule through SAFE encode/decode and check fragment count."""
    celecoxib = "Cc1ccc(-c2cc(C(F)(F)F)nn2-c2ccc(S(N)(=O)=O)cc2)cc1"
    expected_encodings = "c13ccc(S(N)(=O)=O)cc1.Cc1ccc4cc1.c14cc5nn13.C5(F)(F)F"
    safe_celecoxib = safe.encode(celecoxib, canonical=True)
    dec_celecoxib = safe.decode(safe_celecoxib)
    assert safe_celecoxib.count(".") == 3  # 3 fragments
    # we compare length since digits can be random
    assert len(safe_celecoxib) == len(expected_encodings)
    assert dm.same_mol(celecoxib, safe_celecoxib)
    assert dm.same_mol(celecoxib, dec_celecoxib)
18 |
19 |
def test_safe_fragment_randomization():
    """Permuting the dot-separated SAFE fragments must preserve the molecule."""
    celecoxib = "Cc1ccc(-c2cc(C(F)(F)F)nn2-c2ccc(S(N)(=O)=O)cc2)cc1"
    safe_celecoxib = safe.encode(celecoxib)
    fragments = safe_celecoxib.split(".")
    randomized_fragment_safe_str = np.random.permutation(fragments).tolist()
    randomized_fragment_safe_str = ".".join(randomized_fragment_safe_str)
    assert dm.same_mol(celecoxib, randomized_fragment_safe_str)
27 |
28 |
def test_randomized_encoder():
    """Randomized encoding with different seeds should produce distinct strings."""
    celecoxib = "Cc1ccc(-c2cc(C(F)(F)F)nn2-c2ccc(S(N)(=O)=O)cc2)cc1"
    encodings = {
        safe.encode(celecoxib, canonical=False, randomize=True, seed=seed) for seed in range(5)
    }
    assert len(encodings) > 1
36 |
37 |
def test_custom_encoder():
    """Encoding with a custom SMARTS slicer should still round-trip the molecule."""
    # cut acyclic single bonds between two ring atoms
    smart_slicer = ["[r]-;!@[r]"]
    celecoxib = "Cc1ccc(-c2cc(C(F)(F)F)nn2-c2ccc(S(N)(=O)=O)cc2)cc1"
    safe_str = safe.encode(celecoxib, canonical=True, slicer=smart_slicer)
    assert dm.same_mol(celecoxib, safe_str)
43 |
44 |
def test_safe_decoder():
    """Individual SAFE fragments are invalid SMILES but decodable with fix=True."""
    celecoxib = "Cc1ccc(-c2cc(C(F)(F)F)nn2-c2ccc(S(N)(=O)=O)cc2)cc1"
    safe_str = safe.encode(celecoxib)
    fragments = safe_str.split(".")
    decoded_fragments = [safe.decode(x, fix=True) for x in fragments]
    # raw fragments should NOT parse as SMILES, while decoding them should work
    assert [dm.to_mol(x) for x in fragments] == [None] * len(fragments)
    assert all(x is not None for x in decoded_fragments)
52 |
53 |
def test_rdkit_smiles_parser_issues():
    """rdkit_safe=True should work around an RDKit stereo-SMILES parsing issue."""
    # see https://github.com/datamol-io/safe/issues/22
    input_sm = r"C(=C/c1ccccc1)\CCc1ccccc1"
    slicer = "brics"
    safe_obj = safe.SAFEConverter(slicer=slicer, require_hs=False)
    with dm.without_rdkit_log():
        # without the rdkit-safe encoding the result cannot be decoded back
        failing_encoded = safe_obj.encoder(
            input_sm,
            canonical=True,
            randomize=False,
            rdkit_safe=False,
        )
        working_encoded = safe_obj.encoder(
            input_sm,
            canonical=True,
            randomize=False,
            rdkit_safe=True,
        )
    working_decoded = safe.decode(working_encoded)
    # compare without stereochemistry since the cut may drop stereo information
    working_no_stero = dm.remove_stereochemistry(dm.to_mol(input_sm))
    input_mol = dm.remove_stereochemistry(dm.to_mol(working_decoded))
    assert safe.decode(failing_encoded) is None
    assert working_decoded is not None
    assert dm.same_mol(working_no_stero, input_mol)
78 |
79 |
@pytest.mark.parametrize(
    "input_sm",
    [
        "O=C(CN1CC[NH2+]CC1)N1CCCCC1",
        "[NH3+]Cc1ccccc1",
        "c1cc2c(cc1[C@@H]1CCC[NH2+]1)OCCO2",
        "[13C]1CCCCC1C[238U]C[NH3+]",
        "COC[CH2:1][CH2:2]O[CH:2]C[OH:3]",
    ],
)
def test_bracket_smiles_issues(input_sm):
    """SMILES with bracket atoms (charges, isotopes, atom maps) must round-trip."""
    slicer = "brics"
    safe_obj = safe.SAFEConverter(slicer=slicer, require_hs=False)
    fragments = []
    with dm.without_rdkit_log():
        safe_str = safe_obj.encoder(
            input_sm,
            canonical=True,
        )
        # each fragment must decode individually as well
        for fragment in safe_str.split("."):
            f = safe_obj.decoder(
                fragment,
                as_mol=False,
                canonical=True,
                fix=True,
                remove_dummies=True,
                remove_added_hs=True,
            )
            fragments.append(f)
    input_mol = dm.to_mol(input_sm)
    assert safe.decode(safe_str) is not None
    assert dm.same_mol(dm.to_mol(safe_str), input_mol)
    assert None not in fragments
113 |
114 |
def test_fused_ring_issue():
    """Molecules with fused ring systems must survive a SAFE round trip."""
    FUSED_RING_LIST = [
        "[H][C@@]12CC[C@@]3(CCC(=O)O3)[C@@]1(C)CC[C@@]1([H])[C@@]2([H])[C@@]([H])(CC2=CC(=O)CC[C@]12C)SC(C)=O",
        "[H][C@@]12C[C@H](C)[C@](OC(=O)CC)(C(=O)COC(=O)CC)[C@@]1(C)C[C@H](O)[C@@]1(Cl)[C@@]2([H])CCC2=CC(=O)C=C[C@]12C",
        "[H][C@@]12CC[C@@](O)(C#C)[C@@]1(CC)CC[C@]1([H])[C@@]3([H])CCC(=O)C=C3CC[C@@]21[H]",
    ]
    for fused_ring in FUSED_RING_LIST:
        output_string = safe.decode(safe.encode(fused_ring))
        assert dm.same_mol(fused_ring, output_string)
124 |
125 |
def test_stereochemistry_issue():
    """Stereochemistry must be preserved unless the slicer cuts a double bond."""
    STEREO_MOL_LIST = [
        "CC(=C\\c1ccccc1)/N=C/C(=O)O",
        "CC(=C/c1ccccc1)/N=C/C(=O)O",
        "CC(=C\\c1ccccc1)/N=C\\C(=O)O",
        "CC(=C/c1ccccc1)/N=C\\C(=O)O",
        "CC(=Cc1ccccc1)N=CC(=O)O",
        "Cc1ccc(-n2c(C)cc(/C=N/Nc3ccc([N+](=O)[O-])cn3)c2C)c(C)c1",
        "Cc1ccc(-n2c(C)cc(/C=N\\Nc3ccc([N+](=O)[O-])cn3)c2C)c(C)c1",
    ]
    # the rotatable slicer never cuts double bonds, so stereo survives the round trip
    for mol in STEREO_MOL_LIST:
        output_string = safe.encode(mol, ignore_stereo=False, slicer="rotatable")
        assert dm.same_mol(mol, output_string)

    # now let's test failure case where we fail because we split on a double bond
    output = safe.encode(STEREO_MOL_LIST[0], ignore_stereo=False, slicer="brics")
    assert dm.same_mol(STEREO_MOL_LIST[0], output) is False
    # the 2D graphs still agree once stereochemistry is stripped
    same_stereo = [dm.remove_stereochemistry(dm.to_mol(x)) for x in [output, STEREO_MOL_LIST[0]]]
    assert dm.same_mol(same_stereo[0], same_stereo[1])

    # check if we ignore the stereo
    output = safe.encode(STEREO_MOL_LIST[0], ignore_stereo=True, slicer="brics")
    assert dm.same_mol(dm.remove_stereochemistry(dm.to_mol(STEREO_MOL_LIST[0])), output)
149 |
--------------------------------------------------------------------------------