├── .gitignore ├── LICENSE ├── README.md ├── bioreason ├── __init__.py ├── dataset │ ├── __init__.py │ ├── kegg.py │ ├── utils.py │ └── variant_effect.py ├── dna_modules │ ├── __init__.py │ ├── dna_module.py │ └── nucleotide_module.py ├── models │ ├── __init__.py │ ├── dl │ │ ├── __init__.py │ │ ├── chat_template_dl.py │ │ ├── configuration_dl.py │ │ └── processing_dl.py │ ├── dna_llm.py │ ├── dna_only.py │ └── evo2_tokenizer.py ├── trainer │ ├── __init__.py │ ├── demo_grpo.py │ ├── grpo_config.py │ └── grpo_trainer.py └── utils │ ├── __init__.py │ └── dna_utils.py ├── figures ├── Figure1.png ├── Figure2.png └── Figure3.png ├── grpo_trainer_lora_model ├── adapter_config.json └── ds_config_stage2.json ├── pyproject.toml ├── reason.py ├── requirements.txt ├── sh_reason.sh ├── sh_train_dna_only.sh ├── sh_train_dna_qwen.sh ├── train_dna_only.py └── train_dna_qwen.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | .idea/ 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | wandb/ 7 | .DS_Store 8 | .vscode/ 9 | .venv/ 10 | .env 11 | .pytest_cache/ 12 | 13 | # C extensions 14 | *.so 15 | 16 | outputs/ 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | share/python-wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | MANIFEST 37 | 38 | # PyInstaller 39 | # Usually these files are written by a python script from a template 40 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
41 | *.manifest 42 | *.spec 43 | 44 | # Installer logs 45 | pip-log.txt 46 | pip-delete-this-directory.txt 47 | 48 | # Unit test / coverage reports 49 | htmlcov/ 50 | .tox/ 51 | .nox/ 52 | .coverage 53 | .coverage.* 54 | .cache 55 | nosetests.xml 56 | coverage.xml 57 | *.cover 58 | *.py,cover 59 | .hypothesis/ 60 | .pytest_cache/ 61 | cover/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | db.sqlite3 71 | db.sqlite3-journal 72 | 73 | # Flask stuff: 74 | instance/ 75 | .webassets-cache 76 | 77 | # Scrapy stuff: 78 | .scrapy 79 | 80 | # Sphinx documentation 81 | docs/_build/ 82 | 83 | # PyBuilder 84 | .pybuilder/ 85 | target/ 86 | 87 | # Jupyter Notebook 88 | .ipynb_checkpoints 89 | 90 | # IPython 91 | profile_default/ 92 | ipython_config.py 93 | 94 | # pyenv 95 | # For a library or package, you might want to ignore these files since the code is 96 | # intended to run in multiple environments; otherwise, check them in: 97 | # .python-version 98 | 99 | # pipenv 100 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 101 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 102 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 103 | # install all needed dependencies. 104 | #Pipfile.lock 105 | 106 | # UV 107 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 108 | # This is especially recommended for binary packages to ensure reproducibility, and is more 109 | # commonly ignored for libraries. 110 | #uv.lock 111 | 112 | # poetry 113 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 114 | # This is especially recommended for binary packages to ensure reproducibility, and is more 115 | # commonly ignored for libraries. 
116 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 117 | #poetry.lock 118 | 119 | # pdm 120 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 121 | #pdm.lock 122 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 123 | # in version control. 124 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 125 | .pdm.toml 126 | .pdm-python 127 | .pdm-build/ 128 | 129 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 130 | __pypackages__/ 131 | 132 | # Celery stuff 133 | celerybeat-schedule 134 | celerybeat.pid 135 | 136 | # SageMath parsed files 137 | *.sage.py 138 | 139 | # Environments 140 | .env 141 | .venv 142 | env/ 143 | venv/ 144 | ENV/ 145 | env.bak/ 146 | venv.bak/ 147 | 148 | # Spyder project settings 149 | .spyderproject 150 | .spyproject 151 | 152 | # Rope project settings 153 | .ropeproject 154 | 155 | # mkdocs documentation 156 | /site 157 | 158 | # mypy 159 | .mypy_cache/ 160 | .dmypy.json 161 | dmypy.json 162 | 163 | # Pyre type checker 164 | .pyre/ 165 | 166 | # pytype static type analyzer 167 | .pytype/ 168 | 169 | # Cython debug symbols 170 | cython_debug/ 171 | 172 | # PyCharm 173 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 174 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 175 | # and can be added to the global gitignore or merged into this file. For a more nuclear 176 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
177 | #.idea/ 178 | 179 | # PyPI configuration file 180 | .pypirc 181 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 🧬 BioReason
Incentivizing Multimodal Biological Reasoning
within a DNA-LLM Model 3 |

4 | 5 |

6 | arXiv 7 | GitHub 8 | Website 9 | HuggingFace Dataset 10 |

11 | 12 |
13 | 14 | ## Updates [Jun 10, 2025] 15 | - We are integrating vLLM to improve the speed and efficiency of the GRPO pipeline. We expect this to be pushed by the end of the week. 16 | - Checkpoints, along with the custom DNA-LLM model class, will be released on HuggingFace by the end of the week. 17 | - More training results with GRPO will be shared soon. 18 | 19 |
20 | 21 | ## Abstract 22 | 23 | Unlocking deep, interpretable biological reasoning from complex genomic data is a major AI challenge hindering scientific discovery. Current DNA foundation models, despite strong sequence representation, struggle with multi-step reasoning and lack inherently transparent, biologically intuitive explanations. We introduce BioReason, a pioneering architecture that, for the first time, deeply integrates a DNA foundation model with a large language model (LLM). This novel connection enables the LLM to directly process and reason with genomic information as a fundamental input, fostering a new form of multimodal biological understanding. BioReason's sophisticated multi-step reasoning is developed through supervised fine-tuning and targeted reinforcement learning, guiding the system to generate logical, biologically coherent deductions. On biological reasoning benchmarks, including KEGG-based disease pathway prediction—where accuracy improves from 88% to 97%—and variant effect prediction, BioReason demonstrates an average 15% performance gain over strong single-modality baselines. 24 | 25 |
26 | 27 | ## Key Contributions 28 | 29 | • **Novel multimodal architecture**: The first successful integration of a DNA foundation model with an LLM, establishing a new methodology for AI-driven biological studies. 30 | 31 | • **Advanced reasoning methodology**: A systematic training approach combining supervised fine-tuning and reinforcement learning that incentivizes multi-step biological reasoning. 32 | 33 | • **New biological reasoning benchmarks**: Development and curation of novel benchmarks for evaluating biological reasoning capabilities, including an annotated reasoning dataset for gene pathway and disease prediction from KEGG. 34 | 35 | • **Empirical performance improvements**: Demonstration that BioReason outperforms both DNA foundation models and LLMs used independently or in simple combination, with average performance gains of 15%+ over baseline. 36 | 37 | • **Interpretable reasoning traces**: A mechanism for generating step-by-step biological reasoning traces that provide interpretable predictions, enhancing scientific insight and hypothesis generation. 38 | 39 |
40 | 41 | ## Datasets 42 | 43 | The datasets used to train and evaluate BioReason can be found on our [HuggingFace collection](https://huggingface.co/collections/wanglab/bioreason-683cd17172a037a31d208f70) with detailed download and usage instructions. 44 | 45 |
46 | 47 | ## Checkpoints 48 | 49 | We will release the checkpoints soon! 50 | 51 |
52 | 53 | ## Installation 54 | 55 | ### Prerequisites 56 | - Python 3.11+ 57 | - CUDA/GPU for best performance 58 | 59 | ### Installation Steps 60 | ```bash 61 | # Clone the repository 62 | git clone https://github.com/bowang-lab/BioReason.git 63 | cd BioReason 64 | 65 | # Install package 66 | pip install -e . 67 | ``` 68 | 69 |
70 | 71 | ## Results 72 | 73 | ### KEGG-Derived Biological Reasoning Task 74 | Performance comparison on 290 test datapoints for multi-step mechanistic reasoning: 75 | 76 | | Model | Accuracy | F1-Score | Precision | Recall | 77 | |-------|----------|----------|-----------|---------| 78 | | [DNA] NT - 500M | 86.55 | 69.76 | 73.23 | 66.61 | 79 | | [DNA] Evo2 - 1B | 88.28 | 72.43 | 75.23 | 69.83 | 80 | | [LLM] Qwen3 - 1B | 85.17 | 65.71 | 71.39 | 64.19 | 81 | | [LLM] Qwen3 - 4B | 93.48 | 85.44 | 88.31 | 86.72 | 82 | | [DNA-LLM] NT + Qwen3 - 1B | 88.42 | 72.13 | 75.42 | 71.91 | 83 | | [DNA-LLM] NT + Qwen3 - 1B (+RL) | 89.66 | 74.11 | 78.82 | 72.96 | 84 | | [DNA-LLM] NT + Qwen3 - 4B | 96.90 | **89.03** | **90.99** | **89.38** | 85 | | [DNA-LLM] Evo2 + Qwen3 - 1B | 90.42 | 75.62 | 77.42 | 73.91 | 86 | | [DNA-LLM] Evo2 + Qwen3 - 4B | **97.24** | 86.30 | 86.75 | 87.25 | 87 | 88 | ### Variant Effect Prediction Benchmarks 89 | Performance on pathogenic/benign classification: 90 | 91 | | Model | Variant Effect - Coding | | Variant Effect - Non-SNV | | 92 | |-------|------------|----------|------------|----------| 93 | | | Accuracy | F1-Score | Accuracy | F1-Score | 94 | | [DNA] NT - 500M | 60.91 | 45.20 | 67.93 | 65.97 | 95 | | [DNA] Evo2 - 1B | 70.07 | 49.19 | 76.17 | 66.51 | 96 | | [LLM] Qwen3 - 1B | 46.55 | 34.82 | 70.67 | 76.21 | 97 | | [LLM] Qwen3 - 4B | 48.99 | 39.58 | 61.86 | 67.60 | 98 | | [DNA-LLM] NT + Qwen3 - 1B | 55.58 | 54.50 | 72.82 | 76.93 | 99 | | [DNA-LLM] NT + Qwen3 - 4B | 60.94 | 55.66 | 65.59 | 73.00 | 100 | | [DNA-LLM] Evo2 + Qwen3 - 1B | 72.83 | 68.90 | **88.20** | **89.91** | 101 | | [DNA-LLM] Evo2 + Qwen3 - 4B | **80.21** | **80.00** | 83.85 | 85.02 | 102 | 103 |
104 | 105 | ## Citation 106 | 107 | If you find this work useful, please cite our paper: 108 | 109 | ```bibtex 110 | @misc{fallahpour2025bioreasonincentivizingmultimodalbiological, 111 | title={BioReason: Incentivizing Multimodal Biological Reasoning within a DNA-LLM Model}, 112 | author={Adibvafa Fallahpour and Andrew Magnuson and Purav Gupta and Shihao Ma and Jack Naimer and Arnav Shah and Haonan Duan and Omar Ibrahim and Hani Goodarzi and Chris J. Maddison and Bo Wang}, 113 | year={2025}, 114 | eprint={2505.23579}, 115 | archivePrefix={arXiv}, 116 | primaryClass={cs.LG}, 117 | url={https://arxiv.org/abs/2505.23579}, 118 | } 119 | ``` 120 | 121 |
122 | 123 | ## Authors 124 | 125 | - **Adibvafa Fallahpour**¹²³⁵ * (adibvafa.fallahpour@mail.utoronto.ca) 126 | - **Andrew Magnuson**¹² * 127 | - **Purav Gupta**¹² * 128 | - **Shihao Ma**¹²³ 129 | - **Jack Naimer**¹²³ 130 | - **Arnav Shah**¹²³ 131 | - **Haonan Duan**¹² 132 | - **Omar Ibrahim**³ 133 | - **Hani Goodarzi**†⁴⁶ 134 | - **Chris J. Maddison**†¹²⁷ 135 | - **Bo Wang**†¹²³ 136 | 137 | ¹ University of Toronto ² Vector Institute ³ University Health Network (UHN)
138 | ⁴ Arc Institute ⁵ Cohere ⁶ University of California, San Francisco ⁷ Google DeepMind 139 | 140 |
141 | * Equal contribution
142 | † Equal advising 143 | 144 | --- 145 | 146 |

147 | Made with ❤️ at the University of Toronto, the Vector Institute, and the University Health Network 148 |

149 | -------------------------------------------------------------------------------- /bioreason/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/BioReason/e74aa1cf06445aada1e48281840f83403b832b64/bioreason/__init__.py -------------------------------------------------------------------------------- /bioreason/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .kegg import KEGGDataset, split_kegg_dataset 2 | from .utils import torch_to_hf_dataset, truncate_dna 3 | from .variant_effect import get_format_variant_effect_function 4 | 5 | __all__ = [ 6 | "KEGGDataset", 7 | "split_kegg_dataset", 8 | "torch_to_hf_dataset", 9 | "truncate_dna", 10 | "get_format_variant_effect_function", 11 | ] 12 | -------------------------------------------------------------------------------- /bioreason/dataset/kegg.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | import sys 5 | import torch 6 | from torch.utils.data import Dataset, DataLoader 7 | from typing import Any, Dict, List, Tuple 8 | 9 | from bioreason.dataset.utils import torch_to_hf_dataset 10 | from bioreason.models.dl.processing_dl import DLProcessor 11 | from trl.data_utils import maybe_apply_chat_template 12 | 13 | 14 | class KEGGDataset(Dataset): 15 | """Dataset for KEGG data.""" 16 | 17 | def __init__(self, data_dir: str): 18 | """ 19 | Initialize the dataset by loading all JSON files from the given directory. 
20 | 21 | Args: 22 | data_dir: Path to the directory containing JSON files 23 | """ 24 | self.data_dir = data_dir 25 | self.data = [] 26 | 27 | # Load all JSON files 28 | json_files = sorted([f for f in os.listdir(data_dir) if f.endswith(".json")]) 29 | 30 | # Process each file 31 | for filename in json_files: 32 | file_path = os.path.join(data_dir, filename) 33 | kegg_id = filename.split("_")[1] 34 | 35 | with open(file_path, "r", encoding="utf-8") as f: 36 | item = json.load(f) 37 | item["kegg_id"] = kegg_id 38 | processed_item = self._process_item(item) 39 | self.data.append(processed_item) 40 | 41 | def _process_item(self, item: Dict[str, Any]) -> Dict[str, Any]: 42 | """ 43 | Process a single data item to format fields as required. 44 | 45 | Args: 46 | item: Original data item from JSON 47 | 48 | Returns: 49 | Processed data item 50 | """ 51 | # Extract question as is 52 | question = item.get("question", "") 53 | 54 | # Convert answer to lowercase and strip whitespace 55 | answer = item.get("answer", "").lower().strip() 56 | 57 | # Combine reasoning steps into a single paragraph with newlines 58 | reasoning_steps = item.get("reasoning", {}).get("reasoning_steps", []) 59 | reasoning = "\n".join(reasoning_steps) 60 | 61 | # Convert sequences to uppercase and strip whitespace 62 | reference_sequence = item.get("reference_sequence", "").upper().strip() 63 | variant_sequence = item.get("variant_sequence", "").upper().strip() 64 | 65 | return { 66 | "question": question, 67 | "answer": answer, 68 | "reasoning": reasoning, 69 | "reference_sequence": reference_sequence, 70 | "variant_sequence": variant_sequence, 71 | } 72 | 73 | def __len__(self) -> int: 74 | """Return the number of items in the dataset.""" 75 | return len(self.data) 76 | 77 | def __getitem__(self, idx: int) -> Dict[str, Any]: 78 | """Return a specific item from the dataset.""" 79 | return self.data[idx] 80 | 81 | 82 | def split_kegg_dataset( 83 | dataset: KEGGDataset, 84 | train_ratio: float = 0.8, 
85 | val_ratio: float = 0.1, 86 | test_ratio: float = 0.1, 87 | seed: int = 42, 88 | ) -> Tuple[KEGGDataset, KEGGDataset, KEGGDataset]: 89 | """ 90 | Split a KEGG dataset into train, validation, and test sets. 91 | 92 | Args: 93 | dataset: The dataset to split 94 | train_ratio: Proportion of data for training 95 | val_ratio: Proportion of data for validation 96 | test_ratio: Proportion of data for testing 97 | batch_size: Batch size for the dataloaders 98 | seed: Random seed for reproducibility 99 | 100 | Returns: 101 | Tuple of (train_dataset, val_dataset, test_dataset) 102 | """ 103 | # Calculate the size of each split 104 | dataset_size = len(dataset) 105 | train_size = int(train_ratio * dataset_size) 106 | val_size = int(val_ratio * dataset_size) 107 | test_size = dataset_size - train_size - val_size 108 | assert train_ratio + val_ratio + test_ratio == 1.0, "Ratios must sum to 1" 109 | 110 | # Set the random seed 111 | torch.manual_seed(seed) 112 | random.seed(seed) 113 | 114 | # Split the dataset 115 | train_dataset, val_dataset, test_dataset = torch.utils.data.random_split( 116 | dataset, [train_size, val_size, test_size] 117 | ) 118 | 119 | return train_dataset, val_dataset, test_dataset 120 | 121 | 122 | def create_kegg_dataloader( 123 | data_dir: str, 124 | batch_size: int = 2, 125 | shuffle: bool = True, 126 | num_workers: int = 2, 127 | pin_memory: bool = True, 128 | ) -> DataLoader: 129 | """ 130 | Create a DataLoader for the KEGG dataset. 
131 | 132 | Args: 133 | data_dir: Path to the directory containing JSON files 134 | batch_size: Batch size for the dataloader 135 | shuffle: Whether to shuffle the data 136 | num_workers: Number of worker processes for loading data 137 | pin_memory: Whether to pin memory for faster data transfer 138 | 139 | Returns: 140 | DataLoader for the KEGG dataset 141 | """ 142 | dataset = KEGGDataset(data_dir) 143 | return DataLoader( 144 | dataset, 145 | batch_size=batch_size, 146 | shuffle=shuffle, 147 | num_workers=num_workers, 148 | pin_memory=pin_memory, 149 | ) 150 | 151 | 152 | def get_format_kegg_function(model_name: str) -> Any: 153 | """ 154 | Get the appropriate format function for a given model name. 155 | """ 156 | if model_name.lower() == "llm": 157 | return format_kegg_for_llm 158 | elif model_name.lower() == "dna-llm": 159 | return format_kegg_for_dna_llm 160 | else: 161 | raise ValueError(f"Unsupported model name: {model_name}") 162 | 163 | 164 | def format_kegg_for_dna_llm(example: Dict[str, Any]) -> Dict[str, Any]: 165 | """ 166 | Format a KEGG example into the required chat format for DNA-LLM. 167 | """ 168 | return { 169 | "prompt": [ 170 | { 171 | "role": "user", 172 | "content": [ 173 | *({"type": "dna", "text": None} for _ in range(2)), 174 | {"type": "text", "text": example["question"].strip()}, 175 | ], 176 | }, 177 | { 178 | "role": "assistant", 179 | "reasoning_content": example["reasoning"].strip(), 180 | "content": [ 181 | {"type": "text", "text": f"Answer: {example['answer'].strip()}"}, 182 | ], 183 | }, 184 | ], 185 | "dna_sequences": [ 186 | example["reference_sequence"], 187 | example["variant_sequence"], 188 | ], 189 | "answer": example["answer"], 190 | } 191 | 192 | 193 | def format_kegg_for_llm(example: Dict[str, Any]) -> Dict[str, Any]: 194 | """ 195 | Format a KEGG example into the required chat format for LLM. 
196 | """ 197 | question = f"Reference sequence: {example['reference_sequence']}\nVariant sequence: {example['variant_sequence']}\nQuestion: {example['question']}" 198 | return { 199 | "prompt": [ 200 | { 201 | "role": "user", 202 | "content": [ 203 | *({"type": "dna", "text": None} for _ in range(2)), 204 | {"type": "text", "text": question.strip()}, 205 | ], 206 | }, 207 | { 208 | "role": "assistant", 209 | "reasoning_content": example["reasoning"].strip(), 210 | "content": [ 211 | {"type": "text", "text": f"Answer: {example['answer'].strip()}"}, 212 | ], 213 | }, 214 | ], 215 | "dna_sequences": [ 216 | "", 217 | "", 218 | ], 219 | "answer": example["answer"], 220 | } 221 | 222 | 223 | def qwen_dna_collate_fn( 224 | examples: List[Dict], 225 | processor: DLProcessor, 226 | max_length_text: int, 227 | max_length_dna: int, 228 | return_answer_in_batch: bool = False, 229 | ) -> Dict: 230 | """ 231 | Custom collate function for Qwen DNA models. 232 | 233 | Creates a batch with proper labels for supervised fine-tuning where only 234 | the assistant responses contribute to the loss calculation. 
235 | """ 236 | prompts_text = [ 237 | maybe_apply_chat_template(example, processor)["prompt"] for example in examples 238 | ] 239 | batch_dna_sequences = [example["dna_sequences"] for example in examples] 240 | 241 | batch = processor( 242 | text=prompts_text, 243 | batch_dna_sequences=batch_dna_sequences, 244 | return_tensors="pt", 245 | padding=True, 246 | padding_side="left", 247 | add_special_tokens=False, 248 | max_length_text=max_length_text, 249 | max_length_dna=max_length_dna, 250 | ) 251 | 252 | # Create labels tensor filled with -100 (ignored in loss calculation) 253 | labels = torch.full_like(batch["input_ids"], -100) 254 | 255 | # Get token IDs for special markers 256 | assistant_start_marker = "<|im_start|>assistant\n" 257 | im_end_marker = "<|im_end|>" 258 | 259 | assistant_start_token_ids = processor.tokenizer.encode( 260 | assistant_start_marker, add_special_tokens=False 261 | ) 262 | im_end_token_ids = processor.tokenizer.encode( 263 | im_end_marker, add_special_tokens=False 264 | ) 265 | 266 | # Convert token arrays to tensors for faster comparison 267 | assistant_marker_tensor = torch.tensor( 268 | assistant_start_token_ids, device=batch["input_ids"].device 269 | ) 270 | im_end_marker_tensor = torch.tensor( 271 | im_end_token_ids, device=batch["input_ids"].device 272 | ) 273 | 274 | # Get dimensions for easier reference 275 | assistant_marker_len = len(assistant_start_token_ids) 276 | im_end_marker_len = len(im_end_token_ids) 277 | 278 | # For each sequence in the batch 279 | for i in range(batch["input_ids"].shape[0]): 280 | input_ids = batch["input_ids"][i] 281 | seq_len = input_ids.size(0) 282 | 283 | # Track assistant sections 284 | assistant_sections = [] 285 | 286 | # Find all assistant start markers 287 | start_positions = [] 288 | for pos in range(seq_len - assistant_marker_len + 1): 289 | if torch.all( 290 | input_ids[pos : pos + assistant_marker_len] == assistant_marker_tensor 291 | ): 292 | start_positions.append( 293 | pos + 
assistant_marker_len 294 | ) # Store position after marker 295 | 296 | # Find all end markers 297 | end_positions = [] 298 | for pos in range(seq_len - im_end_marker_len + 1): 299 | if torch.all( 300 | input_ids[pos : pos + im_end_marker_len] == im_end_marker_tensor 301 | ): 302 | end_positions.append(pos) # Store position at start of end marker 303 | 304 | # Match start and end markers to create sections 305 | for start_pos in start_positions: 306 | # Find the next end marker after this start position 307 | valid_ends = [pos for pos in end_positions if pos > start_pos] 308 | if valid_ends: 309 | end_pos = min(valid_ends) # Take the first end marker after start 310 | # Only include content between markers (not the markers themselves) 311 | if start_pos < end_pos: 312 | assistant_sections.append((start_pos, end_pos)) 313 | else: 314 | # If no end marker, assume the section runs to the end of the sequence 315 | assistant_sections.append((start_pos, seq_len)) 316 | 317 | # Set labels for all identified assistant sections 318 | for start_pos, end_pos in assistant_sections: 319 | if start_pos < end_pos and start_pos < seq_len: 320 | end_pos = min(end_pos, seq_len) # Safety check 321 | labels[i, start_pos:end_pos] = input_ids[start_pos:end_pos] 322 | 323 | # Also mask padding tokens 324 | labels[batch["input_ids"] == processor.tokenizer.pad_token_id] = -100 325 | 326 | # Add labels to batch 327 | batch["labels"] = labels 328 | 329 | # Add answer to batch 330 | if return_answer_in_batch: 331 | batch["answer"] = [example["answer"].strip() for example in examples] 332 | 333 | return batch 334 | 335 | 336 | def dna_collate_fn( 337 | batch: List[Dict[str, Any]], 338 | dna_tokenizer: Any, 339 | label2id: Dict[str, int], 340 | max_length: int = 2048, 341 | ) -> Dict[str, Any]: 342 | """ 343 | Custom collate function for DNA models. 
344 | """ 345 | ref_sequences = [item["reference_sequence"] for item in batch] 346 | alt_sequences = [item["variant_sequence"] for item in batch] 347 | 348 | # Tokenize DNA sequences separately 349 | tokenized_ref = dna_tokenizer( 350 | ref_sequences, 351 | padding=True, 352 | truncation=True, 353 | max_length=max_length, 354 | return_tensors="pt", 355 | ) 356 | 357 | tokenized_alt = dna_tokenizer( 358 | alt_sequences, 359 | padding=True, 360 | truncation=True, 361 | max_length=max_length, 362 | return_tensors="pt", 363 | ) 364 | 365 | # Get labels 366 | labels = [] 367 | for item in batch: 368 | label = label2id[item["answer"]] 369 | labels.append(label) 370 | 371 | # Create labels tensor 372 | labels_tensor = torch.tensor(labels, dtype=torch.long) 373 | 374 | tokenized_batch = { 375 | "ref_ids": tokenized_ref.input_ids, 376 | "ref_attention_mask": tokenized_ref.attention_mask, 377 | "alt_ids": tokenized_alt.input_ids, 378 | "alt_attention_mask": tokenized_alt.attention_mask, 379 | "labels": labels_tensor, 380 | } 381 | 382 | return tokenized_batch 383 | -------------------------------------------------------------------------------- /bioreason/dataset/utils.py: -------------------------------------------------------------------------------- 1 | from datasets import Dataset as HFDataset 2 | from torch.utils.data import Dataset as TorchDataset 3 | from typing import Dict, Any, Union, List 4 | 5 | 6 | def truncate_dna( 7 | example: Dict[str, Any], truncate_dna_per_side: int = 1024 8 | ) -> Dict[str, Any]: 9 | """ 10 | Truncate DNA sequences by removing a specified number of base pairs from both ends. 11 | If the sequence is too short, it will return the middle portion. 
12 | """ 13 | for key in ["reference_sequence", "variant_sequence"]: 14 | sequence = example[key] 15 | seq_len = len(sequence) 16 | 17 | if seq_len > 2 * truncate_dna_per_side + 8: 18 | example[key] = sequence[truncate_dna_per_side:-truncate_dna_per_side] 19 | 20 | return example 21 | 22 | 23 | def torch_to_hf_dataset(torch_dataset: TorchDataset) -> HFDataset: 24 | """ 25 | Convert a PyTorch Dataset to a Hugging Face Dataset. 26 | 27 | This function takes a PyTorch Dataset and converts it to a Hugging Face Dataset 28 | by extracting all items and organizing them into a dictionary structure that 29 | can be used to create a Hugging Face Dataset. 30 | 31 | Args: 32 | torch_dataset: A PyTorch Dataset object to be converted 33 | 34 | Returns: 35 | A Hugging Face Dataset containing the same data as the input PyTorch Dataset 36 | """ 37 | # Get first item to determine structure 38 | if len(torch_dataset) == 0: 39 | return HFDataset.from_dict({}) 40 | 41 | first_item = torch_dataset[0] 42 | 43 | # Initialize dictionary based on first item's keys 44 | data_dict = ( 45 | {k: [] for k in first_item.keys()} 46 | if isinstance(first_item, dict) 47 | else {"data": []} 48 | ) 49 | 50 | # Populate dictionary 51 | for i in range(len(torch_dataset)): 52 | item = torch_dataset[i] 53 | if isinstance(item, dict): 54 | for k in data_dict: 55 | data_dict[k].append(item[k]) 56 | else: 57 | data_dict["data"].append(item) 58 | 59 | return HFDataset.from_dict(data_dict) 60 | -------------------------------------------------------------------------------- /bioreason/dataset/variant_effect.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | import sys 5 | import torch 6 | from torch.utils.data import Dataset, DataLoader 7 | from typing import Any, Dict, List, Tuple 8 | 9 | from bioreason.dataset.utils import torch_to_hf_dataset 10 | from bioreason.models.dl.processing_dl import DLProcessor 11 | from 
trl.data_utils import maybe_apply_chat_template 12 | 13 | 14 | def get_format_variant_effect_function(model_name: str) -> Any: 15 | """ 16 | Get the appropriate format function for a given model name. 17 | """ 18 | if model_name.lower() == "llm": 19 | return format_variant_effect_for_llm 20 | elif model_name.lower() == "dna-llm": 21 | return format_variant_effect_for_dna_llm 22 | else: 23 | raise ValueError(f"Unsupported model name: {model_name}") 24 | 25 | 26 | def clean_variant_effect_example(example: Dict[str, Any]) -> Dict[str, Any]: 27 | """ 28 | Clean a variant effect example. 29 | """ 30 | example['answer'] = example['answer'].split(";")[0].strip().lower() 31 | return example 32 | 33 | 34 | def clean_variant_effect_non_snv_example(example: Dict[str, Any]) -> Dict[str, Any]: 35 | """ 36 | Clean a variant effect non-SNV example. 37 | """ 38 | example['answer'] = example['answer'].replace("[", "").replace("]", "").replace("'", "").replace("_", " ").strip() 39 | return example 40 | 41 | 42 | def format_variant_effect_for_dna_llm(example: Dict[str, Any]) -> Dict[str, Any]: 43 | """ 44 | Format a VEP example into the required chat format for DNA-LLM. 45 | """ 46 | return { 47 | "prompt": [ 48 | { 49 | "role": "user", 50 | "content": [ 51 | *({"type": "dna", "text": None} for _ in range(2)), 52 | {"type": "text", "text": example["question"].strip()}, 53 | ], 54 | }, 55 | { 56 | "role": "assistant", 57 | "reasoning_content": f"Answer: {example['answer'].strip()}", 58 | "content": [ 59 | {"type": "text", "text": f"Answer: {example['answer'].strip()}"}, 60 | ], 61 | }, 62 | ], 63 | "dna_sequences": [ 64 | example["reference_sequence"], 65 | example["variant_sequence"], 66 | ], 67 | "answer": example["answer"].strip(), 68 | } 69 | 70 | 71 | def format_variant_effect_for_llm(example: Dict[str, Any]) -> Dict[str, Any]: 72 | """ 73 | Format a VEP example into the required chat format for LLM. 
74 | """ 75 | question = f"Reference sequence: {example['reference_sequence']}\nVariant sequence: {example['variant_sequence']}\nQuestion: {example['question']}" 76 | return { 77 | "prompt": [ 78 | { 79 | "role": "user", 80 | "content": [ 81 | *({"type": "dna", "text": None} for _ in range(2)), 82 | {"type": "text", "text": question.strip()}, 83 | ], 84 | }, 85 | { 86 | "role": "assistant", 87 | "reasoning_content": f"Answer: {example['answer'].strip()}", 88 | "content": [ 89 | {"type": "text", "text": f"Answer: {example['answer'].strip()}"}, 90 | ], 91 | }, 92 | ], 93 | "dna_sequences": [ 94 | "", 95 | "", 96 | ], 97 | "answer": example["answer"].strip(), 98 | } -------------------------------------------------------------------------------- /bioreason/dna_modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .dna_module import DNABaseModule 2 | from .nucleotide_module import NucleotideDNAModule 3 | 4 | __all__ = ["DNABaseModule", "NucleotideDNAModule"] -------------------------------------------------------------------------------- /bioreason/dna_modules/dna_module.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Dict, Any, Union 3 | import torch 4 | 5 | class DNABaseModule(ABC): 6 | def __init__(self): 7 | super().__init__() 8 | 9 | @abstractmethod 10 | def get_dnallm_key(self): 11 | pass 12 | 13 | @abstractmethod 14 | def get_model_class(self, model_id: str, model_init_kwargs: dict): 15 | pass 16 | 17 | def post_model_init(self, model, processing_class): 18 | pass 19 | 20 | def is_embeds_input(self): 21 | return False 22 | 23 | @abstractmethod 24 | def get_processing_class(self): 25 | pass 26 | 27 | @abstractmethod 28 | def get_dnallm_modules_keywords(self): 29 | pass 30 | 31 | @abstractmethod 32 | def get_custom_multimodal_keywords(self): 33 | pass 34 | 35 | @abstractmethod 36 | def 
get_non_generate_params(self): 37 | pass 38 | 39 | @abstractmethod 40 | def get_custom_processing_keywords(self): 41 | pass 42 | 43 | @abstractmethod 44 | def prepare_prompt(self, processing_class, inputs: dict[str, Union[torch.Tensor, Any]]): 45 | pass 46 | 47 | @abstractmethod 48 | def prepare_model_inputs(self, processing_class, prompts_text, images, return_tensors, padding, padding_side, add_special_tokens): 49 | pass -------------------------------------------------------------------------------- /bioreason/dna_modules/nucleotide_module.py: -------------------------------------------------------------------------------- 1 | from transformers import ( 2 | Qwen2_5_VLForConditionalGeneration, 3 | Qwen2VLForConditionalGeneration, 4 | AutoProcessor, 5 | ) 6 | from typing import Dict, Any, Union, List, Optional, Callable, Type 7 | from trl.data_utils import maybe_apply_chat_template 8 | from trl import SFTTrainer 9 | import torch 10 | 11 | from bioreason.dna_modules.dna_module import DNABaseModule 12 | from bioreason.models.dna_llm import DNALLMModel 13 | from bioreason.models.dl.processing_dl import DLProcessor 14 | 15 | 16 | class NucleotideDNAModule(DNABaseModule): 17 | """ 18 | DNA module implementation for NucleotideTransformer-based models. 19 | 20 | This module provides the interface between DNA-LLM models and the training 21 | infrastructure, handling model loading, processing setup, and reward functions. 22 | """ 23 | 24 | def __init__(self): 25 | """Initialize the NucleotideDNAModule.""" 26 | super().__init__() 27 | 28 | def get_dnallm_key(self) -> str: 29 | """ 30 | Get the key identifier for this DNA-LLM implementation. 31 | 32 | Returns: 33 | String identifier for this module type 34 | """ 35 | return "qwen" 36 | 37 | def get_model_class(self, model_id: str, model_init_kwargs: Dict[str, Any]) -> Type: 38 | """ 39 | Return the appropriate model class based on model ID. 
40 | 41 | Args: 42 | model_id: Identifier for the model 43 | model_init_kwargs: Initialization arguments for the model 44 | 45 | Returns: 46 | The model class to instantiate 47 | 48 | Raises: 49 | ValueError: If the model is not supported 50 | """ 51 | if "DNALLM" in model_id: 52 | model_cls = DNALLMModel 53 | else: 54 | raise ValueError(f"Unsupported model: {model_id}") 55 | return model_cls 56 | 57 | def post_model_init(self, model: Any, processing_class: Any) -> None: 58 | """ 59 | Perform any post-initialization setup on the model. 60 | 61 | Args: 62 | model: The initialized model 63 | processing_class: The processor for the model 64 | """ 65 | # No post-init needed for this implementation 66 | pass 67 | 68 | def get_processing_class(self) -> Type: 69 | """ 70 | Get the processing class to use with this DNA-LLM model. 71 | 72 | Returns: 73 | The processing class 74 | """ 75 | return DLProcessor 76 | 77 | def get_dnallm_modules_keywords(self) -> List[str]: 78 | """ 79 | Get keywords to identify DNA-specific modules in the model. 80 | 81 | Used to exclude DNA modules from LoRA adaptation during training. 82 | 83 | Returns: 84 | List of keywords that identify DNA modules 85 | """ 86 | return ["dna"] 87 | 88 | def get_custom_multimodal_keywords(self) -> List[str]: 89 | """ 90 | Get keywords for multimodal inputs that should be passed to the model. 91 | 92 | Returns: 93 | List of input keywords for multimodal processing 94 | """ 95 | return ["dna_tokenized", "batch_idx_map"] 96 | 97 | def get_non_generate_params(self) -> List[str]: 98 | """ 99 | Get parameter names that should be excluded from generation. 100 | 101 | Returns: 102 | List of parameter names to exclude from generation calls 103 | """ 104 | return [] 105 | 106 | def get_custom_processing_keywords(self) -> List[tuple]: 107 | """ 108 | Get custom processing keywords for the processor. 
109 | 110 | Returns: 111 | List of (component, parameter) tuples for custom processing 112 | """ 113 | return [("dna_tokenizer", "max_length")] 114 | 115 | def prepare_prompt( 116 | self, processing_class: Any, inputs: List[Dict[str, Union[torch.Tensor, Any]]] 117 | ) -> List[str]: 118 | """ 119 | Prepare prompts from input examples. 120 | 121 | Args: 122 | processing_class: The processor to use 123 | inputs: List of input examples 124 | 125 | Returns: 126 | List of prepared prompts 127 | """ 128 | prompts_text = [ 129 | maybe_apply_chat_template(example, processing_class)["prompt"] 130 | for example in inputs 131 | ] 132 | return prompts_text 133 | 134 | def prepare_model_inputs( 135 | self, 136 | processing_class: Any, 137 | model: Any, 138 | prompts_text: List[str], 139 | batch_dna_sequences: List[List[str]], 140 | return_tensors: str = "pt", 141 | padding: bool = True, 142 | padding_side: str = "left", 143 | add_special_tokens: bool = False, 144 | ) -> Dict[str, Any]: 145 | """ 146 | Prepare inputs for the model. 
147 | 148 | Args: 149 | processing_class: The processor to use 150 | model: The model to prepare inputs for 151 | prompts_text: List of text prompts 152 | batch_dna_sequences: List of lists of DNA sequences 153 | return_tensors: Return format for tensors 154 | padding: Whether to pad inputs 155 | padding_side: Side to pad on 156 | add_special_tokens: Whether to add special tokens 157 | 158 | Returns: 159 | Processed inputs for the model 160 | """ 161 | # Handle DataParallel wrapped models by accessing the module attribute if needed 162 | max_length_text = model.max_length_text if not hasattr(model, 'module') else model.module.max_length_text 163 | max_length_dna = model.max_length_dna if not hasattr(model, 'module') else model.module.max_length_dna 164 | 165 | prompt_inputs = processing_class( 166 | text=prompts_text, 167 | batch_dna_sequences=batch_dna_sequences, 168 | return_tensors=return_tensors, 169 | padding=padding, 170 | padding_side=padding_side, 171 | add_special_tokens=add_special_tokens, 172 | max_length_text=max_length_text, 173 | max_length_dna=max_length_dna, 174 | ) 175 | 176 | return prompt_inputs 177 | 178 | def is_embeds_input(self) -> bool: 179 | """ 180 | Whether the model uses embeddings as input (instead of token IDs). 181 | 182 | Returns: 183 | Boolean indicating if the model takes embedding inputs 184 | """ 185 | return True 186 | 187 | @staticmethod 188 | def get_question_template() -> str: 189 | """ 190 | Get the template for formatting questions. 191 | 192 | Returns: 193 | String template for questions 194 | """ 195 | return "{Question}" 196 | 197 | @staticmethod 198 | def format_reward_rec(completions: List[Dict[str, Any]], **kwargs) -> List[float]: 199 | """ 200 | Check if the Qwen model output matches a specific format. 
201 | 202 | Args: 203 | completions: List of model completions 204 | **kwargs: Additional arguments 205 | 206 | Returns: 207 | List of reward scores (1.0 for match, 0.0 for no match) 208 | """ 209 | import re 210 | import os 211 | from datetime import datetime 212 | 213 | # Pattern to match the expected output format 214 | pattern = r".*?\s*.*?\{.*\[\d+,\s*\d+,\s*\d+,\s*\d+\].*\}.*?" 215 | completion_contents = [completion[0]["content"] for completion in completions] 216 | matches = [ 217 | re.search(pattern, content, re.DOTALL) is not None 218 | for content in completion_contents 219 | ] 220 | 221 | # Log format results if in debug mode 222 | current_time = datetime.now().strftime("%d-%H-%M-%S-%f") 223 | if os.getenv("DEBUG_MODE") == "true": 224 | log_path = os.getenv("LOG_PATH") 225 | with open( 226 | log_path.replace(".txt", "_format.txt"), "a", encoding="utf-8" 227 | ) as f: 228 | f.write(f"------------- {current_time} Format reward -------------\n") 229 | for content, match in zip(completion_contents, matches): 230 | f.write(f"Content: {content}\n") 231 | f.write(f"Has format: {bool(match)}\n") 232 | 233 | return [1.0 if match else 0.0 for match in matches] 234 | 235 | @staticmethod 236 | def select_reward_func(func: str, task_type: str) -> Callable: 237 | """ 238 | Select the appropriate reward function based on function name and task type. 239 | 240 | Args: 241 | func: The type of reward function ('accuracy', 'format', etc.) 242 | task_type: The type of task ('rec', etc.) 
243 | 244 | Returns: 245 | The reward function to use 246 | 247 | Raises: 248 | ValueError: If the function or task type is not supported 249 | """ 250 | if func == "accuracy": 251 | match task_type: 252 | case "rec": 253 | return NucleotideDNAModule.iou_reward 254 | case _: 255 | raise ValueError(f"Unsupported reward function: {func}") 256 | elif func == "format": 257 | match task_type: 258 | case "rec": 259 | return NucleotideDNAModule.format_reward_rec 260 | case _: 261 | raise ValueError(f"Unsupported reward function: {func}") 262 | else: 263 | raise ValueError(f"Unsupported reward function: {func}") -------------------------------------------------------------------------------- /bioreason/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .dna_only import DNAClassifierModel 2 | from .dna_llm import DNALLMModel 3 | from .evo2_tokenizer import Evo2Tokenizer 4 | 5 | __all__ = [ 6 | "DNAClassifierModel", 7 | "DNALLMModel", 8 | "Evo2Tokenizer", 9 | ] 10 | -------------------------------------------------------------------------------- /bioreason/models/dl/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /bioreason/models/dl/chat_template_dl.py: -------------------------------------------------------------------------------- 1 | CHAT_TEMPLATE = "{%- set dna_count = namespace(value=0) %}{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML 
tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and not(message.content is string and message.content.startswith('') and message.content.endswith('')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' }} {%- if message.content is string %}{{- message.content + '<|im_end|>' + '\\n' }}{%- else %}{%- for content in message.content %}{%- if content.type == 'dna' or 'dna' in content %}{%- set dna_count.value = dna_count.value + 1 %}{%- if add_dna_id %}DNA Sequence {{- dna_count.value }}: {%- endif %}<|dna_start|><|dna_pad|><|dna_end|>{%- elif 'text' in content %}{{- content.text }}{%- endif %}{%- endfor %}{{- '<|im_end|>' + '\\n' }}{%- endif %}{%- elif message.role == \"assistant\" %}\n {%- set content = message.content[0].text %}\n {%- set reasoning_content = '' %}\n {%- if message.reasoning_content is defined and message.reasoning_content is not none %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '' in message.content %}\n {%- set content = message.content[0].text.split('')[-1].lstrip('\\n') %}\n {%- set reasoning_content = message.content[0].text.split('')[0].rstrip('\\n').split('')[-1].lstrip('\\n') %}\n {%- endif %}\n {%- endif %}\n {%- if loop.index0 > ns.last_query_index %}\n {%- if loop.last or (not loop.last and reasoning_content) %}\n {{- '<|im_start|>' + message.role + '\\n\\n' + 
reasoning_content.strip('\\n') + '\\n\\n\\n' + content.lstrip('\\n') }}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n\\n' }}\n {{- message.content }}\n {{- '\\n' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n {{- '\\n\\n\\n\\n' }}\n {%- endif %}\n{%- endif %}" -------------------------------------------------------------------------------- /bioreason/models/dl/configuration_dl.py: -------------------------------------------------------------------------------- 1 | from transformers import PretrainedConfig 2 | 3 | class DLDNAConfig(PretrainedConfig): 4 | model_type = "dl" 5 | base_config_key = "dna_config" 6 | 7 | def __init__( 8 | self, 9 | depth=32, 10 | hidden_size=3584, 11 | hidden_act="silu", 12 | intermediate_size=3420, 13 | num_heads=16, 14 | in_channels=3, 15 | patch_size=14, 16 | spatial_merge_size=2, 17 | temporal_patch_size=2, 18 | tokens_per_second=4, 19 | window_size=112, 20 | out_hidden_size=3584, 21 | 
fullatt_block_indexes=[7, 15, 23, 31], 22 | **kwargs, 23 | ): 24 | super().__init__(**kwargs) 25 | 26 | self.depth = depth 27 | self.hidden_size = hidden_size 28 | self.hidden_act = hidden_act 29 | self.intermediate_size = intermediate_size 30 | self.num_heads = num_heads 31 | self.in_channels = in_channels 32 | self.patch_size = patch_size 33 | self.spatial_merge_size = spatial_merge_size 34 | self.temporal_patch_size = temporal_patch_size 35 | self.tokens_per_second = tokens_per_second 36 | self.window_size = window_size 37 | self.fullatt_block_indexes = fullatt_block_indexes 38 | self.out_hidden_size = out_hidden_size 39 | 40 | class DLConfig(PretrainedConfig): 41 | r""" 42 | This is the configuration class to store the configuration of a [`Qwen2_5_VLModel`]. It is used to instantiate a 43 | Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration 44 | with the defaults will yield a similar configuration to that of 45 | Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct). 46 | 47 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 48 | documentation from [`PretrainedConfig`] for more information. 49 | 50 | 51 | Args: 52 | vocab_size (`int`, *optional*, defaults to 152064): 53 | Vocabulary size of the Qwen2_5_VL model. Defines the number of different tokens that can be represented by the 54 | `inputs_ids` passed when calling [`Qwen2_5_VLModel`] 55 | hidden_size (`int`, *optional*, defaults to 8192): 56 | Dimension of the hidden representations. 57 | intermediate_size (`int`, *optional*, defaults to 29568): 58 | Dimension of the MLP representations. 59 | num_hidden_layers (`int`, *optional*, defaults to 80): 60 | Number of hidden layers in the Transformer encoder. 
61 | num_attention_heads (`int`, *optional*, defaults to 64): 62 | Number of attention heads for each attention layer in the Transformer encoder. 63 | num_key_value_heads (`int`, *optional*, defaults to 8): 64 | This is the number of key_value heads that should be used to implement Grouped Query Attention. If 65 | `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if 66 | `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When 67 | converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed 68 | by meanpooling all the original heads within that group. For more details checkout [this 69 | paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. 70 | hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): 71 | The non-linear activation function (function or string) in the decoder. 72 | max_position_embeddings (`int`, *optional*, defaults to 32768): 73 | The maximum sequence length that this model might ever be used with. 74 | initializer_range (`float`, *optional*, defaults to 0.02): 75 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 76 | rms_norm_eps (`float`, *optional*, defaults to 1e-05): 77 | The epsilon used by the rms normalization layers. 78 | use_cache (`bool`, *optional*, defaults to `True`): 79 | Whether or not the model should return the last key/values attentions (not used by all models). Only 80 | relevant if `config.is_decoder=True`. 81 | tie_word_embeddings (`bool`, *optional*, defaults to `False`): 82 | Whether the model's input and output word embeddings should be tied. 83 | rope_theta (`float`, *optional*, defaults to 1000000.0): 84 | The base period of the RoPE embeddings. 85 | use_sliding_window (`bool`, *optional*, defaults to `False`): 86 | Whether to use sliding window attention. 
87 | sliding_window (`int`, *optional*, defaults to 4096): 88 | Sliding window attention (SWA) window size. If not specified, will default to `4096`. 89 | max_window_layers (`int`, *optional*, defaults to 80): 90 | The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. 91 | attention_dropout (`float`, *optional*, defaults to 0.0): 92 | The dropout ratio for the attention probabilities. 93 | vision_config (`Dict`, *optional*): 94 | The config for the visual encoder initialization. 95 | rope_scaling (`Dict`, *optional*): 96 | Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type 97 | and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value 98 | accordingly. 99 | Expected contents: 100 | `rope_type` (`str`): 101 | The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', 102 | 'llama3'], with 'default' being the original RoPE implementation. 103 | `factor` (`float`, *optional*): 104 | Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In 105 | most scaling types, a `factor` of x will enable the model to handle sequences of length x * 106 | original maximum pre-trained length. 107 | `original_max_position_embeddings` (`int`, *optional*): 108 | Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during 109 | pretraining. 110 | `attention_factor` (`float`, *optional*): 111 | Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention 112 | computation. If unspecified, it defaults to value recommended by the implementation, using the 113 | `factor` field to infer the suggested value. 114 | `beta_fast` (`float`, *optional*): 115 | Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear 116 | ramp function. 
If unspecified, it defaults to 32. 117 | `beta_slow` (`float`, *optional*): 118 | Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear 119 | ramp function. If unspecified, it defaults to 1. 120 | `short_factor` (`List[float]`, *optional*): 121 | Only used with 'longrope'. The scaling factor to be applied to short contexts (< 122 | `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden 123 | size divided by the number of attention heads divided by 2 124 | `long_factor` (`List[float]`, *optional*): 125 | Only used with 'longrope'. The scaling factor to be applied to long contexts (< 126 | `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden 127 | size divided by the number of attention heads divided by 2 128 | `low_freq_factor` (`float`, *optional*): 129 | Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE 130 | `high_freq_factor` (`float`, *optional*): 131 | Only used with 'llama3'. 
Scaling factor applied to high frequency components of the RoPE 132 | 133 | ```python 134 | >>> from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLConfig 135 | 136 | >>> # Initializing a Qwen2_5_VL style configuration 137 | >>> configuration = Qwen2_5_VLConfig() 138 | 139 | >>> # Initializing a model from the Qwen2-VL-7B style configuration 140 | >>> model = Qwen2_5_VLForConditionalGeneration(configuration) 141 | 142 | >>> # Accessing the model configuration 143 | >>> configuration = model.config 144 | ```""" 145 | 146 | model_type = "dl" 147 | sub_configs = {"dna_config": DLDNAConfig} 148 | keys_to_ignore_at_inference = ["past_key_values"] 149 | # Default tensor parallel plan for base model `Qwen2_5_VL` 150 | base_model_tp_plan = { 151 | "layers.*.self_attn.q_proj": "colwise", 152 | "layers.*.self_attn.k_proj": "colwise", 153 | "layers.*.self_attn.v_proj": "colwise", 154 | "layers.*.self_attn.o_proj": "rowwise", 155 | "layers.*.mlp.gate_proj": "colwise", 156 | "layers.*.mlp.up_proj": "colwise", 157 | "layers.*.mlp.down_proj": "rowwise", 158 | } 159 | base_model_pp_plan = { 160 | "embed_tokens": (["input_ids"], ["inputs_embeds"]), 161 | "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), 162 | "norm": (["hidden_states"], ["hidden_states"]), 163 | } 164 | 165 | def __init__( 166 | self, 167 | vocab_size=152064, 168 | hidden_size=8192, 169 | intermediate_size=29568, 170 | num_hidden_layers=80, 171 | num_attention_heads=64, 172 | num_key_value_heads=8, 173 | hidden_act="silu", 174 | max_position_embeddings=32768, 175 | initializer_range=0.02, 176 | rms_norm_eps=1e-05, 177 | use_cache=True, 178 | tie_word_embeddings=False, 179 | rope_theta=1000000.0, 180 | use_sliding_window=False, 181 | sliding_window=4096, 182 | max_window_layers=80, 183 | attention_dropout=0.0, 184 | vision_config=None, 185 | rope_scaling=None, 186 | image_token_id=None, 187 | **kwargs, 188 | ): 189 | if isinstance(vision_config, dict): 190 | self.vision_config = 
self.sub_configs["vision_config"](**vision_config) 191 | elif vision_config is None: 192 | self.vision_config = self.sub_configs["vision_config"]() 193 | 194 | self.vocab_size = vocab_size 195 | self.max_position_embeddings = max_position_embeddings 196 | self.hidden_size = hidden_size 197 | self.intermediate_size = intermediate_size 198 | self.num_hidden_layers = num_hidden_layers 199 | self.num_attention_heads = num_attention_heads 200 | self.use_sliding_window = use_sliding_window 201 | self.sliding_window = sliding_window 202 | self.max_window_layers = max_window_layers 203 | 204 | # for backward compatibility 205 | if num_key_value_heads is None: 206 | num_key_value_heads = num_attention_heads 207 | 208 | self.num_key_value_heads = num_key_value_heads 209 | self.hidden_act = hidden_act 210 | self.initializer_range = initializer_range 211 | self.rms_norm_eps = rms_norm_eps 212 | self.use_cache = use_cache 213 | self.rope_theta = rope_theta 214 | self.attention_dropout = attention_dropout 215 | self.rope_scaling = rope_scaling 216 | 217 | self.dna_token_id = image_token_id 218 | 219 | # Validate the correctness of rotary position embeddings parameters 220 | # BC: if there is a 'type' field, move it to 'rope_type'. 221 | # and change type from 'mrope' to 'default' because `mrope` does default RoPE calculations 222 | # one can set it to "linear"/"dynamic" etc. 
to have scaled RoPE 223 | # TODO: @raushan update config in the hub 224 | if self.rope_scaling is not None and "type" in self.rope_scaling: 225 | if self.rope_scaling["type"] == "mrope": 226 | self.rope_scaling["type"] = "default" 227 | self.rope_scaling["rope_type"] = self.rope_scaling["type"] 228 | rope_config_validation(self, ignore_keys={"mrope_section"}) 229 | 230 | super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) 231 | 232 | __all__ = ["DLConfig"] -------------------------------------------------------------------------------- /bioreason/models/dl/processing_dl.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Union, Dict, Any, Tuple 2 | 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | 7 | from transformers import AutoTokenizer 8 | from transformers.processing_utils import ( 9 | CommonKwargs, 10 | ProcessingKwargs, 11 | ProcessorMixin, 12 | Unpack, 13 | ) 14 | from transformers.feature_extraction_utils import BatchFeature 15 | from transformers.tokenization_utils_base import PreTokenizedInput, TextInput 16 | from transformers.utils import logging 17 | 18 | from bioreason.utils.dna_utils import DNAInput 19 | 20 | class DLDNAKwargs(CommonKwargs): 21 | """Keyword arguments specific to DNA processing""" 22 | max_length_text: Optional[int] 23 | max_length_dna: Optional[int] 24 | 25 | 26 | class DLProcessorKwargs(ProcessingKwargs, total=False): 27 | """Processing keyword arguments for the DL processor""" 28 | dna_kwargs: DLDNAKwargs 29 | _defaults = { 30 | "text_kwargs": { 31 | "padding": False, 32 | }, 33 | } 34 | 35 | class DLProcessor(ProcessorMixin): 36 | r""" 37 | Constructs a DL processor which wraps a NucleotideTransformer DNA processor and a Qwen2_5 tokenizer into a single processor. 38 | This processor handles both text and DNA sequence processing to prepare inputs for the DNALLMModel. 
39 | 40 | Args: 41 | tokenizer (PreTrainedTokenizerBase, *optional*): 42 | The text tokenizer used for processing text inputs. 43 | dna_tokenizer (PreTrainedTokenizerBase, *optional*): 44 | The DNA tokenizer used for processing DNA sequences. 45 | chat_template (`str`, *optional*): 46 | A Jinja template for chat formatting. If None, will use the tokenizer's template. 47 | """ 48 | 49 | attributes = ["tokenizer", "dna_tokenizer"] 50 | valid_kwargs = ["model", "chat_template"] 51 | tokenizer_class = ( 52 | "Qwen2Tokenizer", "Qwen2TokenizerFast", 53 | "GPT2TokenizerFast", 54 | ) 55 | dna_tokenizer_class = ("EsmTokenizer", "Evo2Tokenizer") 56 | 57 | def __init__( 58 | self, tokenizer=None, dna_tokenizer=None, chat_template=None, **kwargs 59 | ): 60 | """ 61 | Initialize the processor with text and DNA tokenizers. 62 | 63 | Args: 64 | tokenizer: Text tokenizer (usually from a language model) 65 | dna_tokenizer: DNA tokenizer (usually from a DNA model) 66 | chat_template: Template for formatting chat conversations 67 | **kwargs: Additional arguments 68 | """ 69 | self.tokenizer = tokenizer 70 | self.dna_tokenizer = dna_tokenizer 71 | 72 | self.dna_token = ( 73 | "<|dna_pad|>" 74 | if not hasattr(self.tokenizer, "dna_token") 75 | else self.tokenizer.dna_token 76 | ) 77 | 78 | # Get chat template from tokenizer if not provided 79 | if chat_template is None and hasattr(self.tokenizer, "chat_template"): 80 | chat_template = self.tokenizer.chat_template 81 | super().__init__(tokenizer, dna_tokenizer, chat_template=chat_template) 82 | 83 | # The GRPO trainer might expect this to be set 84 | if not hasattr(self.tokenizer, 'pad_token') or self.tokenizer.pad_token is None: 85 | self.tokenizer.pad_token = self.tokenizer.eos_token 86 | 87 | def tokenize_dna_sequences( 88 | self, 89 | batch_dna_sequences: List[List[str]], 90 | max_length: int = 2048, 91 | return_tensors: str = "pt", 92 | device: str = "cuda", 93 | ) -> Dict[str, Any]: 94 | """ 95 | Tokenize a batch of DNA sequences. 
96 | 97 | Args: 98 | batch_dna_sequences: List of lists of DNA sequences per batch item 99 | max_length: Maximum allowed length for DNA sequences 100 | return_tensors: Return format for tensors ("pt" for PyTorch) 101 | device: Device to place tensors on 102 | 103 | Returns: 104 | Dict containing: 105 | - dna_tokenized: The tokenized DNA sequences 106 | - batch_idx_map: Mapping of which sequences belong to which batch item 107 | """ 108 | # Create a mapping to track which sequences belong to which batch item 109 | batch_idx_map = [] 110 | all_sequences = [] 111 | 112 | # Flatten all sequences with batch tracking 113 | for batch_idx, dna_sequences in enumerate(batch_dna_sequences): 114 | for seq in dna_sequences: 115 | all_sequences.append(seq) 116 | batch_idx_map.append(batch_idx) 117 | 118 | # If no sequences in the entire batch, return empty dict 119 | if not all_sequences: 120 | return {"dna_tokenized": None, "batch_idx_map": []} 121 | 122 | # Tokenize all sequences at once 123 | dna_tokenized = self.dna_tokenizer( 124 | all_sequences, 125 | padding=True, 126 | truncation=True, 127 | max_length=max_length, 128 | return_tensors=return_tensors, 129 | return_attention_mask=True, 130 | ) 131 | 132 | return {"dna_tokenized": dna_tokenized, "batch_idx_map": batch_idx_map} 133 | 134 | def __call__( 135 | self, 136 | batch_dna_sequences: Optional[List[List[str]]] = None, 137 | text: Optional[ 138 | Union[ 139 | TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput] 140 | ] 141 | ] = None, 142 | max_length_text: int = 512, 143 | max_length_dna: int = 2048, 144 | return_tensors: str = "pt", 145 | device: str = "cuda", 146 | **kwargs: Unpack[DLProcessorKwargs], 147 | ) -> BatchFeature: 148 | """ 149 | Process text and DNA sequences for model input. 
150 | 151 | Args: 152 | batch_dna_sequences: List of lists of DNA sequences per batch item 153 | text: Input text or list of texts 154 | max_length_text: Maximum length for text sequences 155 | max_length_dna: Maximum length for DNA sequences 156 | return_tensors: Return format for tensors 157 | device: Device to place tensors on 158 | **kwargs: Additional processor keyword arguments 159 | 160 | Returns: 161 | BatchFeature with tokenized inputs for the model 162 | """ 163 | output_kwargs = self._merge_kwargs( 164 | DLProcessorKwargs, 165 | tokenizer_init_kwargs=self.tokenizer.init_kwargs, 166 | **kwargs, 167 | ) 168 | 169 | # Ensure text is a list 170 | if not isinstance(text, list): 171 | text = [text] 172 | 173 | # flattened_dna_sequences = [dna_sequence for dna_sequences in batch_dna_sequences for dna_sequence in dna_sequences] 174 | dna_inputs = {} 175 | if batch_dna_sequences is not None: 176 | # Tokenize DNA sequences 177 | dna_processing_result = self.tokenize_dna_sequences( 178 | batch_dna_sequences, 179 | max_length=max_length_dna, 180 | return_tensors=return_tensors, 181 | device=device, 182 | ) 183 | 184 | # Expand each DNA token into one placeholder per unpadded DNA token; count via the attention mask (not a hard-coded pad id) so it matches the number of DNA embeddings the model splices in 185 | index = 0 186 | for i in range(len(text)): 187 | while self.dna_token in text[i]: 188 | num_dna_tokens = int(dna_processing_result['dna_tokenized']['attention_mask'][index].sum().item()) 189 | text[i] = text[i].replace( 190 | self.dna_token, "<|placeholder|>" * num_dna_tokens, 1 191 | ) 192 | index += 1 193 | text[i] = text[i].replace("<|placeholder|>", self.dna_token) 194 | 195 | 196 | 197 | # Add batch info to the output 198 | dna_inputs = { 199 | # "batch_dna_sequences": batch_dna_sequences, 200 | "dna_tokenized": dna_processing_result["dna_tokenized"], 201 | "batch_idx_map": dna_processing_result["batch_idx_map"], 202 | } 203 | 204 | # Tokenize text 205 | text_kwargs = output_kwargs.get("text_kwargs", {}) 206 | 207 | if 'padding' in text_kwargs: 208 | del text_kwargs['padding'] 209 | 210 | # print("__call__
(processor):", text) 211 | text_inputs = self.tokenizer( 212 | text, 213 | max_length=max_length_text + 2 * max_length_dna, 214 | return_tensors=return_tensors, 215 | padding=True, 216 | truncation=True, 217 | **text_kwargs, 218 | ) 219 | 220 | # The BatchFeature should have all required fields for the model's forward pass 221 | return BatchFeature(data={**text_inputs, **dna_inputs}) 222 | 223 | def batch_decode(self, *args, **kwargs) -> List[str]: 224 | """ 225 | This method forwards all its arguments to the tokenizer's batch_decode. 226 | 227 | Returns: 228 | List of decoded strings 229 | """ 230 | return self.tokenizer.batch_decode(*args, **kwargs) 231 | 232 | def decode(self, *args, **kwargs) -> str: 233 | """ 234 | This method forwards all its arguments to the tokenizer's decode. 235 | 236 | Returns: 237 | Decoded string 238 | """ 239 | return self.tokenizer.decode(*args, **kwargs) 240 | 241 | def post_process_dna_to_text( 242 | self, 243 | generated_outputs: torch.Tensor, 244 | skip_special_tokens: bool = True, 245 | **kwargs, 246 | ) -> List[str]: 247 | """ 248 | Post-process the model output to decode the text. 249 | 250 | Args: 251 | generated_outputs: The token IDs generated by the model 252 | skip_special_tokens: Whether to skip special tokens in the output 253 | **kwargs: Additional arguments for the decoder 254 | 255 | Returns: 256 | List of decoded strings 257 | """ 258 | return self.tokenizer.batch_decode( 259 | generated_outputs, 260 | skip_special_tokens=skip_special_tokens, 261 | **kwargs, 262 | ) 263 | 264 | @property 265 | def model_input_names(self) -> List[str]: 266 | """ 267 | Get the input names expected by the model. 
268 | 269 | Returns: 270 | List of input names 271 | """ 272 | tokenizer_input_names = self.tokenizer.model_input_names 273 | dna_input_names = ["dna_tokenized", "batch_idx_map"] 274 | 275 | return list(dict.fromkeys(tokenizer_input_names + dna_input_names)) 276 | -------------------------------------------------------------------------------- /bioreason/models/dna_llm.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import ArgumentParser 3 | import torch 4 | import torch.nn as nn 5 | from transformers import ( 6 | AutoTokenizer, 7 | AutoModelForCausalLM, 8 | AutoModelForMaskedLM, 9 | ) 10 | 11 | from typing import Optional, List, Dict, Any, Union, Tuple 12 | 13 | from bioreason.utils.dna_utils import DNAInput 14 | from bioreason.models.dl.processing_dl import DLProcessor 15 | from bioreason.models.dl.chat_template_dl import CHAT_TEMPLATE 16 | from bioreason.models.evo2_tokenizer import Evo2Tokenizer 17 | 18 | class DNALLMModel(nn.Module): 19 | """ 20 | A combined model that processes both DNA sequences and text inputs. 21 | 22 | The model uses a DNA encoder (like NucleotideTransformer) to extract features from DNA sequences 23 | and a text model (LLM) to process text inputs and generate responses. The DNA features are 24 | projected to the text model's embedding space and prepended to the text embeddings. 25 | """ 26 | 27 | def __init__( 28 | self, 29 | text_model_name: str, 30 | dna_model_name: str, 31 | cache_dir: Optional[str] = None, 32 | max_length_dna: int = 2048, 33 | max_length_text: int = 512, 34 | text_model_finetune: bool = True, 35 | dna_model_finetune: bool = True, 36 | dna_is_evo2: bool = False, 37 | dna_embedding_layer: str = None 38 | ): 39 | """ 40 | Initialize the DNALLMModel. 41 | 42 | Args: 43 | text_model_name: Name of the text model to be used. 44 | dna_model_name: Name of the DNA model to be used. 45 | cache_dir: Directory to cache the models. 
46 | max_length_dna: Maximum length of DNA sequences. Defaults to 2048. 47 | max_length_text: Maximum length of text sequences. Defaults to 512. 48 | text_model_finetune: Whether to finetune the text model. Defaults to True. 49 | dna_model_finetune: Whether to finetune the DNA model. Defaults to True. 50 | dna_is_evo2: Whether the DNA model is Evo2. Defaults to False. 51 | dna_embedding_layer: Name of the layer to use for the Evo2 model. Defaults to None. 52 | """ 53 | super().__init__() 54 | 55 | self.text_model_finetune = text_model_finetune 56 | self.dna_model_finetune = dna_model_finetune 57 | self.max_length_dna = max_length_dna 58 | self.max_length_text = max_length_text 59 | self.dna_is_evo2 = dna_is_evo2 60 | self.dna_embedding_layer = dna_embedding_layer 61 | 62 | 63 | # Load the text model and tokenizer 64 | self.text_model = AutoModelForCausalLM.from_pretrained( 65 | text_model_name, cache_dir=cache_dir, trust_remote_code=True 66 | ) 67 | self.text_tokenizer = AutoTokenizer.from_pretrained(text_model_name, trust_remote_code=True) 68 | self.text_config = self.text_model.config 69 | self.text_tokenizer.chat_template = CHAT_TEMPLATE 70 | self.text_tokenizer.pad_token = self.text_tokenizer.eos_token 71 | 72 | new_tokens = ["<|dna_start|>", "<|dna_pad|>", "<|dna_end|>"] 73 | self.text_tokenizer.add_special_tokens({"additional_special_tokens": new_tokens}) 74 | self.dna_token_id = self.text_tokenizer.convert_tokens_to_ids("<|dna_pad|>") 75 | 76 | 77 | # Load the DNA model and tokenizer 78 | if not self.dna_is_evo2: 79 | self.dna_model = AutoModelForMaskedLM.from_pretrained( 80 | dna_model_name, cache_dir=cache_dir, trust_remote_code=True 81 | ) 82 | self.dna_tokenizer = AutoTokenizer.from_pretrained(dna_model_name, trust_remote_code=True) 83 | self.dna_config = self.dna_model.config 84 | 85 | else: 86 | from evo2 import Evo2 87 | self.dna_model = Evo2(dna_model_name) 88 | self.dna_tokenizer = Evo2Tokenizer(self.dna_model.tokenizer) 89 | self.dna_config = 
self.dna_model.model.config 90 | self.dna_embedding_layer = self.dna_embedding_layer 91 | 92 | # Get model dimensions 93 | self.text_hidden_size = self.text_config.hidden_size 94 | self.dna_hidden_size = self.dna_config.hidden_size 95 | 96 | # Create projection layer to map DNA embeddings to text model's embedding space 97 | self.dna_projection = nn.Linear(self.dna_hidden_size, self.text_hidden_size) 98 | 99 | # Create processor for handling inputs 100 | self.processor = DLProcessor(tokenizer=self.text_tokenizer, dna_tokenizer=self.dna_tokenizer) 101 | 102 | 103 | def process_dna_embeddings( 104 | self, 105 | dna_tokenized: Dict[str, torch.Tensor], 106 | batch_idx_map: List[int], 107 | batch_size: int, 108 | ) -> List[torch.Tensor]: 109 | """ 110 | Process DNA sequences to obtain embeddings. 111 | 112 | Args: 113 | dna_tokenized: Tokenized DNA sequences 114 | batch_idx_map: Mapping of each sequence to its batch item 115 | batch_size: Number of items in the batch 116 | 117 | Returns: 118 | List of tensor embeddings for each batch item 119 | """ 120 | # Process all sequences to get DNA representations 121 | with torch.no_grad(): 122 | # Handle different model types based on dna_is_evo2 attribute 123 | if self.dna_is_evo2 and self.dna_embedding_layer is not None: # Evo2 model 124 | # Get embeddings from the specific layer in Evo2 125 | hidden_states_list = [] 126 | 127 | for seq_idx in range(len(dna_tokenized["input_ids"])): 128 | # Extract single sequence 129 | input_ids = dna_tokenized["input_ids"][seq_idx:seq_idx+1] 130 | 131 | # Call Evo2 with return_embeddings=True 132 | _, embeddings = self.dna_model( 133 | input_ids, 134 | return_embeddings=True, 135 | layer_names=[self.dna_embedding_layer] 136 | ) 137 | 138 | # Get embeddings for the specified layer 139 | seq_embeddings = embeddings[self.dna_embedding_layer].squeeze(0) 140 | hidden_states_list.append(seq_embeddings) 141 | 142 | # Stack to get same format as non-Evo2 output 143 | if hidden_states_list: 144 | 
hidden_states = torch.stack(hidden_states_list) 145 | else: 146 | return [torch.zeros((0, self.text_hidden_size), device=self.dna_projection.weight.device, dtype=self.dna_projection.weight.dtype) for _ in range(batch_size)] 147 | 148 | else: # Standard HuggingFace model 149 | # Use existing code path for HF models 150 | outputs = self.dna_model( 151 | input_ids=dna_tokenized["input_ids"], 152 | attention_mask=dna_tokenized["attention_mask"], 153 | output_hidden_states=True, 154 | ) 155 | # Get the last hidden state 156 | hidden_states = outputs.hidden_states[-1] # shape: [n_seqs, seq_len, hidden_dim] 157 | 158 | # Project all embeddings at once 159 | hidden_states = hidden_states.to(device=self.dna_projection.weight.device, dtype=self.dna_projection.weight.dtype) 160 | projected_states = self.dna_projection(hidden_states) 161 | 162 | # Group embeddings by batch item 163 | result = [[] for _ in range(batch_size)] 164 | 165 | # For each sequence, get its embeddings and add to appropriate batch result 166 | for seq_idx, batch_idx in enumerate(batch_idx_map): 167 | # Get only the valid (non-padding) tokens 168 | valid_length = dna_tokenized["attention_mask"][seq_idx].sum().item() 169 | seq_embedding = projected_states[seq_idx, :valid_length] 170 | result[batch_idx].append(seq_embedding) 171 | 172 | # Concatenate embeddings for each batch item 173 | for i in range(batch_size): 174 | if result[i]: 175 | result[i] = torch.cat(result[i], dim=0) 176 | else: 177 | result[i] = torch.zeros((0, self.text_hidden_size), device=projected_states.device, dtype=projected_states.dtype) # same device/dtype as real embeddings so torch.cat in forward/generate works 178 | 179 | return result 180 | 181 | def forward( 182 | self, 183 | input_ids: Optional[torch.Tensor] = None, 184 | attention_mask: Optional[torch.Tensor] = None, 185 | dna_tokenized: Optional[Dict[str, torch.Tensor]] = None, 186 | batch_idx_map: Optional[List[int]] = None, 187 | labels: Optional[torch.Tensor] = None, 188 | **kwargs, 189 | ) -> torch.Tensor: 190 | """ 191 | Run the text model on the text embeddings, with DNA embeddings written into the DNA token positions.
192 | 193 | Args: 194 | input_ids: Input IDs (used if provided directly) 195 | attention_mask: Attention mask (used if provided directly) 196 | dna_tokenized: Tokenized DNA sequences (used if provided directly) 197 | batch_idx_map: Batch mapping for DNA sequences (used if provided directly) 198 | labels: Labels for supervised fine-tuning (used if provided directly) 199 | **kwargs: Additional arguments for generation 200 | 201 | Returns: 202 | Outputs from the text model 203 | """ 204 | # Ensure required inputs are available 205 | if input_ids is None or attention_mask is None: 206 | raise ValueError("Either 'inputs' or 'input_ids'/'attention_mask' must be provided") 207 | 208 | batch_size = input_ids.shape[0] 209 | 210 | # Get text embeddings from the model's embedding layer 211 | text_inputs_embeds = self.text_model.get_input_embeddings()(input_ids) 212 | 213 | if dna_tokenized is not None and batch_idx_map: 214 | batch_dna_embeds = self.process_dna_embeddings(dna_tokenized, batch_idx_map, batch_size) 215 | 216 | mask = input_ids == self.dna_token_id 217 | 218 | n_dna_tokens = mask.sum().item() 219 | dna_embeds_flat = torch.cat(batch_dna_embeds, dim=0) 220 | n_dna_features = dna_embeds_flat.shape[0] 221 | 222 | if n_dna_features != n_dna_tokens: 223 | raise ValueError( 224 | f"DNA features and DNA tokens do not match: features {n_dna_features}, tokens: {n_dna_tokens}" 225 | ) 226 | 227 | # Ensure DNA embeddings have the same dtype as the text embeddings 228 | dna_embeds_flat = dna_embeds_flat.to(dtype=text_inputs_embeds.dtype) 229 | text_inputs_embeds[mask] = dna_embeds_flat 230 | 231 | # Handle labels if provided (for training) 232 | if labels is not None: 233 | # TODO: Implement this 234 | pass 235 | 236 | # Forward pass through the text model (loss is computed if labels is provided) 237 | outputs = self.text_model( 238 | inputs_embeds=text_inputs_embeds, 239 | attention_mask=attention_mask, 240 | labels=labels, 241 | **kwargs, 242 | ) 243 | 244 | return outputs 
245 | 246 | def generate( 247 | self, 248 | input_ids: Optional[torch.Tensor] = None, 249 | attention_mask: Optional[torch.Tensor] = None, 250 | dna_tokenized: Optional[Dict[str, torch.Tensor]] = None, 251 | batch_idx_map: Optional[List[int]] = None, 252 | **generation_kwargs, 253 | ) -> Union[torch.Tensor, List[str]]: 254 | """ 255 | Generate text based on DNA and text inputs. 256 | 257 | Args: 258 | inputs: The preprocessed inputs from the processor (preferred method) 259 | batch_dna_sequences: List of lists of DNA sequences per batch item (legacy method) 260 | input_texts: List of input texts (legacy method) 261 | input_ids: Input IDs (used if provided directly) 262 | attention_mask: Attention mask (used if provided directly) 263 | dna_tokenized: Tokenized DNA sequences (used if provided directly) 264 | batch_idx_map: Batch mapping for DNA sequences (used if provided directly) 265 | **generation_kwargs: Additional arguments for generation 266 | 267 | Returns: 268 | Generated token IDs which can be decoded using the processor 269 | """ 270 | # Ensure required inputs are available 271 | if input_ids is None or attention_mask is None: 272 | raise ValueError("Either 'inputs' or 'input_ids'/'attention_mask' must be provided") 273 | 274 | batch_size = input_ids.shape[0] 275 | 276 | # Get text embeddings from the model's embedding layer 277 | text_inputs_embeds = self.text_model.get_input_embeddings()(input_ids) 278 | 279 | if dna_tokenized is not None and batch_idx_map: 280 | batch_dna_embeds = self.process_dna_embeddings(dna_tokenized, batch_idx_map, batch_size) 281 | 282 | mask = input_ids == self.dna_token_id 283 | 284 | n_dna_tokens = mask.sum().item() 285 | dna_embeds_flat = torch.cat(batch_dna_embeds, dim=0) 286 | n_dna_features = dna_embeds_flat.shape[0] 287 | 288 | if n_dna_features != n_dna_tokens: 289 | raise ValueError( 290 | f"DNA features and DNA tokens do not match: features {n_dna_features}, tokens: {n_dna_tokens}" 291 | ) 292 | 293 | # Ensure DNA 
embeddings have the same dtype as the text embeddings 294 | dna_embeds_flat = dna_embeds_flat.to(dtype=text_inputs_embeds.dtype) 295 | text_inputs_embeds[mask] = dna_embeds_flat 296 | 297 | # Generation parameters may need adjustment based on model type 298 | with torch.no_grad(): 299 | outputs = self.text_model.generate( 300 | inputs_embeds=text_inputs_embeds, 301 | attention_mask=attention_mask, 302 | use_cache=True, 303 | **generation_kwargs, 304 | ) 305 | 306 | return outputs -------------------------------------------------------------------------------- /bioreason/models/dna_only.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from typing import Dict 5 | from transformers import AutoModelForMaskedLM, AutoTokenizer 6 | 7 | 8 | class SelfAttentionPooling(nn.Module): 9 | def __init__(self, hidden_size, num_heads=8): 10 | super().__init__() 11 | # Use PyTorch's built-in multi-head attention 12 | self.attention = nn.MultiheadAttention( 13 | embed_dim=hidden_size, 14 | num_heads=num_heads, 15 | batch_first=True 16 | ) 17 | # Learnable query vector 18 | self.query = nn.Parameter(torch.randn(1, 1, hidden_size)) 19 | 20 | def forward(self, embeddings, attention_mask=None): 21 | # Expand query to batch size 22 | batch_size = embeddings.size(0) 23 | query = self.query.expand(batch_size, -1, -1) 24 | 25 | # Create key padding mask from attention mask if provided 26 | key_padding_mask = None 27 | if attention_mask is not None: 28 | key_padding_mask = attention_mask == 0 # Convert to boolean mask where True means ignore 29 | 30 | # Apply attention: query attends to embeddings 31 | context, _ = self.attention( 32 | query=query, # [batch_size, 1, hidden_size] 33 | key=embeddings, # [batch_size, seq_len, hidden_size] 34 | value=embeddings, # [batch_size, seq_len, hidden_size] 35 | key_padding_mask=key_padding_mask 36 | ) 37 | 38 | # Squeeze out the singleton 
dimension 39 | return context.squeeze(1) # [batch_size, hidden_size] 40 | 41 | 42 | class DNAClassifierModel(nn.Module): 43 | """ 44 | A simple classifier that uses a DNA model with a classification head. 45 | """ 46 | 47 | def __init__( 48 | self, 49 | dna_model_name: str, 50 | cache_dir: str = None, 51 | max_length_dna: int = 4096, 52 | num_classes: int = 2, # Binary classification by default 53 | dna_is_evo2: bool = False, 54 | dna_embedding_layer: str = None, 55 | train_just_classifier: bool = True 56 | ): 57 | """ 58 | Initialize the DNAClassifierModel. 59 | 60 | Args: 61 | dna_model_name (str): Name of the DNA model to use 62 | cache_dir (str): Directory to cache models 63 | max_length_dna (int): Maximum sequence length 64 | num_classes (int): Number of output classes 65 | dna_is_evo2: Whether the DNA model is Evo2. Defaults to False 66 | dna_embedding_layer: Name of the layer to use for the Evo2 model. Defaults to None 67 | train_just_classifier: Whether to train just the classifier. 
Defaults to True 68 | """ 69 | super().__init__() 70 | 71 | self.dna_model_name = dna_model_name 72 | self.cache_dir = cache_dir 73 | self.max_length_dna = max_length_dna 74 | self.num_classes = num_classes 75 | self.dna_is_evo2 = dna_is_evo2 76 | self.dna_embedding_layer = dna_embedding_layer 77 | self.train_just_classifier = train_just_classifier 78 | 79 | # Load the DNA model and tokenizer 80 | if not self.dna_is_evo2: 81 | self.dna_model = AutoModelForMaskedLM.from_pretrained( 82 | dna_model_name, cache_dir=cache_dir, trust_remote_code=True 83 | ) 84 | self.dna_tokenizer = AutoTokenizer.from_pretrained(dna_model_name, trust_remote_code=True) 85 | self.dna_config = self.dna_model.config 86 | 87 | else: 88 | from evo2 import Evo2 89 | from bioreason.models.evo2_tokenizer import Evo2Tokenizer 90 | self.dna_model = Evo2(dna_model_name) 91 | self.dna_tokenizer = Evo2Tokenizer(self.dna_model.tokenizer) 92 | self.dna_config = self.dna_model.model.config 93 | self.dna_embedding_layer = self.dna_embedding_layer 94 | 95 | # Get hidden size from model config 96 | self.hidden_size = self.dna_config.hidden_size 97 | 98 | # Add the self-attention pooling module 99 | self.pooler = SelfAttentionPooling(self.hidden_size) 100 | 101 | # Create classification head that takes concatenated embeddings from both sequences 102 | self.classifier = nn.Sequential( 103 | nn.Linear(self.hidden_size * 2, self.hidden_size), 104 | nn.ReLU(), 105 | nn.Dropout(0.1), 106 | nn.Linear(self.hidden_size, num_classes), 107 | ) 108 | 109 | self.max_length_dna = max_length_dna 110 | 111 | def get_dna_embedding(self, input_ids: torch.Tensor, attention_mask: torch.Tensor): 112 | """ 113 | Get DNA embedding for a single DNA sequence using self-attention pooling. 
114 | 115 | Args: 116 | input_ids: DNA tokenized sequence 117 | attention_mask: DNA tokenized sequence attention mask 118 | 119 | Returns: 120 | torch.Tensor: Tensor containing the self-attention pooled DNA embedding 121 | """ 122 | # Add batch dimension if not present 123 | if input_ids.dim() == 1: 124 | input_ids = input_ids.unsqueeze(0) # [1, seq_len] 125 | 126 | # Handle attention mask - create if not provided or add batch dimension 127 | if attention_mask is None: 128 | attention_mask = torch.ones_like(input_ids) 129 | elif attention_mask.dim() == 1: 130 | attention_mask = attention_mask.unsqueeze(0) # [1, seq_len] 131 | 132 | # Get embeddings from DNA model 133 | with torch.set_grad_enabled(not self.train_just_classifier): # Enable gradients for fine-tuning 134 | 135 | if self.dna_is_evo2 and self.dna_embedding_layer is not None: # Evo2 model 136 | # Get embeddings from the specific layer in Evo2 137 | _, embeddings = self.dna_model( 138 | input_ids, 139 | return_embeddings=True, 140 | layer_names=[self.dna_embedding_layer] 141 | ) 142 | 143 | # Get embeddings for the specified layer 144 | hidden_states = embeddings[self.dna_embedding_layer] 145 | 146 | else: 147 | # Get embeddings from the last hidden state 148 | outputs = self.dna_model( 149 | input_ids, 150 | attention_mask=attention_mask, 151 | output_hidden_states=True, 152 | ) 153 | 154 | # Get the last hidden state 155 | hidden_states = outputs.hidden_states[-1] 156 | 157 | # Apply self-attention pooling to get a weighted representation 158 | sequence_embedding = self.pooler(hidden_states, attention_mask) 159 | return sequence_embedding.squeeze(0) 160 | 161 | def forward( 162 | self, ref_ids=None, alt_ids=None, ref_attention_mask=None, alt_attention_mask=None 163 | ): 164 | """ 165 | Forward pass of the model. 
166 | 167 | Args: 168 | ref_ids: Reference sequence token IDs 169 | alt_ids: Alternate sequence token IDs 170 | ref_attention_mask: Reference sequence attention mask (optional; all-ones mask used if omitted) 171 | alt_attention_mask: Alternate sequence attention mask (optional; all-ones mask used if omitted) 172 | 173 | Returns: 174 | torch.Tensor: Classification logits 175 | """ 176 | if ref_ids is None or alt_ids is None: 177 | raise ValueError("Both ref_ids and alt_ids must be provided") 178 | 179 | batch_size = ref_ids.shape[0] 180 | 181 | ref_embeddings = [] 182 | alt_embeddings = [] 183 | 184 | # Process each example in the batch 185 | for i in range(batch_size): 186 | 187 | # Get sequence embeddings (get_dna_embedding builds an all-ones mask when given None) 188 | ref_embed = self.get_dna_embedding(ref_ids[i], ref_attention_mask[i] if ref_attention_mask is not None else None) 189 | alt_embed = self.get_dna_embedding(alt_ids[i], alt_attention_mask[i] if alt_attention_mask is not None else None) 190 | ref_embeddings.append(ref_embed) 191 | alt_embeddings.append(alt_embed) 192 | 193 | # Stack embeddings 194 | ref_embeddings = torch.stack(ref_embeddings) 195 | alt_embeddings = torch.stack(alt_embeddings) 196 | 197 | # Concatenate ref and alt embeddings 198 | combined_embeddings = torch.cat([ref_embeddings, alt_embeddings], dim=1) 199 | 200 | # Pass through classifier 201 | logits = self.classifier(combined_embeddings) 202 | 203 | return logits -------------------------------------------------------------------------------- /bioreason/models/evo2_tokenizer.py: -------------------------------------------------------------------------------- 1 | from transformers.tokenization_utils import PreTrainedTokenizer 2 | from transformers.utils import logging 3 | from transformers import AutoTokenizer 4 | from transformers.tokenization_utils_base import BatchEncoding 5 | import torch 6 | import numpy as np 7 | from typing import List, Dict, Optional, Union, Tuple 8 | 9 | # Register the tokenizer with AutoTokenizer 10 | from transformers.models.auto import AutoTokenizer 11 | from transformers.models.auto.tokenization_auto
import TOKENIZER_MAPPING 12 | from transformers.models.auto.configuration_auto import CONFIG_MAPPING 13 | 14 | logger = logging.get_logger(__name__) 15 | 16 | class Evo2Tokenizer(PreTrainedTokenizer): 17 | """ 18 | Tokenizer for Evo2 models - wraps the CharLevelTokenizer to be compatible with HuggingFace. 19 | """ 20 | vocab_files_names = {} # No vocab files needed 21 | model_input_names = ["input_ids", "attention_mask"] 22 | 23 | def __init__( 24 | self, 25 | evo2_tokenizer, 26 | bos_token="", 27 | eos_token="", 28 | pad_token="", 29 | unk_token="", 30 | **kwargs 31 | ): 32 | """ 33 | Initialize the Evo2Tokenizer. 34 | 35 | Args: 36 | evo2_tokenizer: The Evo2 CharLevelTokenizer to wrap 37 | bos_token: Beginning of sequence token 38 | eos_token: End of sequence token 39 | pad_token: Padding token 40 | unk_token: Unknown token 41 | """ 42 | self.evo2_tokenizer = evo2_tokenizer 43 | 44 | # Map special tokens to Evo2 tokenizer's special token IDs 45 | self._pad_token = pad_token 46 | self._eos_token = eos_token 47 | self._bos_token = bos_token 48 | self._unk_token = unk_token 49 | 50 | # Initialize with special tokens 51 | super().__init__( 52 | bos_token=bos_token, 53 | eos_token=eos_token, 54 | pad_token=pad_token, 55 | unk_token=unk_token, 56 | **kwargs 57 | ) 58 | 59 | # Set token IDs from Evo2 tokenizer 60 | self.pad_token_id = self.evo2_tokenizer.pad_id 61 | self.eos_token_id = self.evo2_tokenizer.eos_id 62 | 63 | @property 64 | def vocab_size(self) -> int: 65 | """Return the vocab size of the tokenizer.""" 66 | return self.evo2_tokenizer.vocab_size 67 | 68 | def get_vocab(self) -> Dict: 69 | """Return vocab as a dictionary.""" 70 | # Evo2 CharLevelTokenizer doesn't have a traditional vocab dict 71 | # Create a simple mapping of ASCII codes to tokens 72 | return {chr(i): i for i in range(self.vocab_size)} 73 | 74 | def _tokenize(self, text: str) -> List[int]: 75 | """Tokenize a string using the Evo2 tokenizer.""" 76 | return [chr(int(token)) for token in 
self.evo2_tokenizer.tokenize(text)] 77 | 78 | def _convert_token_to_id(self, token: str) -> int: 79 | """Convert a token to an id using the Evo2 tokenizer.""" 80 | # Since tokens are just characters, convert to their ASCII value 81 | return ord(token) 82 | 83 | def _convert_id_to_token(self, index: int) -> str: 84 | """Convert an id to a token using the Evo2 tokenizer.""" 85 | # Convert ASCII value back to character 86 | return chr(index) 87 | 88 | def convert_tokens_to_string(self, tokens: List[str]) -> str: 89 | """Convert a sequence of tokens to a single string.""" 90 | return "".join(tokens) 91 | 92 | def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: 93 | """No vocabulary to save for Evo2Tokenizer, so just return an empty tuple.""" 94 | return () 95 | 96 | def __call__( 97 | self, 98 | text: Union[str, List[str]], 99 | text_pair: Optional[Union[str, List[str]]] = None, 100 | padding: Union[bool, str] = False, 101 | truncation: Union[bool, str] = False, 102 | max_length: Optional[int] = None, 103 | return_tensors: Optional[str] = None, 104 | return_token_type_ids: Optional[bool] = None, 105 | return_attention_mask: Optional[bool] = True, 106 | **kwargs 107 | ) -> Dict[str, torch.Tensor]: 108 | """ 109 | Main tokenization method that handles batching and converts to tensors. 
110 | """ 111 | # Handle single string vs list of strings 112 | if isinstance(text, str): 113 | text = [text] 114 | 115 | # Tokenize all sequences - note: tokenizer only accepts strings, not lists 116 | input_ids_list = [] 117 | for seq in text: 118 | # Tokenize and convert numpy.uint8 to Python integers 119 | tokens = [int(token) for token in self.evo2_tokenizer.tokenize(seq)] 120 | 121 | # Truncate if needed 122 | if truncation and max_length and len(tokens) > max_length: 123 | tokens = tokens[:max_length] 124 | 125 | input_ids_list.append(tokens) 126 | 127 | # Apply padding if needed 128 | if padding: 129 | if False:#max_length: 130 | max_len = max_length 131 | else: 132 | max_len = max(len(ids) for ids in input_ids_list) 133 | 134 | # Create padded sequences and attention masks 135 | padded_input_ids = [] 136 | attention_mask = [] 137 | 138 | for ids in input_ids_list: 139 | # Apply left padding (pad on the left) 140 | padding_length = max_len - len(ids) 141 | padded_ids = [self.pad_token_id] * padding_length + ids 142 | mask = [0] * padding_length + [1] * len(ids) 143 | 144 | padded_input_ids.append(padded_ids) 145 | attention_mask.append(mask) 146 | 147 | input_ids_list = padded_input_ids 148 | else: 149 | # Create attention mask without padding 150 | attention_mask = [[1] * len(ids) for ids in input_ids_list] 151 | 152 | # Create result dictionary 153 | result = {"input_ids": input_ids_list} 154 | if return_attention_mask: 155 | result["attention_mask"] = attention_mask 156 | 157 | # Convert to tensors if requested 158 | if return_tensors == "pt": 159 | result = {k: torch.tensor(v) for k, v in result.items()} 160 | 161 | # Return a BatchEncoding object rather than a plain dictionary 162 | return BatchEncoding( 163 | data=result, 164 | tensor_type=return_tensors, 165 | prepend_batch_axis=False, # Already handled in our tensor creation 166 | encoding=None # No encoding info from Evo2's tokenizer 167 | ) 168 | 169 | def batch_decode( 170 | self, 171 | 
sequences: Union[List[int], List[List[int]], torch.Tensor], 172 | skip_special_tokens: bool = False, 173 | **kwargs 174 | ) -> List[str]: 175 | """ 176 | Decode a batch of token ids to strings. 177 | """ 178 | if isinstance(sequences, torch.Tensor): 179 | sequences = sequences.tolist() 180 | 181 | return self.evo2_tokenizer.detokenize_batch(sequences) 182 | 183 | def decode( 184 | self, 185 | token_ids: Union[int, List[int], torch.Tensor], 186 | skip_special_tokens: bool = False, 187 | **kwargs 188 | ) -> str: 189 | """ 190 | Decode a single sequence of token ids to a string. 191 | """ 192 | if isinstance(token_ids, torch.Tensor): 193 | token_ids = token_ids.tolist() 194 | 195 | # Single sequence 196 | if not isinstance(token_ids, list) or not token_ids or not isinstance(token_ids[0], (list, torch.Tensor)): 197 | return self.evo2_tokenizer.detokenize(token_ids) 198 | 199 | # Batch with one item 200 | return self.batch_decode(token_ids, skip_special_tokens, **kwargs)[0] 201 | 202 | 203 | # Register the tokenizer - you'll need to do this when your script loads 204 | # You might want to put this in your __init__.py file 205 | def register_evo2_tokenizer(): 206 | """Register the Evo2Tokenizer with HuggingFace's AutoTokenizer.""" 207 | 208 | # This will register the tokenizer so AutoTokenizer.from_pretrained knows about it 209 | AutoTokenizer.register("evo2", Evo2Tokenizer) 210 | 211 | # If you have a config class, you would also register that 212 | # from transformers.models.auto import AutoConfig 213 | # AutoConfig.register("evo2", Evo2Config) 214 | 215 | print("Evo2Tokenizer registered with AutoTokenizer") 216 | 217 | 218 | if __name__ == "__main__": 219 | register_evo2_tokenizer() -------------------------------------------------------------------------------- /bioreason/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .grpo_config import DNALLMGRPOConfig 2 | from .grpo_trainer import DNALLMGRPOTrainer 3 
| 4 | __all__ = [ 5 | "DNALLMGRPOConfig", 6 | "DNALLMGRPOTrainer", 7 | ] -------------------------------------------------------------------------------- /bioreason/trainer/grpo_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from dataclasses import dataclass, field 16 | from typing import Optional, Union 17 | 18 | from transformers import TrainingArguments 19 | 20 | 21 | @dataclass 22 | class DNALLMGRPOConfig(TrainingArguments): 23 | r""" 24 | Configuration class for the [`GRPOTrainer`]. 25 | 26 | Only the parameters specific to GRPO training are listed here. For details on other parameters, refer to the 27 | [`~transformers.TrainingArguments`] documentation. 28 | 29 | Using [`~transformers.HfArgumentParser`] we can turn this class into 30 | [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the 31 | command line. 32 | 33 | Parameters: 34 | > Parameters that control the model and reference model 35 | 36 | model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`): 37 | Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model` 38 | argument of the [`GRPOTrainer`] is provided as a string.
39 | 40 | > Parameters that control the data preprocessing 41 | 42 | remove_unused_columns (`bool`, *optional*, defaults to `False`): 43 | Whether to only keep the column `"prompt"` in the dataset. If you use a custom reward function that 44 | requires any column other than `"prompts"` and `"completions"`, you should keep this to `False`. 45 | max_prompt_length (`int` or `None`, *optional*, defaults to `512`): 46 | Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left. 47 | num_generations (`int` or `None`, *optional*, defaults to `8`): 48 | Number of generations per prompt to sample. The global batch size (num_processes * per_device_batch_size) 49 | must be divisible by this value. 50 | max_completion_length (`int` or `None`, *optional*, defaults to `800`): 51 | Maximum length of the generated completion. 52 | ds3_gather_for_generation (`bool`, *optional*, defaults to `True`): 53 | This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation, 54 | improving generation speed. However, disabling this option allows training models that exceed the VRAM 55 | capacity of a single GPU, albeit at the cost of slower generation. Disabling this option is not compatible 56 | with vLLM generation. 57 | 58 | > Parameters that control generation 59 | 60 | temperature (`float`, *optional*, defaults to `0.6`): 61 | Temperature for sampling. The higher the temperature, the more random the completions. 62 | top_p (`float`, *optional*, defaults to `0.95`): 63 | Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. Set to 64 | `1.0` to consider all tokens. 65 | top_k (`int` or `None`, *optional*, defaults to `20`): 66 | Number of highest probability vocabulary tokens to keep for top-k-filtering. If `None`, top-k-filtering is 67 | disabled.
68 | min_p (`float` or `None`, *optional*, defaults to `None`): 69 | Minimum token probability, which will be scaled by the probability of the most likely token. It must be a 70 | value between `0.0` and `1.0`. Typical values are in the `0.01-0.2` range. 71 | repetition_penalty (`float`, *optional*, defaults to `1.0`): 72 | Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far. 73 | Values > `1.0` encourage the model to use new tokens, while values < `1.0` encourage the model to repeat 74 | tokens. 75 | cache_implementation (`str` or `None`, *optional*, defaults to `None`): 76 | Implementation of the cache method for faster generation when use_vllm is set to False. 77 | 78 | > Parameters that control generation acceleration powered by vLLM 79 | 80 | use_vllm (`bool`, *optional*, defaults to `False`): 81 | Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept unused for 82 | training, as vLLM will require one for generation. vLLM must be installed (`pip install vllm`). 83 | vllm_device (`str`, *optional*, defaults to `"auto"`): 84 | Device where vLLM generation will run, e.g. `"cuda:1"`. If set to `"auto"` (default), the system will 85 | automatically select the next available GPU after the last one used for training. This assumes that 86 | training has not already occupied all available GPUs. If only one device is available, the device will be 87 | shared between both training and vLLM. 88 | vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.9`): 89 | Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache on the 90 | device dedicated to generation powered by vLLM. Higher values will increase the KV cache size and thus 91 | improve the model's throughput. However, if the value is too high, it may cause out-of-memory (OOM) errors 92 | during initialization.
93 | vllm_dtype (`str`, *optional*, defaults to `"auto"`): 94 | Data type to use for vLLM generation. If set to `"auto"`, the data type will be automatically determined 95 | based on the model configuration. Find the supported values in the vLLM documentation. 96 | vllm_max_model_len (`int` or `None`, *optional*, defaults to `None`): 97 | If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced 98 | `vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model 99 | context size, which might be much larger than the KV cache, leading to inefficiencies. 100 | vllm_enable_prefix_caching (`bool`, *optional*, defaults to `True`): 101 | Whether to enable prefix caching in vLLM. If set to `True` (default), ensure that the model and the hardware 102 | support this feature. 103 | vllm_guided_decoding_regex (`str` or `None`, *optional*, defaults to `None`): 104 | Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled. 105 | 106 | > Parameters that control the training 107 | 108 | learning_rate (`float`, *optional*, defaults to `1e-6`): 109 | Initial learning rate for [`AdamW`] optimizer. The default value replaces that of 110 | [`~transformers.TrainingArguments`]. 111 | beta (`float`, *optional*, defaults to `0.04`): 112 | KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving training 113 | speed, but may be numerically unstable for long training runs. 114 | num_iterations (`int`, *optional*, defaults to `1`): 115 | Number of iterations per batch (denoted as μ in the algorithm). 116 | epsilon (`float`, *optional*, defaults to `0.2`): 117 | Epsilon value for clipping. 118 | epsilon_high (`float` or `None`, *optional*, defaults to `None`): 119 | Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the lower-bound 120 | specified in argument `epsilon`.
Paper [DAPO](https://huggingface.co/papers/2503.14476) recommends `0.28`. 121 | reward_weights (`list[float]` or `None`, *optional*, defaults to `None`): 122 | Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are 123 | weighted equally with weight `1.0`. 124 | sync_ref_model (`bool`, *optional*, defaults to `False`): 125 | Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps, using 126 | the `ref_model_mixup_alpha` parameter. This synchronization originates from the 127 | [TR-DPO](https://huggingface.co/papers/2404.09656) paper. 128 | ref_model_mixup_alpha (`float`, *optional*, defaults to `0.6`): 129 | α parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which controls the mix 130 | between the current policy and the previous reference policy during updates. The reference policy is 131 | updated according to the equation: `π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you 132 | must set `sync_ref_model=True`. 133 | ref_model_sync_steps (`int`, *optional*, defaults to `512`): 134 | τ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which determines how 135 | frequently the current policy is synchronized with the reference policy. To use this parameter, you must 136 | set `sync_ref_model=True`. 137 | 138 | > Parameters that control the logging 139 | 140 | log_completions (`bool`, *optional*, defaults to `True`): 141 | Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is 142 | installed, it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`.
143 | """ 144 | 145 | # Parameters that control the model and reference model 146 | model_init_kwargs: Optional[dict] = field( 147 | default=None, 148 | metadata={ 149 | "help": "Keyword arguments for `transformers.AutoModelForCausalLM.from_pretrained`, used when the `model` " 150 | "argument of the `GRPOTrainer` is provided as a string." 151 | }, 152 | ) 153 | 154 | # Parameters that control the data preprocessing 155 | # The default value remove_unused_columns is overwritten from the parent class, because in GRPO we usually rely on 156 | # additional columns to compute the reward 157 | remove_unused_columns: Optional[bool] = field( 158 | default=False, 159 | metadata={ 160 | "help": "Whether to only keep the column 'prompt' in the dataset. If you use a custom reward function " 161 | "that requires any column other than 'prompts' and 'completions', you should keep this to `False`." 162 | }, 163 | ) 164 | max_prompt_length: Optional[int] = field( 165 | default=512, 166 | metadata={ 167 | "help": "Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left." 168 | }, 169 | ) 170 | num_generations: Optional[int] = field( 171 | default=8, 172 | metadata={ 173 | "help": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) " 174 | "must be divisible by this value." 175 | }, 176 | ) 177 | max_completion_length: Optional[int] = field( 178 | default=800, 179 | metadata={"help": "Maximum length of the generated completion."}, 180 | ) 181 | ds3_gather_for_generation: bool = field( 182 | default=True, 183 | metadata={ 184 | "help": "This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for " 185 | "generation, improving generation speed. However, disabling this option allows training models that " 186 | "exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation. Disabling this option " 187 | "is not compatible with vLLM generation."
188 | }, 189 | ) 190 | 191 | # Parameters that control generation 192 | temperature: float = field( 193 | default=0.6, 194 | metadata={"help": "Temperature for sampling. The higher the temperature, the more random the completions."}, 195 | ) 196 | top_p: float = field( 197 | default=0.95, 198 | metadata={ 199 | "help": "Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. " 200 | "Set to 1.0 to consider all tokens." 201 | }, 202 | ) 203 | top_k: Optional[int] = field( 204 | default=20, 205 | metadata={ 206 | "help": "Number of highest probability vocabulary tokens to keep for top-k-filtering. If `None`, " 207 | "top-k-filtering is disabled." 208 | }, 209 | ) 210 | min_p: Optional[float] = field( 211 | default=None, 212 | metadata={ 213 | "help": "Minimum token probability, which will be scaled by the probability of the most likely token. It " 214 | "must be a value between 0.0 and 1.0. Typical values are in the 0.01-0.2 range." 215 | }, 216 | ) 217 | repetition_penalty: float = field( 218 | default=1.0, 219 | metadata={ 220 | "help": "Float that penalizes new tokens based on whether they appear in the prompt and the generated " 221 | "text so far. Values > 1.0 encourage the model to use new tokens, while values < 1.0 encourage the model " 222 | "to repeat tokens." 223 | }, 224 | ) 225 | cache_implementation: Optional[str] = field( 226 | default=None, 227 | metadata={"help": "Implementation of the cache method for faster generation when use_vllm is set to False."}, 228 | ) 229 | 230 | # Parameters that control generation acceleration powered by vLLM 231 | use_vllm: Optional[bool] = field( 232 | default=False, 233 | metadata={ 234 | "help": "Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept " 235 | "unused for training, as vLLM will require one for generation. vLLM must be installed " 236 | "(`pip install vllm`)."
237 | }, 238 | ) 239 | vllm_device: Optional[str] = field( 240 | default="auto", 241 | metadata={ 242 | "help": "Device where vLLM generation will run, e.g. 'cuda:1'. If set to 'auto' (default), the system " 243 | "will automatically select the next available GPU after the last one used for training. This assumes " 244 | "that training has not already occupied all available GPUs." 245 | }, 246 | ) 247 | vllm_gpu_memory_utilization: float = field( 248 | default=0.9, 249 | metadata={ 250 | "help": "Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV " 251 | "cache on the device dedicated to generation powered by vLLM. Higher values will increase the KV cache " 252 | "size and thus improve the model's throughput. However, if the value is too high, it may cause " 253 | "out-of-memory (OOM) errors during initialization." 254 | }, 255 | ) 256 | vllm_dtype: Optional[str] = field( 257 | default="auto", 258 | metadata={ 259 | "help": "Data type to use for vLLM generation. If set to 'auto', the data type will be automatically " 260 | "determined based on the model configuration. Find the supported values in the vLLM documentation." 261 | }, 262 | ) 263 | vllm_max_model_len: Optional[int] = field( 264 | default=None, 265 | metadata={ 266 | "help": "If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced " 267 | "`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model " 268 | "context size, which might be much larger than the KV cache, leading to inefficiencies." 269 | }, 270 | ) 271 | vllm_enable_prefix_caching: Optional[bool] = field( 272 | default=True, 273 | metadata={ 274 | "help": "Whether to enable prefix caching in vLLM. If set to `True` (default), ensure that the model and " 275 | "the hardware support this feature."
276 | }, 277 | ) 278 | vllm_guided_decoding_regex: Optional[str] = field( 279 | default=None, 280 | metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."}, 281 | ) 282 | 283 | # Parameters that control the training 284 | learning_rate: float = field( 285 | default=1e-6, 286 | metadata={ 287 | "help": "Initial learning rate for `AdamW` optimizer. The default value replaces that of " 288 | "`transformers.TrainingArguments`." 289 | }, 290 | ) 291 | beta: float = field( 292 | default=0.04, 293 | metadata={ 294 | "help": "KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving " 295 | "training speed, but may be numerically unstable for long training runs." 296 | }, 297 | ) 298 | num_iterations: int = field( 299 | default=1, 300 | metadata={"help": "Number of iterations per batch (denoted as μ in the algorithm)."}, 301 | ) 302 | epsilon: float = field( 303 | default=0.2, 304 | metadata={"help": "Epsilon value for clipping."}, 305 | ) 306 | epsilon_high: Optional[float] = field( 307 | default=None, 308 | metadata={ 309 | "help": "Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the " 310 | "lower-bound specified in argument `epsilon`. Paper DAPO recommends `0.28`." 311 | }, 312 | ) 313 | reward_weights: Optional[list[float]] = field( 314 | default=None, 315 | metadata={ 316 | "help": "Weights for each reward function. Must match the number of reward functions. If `None`, all " 317 | "rewards are weighted equally with weight `1.0`." 318 | }, 319 | ) 320 | sync_ref_model: bool = field( 321 | default=False, 322 | metadata={ 323 | "help": "Whether to synchronize the reference model with the active model every `ref_model_sync_steps` " 324 | "steps, using the `ref_model_mixup_alpha` parameter."
325 | }, 326 | ) 327 | ref_model_mixup_alpha: float = field( 328 | default=0.6, 329 | metadata={ 330 | "help": "α parameter from the TR-DPO paper, which controls the mix between the current policy and the " 331 | "previous reference policy during updates. The reference policy is updated according to the equation: " 332 | "`π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you must set `sync_ref_model=True`." 333 | }, 334 | ) 335 | ref_model_sync_steps: int = field( 336 | default=512, 337 | metadata={ 338 | "help": "τ parameter from the TR-DPO paper, which determines how frequently the current policy is " 339 | "synchronized with the reference policy. To use this parameter, you must set `sync_ref_model=True`." 340 | }, 341 | ) 342 | 343 | # Parameters that control the logging 344 | log_completions: bool = field( 345 | default=True, 346 | metadata={ 347 | "help": "Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is " 348 | "installed, it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`." 349 | }, 350 | ) 351 | # The fields below redefine TrainingArguments fields with this project's defaults (report to wandb, log every 2 steps) 352 | report_to: Union[None, str, list[str]] = field( 353 | default="wandb", metadata={"help": "The list of integrations to report the results and logs to."} 354 | ) 355 | 356 | logging_first_step: bool = field(default=False, metadata={"help": "Log the first global_step"}) 357 | logging_steps: float = field( 358 | default=2, 359 | metadata={ 360 | "help": ( 361 | "Log every X updates steps. Should be an integer or a float in range `[0,1)`. " 362 | "If smaller than 1, will be interpreted as ratio of total training steps."
363 | ) 364 | }, 365 | ) -------------------------------------------------------------------------------- /bioreason/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/BioReason/e74aa1cf06445aada1e48281840f83403b832b64/bioreason/utils/__init__.py -------------------------------------------------------------------------------- /bioreason/utils/dna_utils.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, Callable, Optional, Union 2 | 3 | import numpy as np 4 | 5 | from transformers.utils import is_torch_available 6 | 7 | if is_torch_available(): 8 | import torch 9 | # DNA input alias: one sequence (str, token-id list, np.ndarray or torch.Tensor) or a list of such sequences 10 | DNAInput = Union[ 11 | str, list[int], np.ndarray, "torch.Tensor", list[str], list[list[int]], list[np.ndarray], list["torch.Tensor"] 12 | ] # noqa -------------------------------------------------------------------------------- /figures/Figure1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/BioReason/e74aa1cf06445aada1e48281840f83403b832b64/figures/Figure1.png -------------------------------------------------------------------------------- /figures/Figure2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/BioReason/e74aa1cf06445aada1e48281840f83403b832b64/figures/Figure2.png -------------------------------------------------------------------------------- /figures/Figure3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bowang-lab/BioReason/e74aa1cf06445aada1e48281840f83403b832b64/figures/Figure3.png -------------------------------------------------------------------------------- /grpo_trainer_lora_model/adapter_config.json: -------------------------------------------------------------------------------- 1
| { 2 | "alpha_pattern": {}, 3 | "auto_mapping": null, 4 | "base_model_name_or_path": "unsloth/qwen2.5-1.5b-instruct-unsloth-bnb-4bit", 5 | "bias": "none", 6 | "eva_config": null, 7 | "exclude_modules": null, 8 | "fan_in_fan_out": false, 9 | "inference_mode": false, 10 | "init_lora_weights": true, 11 | "layer_replication": null, 12 | "layers_pattern": null, 13 | "layers_to_transform": null, 14 | "loftq_config": {}, 15 | "lora_alpha": 64, 16 | "lora_bias": false, 17 | "lora_dropout": 0, 18 | "megatron_config": null, 19 | "megatron_core": "megatron.core", 20 | "modules_to_save": null, 21 | "peft_type": "LORA", 22 | "r": 64, 23 | "rank_pattern": {}, 24 | "revision": null, 25 | "target_modules": [ 26 | "o_proj", 27 | "gate_proj", 28 | "v_proj", 29 | "up_proj", 30 | "q_proj", 31 | "down_proj", 32 | "k_proj" 33 | ], 34 | "task_type": "CAUSAL_LM", 35 | "use_dora": false, 36 | "use_rslora": false 37 | } -------------------------------------------------------------------------------- /grpo_trainer_lora_model/ds_config_stage2.json: -------------------------------------------------------------------------------- 1 | { 2 | "bf16": { 3 | "enabled": true 4 | }, 5 | "optimizer": { 6 | "type": "AdamW", 7 | "params": { 8 | "lr": "auto", 9 | "betas": "auto", 10 | "eps": "auto", 11 | "weight_decay": "auto" 12 | } 13 | }, 14 | "scheduler": { 15 | "type": "WarmupLR", 16 | "params": { 17 | "warmup_min_lr": "auto", 18 | "warmup_max_lr": "auto", 19 | "warmup_num_steps": "auto" 20 | } 21 | }, 22 | "zero_optimization": { 23 | "stage": 2, 24 | "offload_optimizer": { 25 | "device": "cpu", 26 | "pin_memory": true 27 | }, 28 | "contiguous_gradients": true, 29 | "overlap_comm": true, 30 | "allgather_partitions": true, 31 | "allgather_bucket_size": 5e8, 32 | "reduce_scatter": true, 33 | "reduce_bucket_size": 5e8 34 | }, 35 | "gradient_accumulation_steps": "auto", 36 | "gradient_clipping": "auto", 37 | "steps_per_print": 2000, 38 | "train_batch_size": "auto", 39 | "train_micro_batch_size_per_gpu":
"auto", 40 | "wall_clock_breakdown": false 41 | } 42 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=42", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "bioreason" 7 | version = "0.1.0" 8 | description = "Bio-related Reasoning with Language Models" 9 | readme = "README.md" 10 | requires-python = ">=3.11" 11 | classifiers = [ 12 | "Programming Language :: Python :: 3", 13 | "Programming Language :: Python :: 3.11", 14 | "License :: OSI Approved :: MIT License", 15 | "Operating System :: OS Independent", 16 | ] 17 | dependencies = [ 18 | "torch", 19 | "torchvision", 20 | "transformers", 21 | "accelerate", 22 | "qwen-vl-utils", 23 | "jupyter", 24 | "datasets", 25 | "peft", 26 | "pytorch_lightning", 27 | "wandb", 28 | "trl[vllm]", 29 | "bitsandbytes", 30 | "deepspeed", 31 | ] 32 | 33 | [project.optional-dependencies] 34 | dev = [ 35 | "pytest", 36 | "black", 37 | "isort", 38 | "mypy", 39 | ] 40 | 41 | [tool.setuptools] 42 | packages = ["bioreason"] # NOTE(review): an explicit packages list does not include subpackages (bioreason.models, bioreason.trainer, ...) - consider setuptools package discovery 43 | 44 | [tool.black] 45 | line-length = 88 46 | target-version = ["py311"] 47 | 48 | [tool.isort] 49 | profile = "black" 50 | line_length = 88 51 | 52 | [tool.mypy] 53 | python_version = "3.11" 54 | warn_return_any = true 55 | warn_unused_configs = true 56 | disallow_untyped_defs = true 57 | disallow_incomplete_defs = true -------------------------------------------------------------------------------- /reason.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | 4 | import pathlib 5 | from argparse import ArgumentParser 6 | from typing import List, Dict, Optional 7 | from dataclasses import dataclass, field 8 | 9 | import torch 10 | from torch import nn 11 | import torch.nn.functional as F 12 | from torch.optim import AdamW 13 | from
torch.utils.data import DataLoader, Dataset 14 | from transformers import get_cosine_schedule_with_warmup, AutoTokenizer 15 | 16 | from transformers import ( 17 | AutoTokenizer, 18 | AutoModelForCausalLM, 19 | AutoModelForMaskedLM, 20 | AutoProcessor, 21 | ) 22 | 23 | from datasets import load_dataset, DatasetDict 24 | 25 | from peft import get_peft_model, LoraConfig, prepare_model_for_kbit_training 26 | from transformers import BitsAndBytesConfig 27 | 28 | import pytorch_lightning as pl 29 | from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor 30 | from pytorch_lightning.loggers import WandbLogger 31 | 32 | from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config 33 | #from unsloth import FastLanguageModel, is_bfloat16_supported 34 | 35 | from bioreason.models.dna_llm import DNALLMModel 36 | from bioreason.dna_modules import NucleotideDNAModule 37 | from bioreason.models.dl.processing_dl import DLProcessor 38 | from bioreason.trainer import DNALLMGRPOTrainer, DNALLMGRPOConfig 39 | from bioreason.models.evo2_tokenizer import Evo2Tokenizer, register_evo2_tokenizer 40 | register_evo2_tokenizer() 41 | 42 | # Custom TrainerCallback to override the saving mechanism 43 | from transformers import TrainerCallback, TrainerState, TrainerControl 44 | from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR 45 | 46 | class SaveWithPyTorchCallback(TrainerCallback): 47 | """Custom callback to save models with PyTorch's native save mechanism instead of safetensors""" 48 | def on_save(self, args, state, control, **kwargs): 49 | # Get the checkpoint folder 50 | checkpoint_folder = os.path.join( 51 | args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}" 52 | ) 53 | os.makedirs(checkpoint_folder, exist_ok=True) 54 | 55 | # Save with PyTorch instead of safetensors 56 | checkpoint_path = os.path.join(checkpoint_folder, "pytorch_model.bin") 57 | model = kwargs.get("model") 58 | 59 | # Get model unwrapped from
accelerator etc. 60 | unwrapped_model = model.module if hasattr(model, "module") else model 61 | 62 | # Save using PyTorch directly 63 | torch.save(unwrapped_model.state_dict(), checkpoint_path) 64 | 65 | # DNALLMModel doesn't have a direct config attribute, so we need to save 66 | # the configs of its sub-models 67 | if hasattr(unwrapped_model, "text_model"): 68 | if hasattr(unwrapped_model.text_model, "config"): 69 | unwrapped_model.text_model.config.save_pretrained(checkpoint_folder) 70 | # Handle PEFT models which might have base_model 71 | elif hasattr(unwrapped_model.text_model, "base_model") and hasattr(unwrapped_model.text_model.base_model, "config"): 72 | unwrapped_model.text_model.base_model.config.save_pretrained(checkpoint_folder) 73 | 74 | # Print info about what's being saved 75 | print(f"Saved model checkpoint to {checkpoint_folder}") 76 | lora_params = [k for k in unwrapped_model.state_dict().keys() if "lora" in k] 77 | print(f"Checkpoint contains {len(lora_params)} LoRA parameters") 78 | 79 | # Signal that we've saved 80 | control.should_save = False 81 | return control 82 | 83 | def _get_target_modules(model: DNALLMModel) -> list[str]: 84 | # Collect LoRA target names: the last name component of every nn.Linear under model.text (except lm_head), plus common attention projection names 85 | target_modules = [] 86 | 87 | # Get all unique linear layer names (NOTE(review): this reads model.text, while SaveWithPyTorchCallback looks for text_model - confirm which attribute DNALLMModel exposes) 88 | seen_names = set() 89 | for name, module in model.text.named_modules(): 90 | if isinstance(module, torch.nn.Linear): 91 | names = name.split(".") 92 | target_name = names[-1] # Use the last part of the name 93 | 94 | # Skip output head but include all other linear layers 95 | if target_name != "lm_head" and target_name not in seen_names: 96 | target_modules.append(target_name) 97 | seen_names.add(target_name) 98 | 99 | # Add attention-specific layer names; names not found among the linear layers above are appended anyway 100 | attention_patterns = [ 101 | "q_proj", 102 | "k_proj", 103 | "v_proj", 104 | "out_proj", 105 | "query", 106 | "key", 107 | "value", 108 | ] 109 | for pattern in attention_patterns: 110 | if pattern not in seen_names: 111 |
target_modules.append(pattern) 112 | 113 | # Return all unique layer names to apply LoRA to all layers 114 | return list(target_modules) 115 | 116 | 117 | def extract_xml_answer(text: str) -> str: 118 | # answer = text.split("")[-1] 119 | # answer = answer.split("")[0] 120 | answer = text.split("")[-1] # NOTE(review): str.split("") raises ValueError (empty separator); the delimiter appears to be missing in this listing - confirm against the source file 121 | return answer.strip() 122 | # Return the stripped text between the first "####" and the next one (or the end), or None if "####" is absent 123 | def extract_hash_answer(text: str) -> str | None: 124 | if "####" not in text: 125 | return None 126 | return text.split("####")[1].strip() 127 | 128 | def get_kegg_questions() -> DatasetDict: 129 | data = load_dataset('wanglab/kegg', 'default') # type: ignore 130 | example_dna_sequences = ["ATCTACATGCAT", "CAGCAGCTACAG", "CATCACATCGACATCGAC"] # NOTE(review): unused in this function 131 | num_dna_sequences = 2 # one DNA placeholder per entry in 'dna_sequences' (reference + variant) 132 | 133 | data = data.map(lambda x: { # type: ignore 134 | 'prompt': [ 135 | 136 | { 137 | 'role': 'user', 138 | 'content': [ 139 | *({'type': 'dna', 'text': None} for _ in range(num_dna_sequences)), 140 | {'type': 'text', 'text': x['question']}, 141 | ], 142 | }, 143 | ], 144 | 'dna_sequences': [x['reference_sequence'], x['variant_sequence']], 145 | 'answer': x['answer'], 146 | }) # type: ignore 147 | 148 | return data 149 | 150 | # uncomment middle messages for 1-shot prompting (NOTE(review): no such commented-out messages remain below; question_prompt and x['question'] are not used) 151 | def get_gsm8k_questions(question_prompt: str) -> DatasetDict: 152 | data = load_dataset('openai/gsm8k', 'main') # type: ignore 153 | 154 | example_dna_sequences = ["ATCTACATGCAT", "CAGCAGCTACAG", "CATCACATCGACATCGAC"] 155 | data = data.map(lambda x: { # type: ignore 156 | 'prompt': [ 157 | 158 | { 159 | 'role': 'user', 160 | 'content': [ 161 | *({'type': 'dna', 'text': None} for _ in range(len(example_dna_sequences))), 162 | {'type': 'text', 'text': 'Give me a short introduction to large language model.'} 163 | ] 164 | }, 165 | ], 166 | 'dna_sequences': [dna for dna in example_dna_sequences], 167 | 'answer': extract_hash_answer(x['answer']), 168 | }) # type: ignore 169 | 170 | return data # type: ignore 171 | 172 | def get_gsm8k_questions_old(question_prompt: str)
-> Dataset: 173 | data = load_dataset('openai/gsm8k', 'main') # type: ignore 174 | 175 | example_dna_sequences = ["ATCTACATGCAT", "CAGCAGCTACAG", "CATCACATCGACATCGAC"] 176 | data = data.map(lambda x: { # type: ignore 177 | 'prompt': [ 178 | { 179 | 'role': 'user', 180 | 'content': [ 181 | *({'type': 'dna', 'text': None} for _ in range(len(example_dna_sequences))), 182 | {'type': 'text', 'text': question_prompt.format(Question=x['question'])} 183 | ] 184 | }, 185 | ], 186 | 'dna_sequences': [dna for dna in example_dna_sequences], 187 | 'answer': extract_hash_answer(x['answer']), 188 | }) # type: ignore 189 | 190 | return data # type: ignore 191 | 192 | # Reward functions 193 | def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]: 194 | responses = [completion[0]['content'] for completion in completions] 195 | q = prompts[0][-1]['content'] 196 | extracted_responses = [extract_xml_answer(r) for r in responses] 197 | # extracted_responses = [r.lower().replace("answer:", "").strip() for r in extracted_responses] 198 | print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}") 199 | return [2.0 if a.lower() in r.lower() else 0.0 for r, a in zip(extracted_responses, answer[0])] 200 | 201 | def less_than_4_reward_func(completions, **kwargs) -> list[float]: 202 | responses = [completion[0]['content'] for completion in completions] 203 | extracted_responses = [extract_xml_answer(r) for r in responses] 204 | return [0.5 if len(r.split(' ')) <= 4 else 0.0 for r in extracted_responses] 205 | 206 | def strict_format_reward_func(completions, **kwargs) -> list[float]: 207 | """Reward function that checks if the completion has a specific format.""" 208 | pattern = r"^\n.*?\n\n.*?\n$" 209 | responses = [completion[0]["content"] for completion in completions] 210 | matches = [re.match(pattern, r) for r in responses] 211 | return [0.5 if match else 0.0 for match in matches] 212 | 
213 | def soft_format_reward_func(completions, **kwargs) -> list[float]: 214 | """Reward function that checks if the completion has a specific format.""" 215 | pattern = r".*?\s*.*?" 216 | responses = [completion[0]["content"] for completion in completions] 217 | matches = [re.match(pattern, r) for r in responses] 218 | return [0.5 if match else 0.0 for match in matches] 219 | 220 | def count_xml(text) -> float: 221 | count = 0.0 222 | if text.count("\n") == 1: 223 | count += 0.125 224 | if text.count("\n\n") == 1: 225 | count += 0.125 226 | return count 227 | 228 | def xmlcount_reward_func(completions, **kwargs) -> list[float]: 229 | contents = [completion[0]["content"] for completion in completions] 230 | return [count_xml(c) for c in contents] 231 | 232 | # Format into conversation 233 | def make_conversation(example): 234 | return { 235 | "prompt": [ 236 | {"role": "system", "content": SYSTEM_PROMPT}, 237 | {"role": "user", "content": example["problem"]}, 238 | ], 239 | } 240 | 241 | def make_conversation_image(example): 242 | return { 243 | "prompt": [ 244 | { 245 | "role": "user", 246 | "content": [ 247 | {"type": "image"}, 248 | ], 249 | }, 250 | ], 251 | } 252 | 253 | @dataclass 254 | class GRPOModelConfig(ModelConfig): 255 | 256 | # "HuggingFaceTB/SmolLM-135M-Instruct" 257 | # "Qwen/Qwen2.5-0.5B-Instruct" 258 | model_name_or_path: str = field(default="Qwen/Qwen3-0.6B", metadata={"help": "Model checkpoint for weights initialization."}) 259 | dna_model_name_or_path: str = field(default="InstaDeepAI/nucleotide-transformer-v2-100m-multi-species", metadata={"help": "Model checkpoint for weights initialization."}) 260 | cache_dir: str = field(default=None, metadata={"help": "Path to model cache directory."}) 261 | max_length_text: int = field(default=800, metadata={"help": "Maximum length of text sequences."}) 262 | max_length_dna: int = field(default=800, metadata={"help": "Maximum length of DNA sequences, in groups of 6 nucleotides."}) 263 | sft_checkpoint: 
str = field(default=None, metadata={"help": "Path to the checkpoint for SFT."}) 264 | lora_r: int = field(default=32, metadata={"help": "LoRA R value."}) 265 | lora_alpha: int = field(default=64, metadata={"help": "LoRA alpha."}) 266 | lora_dropout: float = field(default=0.05, metadata={"help": "LoRA dropout."}) 267 | lora_modules_to_save: Optional[list[str]] = field( 268 | default="embed_tokens", 269 | metadata={"help": "Model layers to unfreeze & train."}, 270 | ) 271 | freeze_dna_modules: bool = False 272 | 273 | @dataclass 274 | class GRPOScriptArguments(ScriptArguments): 275 | """ 276 | Script arguments for the GRPO training script. 277 | """ 278 | dataset_name: str = field(default="wanglab/kegg", metadata={"help": "Dataset name with default."}) 279 | data_file_paths: str = field( 280 | default=None, 281 | metadata={"help": "Paths to data files, separated by ':'"}, 282 | ) 283 | arrow_cache_dir: str = field( 284 | default=None, 285 | metadata={"help": "Path to arrow cache directory"}, 286 | ) 287 | val_split_ratio: float = field( 288 | default=0.0, 289 | metadata={"help": "Ratio of validation split, default 0.0"}, 290 | ) 291 | reward_funcs: list[str] = field( 292 | #default_factory=lambda: ["accuracy", "format"], 293 | default_factory=lambda: ["xmlcount", "soft_format", "strict_format", "less_than_4", "correctness"], 294 | #metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"}, 295 | metadata={"help": "List of reward functions. 
Possible values: 'accuracy', 'xmlcount', 'soft_format', 'strict_format', 'less_than_4', 'correctness'"}, 296 | ) 297 | # max_pixels: Optional[int] = field( 298 | # default=12845056, 299 | # metadata={"help": "Maximum number of pixels for the image (for QwenVL)"}, 300 | # ) 301 | # min_pixels: Optional[int] = field( 302 | # default=3136, 303 | # metadata={"help": "Minimum number of pixels for the image (for QwenVL)"}, 304 | # ) 305 | # task_type: Optional[str] = field( 306 | # default=None, 307 | # metadata={"help": "Choose task type: 'default', 'gui', ..."}, 308 | # ) 309 | 310 | 311 | 312 | reward_funcs_registry = { 313 | # "accuracy": accuracy_reward, 314 | # "format": format_reward, 315 | "xmlcount": xmlcount_reward_func, 316 | "soft_format": soft_format_reward_func, 317 | "strict_format": strict_format_reward_func, 318 | "less_than_4": less_than_4_reward_func, 319 | "correctness": correctness_reward_func, 320 | } 321 | 322 | def get_vlm_module(model_name_or_path): 323 | if any(mini_name in model_name_or_path.lower() for mini_name in ["qwen", "smol"]): 324 | return NucleotideDNAModule 325 | else: 326 | raise ValueError(f"Unsupported model: {model_name_or_path}") 327 | 328 | def _get_target_modules(model): 329 | # Apply LoRA to all linear layers in the text model 330 | target_modules = [] 331 | 332 | # Get all unique linear layer names 333 | seen_names = set() 334 | for name, module in model.text_model.named_modules(): 335 | if isinstance(module, torch.nn.Linear): 336 | names = name.split(".") 337 | target_name = names[-1] # Use the last part of the name 338 | 339 | # Skip output head but include all other linear layers 340 | if target_name != "lm_head" and target_name not in seen_names: 341 | target_modules.append(target_name) 342 | seen_names.add(target_name) 343 | 344 | # Add attention-specific layers 345 | attention_patterns = [ 346 | "q_proj", 347 | "k_proj", 348 | "v_proj", 349 | "out_proj", 350 | "query", 351 | "key", 352 | "value", 353 | ] 354 | for 
pattern in attention_patterns: 355 | if pattern not in seen_names: 356 | target_modules.append(pattern) 357 | 358 | # Return all unique layer names to apply LoRA to all layers 359 | return list(target_modules) 360 | 361 | 362 | def _prep_for_training(model, training_args, dna_model_finetune: bool = False) -> LoraConfig: 363 | """ 364 | Load and configure the DNALLMModel. 365 | """ 366 | 367 | # Freeze DNA encoder parameters 368 | if dna_model_finetune: 369 | pass 370 | else: 371 | for param in model.dna_model.parameters(): 372 | param.requires_grad = False 373 | 374 | target_modules = _get_target_modules(model) 375 | 376 | lora_config = LoraConfig( 377 | r=training_args.lora_r, 378 | lora_alpha=training_args.lora_alpha, 379 | lora_dropout=training_args.lora_dropout, 380 | target_modules=target_modules, 381 | init_lora_weights="gaussian", 382 | bias="none", 383 | task_type="CAUSAL_LM", 384 | ) 385 | 386 | # Prepare text model for training 387 | model.text_model = prepare_model_for_kbit_training(model.text_model) 388 | model.text_model = get_peft_model(model.text_model, lora_config) 389 | 390 | # Make projection layer trainable 391 | for param in model.dna_projection.parameters(): 392 | param.requires_grad = True 393 | 394 | return lora_config 395 | 396 | def main(script_args, training_args, model_args): 397 | 398 | print(training_args.output_dir) 399 | #pl.seed_everything(args.seed) 400 | # os.environ["CUDA_LAUNCH_BLOCKING"] = "1" 401 | torch.cuda.empty_cache() 402 | torch.set_float32_matmul_precision("medium") 403 | 404 | # Initialize model 405 | # Load tokenizer for target text 406 | # tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path) 407 | # tokenizer.pad_token = tokenizer.eos_token 408 | 409 | # Load model 410 | model = DNALLMModel( 411 | text_model_name=model_args.model_name_or_path, 412 | dna_model_name=model_args.dna_model_name_or_path, 413 | cache_dir=model_args.cache_dir, 414 | max_length_text=model_args.max_length_text, 415 | 
max_length_dna=model_args.max_length_dna, 416 | text_model_finetune=True, 417 | dna_model_finetune=not model_args.freeze_dna_modules, 418 | debug=False, 419 | ) 420 | 421 | # load checkpoint 422 | if model_args.sft_checkpoint is not None: 423 | print(f"Loading SFT checkpoint from {model_args.sft_checkpoint}") 424 | 425 | # Determine if it's a directory (PEFT format) or file (PyTorch state dict) 426 | is_directory = os.path.isdir(model_args.sft_checkpoint) 427 | 428 | if is_directory: 429 | # It's a PEFT checkpoint directory - load properly with PEFT 430 | from peft import PeftModel 431 | 432 | # First initialize the text model with PEFT 433 | print("Loading as PEFT checkpoint directory") 434 | model.text_model = PeftModel.from_pretrained( 435 | model.text_model, 436 | model_args.sft_checkpoint, 437 | is_trainable=True 438 | ) 439 | 440 | # Verify loaded adapters 441 | print("Loaded LoRA adapters:", model.text_model.active_adapter) 442 | 443 | # Optional: Merge weights into base model 444 | print("Merging SFT LoRA weights into base model...") 445 | model.text_model = model.text_model.merge_and_unload() 446 | print("Successfully merged SFT knowledge into base model") 447 | 448 | else: 449 | # It's a PyTorch state dict file 450 | print("Loading as PyTorch state dict file") 451 | checkpoint = torch.load(model_args.sft_checkpoint) 452 | 453 | # replace model.text_model with text_model for all in state dict 454 | def new_key(k): 455 | if k.startswith("=model."): return k[6:] 456 | elif k.startswith("_forward_module."): return k[len("_forward_module."):] 457 | else: return k 458 | 459 | if "state_dict" in checkpoint: 460 | magic = {new_key(k): v for k, v in checkpoint["state_dict"].items()} 461 | elif "module" in checkpoint: 462 | magic = {new_key(k): v for k, v in checkpoint["module"].items()} 463 | elif isinstance(checkpoint, dict) and all(isinstance(k, str) for k in checkpoint.keys()): 464 | # Direct state dict - the checkpoint itself is the state dict 465 | 
print("Detected direct state dict format") 466 | magic = {new_key(k): v for k, v in checkpoint.items()} 467 | else: 468 | raise ValueError(f"Unsupported checkpoint format: {model_args.sft_checkpoint}") 469 | 470 | # Handle prefix mapping for different model architectures 471 | lora_prefix = False 472 | for key in magic.keys(): 473 | if "lora" in key: 474 | lora_prefix = True 475 | break 476 | 477 | if lora_prefix: 478 | print("Detected LoRA weights in state dict") 479 | # First prepare model for LoRA training 480 | _prep_for_training(model, model_args, dna_model_finetune=model_args.freeze_dna_modules) 481 | 482 | # Print some diagnostic info about the keys 483 | model_keys = set(model.state_dict().keys()) 484 | checkpoint_keys = set(magic.keys()) 485 | print(f"Model has {len(model_keys)} keys") 486 | print(f"Checkpoint has {len(checkpoint_keys)} keys") 487 | 488 | # Try to map LoRA keys more intelligently 489 | new_magic = {} 490 | for k, v in magic.items(): 491 | # Try different prefix mappings based on common patterns 492 | if "base_model.model" in k and k not in model_keys: 493 | new_k = k.replace("text_model.base_model.model", "text_model") 494 | if new_k in model_keys: 495 | new_magic[new_k] = v 496 | continue 497 | 498 | # Try removing common prefixes 499 | if k.startswith("text_model.") and k not in model_keys: 500 | new_k = "text_model.base_model.model." 
+ k[len("text_model."):] 501 | if new_k in model_keys: 502 | new_magic[new_k] = v 503 | continue 504 | 505 | # Keep original key if no mapping found 506 | new_magic[k] = v 507 | 508 | # Include missing target modules in diagnostic info 509 | magic = new_magic 510 | print(f"After key mapping: {len(magic)} keys") 511 | 512 | # Then load weights, allowing missing/extra keys 513 | result = model.load_state_dict(magic, strict=False) 514 | 515 | if len(result.unexpected_keys) > 0: 516 | print(f"Sample unexpected keys: {result.unexpected_keys[:5]}") 517 | if len(result.missing_keys) > 0: 518 | print(f"Sample missing keys: {result.missing_keys[:5]}") 519 | 520 | print(f"Loaded checkpoint with {len(result.missing_keys)} missing keys and {len(result.unexpected_keys)} unexpected keys") 521 | else: 522 | print("Standard weights detected - remapping keys") 523 | # Map keys to model structure 524 | magic = {k.replace("text_model", "text_model.base_model.model"): v for k, v in magic.items()} 525 | magic = {k.replace("dna_model", "dna_model"): v for k, v in magic.items()} 526 | 527 | # Fix the shared memory tensors issue by making a copy of weights 528 | for key in list(magic.keys()): 529 | if 'lm_head.weight' in key: 530 | magic[key] = magic[key].clone() 531 | 532 | # Load weights before setting up LoRA 533 | result = model.load_state_dict(magic, strict=False) 534 | print(f"Loaded checkpoint with {len(result.missing_keys)} missing keys and {len(result.unexpected_keys)} unexpected keys") 535 | 536 | # Now prepare for LoRA training 537 | _prep_for_training(model, model_args, dna_model_finetune=model_args.freeze_dna_modules) 538 | else: 539 | # No checkpoint, just prepare for training 540 | _prep_for_training(model, model_args, dna_model_finetune=model_args.freeze_dna_modules) 541 | 542 | # Get reward functions 543 | reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs] 544 | # reward_funcs = [ 545 | # xmlcount_reward_func, 546 | # 
soft_format_reward_func, 547 | # strict_format_reward_func, 548 | # int_reward_func, 549 | # correctness_reward_func, 550 | # ] 551 | print("reward_funcs:", reward_funcs) 552 | 553 | vlm_module_cls = get_vlm_module(model_args.model_name_or_path) 554 | print("using vlm module:", vlm_module_cls.__name__) 555 | question_prompt = vlm_module_cls.get_question_template() 556 | 557 | 558 | dataset = get_kegg_questions() 559 | 560 | #dataset = get_gsm8k_questions(question_prompt) 561 | 562 | print(dataset) 563 | 564 | #print('ITEM ONE OF THE DATASET', dataset['train'][0]) 565 | 566 | # Custom callback to handle saving with PyTorch's native mechanism 567 | custom_save_callback = SaveWithPyTorchCallback() 568 | 569 | # Initialize the GRPO trainer with custom callback 570 | trainer = DNALLMGRPOTrainer( 571 | model=model, 572 | reward_funcs=reward_funcs, 573 | args=training_args, 574 | dna_module=vlm_module_cls(), 575 | train_dataset=dataset['train'], 576 | eval_dataset=dataset['val'] if training_args.eval_strategy != "no" else None, 577 | peft_config=get_peft_config(model_args), 578 | attn_implementation=model_args.attn_implementation, 579 | torch_dtype=model_args.torch_dtype, 580 | callbacks=[custom_save_callback], # Add our custom callback 581 | ) 582 | 583 | # Set the trainer to save in PyTorch format instead of safetensors 584 | training_args.save_safetensors = False 585 | 586 | # Train and push the model to the Hub 587 | # if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): 588 | # trainer.train(resume_from_checkpoint=True) 589 | # else: 590 | # trainer.train() 591 | 592 | # Train and push the model to the Hub 593 | trainer.train() 594 | 595 | 596 | if __name__ == "__main__": 597 | # os.environ["CUDA_VISIBLE_DEVICES"] = "0,1" 598 | print(f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES')}") 599 | parser = TrlParser((GRPOScriptArguments, DNALLMGRPOConfig, GRPOModelConfig)) 600 | script_args, training_args, model_args = 
parser.parse_args_and_config() 601 | 602 | # Ensure we use PyTorch's save mechanism instead of safetensors 603 | training_args.save_safetensors = False 604 | 605 | main(script_args, training_args, model_args) 606 | 607 | # parser.add_argument("--wandb_project", type=str, default="dna-text-finetune") 608 | # parser.add_argument("--wandb_entity", type=str, default="adibvafa") 609 | 610 | # args = parser.parse_args() 611 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | transformers 4 | accelerate 5 | qwen-vl-utils 6 | jupyter 7 | datasets 8 | peft 9 | pytorch_lightning 10 | wandb 11 | trl[vllm] 12 | bitsandbytes 13 | deepspeed -------------------------------------------------------------------------------- /sh_reason.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=Qwen3_1.7B_SFT_RL # Name of the job 3 | #SBATCH --gres=gpu:4 # Number of GPUs 4 | #SBATCH -p a100 # Partition 5 | #SBATCH -c 12 # Number of cores 6 | #SBATCH --time=12:00:00 # Time limit 7 | #SBATCH --mem=128gb # Memory limit 8 | #SBATCH --output=Qwen3_1.7B_SFT_RL_a100-%j.out # Output file 9 | #SBATCH --error=Qwen3_1.7B_SFT_RL_a100-%j.err # Error file 10 | 11 | ## Environment Setup 12 | echo "CUDA_HOME: $CUDA_HOME" 13 | echo "PATH: $PATH" 14 | echo "LD_LIBRARY_PATH: $LD_LIBRARY_PATH" 15 | echo "which python: $(which python)" 16 | 17 | ## Configuration Variables 18 | # Change these to match your setup 19 | SFT_CHECKPOINT=SFT_CHECKPOINT # Change to the checkpoint of the SFT model 20 | CACHE_DIR=CACHE_DIR # Change to the directory where the model weights are cached 21 | OUTPUT_DIR=OUTPUT_DIR # Change to the directory where the model will be saved 22 | CONDA_ENV=CONDA_ENV # Change to the conda environment 23 | 24 | ## Setup Environment 25 | conda activate $CONDA_ENV # Change 
to the conda environment 26 | cd .../BioReason/ # Change to the directory containing the script 27 | nvidia-smi # Check GPU status 28 | 29 | ## Dependencies 30 | # You might need to install this on a gpu session 31 | # pip install trl[vllm] 32 | 33 | ## ============================================================================= 34 | ## Reinforcement Learning Training with DeepSpeed 35 | ## ============================================================================= 36 | 37 | # Run with DeepSpeed ZeRO Stage 2 38 | srun deepspeed --num_gpus=4 --num_nodes=1 \ 39 | reason.py \ 40 | --deepspeed grpo_trainer_lora_model/ds_config_stage2.json \ 41 | --num_generations 4 \ 42 | --per_device_train_batch_size 2 \ 43 | --bf16 true \ 44 | --ddp_find_unused_parameters false \ 45 | --sft_checkpoint $SFT_CHECKPOINT \ 46 | --model_name_or_path Qwen/Qwen3-1.7B \ 47 | --dna_model_name_or_path InstaDeepAI/nucleotide-transformer-v2-500m-multi-species \ 48 | --cache_dir $CACHE_DIR \ 49 | --output_dir $OUTPUT_DIR \ 50 | --save_strategy "steps" \ 51 | --save_steps 100 \ 52 | --save_total_limit 2 \ 53 | --use_vllm true \ 54 | --temperature 0.6 \ 55 | --top_p 0.95 \ 56 | --top_k 20 \ 57 | --num_train_epochs 1 58 | -------------------------------------------------------------------------------- /sh_train_dna_only.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=train_dna # Name of the job 3 | #SBATCH --time=8:00:00 # Time limit 4 | #SBATCH --partition=gpu_batch # Partition 5 | #SBATCH --gpus=1 # Number of GPUs 6 | #SBATCH --ntasks=1 # Number of tasks 7 | #SBATCH --cpus-per-task=6 # Number of cores 8 | #SBATCH --mem=128gb # Memory limit 9 | #SBATCH --output=train_dna_%j_%x.out # Output file 10 | #SBATCH --error=train_dna_%j_%x.err # Error file 11 | 12 | ## Environment Setup 13 | echo "CUDA_HOME: $CUDA_HOME" 14 | echo "PATH: $PATH" 15 | echo "LD_LIBRARY_PATH: $LD_LIBRARY_PATH" 16 | echo "which python: $(which python)"
17 | 18 | ## Configuration Variables 19 | # Change these to match your setup 20 | CONDA_ENV=CONDA_ENV # Change to your conda environment name 21 | CACHE_DIR=CACHE_DIR # Change to your HuggingFace cache directory 22 | WANDB_PROJECT=WANDB_PROJECT # Change to your W&B project name 23 | 24 | ## Setup Environment 25 | eval "$(conda shell.bash hook)" && conda activate $CONDA_ENV # Change to your conda environment 26 | cd .../BioReason/ # Change to the directory containing the script 27 | nvidia-smi # Check GPU status 28 | 29 | 30 | ## ============================================================================= 31 | ## KEGG Dataset Training (DNA-only models) 32 | ## ============================================================================= 33 | 34 | # NT-500M on KEGG 35 | stdbuf -oL -eL srun python train_dna_only.py \ 36 | --cache_dir $CACHE_DIR \ 37 | --wandb_project $WANDB_PROJECT \ 38 | --dna_model_name InstaDeepAI/nucleotide-transformer-v2-500m-multi-species \ 39 | --strategy ddp \ 40 | --max_epochs 5 \ 41 | --num_gpus 1 \ 42 | --batch_size 1 \ 43 | --max_length_dna 2048 \ 44 | --truncate_dna_per_side 1024 \ 45 | --train_just_classifier True \ 46 | --learning_rate 3e-4 \ 47 | --dataset_type kegg \ 48 | --merge_val_test_set True 49 | 50 | # EVO2-1B on KEGG 51 | stdbuf -oL -eL srun python train_dna_only.py \ 52 | --cache_dir $CACHE_DIR \ 53 | --wandb_project $WANDB_PROJECT \ 54 | --dna_model_name evo2_1b_base \ 55 | --strategy ddp \ 56 | --max_epochs 5 \ 57 | --num_gpus 1 \ 58 | --batch_size 1 \ 59 | --max_length_dna 2048 \ 60 | --truncate_dna_per_side 1024 \ 61 | --train_just_classifier True \ 62 | --dna_is_evo2 True \ 63 | --dna_embedding_layer blocks.20.mlp.l3 \ 64 | --learning_rate 3e-4 \ 65 | --dataset_type kegg \ 66 | --merge_val_test_set True 67 | 68 | ## ============================================================================= 69 | ## Variant Effect Prediction (VEP) Training 70 | ## ============================================================================= 71 | 72 | # NT-500M on VEP 73
| stdbuf -oL -eL srun python train_dna_only.py \ 74 | --cache_dir $CACHE_DIR \ 75 | --wandb_project $WANDB_PROJECT \ 76 | --dna_model_name InstaDeepAI/nucleotide-transformer-v2-500m-multi-species \ 77 | --strategy ddp \ 78 | --max_epochs 3 \ 79 | --num_gpus 1 \ 80 | --batch_size 2 \ 81 | --max_length_dna 2048 \ 82 | --truncate_dna_per_side 1024 \ 83 | --train_just_classifier True \ 84 | --learning_rate 3e-4 \ 85 | --dataset_type variant_effect_coding 86 | 87 | # EVO2-1B on VEP 88 | stdbuf -oL -eL srun python train_dna_only.py \ 89 | --cache_dir $CACHE_DIR \ 90 | --wandb_project $WANDB_PROJECT \ 91 | --dna_model_name evo2_1b_base \ 92 | --strategy ddp \ 93 | --max_epochs 3 \ 94 | --num_gpus 1 \ 95 | --batch_size 2 \ 96 | --max_length_dna 2048 \ 97 | --truncate_dna_per_side 1024 \ 98 | --train_just_classifier True \ 99 | --dna_is_evo2 True \ 100 | --dna_embedding_layer blocks.20.mlp.l3 \ 101 | --learning_rate 3e-4 \ 102 | --dataset_type variant_effect_coding 103 | 104 | ## ============================================================================= 105 | ## Variant Effect Prediction Non-SNV Training 106 | ## ============================================================================= 107 | 108 | # NT-500M on VEP Non-SNV 109 | stdbuf -oL -eL srun python train_dna_only.py \ 110 | --cache_dir $CACHE_DIR \ 111 | --wandb_project $WANDB_PROJECT \ 112 | --dna_model_name InstaDeepAI/nucleotide-transformer-v2-500m-multi-species \ 113 | --strategy ddp \ 114 | --max_epochs 3 \ 115 | --num_gpus 1 \ 116 | --batch_size 2 \ 117 | --max_length_dna 2048 \ 118 | --truncate_dna_per_side 1024 \ 119 | --train_just_classifier True \ 120 | --learning_rate 3e-4 \ 121 | --dataset_type variant_effect_non_snv 122 | 123 | # EVO2-1B on VEP Non-SNV 124 | stdbuf -oL -eL srun python train_dna_only.py \ 125 | --cache_dir $CACHE_DIR \ 126 | --wandb_project $WANDB_PROJECT \ 127 | --dna_model_name evo2_1b_base \ 128 | --strategy ddp \ 129 | --max_epochs 3 \ 130 | --num_gpus 1 \ 131 | --batch_size 2 \
132 | --max_length_dna 2048 \ 133 | --truncate_dna_per_side 1024 \ 134 | --train_just_classifier True \ 135 | --dna_is_evo2 True \ 136 | --dna_embedding_layer blocks.20.mlp.l3 \ 137 | --learning_rate 3e-4 \ 138 | --dataset_type variant_effect_non_snv -------------------------------------------------------------------------------- /sh_train_dna_qwen.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=train_dna_qwen # Name of the job 3 | #SBATCH --time=12:00:00 # Time limit 4 | #SBATCH --partition=gpu_batch # Partition 5 | #SBATCH --gpus=1 # Number of GPUs 6 | #SBATCH --ntasks=1 # Number of tasks 7 | #SBATCH --cpus-per-task=8 # Number of cores 8 | #SBATCH --mem=128gb # Memory limit 9 | #SBATCH --output=train_dna_qwen_%j_%x.out # Output file 10 | #SBATCH --error=train_dna_qwen_%j_%x.err # Error file 11 | 12 | ## Environment Setup 13 | echo "CUDA_HOME: $CUDA_HOME" 14 | echo "PATH: $PATH" 15 | echo "LD_LIBRARY_PATH: $LD_LIBRARY_PATH" 16 | echo "which python: $(which python)" 17 | 18 | ## Configuration Variables 19 | # Change these to match your setup 20 | CONDA_ENV=CONDA_ENV # Change to your conda environment name 21 | CACHE_DIR=CACHE_DIR # Change to your HuggingFace cache directory 22 | OUTPUT_DIR=OUTPUT_DIR # Change to your output/log directory 23 | WANDB_PROJECT=WANDB_PROJECT # Change to your W&B project name 24 | 25 | ## Setup Environment 26 | eval "$(conda shell.bash hook)" && conda activate $CONDA_ENV # Change to your conda environment 27 | cd .../BioReason/ # Change to the directory containing the script 28 | nvidia-smi # Check GPU status 29 | 30 | 31 | ## ============================================================================= 32 | ## KEGG Dataset Training 33 | ## ============================================================================= 34 | 35 | # NT-500M + Qwen3-1.7B on KEGG 36 | stdbuf -oL -eL srun python train_dna_qwen.py \ 37 | --cache_dir $CACHE_DIR \ 38 | --wandb_project $WANDB_PROJECT \ 39 | --text_model_name
Qwen/Qwen3-1.7B \ 40 | --dna_model_name InstaDeepAI/nucleotide-transformer-v2-500m-multi-species \ 41 | --strategy deepspeed_stage_2 \ 42 | --max_epochs 5 \ 43 | --num_gpus 1 \ 44 | --batch_size 1 \ 45 | --model_type dna-llm \ 46 | --dataset_type kegg \ 47 | --merge_val_test_set True \ 48 | --return_answer_in_batch True 49 | 50 | # EVO2-1B + Qwen3-1.7B on KEGG 51 | stdbuf -oL -eL srun python train_dna_qwen.py \ 52 | --cache_dir $CACHE_DIR \ 53 | --wandb_project $WANDB_PROJECT \ 54 | --text_model_name Qwen/Qwen3-1.7B \ 55 | --dna_model_name evo2_1b_base \ 56 | --strategy deepspeed_stage_2 \ 57 | --max_epochs 5 \ 58 | --num_gpus 1 \ 59 | --batch_size 1 \ 60 | --model_type dna-llm \ 61 | --dataset_type kegg \ 62 | --max_length_dna 2048 \ 63 | --truncate_dna_per_side 1024 \ 64 | --dna_is_evo2 True \ 65 | --dna_embedding_layer blocks.20.mlp.l3 \ 66 | --merge_val_test_set True \ 67 | --return_answer_in_batch True 68 | 69 | # Qwen3-4B on KEGG (LLM-only) 70 | stdbuf -oL -eL srun python train_dna_qwen.py \ 71 | --cache_dir $CACHE_DIR \ 72 | --wandb_project $WANDB_PROJECT \ 73 | --text_model_name Qwen/Qwen3-4B \ 74 | --dna_model_name InstaDeepAI/nucleotide-transformer-v2-500m-multi-species \ 75 | --strategy deepspeed_stage_2 \ 76 | --max_epochs 5 \ 77 | --num_gpus 1 \ 78 | --batch_size 1 \ 79 | --model_type llm \ 80 | --dataset_type kegg \ 81 | --max_length_dna 4 \ 82 | --max_length_text 8192 \ 83 | --truncate_dna_per_side 1024 \ 84 | --merge_val_test_set True \ 85 | --return_answer_in_batch True 86 | 87 | ## ============================================================================= 88 | ## Variant Effect Prediction (VEP) Training 89 | ## ============================================================================= 90 | 91 | # NT-500M + Qwen3-4B on VEP 92 | stdbuf -oL -eL srun python train_dna_qwen.py \ 93 | --cache_dir $CACHE_DIR \ 94 | --wandb_project $WANDB_PROJECT \ 95 | --text_model_name Qwen/Qwen3-4B \ 96 | --dna_model_name
InstaDeepAI/nucleotide-transformer-v2-500m-multi-species \ 97 | --strategy deepspeed_stage_2 \ 98 | --max_epochs 3 \ 99 | --num_gpus 1 \ 100 | --batch_size 2 \ 101 | --model_type dna-llm \ 102 | --dataset_type variant_effect_coding \ 103 | --return_answer_in_batch True 104 | 105 | # EVO2-1B + Qwen3-1.7B on VEP 106 | stdbuf -oL -eL srun python train_dna_qwen.py \ 107 | --cache_dir $CACHE_DIR \ 108 | --wandb_project $WANDB_PROJECT \ 109 | --text_model_name Qwen/Qwen3-1.7B \ 110 | --dna_model_name evo2_1b_base \ 111 | --strategy deepspeed_stage_2 \ 112 | --max_epochs 3 \ 113 | --num_gpus 1 \ 114 | --batch_size 2 \ 115 | --model_type dna-llm \ 116 | --dataset_type variant_effect_coding \ 117 | --max_length_dna 2048 \ 118 | --truncate_dna_per_side 1024 \ 119 | --dna_is_evo2 True \ 120 | --dna_embedding_layer blocks.20.mlp.l3 \ 121 | --return_answer_in_batch True 122 | 123 | # Qwen3-4B on VEP (LLM-only) - Testing max length text 124 | stdbuf -oL -eL srun python train_dna_qwen.py \ 125 | --cache_dir $CACHE_DIR \ 126 | --wandb_project $WANDB_PROJECT \ 127 | --text_model_name Qwen/Qwen3-4B \ 128 | --dna_model_name InstaDeepAI/nucleotide-transformer-v2-500m-multi-species \ 129 | --strategy deepspeed_stage_2 \ 130 | --max_epochs 3 \ 131 | --num_gpus 1 \ 132 | --batch_size 2 \ 133 | --model_type llm \ 134 | --dataset_type variant_effect_coding \ 135 | --max_length_dna 4 \ 136 | --max_length_text 4096 \ 137 | --truncate_dna_per_side 1024 \ 138 | --return_answer_in_batch True 139 | 140 | ## ============================================================================= 141 | ## Variant Effect Prediction Non-SNV Training 142 | ## ============================================================================= 143 | 144 | # NT-500M + Qwen3-4B on VEP Non-SNV 145 | stdbuf -oL -eL srun python train_dna_qwen.py \ 146 | --cache_dir $CACHE_DIR \ 147 | --wandb_project $WANDB_PROJECT \ 148 | --text_model_name Qwen/Qwen3-4B \ 149 | --dna_model_name
InstaDeepAI/nucleotide-transformer-v2-500m-multi-species \ 150 | --strategy deepspeed_stage_2 \ 151 | --max_epochs 1 \ 152 | --num_gpus 1 \ 153 | --batch_size 2 \ 154 | --model_type dna-llm \ 155 | --dataset_type variant_effect_non_snv \ 156 | --return_answer_in_batch True 157 | 158 | # EVO2-1B + Qwen3-4B on VEP Non-SNV 159 | stdbuf -oL -eL srun python train_dna_qwen.py \ 160 | --cache_dir $CACHE_DIR \ 161 | --wandb_project $WANDB_PROJECT \ 162 | --text_model_name Qwen/Qwen3-4B \ 163 | --dna_model_name evo2_1b_base \ 164 | --strategy deepspeed_stage_2 \ 165 | --max_epochs 3 \ 166 | --num_gpus 1 \ 167 | --batch_size 2 \ 168 | --model_type dna-llm \ 169 | --dataset_type variant_effect_non_snv \ 170 | --max_length_dna 2048 \ 171 | --truncate_dna_per_side 1024 \ 172 | --dna_is_evo2 True \ 173 | --dna_embedding_layer blocks.20.mlp.l3 \ 174 | --return_answer_in_batch True 175 | 176 | # Qwen3-4B on VEP Non-SNV (LLM-only) - Testing max length text 177 | stdbuf -oL -eL srun python train_dna_qwen.py \ 178 | --cache_dir $CACHE_DIR \ 179 | --wandb_project $WANDB_PROJECT \ 180 | --text_model_name Qwen/Qwen3-4B \ 181 | --dna_model_name InstaDeepAI/nucleotide-transformer-v2-500m-multi-species \ 182 | --strategy deepspeed_stage_2 \ 183 | --max_epochs 1 \ 184 | --num_gpus 1 \ 185 | --batch_size 2 \ 186 | --model_type llm \ 187 | --dataset_type variant_effect_non_snv \ 188 | --max_length_dna 4 \ 189 | --max_length_text 4096 \ 190 | --truncate_dna_per_side 1024 \ 191 | --return_answer_in_batch True -------------------------------------------------------------------------------- /train_dna_only.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import argparse 4 | import torch 5 | import wandb 6 | from torch.optim import AdamW 7 | from torch.utils.data import DataLoader 8 | from transformers import get_cosine_schedule_with_warmup, AutoTokenizer 9 | from datasets import load_dataset, concatenate_datasets 10 | import
pytorch_lightning as pl 11 | from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor 12 | from pytorch_lightning.loggers import WandbLogger 13 | from pytorch_lightning.strategies import DeepSpeedStrategy 14 | from bioreason.models.dna_only import DNAClassifierModel 15 | from bioreason.dataset.base import VariantEffectDataset 16 | from bioreason.dataset.utils import truncate_dna 17 | from bioreason.dataset.kegg import dna_collate_fn 18 | from bioreason.dataset.variant_effect import clean_variant_effect_example 19 | from bioreason.models.evo2_tokenizer import Evo2Tokenizer, register_evo2_tokenizer 20 | register_evo2_tokenizer() 21 | 22 | 23 | class DNAClassifierModelTrainer(pl.LightningModule): 24 | """ 25 | PyTorch Lightning module for training the DNA classifier. 26 | """ 27 | 28 | def __init__(self, args): 29 | """ 30 | Initialize the DNAClassifierModelTrainer. 31 | 32 | Args: 33 | args: Command line arguments 34 | """ 35 | super().__init__() 36 | self.save_hyperparameters(args) 37 | 38 | # Load dataset and labels 39 | self.dataset, self.labels = self.load_dataset() 40 | self.label2id = {label: i for i, label in enumerate(self.labels)} 41 | 42 | # Load model 43 | self.dna_model = DNAClassifierModel( 44 | dna_model_name=self.hparams.dna_model_name, 45 | cache_dir=self.hparams.cache_dir, 46 | max_length_dna=self.hparams.max_length_dna, 47 | num_classes=len(self.labels), 48 | dna_is_evo2=self.hparams.dna_is_evo2, 49 | dna_embedding_layer=self.hparams.dna_embedding_layer, 50 | train_just_classifier=self.hparams.train_just_classifier, 51 | ) 52 | self.dna_tokenizer = self.dna_model.dna_tokenizer 53 | 54 | # Set the training mode for the classifier and pooler 55 | self.dna_model.pooler.train() 56 | self.dna_model.classifier.train() 57 | 58 | # Freeze the DNA model parameters 59 | if self.hparams.dna_is_evo2: 60 | self.dna_model_params = self.dna_model.dna_model.model.parameters() 61 | else: 62 | self.dna_model_params = 
self.dna_model.dna_model.parameters() 63 | 64 | if self.hparams.train_just_classifier: 65 | for param in self.dna_model_params: 66 | param.requires_grad = False 67 | 68 | def _step(self, prefix, batch_idx, batch): 69 | """ 70 | Performs a single training/validation step. 71 | 72 | Args: 73 | batch: Dictionary containing the batch data 74 | prefix: String indicating the step type ('train' or 'val') 75 | 76 | Returns: 77 | torch.Tensor: The computed loss for this batch 78 | """ 79 | ref_ids = batch["ref_ids"].to(self.device) 80 | alt_ids = batch["alt_ids"].to(self.device) 81 | ref_attention_mask = batch["ref_attention_mask"].to(self.device) 82 | alt_attention_mask = batch["alt_attention_mask"].to(self.device) 83 | labels = batch["labels"].to(self.device) 84 | 85 | # Forward pass 86 | logits = self.dna_model(ref_ids=ref_ids, alt_ids=alt_ids, ref_attention_mask=ref_attention_mask, alt_attention_mask=alt_attention_mask) 87 | 88 | # Calculate loss 89 | loss_fn = torch.nn.CrossEntropyLoss() 90 | loss = loss_fn(logits, labels) 91 | 92 | # Calculate accuracy 93 | preds = torch.argmax(logits, dim=1) 94 | acc = (preds == labels).float().mean() 95 | 96 | # Calculate F1 score, precision, and recall for binary classification 97 | # Assuming label 1 is positive and label 0 is negative as mentioned 98 | true_positives = ((preds == 1) & (labels == 1)).float().sum() 99 | false_positives = ((preds == 1) & (labels == 0)).float().sum() 100 | false_negatives = ((preds == 0) & (labels == 1)).float().sum() 101 | 102 | # Calculate precision, recall, and F1 score 103 | precision = true_positives / (true_positives + false_positives + 1e-8) # add small epsilon to avoid division by zero 104 | recall = true_positives / (true_positives + false_negatives + 1e-8) 105 | f1 = 2 * precision * recall / (precision + recall + 1e-8) 106 | 107 | # Logging metrics 108 | self.log( 109 | f"{prefix}_loss", 110 | loss, 111 | on_step=True, 112 | on_epoch=False, 113 | prog_bar=True, 114 | logger=True, 115 | ) 
116 | self.log( 117 | f"{prefix}_acc", 118 | acc, 119 | on_step=True, 120 | on_epoch=False, 121 | prog_bar=True, 122 | logger=True, 123 | ) 124 | self.log( 125 | f"{prefix}_loss_epoch", 126 | loss, 127 | on_step=False, 128 | on_epoch=True, 129 | prog_bar=True, 130 | logger=True, 131 | sync_dist=True, 132 | ) 133 | self.log( 134 | f"{prefix}_acc_epoch", 135 | acc, 136 | on_step=False, 137 | on_epoch=True, 138 | prog_bar=True, 139 | logger=True, 140 | sync_dist=True, 141 | ) 142 | self.log( 143 | f"{prefix}_precision", 144 | precision, 145 | on_step=True, 146 | on_epoch=False, 147 | prog_bar=True, 148 | logger=True, 149 | ) 150 | self.log( 151 | f"{prefix}_precision_epoch", 152 | precision, 153 | on_step=False, 154 | on_epoch=True, 155 | prog_bar=True, 156 | logger=True, 157 | sync_dist=True, 158 | ) 159 | self.log( 160 | f"{prefix}_recall", 161 | recall, 162 | on_step=True, 163 | on_epoch=False, 164 | prog_bar=True, 165 | logger=True, 166 | ) 167 | self.log( 168 | f"{prefix}_recall_epoch", 169 | recall, 170 | on_step=False, 171 | on_epoch=True, 172 | prog_bar=True, 173 | logger=True, 174 | sync_dist=True, 175 | ) 176 | self.log( 177 | f"{prefix}_f1", 178 | f1, 179 | on_step=True, 180 | on_epoch=False, 181 | prog_bar=True, 182 | logger=True, 183 | ) 184 | self.log( 185 | f"{prefix}_f1_epoch", 186 | f1, 187 | on_step=False, 188 | on_epoch=True, 189 | prog_bar=True, 190 | logger=True, 191 | sync_dist=True, 192 | ) 193 | 194 | if (prefix == "test") or (prefix == "train" and (self.global_step % 1000 == 0)) or (prefix == "val" and (batch_idx % 100 == 0)): 195 | wandb_logger = self.logger.experiment 196 | 197 | pred_label = self.labels[preds[0]] 198 | true_label = self.labels[labels[0]] 199 | timestamp = time.time() 200 | step_id = f"gen_{self.global_step}-{timestamp}" 201 | 202 | wandb_logger.log( 203 | { 204 | step_id: wandb.Table( 205 | columns=["timestamp", "prefix", "pred_label", "true_label"], 206 | data=[[timestamp, prefix, pred_label, true_label]], 207 | ) 208 | } 
209 | ) 210 | 211 | print(f"Example {prefix} {batch_idx} {self.global_step}: Prediction: {pred_label}, Target: {true_label}") 212 | 213 | return loss 214 | 215 | def training_step(self, batch, batch_idx): 216 | """Perform a training step.""" 217 | return self._step(prefix="train", batch_idx=batch_idx, batch=batch) 218 | 219 | def validation_step(self, batch, batch_idx): 220 | """Perform a validation step.""" 221 | return self._step(prefix="val", batch_idx=batch_idx, batch=batch) 222 | 223 | def test_step(self, batch, batch_idx): 224 | """Perform a test step.""" 225 | return self._step(prefix="test", batch_idx=batch_idx, batch=batch) 226 | 227 | def configure_optimizers(self): 228 | """Configure optimizers and learning rate schedulers.""" 229 | # Only include parameters that require gradients 230 | classifier_params = [ 231 | { 232 | "params": self.dna_model.classifier.parameters(), 233 | "lr": self.hparams.learning_rate, 234 | }, 235 | { 236 | "params": self.dna_model.pooler.parameters(), 237 | "lr": self.hparams.learning_rate, 238 | } 239 | ] 240 | dna_model_params = [ 241 | { 242 | "params": self.dna_model_params, 243 | "lr": self.hparams.learning_rate * 0.1, 244 | }, 245 | ] 246 | 247 | if self.hparams.train_just_classifier: 248 | # Only train classifier parameters 249 | optimizer = AdamW( 250 | classifier_params, 251 | weight_decay=self.hparams.weight_decay, 252 | ) 253 | else: 254 | # Train both DNA model and classifier with different learning rates 255 | optimizer = AdamW( 256 | classifier_params + dna_model_params, 257 | weight_decay=self.hparams.weight_decay, 258 | ) 259 | 260 | # Get total steps from trainer's estimated stepping batches 261 | total_steps = self.trainer.estimated_stepping_batches 262 | warmup_steps = int(0.1 * total_steps) 263 | 264 | # Create scheduler 265 | scheduler = get_cosine_schedule_with_warmup( 266 | optimizer, 267 | num_warmup_steps=warmup_steps, 268 | num_training_steps=total_steps, 269 | ) 270 | 271 | return [optimizer], 
[{"scheduler": scheduler, "interval": "step"}] 272 | 273 | def load_dataset(self): 274 | """Load the dataset based on the dataset type.""" 275 | if self.hparams.dataset_type == "variant_effect": 276 | dataset = load_dataset("wanglab/variant_effect_llm_tuning") 277 | dataset = VariantEffectDataset(dataset) 278 | labels = sorted(list(set(item["label"] for item in dataset))) 279 | 280 | 281 | elif self.hparams.dataset_type == "kegg": 282 | dataset = load_dataset(self.hparams.kegg_data_dir_huggingface) 283 | 284 | if self.hparams.truncate_dna_per_side: 285 | dataset = dataset.map( 286 | truncate_dna, fn_kwargs={"truncate_dna_per_side": self.hparams.truncate_dna_per_side} 287 | ) 288 | 289 | labels = [] 290 | for split, data in dataset.items(): 291 | labels.extend(data["answer"]) 292 | labels = list(set(labels)) 293 | 294 | elif self.hparams.dataset_type == "variant_effect_coding": 295 | dataset = load_dataset("wanglab/bioR_tasks", "variant_effect_coding") 296 | dataset = dataset.map(clean_variant_effect_example) 297 | 298 | if self.hparams.truncate_dna_per_side: 299 | dataset = dataset.map( 300 | truncate_dna, fn_kwargs={"truncate_dna_per_side": self.hparams.truncate_dna_per_side} 301 | ) 302 | 303 | labels = [] 304 | for split, data in dataset.items(): 305 | labels.extend(data["answer"]) 306 | labels = sorted(list(set(labels))) 307 | 308 | elif self.hparams.dataset_type == "variant_effect_non_snv": 309 | dataset = load_dataset("wanglab/bioR_tasks", "task5_variant_effect_non_snv") 310 | dataset = dataset.rename_column("mutated_sequence", "variant_sequence") 311 | dataset = dataset.map(clean_variant_effect_example) 312 | 313 | if self.hparams.truncate_dna_per_side: 314 | dataset = dataset.map( 315 | truncate_dna, fn_kwargs={"truncate_dna_per_side": self.hparams.truncate_dna_per_side} 316 | ) 317 | 318 | labels = [] 319 | for split, data in dataset.items(): 320 | labels.extend(data["answer"]) 321 | labels = sorted(list(set(labels))) 322 | 323 | else: 324 | raise 
ValueError(f"Invalid dataset type: {self.hparams.dataset_type}") 325 | 326 | print(f"Dataset:\n{dataset}\nLabels:\n{labels}\nNumber of labels:{len(labels)}") 327 | return dataset, labels 328 | 329 | def train_dataloader(self): 330 | """Create and return the training DataLoader.""" 331 | if self.hparams.dataset_type == "variant_effect": 332 | train_dataset = VariantEffectDataset(self.dataset["train"]) 333 | collate_fn = lambda b: VariantEffectDataset.collate_fn_dna_classifier(b, self.dna_tokenizer) 334 | 335 | elif self.hparams.dataset_type == "kegg": 336 | train_dataset = self.dataset["train"] 337 | collate_fn = lambda b: dna_collate_fn(b, dna_tokenizer=self.dna_tokenizer, label2id=self.label2id, max_length=self.hparams.max_length_dna) 338 | 339 | elif self.hparams.dataset_type == "variant_effect_coding": 340 | train_dataset = self.dataset["train"] 341 | collate_fn = lambda b: dna_collate_fn(b, dna_tokenizer=self.dna_tokenizer, label2id=self.label2id, max_length=self.hparams.max_length_dna) 342 | 343 | elif self.hparams.dataset_type == "variant_effect_non_snv": 344 | train_dataset = self.dataset["train"] 345 | collate_fn = lambda b: dna_collate_fn(b, dna_tokenizer=self.dna_tokenizer, label2id=self.label2id, max_length=self.hparams.max_length_dna) 346 | 347 | else: 348 | raise ValueError(f"Invalid dataset type: {self.hparams.dataset_type}") 349 | 350 | return DataLoader( 351 | train_dataset, 352 | batch_size=self.hparams.batch_size, 353 | shuffle=True, 354 | collate_fn=collate_fn, 355 | num_workers=self.hparams.num_workers, 356 | persistent_workers=True, 357 | ) 358 | 359 | def val_dataloader(self): 360 | """Create and return the training DataLoader.""" 361 | if self.hparams.dataset_type == "variant_effect": 362 | val_dataset = VariantEffectDataset(self.dataset["test"]) 363 | collate_fn = lambda b: VariantEffectDataset.collate_fn_dna_classifier(b, self.dna_tokenizer) 364 | 365 | elif self.hparams.dataset_type == "kegg": 366 | 367 | if 
self.hparams.merge_val_test_set: 368 | val_dataset = concatenate_datasets([self.dataset['test'], self.dataset['val']]) 369 | else: 370 | val_dataset = self.dataset["val"] 371 | 372 | collate_fn = lambda b: dna_collate_fn(b, dna_tokenizer=self.dna_tokenizer, label2id=self.label2id, max_length=self.hparams.max_length_dna) 373 | 374 | elif self.hparams.dataset_type == "variant_effect_coding": 375 | val_dataset = self.dataset["test"] 376 | collate_fn = lambda b: dna_collate_fn(b, dna_tokenizer=self.dna_tokenizer, label2id=self.label2id, max_length=self.hparams.max_length_dna) 377 | 378 | elif self.hparams.dataset_type == "variant_effect_non_snv": 379 | val_dataset = self.dataset["test"] 380 | collate_fn = lambda b: dna_collate_fn(b, dna_tokenizer=self.dna_tokenizer, label2id=self.label2id, max_length=self.hparams.max_length_dna) 381 | 382 | else: 383 | raise ValueError(f"Invalid dataset type: {self.hparams.dataset_type}") 384 | 385 | return DataLoader( 386 | val_dataset, 387 | batch_size=self.hparams.batch_size, 388 | shuffle=False, 389 | collate_fn=collate_fn, 390 | num_workers=self.hparams.num_workers, 391 | persistent_workers=True, 392 | ) 393 | 394 | def test_dataloader(self): 395 | """Create and return the test DataLoader.""" 396 | return self.val_dataloader() 397 | 398 | 399 | def main(args): 400 | """Main function to run the training process.""" 401 | # Set random seed and environment variables 402 | pl.seed_everything(args.seed) 403 | # os.environ["CUDA_LAUNCH_BLOCKING"] = "1" 404 | torch.cuda.empty_cache() 405 | torch.set_float32_matmul_precision("medium") 406 | 407 | # Initialize model 408 | model = DNAClassifierModelTrainer(args) 409 | 410 | # Setup directories 411 | run_name = f"{args.wandb_project}-{args.dataset_type}-{args.dna_model_name.split('/')[-1]}" 412 | args.checkpoint_dir = f"{args.checkpoint_dir}/{run_name}-{time.strftime('%Y%m%d-%H%M%S')}" 413 | args.output_dir = f"{args.output_dir}/{run_name}-{time.strftime('%Y%m%d-%H%M%S')}" 414 | 
os.makedirs(args.output_dir, exist_ok=True) 415 | os.makedirs(args.checkpoint_dir, exist_ok=True) 416 | 417 | # Setup callbacks 418 | callbacks = [ 419 | ModelCheckpoint( 420 | dirpath=args.checkpoint_dir, 421 | filename=f"{run_name}-" + "{epoch:02d}-{val_loss_epoch:.4f}", 422 | save_top_k=2, 423 | monitor="val_acc_epoch", 424 | mode="max", 425 | save_last=True, 426 | ), 427 | LearningRateMonitor(logging_interval="step"), 428 | ] 429 | 430 | # Setup logger 431 | is_resuming = args.ckpt_path is not None 432 | logger = WandbLogger( 433 | project=args.wandb_project, 434 | entity=args.wandb_entity, 435 | save_dir=args.log_dir, 436 | name=run_name, 437 | resume="allow" if is_resuming else None, # Allow resuming existing run 438 | ) 439 | 440 | # Initialize trainer 441 | trainer = pl.Trainer( 442 | max_epochs=args.max_epochs, 443 | accelerator="gpu", 444 | devices=args.num_gpus, 445 | strategy=( 446 | "ddp" 447 | if args.strategy == "ddp" 448 | else DeepSpeedStrategy(stage=2, offload_optimizer=False, allgather_bucket_size=5e8, reduce_bucket_size=5e8) 449 | ), 450 | precision="bf16-mixed", 451 | callbacks=callbacks, 452 | logger=logger, 453 | deterministic=False, 454 | enable_checkpointing=True, 455 | enable_progress_bar=True, 456 | enable_model_summary=True, 457 | log_every_n_steps=5, 458 | accumulate_grad_batches=args.gradient_accumulation_steps, 459 | gradient_clip_val=1.0, 460 | val_check_interval=1 / 3, 461 | ) 462 | 463 | # Train model 464 | trainer.fit(model, ckpt_path=args.ckpt_path) 465 | trainer.test(model, ckpt_path=args.ckpt_path if args.ckpt_path else "best") 466 | 467 | # Save final model 468 | final_model_path = os.path.join(args.output_dir, "final_model") 469 | torch.save(model.dna_model.state_dict(), final_model_path) 470 | print(f"Final model saved to {final_model_path}") 471 | 472 | 473 | if __name__ == "__main__": 474 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 475 | parser = argparse.ArgumentParser(description="Train DNA Classifier") 476 | 477 | # 
Model parameters 478 | parser.add_argument( 479 | "--dna_model_name", 480 | type=str, 481 | default="InstaDeepAI/nucleotide-transformer-v2-500m-multi-species", 482 | ) 483 | parser.add_argument("--cache_dir", type=str, default="/model-weights") 484 | parser.add_argument("--max_length_dna", type=int, default=1024) 485 | parser.add_argument("--dna_is_evo2", type=bool, default=False) 486 | parser.add_argument("--dna_embedding_layer", type=str, default=None) 487 | 488 | # Training parameters 489 | parser.add_argument("--strategy", type=str, default="ddp") 490 | parser.add_argument("--batch_size", type=int, default=8) 491 | parser.add_argument("--learning_rate", type=float, default=5e-5) 492 | parser.add_argument("--weight_decay", type=float, default=0.01) 493 | parser.add_argument("--max_epochs", type=int, default=5) 494 | parser.add_argument("--max_steps", type=int, default=-1) 495 | parser.add_argument("--gradient_accumulation_steps", type=int, default=8) 496 | parser.add_argument("--num_workers", type=int, default=4) 497 | parser.add_argument("--num_gpus", type=int, default=1) 498 | parser.add_argument("--train_just_classifier", type=bool, default=True) 499 | parser.add_argument("--dataset_type", type=str, choices=["variant_effect", "kegg", "variant_effect_coding", "variant_effect_non_snv"], default="kegg") 500 | parser.add_argument("--kegg_data_dir_huggingface", type=str, default="wanglab/kegg") 501 | parser.add_argument("--truncate_dna_per_side", type=int, default=0) 502 | 503 | # Output parameters 504 | parser.add_argument("--output_dir", type=str, default="dna_classifier_output") 505 | parser.add_argument( 506 | "--checkpoint_dir", type=str, default="checkpoints" 507 | ) 508 | parser.add_argument("--ckpt_path", type=str, default=None) 509 | parser.add_argument("--log_dir", type=str, default="logs") 510 | parser.add_argument("--wandb_project", type=str, default="dna-only-nt-500m") 511 | parser.add_argument("--wandb_entity", type=str, default="adibvafa") 512 | 
parser.add_argument("--merge_val_test_set", type=bool, default=True) 513 | 514 | # Other parameters 515 | parser.add_argument("--seed", type=int, default=23) 516 | 517 | args = parser.parse_args() 518 | main(args) 519 | --------------------------------------------------------------------------------