├── .gitignore ├── LICENSE ├── README.md ├── main.py ├── mistral ├── __init__.py ├── cache.py ├── model.py ├── moe.py ├── rope.py └── tokenizer.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | # Ignore model weights 163 | mistral-7B-v0.1/ 164 | mistral-7B-v0.1.tar 165 | 166 | mixtral-8x7b-32kseqlen/ 167 | 168 | .vscode/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Notes 2 | 3 | Source code of the mistral model with my personal comments to make it easier for everyone to understand the code. 4 | Please check the [original repository](https://github.com/mistralai/mistral-src) for the most up-to-date version of the code. 
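## Quick usage

A minimal sketch of how to load the weights and sample a completion, mirroring what `interactive()` in `main.py` does. The weights folder name below is just an example (it matches the directory ignored in `.gitignore`), and a CUDA GPU is assumed because `Transformer.from_folder` defaults to `device="cuda"`; treat this as an untested starting point rather than official usage.

```python
from pathlib import Path

from main import generate
from mistral.model import Transformer
from mistral.tokenizer import Tokenizer

# Hypothetical path: point this at the folder containing params.json,
# consolidated.00.pth and tokenizer.model.
model_path = Path("mistral-7B-v0.1")

tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
transformer = Transformer.from_folder(model_path, max_batch_size=3)

completions, _logprobs = generate(
    ["The quick brown fox"],
    transformer,
    tokenizer,
    max_tokens=35,
    temperature=0.7,
)
print(completions[0])
```

The `fire`-based CLI in `main.py` exposes the same code paths, e.g. `python main.py interactive <path_to_weights>`.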
5 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from mistral.cache import RotatingBufferCache 2 | import logging 3 | import torch 4 | import fire 5 | from typing import List 6 | from pathlib import Path 7 | 8 | from mistral.model import Transformer 9 | from mistral.tokenizer import Tokenizer 10 | 11 | 12 | def sample_top_p(probs: torch.Tensor, p: float): 13 | assert 0 <= p <= 1 14 | 15 | probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) 16 | probs_sum = torch.cumsum(probs_sort, dim=-1) 17 | mask = probs_sum - probs_sort > p 18 | probs_sort[mask] = 0.0 19 | probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) 20 | next_token = torch.multinomial(probs_sort, num_samples=1) 21 | return torch.gather(probs_idx, -1, next_token) 22 | 23 | 24 | def sample(logits: torch.Tensor, temperature: float, top_p: float): 25 | if temperature > 0: 26 | probs = torch.softmax(logits / temperature, dim=-1) 27 | next_token = sample_top_p(probs, top_p) 28 | else: 29 | next_token = torch.argmax(logits, dim=-1).unsqueeze(0) 30 | 31 | return next_token.reshape(-1) 32 | 33 | 34 | @torch.inference_mode() 35 | def generate(prompts: List[str], model: Transformer, tokenizer: Tokenizer, *, max_tokens: int, temperature: float, chunk_size: int = None): 36 | model = model.eval() 37 | batch_size, vocabulary_size = len(prompts), model.args.vocab_size # Batch_Size, Vocabulary_Size 38 | 39 | # Tokenize 40 | encoded_prompts = [tokenizer.encode(prompt, bos=True) for prompt in prompts] 41 | # Indicates the number of tokens in each prompt 42 | prompts_sequence_lengths = [len(x) for x in encoded_prompts] 43 | 44 | # Cache 45 | # Indicates the size of the rotating cache 46 | cache_window = max(prompts_sequence_lengths) + max_tokens 47 | 48 | # If the cache window is larger than the sliding window, the cache window is set to the sliding window 49 | if model.args.sliding_window is not None and cache_window > model.args.sliding_window: 50 | cache_window = model.args.sliding_window 51 | 52 | # Create the cache 53 | cache = RotatingBufferCache( 54 | model.n_local_layers, 55 | model.args.max_batch_size, 56 | cache_window, 57 | model.args.n_kv_heads, 58 | model.args.head_dim, 59 | ) 60 | 61 | cache.to(device=model.device, dtype=model.dtype) 62 | cache.reset() 63 | 64 | # Bookkeeping 65 | logprobs = [[] for _ in range(batch_size)] 66 | last_token_prelogits = None 67 | 68 | # One chunk if size not specified 69 | max_prompt_len = max(prompts_sequence_lengths) 70 | if chunk_size is None: 71 | chunk_size = max_prompt_len 72 | 73 | # Encode prompt by chunks 74 | for s in range(0, max_prompt_len, chunk_size): 75 | prompt_chunks = [p[s:s+chunk_size] for p in encoded_prompts] # Extract the tokens belonging to the current chunk 76 | assert all(len(p) > 0 for p in prompt_chunks) 77 | prelogits = model.forward( 78 | torch.tensor(sum(prompt_chunks, []), device=model.device, dtype=torch.long), # Concatenate all the tokens in the current chunk (of all the prompts) in a single tensor 79 | seqlens=[len(p) for p in prompt_chunks], 80 | cache=cache 81 | ) 82 | logits = torch.log_softmax(prelogits, dim=-1) 83 | 84 | if last_token_prelogits is not None: 85 | # Pass > 1 86 | last_token_logits = torch.log_softmax(last_token_prelogits, dim=-1) 87 | for i_seq in range(batch_size): 88 | logprobs[i_seq].append(last_token_logits[i_seq, prompt_chunks[i_seq][0]].item()) 89 | 90 | offset = 0 91 | for i_seq, sequence in 
enumerate(prompt_chunks): 92 | logprobs[i_seq].extend([logits[offset + i, sequence[i + 1]].item() for i in range(len(sequence) - 1)]) 93 | offset += len(sequence) 94 | 95 | last_token_prelogits = prelogits.index_select(0, torch.tensor([len(p) for p in prompt_chunks], device=prelogits.device).cumsum(dim=0) - 1) 96 | assert last_token_prelogits.shape == (batch_size, vocabulary_size) 97 | 98 | # decode 99 | generated_tokens = [] 100 | assert last_token_prelogits is not None 101 | for i_token in range(max_tokens): 102 | next_token = sample(last_token_prelogits, temperature=temperature, top_p=0.8) 103 | 104 | last_token_logits = torch.log_softmax(last_token_prelogits, dim=-1) 105 | for i in range(batch_size): 106 | logprobs[i].append(last_token_logits[i, next_token[i]].item()) 107 | 108 | generated_tokens.append(next_token[:, None]) 109 | last_token_prelogits = model.forward(next_token, seqlens=[1] * len(prompts), cache=cache) 110 | assert last_token_prelogits.shape == (batch_size, vocabulary_size) 111 | 112 | generated_words = [] 113 | if generated_tokens: 114 | generated_tokens = torch.cat(generated_tokens, 1) 115 | for i, x in enumerate(encoded_prompts): 116 | generated_words.append(tokenizer.decode(x + generated_tokens[i].tolist())) 117 | 118 | return generated_words, logprobs 119 | 120 | 121 | def interactive(model_path: str, max_tokens: int = 35, temperature: float = 0.7, instruct: bool = False): 122 | tokenizer = Tokenizer(str(Path(model_path) / "tokenizer.model")) 123 | transformer = Transformer.from_folder(Path(model_path), max_batch_size=3) 124 | 125 | while True: 126 | prompt = input("Prompt: ") 127 | if instruct: 128 | prompt = f"[INST] {prompt} [/INST]" 129 | res, _logprobs = generate( 130 | [prompt], 131 | transformer, 132 | tokenizer, 133 | max_tokens=max_tokens, 134 | temperature=temperature, 135 | ) 136 | print(res[0]) 137 | print("=====================") 138 | 139 | 140 | def demo( 141 | model_path: str, max_tokens: int = 35, temperature: float = 0, num_pipeline_ranks=1 142 | ): 143 | if num_pipeline_ranks > 1: 144 | torch.distributed.init_process_group() 145 | torch.cuda.set_device(torch.distributed.get_rank()) 146 | should_print = torch.distributed.get_rank() == 0 147 | else: 148 | should_print = True 149 | tokenizer = Tokenizer(str(Path(model_path) / "tokenizer.model")) 150 | transformer = Transformer.from_folder( 151 | Path(model_path), max_batch_size=3, num_pipeline_ranks=num_pipeline_ranks 152 | ) 153 | 154 | res, _logprobs = generate( 155 | [ 156 | "This is a test made by me with the help of an AI assistant. I also like to play with videogames. Can you recommend me one game to play with?", 157 | "This is another great test", 158 | "This is a third test, mistral AI is very good at testing. 
", 159 | ], 160 | transformer, 161 | tokenizer, 162 | max_tokens=max_tokens, 163 | temperature=temperature 164 | ) 165 | if should_print: 166 | for x,l in zip(res, _logprobs): 167 | print(x) 168 | logging.debug('Logprobs: %s',l) 169 | print("=====================") 170 | 171 | if __name__ == "__main__": 172 | logging.basicConfig(level=logging.INFO) 173 | fire.Fire({ 174 | "interactive": interactive, 175 | "demo": demo, 176 | }) 177 | -------------------------------------------------------------------------------- /mistral/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkproj/mistral-src-commented/dd6211c2fc41c1807970e0b65909fc9b19fd18cb/mistral/__init__.py -------------------------------------------------------------------------------- /mistral/cache.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import List, Tuple 3 | from dataclasses import dataclass 4 | 5 | from xformers.ops.fmha.attn_bias import ( 6 | AttentionBias, 7 | BlockDiagonalCausalMask, 8 | BlockDiagonalCausalWithOffsetPaddedKeysMask, 9 | BlockDiagonalMask, 10 | ) 11 | 12 | 13 | @dataclass 14 | class RotatingCacheInputMetadata: 15 | # rope absolute positions 16 | positions: torch.Tensor 17 | # which elements in the sequences need to be cached 18 | to_cache_mask: torch.Tensor 19 | # how many elements are cached per sequence 20 | cached_elements: torch.Tensor 21 | # where tokens should go in the cache 22 | cache_positions: torch.Tensor 23 | 24 | # if prefill, use block diagonal causal mask 25 | # else use causal with padded key mask 26 | prefill: bool 27 | mask: AttentionBias # Mask for the attention 28 | seqlens: List[int] 29 | 30 | 31 | def interleave_list(l1: List[torch.Tensor], l2: List[torch.Tensor]): 32 | assert len(l1) == len(l2) 33 | return [v for pair in zip(l1, l2) for v in pair] 34 | 35 | 36 | def unrotate(cache: torch.Tensor, seqlen: int) -> torch.Tensor: # seqlen is the total number of tokens cached so far, including the overwritten one. 
This is needed to calculate the rotation point of the cache 37 | assert cache.ndim == 3 # (Sliding_Window_Size, Num_Heads, Head_Dim) 38 | position = seqlen % cache.shape[0] # This is the pivot point around which we need to rotate the cache 39 | if seqlen < cache.shape[0]: # If the total sequence length so far is smaller than the cache size, then just return the first seqlen elements, as the cache didn't have any rotations yet 40 | return cache[:seqlen] 41 | elif position == 0: 42 | return cache 43 | else: 44 | return torch.cat([cache[position:], cache[:position]], dim=0) # Select the unrotated elements from the cache around the pivot point 45 | 46 | 47 | class CacheView: 48 | def __init__(self, cache_k: torch.Tensor, cache_v: torch.Tensor, metadata: RotatingCacheInputMetadata, kv_seqlens: torch.Tensor): 49 | self.cache_k = cache_k 50 | self.cache_v = cache_v 51 | self.kv_seqlens = kv_seqlens 52 | self.metadata = metadata 53 | 54 | def update(self, xk: torch.Tensor, xv: torch.Tensor): 55 | """ 56 | to_cache_mask masks the last [sliding_window] tokens in each sequence 57 | """ 58 | n_kv_heads, head_dim = self.cache_k.shape[-2:] 59 | flat_cache_k = self.cache_k.view(-1, n_kv_heads, head_dim) # (Max_Batch_Size, Sliding_Window_Size, N_Heads_KV, Head_Dim) --> (Max_Batch_Size * Sliding_Window_Size, N_Heads_KV, Head_Dim) 60 | flat_cache_v = self.cache_v.view(-1, n_kv_heads, head_dim) # (Max_Batch_Size, Sliding_Window_Size, N_Heads_KV, Head_Dim) --> (Max_Batch_Size * Sliding_Window_Size, N_Heads_KV, Head_Dim) 61 | # Copies from the xk and xv tensors to the cache tensors, based on the cache positions and the items to cache (to_cache_mask) 62 | flat_cache_k.index_copy_(0, self.metadata.cache_positions, xk[self.metadata.to_cache_mask]) 63 | flat_cache_v.index_copy_(0, self.metadata.cache_positions, xv[self.metadata.to_cache_mask]) 64 | 65 | def interleave_kv(self, xk: torch.Tensor, xv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 66 | """ 67 | This is a naive implementation and not optimized for speed. 
68 | """ 69 | assert xk.ndim == xv.ndim == 3 # (B * T, H, D) 70 | assert xk.shape == xv.shape 71 | 72 | if all([s == 0 for s in self.metadata.seqlens]): 73 | # No cache to interleave 74 | return xk, xv 75 | 76 | # Make it a list of [(Seq, N_Heads_KV, Head_Dim)] 77 | xk = torch.split(xk, self.metadata.seqlens) # (Seq1+Seq2+Seq3, N_Heads_KV, Head_Dim) --> [(Seq1, N_Heads_KV, Head_Dim), (Seq2, N_Heads_KV, Head_Dim), (Seq3, N_Heads_KV, Head_Dim)] 78 | xv = torch.split(xv, self.metadata.seqlens) # (Seq1+Seq2+Seq3, N_Heads_KV, Head_Dim) --> [(Seq1, N_Heads_KV, Head_Dim), (Seq2, N_Heads_KV, Head_Dim), (Seq3, N_Heads_KV, Head_Dim)] 79 | assert len(xk) == len(self.kv_seqlens), f"Batch size is {len(self.kv_seqlens)}, got {len(xk)}" 80 | 81 | # Order elements in cache by position by unrotating 82 | cache_k = [unrotate(t, s) for t, s in zip(self.cache_k, self.kv_seqlens)] # Currently cached elements, already unrotated, one for each prompt 83 | cache_v = [unrotate(t, s) for t, s in zip(self.cache_v, self.kv_seqlens)] # Currently cached elements, already unrotated, one for each prompt 84 | 85 | interleaved_k = interleave_list(cache_k, xk) # Appends the incoming keys and values to the currently cached elements (one for each prompt) 86 | interleaved_v = interleave_list(cache_v, xv) # Appends the incoming keys and values to the currently cached elements (one for each prompt) 87 | 88 | return torch.cat(interleaved_k, dim=0), torch.cat(interleaved_v, dim=0) 89 | 90 | @property 91 | def sliding_window(self): 92 | return self.cache_k.shape[1] 93 | 94 | @property 95 | def key(self) -> torch.Tensor: 96 | return self.cache_k[:len(self.kv_seqlens)] 97 | 98 | @property 99 | def value(self) -> torch.Tensor: 100 | return self.cache_v[:len(self.kv_seqlens)] 101 | 102 | @property 103 | def prefill(self): 104 | return self.metadata.prefill 105 | 106 | @property 107 | def mask(self): 108 | return self.metadata.mask 109 | 110 | 111 | class RotatingBufferCache: 112 | """ 113 | This is an example that implements a less naive rotating buffer cache, allowing for variable length sequences. 
114 | Allocated cache is rectangular which is wasteful (see PagedAttention for better mechanisms) 115 | """ 116 | def __init__(self, n_layers: int, max_batch_size: int, sliding_window: int, n_kv_heads: int, head_dim: int): 117 | 118 | self.sliding_window = sliding_window 119 | self.n_kv_heads = n_kv_heads 120 | self.head_dim = head_dim # model_dim / n_heads 121 | 122 | self.cache_k = torch.empty(( 123 | n_layers, 124 | max_batch_size, 125 | sliding_window, 126 | n_kv_heads, 127 | head_dim 128 | )) 129 | self.cache_v = torch.empty(( 130 | n_layers, 131 | max_batch_size, 132 | sliding_window, 133 | n_kv_heads, 134 | head_dim 135 | )) 136 | 137 | # holds the valid length for each batch element in the cache 138 | self.kv_seqlens = None 139 | 140 | def get_view(self, layer_id: int, metadata: RotatingCacheInputMetadata) -> CacheView: 141 | return CacheView(self.cache_k[layer_id], self.cache_v[layer_id], metadata, self.kv_seqlens) 142 | 143 | def reset(self): 144 | self.kv_seqlens = None 145 | 146 | def init_kvseqlens(self, batch_size: int): 147 | self.kv_seqlens = torch.zeros((batch_size,), device=self.device, dtype=torch.long) 148 | 149 | @property 150 | def device(self): 151 | return self.cache_k.device 152 | 153 | def to(self, device: torch.device, dtype: torch.dtype): 154 | self.cache_k = self.cache_k.to(device=device, dtype=dtype) 155 | self.cache_v = self.cache_v.to(device=device, dtype=dtype) 156 | 157 | return self 158 | 159 | def update_seqlens(self, seqlens: List[int]): 160 | self.kv_seqlens += torch.tensor(seqlens, device=self.device, dtype=torch.long) 161 | 162 | def get_input_metadata(self, seqlens: List[int]) -> RotatingCacheInputMetadata: 163 | """ 164 | inpput = seqlens [5,7,2] // seqpos [0, 1, 3] // sliding_window 3 165 | --> only cache last 3 tokens in each sequence 166 | - to_cache_mask = [0 0 1 1 1 | 0 0 0 0 1 1 1 | 1 1] 167 | - cached_elements = [3 | 3 | 2] 168 | --> absolute positions are used for rope 169 | - positions = [0 1 2 3 4 | 1 2 3 4 5 6 7 | 3 4] 170 | --> cache positions are positions cache_masked, modulo sliding_window + batch_idx * sliding_window 171 | - cache_positions = [2 0 1 | 5 3 4 | 6 7] 172 | """ 173 | if self.kv_seqlens is None: 174 | self.init_kvseqlens(len(seqlens)) 175 | assert len(seqlens) == len(self.kv_seqlens), f"Batch size is {len(self.kv_seqlens)}, got {len(seqlens)}, did you forget to reset cache?" 176 | seqpos = self.kv_seqlens.tolist() # Indicates the total length seen by the cache so far (including the overwritten elements) for each prompt 177 | 178 | assert len(seqlens) > 0, seqlens 179 | 180 | # [True] if the token position belongs to the last `sliding_window` positions of the sequence. 
It is always True unless the chunk size is bigger than the sliding window 181 | # Indicates which items in the sequence should be cached (the last `sliding_window` tokens of each sequence) 182 | masks = [ 183 | [x >= seqlen - self.sliding_window for x in range(seqlen)] 184 | for seqlen in seqlens # The sequence length of each input in the batch (so we can understand which token belongs to which prompt) 185 | ] 186 | 187 | # Indicates which items in the sequence should be cached (the last `sliding_window` tokens of each sequence) 188 | # Concatenate all the masks of each prompt in the batch 189 | to_cache_mask = torch.tensor(sum(masks, []), device=self.device, dtype=torch.bool) 190 | 191 | # Number of elements in the mask == True 192 | cached_elements = torch.tensor([sum(mask) for mask in masks], device=self.device, dtype=torch.long) 193 | 194 | # The position of each token in the prompt (all concatenated). It may not start from zero (because for example the first chunk may be 5 tokens and we are now processing the second chunk) 195 | positions = torch.cat([torch.arange(pos, pos + seqlen) for pos, seqlen in zip(seqpos, seqlens)]).to(device=self.device, dtype=torch.long) 196 | 197 | # The index of the batch to which each token (in the concatenated list) belongs to. 198 | batch_idx = torch.tensor(sum([[i]*seqlen for i, seqlen in enumerate(seqlens)], []), device=self.device, dtype=torch.long) 199 | 200 | # Where each token should be placed in the cache (based on the position in the prompt and the batch index) 201 | cache_positions = positions % self.sliding_window + batch_idx * self.sliding_window 202 | 203 | # Indicates if it is the first prefill (only True on the first chunk) 204 | first_prefill = seqpos[0] == 0 205 | # Indicates if it is a subsequent prefill (True from second chunk onwards), but False when generating tokens. 206 | subsequent_prefill = any(seqlen > 1 for seqlen in seqlens) 207 | 208 | if first_prefill: 209 | # For first chunk of prompt. 
It creates an attention mask that is causal for each prompt and also local based on the sliding window size 210 | # https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.fmha.attn_bias.BlockDiagonalMask + local attention based on the sliding window 211 | assert all([pos == 0 for pos in seqpos]), (seqpos) 212 | mask = BlockDiagonalCausalMask.from_seqlens(seqlens).make_local_attention(self.sliding_window) 213 | elif subsequent_prefill: 214 | # For subsequent chunks of prompt 215 | mask = BlockDiagonalMask.from_seqlens( 216 | q_seqlen=seqlens, # Size of the query 217 | kv_seqlen=[s + cached_s.clamp(max=self.sliding_window).item() for (s, cached_s) in zip(seqlens, self.kv_seqlens)] # The total number of keys and values will be the incoming sequence length + the cached elements 218 | ).make_local_attention_from_bottomright(self.sliding_window) 219 | else: # For token generation 220 | mask = BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens( 221 | q_seqlen=seqlens, # Size of the query 222 | kv_padding=self.sliding_window, 223 | kv_seqlen=(self.kv_seqlens + cached_elements).clamp(max=self.sliding_window).tolist() # The total number of keys and values will be the incoming sequence length + the cached elements 224 | ) 225 | 226 | return RotatingCacheInputMetadata( 227 | positions=positions, 228 | to_cache_mask=to_cache_mask, 229 | cached_elements=cached_elements, 230 | cache_positions=cache_positions[to_cache_mask], 231 | prefill=first_prefill or subsequent_prefill, 232 | mask=mask, 233 | seqlens=seqlens, 234 | ) 235 | -------------------------------------------------------------------------------- /mistral/model.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import math 4 | from dataclasses import dataclass 5 | from pathlib import Path 6 | from typing import List, Optional 7 | 8 | import torch 9 | from torch import nn 10 | from simple_parsing.helpers import Serializable 11 | 12 | from mistral.rope import precompute_freqs_cis, apply_rotary_emb 13 | from mistral.cache import CacheView, RotatingBufferCache 14 | from mistral.moe import MoeArgs, MoeLayer 15 | 16 | from xformers.ops.fmha import memory_efficient_attention 17 | 18 | 19 | @dataclass 20 | class ModelArgs(Serializable): 21 | dim: int 22 | n_layers: int 23 | head_dim: int 24 | hidden_dim: int 25 | n_heads: int 26 | n_kv_heads: int 27 | norm_eps: float 28 | vocab_size: int 29 | 30 | max_batch_size: int = 0 31 | 32 | # For rotary embeddings. If not set, will be infered from sliding window. 33 | rope_theta: Optional[float] = None 34 | # If this is set, use sliding window attention rotating cache. 35 | sliding_window: Optional[int] = None 36 | # If this is set, we will use MoE layers instead of dense layers. 
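# For example, the released Mixtral-8x7B configuration uses num_experts=8 and num_experts_per_tok=2
# (quoted from memory; the authoritative values come from the params.json shipped with the weights).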
37 | moe: Optional[MoeArgs] = None 38 | 39 | 40 | @dataclass 41 | class SimpleInputMetadata: 42 | # rope absolute positions 43 | positions: torch.Tensor 44 | 45 | @staticmethod 46 | def from_seqlens(seqlens: List[int], device: torch.device) -> "SimpleInputMetadata": 47 | return SimpleInputMetadata( 48 | positions=torch.cat([torch.arange(0, seqlen) for seqlen in seqlens]).to( 49 | device=device, dtype=torch.long 50 | ) 51 | ) 52 | 53 | 54 | def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int, dim: int): 55 | # repeat_intrleave repeats the given dimension like this: x = torch.tensor([1, 2, 3]) --> torch.repeat_interleave(x, repeats=2, dim=0) --> torch.tensor([1, 1, 2, 2, 3, 3]) 56 | # This is used to repeat the keys and values to match the number of query heads (Grouped Query Attention). 57 | keys = torch.repeat_interleave(keys, repeats=repeats, dim=dim) # (Seq, N_Heads_KV, Head_Dim) --> (Seq, N_Heads, Head_Dim) 58 | values = torch.repeat_interleave(values, repeats=repeats, dim=dim) # (Seq, N_Heads_KV, Head_Dim) --> (Seq, N_Heads, Head_Dim) 59 | return keys, values 60 | 61 | 62 | class Attention(nn.Module): 63 | def __init__(self, args: ModelArgs): 64 | super().__init__() 65 | self.args = args 66 | 67 | self.n_heads: int = args.n_heads 68 | self.head_dim: int = args.head_dim 69 | self.n_kv_heads: int = args.n_kv_heads 70 | 71 | self.repeats = self.n_heads // self.n_kv_heads 72 | 73 | self.scale = self.args.head_dim**-0.5 74 | 75 | self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False) 76 | self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False) 77 | self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False) 78 | self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False) 79 | 80 | def forward( 81 | self, 82 | x: torch.Tensor, 83 | freqs_cis: torch.Tensor, 84 | cache: Optional[CacheView], 85 | ) -> torch.Tensor: 86 | seqlen_sum, _ = x.shape # (Seq, Dim) 87 | 88 | xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) # (Seq, Dim) --> (Seq, Dim) 89 | xq = xq.view(seqlen_sum, self.n_heads, self.head_dim) # (Seq, Dim) --> (Seq, N_Heads, Head_Dim) 90 | xk = xk.view(seqlen_sum, self.n_kv_heads, self.head_dim) # (Seq, Dim) --> (Seq, N_Heads_KV, Head_Dim) 91 | xv = xv.view(seqlen_sum, self.n_kv_heads, self.head_dim) # (Seq, Dim) --> (Seq, N_Heads_KV, Head_Dim) 92 | xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) # (Seq, N_Heads, Head_Dim), (Seq, N_Heads_KV, Head_Dim) 93 | 94 | if cache is None: 95 | key, val = xk, xv 96 | elif cache.prefill: 97 | key, val = cache.interleave_kv(xk, xv) # Appends the incoming keys and values to the currently cached keys and values (because we need to use them for attention) 98 | cache.update(xk, xv) # Add the incoming keys and values to the cache in the positions indicated by metadata.cache_positions 99 | else: 100 | cache.update(xk, xv) 101 | key, val = cache.key, cache.value # Retrieve the cached keys and values (including the newly added ones) 102 | key = key.view( 103 | seqlen_sum * cache.sliding_window, self.n_kv_heads, self.head_dim 104 | ) 105 | val = val.view( 106 | seqlen_sum * cache.sliding_window, self.n_kv_heads, self.head_dim 107 | ) 108 | 109 | # Repeat keys and values to match number of query heads 110 | key, val = repeat_kv(key, val, self.repeats, dim=1) 111 | 112 | # xformers requires (B=1, S, H, D) 113 | xq, key, val = xq[None, ...], key[None, ...], val[None, ...] 
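# What the fused kernel below computes for each attention head h (a hypothetical reference sketch,
# not how xformers implements it -- the full score matrix is never materialized and the
# AttentionBias built in cache.py is applied as an additive mask):
#   scores_h = xq[0, :, h] @ key[0, :, h].T * self.scale         # (Seq_Q, Seq_KV)
#   out_h    = softmax(scores_h + bias_h, dim=-1) @ val[0, :, h]  # (Seq_Q, Head_Dim)
# The block-diagonal / causal / sliding-window bias is what stops tokens from attending across
# different prompts in the batch or beyond the sliding window.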
114 | output = memory_efficient_attention( 115 | xq, key, val, None if cache is None else cache.mask 116 | ) # Output: (B=1, Seq, N_Heads, Head_Dim) 117 | # (B=1, Seq, N_Heads, Head_Dim) --> (Seq, N_Heads * Head_Dim) --> (Seq, Dim) 118 | return self.wo(output.view(seqlen_sum, self.n_heads * self.head_dim)) 119 | 120 | 121 | class FeedForward(nn.Module): 122 | def __init__(self, args: ModelArgs): 123 | super().__init__() 124 | 125 | self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False) 126 | self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False) 127 | self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False) 128 | 129 | def forward(self, x) -> torch.Tensor: 130 | return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x)) 131 | 132 | 133 | class RMSNorm(torch.nn.Module): 134 | def __init__(self, dim: int, eps: float = 1e-6): 135 | super().__init__() 136 | self.eps = eps 137 | self.weight = nn.Parameter(torch.ones(dim)) 138 | 139 | def _norm(self, x): 140 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 141 | 142 | def forward(self, x): 143 | output = self._norm(x.float()).type_as(x) 144 | return output * self.weight 145 | 146 | 147 | class TransformerBlock(nn.Module): 148 | def __init__(self, args: ModelArgs): 149 | super().__init__() 150 | self.n_heads = args.n_heads 151 | self.dim = args.dim 152 | self.attention = Attention(args) 153 | self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) 154 | self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) 155 | self.args = args 156 | 157 | self.feed_forward: nn.Module 158 | if args.moe is not None: 159 | self.feed_forward = MoeLayer( 160 | experts=[FeedForward(args=args) for _ in range(args.moe.num_experts)], 161 | gate=nn.Linear(args.dim, args.moe.num_experts, bias=False), 162 | moe_args=args.moe, 163 | ) 164 | else: 165 | self.feed_forward = FeedForward(args=args) 166 | 167 | def forward( 168 | self, x: torch.Tensor, freqs_cis: torch.Tensor, cache: Optional[CacheView] 169 | ) -> torch.Tensor: 170 | r = self.attention.forward(self.attention_norm(x), freqs_cis, cache) 171 | h = x + r 172 | r = self.feed_forward.forward(self.ffn_norm(h)) 173 | out = h + r 174 | return out 175 | 176 | 177 | class Transformer(nn.Module): 178 | def __init__( 179 | self, 180 | args: ModelArgs, 181 | pipeline_rank: int = 0, 182 | num_pipeline_ranks: int = 1, 183 | ): 184 | super().__init__() 185 | self.args = args 186 | self.vocab_size = args.vocab_size 187 | self.n_layers = args.n_layers 188 | self._precomputed_freqs_cis: Optional[torch.Tensor] = None 189 | assert self.vocab_size > 0 190 | assert pipeline_rank < num_pipeline_ranks, (pipeline_rank, num_pipeline_ranks) 191 | self.pipeline_rank = pipeline_rank 192 | self.num_pipeline_ranks = num_pipeline_ranks 193 | # Modules specific to some ranks: 194 | self.tok_embeddings: Optional[nn.Embedding] = None 195 | self.norm: Optional[RMSNorm] = None 196 | self.output: Optional[nn.Linear] = None 197 | if pipeline_rank == 0: 198 | self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim) 199 | if pipeline_rank == num_pipeline_ranks - 1: 200 | self.norm = RMSNorm(args.dim, eps=args.norm_eps) 201 | self.output = nn.Linear(args.dim, args.vocab_size, bias=False) 202 | # Initialize all layers but slice off those not of this rank. 
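# Illustrative example: with n_layers=32 and num_pipeline_ranks=2, num_layers_per_rank = ceil(32 / 2) = 16,
# so rank 0 keeps layers 0-15 and rank 1 keeps layers 16-31; each rank only instantiates its own slice.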
203 | layers = [TransformerBlock(args=args) for _ in range(args.n_layers)] 204 | num_layers_per_rank = math.ceil(self.n_layers / self.num_pipeline_ranks) 205 | offset = self.pipeline_rank * num_layers_per_rank 206 | end = min(self.n_layers, offset + num_layers_per_rank) 207 | # A dictionary that defines which layers are present in the current rank 208 | self.layers = nn.ModuleDict({str(i): layers[i] for i in range(offset, end)}) 209 | self.n_local_layers = len(self.layers) 210 | 211 | @property 212 | def dtype(self) -> torch.dtype: 213 | return next(self.parameters()).dtype 214 | 215 | @property 216 | def device(self) -> torch.device: 217 | return next(self.parameters()).device 218 | 219 | @property 220 | def freqs_cis(self) -> torch.Tensor: 221 | # We cache freqs_cis but need to take care that it is on the right device 222 | # and has the right dtype (complex64). The fact that the dtype is different 223 | # from the module's dtype means we cannot register it as a buffer 224 | if self._precomputed_freqs_cis is None: 225 | # If no sliding window, assume a larger seqlen 226 | theta = self.args.rope_theta 227 | if theta is None: 228 | theta = 1000000.0 if self.args.sliding_window is None else 10000.0 229 | # theta = 10000. 230 | self._precomputed_freqs_cis = precompute_freqs_cis( 231 | self.args.head_dim, 128_000, theta 232 | ) 233 | if self._precomputed_freqs_cis.device != self.device: 234 | self._precomputed_freqs_cis = self._precomputed_freqs_cis.to( 235 | device=self.device 236 | ) 237 | return self._precomputed_freqs_cis 238 | 239 | def forward_partial( 240 | self, 241 | input_ids: torch.Tensor, # The concatenated tokens of the batch 242 | seqlens: List[int], # The sequence length of each input in the batch (so we can understand which token belongs to which prompt) 243 | cache: Optional[RotatingBufferCache] = None, 244 | ) -> torch.Tensor: 245 | """Local forward pass. 246 | 247 | If doing pipeline parallelism, this will return the activations of the last layer of this stage. 248 | For the last stage, this will return the normalized final embeddings. 249 | """ 250 | assert ( 251 | len(seqlens) <= self.args.max_batch_size 252 | ), f"Max batch size is {self.args.max_batch_size}, got batch size of {len(seqlens)}" 253 | (num_toks,) = input_ids.shape 254 | assert sum(seqlens) == num_toks, (sum(seqlens), num_toks) 255 | if cache is not None: 256 | # Generate the attention mask based on the current stage: first pre-fill, subsequent pre-fill or token generation. 
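# For example (values illustrative): the first prompt chunk takes the `first_prefill` branch in
# get_input_metadata, later chunks with seqlens > 1 take `subsequent_prefill`, and one-token
# decoding steps (seqlens == [1, 1, ...]) fall through to the generation mask.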
257 | input_metadata = cache.get_input_metadata(seqlens) 258 | else: 259 | # If we do not use the cache, then just return the positions of the tokens to be used for RoPE 260 | input_metadata = SimpleInputMetadata.from_seqlens(seqlens, self.device) 261 | 262 | if self.pipeline_rank == 0: 263 | # Only the first GPU will take care of the embeddings 264 | assert self.tok_embeddings is not None 265 | h = self.tok_embeddings(input_ids) # Transform the tokens into embeddings 266 | else: 267 | h = torch.empty( 268 | num_toks, self.args.dim, device=self.device, dtype=self.dtype 269 | ) 270 | # Subsequent GPUs will receive the embeddings from the previous GPU 271 | torch.distributed.recv(h, src=self.pipeline_rank - 1) 272 | 273 | freqs_cis = self.freqs_cis[input_metadata.positions] 274 | 275 | # Apply each layer iteratively 276 | for local_layer_id, layer in enumerate(self.layers.values()): 277 | if cache is not None: 278 | assert input_metadata is not None 279 | # Retrieves the KV cache for the current layer 280 | cache_view = cache.get_view(local_layer_id, input_metadata) 281 | else: 282 | cache_view = None 283 | h = layer(h, freqs_cis, cache_view) 284 | 285 | if cache is not None: 286 | cache.update_seqlens(seqlens) # Updates the total sequence length so far seen by the cache among all the iterations 287 | if self.pipeline_rank < self.num_pipeline_ranks - 1: 288 | # After all the layers for the current GPU have been applied, send the output to the next GPU 289 | torch.distributed.send(h, dst=self.pipeline_rank + 1) 290 | return h 291 | else: 292 | # Last rank has a final normalization step. 293 | assert self.norm is not None 294 | return self.norm(h) 295 | 296 | def forward( 297 | self, 298 | input_ids: torch.Tensor, # The concatenated tokens of the batch 299 | seqlens: List[int], # The sequence length of each input in the batch (so we can understand which token belongs to which prompt) 300 | cache: Optional[RotatingBufferCache] = None, 301 | ) -> torch.Tensor: 302 | h = self.forward_partial(input_ids, seqlens, cache=cache) 303 | if self.pipeline_rank < self.num_pipeline_ranks - 1: 304 | # ignore the intermediate activations as we'll get the final output from 305 | # the last stage 306 | outs = torch.empty( 307 | h.shape[0], self.vocab_size, device=h.device, dtype=h.dtype 308 | ) 309 | else: 310 | assert self.output is not None 311 | outs = self.output(h) # Apply the output linear projection of the embeddings to the vocabulary size 312 | if self.num_pipeline_ranks > 1: 313 | torch.distributed.broadcast(outs, src=self.num_pipeline_ranks - 1) 314 | return outs.float() 315 | 316 | def load_state_dict(self, state_dict, *args, **kwargs): 317 | state_to_load = {} 318 | skipped = set([]) 319 | for k, v in state_dict.items(): 320 | if k.startswith("tok_embeddings"): 321 | if self.pipeline_rank == 0: 322 | state_to_load[k] = v 323 | else: 324 | logging.debug( 325 | "Skipping parameter %s at pipeline rank %d", 326 | k, 327 | self.pipeline_rank, 328 | ) 329 | skipped.add(k) 330 | elif k.startswith("norm") or k.startswith("output"): 331 | if self.pipeline_rank == self.num_pipeline_ranks - 1: 332 | state_to_load[k] = v 333 | else: 334 | logging.debug( 335 | "Skipping parameter %s at pipeline rank %d", 336 | k, 337 | self.pipeline_rank, 338 | ) 339 | skipped.add(k) 340 | elif k.startswith("layers"): 341 | layer_id = k.split(".")[1] 342 | if layer_id in self.layers: 343 | state_to_load[k] = v 344 | else: 345 | logging.debug( 346 | "Skipping parameter %s at pipeline rank %d", 347 | k, 348 | self.pipeline_rank, 
349 | ) 350 | skipped.add(k) 351 | else: 352 | raise ValueError(f"Unexpected key {k}") 353 | assert set(state_dict.keys()) == skipped.union(set(state_to_load.keys())) 354 | super().load_state_dict(state_to_load, *args, **kwargs) 355 | 356 | @staticmethod 357 | def from_folder( 358 | folder: Path, 359 | max_batch_size: int = 1, 360 | num_pipeline_ranks: int = 1, 361 | device="cuda", 362 | dtype=torch.float16, 363 | ) -> "Transformer": 364 | with open(folder / "params.json", "r") as f: 365 | model_args = ModelArgs.from_dict(json.load(f)) 366 | model_args.max_batch_size = max_batch_size 367 | if num_pipeline_ranks > 1: 368 | pipeline_rank = torch.distributed.get_rank() 369 | else: 370 | pipeline_rank = 0 371 | with torch.device("meta"): 372 | model = Transformer( 373 | model_args, 374 | pipeline_rank=pipeline_rank, 375 | num_pipeline_ranks=num_pipeline_ranks, 376 | ) 377 | loaded = torch.load(str(folder / "consolidated.00.pth"), mmap=True) 378 | model.load_state_dict(loaded, assign=True) 379 | return model.to(device=device, dtype=dtype) 380 | -------------------------------------------------------------------------------- /mistral/moe.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from typing import List 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from simple_parsing.helpers import Serializable 7 | from torch import nn 8 | 9 | 10 | @dataclasses.dataclass 11 | class MoeArgs(Serializable): 12 | num_experts: int 13 | num_experts_per_tok: int 14 | 15 | 16 | class MoeLayer(nn.Module): 17 | def __init__(self, experts: List[nn.Module], gate: nn.Module, moe_args: MoeArgs): 18 | super().__init__() 19 | assert len(experts) > 0 20 | self.experts = nn.ModuleList(experts) 21 | self.gate = gate 22 | self.args = moe_args 23 | 24 | def forward(self, inputs: torch.Tensor): 25 | # For each token, generate `num_experts` logits indicating which expert to use. 26 | gate_logits = self.gate(inputs) 27 | # For each token, select the top `num_experts_per_tok` experts, and use them to compute the output. 28 | weights, selected_experts = torch.topk(gate_logits, self.args.num_experts_per_tok) 29 | # Apply the softmax to the logits AFTER selecting the top-k; this keeps comparisons across different hyperparameters consistent. 30 | # Because even if we change the total number of experts or the number of experts per token, the sum of the weights will still be 1 for each token. 31 | weights = F.softmax(weights, dim=1, dtype=torch.float).to(inputs.dtype) 32 | results = torch.zeros_like(inputs) 33 | for current_expert_index, current_expert in enumerate(self.experts): 34 | # For each expert, select the tokens it will be applied to. 35 | token_index, token_expert_index = torch.where(selected_experts == current_expert_index) 36 | # Apply the expert to the selected tokens, weighting its output by the (post-softmax) gate weights computed above.
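# Worked example (made-up numbers): with num_experts=8 and num_experts_per_tok=2, a token whose
# top-2 gate picks experts 3 and 7 with softmax weights [0.6, 0.4] ends up with
# results = 0.6 * expert_3(x) + 0.4 * expert_7(x); the other six experts never see that token.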
37 | results[token_index] += weights[token_index, token_expert_index, None] * current_expert( 38 | inputs[token_index] 39 | ) 40 | return results 41 | -------------------------------------------------------------------------------- /mistral/rope.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Tuple 3 | 4 | 5 | def precompute_freqs_cis(dim: int, end: int, theta: float) -> torch.Tensor: 6 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) 7 | t = torch.arange(end, device=freqs.device) # type: ignore 8 | freqs = torch.outer(t, freqs).float() # type: ignore 9 | return torch.polar(torch.ones_like(freqs), freqs) # complex64 10 | 11 | 12 | def apply_rotary_emb( 13 | xq: torch.Tensor, 14 | xk: torch.Tensor, 15 | freqs_cis: torch.Tensor, 16 | ) -> Tuple[torch.Tensor, torch.Tensor]: 17 | xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # (Seq, N_Heads, Head_Dim) --> (Seq, N_Heads, Head_Dim // 2) 18 | xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) # (Seq, N_Heads, Head_Dim) --> (Seq, N_Heads, Head_Dim // 2) 19 | freqs_cis = freqs_cis[:, None, :] # (Seq, 1, Head_Dim // 2) 20 | xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(2) # (Seq, N_Heads, Head_Dim // 2) * (Seq, 1, Head_Dim // 2) --> (Seq, N_Heads, Head_Dim // 2) --> (Seq, N_Heads, Head_Dim) 21 | xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(2) 22 | return xq_out.type_as(xq), xk_out.type_as(xk) 23 | -------------------------------------------------------------------------------- /mistral/tokenizer.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from sentencepiece import SentencePieceProcessor 3 | from typing import List 4 | 5 | 6 | class Tokenizer: 7 | def __init__(self, model_path: str): 8 | assert Path(model_path).exists(), model_path 9 | self._model = SentencePieceProcessor(model_file=model_path) 10 | assert self._model.vocab_size() == self._model.get_piece_size() 11 | 12 | @property 13 | def n_words(self) -> int: 14 | return self._model.vocab_size() 15 | 16 | @property 17 | def bos_id(self) -> int: 18 | return self._model.bos_id() 19 | 20 | @property 21 | def eos_id(self) -> int: 22 | return self._model.eos_id() 23 | 24 | @property 25 | def pad_id(self) -> int: 26 | return self._model.pad_id() 27 | 28 | def encode(self, s: str, bos: bool = True) -> List[int]: 29 | assert isinstance(s, str) 30 | t = self._model.encode(s) 31 | if bos: 32 | t = [self.bos_id, *t] 33 | return t 34 | 35 | def decode(self, t: List[int]) -> str: 36 | return self._model.decode(t) 37 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | fire 2 | sentencepiece 3 | torch>=2.1.0 4 | xformers 5 | simple-parsing --------------------------------------------------------------------------------