├── .github
│   └── ISSUE_TEMPLATE
│       ├── bug_report.yml
│       └── config.yml
├── .gitignore
├── LICENSE
├── README.md
├── assets
│   └── smoe.png
├── deploy
│   ├── .dockerignore
│   ├── Dockerfile
│   └── entrypoint.sh
├── poetry.lock
├── pyproject.toml
├── src
│   └── mistral_inference
│       ├── __init__.py
│       ├── args.py
│       ├── cache.py
│       ├── generate.py
│       ├── lora.py
│       ├── main.py
│       ├── mamba.py
│       ├── model.py
│       ├── moe.py
│       ├── rope.py
│       ├── transformer.py
│       ├── transformer_layers.py
│       └── vision_encoder.py
├── tests
│   └── test_generate.py
└── tutorials
    ├── classifier.ipynb
    └── getting_started.ipynb
/.github/ISSUE_TEMPLATE/bug_report.yml:
--------------------------------------------------------------------------------
1 | name: Bug report related to mistral-inference
2 | description: Submit a bug report that's related to mistral-inference
3 | title: '[BUG: '
4 | labels: ['bug', 'triage']
5 | body:
6 | - type: markdown
7 | attributes:
8 | value: |
9 | Thanks for taking the time to fill out this bug report!
10 | - type: textarea
11 | id: python-vv
12 | attributes:
13 | label: Python -VV
14 | description: Run `python -VV` from your virtual environment
15 | placeholder: Copy-paste the output (no need for backticks, will be formatted into code automatically)
16 | render: shell
17 | validations:
18 | required: true
19 | - type: textarea
20 | id: pip-freeze
21 | attributes:
22 | label: Pip Freeze
23 | description: Run `pip freeze` from your virtual environment
24 | placeholder: Copy-paste the output (no need for backticks, will be formatted into code automatically)
25 | render: shell
26 | validations:
27 | required: true
28 | - type: textarea
29 | id: reproduction-steps
30 | attributes:
31 | label: Reproduction Steps
32 | description: Provide a clear and concise description of the steps that lead to your issue.
33 | placeholder: |
34 | 1. First step...
35 | 2. Step 2...
36 | ...
37 | validations:
38 | required: true
39 | - type: textarea
40 | id: expected-behavior
41 | attributes:
42 | label: Expected Behavior
43 | description: Explain briefly what you expected to happen.
44 | validations:
45 | required: true
46 | - type: textarea
47 | id: additional-context
48 | attributes:
49 | label: Additional Context
50 | description: Add any context about your problem that you deem relevant.
51 | - type: textarea
52 | id: suggested-solutions
53 | attributes:
54 | label: Suggested Solutions
55 | description: Please list any solutions you recommend we consider.
56 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/config.yml:
--------------------------------------------------------------------------------
1 | blank_issues_enabled: false
2 | contact_links:
3 | - name: Documentation
4 | url: https://docs.mistral.ai
5 | about: Developer documentation for the Mistral AI platform
6 | - name: Discord
7 | url: https://discord.com/invite/mistralai
8 | about: Chat with the Mistral community
9 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Mistral Inference
2 |
3 |
4 |
5 |
6 |
7 | This repository contains minimal code to run Mistral models.
8 |
9 | Blog 7B: [https://mistral.ai/news/announcing-mistral-7b/](https://mistral.ai/news/announcing-mistral-7b/)\
10 | Blog 8x7B: [https://mistral.ai/news/mixtral-of-experts/](https://mistral.ai/news/mixtral-of-experts/)\
11 | Blog 8x22B: [https://mistral.ai/news/mixtral-8x22b/](https://mistral.ai/news/mixtral-8x22b/)\
12 | Blog Codestral 22B: [https://mistral.ai/news/codestral](https://mistral.ai/news/codestral/) \
13 | Blog Codestral Mamba 7B: [https://mistral.ai/news/codestral-mamba/](https://mistral.ai/news/codestral-mamba/) \
14 | Blog Mathstral 7B: [https://mistral.ai/news/mathstral/](https://mistral.ai/news/mathstral/) \
15 | Blog Nemo: [https://mistral.ai/news/mistral-nemo/](https://mistral.ai/news/mistral-nemo/) \
16 | Blog Mistral Large 2: [https://mistral.ai/news/mistral-large-2407/](https://mistral.ai/news/mistral-large-2407/) \
17 | Blog Pixtral 12B: [https://mistral.ai/news/pixtral-12b/](https://mistral.ai/news/pixtral-12b/) \
18 | Blog Mistral Small 3.1: [https://mistral.ai/news/mistral-small-3-1/](https://mistral.ai/news/mistral-small-3-1/)
19 |
20 | Discord: [https://discord.com/invite/mistralai](https://discord.com/invite/mistralai)\
21 | Documentation: [https://docs.mistral.ai/](https://docs.mistral.ai/)\
22 | Guardrailing: [https://docs.mistral.ai/usage/guardrailing](https://docs.mistral.ai/usage/guardrailing)
23 |
24 | ## Installation
25 |
26 | Note: You will need a GPU to install `mistral-inference`, as it currently requires `xformers`, and `xformers` itself needs a GPU to be installed.
27 |
28 | ### PyPI
29 |
30 | ```
31 | pip install mistral-inference
32 | ```
33 |
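If the installation succeeded, a quick sanity check (a minimal sketch, not an official command) is to import the package and print its version. The top-level `mistral_inference` package only exposes `__version__` (see `src/mistral_inference/__init__.py`), so this does not require a GPU:

```py
import mistral_inference

# The version string is defined in src/mistral_inference/__init__.py, e.g. "1.6.0".
print(mistral_inference.__version__)
```
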
34 | ### Local
35 |
36 | ```
37 | cd $HOME && git clone https://github.com/mistralai/mistral-inference
38 | cd $HOME/mistral-inference && poetry install
39 | ```
40 |
41 | ## Model download
42 |
43 | ### Direct links
44 |
45 | | Name | Download | md5sum |
46 | |-------------|-------|-------|
47 | | 7B Instruct | https://models.mistralcdn.com/mistral-7b-v0-3/mistral-7B-Instruct-v0.3.tar | `80b71fcb6416085bcb4efad86dfb4d52` |
48 | | 8x7B Instruct | https://models.mistralcdn.com/mixtral-8x7b-v0-1/Mixtral-8x7B-v0.1-Instruct.tar (**Updated model coming soon!**) | `8e2d3930145dc43d3084396f49d38a3f` |
49 | | 8x22B Instruct | https://models.mistralcdn.com/mixtral-8x22b-v0-3/mixtral-8x22B-Instruct-v0.3.tar | `471a02a6902706a2f1e44a693813855b` |
50 | | 7B Base | https://models.mistralcdn.com/mistral-7b-v0-3/mistral-7B-v0.3.tar | `0663b293810d7571dad25dae2f2a5806` |
51 | | 8x7B | **Updated model coming soon!** | - |
52 | | 8x22B | https://models.mistralcdn.com/mixtral-8x22b-v0-3/mixtral-8x22B-v0.3.tar | `a2fa75117174f87d1197e3a4eb50371a` |
53 | | Codestral 22B | https://models.mistralcdn.com/codestral-22b-v0-1/codestral-22B-v0.1.tar | `1ea95d474a1d374b1d1b20a8e0159de3` |
54 | | Mathstral 7B | https://models.mistralcdn.com/mathstral-7b-v0-1/mathstral-7B-v0.1.tar | `5f05443e94489c261462794b1016f10b` |
55 | | Codestral-Mamba 7B | https://models.mistralcdn.com/codestral-mamba-7b-v0-1/codestral-mamba-7B-v0.1.tar | `d3993e4024d1395910c55db0d11db163` |
56 | | Nemo Base | https://models.mistralcdn.com/mistral-nemo-2407/mistral-nemo-base-2407.tar | `c5d079ac4b55fc1ae35f51f0a3c0eb83` |
57 | | Nemo Instruct | https://models.mistralcdn.com/mistral-nemo-2407/mistral-nemo-instruct-2407.tar | `296fbdf911cb88e6f0be74cd04827fe7` |
58 | | Mistral Large 2 | https://models.mistralcdn.com/mistral-large-2407/mistral-large-instruct-2407.tar | `fc602155f9e39151fba81fcaab2fa7c4` |
59 |
60 | Note:
61 | - **Important**:
62 | - `mixtral-8x22B-Instruct-v0.3.tar` is exactly the same as [Mixtral-8x22B-Instruct-v0.1](https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1), only stored in `.safetensors` format
63 | - `mixtral-8x22B-v0.3.tar` is the same as [Mixtral-8x22B-v0.1](https://huggingface.co/mistralai/Mixtral-8x22B-v0.1), but has an extended vocabulary of 32768 tokens.
64 | - `codestral-22B-v0.1.tar` has a custom non-commercial license, called [Mistral AI Non-Production (MNPL) License](https://mistral.ai/licenses/MNPL-0.1.md)
65 | - `mistral-large-instruct-2407.tar` has a custom non-commercial license, called [Mistral AI Research (MRL) License](https://mistral.ai/licenses/MRL-0.1.md)
66 | - All of the models listed above support function calling. For example, Mistral 7B Base/Instruct v3 is a minor update to Mistral 7B Base/Instruct v2, with the addition of function calling capabilities.
67 | - The "coming soon" models will include function calling as well.
68 | - You can download the previous versions of our models from our [docs](https://docs.mistral.ai/getting-started/open_weight_models/#downloading).
69 |
70 | ### From Hugging Face Hub
71 |
72 | | Name | ID | URL |
73 | |-------------|-------|-------|
74 | | Pixtral Large Instruct | mistralai/Pixtral-Large-Instruct-2411 | https://huggingface.co/mistralai/Pixtral-Large-Instruct-2411 |
75 | | Pixtral 12B Base | mistralai/Pixtral-12B-Base-2409 | https://huggingface.co/mistralai/Pixtral-12B-Base-2409 |
76 | | Pixtral 12B | mistralai/Pixtral-12B-2409 | https://huggingface.co/mistralai/Pixtral-12B-2409 |
77 | | Mistral Small 3.1 24B Base | mistralai/Mistral-Small-3.1-24B-Base-2503 | https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Base-2503 |
78 | | Mistral Small 3.1 24B Instruct | mistralai/Mistral-Small-3.1-24B-Instruct-2503 | https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503 |
79 |
80 |
81 | ### Usage
82 |
83 | **News!!!**: Mistral Large 2 is out. Read more about its capabilities [here](https://mistral.ai/news/mistral-large-2407/).
84 |
85 | Create a local folder to store models
86 | ```sh
87 | export MISTRAL_MODEL=$HOME/mistral_models
88 | mkdir -p $MISTRAL_MODEL
89 | ```
90 |
91 | Download the model from any of the links above and extract the content, *e.g.*:
92 |
93 | ```sh
94 | export M12B_DIR=$MISTRAL_MODEL/12B_Nemo
95 | wget https://models.mistralcdn.com/mistral-nemo-2407/mistral-nemo-instruct-2407.tar
96 | mkdir -p $M12B_DIR
97 | tar -xf mistral-nemo-instruct-2407.tar -C $M12B_DIR
98 | ```
99 |
100 | or
101 |
102 | ```sh
103 | export M8x7B_DIR=$MISTRAL_MODEL/8x7b_instruct
104 | wget https://models.mistralcdn.com/mixtral-8x7b-v0-1/Mixtral-8x7B-v0.1-Instruct.tar
105 | mkdir -p $M8x7B_DIR
106 | tar -xf Mixtral-8x7B-v0.1-Instruct.tar -C $M8x7B_DIR
107 | ```
108 |
109 | For model weights hosted on the Hugging Face Hub, here is an example of how to download [Mistral Small 3.1 24B Instruct](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503):
110 |
111 | ```python
112 | from pathlib import Path
113 | from huggingface_hub import snapshot_download
114 |
115 |
116 | mistral_models_path = Path.home().joinpath("mistral_models")
117 |
118 | model_path = mistral_models_path / "mistral-small-3.1-instruct"
119 | model_path.mkdir(parents=True, exist_ok=True)
120 |
121 | repo_id = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
122 |
123 | snapshot_download(
124 | repo_id=repo_id,
125 | allow_patterns=["params.json", "consolidated.safetensors", "tekken.json"],
126 | local_dir=model_path,
127 | )
128 | ```
129 |
130 | ## Usage
131 |
132 | The following sections give an overview of how to run the model from the command-line interface (CLI) or directly within Python.
133 |
134 | ### CLI
135 |
136 | - **Demo**
137 |
138 | To test that a model works in your setup, you can run the `mistral-demo` command.
139 | *E.g.* the 12B Mistral-Nemo model can be tested on a single GPU as follows:
140 |
141 | ```sh
142 | mistral-demo $M12B_DIR
143 | ```
144 |
145 | Large models, such as **8x7B** and **8x22B**, have to be run in a multi-GPU setup.
146 | For these models, you can use the following command:
147 |
148 | ```sh
149 | torchrun --nproc-per-node 2 --no-python mistral-demo $M8x7B_DIR
150 | ```
151 |
152 | *Note*: Change `--nproc-per-node` to more GPUs if available.
153 |
154 | - **Chat**
155 |
156 | To interactively chat with the models, you can make use of the `mistral-chat` command.
157 |
158 | ```sh
159 | mistral-chat $M12B_DIR --instruct --max_tokens 1024 --temperature 0.35
160 | ```
161 |
162 | For large models, you can make use of `torchrun`.
163 |
164 | ```sh
165 | torchrun --nproc-per-node 2 --no-python mistral-chat $M8x7B_DIR --instruct
166 | ```
167 |
168 | *Note*: Change `--nproc-per-node` to more GPUs if necessary (*e.g.* for 8x22B).
169 |
170 | - **Chat with Codestral**
171 |
172 | To use [Codestral](https://mistral.ai/news/codestral/) as a coding assistant you can run the following command using `mistral-chat`.
173 | Make sure `$M22B_CODESTRAL` is set to a valid path to the downloaded codestral folder, e.g. `$HOME/mistral_models/Codestral-22B-v0.1`
174 |
175 | ```sh
176 | mistral-chat $M22B_CODESTRAL --instruct --max_tokens 256
177 | ```
178 |
179 | If you prompt it with *"Write me a function that computes fibonacci in Rust"*, the model should generate something along the following lines:
180 |
181 | ```sh
182 | Sure, here's a simple implementation of a function that computes the Fibonacci sequence in Rust. This function takes an integer `n` as an argument and returns the `n`th Fibonacci number.
183 |
184 | fn fibonacci(n: u32) -> u32 {
185 | match n {
186 | 0 => 0,
187 | 1 => 1,
188 | _ => fibonacci(n - 1) + fibonacci(n - 2),
189 | }
190 | }
191 |
192 | fn main() {
193 | let n = 10;
194 | println!("The {}th Fibonacci number is: {}", n, fibonacci(n));
195 | }
196 |
197 | This function uses recursion to calculate the Fibonacci number. However, it's not the most efficient solution because it performs a lot of redundant calculations. A more efficient solution would use a loop to iteratively calculate the Fibonacci numbers.
198 | ```
199 |
200 | You can continue chatting afterwards, *e.g.* with *"Translate it to Python"*.
201 |
202 | - **Chat with Codestral-Mamba**
203 |
204 | To use [Codestral-Mamba](https://mistral.ai/news/codestral-mamba/) as a coding assistant you can run the following command using `mistral-chat`.
205 | Make sure `$M7B_CODESTRAL_MAMBA` is set to a valid path to the downloaded codestral-mamba folder, e.g. `$HOME/mistral_models/mamba-codestral-7B-v0.1`.
206 |
207 | You then need to additionally install the following packages:
208 |
209 | ```
210 | pip install packaging mamba-ssm causal-conv1d transformers
211 | ```
212 |
213 | before you can start chatting:
214 |
215 | ```sh
216 | mistral-chat $M7B_CODESTRAL_MAMBA --instruct --max_tokens 256
217 | ```
218 |
219 | - **Chat with Mathstral**
220 |
221 | To use [Mathstral](https://mistral.ai/news/mathstral/) as an assistant you can run the following command using `mistral-chat`.
222 | Make sure `$M7B_MATHSTRAL` is set to a valid path to the downloaded mathstral folder, e.g. `$HOME/mistral_models/mathstral-7B-v0.1`.
223 |
224 | ```sh
225 | mistral-chat $M7B_MATHSTRAL --instruct --max_tokens 256
226 | ```
227 |
228 | If you prompt it with *"Albert likes to surf every week. Each surfing session lasts for 4 hours and costs $20 per hour. How much would Albert spend in 5 weeks?"*, the model should answer with the correct calculation.
229 |
230 | You can then continue chatting afterwards, *e.g.* with *"How much would he spend in a year?"*.
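(For reference, the expected arithmetic for the first prompt is 4 hours × $20/hour × 5 weeks = $400, and 4 × $20 × 52 = $4,160 for a full year, so you can sanity-check the model's answers.)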
231 |
232 | - **Chat with Mistral Small 3.1 24B Instruct**
233 |
234 | To use [Mistral Small 3.1 24B Instruct](https://mistral.ai/news/mistral-small-3-1/) as an assistant you can run the following command using `mistral-chat`.
235 | Make sure `$MISTRAL_SMALL_3_1_INSTRUCT` is set to a valid path to the downloaded mistral small folder, e.g. `$HOME/mistral_models/mistral-small-3.1-instruct`
236 |
237 | ```sh
238 | mistral-chat $MISTRAL_SMALL_3_1_INSTRUCT --instruct --max_tokens 256
239 | ```
240 |
241 | If you prompt it with *"The above image presents an image of which park ? Please give the hints to identify the park."* together with the image URL *https://huggingface.co/datasets/patrickvonplaten/random_img/resolve/main/yosemite.png*, the model should answer that the park is Yosemite National Park and give hints to identify it.
242 |
243 | You can then continue chatting afterwards, *e.g.* with *"What is the name of the lake in the image?"*. The model should respond that it is not a lake but a river.
244 |
245 | ### Python
246 |
247 | - *Instruction Following*:
248 |
249 | ```py
250 | from mistral_inference.transformer import Transformer
251 | from mistral_inference.generate import generate
252 |
253 | from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
254 | from mistral_common.protocol.instruct.messages import UserMessage
255 | from mistral_common.protocol.instruct.request import ChatCompletionRequest
256 |
257 |
258 | tokenizer = MistralTokenizer.from_file("./mistral-nemo-instruct-v0.1/tekken.json") # change to extracted tokenizer file
259 | model = Transformer.from_folder("./mistral-nemo-instruct-v0.1") # change to extracted model dir
260 |
261 | prompt = "How expensive would it be to ask a window cleaner to clean all windows in Paris. Make a reasonable guess in US Dollar."
262 |
263 | completion_request = ChatCompletionRequest(messages=[UserMessage(content=prompt)])
264 |
265 | tokens = tokenizer.encode_chat_completion(completion_request).tokens
266 |
267 | out_tokens, _ = generate([tokens], model, max_tokens=1024, temperature=0.35, eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id)
268 | result = tokenizer.instruct_tokenizer.tokenizer.decode(out_tokens[0])
269 |
270 | print(result)
271 | ```
272 |
273 | - *Multimodal Instruction Following*:
274 |
275 |
276 | ```python
277 | from pathlib import Path
278 |
279 | from huggingface_hub import snapshot_download
280 | from mistral_common.protocol.instruct.messages import ImageURLChunk, TextChunk
281 | from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
282 | from mistral_inference.generate import generate
283 | from mistral_inference.transformer import Transformer
284 |
285 | model_path = Path.home().joinpath("mistral_models") / "mistral-small-3.1-instruct" # change to extracted model
286 |
287 | tokenizer = MistralTokenizer.from_file(model_path / "tekken.json")
288 | model = Transformer.from_folder(model_path)
289 |
290 | url = "https://huggingface.co/datasets/patrickvonplaten/random_img/resolve/main/yosemite.png"
291 | prompt = "The above image presents an image of which park ? Please give the hints to identify the park."
292 |
293 | user_content = [ImageURLChunk(image_url=url), TextChunk(text=prompt)]
294 |
295 | tokens, images = tokenizer.instruct_tokenizer.encode_user_content(user_content, False)
296 |
297 | out_tokens, _ = generate(
298 | [tokens],
299 | model,
300 | images=[images],
301 | max_tokens=256,
302 | temperature=0.15,
303 | eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id,
304 | )
305 | result = tokenizer.decode(out_tokens[0])
306 |
307 | print("Prompt:", prompt)
308 | print("Completion:", result)
309 | ```
310 |
311 | - *Function Calling*:
312 |
313 | ```py
314 | from mistral_common.protocol.instruct.tool_calls import Function, Tool
315 |
316 | completion_request = ChatCompletionRequest(
317 | tools=[
318 | Tool(
319 | function=Function(
320 | name="get_current_weather",
321 | description="Get the current weather",
322 | parameters={
323 | "type": "object",
324 | "properties": {
325 | "location": {
326 | "type": "string",
327 | "description": "The city and state, e.g. San Francisco, CA",
328 | },
329 | "format": {
330 | "type": "string",
331 | "enum": ["celsius", "fahrenheit"],
332 | "description": "The temperature unit to use. Infer this from the users location.",
333 | },
334 | },
335 | "required": ["location", "format"],
336 | },
337 | )
338 | )
339 | ],
340 | messages=[
341 | UserMessage(content="What's the weather like today in Paris?"),
342 | ],
343 | )
344 |
345 | tokens = tokenizer.encode_chat_completion(completion_request).tokens
346 |
347 | out_tokens, _ = generate([tokens], model, max_tokens=64, temperature=0.0, eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id)
348 | result = tokenizer.instruct_tokenizer.tokenizer.decode(out_tokens[0])
349 |
350 | print(result)
351 | ```
352 |
353 | - *Fill-in-the-middle (FIM)*:
354 |
355 | Make sure to have `mistral-common >= 1.2.0` installed:
356 | ```
357 | pip install --upgrade mistral-common
358 | ```
359 |
360 | You can simulate a code completion in-filling as follows.
361 |
362 | ```py
363 | from mistral_inference.transformer import Transformer
364 | from mistral_inference.generate import generate
365 | from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
366 | from mistral_common.tokens.instruct.request import FIMRequest
367 |
368 | tokenizer = MistralTokenizer.from_model("codestral-22b")
369 | model = Transformer.from_folder("./mistral_22b_codestral")
370 |
371 | prefix = """def add("""
372 | suffix = """ return sum"""
373 |
374 | request = FIMRequest(prompt=prefix, suffix=suffix)
375 |
376 | tokens = tokenizer.encode_fim(request).tokens
377 |
378 | out_tokens, _ = generate([tokens], model, max_tokens=256, temperature=0.0, eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id)
379 | result = tokenizer.decode(out_tokens[0])
380 |
381 | middle = result.split(suffix)[0].strip()
382 | print(middle)
383 | ```
384 |
385 | ### Test
386 |
387 | To run the logits equivalence tests:
388 | ```
389 | python -m pytest tests
390 | ```
391 |
392 | ## Deployment
393 |
394 | The `deploy` folder contains code to build a [vLLM](https://github.com/vllm-project/vllm) image with the required dependencies to serve Mistral AI models. In the image, the [transformers](https://github.com/huggingface/transformers/) library is used instead of the reference implementation. To build it:
395 |
396 | ```bash
397 | docker build deploy --build-arg MAX_JOBS=8
398 | ```
399 |
400 | Instructions to run the image can be found in the [official documentation](https://docs.mistral.ai/quickstart).
401 |
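The container's entrypoint forwards its arguments to vLLM's OpenAI-compatible server (see `deploy/entrypoint.sh`), so once it is running you can query it with any OpenAI-compatible client. Below is a minimal sketch using `requests`; the host, port (vLLM defaults to `8000`), and model name are assumptions to adapt to your deployment:

```py
import requests

# Assumed endpoint: vLLM's OpenAI-compatible chat completions route on localhost:8000.
response = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "mistralai/Mistral-7B-Instruct-v0.3",  # placeholder: whichever model the server was started with
        "messages": [{"role": "user", "content": "Hello, who are you?"}],
        "max_tokens": 64,
    },
)
print(response.json()["choices"][0]["message"]["content"])
```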
402 |
403 | ## Model platforms
404 |
405 | - Use Mistral models on [Mistral AI official API](https://console.mistral.ai/) (La Plateforme)
406 | - Use Mistral models via [cloud providers](https://docs.mistral.ai/deployment/cloud/overview/)
407 |
408 | ## References
409 |
410 | [1]: [LoRA](https://arxiv.org/abs/2106.09685): Low-Rank Adaptation of Large Language Models, Hu et al. 2021
411 |
--------------------------------------------------------------------------------
/assets/smoe.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mistralai/mistral-inference/6eb35510403825cfb430b0004443053e8c4b70dc/assets/smoe.png
--------------------------------------------------------------------------------
/deploy/.dockerignore:
--------------------------------------------------------------------------------
1 | *
2 | !entrypoint.sh
3 |
--------------------------------------------------------------------------------
/deploy/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM --platform=amd64 nvcr.io/nvidia/cuda:12.1.0-devel-ubuntu22.04 as base
2 |
3 | WORKDIR /workspace
4 |
5 | RUN apt update && \
6 | apt install -y python3-pip python3-packaging \
7 | git ninja-build && \
8 | pip3 install -U pip
9 |
10 | # Tweak this list to reduce build time
11 | # https://developer.nvidia.com/cuda-gpus
12 | ENV TORCH_CUDA_ARCH_LIST="7.0;7.2;7.5;8.0;8.6;8.9;9.0"
13 |
14 | RUN pip3 install "torch==2.1.1"
15 |
16 | # This build is slow but NVIDIA does not provide binaries. Increase MAX_JOBS as needed.
17 | RUN pip3 install "git+https://github.com/stanford-futuredata/megablocks.git"
18 | RUN pip3 install "git+https://github.com/vllm-project/vllm.git"
19 | RUN pip3 install "xformers==0.0.23" "transformers==4.36.0" "fschat[model_worker]==0.2.34"
20 |
21 | RUN git clone https://github.com/NVIDIA/apex && \
22 | cd apex && git checkout 2386a912164b0c5cfcd8be7a2b890fbac5607c82 && \
23 | sed -i '/check_cuda_torch_binary_vs_bare_metal(CUDA_HOME)/d' setup.py && \
24 | python3 setup.py install --cpp_ext --cuda_ext
25 |
26 |
27 | COPY entrypoint.sh .
28 |
29 | RUN chmod +x /workspace/entrypoint.sh
30 |
31 | ENTRYPOINT ["/workspace/entrypoint.sh"]
--------------------------------------------------------------------------------
/deploy/entrypoint.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | if [[ ! -z "${HF_TOKEN}" ]]; then
4 | echo "The HF_TOKEN environment variable is set, logging to Hugging Face."
5 | python3 -c "import huggingface_hub; huggingface_hub.login('${HF_TOKEN}')"
6 | else
7 | echo "The HF_TOKEN environment variable is not set or empty, not logging to Hugging Face."
8 | fi
9 |
10 | # Start the vLLM OpenAI-compatible API server, forwarding any provided arguments
11 | exec python3 -u -m vllm.entrypoints.openai.api_server "$@"
12 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "mistral_inference"
3 | version = "1.6.0"
4 | description = ""
5 | authors = ["bam4d "]
6 | readme = "README.md"
7 | packages = [{ include = "mistral_inference", from = "src" }]
8 |
9 | [tool.ruff]
10 | lint.select = ["E", "F", "W", "Q", "I"]
11 | lint.ignore = ["E203"]
12 | lint.fixable = ["ALL"]
13 | lint.unfixable = []
14 | line-length = 120
15 | exclude = ["docs", "build", "tutorials"]
16 |
17 | [tool.mypy]
18 | disallow_untyped_defs = true
19 | show_error_codes = true
20 | no_implicit_optional = true
21 | warn_return_any = true
22 | warn_unused_ignores = true
23 | exclude = ["docs", "tools", "build"]
24 |
25 | [tool.poetry.dependencies]
26 | python = "^3.9.10"
27 | xformers = ">=0.0.24"
28 | simple-parsing = ">=0.1.5"
29 | fire = ">=0.6.0"
30 | mistral_common = ">=1.5.4"
31 | safetensors = ">=0.4.0"
32 | pillow = ">=10.3.0"
33 |
34 | [tool.poetry.group.dev.dependencies]
35 | types-protobuf = "4.24.0.20240129"
36 | mypy-protobuf = "^3.5.0"
37 | pytest = "7.4.4"
38 | ruff = "^0.2.2"
39 | mypy = "^1.8.0"
40 |
41 | [build-system]
42 | requires = ["poetry-core>=1.0.0"]
43 | build-backend = "poetry.core.masonry.api"
44 |
45 | [tool.pytest.ini_options]
46 | testpaths = ["./tests"]
47 |
48 | [tool.poetry.scripts]
49 | mistral-chat = "mistral_inference.main:mistral_chat"
50 | mistral-demo = "mistral_inference.main:mistral_demo"
51 |
--------------------------------------------------------------------------------
/src/mistral_inference/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "1.6.0"
2 |
--------------------------------------------------------------------------------
/src/mistral_inference/args.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import List, Optional
3 |
4 | from simple_parsing.helpers import Serializable
5 |
6 | from mistral_inference.lora import LoraArgs
7 | from mistral_inference.moe import MoeArgs
8 |
9 | PATCH_MERGE = "patch_merge"
10 |
11 |
12 | @dataclass
13 | class VisionEncoderArgs:
14 | hidden_size: int
15 | num_channels: int
16 | image_size: int
17 | patch_size: int
18 | intermediate_size: int
19 | num_hidden_layers: int
20 | num_attention_heads: int
21 | rope_theta: float = 1e4 # for rope-2D
22 | image_token_id: int = 10
23 | adapter_bias: bool = True
24 | spatial_merge_size: int = 1
25 | add_pre_mm_projector_layer_norm: bool = False
26 | mm_projector_id: str = ""
27 |
28 |
29 | @dataclass
30 | class TransformerArgs(Serializable):
31 | dim: int
32 | n_layers: int
33 | head_dim: int
34 | hidden_dim: int
35 | n_heads: int
36 | n_kv_heads: int
37 | norm_eps: float
38 | vocab_size: int
39 |
40 | max_batch_size: int = 0
41 |
42 | # For rotary embeddings. If not set, will be inferred
43 | rope_theta: Optional[float] = None
44 | # If this is set, we will use MoE layers instead of dense layers.
45 | moe: Optional[MoeArgs] = None
46 | # If this is set, we will load LoRA linear layers instead of linear layers.
47 | lora: Optional[LoraArgs] = None
48 | sliding_window: Optional[int] | Optional[List[int]] = None
49 | _sliding_window: Optional[int] | Optional[List[int]] = None
50 | model_type: str = "transformer"
51 |
52 | vision_encoder: Optional[VisionEncoderArgs] = None
53 |
54 | def __post_init__(self) -> None:
55 | assert self.model_type == "transformer", self.model_type
56 | assert self.sliding_window is None or self._sliding_window is None
57 |
58 | # hack for now so that vLLM is supported correctly
59 | self.sliding_window = self.sliding_window if self.sliding_window is not None else self._sliding_window
60 |
61 |
62 | @dataclass
63 | class MambaArgs(Serializable):
64 | dim: int
65 | n_layers: int
66 | vocab_size: int
67 | n_groups: int
68 | rms_norm: bool
69 | residual_in_fp32: bool
70 | fused_add_norm: bool
71 | pad_vocab_size_multiple: int
72 | tie_embeddings: bool
73 | model_type: str = "mamba"
74 |
75 | def __post_init__(self) -> None:
76 | assert self.model_type == "mamba", self.model_type
77 |
--------------------------------------------------------------------------------
/src/mistral_inference/cache.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import List, Optional, Tuple
3 |
4 | import torch
5 | from xformers.ops.fmha.attn_bias import ( # type: ignore
6 | AttentionBias,
7 | BlockDiagonalCausalMask,
8 | BlockDiagonalCausalWithOffsetPaddedKeysMask,
9 | BlockDiagonalMask,
10 | )
11 |
12 |
13 | def get_cache_sizes(n_layers: int, max_seq_len: int, sliding_window: Optional[int] | Optional[List[int]]) -> List[int]:
14 | if sliding_window is None:
15 | return n_layers * [max_seq_len]
16 | elif isinstance(sliding_window, int):
17 | return n_layers * [sliding_window]
18 | else:
19 | assert isinstance(sliding_window, list), f"Expected list, got {type(sliding_window)}"
20 | assert (
21 | n_layers % len(sliding_window) == 0
22 | ), f"Expected n_layers % len(sliding_window) == 0, got {n_layers} % {len(sliding_window)}"
23 | num_repeats = n_layers // len(sliding_window)
24 | return num_repeats * [w if w is not None else max_seq_len for w in sliding_window]
25 |
26 |
27 | @dataclass
28 | class CacheInputMetadata:
29 | # # rope absolute positions
30 | # positions: torch.Tensor
31 | # # where tokens should go in the cache
32 | # cache_positions: torch.Tensor
33 |
34 | # # if prefill, use block diagonal causal mask
35 | # # else use causal with padded key mask
36 | # prefill: bool
37 | # mask: AttentionBias
38 | # seqlens: List[int]
39 | # rope absolute positions
40 | positions: torch.Tensor
41 | # which elements in the sequences need to be cached
42 | to_cache_mask: torch.Tensor
43 | # how many elements are cached per sequence
44 | cached_elements: torch.Tensor
45 | # where tokens should go in the cache
46 | cache_positions: torch.Tensor
47 | # if prefill, use block diagonal causal mask
48 | # else use causal with padded key mask
49 | prefill: bool
50 | mask: AttentionBias
51 | seqlens: List[int]
52 |
53 |
54 | def interleave_list(l1: List[torch.Tensor], l2: List[torch.Tensor]) -> List[torch.Tensor]:
55 | assert len(l1) == len(l2)
56 | return [v for pair in zip(l1, l2) for v in pair]
57 |
58 |
59 | def unrotate(cache: torch.Tensor, seqlen: int) -> torch.Tensor:
60 | assert cache.ndim == 3 # (W, H, D)
61 | position = seqlen % cache.shape[0]
62 | if seqlen < cache.shape[0]:
63 | return cache[:seqlen]
64 | elif position == 0:
65 | return cache
66 | else:
67 | return torch.cat([cache[position:], cache[:position]], dim=0)
68 |
69 |
70 | class CacheView:
71 | def __init__(
72 | self,
73 | cache_k: torch.Tensor,
74 | cache_v: torch.Tensor,
75 | metadata: CacheInputMetadata,
76 | kv_seqlens: torch.Tensor,
77 | ):
78 | self.cache_k = cache_k
79 | self.cache_v = cache_v
80 | self.kv_seqlens = kv_seqlens
81 | self.metadata = metadata
82 |
83 | def update(self, xk: torch.Tensor, xv: torch.Tensor) -> None:
84 | """
85 | to_cache_mask masks the last [max_seq_len] tokens in each sequence
86 | """
87 | n_kv_heads, head_dim = self.cache_k.shape[-2:]
88 | flat_cache_k = self.cache_k.view(-1, n_kv_heads, head_dim)
89 | flat_cache_v = self.cache_v.view(-1, n_kv_heads, head_dim)
90 |
91 | flat_cache_k.index_copy_(0, self.metadata.cache_positions, xk[self.metadata.to_cache_mask])
92 | flat_cache_v.index_copy_(0, self.metadata.cache_positions, xv[self.metadata.to_cache_mask])
93 |
94 | def interleave_kv(self, xk: torch.Tensor, xv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
95 | """
96 | This is a naive implementation and not optimized for speed.
97 | """
98 | assert xk.ndim == xv.ndim == 3 # (B * T, H, D)
99 | assert xk.shape == xv.shape
100 |
101 | if all([s == 0 for s in self.metadata.seqlens]):
102 | # No cache to interleave
103 | return xk, xv
104 |
105 | # Make it a list of [(T, H, D)]
106 | xk: Tuple[torch.Tensor] = torch.split(xk, self.metadata.seqlens) # type: ignore
107 | xv: Tuple[torch.Tensor] = torch.split(xv, self.metadata.seqlens) # type: ignore
108 | assert len(xk) == len(self.kv_seqlens), f"Batch size is {len(self.kv_seqlens)}, got {len(xk)}"
109 |
110 | # Order elements in cache by position by unrotating
111 | cache_k = [unrotate(t, s) for t, s in zip(self.cache_k, self.kv_seqlens)]
112 | cache_v = [unrotate(t, s) for t, s in zip(self.cache_v, self.kv_seqlens)]
113 |
114 | interleaved_k = interleave_list(cache_k, list(xk))
115 | interleaved_v = interleave_list(cache_v, list(xv))
116 |
117 | return torch.cat(interleaved_k, dim=0), torch.cat(interleaved_v, dim=0)
118 |
119 | @property
120 | def max_seq_len(self) -> int:
121 | return self.cache_k.shape[1]
122 |
123 | @property
124 | def key(self) -> torch.Tensor:
125 | return self.cache_k[: len(self.kv_seqlens)]
126 |
127 | @property
128 | def value(self) -> torch.Tensor:
129 | return self.cache_v[: len(self.kv_seqlens)]
130 |
131 | @property
132 | def prefill(self) -> bool:
133 | return self.metadata.prefill
134 |
135 | @property
136 | def mask(self) -> AttentionBias:
137 | return self.metadata.mask
138 |
139 |
140 | class BufferCache:
141 | """
142 | This is an example that implements a buffer cache, allowing for variable length sequences.
143 | Allocated cache is rectangular which is wasteful (see PagedAttention for better mechanisms)
144 | """
145 |
146 | def __init__(
147 | self,
148 | n_layers: int,
149 | max_batch_size: int,
150 | max_seq_len: int,
151 | n_kv_heads: int,
152 | head_dim: int,
153 | sliding_window: Optional[int] | Optional[List[int]] = None,
154 | ):
155 | self.max_seq_len = max_seq_len
156 | self.n_kv_heads = n_kv_heads
157 | self.head_dim = head_dim
158 | self.n_layers = n_layers
159 |
160 | self.cache_sizes: List[int] = get_cache_sizes(n_layers, max_seq_len, sliding_window)
161 | assert len(self.cache_sizes) == n_layers, f"Expected {n_layers} cache sizes, got {len(self.cache_sizes)}"
162 |
163 | self.cache_k = {}
164 | self.cache_v = {}
165 | for i, cache_size in enumerate(self.cache_sizes):
166 | self.cache_k[i] = torch.empty((max_batch_size, cache_size, n_kv_heads, head_dim))
167 | self.cache_v[i] = torch.empty((max_batch_size, cache_size, n_kv_heads, head_dim))
168 |
169 | # holds the valid length for each batch element in the cache
170 | self.kv_seqlens: Optional[torch.Tensor] = None
171 |
172 | def get_view(self, layer_id: int, metadata: CacheInputMetadata) -> CacheView:
173 | assert self.kv_seqlens is not None
174 | return CacheView(self.cache_k[layer_id], self.cache_v[layer_id], metadata, self.kv_seqlens)
175 |
176 | def reset(self) -> None:
177 | self.kv_seqlens = None
178 |
179 | def init_kvseqlens(self, batch_size: int) -> None:
180 | self.kv_seqlens = torch.zeros((batch_size,), device=self.device, dtype=torch.long)
181 |
182 | @property
183 | def device(self) -> torch.device:
184 | return self.cache_k[0].device
185 |
186 | def to(self, device: torch.device, dtype: torch.dtype) -> "BufferCache":
187 | for i in range(self.n_layers):
188 | self.cache_k[i] = self.cache_k[i].to(device=device, dtype=dtype)
189 | self.cache_v[i] = self.cache_v[i].to(device=device, dtype=dtype)
190 |
191 | return self
192 |
193 | def update_seqlens(self, seqlens: List[int]) -> None:
194 | assert self.kv_seqlens is not None
195 | self.kv_seqlens += torch.tensor(seqlens, device=self.device, dtype=torch.long)
196 |
197 | def get_input_metadata(self, seqlens: List[int]) -> List[CacheInputMetadata]:
198 | """
199 | input = seqlens [5,7,2] // seqpos [0, 1, 3] // sliding_window 3
200 | --> only cache last 3 tokens in each sequence
201 | - to_cache_mask = [0 0 1 1 1 | 0 0 0 0 1 1 1 | 1 1]
202 | - cached_elements = [3 | 3 | 2]
203 | --> absolute positions are used for rope
204 | - positions = [0 1 2 3 4 | 1 2 3 4 5 6 7 | 3 4]
205 | --> cache positions are positions cache_masked, modulo sliding_window + batch_idx * sliding_window
206 | - cache_positions = [2 0 1 | 5 3 4 | 6 7]
207 | """
208 | metadata: List[CacheInputMetadata] = []
209 |
210 | if self.kv_seqlens is None:
211 | self.init_kvseqlens(len(seqlens))
212 |
213 | assert self.kv_seqlens is not None
214 | assert len(seqlens) == len(
215 | self.kv_seqlens
216 | ), f"Batch size is {len(self.kv_seqlens)}, got {len(seqlens)}, did you forget to reset cache?"
217 | seqpos = self.kv_seqlens.tolist()
218 | assert len(seqlens) > 0, seqlens
219 |
220 | for cache_size in self.cache_sizes:
221 | metadata.append(self._get_input_metadata_layer(cache_size, seqlens, seqpos))
222 |
223 | return metadata
224 |
225 | def _get_input_metadata_layer(self, cache_size: int, seqlens: List[int], seqpos: List[int]) -> CacheInputMetadata:
226 | masks = [[x >= seqlen - cache_size for x in range(seqlen)] for seqlen in seqlens]
227 | to_cache_mask = torch.tensor(sum(masks, []), device=self.device, dtype=torch.bool)
228 | cached_elements = torch.tensor([sum(mask) for mask in masks], device=self.device, dtype=torch.long)
229 | positions = torch.cat([torch.arange(pos, pos + seqlen) for pos, seqlen in zip(seqpos, seqlens)]).to(
230 | device=self.device, dtype=torch.long
231 | )
232 | batch_idx = torch.tensor(
233 | sum([[i] * seqlen for i, seqlen in enumerate(seqlens)], []), device=self.device, dtype=torch.long
234 | )
235 | cache_positions = positions % cache_size + batch_idx * cache_size
236 | first_prefill = seqpos[0] == 0
237 | subsequent_prefill = any(seqlen > 1 for seqlen in seqlens)
238 | if first_prefill:
239 | assert all([pos == 0 for pos in seqpos]), seqpos
240 | mask = BlockDiagonalCausalMask.from_seqlens(seqlens).make_local_attention(cache_size)
241 | elif subsequent_prefill:
242 | assert self.kv_seqlens is not None
243 | mask = BlockDiagonalMask.from_seqlens(
244 | q_seqlen=seqlens,
245 | kv_seqlen=[
246 | s + cached_s.clamp(max=cache_size).item() for (s, cached_s) in zip(seqlens, self.kv_seqlens)
247 | ],
248 | ).make_local_attention_from_bottomright(cache_size)
249 | else:
250 | mask = BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
251 | q_seqlen=seqlens,
252 | kv_padding=cache_size,
253 | kv_seqlen=(self.kv_seqlens + cached_elements).clamp(max=cache_size).tolist(),
254 | )
255 | return CacheInputMetadata(
256 | positions=positions,
257 | to_cache_mask=to_cache_mask,
258 | cached_elements=cached_elements,
259 | cache_positions=cache_positions[to_cache_mask],
260 | prefill=first_prefill or subsequent_prefill,
261 | mask=mask,
262 | seqlens=seqlens,
263 | )
264 |
--------------------------------------------------------------------------------
/src/mistral_inference/generate.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Tuple
2 |
3 | import numpy as np
4 | import torch
5 |
6 | from mistral_inference.cache import BufferCache
7 | from mistral_inference.mamba import Mamba
8 | from mistral_inference.transformer import Transformer
9 |
10 |
11 | @torch.inference_mode()
12 | def generate_mamba(
13 | encoded_prompts: List[List[int]],
14 | model: Mamba,
15 | *,
16 | max_tokens: int,
17 | temperature: float,
18 | chunk_size: Optional[int] = None,
19 | eos_id: Optional[int] = None,
20 | ) -> Tuple[List[List[int]], List[List[float]]]:
21 | input_ids = torch.tensor(encoded_prompts, device=model.device)
22 | output = model.model.generate(
23 | input_ids=input_ids,
24 | max_length=input_ids.shape[-1] + max_tokens,
25 | cg=True,
26 | return_dict_in_generate=True,
27 | output_scores=True,
28 | enable_timing=False,
29 | eos_token_id=eos_id,
30 | temperature=temperature,
31 | top_p=0.8,
32 | )
33 | generated_tokens = output.sequences[:, input_ids.shape[-1] :].tolist()
34 |
35 | _logprobs: List[List[float]] = [[] for _ in range(len(generated_tokens))]
36 | for seq_idx, batch_score in enumerate(output.scores):
37 | for batch_idx, score in enumerate(batch_score.tolist()):
38 | _logprobs[batch_idx].append(score[generated_tokens[batch_idx][seq_idx]])
39 |
40 | return generated_tokens, _logprobs
41 |
42 |
43 | @torch.inference_mode()
44 | def generate(
45 | encoded_prompts: List[List[int]],
46 | model: Transformer,
47 | images: List[List[np.ndarray]] = [],
48 | *,
49 | max_tokens: int,
50 | temperature: float,
51 | chunk_size: Optional[int] = None,
52 | eos_id: Optional[int] = None,
53 | ) -> Tuple[List[List[int]], List[List[float]]]:
54 | images_torch: List[List[torch.Tensor]] = []
55 | if images:
56 | assert chunk_size is None
57 | images_torch = [
58 | [torch.tensor(im, device=model.device, dtype=model.dtype) for im in images_for_sample]
59 | for images_for_sample in images
60 | ]
61 |
62 | model = model.eval()
63 | B, V = len(encoded_prompts), model.args.vocab_size
64 |
65 | seqlens = [len(x) for x in encoded_prompts]
66 |
67 | # Cache
68 | cache_window = max(seqlens) + max_tokens
69 | cache = BufferCache(
70 | model.n_local_layers,
71 | model.args.max_batch_size,
72 | cache_window,
73 | model.args.n_kv_heads,
74 | model.args.head_dim,
75 | model.args.sliding_window,
76 | )
77 | cache.to(device=model.device, dtype=model.dtype)
78 | cache.reset()
79 |
80 | # Bookkeeping
81 | logprobs: List[List[float]] = [[] for _ in range(B)]
82 | last_token_prelogits = None
83 |
84 | # One chunk if size not specified
85 | max_prompt_len = max(seqlens)
86 | if chunk_size is None:
87 | chunk_size = max_prompt_len
88 |
89 | flattened_images: List[torch.Tensor] = sum(images_torch, [])
90 |
91 | # Encode prompt by chunks
92 | for s in range(0, max_prompt_len, chunk_size):
93 | prompt_chunks = [p[s : s + chunk_size] for p in encoded_prompts]
94 | assert all(len(p) > 0 for p in prompt_chunks)
95 | prelogits = model.forward(
96 | torch.tensor(sum(prompt_chunks, []), device=model.device, dtype=torch.long),
97 | images=flattened_images,
98 | seqlens=[len(p) for p in prompt_chunks],
99 | cache=cache,
100 | )
101 | logits = torch.log_softmax(prelogits, dim=-1)
102 |
103 | if last_token_prelogits is not None:
104 |             # Not the first chunk: score this chunk's first token with the previous chunk's last logits
105 | last_token_logits = torch.log_softmax(last_token_prelogits, dim=-1)
106 | for i_seq in range(B):
107 | logprobs[i_seq].append(last_token_logits[i_seq, prompt_chunks[i_seq][0]].item())
108 |
109 | offset = 0
110 | for i_seq, sequence in enumerate(prompt_chunks):
111 | logprobs[i_seq].extend([logits[offset + i, sequence[i + 1]].item() for i in range(len(sequence) - 1)])
112 | offset += len(sequence)
113 |
114 | last_token_prelogits = prelogits.index_select(
115 | 0,
116 | torch.tensor([len(p) for p in prompt_chunks], device=prelogits.device).cumsum(dim=0) - 1,
117 | )
118 | assert last_token_prelogits.shape == (B, V)
119 |
120 | # decode
121 | generated_tensors = []
122 | is_finished = torch.tensor([False for _ in range(B)])
123 |
124 | assert last_token_prelogits is not None
125 | for _ in range(max_tokens):
126 | next_token = sample(last_token_prelogits, temperature=temperature, top_p=0.8)
127 |
128 | if eos_id is not None:
129 | is_finished = is_finished | (next_token == eos_id).cpu()
130 |
131 | if is_finished.all():
132 | break
133 |
134 | last_token_logits = torch.log_softmax(last_token_prelogits, dim=-1)
135 | for i in range(B):
136 | logprobs[i].append(last_token_logits[i, next_token[i]].item())
137 |
138 | generated_tensors.append(next_token[:, None])
139 | last_token_prelogits = model.forward(next_token, seqlens=[1] * B, cache=cache)
140 | assert last_token_prelogits.shape == (B, V)
141 |
142 | generated_tokens: List[List[int]]
143 | if generated_tensors:
144 | generated_tokens = torch.cat(generated_tensors, 1).tolist()
145 | else:
146 | generated_tokens = []
147 |
148 | return generated_tokens, logprobs
149 |
150 |
151 | def sample(logits: torch.Tensor, temperature: float, top_p: float) -> torch.Tensor:
152 | if temperature > 0:
153 | probs = torch.softmax(logits / temperature, dim=-1)
154 | next_token = sample_top_p(probs, top_p)
155 | else:
156 | next_token = torch.argmax(logits, dim=-1).unsqueeze(0)
157 |
158 | return next_token.reshape(-1)
159 |
160 |
161 | def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
162 | assert 0 <= p <= 1
163 |
164 | probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
165 | probs_sum = torch.cumsum(probs_sort, dim=-1)
166 | mask = probs_sum - probs_sort > p
167 | probs_sort[mask] = 0.0
168 | probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
169 | next_token = torch.multinomial(probs_sort, num_samples=1)
170 | return torch.gather(probs_idx, -1, next_token)
171 |
--------------------------------------------------------------------------------
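A small, hedged sketch of the sampling helpers in generate.py above: temperature 0 falls back to greedy argmax, anything above 0 uses nucleus (top-p) sampling over the renormalized head of the distribution. The toy logits are arbitrary; importing this module assumes xformers is installed.

import torch

from mistral_inference.generate import sample

logits = torch.tensor([[2.0, 1.0, 0.1], [0.0, 3.0, 0.0]])  # (batch, vocab)

greedy = sample(logits, temperature=0.0, top_p=0.8)
assert greedy.tolist() == [0, 1]  # argmax per row

torch.manual_seed(0)
stochastic = sample(logits, temperature=0.7, top_p=0.8)
print(stochastic.tolist())  # ids drawn only from the smallest token set covering 80% of the probability mass
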
/src/mistral_inference/lora.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from dataclasses import dataclass
3 | from pathlib import Path
4 | from typing import Any, Dict, NamedTuple, Union
5 |
6 | import safetensors.torch
7 | import torch
8 | import torch.nn as nn
9 | from simple_parsing.helpers import Serializable
10 |
11 |
12 | @dataclass
13 | class LoraArgs(Serializable):
14 | rank: int
15 | scaling: float
16 |
17 | def __post_init__(self) -> None:
18 | assert self.rank > 0
19 | assert self.scaling > 0.0
20 |
21 |
22 | class LoRALinear(nn.Module):
23 | """
24 | Implementation of:
25 | - LoRA: https://arxiv.org/abs/2106.09685
26 |
27 | Notes:
28 | - Freezing is handled at network level, not layer level.
29 | - Scaling factor controls relative importance of LoRA skip
30 | connection versus original frozen weight. General guidance is
31 | to keep it to 2.0 and sweep over learning rate when changing
32 | the rank.
33 | """
34 |
35 | def __init__(
36 | self,
37 | in_features: int,
38 | out_features: int,
39 | rank: int,
40 | scaling: float,
41 | bias: bool = False,
42 | ):
43 | super().__init__()
44 |
45 | self.in_features = in_features
46 | self.out_features = out_features
47 | assert not bias
48 | self.bias = bias
49 | self.rank = rank
50 | self.scaling = scaling
51 |
52 | self.lora_A = nn.Linear(
53 | self.in_features,
54 | self.rank,
55 | bias=self.bias,
56 | )
57 | self.lora_B = nn.Linear(
58 | self.rank,
59 | self.out_features,
60 | bias=self.bias,
61 | )
62 |
63 | self.linear = nn.Linear(self.in_features, self.out_features, bias=self.bias)
64 |
65 | # make sure no LoRA weights are marked as "missing" in load_state_dict
66 | def ignore_missing_keys(m: nn.Module, incompatible_keys: NamedTuple) -> None:
67 | incompatible_keys.missing_keys[:] = [] # type: ignore
68 |
69 | self.register_load_state_dict_post_hook(ignore_missing_keys)
70 |
71 | def forward(self, x: torch.Tensor) -> torch.Tensor:
72 | lora = self.lora_B(self.lora_A(x))
73 | result: torch.Tensor = self.linear(x) + lora * self.scaling
74 | return result
75 |
76 | def _load_from_state_dict(self, state_dict: Dict[str, Any], prefix: str, *args, **kwargs) -> None: # type: ignore[no-untyped-def]
77 | key_name = prefix + "weight"
78 |
79 | # full checkpoint
80 | if key_name in state_dict:
81 | w_ref = state_dict[key_name]
82 |
83 | # load frozen weights
84 | state_dict = {
85 | "linear.weight": w_ref,
86 | "lora_A.weight": torch.zeros_like(self.lora_A.weight, device=w_ref.device, dtype=w_ref.dtype),
87 | "lora_B.weight": torch.zeros_like(self.lora_B.weight, device=w_ref.device, dtype=w_ref.dtype),
88 | }
89 | self.load_state_dict(state_dict, assign=True, strict=True)
90 |
91 |
92 | class LoRALoaderMixin:
93 | def load_lora(self, lora_path: Union[Path, str], scaling: float = 2.0) -> None:
94 | """Loads LoRA checkpoint"""
95 |
96 | lora_path = Path(lora_path)
97 | assert lora_path.is_file(), f"{lora_path} does not exist or is not a file"
98 |
99 | state_dict = safetensors.torch.load_file(lora_path)
100 |
101 | self._load_lora_state_dict(state_dict, scaling=scaling)
102 |
103 | def _load_lora_state_dict(self, lora_state_dict: Dict[str, torch.Tensor], scaling: float = 2.0) -> None:
104 | """Loads LoRA state_dict"""
105 | lora_dtypes = set([p.dtype for p in lora_state_dict.values()])
106 | assert (
107 | len(lora_dtypes) == 1
108 | ), f"LoRA weights have multiple different dtypes {lora_dtypes}. All weights need to have the same dtype"
109 | lora_dtype = lora_dtypes.pop()
110 | assert lora_dtype == self.dtype, f"LoRA weights dtype differs from model's dtype {lora_dtype} != {self.dtype}" # type: ignore[attr-defined]
111 | assert all("lora" in key for key in lora_state_dict.keys())
112 |
113 | # move tensors to device
114 | lora_state_dict = {k: v.to(self.device) for k, v in lora_state_dict.items()} # type: ignore[attr-defined]
115 |
116 | state_dict = self.state_dict() # type: ignore[attr-defined]
117 |
118 | if self.args.lora is None: # type: ignore[attr-defined]
119 | logging.info("Loading and merging LoRA weights...")
120 |
121 | # replace every nn.Linear with a LoRALinear with 'meta' device except the output layer
122 | named_modules = dict(self.named_modules()) # type: ignore[attr-defined]
123 | for name, module in named_modules.items():
124 | if isinstance(module, nn.Linear) and name != "output":
125 | layer_id = name.split(".")[1]
126 | if layer_id not in self.layers: # type: ignore[attr-defined]
127 | logging.debug(
128 | "Skipping parameter %s at pipeline rank %d",
129 | name,
130 | self.pipeline_rank, # type: ignore[attr-defined]
131 | )
132 | elif (name + ".lora_B.weight") in lora_state_dict:
133 | weight = (
134 | module.weight
135 | + (lora_state_dict[name + ".lora_B.weight"] @ lora_state_dict[name + ".lora_A.weight"])
136 | * scaling
137 | )
138 |
139 | state_dict[name + ".weight"] = weight
140 | else:
141 | logging.info("Loading LoRA weights...")
142 | for k, v in lora_state_dict.items():
143 |                 # keep only the LoRA weights that belong to this pipeline rank
144 |
145 | layer_id = k.split(".")[1]
146 | if layer_id in self.layers: # type: ignore[attr-defined]
147 | state_dict[k] = v
148 | else:
149 | logging.debug(
150 | "Skipping parameter %s at pipeline rank %d",
151 | k,
152 | self.pipeline_rank, # type: ignore[attr-defined]
153 | )
154 |
155 | self.load_state_dict(state_dict, strict=True) # type: ignore[attr-defined]
156 |
--------------------------------------------------------------------------------
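A hedged sketch of the LoRALinear forward above: the output is the frozen projection plus a low-rank update scaled by `scaling`, i.e. y = Wx + scaling * B(A(x)). The dimensions below are arbitrary toy values.

import torch

from mistral_inference.lora import LoRALinear

torch.manual_seed(0)
layer = LoRALinear(in_features=8, out_features=4, rank=2, scaling=2.0)

x = torch.randn(3, 8)
manual = layer.linear(x) + 2.0 * layer.lora_B(layer.lora_A(x))
assert torch.allclose(layer(x), manual)  # forward is exactly the frozen path plus the scaled low-rank path
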
/src/mistral_inference/main.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import os
4 | import warnings
5 | from pathlib import Path
6 | from typing import List, Optional, Tuple, Type, Union
7 |
8 | import fire # type: ignore
9 | import torch
10 | import torch.distributed as dist
11 | from mistral_common.protocol.instruct.messages import (
12 | AssistantMessage,
13 | ContentChunk,
14 | ImageChunk,
15 | ImageURLChunk,
16 | TextChunk,
17 | UserMessage,
18 | )
19 | from mistral_common.protocol.instruct.request import ChatCompletionRequest
20 | from mistral_common.tokens.tokenizers.base import Tokenizer
21 | from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
22 | from mistral_common.tokens.tokenizers.sentencepiece import is_sentencepiece
23 | from mistral_common.tokens.tokenizers.tekken import (
24 | SpecialTokenPolicy,
25 | Tekkenizer,
26 | is_tekken,
27 | )
28 | from PIL import Image
29 |
30 | from mistral_inference.args import TransformerArgs
31 | from mistral_inference.generate import generate, generate_mamba
32 | from mistral_inference.mamba import Mamba
33 | from mistral_inference.transformer import Transformer
34 |
35 |
36 | def is_torchrun() -> bool:
37 | required_vars = ["MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE"]
38 | return all(var in os.environ for var in required_vars)
39 |
40 |
41 | def load_tokenizer(model_path: Path) -> MistralTokenizer:
42 | tokenizer = [f for f in os.listdir(model_path) if is_tekken(model_path / f) or is_sentencepiece(model_path / f)]
43 | assert (
44 | len(tokenizer) > 0
45 | ), f"No tokenizer in {model_path}, place a `tokenizer.model.[v1,v2,v3]` or `tekken.json` file in {model_path}."
46 | assert (
47 | len(tokenizer) == 1
48 | ), f"Multiple tokenizers {', '.join(tokenizer)} found in `model_path`, make sure to only have one tokenizer"
49 |
50 | mistral_tokenizer = MistralTokenizer.from_file(str(model_path / tokenizer[0]))
51 |
52 | if isinstance(mistral_tokenizer.instruct_tokenizer.tokenizer, Tekkenizer):
53 | mistral_tokenizer.instruct_tokenizer.tokenizer.special_token_policy = SpecialTokenPolicy.KEEP
54 |
55 | logging.info(f"Loaded tokenizer of type {mistral_tokenizer.instruct_tokenizer.__class__}")
56 |
57 | return mistral_tokenizer
58 |
59 |
60 | def get_model_cls(model_path: str) -> Union[Type[Mamba], Type[Transformer]]:
61 | with open(Path(model_path) / "params.json", "r") as f:
62 | args_dict = json.load(f)
63 |
64 | return {"mamba": Mamba, "transformer": Transformer}[args_dict.get("model_type", "transformer")] # type: ignore[return-value]
65 |
66 |
67 | def pad_and_convert_to_tensor(list_of_lists: List[List[int]], pad_id: int) -> List[List[int]]:
68 | # Determine the length of the longest list
69 | max_len = max(len(lst) for lst in list_of_lists)
70 |
71 | # Left pad each list to the maximum length
72 | padded_lists = [[pad_id] * (max_len - len(lst)) + lst for lst in list_of_lists]
73 |
74 | return padded_lists
75 |
76 |
77 | def _get_multimodal_input() -> Tuple[UserMessage, bool]:
78 | chunks: List[ContentChunk] = []
79 |
80 | response = input("Text prompt: ")
81 | if response:
82 | chunks.append(TextChunk(text=response))
83 |
84 | print("[You can input zero, one or more images now.]")
85 | while True:
86 | did_something = False
87 | response = input("Image path or url [Leave empty and press enter to finish image input]: ")
88 | if response:
89 | if Path(response).is_file():
90 | chunks.append(ImageChunk(image=Image.open(response)))
91 | else:
92 | assert response.startswith("http"), f"{response} does not seem to be a valid url."
93 | chunks.append(ImageURLChunk(image_url=response))
94 | did_something = True
95 |
96 | if not did_something:
97 | break
98 |
99 | return UserMessage(content=chunks), not chunks
100 |
101 |
102 | def interactive(
103 | model_path: str,
104 | max_tokens: int = 35,
105 | temperature: float = 0.7,
106 | num_pipeline_ranks: int = 1,
107 | instruct: bool = False,
108 | lora_path: Optional[str] = None,
109 | ) -> None:
110 | if is_torchrun():
111 | torch.distributed.init_process_group()
112 | torch.cuda.set_device(torch.distributed.get_rank())
113 | should_print = torch.distributed.get_rank() == 0
114 |
115 | num_pipeline_ranks = torch.distributed.get_world_size()
116 | else:
117 | should_print = True
118 | num_pipeline_ranks = 1
119 |
120 | mistral_tokenizer: MistralTokenizer = load_tokenizer(Path(model_path))
121 | tokenizer: Tokenizer = mistral_tokenizer.instruct_tokenizer.tokenizer
122 |
123 | model_cls = get_model_cls(model_path)
124 | model = model_cls.from_folder(Path(model_path), max_batch_size=3, num_pipeline_ranks=num_pipeline_ranks)
125 | is_multimodal = isinstance(model.args, TransformerArgs) and model.args.vision_encoder is not None
126 |
127 | if is_multimodal:
128 | assert instruct, "Multimodal models should only be used in instruct mode"
129 |
130 | # load LoRA
131 | if lora_path is not None:
132 | model.load_lora(Path(lora_path))
133 |
134 | prompt: str = ""
135 | messages: List[UserMessage | AssistantMessage] = []
136 |
137 | while True:
138 | if should_print:
139 | if not is_multimodal:
140 | user_input = input("Prompt: ")
141 |
142 | if instruct:
143 | if is_multimodal:
144 | mm_input, finished = _get_multimodal_input()
145 | if finished:
146 | break
147 | messages += [mm_input]
148 | else:
149 | messages += [UserMessage(content=user_input)]
150 | chat_completion_request = ChatCompletionRequest(messages=messages)
151 |
152 | tokenized = mistral_tokenizer.encode_chat_completion(chat_completion_request)
153 | tokens = tokenized.tokens
154 | images = tokenized.images
155 | else:
156 | prompt += user_input
157 |
158 | tokens = tokenizer.encode(prompt, bos=True, eos=False)
159 | images = []
160 |
161 | length_tensor = torch.tensor([len(tokens)], dtype=torch.int)
162 | else:
163 | length_tensor = torch.tensor([0], dtype=torch.int)
164 | images = []
165 |
166 | if is_torchrun():
167 | dist.broadcast(length_tensor, src=0)
168 |
169 | if not should_print:
170 | tokens = int(length_tensor.item()) * [0]
171 |
172 | generate_fn = generate if isinstance(model, Transformer) else generate_mamba
173 | generated_tokens, _ = generate_fn( # type: ignore[operator]
174 | [tokens],
175 | model,
176 | [images],
177 | max_tokens=max_tokens,
178 | temperature=temperature,
179 | eos_id=tokenizer.eos_id,
180 | )
181 |
182 | answer = tokenizer.decode(generated_tokens[0])
183 |
184 | if should_print:
185 | print(answer)
186 | print("=====================")
187 |
188 | if instruct:
189 | messages += [AssistantMessage(content=answer)]
190 | else:
191 | prompt += answer
192 |
193 |
194 | def demo(
195 | model_path: str,
196 | max_tokens: int = 35,
197 | temperature: float = 0,
198 | lora_path: Optional[str] = None,
199 | ) -> None:
200 | if is_torchrun():
201 | torch.distributed.init_process_group()
202 | torch.cuda.set_device(torch.distributed.get_rank())
203 | should_print = torch.distributed.get_rank() == 0
204 |
205 | num_pipeline_ranks = torch.distributed.get_world_size()
206 | else:
207 | should_print = True
208 | num_pipeline_ranks = 1
209 |
210 | model_cls = get_model_cls(model_path)
211 | model = model_cls.from_folder(Path(model_path), max_batch_size=3, num_pipeline_ranks=num_pipeline_ranks)
212 | # load LoRA
213 | if lora_path is not None:
214 | model.load_lora(Path(lora_path))
215 |
216 | mistral_tokenizer: MistralTokenizer = load_tokenizer(Path(model_path))
217 | tokenizer: Tokenizer = mistral_tokenizer.instruct_tokenizer.tokenizer
218 |
219 | prompts = [
220 | "This is a test",
221 | "This is another great test",
222 | "This is a third test, mistral AI is very good at testing. ",
223 | ]
224 |
225 | encoded_prompts = [tokenizer.encode(prompt, bos=True, eos=False) for prompt in prompts]
226 |
227 | if isinstance(model, Transformer):
228 | generate_fn = generate
229 | else:
230 | generate_fn = generate_mamba # type: ignore[assignment]
231 | warnings.warn(
232 |             "Batched generation is not correctly supported at the moment and may therefore lead to worse results "
233 |             "than non-batched generation. "
234 | "See https://github.com/state-spaces/mamba/issues/66#issuecomment-1862349718 for more information."
235 | )
236 | encoded_prompts = pad_and_convert_to_tensor(encoded_prompts, mistral_tokenizer.instruct_tokenizer.BOS) # type: ignore[attr-defined]
237 |
238 | generated_tokens, _logprobs = generate_fn(
239 | encoded_prompts,
240 | model, # type: ignore[arg-type]
241 | max_tokens=max_tokens,
242 | temperature=temperature,
243 | eos_id=tokenizer.eos_id,
244 | )
245 |
246 | generated_words = []
247 | for i, x in enumerate(generated_tokens):
248 | generated_words.append(tokenizer.decode(encoded_prompts[i] + x))
249 |
250 | res = generated_words
251 |
252 | if should_print:
253 | for w, logprob in zip(res, _logprobs):
254 | print(w)
255 | logging.debug("Logprobs: %s", logprob)
256 | print("=====================")
257 |
258 |
259 | def mistral_chat() -> None:
260 | fire.Fire(interactive)
261 |
262 |
263 | def mistral_demo() -> None:
264 | fire.Fire(demo)
265 |
266 |
267 | if __name__ == "__main__":
268 | logging.basicConfig(level=logging.INFO)
269 | fire.Fire(
270 | {
271 | "interactive": interactive,
272 | "demo": demo,
273 | }
274 | )
275 |
--------------------------------------------------------------------------------
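The entry points in main.py above are dispatched through fire, so the module can be run directly; the commands below follow from the `fire.Fire({...})` block at the bottom of the file, with `<path/to/model>` standing in for a downloaded model folder (params.json, weights and a tokenizer file). The left-padding helper can be exercised without any model assets.

# python -m mistral_inference.main interactive <path/to/model> --instruct
# python -m mistral_inference.main demo <path/to/model> --max_tokens 64

from mistral_inference.main import pad_and_convert_to_tensor

padded = pad_and_convert_to_tensor([[1, 2, 3], [7]], pad_id=0)
assert padded == [[1, 2, 3], [0, 0, 7]]  # shorter prompts are left-padded to the longest one
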
/src/mistral_inference/mamba.py:
--------------------------------------------------------------------------------
1 | import json
2 | from pathlib import Path
3 | from typing import List, Optional, Union
4 |
5 | import safetensors
6 | import torch
7 | import torch.nn as nn
8 |
9 | from mistral_inference.args import MambaArgs
10 | from mistral_inference.cache import BufferCache
11 | from mistral_inference.model import ModelBase
12 |
13 | _is_mamba_installed = False
14 | try:
15 | from mamba_ssm.models.config_mamba import MambaConfig
16 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
17 |
18 | _is_mamba_installed = True
19 | except ImportError:
20 | _is_mamba_installed = False
21 |
22 |
23 | class Mamba(ModelBase, nn.Module):
24 | def __init__(self, args: MambaArgs):
25 | super().__init__()
26 | self.args = args
27 | assert _is_mamba_installed, "Mamba is not installed. Please install it using `pip install mamba-ssm`."
28 |
29 | # make sure naming is consistent with `mamba_ssm`
30 | config = MambaConfig(
31 | d_model=args.dim,
32 | n_layer=args.n_layers,
33 | vocab_size=args.vocab_size,
34 | ssm_cfg={"ngroups": args.n_groups, "layer": "Mamba2"},
35 | attn_layer_idx=[],
36 | attn_cfg={},
37 | rms_norm=args.rms_norm,
38 | residual_in_fp32=args.residual_in_fp32,
39 | fused_add_norm=args.fused_add_norm,
40 | pad_vocab_size_multiple=args.pad_vocab_size_multiple,
41 | tie_embeddings=args.tie_embeddings,
42 | )
43 | self.model = MambaLMHeadModel(config)
44 |
45 | @property
46 | def dtype(self) -> torch.dtype:
47 | return next(self.parameters()).dtype
48 |
49 | @property
50 | def device(self) -> torch.device:
51 | return next(self.parameters()).device
52 |
53 | def forward(
54 | self,
55 | input_ids: torch.Tensor,
56 | seqlens: List[int], # not supported for now
57 | cache: Optional[BufferCache] = None, # not supported for now
58 | ) -> torch.Tensor:
59 | lm_output = self.model(input_ids)
60 | result: torch.Tensor = lm_output.logits
61 | return result
62 |
63 | @staticmethod
64 | def from_folder(
65 | folder: Union[Path, str],
66 | max_batch_size: int = 1,
67 | num_pipeline_ranks: int = 1,
68 | device: Union[torch.device, str] = "cuda",
69 | dtype: Optional[torch.dtype] = None,
70 | ) -> "Mamba":
71 | with open(Path(folder) / "params.json", "r") as f:
72 | model_args = MambaArgs.from_dict(json.load(f))
73 |
74 | with torch.device("meta"):
75 | model = Mamba(model_args)
76 |
77 | model_file = Path(folder) / "consolidated.safetensors"
78 |
79 | assert model_file.exists(), f"Make sure {model_file} exists."
80 | loaded = safetensors.torch.load_file(str(model_file))
81 |
82 | model.load_state_dict(loaded, assign=True, strict=True)
83 | return model.to(device=device, dtype=dtype)
84 |
--------------------------------------------------------------------------------
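Mamba.from_folder above (and Transformer.from_folder later in this dump) builds the module under torch.device("meta") so that no memory is allocated, then materializes real weights with load_state_dict(..., assign=True). Below is a self-contained sketch of that pattern with a plain nn.Linear standing in for the checkpointed model, since loading Mamba itself requires mamba_ssm and a real folder of weights.

import torch
import torch.nn as nn

with torch.device("meta"):
    model = nn.Linear(4, 2, bias=False)  # parameters are meta tensors, no storage yet

weights = {"weight": torch.randn(2, 4)}  # stands in for safetensors.torch.load_file(...)
model.load_state_dict(weights, assign=True, strict=True)

assert not model.weight.is_meta
print(model(torch.randn(1, 4)))
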
/src/mistral_inference/model.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from pathlib import Path
3 | from typing import List, Optional, Union
4 |
5 | import torch
6 | import torch.nn as nn
7 |
8 | from mistral_inference.cache import BufferCache
9 |
10 |
11 | class ModelBase(nn.Module, ABC):
12 | def __init__(self) -> None:
13 | super().__init__()
14 |
15 | @property
16 | @abstractmethod
17 | def dtype(self) -> torch.dtype:
18 | pass
19 |
20 | @property
21 | @abstractmethod
22 | def device(self) -> torch.device:
23 | pass
24 |
25 | @abstractmethod
26 | def forward(
27 | self,
28 | input_ids: torch.Tensor,
29 | seqlens: List[int], # not supported for now
30 | cache: Optional[BufferCache] = None, # not supported for now
31 | ) -> torch.Tensor:
32 | pass
33 |
34 | @staticmethod
35 | @abstractmethod
36 | def from_folder(
37 | folder: Union[Path, str],
38 | max_batch_size: int = 1,
39 | num_pipeline_ranks: int = 1,
40 | device: Union[torch.device, str] = "cuda",
41 | dtype: Optional[torch.dtype] = None,
42 | ) -> "ModelBase":
43 | pass
44 |
--------------------------------------------------------------------------------
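A hedged sketch of what the ModelBase contract above requires from a concrete model: the dtype/device properties, a forward over flattened token ids, and a from_folder constructor. ToyModel is purely illustrative and not part of the library.

from pathlib import Path
from typing import List, Optional, Union

import torch
import torch.nn as nn

from mistral_inference.cache import BufferCache
from mistral_inference.model import ModelBase


class ToyModel(ModelBase):
    def __init__(self, vocab_size: int = 8, dim: int = 4):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)

    @property
    def dtype(self) -> torch.dtype:
        return next(self.parameters()).dtype

    @property
    def device(self) -> torch.device:
        return next(self.parameters()).device

    def forward(
        self,
        input_ids: torch.Tensor,
        seqlens: List[int],
        cache: Optional[BufferCache] = None,
    ) -> torch.Tensor:
        return self.embed(input_ids)  # (num_tokens, dim)

    @staticmethod
    def from_folder(
        folder: Union[Path, str],
        max_batch_size: int = 1,
        num_pipeline_ranks: int = 1,
        device: Union[torch.device, str] = "cpu",
        dtype: Optional[torch.dtype] = None,
    ) -> "ToyModel":
        return ToyModel().to(device=device, dtype=dtype)


print(ToyModel.from_folder(".").forward(torch.tensor([1, 2, 3]), seqlens=[3]).shape)  # torch.Size([3, 4])
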
/src/mistral_inference/moe.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | from typing import List
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | from simple_parsing.helpers import Serializable
7 | from torch import nn
8 |
9 |
10 | @dataclasses.dataclass
11 | class MoeArgs(Serializable):
12 | num_experts: int
13 | num_experts_per_tok: int
14 |
15 |
16 | class MoeLayer(nn.Module):
17 | def __init__(self, experts: List[nn.Module], gate: nn.Module, moe_args: MoeArgs):
18 | super().__init__()
19 | assert len(experts) > 0
20 | self.experts = nn.ModuleList(experts)
21 | self.gate = gate
22 | self.args = moe_args
23 |
24 | def forward(self, inputs: torch.Tensor) -> torch.Tensor:
25 | gate_logits = self.gate(inputs)
26 | weights, selected_experts = torch.topk(gate_logits, self.args.num_experts_per_tok)
27 | weights = F.softmax(weights, dim=1, dtype=torch.float).to(inputs.dtype)
28 | results = torch.zeros_like(inputs)
29 | for i, expert in enumerate(self.experts):
30 | batch_idx, nth_expert = torch.where(selected_experts == i)
31 | results[batch_idx] += weights[batch_idx, nth_expert, None] * expert(inputs[batch_idx])
32 | return results
33 |
--------------------------------------------------------------------------------
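A hedged, self-contained sketch of the routing above: every token keeps its top num_experts_per_tok experts by gate logit, and their outputs are mixed with softmax weights. The linear experts and all dimensions are toy placeholders.

import torch
from torch import nn

from mistral_inference.moe import MoeArgs, MoeLayer

torch.manual_seed(0)
dim, num_experts = 4, 3
layer = MoeLayer(
    experts=[nn.Linear(dim, dim, bias=False) for _ in range(num_experts)],
    gate=nn.Linear(dim, num_experts, bias=False),
    moe_args=MoeArgs(num_experts=num_experts, num_experts_per_tok=2),
)

tokens = torch.randn(5, dim)  # (num_tokens, dim), as in the flattened sequences used elsewhere
out = layer(tokens)
assert out.shape == tokens.shape  # each token is a weighted sum of its two selected experts
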
/src/mistral_inference/rope.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import torch
4 |
5 |
6 | def precompute_freqs_cis(dim: int, end: int, theta: float) -> torch.Tensor:
7 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
8 | t = torch.arange(end, device=freqs.device)
9 | freqs = torch.outer(t, freqs).float()
10 | return torch.polar(torch.ones_like(freqs), freqs) # complex64
11 |
12 |
13 | def apply_rotary_emb(
14 | xq: torch.Tensor,
15 | xk: torch.Tensor,
16 | freqs_cis: torch.Tensor,
17 | ) -> Tuple[torch.Tensor, torch.Tensor]:
18 | xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
19 | xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
20 | freqs_cis = freqs_cis[:, None, :]
21 | xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2)
22 | xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2)
23 | return xq_out.type_as(xq), xk_out.type_as(xk)
24 |
25 |
26 | def precompute_freqs_cis_2d(
27 | dim: int,
28 | height: int,
29 | width: int,
30 | theta: float,
31 | ) -> torch.Tensor:
32 | """
33 | freqs_cis: 2D complex tensor of shape (height, width, dim // 2) to be indexed by
34 | (height, width) position tuples
35 | """
36 | # (dim / 2) frequency bases
37 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
38 |
39 | h = torch.arange(height, device=freqs.device)
40 | w = torch.arange(width, device=freqs.device)
41 |
42 | freqs_h = torch.outer(h, freqs[::2]).float()
43 | freqs_w = torch.outer(w, freqs[1::2]).float()
44 | freqs_2d = torch.cat(
45 | [
46 | freqs_h[:, None, :].repeat(1, width, 1),
47 | freqs_w[None, :, :].repeat(height, 1, 1),
48 | ],
49 | dim=-1,
50 | )
51 | return torch.polar(torch.ones_like(freqs_2d), freqs_2d)
52 |
--------------------------------------------------------------------------------
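A hedged sketch of the 1D helpers above with toy shapes: queries and keys are (seq_len, n_heads, head_dim) and freqs_cis is indexed by absolute position, matching how Attention.forward uses them in transformer_layers.py. Since each channel pair is rotated by a unit-modulus complex factor, the per-token norm is preserved.

import torch

from mistral_inference.rope import apply_rotary_emb, precompute_freqs_cis

seq_len, n_heads, head_dim = 6, 2, 8
freqs_cis = precompute_freqs_cis(head_dim, 32, 10_000.0)  # (32, head_dim // 2), complex64

positions = torch.arange(seq_len)
xq = torch.randn(seq_len, n_heads, head_dim)
xk = torch.randn(seq_len, n_heads, head_dim)

xq_rot, xk_rot = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis[positions])
assert xq_rot.shape == xq.shape and xk_rot.shape == xk.shape
assert torch.allclose(xq.norm(dim=-1), xq_rot.norm(dim=-1), atol=1e-5)  # rotations preserve norms
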
/src/mistral_inference/transformer.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import math
4 | from dataclasses import dataclass
5 | from pathlib import Path
6 | from typing import Any, List, Mapping, Optional, Union
7 |
8 | import safetensors.torch
9 | import torch
10 | from torch import nn
11 |
12 | from mistral_inference.args import PATCH_MERGE, TransformerArgs
13 | from mistral_inference.cache import BufferCache, CacheInputMetadata
14 | from mistral_inference.lora import LoRALoaderMixin
15 | from mistral_inference.model import ModelBase
16 | from mistral_inference.rope import precompute_freqs_cis
17 | from mistral_inference.transformer_layers import RMSNorm, TransformerBlock
18 | from mistral_inference.vision_encoder import PatchMerger, VisionLanguageAdapter, VisionTransformer
19 |
20 |
21 | @dataclass
22 | class SimpleInputMetadata:
23 | # rope absolute positions
24 | positions: torch.Tensor
25 |
26 | @staticmethod
27 | def from_seqlens(seqlens: List[int], device: torch.device) -> "SimpleInputMetadata":
28 | return SimpleInputMetadata(
29 | positions=torch.cat([torch.arange(0, seqlen) for seqlen in seqlens]).to(device=device, dtype=torch.long)
30 | )
31 |
32 |
33 | class Transformer(ModelBase, LoRALoaderMixin):
34 | def __init__(
35 | self,
36 | args: TransformerArgs,
37 | pipeline_rank: int = 0,
38 | num_pipeline_ranks: int = 1,
39 | softmax_fp32: bool = True,
40 | ):
41 | super().__init__()
42 | self.args = args
43 | self.vocab_size = args.vocab_size
44 | self.n_layers = args.n_layers
45 | self._precomputed_freqs_cis: Optional[torch.Tensor] = None
46 | assert self.vocab_size > 0
47 | assert pipeline_rank < num_pipeline_ranks, (pipeline_rank, num_pipeline_ranks)
48 | self.pipeline_rank = pipeline_rank
49 | self.num_pipeline_ranks = num_pipeline_ranks
50 | self.softmax_fp32 = softmax_fp32
51 |
52 | # Modules specific to some ranks:
53 | self.tok_embeddings: Optional[nn.Embedding] = None
54 | self.norm: Optional[RMSNorm] = None
55 | self.output: Optional[nn.Linear] = None
56 | if pipeline_rank == 0:
57 | self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
58 |
59 | self.vision_encoder: Optional[VisionTransformer] = None
60 | self.vision_language_adapter: Optional[VisionLanguageAdapter] = None
61 |
62 | if args.vision_encoder is not None:
63 | self.vision_encoder = VisionTransformer(args.vision_encoder)
64 | self.vision_language_adapter = VisionLanguageAdapter(
65 | args.vision_encoder.hidden_size, args.dim, args.vision_encoder.adapter_bias
66 | )
67 |
68 | if args.vision_encoder.add_pre_mm_projector_layer_norm:
69 | self.pre_mm_projector_norm = RMSNorm(args.vision_encoder.hidden_size, eps=1e-5)
70 |
71 | if args.vision_encoder.mm_projector_id == PATCH_MERGE:
72 | self.patch_merger = PatchMerger(
73 | vision_encoder_dim=args.vision_encoder.hidden_size,
74 | spatial_merge_size=args.vision_encoder.spatial_merge_size,
75 | )
76 |
77 | if pipeline_rank == num_pipeline_ranks - 1:
78 | self.norm = RMSNorm(args.dim, eps=args.norm_eps)
79 | self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
80 | # Initialize all layers but slice off those not of this rank.
81 | layers = [
82 | TransformerBlock(
83 | dim=args.dim,
84 | hidden_dim=args.hidden_dim,
85 | n_heads=args.n_heads,
86 | n_kv_heads=args.n_kv_heads,
87 | head_dim=args.head_dim,
88 | norm_eps=args.norm_eps,
89 | lora=args.lora,
90 | moe=args.moe,
91 | )
92 | for _ in range(args.n_layers)
93 | ]
94 | num_layers_per_rank = math.ceil(self.n_layers / self.num_pipeline_ranks)
95 | offset = self.pipeline_rank * num_layers_per_rank
96 | end = min(self.n_layers, offset + num_layers_per_rank)
97 | self.layers = nn.ModuleDict({str(i): layers[i] for i in range(offset, end)})
98 | self.n_local_layers = len(self.layers)
99 |
100 | @property
101 | def dtype(self) -> torch.dtype:
102 | return next(self.parameters()).dtype
103 |
104 | @property
105 | def device(self) -> torch.device:
106 | return next(self.parameters()).device
107 |
108 | @property
109 | def freqs_cis(self) -> torch.Tensor:
110 | # We cache freqs_cis but need to take care that it is on the right device
111 | # and has the right dtype (complex64). The fact that the dtype is different
112 | # from the module's dtype means we cannot register it as a buffer
113 | if self._precomputed_freqs_cis is None:
114 | # default to 10**6
115 | theta = self.args.rope_theta or 1000000.0
116 | self._precomputed_freqs_cis = precompute_freqs_cis(self.args.head_dim, 128_000, theta)
117 |
118 | if self._precomputed_freqs_cis.device != self.device:
119 | self._precomputed_freqs_cis = self._precomputed_freqs_cis.to(device=self.device)
120 | return self._precomputed_freqs_cis
121 |
122 | def embed_vision_language_features(self, input_ids: torch.Tensor, images: List[torch.Tensor]) -> torch.Tensor:
123 | assert self.tok_embeddings is not None
124 | assert self.vision_encoder is not None
125 | assert self.vision_language_adapter is not None
126 | assert self.args.vision_encoder is not None
127 |
128 | text_locations = input_ids != self.args.vision_encoder.image_token_id
129 | image_locations = input_ids == self.args.vision_encoder.image_token_id
130 | text_features = self.tok_embeddings(input_ids[text_locations])
131 |
132 | image_features = self.vision_encoder(images)
133 |
134 | if self.args.vision_encoder.add_pre_mm_projector_layer_norm:
135 | image_features = self.pre_mm_projector_norm(image_features)
136 |
137 | if self.args.vision_encoder.mm_projector_id == PATCH_MERGE:
138 | patch_size = self.args.vision_encoder.patch_size
139 | img_patch_dims = [(img.shape[1] // patch_size, img.shape[2] // patch_size) for img in images]
140 | image_features = self.patch_merger(image_features, image_sizes=img_patch_dims)
141 |
142 | image_features = self.vision_language_adapter(image_features)
143 |
144 | N_txt, D_txt = text_features.shape
145 | N_img, D_img = image_features.shape
146 |
147 | seq_len = input_ids.shape[0]
148 |
149 | assert D_txt == D_img, f"Text features dim {D_txt} should be equal to image features dim {D_img}"
150 | assert seq_len == N_txt + N_img, (
151 | f"seq_len {seq_len} should be equal to N_txt + N_img {(N_txt, N_img, image_locations.sum().item())}"
152 | )
153 |
154 | combined_features = torch.empty(
155 | (seq_len, D_txt),
156 | dtype=text_features.dtype,
157 | device=text_features.device,
158 | )
159 | combined_features[text_locations, :] = text_features
160 | combined_features[image_locations, :] = image_features
161 | return combined_features
162 |
163 | def forward_partial(
164 | self,
165 | input_ids: torch.Tensor,
166 | seqlens: List[int],
167 | cache: Optional[BufferCache] = None,
168 | images: Optional[List[torch.Tensor]] = None,
169 | ) -> torch.Tensor:
170 | """Local forward pass.
171 |
172 | If doing pipeline parallelism, this will return the activations of the last layer of this stage.
173 | For the last stage, this will return the normalized final embeddings.
174 | """
175 | assert len(seqlens) <= self.args.max_batch_size, (
176 | f"Max batch size is {self.args.max_batch_size}, got batch size of {len(seqlens)}"
177 | )
178 | (num_toks,) = input_ids.shape
179 | assert sum(seqlens) == num_toks, (sum(seqlens), num_toks)
180 |
181 | input_metadata: List[CacheInputMetadata] | List[SimpleInputMetadata]
182 |
183 | if cache is not None:
184 | input_metadata = cache.get_input_metadata(seqlens)
185 | else:
186 | input_metadata = [SimpleInputMetadata.from_seqlens(seqlens, self.device) for _ in range(len(self.layers))]
187 |
188 | if self.pipeline_rank == 0:
189 | assert self.tok_embeddings is not None
190 | if self.vision_encoder is not None and images:
191 | h = self.embed_vision_language_features(input_ids, images)
192 | else:
193 | h = self.tok_embeddings(input_ids)
194 | else:
195 | h = torch.empty(num_toks, self.args.dim, device=self.device, dtype=self.dtype)
196 | torch.distributed.recv(h, src=self.pipeline_rank - 1)
197 |
198 | # freqs_cis is always the same for every layer
199 | freqs_cis = self.freqs_cis[input_metadata[0].positions]
200 |
201 | for local_layer_id, layer in enumerate(self.layers.values()):
202 | if cache is not None:
203 | assert input_metadata is not None
204 | cache_metadata = input_metadata[local_layer_id]
205 | assert isinstance(cache_metadata, CacheInputMetadata)
206 | cache_view = cache.get_view(local_layer_id, cache_metadata)
207 | else:
208 | cache_view = None
209 | h = layer(h, freqs_cis, cache_view)
210 |
211 | if cache is not None:
212 | cache.update_seqlens(seqlens)
213 | if self.pipeline_rank < self.num_pipeline_ranks - 1:
214 | torch.distributed.send(h, dst=self.pipeline_rank + 1)
215 | return h
216 | else:
217 | # Last rank has a final normalization step.
218 | assert self.norm is not None
219 | return self.norm(h) # type: ignore
220 |
221 | def forward(
222 | self,
223 | input_ids: torch.Tensor,
224 | seqlens: List[int],
225 | cache: Optional[BufferCache] = None,
226 | images: Optional[List[torch.Tensor]] = None,
227 | ) -> torch.Tensor:
228 | h = self.forward_partial(input_ids, seqlens, cache=cache, images=images)
229 | if self.pipeline_rank < self.num_pipeline_ranks - 1:
230 | # ignore the intermediate activations as we'll get the final output from
231 | # the last stage
232 | outs = torch.empty(h.shape[0], self.vocab_size, device=h.device, dtype=h.dtype)
233 | else:
234 | assert self.output is not None
235 | outs = self.output(h)
236 | if self.num_pipeline_ranks > 1:
237 | torch.distributed.broadcast(outs, src=self.num_pipeline_ranks - 1)
238 |
239 | if self.softmax_fp32:
240 | return outs.float()
241 | else:
242 | return outs
243 |
244 | def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False) -> None:
245 | state_to_load = {}
246 | skipped = set([])
247 | for k, v in state_dict.items():
248 | if k.startswith("tok_embeddings"):
249 | if self.pipeline_rank == 0:
250 | state_to_load[k] = v
251 | else:
252 | logging.debug(
253 | "Skipping parameter %s at pipeline rank %d",
254 | k,
255 | self.pipeline_rank,
256 | )
257 | skipped.add(k)
258 | elif k.startswith("norm") or k.startswith("output"):
259 | if self.pipeline_rank == self.num_pipeline_ranks - 1:
260 | state_to_load[k] = v
261 | else:
262 | logging.debug(
263 | "Skipping parameter %s at pipeline rank %d",
264 | k,
265 | self.pipeline_rank,
266 | )
267 | skipped.add(k)
268 | elif k.startswith("layers"):
269 | layer_id = k.split(".")[1]
270 | if layer_id in self.layers:
271 | state_to_load[k] = v
272 | else:
273 | logging.debug(
274 | "Skipping parameter %s at pipeline rank %d",
275 | k,
276 | self.pipeline_rank,
277 | )
278 | skipped.add(k)
279 | elif any(
280 | k.startswith(key)
281 | for key in ["vision_encoder", "vision_language_adapter", "patch_merger", "pre_mm_projector_norm"]
282 | ):
283 | if self.pipeline_rank == 0:
284 | state_to_load[k] = v
285 | else:
286 | logging.debug(
287 | "Skipping parameter %s at pipeline rank %d",
288 | k,
289 | self.pipeline_rank,
290 | )
291 | skipped.add(k)
292 | else:
293 | raise ValueError(f"Unexpected key {k}")
294 | assert set(state_dict.keys()) == skipped.union(set(state_to_load.keys()))
295 | super().load_state_dict(state_to_load, strict=strict, assign=assign)
296 |
297 | @staticmethod
298 | def from_folder(
299 | folder: Union[Path, str],
300 | max_batch_size: int = 1,
301 | num_pipeline_ranks: int = 1,
302 | device: Union[torch.device, str] = "cuda",
303 | dtype: Optional[torch.dtype] = None,
304 | softmax_fp32: bool = True,
305 | ) -> "Transformer":
306 | with open(Path(folder) / "params.json", "r") as f:
307 | model_args = TransformerArgs.from_dict(json.load(f))
308 | model_args.max_batch_size = max_batch_size
309 | if num_pipeline_ranks > 1:
310 | pipeline_rank = torch.distributed.get_rank()
311 | else:
312 | pipeline_rank = 0
313 | with torch.device("meta"):
314 | model = Transformer(
315 | model_args,
316 | pipeline_rank=pipeline_rank,
317 | num_pipeline_ranks=num_pipeline_ranks,
318 | softmax_fp32=softmax_fp32,
319 | )
320 |
321 | pt_model_file = Path(folder) / "consolidated.00.pth"
322 | safetensors_model_file = Path(folder) / "consolidated.safetensors"
323 |
324 | assert pt_model_file.exists() or safetensors_model_file.exists(), (
325 | f"Make sure either {pt_model_file} or {safetensors_model_file} exists"
326 | )
327 | assert not (pt_model_file.exists() and safetensors_model_file.exists()), (
328 | f"Both {pt_model_file} and {safetensors_model_file} cannot exist"
329 | )
330 |
331 | if pt_model_file.exists():
332 | loaded = torch.load(str(pt_model_file), mmap=True)
333 | else:
334 | loaded = safetensors.torch.load_file(str(safetensors_model_file))
335 |
336 | model.load_state_dict(loaded, assign=True, strict=True)
337 |
338 | return model.to(device=device, dtype=dtype)
339 |
--------------------------------------------------------------------------------
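A hedged illustration of the layer partitioning in Transformer.__init__ above: each pipeline rank keeps a contiguous slice of ceil(n_layers / num_pipeline_ranks) layers, so the last rank may hold fewer. The numbers are illustrative only (10 layers over 3 ranks).

import math

n_layers, num_pipeline_ranks = 10, 3
num_layers_per_rank = math.ceil(n_layers / num_pipeline_ranks)  # 4

for rank in range(num_pipeline_ranks):
    offset = rank * num_layers_per_rank
    end = min(n_layers, offset + num_layers_per_rank)
    print(f"rank {rank}: layers {list(range(offset, end))}")
# rank 0: layers [0, 1, 2, 3]
# rank 1: layers [4, 5, 6, 7]
# rank 2: layers [8, 9]
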
/src/mistral_inference/transformer_layers.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from typing import Optional, Tuple, Type, Union
3 |
4 | import torch
5 | from torch import nn
6 | from xformers.ops.fmha import memory_efficient_attention # type: ignore
7 | from xformers.ops.fmha.attn_bias import BlockDiagonalMask
8 |
9 | from mistral_inference.args import LoraArgs
10 | from mistral_inference.cache import CacheView
11 | from mistral_inference.lora import LoRALinear
12 | from mistral_inference.moe import MoeArgs, MoeLayer
13 | from mistral_inference.rope import apply_rotary_emb
14 |
15 |
16 | def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int, dim: int) -> Tuple[torch.Tensor, torch.Tensor]:
17 | keys = torch.repeat_interleave(keys, repeats=repeats, dim=dim)
18 | values = torch.repeat_interleave(values, repeats=repeats, dim=dim)
19 | return keys, values
20 |
21 |
22 | def maybe_lora(
23 | lora_args: Optional[LoraArgs],
24 | ) -> Union[Type[nn.Linear], partial[LoRALinear]]:
25 | if lora_args is None:
26 | return nn.Linear
27 | else:
28 | return partial(LoRALinear, rank=lora_args.rank, scaling=lora_args.scaling)
29 |
30 |
31 | class Attention(nn.Module):
32 | def __init__(
33 | self,
34 | dim: int,
35 | n_heads: int,
36 | head_dim: int,
37 | n_kv_heads: int,
38 | lora: Optional[LoraArgs] = None,
39 | ):
40 | super().__init__()
41 |
42 | self.n_heads: int = n_heads
43 | self.head_dim: int = head_dim
44 | self.n_kv_heads: int = n_kv_heads
45 |
46 | self.repeats = self.n_heads // self.n_kv_heads
47 |
48 | self.scale = self.head_dim**-0.5
49 |
50 | MaybeLora = maybe_lora(lora)
51 | self.wq = MaybeLora(dim, n_heads * head_dim, bias=False)
52 | self.wk = MaybeLora(dim, n_kv_heads * head_dim, bias=False)
53 | self.wv = MaybeLora(dim, n_kv_heads * head_dim, bias=False)
54 | self.wo = MaybeLora(n_heads * head_dim, dim, bias=False)
55 |
56 | def forward(
57 | self,
58 | x: torch.Tensor,
59 | freqs_cis: torch.Tensor,
60 | cache: Optional[CacheView] = None,
61 | mask: Optional[BlockDiagonalMask] = None,
62 | ) -> torch.Tensor:
63 | assert mask is None or cache is None
64 | seqlen_sum, _ = x.shape
65 |
66 | xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
67 | xq = xq.view(seqlen_sum, self.n_heads, self.head_dim)
68 | xk = xk.view(seqlen_sum, self.n_kv_heads, self.head_dim)
69 | xv = xv.view(seqlen_sum, self.n_kv_heads, self.head_dim)
70 | xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
71 |
72 | if cache is None:
73 | key, val = xk, xv
74 | elif cache.prefill:
75 | key, val = cache.interleave_kv(xk, xv)
76 | cache.update(xk, xv)
77 | else:
78 | cache.update(xk, xv)
79 | key, val = cache.key, cache.value
80 | key = key.view(seqlen_sum * cache.max_seq_len, self.n_kv_heads, self.head_dim)
81 | val = val.view(seqlen_sum * cache.max_seq_len, self.n_kv_heads, self.head_dim)
82 |
83 | # Repeat keys and values to match number of query heads
84 | key, val = repeat_kv(key, val, self.repeats, dim=1)
85 |
86 | # xformers requires (B=1, S, H, D)
87 | xq, key, val = xq[None, ...], key[None, ...], val[None, ...]
88 | output = memory_efficient_attention(xq, key, val, mask if cache is None else cache.mask)
89 | output = output.view(seqlen_sum, self.n_heads * self.head_dim)
90 |
91 | assert isinstance(output, torch.Tensor)
92 |
93 | return self.wo(output) # type: ignore
94 |
95 |
96 | class FeedForward(nn.Module):
97 | def __init__(self, dim: int, hidden_dim: int, lora: Optional[LoraArgs] = None):
98 | super().__init__()
99 |
100 | MaybeLora = maybe_lora(lora)
101 | self.w1 = MaybeLora(dim, hidden_dim, bias=False)
102 | self.w2 = MaybeLora(hidden_dim, dim, bias=False)
103 | self.w3 = MaybeLora(dim, hidden_dim, bias=False)
104 |
105 | def forward(self, x: torch.Tensor) -> torch.Tensor:
106 | return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x)) # type: ignore
107 |
108 |
109 | class RMSNorm(torch.nn.Module):
110 | def __init__(self, dim: int, eps: float = 1e-6):
111 | super().__init__()
112 | self.eps = eps
113 | self.weight = nn.Parameter(torch.ones(dim))
114 |
115 | def _norm(self, x: torch.Tensor) -> torch.Tensor:
116 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
117 |
118 | def forward(self, x: torch.Tensor) -> torch.Tensor:
119 | output = self._norm(x.float()).type_as(x)
120 | return output * self.weight
121 |
122 |
123 | class TransformerBlock(nn.Module):
124 | def __init__(
125 | self,
126 | dim: int,
127 | hidden_dim: int,
128 | n_heads: int,
129 | n_kv_heads: int,
130 | head_dim: int,
131 | norm_eps: float,
132 | lora: Optional[LoraArgs] = None,
133 | moe: Optional[MoeArgs] = None,
134 | ):
135 | super().__init__()
136 | self.n_heads = n_heads
137 | self.dim = dim
138 | self.attention = Attention(
139 | dim=dim,
140 | n_heads=n_heads,
141 | head_dim=head_dim,
142 | n_kv_heads=n_kv_heads,
143 | lora=lora,
144 | )
145 | self.attention_norm = RMSNorm(dim, eps=norm_eps)
146 | self.ffn_norm = RMSNorm(dim, eps=norm_eps)
147 |
148 | self.feed_forward: nn.Module
149 | if moe is not None:
150 | self.feed_forward = MoeLayer(
151 | experts=[FeedForward(dim=dim, hidden_dim=hidden_dim, lora=lora) for _ in range(moe.num_experts)],
152 | gate=nn.Linear(dim, moe.num_experts, bias=False),
153 | moe_args=moe,
154 | )
155 | else:
156 | self.feed_forward = FeedForward(dim=dim, hidden_dim=hidden_dim, lora=lora)
157 |
158 | def forward(
159 | self,
160 | x: torch.Tensor,
161 | freqs_cis: torch.Tensor,
162 | cache: Optional[CacheView] = None,
163 | mask: Optional[BlockDiagonalMask] = None,
164 | ) -> torch.Tensor:
165 | r = self.attention.forward(self.attention_norm(x), freqs_cis, cache)
166 | h = x + r
167 | r = self.feed_forward.forward(self.ffn_norm(h))
168 | out = h + r
169 | return out
170 |
--------------------------------------------------------------------------------
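A hedged sketch of two building blocks above. repeat_kv implements grouped-query attention: with n_heads // n_kv_heads = 2, each key/value head is duplicated in place so every query head has a matching KV head. RMSNorm divides by the root-mean-square of the features and rescales with a learned weight. Toy shapes only; importing the module assumes xformers is installed.

import torch

from mistral_inference.transformer_layers import RMSNorm, repeat_kv

seqlen, n_kv_heads, head_dim, repeats = 3, 2, 8, 2
keys = torch.randn(seqlen, n_kv_heads, head_dim)
values = torch.randn(seqlen, n_kv_heads, head_dim)

k_rep, v_rep = repeat_kv(keys, values, repeats=repeats, dim=1)
assert k_rep.shape == (seqlen, n_kv_heads * repeats, head_dim)
# repeat_interleave duplicates each KV head in place: head order is 0, 0, 1, 1
assert torch.equal(k_rep[:, 0], k_rep[:, 1]) and torch.equal(k_rep[:, 2], k_rep[:, 3])

norm = RMSNorm(dim=head_dim)
x = torch.randn(seqlen, head_dim)
print(norm(x).shape)  # torch.Size([3, 8])
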
/src/mistral_inference/vision_encoder.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional
2 |
3 | import torch
4 | import torch.nn as nn
5 | from xformers.ops.fmha.attn_bias import BlockDiagonalMask
6 |
7 | from mistral_inference.args import VisionEncoderArgs
8 | from mistral_inference.rope import precompute_freqs_cis_2d
9 | from mistral_inference.transformer_layers import RMSNorm, TransformerBlock
10 |
11 |
12 | def position_meshgrid(
13 | patch_embeds_list: list[torch.Tensor],
14 | ) -> torch.Tensor:
15 | positions = torch.cat(
16 | [
17 | torch.stack(
18 | torch.meshgrid(
19 | torch.arange(p.shape[-2]),
20 | torch.arange(p.shape[-1]),
21 | indexing="ij",
22 | ),
23 | dim=-1,
24 | ).reshape(-1, 2)
25 | for p in patch_embeds_list
26 | ]
27 | )
28 | return positions
29 |
30 |
31 | class VisionTransformer(nn.Module):
32 | def __init__(self, args: VisionEncoderArgs):
33 | super().__init__()
34 | self.args = args
35 | self.patch_conv = nn.Conv2d(
36 | in_channels=args.num_channels,
37 | out_channels=args.hidden_size,
38 | kernel_size=args.patch_size,
39 | stride=args.patch_size,
40 | bias=False,
41 | )
42 | self.ln_pre = RMSNorm(args.hidden_size, eps=1e-5)
43 | self.transformer = VisionTransformerBlocks(args)
44 |
45 | head_dim = self.args.hidden_size // self.args.num_attention_heads
46 | assert head_dim % 2 == 0, "ROPE requires even head_dim"
47 | self._freqs_cis: Optional[torch.Tensor] = None
48 |
49 | @property
50 | def max_patches_per_side(self) -> int:
51 | return self.args.image_size // self.args.patch_size
52 |
53 | @property
54 | def device(self) -> torch.device:
55 | return next(self.parameters()).device
56 |
57 | @property
58 | def freqs_cis(self) -> torch.Tensor:
59 | if self._freqs_cis is None:
60 | self._freqs_cis = precompute_freqs_cis_2d(
61 | dim=self.args.hidden_size // self.args.num_attention_heads,
62 | height=self.max_patches_per_side,
63 | width=self.max_patches_per_side,
64 | theta=self.args.rope_theta,
65 | )
66 |
67 | if self._freqs_cis.device != self.device:
68 | self._freqs_cis = self._freqs_cis.to(device=self.device)
69 |
70 | return self._freqs_cis
71 |
72 | def forward(
73 | self,
74 | images: List[torch.Tensor],
75 | ) -> torch.Tensor:
76 | """
77 | Args:
78 | images: list of N_img images of variable sizes, each of shape (C, H, W)
79 |
80 | Returns:
81 | image_features: tensor of token features for all tokens of all images of
82 | shape (N_toks, D)
83 | """
84 | # pass images through initial convolution independently
85 | patch_embeds_list = [self.patch_conv(img.unsqueeze(0)).squeeze(0) for img in images]
86 |
87 | # flatten to a single sequence
88 | patch_embeds = torch.cat([p.flatten(1).permute(1, 0) for p in patch_embeds_list], dim=0)
89 | patch_embeds = self.ln_pre(patch_embeds)
90 |
91 | # positional embeddings
92 | positions = position_meshgrid(patch_embeds_list).to(self.device)
93 | freqs_cis = self.freqs_cis[positions[:, 0], positions[:, 1]]
94 |
95 | # pass through Transformer with a block diagonal mask delimiting images
96 | mask = BlockDiagonalMask.from_seqlens(
97 | [p.shape[-2] * p.shape[-1] for p in patch_embeds_list],
98 | )
99 | out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis)
100 |
101 |         # out: (N_toks, hidden_size) token features for the single flattened sequence
102 | return out # type: ignore[no-any-return]
103 |
104 |
105 | class VisionLanguageAdapter(nn.Module):
106 | def __init__(self, in_dim: int, out_dim: int, bias: bool = True):
107 | super().__init__()
108 | self.w_in = nn.Linear(
109 | in_dim,
110 | out_dim,
111 | bias=bias,
112 | )
113 | self.gelu = nn.GELU()
114 | self.w_out = nn.Linear(out_dim, out_dim, bias=bias)
115 |
116 | def forward(self, x: torch.Tensor) -> torch.Tensor:
117 | return self.w_out(self.gelu(self.w_in(x))) # type: ignore[no-any-return]
118 |
119 |
120 | class VisionTransformerBlocks(nn.Module):
121 | def __init__(self, args: VisionEncoderArgs):
122 | super().__init__()
123 | self.layers = torch.nn.ModuleList()
124 | for _ in range(args.num_hidden_layers):
125 | self.layers.append(
126 | TransformerBlock(
127 | dim=args.hidden_size,
128 | hidden_dim=args.intermediate_size,
129 | n_heads=args.num_attention_heads,
130 | n_kv_heads=args.num_attention_heads,
131 | head_dim=args.hidden_size // args.num_attention_heads,
132 | norm_eps=1e-5,
133 | )
134 | )
135 |
136 | def forward(
137 | self,
138 | x: torch.Tensor,
139 | mask: BlockDiagonalMask,
140 | freqs_cis: Optional[torch.Tensor],
141 | ) -> torch.Tensor:
142 | for layer in self.layers:
143 | x = layer(x, mask=mask, freqs_cis=freqs_cis)
144 | return x
145 |
146 |
147 | class PatchMerger(nn.Module):
148 | """
149 | Learned merging of spatial_merge_size ** 2 patches
150 | """
151 |
152 | def __init__(
153 | self,
154 | vision_encoder_dim: int,
155 | spatial_merge_size: int,
156 | ) -> None:
157 | super().__init__()
158 |
159 | mlp_input_dim = vision_encoder_dim * (spatial_merge_size**2)
160 |
161 | self.spatial_merge_size = spatial_merge_size
162 | self.mlp_input_dim = mlp_input_dim
163 |
164 | self.merging_layer = nn.Linear(mlp_input_dim, vision_encoder_dim, bias=False)
165 |
166 | def forward(self, x: torch.Tensor, image_sizes: list[tuple[int, int]]) -> torch.Tensor:
167 | # image_sizes specified in tokens
168 | assert sum([h * w for h, w in image_sizes]) == len(x), f"{sum([h * w for h, w in image_sizes])} != {len(x)}"
169 |
170 | # x is (N, vision_encoder_dim)
171 | x = self.permute(x, image_sizes)
172 |
173 | # x is (N / spatial_merge_size ** 2,
174 | # vision_encoder_dim * spatial_merge_size ** 2)
175 | x = self.merging_layer(x)
176 |
177 | # x is (N / spatial_merge_size ** 2, vision_encoder_dim)
178 | return x
179 |
180 | def permute(
181 | self,
182 | x: torch.Tensor,
183 | image_sizes: list[tuple[int, int]],
184 | ) -> torch.Tensor:
185 | """
186 | Args:
187 | x: (N, D) where N is flattened and concatenated patch tokens
188 | for all images
189 | image_sizes: list of tuple of (height, width) in tokens for
190 | each image
191 |         Returns:
192 |             image_features: patch tokens reordered so that every
193 |             (spatial_merge_size, spatial_merge_size) grid is contiguous,
194 |             with shape (N / spatial_merge_size ** 2, D * spatial_merge_size ** 2)
195 | """
196 |
197 | sub_grids = get_sub_grids(
198 | x=x, image_sizes=image_sizes, spatial_merge_size=self.spatial_merge_size
199 | ) # list of [d x sub_grid_size x sub_grid_size x n_patches]
200 | permuted_tensor = [
201 | grid.view(-1, grid.shape[-1]).t() for grid in sub_grids
202 | ] # n_patches x d * sub_grid_size * sub_grid_size
203 | return torch.cat(permuted_tensor, dim=0) # (N / spatial_merge_size ** 2, d * spatial_merge_size ** 2)
204 |
205 |
206 | def get_sub_grids(
207 | x: torch.Tensor,
208 | image_sizes: list[tuple[int, int]],
209 | spatial_merge_size: int,
210 | ) -> list[torch.Tensor]:
211 | # image_sizes specified in tokens
212 | tokens_per_image = [h * w for h, w in image_sizes]
213 | d = x.shape[-1]
214 | all_img_sub_grids: list[torch.Tensor] = []
215 | sub_grid_size = spatial_merge_size
216 |
217 | for image_index, image_tokens in enumerate(x.split(tokens_per_image)):
218 | # Reshape image_tokens into a 2D grid
219 | h, w = image_sizes[image_index]
220 | image_grid = image_tokens.view(h, w, d).permute(2, 0, 1)[None, :, :, :] # 1 x d x h x w
221 | sub_grids = torch.nn.functional.unfold(image_grid, kernel_size=sub_grid_size, stride=sub_grid_size)
222 | sub_grids = sub_grids.view(
223 | 1, d, sub_grid_size, sub_grid_size, -1
224 | ) # 1 x d x sub_grid_size x sub_grid_size x n_patches
225 |
226 | all_img_sub_grids.append(sub_grids[0])
227 |
228 | return all_img_sub_grids
229 |
--------------------------------------------------------------------------------
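A hedged sketch of the patch-merging path above: a 4x4 grid of patch tokens with spatial_merge_size=2 is cut by get_sub_grids into four 2x2 sub-grids, which PatchMerger flattens and projects back to the encoder dimension, leaving 4x fewer tokens. The feature dimension is a toy value.

import torch

from mistral_inference.vision_encoder import PatchMerger, get_sub_grids

h, w, d = 4, 4, 6
tokens = torch.randn(h * w, d)  # flattened patch tokens for a single image

sub_grids = get_sub_grids(x=tokens, image_sizes=[(h, w)], spatial_merge_size=2)
assert sub_grids[0].shape == (d, 2, 2, (h // 2) * (w // 2))  # d x grid x grid x n_sub_grids

merger = PatchMerger(vision_encoder_dim=d, spatial_merge_size=2)
merged = merger(tokens, image_sizes=[(h, w)])
assert merged.shape == (h * w // 4, d)  # four patches merged into one token each
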
/tests/test_generate.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import numpy as np
4 | import torch
5 | from mistral_inference.args import VisionEncoderArgs
6 | from mistral_inference.generate import generate_mamba
7 | from mistral_inference.main import generate
8 | from mistral_inference.mamba import Mamba, MambaArgs
9 | from mistral_inference.transformer import Transformer, TransformerArgs
10 |
11 |
12 | class DebugTokenizer:
13 | @property
14 | def bos_id(self) -> int:
15 | return 0
16 |
17 | @property
18 | def eos_id(self) -> int:
19 | return 1
20 |
21 | @property
22 | def pad_id(self) -> int:
23 | return -1
24 |
25 | def encode(self, s: str, bos: bool = True) -> List[int]:
26 | assert isinstance(s, str)
27 | t = [int(x) for x in s.split()]
28 | if bos:
29 | t = [self.bos_id, *t]
30 | return t
31 |
32 | def decode(self, t: List[int]) -> str:
33 | return " ".join([str(x) for x in t])
34 |
35 |
36 | def test_generation_transformer() -> None:
37 | torch.manual_seed(42)
38 |
39 | sequences = ["1 2 3 4 5 6 7", "0 1 2", "12 13 14", "2 4 34"]
40 | args = TransformerArgs(
41 | dim=512,
42 | n_layers=1,
43 | head_dim=128,
44 | hidden_dim=2048,
45 | n_heads=4,
46 | n_kv_heads=2,
47 | norm_eps=1e-5,
48 | vocab_size=32_000,
49 | max_batch_size=len(sequences),
50 | )
51 | model = Transformer(args).to("cuda", dtype=torch.float32)
52 | tokenizer = DebugTokenizer()
53 |
54 | encoded = [tokenizer.encode(s, bos=True) for s in sequences]
55 | toks, all_logprobs_old = generate(encoded, model, temperature=0.0, max_tokens=7)
56 |
57 | # concat generated and prompt
58 | encoded = [e + t for e, t in zip(encoded, toks)]
59 |
60 | generated, all_logprobs_new = generate(encoded, model, temperature=0.0, max_tokens=0)
61 |
62 | assert generated == []
63 |
64 | # Verify that logprobs are the same
65 | assert len(sequences) == len(all_logprobs_old) == len(all_logprobs_new)
66 | for lp_old, lp_new in zip(all_logprobs_old, all_logprobs_new):
67 | assert all([abs(x - y) < 5e-4 for x, y in zip(lp_old, lp_new)]), f"\n{lp_old}\n{lp_new}"
68 |
69 | print("All tests passed.")
70 |
71 |
72 | def test_generation_pixtral() -> None:
73 | torch.manual_seed(42)
74 | gen = np.random.default_rng(seed=42)
75 |
76 | sequences = ["1 2 2 2 2 4 5 6 7", "12 13 14", "2 2 2 2 7 8 9"]
77 | images = [[gen.normal(size=(3, 4, 4))], [], [gen.normal(size=(3, 4, 4))]]
78 | args = TransformerArgs(
79 | dim=512,
80 | n_layers=1,
81 | head_dim=128,
82 | hidden_dim=2048,
83 | n_heads=4,
84 | n_kv_heads=2,
85 | norm_eps=1e-5,
86 | vocab_size=32_000,
87 | max_batch_size=len(sequences),
88 | vision_encoder=VisionEncoderArgs(
89 | hidden_size=128,
90 | num_channels=3,
91 | image_size=4,
92 | patch_size=2,
93 | intermediate_size=256,
94 | num_hidden_layers=1,
95 | num_attention_heads=2,
96 | rope_theta=10000,
97 | image_token_id=2,
98 | ),
99 | )
100 | model = Transformer(args).to("cuda", dtype=torch.float32)
101 | tokenizer = DebugTokenizer()
102 |
103 | encoded = [tokenizer.encode(s, bos=True) for s in sequences]
104 | toks, all_logprobs_old = generate(encoded, model, images=images, temperature=0.0, max_tokens=7)
105 |
106 | # concat generated and prompt
107 | encoded = [e + t for e, t in zip(encoded, toks)]
108 |
109 | generated, all_logprobs_new = generate(encoded, model, images=images, temperature=0.0, max_tokens=0)
110 |
111 | assert generated == []
112 |
113 | # Verify that logprobs are the same
114 | assert len(sequences) == len(all_logprobs_old) == len(all_logprobs_new)
115 | for lp_old, lp_new in zip(all_logprobs_old, all_logprobs_new):
116 | assert all([abs(x - y) < 5e-4 for x, y in zip(lp_old, lp_new)]), f"\n{lp_old}\n{lp_new}"
117 |
118 | print("All tests passed.")
119 |
120 |
121 | def test_generation_pixtral_patch_merger() -> None:
122 | torch.manual_seed(42)
123 | gen = np.random.default_rng(seed=42)
124 |
125 | sequences = ["1 2 2 2 2 4 5 6 7", "12 13 14", "2 2 2 2 7 8 9"]
126 | images = [[gen.normal(size=(3, 8, 8))], [], [gen.normal(size=(3, 8, 8))]]
127 | args = TransformerArgs(
128 | dim=512,
129 | n_layers=1,
130 | head_dim=128,
131 | hidden_dim=2048,
132 | n_heads=4,
133 | n_kv_heads=2,
134 | norm_eps=1e-5,
135 | vocab_size=32_000,
136 | max_batch_size=len(sequences),
137 | vision_encoder=VisionEncoderArgs(
138 | hidden_size=128,
139 | num_channels=3,
140 | image_size=8,
141 | patch_size=2,
142 | intermediate_size=256,
143 | num_hidden_layers=1,
144 | num_attention_heads=2,
145 | rope_theta=10000,
146 | image_token_id=2,
147 | adapter_bias=False,
148 | spatial_merge_size=2,
149 | add_pre_mm_projector_layer_norm=True,
150 | mm_projector_id="patch_merge",
151 | ),
152 | )
153 | model = Transformer(args).to("cuda", dtype=torch.float32)
154 | tokenizer = DebugTokenizer()
155 |
156 | encoded = [tokenizer.encode(s, bos=True) for s in sequences]
157 | toks, all_logprobs_old = generate(encoded, model, images=images, temperature=0.0, max_tokens=7)
158 |
159 | # concat generated and prompt
160 | encoded = [e + t for e, t in zip(encoded, toks)]
161 |
162 | generated, all_logprobs_new = generate(encoded, model, images=images, temperature=0.0, max_tokens=0)
163 |
164 | assert generated == []
165 |
166 | # Verify that logprobs are the same
167 | assert len(sequences) == len(all_logprobs_old) == len(all_logprobs_new)
168 | for lp_old, lp_new in zip(all_logprobs_old, all_logprobs_new):
169 | assert all([abs(x - y) < 5e-4 for x, y in zip(lp_old, lp_new)]), f"\n{lp_old}\n{lp_new}"
170 |
171 | print("All tests passed.")
172 |
173 |
174 | def test_generation_mamba() -> None:
175 | torch.manual_seed(42)
176 |
177 | sequences = ["1 2 3 4 5 6 7"]
178 | args = MambaArgs(
179 | dim=512,
180 | n_layers=1,
181 | n_groups=1,
182 | rms_norm=True,
183 | residual_in_fp32=True,
184 | fused_add_norm=True,
185 | pad_vocab_size_multiple=1,
186 | tie_embeddings=False,
187 | vocab_size=32768,
188 | )
189 | model = Mamba(args).to("cuda", dtype=torch.float32)
190 | tokenizer = DebugTokenizer()
191 |
192 | encoded = [tokenizer.encode(s, bos=True) for s in sequences]
193 | toks, all_logprobs_old = generate_mamba(encoded, model, temperature=0.0, max_tokens=7)
194 |
195 | assert len(toks[0]) == 7
196 | assert toks == [[25574, 14821, 11843, 23698, 12735, 23522, 27542]]
197 |
198 |
199 | def test_chunks_transformer() -> None:
200 | torch.manual_seed(42)
201 |
202 | sequences = [
203 | " ".join([str(i) for i in range(7)]),
204 | " ".join([str(i) for i in range(9, 0, -1)]),
205 | ]
206 | args = TransformerArgs(
207 | dim=512,
208 | n_layers=1,
209 | head_dim=128,
210 | hidden_dim=2048,
211 | n_heads=4,
212 | n_kv_heads=2,
213 | norm_eps=1e-5,
214 | vocab_size=32_000,
215 | max_batch_size=3,
216 | )
217 | model = Transformer(args).to("cuda", dtype=torch.float32)
218 | tokenizer = DebugTokenizer()
219 |
220 | encoded = [tokenizer.encode(s, bos=True) for s in sequences]
221 | toks, all_logprobs_old = generate(encoded, model, temperature=0.0, max_tokens=8)
222 |
223 | # concat generated and prompt
224 | encoded = [e + t for e, t in zip(encoded, toks)]
225 |
226 | generated, all_logprobs_new = generate(encoded, model, temperature=0.0, max_tokens=0, chunk_size=5)
227 | assert len(generated) == 0
228 |
229 | for lp_old, lp_new in zip(all_logprobs_old, all_logprobs_new):
230 | assert all([abs(x - y) < 5e-4 for x, y in zip(lp_old, lp_new)]), f"\n{lp_old}\n{lp_new}"
231 |
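232 | 
233 | # A minimal, hypothetical convenience entry point (a sketch, not a required part
234 | # of the test suite) so the checks above can be run directly with
235 | # `python tests/test_generate.py`. It assumes a CUDA device is available,
236 | # since every test above moves its model to "cuda".
237 | if __name__ == "__main__":
238 |     test_generation_pixtral()
239 |     test_generation_pixtral_patch_merger()
240 |     test_generation_mamba()
241 |     test_chunks_transformer()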
--------------------------------------------------------------------------------
/tutorials/getting_started.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Getting Started with `mistral-inference`\n",
8 | "\n",
9 | "This notebook will guide you through the process of running Mistral models locally. We will cover the following: \n",
10 | "- How to chat with Mistral 7B Instruct\n",
11 | "- How to run Mistral 7B Instruct with function calling capabilities\n",
12 | "\n",
13 | "We recommend using a GPU such as the A100 to run this notebook. "
14 | ]
15 | },
16 | {
17 | "cell_type": "code",
18 | "execution_count": null,
19 | "metadata": {
20 | "id": "G6tXvIsQenpI"
21 | },
22 | "outputs": [],
23 | "source": [
24 | "!pip install mistral-inference"
25 | ]
26 | },
27 | {
28 | "cell_type": "markdown",
29 | "metadata": {},
30 | "source": [
31 | "## Download Mistral 7B Instruct"
32 | ]
33 | },
34 | {
35 | "cell_type": "code",
36 | "execution_count": null,
37 | "metadata": {
38 | "colab": {
39 | "background_save": true
40 | },
41 | "id": "4ytmRt0WQeMW"
42 | },
43 | "outputs": [],
44 | "source": [
45 | "!wget https://models.mistralcdn.com/mistral-7b-v0-3/mistral-7B-Instruct-v0.3.tar"
46 | ]
47 | },
48 | {
49 | "cell_type": "code",
50 | "execution_count": null,
51 | "metadata": {
52 | "id": "eRZg_8wvs5A6"
53 | },
54 | "outputs": [],
55 | "source": [
56 | "!DIR=$HOME/mistral_7b_instruct_v3 && mkdir -p $DIR && tar -xf mistral-7B-Instruct-v0.3.tar -C $DIR"
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "execution_count": null,
62 | "metadata": {
63 | "id": "7CN8gShDf65M"
64 | },
65 | "outputs": [],
66 | "source": [
67 | "!ls ~/mistral_7b_instruct_v3"
68 | ]
69 | },
70 | {
71 | "cell_type": "markdown",
72 | "metadata": {},
73 | "source": [
74 | "## Chat with the model"
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "execution_count": null,
80 | "metadata": {},
81 | "outputs": [],
82 | "source": [
83 | "import os \n",
84 | "\n",
85 | "from mistral_inference.transformer import Transformer\n",
86 | "from mistral_inference.generate import generate\n",
87 | "\n",
88 | "from mistral_common.tokens.tokenizers.mistral import MistralTokenizer\n",
89 | "from mistral_common.protocol.instruct.messages import UserMessage\n",
90 | "from mistral_common.protocol.instruct.request import ChatCompletionRequest\n",
91 | "\n",
92 | "# load tokenizer\n",
93 | "mistral_tokenizer = MistralTokenizer.from_file(os.path.expanduser(\"~\")+\"/mistral_7b_instruct_v3/tokenizer.model.v3\")\n",
94 | "# chat completion request\n",
95 | "completion_request = ChatCompletionRequest(messages=[UserMessage(content=\"Explain Machine Learning to me in a nutshell.\")])\n",
96 | "# encode message\n",
97 | "tokens = mistral_tokenizer.encode_chat_completion(completion_request).tokens\n",
98 | "# load model\n",
99 | "model = Transformer.from_folder(os.path.expanduser(\"~\")+\"/mistral_7b_instruct_v3\")\n",
100 | "# generate results\n",
101 | "out_tokens, _ = generate([tokens], model, max_tokens=64, temperature=0.0, eos_id=mistral_tokenizer.instruct_tokenizer.tokenizer.eos_id)\n",
102 | "# decode generated tokens\n",
103 | "result = mistral_tokenizer.instruct_tokenizer.tokenizer.decode(out_tokens[0])\n",
104 | "print(result)"
105 | ]
106 | },
107 | {
108 | "cell_type": "markdown",
109 | "metadata": {
110 | "id": "ce4woS3LkgZ9"
111 | },
112 | "source": [
113 | "## Function calling\n",
114 | "\n",
115 | "Mistral 7B Instruct v3 also supports function calling!"
116 | ]
117 | },
118 | {
119 | "cell_type": "markdown",
120 | "metadata": {
121 | "id": "TKfPiEwNk1kh"
122 | },
123 | "source": [
124 | "Let's start by creating a function calling example."
125 | ]
126 | },
127 | {
128 | "cell_type": "code",
129 | "execution_count": null,
130 | "metadata": {
131 | "id": "0PJdwvDEk3dl"
132 | },
133 | "outputs": [],
134 | "source": [
135 | "from mistral_common.protocol.instruct.messages import UserMessage\n",
136 | "from mistral_common.protocol.instruct.request import ChatCompletionRequest\n",
137 | "from mistral_common.protocol.instruct.tool_calls import Function, Tool\n",
138 | "\n",
139 | "completion_request = ChatCompletionRequest(\n",
140 | " tools=[\n",
141 | " Tool(\n",
142 | " function=Function(\n",
143 | " name=\"get_current_weather\",\n",
144 | " description=\"Get the current weather\",\n",
145 | " parameters={\n",
146 | " \"type\": \"object\",\n",
147 | " \"properties\": {\n",
148 | " \"location\": {\n",
149 | " \"type\": \"string\",\n",
150 | " \"description\": \"The city and state, e.g. San Francisco, CA\",\n",
151 | " },\n",
152 | " \"format\": {\n",
153 | " \"type\": \"string\",\n",
154 | " \"enum\": [\"celsius\", \"fahrenheit\"],\n",
155 | " \"description\": \"The temperature unit to use. Infer this from the users location.\",\n",
156 | " },\n",
157 | " },\n",
158 | " \"required\": [\"location\", \"format\"],\n",
159 | " },\n",
160 | " )\n",
161 | " )\n",
162 | " ],\n",
163 | " messages=[\n",
164 | " UserMessage(content=\"What's the weather like today in Paris?\"),\n",
165 | " ],\n",
166 | ")"
167 | ]
168 | },
169 | {
170 | "cell_type": "markdown",
171 | "metadata": {
172 | "id": "bG6ZeZUylpBW"
173 | },
174 | "source": [
175 | "Since we have already loaded the tokenizer and the model in the example above, we will skip these steps here. \n",
176 | "\n",
177 | "Now we can encode the message with our tokenizer using `MistralTokenizer`. "
178 | ]
179 | },
180 | {
181 | "cell_type": "code",
182 | "execution_count": null,
183 | "metadata": {
184 | "id": "Ii8q-JNClwiq"
185 | },
186 | "outputs": [],
187 | "source": [
188 | "from mistral_common.tokens.tokenizers.mistral import MistralTokenizer\n",
189 | "\n",
190 | "tokens = mistral_tokenizer.encode_chat_completion(completion_request).tokens"
191 | ]
192 | },
193 | {
194 | "cell_type": "markdown",
195 | "metadata": {
196 | "id": "NrueDujkmJT4"
197 | },
198 | "source": [
199 | "and run `generate` to get a response. Don't forget to pass the EOS id!"
200 | ]
201 | },
202 | {
203 | "cell_type": "code",
204 | "execution_count": null,
205 | "metadata": {
206 | "id": "GWJYO43rl0V8"
207 | },
208 | "outputs": [],
209 | "source": [
210 | "from mistral_inference.generate import generate\n",
211 | "\n",
212 | "out_tokens, _ = generate([tokens], model, max_tokens=64, temperature=0.0, eos_id=mistral_tokenizer.instruct_tokenizer.tokenizer.eos_id)"
213 | ]
214 | },
215 | {
216 | "cell_type": "markdown",
217 | "metadata": {
218 | "id": "v7baJ1msmPMv"
219 | },
220 | "source": [
221 | "Finally, we can decode the generated tokens."
222 | ]
223 | },
224 | {
225 | "cell_type": "code",
226 | "execution_count": null,
227 | "metadata": {
228 | "id": "RKhryfBWmHon"
229 | },
230 | "outputs": [],
231 | "source": [
232 | "result = mistral_tokenizer.instruct_tokenizer.tokenizer.decode(out_tokens[0])\n",
233 | "result"
234 | ]
235 | }
236 | ],
237 | "metadata": {
238 | "accelerator": "GPU",
239 | "colab": {
240 | "gpuType": "L4",
241 | "machine_shape": "hm",
242 | "provenance": []
243 | },
244 | "kernelspec": {
245 | "display_name": "Python 3 (ipykernel)",
246 | "language": "python",
247 | "name": "python3"
248 | },
249 | "language_info": {
250 | "codemirror_mode": {
251 | "name": "ipython",
252 | "version": 3
253 | },
254 | "file_extension": ".py",
255 | "mimetype": "text/x-python",
256 | "name": "python",
257 | "nbconvert_exporter": "python",
258 | "pygments_lexer": "ipython3",
259 | "version": "3.11.8"
260 | }
261 | },
262 | "nbformat": 4,
263 | "nbformat_minor": 4
264 | }
265 |
--------------------------------------------------------------------------------