├── .github ├── ISSUE_TEMPLATE │ ├── bug-report.yml │ ├── documentation.yml │ ├── feature-request.yml │ └── generic-issue.yml └── workflows │ └── docker.yml ├── .gitignore ├── .readthedocs.yml ├── LICENSE ├── README.md ├── docs ├── Makefile ├── make.bat ├── requirements.txt └── source │ ├── conf.py │ ├── data.rst │ ├── examples.md │ ├── features.md │ ├── index.rst │ ├── installation.md │ ├── models.rst │ ├── quickstart.md │ ├── requirements.md │ ├── school.md │ ├── task.rst │ ├── train.rst │ ├── utils.rst │ └── weight.rst ├── unimol ├── README.md ├── docker │ └── Dockerfile ├── example_data │ ├── molecule │ │ ├── dict.txt │ │ ├── train.lmdb │ │ └── valid.lmdb │ └── pocket │ │ ├── dict_coarse.txt │ │ ├── train.lmdb │ │ └── valid.lmdb ├── figure │ └── overview.png ├── notebooks │ ├── mol_property_demo.csv │ ├── unimol_binding_pose_demo.ipynb │ ├── unimol_mol_property_demo.ipynb │ ├── unimol_mol_repr_demo.ipynb │ ├── unimol_pocket_repr_demo.ipynb │ └── unimol_posebuster_demo.ipynb ├── requirements.txt ├── setup.py └── unimol │ ├── __init__.py │ ├── data │ ├── __init__.py │ ├── add_2d_conformer_dataset.py │ ├── atom_type_dataset.py │ ├── conformer_sample_dataset.py │ ├── coord_pad_dataset.py │ ├── cropping_dataset.py │ ├── data_utils.py │ ├── distance_dataset.py │ ├── from_str_dataset.py │ ├── key_dataset.py │ ├── lmdb_dataset.py │ ├── mask_points_dataset.py │ ├── normalize_dataset.py │ ├── prepend_and_append_2d_dataset.py │ ├── remove_hydrogen_dataset.py │ └── tta_dataset.py │ ├── infer.py │ ├── losses │ ├── __init__.py │ ├── conf_gen.py │ ├── cross_entropy.py │ ├── docking_pose.py │ ├── reg_loss.py │ └── unimol.py │ ├── models │ ├── __init__.py │ ├── conf_gen.py │ ├── docking_pose.py │ ├── transformer_encoder_with_pair.py │ └── unimol.py │ ├── tasks │ ├── __init__.py │ ├── docking_pose.py │ ├── unimol.py │ ├── unimol_conf_gen.py │ ├── unimol_finetune.py │ ├── unimol_pocket.py │ └── unimol_pocket_finetune.py │ └── utils │ ├── __init__.py │ ├── conf_gen_cal_metrics.py │ ├── conformer_model.py │ ├── coordinate_model.py │ ├── docking.py │ └── docking_utils.py ├── unimol2 ├── README.md ├── docker │ └── Dockerfile ├── figure │ ├── predicted_loss.jpg │ └── unimol2_arch.jpg ├── requirements.txt └── unimol2 │ ├── __init__.py │ ├── data │ ├── __init__.py │ ├── add_2d_conformer_dataset.py │ ├── conformer_sample_dataset.py │ ├── cropping_dataset.py │ ├── data_utils.py │ ├── graph_features.py │ ├── index_atom_dataset.py │ ├── key_dataset.py │ ├── lmdb_dataset.py │ ├── molecule_dataset.py │ ├── noised_points_dataset.py │ ├── normalize_dataset.py │ ├── remove_hydrogen_dataset.py │ ├── tta_dataset.py │ └── unimol2_dataset.py │ ├── infer.py │ ├── losses │ ├── __init__.py │ ├── cross_entropy.py │ ├── reg_loss.py │ └── unimol2.py │ ├── models │ ├── __init__.py │ ├── layers.py │ ├── transformer_encoder_with_pair.py │ └── unimol2.py │ └── tasks │ ├── __init__.py │ ├── unimol2.py │ └── unimol_finetune.py ├── unimol_docking_v2 ├── README.md ├── docker │ └── Dockerfile ├── example_data │ ├── dict_mol.txt │ ├── dict_pkt.txt │ ├── docking_grid.json │ ├── ligand.sdf │ └── protein.pdb ├── figure │ └── bohrium_app.gif ├── interface │ ├── demo.py │ ├── demo.sh │ ├── demo_batch_one2one.sh │ ├── input_batch_one2one.csv │ ├── posebuster_demo.ipynb │ └── predictor │ │ ├── __init__.py │ │ ├── processor.py │ │ └── unimol_predictor.py ├── train.sh └── unimol │ ├── __init__.py │ ├── data │ ├── __init__.py │ ├── conformer_sample_dataset.py │ ├── coord_pad_dataset.py │ ├── cropping_dataset.py │ ├── data_utils.py │ ├── 
distance_dataset.py │ ├── key_dataset.py │ ├── lmdb_dataset.py │ ├── normalize_dataset.py │ ├── prepend_and_append_2d_dataset.py │ ├── realign_ligand_dataset.py │ ├── remove_hydrogen_dataset.py │ └── tta_dataset.py │ ├── infer.py │ ├── losses │ ├── __init__.py │ └── docking_pose_v2.py │ ├── models │ ├── __init__.py │ ├── docking_pose_v2.py │ ├── transformer_encoder_with_pair.py │ └── unimol.py │ ├── scripts │ └── 6tsr.py │ └── tasks │ ├── __init__.py │ └── docking_pose_v2.py ├── unimol_plus ├── README.md ├── docker │ └── Dockerfile ├── figure │ └── overview.png ├── inference.py ├── inference.sh ├── scripts │ ├── download.sh │ ├── get_3d_lmdb.py │ ├── get_label3d_lmdb.py │ ├── make_oc20_test_submission.py │ ├── make_pcq_test_dev_submission.py │ └── oc20_preprocess.py ├── setup.py ├── train_oc20.sh ├── train_pcq.sh └── unimol_plus │ ├── __init__.py │ ├── data │ ├── __init__.py │ ├── conformer_sample_dataset.py │ ├── data_utils.py │ ├── key_dataset.py │ ├── lmdb_dataset.py │ ├── oc20_dataset.py │ └── pcq_dataset.py │ ├── losses │ ├── __init__.py │ └── unimol_plus.py │ ├── models │ ├── __init__.py │ ├── layers.py │ ├── unimol_plus_encoder.py │ ├── unimol_plus_oc20.py │ └── unimol_plus_pcq.py │ └── tasks │ ├── __init__.py │ ├── oc20.py │ └── pcq.py └── unimol_tools ├── MANIFEST.in ├── README.md ├── requirements.txt ├── setup.py └── unimol_tools ├── __init__.py ├── config ├── __init__.py ├── default.yaml └── model_config.py ├── data ├── __init__.py ├── conformer.py ├── datahub.py ├── datareader.py ├── datascaler.py ├── dictionary.py └── split.py ├── models ├── __init__.py ├── loss.py ├── nnmodel.py ├── transformers.py ├── transformersv2.py ├── unimol.py └── unimolv2.py ├── predict.py ├── predictor.py ├── tasks ├── __init__.py └── trainer.py ├── train.py ├── utils ├── __init__.py ├── base_logger.py ├── config_handler.py ├── metrics.py └── util.py └── weights ├── __init__.py └── weighthub.py /.github/ISSUE_TEMPLATE/bug-report.yml: -------------------------------------------------------------------------------- 1 | name: Bug report 2 | description: Create a bug report to help us eliminate issues and improve Uni-Mol. 3 | title: "[BUG] _Replace With Suitable Title_" 4 | labels: bug 5 | assignees: [] 6 | body: 7 | - type: textarea 8 | attributes: 9 | label: Describe the bug 10 | description: | 11 | A clear and concise description of what the bug is. 12 | validations: 13 | required: true 14 | 15 | - type: input 16 | id: version 17 | attributes: 18 | label: Uni-Mol Version 19 | description: "The version of Uni-Mol you are using, e.g. Uni-Mol, Uni-Mol2, Uni-Mol Tools, etc." 20 | validations: 21 | required: true 22 | 23 | - type: textarea 24 | attributes: 25 | label: Expected behavior 26 | description: | 27 | A clear and concise description of what you expected to happen. 28 | validations: 29 | required: true 30 | 31 | - type: textarea 32 | attributes: 33 | label: To Reproduce 34 | description: | 35 | It is recommended to attach your data, cases here for the developers to reproduce the bug. 36 | 37 | - type: textarea 38 | attributes: 39 | label: Environment 40 | description: | 41 | - OS: [e.g. Ubuntu 20.04] 42 | - Dependencies: [e.g. PyTorch, SciPy, NumPy, ...] 43 | 44 | - type: textarea 45 | attributes: 46 | label: Additional Context 47 | description: | 48 | Add any other context about the problem here. 49 | 50 | - type: markdown 51 | attributes: 52 | value: > 53 | Thanks for contributing 🎉! 
54 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/documentation.yml: -------------------------------------------------------------------------------- 1 | name: Docs 2 | description: For users or developers to report issues related to software documentation, such as missing or incomplete documentation, or documentation that is difficult to understand. 3 | labels: documentation 4 | assignees: [] 5 | body: 6 | - type: textarea 7 | attributes: 8 | label: Details 9 | description: | 10 | Please provide details about the documentation issue you are experiencing. Include any relevant information such as the specific section of the documentation, the information that is missing or unclear, and any suggestions for improvement. 11 | validations: 12 | required: true 13 | - type: markdown 14 | attributes: 15 | value: > 16 | Thank you for reporting this documentation issue. We will review and update the documentation as necessary. 📚 17 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request.yml: -------------------------------------------------------------------------------- 1 | name: Feature request 2 | description: Suggest an idea for this project. 3 | title: "[Feature Request] _Replace with Title_" 4 | labels: enhancement 5 | assignees: [] 6 | body: 7 | - type: textarea 8 | id: summary 9 | attributes: 10 | label: Summary 11 | description: "Please provide a brief and concise description of the suggested feature or change" 12 | placeholder: 13 | value: 14 | validations: 15 | required: true 16 | - type: textarea 17 | id: details 18 | attributes: 19 | label: Detailed Description 20 | description: "Please explain how you would like to see Uni-Mol enhanced, what feature(s) you are looking for, what specific problems this will solve. If possible, provide references to relevant background information like publications or web pages, and whether you are planning to implement the enhancement yourself or would like to participate in the implementation. If applicable add a reference to an existing bug report or issue that this will address." 21 | placeholder: 22 | value: 23 | validations: 24 | required: true 25 | - type: textarea 26 | id: further 27 | attributes: 28 | label: Further Information, Files, and Links 29 | description: "Put any additional information here, attach relevant text or image files and URLs to external sites, e.g. relevant publications" 30 | placeholder: 31 | value: 32 | validations: 33 | required: false 34 | - type: markdown 35 | attributes: 36 | value: > 37 | Thanks for contributing 🎉! 38 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/generic-issue.yml: -------------------------------------------------------------------------------- 1 | name: Generic issue 2 | description: For issues that do not fit any of the other categories. 3 | title: _Replace With a Descriptive Title_ 4 | labels: wontfix 5 | assignees: [] 6 | body: 7 | - type: textarea 8 | id: summary 9 | attributes: 10 | label: Summary 11 | description: "Please provide a clear and concise description of what the question is." 12 | placeholder: 13 | value: 14 | validations: 15 | required: true 16 | - type: input 17 | id: version 18 | attributes: 19 | label: Uni-Mol Version 20 | description: "The version of Uni-Mol you are using, e.g. Uni-Mol, Uni-Mol2, Uni-Mol Tools, etc." 
21 | validations: 22 | required: true 23 | - type: textarea 24 | id: details 25 | attributes: 26 | label: Details 27 | description: "Please explain the issue in detail here." 28 | placeholder: 29 | value: 30 | validations: 31 | required: true 32 | - type: markdown 33 | attributes: 34 | value: > 35 | Thanks for contributing 🎉! 36 | -------------------------------------------------------------------------------- /.github/workflows/docker.yml: -------------------------------------------------------------------------------- 1 | name: Build and Publish Docker 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | docker: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - 13 | name: Checkout 14 | uses: actions/checkout@v3 15 | - 16 | name: Set up QEMU 17 | uses: docker/setup-qemu-action@v2 18 | - 19 | name: Set up Docker Buildx 20 | uses: docker/setup-buildx-action@v2 21 | - 22 | name: Login to DockerHub 23 | uses: docker/login-action@v2 24 | with: 25 | username: ${{ secrets.DOCKERHUB_USERNAME }} 26 | password: ${{ secrets.DOCKERHUB_TOKEN }} 27 | - 28 | name: Free Disk Space (Ubuntu) 29 | uses: jlumbroso/free-disk-space@main 30 | with: 31 | # this might remove tools that are actually needed, 32 | # if set to "true" but frees about 6 GB 33 | tool-cache: false 34 | 35 | # all of these default to true, but feel free to set to 36 | # "false" if necessary for your workflow 37 | android: true 38 | dotnet: true 39 | haskell: true 40 | large-packages: true 41 | docker-images: false 42 | swap-storage: false 43 | - 44 | name: Set up swap space 45 | uses: pierotofy/set-swap-space@v1.0 46 | with: 47 | swap-size-gb: 10 48 | - 49 | name: Build and push with rdma 50 | uses: docker/build-push-action@v3 51 | with: 52 | context: ./unimol/docker/ 53 | push: true 54 | tags: dptechnology/unimol:latest-pytorch1.11.0-cuda11.3 55 | 56 | publish_package: 57 | name: Publish package 58 | needs: [docker] 59 | 60 | runs-on: ubuntu-latest 61 | steps: 62 | - uses: actions/checkout@v3 63 | 64 | - uses: actions/setup-python@v4 65 | with: 66 | python-version: '3.10' 67 | 68 | - name: Build core package 69 | env: 70 | FLASH_ATTENTION_SKIP_CUDA_BUILD: "TRUE" 71 | run: | 72 | pip install setuptools wheel twine 73 | cd unimol_tools 74 | python setup.py sdist --dist-dir=dist 75 | 76 | - name: Deploy 77 | env: 78 | TWINE_USERNAME: "__token__" 79 | TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }} 80 | run: | 81 | cd unimol_tools 82 | python -m twine upload dist/* -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Set the OS, Python version and other tools you might need 9 | build: 10 | os: ubuntu-22.04 11 | tools: 12 | python: "3.12" 13 | # You can also specify other tool versions: 14 | # nodejs: "20" 15 | # rust: "1.70" 16 | # golang: "1.20" 17 | 18 | # Build documentation in the docs/ directory with Sphinx 19 | sphinx: 20 | configuration: docs/source/conf.py 21 | 22 | 23 | # Optional but recommended, declare the Python requirements required 24 | # to build your documentation 25 | # See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html 26 | python: 27 | install: 28 | - requirements: docs/requirements.txt -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 DP Technology 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this 
permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | # Packages required to build docs, independent of application dependencies 2 | 3 | -e ./unimol_tools 4 | 5 | sphinx 6 | 7 | sphinx_rtd_theme==2.0.0rc2 8 | sphinx-tabs 9 | sphinx-intl 10 | sphinx-design 11 | sphinx-multiproject 12 | 13 | # RTD deps :) 14 | readthedocs-sphinx-search 15 | sphinx-hoverxref 16 | sphinx-notfound-page 17 | 18 | # Docs 19 | sphinxemoji 20 | sphinxcontrib-httpdomain 21 | sphinx-prompt 22 | sphinx-autobuild 23 | sphinxext-opengraph 24 | sphinx-copybutton 25 | 26 | # Markdown 27 | myst_parser -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 
2 | # 3 | # For the full list of built-in configuration values, see the documentation: 4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 5 | 6 | # -- Project information ----------------------------------------------------- 7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 8 | 9 | project = 'Uni-Mol' 10 | copyright = '2023, cuiyaning' 11 | author = 'cuiyaning' 12 | release = '0.1.1' 13 | 14 | # -- General configuration --------------------------------------------------- 15 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 16 | 17 | extensions = [ 18 | 'sphinx.ext.autodoc', 19 | 'sphinx.ext.viewcode', 20 | 'myst_parser', 21 | ] 22 | 23 | templates_path = ['_templates'] 24 | exclude_patterns = [] 25 | 26 | highlight_language = 'python' 27 | 28 | 29 | # List of modules to be mocked up. This is useful when some external 30 | # dependencies are not met at build time and break the building process. 31 | autodoc_mock_imports = [ 32 | 'rdkit', 33 | 'unicore', 34 | 'torch', 35 | 'sklearn', 36 | ] 37 | 38 | 39 | # -- Options for HTML output ------------------------------------------------- 40 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 41 | 42 | html_theme = 'sphinx_rtd_theme' 43 | html_static_path = ['_static'] 44 | 45 | # -- Autodoc configuration --------------------------------------------------- 46 | 47 | autoclass_content = 'class' 48 | 49 | # Explicitly set the member order so that constructor arguments are documented first 50 | autodoc_member_order = 'bysource' 51 | 52 | # Default autodoc options, including documenting constructor (__init__) parameters 53 | 54 | autodoc_default_options = { 55 | 'members': True, 56 | 'special-members': '__init__', 57 | #'undoc-members': False, 58 | 'private-members': True, 59 | #'show-inheritance': False, 60 | } 61 | -------------------------------------------------------------------------------- /docs/source/data.rst: -------------------------------------------------------------------------------- 1 | Data 2 | ==== 3 | 4 | `unimol_tools.data `_ contains functions and classes for loading, storing, and scaling data and features. 5 | 6 | DataHub 7 | ------- 8 | 9 | Classes and functions from `unimol_tools.data.datahub.py `_. 10 | 11 | .. automodule:: unimol_tools.data.datahub 12 | :members: 13 | 14 | Datareader 15 | ---------- 16 | 17 | Classes and functions from `unimol_tools.data.datareader.py `_. 18 | 19 | .. automodule:: unimol_tools.data.datareader 20 | :members: 21 | 22 | Datascaler 23 | ----------- 24 | 25 | Classes and functions from `unimol_tools.data.datascaler.py `_. 26 | 27 | .. automodule:: unimol_tools.data.datascaler 28 | :members: 29 | 30 | Conformer 31 | --------- 32 | 33 | Classes and functions from `unimol_tools.data.conformer.py `_. 34 | 35 | .. automodule:: unimol_tools.data.conformer 36 | :members: 37 | 38 | Split 39 | ------- 40 | 41 | `unimol_tools.data.split.py `_ manages the dataset splitting methods. 42 | 43 | .. automodule:: unimol_tools.data.split 44 | :members: -------------------------------------------------------------------------------- /docs/source/examples.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | 3 | Welcome to the examples section! On our platform Bohrium, we offer a variety of notebook cases for studying Uni-Mol. These notebooks provide practical examples and applications of Uni-Mol in different scientific fields. You can explore these notebooks to gain hands-on experience and deepen your understanding of Uni-Mol.
4 | 5 | ## Uni-Mol Notebooks on Bohrium 6 | Explore our collection of Uni-Mol notebooks on Bohrium: [Uni-Mol Notebooks](https://bohrium.dp.tech/search?searchKey=UniMol&%3BactiveTab=notebook&activeTab=notebook) 7 | 8 | ### Uni-Mol for QSAR (Quantitative Structure-Activity Relationship) 9 | Uni-Mol can be used to predict the biological activity of compounds based on their chemical structure. These notebooks demonstrate how to apply Uni-Mol for QSAR tasks: 10 | - [QSAR Example 1](https://bohrium.dp.tech/notebooks/7141701322) 11 | - [QSAR Example 2](https://bohrium.dp.tech/notebooks/9919429887) 12 | 13 | ### Uni-Mol for OLED Properties Predictions 14 | Organic Light Emitting Diodes (OLEDs) are used in various display technologies. Uni-Mol can predict the properties of OLED molecules, aiding in the design of more efficient materials. Check out these notebooks for detailed examples: 15 | - [OLED Properties Prediction Example 1](https://bohrium.dp.tech/notebooks/2412844127) 16 | - [OLED Properties Prediction Example 2](https://bohrium.dp.tech/notebooks/7637046852) 17 | 18 | ### Uni-Mol Predicts Liquid Flow Battery Solubility 19 | Liquid flow batteries are a promising technology for energy storage. Uni-Mol can predict the solubility of compounds used in these batteries, helping to optimize their performance. Explore this notebook to see how Uni-Mol is applied in this context: 20 | - [Liquid Flow Battery Solubility Prediction](https://bohrium.dp.tech/notebooks/7941779831) 21 | 22 | These examples provide a glimpse into the powerful capabilities of Uni-Mol in various scientific applications. We encourage you to explore these notebooks and experiment with Uni-Mol to discover its full potential. -------------------------------------------------------------------------------- /docs/source/features.md: -------------------------------------------------------------------------------- 1 | # New Features 2 | 3 | ## 2025-03-28 4 | Unimol_tools now supports Distributed Data Parallel (DDP)! 5 | 6 | ## 2024-11-22 7 | Unimol V2 has been added to Unimol_tools! 8 | 9 | ## 2024-06-25 10 | 11 | Unimol_tools has been published to PyPI! Hugging Face is now used to manage the pretrained models. 12 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. UniMol documentation master file, created by 2 | sphinx-quickstart on Wed Nov 29 03:53:18 2023. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to Uni-Mol's documentation! 7 | ========================================== 8 | 9 | Uni-Mol is the first universal large-scale three-dimensional Molecular Representation Learning (MRL) framework, developed by DP Technology. It expands the application scope and representation capabilities of MRL. 10 | 11 | This framework consists of two models: one trained on billions of molecular three-dimensional conformations and the other on millions of protein pocket structures. 12 | 13 | It has shown excellent performance in various molecular property prediction tasks, with especially strong results on 3D-related tasks. In addition to drug design, Uni-Mol can also predict the properties of materials, such as the gas adsorption performance of MOF materials and the optical properties of OLED molecules. 14 | 15 | .. Important:: 16 | 17 | The project Uni-Mol is licensed under `MIT LICENSE `_.
18 | If you use Uni-Mol in your research, please kindly cite the following works: 19 | 20 | - Gengmo Zhou, Zhifeng Gao, Qiankun Ding, Hang Zheng, Hongteng Xu, Zhewei Wei, Linfeng Zhang, Guolin Ke. "Uni-Mol: A Universal 3D Molecular Representation Learning Framework." The Eleventh International Conference on Learning Representations, 2023. `https://openreview.net/forum?id=6K2RM6wVqKu `_. 21 | - Shuqi Lu, Zhifeng Gao, Di He, Linfeng Zhang, Guolin Ke. "Data-driven quantum chemical property prediction leveraging 3D conformations with Uni-Mol+." Nature Communications, 2024. `https://www.nature.com/articles/s41467-024-51321-w `_. 22 | 23 | 24 | Uni-Mol tools is an easy-to-use wrapper for property prediction, molecular representation, and other downstream tasks with Uni-Mol. It includes the following tools: 25 | 26 | * molecular property prediction with Uni-Mol. 27 | * molecular representation with Uni-Mol. 28 | * other downstream tasks with Uni-Mol. 29 | 30 | .. toctree:: 31 | :maxdepth: 2 32 | :caption: Getting Started: 33 | 34 | requirements 35 | installation 36 | 37 | .. toctree:: 38 | :maxdepth: 2 39 | :caption: Tutorials: 40 | 41 | quickstart 42 | school 43 | examples 44 | 45 | .. toctree:: 46 | :maxdepth: 2 47 | :caption: Uni-Mol tools: 48 | 49 | train 50 | data 51 | models 52 | task 53 | utils 54 | weight 55 | features 56 | 57 | 58 | Indices and tables 59 | ================== 60 | 61 | * :ref:`genindex` 62 | * :ref:`modindex` 63 | * :ref:`search` 64 | -------------------------------------------------------------------------------- /docs/source/installation.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | ## Install 4 | - PyTorch is required; please install PyTorch according to your environment. If you are using CUDA, please install PyTorch with CUDA support. More details can be found at https://pytorch.org/get-started/locally/ 5 | - Currently, RDKit requires numpy<2.0.0, so please install RDKit together with numpy<2.0.0. 6 | 7 | ### Option 1: Installing from PyPi (Recommended) 8 | 9 | ```bash 10 | pip install unimol_tools 11 | ``` 12 | 13 | We recommend installing ```huggingface_hub``` so that the required Uni-Mol models can be automatically downloaded at runtime! It can be installed by 14 | 15 | ```bash 16 | pip install huggingface_hub 17 | ``` 18 | 19 | `huggingface_hub` allows you to easily download and manage models from the Hugging Face Hub, which is key for using Uni-Mol models. 20 | 21 | ### Option 2: Installing from source 22 | 23 | ```bash 24 | ## Dependencies installation 25 | pip install -r requirements.txt 26 | 27 | ## Clone repository 28 | git clone https://github.com/deepmodeling/Uni-Mol.git 29 | cd Uni-Mol/unimol_tools 30 | 31 | ## Install 32 | python setup.py install 33 | ``` 34 | 35 | ### Models in Huggingface 36 | 37 | The Uni-Mol pretrained models can be found at [dptech/Uni-Mol-Models](https://huggingface.co/dptech/Uni-Mol-Models/tree/main). 38 | 39 | If the download is slow, you can use other mirrors, such as: 40 | 41 | ```bash 42 | export HF_ENDPOINT=https://hf-mirror.com 43 | ``` 44 | 45 | Setting the `HF_ENDPOINT` environment variable specifies the mirror address for the Hugging Face Hub to use when downloading models. 46 | 47 | ## Bohrium notebook 48 | 49 | Uni-Mol images are available on the online notebook platform [Bohrium notebook](https://nb.bohrium.dp.tech/).
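
## Quick usage check

Once installation finishes, a small end-to-end run is a convenient way to confirm that the package, RDKit, and the weight download via `huggingface_hub` all work. The snippet below is a minimal sketch using the `unimol_tools` Python API (`MolTrain` / `MolPredict`); the CSV file names, the column layout (a SMILES column plus a label column), and the `./exp` output directory are placeholder assumptions for your own data and settings.

```python
from unimol_tools import MolTrain, MolPredict

# Fit a small classification model; train.csv is a placeholder CSV with
# a SMILES column and a label column. Pretrained weights are fetched
# from the Hugging Face Hub on first use.
clf = MolTrain(
    task='classification',
    data_type='molecule',
    epochs=5,
    batch_size=16,
    metrics='auc',
)
clf.fit(data='train.csv')

# Reload the saved model (./exp is the assumed default output directory)
# and predict on new molecules.
predictor = MolPredict(load_model='./exp')
predictions = predictor.predict(data='test.csv')
```

If this runs end to end, the installation is working; see the Uni-Mol tools sections of this documentation for the full set of training and prediction options.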
50 | -------------------------------------------------------------------------------- /docs/source/models.rst: -------------------------------------------------------------------------------- 1 | .. _models: 2 | 3 | Models 4 | ====== 5 | 6 | `unimol_tools.models `_ contains the models of Uni-Mol. 7 | 8 | 9 | Uni-Mol 10 | ------- 11 | 12 | `unimol_tools.models.unimol.py `_ contains the :class:`~unimol_tools.models.UniMolModel`, which is the backbone of the Uni-Mol model. 13 | 14 | .. automodule:: unimol_tools.models.unimol 15 | :members: 16 | 17 | Model 18 | ----- 19 | 20 | `unimol_tools.models.nnmodel.py `_ contains the :class:`~unimol_tools.models.NNModel`, which is responsible for initializing the model. 21 | 22 | .. automodule:: unimol_tools.models.nnmodel 23 | :members: 24 | 25 | Loss 26 | ----- 27 | 28 | `unimol_tools.models.loss.py `_ contains different loss functions. 29 | 30 | .. automodule:: unimol_tools.models.loss 31 | :members: 32 | 33 | Transformers 34 | ------------ 35 | 36 | `unimol_tools.models.transformers.py `_ contains a custom Transformer Encoder module that extends `PyTorch's nn.Module `_. 37 | 38 | .. automodule:: unimol_tools.models.transformers 39 | :members: -------------------------------------------------------------------------------- /docs/source/requirements.md: -------------------------------------------------------------------------------- 1 | # Requirements 2 | 3 | For small datasets (~1000 molecules), it is possible to train models within a few minutes on a standard laptop with CPUs only. However, for larger datasets and larger Uni-Mol models, we recommend using a GPU for significantly faster training. 4 | 5 | Notice: [Uni-Core](https://github.com/dptech-corp/Uni-Core) is needed, please install it first. The current Uni-Core requires torch>=2.0.0 by default; if you want to install another version, please check its [Installation Documentation](https://github.com/dptech-corp/Uni-Core#installation). 6 | 7 | 8 | Uni-Mol uses Python 3.8+ and all models are built with [PyTorch](https://pytorch.org/). See [Installation](#installation) for details on how to install Uni-Mol and its dependencies. 9 | -------------------------------------------------------------------------------- /docs/source/school.md: -------------------------------------------------------------------------------- 1 | # Uni-Mol School 2 | 3 | Welcome to Uni-Mol School! This course is designed to provide comprehensive training on Uni-Mol, a powerful tool for molecular modeling and simulations. 4 | 5 | ## Course Introduction 6 | The properties of drugs are determined by their three-dimensional structures, which are crucial for their efficacy and absorption. Drug design requires consideration of molecular diversity. Current Molecular Representation Learning (MRL) models mainly utilize one-dimensional or two-dimensional data, with limited capability to integrate 3D information. 7 | 8 | Uni-Mol, developed by the DP Technology team, is the first general large-scale 3D MRL framework in the field of drug design, expanding the application scope and representation capabilities of MRL. This framework consists of two models trained on billions of molecular 3D conformations and millions of protein pocket structures, respectively. It has shown excellent performance in various molecular property prediction tasks, especially in 3D-related tasks. Besides drug design, Uni-Mol can also predict the properties of materials, such as gas adsorption performance of MOF materials and optical properties of OLED molecules.
9 | 10 | ## Course Content 11 | | Topic | Course Content | Instructor | 12 | |-------|----------------|------------| 13 | | Introduction to Uni-Mol | Uni-Mol molecular 3D representation learning framework and pre-trained models | Chen Letian | 14 | | Uni-Mol for Materials Science | Case study of Uni-Mol in predicting the properties of battery materials | Chen Letian | 15 | | | 3D Representation Learning Framework and Pre-trained Models for Nanoporous Materials | Chen Letian | 16 | | | Efficient Screening of Ir(III) Complex Emitters: A Study Combining Machine Learning and Computational Analysis | Chen Letian | 17 | | | Application of 3D Molecular Pre-trained Model Uni-Mol in Flow Batteries | Xie Qiming | 18 | | | Materials Science Uni-Mol Notebook Case Study | | 19 | | Uni-Mol for Biomedical Science | Application of Uni-Mol in Molecular Docking | Zhou Gengmo | 20 | | | Application of Uni-Mol in Molecular Generation | Song Ke | 21 | | | Biomedical Science Uni-Mol Notebook Case Study | | 22 | 23 | ## How to Enroll 24 | Enroll now and start your journey with Uni-Mol! [Click here to enroll](https://bohrium.dp.tech/courses/6134196349?tab=courses) 25 | 26 | Don't miss this opportunity to advance your knowledge and skills in molecular modeling with Uni-Mol! -------------------------------------------------------------------------------- /docs/source/task.rst: -------------------------------------------------------------------------------- 1 | .. _tasks: 2 | 3 | Task 4 | ====== 5 | 6 | `unimol_tools.tasks `_ oversees the tasks related to the model, such as training and prediction. 7 | 8 | 9 | Trainer 10 | ------- 11 | 12 | `unimol_tools.tasks.trainer.py `_ contains the :class:`~unimol_tools.tasks.trainer.Trainer`, which manages the training, validation, and testing phases. 13 | 14 | .. automodule:: unimol_tools.tasks.trainer 15 | :members: 16 | -------------------------------------------------------------------------------- /docs/source/train.rst: -------------------------------------------------------------------------------- 1 | .. _train: 2 | 3 | Interface 4 | ======================= 5 | 6 | 7 | Train 8 | ----- 9 | 10 | `unimol_tools.train.py `_ trains a Uni-Mol model. 11 | 12 | .. automodule:: unimol_tools.train 13 | :members: 14 | 15 | 16 | Predict 17 | ------------ 18 | 19 | `unimol_tools.predict.py `_ makes predictions with a trained Uni-Mol model. 20 | 21 | .. automodule:: unimol_tools.predict 22 | :members: 23 | 24 | 25 | Uni-Mol representation 26 | ------------------------ 27 | 28 | `unimol_tools.predictor.py `_ gets the Uni-Mol representation. 29 | 30 | .. automodule:: unimol_tools.predictor 31 | :members: 32 | -------------------------------------------------------------------------------- /docs/source/utils.rst: -------------------------------------------------------------------------------- 1 | .. _utils: 2 | 3 | Utils 4 | ======= 5 | 6 | `unimol_tools.utils `_ contains the utilities related to the model, such as metrics and logging. 7 | 8 | 9 | Metrics 10 | ------- 11 | 12 | `unimol_tools.utils.metrics `_ contains the metrics included in the model. 13 | 14 | .. automodule:: unimol_tools.utils.metrics 15 | :members: 16 | 17 | Logger 18 | ------- 19 | 20 | `unimol_tools.utils.base_logger.py `_ controls the logger. 21 | 22 | .. automodule:: unimol_tools.utils.base_logger 23 | :members: 24 | 25 | Config 26 | ------- 27 | 28 | `unimol_tools.utils.config_handler.py `_ manages the config input file. 29 | 30 | ..
automodule:: unimol_tools.utils.config_handler 31 | :members: 32 | 33 | Padding 34 | ------- 35 | 36 | `unimol_tools.utils.util.py `_ contains some padding methods. 37 | 38 | .. automodule:: unimol_tools.utils.util 39 | :members: -------------------------------------------------------------------------------- /docs/source/weight.rst: -------------------------------------------------------------------------------- 1 | .. _weights: 2 | 3 | Weights 4 | ======= 5 | 6 | We recommend installing ``huggingface_hub`` so that the required Uni-Mol models can be automatically downloaded at runtime! It can be installed by: 7 | 8 | .. code-block:: bash 9 | 10 | pip install huggingface_hub 11 | 12 | ``huggingface_hub`` allows you to easily download and manage models from the Hugging Face Hub, which is key for using Uni-Mol models. 13 | 14 | Models in Huggingface 15 | --------------------- 16 | 17 | The Uni-Mol pretrained models can be found at `dptech/Uni-Mol-Models `_. 18 | 19 | If the download is slow, you can use other mirrors, such as: 20 | 21 | .. code-block:: bash 22 | 23 | export HF_ENDPOINT=https://hf-mirror.com 24 | 25 | Setting the ``HF_ENDPOINT`` environment variable specifies the mirror address for the Hugging Face Hub to use when downloading models. 26 | 27 | `unimol_tools.weights.weighthub.py `_ manages downloading and loading the pretrained model weights. 28 | 29 | .. automodule:: unimol_tools.weights.weighthub 30 | :members: -------------------------------------------------------------------------------- /unimol/docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM dptechnology/unicore:0.0.1-pytorch1.11.0-cuda11.3 2 | 3 | RUN pip install setuptools wheel twine 4 | 5 | RUN pip install rdkit-pypi==2021.9.5.1 6 | 7 | RUN ldconfig && \ 8 | apt-get clean && \ 9 | apt-get autoremove && \ 10 | rm -rf /var/lib/apt/lists/* /tmp/* && \ 11 | pip cache purge 12 | -------------------------------------------------------------------------------- /unimol/example_data/molecule/dict.txt: -------------------------------------------------------------------------------- 1 | [PAD] 2 | [CLS] 3 | [SEP] 4 | [UNK] 5 | C 6 | N 7 | O 8 | S 9 | H 10 | Cl 11 | F 12 | Br 13 | I 14 | Si 15 | P 16 | B 17 | Na 18 | K 19 | Al 20 | Ca 21 | Sn 22 | As 23 | Hg 24 | Fe 25 | Zn 26 | Cr 27 | Se 28 | Gd 29 | Au 30 | Li -------------------------------------------------------------------------------- /unimol/example_data/molecule/train.lmdb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/Uni-Mol/90f52c41299a1a582da0f9765e9f87aa21faa16a/unimol/example_data/molecule/train.lmdb -------------------------------------------------------------------------------- /unimol/example_data/molecule/valid.lmdb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/Uni-Mol/90f52c41299a1a582da0f9765e9f87aa21faa16a/unimol/example_data/molecule/valid.lmdb -------------------------------------------------------------------------------- /unimol/example_data/pocket/dict_coarse.txt: -------------------------------------------------------------------------------- 1 | [PAD] 2 | [CLS] 3 | [SEP] 4 | [UNK] 5 | C 6 | N 7 | O 8 | S 9 | H 10 | -------------------------------------------------------------------------------- /unimol/example_data/pocket/train.lmdb: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepmodeling/Uni-Mol/90f52c41299a1a582da0f9765e9f87aa21faa16a/unimol/example_data/pocket/train.lmdb -------------------------------------------------------------------------------- /unimol/example_data/pocket/valid.lmdb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/Uni-Mol/90f52c41299a1a582da0f9765e9f87aa21faa16a/unimol/example_data/pocket/valid.lmdb -------------------------------------------------------------------------------- /unimol/figure/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/Uni-Mol/90f52c41299a1a582da0f9765e9f87aa21faa16a/unimol/figure/overview.png -------------------------------------------------------------------------------- /unimol/notebooks/mol_property_demo.csv: -------------------------------------------------------------------------------- 1 | mol,Class 2 | O=C1N(C)C(=N[C@@]1(c1cc(ccc1)-c1cccnc1)c1ccncc1)N,1 3 | s1cc(cc1)[C@@]1(N=C(N)N(C)C1=O)c1cc(ccc1)-c1cccnc1,1 4 | O=C(NC1CCCCC1)CCc1cc2cc(ccc2nc1N)-c1ccccc1C,1 5 | S(=O)(=O)(C(CCC)CCC)C[C@@H](NC(OCc1ccccc1)=O)C(=O)N[C@H]([C@H](O)C[NH2+]Cc1cc(OC)ccc1)Cc1ccccc1,0 6 | S1(=O)(=O)C[C@@H](Cc2cc(F)c3NCC4(CCC(F)(F)CC4)c3c2)[C@H](O)[C@@H]([NH2+]Cc2cc(ccc2)C(C)(C)C)C1,0 7 | O=C(N1CC[C@H](C[C@H]1c1ccccc1)c1ccccc1)[C@@H]1C[NH2+]C[C@]12CCCc1c2cccc1,0 8 | S1(=O)(=O)C[C@@H](Cc2cc(C[C@@H]3N(CCC)C(OC3)=O)c(O)cc2)[C@H](O)[C@@H]([NH2+]Cc2cc(ccc2)C(C)C)C1,0 9 | O(C)c1ccc(cc1C)[C@@]1(N=C(N)N(C)C1=O)C12CC3CC(C1)CC(C2)C3,0 10 | Clc1cc2CC([NH+]=C(N[C@@H](Cc3ccccc3)C=3NC(=O)c4c(N=3)ccnc4)c2cc1)(C)C,0 11 | O=C1N(CCC1)c1cc(cc(NCC)c1)C(=O)N[C@H]([C@H](O)C[NH2+]C1CCCCC1)Cc1ccccc1,0 12 | Fc1ccc(NC(=O)c2ncc(OCC)cc2)cc1[C@]1(N=C(OCC1(F)F)N)C,0 13 | O(C)c1cc(ccc1)C[NH2+]C[C@@H](O)[C@@H](NC(=O)c1cc(ccc1)C(=O)N(CCC)CCC)Cc1ccccc1,0 14 | FC(F)(F)c1cc(ccc1)C[NH2+]C[C@@H](O)[C@@H](NC(=O)C=1C=C(N2CCCC2=O)C(=O)N(C=1)C1CCCC1)Cc1ccccc1,1 15 | Fc1ccc(NC(=O)c2ncc(cc2)C#N)cc1[C@]1(N=C(OC[C@@H]1F)N)CF,1 16 | FC1(F)CN2C(=NC1)[C@]([NH+]=C2N)(c1cc(ccc1)C#CCOC)c1ccc(OC(F)F)cc1,1 17 | [NH+]=1[C@](N=C(C)C=1N)(C1CC1)c1cc(ccc1)-c1cc(cnc1)C#CC,1 18 | Fc1ccc(cc1-c1cncnc1)[C@]1([NH+]=C(N)c2c1cccc2F)c1cc(ncc1)C(F)(F)F,0 19 | O=C1N(C)C(=N[C@@]1(c1cc(ccc1)-c1cncnc1)c1cn(nc1)CC)N,1 20 | O(C)c1cc(ccc1)-c1cc(ccc1)CC[C@]1(N=C(N)N(C)C(=O)C1)C,0 21 | O1c2c(cc(cc2)-c2cc(ccc2)C#N)[C@]2(N=C(N)N(C)C2=O)CC1(C)C,1 -------------------------------------------------------------------------------- /unimol/requirements.txt: -------------------------------------------------------------------------------- 1 | git+git://github.com/dptech-corp/Uni-Core.git@stable#egg=Uni-Core 2 | -------------------------------------------------------------------------------- /unimol/setup.py: -------------------------------------------------------------------------------- 1 | """Install script for setuptools.""" 2 | 3 | from setuptools import find_packages 4 | from setuptools import setup 5 | 6 | setup( 7 | name="unimol", 8 | version="1.0.0", 9 | description="", 10 | author="DP Technology", 11 | author_email="unimol@dp.tech", 12 | license="The MIT License", 13 | url="https://github.com/deepmodeling/Uni-Mol", 14 | packages=find_packages( 15 | exclude=["scripts", "tests", "example_data", "docker", "figure"] 16 | ), 17 | install_requires=[ 18 | "numpy", 19 | "pandas", 20 | "scikit-learn-extra", 21 | ], 22 | classifiers=[ 23 | "Development Status :: 5 - Production/Stable", 24 | "Intended Audience :: 
Science/Research", 25 | "License :: OSI Approved :: Apache Software License", 26 | "Operating System :: POSIX :: Linux", 27 | "Programming Language :: Python :: 3.7", 28 | "Programming Language :: Python :: 3.8", 29 | "Programming Language :: Python :: 3.9", 30 | "Programming Language :: Python :: 3.10", 31 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 32 | ], 33 | ) 34 | -------------------------------------------------------------------------------- /unimol/unimol/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import unimol.tasks 3 | import unimol.data 4 | import unimol.models 5 | import unimol.losses 6 | import unimol.utils 7 | -------------------------------------------------------------------------------- /unimol/unimol/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .key_dataset import KeyDataset 2 | from .normalize_dataset import ( 3 | NormalizeDataset, 4 | NormalizeDockingPoseDataset, 5 | ) 6 | from .remove_hydrogen_dataset import ( 7 | RemoveHydrogenDataset, 8 | RemoveHydrogenResiduePocketDataset, 9 | RemoveHydrogenPocketDataset, 10 | ) 11 | from .tta_dataset import ( 12 | TTADataset, 13 | TTADockingPoseDataset, 14 | ) 15 | from .cropping_dataset import ( 16 | CroppingDataset, 17 | CroppingPocketDataset, 18 | CroppingResiduePocketDataset, 19 | CroppingPocketDockingPoseDataset, 20 | ) 21 | from .atom_type_dataset import AtomTypeDataset 22 | from .add_2d_conformer_dataset import Add2DConformerDataset 23 | from .distance_dataset import ( 24 | DistanceDataset, 25 | EdgeTypeDataset, 26 | CrossDistanceDataset, 27 | ) 28 | from .conformer_sample_dataset import ( 29 | ConformerSampleDataset, 30 | ConformerSamplePocketDataset, 31 | ConformerSamplePocketFinetuneDataset, 32 | ConformerSampleConfGDataset, 33 | ConformerSampleConfGV2Dataset, 34 | ConformerSampleDockingPoseDataset, 35 | ) 36 | from .mask_points_dataset import MaskPointsDataset, MaskPointsPocketDataset 37 | from .coord_pad_dataset import RightPadDatasetCoord, RightPadDatasetCross2D 38 | from .from_str_dataset import FromStrLabelDataset 39 | from .lmdb_dataset import LMDBDataset 40 | from .prepend_and_append_2d_dataset import PrependAndAppend2DDataset 41 | 42 | __all__ = [] -------------------------------------------------------------------------------- /unimol/unimol/data/add_2d_conformer_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 
4 | 5 | import numpy as np 6 | from functools import lru_cache 7 | from unicore.data import BaseWrapperDataset 8 | from rdkit import Chem 9 | from rdkit.Chem import AllChem 10 | 11 | 12 | class Add2DConformerDataset(BaseWrapperDataset): 13 | def __init__(self, dataset, smi, atoms, coordinates): 14 | self.dataset = dataset 15 | self.smi = smi 16 | self.atoms = atoms 17 | self.coordinates = coordinates 18 | self.set_epoch(None) 19 | 20 | def set_epoch(self, epoch, **unused): 21 | super().set_epoch(epoch) 22 | self.epoch = epoch 23 | 24 | @lru_cache(maxsize=16) 25 | def __cached_item__(self, index: int, epoch: int): 26 | atoms = np.array(self.dataset[index][self.atoms]) 27 | assert len(atoms) > 0 28 | smi = self.dataset[index][self.smi] 29 | coordinates_2d = smi2_2Dcoords(smi) 30 | coordinates = self.dataset[index][self.coordinates] 31 | coordinates.append(coordinates_2d) 32 | return {"smi": smi, "atoms": atoms, "coordinates": coordinates} 33 | 34 | def __getitem__(self, index: int): 35 | return self.__cached_item__(index, self.epoch) 36 | 37 | 38 | def smi2_2Dcoords(smi): 39 | mol = Chem.MolFromSmiles(smi) 40 | mol = AllChem.AddHs(mol) 41 | AllChem.Compute2DCoords(mol) 42 | coordinates = mol.GetConformer().GetPositions().astype(np.float32) 43 | assert len(mol.GetAtoms()) == len( 44 | coordinates 45 | ), "2D coordinates shape is not aligned with {}".format(smi) 46 | return coordinates 47 | -------------------------------------------------------------------------------- /unimol/unimol/data/atom_type_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | from functools import lru_cache 6 | from unicore.data import BaseWrapperDataset 7 | 8 | 9 | class AtomTypeDataset(BaseWrapperDataset): 10 | def __init__( 11 | self, 12 | raw_dataset, 13 | dataset, 14 | smi="smi", 15 | atoms="atoms", 16 | ): 17 | self.raw_dataset = raw_dataset 18 | self.dataset = dataset 19 | self.smi = smi 20 | self.atoms = atoms 21 | 22 | @lru_cache(maxsize=16) 23 | def __getitem__(self, index: int): 24 | # for compatibility with older rdkit versions 25 | if len(self.dataset[index]["atoms"]) != len(self.dataset[index]["coordinates"]): 26 | min_len = min( 27 | len(self.dataset[index]["atoms"]), 28 | len(self.dataset[index]["coordinates"]), 29 | ) 30 | self.dataset[index]["atoms"] = self.dataset[index]["atoms"][:min_len] 31 | self.dataset[index]["coordinates"] = self.dataset[index]["coordinates"][ 32 | :min_len 33 | ] 34 | return self.dataset[index] 35 | -------------------------------------------------------------------------------- /unimol/unimol/data/coord_pad_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree.
4 | 5 | from unicore.data import BaseWrapperDataset 6 | 7 | 8 | def collate_tokens_coords( 9 | values, 10 | pad_idx, 11 | left_pad=False, 12 | pad_to_length=None, 13 | pad_to_multiple=1, 14 | ): 15 | """Convert a list of 1d tensors into a padded 2d tensor.""" 16 | size = max(v.size(0) for v in values) 17 | size = size if pad_to_length is None else max(size, pad_to_length) 18 | if pad_to_multiple != 1 and size % pad_to_multiple != 0: 19 | size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple) 20 | res = values[0].new(len(values), size, 3).fill_(pad_idx) 21 | 22 | def copy_tensor(src, dst): 23 | assert dst.numel() == src.numel() 24 | dst.copy_(src) 25 | 26 | for i, v in enumerate(values): 27 | copy_tensor(v, res[i][size - len(v) :, :] if left_pad else res[i][: len(v), :]) 28 | return res 29 | 30 | 31 | class RightPadDatasetCoord(BaseWrapperDataset): 32 | def __init__(self, dataset, pad_idx, left_pad=False): 33 | super().__init__(dataset) 34 | self.pad_idx = pad_idx 35 | self.left_pad = left_pad 36 | 37 | def collater(self, samples): 38 | return collate_tokens_coords( 39 | samples, self.pad_idx, left_pad=self.left_pad, pad_to_multiple=8 40 | ) 41 | 42 | 43 | def collate_cross_2d( 44 | values, 45 | pad_idx, 46 | left_pad=False, 47 | pad_to_length=None, 48 | pad_to_multiple=1, 49 | ): 50 | """Convert a list of 2d tensors into a padded 2d tensor.""" 51 | size_h = max(v.size(0) for v in values) 52 | size_w = max(v.size(1) for v in values) 53 | if pad_to_multiple != 1 and size_h % pad_to_multiple != 0: 54 | size_h = int(((size_h - 0.1) // pad_to_multiple + 1) * pad_to_multiple) 55 | if pad_to_multiple != 1 and size_w % pad_to_multiple != 0: 56 | size_w = int(((size_w - 0.1) // pad_to_multiple + 1) * pad_to_multiple) 57 | res = values[0].new(len(values), size_h, size_w).fill_(pad_idx) 58 | 59 | def copy_tensor(src, dst): 60 | assert dst.numel() == src.numel() 61 | dst.copy_(src) 62 | 63 | for i, v in enumerate(values): 64 | copy_tensor( 65 | v, 66 | res[i][size_h - v.size(0) :, size_w - v.size(1) :] 67 | if left_pad 68 | else res[i][: v.size(0), : v.size(1)], 69 | ) 70 | return res 71 | 72 | 73 | class RightPadDatasetCross2D(BaseWrapperDataset): 74 | def __init__(self, dataset, pad_idx, left_pad=False): 75 | super().__init__(dataset) 76 | self.pad_idx = pad_idx 77 | self.left_pad = left_pad 78 | 79 | def collater(self, samples): 80 | return collate_cross_2d( 81 | samples, self.pad_idx, left_pad=self.left_pad, pad_to_multiple=8 82 | ) 83 | -------------------------------------------------------------------------------- /unimol/unimol/data/data_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 
4 | 5 | import numpy as np 6 | import contextlib 7 | 8 | 9 | @contextlib.contextmanager 10 | def numpy_seed(seed, *addl_seeds): 11 | """Context manager which seeds the NumPy PRNG with the specified seed and 12 | restores the state afterward""" 13 | if seed is None: 14 | yield 15 | return 16 | if len(addl_seeds) > 0: 17 | seed = int(hash((seed, *addl_seeds)) % 1e6) 18 | state = np.random.get_state() 19 | np.random.seed(seed) 20 | try: 21 | yield 22 | finally: 23 | np.random.set_state(state) 24 | -------------------------------------------------------------------------------- /unimol/unimol/data/distance_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | import numpy as np 6 | import torch 7 | from scipy.spatial import distance_matrix 8 | from functools import lru_cache 9 | from unicore.data import BaseWrapperDataset 10 | 11 | 12 | class DistanceDataset(BaseWrapperDataset): 13 | def __init__(self, dataset): 14 | super().__init__(dataset) 15 | self.dataset = dataset 16 | 17 | @lru_cache(maxsize=16) 18 | def __getitem__(self, idx): 19 | pos = self.dataset[idx].view(-1, 3).numpy() 20 | dist = distance_matrix(pos, pos).astype(np.float32) 21 | return torch.from_numpy(dist) 22 | 23 | 24 | class EdgeTypeDataset(BaseWrapperDataset): 25 | def __init__(self, dataset: torch.utils.data.Dataset, num_types: int): 26 | self.dataset = dataset 27 | self.num_types = num_types 28 | 29 | @lru_cache(maxsize=16) 30 | def __getitem__(self, index: int): 31 | node_input = self.dataset[index].clone() 32 | offset = node_input.view(-1, 1) * self.num_types + node_input.view(1, -1) 33 | return offset 34 | 35 | 36 | class CrossDistanceDataset(BaseWrapperDataset): 37 | def __init__(self, mol_dataset, pocket_dataset): 38 | super().__init__(mol_dataset) 39 | self.mol_dataset = mol_dataset 40 | self.pocket_dataset = pocket_dataset 41 | 42 | @lru_cache(maxsize=16) 43 | def __getitem__(self, idx): 44 | mol_pos = self.mol_dataset[idx].view(-1, 3).numpy() 45 | pocket_pos = self.pocket_dataset[idx].view(-1, 3).numpy() 46 | dist = distance_matrix(mol_pos, pocket_pos).astype(np.float32) 47 | return torch.from_numpy(dist) 48 | -------------------------------------------------------------------------------- /unimol/unimol/data/from_str_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from functools import lru_cache 3 | from unicore.data import UnicoreDataset 4 | 5 | 6 | class FromStrLabelDataset(UnicoreDataset): 7 | def __init__(self, labels): 8 | super().__init__() 9 | self.labels = labels 10 | 11 | @lru_cache(maxsize=16) 12 | def __getitem__(self, index): 13 | return self.labels[index] 14 | 15 | def __len__(self): 16 | return len(self.labels) 17 | 18 | def collater(self, samples): 19 | return torch.tensor(list(map(float, samples))) 20 | -------------------------------------------------------------------------------- /unimol/unimol/data/key_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 
4 | 5 | from functools import lru_cache 6 | from unicore.data import BaseWrapperDataset 7 | 8 | 9 | class KeyDataset(BaseWrapperDataset): 10 | def __init__(self, dataset, key): 11 | self.dataset = dataset 12 | self.key = key 13 | 14 | def __len__(self): 15 | return len(self.dataset) 16 | 17 | @lru_cache(maxsize=16) 18 | def __getitem__(self, idx): 19 | return self.dataset[idx][self.key] 20 | -------------------------------------------------------------------------------- /unimol/unimol/data/lmdb_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | 6 | import lmdb 7 | import os 8 | import pickle 9 | from functools import lru_cache 10 | import logging 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | class LMDBDataset: 16 | def __init__(self, db_path): 17 | self.db_path = db_path 18 | assert os.path.isfile(self.db_path), "{} not found".format(self.db_path) 19 | env = self.connect_db(self.db_path) 20 | with env.begin() as txn: 21 | self._keys = list(txn.cursor().iternext(values=False)) 22 | 23 | def connect_db(self, lmdb_path, save_to_self=False): 24 | env = lmdb.open( 25 | lmdb_path, 26 | subdir=False, 27 | readonly=True, 28 | lock=False, 29 | readahead=False, 30 | meminit=False, 31 | max_readers=256, 32 | ) 33 | if not save_to_self: 34 | return env 35 | else: 36 | self.env = env 37 | 38 | def __len__(self): 39 | return len(self._keys) 40 | 41 | @lru_cache(maxsize=16) 42 | def __getitem__(self, idx): 43 | if not hasattr(self, "env"): 44 | self.connect_db(self.db_path, save_to_self=True) 45 | datapoint_pickled = self.env.begin().get(f"{idx}".encode("ascii")) 46 | data = pickle.loads(datapoint_pickled) 47 | return data 48 | -------------------------------------------------------------------------------- /unimol/unimol/data/normalize_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | import numpy as np 6 | from functools import lru_cache 7 | from unicore.data import BaseWrapperDataset 8 | 9 | 10 | class NormalizeDataset(BaseWrapperDataset): 11 | def __init__(self, dataset, coordinates, normalize_coord=True): 12 | self.dataset = dataset 13 | self.coordinates = coordinates 14 | self.normalize_coord = normalize_coord # normalize the coordinates. 
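# When normalize_coord is True, __cached_item__ subtracts the per-sample mean
# so the coordinates are centered at the origin before further processing.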
15 | self.set_epoch(None) 16 | 17 | def set_epoch(self, epoch, **unused): 18 | super().set_epoch(epoch) 19 | self.epoch = epoch 20 | 21 | @lru_cache(maxsize=16) 22 | def __cached_item__(self, index: int, epoch: int): 23 | dd = self.dataset[index].copy() 24 | coordinates = dd[self.coordinates] 25 | # normalize 26 | if self.normalize_coord: 27 | coordinates = coordinates - coordinates.mean(axis=0) 28 | dd[self.coordinates] = coordinates.astype(np.float32) 29 | return dd 30 | 31 | def __getitem__(self, index: int): 32 | return self.__cached_item__(index, self.epoch) 33 | 34 | 35 | class NormalizeDockingPoseDataset(BaseWrapperDataset): 36 | def __init__( 37 | self, 38 | dataset, 39 | coordinates, 40 | pocket_coordinates, 41 | center_coordinates="center_coordinates", 42 | ): 43 | self.dataset = dataset 44 | self.coordinates = coordinates 45 | self.pocket_coordinates = pocket_coordinates 46 | self.center_coordinates = center_coordinates 47 | self.set_epoch(None) 48 | 49 | def set_epoch(self, epoch, **unused): 50 | super().set_epoch(epoch) 51 | self.epoch = epoch 52 | 53 | @lru_cache(maxsize=16) 54 | def __cached_item__(self, index: int, epoch: int): 55 | dd = self.dataset[index].copy() 56 | coordinates = dd[self.coordinates] 57 | pocket_coordinates = dd[self.pocket_coordinates] 58 | # normalize coordinates and pocket coordinates ,align with pocket center coordinates 59 | center_coordinates = pocket_coordinates.mean(axis=0) 60 | coordinates = coordinates - center_coordinates 61 | pocket_coordinates = pocket_coordinates - center_coordinates 62 | dd[self.coordinates] = coordinates.astype(np.float32) 63 | dd[self.pocket_coordinates] = pocket_coordinates.astype(np.float32) 64 | dd[self.center_coordinates] = center_coordinates.astype(np.float32) 65 | return dd 66 | 67 | def __getitem__(self, index: int): 68 | return self.__cached_item__(index, self.epoch) 69 | -------------------------------------------------------------------------------- /unimol/unimol/data/prepend_and_append_2d_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | import torch 6 | from functools import lru_cache 7 | from unicore.data import BaseWrapperDataset 8 | 9 | 10 | class PrependAndAppend2DDataset(BaseWrapperDataset): 11 | def __init__(self, dataset, token=None): 12 | super().__init__(dataset) 13 | self.token = token 14 | 15 | @lru_cache(maxsize=16) 16 | def __getitem__(self, idx): 17 | item = self.dataset[idx] 18 | if self.token is not None: 19 | h, w = item.size(-2), item.size(-1) 20 | new_item = torch.full((h + 2, w + 2), self.token).type_as(item) 21 | new_item[1:-1, 1:-1] = item 22 | return new_item 23 | return item 24 | -------------------------------------------------------------------------------- /unimol/unimol/data/tta_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 
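"""Test-time augmentation (TTA) wrappers.

Sketch of the index layout (illustrative): every source entry is expanded into
``conf_size`` consecutive items, and an expanded index maps back as

    smi_idx   = index // conf_size   # which molecule
    coord_idx = index %  conf_size   # which sampled conformer

so downstream code can aggregate predictions over the ``conf_size`` copies.
"""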
4 | 5 | import numpy as np 6 | from functools import lru_cache 7 | from unicore.data import BaseWrapperDataset 8 | 9 | 10 | class TTADataset(BaseWrapperDataset): 11 | def __init__(self, dataset, seed, atoms, coordinates, conf_size=10): 12 | self.dataset = dataset 13 | self.seed = seed 14 | self.atoms = atoms 15 | self.coordinates = coordinates 16 | self.conf_size = conf_size 17 | self.set_epoch(None) 18 | 19 | def set_epoch(self, epoch, **unused): 20 | super().set_epoch(epoch) 21 | self.epoch = epoch 22 | 23 | def __len__(self): 24 | return len(self.dataset) * self.conf_size 25 | 26 | @lru_cache(maxsize=16) 27 | def __cached_item__(self, index: int, epoch: int): 28 | smi_idx = index // self.conf_size 29 | coord_idx = index % self.conf_size 30 | atoms = np.array(self.dataset[smi_idx][self.atoms]) 31 | coordinates = np.array(self.dataset[smi_idx][self.coordinates][coord_idx]) 32 | smi = self.dataset[smi_idx]["smi"] 33 | target = self.dataset[smi_idx].get("target", None) 34 | return { 35 | "atoms": atoms, 36 | "coordinates": coordinates.astype(np.float32), 37 | "smi": smi, 38 | "target": target, 39 | } 40 | 41 | def __getitem__(self, index: int): 42 | return self.__cached_item__(index, self.epoch) 43 | 44 | 45 | class TTADockingPoseDataset(BaseWrapperDataset): 46 | def __init__( 47 | self, 48 | dataset, 49 | atoms, 50 | coordinates, 51 | pocket_atoms, 52 | pocket_coordinates, 53 | holo_coordinates, 54 | holo_pocket_coordinates, 55 | is_train=True, 56 | conf_size=10, 57 | ): 58 | self.dataset = dataset 59 | self.atoms = atoms 60 | self.coordinates = coordinates 61 | self.pocket_atoms = pocket_atoms 62 | self.pocket_coordinates = pocket_coordinates 63 | self.holo_coordinates = holo_coordinates 64 | self.holo_pocket_coordinates = holo_pocket_coordinates 65 | self.is_train = is_train 66 | self.conf_size = conf_size 67 | self.set_epoch(None) 68 | 69 | def set_epoch(self, epoch, **unused): 70 | super().set_epoch(epoch) 71 | self.epoch = epoch 72 | 73 | def __len__(self): 74 | return len(self.dataset) * self.conf_size 75 | 76 | @lru_cache(maxsize=16) 77 | def __cached_item__(self, index: int, epoch: int): 78 | smi_idx = index // self.conf_size 79 | coord_idx = index % self.conf_size 80 | atoms = np.array(self.dataset[smi_idx][self.atoms]) 81 | coordinates = np.array(self.dataset[smi_idx][self.coordinates][coord_idx]) 82 | pocket_atoms = np.array( 83 | [item[0] for item in self.dataset[smi_idx][self.pocket_atoms]] 84 | ) 85 | pocket_coordinates = np.array(self.dataset[smi_idx][self.pocket_coordinates][0]) 86 | if self.is_train: 87 | holo_coordinates = np.array(self.dataset[smi_idx][self.holo_coordinates][0]) 88 | holo_pocket_coordinates = np.array( 89 | self.dataset[smi_idx][self.holo_pocket_coordinates][0] 90 | ) 91 | else: 92 | holo_coordinates = coordinates 93 | holo_pocket_coordinates = pocket_coordinates 94 | 95 | smi = self.dataset[smi_idx]["smi"] 96 | pocket = self.dataset[smi_idx]["pocket"] 97 | 98 | return { 99 | "atoms": atoms, 100 | "coordinates": coordinates.astype(np.float32), 101 | "pocket_atoms": pocket_atoms, 102 | "pocket_coordinates": pocket_coordinates.astype(np.float32), 103 | "holo_coordinates": holo_coordinates.astype(np.float32), 104 | "holo_pocket_coordinates": holo_pocket_coordinates.astype(np.float32), 105 | "smi": smi, 106 | "pocket": pocket, 107 | } 108 | 109 | def __getitem__(self, index: int): 110 | return self.__cached_item__(index, self.epoch) 111 | -------------------------------------------------------------------------------- /unimol/unimol/infer.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # Copyright (c) DP Techonology, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | import os 9 | import sys 10 | import pickle 11 | import torch 12 | from unicore import checkpoint_utils, distributed_utils, options, utils 13 | from unicore.logging import progress_bar 14 | from unicore import tasks 15 | 16 | logging.basicConfig( 17 | format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 18 | datefmt="%Y-%m-%d %H:%M:%S", 19 | level=os.environ.get("LOGLEVEL", "INFO").upper(), 20 | stream=sys.stdout, 21 | ) 22 | logger = logging.getLogger("unimol.inference") 23 | 24 | 25 | def main(args): 26 | 27 | assert ( 28 | args.batch_size is not None 29 | ), "Must specify batch size either with --batch-size" 30 | 31 | use_fp16 = args.fp16 32 | use_cuda = torch.cuda.is_available() and not args.cpu 33 | 34 | if use_cuda: 35 | torch.cuda.set_device(args.device_id) 36 | 37 | if args.distributed_world_size > 1: 38 | data_parallel_world_size = distributed_utils.get_data_parallel_world_size() 39 | data_parallel_rank = distributed_utils.get_data_parallel_rank() 40 | else: 41 | data_parallel_world_size = 1 42 | data_parallel_rank = 0 43 | 44 | # Load model 45 | logger.info("loading model(s) from {}".format(args.path)) 46 | state = checkpoint_utils.load_checkpoint_to_cpu(args.path) 47 | task = tasks.setup_task(args) 48 | model = task.build_model(args) 49 | model.load_state_dict(state["model"], strict=False) 50 | 51 | # Move models to GPU 52 | if use_cuda: 53 | model.cuda() 54 | # fp16 only supported on CUDA for fused kernels 55 | if use_fp16: 56 | model.half() 57 | 58 | # Print args 59 | logger.info(args) 60 | 61 | # Build loss 62 | loss = task.build_loss(args) 63 | loss.eval() 64 | 65 | for subset in args.valid_subset.split(","): 66 | try: 67 | task.load_dataset(subset, combine=False, epoch=1) 68 | dataset = task.dataset(subset) 69 | except KeyError: 70 | raise Exception("Cannot find dataset: " + subset) 71 | 72 | if not os.path.exists(args.results_path): 73 | os.makedirs(args.results_path) 74 | try: 75 | fname = (args.path).split("/")[-2] 76 | except: 77 | fname = 'infer' 78 | save_path = os.path.join(args.results_path, fname + "_" + subset + ".out.pkl") 79 | # Initialize data iterator 80 | itr = task.get_batch_iterator( 81 | dataset=dataset, 82 | batch_size=args.batch_size, 83 | ignore_invalid_inputs=True, 84 | required_batch_size_multiple=args.required_batch_size_multiple, 85 | seed=args.seed, 86 | num_shards=data_parallel_world_size, 87 | shard_id=data_parallel_rank, 88 | num_workers=args.num_workers, 89 | data_buffer_size=args.data_buffer_size, 90 | ).next_epoch_itr(shuffle=False) 91 | progress = progress_bar.progress_bar( 92 | itr, 93 | log_format=args.log_format, 94 | log_interval=args.log_interval, 95 | prefix=f"valid on '{subset}' subset", 96 | default_log_format=("tqdm" if not args.no_progress_bar else "simple"), 97 | ) 98 | log_outputs = [] 99 | for i, sample in enumerate(progress): 100 | sample = utils.move_to_cuda(sample) if use_cuda else sample 101 | if len(sample) == 0: 102 | continue 103 | _, _, log_output = task.valid_step(sample, model, loss, test=True) 104 | progress.log({}, step=i) 105 | log_outputs.append(log_output) 106 | pickle.dump(log_outputs, open(save_path, "wb")) 107 | logger.info("Done inference! 
") 108 | return None 109 | 110 | 111 | def cli_main(): 112 | parser = options.get_validation_parser() 113 | options.add_model_args(parser) 114 | args = options.parse_args_and_arch(parser) 115 | 116 | distributed_utils.call_main(args, main) 117 | 118 | 119 | if __name__ == "__main__": 120 | cli_main() 121 | -------------------------------------------------------------------------------- /unimol/unimol/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import importlib 3 | 4 | # automatically import any Python files in the criterions/ directory 5 | for file in sorted(Path(__file__).parent.glob("*.py")): 6 | if not file.name.startswith("_"): 7 | importlib.import_module("unimol.losses." + file.name[:-3]) 8 | -------------------------------------------------------------------------------- /unimol/unimol/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .unimol import UniMolModel 2 | from .transformer_encoder_with_pair import TransformerEncoderWithPair 3 | from .conf_gen import UnimolConfGModel 4 | from .docking_pose import DockingPoseModel -------------------------------------------------------------------------------- /unimol/unimol/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import importlib 3 | 4 | # automatically import any Python files in the criterions/ directory 5 | for file in sorted(Path(__file__).parent.glob("*.py")): 6 | if not file.name.startswith("_"): 7 | importlib.import_module("unimol.tasks." + file.name[:-3]) 8 | -------------------------------------------------------------------------------- /unimol/unimol/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/Uni-Mol/90f52c41299a1a582da0f9765e9f87aa21faa16a/unimol/unimol/utils/__init__.py -------------------------------------------------------------------------------- /unimol/unimol/utils/docking.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Techonology, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import os 7 | import numpy as np 8 | import pandas as pd 9 | from multiprocessing import Pool 10 | from tqdm import tqdm 11 | import glob 12 | import argparse 13 | from docking_utils import ( 14 | docking_data_pre, 15 | ensemble_iterations, 16 | print_results, 17 | rmsd_func, 18 | ) 19 | import warnings 20 | 21 | warnings.filterwarnings(action="ignore") 22 | 23 | 24 | def result_log(dir_path): 25 | ### result logging ### 26 | output_dir = os.path.join(dir_path, "cache") 27 | rmsd_results = [] 28 | for path in glob.glob(os.path.join(output_dir, "*.docking.pkl")): 29 | ( 30 | mol, 31 | bst_predict_coords, 32 | holo_coords, 33 | bst_loss, 34 | smi, 35 | pocket, 36 | pocket_coords, 37 | ) = pd.read_pickle(path) 38 | rmsd = rmsd_func(holo_coords, bst_predict_coords, mol=mol) 39 | rmsd_results.append(rmsd) 40 | rmsd_results = np.array(rmsd_results) 41 | print_results(rmsd_results) 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser(description="docking") 46 | parser.add_argument( 47 | "--reference-file", 48 | type=str, 49 | default="./protein_ligand_binding_pose_prediction/test.lmdb", 50 | help="Location of the reference set", 51 | ) 52 | parser.add_argument("--nthreads", type=int, default=40, help="num of threads") 53 | parser.add_argument( 54 | "--predict-file", 55 | type=str, 56 | default="./infer_pose/save_pose_test.out.pkl", 57 | help="Location of the prediction file", 58 | ) 59 | parser.add_argument( 60 | "--output-path", 61 | type=str, 62 | default="./protein_ligand_binding_pose_prediction", 63 | help="Location of the docking output path", 64 | ) 65 | parser.add_argument( 66 | "--optimization-model", 67 | type=str, 68 | default="conformer", 69 | help="Optimize coordinates ('coordinate') or ligand internal torsions ('conformer')", 70 | choices=["coordinate", "conformer"], 71 | ) 72 | args = parser.parse_args() 73 | 74 | raw_data_path, predict_path, dir_path, nthreads, model_choice = ( 75 | args.reference_file, 76 | args.predict_file, 77 | args.output_path, 78 | args.nthreads, 79 | args.optimization_model, 80 | ) 81 | tta_times = 10 82 | ( 83 | mol_list, 84 | smi_list, 85 | pocket_list, 86 | pocket_coords_list, 87 | distance_predict_list, 88 | holo_distance_predict_list, 89 | holo_coords_list, 90 | holo_center_coords_list, 91 | ) = docking_data_pre(raw_data_path, predict_path) 92 | iterations = ensemble_iterations( 93 | mol_list, 94 | smi_list, 95 | pocket_list, 96 | pocket_coords_list, 97 | distance_predict_list, 98 | holo_distance_predict_list, 99 | holo_coords_list, 100 | holo_center_coords_list, 101 | tta_times=tta_times, 102 | ) 103 | sz = len(mol_list) // tta_times 104 | new_pocket_list = pocket_list[::tta_times] 105 | output_dir = os.path.join(dir_path, "cache") 106 | os.makedirs(output_dir, exist_ok=True) 107 | 108 | def dump(content): 109 | pocket = content[3] 110 | output_name = os.path.join(output_dir, "{}.pkl".format(pocket)) 111 | try: 112 | os.remove(output_name) 113 | except: 114 | pass 115 | pd.to_pickle(content, output_name) 116 | return True 117 | 118 | # skip step if repeat 119 | with Pool(nthreads) as pool: 120 | for inner_output in tqdm(pool.imap_unordered(dump, iterations), total=sz): 121 | if not inner_output: 122 | print("fail to dump") 123 | 124 | def single_docking(pocket_name): 125 | input_name = os.path.join(output_dir, "{}.pkl".format(pocket_name)) 126 | output_name = os.path.join(output_dir, "{}.docking.pkl".format(pocket_name)) 127 | output_ligand_name = os.path.join( 128 | output_dir, "{}.ligand.sdf".format(pocket_name) 129 
| ) 130 | try: 131 | os.remove(output_name) 132 | except: 133 | pass 134 | try: 135 | os.remove(output_ligand_name) 136 | except: 137 | pass 138 | 139 | cmd = "python ./unimol/utils/{}_model.py --input {} --output {} --output-ligand {}".format( 140 | model_choice, input_name, output_name, output_ligand_name 141 | ) 142 | os.system(cmd) 143 | return True 144 | 145 | 146 | with Pool(nthreads) as pool: 147 | for inner_output in tqdm( 148 | pool.imap_unordered(single_docking, new_pocket_list), total=len(new_pocket_list) 149 | ): 150 | if not inner_output: 151 | print("fail to docking") 152 | 153 | result_log(args.output_path) 154 | -------------------------------------------------------------------------------- /unimol2/README.md: -------------------------------------------------------------------------------- 1 | 2 | Uni-Mol2: Exploring Molecular Pretraining Model at Scale 3 | ================================================================== 4 | 5 |

6 | <img src="figure/unimol2_arch.jpg" alt="overview"> 7 |

8 | 9 | We present Uni-Mol2, an innovative 10 | molecular pretraining model that leverages a two-track transformer to effectively 11 | integrate features at the atomic level, graph level, and geometry structure level. 12 | Along with this, we systematically investigate the scaling law within molecular 13 | pretraining models, characterizing the power-law correlations between validation 14 | loss and model size, dataset size, and computational resources. Consequently, 15 | we successfully scale Uni-Mol2 to 1.1 billion parameters through pretraining on 16 | 800 million conformations, making it the largest molecular pretraining model to 17 | date. 18 | 19 | 20 | Dependencies 21 | ------------ 22 | - [Uni-Core](https://github.com/dptech-corp/Uni-Core) with PyTorch > 2.0.0; see its [Installation Documentation](https://github.com/dptech-corp/Uni-Core#installation). 23 | - rdkit==2022.09.5, install via `pip install rdkit==2022.09.5` 24 | 25 | 26 | Model Zoo 27 | ------------ 28 | 29 | 30 | | Model | Layers | Embedding dim | Attention heads | Pair embedding dim | Pair hidden dim | FFN embedding dim | Learning rate | Batch size | 31 | |-----------|--------|---------------|-----------------|--------------------|-----------------|-------------------|---------------|-----------| 32 | | [**UniMol2-84M**](https://huggingface.co/dptech/Uni-Mol2/blob/main/modelzoo/84M/checkpoint.pt) | 12 | 768 | 48 | 512 | 64 | 768 | 1e-4 | 1024 | 33 | | [**UniMol2-164M**](https://huggingface.co/dptech/Uni-Mol2/blob/main/modelzoo/164M/checkpoint.pt) | 24 | 768 | 48 | 512 | 64 | 768 | 1e-4 | 1024 | 34 | | [**UniMol2-310M**](https://huggingface.co/dptech/Uni-Mol2/blob/main/modelzoo/310M/checkpoint.pt) | 32 | 1024 | 64 | 512 | 64 | 1024 | 1e-4 | 1024 | 35 | | [**UniMol2-570M**](https://huggingface.co/dptech/Uni-Mol2/blob/main/modelzoo/570M/checkpoint.pt) | 32 | 1536 | 96 | 512 | 64 | 1536 | 1e-4 | 1024 | 36 | | [**UniMol2-1.1B**](https://huggingface.co/dptech/Uni-Mol2/blob/main/modelzoo/1.1B/checkpoint.pt) | 64 | 1536 | 96 | 512 | 64 | 1536 | 1e-4 | 1024 | 37 | 38 | 39 | Downstream Finetune 40 | ------------ 41 | 42 | ``` 43 | task_name="qm9dft_v2" # molecular property prediction task name 44 | task_num=3 45 | weight_name="checkpoint.pt" 46 | loss_func="finetune_smooth_mae" 47 | arch_name=84M 48 | arch=unimol2_$arch_name 49 | 50 | 51 | data_path="Your Data Path" 52 | weight_path="Your Checkpoint Path" 53 | weight_path=$weight_path/$weight_name 54 | 55 | drop_feat_prob=1.0 56 | use_2d_pos=0.0 57 | ema_decay=0.999 58 | 59 | lr=1e-4 60 | batch_size=32 61 | epoch=40 62 | dropout=0 63 | warmup=0.06 64 | local_batch_size=16 65 | seed=0 66 | conf_size=11 67 | 68 | n_gpu=1 69 | reg_task="--reg" 70 | metric="valid_agg_mae" 71 | save_dir="./save_dir" 72 | 73 | update_freq=`expr $batch_size / $local_batch_size` 74 | global_batch_size=`expr $local_batch_size \* $n_gpu \* $update_freq` 75 | 76 | torchrun --standalone --nnodes=1 --nproc_per_node=$n_gpu \ 77 | $(which unicore-train) $data_path \ 78 | --task-name $task_name --user-dir ./unimol2 --train-subset train --valid-subset valid,test \ 79 | --conf-size $conf_size \ 80 | --num-workers 8 --ddp-backend=c10d \ 81 | --task mol_finetune --loss $loss_func --arch $arch \ 82 | --classification-head-name $task_name --num-classes $task_num \ 83 | --optimizer adam --adam-betas "(0.9, 0.99)" --adam-eps 1e-6 --clip-norm 1.0 \ 84 | --lr-scheduler polynomial_decay --lr $lr --warmup-ratio $warmup --max-epoch $epoch \ 85 | --batch-size $local_batch_size --pooler-dropout $dropout \ 86 | --update-freq 
$update_freq --seed $seed \ 87 | --fp16 --fp16-init-scale 4 --fp16-scale-window 256 --no-save \ 88 | --log-interval 100 --log-format simple \ 89 | --validate-interval 1 \ 90 | --finetune-from-model $weight_path \ 91 | --best-checkpoint-metric $metric --patience 20 \ 92 | --save-dir $save_dir \ 93 | --drop-feat-prob ${drop_feat_prob} \ 94 | --use-2d-pos-prob ${use_2d_pos} \ 95 | $more_args \ 96 | $reg_task \ 97 | --find-unused-parameters 98 | ``` 99 | 100 | 101 | Citation 102 | ------------ 103 | 104 | Please kindly cite this paper if you use the data/code/model. 105 | ``` 106 | @article{ji2024uni, 107 | title={Uni-Mol2: Exploring Molecular Pretraining Model at Scale}, 108 | author={Xiaohong, Ji and Zhen, Wang and Zhifeng, Gao and Hang, Zheng and Linfeng, Zhang and Guolin, Ke and Weinan, E}, 109 | journal={arXiv preprint arXiv:2406.14969}, 110 | year={2024} 111 | } 112 | 113 | ``` 114 | 115 | License 116 | ------- 117 | 118 | This project is licensed under the terms of the MIT license. See [LICENSE](https://github.com/deepmodeling/Uni-Mol/blob/main/LICENSE) for additional details. 119 | -------------------------------------------------------------------------------- /unimol2/docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM dptechnology/unicore:latest-pytorch2.0.1-cuda11.7-rdma 2 | RUN pip install rdkit==2022.09.5 3 | -------------------------------------------------------------------------------- /unimol2/figure/predicted_loss.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/Uni-Mol/90f52c41299a1a582da0f9765e9f87aa21faa16a/unimol2/figure/predicted_loss.jpg -------------------------------------------------------------------------------- /unimol2/figure/unimol2_arch.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/Uni-Mol/90f52c41299a1a582da0f9765e9f87aa21faa16a/unimol2/figure/unimol2_arch.jpg -------------------------------------------------------------------------------- /unimol2/requirements.txt: -------------------------------------------------------------------------------- 1 | git+git://github.com/dptech-corp/Uni-Core.git@stable#egg=Uni-Core 2 | rdkit-pypi==2022.9.5 3 | -------------------------------------------------------------------------------- /unimol2/unimol2/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import unimol2.tasks 3 | import unimol2.data 4 | import unimol2.models 5 | import unimol2.losses -------------------------------------------------------------------------------- /unimol2/unimol2/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .key_dataset import KeyDataset 2 | from .normalize_dataset import ( 3 | NormalizeDataset, 4 | ) 5 | from .remove_hydrogen_dataset import ( 6 | RemoveHydrogenDataset, 7 | ) 8 | from .tta_dataset import ( 9 | TTADataset, 10 | ) 11 | from .cropping_dataset import CroppingDataset 12 | from .add_2d_conformer_dataset import Add2DConformerDataset 13 | from .conformer_sample_dataset import ConformerSampleDataset 14 | 15 | from .graph_features import PairTypeDataset 16 | 17 | from .molecule_dataset import MoleculeFeatureDataset 18 | from .noised_points_dataset import NoisePointsDataset, PadBiasDataset2D, AttnBiasDataset 19 | 20 | from .lmdb_dataset import LMDBDataset 21 | from .unimol2_dataset 
import Unimol2FeatureDataset, Unimol2FinetuneFeatureDataset 22 | from .index_atom_dataset import IndexAtomDataset 23 | 24 | __all__ = [] -------------------------------------------------------------------------------- /unimol2/unimol2/data/add_2d_conformer_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | import numpy as np 6 | from functools import lru_cache 7 | from unicore.data import BaseWrapperDataset 8 | from rdkit import Chem 9 | from rdkit.Chem import AllChem 10 | 11 | 12 | class Add2DConformerDataset(BaseWrapperDataset): 13 | def __init__(self, dataset, smi, atoms, coordinates): 14 | self.dataset = dataset 15 | self.smi = smi 16 | self.atoms = atoms 17 | self.coordinates = coordinates 18 | self.set_epoch(None) 19 | 20 | def set_epoch(self, epoch, **unused): 21 | super().set_epoch(epoch) 22 | self.epoch = epoch 23 | 24 | @lru_cache(maxsize=16) 25 | def __cached_item__(self, index: int, epoch: int): 26 | atoms = np.array(self.dataset[index][self.atoms]) 27 | assert len(atoms) > 0 28 | smi = self.dataset[index][self.smi] 29 | coordinates_2d = smi2_2Dcoords(smi) 30 | coordinates = self.dataset[index][self.coordinates] 31 | return { 32 | "smi": smi, 33 | "atoms": atoms, 34 | "coordinates": coordinates, 35 | "coordinates_2d": coordinates_2d 36 | } 37 | 38 | def __getitem__(self, index: int): 39 | return self.__cached_item__(index, self.epoch) 40 | 41 | 42 | def smi2_2Dcoords(smi): 43 | mol = Chem.MolFromSmiles(smi) 44 | mol = AllChem.AddHs(mol) 45 | AllChem.Compute2DCoords(mol) 46 | coordinates = mol.GetConformer().GetPositions().astype(np.float32) 47 | assert len(mol.GetAtoms()) == len( 48 | coordinates 49 | ), "2D coordinates shape is not aligned with {}".format(smi) 50 | return coordinates 51 | -------------------------------------------------------------------------------- /unimol2/unimol2/data/conformer_sample_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | import numpy as np 6 | from functools import lru_cache 7 | from unicore.data import BaseWrapperDataset 8 | from .
import data_utils 9 | 10 | 11 | class ConformerSampleDataset(BaseWrapperDataset): 12 | def __init__(self, dataset, seed, atoms, coordinates, coordinates_2d): 13 | self.dataset = dataset 14 | self.seed = seed 15 | self.atoms = atoms 16 | self.coordinates = coordinates 17 | self.coordinates_2d = coordinates_2d 18 | self.set_epoch(None) 19 | 20 | def set_epoch(self, epoch, **unused): 21 | super().set_epoch(epoch) 22 | self.epoch = epoch 23 | 24 | @lru_cache(maxsize=16) 25 | def __cached_item__(self, index: int, epoch: int): 26 | atoms = np.array(self.dataset[index][self.atoms]) 27 | assert len(atoms) > 0 28 | coordinates_list = self.dataset[index][self.coordinates] 29 | if not isinstance(coordinates_list, list): 30 | coordinates_list = [coordinates_list] 31 | 32 | size = len(coordinates_list) 33 | with data_utils.numpy_seed(self.seed, epoch, index): 34 | sample_idx = np.random.randint(size) 35 | coordinates = coordinates_list[sample_idx] 36 | 37 | return { 38 | "atoms": atoms, 39 | "coordinates": coordinates.astype(np.float32), 40 | "coordinates_2d": self.dataset[index][self.coordinates_2d] 41 | } 42 | 43 | def __getitem__(self, index: int): 44 | return self.__cached_item__(index, self.epoch) -------------------------------------------------------------------------------- /unimol2/unimol2/data/cropping_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | import numpy as np 6 | from functools import lru_cache 7 | import logging 8 | from unicore.data import BaseWrapperDataset 9 | from . import data_utils 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | class CroppingDataset(BaseWrapperDataset): 15 | def __init__(self, dataset, seed, atoms, coordinates, coordinates_2d, max_atoms=256): 16 | self.dataset = dataset 17 | self.seed = seed 18 | self.atoms = atoms 19 | self.coordinates = coordinates 20 | self.coordinates_2d = coordinates_2d 21 | self.max_atoms = max_atoms 22 | self.set_epoch(None) 23 | 24 | def set_epoch(self, epoch, **unused): 25 | super().set_epoch(epoch) 26 | self.epoch = epoch 27 | 28 | @lru_cache(maxsize=16) 29 | def __cached_item__(self, index: int, epoch: int): 30 | dd = self.dataset[index].copy() 31 | atoms = dd[self.atoms] 32 | coordinates = dd[self.coordinates] 33 | coordinates_2d = dd[self.coordinates_2d] 34 | if self.max_atoms and len(atoms) > self.max_atoms: 35 | with data_utils.numpy_seed(self.seed, epoch, index): 36 | index = np.random.choice(len(atoms), self.max_atoms, replace=False) 37 | atoms = np.array(atoms)[index] 38 | coordinates = coordinates[index] 39 | coordinates_2d = coordinates_2d[index] 40 | dd[self.atoms] = atoms 41 | dd[self.coordinates] = coordinates.astype(np.float32) 42 | dd[self.coordinates_2d] = coordinates_2d.astype(np.float32) 43 | return dd 44 | 45 | def __getitem__(self, index: int): 46 | return self.__cached_item__(index, self.epoch) -------------------------------------------------------------------------------- /unimol2/unimol2/data/data_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 
4 | 5 | import numpy as np 6 | import contextlib 7 | 8 | 9 | @contextlib.contextmanager 10 | def numpy_seed(seed, *addl_seeds): 11 | """Context manager which seeds the NumPy PRNG with the specified seed and 12 | restores the state afterward""" 13 | if seed is None: 14 | yield 15 | return 16 | if len(addl_seeds) > 0: 17 | seed = int(hash((seed, *addl_seeds)) % 1e6) 18 | state = np.random.get_state() 19 | np.random.seed(seed) 20 | try: 21 | yield 22 | finally: 23 | np.random.set_state(state) 24 | -------------------------------------------------------------------------------- /unimol2/unimol2/data/index_atom_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | import numpy as np 6 | import torch 7 | 8 | from functools import lru_cache 9 | from unicore.data import BaseWrapperDataset 10 | 11 | from rdkit import Chem 12 | from rdkit.Chem import AllChem 13 | 14 | 15 | class IndexAtomDataset(BaseWrapperDataset): 16 | def __init__( 17 | self, 18 | smi_dataset: torch.utils.data.Dataset, 19 | token_dataset: torch.utils.data.Dataset, 20 | ): 21 | super().__init__(smi_dataset) 22 | self.smi_dataset = smi_dataset 23 | self.token_dataset = token_dataset 24 | self.set_epoch(None) 25 | 26 | def set_epoch(self, epoch, **unused): 27 | super().set_epoch(epoch) 28 | self.epoch = epoch 29 | 30 | @lru_cache(maxsize=16) 31 | def __getitem__(self, index: int): 32 | atoms = self.token_dataset[index] 33 | 34 | atom_index = [ 35 | AllChem.GetPeriodicTable().GetAtomicNumber(item) for item in atoms 36 | ] 37 | 38 | return np.array(atom_index) -------------------------------------------------------------------------------- /unimol2/unimol2/data/key_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | from functools import lru_cache 6 | from unicore.data import BaseWrapperDataset 7 | 8 | 9 | class KeyDataset(BaseWrapperDataset): 10 | def __init__(self, dataset, key): 11 | self.dataset = dataset 12 | self.key = key 13 | 14 | def __len__(self): 15 | return len(self.dataset) 16 | 17 | @lru_cache(maxsize=16) 18 | def __getitem__(self, idx): 19 | return self.dataset[idx][self.key] 20 | -------------------------------------------------------------------------------- /unimol2/unimol2/data/lmdb_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 
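"""Epoch-sharded LMDB loading for large-scale pretraining.

Assumed on-disk layout (file names follow the format string used below; the
split name "train" is an example):

    db_dir/train_part_0.lmdb
    db_dir/train_part_1.lmdb
    ...

Each epoch's shard is first staged into a local ``tmp_data_dir``, and when an
``lmdb_copy_thread`` pool is supplied, the following epoch's shard is copied in
the background while the current one is being read.
"""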
4 | 5 | 6 | import lmdb 7 | import os 8 | import pickle 9 | from functools import lru_cache 10 | import logging 11 | import shutil 12 | import time 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | # async ckp copy 18 | def lmdb_data_copy_fun(src_path_dir, target_path_dir, epoch, split): 19 | db_path_src = os.path.join(src_path_dir, "{}_part_{}.lmdb".format(split, epoch)) 20 | db_path_tgt = os.path.join(target_path_dir, "{}_part_{}.lmdb".format(split, epoch)) 21 | 22 | if os.path.exists(db_path_tgt): 23 | return 24 | 25 | if not os.path.exists(db_path_src): 26 | logger.warning(f"please not that {db_path_src} not exists.") 27 | return 28 | 29 | shutil.copyfile(db_path_src, db_path_tgt) 30 | 31 | last_db_path_tgt = os.path.join(target_path_dir, "{}_part_{}.lmdb".format(split, epoch-2)) 32 | if os.path.exists(last_db_path_tgt): 33 | os.remove(last_db_path_tgt) 34 | 35 | logger.info(f"finished async copy file from {db_path_src} to {db_path_tgt}") 36 | return 37 | 38 | class LMDBDataset(): 39 | def __init__(self, db_dir, split, epoch, max_epoch, lmdb_copy_thread=None, tmp_data_dir="/temp/"): 40 | self.db_dir = db_dir 41 | if not os.path.exists(tmp_data_dir): 42 | os.makedirs(tmp_data_dir, exist_ok=True) 43 | 44 | self.tmp_data_dir = tmp_data_dir 45 | self.lmdb_copy_thread = lmdb_copy_thread 46 | self.split = split 47 | self.max_epoch = max_epoch 48 | self.content = self.load_data(self.tmp_data_dir, epoch, split) 49 | 50 | def load_data(self, data_dir, epoch, split): 51 | self.db_path_tgt = os.path.join(data_dir, "{}_part_{}.lmdb".format(split, epoch-1)) 52 | 53 | self.db_path_tgt_lock = os.path.join(data_dir, "{}_part_{}.lmdb.lock".format(split, epoch-1)) 54 | if not os.path.exists(self.db_path_tgt): 55 | self.db_path_src = os.path.join(self.db_dir, "{}_part_{}.lmdb".format(split, epoch-1)) 56 | if not os.path.exists(self.db_path_src): 57 | raise FileNotFoundError(f"{0} not found, please make sure the max-epoch were setting right.".format(self.db_path_src)) 58 | os.system(f"touch {self.db_path_tgt_lock}") 59 | shutil.copyfile(self.db_path_src, self.db_path_tgt) 60 | logger.info(f"{self.db_path_tgt} not exist, copy file from {self.db_path_src}") 61 | os.system(f"rm -rf {self.db_path_tgt_lock}") 62 | else: 63 | while os.path.exists(self.db_path_tgt_lock): 64 | time.sleep(1) 65 | 66 | env = lmdb.open( 67 | self.db_path_tgt, 68 | subdir=False, 69 | readonly=True, 70 | lock=False, 71 | readahead=False, 72 | meminit=False, 73 | max_readers=256, 74 | ) 75 | content = [] 76 | with env.begin() as txn: 77 | self._keys = list(range(txn.stat()['entries'])) 78 | for idx in self._keys: 79 | datapoint_pickled = txn.get( idx.to_bytes(4, byteorder="big") ) 80 | content.append(datapoint_pickled) 81 | 82 | if self.lmdb_copy_thread is not None: 83 | self.lmdb_copy_thread.apply_async(lmdb_data_copy_fun, 84 | (self.db_dir, self.tmp_data_dir, epoch, self.split)) 85 | return content 86 | 87 | def __len__(self): 88 | return len(self._keys) 89 | 90 | def set_epoch(self, epoch): 91 | if epoch is not None and epoch < self.max_epoch: 92 | self.content = self.load_data(self.tmp_data_dir, epoch, self.split) 93 | 94 | @lru_cache(maxsize=16) 95 | def __getitem__(self, idx): 96 | datapoint_pickled = self.content[idx] 97 | data = pickle.loads(datapoint_pickled) 98 | return data 99 | -------------------------------------------------------------------------------- /unimol2/unimol2/data/normalize_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP 
Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | import numpy as np 6 | from functools import lru_cache 7 | from unicore.data import BaseWrapperDataset 8 | 9 | 10 | class NormalizeDataset(BaseWrapperDataset): 11 | def __init__(self, dataset, coordinates, coordinates_2d, normalize_coord=True): 12 | self.dataset = dataset 13 | self.coordinates = coordinates 14 | self.coordinates_2d = coordinates_2d 15 | self.normalize_coord = normalize_coord # normalize the coordinates. 16 | self.set_epoch(None) 17 | 18 | def set_epoch(self, epoch, **unused): 19 | super().set_epoch(epoch) 20 | self.epoch = epoch 21 | 22 | @lru_cache(maxsize=16) 23 | def __cached_item__(self, index: int, epoch: int): 24 | dd = self.dataset[index].copy() 25 | coordinates = dd[self.coordinates] 26 | coordinates_2d = dd[self.coordinates_2d] 27 | # normalize 28 | if self.normalize_coord: 29 | coordinates = coordinates - coordinates.mean(axis=0) 30 | dd[self.coordinates] = coordinates.astype(np.float32) 31 | dd[self.coordinates_2d] = coordinates_2d - coordinates_2d.mean(axis=0) 32 | return dd 33 | 34 | def __getitem__(self, index: int): 35 | return self.__cached_item__(index, self.epoch) 36 | 37 | 38 | class NormalizeDockingPoseDataset(BaseWrapperDataset): 39 | def __init__( 40 | self, 41 | dataset, 42 | coordinates, 43 | pocket_coordinates, 44 | center_coordinates="center_coordinates", 45 | ): 46 | self.dataset = dataset 47 | self.coordinates = coordinates 48 | self.pocket_coordinates = pocket_coordinates 49 | self.center_coordinates = center_coordinates 50 | self.set_epoch(None) 51 | 52 | def set_epoch(self, epoch, **unused): 53 | super().set_epoch(epoch) 54 | self.epoch = epoch 55 | 56 | @lru_cache(maxsize=16) 57 | def __cached_item__(self, index: int, epoch: int): 58 | dd = self.dataset[index].copy() 59 | coordinates = dd[self.coordinates] 60 | pocket_coordinates = dd[self.pocket_coordinates] 61 | # normalize coordinates and pocket coordinates ,align with pocket center coordinates 62 | center_coordinates = pocket_coordinates.mean(axis=0) 63 | coordinates = coordinates - center_coordinates 64 | pocket_coordinates = pocket_coordinates - center_coordinates 65 | dd[self.coordinates] = coordinates.astype(np.float32) 66 | dd[self.pocket_coordinates] = pocket_coordinates.astype(np.float32) 67 | dd[self.center_coordinates] = center_coordinates.astype(np.float32) 68 | return dd 69 | 70 | def __getitem__(self, index: int): 71 | return self.__cached_item__(index, self.epoch) 72 | -------------------------------------------------------------------------------- /unimol2/unimol2/data/remove_hydrogen_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 
4 | 5 | import numpy as np 6 | from functools import lru_cache 7 | from unicore.data import BaseWrapperDataset 8 | 9 | 10 | class RemoveHydrogenDataset(BaseWrapperDataset): 11 | def __init__( 12 | self, 13 | dataset, 14 | atoms, 15 | coordinates, 16 | coordinates_2d, 17 | remove_hydrogen=False, 18 | ): 19 | self.dataset = dataset 20 | self.atoms = atoms 21 | self.coordinates = coordinates 22 | self.coordinates_2d = coordinates_2d 23 | self.remove_hydrogen = remove_hydrogen 24 | self.set_epoch(None) 25 | 26 | def set_epoch(self, epoch, **unused): 27 | super().set_epoch(epoch) 28 | self.epoch = epoch 29 | 30 | @lru_cache(maxsize=16) 31 | def __cached_item__(self, index: int, epoch: int): 32 | dd = self.dataset[index].copy() 33 | atoms = dd[self.atoms] 34 | coordinates = dd[self.coordinates] 35 | coordinates_2d = dd[self.coordinates_2d] 36 | 37 | if self.remove_hydrogen: 38 | mask_hydrogen = atoms != "H" 39 | atoms = atoms[mask_hydrogen] 40 | coordinates = coordinates[mask_hydrogen] 41 | coordinates_2d = coordinates_2d[mask_hydrogen] 42 | dd[self.atoms] = atoms 43 | dd[self.coordinates] = coordinates.astype(np.float32) 44 | dd[self.coordinates_2d] = coordinates_2d.astype(np.float32) 45 | return dd 46 | 47 | def __getitem__(self, index: int): 48 | return self.__cached_item__(index, self.epoch) -------------------------------------------------------------------------------- /unimol2/unimol2/data/tta_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | import numpy as np 6 | from functools import lru_cache 7 | from unicore.data import BaseWrapperDataset 8 | 9 | 10 | class TTADataset(BaseWrapperDataset): 11 | def __init__(self, dataset, seed, atoms, coordinates, conf_size=10): 12 | self.dataset = dataset 13 | self.seed = seed 14 | self.atoms = atoms 15 | self.coordinates = coordinates 16 | self.conf_size = conf_size 17 | self.set_epoch(None) 18 | 19 | def set_epoch(self, epoch, **unused): 20 | super().set_epoch(epoch) 21 | self.epoch = epoch 22 | 23 | def __len__(self): 24 | return len(self.dataset) * self.conf_size 25 | 26 | @lru_cache(maxsize=16) 27 | def __cached_item__(self, index: int, epoch: int): 28 | smi_idx = index // self.conf_size 29 | coord_idx = index % self.conf_size 30 | atoms = np.array(self.dataset[smi_idx][self.atoms]) 31 | coordinates = np.array(self.dataset[smi_idx][self.coordinates][coord_idx]) 32 | smi = self.dataset[smi_idx]["smi"] 33 | target = self.dataset[smi_idx]["target"] 34 | return { 35 | "atoms": atoms, 36 | "coordinates": coordinates.astype(np.float32), 37 | "smi": smi, 38 | "target": target, 39 | } 40 | 41 | def __getitem__(self, index: int): 42 | return self.__cached_item__(index, self.epoch) -------------------------------------------------------------------------------- /unimol2/unimol2/infer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # Copyright (c) DP Techonology, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import logging 8 | import os 9 | import sys 10 | import pickle 11 | import torch 12 | from unicore import checkpoint_utils, distributed_utils, options, utils 13 | from unicore.logging import progress_bar 14 | from unicore import tasks 15 | 16 | logging.basicConfig( 17 | format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 18 | datefmt="%Y-%m-%d %H:%M:%S", 19 | level=os.environ.get("LOGLEVEL", "INFO").upper(), 20 | stream=sys.stdout, 21 | ) 22 | logger = logging.getLogger("unimol.inference") 23 | 24 | 25 | def main(args): 26 | 27 | assert ( 28 | args.batch_size is not None 29 | ), "Must specify batch size either with --batch-size" 30 | 31 | use_fp16 = args.fp16 32 | use_cuda = torch.cuda.is_available() and not args.cpu 33 | 34 | if use_cuda: 35 | torch.cuda.set_device(args.device_id) 36 | 37 | if args.distributed_world_size > 1: 38 | data_parallel_world_size = distributed_utils.get_data_parallel_world_size() 39 | data_parallel_rank = distributed_utils.get_data_parallel_rank() 40 | else: 41 | data_parallel_world_size = 1 42 | data_parallel_rank = 0 43 | 44 | # Load model 45 | logger.info("loading model(s) from {}".format(args.path)) 46 | state = checkpoint_utils.load_checkpoint_to_cpu(args.path) 47 | task = tasks.setup_task(args) 48 | model = task.build_model(args) 49 | model.load_state_dict(state["model"], strict=False) 50 | 51 | # Move models to GPU 52 | if use_fp16: 53 | model.half() 54 | if use_cuda: 55 | model.cuda() 56 | 57 | # Print args 58 | logger.info(args) 59 | 60 | # Build loss 61 | loss = task.build_loss(args) 62 | loss.eval() 63 | 64 | for subset in args.valid_subset.split(","): 65 | try: 66 | task.load_dataset(subset, combine=False, epoch=1) 67 | dataset = task.dataset(subset) 68 | except KeyError: 69 | raise Exception("Cannot find dataset: " + subset) 70 | 71 | if not os.path.exists(args.results_path): 72 | os.makedirs(args.results_path) 73 | fname = (args.path).split("/")[-2] 74 | save_path = os.path.join(args.results_path, fname + "_" + subset + ".out.pkl") 75 | # Initialize data iterator 76 | itr = task.get_batch_iterator( 77 | dataset=dataset, 78 | batch_size=args.batch_size, 79 | ignore_invalid_inputs=True, 80 | required_batch_size_multiple=args.required_batch_size_multiple, 81 | seed=args.seed, 82 | num_shards=data_parallel_world_size, 83 | shard_id=data_parallel_rank, 84 | num_workers=args.num_workers, 85 | data_buffer_size=args.data_buffer_size, 86 | ).next_epoch_itr(shuffle=False) 87 | progress = progress_bar.progress_bar( 88 | itr, 89 | log_format=args.log_format, 90 | log_interval=args.log_interval, 91 | prefix=f"valid on '{subset}' subset", 92 | default_log_format=("tqdm" if not args.no_progress_bar else "simple"), 93 | ) 94 | log_outputs = [] 95 | for i, sample in enumerate(progress): 96 | sample = utils.move_to_cuda(sample) if use_cuda else sample 97 | if len(sample) == 0: 98 | continue 99 | _, _, log_output = task.valid_step(sample, model, loss, test=True) 100 | progress.log({}, step=i) 101 | log_outputs.append(log_output) 102 | pickle.dump(log_outputs, open(save_path, "wb")) 103 | logger.info("Done inference! 
") 104 | return None 105 | 106 | 107 | def cli_main(): 108 | parser = options.get_validation_parser() 109 | options.add_model_args(parser) 110 | args = options.parse_args_and_arch(parser) 111 | 112 | distributed_utils.call_main(args, main) 113 | 114 | 115 | if __name__ == "__main__": 116 | cli_main() 117 | -------------------------------------------------------------------------------- /unimol2/unimol2/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import importlib 3 | 4 | # automatically import any Python files in the criterions/ directory 5 | for file in sorted(Path(__file__).parent.glob("*.py")): 6 | if not file.name.startswith("_"): 7 | importlib.import_module("unimol2.losses." + file.name[:-3]) 8 | -------------------------------------------------------------------------------- /unimol2/unimol2/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .unimol2 import UniMol2Model 2 | from .transformer_encoder_with_pair import TransformerEncoderWithPair -------------------------------------------------------------------------------- /unimol2/unimol2/models/transformer_encoder_with_pair.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | from typing import Optional, Tuple 6 | 7 | import math 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from unicore.modules import LayerNorm 12 | from .layers import TransformerEncoderLayer, Dropout 13 | 14 | 15 | class TransformerEncoderWithPair(nn.Module): 16 | def __init__( 17 | self, 18 | num_encoder_layers: int = 6, 19 | embedding_dim: int = 768, 20 | 21 | pair_dim: int = 64, 22 | pair_hidden_dim: int = 32, 23 | 24 | ffn_embedding_dim: int = 3072, 25 | num_attention_heads: int = 8, 26 | dropout: float = 0.1, 27 | attention_dropout: float = 0.1, 28 | activation_dropout: float = 0.0, 29 | activation_fn: str = "gelu", 30 | droppath_prob: float = 0.0, 31 | pair_dropout: float = 0.25, 32 | ) -> None: 33 | super().__init__() 34 | self.embedding_dim = embedding_dim 35 | self.num_head = num_attention_heads 36 | self.layer_norm = LayerNorm(embedding_dim) 37 | self.pair_layer_norm = LayerNorm(pair_dim) 38 | self.layers = nn.ModuleList([]) 39 | 40 | if droppath_prob > 0: 41 | droppath_probs = [ 42 | x.item() for x in torch.linspace(0, droppath_prob, num_encoder_layers) 43 | ] 44 | else: 45 | droppath_probs = None 46 | 47 | self.layers.extend( 48 | [ 49 | TransformerEncoderLayer( 50 | embedding_dim=embedding_dim, 51 | pair_dim=pair_dim, 52 | pair_hidden_dim=pair_hidden_dim, 53 | ffn_embedding_dim=ffn_embedding_dim, 54 | num_attention_heads=num_attention_heads, 55 | dropout=dropout, 56 | attention_dropout=attention_dropout, 57 | activation_dropout=activation_dropout, 58 | activation_fn=activation_fn, 59 | droppath_prob=droppath_probs[i] 60 | if droppath_probs is not None 61 | else 0, 62 | pair_dropout=pair_dropout, 63 | ) 64 | for i in range(num_encoder_layers) 65 | ] 66 | ) 67 | 68 | def forward( 69 | self, 70 | x, 71 | pair, 72 | atom_mask, 73 | pair_mask, 74 | attn_mask=None, 75 | ) -> Tuple[torch.Tensor, torch.Tensor]: 76 | 77 | x = self.layer_norm(x) 78 | pair = self.pair_layer_norm(pair) 79 | op_mask = atom_mask.unsqueeze(-1) 80 | op_mask = op_mask * (op_mask.size(-2) ** -0.5) 
81 | eps = 1e-3 82 | op_norm = 1.0 / (eps + torch.einsum("...bc,...dc->...bdc", op_mask, op_mask)) 83 | for layer in self.layers: 84 | x, pair = layer( 85 | x, 86 | pair, 87 | pair_mask=pair_mask, 88 | self_attn_mask=attn_mask, 89 | op_mask=op_mask, 90 | op_norm=op_norm, 91 | ) 92 | return x, pair -------------------------------------------------------------------------------- /unimol2/unimol2/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import importlib 3 | 4 | # automatically import any Python files in the criterions/ directory 5 | for file in sorted(Path(__file__).parent.glob("*.py")): 6 | if not file.name.startswith("_"): 7 | importlib.import_module("unimol2.tasks." + file.name[:-3]) 8 | -------------------------------------------------------------------------------- /unimol_docking_v2/docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM dptechnology/unicore:latest-pytorch1.12.1-cuda11.6-rdma 2 | 3 | RUN pip install rdkit-pypi==2022.9.3 4 | RUN pip install biopandas 5 | 6 | RUN ldconfig && \ 7 | apt-get clean && \ 8 | apt-get autoremove && \ 9 | rm -rf /var/lib/apt/lists/* /tmp/* && \ 10 | conda clean -ya -------------------------------------------------------------------------------- /unimol_docking_v2/example_data/dict_mol.txt: -------------------------------------------------------------------------------- 1 | [PAD] 2 | [CLS] 3 | [SEP] 4 | [UNK] 5 | C 6 | N 7 | O 8 | S 9 | H 10 | Cl 11 | F 12 | Br 13 | I 14 | Si 15 | P 16 | B 17 | Na 18 | K 19 | Al 20 | Ca 21 | Sn 22 | As 23 | Hg 24 | Fe 25 | Zn 26 | Cr 27 | Se 28 | Gd 29 | Au 30 | Li -------------------------------------------------------------------------------- /unimol_docking_v2/example_data/dict_pkt.txt: -------------------------------------------------------------------------------- 1 | [PAD] 2 | [CLS] 3 | [SEP] 4 | [UNK] 5 | C 6 | N 7 | O 8 | S 9 | H -------------------------------------------------------------------------------- /unimol_docking_v2/example_data/docking_grid.json: -------------------------------------------------------------------------------- 1 | { 2 | "center_x": 8.728845596313477, 3 | "center_y": 25.618770599365234, 4 | "center_z": 4.682269096374512, 5 | "size_x": 19.121000289916992, 6 | "size_y": 16.562999725341797, 7 | "size_z": 18.64900016784668 8 | } -------------------------------------------------------------------------------- /unimol_docking_v2/figure/bohrium_app.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/Uni-Mol/90f52c41299a1a582da0f9765e9f87aa21faa16a/unimol_docking_v2/figure/bohrium_app.gif -------------------------------------------------------------------------------- /unimol_docking_v2/interface/demo.sh: -------------------------------------------------------------------------------- 1 | python demo.py --mode single --conf-size 10 --cluster \ 2 | --input-protein ../example_data/protein.pdb \ 3 | --input-ligand ../example_data/ligand.sdf \ 4 | --input-docking-grid ../example_data/docking_grid.json \ 5 | --output-ligand-name ligand_predict \ 6 | --output-ligand-dir predict_sdf \ 7 | --steric-clash-fix \ 8 | --model-dir checkpoint_best.pt -------------------------------------------------------------------------------- /unimol_docking_v2/interface/demo_batch_one2one.sh: -------------------------------------------------------------------------------- 1 | python demo.py 
--mode batch_one2one --batch-size 8 --conf-size 10 --cluster \ 2 | --input-batch-file input_batch_one2one.csv \ 3 | --output-ligand-dir predict_sdf \ 4 | --steric-clash-fix \ 5 | --model-dir checkpoint_best.pt 6 | -------------------------------------------------------------------------------- /unimol_docking_v2/interface/input_batch_one2one.csv: -------------------------------------------------------------------------------- 1 | input_protein,input_ligand,input_docking_grid,output_ligand_name 2 | protein1.pdb,ligand_prepared1.sdf,docking_grid1.json,ligand_predict1 3 | protein2.pdb,ligand_prepared2.sdf,docking_grid2.json,ligand_predict2 4 | protein3.pdb,ligand_prepared3.sdf,docking_grid3.json,ligand_predict3 -------------------------------------------------------------------------------- /unimol_docking_v2/interface/predictor/__init__.py: -------------------------------------------------------------------------------- 1 | from .unimol_predictor import UnimolPredictor 2 | from .processor import Processor -------------------------------------------------------------------------------- /unimol_docking_v2/train.sh: -------------------------------------------------------------------------------- 1 | data_path="./protein_ligand_binding_pose_prediction_v2" # replace to your data path 2 | save_dir="./save_pose" # replace to your save path 3 | n_gpu=8 4 | MASTER_PORT=10086 5 | finetune_mol_model="./weights/mol_checkpoint.pt" 6 | finetune_pocket_model="./weights/pocket_checkpoint.pt" 7 | lr=3e-4 8 | batch_size=8 9 | epoch=100 10 | dropout=0.2 11 | warmup=0.06 12 | update_freq=1 13 | dist_threshold=8.0 14 | recycling=4 15 | 16 | export NCCL_ASYNC_ERROR_HANDLING=1 17 | export OMP_NUM_THREADS=1 18 | python -m torch.distributed.launch --nproc_per_node=$n_gpu --master_port=$MASTER_PORT $(which unicore-train) --user-dir ./unimol $data_path --train-subset train --valid-subset valid \ 19 | --num-workers 8 --ddp-backend=c10d \ 20 | --task docking_pose_v2 --loss docking_pose_v2 --arch docking_pose_v2 \ 21 | --optimizer adam --adam-betas '(0.9, 0.99)' --adam-eps 1e-6 --clip-norm 1.0 \ 22 | --lr-scheduler polynomial_decay --lr $lr --warmup-ratio $warmup --max-epoch $epoch --batch-size $batch_size \ 23 | --mol-pooler-dropout $dropout --pocket-pooler-dropout $dropout \ 24 | --update-freq $update_freq --seed 42 \ 25 | --fp16 --fp16-init-scale 4 --fp16-scale-window 256 \ 26 | --tensorboard-logdir $save_dir/tsb \ 27 | --log-interval 100 --log-format simple \ 28 | --validate-interval 1 --keep-last-epochs 10 \ 29 | --best-checkpoint-metric valid_loss --patience 2000 --all-gather-list-size 1024000 \ 30 | --finetune-mol-model $finetune_mol_model \ 31 | --finetune-pocket-model $finetune_pocket_model \ 32 | --dist-threshold $dist_threshold --recycling $recycling \ 33 | --save-dir $save_dir \ 34 | --find-unused-parameters \ 35 | --required-batch-size-multiple 1 36 | -------------------------------------------------------------------------------- /unimol_docking_v2/unimol/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import unimol.tasks 3 | import unimol.data 4 | import unimol.models 5 | import unimol.losses 6 | -------------------------------------------------------------------------------- /unimol_docking_v2/unimol/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .key_dataset import KeyDataset 2 | from .normalize_dataset import ( 3 | NormalizeDataset, 4 | NormalizeDockingPoseDataset, 5 | ) 6 
| from .remove_hydrogen_dataset import ( 7 | RemoveHydrogenPocketDataset, 8 | ) 9 | from .tta_dataset import ( 10 | TTADockingPoseDataset, 11 | ) 12 | from .cropping_dataset import ( 13 | CroppingPocketDataset, 14 | ) 15 | from .distance_dataset import ( 16 | DistanceDataset, 17 | EdgeTypeDataset, 18 | CrossDistanceDataset, 19 | ) 20 | from .conformer_sample_dataset import ( 21 | ConformerSampleDockingPoseDataset, 22 | ) 23 | from .coord_pad_dataset import RightPadDatasetCoord, RightPadDatasetCross2D 24 | from .lmdb_dataset import LMDBDataset 25 | from .prepend_and_append_2d_dataset import PrependAndAppend2DDataset 26 | from .realign_ligand_dataset import ReAlignLigandDataset 27 | 28 | __all__ = [] -------------------------------------------------------------------------------- /unimol_docking_v2/unimol/data/conformer_sample_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | import numpy as np 6 | from functools import lru_cache 7 | from unicore.data import BaseWrapperDataset 8 | from . import data_utils 9 | 10 | 11 | class ConformerSampleDockingPoseDataset(BaseWrapperDataset): 12 | def __init__( 13 | self, 14 | dataset, 15 | seed, 16 | atoms, 17 | coordinates, 18 | pocket_atoms, 19 | pocket_coordinates, 20 | holo_coordinates, 21 | holo_pocket_coordinates, 22 | is_train=True, 23 | ): 24 | self.dataset = dataset 25 | self.seed = seed 26 | self.atoms = atoms 27 | self.coordinates = coordinates 28 | self.pocket_atoms = pocket_atoms 29 | self.pocket_coordinates = pocket_coordinates 30 | self.holo_coordinates = holo_coordinates 31 | self.holo_pocket_coordinates = holo_pocket_coordinates 32 | self.is_train = is_train 33 | self.set_epoch(None) 34 | 35 | def set_epoch(self, epoch, **unused): 36 | super().set_epoch(epoch) 37 | self.epoch = epoch 38 | 39 | @lru_cache(maxsize=16) 40 | def __cached_item__(self, index: int, epoch: int): 41 | atoms = np.array(self.dataset[index][self.atoms]) 42 | size = len(self.dataset[index][self.coordinates]) 43 | with data_utils.numpy_seed(self.seed, epoch, index): 44 | sample_idx = np.random.randint(size) 45 | coordinates = self.dataset[index][self.coordinates][sample_idx] 46 | pocket_atoms = np.array( 47 | [item[0] for item in self.dataset[index][self.pocket_atoms]] 48 | ) 49 | pocket_coordinates = self.dataset[index][self.pocket_coordinates][0] 50 | if self.is_train: 51 | holo_coordinates = self.dataset[index][self.holo_coordinates][0] 52 | holo_pocket_coordinates = self.dataset[index][self.holo_pocket_coordinates][ 53 | 0 54 | ] 55 | else: 56 | holo_coordinates = coordinates 57 | holo_pocket_coordinates = pocket_coordinates 58 | 59 | smi = self.dataset[index]["smi"] 60 | pocket = self.dataset[index]["pocket"] 61 | 62 | return { 63 | "atoms": atoms, 64 | "coordinates": coordinates.astype(np.float32), 65 | "pocket_atoms": pocket_atoms, 66 | "pocket_coordinates": pocket_coordinates.astype(np.float32), 67 | "holo_coordinates": holo_coordinates.astype(np.float32), 68 | "holo_pocket_coordinates": holo_pocket_coordinates.astype(np.float32), 69 | "smi": smi, 70 | "pocket": pocket, 71 | } 72 | 73 | def __getitem__(self, index: int): 74 | return self.__cached_item__(index, self.epoch) 75 | -------------------------------------------------------------------------------- /unimol_docking_v2/unimol/data/coord_pad_dataset.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | from unicore.data import BaseWrapperDataset 6 | 7 | 8 | def collate_tokens_coords( 9 | values, 10 | pad_idx, 11 | left_pad=False, 12 | pad_to_length=None, 13 | pad_to_multiple=1, 14 | ): 15 | """Convert a list of 1d tensors into a padded 2d tensor.""" 16 | size = max(v.size(0) for v in values) 17 | size = size if pad_to_length is None else max(size, pad_to_length) 18 | if pad_to_multiple != 1 and size % pad_to_multiple != 0: 19 | size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple) 20 | res = values[0].new(len(values), size, 3).fill_(pad_idx) 21 | 22 | def copy_tensor(src, dst): 23 | assert dst.numel() == src.numel() 24 | dst.copy_(src) 25 | 26 | for i, v in enumerate(values): 27 | copy_tensor(v, res[i][size - len(v) :, :] if left_pad else res[i][: len(v), :]) 28 | return res 29 | 30 | 31 | class RightPadDatasetCoord(BaseWrapperDataset): 32 | def __init__(self, dataset, pad_idx, left_pad=False): 33 | super().__init__(dataset) 34 | self.pad_idx = pad_idx 35 | self.left_pad = left_pad 36 | 37 | def collater(self, samples): 38 | return collate_tokens_coords( 39 | samples, self.pad_idx, left_pad=self.left_pad, pad_to_multiple=8 40 | ) 41 | 42 | 43 | def collate_cross_2d( 44 | values, 45 | pad_idx, 46 | left_pad=False, 47 | pad_to_length=None, 48 | pad_to_multiple=1, 49 | ): 50 | """Convert a list of 2d tensors into a padded 2d tensor.""" 51 | size_h = max(v.size(0) for v in values) 52 | size_w = max(v.size(1) for v in values) 53 | if pad_to_multiple != 1 and size_h % pad_to_multiple != 0: 54 | size_h = int(((size_h - 0.1) // pad_to_multiple + 1) * pad_to_multiple) 55 | if pad_to_multiple != 1 and size_w % pad_to_multiple != 0: 56 | size_w = int(((size_w - 0.1) // pad_to_multiple + 1) * pad_to_multiple) 57 | res = values[0].new(len(values), size_h, size_w).fill_(pad_idx) 58 | 59 | def copy_tensor(src, dst): 60 | assert dst.numel() == src.numel() 61 | dst.copy_(src) 62 | 63 | for i, v in enumerate(values): 64 | copy_tensor( 65 | v, 66 | res[i][size_h - v.size(0) :, size_w - v.size(1) :] 67 | if left_pad 68 | else res[i][: v.size(0), : v.size(1)], 69 | ) 70 | return res 71 | 72 | 73 | class RightPadDatasetCross2D(BaseWrapperDataset): 74 | def __init__(self, dataset, pad_idx, left_pad=False): 75 | super().__init__(dataset) 76 | self.pad_idx = pad_idx 77 | self.left_pad = left_pad 78 | 79 | def collater(self, samples): 80 | return collate_cross_2d( 81 | samples, self.pad_idx, left_pad=self.left_pad, pad_to_multiple=8 82 | ) 83 | -------------------------------------------------------------------------------- /unimol_docking_v2/unimol/data/cropping_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | import numpy as np 6 | from functools import lru_cache 7 | import logging 8 | from unicore.data import BaseWrapperDataset 9 | from . 
import data_utils 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | class CroppingPocketDataset(BaseWrapperDataset): 15 | def __init__(self, dataset, seed, atoms, coordinates, holo_coordinates, max_atoms=256): 16 | self.dataset = dataset 17 | self.seed = seed 18 | self.atoms = atoms 19 | self.coordinates = coordinates 20 | self.holo_coordinates = holo_coordinates 21 | self.max_atoms = ( 22 | max_atoms # max number of atoms in a molecule, None indicates no limit. 23 | ) 24 | self.set_epoch(None) 25 | 26 | def set_epoch(self, epoch, **unused): 27 | super().set_epoch(epoch) 28 | self.epoch = epoch 29 | 30 | @lru_cache(maxsize=16) 31 | def __cached_item__(self, index: int, epoch: int): 32 | dd = self.dataset[index].copy() 33 | atoms = dd[self.atoms] 34 | coordinates = dd[self.coordinates] 35 | # residue = dd["residue"] 36 | holo_coordinates = dd[self.holo_coordinates] 37 | 38 | # crop atoms according to their distance to the center of pockets 39 | if self.max_atoms and len(atoms) > self.max_atoms: 40 | with data_utils.numpy_seed(self.seed, epoch, index): 41 | distance = np.linalg.norm( 42 | coordinates - coordinates.mean(axis=0), axis=1 43 | ) 44 | 45 | def softmax(x): 46 | x -= np.max(x) 47 | x = np.exp(x) / np.sum(np.exp(x)) 48 | return x 49 | 50 | distance += 1 # prevent inf 51 | weight = softmax(np.reciprocal(distance)) 52 | index = np.random.choice( 53 | len(atoms), self.max_atoms, replace=False, p=weight 54 | ) 55 | atoms = atoms[index] 56 | coordinates = coordinates[index] 57 | # residue = residue[index] 58 | holo_coordinates = holo_coordinates[index] 59 | 60 | dd[self.atoms] = atoms 61 | dd[self.coordinates] = coordinates.astype(np.float32) 62 | # dd["residue"] = residue 63 | dd[self.holo_coordinates] = holo_coordinates.astype(np.float32) 64 | return dd 65 | 66 | def __getitem__(self, index: int): 67 | return self.__cached_item__(index, self.epoch) 68 | -------------------------------------------------------------------------------- /unimol_docking_v2/unimol/data/data_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | import numpy as np 6 | import contextlib 7 | 8 | 9 | @contextlib.contextmanager 10 | def numpy_seed(seed, *addl_seeds): 11 | """Context manager which seeds the NumPy PRNG with the specified seed and 12 | restores the state afterward""" 13 | if seed is None: 14 | yield 15 | return 16 | if len(addl_seeds) > 0: 17 | seed = int(hash((seed, *addl_seeds)) % 1e6) 18 | state = np.random.get_state() 19 | np.random.seed(seed) 20 | try: 21 | yield 22 | finally: 23 | np.random.set_state(state) 24 | -------------------------------------------------------------------------------- /unimol_docking_v2/unimol/data/distance_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 
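# Overview of the wrappers defined below:
#   - DistanceDataset: turns an (N, 3) coordinate tensor into an (N, N) float32 Euclidean distance matrix (scipy distance_matrix).
#   - EdgeTypeDataset: encodes each atom pair (i, j) as the single integer token_i * num_types + token_j, giving an (N, N) pair-type matrix.
#   - CrossDistanceDataset: computes the (N_mol, N_pocket) cross distances between ligand and pocket coordinates.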
4 | 5 | import numpy as np 6 | import torch 7 | from scipy.spatial import distance_matrix 8 | from functools import lru_cache 9 | from unicore.data import BaseWrapperDataset 10 | 11 | 12 | class DistanceDataset(BaseWrapperDataset): 13 | def __init__(self, dataset): 14 | super().__init__(dataset) 15 | self.dataset = dataset 16 | 17 | @lru_cache(maxsize=16) 18 | def __getitem__(self, idx): 19 | pos = self.dataset[idx].view(-1, 3).numpy() 20 | dist = distance_matrix(pos, pos).astype(np.float32) 21 | return torch.from_numpy(dist) 22 | 23 | 24 | class EdgeTypeDataset(BaseWrapperDataset): 25 | def __init__(self, dataset: torch.utils.data.Dataset, num_types: int): 26 | self.dataset = dataset 27 | self.num_types = num_types 28 | 29 | @lru_cache(maxsize=16) 30 | def __getitem__(self, index: int): 31 | node_input = self.dataset[index].clone() 32 | offset = node_input.view(-1, 1) * self.num_types + node_input.view(1, -1) 33 | return offset 34 | 35 | 36 | class CrossDistanceDataset(BaseWrapperDataset): 37 | def __init__(self, mol_dataset, pocket_dataset): 38 | super().__init__(mol_dataset) 39 | self.mol_dataset = mol_dataset 40 | self.pocket_dataset = pocket_dataset 41 | 42 | @lru_cache(maxsize=16) 43 | def __getitem__(self, idx): 44 | mol_pos = self.mol_dataset[idx].view(-1, 3).numpy() 45 | pocket_pos = self.pocket_dataset[idx].view(-1, 3).numpy() 46 | dist = distance_matrix(mol_pos, pocket_pos).astype(np.float32) 47 | return torch.from_numpy(dist) 48 | -------------------------------------------------------------------------------- /unimol_docking_v2/unimol/data/key_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | from functools import lru_cache 6 | from unicore.data import BaseWrapperDataset 7 | 8 | 9 | class KeyDataset(BaseWrapperDataset): 10 | def __init__(self, dataset, key): 11 | self.dataset = dataset 12 | self.key = key 13 | 14 | def __len__(self): 15 | return len(self.dataset) 16 | 17 | @lru_cache(maxsize=16) 18 | def __getitem__(self, idx): 19 | return self.dataset[idx][self.key] 20 | -------------------------------------------------------------------------------- /unimol_docking_v2/unimol/data/lmdb_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 
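# This reader expects a single-file LMDB (subdir=False) whose keys are the sample indices "0", "1", ... encoded as ASCII strings and whose values are pickled sample dicts; the environment is opened lazily in __getitem__ (e.g. so each dataloader worker holds its own handle).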
4 | 5 | 6 | import lmdb 7 | import os 8 | import pickle 9 | from functools import lru_cache 10 | import logging 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | class LMDBDataset: 16 | def __init__(self, db_path): 17 | self.db_path = db_path 18 | assert os.path.isfile(self.db_path), "{} not found".format(self.db_path) 19 | env = self.connect_db(self.db_path) 20 | with env.begin() as txn: 21 | self._keys = list(txn.cursor().iternext(values=False)) 22 | 23 | def connect_db(self, lmdb_path, save_to_self=False): 24 | env = lmdb.open( 25 | lmdb_path, 26 | subdir=False, 27 | readonly=True, 28 | lock=False, 29 | readahead=False, 30 | meminit=False, 31 | max_readers=256, 32 | ) 33 | if not save_to_self: 34 | return env 35 | else: 36 | self.env = env 37 | 38 | def __len__(self): 39 | return len(self._keys) 40 | 41 | @lru_cache(maxsize=16) 42 | def __getitem__(self, idx): 43 | if not hasattr(self, "env"): 44 | self.connect_db(self.db_path, save_to_self=True) 45 | datapoint_pickled = self.env.begin().get(f"{idx}".encode("ascii")) 46 | data = pickle.loads(datapoint_pickled) 47 | return data 48 | -------------------------------------------------------------------------------- /unimol_docking_v2/unimol/data/normalize_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | import numpy as np 6 | from functools import lru_cache 7 | from unicore.data import BaseWrapperDataset 8 | 9 | 10 | class NormalizeDataset(BaseWrapperDataset): 11 | def __init__(self, dataset, coordinates, normalize_coord=True): 12 | self.dataset = dataset 13 | self.coordinates = coordinates 14 | self.normalize_coord = normalize_coord # normalize the coordinates. 
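# when enabled, __cached_item__ below subtracts the per-item mean so each coordinate array is centered at the origin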
15 | self.set_epoch(None) 16 | 17 | def set_epoch(self, epoch, **unused): 18 | super().set_epoch(epoch) 19 | self.epoch = epoch 20 | 21 | @lru_cache(maxsize=16) 22 | def __cached_item__(self, index: int, epoch: int): 23 | dd = self.dataset[index].copy() 24 | coordinates = dd[self.coordinates] 25 | # normalize 26 | if self.normalize_coord: 27 | coordinates = coordinates - coordinates.mean(axis=0) 28 | dd[self.coordinates] = coordinates.astype(np.float32) 29 | return dd 30 | 31 | def __getitem__(self, index: int): 32 | return self.__cached_item__(index, self.epoch) 33 | 34 | 35 | class NormalizeDockingPoseDataset(BaseWrapperDataset): 36 | def __init__( 37 | self, 38 | dataset, 39 | coordinates, 40 | pocket_coordinates, 41 | center_coordinates="center_coordinates", 42 | ): 43 | self.dataset = dataset 44 | self.coordinates = coordinates 45 | self.pocket_coordinates = pocket_coordinates 46 | self.center_coordinates = center_coordinates 47 | self.set_epoch(None) 48 | 49 | def set_epoch(self, epoch, **unused): 50 | super().set_epoch(epoch) 51 | self.epoch = epoch 52 | 53 | @lru_cache(maxsize=16) 54 | def __cached_item__(self, index: int, epoch: int): 55 | dd = self.dataset[index].copy() 56 | coordinates = dd[self.coordinates] 57 | pocket_coordinates = dd[self.pocket_coordinates] 58 | # normalize coordinates and pocket coordinates, aligning both with the pocket center coordinates 59 | center_coordinates = pocket_coordinates.mean(axis=0) 60 | coordinates = coordinates - center_coordinates 61 | pocket_coordinates = pocket_coordinates - center_coordinates 62 | dd[self.coordinates] = coordinates.astype(np.float32) 63 | dd[self.pocket_coordinates] = pocket_coordinates.astype(np.float32) 64 | dd[self.center_coordinates] = center_coordinates.astype(np.float32) 65 | return dd 66 | 67 | def __getitem__(self, index: int): 68 | return self.__cached_item__(index, self.epoch) 69 | -------------------------------------------------------------------------------- /unimol_docking_v2/unimol/data/prepend_and_append_2d_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | import torch 6 | from functools import lru_cache 7 | from unicore.data import BaseWrapperDataset 8 | 9 | 10 | class PrependAndAppend2DDataset(BaseWrapperDataset): 11 | def __init__(self, dataset, token=None): 12 | super().__init__(dataset) 13 | self.token = token 14 | 15 | @lru_cache(maxsize=16) 16 | def __getitem__(self, idx): 17 | item = self.dataset[idx] 18 | if self.token is not None: 19 | h, w = item.size(-2), item.size(-1) 20 | new_item = torch.full((h + 2, w + 2), self.token).type_as(item) 21 | new_item[1:-1, 1:-1] = item 22 | return new_item 23 | return item 24 | -------------------------------------------------------------------------------- /unimol_docking_v2/unimol/data/realign_ligand_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree.
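# ReAlignLigandDataset centers the ligand and pocket point clouds, derives the principal axes of each from the eigenvectors of its inertia tensor, and rotates the ligand (scipy Rotation.align_vectors) so its axes match the pocket's, giving a rough, orientation-normalized initial ligand placement in the pocket frame.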
5 | 6 | import numpy as np 7 | from functools import lru_cache 8 | import logging 9 | from unicore.data import BaseWrapperDataset 10 | from scipy.spatial.transform import Rotation 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | class ReAlignLigandDataset(BaseWrapperDataset): 16 | def __init__(self, dataset, coordinates, pocket_coordinates): 17 | self.dataset = dataset 18 | self.coordinates = coordinates 19 | self.pocket_coordinates = pocket_coordinates 20 | self.set_epoch(None) 21 | 22 | def set_epoch(self, epoch, **unused): 23 | super().set_epoch(epoch) 24 | self.epoch = epoch 25 | 26 | @lru_cache(maxsize=16) 27 | def __cached_item__(self, index: int, epoch: int): 28 | dd = self.dataset[index].copy() 29 | coordinates = dd[self.coordinates] 30 | pocket_coordinates = dd[self.pocket_coordinates] 31 | normal_coordinates, normal_pocket_coordinates = realigncoordinates(coordinates, pocket_coordinates) 32 | 33 | dd[self.coordinates] = normal_coordinates.astype(np.float32) 34 | dd[self.pocket_coordinates] = normal_pocket_coordinates.astype(np.float32) 35 | return dd 36 | 37 | 38 | def __getitem__(self, index: int): 39 | return self.__cached_item__(index, self.epoch) 40 | 41 | def calc_inertia_tensor(new_coord, mass=None): 42 | """ This function calculates the Elements of inertia tensor for the 43 | center-moved coordinates. 44 | """ 45 | if mass is None: 46 | mass = 1.0 47 | I_xx = (mass * np.sum(np.square(new_coord[:,1:3:1]),axis=1)).sum() 48 | I_yy = (mass * np.sum(np.square(new_coord[:,0:3:2]),axis=1)).sum() 49 | I_zz = (mass * np.sum(np.square(new_coord[:,0:2:1]),axis=1)).sum() 50 | I_xy = (-1 * mass * np.prod(new_coord[:,0:2:1],axis=1)).sum() 51 | I_yz = (-1 * mass * np.prod(new_coord[:,1:3:1],axis=1)).sum() 52 | I_xz = (-1 * mass * np.prod(new_coord[:,0:3:2],axis=1)).sum() 53 | I = np.array([[I_xx, I_xy, I_xz], 54 | [I_xy, I_yy, I_yz], 55 | [I_xz, I_yz, I_zz]]) 56 | return I 57 | 58 | def realigncoordinates(coordinates, pocket_coordinates): 59 | coordinates = coordinates - coordinates.mean(axis=0) 60 | pocket_coordinates = pocket_coordinates - pocket_coordinates.mean(axis=0) 61 | 62 | D = calc_inertia_tensor(coordinates) 63 | I, E = np.linalg.eigh(D) 64 | 65 | D_poc = calc_inertia_tensor(pocket_coordinates) 66 | I_poc, E_poc = np.linalg.eigh(D_poc) 67 | 68 | _R, _score = Rotation.align_vectors(E[:,:].T, E_poc[:,:].T) 69 | new_coordinates = np.dot(coordinates, _R.as_matrix()) 70 | 71 | return new_coordinates, pocket_coordinates 72 | 73 | 74 | 75 | -------------------------------------------------------------------------------- /unimol_docking_v2/unimol/data/remove_hydrogen_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 
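# RemoveHydrogenPocketDataset drops "H" atoms (or, when remove_polar_hydrogen is set and remove_hydrogen is not, only the trailing hydrogens) and applies the same mask to both the input and holo coordinates so the atom list and coordinate arrays stay aligned.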
4 | 5 | import numpy as np 6 | from functools import lru_cache 7 | from unicore.data import BaseWrapperDataset 8 | 9 | 10 | class RemoveHydrogenPocketDataset(BaseWrapperDataset): 11 | def __init__( 12 | self, 13 | dataset, 14 | atoms, 15 | coordinates, 16 | holo_coordinates, 17 | remove_hydrogen=True, 18 | remove_polar_hydrogen=False, 19 | ): 20 | self.dataset = dataset 21 | self.atoms = atoms 22 | self.coordinates = coordinates 23 | self.holo_coordinates = holo_coordinates 24 | self.remove_hydrogen = remove_hydrogen 25 | self.remove_polar_hydrogen = remove_polar_hydrogen 26 | self.set_epoch(None) 27 | 28 | def set_epoch(self, epoch, **unused): 29 | super().set_epoch(epoch) 30 | self.epoch = epoch 31 | 32 | @lru_cache(maxsize=16) 33 | def __cached_item__(self, index: int, epoch: int): 34 | dd = self.dataset[index].copy() 35 | atoms = dd[self.atoms] 36 | coordinates = dd[self.coordinates] 37 | holo_coordinates = dd[self.holo_coordinates] 38 | 39 | if self.remove_hydrogen: 40 | mask_hydrogen = atoms != "H" 41 | atoms = atoms[mask_hydrogen] 42 | coordinates = coordinates[mask_hydrogen] 43 | holo_coordinates = holo_coordinates[mask_hydrogen] 44 | if not self.remove_hydrogen and self.remove_polar_hydrogen: 45 | end_idx = 0 46 | for i, atom in enumerate(atoms[::-1]): 47 | if atom != "H": 48 | break 49 | else: 50 | end_idx = i + 1 51 | if end_idx != 0: 52 | atoms = atoms[:-end_idx] 53 | coordinates = coordinates[:-end_idx] 54 | holo_coordinates = holo_coordinates[:-end_idx] 55 | dd[self.atoms] = atoms 56 | dd[self.coordinates] = coordinates.astype(np.float32) 57 | dd[self.holo_coordinates] = holo_coordinates.astype(np.float32) 58 | return dd 59 | 60 | def __getitem__(self, index: int): 61 | return self.__cached_item__(index, self.epoch) 62 | -------------------------------------------------------------------------------- /unimol_docking_v2/unimol/data/tta_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 
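# TTADockingPoseDataset expands each ligand into conf_size test-time-augmentation entries: the dataset length becomes len(dataset) * conf_size, and item i maps to molecule i // conf_size with conformer i % conf_size, so predictions can later be aggregated (e.g. clustered) across input conformations.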
4 | 5 | import numpy as np 6 | from functools import lru_cache 7 | from unicore.data import BaseWrapperDataset 8 | 9 | 10 | class TTADockingPoseDataset(BaseWrapperDataset): 11 | def __init__( 12 | self, 13 | dataset, 14 | atoms, 15 | coordinates, 16 | pocket_atoms, 17 | pocket_coordinates, 18 | holo_coordinates, 19 | holo_pocket_coordinates, 20 | is_train=True, 21 | conf_size=10, 22 | ): 23 | self.dataset = dataset 24 | self.atoms = atoms 25 | self.coordinates = coordinates 26 | self.pocket_atoms = pocket_atoms 27 | self.pocket_coordinates = pocket_coordinates 28 | self.holo_coordinates = holo_coordinates 29 | self.holo_pocket_coordinates = holo_pocket_coordinates 30 | self.is_train = is_train 31 | self.conf_size = conf_size 32 | self.set_epoch(None) 33 | 34 | def set_epoch(self, epoch, **unused): 35 | super().set_epoch(epoch) 36 | self.epoch = epoch 37 | 38 | def __len__(self): 39 | return len(self.dataset) * self.conf_size 40 | 41 | @lru_cache(maxsize=16) 42 | def __cached_item__(self, index: int, epoch: int): 43 | smi_idx = index // self.conf_size 44 | coord_idx = index % self.conf_size 45 | atoms = np.array(self.dataset[smi_idx][self.atoms]) 46 | coordinates = np.array(self.dataset[smi_idx][self.coordinates][coord_idx]) 47 | pocket_atoms = np.array( 48 | [item[0] for item in self.dataset[smi_idx][self.pocket_atoms]] 49 | ) 50 | pocket_coordinates = np.array(self.dataset[smi_idx][self.pocket_coordinates][0]) 51 | if self.is_train: 52 | holo_coordinates = np.array(self.dataset[smi_idx][self.holo_coordinates][0]) 53 | holo_pocket_coordinates = np.array( 54 | self.dataset[smi_idx][self.holo_pocket_coordinates][0] 55 | ) 56 | else: 57 | holo_coordinates = coordinates 58 | holo_pocket_coordinates = pocket_coordinates 59 | 60 | smi = self.dataset[smi_idx]["smi"] 61 | pocket = self.dataset[smi_idx]["pocket"] 62 | # id = self.dataset[smi_idx]['id'] 63 | 64 | return { 65 | "atoms": atoms, 66 | "coordinates": coordinates.astype(np.float32), 67 | "pocket_atoms": pocket_atoms, 68 | "pocket_coordinates": pocket_coordinates.astype(np.float32), 69 | "holo_coordinates": holo_coordinates.astype(np.float32), 70 | "holo_pocket_coordinates": holo_pocket_coordinates.astype(np.float32), 71 | "smi": smi, 72 | "pocket": pocket, 73 | # 'id': id, 74 | } 75 | 76 | def __getitem__(self, index: int): 77 | return self.__cached_item__(index, self.epoch) 78 | -------------------------------------------------------------------------------- /unimol_docking_v2/unimol/infer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # Copyright (c) DP Techonology, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
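# Inference entry point: loads a checkpoint via unicore's checkpoint_utils, builds the task, model and loss, iterates the requested valid subsets batch by batch with task.valid_step(..., test=True), and pickles the per-batch log outputs to <results_path>/<subset>.pkl.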
6 | 7 | import logging 8 | import os 9 | import sys 10 | import pickle 11 | import torch 12 | from unicore import checkpoint_utils, distributed_utils, options, utils 13 | from unicore.logging import progress_bar 14 | from unicore import tasks 15 | 16 | logging.basicConfig( 17 | format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 18 | datefmt="%Y-%m-%d %H:%M:%S", 19 | level=os.environ.get("LOGLEVEL", "INFO").upper(), 20 | stream=sys.stdout, 21 | ) 22 | logger = logging.getLogger("unimol.inference") 23 | 24 | 25 | def main(args): 26 | 27 | assert ( 28 | args.batch_size is not None 29 | ), "Must specify batch size either with --batch-size" 30 | 31 | use_cuda = torch.cuda.is_available() and not args.cpu 32 | use_fp16 = args.fp16 and use_cuda 33 | 34 | if use_cuda: 35 | torch.cuda.set_device(args.device_id) 36 | 37 | if args.distributed_world_size > 1: 38 | data_parallel_world_size = distributed_utils.get_data_parallel_world_size() 39 | data_parallel_rank = distributed_utils.get_data_parallel_rank() 40 | else: 41 | data_parallel_world_size = 1 42 | data_parallel_rank = 0 43 | 44 | # Load model 45 | logger.info("loading model(s) from {}".format(args.path)) 46 | state = checkpoint_utils.load_checkpoint_to_cpu(args.path) 47 | task = tasks.setup_task(args) 48 | model = task.build_model(args) 49 | model.load_state_dict(state["model"], strict=False) 50 | 51 | # Move models to GPU 52 | if use_fp16: 53 | model.half() 54 | if use_cuda: 55 | model.cuda() 56 | 57 | # Print args 58 | logger.info(args) 59 | 60 | # Build loss 61 | loss = task.build_loss(args) 62 | loss.eval() 63 | 64 | for subset in args.valid_subset.split(","): 65 | try: 66 | task.load_dataset(subset, combine=False, epoch=1) 67 | dataset = task.dataset(subset) 68 | except KeyError: 69 | raise Exception("Cannot find dataset: " + subset) 70 | 71 | if not os.path.exists(args.results_path): 72 | os.makedirs(args.results_path) 73 | save_path = os.path.join(args.results_path, subset + ".pkl") 74 | # Initialize data iterator 75 | itr = task.get_batch_iterator( 76 | dataset=dataset, 77 | batch_size=args.batch_size, 78 | ignore_invalid_inputs=True, 79 | required_batch_size_multiple=args.required_batch_size_multiple, 80 | seed=args.seed, 81 | num_shards=data_parallel_world_size, 82 | shard_id=data_parallel_rank, 83 | num_workers=args.num_workers, 84 | data_buffer_size=args.data_buffer_size, 85 | ).next_epoch_itr(shuffle=False) 86 | progress = progress_bar.progress_bar( 87 | itr, 88 | log_format=args.log_format, 89 | log_interval=args.log_interval, 90 | prefix=f"valid on '{subset}' subset", 91 | default_log_format=("tqdm" if not args.no_progress_bar else "simple"), 92 | ) 93 | log_outputs = [] 94 | for i, sample in enumerate(progress): 95 | sample = utils.move_to_cuda(sample) if use_cuda else sample 96 | if len(sample) == 0: 97 | continue 98 | _, _, log_output = task.valid_step(sample, model, loss, test=True) 99 | progress.log({}, step=i) 100 | log_outputs.append(log_output) 101 | pickle.dump(log_outputs, open(save_path, "wb")) 102 | logger.info("Done inference! 
") 103 | return None 104 | 105 | 106 | def cli_main(): 107 | parser = options.get_validation_parser() 108 | options.add_model_args(parser) 109 | args = options.parse_args_and_arch(parser) 110 | 111 | distributed_utils.call_main(args, main) 112 | 113 | 114 | if __name__ == "__main__": 115 | cli_main() 116 | -------------------------------------------------------------------------------- /unimol_docking_v2/unimol/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import importlib 3 | 4 | # automatically import any Python files in the criterions/ directory 5 | for file in sorted(Path(__file__).parent.glob("*.py")): 6 | if not file.name.startswith("_"): 7 | importlib.import_module("unimol.losses." + file.name[:-3]) 8 | -------------------------------------------------------------------------------- /unimol_docking_v2/unimol/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .unimol import UniMolModel 2 | from .transformer_encoder_with_pair import TransformerEncoderWithPair 3 | from .docking_pose_v2 import DockingPoseV2Model -------------------------------------------------------------------------------- /unimol_docking_v2/unimol/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import importlib 3 | 4 | # automatically import any Python files in the criterions/ directory 5 | for file in sorted(Path(__file__).parent.glob("*.py")): 6 | if not file.name.startswith("_"): 7 | importlib.import_module("unimol.tasks." + file.name[:-3]) 8 | -------------------------------------------------------------------------------- /unimol_plus/docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM dptechnology/unicore:latest-pytorch1.12.1-cuda11.6-rdma 2 | 3 | RUN pip install rdkit==2022.09.3 4 | 5 | RUN ldconfig && \ 6 | apt-get clean && \ 7 | apt-get autoremove && \ 8 | rm -rf /var/lib/apt/lists/* /tmp/* && \ 9 | conda clean -ya 10 | -------------------------------------------------------------------------------- /unimol_plus/figure/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/Uni-Mol/90f52c41299a1a582da0f9765e9f87aa21faa16a/unimol_plus/figure/overview.png -------------------------------------------------------------------------------- /unimol_plus/inference.sh: -------------------------------------------------------------------------------- 1 | [ -z "${MASTER_PORT}" ] && MASTER_PORT=10087 2 | [ -z "${MASTER_IP}" ] && MASTER_IP=127.0.0.1 3 | [ -z "${n_gpu}" ] && n_gpu=$(nvidia-smi -L | wc -l) 4 | [ -z "${OMPI_COMM_WORLD_SIZE}" ] && OMPI_COMM_WORLD_SIZE=1 5 | [ -z "${OMPI_COMM_WORLD_RANK}" ] && OMPI_COMM_WORLD_RANK=0 6 | 7 | [ -z "${arch}" ] && arch=unimol_plus_pcq_base 8 | [ -z "${task}" ] && task=pcq 9 | [ -z "${batch_size}" ] && batch_size=128 10 | 11 | mkdir -p $results_path 12 | 13 | torchrun --nproc_per_node=$n_gpu --nnodes=$OMPI_COMM_WORLD_SIZE --node_rank=$OMPI_COMM_WORLD_RANK --master_addr=$MASTER_IP --master_port=$MASTER_PORT \ 14 | ./inference.py --user-dir ./unimol_plus/ $data_path --valid-subset $1 \ 15 | --results-path $results_path \ 16 | --num-workers 8 --ddp-backend=c10d --batch-size $batch_size \ 17 | --task $task --loss unimol_plus --arch $arch \ 18 | --path $weight_path \ 19 | --fp16 --fp16-init-scale 4 --fp16-scale-window 
256 \ 20 | --log-interval 50 --log-format simple --label-prob 0.0 -------------------------------------------------------------------------------- /unimol_plus/scripts/download.sh: -------------------------------------------------------------------------------- 1 | wget http://ogb-data.stanford.edu/data/lsc/pcqm4m-v2-train.sdf.tar.gz 2 | md5sum pcqm4m-v2-train.sdf.tar.gz # fd72bce606e7ddf36c2a832badeec6ab 3 | tar -xf pcqm4m-v2-train.sdf.tar.gz # extracted pcqm4m-v2-train.sdf 4 | 5 | wget 'https://dgl-data.s3-accelerate.amazonaws.com/dataset/OGB-LSC/pcqm4m-v2.zip' 6 | unzip pcqm4m-v2.zip 7 | mv pcqm4m-v2-train.sdf pcqm4m-v2 8 | -------------------------------------------------------------------------------- /unimol_plus/scripts/get_label3d_lmdb.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import os, sys 3 | import pickle 4 | from tqdm import tqdm 5 | from multiprocessing import Pool 6 | import lmdb 7 | from rdkit import Chem 8 | import torch 9 | 10 | split = torch.load("split_dict.pt") 11 | train_index = split["train"] 12 | 13 | 14 | os.system("rm -f label_3D.lmdb") 15 | 16 | env_new = lmdb.open( 17 | "label_3D.lmdb", 18 | subdir=False, 19 | readonly=False, 20 | lock=False, 21 | readahead=False, 22 | meminit=False, 23 | max_readers=1, 24 | map_size=int(100e9), 25 | ) 26 | txn_write = env_new.begin(write=True) 27 | 28 | i = 0 29 | with open("pcqm4m-v2-train.sdf", "r") as input: 30 | cur_content = "" 31 | for line in input: 32 | cur_content += line 33 | if line == "$$$$\n": 34 | ret = gzip.compress(pickle.dumps(cur_content)) 35 | a = txn_write.put(int(train_index[i]).to_bytes(4, byteorder="big"), ret) 36 | i += 1 37 | cur_content = "" 38 | if i % 10000 == 0: 39 | txn_write.commit() 40 | txn_write = env_new.begin(write=True) 41 | print("processed {} molecules".format(i)) 42 | 43 | txn_write.commit() 44 | env_new.close() 45 | -------------------------------------------------------------------------------- /unimol_plus/scripts/make_oc20_test_submission.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import pickle 4 | import os, sys 5 | import glob 6 | 7 | input_folder = sys.argv[1] 8 | 9 | subsets = ["test_id", "test_ood_ads", "test_ood_both", "test_ood_cat"] 10 | 11 | 12 | def flatten(d, index): 13 | res = [] 14 | for x in d: 15 | res.extend(x[index]) 16 | return np.array(res) 17 | 18 | 19 | def one_ckp(folder, subset): 20 | s = f"{folder}/" + subset + "*.pkl" 21 | files = sorted(glob.glob(s)) 22 | data = [] 23 | for file in files: 24 | with open(file, "rb") as f: 25 | try: 26 | data.extend(pickle.load(f)) 27 | except Exception as e: 28 | print("Error in file: ", file) 29 | raise e 30 | 31 | id = flatten(data, 0) 32 | y_pred = flatten(data, 2) 33 | 34 | return np.array(id), np.array(y_pred) 35 | 36 | 37 | submission_file = {} 38 | 39 | for subset in subsets: 40 | id, y_pred = one_ckp(input_folder, subset) 41 | prefix = "_".join(subset.split("_")[1:]) 42 | submission_file[prefix + "_ids"] = id 43 | submission_file[prefix + "_energy"] = y_pred 44 | 45 | np.savez_compressed(sys.argv[2], **submission_file) 46 | -------------------------------------------------------------------------------- /unimol_plus/scripts/make_pcq_test_dev_submission.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import pickle 4 | import os, sys 5 | import glob 6 | import pandas as pd 7 | 8 | input_folder = 
sys.argv[1] 9 | subset = sys.argv[2] 10 | 11 | 12 | split = torch.load("./split_dict.pt") 13 | valid_index = split[subset] 14 | 15 | 16 | def flatten(d, index): 17 | res = [] 18 | for x in d: 19 | res.extend(x[index]) 20 | return np.array(res) 21 | 22 | 23 | def one_ckp(folder, subset): 24 | s = f"{folder}/" + subset + "*.pkl" 25 | files = sorted(glob.glob(s)) 26 | data = [] 27 | for file in files: 28 | with open(file, "rb") as f: 29 | try: 30 | data.extend(pickle.load(f)) 31 | except Exception as e: 32 | print("Error in file: ", file) 33 | raise e 34 | 35 | id = flatten(data, 0) 36 | # index 1 is the predicted position 37 | gap_pred = flatten(data, 2) 38 | 39 | df = pd.DataFrame( 40 | { 41 | "id": id, 42 | "data_index": id, 43 | "pred": gap_pred, 44 | } 45 | ) 46 | df_grouped = df.groupby(["id"]) 47 | df_mean = df_grouped.agg("mean") 48 | return df_mean.sort_values(by="data_index") 49 | 50 | 51 | def save_test_submission(input_dict, dir_path: str, mode: str): 52 | """ 53 | save test submission file at dir_path 54 | """ 55 | assert "y_pred" in input_dict 56 | assert mode in ["test-dev", "test-challenge"] 57 | 58 | y_pred = input_dict["y_pred"] 59 | 60 | if mode == "test-dev": 61 | filename = os.path.join(dir_path, "y_pred_pcqm4m-v2_test-dev") 62 | assert y_pred.shape == (147037,) 63 | elif mode == "test-challenge": 64 | filename = os.path.join(dir_path, "y_pred_pcqm4m-v2_test-challenge") 65 | assert y_pred.shape == (147432,) 66 | 67 | assert isinstance(filename, str) 68 | assert isinstance(y_pred, np.ndarray) or isinstance(y_pred, torch.Tensor) 69 | 70 | if not os.path.exists(dir_path): 71 | os.makedirs(dir_path) 72 | 73 | if isinstance(y_pred, torch.Tensor): 74 | y_pred = y_pred.numpy() 75 | y_pred = y_pred.astype(np.float32) 76 | np.savez_compressed(filename, y_pred=y_pred) 77 | 78 | 79 | df_mean = one_ckp(input_folder, subset) 80 | pred_id = df_mean["data_index"].values 81 | for i in range(len(valid_index)): 82 | assert valid_index[i] == pred_id[i] 83 | pred = df_mean["pred"].values 84 | print(pred.shape) 85 | save_test_submission({"y_pred": pred}, "./", subset) 86 | -------------------------------------------------------------------------------- /unimol_plus/scripts/oc20_preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import lmdb 3 | import pickle 4 | from multiprocessing import Pool 5 | from tqdm import tqdm 6 | import numpy as np 7 | import argparse 8 | from multiprocessing import cpu_count 9 | 10 | nthreads = cpu_count() 11 | 12 | 13 | def inner_read(cursor): 14 | key, value = cursor 15 | data = pickle.loads(value) 16 | 17 | if "y_relaxed" in data: 18 | ret_data = { 19 | "cell": data["cell"].numpy().astype(np.float32), 20 | "pos": data["pos"].numpy().astype(np.float32), 21 | "atomic_numbers": data["atomic_numbers"].numpy().astype(np.int8), 22 | "tags": data["tags"].numpy().astype(np.int8), 23 | "pos_relaxed": data["pos_relaxed"].numpy().astype(np.float32), 24 | "y_relaxed": data["y_relaxed"], 25 | "sid": data["sid"], 26 | } 27 | else: 28 | ret_data = { 29 | "cell": data["cell"].numpy().astype(np.float32), 30 | "pos": data["pos"].numpy().astype(np.float32), 31 | "atomic_numbers": data["atomic_numbers"].numpy().astype(np.int8), 32 | "tags": data["tags"].numpy().astype(np.int8), 33 | "sid": data["sid"], 34 | } 35 | return data["sid"], pickle.dumps(ret_data) 36 | 37 | 38 | if __name__ == "__main__": 39 | parser = argparse.ArgumentParser(description="generate lmdb file") 40 | parser.add_argument("--input-path", type=str, 
help="initial oc20 data file path") 41 | parser.add_argument("--out-path", type=str, help="output path") 42 | parser.add_argument("--split", type=str, help="train/valid/test") 43 | args = parser.parse_args() 44 | 45 | train_list = ["train"] 46 | valid_list = ["val_id", "val_ood_ads", "val_ood_both", "val_ood_cat"] 47 | test_list = ["test_id", "test_ood_ads", "test_ood_both", "test_ood_cat"] 48 | path = args.input_path 49 | out_path = args.out_path 50 | 51 | if args.split == "train": 52 | name_list = train_list 53 | elif args.split == "valid": 54 | name_list = valid_list 55 | elif args.split == "test": 56 | name_list = test_list 57 | 58 | file_list = [os.path.join(path, name, "data.lmdb") for name in name_list] 59 | with Pool(nthreads) as pool: 60 | for filename, outname in zip(file_list, name_list): 61 | i = 0 62 | env = lmdb.open( 63 | filename, 64 | subdir=False, 65 | readonly=True, 66 | lock=False, 67 | readahead=True, 68 | meminit=False, 69 | max_readers=nthreads, 70 | map_size=int(1000e9), 71 | ) 72 | txn = env.begin() 73 | 74 | out_dir = os.path.join(out_path, outname) 75 | if not os.path.exists(out_dir): 76 | os.mkdir(out_dir) 77 | outputfilename = os.path.join(out_dir, "data.lmdb") 78 | try: 79 | os.remove(outputfilename) 80 | except: 81 | pass 82 | 83 | env_new = lmdb.open( 84 | outputfilename, 85 | subdir=False, 86 | readonly=False, 87 | lock=False, 88 | readahead=False, 89 | meminit=False, 90 | max_readers=1, 91 | map_size=int(1000e9), 92 | ) 93 | txn_write = env_new.begin(write=True) 94 | for inner_output in tqdm( 95 | pool.imap(inner_read, txn.cursor()), total=env.stat()["entries"] 96 | ): 97 | txn_write.put(f"{inner_output[0]}".encode("ascii"), inner_output[1]) 98 | i += 1 99 | if i % 1000 == 0: 100 | txn_write.commit() 101 | txn_write = env_new.begin(write=True) 102 | txn_write.commit() 103 | env_new.close() 104 | env.close() 105 | -------------------------------------------------------------------------------- /unimol_plus/setup.py: -------------------------------------------------------------------------------- 1 | """Install script for setuptools.""" 2 | 3 | from setuptools import find_packages 4 | from setuptools import setup 5 | 6 | setup( 7 | name="unimol_plus", 8 | version="1.0.0", 9 | description="", 10 | author="DP Technology", 11 | author_email="unimol@dp.tech", 12 | license="The MIT License", 13 | url="https://github.com/deepmodeling/Uni-Mol", 14 | packages=find_packages( 15 | exclude=["scripts", "tests", "example_data", "docker", "figure"] 16 | ), 17 | install_requires=[ 18 | "numpy", 19 | "pandas", 20 | ], 21 | classifiers=[ 22 | "Development Status :: 5 - Production/Stable", 23 | "Intended Audience :: Science/Research", 24 | "License :: OSI Approved :: Apache Software License", 25 | "Operating System :: POSIX :: Linux", 26 | "Programming Language :: Python :: 3.7", 27 | "Programming Language :: Python :: 3.8", 28 | "Programming Language :: Python :: 3.9", 29 | "Programming Language :: Python :: 3.10", 30 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 31 | ], 32 | ) 33 | -------------------------------------------------------------------------------- /unimol_plus/train_oc20.sh: -------------------------------------------------------------------------------- 1 | [ -z "${MASTER_PORT}" ] && MASTER_PORT=10088 2 | [ -z "${MASTER_IP}" ] && MASTER_IP=127.0.0.1 3 | [ -z "${n_gpu}" ] && n_gpu=$(nvidia-smi -L | wc -l) 4 | [ -z "${OMPI_COMM_WORLD_SIZE}" ] && OMPI_COMM_WORLD_SIZE=1 5 | [ -z "${OMPI_COMM_WORLD_RANK}" ] && OMPI_COMM_WORLD_RANK=0 6 | 7 | [ -z 
"${lr}" ] && lr=2e-4 8 | [ -z "${end_lr}" ] && end_lr=1e-9 9 | [ -z "${warmup_steps}" ] && warmup_steps=150000 10 | [ -z "${total_steps}" ] && total_steps=1500000 11 | [ -z "${update_freq}" ] && update_freq=1 12 | [ -z "${seed}" ] && seed=1 13 | [ -z "${clip_norm}" ] && clip_norm=5 14 | [ -z "${weight_decay}" ] && weight_decay=0.0 15 | [ -z "${pos_loss_weight}" ] && pos_loss_weight=12.0 16 | [ -z "${dist_loss_weight}" ] && dist_loss_weight=5.0 17 | [ -z "${min_pos_loss_weight}" ] && min_pos_loss_weight=1.0 18 | [ -z "${min_dist_loss_weight}" ] && min_dist_loss_weight=1.0 19 | [ -z "${noise}" ] && noise=0.3 20 | [ -z "${label_prob}" ] && label_prob=0.8 21 | [ -z "${mid_prob}" ] && mid_prob=0.1 22 | [ -z "${mid_lower}" ] && mid_lower=0.4 23 | [ -z "${mid_upper}" ] && mid_upper=0.6 24 | [ -z "${ema_decay}" ] && ema_decay=0.999 25 | [ -z "${arch}" ] && arch=unimol_plus_oc20_base 26 | 27 | 28 | export NCCL_ASYNC_ERROR_HANDLING=1 29 | export OMP_NUM_THREADS=1 30 | echo "n_gpu per node" $n_gpu 31 | echo "OMPI_COMM_WORLD_SIZE" $OMPI_COMM_WORLD_SIZE 32 | echo "OMPI_COMM_WORLD_RANK" $OMPI_COMM_WORLD_RANK 33 | echo "MASTER_IP" $MASTER_IP 34 | echo "MASTER_PORT" $MASTER_PORT 35 | echo "data" $1 36 | echo "save_dir" $2 37 | echo "warmup_step" $warmup_step 38 | echo "total_step" $total_step 39 | echo "update_freq" $update_freq 40 | echo "seed" $seed 41 | echo "valid_sets" $valid_sets 42 | 43 | data_path=$1 44 | save_dir=$2 45 | lr=$3 46 | batch_size=$4 47 | 48 | more_args="" 49 | 50 | 51 | 52 | more_args=$more_args" --ema-decay $ema_decay --validate-with-ema" 53 | save_dir=$save_dir"-ema"$ema_decay 54 | 55 | if [ -z "${train_with_valid_data}" ] 56 | then 57 | echo "normal training" 58 | else 59 | echo "training with additional validation data" 60 | more_args=$more_args" --train-with-valid-data" 61 | save_dir=$save_dir"-full" 62 | fi 63 | 64 | 65 | mkdir -p $save_dir 66 | 67 | export NCCL_ASYNC_ERROR_HANDLING=1 68 | export OMP_NUM_THREADS=1 69 | 70 | torchrun --nproc_per_node=$n_gpu --nnodes=$OMPI_COMM_WORLD_SIZE --node_rank=$OMPI_COMM_WORLD_RANK --master_addr=$MASTER_IP --master_port=$MASTER_PORT \ 71 | $(which unicore-train) $data_path --user-dir ./unimol_plus \ 72 | --num-workers 8 --ddp-backend=no_c10d \ 73 | --task oc20 --loss unimol_plus --arch $arch \ 74 | --train-subset train --valid-subset val_id,val_ood_ads,val_ood_cat,val_ood_both --best-checkpoint-metric loss \ 75 | --fp16 --fp16-init-scale 4 --fp16-scale-window 256 --tensorboard-logdir $save_dir/tsb \ 76 | --log-interval 100 --log-format simple \ 77 | --save-interval-updates 10000 --validate-interval-updates 10000 --keep-interval-updates 10 --no-epoch-checkpoints \ 78 | --save-dir $save_dir \ 79 | --batch-size $batch_size \ 80 | --data-buffer-size 32 --fixed-validation-seed 11 --batch-size-valid $((batch_size*4)) \ 81 | --optimizer adam --adam-betas '(0.9, 0.999)' --adam-eps 1e-8 --clip-norm $clip_norm \ 82 | --lr $lr --end-learning-rate $end_lr --lr-scheduler polynomial_decay --power 1 \ 83 | --warmup-updates $warmup_steps --total-num-update $total_steps --max-update $total_steps --update-freq $update_freq \ 84 | --weight-decay $weight_decay \ 85 | --dist-loss-weight $dist_loss_weight --pos-loss-weight $pos_loss_weight \ 86 | --min-dist-loss-weight $min_dist_loss_weight --min-pos-loss-weight $min_pos_loss_weight \ 87 | --label-prob $label_prob --noise-scale $noise \ 88 | --mid-prob $mid_prob --mid-lower $mid_lower --mid-upper $mid_upper --seed $seed $more_args -------------------------------------------------------------------------------- 
/unimol_plus/train_pcq.sh: -------------------------------------------------------------------------------- 1 | [ -z "${MASTER_PORT}" ] && MASTER_PORT=10088 2 | [ -z "${MASTER_IP}" ] && MASTER_IP=127.0.0.1 3 | [ -z "${n_gpu}" ] && n_gpu=$(nvidia-smi -L | wc -l) 4 | [ -z "${OMPI_COMM_WORLD_SIZE}" ] && OMPI_COMM_WORLD_SIZE=1 5 | [ -z "${OMPI_COMM_WORLD_RANK}" ] && OMPI_COMM_WORLD_RANK=0 6 | 7 | [ -z "${lr}" ] && lr=2e-4 8 | [ -z "${end_lr}" ] && end_lr=1e-9 9 | [ -z "${warmup_steps}" ] && warmup_steps=150000 10 | [ -z "${total_steps}" ] && total_steps=1500000 11 | [ -z "${update_freq}" ] && update_freq=1 12 | [ -z "${seed}" ] && seed=1 13 | [ -z "${clip_norm}" ] && clip_norm=5 14 | [ -z "${weight_decay}" ] && weight_decay=0.0 15 | [ -z "${pos_loss_weight}" ] && pos_loss_weight=0.3 16 | [ -z "${dist_loss_weight}" ] && dist_loss_weight=1.5 17 | [ -z "${min_pos_loss_weight}" ] && min_pos_loss_weight=0.06 18 | [ -z "${min_dist_loss_weight}" ] && min_dist_loss_weight=0.3 19 | [ -z "${valid_sets}" ] && valid_sets="valid" 20 | [ -z "${noise}" ] && noise=0.2 21 | [ -z "${label_prob}" ] && label_prob=0.8 22 | [ -z "${mid_prob}" ] && mid_prob=0.1 23 | [ -z "${mid_lower}" ] && mid_lower=0.4 24 | [ -z "${mid_upper}" ] && mid_upper=0.6 25 | [ -z "${ema_decay}" ] && ema_decay=0.999 26 | [ -z "${arch}" ] && arch=unimol_plus_pcq_base 27 | 28 | 29 | export NCCL_ASYNC_ERROR_HANDLING=1 30 | export OMP_NUM_THREADS=1 31 | echo "n_gpu per node" $n_gpu 32 | echo "OMPI_COMM_WORLD_SIZE" $OMPI_COMM_WORLD_SIZE 33 | echo "OMPI_COMM_WORLD_RANK" $OMPI_COMM_WORLD_RANK 34 | echo "MASTER_IP" $MASTER_IP 35 | echo "MASTER_PORT" $MASTER_PORT 36 | echo "data" $1 37 | echo "save_dir" $2 38 | echo "warmup_steps" $warmup_steps 39 | echo "total_steps" $total_steps 40 | echo "update_freq" $update_freq 41 | echo "seed" $seed 42 | echo "valid_sets" $valid_sets 43 | 44 | data_path=$1 45 | save_dir=$2 46 | lr=$3 47 | batch_size=$4 48 | 49 | more_args="" 50 | 51 | 52 | more_args=$more_args" --ema-decay $ema_decay --validate-with-ema" 53 | save_dir=$save_dir"-ema"$ema_decay 54 | 55 | 56 | mkdir -p $save_dir 57 | 58 | export NCCL_ASYNC_ERROR_HANDLING=1 59 | export OMP_NUM_THREADS=1 60 | 61 | torchrun --nproc_per_node=$n_gpu --nnodes=$OMPI_COMM_WORLD_SIZE --node_rank=$OMPI_COMM_WORLD_RANK --master_addr=$MASTER_IP --master_port=$MASTER_PORT \ 62 | $(which unicore-train) $data_path --user-dir ./unimol_plus --train-subset train --valid-subset $valid_sets \ 63 | --num-workers 8 --ddp-backend=c10d \ 64 | --task pcq --loss unimol_plus --arch $arch \ 65 | --fp16 --fp16-init-scale 4 --fp16-scale-window 256 --tensorboard-logdir $save_dir/tsb \ 66 | --log-interval 100 --log-format simple \ 67 | --save-interval-updates 10000 --validate-interval-updates 10000 --keep-interval-updates 50 --no-epoch-checkpoints \ 68 | --save-dir $save_dir \ 69 | --batch-size $batch_size \ 70 | --data-buffer-size 32 --fixed-validation-seed 11 --batch-size-valid 256 \ 71 | --optimizer adam --adam-betas '(0.9, 0.999)' --adam-eps 1e-8 --clip-norm $clip_norm \ 72 | --lr $lr --end-learning-rate $end_lr --lr-scheduler polynomial_decay --power 1 \ 73 | --warmup-updates $warmup_steps --total-num-update $total_steps --max-update $total_steps --update-freq $update_freq \ 74 | --weight-decay $weight_decay \ 75 | --dist-loss-weight $dist_loss_weight --pos-loss-weight $pos_loss_weight \ 76 | --min-dist-loss-weight $min_dist_loss_weight --min-pos-loss-weight $min_pos_loss_weight \ 77 | --label-prob $label_prob --noise-scale $noise \ 78 | --mid-prob $mid_prob --mid-lower $mid_lower
--mid-upper $mid_upper --seed $seed $more_args -------------------------------------------------------------------------------- /unimol_plus/unimol_plus/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import unimol_plus.tasks 3 | import unimol_plus.data 4 | import unimol_plus.models 5 | import unimol_plus.losses -------------------------------------------------------------------------------- /unimol_plus/unimol_plus/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .key_dataset import KeyDataset 2 | from .conformer_sample_dataset import ( 3 | ConformationSampleDataset, 4 | ConformationExpandDataset, 5 | ) 6 | from .lmdb_dataset import LMDBDataset, StackedLMDBDataset 7 | from .pcq_dataset import ( 8 | PCQDataset, 9 | ) 10 | from .oc20_dataset import ( 11 | Is2reDataset, 12 | ) 13 | 14 | from .data_utils import numpy_seed 15 | 16 | __all__ = [] 17 | -------------------------------------------------------------------------------- /unimol_plus/unimol_plus/data/conformer_sample_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | import numpy as np 6 | from functools import lru_cache 7 | from unicore.data import BaseWrapperDataset, data_utils 8 | from copy import deepcopy 9 | from tqdm import tqdm 10 | 11 | 12 | class ConformationSampleDataset(BaseWrapperDataset): 13 | def __init__( 14 | self, 15 | dataset, 16 | seed, 17 | coordinates, 18 | target_coordinates, 19 | ): 20 | self.dataset = dataset 21 | self.seed = seed 22 | self.coordinates = coordinates 23 | self.target_coordinates = target_coordinates 24 | self.set_epoch(None) 25 | 26 | def set_epoch(self, epoch, **unused): 27 | super().set_epoch(epoch) 28 | self.epoch = epoch 29 | 30 | @lru_cache(maxsize=16) 31 | def __cached_item__(self, index: int, epoch: int): 32 | data = deepcopy(self.dataset[index]) 33 | size = len(data[self.coordinates]) 34 | with data_utils.numpy_seed(self.seed, epoch, index): 35 | sample_idx = np.random.randint(size) 36 | coordinates = data[self.coordinates][sample_idx] 37 | if isinstance(data[self.target_coordinates], list): 38 | target_coordinates = data[self.target_coordinates][-1] 39 | else: 40 | target_coordinates = data[self.target_coordinates] 41 | del data[self.coordinates] 42 | del data[self.target_coordinates] 43 | data["coordinates"] = coordinates 44 | data["target_coordinates"] = target_coordinates 45 | return data 46 | 47 | def __getitem__(self, index: int): 48 | return self.__cached_item__(index, self.epoch) 49 | 50 | 51 | class ConformationExpandDataset(BaseWrapperDataset): 52 | def __init__( 53 | self, 54 | dataset, 55 | seed, 56 | coordinates, 57 | target_coordinates, 58 | ): 59 | self.dataset = dataset 60 | self.seed = seed 61 | self.coordinates = coordinates 62 | self.target_coordinates = target_coordinates 63 | self._init_idx() 64 | self.set_epoch(None) 65 | 66 | def set_epoch(self, epoch, **unused): 67 | super().set_epoch(epoch) 68 | self.epoch = epoch 69 | 70 | def _init_idx(self): 71 | self.idx2key = [] 72 | for i in tqdm(range(len(self.dataset))): 73 | size = len(self.dataset[i][self.coordinates]) 74 | self.idx2key.extend([(i, j) for j in range(size)]) 75 | self.cnt = len(self.idx2key) 76 | 77 | def __len__(self): 78 | return self.cnt 79 | 80 | 
@lru_cache(maxsize=16) 81 | def __cached_item__(self, index: int, epoch: int): 82 | key_idx, conf_idx = self.idx2key[index] 83 | data = self.dataset[key_idx] 84 | coordinates = data[self.coordinates][conf_idx] 85 | if isinstance(data[self.target_coordinates], list): 86 | target_coordinates = data[self.target_coordinates][-1] 87 | else: 88 | target_coordinates = data[self.target_coordinates] 89 | 90 | ret_data = deepcopy(data) 91 | del ret_data[self.coordinates] 92 | del ret_data[self.target_coordinates] 93 | ret_data["coordinates"] = coordinates 94 | ret_data["target_coordinates"] = target_coordinates 95 | return ret_data 96 | 97 | def __getitem__(self, index: int): 98 | return self.__cached_item__(index, self.epoch) 99 | -------------------------------------------------------------------------------- /unimol_plus/unimol_plus/data/data_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | import numpy as np 6 | import torch 7 | import contextlib 8 | 9 | 10 | def str_hash(text: str): 11 | hash = 0 12 | for ch in text: 13 | hash = (hash * 281 ^ ord(ch) * 997) & 0xFFFFFFFF 14 | return hash 15 | 16 | 17 | @contextlib.contextmanager 18 | def numpy_seed(seed, *addl_seeds, key=None): 19 | """Context manager which seeds the NumPy PRNG with the specified seed and 20 | restores the state afterward""" 21 | if seed is None: 22 | yield 23 | return 24 | 25 | def check_seed(s): 26 | assert type(s) == int or type(s) == np.int32 or type(s) == np.int64 27 | 28 | check_seed(seed) 29 | if len(addl_seeds) > 0: 30 | for s in addl_seeds: 31 | check_seed(s) 32 | seed = int(hash((seed, *addl_seeds)) % 1e8) 33 | if key is not None: 34 | seed = int(hash((seed, str_hash(key))) % 1e8) 35 | state = np.random.get_state() 36 | np.random.seed(seed) 37 | try: 38 | yield 39 | finally: 40 | np.random.set_state(state) 41 | 42 | 43 | def convert_to_single_emb(x, sizes): 44 | assert x.shape[-1] == len(sizes) 45 | offset = 1 46 | for i in range(len(sizes)): 47 | assert (x[..., i] < sizes[i]).all() 48 | x[..., i] = x[..., i] + offset 49 | offset += sizes[i] 50 | return x 51 | 52 | 53 | def pad_1d(samples, pad_len, pad_value=0): 54 | batch_size = len(samples) 55 | tensor = torch.full([batch_size, pad_len], pad_value, dtype=samples[0].dtype) 56 | for i in range(batch_size): 57 | tensor[i, : samples[i].shape[0]] = samples[i] 58 | return tensor 59 | 60 | 61 | def pad_1d_feat(samples, pad_len, pad_value=0): 62 | batch_size = len(samples) 63 | assert len(samples[0].shape) == 2 64 | feat_size = samples[0].shape[-1] 65 | tensor = torch.full( 66 | [batch_size, pad_len, feat_size], pad_value, dtype=samples[0].dtype 67 | ) 68 | for i in range(batch_size): 69 | tensor[i, : samples[i].shape[0]] = samples[i] 70 | return tensor 71 | 72 | 73 | def pad_2d(samples, pad_len, pad_value=0): 74 | batch_size = len(samples) 75 | tensor = torch.full( 76 | [batch_size, pad_len, pad_len], pad_value, dtype=samples[0].dtype 77 | ) 78 | for i in range(batch_size): 79 | tensor[i, : samples[i].shape[0], : samples[i].shape[1]] = samples[i] 80 | return tensor 81 | 82 | 83 | def pad_2d_feat(samples, pad_len, pad_value=0): 84 | batch_size = len(samples) 85 | assert len(samples[0].shape) == 3 86 | feat_size = samples[0].shape[-1] 87 | tensor = torch.full( 88 | [batch_size, pad_len, pad_len, feat_size], pad_value, dtype=samples[0].dtype 89 | ) 90 | for i in 
range(batch_size): 91 | tensor[i, : samples[i].shape[0], : samples[i].shape[1]] = samples[i] 92 | return tensor 93 | 94 | 95 | def pad_attn_bias(samples, pad_len): 96 | batch_size = len(samples) 97 | pad_len = pad_len + 1 98 | tensor = torch.full( 99 | [batch_size, pad_len, pad_len], float("-inf"), dtype=samples[0].dtype 100 | ) 101 | for i in range(batch_size): 102 | tensor[i, : samples[i].shape[0], : samples[i].shape[1]] = samples[i] 103 | tensor[i, samples[i].shape[0] :, : samples[i].shape[1]] = 0 104 | return tensor 105 | -------------------------------------------------------------------------------- /unimol_plus/unimol_plus/data/key_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | from functools import lru_cache 6 | from unicore.data import BaseWrapperDataset 7 | 8 | 9 | class KeyDataset(BaseWrapperDataset): 10 | def __init__(self, dataset, key): 11 | self.dataset = dataset 12 | self.key = key 13 | self.epoch = None 14 | 15 | def set_epoch(self, epoch, **unused): 16 | super().set_epoch(epoch) 17 | self.epoch = epoch 18 | 19 | def __len__(self): 20 | return len(self.dataset) 21 | 22 | @lru_cache(maxsize=16) 23 | def __cached_item__(self, idx: int, epoch: int): 24 | return self.dataset[idx][self.key] 25 | 26 | def __getitem__(self, idx): 27 | return self.__cached_item__(idx, self.epoch) 28 | -------------------------------------------------------------------------------- /unimol_plus/unimol_plus/data/lmdb_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 
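# Note: the LMDB environment used for reads is opened lazily on the first
# __getitem__ call (connect_db with save_to_self=True) instead of being stored
# in __init__, so LMDBDataset objects stay picklable for multi-worker loading.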
4 | 5 | 6 | import lmdb 7 | import os 8 | import numpy as np 9 | import gzip 10 | import pickle 11 | from functools import lru_cache 12 | import logging 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | class LMDBDataset: 18 | def __init__(self, db_path, key_to_id=True, gzip=True): 19 | self.db_path = db_path 20 | assert os.path.isfile(self.db_path), "{} not found".format(self.db_path) 21 | env = self.connect_db(self.db_path) 22 | with env.begin() as txn: 23 | self._keys = list(txn.cursor().iternext(values=False)) 24 | self.key_to_id = key_to_id 25 | self.gzip = gzip 26 | 27 | def connect_db(self, lmdb_path, save_to_self=False): 28 | env = lmdb.open( 29 | lmdb_path, 30 | subdir=False, 31 | readonly=True, 32 | lock=False, 33 | readahead=False, 34 | meminit=False, 35 | max_readers=256, 36 | ) 37 | if not save_to_self: 38 | return env 39 | else: 40 | self.env = env 41 | 42 | def __len__(self): 43 | return len(self._keys) 44 | 45 | @lru_cache(maxsize=16) 46 | def __getitem__(self, idx): 47 | if not hasattr(self, "env"): 48 | self.connect_db(self.db_path, save_to_self=True) 49 | key = self._keys[idx] 50 | datapoint_pickled = self.env.begin().get(key) 51 | if self.gzip: 52 | datapoint_pickled = gzip.decompress(datapoint_pickled) 53 | data = pickle.loads(datapoint_pickled) 54 | if self.key_to_id: 55 | data["id"] = int.from_bytes(key, "big") 56 | return data 57 | 58 | 59 | class StackedLMDBDataset: 60 | def __init__(self, datasets): 61 | self._len = 0 62 | self.datasets = [] 63 | self.idx_to_file = {} 64 | self.idx_offset = [] 65 | for dataset in datasets: 66 | self.datasets.append(dataset) 67 | for i in range(len(dataset)): 68 | self.idx_to_file[i + self._len] = len(self.datasets) - 1 69 | self.idx_offset.append(self._len) 70 | self._len += len(dataset) 71 | 72 | def __len__(self): 73 | return self._len 74 | 75 | @lru_cache(maxsize=16) 76 | def __getitem__(self, idx): 77 | file_idx = self.idx_to_file[idx] 78 | sub_idx = idx - self.idx_offset[file_idx] 79 | return self.datasets[file_idx][sub_idx] 80 | -------------------------------------------------------------------------------- /unimol_plus/unimol_plus/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import importlib 3 | 4 | # automatically import any Python files in the criterions/ directory 5 | for file in sorted(Path(__file__).parent.glob("*.py")): 6 | if not file.name.startswith("_"): 7 | importlib.import_module("unimol_plus.losses." + file.name[:-3]) 8 | -------------------------------------------------------------------------------- /unimol_plus/unimol_plus/models/__init__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import importlib 3 | 4 | # automatically import any Python files in the criterions/ directory 5 | for file in sorted(Path(__file__).parent.glob("*.py")): 6 | if not file.name.startswith("_"): 7 | importlib.import_module("unimol_plus.models." 
+ file.name[:-3]) -------------------------------------------------------------------------------- /unimol_plus/unimol_plus/models/unimol_plus_encoder.py: -------------------------------------------------------------------------------- 1 | import imp 2 | from typing import Optional, Tuple 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | from unicore.modules import LayerNorm 8 | 9 | from .layers import ( 10 | UnimolPlusEncoderLayer, 11 | Dropout, 12 | ) 13 | 14 | 15 | class UnimolPLusEncoder(nn.Module): 16 | def __init__( 17 | self, 18 | num_encoder_layers: int = 6, 19 | embedding_dim: int = 768, 20 | pair_dim: int = 64, 21 | pair_hidden_dim: int = 32, 22 | ffn_embedding_dim: int = 3072, 23 | num_attention_heads: int = 8, 24 | dropout: float = 0.1, 25 | attention_dropout: float = 0.1, 26 | activation_dropout: float = 0.1, 27 | activation_fn: str = "gelu", 28 | droppath_prob: float = 0.0, 29 | pair_dropout: float = 0.25, 30 | ) -> None: 31 | super().__init__() 32 | self.embedding_dim = embedding_dim 33 | self.num_head = num_attention_heads 34 | self.layer_norm = LayerNorm(embedding_dim) 35 | self.pair_layer_norm = LayerNorm(pair_dim) 36 | self.layers = nn.ModuleList([]) 37 | 38 | if droppath_prob > 0: 39 | droppath_probs = [ 40 | x.item() for x in torch.linspace(0, droppath_prob, num_encoder_layers) 41 | ] 42 | else: 43 | droppath_probs = None 44 | 45 | self.layers.extend( 46 | [ 47 | UnimolPlusEncoderLayer( 48 | embedding_dim=embedding_dim, 49 | pair_dim=pair_dim, 50 | pair_hidden_dim=pair_hidden_dim, 51 | ffn_embedding_dim=ffn_embedding_dim, 52 | num_attention_heads=num_attention_heads, 53 | dropout=dropout, 54 | attention_dropout=attention_dropout, 55 | activation_dropout=activation_dropout, 56 | activation_fn=activation_fn, 57 | droppath_prob=droppath_probs[i] 58 | if droppath_probs is not None 59 | else 0, 60 | pair_dropout=pair_dropout, 61 | ) 62 | for i in range(num_encoder_layers) 63 | ] 64 | ) 65 | 66 | def forward( 67 | self, 68 | x, 69 | pair, 70 | atom_mask, 71 | pair_mask, 72 | attn_mask=None, 73 | ) -> Tuple[torch.Tensor, torch.Tensor]: 74 | x = self.layer_norm(x) 75 | pair = self.pair_layer_norm(pair) 76 | op_mask = atom_mask.unsqueeze(-1) 77 | op_mask = op_mask * (op_mask.size(-2) ** -0.5) 78 | eps = 1e-3 79 | op_norm = 1.0 / (eps + torch.einsum("...bc,...dc->...bdc", op_mask, op_mask)) 80 | for layer in self.layers: 81 | x, pair = layer( 82 | x, 83 | pair, 84 | pair_mask=pair_mask, 85 | self_attn_mask=attn_mask, 86 | op_mask=op_mask, 87 | op_norm=op_norm, 88 | ) 89 | return x, pair 90 | -------------------------------------------------------------------------------- /unimol_plus/unimol_plus/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import importlib 3 | 4 | # automatically import any Python files in the criterions/ directory 5 | for file in sorted(Path(__file__).parent.glob("*.py")): 6 | if not file.name.startswith("_"): 7 | importlib.import_module("unimol_plus.tasks." + file.name[:-3]) 8 | -------------------------------------------------------------------------------- /unimol_plus/unimol_plus/tasks/oc20.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT License. 
3 | # 4 | 5 | import os 6 | import numpy as np 7 | from unicore.data import ( 8 | NestedDictionaryDataset, 9 | EpochShuffleDataset, 10 | ) 11 | from unimol_plus.data import ( 12 | LMDBDataset, 13 | StackedLMDBDataset, 14 | Is2reDataset, 15 | ) 16 | 17 | from unicore.tasks import UnicoreTask, register_task 18 | 19 | 20 | @register_task("oc20") 21 | class MDTask(UnicoreTask): 22 | @staticmethod 23 | def add_args(parser): 24 | parser.add_argument("data", metavar="FILE", help="file prefix for data") 25 | parser.add_argument( 26 | "--train-with-valid-data", 27 | default=False, 28 | action="store_true", 29 | ) 30 | 31 | def __init__(self, args): 32 | super().__init__(args) 33 | self.seed = args.seed 34 | 35 | @classmethod 36 | def setup_task(cls, args, **kwargs): 37 | return cls(args) 38 | 39 | def load_dataset(self, split, combine=False, **kwargs): 40 | assert split in [ 41 | "train", 42 | "val_id", 43 | "val_ood_ads", 44 | "val_ood_cat", 45 | "val_ood_both", 46 | "test_id", 47 | "test_ood_ads", 48 | "test_ood_cat", 49 | "test_ood_both", 50 | "test_sumbit", 51 | ], "invalid split: {}!".format(split) 52 | print(" > Loading {} ...".format(split)) 53 | 54 | if self.args.train_with_valid_data and split == "train": 55 | datasets = [] 56 | for cur_split in [ 57 | "train", 58 | "val_id", 59 | "val_ood_ads", 60 | "val_ood_cat", 61 | "val_ood_both", 62 | ]: 63 | db_path = os.path.join(self.args.data, cur_split, "data.lmdb") 64 | lmdb_dataset = LMDBDataset(db_path, key_to_id=False, gzip=False) 65 | datasets.append(lmdb_dataset) 66 | lmdb_dataset = StackedLMDBDataset(datasets) 67 | else: 68 | db_path = os.path.join(self.args.data, split, "data.lmdb") 69 | lmdb_dataset = LMDBDataset(db_path, key_to_id=False, gzip=False) 70 | 71 | is_train = split == "train" 72 | is2re_dataset = Is2reDataset(lmdb_dataset, self.args, is_train=is_train) 73 | nest_dataset = NestedDictionaryDataset( 74 | { 75 | "batched_data": is2re_dataset, 76 | }, 77 | ) 78 | 79 | if is_train: 80 | nest_dataset = EpochShuffleDataset( 81 | nest_dataset, len(nest_dataset), self.seed 82 | ) 83 | self.datasets[split] = nest_dataset 84 | 85 | print("| Loaded {} with {} samples".format(split, len(nest_dataset))) 86 | 87 | self.datasets[split] = nest_dataset 88 | 89 | def build_model(self, args): 90 | from unicore import models 91 | 92 | model = models.build_model(args, self) 93 | return model 94 | -------------------------------------------------------------------------------- /unimol_plus/unimol_plus/tasks/pcq.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Techonology, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
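# Note: load_dataset below draws one input conformer per item and epoch during
# training (ConformationSampleDataset), while validation/test expands every
# stored conformer into its own item (ConformationExpandDataset).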
5 | 6 | import logging 7 | import os 8 | 9 | import numpy as np 10 | from unicore.data import ( 11 | NestedDictionaryDataset, 12 | EpochShuffleDataset, 13 | ) 14 | from unimol_plus.data import ( 15 | KeyDataset, 16 | LMDBDataset, 17 | ConformationSampleDataset, 18 | ConformationExpandDataset, 19 | PCQDataset, 20 | ) 21 | from unicore.tasks import UnicoreTask, register_task 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | @register_task("pcq") 27 | class PCQTask(UnicoreTask): 28 | """Task for training Uni-Mol+ on the PCQ dataset.""" 29 | 30 | @staticmethod 31 | def add_args(parser): 32 | """Add task-specific arguments to the parser.""" 33 | parser.add_argument("data", help="downstream data path") 34 | 35 | def __init__(self, args): 36 | super().__init__(args) 37 | self.seed = args.seed 38 | 39 | @classmethod 40 | def setup_task(cls, args, **kwargs): 41 | return cls(args) 42 | 43 | def load_dataset(self, split, force_valid=False, **kwargs): 44 | split_path = os.path.join(self.args.data, split + ".lmdb") 45 | dataset = LMDBDataset(split_path) 46 | is_train = (split == "train") and not force_valid 47 | if is_train: 48 | sample_dataset = ConformationSampleDataset( 49 | dataset, 50 | self.seed, 51 | "input_pos", 52 | "label_pos", 53 | ) 54 | else: 55 | sample_dataset = ConformationExpandDataset( 56 | dataset, 57 | self.seed, 58 | "input_pos", 59 | "label_pos", 60 | ) 61 | raw_coord_dataset = KeyDataset(sample_dataset, "coordinates") 62 | tgt_coord_dataset = KeyDataset(sample_dataset, "target_coordinates") 63 | graph_features = PCQDataset( 64 | sample_dataset, 65 | raw_coord_dataset, 66 | tgt_coord_dataset if split in ["train"] else None, 67 | is_train=is_train, 68 | label_prob=self.args.label_prob, 69 | mid_prob=self.args.mid_prob, 70 | mid_lower=self.args.mid_lower, 71 | mid_upper=self.args.mid_upper, 72 | noise=self.args.noise_scale, 73 | seed=self.seed + 2, 74 | ) 75 | 76 | nest_dataset = NestedDictionaryDataset( 77 | { 78 | "batched_data": graph_features, 79 | }, 80 | ) 81 | if is_train: 82 | nest_dataset = EpochShuffleDataset( 83 | nest_dataset, len(nest_dataset), self.seed 84 | ) 85 | print("| Loaded {} with {} samples".format(split, len(nest_dataset))) 86 | self.datasets[split] = nest_dataset 87 | 88 | def build_model(self, args): 89 | from unicore import models 90 | 91 | model = models.build_model(args, self) 92 | return model 93 | -------------------------------------------------------------------------------- /unimol_tools/MANIFEST.in: -------------------------------------------------------------------------------- 1 | include unimol_tools/config/* 2 | include unimol_tools/weights/* 3 | -------------------------------------------------------------------------------- /unimol_tools/README.md: -------------------------------------------------------------------------------- 1 | > [!IMPORTANT] 2 | > To enable more efficient maintenance and facilitate future development, we have migrated unimol-tools out of the Uni-Mol repository. In principle, the unimol_tools folder will be archived at the current stable version (0.1.3.post1) from **June 1, 2025**. 3 | 4 | > Issue tracking and new features will be maintained in [unimol-tools](https://github.com/deepmodeling/unimol_tools). Please submit unimol-tools-related issues, pull requests, or suggestions at https://github.com/deepmodeling/unimol_tools/issues. 5 | 6 | # Uni-Mol tools for various prediction and downstream tasks
7 | 8 | Documentation of Uni-Mol tools is available at https://unimol.readthedocs.io/en/latest/ 9 | 10 | ## Install 11 | - PyTorch is required; please install PyTorch according to your environment. If you are using CUDA, please install PyTorch with CUDA support. More details can be found at https://pytorch.org/get-started/locally/ 12 | - Currently, rdkit requires numpy<2.0.0, so please install rdkit together with numpy<2.0.0. 13 | 14 | ### Option 1: Installing from PyPI (recommended, for the stable version) 15 | 16 | ```bash 17 | pip install unimol_tools --upgrade 18 | ``` 19 | 20 | We recommend installing `huggingface_hub` so that the required Uni-Mol models can be downloaded automatically at runtime. It can be installed by: 21 | 22 | ```bash 23 | pip install huggingface_hub 24 | ``` 25 | 26 | `huggingface_hub` allows you to easily download and manage models from the Hugging Face Hub, which is key for using Uni-Mol models. 27 | 28 | ### Option 2: Installing from source (for the latest version) 29 | 30 | ```bash 31 | ## Clone repository 32 | git clone https://github.com/deepmodeling/Uni-Mol.git 33 | cd Uni-Mol/unimol_tools 34 | 35 | ## Dependencies installation 36 | pip install -r requirements.txt 37 | 38 | ## Install 39 | python setup.py install 40 | ``` 41 | 42 | ### Models on Hugging Face 43 | 44 | The Uni-Mol pretrained models can be found at [dptech/Uni-Mol-Models](https://huggingface.co/dptech/Uni-Mol-Models/tree/main). 45 | 46 | If the download is slow, you can use other mirrors, such as: 47 | 48 | ```bash 49 | export HF_ENDPOINT=https://hf-mirror.com 50 | ``` 51 | 52 | Setting the `HF_ENDPOINT` environment variable specifies the mirror address for the Hugging Face Hub to use when downloading models. 53 | 54 | ### Modify the default directory for weights 55 | 56 | Setting the `UNIMOL_WEIGHT_DIR` environment variable specifies the directory for pre-trained weights if the weights have been downloaded from another source. 57 | 58 | ```bash 59 | export UNIMOL_WEIGHT_DIR=/path/to/your/weights/dir/ 60 | ``` 61 | 62 | ## News 63 | - 2025-03-28: Unimol_tools now supports Distributed Data Parallel (DDP)! 64 | - 2024-11-22: Unimol V2 has been added to Unimol_tools! 65 | - 2024-07-23: User experience improvements: add `UNIMOL_WEIGHT_DIR`. 66 | - 2024-06-25: unimol_tools has been published to PyPI! Hugging Face is used to manage the pretrained models. 67 | - 2024-06-20: unimol_tools v0.1.0 released; the dependency on Uni-Core has been removed. We will publish to PyPI soon.
68 | - 2024-03-20: unimol_tools documentation is available at https://unimol.readthedocs.io/en/latest/ 69 | 70 | ## Molecule property prediction 71 | ```python 72 | from unimol_tools import MolTrain, MolPredict 73 | clf = MolTrain(task='classification', 74 | data_type='molecule', 75 | epochs=10, 76 | batch_size=16, 77 | metrics='auc', 78 | ) 79 | pred = clf.fit(data = data) 80 | # currently supports SMILES-based csv/txt files, and 81 | # custom dicts of {'atoms':[['C','C'],['C','H','O']], 'coordinates':[coordinates_1,coordinates_2]} 82 | 83 | clf = MolPredict(load_model='../exp') 84 | res = clf.predict(data = data) 85 | ``` 86 | ## Uni-Mol molecule and atom level representation 87 | ```python 88 | import numpy as np 89 | from unimol_tools import UniMolRepr 90 | # Uni-Mol representation for a single SMILES 91 | clf = UniMolRepr(data_type='molecule', remove_hs=False) 92 | smiles = 'c1ccc(cc1)C2=NCC(=O)Nc3c2cc(cc3)[N+](=O)[O]' 93 | smiles_list = [smiles] 94 | unimol_repr = clf.get_repr(smiles_list, return_atomic_reprs=True) 95 | # CLS token repr 96 | print(np.array(unimol_repr['cls_repr']).shape) 97 | # atomic-level repr, aligned with rdkit mol.GetAtoms() 98 | print(np.array(unimol_repr['atomic_reprs']).shape) 99 | ``` 100 | 101 | Please kindly cite our papers if you use the data/code/model. 102 | ``` 103 | @inproceedings{ 104 | zhou2023unimol, 105 | title={Uni-Mol: A Universal 3D Molecular Representation Learning Framework}, 106 | author={Gengmo Zhou and Zhifeng Gao and Qiankun Ding and Hang Zheng and Hongteng Xu and Zhewei Wei and Linfeng Zhang and Guolin Ke}, 107 | booktitle={The Eleventh International Conference on Learning Representations}, 108 | year={2023}, 109 | url={https://openreview.net/forum?id=6K2RM6wVqKu} 110 | } 111 | @article{gao2023uni, 112 | title={Uni-qsar: an auto-ml tool for molecular property prediction}, 113 | author={Gao, Zhifeng and Ji, Xiaohong and Zhao, Guojiang and Wang, Hongshuai and Zheng, Hang and Ke, Guolin and Zhang, Linfeng}, 114 | journal={arXiv preprint arXiv:2304.12239}, 115 | year={2023} 116 | } 117 | ``` 118 | 119 | License 120 | ------- 121 | 122 | This project is licensed under the terms of the MIT license. See LICENSE for additional details. 123 | -------------------------------------------------------------------------------- /unimol_tools/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.22.4 2 | pandas==1.4.0 3 | scikit-learn==1.5.0 4 | torch 5 | joblib 6 | rdkit 7 | pyyaml 8 | addict 9 | tqdm -------------------------------------------------------------------------------- /unimol_tools/setup.py: -------------------------------------------------------------------------------- 1 | """Install script for setuptools.""" 2 | 3 | from setuptools import find_packages 4 | from setuptools import setup 5 | 6 | setup( 7 | name="unimol_tools", 8 | version="0.1.3.post1", 9 | description=( 10 | "unimol_tools is a Python package for property prediction with Uni-Mol in molecule, materials and protein."
11 | ), 12 | long_description=open('README.md').read(), 13 | long_description_content_type='text/markdown', 14 | author="DP Technology", 15 | author_email="unimol@dp.tech", 16 | license="The MIT License", 17 | url="https://github.com/deepmodeling/Uni-Mol/unimol_tools", 18 | packages=find_packages( 19 | where='.', 20 | exclude=[ 21 | "build", 22 | "dist", 23 | ], 24 | ), 25 | install_requires=[ 26 | "numpy<2.0.0,>=1.22.4", 27 | "pandas<2.0.0", 28 | "torch", 29 | "joblib", 30 | "rdkit", 31 | "pyyaml", 32 | "addict", 33 | "scikit-learn", 34 | "numba", 35 | "tqdm", 36 | ], 37 | python_requires=">=3.6", 38 | include_package_data=True, 39 | classifiers=[ 40 | "Development Status :: 5 - Production/Stable", 41 | "Intended Audience :: Science/Research", 42 | "License :: OSI Approved :: Apache Software License", 43 | "Operating System :: POSIX :: Linux", 44 | "Programming Language :: Python :: 3.7", 45 | "Programming Language :: Python :: 3.8", 46 | "Programming Language :: Python :: 3.9", 47 | "Programming Language :: Python :: 3.10", 48 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 49 | ], 50 | ) 51 | -------------------------------------------------------------------------------- /unimol_tools/unimol_tools/__init__.py: -------------------------------------------------------------------------------- 1 | from .predict import MolPredict 2 | from .predictor import UniMolRepr 3 | from .train import MolTrain 4 | -------------------------------------------------------------------------------- /unimol_tools/unimol_tools/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .model_config import MODEL_CONFIG, MODEL_CONFIG_V2 2 | -------------------------------------------------------------------------------- /unimol_tools/unimol_tools/config/default.yaml: -------------------------------------------------------------------------------- 1 | ### data 2 | smiles_col: "SMILES" 3 | target_col_prefix: "TARGET" 4 | target_normalize: "auto" 5 | anomaly_clean: False 6 | smi_strict: False 7 | ### model 8 | model_name: "unimolv1" 9 | ### trainer 10 | split_method: "5fold_random" 11 | split_seed: 42 12 | seed: 42 13 | logger_level: 1 14 | patience: 10 15 | max_epochs: 100 16 | learning_rate: 1e-4 17 | warmup_ratio: 0.03 18 | batch_size: 16 19 | max_norm: 5.0 20 | use_cuda: True 21 | use_amp: True 22 | use_ddp: True 23 | use_gpu: 0, 1 -------------------------------------------------------------------------------- /unimol_tools/unimol_tools/config/model_config.py: -------------------------------------------------------------------------------- 1 | MODEL_CONFIG = { 2 | "weight": { 3 | "protein": "poc_pre_220816.pt", 4 | "molecule_no_h": "mol_pre_no_h_220816.pt", 5 | "molecule_all_h": "mol_pre_all_h_220816.pt", 6 | "crystal": "mp_all_h_230313.pt", 7 | "oled": "oled_pre_no_h_230101.pt", 8 | }, 9 | "dict": { 10 | "protein": "poc.dict.txt", 11 | "molecule_no_h": "mol.dict.txt", 12 | "molecule_all_h": "mol.dict.txt", 13 | "crystal": "mp.dict.txt", 14 | "oled": "oled.dict.txt", 15 | }, 16 | } 17 | 18 | MODEL_CONFIG_V2 = { 19 | 'weight': { 20 | '84m': 'modelzoo/84M/checkpoint.pt', 21 | '164m': 'modelzoo/164M/checkpoint.pt', 22 | '310m': 'modelzoo/310M/checkpoint.pt', 23 | '570m': 'modelzoo/570M/checkpoint.pt', 24 | '1.1B': 'modelzoo/1.1B/checkpoint.pt', 25 | }, 26 | } 27 | -------------------------------------------------------------------------------- /unimol_tools/unimol_tools/data/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .datahub import DataHub 2 | from .dictionary import Dictionary 3 | -------------------------------------------------------------------------------- /unimol_tools/unimol_tools/data/dictionary.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import logging 7 | 8 | import numpy as np 9 | 10 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 11 | 12 | 13 | class Dictionary: 14 | """A mapping from symbols to consecutive integers""" 15 | 16 | def __init__( 17 | self, 18 | *, # begin keyword-only arguments 19 | bos="[CLS]", 20 | pad="[PAD]", 21 | eos="[SEP]", 22 | unk="[UNK]", 23 | extra_special_symbols=None, 24 | ): 25 | self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos 26 | self.symbols = [] 27 | self.count = [] 28 | self.indices = {} 29 | self.specials = set() 30 | self.specials.add(bos) 31 | self.specials.add(unk) 32 | self.specials.add(pad) 33 | self.specials.add(eos) 34 | 35 | def __eq__(self, other): 36 | return self.indices == other.indices 37 | 38 | def __getitem__(self, idx): 39 | if idx < len(self.symbols): 40 | return self.symbols[idx] 41 | return self.unk_word 42 | 43 | def __len__(self): 44 | """Returns the number of symbols in the dictionary""" 45 | return len(self.symbols) 46 | 47 | def __contains__(self, sym): 48 | return sym in self.indices 49 | 50 | def vec_index(self, a): 51 | return np.vectorize(self.index)(a) 52 | 53 | def index(self, sym): 54 | """Returns the index of the specified symbol""" 55 | assert isinstance(sym, str) 56 | if sym in self.indices: 57 | return self.indices[sym] 58 | return self.indices[self.unk_word] 59 | 60 | def special_index(self): 61 | return [self.index(x) for x in self.specials] 62 | 63 | def add_symbol(self, word, n=1, overwrite=False, is_special=False): 64 | """Adds a word to the dictionary""" 65 | if is_special: 66 | self.specials.add(word) 67 | if word in self.indices and not overwrite: 68 | idx = self.indices[word] 69 | self.count[idx] = self.count[idx] + n 70 | return idx 71 | else: 72 | idx = len(self.symbols) 73 | self.indices[word] = idx 74 | self.symbols.append(word) 75 | self.count.append(n) 76 | return idx 77 | 78 | def bos(self): 79 | """Helper to get index of beginning-of-sentence symbol""" 80 | return self.index(self.bos_word) 81 | 82 | def pad(self): 83 | """Helper to get index of pad symbol""" 84 | return self.index(self.pad_word) 85 | 86 | def eos(self): 87 | """Helper to get index of end-of-sentence symbol""" 88 | return self.index(self.eos_word) 89 | 90 | def unk(self): 91 | """Helper to get index of unk symbol""" 92 | return self.index(self.unk_word) 93 | 94 | @classmethod 95 | def load(cls, f): 96 | """Loads the dictionary from a text file with the format: 97 | 98 | ``` 99 | <symbol0> <count0> 100 | <symbol1> <count1> 101 | ... 102 | ``` 103 | """ 104 | d = cls() 105 | d.add_from_file(f) 106 | return d 107 | 108 | def add_from_file(self, f): 109 | """ 110 | Loads a pre-existing dictionary from a text file and adds its symbols 111 | to this instance.
112 | """ 113 | if isinstance(f, str): 114 | try: 115 | with open(f, "r", encoding="utf-8") as fd: 116 | self.add_from_file(fd) 117 | except FileNotFoundError as fnfe: 118 | raise fnfe 119 | except UnicodeError: 120 | raise Exception( 121 | "Incorrect encoding detected in {}, please " 122 | "rebuild the dataset".format(f) 123 | ) 124 | return 125 | 126 | lines = f.readlines() 127 | 128 | for line_idx, line in enumerate(lines): 129 | try: 130 | splits = line.rstrip().rsplit(" ", 1) 131 | line = splits[0] 132 | field = splits[1] if len(splits) > 1 else str(len(lines) - line_idx) 133 | if field == "#overwrite": 134 | overwrite = True 135 | line, field = line.rsplit(" ", 1) 136 | else: 137 | overwrite = False 138 | count = int(field) 139 | word = line 140 | if word in self and not overwrite: 141 | logger.info( 142 | "Duplicate word found when loading Dictionary: '{}', index is {}.".format( 143 | word, self.indices[word] 144 | ) 145 | ) 146 | else: 147 | self.add_symbol(word, n=count, overwrite=overwrite) 148 | except ValueError: 149 | raise ValueError( 150 | "Incorrect dictionary format, expected ' [flags]'" 151 | ) 152 | -------------------------------------------------------------------------------- /unimol_tools/unimol_tools/data/split.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | from __future__ import absolute_import, division, print_function 6 | 7 | import numpy as np 8 | from sklearn.model_selection import GroupKFold, KFold, StratifiedKFold 9 | 10 | from ..utils import logger 11 | 12 | 13 | class Splitter(object): 14 | """ 15 | The Splitter class is responsible for splitting a dataset into train and test sets 16 | based on the specified method. 17 | """ 18 | 19 | def __init__(self, method='random', kfold=5, seed=42, **params): 20 | """ 21 | Initializes the Splitter with a specified split method and random seed. 22 | 23 | :param split_method: (str) The method for splitting the dataset, in the format 'Nfold_method'. 24 | Defaults to '5fold_random'. 25 | :param seed: (int) Random seed for reproducibility in random splitting. Defaults to 42. 26 | """ 27 | self.method = method 28 | self.n_splits = kfold 29 | self.seed = seed 30 | self.splitter = self._init_split() 31 | 32 | def _init_split(self): 33 | """ 34 | Initializes the actual splitter object based on the specified method. 35 | 36 | :return: The initialized splitter object. 37 | :raises ValueError: If an unknown splitting method is specified. 38 | """ 39 | if self.n_splits == 1: 40 | return None 41 | if self.method == 'random': 42 | splitter = KFold( 43 | n_splits=self.n_splits, shuffle=True, random_state=self.seed 44 | ) 45 | elif self.method == 'scaffold' or self.method == 'group': 46 | splitter = GroupKFold(n_splits=self.n_splits) 47 | elif self.method == 'stratified': 48 | splitter = StratifiedKFold( 49 | n_splits=self.n_splits, shuffle=True, random_state=self.seed 50 | ) 51 | elif self.method == 'select': 52 | splitter = GroupKFold(n_splits=self.n_splits) 53 | else: 54 | raise ValueError( 55 | 'Unknown splitter method: {}fold - {}'.format( 56 | self.n_splits, self.method 57 | ) 58 | ) 59 | 60 | return splitter 61 | 62 | def split(self, smiles, target=None, group=None, scaffolds=None, **params): 63 | """ 64 | Splits the dataset into train and test sets based on the initialized method. 
65 | 66 | :param data: The dataset to be split. 67 | :param target: (optional) Target labels for stratified splitting. Defaults to None. 68 | :param group: (optional) Group labels for group-based splitting. Defaults to None. 69 | 70 | :return: An iterator yielding train and test set indices for each fold. 71 | :raises ValueError: If the splitter method does not support the provided parameters. 72 | """ 73 | if self.n_splits == 1: 74 | logger.warning( 75 | 'Only one fold is used for training, no splitting is performed.' 76 | ) 77 | return [(np.arange(len(smiles)), ())] 78 | if smiles is None and 'atoms' in params: 79 | smiles = params['atoms'] 80 | logger.warning('Atoms are used as SMILES for splitting.') 81 | if self.method in ['random']: 82 | self.skf = self.splitter.split(smiles) 83 | elif self.method in ['scaffold']: 84 | self.skf = self.splitter.split(smiles, target, scaffolds) 85 | elif self.method in ['group']: 86 | self.skf = self.splitter.split(smiles, target, group) 87 | elif self.method in ['stratified']: 88 | self.skf = self.splitter.split(smiles, group) 89 | elif self.method in ['select']: 90 | unique_groups = np.unique(group) 91 | if len(unique_groups) == self.n_splits: 92 | split_folds = [] 93 | for unique_group in unique_groups: 94 | train_idx = np.where(group != unique_group)[0] 95 | test_idx = np.where(group == unique_group)[0] 96 | split_folds.append((train_idx, test_idx)) 97 | self.split_folds = split_folds 98 | return self.split_folds 99 | else: 100 | logger.error( 101 | 'The number of unique groups is not equal to the number of splits.' 102 | ) 103 | exit(1) 104 | else: 105 | logger.error('Unknown splitter method: {}'.format(self.method)) 106 | exit(1) 107 | self.split_folds = list(self.skf) 108 | return self.split_folds 109 | -------------------------------------------------------------------------------- /unimol_tools/unimol_tools/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .nnmodel import NNModel, UniMolModel, UniMolV2Model 2 | -------------------------------------------------------------------------------- /unimol_tools/unimol_tools/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from .trainer import Trainer 2 | -------------------------------------------------------------------------------- /unimol_tools/unimol_tools/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_logger import logger 2 | from .config_handler import YamlHandler 3 | from .metrics import Metrics 4 | from .util import * 5 | -------------------------------------------------------------------------------- /unimol_tools/unimol_tools/utils/base_logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 
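# Provides a process-wide singleton logger ("Uni-Mol Tools") that writes to the
# console and to ./logs/unimol_tools_<timestamp>.log with daily rotation
# (keeping up to 5 backups).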
4 | 5 | from __future__ import absolute_import, division, print_function 6 | 7 | import datetime 8 | import logging 9 | import os 10 | import sys 11 | import threading 12 | from logging.handlers import TimedRotatingFileHandler 13 | 14 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 15 | 16 | 17 | class PackagePathFilter(logging.Filter): 18 | """A custom logging filter for adding the relative path to the log record.""" 19 | 20 | def filter(self, record): 21 | """add relative path to record""" 22 | pathname = record.pathname 23 | record.relativepath = None 24 | abs_sys_paths = map(os.path.abspath, sys.path) 25 | for path in sorted(abs_sys_paths, key=len, reverse=True): # longer paths first 26 | if not path.endswith(os.sep): 27 | path += os.sep 28 | if pathname.startswith(path): 29 | record.relativepath = os.path.relpath(pathname, path) 30 | break 31 | return True 32 | 33 | 34 | class Logger(object): 35 | """A custom logger class that provides logging functionality to console and file.""" 36 | 37 | _instance = None 38 | _lock = threading.Lock() 39 | 40 | DATE_FORMAT = "%Y-%m-%d %H:%M:%S" 41 | LOG_FORMAT = "%(asctime)s | %(relativepath)s | %(lineno)s | %(levelname)s | %(name)s | %(message)s" 42 | 43 | def __new__(cls, *args, **kwargs): 44 | if not cls._instance: 45 | with cls._lock: 46 | if not cls._instance: 47 | cls._instance = super(Logger, cls).__new__(cls) 48 | return cls._instance 49 | 50 | def __init__(self, logger_name='None'): 51 | """ 52 | :param logger_name: (str) The name of the logger (default: 'None') 53 | """ 54 | self.logger = logging.getLogger(logger_name) 55 | logging.root.setLevel(logging.NOTSET) 56 | self.log_file_name = 'unimol_tools_{0}.log'.format( 57 | datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") 58 | ) 59 | 60 | cwd_path = os.path.abspath(os.getcwd()) 61 | self.log_path = os.path.join(cwd_path, "logs") 62 | 63 | if not os.path.exists(self.log_path): 64 | os.makedirs(self.log_path) 65 | self.backup_count = 5 66 | 67 | self.console_output_level = 'INFO' 68 | self.file_output_level = 'INFO' 69 | 70 | self.formatter = logging.Formatter(self.LOG_FORMAT, self.DATE_FORMAT) 71 | 72 | def get_logger(self): 73 | """ 74 | Get the logger object. 75 | 76 | :return: logging.Logger - a logger object. 
77 | 78 | """ 79 | if not self.logger.handlers: 80 | console_handler = logging.StreamHandler() 81 | console_handler.setFormatter(self.formatter) 82 | console_handler.setLevel(self.console_output_level) 83 | console_handler.addFilter(PackagePathFilter()) 84 | self.logger.addHandler(console_handler) 85 | 86 | file_handler = TimedRotatingFileHandler( 87 | filename=os.path.join(self.log_path, self.log_file_name), 88 | when='D', 89 | interval=1, 90 | backupCount=self.backup_count, 91 | delay=True, 92 | encoding='utf-8', 93 | ) 94 | file_handler.setFormatter(self.formatter) 95 | file_handler.setLevel(self.file_output_level) 96 | self.logger.addHandler(file_handler) 97 | return self.logger 98 | 99 | 100 | # add highlight formatter to logger 101 | class HighlightFormatter(logging.Formatter): 102 | def format(self, record): 103 | if record.levelno == logging.WARNING: 104 | record.msg = "\033[93m{}\033[0m".format(record.msg) # 黄色高亮 105 | return super().format(record) 106 | 107 | 108 | logger = Logger('Uni-Mol Tools').get_logger() 109 | logger.setLevel(logging.INFO) 110 | 111 | # highlight warning messages in console 112 | for handler in logger.handlers: 113 | if isinstance(handler, logging.StreamHandler): 114 | handler.setFormatter(HighlightFormatter(Logger.LOG_FORMAT, Logger.DATE_FORMAT)) 115 | -------------------------------------------------------------------------------- /unimol_tools/unimol_tools/utils/config_handler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | from __future__ import absolute_import, division, print_function 6 | 7 | import os 8 | 9 | import yaml 10 | from addict import Dict 11 | 12 | from .base_logger import logger 13 | 14 | 15 | class YamlHandler: 16 | '''A clss to read and write the yaml file''' 17 | 18 | def __init__(self, file_path): 19 | """ 20 | A custom logger class that provides logging functionality to console and file. 21 | 22 | :param file_path: (str) The yaml file path of the config. 
23 | """ 24 | if not os.path.exists(file_path): 25 | raise FileExistsError(OSError) 26 | self.file_path = file_path 27 | 28 | def read_yaml(self, encoding='utf-8'): 29 | """read yaml file and convert to easydict 30 | 31 | :param encoding: (str) encoding method uses utf-8 by default 32 | :return: Dict (addict), the usage of Dict is the same as dict 33 | """ 34 | with open(self.file_path, encoding=encoding) as f: 35 | return Dict(yaml.load(f.read(), Loader=yaml.FullLoader)) 36 | 37 | def write_yaml(self, data, out_file_path, encoding='utf-8'): 38 | """write dict or easydict to yaml file(auto write to self.file_path) 39 | 40 | :param data: (dict or Dict(addict)) dict containing the contents of the yaml file 41 | """ 42 | with open(out_file_path, encoding=encoding, mode='w') as f: 43 | return yaml.dump( 44 | addict2dict(data) if isinstance(data, Dict) else data, 45 | stream=f, 46 | allow_unicode=True, 47 | ) 48 | 49 | 50 | def addict2dict(addict_obj): 51 | '''convert addict to dict 52 | 53 | :param addict_obj: (Dict(addict)) the addict obj that you want to convert to dict 54 | 55 | :return: (Dict) converted result 56 | ''' 57 | dict_obj = {} 58 | for key, vals in addict_obj.items(): 59 | dict_obj[key] = addict2dict(vals) if isinstance(vals, Dict) else vals 60 | return dict_obj 61 | 62 | 63 | if __name__ == '__main__': 64 | yaml_handler = YamlHandler('../config/default.yaml') 65 | config = yaml_handler.read_yaml() 66 | print(config.Modelhub) 67 | -------------------------------------------------------------------------------- /unimol_tools/unimol_tools/utils/util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) DP Technology. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | from hashlib import md5 6 | 7 | 8 | def pad_1d_tokens( 9 | values, 10 | pad_idx, 11 | left_pad=False, 12 | pad_to_length=None, 13 | pad_to_multiple=1, 14 | ): 15 | """ 16 | padding one dimension tokens inputs. 17 | 18 | :param values: A list of 1d tensors. 19 | :param pad_idx: The padding index. 20 | :param left_pad: Whether to left pad the tensors. Defaults to False. 21 | :param pad_to_length: The desired length of the padded tensors. Defaults to None. 22 | :param pad_to_multiple: The multiple to pad the tensors to. Defaults to 1. 23 | 24 | :return: A padded 1d tensor as a torch.Tensor. 25 | 26 | """ 27 | size = max(v.size(0) for v in values) 28 | size = size if pad_to_length is None else max(size, pad_to_length) 29 | if pad_to_multiple != 1 and size % pad_to_multiple != 0: 30 | size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple) 31 | res = values[0].new(len(values), size).fill_(pad_idx) 32 | 33 | def copy_tensor(src, dst): 34 | assert dst.numel() == src.numel() 35 | dst.copy_(src) 36 | 37 | for i, v in enumerate(values): 38 | copy_tensor(v, res[i][size - len(v) :] if left_pad else res[i][: len(v)]) 39 | return res 40 | 41 | 42 | def pad_2d( 43 | values, 44 | pad_idx, 45 | dim=1, 46 | left_pad=False, 47 | pad_to_length=None, 48 | pad_to_multiple=1, 49 | ): 50 | """ 51 | padding two dimension tensor inputs. 52 | 53 | :param values: A list of 2d tensors. 54 | :param pad_idx: The padding index. 55 | :param left_pad: Whether to pad on the left side. Defaults to False. 56 | :param pad_to_length: The length to pad the tensors to. If None, the maximum length in the list 57 | is used. Defaults to None. 58 | :param pad_to_multiple: The multiple to pad the tensors to. 
Defaults to 1. 59 | 60 | :return: A padded 2d tensor as a torch.Tensor. 61 | """ 62 | size = max(v.size(0) for v in values) 63 | size = size if pad_to_length is None else max(size, pad_to_length) 64 | if pad_to_multiple != 1 and size % pad_to_multiple != 0: 65 | size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple) 66 | if dim == 1: 67 | res = values[0].new(len(values), size, size).fill_(pad_idx) 68 | else: 69 | res = values[0].new(len(values), size, size, dim).fill_(pad_idx) 70 | 71 | def copy_tensor(src, dst): 72 | assert dst.numel() == src.numel() 73 | dst.copy_(src) 74 | 75 | for i, v in enumerate(values): 76 | copy_tensor( 77 | v, 78 | ( 79 | res[i][size - len(v) :, size - len(v) :] 80 | if left_pad 81 | else res[i][: len(v), : len(v)] 82 | ), 83 | ) 84 | return res 85 | 86 | 87 | def pad_coords( 88 | values, 89 | pad_idx, 90 | dim=3, 91 | left_pad=False, 92 | pad_to_length=None, 93 | pad_to_multiple=1, 94 | ): 95 | """ 96 | padding two dimension tensor coords which the third dimension is 3. 97 | 98 | :param values: A list of 1d tensors. 99 | :param pad_idx: The value used for padding. 100 | :param left_pad: Whether to pad on the left side. Defaults to False. 101 | :param pad_to_length: The desired length of the padded tensor. Defaults to None. 102 | :param pad_to_multiple: The multiple to pad the tensor to. Defaults to 1. 103 | 104 | :return: A padded 2d coordinate tensor as a torch.Tensor. 105 | """ 106 | size = max(v.size(0) for v in values) 107 | size = size if pad_to_length is None else max(size, pad_to_length) 108 | if pad_to_multiple != 1 and size % pad_to_multiple != 0: 109 | size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple) 110 | res = values[0].new(len(values), size, dim).fill_(pad_idx) 111 | 112 | def copy_tensor(src, dst): 113 | assert dst.numel() == src.numel() 114 | dst.copy_(src) 115 | 116 | for i, v in enumerate(values): 117 | copy_tensor(v, res[i][size - len(v) :, :] if left_pad else res[i][: len(v), :]) 118 | return res 119 | -------------------------------------------------------------------------------- /unimol_tools/unimol_tools/weights/__init__.py: -------------------------------------------------------------------------------- 1 | from .weighthub import WEIGHT_DIR, weight_download, weight_download_v2 2 | -------------------------------------------------------------------------------- /unimol_tools/unimol_tools/weights/weighthub.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from ..utils import logger 4 | 5 | try: 6 | from huggingface_hub import snapshot_download 7 | except: 8 | huggingface_hub_installed = False 9 | 10 | def snapshot_download(*args, **kwargs): 11 | raise ImportError( 12 | 'huggingface_hub is not installed. If weights are not avaliable, please install it by running: pip install huggingface_hub. Otherwise, please download the weights manually from https://huggingface.co/dptech/Uni-Mol-Models' 13 | ) 14 | 15 | 16 | WEIGHT_DIR = os.environ.get( 17 | 'UNIMOL_WEIGHT_DIR', os.path.dirname(os.path.abspath(__file__)) 18 | ) 19 | 20 | os.environ["HF_ENDPOINT"] = "https://hf-mirror.com" # use mirror to download weights 21 | 22 | 23 | def log_weights_dir(): 24 | """ 25 | Logs the directory where the weights are stored. 
26 | """ 27 | if 'UNIMOL_WEIGHT_DIR' in os.environ: 28 | logger.warning( 29 | f'Using custom weight directory from UNIMOL_WEIGHT_DIR: {WEIGHT_DIR}' 30 | ) 31 | else: 32 | logger.info(f'Weights will be downloaded to default directory: {WEIGHT_DIR}') 33 | 34 | 35 | def weight_download(pretrain, save_path, local_dir_use_symlinks=True): 36 | """ 37 | Downloads the specified pretrained model weights. 38 | 39 | :param pretrain: (str), The name of the pretrained model to download. 40 | :param save_path: (str), The directory where the weights should be saved. 41 | :param local_dir_use_symlinks: (bool, optional), Whether to use symlinks for the local directory. Defaults to True. 42 | """ 43 | log_weights_dir() 44 | 45 | if os.path.exists(os.path.join(save_path, pretrain)): 46 | logger.info(f'{pretrain} exists in {save_path}') 47 | return 48 | 49 | logger.info(f'Downloading {pretrain}') 50 | snapshot_download( 51 | repo_id="dptech/Uni-Mol-Models", 52 | local_dir=save_path, 53 | allow_patterns=pretrain, 54 | local_dir_use_symlinks=local_dir_use_symlinks, 55 | # max_workers=8 56 | ) 57 | 58 | 59 | def weight_download_v2(pretrain, save_path, local_dir_use_symlinks=True): 60 | """ 61 | Downloads the specified pretrained model weights. 62 | 63 | :param pretrain: (str), The name of the pretrained model to download. 64 | :param save_path: (str), The directory where the weights should be saved. 65 | :param local_dir_use_symlinks: (bool, optional), Whether to use symlinks for the local directory. Defaults to True. 66 | """ 67 | log_weights_dir() 68 | 69 | if os.path.exists(os.path.join(save_path, pretrain)): 70 | logger.info(f'{pretrain} exists in {save_path}') 71 | return 72 | 73 | logger.info(f'Downloading {pretrain}') 74 | snapshot_download( 75 | repo_id="dptech/Uni-Mol2", 76 | local_dir=save_path, 77 | allow_patterns=pretrain, 78 | local_dir_use_symlinks=local_dir_use_symlinks, 79 | # max_workers=8 80 | ) 81 | 82 | 83 | # Download all the weights when this script is run 84 | def download_all_weights(local_dir_use_symlinks=False): 85 | """ 86 | Downloads all available pretrained model weights to the WEIGHT_DIR. 87 | 88 | :param local_dir_use_symlinks: (bool, optional), Whether to use symlinks for the local directory. Defaults to False. 89 | """ 90 | log_weights_dir() 91 | 92 | logger.info(f'Downloading all weights to {WEIGHT_DIR}') 93 | snapshot_download( 94 | repo_id="dptech/Uni-Mol-Models", 95 | local_dir=WEIGHT_DIR, 96 | allow_patterns='*', 97 | local_dir_use_symlinks=local_dir_use_symlinks, 98 | # max_workers=8 99 | ) 100 | 101 | 102 | if '__main__' == __name__: 103 | download_all_weights() 104 | --------------------------------------------------------------------------------