├── .dockerignore ├── .gitattributes ├── .gitignore ├── .pre-commit-config.yaml ├── Dockerfile ├── LICENSE-CODE ├── README.md ├── examples ├── generations │ ├── generate.sh │ └── prompts.csv └── protein_gym_zero_shot │ ├── csv_to_fasta.py │ ├── download_and_prepare.sh │ ├── run.sh │ └── score.py ├── mypy.ini ├── pixi.lock ├── pyproject.toml ├── setup.sh └── src └── progen3 ├── __init__.py ├── batch_preparer.py ├── common ├── __init__.py ├── dist.py └── model_loading.py ├── config.py ├── generator.py ├── model ├── __init__.py ├── attention.py ├── mb_wrapper.py └── moe.py ├── modeling.py ├── scorer.py ├── tokenizer.json ├── tokenizer.py └── tools ├── .gitignore ├── __init__.py ├── generate.py ├── score.py └── utils.py /.dockerignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | .git 3 | **/.pixi 4 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # GitHub syntax highlighting 2 | pixi.lock linguist-language=YAML linguist-generated=true 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | local/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | pip-wheel-metadata/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # pipenv 89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 92 | # install all needed dependencies. 93 | #Pipfile.lock 94 | 95 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 96 | __pypackages__/ 97 | 98 | # Celery stuff 99 | celerybeat-schedule 100 | celerybeat.pid 101 | 102 | # SageMath parsed files 103 | *.sage.py 104 | 105 | # Environments 106 | .env 107 | .venv 108 | env/ 109 | venv/ 110 | ENV/ 111 | env.bak/ 112 | venv.bak/ 113 | .idea 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | 133 | # Other 134 | .DS_Store 135 | checkpoints/ 136 | *.egg-info/ 137 | **/data/ 138 | **/outputs/ 139 | 140 | # VSCode 141 | .vscode/ 142 | 143 | # pixi environments 144 | .pixi 145 | *.egg-info 146 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3 3 | 4 | repos: 5 | - repo: https://github.com/pre-commit/pre-commit-hooks 6 | rev: v5.0.0 7 | hooks: 8 | # list of supported hooks: https://pre-commit.com/hooks.html 9 | - id: trailing-whitespace 10 | - id: end-of-file-fixer 11 | - id: check-docstring-first 12 | - id: check-yaml 13 | - id: debug-statements 14 | - id: detect-private-key 15 | - id: check-executables-have-shebangs 16 | - id: check-toml 17 | - id: check-case-conflict 18 | - id: check-added-large-files 19 | 20 | # python code formatting 21 | - repo: https://github.com/psf/black 22 | rev: 23.1.0 23 | hooks: 24 | - id: black 25 | args: [--line-length, "120"] 26 | 27 | # python import sorting 28 | - repo: https://github.com/PyCQA/isort 29 | rev: 5.12.0 30 | hooks: 31 | - id: isort 32 | args: 33 | [ 34 | "--filter-files", 35 | "--skip", 36 | "__init__.py", 37 | "--profile", 38 | "black", 39 | "-m", 40 | "3", 41 | ] 42 | 43 | # python upgrading syntax to newer version 44 | - repo: https://github.com/asottile/pyupgrade 45 | rev: v3.3.1 46 | hooks: 47 | - id: pyupgrade 48 | args: [--py38-plus] 49 | 50 | # python check (PEP8), programming errors and code complexity 51 | - repo: https://github.com/PyCQA/flake8 52 | rev: 7.1.1 53 | hooks: 54 | - id: flake8 55 | args: 56 | [ 57 | "--max-line-length", 58 | "120", 59 | "--extend-ignore", 60 | "E203,E402,E501,F401,F841,W503,E741,E231,F403", 61 | "--exclude", 62 | "logs/*,data/*", 63 | ] 64 | 65 | # python security linter 66 | - repo: https://github.com/PyCQA/bandit 67 | rev: "1.7.5" 68 | hooks: 69 | - id: bandit 70 | args: ["-s", "B101", "--exclude", "tests"] 71 | 72 | # yaml formatting 73 | - repo: https://github.com/pre-commit/mirrors-prettier 74 | rev: v3.0.0-alpha.6 75 | hooks: 76 | - id: prettier 77 | types: [yaml] 78 | exclude: "environment.yaml" 79 | 80 | # shell scripts linter 81 | - repo: https://github.com/shellcheck-py/shellcheck-py 82 | rev: v0.9.0.2 83 | hooks: 84 | - id: shellcheck 85 | 86 | # md formatting 87 | - repo: https://github.com/executablebooks/mdformat 88 | rev: 0.7.17 89 | hooks: 90 | - id: mdformat 91 | args: ["--number"] 92 | additional_dependencies: 93 | - mdformat-gfm 94 | - mdformat-tables 95 | - mdformat_frontmatter 96 | # - mdformat-toc 97 | # - mdformat-black 98 | 99 | # jupyter notebook cell output clearing 100 | - repo: https://github.com/kynan/nbstripout 101 | rev: 0.6.1 102 | hooks: 103 | - id: nbstripout 104 | 105 | # jupyter notebook linting 106 | - repo: https://github.com/nbQA-dev/nbQA 107 | rev: 1.6.3 108 | hooks: 109 | - 
id: nbqa-black 110 | args: ["--line-length=120"] 111 | additional_dependencies: ["setuptools", "black"] 112 | - id: nbqa-isort 113 | args: ["--profile=black"] 114 | additional_dependencies: ["setuptools", "isort"] 115 | - id: nbqa-flake8 116 | args: 117 | [ 118 | "--extend-ignore=E203,E402,E501,F401,F841,W503", 119 | "--exclude=logs/*,data/*", 120 | ] 121 | additional_dependencies: ["setuptools", "flake8"] 122 | 123 | # python type checking 124 | - repo: https://github.com/pre-commit/mirrors-mypy 125 | rev: v1.14.1 126 | hooks: 127 | - id: mypy 128 | exclude: 'test_.*\.py$' 129 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM --platform=linux/x86_64 mosaicml/pytorch:2.5.1_cu124-python3.11-ubuntu22.04 2 | 3 | SHELL ["/bin/bash", "-c"] 4 | WORKDIR /root 5 | 6 | # Install base utilities 7 | RUN apt-get update && \ 8 | apt-get install -y \ 9 | build-essential \ 10 | wget \ 11 | curl \ 12 | git \ 13 | vim \ 14 | libxml2 \ 15 | apt-transport-https \ 16 | ca-certificates \ 17 | gnupg \ 18 | unzip && \ 19 | apt-get -y clean && \ 20 | apt-get -y autoremove && \ 21 | rm -rf /var/lib/apt/lists/* 22 | 23 | # Install pixi 24 | RUN curl -fsSL https://pixi.sh/install.sh | bash 25 | ENV PATH="/root/.pixi/bin:${PATH}" 26 | ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.7 8.9 9.0" 27 | RUN pip install "megablocks[gg]==0.7.0" 28 | RUN MAX_JOBS=4 pip install flash-attn --no-build-isolation 29 | -------------------------------------------------------------------------------- /LICENSE-CODE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 
40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright 2025 Profluent Bio Inc. 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ProGen3 2 | 3 | This repository contains code to access the ProGen3 family of models as well as cli interface to run these models for scoring and generation. The model weights are released under a **non-commercial** usage license. Please see the [LICENSE](#license) section below. 4 | 5 | ## Models Available 6 | 7 | More info on ProGen3 Models can be found [here](https://profluent.bio/showcase/progen3). 
In the subsequent sections, you can provide a model name from the following table where needed.
8 |
9 | | Model Name | Model Parameters | HuggingFace Link |
10 | | -------------------------- | ---------------- | ------------------------------------------------------------------------------- |
11 | | Profluent-Bio/progen3-112m | 112M | [Profluent-Bio/progen3-112m](https://huggingface.co/Profluent-Bio/progen3-112m) |
12 | | Profluent-Bio/progen3-219m | 219M | [Profluent-Bio/progen3-219m](https://huggingface.co/Profluent-Bio/progen3-219m) |
13 | | Profluent-Bio/progen3-339m | 339M | [Profluent-Bio/progen3-339m](https://huggingface.co/Profluent-Bio/progen3-339m) |
14 | | Profluent-Bio/progen3-762m | 762M | [Profluent-Bio/progen3-762m](https://huggingface.co/Profluent-Bio/progen3-762m) |
15 | | Profluent-Bio/progen3-1b | 1B | [Profluent-Bio/progen3-1b](https://huggingface.co/Profluent-Bio/progen3-1b) |
16 | | Profluent-Bio/progen3-3b | 3B | [Profluent-Bio/progen3-3b](https://huggingface.co/Profluent-Bio/progen3-3b) |
17 |
18 | ## Installation
19 |
20 | ### Local Usage
21 |
22 | The ProGen3 family of models requires at least one GPU (we have tested only on A100/H100 GPUs with at least 40GB of VRAM; GPUs from prior generations may not work because they lack support for bf16 precision and flash attention kernels).
23 |
24 | 1. Clone this repo and run `bash setup.sh`
25 |
26 | ### Docker Usage
27 |
28 | We also provide a Docker image, `ghcr.io/profluent-ai/progen3:v0.1.0`, that comes with certain libraries pre-installed to reduce installation time. This image still needs to be run on a machine with a GPU.
29 |
30 | Within the container, follow the steps in the section [Local Usage](#local-usage).
31 |
32 | ## Interactive Usage
33 |
34 | ```python
35 | import torch
36 |
37 | from progen3.modeling import ProGen3ForCausalLM
38 | from progen3.batch_preparer import ProGen3BatchPreparer
39 | from progen3.scorer import ProGen3Scorer
40 |
41 | model = ProGen3ForCausalLM.from_pretrained("Profluent-Bio/progen3-3b", torch_dtype=torch.bfloat16)
42 | model = model.eval().to("cuda:0")
43 | batch_preparer = ProGen3BatchPreparer()
44 |
45 | # Direct Usage
46 | sequence = "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN"
47 |
48 | inputs = batch_preparer.get_batch_kwargs([sequence], device="cuda:0", reverse=False)
49 | outputs = model(**inputs, return_dict=True)
50 | print(outputs.logits)
51 |
52 | # Usage with the scorer (returns the averaged log likelihood of the forward and reverse directions)
53 | # We suggest using the Scoring CLI below when scoring a very large number of sequences
54 | scorer = ProGen3Scorer(model=model)
55 | scores = scorer.score_batch(sequences=[sequence])
56 | print(scores["log_likelihood"][0])
57 | ```
58 |
59 | ## Scoring CLI
60 |
61 | Prepare the fasta file (described in the next section) containing sequences to be scored. Then, run the following command (modifying the arguments as necessary):
62 |
63 | ```bash
64 | torchrun --nproc-per-node=gpu -m progen3.tools.score \
65 | --fasta-path sequences.fasta \
66 | --output-path scores.csv \
67 | --model-name Profluent-Bio/progen3-3b \
68 | --fsdp
69 | ```
70 |
71 | See the example in [examples/protein_gym_zero_shot](examples/protein_gym_zero_shot) for how we use this to evaluate Spearman correlation on ProteinGym assays.
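If your sequences live in Python rather than in a FASTA file already, a minimal sketch along these lines (the IDs, sequences, and file name are illustrative) writes an input that `progen3.tools.score` accepts:

```python
# Minimal sketch: write a FASTA input for the scoring CLI.
# IDs must be unique; sequences are plain residue strings (at most 8190 residues each).
sequences = {
    "seq_id_1": "AASNNMETYR",
    "seq_id_2": "MKKLPSDDEFGHKKL",
}

with open("sequences.fasta", "w") as handle:
    for seq_id, seq in sequences.items():
        handle.write(f">{seq_id}\n{seq}\n")
```

The ProteinGym example below performs this kind of CSV-to-FASTA conversion for you (see `csv_to_fasta.py`) before scoring: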
72 | 73 | ```bash 74 | bash examples/protein_gym_zero_shot/download_and_prepare.sh 75 | 76 | # Run for individual assay 77 | bash examples/protein_gym_zero_shot/run.sh A0A140D2T1_ZIKV_Sourisseau_2019 Profluent-Bio/progen3-3b 78 | 79 | # Run for all assays and print aggregated spearman 80 | bash examples/protein_gym_zero_shot/run.sh all Profluent-Bio/progen3-3b 81 | ``` 82 | 83 | ### Input File Formats 84 | 85 | - `fasta-path` : A path to fasta file 86 | 87 | Each fasta file should be of form: 88 | 89 | ``` 90 | >seq_id_1 91 | AASNNMETYR 92 | >seq_id_2 93 | MKKLPSDDEFGHKKL 94 | ``` 95 | 96 | where each sequence has a unique id and maximum length of 8190 residues. Once the scoring is complete, the output csv file will have following columns: 97 | 98 | - `sequence_id` : The sequence id corresponding to each sequence in the input fasta file 99 | - `log_likelihood` : The log likelihood of the sequence. 100 | - `perplexity` : The perplexity of the sequence. 101 | 102 | ## Generation CLI 103 | 104 | Prepare the prompt file (described in the next section). Then run the following command (modifying the arguments as necessary): 105 | 106 | ```bash 107 | mkdir -p generation_outputs/ 108 | 109 | torchrun --nproc-per-node=gpu -m progen3.tools.generate \ 110 | --prompt-file examples/generations/prompts.csv \ 111 | --output-dir generation_outputs/ \ 112 | --model-name Profluent-Bio/progen3-3b \ 113 | --n-per-prompt 5000 \ 114 | --fsdp \ 115 | --temperature 0.85 \ 116 | --top-p 0.95 117 | ``` 118 | 119 | You will get two files per prompt in the generation_outputs directory after completion. First is `{id}.gen.fasta` which contains the raw tokens generated by the model for each generated sequence (not containing the prompt). The second is `{id}.seq.fasta` where the generations are combined with prompt properly to output a valid N-to-C terminal protein sequence. Note, the `seq.fasta` files only contain a subset of `gen.fasta` file generations since some of the generations in `gen.fasta` file may be invalid. 120 | 121 | See an example in directory [examples/generations](examples/generations) of a prompt file for generations for Deaminase given forward and reverse prefixes and for infilling a section of PETase (again both in forward and reverse direction). 122 | 123 | ### Prompt File Format 124 | 125 | The prompt file should be a .csv file. Required columns: `id`, `sequence`, `min_new_tokens`, `max_new_tokens`. (min/max new tokens are added to the sequence length). id is the unique identifier for each prompt `sequence`. 126 | 127 | `sequence` format: `<1/2>` 128 | 129 | - `<1/2>` : The prompt to generate from. Assuming you want to generate for original sequence `AASNNMETYR`, the prompt should be: 130 | - For CLM in N-to-C (forward) direction: residues from N direction (e.g. `1AAS`) 131 | - For CLM in C-to-N (reverse) direction: residues from C direction (e.g. `2RYTE`) 132 | - For unconditional generation, you can leave the prompt part empty and only provide the direction indicator. e.g. `1` for N-to-C (forward) direction and `2` for C-to-N (reverse) direction. 133 | 134 | ### Info on GLM prompt preparation 135 | 136 | Assume you have a original protein sequence of length `N`. Let's say you want to fill in the span from original residue from `` to `` (start inclusive, end exclusive, 0-indexed) but shorten/lengthen it to (on average) `M` residues. 
For the N-to-C (forward) direction, the prompt should be: 137 | 138 | `1[GLM]--` 139 | 140 | For the C-to-N (reverse) direction, the prompt should be: 141 | 142 | `2[GLM]--` 143 | 144 | where `` and `` is the start and end of the span from the opposite direction. 145 | 146 | Note: you would want to set the `max_new_tokens` to the maximum length of the infill you want (this would be a hard constraint compared to `` in the prompt format which is a soft length constraint). 147 | 148 | Example: 149 | Consider you want to fill in the `NNMET` part of `AASNNMETYR` with, on average, 4 residues instead of 5. So the start pos is 3 and the end pos is 8. 150 | 151 | If you want to fill in the `NNMET` part in N-to-C direction, the prompt should be: 152 | 153 | `1AASNNMETYR[GLM]3-8-4` 154 | 155 | If you want to fill in the `NNMET` part in C-to-N direction, the prompt should be: 156 | 157 | `2RYTEMNNSAA[GLM]2-7-4` 158 | 159 | ## LICENSE 160 | 161 | The code in this repository in released under [Apache 2.0 License](LICENSE-CODE). 162 | 163 | The model weights are released under [CC BY-NC-SA 4.0 License](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode.txt). 164 | -------------------------------------------------------------------------------- /examples/generations/generate.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -euo pipefail 4 | 5 | model_name=$1 6 | 7 | SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" 8 | OUTPUT_DIR=$SCRIPT_DIR/outputs/$model_name 9 | 10 | mkdir -p "$OUTPUT_DIR" 11 | 12 | torchrun --nproc-per-node=gpu -m progen3.tools.generate \ 13 | --prompt-file "$SCRIPT_DIR/prompts.csv" \ 14 | --output-dir "$OUTPUT_DIR" \ 15 | --model-name "$model_name" \ 16 | --n-per-prompt 5000 \ 17 | --fsdp \ 18 | --temperature 0.85 \ 19 | --top-p 0.95 20 | -------------------------------------------------------------------------------- /examples/generations/prompts.csv: -------------------------------------------------------------------------------- 1 | id,sequence,min_new_tokens,max_new_tokens 2 | TADA_ECOLI-fwd,1MSEVEFSHEYWMRHALTLAKRAWDEREVPVGAVLVHNNRV,100,200 3 | TADA_ECOLI-rev,2DTSSQAKKQAKIEQRRMRFFDSLLAACEDALIGETIEVRHNMGPHHL,100,200 4 | PETH_PISS1-fwd-infill,1MNFPRASRLMQAAVLGGLMAVSAAATAQTNPYARGPNPTAASLEASAGPFTVRSFTVSRPSGYGAGTVYYPTNAGGTVGAIAIVPGYTARQSSIKWWGPRLASHGFVVITIDTNSTLDQPSSRSSQQMAALRQVASLNGTSSSPIYGKVDTARMGVMGWSMGGGGSLISAANNPSLKAAAPQAPWDSSTNFSSVTVPTLIFACENDSIAPVNSSALPIYDSMSRNAKQFLEINGGSHSCANSGNSNQALIGKKGVAWMKRFMDNDTRYSTFACENPNSTRVSDFRTANCS[GLM]172-230-58,48,88 5 | PETH_PISS1-rev-infill,2SCNATRFDSVRTSNPNECAFTSYRTDNDMFRKMWAVGKKGILAQNSNGSNACSHSGGNIELFQKANRSMSDYIPLASSNVPAISDNECAFILTPVTVSSFNTSSDWPAQPAAAKLSPNNAASILSGGGGMSWGMVGMRATDVKGYIPSSSTGNLSAVQRLAAMQQSSRSSPQDLTSNTDITIVVFGHSALRPGWWKISSQRATYGPVIAIAGVTGGANTPYYVTGAGYGSPRSVTFSRVTFPGASAELSAATPNPGRAYPNTQATAAASVAMLGGLVAAQMLRSARPFNM[GLM]60-118-58,48,88 6 | -------------------------------------------------------------------------------- /examples/protein_gym_zero_shot/csv_to_fasta.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | from pathlib import Path 4 | 5 | import click 6 | import pandas as pd 7 | from tqdm import tqdm 8 | 9 | 10 | @click.command() 11 | @click.option("--csv-path", type=str, required=True) 12 | @click.option("--fasta-path", type=str, required=True) 13 | def main(csv_path: str, fasta_path: str) -> None: 14 | if os.path.isdir(csv_path): 15 | csv_files = 
glob.glob(os.path.join(csv_path, "*.csv")) 16 | elif os.path.isfile(csv_path): 17 | csv_files = [csv_path] 18 | else: 19 | raise ValueError(f"Invalid CSV path: {csv_path}") 20 | 21 | num_sequences = 0 22 | with open(fasta_path, "w") as f: 23 | for csv_file in tqdm(csv_files, desc="Converting files to FASTA", ncols=80): 24 | df = pd.read_csv(csv_file) 25 | assay_name = Path(csv_file).stem 26 | for idx, sequence in enumerate(df["mutated_sequence"].values): 27 | f.write(f">{assay_name}+{idx}\n{sequence}\n") 28 | num_sequences += 1 29 | 30 | print(f"Wrote {num_sequences} sequences to {fasta_path}") 31 | 32 | 33 | if __name__ == "__main__": 34 | main() 35 | -------------------------------------------------------------------------------- /examples/protein_gym_zero_shot/download_and_prepare.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -euo pipefail 4 | 5 | SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" 6 | DATA_DIR=$SCRIPT_DIR/data 7 | 8 | if [ ! -f "$DATA_DIR/DMS_ProteinGym_substitutions.zip" ]; then 9 | wget https://marks.hms.harvard.edu/proteingym/DMS_ProteinGym_substitutions.zip -P "$DATA_DIR" 10 | fi 11 | if [ ! -f "$DATA_DIR/DMS_substitutions.csv" ]; then 12 | wget https://marks.hms.harvard.edu/proteingym/DMS_substitutions.csv -P "$DATA_DIR" 13 | fi 14 | if [ ! -d "$DATA_DIR/DMS_ProteinGym_substitutions" ]; then 15 | unzip "$DATA_DIR/DMS_ProteinGym_substitutions.zip" -d "$DATA_DIR" 16 | fi 17 | -------------------------------------------------------------------------------- /examples/protein_gym_zero_shot/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -euo pipefail 4 | 5 | assay_name=$1 6 | model_name=$2 7 | 8 | SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" 9 | DATA_DIR=$SCRIPT_DIR/data 10 | FASTA_DIR=$DATA_DIR/DMS_ProteinGym_substitutions_fasta 11 | OUTPUTS_DIR=$SCRIPT_DIR/outputs/$model_name 12 | 13 | mkdir -p "$FASTA_DIR" 14 | mkdir -p "$OUTPUTS_DIR" 15 | 16 | fasta_file_name=$FASTA_DIR/${assay_name}.fasta 17 | 18 | if [ "$assay_name" = "all" ]; then 19 | CSV_PATH="$DATA_DIR/DMS_ProteinGym_substitutions" 20 | else 21 | CSV_PATH="$DATA_DIR/DMS_ProteinGym_substitutions/$assay_name.csv" 22 | fi 23 | 24 | python3 "$SCRIPT_DIR"/csv_to_fasta.py \ 25 | --csv-path "$CSV_PATH" \ 26 | --fasta-path "$fasta_file_name" 27 | 28 | if [ ! 
-f "$OUTPUTS_DIR/$assay_name.csv" ]; then 29 | torchrun --nproc-per-node=gpu -m progen3.tools.score \ 30 | --model-name "$model_name" \ 31 | --fasta-path "$FASTA_DIR/$assay_name.fasta" \ 32 | --output-path "$OUTPUTS_DIR/$assay_name.csv" \ 33 | --fsdp 34 | else 35 | echo "Skipping scoring because $OUTPUTS_DIR/$assay_name.csv already exists" 36 | fi 37 | 38 | python3 "$SCRIPT_DIR"/score.py \ 39 | --assays-dir "$DATA_DIR/DMS_ProteinGym_substitutions" \ 40 | --outputs-dir "$OUTPUTS_DIR" \ 41 | --index-file "$DATA_DIR/DMS_substitutions.csv" \ 42 | $([ "$assay_name" = "all" ] && echo "--split-all-scores") 43 | -------------------------------------------------------------------------------- /examples/protein_gym_zero_shot/score.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import click 4 | import pandas as pd 5 | from scipy.stats import spearmanr 6 | from tqdm import tqdm 7 | 8 | 9 | @click.command() 10 | @click.option("--assays-dir", type=click.Path(exists=True), required=True) 11 | @click.option("--outputs-dir", type=click.Path(exists=True), required=True) 12 | @click.option("--index-file", type=click.Path(exists=True), required=True) 13 | @click.option("--split-all-scores", is_flag=True) 14 | def main(assays_dir: str, outputs_dir: str, index_file: str, split_all_scores: bool) -> None: 15 | if split_all_scores: 16 | all_scores_file = os.path.join(outputs_dir, "all.csv") 17 | all_scores_df = pd.read_csv(all_scores_file) 18 | all_scores_df["DMS_id"] = all_scores_df["sequence_id"].apply(lambda x: x.split("+")[0]) 19 | for assay_name in all_scores_df["DMS_id"].unique(): 20 | assay_df = all_scores_df[all_scores_df.DMS_id == assay_name] 21 | assay_df.to_csv(os.path.join(outputs_dir, f"{assay_name}.csv"), index=False) 22 | 23 | index_df = pd.read_csv(index_file) 24 | spearmans = {} 25 | 26 | for DMS_id in tqdm(index_df["DMS_id"].unique(), desc="Scoring", ncols=80): 27 | output_file = os.path.join(outputs_dir, f"{DMS_id}.csv") 28 | if not os.path.exists(output_file): 29 | continue 30 | 31 | assay_file = os.path.join(assays_dir, f"{DMS_id}.csv") 32 | assay_df = pd.read_csv(assay_file) 33 | output_df = pd.read_csv(output_file) 34 | 35 | # Remove the DMS_id from the sequence id and keep only sequence index within the assay 36 | output_df["sequence_id"] = output_df["sequence_id"].apply(lambda x: x.split("+")[1]).astype(int) 37 | output_df = output_df.sort_values(by="sequence_id") 38 | 39 | assert len(assay_df) == len(output_df), "Assay and output files must have the same number of rows" 40 | 41 | y_true = assay_df["DMS_score"].values 42 | y_pred = output_df["log_likelihood"].values 43 | 44 | spearman_corr, _ = spearmanr(y_true, y_pred) 45 | spearmans[DMS_id] = spearman_corr 46 | 47 | index_df["spearman_corr"] = index_df["DMS_id"].apply(lambda x: spearmans.get(x, None)) 48 | print(index_df[index_df["spearman_corr"].notna()][["DMS_id", "spearman_corr"]]) 49 | 50 | spearman_df = index_df[index_df["spearman_corr"].notna()] 51 | aggregated_spearman = ( 52 | spearman_df.groupby(["coarse_selection_type", "UniProt_ID"])["spearman_corr"] 53 | .mean() 54 | .groupby("coarse_selection_type") 55 | .mean() 56 | .mean() 57 | .item() 58 | ) 59 | print(f"Aggregated Spearman: {aggregated_spearman}") 60 | 61 | 62 | if __name__ == "__main__": 63 | main() 64 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | disallow_untyped_defs 
= True 3 | install_types = True 4 | non_interactive = True 5 | ignore_missing_imports = True 6 | exclude = test_.*\.py$ 7 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "progen3" 3 | version = "0.1.0" 4 | description = "Progen3 Models" 5 | requires-python = ">=3.10, <3.13" 6 | dependencies = [ 7 | "transformers>=4.42,<4.49", 8 | "pip", 9 | "tokenizers", 10 | "accelerate", 11 | "torch>=2.5.0,<2.5.2", 12 | "click", 13 | "tqdm", 14 | "pandas", 15 | "scipy", 16 | "biopython" 17 | ] 18 | 19 | [pypi-options] 20 | index-url = "https://pypi.org/simple" 21 | extra-index-urls = ["https://download.pytorch.org/whl/cu124"] # pin to cuda 12.4 22 | 23 | [build-system] 24 | requires = ["hatchling"] 25 | build-backend = "hatchling.build" 26 | 27 | [tool.pixi.project] 28 | channels = ["conda-forge"] 29 | platforms = ["linux-64", "osx-arm64"] 30 | 31 | [tool.pixi.pypi-dependencies] 32 | progen3 = { path = ".", editable = true } 33 | -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | pip install -e . 4 | pip install "megablocks[gg]==0.7.0" --no-build-isolation 5 | MAX_JOBS=4 pip install flash-attn --no-build-isolation 6 | -------------------------------------------------------------------------------- /src/progen3/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Profluent-AI/progen3/5b393afff2500a62858471f77fbcb59b20c0aa91/src/progen3/__init__.py -------------------------------------------------------------------------------- /src/progen3/batch_preparer.py: -------------------------------------------------------------------------------- 1 | import re 2 | from copy import deepcopy 3 | from dataclasses import dataclass 4 | from typing import Any, Optional 5 | 6 | import numpy as np 7 | import torch 8 | from torch.nn.utils.rnn import pad_sequence 9 | 10 | from progen3.common import dist 11 | 12 | from .tokenizer import END_OF_SPAN_TOKEN, get_tokenizer 13 | 14 | CLM_PATTERN = re.compile(r"^[A-Z]+$") 15 | GLM_PATTERN = re.compile(r"^[A-Z]+\[GLM\](?:\d+\-\d+\-\d+;)*\d+\-\d+\-\d+;?$") 16 | 17 | 18 | @dataclass 19 | class DataPrepConfig: 20 | fuzzy_span_len_factor: float = 0.2 21 | max_glm_spans: int = 50 22 | 23 | 24 | class ProGen3BatchPreparer: 25 | """ 26 | Takes a batch of sequences and prepares them to be fed into the model's forward pass. 27 | """ 28 | 29 | def __init__( 30 | self, 31 | data_prep_config: Optional[DataPrepConfig] = None, 32 | rng: Optional[np.random.Generator] = None, 33 | ): 34 | tokenizer = get_tokenizer() 35 | super().__init__() 36 | self.data_prep_config = data_prep_config or DataPrepConfig() 37 | self.rng = rng or np.random.default_rng(0) 38 | self.tokenizer = tokenizer 39 | self.pad_token_id = tokenizer.token_to_id("") 40 | 41 | def get_batch_kwargs( # type: ignore[override] 42 | self, 43 | sequences: list[str], 44 | device: torch.device = torch.device("cpu"), 45 | reverse: bool = False, 46 | ) -> dict[str, torch.Tensor]: 47 | """ 48 | NOTE: This function assumes all sequences are in 1->2 direction only. 49 | Passing reverse sequences will result in incorrect encoding. 
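The returned dict has keys "input_ids", "labels", "position_ids", and "sequence_ids"; each value is a long tensor of shape (len(sequences), longest encoding in the batch), padded and moved to `device`.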
50 | """ 51 | sequence_encodings = [self.prepare_singleseq(sequence, reverse) for sequence in sequences] 52 | padded_encodings = self.pad_encodings(sequence_encodings) 53 | padded_encodings = {k: v.to(device=device, non_blocking=True) for k, v in padded_encodings.items()} 54 | 55 | return padded_encodings 56 | 57 | def pad_encodings(self, sequence_encodings: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]: 58 | padding_value = { 59 | "input_ids": self.pad_token_id, 60 | "labels": self.pad_token_id, 61 | "position_ids": 0, 62 | "sequence_ids": 0, 63 | } 64 | padded_batch = {} 65 | for key, padding_value in padding_value.items(): 66 | padded_batch[key] = pad_sequence( 67 | [enc[key] for enc in sequence_encodings], 68 | batch_first=True, 69 | padding_value=padding_value, 70 | ).to(dtype=torch.long) 71 | 72 | return padded_batch 73 | 74 | def get_generation_kwargs(self, sequence: str, reverse_sequences: bool) -> dict[str, torch.Tensor]: 75 | """ 76 | NOTE: This function assumes both sequence and context are in 1->2 direction only. 77 | Passing reverse sequences will result in incorrect encoding. 78 | """ 79 | single_seq_encoding = self.prepare_singleseq(sequence, reverse_sequences) 80 | prefix_length = single_seq_encoding["metadata"]["prefix_length"] 81 | 82 | input_ids = single_seq_encoding["input_ids"][:prefix_length] 83 | sequence_ids = single_seq_encoding["sequence_ids"][:prefix_length] 84 | position_ids = single_seq_encoding["position_ids"][:prefix_length] 85 | 86 | return { 87 | "input_ids": input_ids.unsqueeze(0).to(dist.get_device()), 88 | "sequence_ids": sequence_ids.unsqueeze(0).to(dist.get_device()), 89 | "position_ids": position_ids.unsqueeze(0).to(dist.get_device()), 90 | } 91 | 92 | def prepare_singleseq(self, sequence: str, reverse_sequence: bool) -> dict[str, Any]: 93 | """ 94 | NOTE: This function assumes sequence is in 1->2 direction only. 95 | Passing reverse sequence as first argument will result in incorrect encoding. 
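A plain residue string (e.g. "MKVL") is dispatched to prepare_clm; a string carrying a "[GLM]" span specification (e.g. "MKVL[GLM]1-3-2") is dispatched to prepare_glm.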
96 | """ 97 | example = (self.prepare_clm if not is_glm_instance(sequence) else self.prepare_glm)(sequence, reverse_sequence) 98 | return example 99 | 100 | def prepare_clm(self, sequence: str, reverse_sequence: bool) -> dict[str, Any]: 101 | sequence = "1" + sequence + "2" 102 | if reverse_sequence: 103 | sequence = sequence[::-1] 104 | 105 | tokens = self.tokenizer.encode(f"{sequence}").ids 106 | return { 107 | "input_ids": torch.tensor(tokens), 108 | "labels": torch.tensor(tokens), 109 | "position_ids": torch.arange(len(tokens)), 110 | "sequence_ids": torch.zeros(len(tokens)), 111 | # Metadata for generation 112 | # remove <1/2> from the end for generation 113 | "metadata": {"prefix_length": len(tokens) - 2}, 114 | } 115 | 116 | def prepare_glm(self, sequence: str, reverse_sequence: bool) -> dict[str, Any]: 117 | sequence, masking_info = get_spans_to_mask(sequence) 118 | spans_to_mask = sorted(masking_info.keys()) 119 | remaining_spans = get_remaining_spans_from_infill_spans(spans_to_mask, len(sequence)) 120 | tokens = list(sequence) 121 | if reverse_sequence: 122 | tokens = tokens[::-1] 123 | spans_to_mask = [(len(tokens) - e, len(tokens) - s) for s, e in spans_to_mask] 124 | remaining_spans = [(len(tokens) - e, len(tokens) - s) for s, e in remaining_spans] 125 | masking_info = {(len(tokens) - e, len(tokens) - s): L for (s, e), L in masking_info.items()} 126 | 127 | infill_span_ids = self.rng.choice(self.data_prep_config.max_glm_spans, len(spans_to_mask)) 128 | infill_span_ids = [f"" for i in infill_span_ids] # type: ignore 129 | 130 | all_spans = sorted( 131 | [(*x, True, infill_span_ids[i]) for i, x in enumerate(spans_to_mask)] 132 | + [(*x, False, "") for x in remaining_spans] 133 | ) 134 | 135 | prefix_tokens, suffix_tokens = [], [] 136 | prefix_pos_ids, suffix_pos_ids = [], [] 137 | 138 | pos_id_start = 0 + 2 # 0 for , 1 for 139 | for s, e, is_infill_span, span_id in all_spans: 140 | if is_infill_span: 141 | # If span is infill span, add it to suffix and replace it with span_id in prefix 142 | span_suffix_tokens = [span_id] + tokens[s:e] + [END_OF_SPAN_TOKEN] 143 | suffix_tokens.extend(span_suffix_tokens) 144 | suffix_pos_ids.extend(list(range(pos_id_start, pos_id_start + len(span_suffix_tokens)))) 145 | 146 | prefix_tokens.append(span_id) 147 | prefix_pos_ids.append(pos_id_start) 148 | pos_id_start += 1 149 | 150 | # The infill length L may be different from the span length (for example for miniaturization) 151 | L = masking_info[(s, e)] 152 | fuzzy_diff = np.floor(L * self.data_prep_config.fuzzy_span_len_factor) 153 | pos_id_start += L + int(fuzzy_diff) 154 | else: 155 | # If span is remaining span, add it to prefix 156 | prefix = tokens[s:e] 157 | prefix_tokens.extend(deepcopy(prefix)) 158 | prefix_pos_ids.extend(list(range(pos_id_start, pos_id_start + len(prefix)))) 159 | pos_id_start += len(prefix) 160 | 161 | term_1, term_2 = ("1", "2") if not reverse_sequence else ("2", "1") 162 | 163 | prefix_tokens = ["", term_1] + prefix_tokens + [term_2, ""] 164 | prefix_labels = [""] * len(prefix_tokens) 165 | prefix_pos_ids = [0, 1] + prefix_pos_ids + [pos_id_start, pos_id_start + 1] 166 | 167 | input_tokens = prefix_tokens + suffix_tokens 168 | labels = prefix_labels + [x if not x.startswith("" for x in suffix_tokens] 169 | pos_ids = prefix_pos_ids + suffix_pos_ids 170 | 171 | input_tokens_ids = self.tokenizer.encode("".join(input_tokens)).ids 172 | labels_ids = self.tokenizer.encode("".join(labels)).ids 173 | 174 | return { 175 | "input_ids": torch.tensor(input_tokens_ids), 176 | 
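# The prefix positions (and the span-id markers) are masked out in `labels`, so the loss is computed only on the infilled span tokens in the suffix.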
"labels": torch.tensor(labels_ids), 177 | "position_ids": torch.tensor(pos_ids), 178 | "sequence_ids": torch.zeros(len(input_tokens_ids)), 179 | # Metadata for generation 180 | # +1 for first span_id in suffix which we want to keep 181 | "metadata": {"prefix_length": len(prefix_pos_ids) + 1}, 182 | } 183 | 184 | 185 | def assert_valid_instance(sequence: str) -> None: 186 | if len(sequence) == 0: 187 | return 188 | 189 | if not CLM_PATTERN.match(sequence) and not GLM_PATTERN.match(sequence): 190 | raise ValueError(f"Sequence is not a valid CLM or GLM instance: {sequence}") 191 | 192 | 193 | def is_glm_instance(sequence: str) -> bool: 194 | return GLM_PATTERN.match(sequence) is not None 195 | 196 | 197 | def get_spans_to_mask(sequence: str) -> tuple[str, dict[tuple[int, int], int]]: 198 | spans = {} 199 | sequence, spans_str = sequence.split("[GLM]") 200 | spans_str = spans_str.strip(";") 201 | for span in spans_str.split(";"): 202 | s, e, length = span.split("-") 203 | spans[(int(s), int(e))] = int(length) 204 | return sequence, spans 205 | 206 | 207 | def prepare_glm_string_from_spans(spans: dict[tuple[int, int], int]) -> str: 208 | return "[GLM]" + ";".join(f"{s}-{e}-{v}" for (s, e), v in spans.items()) 209 | 210 | 211 | def get_remaining_spans_from_infill_spans( 212 | infill_spans: list[tuple[int, int]], num_tokens: int 213 | ) -> list[tuple[int, int]]: 214 | remaining_spans = [] 215 | start = 0 216 | for s, e in infill_spans: 217 | assert s >= 0 and e <= num_tokens, f"Span {s}-{e} is invalid for sequence of length {num_tokens}." 218 | if start < s: 219 | remaining_spans.append((start, s)) 220 | start = e 221 | if start < num_tokens: 222 | remaining_spans.append((start, num_tokens)) 223 | 224 | return remaining_spans 225 | -------------------------------------------------------------------------------- /src/progen3/common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Profluent-AI/progen3/5b393afff2500a62858471f77fbcb59b20c0aa91/src/progen3/common/__init__.py -------------------------------------------------------------------------------- /src/progen3/common/dist.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any, Union 3 | 4 | import numpy as np 5 | import torch 6 | import torch.distributed as dist 7 | from torch.distributed.fsdp import FullyShardedDataParallel 8 | from transformers.generation.utils import GenerateOutput 9 | from transformers.modeling_utils import PreTrainedModel 10 | 11 | 12 | def get_world_size(group: Any = None) -> int: 13 | if os.environ.get("RANK", -1) == -1 or not dist.is_initialized(): 14 | return 1 15 | return dist.get_world_size(group=group) 16 | 17 | 18 | def get_rank(group: Any = None) -> int: 19 | if os.environ.get("RANK", -1) == -1 or not dist.is_initialized(): 20 | return 0 21 | return dist.get_rank(group=group) 22 | 23 | 24 | def get_device() -> int: 25 | if torch.cuda.is_available(): 26 | return torch.cuda.current_device() 27 | return torch.device("cpu") 28 | 29 | 30 | def get_local_rank() -> int: 31 | return int(os.environ.get("LOCAL_RANK", 0)) if dist.is_initialized() else 0 32 | 33 | 34 | def setup_dist() -> None: 35 | rank = int(os.environ.get("RANK", -1)) 36 | if dist.is_available() and torch.cuda.is_available() and rank != -1: 37 | torch.distributed.init_process_group(backend="nccl") 38 | torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0))) 39 | 40 | 41 | def destroy_process_group() -> 
None: 42 | if dist.is_initialized(): 43 | dist.destroy_process_group() 44 | 45 | 46 | def barrier() -> None: 47 | if dist.is_initialized(): 48 | dist.barrier() 49 | 50 | 51 | def is_initialized() -> bool: 52 | return dist.is_initialized() 53 | 54 | 55 | @torch.no_grad() 56 | def generate( 57 | model: Union[FullyShardedDataParallel, PreTrainedModel], *args: Any, **kwargs: Any 58 | ) -> Union[GenerateOutput, torch.LongTensor]: 59 | if any(isinstance(m, FullyShardedDataParallel) for m in [model, *model.named_children()]): 60 | kwargs["synced_gpus"] = True 61 | with FullyShardedDataParallel.summon_full_params(model, writeback=False, recurse=False): 62 | return model.generate(*args, **kwargs) 63 | return model.generate(*args, **kwargs) 64 | -------------------------------------------------------------------------------- /src/progen3/common/model_loading.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | from contextlib import contextmanager 3 | from pathlib import Path 4 | from typing import Any, Callable, Generator, Optional, Union 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.distributed._tensor import DTensor 9 | from torch.distributed.fsdp import ( 10 | BackwardPrefetch, 11 | FullyShardedDataParallel, 12 | MixedPrecision, 13 | ShardingStrategy, 14 | ) 15 | from torch.distributed.fsdp.wrap import CustomPolicy 16 | from transformers.modeling_utils import PreTrainedModel 17 | 18 | from progen3.common.dist import get_rank 19 | 20 | 21 | def get_model( 22 | model_name_or_path: Union[str, Path], 23 | model_class: PreTrainedModel, 24 | dtype: Optional[torch.dtype] = None, 25 | fsdp: bool = False, 26 | ) -> Union[PreTrainedModel, FullyShardedDataParallel]: 27 | init_ctx = contextlib.nullcontext if get_rank() == 0 or not fsdp else init_empty_weights 28 | device = torch.cuda.current_device() if not fsdp else torch.device("cpu") 29 | with init_ctx(): 30 | model = model_class.from_pretrained(model_name_or_path, device_map=device, torch_dtype=dtype) 31 | if fsdp: 32 | model = _fsdp_wrap(model, mixed_precision=dtype) 33 | return model 34 | 35 | 36 | def _fsdp_wrap(model: PreTrainedModel, mixed_precision: Optional[torch.dtype] = None) -> FullyShardedDataParallel: 37 | if mixed_precision is not None: 38 | mp = MixedPrecision( 39 | param_dtype=mixed_precision, 40 | reduce_dtype=mixed_precision, 41 | buffer_dtype=mixed_precision, 42 | ) 43 | else: 44 | mp = None 45 | return FullyShardedDataParallel( 46 | model, 47 | auto_wrap_policy=_auto_wrap_policy(model), 48 | sharding_strategy=ShardingStrategy.FULL_SHARD, 49 | backward_prefetch=BackwardPrefetch.BACKWARD_PRE, 50 | mixed_precision=mp, 51 | device_id=torch.cuda.current_device(), 52 | limit_all_gathers=True, 53 | sync_module_states=True, 54 | ) 55 | 56 | 57 | def _auto_wrap_policy(obj: PreTrainedModel) -> CustomPolicy: 58 | def lambda_fn(module: torch.nn.Module) -> Union[bool, dict]: 59 | ret = False 60 | if hasattr(module, "_fsdp_wrap"): 61 | ret = bool(module._fsdp_wrap) 62 | elif hasattr(obj, "fsdp_wrap_fn") and callable(obj.fsdp_wrap_fn): 63 | ret = obj.fsdp_wrap_fn(module) 64 | # TODO: may need to modify a dict ret in case some values are strings when they shouldn't be 65 | return ret 66 | 67 | return CustomPolicy(lambda_fn) 68 | 69 | 70 | # Modified from https://github.com/huggingface/accelerate/blob/main/src/accelerate/big_modeling.py 71 | @contextmanager 72 | def init_empty_weights(include_buffers: bool = False) -> Generator[None, None, None]: 73 | """Meta initialization 
context manager. 74 | 75 | A context manager under which models are initialized with all parameters 76 | on the meta device, therefore creating an empty model. Useful when just 77 | initializing the model would blow the available RAM. 78 | 79 | Args: 80 | include_buffers (`bool`, *optional*, defaults to `False`): Whether or 81 | not to also put all buffers on the meta device while initializing. 82 | 83 | Example: 84 | ```python 85 | import torch.nn as nn 86 | 87 | # Initialize a model with 100 billions parameters in no time and without using any RAM. 88 | with init_empty_weights(): 89 | tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)]) 90 | ``` 91 | 92 | 93 | 94 | Any model created under this context manager has no weights. As such you can't do something like 95 | `model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`]. 96 | 97 | 98 | """ 99 | with init_on_device( 100 | torch.device("meta"), 101 | include_buffers=include_buffers, 102 | ) as f: 103 | yield f 104 | 105 | 106 | # Modified from https://github.com/huggingface/accelerate/blob/main/src/accelerate/big_modeling.py 107 | @contextmanager 108 | def init_on_device(device: torch.device, include_buffers: bool = False) -> Generator[None, None, None]: 109 | """Device initialization context manager. 110 | 111 | A context manager under which models are initialized with all parameters 112 | on the specified device. 113 | 114 | Args: 115 | device (`torch.device`): Device to initialize all parameters on. 116 | include_buffers (`bool`, *optional*, defaults to `False`): Whether or 117 | not to also put all buffers on the meta device while initializing. 118 | 119 | Example: 120 | ```python 121 | import torch.nn as nn 122 | 123 | with init_on_device(device=torch.device("cuda")): 124 | tst = nn.Liner(100, 100) # on `cuda` device 125 | ``` 126 | """ 127 | old_register_parameter = nn.Module.register_parameter 128 | if include_buffers: 129 | old_register_buffer = nn.Module.register_buffer 130 | 131 | def register_empty_parameter( 132 | self: torch.nn.Module, 133 | name: str, 134 | param: Optional[torch.nn.Parameter], 135 | ) -> None: 136 | old_register_parameter(self, name, param) 137 | if param is not None: 138 | parameter = self._parameters[name] 139 | assert parameter is not None 140 | if isinstance(parameter, DTensor): 141 | self._parameters[name] = parameter.to(device) # type: ignore 142 | else: 143 | param_cls = type(parameter) 144 | kwargs = parameter.__dict__ 145 | self._parameters[name] = param_cls( 146 | parameter.to(device), 147 | **kwargs, 148 | ) 149 | 150 | def register_empty_buffer( 151 | self: torch.nn.Module, 152 | name: str, 153 | tensor: Optional[torch.Tensor], 154 | persistent: bool = True, 155 | ) -> None: 156 | old_register_buffer(self, name, tensor, persistent=persistent) 157 | if tensor is not None: 158 | named_buffer = self._buffers[name] 159 | assert named_buffer is not None 160 | self._buffers[name] = named_buffer.to(device) 161 | 162 | # Patch tensor creation 163 | if include_buffers: 164 | tensor_constructors_to_patch = { 165 | torch_function_name: getattr(torch, torch_function_name) 166 | for torch_function_name in ["empty", "zeros", "ones", "full"] 167 | } 168 | else: 169 | tensor_constructors_to_patch = {} 170 | 171 | def patch_tensor_constructor(fn: Callable) -> Callable: 172 | def wrapper(*args: Any, **kwargs: Any) -> torch.Tensor: 173 | kwargs["device"] = device 174 | return fn(*args, **kwargs) 175 | 176 | return wrapper 177 | 178 | try: 179 | 
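# Temporarily monkey-patch nn.Module parameter registration (plus buffer registration and the torch tensor constructors when include_buffers=True) so that everything created inside this context lands on `device`; the originals are restored in the finally block.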
nn.Module.register_parameter = register_empty_parameter # type: ignore 180 | if include_buffers: 181 | nn.Module.register_buffer = register_empty_buffer # type: ignore 182 | for torch_function_name in tensor_constructors_to_patch.keys(): 183 | setattr( 184 | torch, 185 | torch_function_name, 186 | patch_tensor_constructor(getattr(torch, torch_function_name)), 187 | ) 188 | yield 189 | finally: 190 | nn.Module.register_parameter = old_register_parameter # type: ignore 191 | if include_buffers: 192 | nn.Module.register_buffer = old_register_buffer # type: ignore 193 | for ( 194 | torch_function_name, 195 | old_torch_function, 196 | ) in tensor_constructors_to_patch.items(): 197 | setattr(torch, torch_function_name, old_torch_function) 198 | -------------------------------------------------------------------------------- /src/progen3/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Mixtral AI and the HuggingFace Inc. team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Adapted from Mixtral model configuration.""" 15 | 16 | from transformers.configuration_utils import PretrainedConfig 17 | from transformers.utils import logging 18 | 19 | from .tokenizer import get_tokenizer 20 | 21 | logger = logging.get_logger(__name__) 22 | 23 | 24 | class ProGen3Config(PretrainedConfig): 25 | model_type = "progen3" 26 | keys_to_ignore_at_inference = ["past_key_values"] 27 | 28 | def __init__( # type: ignore 29 | self, 30 | # Model architecture/initialization 31 | vocab_size=None, 32 | hidden_size=4096, 33 | intermediate_size=16384, 34 | gated_mlp=False, 35 | num_hidden_layers=40, 36 | num_attention_heads=32, 37 | num_key_value_heads=8, 38 | hidden_act="silu", 39 | rms_norm_eps=1e-5, 40 | initializer_range=0.02, 41 | torch_dtype="bfloat16", 42 | use_cache=True, 43 | gradient_checkpointing=False, 44 | no_ffn_gradient_checkpointing=False, 45 | # Tokenization 46 | pad_token_id=None, 47 | bos_token_id=None, 48 | eos_token_id=None, 49 | tie_word_embeddings=False, 50 | # Attention implementation & rotary positional embeddings 51 | fused_attention_norm=False, 52 | msa_style_attention=True, 53 | max_num_sequences=512, 54 | max_position_embeddings=1024 * 64, 55 | rope_theta=100000.0, 56 | attention_dropout=0.0, 57 | clip_qkv=None, 58 | # Mixture of experts implementation 59 | moe_implementation="megablocks", 60 | moe_expert_selection="switch", 61 | moe_grouped_gemm=True, 62 | moe_memory_optimized=None, 63 | num_experts=8, 64 | num_experts_per_tok=2, 65 | moe_world_size=1, 66 | output_router_weights=False, 67 | # Additional activation quantization fn 68 | quantize_inputs_num_bits=None, 69 | quantize_rematerialize_num_bits=None, 70 | quantize_scatter_num_bits=None, 71 | # Loss function details 72 | mlm_loss_coef=1.0, 73 | # From DBRX, https://github.com/databricks/dbrx/blob/main/model/config.json 74 | router_aux_loss_coef=0.05, 75 | **kwargs, 76 | ) -> None: 77 | tokenizer = get_tokenizer() 
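# Special token ids are always taken from the bundled tokenizer; user-provided pad/bos/eos ids are only sanity-checked (see the warnings at the end of __init__), and vocab_size is honored only when it is larger than the tokenizer's.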
78 | super().__init__( 79 | pad_token_id=tokenizer.token_to_id(""), 80 | bos_token_id=tokenizer.token_to_id(""), 81 | eos_token_id=tokenizer.token_to_id(""), 82 | tie_word_embeddings=tie_word_embeddings, 83 | torch_dtype=torch_dtype, 84 | **kwargs, 85 | ) 86 | 87 | self.max_position_embeddings = max_position_embeddings 88 | self.hidden_size = hidden_size 89 | if intermediate_size is None: 90 | intermediate_size = 3 * hidden_size if gated_mlp else 4 * hidden_size 91 | self.intermediate_size = intermediate_size 92 | self.gated_mlp = gated_mlp 93 | self.num_hidden_layers = num_hidden_layers 94 | self.num_attention_heads = num_attention_heads 95 | 96 | # for backward compatibility 97 | if num_key_value_heads is None: 98 | num_key_value_heads = num_attention_heads 99 | 100 | self.fused_attention_norm = fused_attention_norm 101 | self.num_key_value_heads = num_key_value_heads 102 | self.hidden_act = hidden_act 103 | self.initializer_range = initializer_range 104 | self.rms_norm_eps = rms_norm_eps 105 | self.use_cache = use_cache 106 | self.rope_theta = rope_theta 107 | self.attention_dropout = attention_dropout 108 | self.msa_style_attention = msa_style_attention 109 | self.max_num_sequences = max_num_sequences 110 | assert clip_qkv is None or clip_qkv > 0 111 | self.clip_qkv = clip_qkv 112 | 113 | num_experts_per_tok = min(num_experts_per_tok, num_experts) 114 | assert num_experts > 0 and num_experts_per_tok > 0 115 | assert ( 116 | num_experts % moe_world_size == 0 117 | ), f"Expected {moe_world_size=} to perfectly divide {num_experts=}" # noqa: E225 118 | if num_experts == 1: 119 | moe_implementation = "eager" 120 | moe_expert_selection = "switch" 121 | if num_experts == 1 or moe_expert_selection == "sinkhorn": 122 | router_aux_loss_coef = 0.0 123 | if moe_implementation != "megablocks": 124 | moe_world_size = 1 125 | output_router_weights = output_router_weights or router_aux_loss_coef > 0 126 | if moe_memory_optimized is None: 127 | moe_memory_optimized = moe_grouped_gemm 128 | 129 | self.quantize_inputs_num_bits = quantize_inputs_num_bits 130 | self.quantize_rematerialize_num_bits = quantize_rematerialize_num_bits 131 | self.quantize_scatter_num_bits = quantize_scatter_num_bits 132 | assert quantize_inputs_num_bits is None or quantize_inputs_num_bits == 8, "Only 8-bit quantization is supported" 133 | assert ( 134 | self.quantize_inputs_num_bits == self.quantize_rematerialize_num_bits == self.quantize_scatter_num_bits 135 | ), "Different quantization bitwidths for inputs, rematerialize, and scatter are not supported" 136 | 137 | self.num_experts = num_experts 138 | self.num_experts_per_tok = num_experts_per_tok 139 | self.output_router_weights = output_router_weights 140 | self.mlm_loss_coef = mlm_loss_coef 141 | self.router_aux_loss_coef = router_aux_loss_coef 142 | 143 | self.moe_implementation = moe_implementation 144 | self.moe_expert_selection = moe_expert_selection 145 | self.moe_grouped_gemm = moe_grouped_gemm 146 | self.moe_memory_optimized = moe_memory_optimized 147 | self.moe_world_size = max(1, moe_world_size) 148 | 149 | self.vocab_size = tokenizer.get_vocab_size() 150 | self.gradient_checkpointing = gradient_checkpointing 151 | self.no_ffn_gradient_checkpointing = no_ffn_gradient_checkpointing 152 | 153 | if vocab_size is not None: 154 | if vocab_size < self.vocab_size: 155 | logger.warning(f"Ignoring vocab_size {vocab_size}. 
Using larger {self.vocab_size} from tokenizer.") 156 | elif vocab_size > self.vocab_size: 157 | logger.warning(f"Using vocab_size {vocab_size} instead of smaller {self.vocab_size} from tokenizer.") 158 | self.vocab_size = vocab_size 159 | if pad_token_id is not None and pad_token_id != self.pad_token_id: 160 | logger.warning(f"Ignoring pad_token_id. Using {self.pad_token_id} from tokenizer") 161 | if bos_token_id is not None and bos_token_id != self.bos_token_id: 162 | logger.warning(f"Ignoring bos_token_id. Using {self.bos_token_id} from tokenizer") 163 | if eos_token_id is not None and eos_token_id != self.eos_token_id: 164 | logger.warning(f"Ignoring eos_token_id. Using {self.eos_token_id} from tokenizer") 165 | -------------------------------------------------------------------------------- /src/progen3/generator.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | import re 4 | from copy import deepcopy 5 | from typing import Iterator, NamedTuple, Optional 6 | 7 | import pandas as pd 8 | import torch 9 | import torch.distributed 10 | from tqdm import tqdm 11 | from transformers.cache_utils import DynamicCache 12 | from transformers.generation import GenerateDecoderOnlyOutput, GenerationConfig 13 | 14 | from progen3.batch_preparer import ( 15 | ProGen3BatchPreparer, 16 | assert_valid_instance, 17 | get_spans_to_mask, 18 | is_glm_instance, 19 | prepare_glm_string_from_spans, 20 | ) 21 | from progen3.common import dist 22 | from progen3.common.dist import generate 23 | from progen3.modeling import ProGen3ForCausalLM 24 | from progen3.tools.utils import batched, write_fasta_sequences 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | 29 | class GenerationResult(NamedTuple): 30 | generation: str # aka raw generation from the model 31 | sequence: Optional[str] # cleaned, validated and compiled sequence in the forward direction 32 | 33 | 34 | class ProGen3Generator: 35 | OUTPUT_GENERATIONS_PATTERN = "{output_dir}/{prompt_id}.gen.fasta" 36 | OUTPUT_SEQUENCES_PATTERN = "{output_dir}/{prompt_id}.seq.fasta" 37 | 38 | def __init__( 39 | self, 40 | model: ProGen3ForCausalLM, 41 | max_batch_tokens: int = 65536, 42 | temperature: float = 0.2, 43 | top_p: float = 0.95, 44 | ): 45 | self.model = model 46 | self.model.eval() 47 | self.batch_preparer = ProGen3BatchPreparer() 48 | 49 | # eos_token_id is set in run_generation() 50 | self.default_gen_config = GenerationConfig( 51 | do_sample=True, 52 | use_cache=True, 53 | output_logits=True, 54 | return_dict_in_generate=True, 55 | temperature=temperature, 56 | top_p=top_p, 57 | pad_token_id=self.batch_preparer.tokenizer.padding["pad_id"], 58 | ) 59 | 60 | self.max_batch_tokens = max_batch_tokens 61 | 62 | def generate( 63 | self, 64 | prompt: str, 65 | num_sequences: int, 66 | min_new_tokens: int, 67 | max_new_tokens: int, 68 | gen_config: Optional[GenerationConfig] = None, 69 | ) -> Iterator[GenerationResult]: 70 | prompt, direction = parse_directed_prompt(prompt) 71 | # After above operation, prompt is in 1->2 format only (including GLM spans) 72 | assert_valid_instance(prompt) 73 | 74 | # Prepare input 75 | reverse_sequence = direction == "rev" 76 | input_encoding = self.batch_preparer.get_generation_kwargs(prompt, reverse_sequence) 77 | num_input_tokens = len(input_encoding["input_ids"][0]) 78 | logger.info(f"Generating for {prompt} ({direction}) with {num_input_tokens} input tokens") 79 | 80 | # Fill out generation config 81 | end_token = "" if not is_glm_instance(prompt) 
else "" 82 | end_token_id = self.batch_preparer.tokenizer.token_to_id(end_token) 83 | assert end_token_id is not None, f"End token {end_token} not found in tokenizer" 84 | 85 | gen_config = deepcopy(gen_config or self.default_gen_config) 86 | gen_config.eos_token_id = end_token_id 87 | gen_config.max_new_tokens = max_new_tokens + 2 88 | gen_config.min_new_tokens = min_new_tokens + 2 89 | 90 | batch_size = max(1, math.floor(self.max_batch_tokens / (num_input_tokens + gen_config.max_new_tokens))) 91 | 92 | # Prepare key-value cache with the prompt 93 | cached_length = num_input_tokens - 1 94 | key_value_cache: DynamicCache | None = None 95 | if cached_length > 0: 96 | with torch.no_grad(): 97 | cached_encoding = {k: v[:, :cached_length] for k, v in input_encoding.items()} 98 | key_value_cache = self.model(**cached_encoding, use_cache=True, return_dict=True).past_key_values 99 | assert key_value_cache is not None, "Key-value cache must be non-None" 100 | assert key_value_cache.get_seq_length(0) == cached_length, f"Cache must have {cached_length} tokens." 101 | torch.cuda.empty_cache() 102 | 103 | # Generate sequences 104 | pbar = tqdm( 105 | total=num_sequences, 106 | ncols=80, 107 | disable=dist.get_rank() != 0, 108 | desc=f"Sequences generated (Rank {dist.get_rank()}): ", 109 | ) 110 | for i in range(0, num_sequences, batch_size): 111 | num_gen = min(num_sequences - i, batch_size) 112 | 113 | if key_value_cache: 114 | key_value_cache.batch_repeat_interleave(num_gen) 115 | gen_config.num_return_sequences = num_gen 116 | outputs: GenerateDecoderOnlyOutput = generate( 117 | self.model, **input_encoding, generation_config=gen_config, past_key_values=key_value_cache 118 | ) 119 | if key_value_cache: 120 | key_value_cache.crop(cached_length) 121 | key_value_cache.batch_select_indices([0]) 122 | assert key_value_cache.get_seq_length(0) == cached_length, f"Cache must have {cached_length} tokens." 123 | 124 | # Get generated sequences 125 | input_token_ids = input_encoding["input_ids"][0].cpu().numpy().tolist() 126 | input_token_ids_in_completions = outputs.sequences[:, :num_input_tokens].tolist() 127 | assert all( 128 | input_token_ids == x for x in input_token_ids_in_completions 129 | ), "Input tokens must be the same in prompt and completion." 
130 | completions = outputs.sequences[:, num_input_tokens:].tolist() 131 | decoded_completions = self.batch_preparer.tokenizer.decode_batch(completions, skip_special_tokens=False) 132 | pbar.update(num_gen) 133 | for decoded_completion in decoded_completions: 134 | compiled_completion = compile_generation(prompt, decoded_completion, direction) 135 | yield GenerationResult(sequence=compiled_completion, generation=decoded_completion) 136 | 137 | def run(self, prompt_file: str, output_dir: str, n_per_prompt: int) -> None: 138 | prompt_df = pd.read_csv(prompt_file) 139 | for _, row in prompt_df.iterrows(): 140 | prompt_id, prompt = row["id"], row["sequence"] 141 | min_new_tokens = row["min_new_tokens"] 142 | max_new_tokens = row["max_new_tokens"] 143 | generations_per_rank = math.ceil(n_per_prompt / dist.get_world_size()) 144 | 145 | results = [] 146 | generation_iterator = self.generate(prompt, generations_per_rank, min_new_tokens, max_new_tokens) 147 | 148 | for batch_generations in batched(generation_iterator, 500): 149 | if torch.distributed.is_initialized(): 150 | all_batch_generations: list = [None for _ in range(dist.get_world_size())] 151 | torch.distributed.all_gather_object(all_batch_generations, batch_generations) 152 | batch_generations = [result for rank_results in all_batch_generations for result in rank_results] 153 | results.extend(batch_generations) 154 | 155 | results = results[:n_per_prompt] 156 | if dist.get_rank() == 0: 157 | self._save_generations(prompt_id, results, output_dir) 158 | 159 | def _save_generations(self, prompt_id: str, results: list[GenerationResult], output_dir: str) -> None: 160 | generation_seqs = {} 161 | sequence_seqs = {} 162 | for i, result in enumerate(results): 163 | new_seq_id = str(i) 164 | generation_seqs[new_seq_id] = result.generation 165 | if result.sequence: 166 | sequence_seqs[new_seq_id] = result.sequence 167 | 168 | generations_file = self.OUTPUT_GENERATIONS_PATTERN.format(output_dir=output_dir, prompt_id=prompt_id) 169 | sequences_file = self.OUTPUT_SEQUENCES_PATTERN.format(output_dir=output_dir, prompt_id=prompt_id) 170 | write_fasta_sequences(generations_file, generation_seqs) 171 | write_fasta_sequences(sequences_file, sequence_seqs) 172 | 173 | 174 | def completion_validity_pattern(mode: str, direction: str) -> re.Pattern: 175 | pattern = r"[ACDEFGHIKLMNPQRSTVWY]+" # all non-X standard amino acids 176 | if mode == "CLM": 177 | terminal = "1" if direction == "rev" else "2" 178 | pattern += "(" + re.escape(terminal + "") + ")" 179 | elif mode == "GLM": 180 | pattern += "(" + re.escape("") + ")" 181 | 182 | return re.compile(pattern) 183 | 184 | 185 | def compile_generation(prompt: str, completion: str, direction: str) -> str | None: 186 | mode = "GLM" if is_glm_instance(prompt) else "CLM" 187 | pattern = completion_validity_pattern(mode, direction) 188 | match = pattern.match(completion) 189 | if match is None: 190 | return None 191 | 192 | # Remove special tokens from the completion 193 | stripped_completion = re.sub(r"[^A-Z]", "", completion) 194 | 195 | # original prompt is always 1->2, whereas completion is either: 196 | # direction fwd: 1->2 197 | # direction rev: 2->1 198 | match (mode, direction): 199 | case ("CLM", "fwd"): 200 | compiled_completion = prompt + stripped_completion 201 | case ("CLM", "rev"): 202 | compiled_completion = stripped_completion[::-1] + prompt 203 | case ("GLM", "fwd"): 204 | prompt, spans = get_spans_to_mask(prompt) 205 | assert len(spans) == 1, f"GLM can only have one span, but got {spans} spans." 
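            # `spans` maps a single (start, end) window of the original prompt to the masked
            # region; the generated residues are spliced back into that window (for "rev"
            # they are reversed first, since the completion was produced in 2->1 order).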
206 | s, e = next(iter(spans.keys())) 207 | compiled_completion = prompt[:s] + stripped_completion + prompt[e:] 208 | case ("GLM", "rev"): 209 | prompt, spans = get_spans_to_mask(prompt) 210 | assert len(spans) == 1, f"GLM can only have one span, but got {spans} spans." 211 | s, e = next(iter(spans.keys())) 212 | compiled_completion = prompt[:s] + stripped_completion[::-1] + prompt[e:] 213 | case _: 214 | raise ValueError(f"Invalid mode or direction: {mode} {direction}") 215 | 216 | return compiled_completion 217 | 218 | 219 | def parse_directed_prompt(prompt: str) -> tuple[str, str]: 220 | assert prompt.startswith("1") or prompt.startswith("2"), "Prompt must start with 1 or 2" 221 | is_fwd = prompt.startswith("1") 222 | prompt = prompt[1:] 223 | 224 | is_glm = is_glm_instance(prompt) 225 | 226 | match (is_glm, is_fwd): 227 | case (False, True): 228 | return prompt, "fwd" 229 | case (False, False): 230 | return prompt[::-1], "rev" 231 | case (True, True): 232 | return prompt, "fwd" 233 | case (True, False): 234 | sequence, spans = get_spans_to_mask(prompt) 235 | sequence = sequence[::-1] 236 | spans = {(len(sequence) - e, len(sequence) - s): v for (s, e), v in spans.items()} 237 | post_string = prepare_glm_string_from_spans(spans) 238 | return sequence + post_string, "rev" 239 | case _: 240 | raise ValueError(f"Invalid mode or direction: {is_glm} {is_fwd}") 241 | -------------------------------------------------------------------------------- /src/progen3/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Profluent-AI/progen3/5b393afff2500a62858471f77fbcb59b20c0aa91/src/progen3/model/__init__.py -------------------------------------------------------------------------------- /src/progen3/model/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.attention import SDPBackend, sdpa_kernel 5 | from torch.nn.attention.bias import causal_lower_right 6 | from transformers.cache_utils import Cache 7 | from transformers.utils import logging 8 | 9 | from ..config import ProGen3Config 10 | 11 | logger = logging.get_logger(__name__) 12 | 13 | 14 | # Copied from transformers.models.llama.modeling_llama.repeat_kv 15 | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: 16 | """This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
17 | 18 | The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, 19 | num_attention_heads, seqlen, head_dim) 20 | """ 21 | batch, num_key_value_heads, slen, head_dim = hidden_states.shape 22 | if n_rep == 1: 23 | return hidden_states 24 | hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) 25 | return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) 26 | 27 | 28 | # Copied from transformers.models.llama.modeling_llama.rotate_half 29 | def rotate_half(x: torch.Tensor) -> torch.Tensor: 30 | """Rotates half the hidden dims of the input.""" 31 | x1 = x[..., : x.shape[-1] // 2] 32 | x2 = x[..., x.shape[-1] // 2 :] 33 | return torch.cat((-x2, x1), dim=-1) 34 | 35 | 36 | # Adapted from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding and 37 | # transformers.models.llama.modeling_llama.apply_rotary_pos_emb 38 | class RotaryPositionalEmbedding(nn.Module): 39 | def __init__( 40 | self, dim: int, max_position_embeddings: int = 2048, base: float = 10000, device: torch.device | None = None 41 | ): 42 | super().__init__() 43 | 44 | self.dim = dim 45 | self.base = base 46 | self.max_position_embeddings = max_position_embeddings 47 | inv_freq = base ** -(torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) 48 | self.register_buffer("inv_freq", inv_freq, persistent=False) 49 | 50 | # Build here to make `torch.jit.trace` work. 51 | self._set_sin_cos_cache( 52 | seq_len=max_position_embeddings, 53 | device=self.inv_freq.device, 54 | ) 55 | 56 | def _set_sin_cos_cache(self, seq_len: int, device: torch.device) -> None: 57 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 58 | self.max_seq_len_cached = seq_len 59 | t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) 60 | angles = torch.outer(t, self.inv_freq.to(device)) 61 | angles = torch.cat((angles, angles), dim=1) 62 | self.register_buffer("cos_cached", angles.cos(), persistent=False) 63 | self.register_buffer("sin_cached", angles.sin(), persistent=False) 64 | 65 | def forward( 66 | self, 67 | q: torch.Tensor, 68 | k: torch.Tensor, 69 | position_ids: torch.LongTensor, 70 | ) -> tuple[torch.Tensor, torch.Tensor]: 71 | # q, k: [bsz, n, num_attention_heads, head_size] 72 | # position_ids: [bsz, n] 73 | device, dtype = q.device, q.dtype 74 | 75 | # max position id can be different from number of tokens in the sequence 76 | # For example, for GLM/Infilling 77 | seq_len = position_ids.max().item() + 1 78 | if seq_len > self.max_seq_len_cached: 79 | self._set_sin_cos_cache(seq_len=seq_len, device=device) 80 | 81 | # angles_cached[position_ids] gets us something of shape (batch_size, seq_len, head_dim), 82 | # so unsqueeze dimension -2 to broadcast to (batch_size, seq_len, n_heads, head_dim). 
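        # Indexing the cached tables by position id (rather than by sequence offset) also
        # handles non-contiguous position_ids, e.g. the GLM/infilling case noted above,
        # where the maximum position id can exceed the number of tokens present.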
83 | idxs = position_ids.to(device) 84 | cos = self.cos_cached.to(device=device, dtype=dtype).unsqueeze(-2)[idxs] 85 | sin = self.sin_cached.to(device=device, dtype=dtype).unsqueeze(-2)[idxs] 86 | 87 | q_embed = (q * cos) + (rotate_half(q) * sin) 88 | k_embed = (k * cos) + (rotate_half(k) * sin) 89 | return q_embed, k_embed 90 | 91 | 92 | # Adapted from transformers.models.mistral.modeling_mistral.MistralAttention 93 | class Attention(nn.Module): 94 | """Multi-headed attention from 'Attention Is All You Need' paper.""" 95 | 96 | def __init__(self, config: ProGen3Config, layer_idx: int): 97 | super().__init__() 98 | self.config = config 99 | self.layer_idx = layer_idx 100 | 101 | self.hidden_size = config.hidden_size 102 | self.num_heads = config.num_attention_heads 103 | self.head_dim = self.hidden_size // self.num_heads 104 | self.num_kv_heads = config.num_key_value_heads 105 | self.num_key_value_groups = self.num_heads // self.num_kv_heads 106 | self.max_position_embeddings = config.max_position_embeddings 107 | self.max_num_seqs = config.max_num_sequences 108 | self.rope_theta = config.rope_theta 109 | self.attention_dropout = config.attention_dropout 110 | self.clip_qkv = config.clip_qkv 111 | 112 | if (self.head_dim * self.num_heads) != self.hidden_size: 113 | raise ValueError( 114 | f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" 115 | f" and `num_heads`: {self.num_heads})." 116 | ) 117 | self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) 118 | self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False) 119 | self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False) 120 | self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) 121 | 122 | self.rotary_emb = RotaryPositionalEmbedding( 123 | self.head_dim, 124 | max_position_embeddings=self.max_position_embeddings, 125 | base=self.rope_theta, 126 | ) 127 | 128 | def prepare_qkv( 129 | self, 130 | hidden_states: torch.Tensor, 131 | position_ids: torch.LongTensor, 132 | past_key_value: Cache | None = None, 133 | use_cache: bool | None = None, 134 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 135 | bsz, q_len, _ = hidden_states.size() 136 | query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) 137 | key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_kv_heads, self.head_dim) 138 | val_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_kv_heads, self.head_dim) 139 | if self.clip_qkv is not None: 140 | query_states = query_states.clamp(-self.clip_qkv, self.clip_qkv) 141 | key_states = key_states.clamp(-self.clip_qkv, self.clip_qkv) 142 | val_states = val_states.clamp(-self.clip_qkv, self.clip_qkv) 143 | 144 | query_states, key_states = self.rotary_emb( 145 | query_states, 146 | key_states, 147 | position_ids, 148 | ) 149 | 150 | if use_cache and past_key_value is not None: 151 | key_states, val_states = key_states.transpose(1, 2), val_states.transpose(1, 2) 152 | key_states, val_states = past_key_value.update(key_states, val_states, self.layer_idx) 153 | key_states, val_states = key_states.transpose(1, 2), val_states.transpose(1, 2) 154 | 155 | # In PEFT, usually we cast the layer norms in float32 for training stability reasons 156 | # therefore the input hidden states gets silently casted in float32. Hence, we need 157 | # cast them back in float16 just to be sure everything works as expected. 
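        # Note: when a key-value cache is in use, the key/value states above already include
        # previously cached tokens, so their length can exceed q_len; _sdpa_attn handles that
        # case with a lower-right causal bias instead of the default is_causal mask.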
158 | input_dtype = query_states.dtype 159 | if torch.is_autocast_enabled(): 160 | target_dtype = torch.get_autocast_gpu_dtype() 161 | # Handle the case where the model is quantized 162 | elif hasattr(self.config, "_pre_quantization_dtype"): 163 | target_dtype = self.config._pre_quantization_dtype 164 | else: 165 | target_dtype = self.q_proj.weight.dtype 166 | if input_dtype != target_dtype: 167 | logger.warning_once( 168 | f"The input hidden states seems to be silently casted in {input_dtype}. " 169 | f"This might be because you have upcasted embedding or layer norm layers " 170 | f"in {input_dtype}. We will cast back the input in {target_dtype}." 171 | ) 172 | query_states = query_states.to(target_dtype) 173 | key_states = key_states.to(target_dtype) 174 | val_states = val_states.to(target_dtype) 175 | 176 | return query_states, key_states, val_states 177 | 178 | def forward( 179 | self, 180 | hidden_states: torch.Tensor, 181 | position_ids: torch.LongTensor, 182 | past_key_value: Cache | None = None, 183 | output_attentions: bool | None = None, 184 | use_cache: bool | None = None, 185 | ) -> tuple[torch.Tensor, torch.Tensor | None, Cache | None]: 186 | query_states, key_states, val_states = self.prepare_qkv( 187 | hidden_states=hidden_states, 188 | position_ids=position_ids, 189 | past_key_value=past_key_value, 190 | use_cache=use_cache, 191 | ) 192 | 193 | attn_output, attn_weights = self._attn( 194 | query_states=query_states, 195 | key_states=key_states, 196 | val_states=val_states, 197 | output_attentions=output_attentions, 198 | ) 199 | 200 | attn_output = self.o_proj(attn_output) 201 | return attn_output, attn_weights, past_key_value 202 | 203 | def _attn( 204 | self, 205 | query_states: torch.Tensor, 206 | key_states: torch.Tensor, 207 | val_states: torch.Tensor, 208 | output_attentions: bool | None = None, 209 | ) -> tuple[torch.Tensor, torch.Tensor | None]: 210 | assert not output_attentions, "output_attentions not supported" 211 | return self._sdpa_attn( 212 | query_states=query_states, 213 | key_states=key_states, 214 | val_states=val_states, 215 | ) 216 | 217 | def _sdpa_attn( 218 | self, query_states: torch.Tensor, key_states: torch.Tensor, val_states: torch.Tensor 219 | ) -> tuple[torch.Tensor, None]: 220 | query_states = query_states.transpose(1, 2) 221 | key_states = key_states.transpose(1, 2) 222 | val_states = val_states.transpose(1, 2) 223 | 224 | # repeat k/v heads if n_kv_heads < n_heads 225 | # enable_gqa in F.sdpa is not supported for all backends yet 226 | key_states = repeat_kv(key_states, self.num_key_value_groups) 227 | val_states = repeat_kv(val_states, self.num_key_value_groups) 228 | 229 | bsz, q_len = query_states.shape[0], query_states.shape[2] 230 | k_len = key_states.shape[2] 231 | 232 | causal_mask = None 233 | if k_len > q_len: 234 | causal_mask = causal_lower_right(q_len, k_len) 235 | elif k_len < q_len: 236 | raise ValueError("k_len must be greater than or equal to q_len") 237 | 238 | with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]): 239 | attn_output = F.scaled_dot_product_attention( 240 | query_states, 241 | key_states, 242 | val_states, 243 | is_causal=causal_mask is None, 244 | attn_mask=causal_mask, 245 | ) 246 | 247 | attn_output = attn_output.transpose(1, 2).contiguous() 248 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 249 | return attn_output, None 250 | -------------------------------------------------------------------------------- /src/progen3/model/mb_wrapper.py: 
-------------------------------------------------------------------------------- 1 | # mypy: ignore-errors 2 | # Copyright 2022 MosaicML LLM Foundry authors 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | """Adapted from LLM foundry.""" 6 | import functools 7 | import logging 8 | 9 | import megablocks 10 | import megablocks.layers.arguments 11 | import megablocks.layers.common 12 | import megablocks.layers.dmoe 13 | import torch 14 | import torch.distributed as dist 15 | import torch.nn as nn 16 | from torch.distributed.tensor import DeviceMesh, DTensor, Placement, Shard 17 | from torch.distributed.tensor.device_mesh import init_device_mesh 18 | from transformers.activations import ACT2FN 19 | 20 | from ..config import ProGen3Config 21 | 22 | log = logging.getLogger(__name__) 23 | 24 | __all__ = [ 25 | "mb_build_dmoe", 26 | "mb_setup_args", 27 | ] 28 | 29 | functional_ACT2FN = {**ACT2FN} 30 | functional_ACT2FN["gelu"] = torch.nn.functional.gelu 31 | functional_ACT2FN["silu"] = torch.nn.functional.silu 32 | 33 | 34 | def dtensorify_param( 35 | param: nn.Parameter, 36 | mesh: DeviceMesh, 37 | placements: list[Placement], 38 | ): 39 | """Construct a DTensor from an already sharded local parameter.""" 40 | param_dtensor = DTensor.from_local( 41 | param.data, 42 | device_mesh=mesh, 43 | placements=placements, 44 | run_check=False, 45 | ) 46 | return nn.Parameter(param_dtensor) 47 | 48 | 49 | def get_mb_device_mesh(config: ProGen3Config) -> DeviceMesh: 50 | """Helper function to get the device mesh for MegaBlocks MoE. 51 | 52 | Args: 53 | moe_world_size (int): The MoE world size. 54 | world_size (int): The world size. 55 | 56 | Raises: 57 | ValueError: If the device mesh configuration is not valid. 58 | 59 | Returns: 60 | The device mesh for MegaBlocks MoE. 61 | """ 62 | world_size = dist.get_world_size() if dist.is_initialized() else 1 63 | assert world_size >= config.moe_world_size and world_size % config.moe_world_size == 0 64 | return init_device_mesh( 65 | "cuda", 66 | (world_size // config.moe_world_size, config.moe_world_size), 67 | mesh_dim_names=("weight_parallel", "expert_parallel"), 68 | ) 69 | 70 | 71 | def mb_setup_args( 72 | config: ProGen3Config, 73 | device: str | None = None, 74 | dtype: torch.dtype = torch.float32, 75 | **kwargs, 76 | ) -> tuple[megablocks.layers.arguments.Arguments, DeviceMesh]: 77 | """Setup the MegaBlocks args. 78 | 79 | Args: 80 | config (MixtralConfig): The model config object. 81 | device (Optional[str]): The device to run the FFN on. 82 | 83 | Returns: 84 | tuple[megablocks.layers.arguments.Arguments, DeviceMesh]: 85 | The MegaBlocks args and the device mesh for FSDP/expert parallelism. 
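
    Note:
        The device mesh is only built when config.moe_world_size > 1; otherwise None is
        returned in its place. When built, it has shape
        (world_size // moe_world_size, moe_world_size) with mesh dimensions
        ("weight_parallel", "expert_parallel"), as constructed by get_mb_device_mesh above.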
86 | """ 87 | # Configure device mesh for expert parallelism if desired 88 | device_mesh = None 89 | if config.moe_world_size > 1: 90 | device_mesh = get_mb_device_mesh(config) 91 | kwargs.update( 92 | moe_expert_model_parallelism=True, 93 | expert_parallel_group=device_mesh["expert_parallel"].get_group(0), 94 | ) 95 | 96 | args = megablocks.layers.arguments.Arguments( 97 | hidden_size=config.hidden_size, 98 | ffn_hidden_size=config.intermediate_size, 99 | num_layers=config.num_hidden_layers, 100 | bias=False, 101 | return_bias=False, 102 | activation_fn=functional_ACT2FN[config.hidden_act], 103 | moe_num_experts=config.num_experts, 104 | moe_top_k=config.num_experts_per_tok, 105 | moe_loss_weight=config.router_aux_loss_coef, 106 | bf16=dtype is torch.bfloat16, 107 | fp16=dtype is torch.float16, 108 | device=device, 109 | mlp_type="glu" if config.gated_mlp else "mlp", 110 | mlp_impl="grouped" if config.moe_grouped_gemm else "sparse", 111 | memory_optimized_mlp=config.moe_memory_optimized, 112 | moe_normalize_expert_weights=1, 113 | init_method=functools.partial(torch.nn.init.normal_, mean=0.0, std=config.initializer_range), 114 | **kwargs, 115 | ) 116 | 117 | return args, device_mesh 118 | 119 | 120 | def attach_ffn_mb_args( 121 | ffn: megablocks.layers.dmoe.dMoE, 122 | args: megablocks.layers.arguments.Arguments, 123 | ): 124 | """Attach arguments used in parameter initialization to the FFN. 125 | 126 | Args: 127 | ffn (nn.Module): The FFN module. 128 | args (megablocks.layers.arguments.Arguments): The arguments for MegaBlocks. 129 | """ 130 | ffn.experts.mlp.hidden_size = args.ffn_hidden_size 131 | ffn.experts.mlp.expert_parallel_group = args.expert_parallel_group 132 | 133 | 134 | def set_ffn_device_mesh( 135 | ffn: nn.Module, 136 | moe_world_size: int, 137 | device_mesh: DeviceMesh, 138 | ): 139 | """Sets the device mesh in FSDP kwargs. 140 | 141 | Args: 142 | ffn (nn.Module): The FFN module. 143 | moe_world_size (int): The MoE world size. 144 | device_mesh (DeviceMesh): The full device mesh. 145 | 146 | Raises: 147 | RuntimeError: If the device mesh is 3D. 148 | ValueError: If the device mesh is not 2D or 3D. 
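
    Note:
        This is a no-op when moe_world_size <= 1 or device_mesh is None. Otherwise the
        expert MLP parameters are re-registered as DTensors sharded along dim 0 of the
        "expert_parallel" sub-mesh, and FSDP is pointed at the "weight_parallel" sub-mesh
        via the _fsdp_kwargs_dict attribute consumed by fsdp_wrap_fn.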
149 | """ 150 | if moe_world_size > 1 and device_mesh is not None: 151 | expert_mesh = device_mesh["expert_parallel"] 152 | expert_placements: list[Placement] = [Shard(dim=0)] 153 | # Register in two loops as you cannot overwrite parameters while iterating over named_parameters() 154 | dtensorified_params = [ 155 | ( 156 | name, 157 | dtensorify_param( 158 | param=parameter, 159 | mesh=expert_mesh, 160 | placements=expert_placements, 161 | ), 162 | ) 163 | for name, parameter in ffn.experts.mlp.named_parameters() 164 | ] 165 | for name, dtensorified_param in dtensorified_params: 166 | ffn.experts.mlp.register_parameter(name, dtensorified_param) 167 | 168 | ffn.experts._fsdp_kwargs_dict = {"device_mesh": device_mesh["weight_parallel"]} 169 | 170 | 171 | def mb_build_dmoe( 172 | config: ProGen3Config, 173 | args: megablocks.layers.arguments.Arguments, 174 | device_mesh: DeviceMesh | None = None, 175 | **kwargs, 176 | ) -> megablocks.layers.dmoe.dMoE: 177 | ffn = megablocks.layers.dmoe.dMoE(args) 178 | attach_ffn_mb_args( 179 | ffn=ffn, 180 | args=args, 181 | ) 182 | set_ffn_device_mesh( 183 | ffn=ffn, 184 | moe_world_size=config.moe_world_size, 185 | device_mesh=device_mesh, 186 | ) 187 | return ffn 188 | -------------------------------------------------------------------------------- /src/progen3/model/moe.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Type 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from transformers.activations import ACT2FN 8 | 9 | from ..config import ProGen3Config 10 | from .mb_wrapper import mb_build_dmoe 11 | 12 | 13 | def promote_scalar(x: torch.Tensor) -> torch.Tensor: 14 | return x.view(1) if len(x.size()) == 0 else x 15 | 16 | 17 | class LogitConverter(ABC): 18 | @classmethod 19 | @abstractmethod 20 | def logits_to_probs(cls, logits: torch.Tensor, dtype: torch.dtype | None = None) -> torch.Tensor: 21 | raise NotImplementedError 22 | 23 | 24 | class SoftmaxMixIn(LogitConverter): 25 | """Converts logits to probabilities using a softmax.""" 26 | 27 | @classmethod 28 | def logits_to_probs(cls, logits: torch.Tensor, dtype: torch.dtype | None = None) -> torch.Tensor: 29 | dtype = logits.dtype if dtype is None else dtype 30 | return F.softmax(logits, dim=-1, dtype=dtype) 31 | 32 | 33 | class MLP(nn.Module): 34 | def __init__(self, config: ProGen3Config): 35 | super().__init__() 36 | self.ffn_dim = config.intermediate_size 37 | self.hidden_dim = config.hidden_size 38 | self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) 39 | self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) 40 | self.act_fn = ACT2FN[config.hidden_act] 41 | 42 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: 43 | return self.w2(self.act_fn(self.w1(hidden_states))) 44 | 45 | 46 | class GLUMLP(nn.Module): 47 | def __init__(self, config: ProGen3Config): 48 | super().__init__() 49 | self.ffn_dim = config.intermediate_size 50 | self.hidden_dim = config.hidden_size 51 | self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) 52 | self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) 53 | self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) 54 | self.act_fn = ACT2FN[config.hidden_act] 55 | 56 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: 57 | hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) 58 | hidden_states = self.w2(hidden_states) 59 | return 
hidden_states 60 | 61 | 62 | class SparseMoeBlock(nn.Module): 63 | """Strictly equivalent to standard MoE with full capacity (no dropped tokens).""" 64 | 65 | def __init__(self, config: ProGen3Config, **kwargs: dict): 66 | super().__init__() 67 | self.hidden_dim = config.hidden_size 68 | self.ffn_dim = config.intermediate_size 69 | self.n_experts = config.num_experts 70 | self.top_k = config.num_experts_per_tok 71 | self.expert_selector: Type[LogitConverter] = MOE_EXPERT_SELECTION[config.moe_expert_selection] 72 | mlp_cls = GLUMLP if config.gated_mlp else MLP 73 | self.experts = nn.ModuleList([mlp_cls(config) for _ in range(self.n_experts)]) 74 | if self.n_experts > 1: 75 | self.gate = nn.Linear(self.hidden_dim, self.n_experts, bias=False) 76 | 77 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: 78 | # router_logits: (batch_size * sequence_length, n_experts) 79 | # router_weights: (batch_size * sequence_length, n_experts) 80 | bsz, seqlen, dim = hidden_states.shape 81 | if self.n_experts == 1: 82 | final_hidden_states = self.experts[0](hidden_states) 83 | router_logits = torch.zeros( 84 | (bsz * seqlen, self.n_experts), 85 | device=hidden_states.device, 86 | dtype=hidden_states.dtype, 87 | ) 88 | return final_hidden_states, F.softmax(router_logits, dim=-1) 89 | 90 | router_logits = self.gate(hidden_states) 91 | routing_weights = self.expert_selector.logits_to_probs(router_logits, dtype=torch.float32) 92 | router_logits = router_logits.view(-1, router_logits.shape[-1]) 93 | routing_weights = routing_weights.view(-1, routing_weights.shape[-1]) 94 | routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) 95 | routing_weights /= routing_weights.sum(dim=-1, keepdim=True) 96 | routing_weights = routing_weights.to(hidden_states.dtype) 97 | 98 | # Dense mixture of experts 99 | if self.n_experts == self.top_k: 100 | routing_weights = routing_weights.unsqueeze(1) 101 | final_hidden_states = torch.stack([e(hidden_states) for e in self.experts], dim=-1) 102 | final_hidden_states = torch.einsum("lde,lde->ld", final_hidden_states, routing_weights) 103 | final_hidden_states = final_hidden_states.reshape(bsz, seqlen, dim) 104 | return final_hidden_states, F.softmax(router_logits, dim=-1) 105 | 106 | # One hot encode the selected experts to create an expert mask 107 | # this will be used to easily index which expert is going to be sollicitated 108 | hidden_states = hidden_states.view(-1, dim) 109 | expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.n_experts).permute(2, 1, 0) 110 | final_hidden_states = torch.zeros_like(hidden_states) 111 | 112 | # Loop over all available experts in the model and perform the computation on each expert 113 | for expert_idx in range(self.n_experts): 114 | expert = self.experts[expert_idx] 115 | idx, top_x = torch.where(expert_mask[expert_idx]) 116 | 117 | if top_x.shape[0] == 0: 118 | continue 119 | 120 | # in torch it is faster to index using lists than torch tensors 121 | top_x_list = top_x.tolist() 122 | idx_list = idx.tolist() 123 | 124 | # Index the correct hidden states and compute the expert hidden state for 125 | # the current expert. 
We need to make sure to multiply the output hidden 126 | # states by `routing_weights` on the corresponding tokens (top-1 and top-2) 127 | current_state = expert(hidden_states[top_x_list]) 128 | current_state = current_state * routing_weights[top_x_list, idx_list, None] 129 | 130 | # However `index_add_` only support torch tensors for indexing so we'll use 131 | # the `top_x` tensor here. 132 | final_hidden_states.index_add_(0, top_x, current_state.to(hidden_states.dtype)) 133 | final_hidden_states = final_hidden_states.reshape(bsz, seqlen, dim) 134 | return final_hidden_states, F.softmax(router_logits, dim=-1) 135 | 136 | 137 | MOE_CLASSES = { 138 | "eager": SparseMoeBlock, 139 | "megablocks": mb_build_dmoe, 140 | } 141 | MOE_EXPERT_SELECTION = {"switch": SoftmaxMixIn} 142 | -------------------------------------------------------------------------------- /src/progen3/modeling.py: -------------------------------------------------------------------------------- 1 | # mypy: ignore-errors 2 | # Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX 5 | # and OPT implementations in this library. It has been modified from its 6 | # original forms to accommodate minor architectural differences compared 7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 20 | import contextlib 21 | import functools 22 | import os 23 | import re 24 | from dataclasses import dataclass 25 | from typing import Any, Callable, Mapping 26 | 27 | import megablocks.layers.moe 28 | import torch 29 | import torch.nn as nn 30 | import torch.nn.functional as F 31 | import torch.utils.checkpoint 32 | from megablocks.layers.dmoe import dMoE 33 | from transformers.cache_utils import Cache, DynamicCache 34 | from transformers.modeling_outputs import ModelOutput 35 | from transformers.modeling_utils import GenerationMixin, PreTrainedModel 36 | from transformers.utils import logging 37 | 38 | from progen3.common.model_loading import init_empty_weights 39 | 40 | try: 41 | from flash_attn.ops.triton.layer_norm import rms_norm_fn 42 | except ImportError: 43 | raise ImportError( 44 | "triton_rms_norm requires Flash Attention to be installed. 
" + "Please pip install flash-attn.", 45 | ) 46 | 47 | from .config import ProGen3Config 48 | from .model.attention import Attention 49 | from .model.mb_wrapper import mb_setup_args 50 | from .model.moe import MOE_CLASSES 51 | 52 | logger = logging.get_logger(__name__) 53 | 54 | 55 | def _update_state_dict( 56 | state_dict: Mapping[str, Any], 57 | config: ProGen3Config, 58 | ): 59 | # Make state dict interoperable between megablocks implementations 60 | key_sub = {} 61 | if config.moe_implementation == "eager": 62 | # TODO: add megablocks to eager substitutions here 63 | key_sub = {} 64 | 65 | def update_key(key): 66 | for k, v in key_sub.items(): 67 | key = re.sub(k, v, key) 68 | return key 69 | 70 | return {update_key(k): v for k, v in state_dict.items()} 71 | 72 | 73 | @dataclass 74 | class MoeModelOutputWithPast(ModelOutput): 75 | """Base class for model's outputs, with potential hidden states and attentions. 76 | 77 | Args: 78 | last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): 79 | Sequence of hidden-states at the output of the last layer of the model. 80 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): 81 | Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape 82 | `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if 83 | `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, 84 | encoder_sequence_length, embed_size_per_head)`. 85 | 86 | Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if 87 | `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` 88 | input) to speed up sequential decoding. 89 | hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): 90 | Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + 91 | one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. 92 | 93 | Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. 94 | attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): 95 | Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, 96 | sequence_length)`. 97 | 98 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention 99 | heads. 100 | 101 | router_weights (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_weights=True` and `config.add_router_weights=True` is passed or when `config.output_router_weights=True`): 102 | Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. 103 | 104 | Raw router weights (post-softmax) that are computed by MoE routers, these terms are used to compute the auxiliary 105 | loss for Mixture of Experts models. 106 | """ 107 | 108 | last_hidden_state: torch.FloatTensor | None = None 109 | past_key_values: tuple[tuple[torch.FloatTensor]] | None = None 110 | hidden_states: tuple[torch.FloatTensor, ...] | None = None 111 | attentions: tuple[torch.FloatTensor, ...] 
| None = None 112 | router_weights: tuple[torch.FloatTensor] | None = None 113 | 114 | 115 | @dataclass 116 | class MoeCausalOutputWithPast(ModelOutput): 117 | """Base class for joint causal/masked language model with mixture of experts outputs. 118 | 119 | Args: 120 | loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): 121 | Total loss. 122 | 123 | ar_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): 124 | Autoregressive language modeling loss. 125 | 126 | logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): 127 | Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). 128 | 129 | aux_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided): 130 | aux_loss for the sparse modules. 131 | 132 | router_weights (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_weights=True` or `config.output_router_weights=True`): 133 | Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. 134 | 135 | Raw router logits (post-softmax) that are computed by MoE routers, these terms are used to compute the auxiliary 136 | loss for Mixture of Experts models. 137 | 138 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): 139 | Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape 140 | `(batch_size, num_heads, sequence_length, embed_size_per_head)`) 141 | 142 | Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see 143 | `past_key_values` input) to speed up sequential decoding. 144 | hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): 145 | Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + 146 | one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. 147 | 148 | Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. 149 | attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): 150 | Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, 151 | sequence_length)`. 152 | 153 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention 154 | heads. 155 | """ 156 | 157 | loss: torch.FloatTensor | None = None 158 | ar_loss: torch.FloatTensor | None = None 159 | aux_loss: torch.FloatTensor | None = None 160 | logits: torch.FloatTensor = None 161 | past_key_values: tuple[tuple[torch.FloatTensor]] | None = None 162 | hidden_states: tuple[torch.FloatTensor, ...] | None = None 163 | attentions: tuple[torch.FloatTensor, ...] 
| None = None 164 | router_weights: tuple[torch.FloatTensor] | None = None 165 | 166 | 167 | class RMSNorm(nn.Module): 168 | def __init__(self, hidden_size, eps=1e-6): 169 | super().__init__() 170 | self.weight = nn.Parameter(torch.ones(hidden_size)) 171 | self.variance_epsilon = eps 172 | 173 | if not isinstance(hidden_size, int): 174 | raise ValueError("TritonRMSNorm only supports 1D tensors") 175 | 176 | self.rms_norm_fn = rms_norm_fn 177 | 178 | def forward(self, hidden_states: torch.Tensor): 179 | input_dtype = hidden_states.dtype 180 | return self.rms_norm_fn( 181 | hidden_states, 182 | self.weight, 183 | None, # no bias 184 | residual=None, 185 | eps=self.variance_epsilon, 186 | dropout_p=0.0, # no dropout by default 187 | prenorm=False, 188 | residual_in_fp32=False, 189 | ).to(input_dtype) 190 | 191 | 192 | class NormAttentionNorm(nn.Module): 193 | def __init__(self, config: ProGen3Config, layer_idx: int): 194 | super().__init__() 195 | self.self_attn = Attention(config, layer_idx) 196 | self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 197 | self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 198 | 199 | def forward( 200 | self, 201 | hidden_states: torch.Tensor, 202 | position_ids: torch.LongTensor, 203 | past_key_value: Cache | None = None, 204 | output_attentions: bool | None = None, 205 | use_cache: bool | None = None, 206 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, Cache | None]: 207 | residual = hidden_states 208 | hidden_states = self.input_layernorm(hidden_states) 209 | 210 | # Self Attention 211 | hidden_states, self_attn_weights, present_key_value = self.self_attn( 212 | hidden_states=hidden_states, 213 | position_ids=position_ids, 214 | past_key_value=past_key_value, 215 | output_attentions=output_attentions, 216 | use_cache=use_cache, 217 | ) 218 | hidden_states = residual + hidden_states 219 | 220 | # Fully Connected 221 | residual = hidden_states 222 | hidden_states = self.post_attention_layernorm(hidden_states) 223 | return hidden_states, residual, self_attn_weights, present_key_value 224 | 225 | 226 | class DecoderLayer(nn.Module): 227 | def __init__(self, config: ProGen3Config, layer_idx: int, **moe_kwargs): 228 | super().__init__() 229 | self.initializer_range = config.initializer_range 230 | self.hidden_size = config.hidden_size 231 | self.fused_attention_norm = config.fused_attention_norm 232 | if self.fused_attention_norm: 233 | self.norm_attn_norm = NormAttentionNorm(config, layer_idx) 234 | else: 235 | self.self_attn = Attention(config, layer_idx) 236 | self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 237 | self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 238 | self.block_sparse_moe = MOE_CLASSES[config.moe_implementation](config, **moe_kwargs) 239 | self.moe_implementation = config.moe_implementation 240 | 241 | def forward( 242 | self, 243 | hidden_states: torch.Tensor, 244 | position_ids: torch.LongTensor, 245 | past_key_value: Cache | None = None, 246 | output_attentions: bool | None = None, 247 | output_router_weights: bool | None = None, 248 | use_cache: bool | None = None, 249 | ) -> tuple[torch.Tensor, ...]: 250 | if self.fused_attention_norm: 251 | hidden_states, residual, self_attn_weights, present_key_value = self.norm_attn_norm( 252 | hidden_states=hidden_states, 253 | position_ids=position_ids, 254 | past_key_value=past_key_value, 255 | output_attentions=output_attentions, 256 | use_cache=use_cache, 257 | ) 
258 | else: 259 | residual = hidden_states 260 | hidden_states = self.input_layernorm(hidden_states) 261 | 262 | # Self Attention 263 | hidden_states, self_attn_weights, present_key_value = self.self_attn( 264 | hidden_states=hidden_states, 265 | position_ids=position_ids, 266 | past_key_value=past_key_value, 267 | output_attentions=output_attentions, 268 | use_cache=use_cache, 269 | ) 270 | hidden_states = residual + hidden_states 271 | 272 | residual = hidden_states 273 | hidden_states = self.post_attention_layernorm(hidden_states) 274 | 275 | # Fully Connected 276 | if self.moe_implementation == "megablocks": 277 | hidden_states = self.block_sparse_moe(hidden_states) 278 | else: 279 | hidden_states, router_weights = self.block_sparse_moe(hidden_states) 280 | hidden_states = residual + hidden_states 281 | 282 | outputs = (hidden_states,) 283 | if output_attentions: 284 | outputs += (self_attn_weights,) 285 | if use_cache: 286 | outputs += (present_key_value,) 287 | if output_router_weights: 288 | outputs += (router_weights,) 289 | return outputs 290 | 291 | 292 | class ProGen3PreTrainedModel(PreTrainedModel): 293 | config_class = ProGen3Config 294 | base_model_prefix = "model" 295 | supports_gradient_checkpointing = True 296 | _no_split_modules = ["DecoderLayer"] 297 | _transformer_layer_cls = [DecoderLayer] 298 | _skip_keys_device_placement = "past_key_values" 299 | _supports_flash_attn_2 = False 300 | _supports_sdpa = True 301 | _supports_cache_class = True 302 | _vocab_keys = [] 303 | 304 | def _init_weights(self, module): 305 | std = self.config.initializer_range 306 | if isinstance(module, nn.Linear): 307 | module.weight.data.normal_(mean=0.0, std=std) 308 | if module.bias is not None: 309 | module.bias.data.zero_() 310 | elif isinstance(module, nn.Embedding): 311 | module.weight.data.normal_(mean=0.0, std=std) 312 | if module.padding_idx is not None: 313 | module.weight.data[module.padding_idx].zero_() 314 | elif isinstance(module, RMSNorm): 315 | module.weight.data.fill_(1.0) 316 | 317 | def post_init(self): 318 | super().post_init() 319 | self._set_update_state_dict() 320 | 321 | def _set_update_state_dict(self, update_fn: Callable | None = None): 322 | if update_fn is None: 323 | update_fn = functools.partial( 324 | _update_state_dict, 325 | config=self.config, 326 | ) 327 | self._update_state_dict = update_fn 328 | for child in self._modules.values(): 329 | child._update_state_dict = update_fn 330 | if isinstance(child, ProGen3PreTrainedModel): 331 | child._set_update_state_dict(update_fn) 332 | 333 | def _load_from_state_dict(self, state_dict, *args, **kwargs): 334 | state_dict = self._update_state_dict(state_dict) 335 | return super()._load_from_state_dict(state_dict, *args, **kwargs) 336 | 337 | def param_init_fn(self, module): 338 | std = self.config.initializer_range 339 | if isinstance(module, dMoE): 340 | module.experts.mlp.w1.data.normal_(mean=0.0, std=std) 341 | module.experts.mlp.w2.data.normal_(mean=0.0, std=std) 342 | if hasattr(module.experts.mlp, "v1"): 343 | module.experts.mlp.v1.data.normal_(mean=0.0, std=std) 344 | else: 345 | self._init_weights(module) 346 | 347 | def _backward_compatibility_gradient_checkpointing(self): 348 | if self.supports_gradient_checkpointing and getattr(self.config, "gradient_checkpointing", False): 349 | self.gradient_checkpointing_enable(dict(use_reentrant=False)) 350 | 351 | def fsdp_wrap_fn(self, module): 352 | if hasattr(module, "_fsdp_kwargs_dict"): 353 | return module._fsdp_kwargs_dict 354 | return isinstance(module, 
tuple(self._transformer_layer_cls)) 355 | 356 | def activation_checkpointing_fn(self, module): 357 | attn_cls = NormAttentionNorm if self.config.fused_attention_norm else (Attention, RMSNorm) 358 | ckpt_cls = attn_cls if self.config.no_ffn_gradient_checkpointing else tuple(self._transformer_layer_cls) 359 | return isinstance(module, ckpt_cls) 360 | 361 | @classmethod 362 | def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike | None, *args, **kwargs): 363 | return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs) 364 | 365 | 366 | class ProGen3Model(ProGen3PreTrainedModel): 367 | _vocab_keys = ["embed_tokens.weight"] 368 | 369 | def __init__(self, config: ProGen3Config, meta_init: bool = False): 370 | super().__init__(config) 371 | self.padding_idx = config.pad_token_id 372 | self.vocab_size = config.vocab_size 373 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) 374 | self.embed_seq_id = nn.Embedding(config.max_num_sequences, config.hidden_size) 375 | p = next(self.embed_tokens.parameters()) 376 | if config.moe_implementation == "megablocks": 377 | mb_args, device_mesh = mb_setup_args(config, dtype=p.dtype, device=p.device) 378 | kwargs = dict(args=mb_args, device_mesh=device_mesh) 379 | self.mb_args = mb_args 380 | self.expert_parallel_device_mesh = device_mesh 381 | else: 382 | kwargs = dict() 383 | ctx = init_empty_weights if meta_init else contextlib.nullcontext 384 | with ctx(): 385 | self.layers = nn.ModuleList([DecoderLayer(config, i, **kwargs) for i in range(config.num_hidden_layers)]) 386 | self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 387 | self.gradient_checkpointing = config.gradient_checkpointing 388 | self.post_init() 389 | 390 | def get_input_embeddings(self): 391 | return self.embed_tokens 392 | 393 | def set_input_embeddings(self, value): 394 | self.embed_tokens = value 395 | 396 | def forward( 397 | self, 398 | input_ids: torch.LongTensor, 399 | position_ids: torch.LongTensor, 400 | sequence_ids: torch.LongTensor, 401 | past_key_values: Cache | None = None, 402 | use_cache: bool | None = None, 403 | output_attentions: bool | None = None, 404 | output_hidden_states: bool | None = None, 405 | output_router_weights: bool | None = None, 406 | return_dict: bool | None = None, 407 | ) -> tuple[torch.Tensor, ...] 
| MoeModelOutputWithPast: 408 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 409 | output_router_weights = ( 410 | output_router_weights if output_router_weights is not None else self.config.output_router_weights 411 | ) 412 | output_hidden_states = ( 413 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 414 | ) 415 | use_cache = use_cache if use_cache is not None else self.config.use_cache 416 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 417 | if output_router_weights and self.config.moe_implementation == "megablocks": 418 | raise ValueError(f"{output_router_weights=} not compatible with megablocks MoE implementation") 419 | if self.config.moe_implementation == "megablocks": 420 | megablocks.layers.moe.clear_load_balancing_loss() 421 | 422 | # retrieve input_ids and inputs_embeds 423 | batch_size, seq_length = input_ids.shape 424 | 425 | if self.gradient_checkpointing and self.training and torch.is_grad_enabled(): 426 | if use_cache: 427 | logger.warning_once( 428 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 429 | ) 430 | use_cache = False 431 | 432 | if use_cache and past_key_values is None: 433 | past_key_values = DynamicCache.from_legacy_cache(past_key_values) 434 | elif not use_cache: 435 | # To avoid weirdness with gradient checkpointing: https://github.com/huggingface/transformers/issues/28499 436 | past_key_values = None 437 | 438 | position_ids = position_ids.view(-1, seq_length).long() 439 | sequence_ids = sequence_ids.view(-1, seq_length).long() 440 | inputs_embeds = self.embed_tokens(input_ids) 441 | inputs_embeds = inputs_embeds + self.embed_seq_id(sequence_ids) 442 | 443 | # In case we need to do any manual typecasting 444 | if torch.is_autocast_enabled(): 445 | target_dtype = torch.get_autocast_gpu_dtype() 446 | elif hasattr(self.config, "_pre_quantization_dtype"): 447 | target_dtype = self.config._pre_quantization_dtype 448 | elif self.config.fused_attention_norm: 449 | target_dtype = self.layers[0].norm_attn_norm.self_attn.q_proj.weight.dtype 450 | else: 451 | target_dtype = self.layers[0].self_attn.q_proj.weight.dtype 452 | hidden_states = inputs_embeds.to(target_dtype) 453 | 454 | # decoder layers 455 | all_hidden_states = () if output_hidden_states else None 456 | all_self_attns = () if output_attentions else None 457 | all_router_weights = () if output_router_weights else None 458 | next_decoder_cache = None 459 | 460 | for decoder_layer in self.layers: 461 | if output_hidden_states: 462 | all_hidden_states += (hidden_states,) 463 | 464 | if self.gradient_checkpointing and self.training and torch.is_grad_enabled(): 465 | layer_outputs = self._gradient_checkpointing_func( 466 | decoder_layer.__call__, 467 | hidden_states, 468 | position_ids, 469 | past_key_values, 470 | output_attentions, 471 | output_router_weights, 472 | use_cache, 473 | ) 474 | else: 475 | layer_outputs = decoder_layer( 476 | hidden_states, 477 | position_ids=position_ids, 478 | past_key_value=past_key_values, 479 | output_attentions=output_attentions, 480 | output_router_weights=output_router_weights, 481 | use_cache=use_cache, 482 | ) 483 | 484 | hidden_states = layer_outputs[0] 485 | 486 | if use_cache: 487 | next_decoder_cache = layer_outputs[2 if output_attentions else 1] 488 | 489 | if output_attentions: 490 | all_self_attns += (layer_outputs[1],) 491 | 492 | if output_router_weights: 493 | 
all_router_weights += (layer_outputs[-1],) 494 | 495 | hidden_states = self.norm(hidden_states) 496 | 497 | # add hidden states from the last decoder layer 498 | if output_hidden_states: 499 | all_hidden_states += (hidden_states,) 500 | 501 | next_cache = next_decoder_cache if use_cache else None 502 | 503 | if not return_dict: 504 | return tuple( 505 | v 506 | for v in [ 507 | hidden_states, 508 | next_cache, 509 | all_hidden_states, 510 | all_self_attns, 511 | all_router_weights, 512 | ] 513 | if v is not None 514 | ) 515 | return MoeModelOutputWithPast( 516 | last_hidden_state=hidden_states, 517 | past_key_values=next_cache, 518 | hidden_states=all_hidden_states, 519 | attentions=all_self_attns, 520 | router_weights=all_router_weights, 521 | ) 522 | 523 | 524 | class ProGen3ForCausalLM(ProGen3PreTrainedModel, GenerationMixin): 525 | _vocab_keys = ["model.embed_tokens.weight", "lm_head.weight"] 526 | 527 | def __init__(self, config: ProGen3Config, meta_init: bool = False): 528 | super().__init__(config) 529 | self.model = ProGen3Model(config, meta_init=meta_init) 530 | self.vocab_size = config.vocab_size 531 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 532 | self.router_aux_loss_coef = config.router_aux_loss_coef 533 | self.num_experts = config.num_experts 534 | self.num_experts_per_tok = config.num_experts_per_tok 535 | self.gradient_checkpointing = config.gradient_checkpointing 536 | self.post_init() 537 | 538 | @property 539 | def device_mesh(self): 540 | return self.model.device_mesh 541 | 542 | def get_input_embeddings(self): 543 | return self.model.embed_tokens 544 | 545 | def set_input_embeddings(self, value): 546 | self.model.embed_tokens = value 547 | 548 | def get_output_embeddings(self): 549 | return self.lm_head 550 | 551 | def set_output_embeddings(self, new_embeddings): 552 | self.lm_head = new_embeddings 553 | 554 | def set_decoder(self, decoder): 555 | self.model = decoder 556 | 557 | def get_decoder(self): 558 | return self.model 559 | 560 | def forward( 561 | self, 562 | input_ids: torch.LongTensor, 563 | position_ids: torch.LongTensor, 564 | sequence_ids: torch.LongTensor, 565 | past_key_values: Cache | None = None, 566 | labels: torch.LongTensor | None = None, 567 | use_cache: bool | None = None, 568 | output_attentions: bool | None = None, 569 | output_hidden_states: bool | None = None, 570 | output_router_weights: bool | None = None, 571 | return_dict: bool | None = None, 572 | ) -> tuple[torch.Tensor, ...] 
| MoeCausalOutputWithPast: 573 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 574 | output_router_weights = ( 575 | output_router_weights if output_router_weights is not None else self.config.output_router_weights 576 | ) 577 | 578 | output_hidden_states = ( 579 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 580 | ) 581 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 582 | 583 | # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn) 584 | outputs = self.model( 585 | input_ids=input_ids, 586 | position_ids=position_ids, 587 | sequence_ids=sequence_ids, 588 | past_key_values=past_key_values, 589 | use_cache=use_cache, 590 | output_attentions=output_attentions, 591 | output_hidden_states=output_hidden_states, 592 | output_router_weights=output_router_weights, 593 | return_dict=return_dict, 594 | ) 595 | 596 | hidden_states = outputs[0] 597 | loss = None 598 | 599 | # Compute autoregressive language modeling loss 600 | logits = self.lm_head(hidden_states).float() 601 | if labels is not None: 602 | # Shift inputs & labels so that tokens < n predict n, and flatten them 603 | shift_logits = logits[..., :-1, :].contiguous().view(-1, self.config.vocab_size) 604 | shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device) 605 | ar_loss = F.cross_entropy(shift_logits, shift_labels, reduction="none") 606 | mask = shift_labels != self.model.padding_idx 607 | n_ar = mask.sum() 608 | ar_loss = (ar_loss * mask.to(ar_loss)).sum() / (1 if n_ar == 0 else n_ar) 609 | loss = ar_loss 610 | else: 611 | n_ar, ar_loss = 0, 0 612 | 613 | aux_loss = None 614 | if self.config.moe_implementation == "megablocks" and self.training: 615 | aux_loss = megablocks.layers.moe.batched_load_balancing_loss(self.model.mb_args) 616 | if loss is not None: 617 | loss += aux_loss 618 | aux_loss /= self.router_aux_loss_coef 619 | 620 | if not return_dict: 621 | output = (logits,) + outputs[1:] 622 | if output_router_weights: 623 | output = (aux_loss,) + output 624 | return (loss,) + output if loss is not None else output 625 | 626 | return MoeCausalOutputWithPast( 627 | loss=loss, 628 | ar_loss=None if labels is None else ar_loss, 629 | aux_loss=aux_loss, 630 | logits=logits, 631 | past_key_values=outputs.past_key_values, 632 | hidden_states=outputs.hidden_states, 633 | attentions=outputs.attentions, 634 | router_weights=outputs.router_weights, 635 | ) 636 | 637 | def prepare_inputs_for_generation( 638 | self, 639 | input_ids, 640 | position_ids, 641 | sequence_ids, 642 | past_key_values=None, 643 | cache_position=None, 644 | **kwargs, 645 | ): 646 | # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens 647 | # Exception 1: when passing input_embeds, input_ids may be missing entries 648 | # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here 649 | # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want a dummy token in that case 650 | if past_key_values is not None: 651 | if cache_position[-1] >= input_ids.shape[1]: # Exception 3 652 | input_ids = input_ids[:, -cache_position.shape[0] :] 653 | position_ids = position_ids[:, cache_position] 654 | sequence_ids = sequence_ids[:, cache_position] 655 | elif input_ids.shape[1] != len(cache_position): # Default case (the "else", a no-op, is Exception 2) 656 | input_ids =
input_ids[:, cache_position] 657 | position_ids = position_ids[:, cache_position] 658 | sequence_ids = sequence_ids[:, cache_position] 659 | 660 | model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases 661 | 662 | model_inputs.update( 663 | position_ids=position_ids, 664 | sequence_ids=sequence_ids, 665 | past_key_values=past_key_values, 666 | use_cache=kwargs.get("use_cache", None), 667 | output_router_weights=kwargs.get("output_router_weights", None), 668 | ) 669 | return model_inputs 670 | 671 | @staticmethod 672 | def _reorder_cache(past_key_values, beam_idx): 673 | if isinstance(past_key_values, Cache): 674 | return past_key_values.reorder_cache(beam_idx) 675 | 676 | reordered_past = () 677 | for layer_past in past_key_values: 678 | reordered_past += ( 679 | tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), 680 | ) 681 | return DynamicCache.from_legacy_cache(reordered_past) 682 | 683 | def _update_model_kwargs_for_generation( 684 | self, 685 | outputs: ModelOutput, 686 | model_kwargs: dict[str, Any], 687 | num_new_tokens: int = 1, 688 | **kwargs, 689 | ) -> dict[str, Any]: 690 | # Change made in transformers>4.42.0 to return two values, 691 | # cache_name and past_key_values, instead of a single past_key_values 692 | cache_name, cache = self._extract_past_from_model_output(outputs) 693 | assert cache_name == "past_key_values", "Only past_key_values is supported" 694 | model_kwargs["past_key_values"] = cache 695 | 696 | # update position_ids with one plus last value 697 | pos_ids = model_kwargs["position_ids"] 698 | new_delta = torch.arange(num_new_tokens, device=pos_ids.device, dtype=pos_ids.dtype).unsqueeze(0) + 1 699 | model_kwargs["position_ids"] = torch.cat([pos_ids, pos_ids[:, -1:] + new_delta], dim=-1) 700 | 701 | # update sequence_ids with last value 702 | seq_ids = model_kwargs["sequence_ids"] 703 | model_kwargs["sequence_ids"] = torch.cat([seq_ids, seq_ids[:, -1:].repeat(1, num_new_tokens)], dim=-1) 704 | 705 | if model_kwargs.get("use_cache", True): 706 | model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens 707 | else: 708 | past_positions = model_kwargs.pop("cache_position") 709 | new_positions = torch.arange( 710 | past_positions[-1] + 1, 711 | past_positions[-1] + num_new_tokens + 1, 712 | dtype=past_positions.dtype, 713 | ).to(past_positions.device) 714 | model_kwargs["cache_position"] = torch.cat((past_positions, new_positions)) 715 | 716 | return model_kwargs 717 | -------------------------------------------------------------------------------- /src/progen3/scorer.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from collections import defaultdict 3 | from typing import Any 4 | 5 | import pandas as pd 6 | import torch 7 | import torch.distributed 8 | import torch.nn as nn 9 | from Bio import SeqIO 10 | from tqdm import tqdm 11 | 12 | from progen3.batch_preparer import ProGen3BatchPreparer 13 | from progen3.common import dist 14 | from progen3.modeling import MoeCausalOutputWithPast, ProGen3ForCausalLM 15 | 16 | IndexedSequence = tuple[int, str] 17 | 18 | 19 | class ProGen3Scorer: 20 | def __init__( 21 | self, 22 | model: ProGen3ForCausalLM, 23 | max_batch_tokens: int = 65536, 24 | reduction: str = "mean", 25 | ): 26 | super().__init__() 27 | self.batch_preparer = ProGen3BatchPreparer() 28 | if reduction not in ["mean", "sum"]: 29 | raise ValueError(f"Reduction must be one of {['mean', 
'sum']}") 30 | self.reduction = reduction 31 | self.model = model 32 | self.max_batch_tokens = max_batch_tokens 33 | self.model.eval() 34 | 35 | def group_by_length(self, indexed_sequences: list[IndexedSequence]) -> list[list[IndexedSequence]]: 36 | batches: list[list[IndexedSequence]] = [[]] 37 | for idx, seq in sorted(indexed_sequences, key=lambda idx_seq: (len(idx_seq[1]), idx_seq[0])): 38 | if len(batches[-1]) > 0 and len(seq) * (len(batches[-1]) + 1) > self.max_batch_tokens: 39 | batches.append([]) 40 | batches[-1].append((idx, seq)) 41 | 42 | return batches 43 | 44 | def batch_sequences(self, sequences: list[str]) -> list[list[int]]: # type: ignore[override] 45 | """ 46 | Batches the sequences and returns indices for the current rank 47 | We want to keep sequences of similar length together. 48 | Ensures that no batch exceeds max_batch_tokens 49 | """ 50 | indexed_sequences: list[IndexedSequence] = list(enumerate(sequences)) 51 | indexed_batches = self.group_by_length(indexed_sequences) 52 | batches = [[item[0] for item in batch] for batch in indexed_batches] # type: ignore[no-redef,misc] 53 | 54 | assert sorted(sum(batches, [])) == list( 55 | range(len(sequences)) 56 | ), "Batches must contain all indices with no repetition" 57 | 58 | world_size = dist.get_world_size() 59 | extra_batches_needed = (world_size - len(batches)) % world_size 60 | # create extra batches to make the total number of batches divisible by the world size 61 | batches += [batches[0][:1]] * extra_batches_needed 62 | assert len(batches) % world_size == 0 63 | 64 | return batches[dist.get_rank() :: world_size] # type: ignore[return-value] 65 | 66 | @torch.no_grad 67 | def score_batch(self, sequences: list[str]) -> dict[str, list[float]]: 68 | kwargs_n_to_c = self.batch_preparer.get_batch_kwargs(sequences, device=dist.get_device(), reverse=False) 69 | output_batch = self._log_likelihoods(kwargs_n_to_c) 70 | 71 | kwargs_c_to_n = self.batch_preparer.get_batch_kwargs(sequences, device=dist.get_device(), reverse=True) 72 | output_rev_batch = self._log_likelihoods(kwargs_c_to_n) 73 | scores: dict[str, list[float]] = {"log_likelihood": [], "perplexity": []} 74 | for i in range(len(sequences)): 75 | ll_batch, ll_rev_batch = output_batch[i], output_rev_batch[i] 76 | ll = (ll_batch + ll_rev_batch) / 2 77 | scores["log_likelihood"].append(ll.item()) 78 | scores["perplexity"].append(torch.exp(-ll).item()) 79 | return scores 80 | 81 | def _log_likelihoods(self, model_forward_kwargs: dict[str, Any]) -> torch.Tensor: 82 | output: MoeCausalOutputWithPast = self.model( 83 | input_ids=model_forward_kwargs["input_ids"], 84 | labels=model_forward_kwargs["labels"], 85 | sequence_ids=model_forward_kwargs["sequence_ids"], 86 | position_ids=model_forward_kwargs["position_ids"], 87 | return_dict=True, 88 | ) 89 | labels = model_forward_kwargs["labels"] 90 | target_mask = labels != self.model.config.pad_token_id 91 | 92 | targets = labels[..., 1:].contiguous() 93 | target_mask = target_mask[..., 1:].contiguous() 94 | logits = output.logits[..., :-1, :].contiguous().to(torch.float32) 95 | flat_logits = logits.view(-1, logits.shape[-1]) 96 | nll = nn.functional.cross_entropy(flat_logits, targets.view(-1), reduction="none").view(targets.shape) 97 | nll = (nll * target_mask.to(nll)).sum(dim=1) 98 | if self.reduction == "mean": 99 | nll = nll / target_mask.sum(dim=1) 100 | return -nll.detach() 101 | 102 | def evaluate(self, sequences: list[str]) -> dict[str, torch.Tensor]: 103 | """ 104 | Returns a dictionary mapping a scoring metric to a 
tensor of scores. 105 | 106 | In a distributed setting, each rank will compute scores for its own sequences, 107 | and then we will all_gather the scores from each rank to get the full tensor of scores. 108 | Each rank is responsible for its own indices in the full tensor. 109 | """ 110 | 111 | sequence_batch_indices: list[list[int]] = self.batch_sequences(sequences) 112 | scores: dict[str, list[float]] = defaultdict(list) 113 | pbar = tqdm(desc="Scored sequences: ", ncols=80, disable=dist.get_rank() != 0) 114 | for indices in sequence_batch_indices: 115 | sequence_batch = [sequences[i] for i in indices] 116 | batch_scores: dict[str, list[float]] = self.score_batch(sequence_batch) 117 | for metric, float_list in batch_scores.items(): 118 | scores[metric] += float_list 119 | 120 | batch_size = torch.tensor(len(sequence_batch), device=dist.get_device()) 121 | if torch.distributed.is_initialized(): 122 | torch.distributed.all_reduce(batch_size, op=torch.distributed.ReduceOp.SUM) 123 | pbar.update(batch_size.item()) 124 | 125 | sequence_indices = list(itertools.chain.from_iterable(sequence_batch_indices)) 126 | 127 | if torch.distributed.is_initialized(): 128 | # gather scores and sequence batch indices from all ranks 129 | scores, sequence_indices = self._dist_get_rank_scores_and_indices(scores, sequence_indices) 130 | 131 | ordered_scores = self._order_scores_by_indices(scores, sequence_indices) 132 | return ordered_scores 133 | 134 | def _dist_get_rank_scores_and_indices( 135 | self, 136 | scores: dict[str, list[float]], 137 | sequence_indices: list[int], 138 | ) -> tuple[dict[str, list[float]], list[int]]: 139 | """ 140 | Concatenates scores and sequence batch indices from all ranks 141 | Puts scores and indices in the same form as the local scores and indices 142 | """ 143 | all_rank: list = [None for _ in range(dist.get_world_size())] 144 | torch.distributed.all_gather_object(all_rank, (scores, sequence_indices)) 145 | 146 | all_scores: dict[str, list[float]] = defaultdict(list) 147 | all_indices: list[int] = [] 148 | for rank_scores, rank_indices in all_rank: # type: ignore 149 | for metric, float_list in rank_scores.items(): # type: ignore 150 | all_scores[metric] += float_list 151 | all_indices.extend(rank_indices) 152 | 153 | return all_scores, all_indices 154 | 155 | def _order_scores_by_indices( 156 | self, scores: dict[str, list[float]], sequence_indices: list[int] 157 | ) -> dict[str, torch.Tensor]: 158 | """ 159 | Scores are not ordered. We need to order them based on the sequence_indices. 
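For example, scores {"log_likelihood": [a, b]} with sequence_indices [2, 0] produce a tensor of length max(sequence_indices) + 1 = 3 whose entry 2 is a and entry 0 is b; indices that were never scored keep their initialized value of 0.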
160 | """ 161 | output: dict[str, torch.Tensor] = {} 162 | for metric, float_list in scores.items(): 163 | if metric not in output: 164 | output[metric] = torch.zeros(max(sequence_indices) + 1) 165 | for i, idx in enumerate(sequence_indices): 166 | output[metric][idx] = float_list[i] 167 | return output 168 | 169 | def run(self, fasta_path: str, output_path: str) -> None: 170 | sequences = {} 171 | for record in SeqIO.parse(fasta_path, "fasta"): 172 | sequences[record.id] = str(record.seq) 173 | 174 | seq_ids = sorted(list(sequences.keys())) 175 | seqs = [sequences[seq_id] for seq_id in seq_ids] 176 | 177 | scores = self.evaluate(seqs) 178 | 179 | if dist.get_rank() == 0: 180 | df = pd.DataFrame(scores, index=seq_ids) 181 | df.index.name = "sequence_id" 182 | df.to_csv(output_path) 183 | -------------------------------------------------------------------------------- /src/progen3/tokenizer.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "1.0", 3 | "truncation": null, 4 | "padding": { 5 | "strategy": "BatchLongest", 6 | "direction": "Right", 7 | "pad_to_multiple_of": null, 8 | "pad_id": 0, 9 | "pad_type_id": 0, 10 | "pad_token": "" 11 | }, 12 | "added_tokens": [ 13 | { 14 | "id": 0, 15 | "content": "", 16 | "single_word": false, 17 | "lstrip": false, 18 | "rstrip": false, 19 | "normalized": false, 20 | "special": true 21 | }, 22 | { 23 | "id": 1, 24 | "content": "", 25 | "single_word": false, 26 | "lstrip": false, 27 | "rstrip": false, 28 | "normalized": false, 29 | "special": true 30 | }, 31 | { 32 | "id": 2, 33 | "content": "", 34 | "single_word": false, 35 | "lstrip": false, 36 | "rstrip": false, 37 | "normalized": false, 38 | "special": true 39 | }, 40 | { 41 | "id": 3, 42 | "content": "", 43 | "single_word": false, 44 | "lstrip": false, 45 | "rstrip": false, 46 | "normalized": false, 47 | "special": true 48 | }, 49 | { 50 | "id": 4, 51 | "content": "", 52 | "single_word": false, 53 | "lstrip": false, 54 | "rstrip": false, 55 | "normalized": false, 56 | "special": true 57 | }, 58 | { 59 | "id": 5, 60 | "content": "", 61 | "single_word": false, 62 | "lstrip": false, 63 | "rstrip": false, 64 | "normalized": false, 65 | "special": true 66 | }, 67 | { 68 | "id": 34, 69 | "content": "", 70 | "single_word": false, 71 | "lstrip": false, 72 | "rstrip": false, 73 | "normalized": false, 74 | "special": true 75 | }, 76 | { 77 | "id": 35, 78 | "content": "", 79 | "single_word": false, 80 | "lstrip": false, 81 | "rstrip": false, 82 | "normalized": false, 83 | "special": true 84 | }, 85 | { 86 | "id": 36, 87 | "content": "", 88 | "single_word": false, 89 | "lstrip": false, 90 | "rstrip": false, 91 | "normalized": false, 92 | "special": true 93 | }, 94 | { 95 | "id": 37, 96 | "content": "", 97 | "single_word": false, 98 | "lstrip": false, 99 | "rstrip": false, 100 | "normalized": false, 101 | "special": true 102 | }, 103 | { 104 | "id": 38, 105 | "content": "", 106 | "single_word": false, 107 | "lstrip": false, 108 | "rstrip": false, 109 | "normalized": false, 110 | "special": true 111 | }, 112 | { 113 | "id": 39, 114 | "content": "", 115 | "single_word": false, 116 | "lstrip": false, 117 | "rstrip": false, 118 | "normalized": false, 119 | "special": true 120 | }, 121 | { 122 | "id": 40, 123 | "content": "", 124 | "single_word": false, 125 | "lstrip": false, 126 | "rstrip": false, 127 | "normalized": false, 128 | "special": true 129 | }, 130 | { 131 | "id": 41, 132 | "content": "", 133 | "single_word": false, 134 | "lstrip": false, 135 | 
"rstrip": false, 136 | "normalized": false, 137 | "special": true 138 | }, 139 | { 140 | "id": 42, 141 | "content": "", 142 | "single_word": false, 143 | "lstrip": false, 144 | "rstrip": false, 145 | "normalized": false, 146 | "special": true 147 | }, 148 | { 149 | "id": 43, 150 | "content": "", 151 | "single_word": false, 152 | "lstrip": false, 153 | "rstrip": false, 154 | "normalized": false, 155 | "special": true 156 | }, 157 | { 158 | "id": 44, 159 | "content": "", 160 | "single_word": false, 161 | "lstrip": false, 162 | "rstrip": false, 163 | "normalized": false, 164 | "special": true 165 | }, 166 | { 167 | "id": 45, 168 | "content": "", 169 | "single_word": false, 170 | "lstrip": false, 171 | "rstrip": false, 172 | "normalized": false, 173 | "special": true 174 | }, 175 | { 176 | "id": 46, 177 | "content": "", 178 | "single_word": false, 179 | "lstrip": false, 180 | "rstrip": false, 181 | "normalized": false, 182 | "special": true 183 | }, 184 | { 185 | "id": 47, 186 | "content": "", 187 | "single_word": false, 188 | "lstrip": false, 189 | "rstrip": false, 190 | "normalized": false, 191 | "special": true 192 | }, 193 | { 194 | "id": 48, 195 | "content": "", 196 | "single_word": false, 197 | "lstrip": false, 198 | "rstrip": false, 199 | "normalized": false, 200 | "special": true 201 | }, 202 | { 203 | "id": 49, 204 | "content": "", 205 | "single_word": false, 206 | "lstrip": false, 207 | "rstrip": false, 208 | "normalized": false, 209 | "special": true 210 | }, 211 | { 212 | "id": 50, 213 | "content": "", 214 | "single_word": false, 215 | "lstrip": false, 216 | "rstrip": false, 217 | "normalized": false, 218 | "special": true 219 | }, 220 | { 221 | "id": 51, 222 | "content": "", 223 | "single_word": false, 224 | "lstrip": false, 225 | "rstrip": false, 226 | "normalized": false, 227 | "special": true 228 | }, 229 | { 230 | "id": 52, 231 | "content": "", 232 | "single_word": false, 233 | "lstrip": false, 234 | "rstrip": false, 235 | "normalized": false, 236 | "special": true 237 | }, 238 | { 239 | "id": 53, 240 | "content": "", 241 | "single_word": false, 242 | "lstrip": false, 243 | "rstrip": false, 244 | "normalized": false, 245 | "special": true 246 | }, 247 | { 248 | "id": 54, 249 | "content": "", 250 | "single_word": false, 251 | "lstrip": false, 252 | "rstrip": false, 253 | "normalized": false, 254 | "special": true 255 | }, 256 | { 257 | "id": 55, 258 | "content": "", 259 | "single_word": false, 260 | "lstrip": false, 261 | "rstrip": false, 262 | "normalized": false, 263 | "special": true 264 | }, 265 | { 266 | "id": 56, 267 | "content": "", 268 | "single_word": false, 269 | "lstrip": false, 270 | "rstrip": false, 271 | "normalized": false, 272 | "special": true 273 | }, 274 | { 275 | "id": 57, 276 | "content": "", 277 | "single_word": false, 278 | "lstrip": false, 279 | "rstrip": false, 280 | "normalized": false, 281 | "special": true 282 | }, 283 | { 284 | "id": 58, 285 | "content": "", 286 | "single_word": false, 287 | "lstrip": false, 288 | "rstrip": false, 289 | "normalized": false, 290 | "special": true 291 | }, 292 | { 293 | "id": 59, 294 | "content": "", 295 | "single_word": false, 296 | "lstrip": false, 297 | "rstrip": false, 298 | "normalized": false, 299 | "special": true 300 | }, 301 | { 302 | "id": 60, 303 | "content": "", 304 | "single_word": false, 305 | "lstrip": false, 306 | "rstrip": false, 307 | "normalized": false, 308 | "special": true 309 | }, 310 | { 311 | "id": 61, 312 | "content": "", 313 | "single_word": false, 314 | "lstrip": false, 315 | "rstrip": 
false, 316 | "normalized": false, 317 | "special": true 318 | }, 319 | { 320 | "id": 62, 321 | "content": "", 322 | "single_word": false, 323 | "lstrip": false, 324 | "rstrip": false, 325 | "normalized": false, 326 | "special": true 327 | }, 328 | { 329 | "id": 63, 330 | "content": "", 331 | "single_word": false, 332 | "lstrip": false, 333 | "rstrip": false, 334 | "normalized": false, 335 | "special": true 336 | }, 337 | { 338 | "id": 64, 339 | "content": "", 340 | "single_word": false, 341 | "lstrip": false, 342 | "rstrip": false, 343 | "normalized": false, 344 | "special": true 345 | }, 346 | { 347 | "id": 65, 348 | "content": "", 349 | "single_word": false, 350 | "lstrip": false, 351 | "rstrip": false, 352 | "normalized": false, 353 | "special": true 354 | }, 355 | { 356 | "id": 66, 357 | "content": "", 358 | "single_word": false, 359 | "lstrip": false, 360 | "rstrip": false, 361 | "normalized": false, 362 | "special": true 363 | }, 364 | { 365 | "id": 67, 366 | "content": "", 367 | "single_word": false, 368 | "lstrip": false, 369 | "rstrip": false, 370 | "normalized": false, 371 | "special": true 372 | }, 373 | { 374 | "id": 68, 375 | "content": "", 376 | "single_word": false, 377 | "lstrip": false, 378 | "rstrip": false, 379 | "normalized": false, 380 | "special": true 381 | }, 382 | { 383 | "id": 69, 384 | "content": "", 385 | "single_word": false, 386 | "lstrip": false, 387 | "rstrip": false, 388 | "normalized": false, 389 | "special": true 390 | }, 391 | { 392 | "id": 70, 393 | "content": "", 394 | "single_word": false, 395 | "lstrip": false, 396 | "rstrip": false, 397 | "normalized": false, 398 | "special": true 399 | }, 400 | { 401 | "id": 71, 402 | "content": "", 403 | "single_word": false, 404 | "lstrip": false, 405 | "rstrip": false, 406 | "normalized": false, 407 | "special": true 408 | }, 409 | { 410 | "id": 72, 411 | "content": "", 412 | "single_word": false, 413 | "lstrip": false, 414 | "rstrip": false, 415 | "normalized": false, 416 | "special": true 417 | }, 418 | { 419 | "id": 73, 420 | "content": "", 421 | "single_word": false, 422 | "lstrip": false, 423 | "rstrip": false, 424 | "normalized": false, 425 | "special": true 426 | }, 427 | { 428 | "id": 74, 429 | "content": "", 430 | "single_word": false, 431 | "lstrip": false, 432 | "rstrip": false, 433 | "normalized": false, 434 | "special": true 435 | }, 436 | { 437 | "id": 75, 438 | "content": "", 439 | "single_word": false, 440 | "lstrip": false, 441 | "rstrip": false, 442 | "normalized": false, 443 | "special": true 444 | }, 445 | { 446 | "id": 76, 447 | "content": "", 448 | "single_word": false, 449 | "lstrip": false, 450 | "rstrip": false, 451 | "normalized": false, 452 | "special": true 453 | }, 454 | { 455 | "id": 77, 456 | "content": "", 457 | "single_word": false, 458 | "lstrip": false, 459 | "rstrip": false, 460 | "normalized": false, 461 | "special": true 462 | }, 463 | { 464 | "id": 78, 465 | "content": "", 466 | "single_word": false, 467 | "lstrip": false, 468 | "rstrip": false, 469 | "normalized": false, 470 | "special": true 471 | }, 472 | { 473 | "id": 79, 474 | "content": "", 475 | "single_word": false, 476 | "lstrip": false, 477 | "rstrip": false, 478 | "normalized": false, 479 | "special": true 480 | }, 481 | { 482 | "id": 80, 483 | "content": "", 484 | "single_word": false, 485 | "lstrip": false, 486 | "rstrip": false, 487 | "normalized": false, 488 | "special": true 489 | }, 490 | { 491 | "id": 81, 492 | "content": "", 493 | "single_word": false, 494 | "lstrip": false, 495 | "rstrip": false, 496 | 
"normalized": false, 497 | "special": true 498 | }, 499 | { 500 | "id": 82, 501 | "content": "", 502 | "single_word": false, 503 | "lstrip": false, 504 | "rstrip": false, 505 | "normalized": false, 506 | "special": true 507 | }, 508 | { 509 | "id": 83, 510 | "content": "", 511 | "single_word": false, 512 | "lstrip": false, 513 | "rstrip": false, 514 | "normalized": false, 515 | "special": true 516 | }, 517 | { 518 | "id": 84, 519 | "content": "", 520 | "single_word": false, 521 | "lstrip": false, 522 | "rstrip": false, 523 | "normalized": false, 524 | "special": true 525 | }, 526 | { 527 | "id": 85, 528 | "content": "", 529 | "single_word": false, 530 | "lstrip": false, 531 | "rstrip": false, 532 | "normalized": false, 533 | "special": true 534 | }, 535 | { 536 | "id": 86, 537 | "content": "", 538 | "single_word": false, 539 | "lstrip": false, 540 | "rstrip": false, 541 | "normalized": false, 542 | "special": true 543 | }, 544 | { 545 | "id": 87, 546 | "content": "", 547 | "single_word": false, 548 | "lstrip": false, 549 | "rstrip": false, 550 | "normalized": false, 551 | "special": true 552 | }, 553 | { 554 | "id": 88, 555 | "content": "", 556 | "single_word": false, 557 | "lstrip": false, 558 | "rstrip": false, 559 | "normalized": false, 560 | "special": true 561 | }, 562 | { 563 | "id": 89, 564 | "content": "", 565 | "single_word": false, 566 | "lstrip": false, 567 | "rstrip": false, 568 | "normalized": false, 569 | "special": true 570 | }, 571 | { 572 | "id": 90, 573 | "content": "", 574 | "single_word": false, 575 | "lstrip": false, 576 | "rstrip": false, 577 | "normalized": false, 578 | "special": true 579 | }, 580 | { 581 | "id": 91, 582 | "content": "", 583 | "single_word": false, 584 | "lstrip": false, 585 | "rstrip": false, 586 | "normalized": false, 587 | "special": true 588 | }, 589 | { 590 | "id": 92, 591 | "content": "", 592 | "single_word": false, 593 | "lstrip": false, 594 | "rstrip": false, 595 | "normalized": false, 596 | "special": true 597 | }, 598 | { 599 | "id": 93, 600 | "content": "", 601 | "single_word": false, 602 | "lstrip": false, 603 | "rstrip": false, 604 | "normalized": false, 605 | "special": true 606 | }, 607 | { 608 | "id": 94, 609 | "content": "", 610 | "single_word": false, 611 | "lstrip": false, 612 | "rstrip": false, 613 | "normalized": false, 614 | "special": true 615 | }, 616 | { 617 | "id": 95, 618 | "content": "", 619 | "single_word": false, 620 | "lstrip": false, 621 | "rstrip": false, 622 | "normalized": false, 623 | "special": true 624 | }, 625 | { 626 | "id": 96, 627 | "content": "", 628 | "single_word": false, 629 | "lstrip": false, 630 | "rstrip": false, 631 | "normalized": false, 632 | "special": true 633 | }, 634 | { 635 | "id": 97, 636 | "content": "", 637 | "single_word": false, 638 | "lstrip": false, 639 | "rstrip": false, 640 | "normalized": false, 641 | "special": true 642 | }, 643 | { 644 | "id": 98, 645 | "content": "", 646 | "single_word": false, 647 | "lstrip": false, 648 | "rstrip": false, 649 | "normalized": false, 650 | "special": true 651 | }, 652 | { 653 | "id": 99, 654 | "content": "", 655 | "single_word": false, 656 | "lstrip": false, 657 | "rstrip": false, 658 | "normalized": false, 659 | "special": true 660 | }, 661 | { 662 | "id": 100, 663 | "content": "", 664 | "single_word": false, 665 | "lstrip": false, 666 | "rstrip": false, 667 | "normalized": false, 668 | "special": true 669 | }, 670 | { 671 | "id": 101, 672 | "content": "", 673 | "single_word": false, 674 | "lstrip": false, 675 | "rstrip": false, 676 | 
"normalized": false, 677 | "special": true 678 | }, 679 | { 680 | "id": 102, 681 | "content": "", 682 | "single_word": false, 683 | "lstrip": false, 684 | "rstrip": false, 685 | "normalized": false, 686 | "special": true 687 | }, 688 | { 689 | "id": 103, 690 | "content": "", 691 | "single_word": false, 692 | "lstrip": false, 693 | "rstrip": false, 694 | "normalized": false, 695 | "special": true 696 | }, 697 | { 698 | "id": 104, 699 | "content": "", 700 | "single_word": false, 701 | "lstrip": false, 702 | "rstrip": false, 703 | "normalized": false, 704 | "special": true 705 | }, 706 | { 707 | "id": 105, 708 | "content": "", 709 | "single_word": false, 710 | "lstrip": false, 711 | "rstrip": false, 712 | "normalized": false, 713 | "special": true 714 | }, 715 | { 716 | "id": 106, 717 | "content": "", 718 | "single_word": false, 719 | "lstrip": false, 720 | "rstrip": false, 721 | "normalized": false, 722 | "special": true 723 | }, 724 | { 725 | "id": 107, 726 | "content": "", 727 | "single_word": false, 728 | "lstrip": false, 729 | "rstrip": false, 730 | "normalized": false, 731 | "special": true 732 | }, 733 | { 734 | "id": 108, 735 | "content": "", 736 | "single_word": false, 737 | "lstrip": false, 738 | "rstrip": false, 739 | "normalized": false, 740 | "special": true 741 | }, 742 | { 743 | "id": 109, 744 | "content": "", 745 | "single_word": false, 746 | "lstrip": false, 747 | "rstrip": false, 748 | "normalized": false, 749 | "special": true 750 | }, 751 | { 752 | "id": 110, 753 | "content": "", 754 | "single_word": false, 755 | "lstrip": false, 756 | "rstrip": false, 757 | "normalized": false, 758 | "special": true 759 | }, 760 | { 761 | "id": 111, 762 | "content": "", 763 | "single_word": false, 764 | "lstrip": false, 765 | "rstrip": false, 766 | "normalized": false, 767 | "special": true 768 | }, 769 | { 770 | "id": 112, 771 | "content": "", 772 | "single_word": false, 773 | "lstrip": false, 774 | "rstrip": false, 775 | "normalized": false, 776 | "special": true 777 | }, 778 | { 779 | "id": 113, 780 | "content": "", 781 | "single_word": false, 782 | "lstrip": false, 783 | "rstrip": false, 784 | "normalized": false, 785 | "special": true 786 | }, 787 | { 788 | "id": 114, 789 | "content": "", 790 | "single_word": false, 791 | "lstrip": false, 792 | "rstrip": false, 793 | "normalized": false, 794 | "special": true 795 | }, 796 | { 797 | "id": 115, 798 | "content": "", 799 | "single_word": false, 800 | "lstrip": false, 801 | "rstrip": false, 802 | "normalized": false, 803 | "special": true 804 | }, 805 | { 806 | "id": 116, 807 | "content": "", 808 | "single_word": false, 809 | "lstrip": false, 810 | "rstrip": false, 811 | "normalized": false, 812 | "special": true 813 | }, 814 | { 815 | "id": 117, 816 | "content": "", 817 | "single_word": false, 818 | "lstrip": false, 819 | "rstrip": false, 820 | "normalized": false, 821 | "special": true 822 | }, 823 | { 824 | "id": 118, 825 | "content": "", 826 | "single_word": false, 827 | "lstrip": false, 828 | "rstrip": false, 829 | "normalized": false, 830 | "special": true 831 | }, 832 | { 833 | "id": 119, 834 | "content": "", 835 | "single_word": false, 836 | "lstrip": false, 837 | "rstrip": false, 838 | "normalized": false, 839 | "special": true 840 | }, 841 | { 842 | "id": 120, 843 | "content": "", 844 | "single_word": false, 845 | "lstrip": false, 846 | "rstrip": false, 847 | "normalized": false, 848 | "special": true 849 | }, 850 | { 851 | "id": 121, 852 | "content": "", 853 | "single_word": false, 854 | "lstrip": false, 855 | "rstrip": false, 
856 | "normalized": false, 857 | "special": true 858 | }, 859 | { 860 | "id": 122, 861 | "content": "", 862 | "single_word": false, 863 | "lstrip": false, 864 | "rstrip": false, 865 | "normalized": false, 866 | "special": true 867 | }, 868 | { 869 | "id": 123, 870 | "content": "", 871 | "single_word": false, 872 | "lstrip": false, 873 | "rstrip": false, 874 | "normalized": false, 875 | "special": true 876 | }, 877 | { 878 | "id": 124, 879 | "content": "", 880 | "single_word": false, 881 | "lstrip": false, 882 | "rstrip": false, 883 | "normalized": false, 884 | "special": true 885 | }, 886 | { 887 | "id": 125, 888 | "content": "", 889 | "single_word": false, 890 | "lstrip": false, 891 | "rstrip": false, 892 | "normalized": false, 893 | "special": true 894 | }, 895 | { 896 | "id": 126, 897 | "content": "", 898 | "single_word": false, 899 | "lstrip": false, 900 | "rstrip": false, 901 | "normalized": false, 902 | "special": true 903 | }, 904 | { 905 | "id": 127, 906 | "content": "", 907 | "single_word": false, 908 | "lstrip": false, 909 | "rstrip": false, 910 | "normalized": false, 911 | "special": true 912 | }, 913 | { 914 | "id": 128, 915 | "content": "", 916 | "single_word": false, 917 | "lstrip": false, 918 | "rstrip": false, 919 | "normalized": false, 920 | "special": true 921 | }, 922 | { 923 | "id": 129, 924 | "content": "", 925 | "single_word": false, 926 | "lstrip": false, 927 | "rstrip": false, 928 | "normalized": false, 929 | "special": true 930 | }, 931 | { 932 | "id": 130, 933 | "content": "", 934 | "single_word": false, 935 | "lstrip": false, 936 | "rstrip": false, 937 | "normalized": false, 938 | "special": true 939 | }, 940 | { 941 | "id": 131, 942 | "content": "", 943 | "single_word": false, 944 | "lstrip": false, 945 | "rstrip": false, 946 | "normalized": false, 947 | "special": true 948 | }, 949 | { 950 | "id": 132, 951 | "content": "", 952 | "single_word": false, 953 | "lstrip": false, 954 | "rstrip": false, 955 | "normalized": false, 956 | "special": true 957 | }, 958 | { 959 | "id": 133, 960 | "content": "", 961 | "single_word": false, 962 | "lstrip": false, 963 | "rstrip": false, 964 | "normalized": false, 965 | "special": true 966 | } 967 | ], 968 | "normalizer": null, 969 | "pre_tokenizer": { 970 | "type": "ByteLevel", 971 | "add_prefix_space": false, 972 | "trim_offsets": true, 973 | "use_regex": true 974 | }, 975 | "post_processor": { 976 | "type": "ByteLevel", 977 | "add_prefix_space": true, 978 | "trim_offsets": true, 979 | "use_regex": true 980 | }, 981 | "decoder": { 982 | "type": "ByteLevel", 983 | "add_prefix_space": true, 984 | "trim_offsets": true, 985 | "use_regex": true 986 | }, 987 | "model": { 988 | "type": "BPE", 989 | "dropout": null, 990 | "continuing_subword_prefix": null, 991 | "end_of_word_suffix": null, 992 | "fuse_unk": false, 993 | "byte_fallback": false, 994 | "vocab": { 995 | "": 0, 996 | "": 1, 997 | "": 2, 998 | "": 3, 999 | "": 4, 1000 | "": 5, 1001 | "1": 6, 1002 | "2": 7, 1003 | "A": 8, 1004 | "B": 9, 1005 | "C": 10, 1006 | "D": 11, 1007 | "E": 12, 1008 | "F": 13, 1009 | "G": 14, 1010 | "H": 15, 1011 | "I": 16, 1012 | "J": 17, 1013 | "K": 18, 1014 | "L": 19, 1015 | "M": 20, 1016 | "N": 21, 1017 | "O": 22, 1018 | "P": 23, 1019 | "Q": 24, 1020 | "R": 25, 1021 | "S": 26, 1022 | "T": 27, 1023 | "U": 28, 1024 | "V": 29, 1025 | "W": 30, 1026 | "X": 31, 1027 | "Y": 32, 1028 | "Z": 33 1029 | }, 1030 | "merges": [] 1031 | } 1032 | } 1033 | -------------------------------------------------------------------------------- /src/progen3/tokenizer.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | 3 | from tokenizers import Tokenizer 4 | 5 | END_OF_SPAN_TOKEN = "" # nosec 6 | PAD_TOKEN_ID = 0 7 | 8 | 9 | def get_tokenizer() -> Tokenizer: 10 | fname = os.path.join(os.path.dirname(__file__), "tokenizer.json") 11 | tokenizer: Tokenizer = Tokenizer.from_file(fname) 12 | assert ( 13 | tokenizer.padding["pad_id"] == PAD_TOKEN_ID 14 | ), f"Padding token id must be {PAD_TOKEN_ID}, but got {tokenizer.padding['pad_id']}" 15 | 16 | return tokenizer 17 | -------------------------------------------------------------------------------- /src/progen3/tools/.gitignore: -------------------------------------------------------------------------------- 1 | *.fasta 2 | *.csv 3 | data/ 4 | -------------------------------------------------------------------------------- /src/progen3/tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Profluent-AI/progen3/5b393afff2500a62858471f77fbcb59b20c0aa91/src/progen3/tools/__init__.py -------------------------------------------------------------------------------- /src/progen3/tools/generate.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import click 4 | from torch.distributed.elastic.multiprocessing.errors import record 5 | 6 | from progen3.common import dist 7 | from progen3.generator import ProGen3Generator 8 | from progen3.tools.utils import get_progen3_model 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | @record 14 | @click.command() 15 | @click.option("--prompt-file", type=str, required=True, help="Must be a .csv file with a 'sequence' column.") 16 | @click.option("--model-name", type=str, default="progen3-base") 17 | @click.option("--n-per-prompt", type=int, default=1, help="Number of sequences to generate per prompt.") 18 | @click.option("--output-dir", type=str, default=".", help="Must be a directory.") 19 | @click.option("--max-batch-tokens", type=int, default=65536, help="Maximum number of tokens in a generation batch. Dependent on GPU memory.") 20 | @click.option("--fsdp", "fsdp", is_flag=True, help="Use FSDP.") 21 | @click.option("--temperature", type=float, default=0.2, help="Temperature for generation.") 22 | @click.option("--top-p", type=float, default=0.95, help="Top-p for generation.") 23 | def generate( 24 | prompt_file: str, 25 | model_name: str, 26 | n_per_prompt: int, 27 | output_dir: str, 28 | max_batch_tokens: int, 29 | fsdp: bool, 30 | temperature: float, 31 | top_p: float, 32 | ) -> None: 33 | if not dist.is_initialized() and fsdp: 34 | raise ValueError("Distributed training is not initialized but fsdp is set to True.") 35 | model = get_progen3_model(model_name, use_fsdp=fsdp) 36 | generator = ProGen3Generator( 37 | model=model, 38 | max_batch_tokens=max_batch_tokens, 39 | temperature=temperature, 40 | top_p=top_p, 41 | ) 42 | generator.run(prompt_file, output_dir, n_per_prompt) 43 | 44 | 45 | if __name__ == "__main__": 46 | logging.basicConfig( 47 | format="[%(asctime)s] [%(levelname)s] [%(name)s:%(lineno)s:%(funcName)s] %(message)s", 48 | datefmt="%m/%d/%Y %H:%M:%S", 49 | level=logging.INFO, 50 | ) 51 | dist.setup_dist() 52 | try: 53 | generate() 54 | finally: 55 | dist.destroy_process_group() 56 | -------------------------------------------------------------------------------- /src/progen3/tools/score.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 |
import click 4 | 5 | from progen3.common import dist 6 | from progen3.scorer import ProGen3Scorer 7 | from progen3.tools.utils import get_progen3_model, seed_all 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | @click.command() 13 | @click.option("--model-name", type=str, required=True) 14 | @click.option("--fasta-path", type=str, required=True) 15 | @click.option("--output-path", type=str, required=True) 16 | @click.option( 17 | "--max-batch-tokens", 18 | type=int, 19 | default=65536, 20 | help="Maximum number of tokens to score in a batch. Dependent on GPU memory.", 21 | ) 22 | @click.option("--fsdp", "fsdp", is_flag=True, help="Use FSDP.") 23 | @click.option("--seed", "seed", type=int, help="Seed for random number generators.", default=42) 24 | def score( 25 | fasta_path: str, 26 | model_name: str, 27 | output_path: str, 28 | max_batch_tokens: int, 29 | fsdp: bool, 30 | seed: int, 31 | ) -> None: 32 | logger.info(f"Using fsdp: {fsdp}") 33 | if not dist.is_initialized() and fsdp: 34 | raise ValueError("Distributed training is not initialized but fsdp is set to True.") 35 | seed_all(seed) 36 | model = get_progen3_model(model_name, use_fsdp=fsdp) 37 | scorer = ProGen3Scorer(model, max_batch_tokens=max_batch_tokens) 38 | scorer.run(fasta_path, output_path) 39 | 40 | 41 | if __name__ == "__main__": 42 | logging.basicConfig( 43 | format="[%(asctime)s] [%(levelname)s] [%(name)s:%(lineno)s:%(funcName)s] %(message)s", 44 | datefmt="%m/%d/%Y %H:%M:%S", 45 | level=logging.INFO, 46 | ) 47 | dist.setup_dist() 48 | try: 49 | score() 50 | finally: 51 | dist.destroy_process_group() 52 | -------------------------------------------------------------------------------- /src/progen3/tools/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | from itertools import islice 4 | from pathlib import Path 5 | from typing import Iterator, List, TypeVar 6 | 7 | import numpy as np 8 | import torch 9 | 10 | from progen3.common.model_loading import get_model 11 | from progen3.modeling import ProGen3ForCausalLM 12 | 13 | AVAILABLE_MODELS = [ 14 | "Profluent-Bio/progen3-3b", 15 | "Profluent-Bio/progen3-1b", 16 | "Profluent-Bio/progen3-762m", 17 | "Profluent-Bio/progen3-339m", 18 | "Profluent-Bio/progen3-219m", 19 | "Profluent-Bio/progen3-112m", 20 | ] 21 | FILE_DIR = Path(__file__).parent 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | def get_progen3_model(model_name: str, use_fsdp: bool) -> ProGen3ForCausalLM: 27 | """ 28 | Initializes the model on CPU on rank 0 only; all other ranks are initialized with empty weights. 29 | Returns an FSDP-wrapped model when use_fsdp is True. 30 | """ 31 | if model_name not in AVAILABLE_MODELS: 32 | logger.warning(f"Model {model_name} not in AVAILABLE_MODELS; assuming it's a local path.") 33 | 34 | model = get_model( 35 | model_name_or_path=model_name, model_class=ProGen3ForCausalLM, fsdp=use_fsdp, dtype=torch.bfloat16 36 | ) 37 | return model 38 | 39 | 40 | def seed_all(seed: int) -> None: 41 | torch.manual_seed(seed) 42 | torch.cuda.manual_seed_all(seed) 43 | np.random.seed(seed) 44 | random.seed(seed) 45 | 46 | 47 | def write_fasta_sequences(file_path: str, sequences: dict[str, str]) -> None: 48 | with open(file_path, "w") as f: 49 | for seq_id, seq in sequences.items(): 50 | f.write(f">{seq_id}\n{seq}\n") 51 | 52 | 53 | T = TypeVar("T") 54 | 55 | 56 | def batched(iterator: Iterator[T], n: int) -> Iterator[List[T]]: 57 | "Batch data into lists of length n. The last batch may be shorter."
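# NOTE: `iterator` must be an actual iterator (e.g. the result of iter(...) or a generator),
# not a reusable iterable such as a list or str: islice() restarts a reusable iterable from
# the beginning on every call, so this loop would yield the same first n items forever.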
58 | # batched(iter('ABCDEFG'), 3) --> ABC DEF G 59 | while True: 60 | batch = list(islice(iterator, n)) 61 | if not batch: 62 | return 63 | yield batch 64 | --------------------------------------------------------------------------------
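Example: a minimal sketch of running the scorer programmatically instead of through src/progen3/tools/score.py, mirroring that script's __main__ block. The checkpoint name is taken from AVAILABLE_MODELS; the FASTA and CSV paths are placeholders, and a GPU environment with the package installed is assumed.

from progen3.common import dist
from progen3.scorer import ProGen3Scorer
from progen3.tools.utils import get_progen3_model, seed_all

dist.setup_dist()  # initialize the distributed context, as the CLI tools do
seed_all(42)  # seed torch, numpy and random for reproducibility
model = get_progen3_model("Profluent-Bio/progen3-1b", use_fsdp=False)
scorer = ProGen3Scorer(model, max_batch_tokens=65536)
scorer.run("proteins.fasta", "scores.csv")  # placeholder paths; writes per-sequence log_likelihood and perplexity
dist.destroy_process_group()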