├── .github └── ISSUE_TEMPLATE │ ├── bug_report.yml │ └── config.yml ├── .gitignore ├── LICENSE ├── README.md ├── assets └── smoe.png ├── deploy ├── .dockerignore ├── Dockerfile └── entrypoint.sh ├── poetry.lock ├── pyproject.toml ├── src └── mistral_inference │ ├── __init__.py │ ├── args.py │ ├── cache.py │ ├── generate.py │ ├── lora.py │ ├── main.py │ ├── mamba.py │ ├── model.py │ ├── moe.py │ ├── rope.py │ ├── transformer.py │ ├── transformer_layers.py │ └── vision_encoder.py ├── tests └── test_generate.py └── tutorials ├── classifier.ipynb └── getting_started.ipynb /.github/ISSUE_TEMPLATE/bug_report.yml: -------------------------------------------------------------------------------- 1 | name: Bug report related to mistral-inference 2 | description: Submit a bug report that's related to mistral-inference 3 | title: '[BUG: ' 4 | labels: ['bug', 'triage'] 5 | body: 6 | - type: markdown 7 | attributes: 8 | value: | 9 | Thanks for taking the time to fill out this bug report! 10 | - type: textarea 11 | id: python-vv 12 | attributes: 13 | label: Python -VV 14 | description: Run `python -VV` from your virtual environment 15 | placeholder: Copy-paste the output (no need for backticks, will be formatted into code automatically) 16 | render: shell 17 | validations: 18 | required: true 19 | - type: textarea 20 | id: pip-freeze 21 | attributes: 22 | label: Pip Freeze 23 | description: Run `pip freeze` from your virtual environment 24 | placeholder: Copy-paste the output (no need for backticks, will be formatted into code automatically) 25 | render: shell 26 | validations: 27 | required: true 28 | - type: textarea 29 | id: reproduction-steps 30 | attributes: 31 | label: Reproduction Steps 32 | description: Provide a clear and concise description of the steps that lead to your issue. 33 | placeholder: | 34 | 1. First step... 35 | 2. Step 2... 36 | ... 37 | validations: 38 | required: true 39 | - type: textarea 40 | id: expected-behavior 41 | attributes: 42 | label: Expected Behavior 43 | description: Explain briefly what you expected to happen. 44 | validations: 45 | required: true 46 | - type: textarea 47 | id: additional-context 48 | attributes: 49 | label: Additional Context 50 | description: Add any context about your problem that you deem relevant. 51 | - type: textarea 52 | id: suggested-solutions 53 | attributes: 54 | label: Suggested Solutions 55 | description: Please list any solutions you recommend we consider. 
56 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | contact_links: 3 | - name: Documentation 4 | url: https://docs.mistral.ai 5 | about: Developer documentation for the Mistral AI platform 6 | - name: Discord 7 | url: https://discord.com/invite/mistralai 8 | about: Chat with the Mistral community 9 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Mistral Inference 2 | 3 | Open In Colab 4 | 5 | 6 | 7 | This repository contains minimal code to run Mistral models. 
8 | 9 | Blog 7B: [https://mistral.ai/news/announcing-mistral-7b/](https://mistral.ai/news/announcing-mistral-7b/)\ 10 | Blog 8x7B: [https://mistral.ai/news/mixtral-of-experts/](https://mistral.ai/news/mixtral-of-experts/)\ 11 | Blog 8x22B: [https://mistral.ai/news/mixtral-8x22b/](https://mistral.ai/news/mixtral-8x22b/)\ 12 | Blog Codestral 22B: [https://mistral.ai/news/codestral/](https://mistral.ai/news/codestral/) \ 13 | Blog Codestral Mamba 7B: [https://mistral.ai/news/codestral-mamba/](https://mistral.ai/news/codestral-mamba/) \ 14 | Blog Mathstral 7B: [https://mistral.ai/news/mathstral/](https://mistral.ai/news/mathstral/) \ 15 | Blog Nemo: [https://mistral.ai/news/mistral-nemo/](https://mistral.ai/news/mistral-nemo/) \ 16 | Blog Mistral Large 2: [https://mistral.ai/news/mistral-large-2407/](https://mistral.ai/news/mistral-large-2407/) \ 17 | Blog Pixtral 12B: [https://mistral.ai/news/pixtral-12b/](https://mistral.ai/news/pixtral-12b/) \ 18 | Blog Mistral Small 3.1: [https://mistral.ai/news/mistral-small-3-1/](https://mistral.ai/news/mistral-small-3-1/) 19 | 20 | Discord: [https://discord.com/invite/mistralai](https://discord.com/invite/mistralai)\ 21 | Documentation: [https://docs.mistral.ai/](https://docs.mistral.ai/)\ 22 | Guardrailing: [https://docs.mistral.ai/usage/guardrailing](https://docs.mistral.ai/usage/guardrailing) 23 | 24 | ## Installation 25 | 26 | Note: You will need a GPU to install `mistral-inference`, as it currently requires `xformers`, and `xformers` itself needs a GPU for installation. 27 | 28 | ### PyPI 29 | 30 | ``` 31 | pip install mistral-inference 32 | ``` 33 | 34 | ### Local 35 | 36 | ``` 37 | cd $HOME && git clone https://github.com/mistralai/mistral-inference 38 | cd $HOME/mistral-inference && poetry install
39 | ``` 40 | 41 | ## Model download 42 | 43 | ### Direct links 44 | 45 | | Name | Download | md5sum | 46 | |-------------|-------|-------| 47 | | 7B Instruct | https://models.mistralcdn.com/mistral-7b-v0-3/mistral-7B-Instruct-v0.3.tar | `80b71fcb6416085bcb4efad86dfb4d52` | 48 | | 8x7B Instruct | https://models.mistralcdn.com/mixtral-8x7b-v0-1/Mixtral-8x7B-v0.1-Instruct.tar (**Updated model coming soon!**) | `8e2d3930145dc43d3084396f49d38a3f` | 49 | | 8x22 Instruct | https://models.mistralcdn.com/mixtral-8x22b-v0-3/mixtral-8x22B-Instruct-v0.3.tar | `471a02a6902706a2f1e44a693813855b` | 50 | | 7B Base | https://models.mistralcdn.com/mistral-7b-v0-3/mistral-7B-v0.3.tar | `0663b293810d7571dad25dae2f2a5806` | 51 | | 8x7B | **Updated model coming soon!** | - | 52 | | 8x22B | https://models.mistralcdn.com/mixtral-8x22b-v0-3/mixtral-8x22B-v0.3.tar | `a2fa75117174f87d1197e3a4eb50371a` | 53 | | Codestral 22B | https://models.mistralcdn.com/codestral-22b-v0-1/codestral-22B-v0.1.tar | `1ea95d474a1d374b1d1b20a8e0159de3` | 54 | | Mathstral 7B | https://models.mistralcdn.com/mathstral-7b-v0-1/mathstral-7B-v0.1.tar | `5f05443e94489c261462794b1016f10b` | 55 | | Codestral-Mamba 7B | https://models.mistralcdn.com/codestral-mamba-7b-v0-1/codestral-mamba-7B-v0.1.tar | `d3993e4024d1395910c55db0d11db163` | 56 | | Nemo Base | https://models.mistralcdn.com/mistral-nemo-2407/mistral-nemo-base-2407.tar | `c5d079ac4b55fc1ae35f51f0a3c0eb83` | 57 | | Nemo Instruct | https://models.mistralcdn.com/mistral-nemo-2407/mistral-nemo-instruct-2407.tar | `296fbdf911cb88e6f0be74cd04827fe7` | 58 | | Mistral Large 2 | https://models.mistralcdn.com/mistral-large-2407/mistral-large-instruct-2407.tar | `fc602155f9e39151fba81fcaab2fa7c4` | 59 | 60 | Note: 61 | - **Important**: 62 | - `mixtral-8x22B-Instruct-v0.3.tar` is exactly the same as [Mixtral-8x22B-Instruct-v0.1](https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1), only stored in `.safetensors` format 63 | - `mixtral-8x22B-v0.3.tar` is the same as [Mixtral-8x22B-v0.1](https://huggingface.co/mistralai/Mixtral-8x22B-v0.1), but has an extended vocabulary of 32768 tokens. 64 | - `codestral-22B-v0.1.tar` has a custom non-commercial license, called [Mistral AI Non-Production (MNPL) License](https://mistral.ai/licenses/MNPL-0.1.md) 65 | - `mistral-large-instruct-2407.tar` has a custom non-commercial license, called [Mistral AI Research (MRL) License](https://mistral.ai/licenses/MRL-0.1.md) 66 | - All of the listed models above support function calling. For example, Mistral 7B Base/Instruct v3 is a minor update to Mistral 7B Base/Instruct v2, with the addition of function calling capabilities. 67 | - The "coming soon" models will include function calling as well. 68 | - You can download the previous versions of our models from our [docs](https://docs.mistral.ai/getting-started/open_weight_models/#downloading). 
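If you want to sanity-check a download before extracting it, one simple option is to compare the archive against the `md5sum` column above, *e.g.* for the 7B Instruct archive:

```sh
md5sum mistral-7B-Instruct-v0.3.tar
# should print: 80b71fcb6416085bcb4efad86dfb4d52
```

(On macOS, `md5 mistral-7B-Instruct-v0.3.tar` prints the same digest in a slightly different format.)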
69 | 70 | ### From Hugging Face Hub 71 | 72 | | Name | ID | URL | 73 | |-------------|-------|-------| 74 | | Pixtral Large Instruct | mistralai/Pixtral-Large-Instruct-2411 | https://huggingface.co/mistralai/Pixtral-Large-Instruct-2411 | 75 | | Pixtral 12B Base | mistralai/Pixtral-12B-Base-2409 | https://huggingface.co/mistralai/Pixtral-12B-Base-2409 | 76 | | Pixtral 12B | mistralai/Pixtral-12B-2409 | https://huggingface.co/mistralai/Pixtral-12B-2409 | 77 | | Mistral Small 3.1 24B Base | mistralai/Mistral-Small-3.1-24B-Base-2503 | https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Base-2503 | 78 | | Mistral Small 3.1 24B Instruct | mistralai/Mistral-Small-3.1-24B-Instruct-2503 | https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503 | 79 | 80 | 81 | ### Usage 82 | 83 | **News!!!**: Mistral Large 2 is out. Read more about its capabilities [here](https://mistral.ai/news/mistral-large-2407/). 84 | 85 | Create a local folder to store models: 86 | ```sh 87 | export MISTRAL_MODEL=$HOME/mistral_models 88 | mkdir -p $MISTRAL_MODEL 89 | ``` 90 | 91 | Download any of the links above and extract the content, *e.g.*: 92 | 93 | ```sh 94 | export M12B_DIR=$MISTRAL_MODEL/12B_Nemo 95 | wget https://models.mistralcdn.com/mistral-nemo-2407/mistral-nemo-instruct-2407.tar 96 | mkdir -p $M12B_DIR 97 | tar -xf mistral-nemo-instruct-2407.tar -C $M12B_DIR 98 | ``` 99 | 100 | or 101 | 102 | ```sh 103 | export M8x7B_DIR=$MISTRAL_MODEL/8x7b_instruct 104 | wget https://models.mistralcdn.com/mixtral-8x7b-v0-1/Mixtral-8x7B-v0.1-Instruct.tar 105 | mkdir -p $M8x7B_DIR 106 | tar -xf Mixtral-8x7B-v0.1-Instruct.tar -C $M8x7B_DIR 107 | ``` 108 | 109 | For model weights hosted on the Hugging Face Hub, here is an example of how to download [Mistral Small 3.1 24B Instruct](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503): 110 | 111 | ```python 112 | from pathlib import Path 113 | from huggingface_hub import snapshot_download 114 | 115 | 116 | mistral_models_path = Path.home().joinpath("mistral_models") 117 | 118 | model_path = mistral_models_path / "mistral-small-3.1-instruct" 119 | model_path.mkdir(parents=True, exist_ok=True) 120 | 121 | repo_id = "mistralai/Mistral-Small-3.1-24B-Instruct-2503" 122 | 123 | snapshot_download( 124 | repo_id=repo_id, 125 | allow_patterns=["params.json", "consolidated.safetensors", "tekken.json"], 126 | local_dir=model_path, 127 | ) 128 | ``` 129 | 130 | ## Usage 131 | 132 | The following sections give an overview of how to run the model from the command-line interface (CLI) or directly within Python. 133 | 134 | ### CLI 135 | 136 | - **Demo** 137 | 138 | To test that a model works in your setup, you can run the `mistral-demo` command. 139 | *E.g.* the 12B Mistral-Nemo model can be tested on a single GPU as follows: 140 | 141 | ```sh 142 | mistral-demo $M12B_DIR 143 | ``` 144 | 145 | Large models, such as **8x7B** and **8x22B**, have to be run in a multi-GPU setup. 146 | For these models, you can use the following command: 147 | 148 | ```sh 149 | torchrun --nproc-per-node 2 --no-python mistral-demo $M8x7B_DIR 150 | ``` 151 | 152 | *Note*: Change `--nproc-per-node` to more GPUs if available. 153 | 154 | - **Chat** 155 | 156 | To interactively chat with the models, you can make use of the `mistral-chat` command. 157 | 158 | ```sh 159 | mistral-chat $M12B_DIR --instruct --max_tokens 1024 --temperature 0.35 160 | ``` 161 | 162 | For large models, you can make use of `torchrun`.
163 | 164 | ```sh 165 | torchrun --nproc-per-node 2 --no-python mistral-chat $M8x7B_DIR --instruct 166 | ``` 167 | 168 | *Note*: Change `--nproc-per-node` to more GPUs if necessary (*e.g.* for 8x22B). 169 | 170 | - **Chat with Codestral** 171 | 172 | To use [Codestral](https://mistral.ai/news/codestral/) as a coding assistant, you can run the following command using `mistral-chat`. 173 | Make sure `$M22B_CODESTRAL` is set to a valid path to the downloaded codestral folder, e.g. `$HOME/mistral_models/Codestral-22B-v0.1` 174 | 175 | ```sh 176 | mistral-chat $M22B_CODESTRAL --instruct --max_tokens 256 177 | ``` 178 | 179 | If you prompt it with *"Write me a function that computes fibonacci in Rust"*, the model should generate something along the following lines: 180 | 181 | ```sh 182 | Sure, here's a simple implementation of a function that computes the Fibonacci sequence in Rust. This function takes an integer `n` as an argument and returns the `n`th Fibonacci number. 183 | 184 | fn fibonacci(n: u32) -> u32 { 185 | match n { 186 | 0 => 0, 187 | 1 => 1, 188 | _ => fibonacci(n - 1) + fibonacci(n - 2), 189 | } 190 | } 191 | 192 | fn main() { 193 | let n = 10; 194 | println!("The {}th Fibonacci number is: {}", n, fibonacci(n)); 195 | } 196 | 197 | This function uses recursion to calculate the Fibonacci number. However, it's not the most efficient solution because it performs a lot of redundant calculations. A more efficient solution would use a loop to iteratively calculate the Fibonacci numbers. 198 | ``` 199 | 200 | You can continue chatting afterwards, *e.g.* with *"Translate it to Python"*. 201 | 202 | - **Chat with Codestral-Mamba** 203 | 204 | To use [Codestral-Mamba](https://mistral.ai/news/codestral-mamba/) as a coding assistant, you can run the following command using `mistral-chat`. 205 | Make sure `$M7B_CODESTRAL_MAMBA` is set to a valid path to the downloaded codestral-mamba folder, e.g. `$HOME/mistral_models/mamba-codestral-7B-v0.1`. 206 | 207 | You then need to additionally install the following packages: 208 | 209 | ``` 210 | pip install packaging mamba-ssm causal-conv1d transformers 211 | ``` 212 | 213 | before you can start chatting: 214 | 215 | ```sh 216 | mistral-chat $M7B_CODESTRAL_MAMBA --instruct --max_tokens 256 217 | ``` 218 | 219 | - **Chat with Mathstral** 220 | 221 | To use [Mathstral](https://mistral.ai/news/mathstral/) as an assistant, you can run the following command using `mistral-chat`. 222 | Make sure `$M7B_MATHSTRAL` is set to a valid path to the downloaded mathstral folder, e.g. `$HOME/mistral_models/mathstral-7B-v0.1` 223 | 224 | ```sh 225 | mistral-chat $M7B_MATHSTRAL --instruct --max_tokens 256 226 | ``` 227 | 228 | If you prompt it with *"Albert likes to surf every week. Each surfing session lasts for 4 hours and costs $20 per hour. How much would Albert spend in 5 weeks?"*, the model should answer with the correct calculation. 229 | 230 | You can then continue chatting afterwards, *e.g.* with *"How much would he spend in a year?"*. 231 | 232 | - **Chat with Mistral Small 3.1 24B Instruct** 233 | 234 | To use [Mistral Small 3.1 24B Instruct](https://mistral.ai/news/mistral-small-3-1/) as an assistant, you can run the following command using `mistral-chat`. 235 | Make sure `$MISTRAL_SMALL_3_1_INSTRUCT` is set to a valid path to the downloaded mistral small folder, e.g.
`$HOME/mistral_models/mistral-small-3.1-instruct` 236 | 237 | ```sh 238 | mistral-chat $MISTRAL_SMALL_3_1_INSTRUCT --instruct --max_tokens 256 239 | ``` 240 | 241 | If you prompt it with *"The above image presents an image of which park ? Please give the hints to identify the park."* with the following image URL *https://huggingface.co/datasets/patrickvonplaten/random_img/resolve/main/yosemite.png*, the model should answer with the Yosemite park and give hints to identify it. 242 | 243 | You can then continue chatting afterwards, *e.g.* with *"What is the name of the lake in the image?"*. The model should respond that it is not a lake but a river. 244 | 245 | ### Python 246 | 247 | - *Instruction Following*: 248 | 249 | ```py 250 | from mistral_inference.transformer import Transformer 251 | from mistral_inference.generate import generate 252 | 253 | from mistral_common.tokens.tokenizers.mistral import MistralTokenizer 254 | from mistral_common.protocol.instruct.messages import UserMessage 255 | from mistral_common.protocol.instruct.request import ChatCompletionRequest 256 | 257 | 258 | tokenizer = MistralTokenizer.from_file("./mistral-nemo-instruct-v0.1/tekken.json") # change to extracted tokenizer file 259 | model = Transformer.from_folder("./mistral-nemo-instruct-v0.1") # change to extracted model dir 260 | 261 | prompt = "How expensive would it be to ask a window cleaner to clean all windows in Paris. Make a reasonable guess in US Dollar." 262 | 263 | completion_request = ChatCompletionRequest(messages=[UserMessage(content=prompt)]) 264 | 265 | tokens = tokenizer.encode_chat_completion(completion_request).tokens 266 | 267 | out_tokens, _ = generate([tokens], model, max_tokens=1024, temperature=0.35, eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id) 268 | result = tokenizer.instruct_tokenizer.tokenizer.decode(out_tokens[0]) 269 | 270 | print(result) 271 | ``` 272 | 273 | - *Multimodal Instruction Following*: 274 | 275 | 276 | ```python 277 | from pathlib import Path 278 | 279 | from huggingface_hub import snapshot_download 280 | from mistral_common.protocol.instruct.messages import ImageURLChunk, TextChunk 281 | from mistral_common.tokens.tokenizers.mistral import MistralTokenizer 282 | from mistral_inference.generate import generate 283 | from mistral_inference.transformer import Transformer 284 | 285 | model_path = Path.home().joinpath("mistral_models") / "mistral-small-3.1-instruct" # change to extracted model 286 | 287 | tokenizer = MistralTokenizer.from_file(model_path / "tekken.json") 288 | model = Transformer.from_folder(model_path) 289 | 290 | url = "https://huggingface.co/datasets/patrickvonplaten/random_img/resolve/main/yosemite.png" 291 | prompt = "The above image presents an image of which park ? Please give the hints to identify the park." 
292 | 293 | user_content = [ImageURLChunk(image_url=url), TextChunk(text=prompt)] 294 | 295 | tokens, images = tokenizer.instruct_tokenizer.encode_user_content(user_content, False) 296 | 297 | out_tokens, _ = generate( 298 | [tokens], 299 | model, 300 | images=[images], 301 | max_tokens=256, 302 | temperature=0.15, 303 | eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id, 304 | ) 305 | result = tokenizer.decode(out_tokens[0]) 306 | 307 | print("Prompt:", prompt) 308 | print("Completion:", result) 309 | ``` 310 | 311 | - *Function Calling*: 312 | 313 | ```py 314 | from mistral_common.protocol.instruct.tool_calls import Function, Tool 315 | 316 | completion_request = ChatCompletionRequest( 317 | tools=[ 318 | Tool( 319 | function=Function( 320 | name="get_current_weather", 321 | description="Get the current weather", 322 | parameters={ 323 | "type": "object", 324 | "properties": { 325 | "location": { 326 | "type": "string", 327 | "description": "The city and state, e.g. San Francisco, CA", 328 | }, 329 | "format": { 330 | "type": "string", 331 | "enum": ["celsius", "fahrenheit"], 332 | "description": "The temperature unit to use. Infer this from the user's location.", 333 | }, 334 | }, 335 | "required": ["location", "format"], 336 | }, 337 | ) 338 | ) 339 | ], 340 | messages=[ 341 | UserMessage(content="What's the weather like today in Paris?"), 342 | ], 343 | ) 344 | 345 | tokens = tokenizer.encode_chat_completion(completion_request).tokens 346 | 347 | out_tokens, _ = generate([tokens], model, max_tokens=64, temperature=0.0, eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id) 348 | result = tokenizer.instruct_tokenizer.tokenizer.decode(out_tokens[0]) 349 | 350 | print(result) 351 | ``` 352 | 353 | - *Fill-in-the-middle (FIM)*: 354 | 355 | Make sure to have `mistral-common >= 1.2.0` installed: 356 | ``` 357 | pip install --upgrade mistral-common 358 | ``` 359 | 360 | You can simulate code completion in-filling as follows. 361 | 362 | ```py 363 | from mistral_inference.transformer import Transformer 364 | from mistral_inference.generate import generate 365 | from mistral_common.tokens.tokenizers.mistral import MistralTokenizer 366 | from mistral_common.tokens.instruct.request import FIMRequest 367 | 368 | tokenizer = MistralTokenizer.from_model("codestral-22b") 369 | model = Transformer.from_folder("./mistral_22b_codestral") 370 | 371 | prefix = """def add(""" 372 | suffix = """ return sum""" 373 | 374 | request = FIMRequest(prompt=prefix, suffix=suffix) 375 | 376 | tokens = tokenizer.encode_fim(request).tokens 377 | 378 | out_tokens, _ = generate([tokens], model, max_tokens=256, temperature=0.0, eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id) 379 | result = tokenizer.decode(out_tokens[0]) 380 | 381 | middle = result.split(suffix)[0].strip() 382 | print(middle) 383 | ``` 384 | 385 | ### Test 386 | 387 | To run the logits equivalence tests: 388 | ``` 389 | python -m pytest tests 390 | ``` 391 | 392 | ## Deployment 393 | 394 | The `deploy` folder contains code to build a [vLLM](https://github.com/vllm-project/vllm) image with the required dependencies to serve Mistral AI models. In the image, the [transformers](https://github.com/huggingface/transformers/) library is used instead of the reference implementation. To build it: 395 | 396 | ```bash 397 | docker build deploy --build-arg MAX_JOBS=8 398 | ``` 399 | 400 | Instructions to run the image can be found in the [official documentation](https://docs.mistral.ai/quickstart).
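As a minimal sketch of what running the image can look like (assuming you added `-t mistral-vllm` to the build command above; the exact server flags depend on your vLLM version and hardware), the entrypoint forwards all arguments to vLLM's OpenAI-compatible API server:

```bash
# Hypothetical invocation: the `mistral-vllm` tag and the served model are examples.
docker run --gpus all -p 8000:8000 \
  -e HF_TOKEN=$HF_TOKEN \
  mistral-vllm \
  --host 0.0.0.0 --port 8000 \
  --model mistralai/Mistral-7B-Instruct-v0.3
```

`HF_TOKEN` is only needed for gated models; `entrypoint.sh` uses it to log in to Hugging Face before starting the server.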
401 | 402 | 403 | ## Model platforms 404 | 405 | - Use Mistral models on [Mistral AI official API](https://console.mistral.ai/) (La Plateforme) 406 | - Use Mistral models via [cloud providers](https://docs.mistral.ai/deployment/cloud/overview/) 407 | 408 | ## References 409 | 410 | [1]: [LoRA](https://arxiv.org/abs/2106.09685): Low-Rank Adaptation of Large Language Models, Hu et al. 2021 411 | -------------------------------------------------------------------------------- /assets/smoe.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mistralai/mistral-inference/6eb35510403825cfb430b0004443053e8c4b70dc/assets/smoe.png -------------------------------------------------------------------------------- /deploy/.dockerignore: -------------------------------------------------------------------------------- 1 | * 2 | !entrypoint.sh 3 | -------------------------------------------------------------------------------- /deploy/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM --platform=amd64 nvcr.io/nvidia/cuda:12.1.0-devel-ubuntu22.04 as base 2 | 3 | WORKDIR /workspace 4 | 5 | RUN apt update && \ 6 | apt install -y python3-pip python3-packaging \ 7 | git ninja-build && \ 8 | pip3 install -U pip 9 | 10 | # Tweak this list to reduce build time 11 | # https://developer.nvidia.com/cuda-gpus 12 | ENV TORCH_CUDA_ARCH_LIST "7.0;7.2;7.5;8.0;8.6;8.9;9.0" 13 | 14 | RUN pip3 install "torch==2.1.1" 15 | 16 | # This build is slow but NVIDIA does not provide binaries. Increase MAX_JOBS as needed. 17 | RUN pip3 install "git+https://github.com/stanford-futuredata/megablocks.git" 18 | RUN pip3 install "git+https://github.com/vllm-project/vllm.git" 19 | RUN pip3 install "xformers==0.0.23" "transformers==4.36.0" "fschat[model_worker]==0.2.34" 20 | 21 | RUN git clone https://github.com/NVIDIA/apex && \ 22 | cd apex && git checkout 2386a912164b0c5cfcd8be7a2b890fbac5607c82 && \ 23 | sed -i '/check_cuda_torch_binary_vs_bare_metal(CUDA_HOME)/d' setup.py && \ 24 | python3 setup.py install --cpp_ext --cuda_ext 25 | 26 | 27 | COPY entrypoint.sh . 28 | 29 | RUN chmod +x /workspace/entrypoint.sh 30 | 31 | ENTRYPOINT ["/workspace/entrypoint.sh"] -------------------------------------------------------------------------------- /deploy/entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [[ ! -z "${HF_TOKEN}" ]]; then 4 | echo "The HF_TOKEN environment variable is set, logging to Hugging Face." 5 | python3 -c "import huggingface_hub; huggingface_hub.login('${HF_TOKEN}')" 6 | else 7 | echo "The HF_TOKEN environment variable is not set or empty, not logging to Hugging Face." 
8 | fi 9 | 10 | # Run the provided command 11 | exec python3 -u -m vllm.entrypoints.openai.api_server "$@" 12 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "mistral_inference" 3 | version = "1.6.0" 4 | description = "" 5 | authors = ["bam4d "] 6 | readme = "README.md" 7 | packages = [{ include = "mistral_inference", from = "src" }] 8 | 9 | [tool.ruff] 10 | lint.select = ["E", "F", "W", "Q", "I"] 11 | lint.ignore = ["E203"] 12 | lint.fixable = ["ALL"] 13 | lint.unfixable = [] 14 | line-length = 120 15 | exclude = ["docs", "build", "tutorials"] 16 | 17 | [tool.mypy] 18 | disallow_untyped_defs = true 19 | show_error_codes = true 20 | no_implicit_optional = true 21 | warn_return_any = true 22 | warn_unused_ignores = true 23 | exclude = ["docs", "tools", "build"] 24 | 25 | [tool.poetry.dependencies] 26 | python = "^3.9.10" 27 | xformers = ">=0.0.24" 28 | simple-parsing = ">=0.1.5" 29 | fire = ">=0.6.0" 30 | mistral_common = ">=1.5.4" 31 | safetensors = ">=0.4.0" 32 | pillow = ">=10.3.0" 33 | 34 | [tool.poetry.group.dev.dependencies] 35 | types-protobuf = "4.24.0.20240129" 36 | mypy-protobuf = "^3.5.0" 37 | pytest = "7.4.4" 38 | ruff = "^0.2.2" 39 | mypy = "^1.8.0" 40 | 41 | [build-system] 42 | requires = ["poetry-core>=1.0.0"] 43 | build-backend = "poetry.core.masonry.api" 44 | 45 | [tool.pytest.ini_options] 46 | testpaths = ["./tests"] 47 | 48 | [tool.poetry.scripts] 49 | mistral-chat = "mistral_inference.main:mistral_chat" 50 | mistral-demo = "mistral_inference.main:mistral_demo" 51 | -------------------------------------------------------------------------------- /src/mistral_inference/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.6.0" 2 | -------------------------------------------------------------------------------- /src/mistral_inference/args.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List, Optional 3 | 4 | from simple_parsing.helpers import Serializable 5 | 6 | from mistral_inference.lora import LoraArgs 7 | from mistral_inference.moe import MoeArgs 8 | 9 | PATCH_MERGE = "patch_merge" 10 | 11 | 12 | @dataclass 13 | class VisionEncoderArgs: 14 | hidden_size: int 15 | num_channels: int 16 | image_size: int 17 | patch_size: int 18 | intermediate_size: int 19 | num_hidden_layers: int 20 | num_attention_heads: int 21 | rope_theta: float = 1e4 # for rope-2D 22 | image_token_id: int = 10 23 | adapter_bias: bool = True 24 | spatial_merge_size: int = 1 25 | add_pre_mm_projector_layer_norm: bool = False 26 | mm_projector_id: str = "" 27 | 28 | 29 | @dataclass 30 | class TransformerArgs(Serializable): 31 | dim: int 32 | n_layers: int 33 | head_dim: int 34 | hidden_dim: int 35 | n_heads: int 36 | n_kv_heads: int 37 | norm_eps: float 38 | vocab_size: int 39 | 40 | max_batch_size: int = 0 41 | 42 | # For rotary embeddings. If not set, will be inferred 43 | rope_theta: Optional[float] = None 44 | # If this is set, we will use MoE layers instead of dense layers. 45 | moe: Optional[MoeArgs] = None 46 | # If this is set, we will load LoRA linear layers instead of linear layers. 
47 | lora: Optional[LoraArgs] = None 48 | sliding_window: Optional[int] | Optional[List[int]] = None 49 | _sliding_window: Optional[int] | Optional[List[int]] = None 50 | model_type: str = "transformer" 51 | 52 | vision_encoder: Optional[VisionEncoderArgs] = None 53 | 54 | def __post_init__(self) -> None: 55 | assert self.model_type == "transformer", self.model_type 56 | assert self.sliding_window is None or self._sliding_window is None 57 | 58 | # hack for now so that vLLM is supported correctly 59 | self.sliding_window = self.sliding_window if self.sliding_window is not None else self._sliding_window 60 | 61 | 62 | @dataclass 63 | class MambaArgs(Serializable): 64 | dim: int 65 | n_layers: int 66 | vocab_size: int 67 | n_groups: int 68 | rms_norm: bool 69 | residual_in_fp32: bool 70 | fused_add_norm: bool 71 | pad_vocab_size_multiple: int 72 | tie_embeddings: bool 73 | model_type: str = "mamba" 74 | 75 | def __post_init__(self) -> None: 76 | assert self.model_type == "mamba", self.model_type 77 | -------------------------------------------------------------------------------- /src/mistral_inference/cache.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List, Optional, Tuple 3 | 4 | import torch 5 | from xformers.ops.fmha.attn_bias import ( # type: ignore 6 | AttentionBias, 7 | BlockDiagonalCausalMask, 8 | BlockDiagonalCausalWithOffsetPaddedKeysMask, 9 | BlockDiagonalMask, 10 | ) 11 | 12 | 13 | def get_cache_sizes(n_layers: int, max_seq_len: int, sliding_window: Optional[int] | Optional[List[int]]) -> List[int]: 14 | if sliding_window is None: 15 | return n_layers * [max_seq_len] 16 | elif isinstance(sliding_window, int): 17 | return n_layers * [sliding_window] 18 | else: 19 | assert isinstance(sliding_window, list), f"Expected list, got {type(sliding_window)}" 20 | assert ( 21 | n_layers % len(sliding_window) == 0 22 | ), f"Expected n_layers % len(sliding_window) == 0, got {n_layers} % {len(sliding_window)}" 23 | num_repeats = n_layers // len(sliding_window) 24 | return num_repeats * [w if w is not None else max_seq_len for w in sliding_window] 25 | 26 | 27 | @dataclass 28 | class CacheInputMetadata: 29 | # # rope absolute positions 30 | # positions: torch.Tensor 31 | # # where tokens should go in the cache 32 | # cache_positions: torch.Tensor 33 | 34 | # # if prefill, use block diagonal causal mask 35 | # # else use causal with padded key mask 36 | # prefill: bool 37 | # mask: AttentionBias 38 | # seqlens: List[int] 39 | # rope absolute positions 40 | positions: torch.Tensor 41 | # which elements in the sequences need to be cached 42 | to_cache_mask: torch.Tensor 43 | # how many elements are cached per sequence 44 | cached_elements: torch.Tensor 45 | # where tokens should go in the cache 46 | cache_positions: torch.Tensor 47 | # if prefill, use block diagonal causal mask 48 | # else use causal with padded key mask 49 | prefill: bool 50 | mask: AttentionBias 51 | seqlens: List[int] 52 | 53 | 54 | def interleave_list(l1: List[torch.Tensor], l2: List[torch.Tensor]) -> List[torch.Tensor]: 55 | assert len(l1) == len(l2) 56 | return [v for pair in zip(l1, l2) for v in pair] 57 | 58 | 59 | def unrotate(cache: torch.Tensor, seqlen: int) -> torch.Tensor: 60 | assert cache.ndim == 3 # (W, H, D) 61 | position = seqlen % cache.shape[0] 62 | if seqlen < cache.shape[0]: 63 | return cache[:seqlen] 64 | elif position == 0: 65 | return cache 66 | else: 67 | return torch.cat([cache[position:], 
cache[:position]], dim=0) 68 | 69 | 70 | class CacheView: 71 | def __init__( 72 | self, 73 | cache_k: torch.Tensor, 74 | cache_v: torch.Tensor, 75 | metadata: CacheInputMetadata, 76 | kv_seqlens: torch.Tensor, 77 | ): 78 | self.cache_k = cache_k 79 | self.cache_v = cache_v 80 | self.kv_seqlens = kv_seqlens 81 | self.metadata = metadata 82 | 83 | def update(self, xk: torch.Tensor, xv: torch.Tensor) -> None: 84 | """ 85 | to_cache_mask masks the last [max_seq_len] tokens in each sequence 86 | """ 87 | n_kv_heads, head_dim = self.cache_k.shape[-2:] 88 | flat_cache_k = self.cache_k.view(-1, n_kv_heads, head_dim) 89 | flat_cache_v = self.cache_v.view(-1, n_kv_heads, head_dim) 90 | 91 | flat_cache_k.index_copy_(0, self.metadata.cache_positions, xk[self.metadata.to_cache_mask]) 92 | flat_cache_v.index_copy_(0, self.metadata.cache_positions, xv[self.metadata.to_cache_mask]) 93 | 94 | def interleave_kv(self, xk: torch.Tensor, xv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 95 | """ 96 | This is a naive implementation and not optimized for speed. 97 | """ 98 | assert xk.ndim == xv.ndim == 3 # (B * T, H, D) 99 | assert xk.shape == xv.shape 100 | 101 | if all([s == 0 for s in self.metadata.seqlens]): 102 | # No cache to interleave 103 | return xk, xv 104 | 105 | # Make it a list of [(T, H, D)] 106 | xk: Tuple[torch.Tensor] = torch.split(xk, self.metadata.seqlens) # type: ignore 107 | xv: Tuple[torch.Tensor] = torch.split(xv, self.metadata.seqlens) # type: ignore 108 | assert len(xk) == len(self.kv_seqlens), f"Batch size is {len(self.kv_seqlens)}, got {len(xk)}" 109 | 110 | # Order elements in cache by position by unrotating 111 | cache_k = [unrotate(t, s) for t, s in zip(self.cache_k, self.kv_seqlens)] 112 | cache_v = [unrotate(t, s) for t, s in zip(self.cache_v, self.kv_seqlens)] 113 | 114 | interleaved_k = interleave_list(cache_k, list(xk)) 115 | interleaved_v = interleave_list(cache_v, list(xv)) 116 | 117 | return torch.cat(interleaved_k, dim=0), torch.cat(interleaved_v, dim=0) 118 | 119 | @property 120 | def max_seq_len(self) -> int: 121 | return self.cache_k.shape[1] 122 | 123 | @property 124 | def key(self) -> torch.Tensor: 125 | return self.cache_k[: len(self.kv_seqlens)] 126 | 127 | @property 128 | def value(self) -> torch.Tensor: 129 | return self.cache_v[: len(self.kv_seqlens)] 130 | 131 | @property 132 | def prefill(self) -> bool: 133 | return self.metadata.prefill 134 | 135 | @property 136 | def mask(self) -> AttentionBias: 137 | return self.metadata.mask 138 | 139 | 140 | class BufferCache: 141 | """ 142 | This is an example that implements a buffer cache, allowing for variable length sequences. 
143 | Allocated cache is rectangular which is wasteful (see PagedAttention for better mechanisms) 144 | """ 145 | 146 | def __init__( 147 | self, 148 | n_layers: int, 149 | max_batch_size: int, 150 | max_seq_len: int, 151 | n_kv_heads: int, 152 | head_dim: int, 153 | sliding_window: Optional[int] | Optional[List[int]] = None, 154 | ): 155 | self.max_seq_len = max_seq_len 156 | self.n_kv_heads = n_kv_heads 157 | self.head_dim = head_dim 158 | self.n_layers = n_layers 159 | 160 | self.cache_sizes: List[int] = get_cache_sizes(n_layers, max_seq_len, sliding_window) 161 | assert len(self.cache_sizes) == n_layers, f"Expected {n_layers} cache sizes, got {len(self.cache_sizes)}" 162 | 163 | self.cache_k = {} 164 | self.cache_v = {} 165 | for i, cache_size in enumerate(self.cache_sizes): 166 | self.cache_k[i] = torch.empty((max_batch_size, cache_size, n_kv_heads, head_dim)) 167 | self.cache_v[i] = torch.empty((max_batch_size, cache_size, n_kv_heads, head_dim)) 168 | 169 | # holds the valid length for each batch element in the cache 170 | self.kv_seqlens: Optional[torch.Tensor] = None 171 | 172 | def get_view(self, layer_id: int, metadata: CacheInputMetadata) -> CacheView: 173 | assert self.kv_seqlens is not None 174 | return CacheView(self.cache_k[layer_id], self.cache_v[layer_id], metadata, self.kv_seqlens) 175 | 176 | def reset(self) -> None: 177 | self.kv_seqlens = None 178 | 179 | def init_kvseqlens(self, batch_size: int) -> None: 180 | self.kv_seqlens = torch.zeros((batch_size,), device=self.device, dtype=torch.long) 181 | 182 | @property 183 | def device(self) -> torch.device: 184 | return self.cache_k[0].device 185 | 186 | def to(self, device: torch.device, dtype: torch.dtype) -> "BufferCache": 187 | for i in range(self.n_layers): 188 | self.cache_k[i] = self.cache_k[i].to(device=device, dtype=dtype) 189 | self.cache_v[i] = self.cache_v[i].to(device=device, dtype=dtype) 190 | 191 | return self 192 | 193 | def update_seqlens(self, seqlens: List[int]) -> None: 194 | assert self.kv_seqlens is not None 195 | self.kv_seqlens += torch.tensor(seqlens, device=self.device, dtype=torch.long) 196 | 197 | def get_input_metadata(self, seqlens: List[int]) -> List[CacheInputMetadata]: 198 | """ 199 | input = seqlens [5,7,2] // seqpos [0, 1, 3] // sliding_window 3 200 | --> only cache last 3 tokens in each sequence 201 | - to_cache_mask = [0 0 1 1 1 | 0 0 0 0 1 1 1 | 1 1] 202 | - cached_elements = [3 | 3 | 2] 203 | --> absolute positions are used for rope 204 | - positions = [0 1 2 3 4 | 1 2 3 4 5 6 7 | 3 4] 205 | --> cache positions are positions cache_masked, modulo sliding_window + batch_idx * sliding_window 206 | - cache_positions = [2 0 1 | 5 3 4 | 6 7] 207 | """ 208 | metadata: List[CacheInputMetadata] = [] 209 | 210 | if self.kv_seqlens is None: 211 | self.init_kvseqlens(len(seqlens)) 212 | 213 | assert self.kv_seqlens is not None 214 | assert len(seqlens) == len( 215 | self.kv_seqlens 216 | ), f"Batch size is {len(self.kv_seqlens)}, got {len(seqlens)}, did you forget to reset cache?" 
217 | seqpos = self.kv_seqlens.tolist() 218 | assert len(seqlens) > 0, seqlens 219 | 220 | for cache_size in self.cache_sizes: 221 | metadata.append(self._get_input_metadata_layer(cache_size, seqlens, seqpos)) 222 | 223 | return metadata 224 | 225 | def _get_input_metadata_layer(self, cache_size: int, seqlens: List[int], seqpos: List[int]) -> CacheInputMetadata: 226 | masks = [[x >= seqlen - cache_size for x in range(seqlen)] for seqlen in seqlens] 227 | to_cache_mask = torch.tensor(sum(masks, []), device=self.device, dtype=torch.bool) 228 | cached_elements = torch.tensor([sum(mask) for mask in masks], device=self.device, dtype=torch.long) 229 | positions = torch.cat([torch.arange(pos, pos + seqlen) for pos, seqlen in zip(seqpos, seqlens)]).to( 230 | device=self.device, dtype=torch.long 231 | ) 232 | batch_idx = torch.tensor( 233 | sum([[i] * seqlen for i, seqlen in enumerate(seqlens)], []), device=self.device, dtype=torch.long 234 | ) 235 | cache_positions = positions % cache_size + batch_idx * cache_size 236 | first_prefill = seqpos[0] == 0 237 | subsequent_prefill = any(seqlen > 1 for seqlen in seqlens) 238 | if first_prefill: 239 | assert all([pos == 0 for pos in seqpos]), seqpos 240 | mask = BlockDiagonalCausalMask.from_seqlens(seqlens).make_local_attention(cache_size) 241 | elif subsequent_prefill: 242 | assert self.kv_seqlens is not None 243 | mask = BlockDiagonalMask.from_seqlens( 244 | q_seqlen=seqlens, 245 | kv_seqlen=[ 246 | s + cached_s.clamp(max=cache_size).item() for (s, cached_s) in zip(seqlens, self.kv_seqlens) 247 | ], 248 | ).make_local_attention_from_bottomright(cache_size) 249 | else: 250 | mask = BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens( 251 | q_seqlen=seqlens, 252 | kv_padding=cache_size, 253 | kv_seqlen=(self.kv_seqlens + cached_elements).clamp(max=cache_size).tolist(), 254 | ) 255 | return CacheInputMetadata( 256 | positions=positions, 257 | to_cache_mask=to_cache_mask, 258 | cached_elements=cached_elements, 259 | cache_positions=cache_positions[to_cache_mask], 260 | prefill=first_prefill or subsequent_prefill, 261 | mask=mask, 262 | seqlens=seqlens, 263 | ) 264 | -------------------------------------------------------------------------------- /src/mistral_inference/generate.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from mistral_inference.cache import BufferCache 7 | from mistral_inference.mamba import Mamba 8 | from mistral_inference.transformer import Transformer 9 | 10 | 11 | @torch.inference_mode() 12 | def generate_mamba( 13 | encoded_prompts: List[List[int]], 14 | model: Mamba, 15 | *, 16 | max_tokens: int, 17 | temperature: float, 18 | chunk_size: Optional[int] = None, 19 | eos_id: Optional[int] = None, 20 | ) -> Tuple[List[List[int]], List[List[float]]]: 21 | input_ids = torch.tensor(encoded_prompts, device=model.device) 22 | output = model.model.generate( 23 | input_ids=input_ids, 24 | max_length=input_ids.shape[-1] + max_tokens, 25 | cg=True, 26 | return_dict_in_generate=True, 27 | output_scores=True, 28 | enable_timing=False, 29 | eos_token_id=eos_id, 30 | temperature=temperature, 31 | top_p=0.8, 32 | ) 33 | generated_tokens = output.sequences[:, input_ids.shape[-1] :].tolist() 34 | 35 | _logprobs: List[List[float]] = [[] for _ in range(len(generated_tokens))] 36 | for seq_idx, batch_score in enumerate(output.scores): 37 | for batch_idx, score in enumerate(batch_score.tolist()): 38 | 
_logprobs[batch_idx].append(score[generated_tokens[batch_idx][seq_idx]]) 39 | 40 | return generated_tokens, _logprobs 41 | 42 | 43 | @torch.inference_mode() 44 | def generate( 45 | encoded_prompts: List[List[int]], 46 | model: Transformer, 47 | images: List[List[np.ndarray]] = [], 48 | *, 49 | max_tokens: int, 50 | temperature: float, 51 | chunk_size: Optional[int] = None, 52 | eos_id: Optional[int] = None, 53 | ) -> Tuple[List[List[int]], List[List[float]]]: 54 | images_torch: List[List[torch.Tensor]] = [] 55 | if images: 56 | assert chunk_size is None 57 | images_torch = [ 58 | [torch.tensor(im, device=model.device, dtype=model.dtype) for im in images_for_sample] 59 | for images_for_sample in images 60 | ] 61 | 62 | model = model.eval() 63 | B, V = len(encoded_prompts), model.args.vocab_size 64 | 65 | seqlens = [len(x) for x in encoded_prompts] 66 | 67 | # Cache 68 | cache_window = max(seqlens) + max_tokens 69 | cache = BufferCache( 70 | model.n_local_layers, 71 | model.args.max_batch_size, 72 | cache_window, 73 | model.args.n_kv_heads, 74 | model.args.head_dim, 75 | model.args.sliding_window, 76 | ) 77 | cache.to(device=model.device, dtype=model.dtype) 78 | cache.reset() 79 | 80 | # Bookkeeping 81 | logprobs: List[List[float]] = [[] for _ in range(B)] 82 | last_token_prelogits = None 83 | 84 | # One chunk if size not specified 85 | max_prompt_len = max(seqlens) 86 | if chunk_size is None: 87 | chunk_size = max_prompt_len 88 | 89 | flattened_images: List[torch.Tensor] = sum(images_torch, []) 90 | 91 | # Encode prompt by chunks 92 | for s in range(0, max_prompt_len, chunk_size): 93 | prompt_chunks = [p[s : s + chunk_size] for p in encoded_prompts] 94 | assert all(len(p) > 0 for p in prompt_chunks) 95 | prelogits = model.forward( 96 | torch.tensor(sum(prompt_chunks, []), device=model.device, dtype=torch.long), 97 | images=flattened_images, 98 | seqlens=[len(p) for p in prompt_chunks], 99 | cache=cache, 100 | ) 101 | logits = torch.log_softmax(prelogits, dim=-1) 102 | 103 | if last_token_prelogits is not None: 104 | # Pass > 1 105 | last_token_logits = torch.log_softmax(last_token_prelogits, dim=-1) 106 | for i_seq in range(B): 107 | logprobs[i_seq].append(last_token_logits[i_seq, prompt_chunks[i_seq][0]].item()) 108 | 109 | offset = 0 110 | for i_seq, sequence in enumerate(prompt_chunks): 111 | logprobs[i_seq].extend([logits[offset + i, sequence[i + 1]].item() for i in range(len(sequence) - 1)]) 112 | offset += len(sequence) 113 | 114 | last_token_prelogits = prelogits.index_select( 115 | 0, 116 | torch.tensor([len(p) for p in prompt_chunks], device=prelogits.device).cumsum(dim=0) - 1, 117 | ) 118 | assert last_token_prelogits.shape == (B, V) 119 | 120 | # decode 121 | generated_tensors = [] 122 | is_finished = torch.tensor([False for _ in range(B)]) 123 | 124 | assert last_token_prelogits is not None 125 | for _ in range(max_tokens): 126 | next_token = sample(last_token_prelogits, temperature=temperature, top_p=0.8) 127 | 128 | if eos_id is not None: 129 | is_finished = is_finished | (next_token == eos_id).cpu() 130 | 131 | if is_finished.all(): 132 | break 133 | 134 | last_token_logits = torch.log_softmax(last_token_prelogits, dim=-1) 135 | for i in range(B): 136 | logprobs[i].append(last_token_logits[i, next_token[i]].item()) 137 | 138 | generated_tensors.append(next_token[:, None]) 139 | last_token_prelogits = model.forward(next_token, seqlens=[1] * B, cache=cache) 140 | assert last_token_prelogits.shape == (B, V) 141 | 142 | generated_tokens: List[List[int]] 143 | if 
generated_tensors: 144 | generated_tokens = torch.cat(generated_tensors, 1).tolist() 145 | else: 146 | generated_tokens = [] 147 | 148 | return generated_tokens, logprobs 149 | 150 | 151 | def sample(logits: torch.Tensor, temperature: float, top_p: float) -> torch.Tensor: 152 | if temperature > 0: 153 | probs = torch.softmax(logits / temperature, dim=-1) 154 | next_token = sample_top_p(probs, top_p) 155 | else: 156 | next_token = torch.argmax(logits, dim=-1).unsqueeze(0) 157 | 158 | return next_token.reshape(-1) 159 | 160 | 161 | def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor: 162 | assert 0 <= p <= 1 163 | 164 | probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) 165 | probs_sum = torch.cumsum(probs_sort, dim=-1) 166 | mask = probs_sum - probs_sort > p 167 | probs_sort[mask] = 0.0 168 | probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) 169 | next_token = torch.multinomial(probs_sort, num_samples=1) 170 | return torch.gather(probs_idx, -1, next_token) 171 | -------------------------------------------------------------------------------- /src/mistral_inference/lora.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from dataclasses import dataclass 3 | from pathlib import Path 4 | from typing import Any, Dict, NamedTuple, Union 5 | 6 | import safetensors.torch 7 | import torch 8 | import torch.nn as nn 9 | from simple_parsing.helpers import Serializable 10 | 11 | 12 | @dataclass 13 | class LoraArgs(Serializable): 14 | rank: int 15 | scaling: float 16 | 17 | def __post_init__(self) -> None: 18 | assert self.rank > 0 19 | assert self.scaling > 0.0 20 | 21 | 22 | class LoRALinear(nn.Module): 23 | """ 24 | Implementation of: 25 | - LoRA: https://arxiv.org/abs/2106.09685 26 | 27 | Notes: 28 | - Freezing is handled at network level, not layer level. 29 | - Scaling factor controls relative importance of LoRA skip 30 | connection versus original frozen weight. General guidance is 31 | to keep it to 2.0 and sweep over learning rate when changing 32 | the rank. 
33 | """ 34 | 35 | def __init__( 36 | self, 37 | in_features: int, 38 | out_features: int, 39 | rank: int, 40 | scaling: float, 41 | bias: bool = False, 42 | ): 43 | super().__init__() 44 | 45 | self.in_features = in_features 46 | self.out_features = out_features 47 | assert not bias 48 | self.bias = bias 49 | self.rank = rank 50 | self.scaling = scaling 51 | 52 | self.lora_A = nn.Linear( 53 | self.in_features, 54 | self.rank, 55 | bias=self.bias, 56 | ) 57 | self.lora_B = nn.Linear( 58 | self.rank, 59 | self.out_features, 60 | bias=self.bias, 61 | ) 62 | 63 | self.linear = nn.Linear(self.in_features, self.out_features, bias=self.bias) 64 | 65 | # make sure no LoRA weights are marked as "missing" in load_state_dict 66 | def ignore_missing_keys(m: nn.Module, incompatible_keys: NamedTuple) -> None: 67 | incompatible_keys.missing_keys[:] = [] # type: ignore 68 | 69 | self.register_load_state_dict_post_hook(ignore_missing_keys) 70 | 71 | def forward(self, x: torch.Tensor) -> torch.Tensor: 72 | lora = self.lora_B(self.lora_A(x)) 73 | result: torch.Tensor = self.linear(x) + lora * self.scaling 74 | return result 75 | 76 | def _load_from_state_dict(self, state_dict: Dict[str, Any], prefix: str, *args, **kwargs) -> None: # type: ignore[no-untyped-def] 77 | key_name = prefix + "weight" 78 | 79 | # full checkpoint 80 | if key_name in state_dict: 81 | w_ref = state_dict[key_name] 82 | 83 | # load frozen weights 84 | state_dict = { 85 | "linear.weight": w_ref, 86 | "lora_A.weight": torch.zeros_like(self.lora_A.weight, device=w_ref.device, dtype=w_ref.dtype), 87 | "lora_B.weight": torch.zeros_like(self.lora_B.weight, device=w_ref.device, dtype=w_ref.dtype), 88 | } 89 | self.load_state_dict(state_dict, assign=True, strict=True) 90 | 91 | 92 | class LoRALoaderMixin: 93 | def load_lora(self, lora_path: Union[Path, str], scaling: float = 2.0) -> None: 94 | """Loads LoRA checkpoint""" 95 | 96 | lora_path = Path(lora_path) 97 | assert lora_path.is_file(), f"{lora_path} does not exist or is not a file" 98 | 99 | state_dict = safetensors.torch.load_file(lora_path) 100 | 101 | self._load_lora_state_dict(state_dict, scaling=scaling) 102 | 103 | def _load_lora_state_dict(self, lora_state_dict: Dict[str, torch.Tensor], scaling: float = 2.0) -> None: 104 | """Loads LoRA state_dict""" 105 | lora_dtypes = set([p.dtype for p in lora_state_dict.values()]) 106 | assert ( 107 | len(lora_dtypes) == 1 108 | ), f"LoRA weights have multiple different dtypes {lora_dtypes}. 
All weights need to have the same dtype" 109 | lora_dtype = lora_dtypes.pop() 110 | assert lora_dtype == self.dtype, f"LoRA weights dtype differs from model's dtype {lora_dtype} != {self.dtype}" # type: ignore[attr-defined] 111 | assert all("lora" in key for key in lora_state_dict.keys()) 112 | 113 | # move tensors to device 114 | lora_state_dict = {k: v.to(self.device) for k, v in lora_state_dict.items()} # type: ignore[attr-defined] 115 | 116 | state_dict = self.state_dict() # type: ignore[attr-defined] 117 | 118 | if self.args.lora is None: # type: ignore[attr-defined] 119 | logging.info("Loading and merging LoRA weights...") 120 | 121 | # replace every nn.Linear with a LoRALinear with 'meta' device except the output layer 122 | named_modules = dict(self.named_modules()) # type: ignore[attr-defined] 123 | for name, module in named_modules.items(): 124 | if isinstance(module, nn.Linear) and name != "output": 125 | layer_id = name.split(".")[1] 126 | if layer_id not in self.layers: # type: ignore[attr-defined] 127 | logging.debug( 128 | "Skipping parameter %s at pipeline rank %d", 129 | name, 130 | self.pipeline_rank, # type: ignore[attr-defined] 131 | ) 132 | elif (name + ".lora_B.weight") in lora_state_dict: 133 | weight = ( 134 | module.weight 135 | + (lora_state_dict[name + ".lora_B.weight"] @ lora_state_dict[name + ".lora_A.weight"]) 136 | * scaling 137 | ) 138 | 139 | state_dict[name + ".weight"] = weight 140 | else: 141 | logging.info("Loading LoRA weights...") 142 | for k, v in lora_state_dict.items(): 143 | state_dict.update(lora_state_dict) 144 | 145 | layer_id = k.split(".")[1] 146 | if layer_id in self.layers: # type: ignore[attr-defined] 147 | state_dict[k] = v 148 | else: 149 | logging.debug( 150 | "Skipping parameter %s at pipeline rank %d", 151 | k, 152 | self.pipeline_rank, # type: ignore[attr-defined] 153 | ) 154 | 155 | self.load_state_dict(state_dict, strict=True) # type: ignore[attr-defined] 156 | -------------------------------------------------------------------------------- /src/mistral_inference/main.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import warnings 5 | from pathlib import Path 6 | from typing import List, Optional, Tuple, Type, Union 7 | 8 | import fire # type: ignore 9 | import torch 10 | import torch.distributed as dist 11 | from mistral_common.protocol.instruct.messages import ( 12 | AssistantMessage, 13 | ContentChunk, 14 | ImageChunk, 15 | ImageURLChunk, 16 | TextChunk, 17 | UserMessage, 18 | ) 19 | from mistral_common.protocol.instruct.request import ChatCompletionRequest 20 | from mistral_common.tokens.tokenizers.base import Tokenizer 21 | from mistral_common.tokens.tokenizers.mistral import MistralTokenizer 22 | from mistral_common.tokens.tokenizers.sentencepiece import is_sentencepiece 23 | from mistral_common.tokens.tokenizers.tekken import ( 24 | SpecialTokenPolicy, 25 | Tekkenizer, 26 | is_tekken, 27 | ) 28 | from PIL import Image 29 | 30 | from mistral_inference.args import TransformerArgs 31 | from mistral_inference.generate import generate, generate_mamba 32 | from mistral_inference.mamba import Mamba 33 | from mistral_inference.transformer import Transformer 34 | 35 | 36 | def is_torchrun() -> bool: 37 | required_vars = ["MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE"] 38 | return all(var in os.environ for var in required_vars) 39 | 40 | 41 | def load_tokenizer(model_path: Path) -> MistralTokenizer: 42 | tokenizer = [f for f in 
os.listdir(model_path) if is_tekken(model_path / f) or is_sentencepiece(model_path / f)] 43 | assert ( 44 | len(tokenizer) > 0 45 | ), f"No tokenizer in {model_path}, place a `tokenizer.model.[v1,v2,v3]` or `tekken.json` file in {model_path}." 46 | assert ( 47 | len(tokenizer) == 1 48 | ), f"Multiple tokenizers {', '.join(tokenizer)} found in `model_path`, make sure to only have one tokenizer" 49 | 50 | mistral_tokenizer = MistralTokenizer.from_file(str(model_path / tokenizer[0])) 51 | 52 | if isinstance(mistral_tokenizer.instruct_tokenizer.tokenizer, Tekkenizer): 53 | mistral_tokenizer.instruct_tokenizer.tokenizer.special_token_policy = SpecialTokenPolicy.KEEP 54 | 55 | logging.info(f"Loaded tokenizer of type {mistral_tokenizer.instruct_tokenizer.__class__}") 56 | 57 | return mistral_tokenizer 58 | 59 | 60 | def get_model_cls(model_path: str) -> Union[Type[Mamba], Type[Transformer]]: 61 | with open(Path(model_path) / "params.json", "r") as f: 62 | args_dict = json.load(f) 63 | 64 | return {"mamba": Mamba, "transformer": Transformer}[args_dict.get("model_type", "transformer")] # type: ignore[return-value] 65 | 66 | 67 | def pad_and_convert_to_tensor(list_of_lists: List[List[int]], pad_id: int) -> List[List[int]]: 68 | # Determine the length of the longest list 69 | max_len = max(len(lst) for lst in list_of_lists) 70 | 71 | # Left pad each list to the maximum length 72 | padded_lists = [[pad_id] * (max_len - len(lst)) + lst for lst in list_of_lists] 73 | 74 | return padded_lists 75 | 76 | 77 | def _get_multimodal_input() -> Tuple[UserMessage, bool]: 78 | chunks: List[ContentChunk] = [] 79 | 80 | response = input("Text prompt: ") 81 | if response: 82 | chunks.append(TextChunk(text=response)) 83 | 84 | print("[You can input zero, one or more images now.]") 85 | while True: 86 | did_something = False 87 | response = input("Image path or url [Leave empty and press enter to finish image input]: ") 88 | if response: 89 | if Path(response).is_file(): 90 | chunks.append(ImageChunk(image=Image.open(response))) 91 | else: 92 | assert response.startswith("http"), f"{response} does not seem to be a valid url." 
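                # Anything that is not an existing local file is treated as a remote image URL.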
93 | chunks.append(ImageURLChunk(image_url=response)) 94 | did_something = True 95 | 96 | if not did_something: 97 | break 98 | 99 | return UserMessage(content=chunks), not chunks 100 | 101 | 102 | def interactive( 103 | model_path: str, 104 | max_tokens: int = 35, 105 | temperature: float = 0.7, 106 | num_pipeline_ranks: int = 1, 107 | instruct: bool = False, 108 | lora_path: Optional[str] = None, 109 | ) -> None: 110 | if is_torchrun(): 111 | torch.distributed.init_process_group() 112 | torch.cuda.set_device(torch.distributed.get_rank()) 113 | should_print = torch.distributed.get_rank() == 0 114 | 115 | num_pipeline_ranks = torch.distributed.get_world_size() 116 | else: 117 | should_print = True 118 | num_pipeline_ranks = 1 119 | 120 | mistral_tokenizer: MistralTokenizer = load_tokenizer(Path(model_path)) 121 | tokenizer: Tokenizer = mistral_tokenizer.instruct_tokenizer.tokenizer 122 | 123 | model_cls = get_model_cls(model_path) 124 | model = model_cls.from_folder(Path(model_path), max_batch_size=3, num_pipeline_ranks=num_pipeline_ranks) 125 | is_multimodal = isinstance(model.args, TransformerArgs) and model.args.vision_encoder is not None 126 | 127 | if is_multimodal: 128 | assert instruct, "Multimodal models should only be used in instruct mode" 129 | 130 | # load LoRA 131 | if lora_path is not None: 132 | model.load_lora(Path(lora_path)) 133 | 134 | prompt: str = "" 135 | messages: List[UserMessage | AssistantMessage] = [] 136 | 137 | while True: 138 | if should_print: 139 | if not is_multimodal: 140 | user_input = input("Prompt: ") 141 | 142 | if instruct: 143 | if is_multimodal: 144 | mm_input, finished = _get_multimodal_input() 145 | if finished: 146 | break 147 | messages += [mm_input] 148 | else: 149 | messages += [UserMessage(content=user_input)] 150 | chat_completion_request = ChatCompletionRequest(messages=messages) 151 | 152 | tokenized = mistral_tokenizer.encode_chat_completion(chat_completion_request) 153 | tokens = tokenized.tokens 154 | images = tokenized.images 155 | else: 156 | prompt += user_input 157 | 158 | tokens = tokenizer.encode(prompt, bos=True, eos=False) 159 | images = [] 160 | 161 | length_tensor = torch.tensor([len(tokens)], dtype=torch.int) 162 | else: 163 | length_tensor = torch.tensor([0], dtype=torch.int) 164 | images = [] 165 | 166 | if is_torchrun(): 167 | dist.broadcast(length_tensor, src=0) 168 | 169 | if not should_print: 170 | tokens = int(length_tensor.item()) * [0] 171 | 172 | generate_fn = generate if isinstance(model, Transformer) else generate_mamba 173 | generated_tokens, _ = generate_fn( # type: ignore[operator] 174 | [tokens], 175 | model, 176 | [images], 177 | max_tokens=max_tokens, 178 | temperature=temperature, 179 | eos_id=tokenizer.eos_id, 180 | ) 181 | 182 | answer = tokenizer.decode(generated_tokens[0]) 183 | 184 | if should_print: 185 | print(answer) 186 | print("=====================") 187 | 188 | if instruct: 189 | messages += [AssistantMessage(content=answer)] 190 | else: 191 | prompt += answer 192 | 193 | 194 | def demo( 195 | model_path: str, 196 | max_tokens: int = 35, 197 | temperature: float = 0, 198 | lora_path: Optional[str] = None, 199 | ) -> None: 200 | if is_torchrun(): 201 | torch.distributed.init_process_group() 202 | torch.cuda.set_device(torch.distributed.get_rank()) 203 | should_print = torch.distributed.get_rank() == 0 204 | 205 | num_pipeline_ranks = torch.distributed.get_world_size() 206 | else: 207 | should_print = True 208 | num_pipeline_ranks = 1 209 | 210 | model_cls = get_model_cls(model_path) 211 | model = 
model_cls.from_folder(Path(model_path), max_batch_size=3, num_pipeline_ranks=num_pipeline_ranks) 212 | # load LoRA 213 | if lora_path is not None: 214 | model.load_lora(Path(lora_path)) 215 | 216 | mistral_tokenizer: MistralTokenizer = load_tokenizer(Path(model_path)) 217 | tokenizer: Tokenizer = mistral_tokenizer.instruct_tokenizer.tokenizer 218 | 219 | prompts = [ 220 | "This is a test", 221 | "This is another great test", 222 | "This is a third test, mistral AI is very good at testing. ", 223 | ] 224 | 225 | encoded_prompts = [tokenizer.encode(prompt, bos=True, eos=False) for prompt in prompts] 226 | 227 | if isinstance(model, Transformer): 228 | generate_fn = generate 229 | else: 230 | generate_fn = generate_mamba # type: ignore[assignment] 231 | warnings.warn( 232 | "Batched generation is not correctly supported at the moment and therefore might lead to worse results " 233 | "as compared to non-batched generation. " 234 | "See https://github.com/state-spaces/mamba/issues/66#issuecomment-1862349718 for more information." 235 | ) 236 | encoded_prompts = pad_and_convert_to_tensor(encoded_prompts, mistral_tokenizer.instruct_tokenizer.BOS) # type: ignore[attr-defined] 237 | 238 | generated_tokens, _logprobs = generate_fn( 239 | encoded_prompts, 240 | model, # type: ignore[arg-type] 241 | max_tokens=max_tokens, 242 | temperature=temperature, 243 | eos_id=tokenizer.eos_id, 244 | ) 245 | 246 | generated_words = [] 247 | for i, x in enumerate(generated_tokens): 248 | generated_words.append(tokenizer.decode(encoded_prompts[i] + x)) 249 | 250 | res = generated_words 251 | 252 | if should_print: 253 | for w, logprob in zip(res, _logprobs): 254 | print(w) 255 | logging.debug("Logprobs: %s", logprob) 256 | print("=====================") 257 | 258 | 259 | def mistral_chat() -> None: 260 | fire.Fire(interactive) 261 | 262 | 263 | def mistral_demo() -> None: 264 | fire.Fire(demo) 265 | 266 | 267 | if __name__ == "__main__": 268 | logging.basicConfig(level=logging.INFO) 269 | fire.Fire( 270 | { 271 | "interactive": interactive, 272 | "demo": demo, 273 | } 274 | ) 275 | -------------------------------------------------------------------------------- /src/mistral_inference/mamba.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | from typing import List, Optional, Union 4 | 5 | import safetensors 6 | import torch 7 | import torch.nn as nn 8 | 9 | from mistral_inference.args import MambaArgs 10 | from mistral_inference.cache import BufferCache 11 | from mistral_inference.model import ModelBase 12 | 13 | _is_mamba_installed = False 14 | try: 15 | from mamba_ssm.models.config_mamba import MambaConfig 16 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel 17 | 18 | _is_mamba_installed = True 19 | except ImportError: 20 | _is_mamba_installed = False 21 | 22 | 23 | class Mamba(ModelBase, nn.Module): 24 | def __init__(self, args: MambaArgs): 25 | super().__init__() 26 | self.args = args 27 | assert _is_mamba_installed, "Mamba is not installed. Please install it using `pip install mamba-ssm`." 
28 | 29 | # make sure naming is consistent with `mamba_ssm` 30 | config = MambaConfig( 31 | d_model=args.dim, 32 | n_layer=args.n_layers, 33 | vocab_size=args.vocab_size, 34 | ssm_cfg={"ngroups": args.n_groups, "layer": "Mamba2"}, 35 | attn_layer_idx=[], 36 | attn_cfg={}, 37 | rms_norm=args.rms_norm, 38 | residual_in_fp32=args.residual_in_fp32, 39 | fused_add_norm=args.fused_add_norm, 40 | pad_vocab_size_multiple=args.pad_vocab_size_multiple, 41 | tie_embeddings=args.tie_embeddings, 42 | ) 43 | self.model = MambaLMHeadModel(config) 44 | 45 | @property 46 | def dtype(self) -> torch.dtype: 47 | return next(self.parameters()).dtype 48 | 49 | @property 50 | def device(self) -> torch.device: 51 | return next(self.parameters()).device 52 | 53 | def forward( 54 | self, 55 | input_ids: torch.Tensor, 56 | seqlens: List[int], # not supported for now 57 | cache: Optional[BufferCache] = None, # not supported for now 58 | ) -> torch.Tensor: 59 | lm_output = self.model(input_ids) 60 | result: torch.Tensor = lm_output.logits 61 | return result 62 | 63 | @staticmethod 64 | def from_folder( 65 | folder: Union[Path, str], 66 | max_batch_size: int = 1, 67 | num_pipeline_ranks: int = 1, 68 | device: Union[torch.device, str] = "cuda", 69 | dtype: Optional[torch.dtype] = None, 70 | ) -> "Mamba": 71 | with open(Path(folder) / "params.json", "r") as f: 72 | model_args = MambaArgs.from_dict(json.load(f)) 73 | 74 | with torch.device("meta"): 75 | model = Mamba(model_args) 76 | 77 | model_file = Path(folder) / "consolidated.safetensors" 78 | 79 | assert model_file.exists(), f"Make sure {model_file} exists." 80 | loaded = safetensors.torch.load_file(str(model_file)) 81 | 82 | model.load_state_dict(loaded, assign=True, strict=True) 83 | return model.to(device=device, dtype=dtype) 84 | -------------------------------------------------------------------------------- /src/mistral_inference/model.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from pathlib import Path 3 | from typing import List, Optional, Union 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from mistral_inference.cache import BufferCache 9 | 10 | 11 | class ModelBase(nn.Module, ABC): 12 | def __init__(self) -> None: 13 | super().__init__() 14 | 15 | @property 16 | @abstractmethod 17 | def dtype(self) -> torch.dtype: 18 | pass 19 | 20 | @property 21 | @abstractmethod 22 | def device(self) -> torch.device: 23 | pass 24 | 25 | @abstractmethod 26 | def forward( 27 | self, 28 | input_ids: torch.Tensor, 29 | seqlens: List[int], # not supported for now 30 | cache: Optional[BufferCache] = None, # not supported for now 31 | ) -> torch.Tensor: 32 | pass 33 | 34 | @staticmethod 35 | @abstractmethod 36 | def from_folder( 37 | folder: Union[Path, str], 38 | max_batch_size: int = 1, 39 | num_pipeline_ranks: int = 1, 40 | device: Union[torch.device, str] = "cuda", 41 | dtype: Optional[torch.dtype] = None, 42 | ) -> "ModelBase": 43 | pass 44 | -------------------------------------------------------------------------------- /src/mistral_inference/moe.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from typing import List 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from simple_parsing.helpers import Serializable 7 | from torch import nn 8 | 9 | 10 | @dataclasses.dataclass 11 | class MoeArgs(Serializable): 12 | num_experts: int 13 | num_experts_per_tok: int 14 | 15 | 16 | class MoeLayer(nn.Module): 17 | 
def __init__(self, experts: List[nn.Module], gate: nn.Module, moe_args: MoeArgs): 18 | super().__init__() 19 | assert len(experts) > 0 20 | self.experts = nn.ModuleList(experts) 21 | self.gate = gate 22 | self.args = moe_args 23 | 24 | def forward(self, inputs: torch.Tensor) -> torch.Tensor: 25 | gate_logits = self.gate(inputs) 26 | weights, selected_experts = torch.topk(gate_logits, self.args.num_experts_per_tok) 27 | weights = F.softmax(weights, dim=1, dtype=torch.float).to(inputs.dtype) 28 | results = torch.zeros_like(inputs) 29 | for i, expert in enumerate(self.experts): 30 | batch_idx, nth_expert = torch.where(selected_experts == i) 31 | results[batch_idx] += weights[batch_idx, nth_expert, None] * expert(inputs[batch_idx]) 32 | return results 33 | -------------------------------------------------------------------------------- /src/mistral_inference/rope.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | 5 | 6 | def precompute_freqs_cis(dim: int, end: int, theta: float) -> torch.Tensor: 7 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) 8 | t = torch.arange(end, device=freqs.device) 9 | freqs = torch.outer(t, freqs).float() 10 | return torch.polar(torch.ones_like(freqs), freqs) # complex64 11 | 12 | 13 | def apply_rotary_emb( 14 | xq: torch.Tensor, 15 | xk: torch.Tensor, 16 | freqs_cis: torch.Tensor, 17 | ) -> Tuple[torch.Tensor, torch.Tensor]: 18 | xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) 19 | xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) 20 | freqs_cis = freqs_cis[:, None, :] 21 | xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2) 22 | xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2) 23 | return xq_out.type_as(xq), xk_out.type_as(xk) 24 | 25 | 26 | def precompute_freqs_cis_2d( 27 | dim: int, 28 | height: int, 29 | width: int, 30 | theta: float, 31 | ) -> torch.Tensor: 32 | """ 33 | freqs_cis: 2D complex tensor of shape (height, width, dim // 2) to be indexed by 34 | (height, width) position tuples 35 | """ 36 | # (dim / 2) frequency bases 37 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)) 38 | 39 | h = torch.arange(height, device=freqs.device) 40 | w = torch.arange(width, device=freqs.device) 41 | 42 | freqs_h = torch.outer(h, freqs[::2]).float() 43 | freqs_w = torch.outer(w, freqs[1::2]).float() 44 | freqs_2d = torch.cat( 45 | [ 46 | freqs_h[:, None, :].repeat(1, width, 1), 47 | freqs_w[None, :, :].repeat(height, 1, 1), 48 | ], 49 | dim=-1, 50 | ) 51 | return torch.polar(torch.ones_like(freqs_2d), freqs_2d) 52 | -------------------------------------------------------------------------------- /src/mistral_inference/transformer.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import math 4 | from dataclasses import dataclass 5 | from pathlib import Path 6 | from typing import Any, List, Mapping, Optional, Union 7 | 8 | import safetensors.torch 9 | import torch 10 | from torch import nn 11 | 12 | from mistral_inference.args import PATCH_MERGE, TransformerArgs 13 | from mistral_inference.cache import BufferCache, CacheInputMetadata 14 | from mistral_inference.lora import LoRALoaderMixin 15 | from mistral_inference.model import ModelBase 16 | from mistral_inference.rope import precompute_freqs_cis 17 | from mistral_inference.transformer_layers import RMSNorm, TransformerBlock 18 | from 
mistral_inference.vision_encoder import PatchMerger, VisionLanguageAdapter, VisionTransformer 19 | 20 | 21 | @dataclass 22 | class SimpleInputMetadata: 23 | # rope absolute positions 24 | positions: torch.Tensor 25 | 26 | @staticmethod 27 | def from_seqlens(seqlens: List[int], device: torch.device) -> "SimpleInputMetadata": 28 | return SimpleInputMetadata( 29 | positions=torch.cat([torch.arange(0, seqlen) for seqlen in seqlens]).to(device=device, dtype=torch.long) 30 | ) 31 | 32 | 33 | class Transformer(ModelBase, LoRALoaderMixin): 34 | def __init__( 35 | self, 36 | args: TransformerArgs, 37 | pipeline_rank: int = 0, 38 | num_pipeline_ranks: int = 1, 39 | softmax_fp32: bool = True, 40 | ): 41 | super().__init__() 42 | self.args = args 43 | self.vocab_size = args.vocab_size 44 | self.n_layers = args.n_layers 45 | self._precomputed_freqs_cis: Optional[torch.Tensor] = None 46 | assert self.vocab_size > 0 47 | assert pipeline_rank < num_pipeline_ranks, (pipeline_rank, num_pipeline_ranks) 48 | self.pipeline_rank = pipeline_rank 49 | self.num_pipeline_ranks = num_pipeline_ranks 50 | self.softmax_fp32 = softmax_fp32 51 | 52 | # Modules specific to some ranks: 53 | self.tok_embeddings: Optional[nn.Embedding] = None 54 | self.norm: Optional[RMSNorm] = None 55 | self.output: Optional[nn.Linear] = None 56 | if pipeline_rank == 0: 57 | self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim) 58 | 59 | self.vision_encoder: Optional[VisionTransformer] = None 60 | self.vision_language_adapter: Optional[VisionLanguageAdapter] = None 61 | 62 | if args.vision_encoder is not None: 63 | self.vision_encoder = VisionTransformer(args.vision_encoder) 64 | self.vision_language_adapter = VisionLanguageAdapter( 65 | args.vision_encoder.hidden_size, args.dim, args.vision_encoder.adapter_bias 66 | ) 67 | 68 | if args.vision_encoder.add_pre_mm_projector_layer_norm: 69 | self.pre_mm_projector_norm = RMSNorm(args.vision_encoder.hidden_size, eps=1e-5) 70 | 71 | if args.vision_encoder.mm_projector_id == PATCH_MERGE: 72 | self.patch_merger = PatchMerger( 73 | vision_encoder_dim=args.vision_encoder.hidden_size, 74 | spatial_merge_size=args.vision_encoder.spatial_merge_size, 75 | ) 76 | 77 | if pipeline_rank == num_pipeline_ranks - 1: 78 | self.norm = RMSNorm(args.dim, eps=args.norm_eps) 79 | self.output = nn.Linear(args.dim, args.vocab_size, bias=False) 80 | # Initialize all layers but slice off those not of this rank. 81 | layers = [ 82 | TransformerBlock( 83 | dim=args.dim, 84 | hidden_dim=args.hidden_dim, 85 | n_heads=args.n_heads, 86 | n_kv_heads=args.n_kv_heads, 87 | head_dim=args.head_dim, 88 | norm_eps=args.norm_eps, 89 | lora=args.lora, 90 | moe=args.moe, 91 | ) 92 | for _ in range(args.n_layers) 93 | ] 94 | num_layers_per_rank = math.ceil(self.n_layers / self.num_pipeline_ranks) 95 | offset = self.pipeline_rank * num_layers_per_rank 96 | end = min(self.n_layers, offset + num_layers_per_rank) 97 | self.layers = nn.ModuleDict({str(i): layers[i] for i in range(offset, end)}) 98 | self.n_local_layers = len(self.layers) 99 | 100 | @property 101 | def dtype(self) -> torch.dtype: 102 | return next(self.parameters()).dtype 103 | 104 | @property 105 | def device(self) -> torch.device: 106 | return next(self.parameters()).device 107 | 108 | @property 109 | def freqs_cis(self) -> torch.Tensor: 110 | # We cache freqs_cis but need to take care that it is on the right device 111 | # and has the right dtype (complex64). 
The fact that the dtype is different 112 | # from the module's dtype means we cannot register it as a buffer 113 | if self._precomputed_freqs_cis is None: 114 | # default to 10**6 115 | theta = self.args.rope_theta or 1000000.0 116 | self._precomputed_freqs_cis = precompute_freqs_cis(self.args.head_dim, 128_000, theta) 117 | 118 | if self._precomputed_freqs_cis.device != self.device: 119 | self._precomputed_freqs_cis = self._precomputed_freqs_cis.to(device=self.device) 120 | return self._precomputed_freqs_cis 121 | 122 | def embed_vision_language_features(self, input_ids: torch.Tensor, images: List[torch.Tensor]) -> torch.Tensor: 123 | assert self.tok_embeddings is not None 124 | assert self.vision_encoder is not None 125 | assert self.vision_language_adapter is not None 126 | assert self.args.vision_encoder is not None 127 | 128 | text_locations = input_ids != self.args.vision_encoder.image_token_id 129 | image_locations = input_ids == self.args.vision_encoder.image_token_id 130 | text_features = self.tok_embeddings(input_ids[text_locations]) 131 | 132 | image_features = self.vision_encoder(images) 133 | 134 | if self.args.vision_encoder.add_pre_mm_projector_layer_norm: 135 | image_features = self.pre_mm_projector_norm(image_features) 136 | 137 | if self.args.vision_encoder.mm_projector_id == PATCH_MERGE: 138 | patch_size = self.args.vision_encoder.patch_size 139 | img_patch_dims = [(img.shape[1] // patch_size, img.shape[2] // patch_size) for img in images] 140 | image_features = self.patch_merger(image_features, image_sizes=img_patch_dims) 141 | 142 | image_features = self.vision_language_adapter(image_features) 143 | 144 | N_txt, D_txt = text_features.shape 145 | N_img, D_img = image_features.shape 146 | 147 | seq_len = input_ids.shape[0] 148 | 149 | assert D_txt == D_img, f"Text features dim {D_txt} should be equal to image features dim {D_img}" 150 | assert seq_len == N_txt + N_img, ( 151 | f"seq_len {seq_len} should be equal to N_txt + N_img {(N_txt, N_img, image_locations.sum().item())}" 152 | ) 153 | 154 | combined_features = torch.empty( 155 | (seq_len, D_txt), 156 | dtype=text_features.dtype, 157 | device=text_features.device, 158 | ) 159 | combined_features[text_locations, :] = text_features 160 | combined_features[image_locations, :] = image_features 161 | return combined_features 162 | 163 | def forward_partial( 164 | self, 165 | input_ids: torch.Tensor, 166 | seqlens: List[int], 167 | cache: Optional[BufferCache] = None, 168 | images: Optional[List[torch.Tensor]] = None, 169 | ) -> torch.Tensor: 170 | """Local forward pass. 171 | 172 | If doing pipeline parallelism, this will return the activations of the last layer of this stage. 173 | For the last stage, this will return the normalized final embeddings. 
174 | """ 175 | assert len(seqlens) <= self.args.max_batch_size, ( 176 | f"Max batch size is {self.args.max_batch_size}, got batch size of {len(seqlens)}" 177 | ) 178 | (num_toks,) = input_ids.shape 179 | assert sum(seqlens) == num_toks, (sum(seqlens), num_toks) 180 | 181 | input_metadata: List[CacheInputMetadata] | List[SimpleInputMetadata] 182 | 183 | if cache is not None: 184 | input_metadata = cache.get_input_metadata(seqlens) 185 | else: 186 | input_metadata = [SimpleInputMetadata.from_seqlens(seqlens, self.device) for _ in range(len(self.layers))] 187 | 188 | if self.pipeline_rank == 0: 189 | assert self.tok_embeddings is not None 190 | if self.vision_encoder is not None and images: 191 | h = self.embed_vision_language_features(input_ids, images) 192 | else: 193 | h = self.tok_embeddings(input_ids) 194 | else: 195 | h = torch.empty(num_toks, self.args.dim, device=self.device, dtype=self.dtype) 196 | torch.distributed.recv(h, src=self.pipeline_rank - 1) 197 | 198 | # freqs_cis is always the same for every layer 199 | freqs_cis = self.freqs_cis[input_metadata[0].positions] 200 | 201 | for local_layer_id, layer in enumerate(self.layers.values()): 202 | if cache is not None: 203 | assert input_metadata is not None 204 | cache_metadata = input_metadata[local_layer_id] 205 | assert isinstance(cache_metadata, CacheInputMetadata) 206 | cache_view = cache.get_view(local_layer_id, cache_metadata) 207 | else: 208 | cache_view = None 209 | h = layer(h, freqs_cis, cache_view) 210 | 211 | if cache is not None: 212 | cache.update_seqlens(seqlens) 213 | if self.pipeline_rank < self.num_pipeline_ranks - 1: 214 | torch.distributed.send(h, dst=self.pipeline_rank + 1) 215 | return h 216 | else: 217 | # Last rank has a final normalization step. 218 | assert self.norm is not None 219 | return self.norm(h) # type: ignore 220 | 221 | def forward( 222 | self, 223 | input_ids: torch.Tensor, 224 | seqlens: List[int], 225 | cache: Optional[BufferCache] = None, 226 | images: Optional[List[torch.Tensor]] = None, 227 | ) -> torch.Tensor: 228 | h = self.forward_partial(input_ids, seqlens, cache=cache, images=images) 229 | if self.pipeline_rank < self.num_pipeline_ranks - 1: 230 | # ignore the intermediate activations as we'll get the final output from 231 | # the last stage 232 | outs = torch.empty(h.shape[0], self.vocab_size, device=h.device, dtype=h.dtype) 233 | else: 234 | assert self.output is not None 235 | outs = self.output(h) 236 | if self.num_pipeline_ranks > 1: 237 | torch.distributed.broadcast(outs, src=self.num_pipeline_ranks - 1) 238 | 239 | if self.softmax_fp32: 240 | return outs.float() 241 | else: 242 | return outs 243 | 244 | def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False) -> None: 245 | state_to_load = {} 246 | skipped = set([]) 247 | for k, v in state_dict.items(): 248 | if k.startswith("tok_embeddings"): 249 | if self.pipeline_rank == 0: 250 | state_to_load[k] = v 251 | else: 252 | logging.debug( 253 | "Skipping parameter %s at pipeline rank %d", 254 | k, 255 | self.pipeline_rank, 256 | ) 257 | skipped.add(k) 258 | elif k.startswith("norm") or k.startswith("output"): 259 | if self.pipeline_rank == self.num_pipeline_ranks - 1: 260 | state_to_load[k] = v 261 | else: 262 | logging.debug( 263 | "Skipping parameter %s at pipeline rank %d", 264 | k, 265 | self.pipeline_rank, 266 | ) 267 | skipped.add(k) 268 | elif k.startswith("layers"): 269 | layer_id = k.split(".")[1] 270 | if layer_id in self.layers: 271 | state_to_load[k] = v 272 | else: 273 | 
logging.debug( 274 | "Skipping parameter %s at pipeline rank %d", 275 | k, 276 | self.pipeline_rank, 277 | ) 278 | skipped.add(k) 279 | elif any( 280 | k.startswith(key) 281 | for key in ["vision_encoder", "vision_language_adapter", "patch_merger", "pre_mm_projector_norm"] 282 | ): 283 | if self.pipeline_rank == 0: 284 | state_to_load[k] = v 285 | else: 286 | logging.debug( 287 | "Skipping parameter %s at pipeline rank %d", 288 | k, 289 | self.pipeline_rank, 290 | ) 291 | skipped.add(k) 292 | else: 293 | raise ValueError(f"Unexpected key {k}") 294 | assert set(state_dict.keys()) == skipped.union(set(state_to_load.keys())) 295 | super().load_state_dict(state_to_load, strict=strict, assign=assign) 296 | 297 | @staticmethod 298 | def from_folder( 299 | folder: Union[Path, str], 300 | max_batch_size: int = 1, 301 | num_pipeline_ranks: int = 1, 302 | device: Union[torch.device, str] = "cuda", 303 | dtype: Optional[torch.dtype] = None, 304 | softmax_fp32: bool = True, 305 | ) -> "Transformer": 306 | with open(Path(folder) / "params.json", "r") as f: 307 | model_args = TransformerArgs.from_dict(json.load(f)) 308 | model_args.max_batch_size = max_batch_size 309 | if num_pipeline_ranks > 1: 310 | pipeline_rank = torch.distributed.get_rank() 311 | else: 312 | pipeline_rank = 0 313 | with torch.device("meta"): 314 | model = Transformer( 315 | model_args, 316 | pipeline_rank=pipeline_rank, 317 | num_pipeline_ranks=num_pipeline_ranks, 318 | softmax_fp32=softmax_fp32, 319 | ) 320 | 321 | pt_model_file = Path(folder) / "consolidated.00.pth" 322 | safetensors_model_file = Path(folder) / "consolidated.safetensors" 323 | 324 | assert pt_model_file.exists() or safetensors_model_file.exists(), ( 325 | f"Make sure either {pt_model_file} or {safetensors_model_file} exists" 326 | ) 327 | assert not (pt_model_file.exists() and safetensors_model_file.exists()), ( 328 | f"Both {pt_model_file} and {safetensors_model_file} cannot exist" 329 | ) 330 | 331 | if pt_model_file.exists(): 332 | loaded = torch.load(str(pt_model_file), mmap=True) 333 | else: 334 | loaded = safetensors.torch.load_file(str(safetensors_model_file)) 335 | 336 | model.load_state_dict(loaded, assign=True, strict=True) 337 | 338 | return model.to(device=device, dtype=dtype) 339 | -------------------------------------------------------------------------------- /src/mistral_inference/transformer_layers.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Optional, Tuple, Type, Union 3 | 4 | import torch 5 | from torch import nn 6 | from xformers.ops.fmha import memory_efficient_attention # type: ignore 7 | from xformers.ops.fmha.attn_bias import BlockDiagonalMask 8 | 9 | from mistral_inference.args import LoraArgs 10 | from mistral_inference.cache import CacheView 11 | from mistral_inference.lora import LoRALinear 12 | from mistral_inference.moe import MoeArgs, MoeLayer 13 | from mistral_inference.rope import apply_rotary_emb 14 | 15 | 16 | def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int, dim: int) -> Tuple[torch.Tensor, torch.Tensor]: 17 | keys = torch.repeat_interleave(keys, repeats=repeats, dim=dim) 18 | values = torch.repeat_interleave(values, repeats=repeats, dim=dim) 19 | return keys, values 20 | 21 | 22 | def maybe_lora( 23 | lora_args: Optional[LoraArgs], 24 | ) -> Union[Type[nn.Linear], partial[LoRALinear]]: 25 | if lora_args is None: 26 | return nn.Linear 27 | else: 28 | return partial(LoRALinear, rank=lora_args.rank, 
scaling=lora_args.scaling) 29 | 30 | 31 | class Attention(nn.Module): 32 | def __init__( 33 | self, 34 | dim: int, 35 | n_heads: int, 36 | head_dim: int, 37 | n_kv_heads: int, 38 | lora: Optional[LoraArgs] = None, 39 | ): 40 | super().__init__() 41 | 42 | self.n_heads: int = n_heads 43 | self.head_dim: int = head_dim 44 | self.n_kv_heads: int = n_kv_heads 45 | 46 | self.repeats = self.n_heads // self.n_kv_heads 47 | 48 | self.scale = self.head_dim**-0.5 49 | 50 | MaybeLora = maybe_lora(lora) 51 | self.wq = MaybeLora(dim, n_heads * head_dim, bias=False) 52 | self.wk = MaybeLora(dim, n_kv_heads * head_dim, bias=False) 53 | self.wv = MaybeLora(dim, n_kv_heads * head_dim, bias=False) 54 | self.wo = MaybeLora(n_heads * head_dim, dim, bias=False) 55 | 56 | def forward( 57 | self, 58 | x: torch.Tensor, 59 | freqs_cis: torch.Tensor, 60 | cache: Optional[CacheView] = None, 61 | mask: Optional[BlockDiagonalMask] = None, 62 | ) -> torch.Tensor: 63 | assert mask is None or cache is None 64 | seqlen_sum, _ = x.shape 65 | 66 | xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) 67 | xq = xq.view(seqlen_sum, self.n_heads, self.head_dim) 68 | xk = xk.view(seqlen_sum, self.n_kv_heads, self.head_dim) 69 | xv = xv.view(seqlen_sum, self.n_kv_heads, self.head_dim) 70 | xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) 71 | 72 | if cache is None: 73 | key, val = xk, xv 74 | elif cache.prefill: 75 | key, val = cache.interleave_kv(xk, xv) 76 | cache.update(xk, xv) 77 | else: 78 | cache.update(xk, xv) 79 | key, val = cache.key, cache.value 80 | key = key.view(seqlen_sum * cache.max_seq_len, self.n_kv_heads, self.head_dim) 81 | val = val.view(seqlen_sum * cache.max_seq_len, self.n_kv_heads, self.head_dim) 82 | 83 | # Repeat keys and values to match number of query heads 84 | key, val = repeat_kv(key, val, self.repeats, dim=1) 85 | 86 | # xformers requires (B=1, S, H, D) 87 | xq, key, val = xq[None, ...], key[None, ...], val[None, ...] 
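        # The attention bias comes from the cache during cached prefill/decoding;
        # cache-free calls (e.g. the vision encoder) rely on the explicit `mask` argument.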
88 | output = memory_efficient_attention(xq, key, val, mask if cache is None else cache.mask) 89 | output = output.view(seqlen_sum, self.n_heads * self.head_dim) 90 | 91 | assert isinstance(output, torch.Tensor) 92 | 93 | return self.wo(output) # type: ignore 94 | 95 | 96 | class FeedForward(nn.Module): 97 | def __init__(self, dim: int, hidden_dim: int, lora: Optional[LoraArgs] = None): 98 | super().__init__() 99 | 100 | MaybeLora = maybe_lora(lora) 101 | self.w1 = MaybeLora(dim, hidden_dim, bias=False) 102 | self.w2 = MaybeLora(hidden_dim, dim, bias=False) 103 | self.w3 = MaybeLora(dim, hidden_dim, bias=False) 104 | 105 | def forward(self, x: torch.Tensor) -> torch.Tensor: 106 | return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x)) # type: ignore 107 | 108 | 109 | class RMSNorm(torch.nn.Module): 110 | def __init__(self, dim: int, eps: float = 1e-6): 111 | super().__init__() 112 | self.eps = eps 113 | self.weight = nn.Parameter(torch.ones(dim)) 114 | 115 | def _norm(self, x: torch.Tensor) -> torch.Tensor: 116 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 117 | 118 | def forward(self, x: torch.Tensor) -> torch.Tensor: 119 | output = self._norm(x.float()).type_as(x) 120 | return output * self.weight 121 | 122 | 123 | class TransformerBlock(nn.Module): 124 | def __init__( 125 | self, 126 | dim: int, 127 | hidden_dim: int, 128 | n_heads: int, 129 | n_kv_heads: int, 130 | head_dim: int, 131 | norm_eps: float, 132 | lora: Optional[LoraArgs] = None, 133 | moe: Optional[MoeArgs] = None, 134 | ): 135 | super().__init__() 136 | self.n_heads = n_heads 137 | self.dim = dim 138 | self.attention = Attention( 139 | dim=dim, 140 | n_heads=n_heads, 141 | head_dim=head_dim, 142 | n_kv_heads=n_kv_heads, 143 | lora=lora, 144 | ) 145 | self.attention_norm = RMSNorm(dim, eps=norm_eps) 146 | self.ffn_norm = RMSNorm(dim, eps=norm_eps) 147 | 148 | self.feed_forward: nn.Module 149 | if moe is not None: 150 | self.feed_forward = MoeLayer( 151 | experts=[FeedForward(dim=dim, hidden_dim=hidden_dim, lora=lora) for _ in range(moe.num_experts)], 152 | gate=nn.Linear(dim, moe.num_experts, bias=False), 153 | moe_args=moe, 154 | ) 155 | else: 156 | self.feed_forward = FeedForward(dim=dim, hidden_dim=hidden_dim, lora=lora) 157 | 158 | def forward( 159 | self, 160 | x: torch.Tensor, 161 | freqs_cis: torch.Tensor, 162 | cache: Optional[CacheView] = None, 163 | mask: Optional[BlockDiagonalMask] = None, 164 | ) -> torch.Tensor: 165 | r = self.attention.forward(self.attention_norm(x), freqs_cis, cache) 166 | h = x + r 167 | r = self.feed_forward.forward(self.ffn_norm(h)) 168 | out = h + r 169 | return out 170 | -------------------------------------------------------------------------------- /src/mistral_inference/vision_encoder.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import torch.nn as nn 5 | from xformers.ops.fmha.attn_bias import BlockDiagonalMask 6 | 7 | from mistral_inference.args import VisionEncoderArgs 8 | from mistral_inference.rope import precompute_freqs_cis_2d 9 | from mistral_inference.transformer_layers import RMSNorm, TransformerBlock 10 | 11 | 12 | def position_meshgrid( 13 | patch_embeds_list: list[torch.Tensor], 14 | ) -> torch.Tensor: 15 | positions = torch.cat( 16 | [ 17 | torch.stack( 18 | torch.meshgrid( 19 | torch.arange(p.shape[-2]), 20 | torch.arange(p.shape[-1]), 21 | indexing="ij", 22 | ), 23 | dim=-1, 24 | ).reshape(-1, 2) 25 | for p in patch_embeds_list 26 | ] 27 | ) 
28 | return positions 29 | 30 | 31 | class VisionTransformer(nn.Module): 32 | def __init__(self, args: VisionEncoderArgs): 33 | super().__init__() 34 | self.args = args 35 | self.patch_conv = nn.Conv2d( 36 | in_channels=args.num_channels, 37 | out_channels=args.hidden_size, 38 | kernel_size=args.patch_size, 39 | stride=args.patch_size, 40 | bias=False, 41 | ) 42 | self.ln_pre = RMSNorm(args.hidden_size, eps=1e-5) 43 | self.transformer = VisionTransformerBlocks(args) 44 | 45 | head_dim = self.args.hidden_size // self.args.num_attention_heads 46 | assert head_dim % 2 == 0, "ROPE requires even head_dim" 47 | self._freqs_cis: Optional[torch.Tensor] = None 48 | 49 | @property 50 | def max_patches_per_side(self) -> int: 51 | return self.args.image_size // self.args.patch_size 52 | 53 | @property 54 | def device(self) -> torch.device: 55 | return next(self.parameters()).device 56 | 57 | @property 58 | def freqs_cis(self) -> torch.Tensor: 59 | if self._freqs_cis is None: 60 | self._freqs_cis = precompute_freqs_cis_2d( 61 | dim=self.args.hidden_size // self.args.num_attention_heads, 62 | height=self.max_patches_per_side, 63 | width=self.max_patches_per_side, 64 | theta=self.args.rope_theta, 65 | ) 66 | 67 | if self._freqs_cis.device != self.device: 68 | self._freqs_cis = self._freqs_cis.to(device=self.device) 69 | 70 | return self._freqs_cis 71 | 72 | def forward( 73 | self, 74 | images: List[torch.Tensor], 75 | ) -> torch.Tensor: 76 | """ 77 | Args: 78 | images: list of N_img images of variable sizes, each of shape (C, H, W) 79 | 80 | Returns: 81 | image_features: tensor of token features for all tokens of all images of 82 | shape (N_toks, D) 83 | """ 84 | # pass images through initial convolution independently 85 | patch_embeds_list = [self.patch_conv(img.unsqueeze(0)).squeeze(0) for img in images] 86 | 87 | # flatten to a single sequence 88 | patch_embeds = torch.cat([p.flatten(1).permute(1, 0) for p in patch_embeds_list], dim=0) 89 | patch_embeds = self.ln_pre(patch_embeds) 90 | 91 | # positional embeddings 92 | positions = position_meshgrid(patch_embeds_list).to(self.device) 93 | freqs_cis = self.freqs_cis[positions[:, 0], positions[:, 1]] 94 | 95 | # pass through Transformer with a block diagonal mask delimiting images 96 | mask = BlockDiagonalMask.from_seqlens( 97 | [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], 98 | ) 99 | out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis) 100 | 101 | # remove batch dimension of the single sequence 102 | return out # type: ignore[no-any-return] 103 | 104 | 105 | class VisionLanguageAdapter(nn.Module): 106 | def __init__(self, in_dim: int, out_dim: int, bias: bool = True): 107 | super().__init__() 108 | self.w_in = nn.Linear( 109 | in_dim, 110 | out_dim, 111 | bias=bias, 112 | ) 113 | self.gelu = nn.GELU() 114 | self.w_out = nn.Linear(out_dim, out_dim, bias=bias) 115 | 116 | def forward(self, x: torch.Tensor) -> torch.Tensor: 117 | return self.w_out(self.gelu(self.w_in(x))) # type: ignore[no-any-return] 118 | 119 | 120 | class VisionTransformerBlocks(nn.Module): 121 | def __init__(self, args: VisionEncoderArgs): 122 | super().__init__() 123 | self.layers = torch.nn.ModuleList() 124 | for _ in range(args.num_hidden_layers): 125 | self.layers.append( 126 | TransformerBlock( 127 | dim=args.hidden_size, 128 | hidden_dim=args.intermediate_size, 129 | n_heads=args.num_attention_heads, 130 | n_kv_heads=args.num_attention_heads, 131 | head_dim=args.hidden_size // args.num_attention_heads, 132 | norm_eps=1e-5, 133 | ) 134 | ) 135 | 136 | 
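    # Run the flattened patch sequence through every block, forwarding the
    # per-image block-diagonal mask and the 2D rotary embeddings.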
def forward( 137 | self, 138 | x: torch.Tensor, 139 | mask: BlockDiagonalMask, 140 | freqs_cis: Optional[torch.Tensor], 141 | ) -> torch.Tensor: 142 | for layer in self.layers: 143 | x = layer(x, mask=mask, freqs_cis=freqs_cis) 144 | return x 145 | 146 | 147 | class PatchMerger(nn.Module): 148 | """ 149 | Learned merging of spatial_merge_size ** 2 patches 150 | """ 151 | 152 | def __init__( 153 | self, 154 | vision_encoder_dim: int, 155 | spatial_merge_size: int, 156 | ) -> None: 157 | super().__init__() 158 | 159 | mlp_input_dim = vision_encoder_dim * (spatial_merge_size**2) 160 | 161 | self.spatial_merge_size = spatial_merge_size 162 | self.mlp_input_dim = mlp_input_dim 163 | 164 | self.merging_layer = nn.Linear(mlp_input_dim, vision_encoder_dim, bias=False) 165 | 166 | def forward(self, x: torch.Tensor, image_sizes: list[tuple[int, int]]) -> torch.Tensor: 167 | # image_sizes specified in tokens 168 | assert sum([h * w for h, w in image_sizes]) == len(x), f"{sum([h * w for h, w in image_sizes])} != {len(x)}" 169 | 170 | # x is (N, vision_encoder_dim) 171 | x = self.permute(x, image_sizes) 172 | 173 | # x is (N / spatial_merge_size ** 2, 174 | # vision_encoder_dim * spatial_merge_size ** 2) 175 | x = self.merging_layer(x) 176 | 177 | # x is (N / spatial_merge_size ** 2, vision_encoder_dim) 178 | return x 179 | 180 | def permute( 181 | self, 182 | x: torch.Tensor, 183 | image_sizes: list[tuple[int, int]], 184 | ) -> torch.Tensor: 185 | """ 186 | Args: 187 | x: (N, D) where N is flattened and concatenated patch tokens 188 | for all images 189 | image_sizes: list of tuple of (height, width) in tokens for 190 | each image 191 | Returns: 192 | image_features: reorders patch tokens so each grid of 193 | (spatial_merge_size, spatial_merge_size) is contiguous. 
194 | now (N / spatial_merge_size ** 2, D * spatial_merge_size ** 2) 195 | """ 196 | 197 | sub_grids = get_sub_grids( 198 | x=x, image_sizes=image_sizes, spatial_merge_size=self.spatial_merge_size 199 | ) # list of [d x sub_grid_size x sub_grid_size x n_patches] 200 | permuted_tensor = [ 201 | grid.view(-1, grid.shape[-1]).t() for grid in sub_grids 202 | ] # n_patches x d * sub_grid_size * sub_grid_size 203 | return torch.cat(permuted_tensor, dim=0) # (N / spatial_merge_size ** 2, d * spatial_merge_size ** 2) 204 | 205 | 206 | def get_sub_grids( 207 | x: torch.Tensor, 208 | image_sizes: list[tuple[int, int]], 209 | spatial_merge_size: int, 210 | ) -> list[torch.Tensor]: 211 | # image_sizes specified in tokens 212 | tokens_per_image = [h * w for h, w in image_sizes] 213 | d = x.shape[-1] 214 | all_img_sub_grids: list[torch.Tensor] = [] 215 | sub_grid_size = spatial_merge_size 216 | 217 | for image_index, image_tokens in enumerate(x.split(tokens_per_image)): 218 | # Reshape image_tokens into a 2D grid 219 | h, w = image_sizes[image_index] 220 | image_grid = image_tokens.view(h, w, d).permute(2, 0, 1)[None, :, :, :] # 1 x d x h x w 221 | sub_grids = torch.nn.functional.unfold(image_grid, kernel_size=sub_grid_size, stride=sub_grid_size) 222 | sub_grids = sub_grids.view( 223 | 1, d, sub_grid_size, sub_grid_size, -1 224 | ) # 1 x d x sub_grid_size x sub_grid_size x n_patches 225 | 226 | all_img_sub_grids.append(sub_grids[0]) 227 | 228 | return all_img_sub_grids 229 | -------------------------------------------------------------------------------- /tests/test_generate.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import numpy as np 4 | import torch 5 | from mistral_inference.args import VisionEncoderArgs 6 | from mistral_inference.generate import generate_mamba 7 | from mistral_inference.main import generate 8 | from mistral_inference.mamba import Mamba, MambaArgs 9 | from mistral_inference.transformer import Transformer, TransformerArgs 10 | 11 | 12 | class DebugTokenizer: 13 | @property 14 | def bos_id(self) -> int: 15 | return 0 16 | 17 | @property 18 | def eos_id(self) -> int: 19 | return 1 20 | 21 | @property 22 | def pad_id(self) -> int: 23 | return -1 24 | 25 | def encode(self, s: str, bos: bool = True) -> List[int]: 26 | assert isinstance(s, str) 27 | t = [int(x) for x in s.split()] 28 | if bos: 29 | t = [self.bos_id, *t] 30 | return t 31 | 32 | def decode(self, t: List[int]) -> str: 33 | return " ".join([str(x) for x in t]) 34 | 35 | 36 | def test_generation_transformer() -> None: 37 | torch.manual_seed(42) 38 | 39 | sequences = ["1 2 3 4 5 6 7", "0 1 2", "12 13 14", "2 4 34"] 40 | args = TransformerArgs( 41 | dim=512, 42 | n_layers=1, 43 | head_dim=128, 44 | hidden_dim=2048, 45 | n_heads=4, 46 | n_kv_heads=2, 47 | norm_eps=1e-5, 48 | vocab_size=32_000, 49 | max_batch_size=len(sequences), 50 | ) 51 | model = Transformer(args).to("cuda", dtype=torch.float32) 52 | tokenizer = DebugTokenizer() 53 | 54 | encoded = [tokenizer.encode(s, bos=True) for s in sequences] 55 | toks, all_logprobs_old = generate(encoded, model, temperature=0.0, max_tokens=7) 56 | 57 | # concat generated and prompt 58 | encoded = [e + t for e, t in zip(encoded, toks)] 59 | 60 | generated, all_logprobs_new = generate(encoded, model, temperature=0.0, max_tokens=0) 61 | 62 | assert generated == [] 63 | 64 | # Verify that logprobs are the same 65 | assert len(sequences) == len(all_logprobs_old) == len(all_logprobs_new) 66 | for lp_old, lp_new in 
zip(all_logprobs_old, all_logprobs_new): 67 | assert all([abs(x - y) < 5e-4 for x, y in zip(lp_old, lp_new)]), f"\n{lp_old}\n{lp_new}" 68 | 69 | print("All tests passed.") 70 | 71 | 72 | def test_generation_pixtral() -> None: 73 | torch.manual_seed(42) 74 | gen = np.random.default_rng(seed=42) 75 | 76 | sequences = ["1 2 2 2 2 4 5 6 7", "12 13 14", "2 2 2 2 7 8 9"] 77 | images = [[gen.normal(size=(3, 4, 4))], [], [gen.normal(size=(3, 4, 4))]] 78 | args = TransformerArgs( 79 | dim=512, 80 | n_layers=1, 81 | head_dim=128, 82 | hidden_dim=2048, 83 | n_heads=4, 84 | n_kv_heads=2, 85 | norm_eps=1e-5, 86 | vocab_size=32_000, 87 | max_batch_size=len(sequences), 88 | vision_encoder=VisionEncoderArgs( 89 | hidden_size=128, 90 | num_channels=3, 91 | image_size=4, 92 | patch_size=2, 93 | intermediate_size=256, 94 | num_hidden_layers=1, 95 | num_attention_heads=2, 96 | rope_theta=10000, 97 | image_token_id=2, 98 | ), 99 | ) 100 | model = Transformer(args).to("cuda", dtype=torch.float32) 101 | tokenizer = DebugTokenizer() 102 | 103 | encoded = [tokenizer.encode(s, bos=True) for s in sequences] 104 | toks, all_logprobs_old = generate(encoded, model, images=images, temperature=0.0, max_tokens=7) 105 | 106 | # concat generated and prompt 107 | encoded = [e + t for e, t in zip(encoded, toks)] 108 | 109 | generated, all_logprobs_new = generate(encoded, model, images=images, temperature=0.0, max_tokens=0) 110 | 111 | assert generated == [] 112 | 113 | # Verify that logprobs are the same 114 | assert len(sequences) == len(all_logprobs_old) == len(all_logprobs_new) 115 | for lp_old, lp_new in zip(all_logprobs_old, all_logprobs_new): 116 | assert all([abs(x - y) < 5e-4 for x, y in zip(lp_old, lp_new)]), f"\n{lp_old}\n{lp_new}" 117 | 118 | print("All tests passed.") 119 | 120 | 121 | def test_generation_pixtral_patch_merger() -> None: 122 | torch.manual_seed(42) 123 | gen = np.random.default_rng(seed=42) 124 | 125 | sequences = ["1 2 2 2 2 4 5 6 7", "12 13 14", "2 2 2 2 7 8 9"] 126 | images = [[gen.normal(size=(3, 8, 8))], [], [gen.normal(size=(3, 8, 8))]] 127 | args = TransformerArgs( 128 | dim=512, 129 | n_layers=1, 130 | head_dim=128, 131 | hidden_dim=2048, 132 | n_heads=4, 133 | n_kv_heads=2, 134 | norm_eps=1e-5, 135 | vocab_size=32_000, 136 | max_batch_size=len(sequences), 137 | vision_encoder=VisionEncoderArgs( 138 | hidden_size=128, 139 | num_channels=3, 140 | image_size=8, 141 | patch_size=2, 142 | intermediate_size=256, 143 | num_hidden_layers=1, 144 | num_attention_heads=2, 145 | rope_theta=10000, 146 | image_token_id=2, 147 | adapter_bias=False, 148 | spatial_merge_size=2, 149 | add_pre_mm_projector_layer_norm=True, 150 | mm_projector_id="patch_merge", 151 | ), 152 | ) 153 | model = Transformer(args).to("cuda", dtype=torch.float32) 154 | tokenizer = DebugTokenizer() 155 | 156 | encoded = [tokenizer.encode(s, bos=True) for s in sequences] 157 | toks, all_logprobs_old = generate(encoded, model, images=images, temperature=0.0, max_tokens=7) 158 | 159 | # concat generated and prompt 160 | encoded = [e + t for e, t in zip(encoded, toks)] 161 | 162 | generated, all_logprobs_new = generate(encoded, model, images=images, temperature=0.0, max_tokens=0) 163 | 164 | assert generated == [] 165 | 166 | # Verify that logprobs are the same 167 | assert len(sequences) == len(all_logprobs_old) == len(all_logprobs_new) 168 | for lp_old, lp_new in zip(all_logprobs_old, all_logprobs_new): 169 | assert all([abs(x - y) < 5e-4 for x, y in zip(lp_old, lp_new)]), f"\n{lp_old}\n{lp_new}" 170 | 171 | print("All tests passed.") 
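

# --- Illustrative sketch, not part of the original test suite ---
# A minimal sanity check of the nucleus (top-p) sampling helper defined in
# `mistral_inference.generate`: a very small p should keep only the most
# probable token, while p=1.0 leaves every token eligible. The test name and
# the probability values below are assumptions added for illustration.
def test_sample_top_p_sketch() -> None:
    import torch

    from mistral_inference.generate import sample_top_p

    torch.manual_seed(0)
    probs = torch.tensor([[0.7, 0.2, 0.1]])

    # With a tiny p the cumulative mass is covered by the top token alone,
    # so index 0 must always be returned.
    for _ in range(10):
        assert sample_top_p(probs, p=0.05).item() == 0

    # With p=1.0 nothing is masked; any of the three indices may be drawn.
    draws = {sample_top_p(probs, p=1.0).item() for _ in range(50)}
    assert draws.issubset({0, 1, 2})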
172 | 173 | 174 | def test_generation_mamba() -> None: 175 | torch.manual_seed(42) 176 | 177 | sequences = ["1 2 3 4 5 6 7"] 178 | args = MambaArgs( 179 | dim=512, 180 | n_layers=1, 181 | n_groups=1, 182 | rms_norm=True, 183 | residual_in_fp32=True, 184 | fused_add_norm=True, 185 | pad_vocab_size_multiple=1, 186 | tie_embeddings=False, 187 | vocab_size=32768, 188 | ) 189 | model = Mamba(args).to("cuda", dtype=torch.float32) 190 | tokenizer = DebugTokenizer() 191 | 192 | encoded = [tokenizer.encode(s, bos=True) for s in sequences] 193 | toks, all_logprobs_old = generate_mamba(encoded, model, temperature=0.0, max_tokens=7) 194 | 195 | assert len(toks[0]) == 7 196 | assert toks == [[25574, 14821, 11843, 23698, 12735, 23522, 27542]] 197 | 198 | 199 | def test_chunks_transformer() -> None: 200 | torch.manual_seed(42) 201 | 202 | sequences = [ 203 | " ".join([str(i) for i in range(7)]), 204 | " ".join([str(i) for i in range(9, 0, -1)]), 205 | ] 206 | args = TransformerArgs( 207 | dim=512, 208 | n_layers=1, 209 | head_dim=128, 210 | hidden_dim=2048, 211 | n_heads=4, 212 | n_kv_heads=2, 213 | norm_eps=1e-5, 214 | vocab_size=32_000, 215 | max_batch_size=3, 216 | ) 217 | model = Transformer(args).to("cuda", dtype=torch.float32) 218 | tokenizer = DebugTokenizer() 219 | 220 | encoded = [tokenizer.encode(s, bos=True) for s in sequences] 221 | toks, all_logprobs_old = generate(encoded, model, temperature=0.0, max_tokens=8) 222 | 223 | # concat generated and prompt 224 | encoded = [e + t for e, t in zip(encoded, toks)] 225 | 226 | generated, all_logprobs_new = generate(encoded, model, temperature=0.0, max_tokens=0, chunk_size=5) 227 | assert len(generated) == 0 228 | 229 | for lp_old, lp_new in zip(all_logprobs_old, all_logprobs_new): 230 | assert all([abs(x - y) < 5e-4 for x, y in zip(lp_old, lp_new)]), f"\n{lp_old}\n{lp_new}" 231 | -------------------------------------------------------------------------------- /tutorials/getting_started.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Getting Started with `mistral-inference`\n", 8 | "\n", 9 | "This notebook will guide you through the process of running Mistral models locally. We will cover the following: \n", 10 | "- How to chat with Mistral 7B Instruct\n", 11 | "- How to run Mistral 7B Instruct with function calling capabilities\n", 12 | "\n", 13 | "We recommend using a GPU such as the A100 to run this notebook. 
" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "metadata": { 20 | "id": "G6tXvIsQenpI" 21 | }, 22 | "outputs": [], 23 | "source": [ 24 | "!pip install mistral-inference" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "metadata": {}, 30 | "source": [ 31 | "## Download Mistral 7B Instruct" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "metadata": { 38 | "colab": { 39 | "background_save": true 40 | }, 41 | "id": "4ytmRt0WQeMW" 42 | }, 43 | "outputs": [], 44 | "source": [ 45 | "!wget https://models.mistralcdn.com/mistral-7b-v0-3/mistral-7B-Instruct-v0.3.tar" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "metadata": { 52 | "id": "eRZg_8wvs5A6" 53 | }, 54 | "outputs": [], 55 | "source": [ 56 | "!DIR=$HOME/mistral_7b_instruct_v3 && mkdir -p $DIR && tar -xf mistral-7B-Instruct-v0.3.tar -C $DIR" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "metadata": { 63 | "id": "7CN8gShDf65M" 64 | }, 65 | "outputs": [], 66 | "source": [ 67 | "!ls mistral_7b_instruct_v3" 68 | ] 69 | }, 70 | { 71 | "cell_type": "markdown", 72 | "metadata": {}, 73 | "source": [ 74 | "## Chat with the model" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "import os \n", 84 | "\n", 85 | "from mistral_inference.transformer import Transformer\n", 86 | "from mistral_inference.generate import generate\n", 87 | "\n", 88 | "from mistral_common.tokens.tokenizers.mistral import MistralTokenizer\n", 89 | "from mistral_common.protocol.instruct.messages import UserMessage\n", 90 | "from mistral_common.protocol.instruct.request import ChatCompletionRequest\n", 91 | "\n", 92 | "# load tokenizer\n", 93 | "mistral_tokenizer = MistralTokenizer.from_file(os.path.expanduser(\"~\")+\"/mistral_7b_instruct_v3/tokenizer.model.v3\")\n", 94 | "# chat completion request\n", 95 | "completion_request = ChatCompletionRequest(messages=[UserMessage(content=\"Explain Machine Learning to me in a nutshell.\")])\n", 96 | "# encode message\n", 97 | "tokens = mistral_tokenizer.encode_chat_completion(completion_request).tokens\n", 98 | "# load model\n", 99 | "model = Transformer.from_folder(os.path.expanduser(\"~\")+\"/mistral_7b_instruct_v3\")\n", 100 | "# generate results\n", 101 | "out_tokens, _ = generate([tokens], model, max_tokens=64, temperature=0.0, eos_id=mistral_tokenizer.instruct_tokenizer.tokenizer.eos_id)\n", 102 | "# decode generated tokens\n", 103 | "result = mistral_tokenizer.instruct_tokenizer.tokenizer.decode(out_tokens[0])\n", 104 | "print(result)" 105 | ] 106 | }, 107 | { 108 | "cell_type": "markdown", 109 | "metadata": { 110 | "id": "ce4woS3LkgZ9" 111 | }, 112 | "source": [ 113 | "## Function calling\n", 114 | "\n", 115 | "Mistral 7B Instruct v3 also supports function calling!" 
116 | ]
117 | },
118 | {
119 | "cell_type": "markdown",
120 | "metadata": {
121 | "id": "TKfPiEwNk1kh"
122 | },
123 | "source": [
124 | "Let's start by creating a function calling example."
125 | ]
126 | },
127 | {
128 | "cell_type": "code",
129 | "execution_count": null,
130 | "metadata": {
131 | "id": "0PJdwvDEk3dl"
132 | },
133 | "outputs": [],
134 | "source": [
135 | "from mistral_common.protocol.instruct.messages import UserMessage\n",
136 | "from mistral_common.protocol.instruct.request import ChatCompletionRequest\n",
137 | "from mistral_common.protocol.instruct.tool_calls import Function, Tool\n",
138 | "\n",
139 | "completion_request = ChatCompletionRequest(\n",
140 | "    tools=[\n",
141 | "        Tool(\n",
142 | "            function=Function(\n",
143 | "                name=\"get_current_weather\",\n",
144 | "                description=\"Get the current weather\",\n",
145 | "                parameters={\n",
146 | "                    \"type\": \"object\",\n",
147 | "                    \"properties\": {\n",
148 | "                        \"location\": {\n",
149 | "                            \"type\": \"string\",\n",
150 | "                            \"description\": \"The city and state, e.g. San Francisco, CA\",\n",
151 | "                        },\n",
152 | "                        \"format\": {\n",
153 | "                            \"type\": \"string\",\n",
154 | "                            \"enum\": [\"celsius\", \"fahrenheit\"],\n",
155 | "                            \"description\": \"The temperature unit to use. Infer this from the user's location.\",\n",
156 | "                        },\n",
157 | "                    },\n",
158 | "                    \"required\": [\"location\", \"format\"],\n",
159 | "                },\n",
160 | "            )\n",
161 | "        )\n",
162 | "    ],\n",
163 | "    messages=[\n",
164 | "        UserMessage(content=\"What's the weather like today in Paris?\"),\n",
165 | "    ],\n",
166 | ")"
167 | ]
168 | },
169 | {
170 | "cell_type": "markdown",
171 | "metadata": {
172 | "id": "bG6ZeZUylpBW"
173 | },
174 | "source": [
175 | "Since we have already loaded the tokenizer and the model in the example above, we will skip these steps here. \n",
176 | "\n",
177 | "Now we can encode the message with our `MistralTokenizer` instance. "
178 | ]
179 | },
180 | {
181 | "cell_type": "code",
182 | "execution_count": null,
183 | "metadata": {
184 | "id": "Ii8q-JNClwiq"
185 | },
186 | "outputs": [],
187 | "source": [
188 | "from mistral_common.tokens.tokenizers.mistral import MistralTokenizer\n",
189 | "\n",
190 | "tokens = mistral_tokenizer.encode_chat_completion(completion_request).tokens"
191 | ]
192 | },
193 | {
194 | "cell_type": "markdown",
195 | "metadata": {
196 | "id": "NrueDujkmJT4"
197 | },
198 | "source": [
199 | "and run `generate` to get a response. Don't forget to pass the EOS id!"
200 | ]
201 | },
202 | {
203 | "cell_type": "code",
204 | "execution_count": null,
205 | "metadata": {
206 | "id": "GWJYO43rl0V8"
207 | },
208 | "outputs": [],
209 | "source": [
210 | "from mistral_inference.generate import generate\n",
211 | "\n",
212 | "out_tokens, _ = generate([tokens], model, max_tokens=64, temperature=0.0, eos_id=mistral_tokenizer.instruct_tokenizer.tokenizer.eos_id)"
213 | ]
214 | },
215 | {
216 | "cell_type": "markdown",
217 | "metadata": {
218 | "id": "v7baJ1msmPMv"
219 | },
220 | "source": [
221 | "Finally, we can decode the generated tokens."
222 | ]
223 | },
224 | {
225 | "cell_type": "code",
226 | "execution_count": null,
227 | "metadata": {
228 | "id": "RKhryfBWmHon"
229 | },
230 | "outputs": [],
231 | "source": [
232 | "result = mistral_tokenizer.instruct_tokenizer.tokenizer.decode(out_tokens[0])\n",
233 | "result"
234 | ]
235 | }
236 | ],
237 | "metadata": {
238 | "accelerator": "GPU",
239 | "colab": {
240 | "gpuType": "L4",
241 | "machine_shape": "hm",
242 | "provenance": []
243 | },
244 | "kernelspec": {
245 | "display_name": "Python 3 (ipykernel)",
246 | "language": "python",
247 | "name": "python3"
248 | },
249 | "language_info": {
250 | "codemirror_mode": {
251 | "name": "ipython",
252 | "version": 3
253 | },
254 | "file_extension": ".py",
255 | "mimetype": "text/x-python",
256 | "name": "python",
257 | "nbconvert_exporter": "python",
258 | "pygments_lexer": "ipython3",
259 | "version": "3.11.8"
260 | }
261 | },
262 | "nbformat": 4,
263 | "nbformat_minor": 4
264 | }
265 | 
--------------------------------------------------------------------------------
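
A note on the function-calling example in tutorials/getting_started.ipynb: the decoded `result` from the last cell typically contains the model's tool call rather than a natural-language answer. The sketch below shows one way that output might be post-processed. It is a hedged sketch, not part of the repository: it assumes the decoded string resembles a JSON list such as [{"name": "get_current_weather", "arguments": {...}}], possibly preceded by a [TOOL_CALLS] marker, and the local get_current_weather function is a hypothetical stand-in; the exact output format depends on the model and tokenizer version.

import json

def parse_tool_calls(decoded: str) -> list:
    """Extract tool calls from a decoded completion (format may vary by version)."""
    # Drop an optional leading control marker such as "[TOOL_CALLS]".
    payload = decoded.split("[TOOL_CALLS]", 1)[-1].strip()
    return json.loads(payload)

def get_current_weather(location: str, format: str) -> str:
    # Hypothetical local implementation, used only for illustration.
    return f"It is 22 degrees {format} in {location}."

# `result` is the decoded string produced by the notebook's final cell.
for call in parse_tool_calls(result):
    arguments = call["arguments"]
    if isinstance(arguments, str):  # some versions return arguments as a JSON string
        arguments = json.loads(arguments)
    if call["name"] == "get_current_weather":
        print(get_current_weather(**arguments))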