├── LICENSE.txt
├── README.md
├── llama.py
└── quarantine
    ├── __pycache__
    │   └── load.cpython-310.pyc
    └── load.py

/LICENSE.txt:
--------------------------------------------------------------------------------
1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution."
62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 
180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Just Large Models
2 | Hackable, with as little abstraction as possible. Done for my own purposes, feel free to rip.
3 | 
4 | Every model should have its own runnable logic. Separate. Not shared! Each file does A Thing. The adaptability of huggingface's code is incredibly bad, due to over-abstraction and incidental complexity. Therefore, DIY. Right now, I'm still improving it as I have time.
5 | 
6 | This is designed to rely on HF only for fetching and loading the model files. That is where the dependency ends.
7 | 
8 | ## Rules:
9 | - The code is the tool. Edit it as you see fit!
10 | - All h*ggingface imports will be placed in quarantine. Model pass files will contain no references to them
11 | - I will not be addressing issues
12 | - Not a single kwargs shall be observed
13 | - One model forward pass = one function call. Simple as.
14 | 
15 | ## Current models
16 | ```
17 | python ./llama.py
18 | ```
--------------------------------------------------------------------------------
/llama.py:
--------------------------------------------------------------------------------
1 | # Why? I got a little annoyed w/ huggingface's.. code. So I ripped it out and fixed some of the over-abstraction
2 | # The end goal is to remove all dependencies on huggingface.
3 | 
4 | # Attribution:
5 | # Most of this code has been ripped out of huggingface, with great pain
6 | # Shout out to Karpathy's GPT & makemore series. Helped a lot
7 | # Shout out to the llama implementation from Meta
8 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
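# Rough map of this file, for anyone hacking on it:
#   generate()      greedy decoding loop; calls model_as_fn once per decoding step and
#                   threads past_key_values / attention_mask through by hand
#   model_as_fn()   the whole forward pass as one function: embed the tokens, build the
#                   causal mask, run every decoder layer, then final rmsnorm + lm_head matmul
#   self_attn() / mlp_new() / rmsnorm()   plain tensor math against weights pulled straight
#                   out of the flat state dict -- no model class, no config object
# The state dict and tokenizer come from quarantine/load.py.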
9 | 10 | import torch 11 | import math 12 | from box import Box # Forgive me 13 | from torch import nn 14 | from torch.nn.functional import silu 15 | from transformers.modeling_outputs import CausalLMOutputWithPast 16 | from quarantine.load import load_hf_llama_state_dict, load_sterilized_hf_llama_tokenizer 17 | 18 | def generate( 19 | input_ids = None, 20 | max_length = None, 21 | attention_mask = None, 22 | use_cache = None, 23 | model_as_fn = None, 24 | model_state_dict = None, 25 | pad_token_id = None, 26 | eos_token_id = None, 27 | ): 28 | eos_token_id_tensor = torch.tensor([eos_token_id]).to(input_ids.device) 29 | 30 | # We want to mark sequences that are finished 31 | unfinished_sequences = torch.ones( 32 | input_ids.shape[0], dtype=torch.long, device=input_ids.device) 33 | 34 | past_key_values = None 35 | position_ids = None 36 | 37 | while True: 38 | # todo: This used to be worse, but its still bad. continue healing 39 | if past_key_values: 40 | inference_state = {"input_ids": input_ids[:, -1:]} 41 | else: 42 | inference_state = {"input_ids": input_ids} 43 | 44 | if attention_mask is not None: 45 | # create position_ids on the fly for batch generation 46 | position_ids = attention_mask.long().cumsum(-1) - 1 47 | position_ids.masked_fill_(attention_mask == 0, 1) 48 | if past_key_values: 49 | position_ids = position_ids[:, -1].unsqueeze(-1) 50 | 51 | inference_state.update( 52 | { 53 | "position_ids": position_ids, 54 | "past_key_values": past_key_values, 55 | "use_cache": use_cache, 56 | "attention_mask": attention_mask, 57 | } 58 | ) 59 | 60 | outputs = model_as_fn( 61 | model_state_dict, 62 | input_ids = inference_state['input_ids'], 63 | position_ids = inference_state['position_ids'], 64 | past_key_values = inference_state['past_key_values'], 65 | use_cache = inference_state['use_cache'], 66 | attention_mask = inference_state['attention_mask'], 67 | ) 68 | 69 | next_token_logits = outputs.logits[:, -1, :] 70 | next_tokens = torch.argmax(next_token_logits, dim=-1) 71 | 72 | # finished sentences should have their next token be a padding token 73 | next_tokens = next_tokens * unfinished_sequences + \ 74 | pad_token_id * (1 - unfinished_sequences) 75 | 76 | # update generated ids, model inputs, and length for next step 77 | input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) 78 | past_key_values = outputs.past_key_values 79 | 80 | attention_mask = torch.cat( 81 | [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 82 | ) 83 | 84 | unfinished_sequences = unfinished_sequences.mul( 85 | next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne( 86 | eos_token_id_tensor.unsqueeze(1)).prod(dim=0) 87 | ) 88 | 89 | # stop when each sentence is finished 90 | if unfinished_sequences.max() == 0: 91 | break 92 | 93 | # # stop if we exceed the maximum length 94 | cur_len = input_ids.shape[-1] 95 | is_done = cur_len >= max_length 96 | if is_done: 97 | break 98 | 99 | return input_ids 100 | 101 | def get_total_layers(model_state_dict): 102 | layer_keys = [key for key in model_state_dict.keys() if "model.layers." 
in key] 103 | layer_numbers = set([int(key.split('.')[2]) for key in layer_keys]) 104 | total_layers = max(layer_numbers) + 1 105 | return total_layers 106 | 107 | def get_layer_weights(layer_number, model_state_dict): 108 | layer_key = f"model.layers.{layer_number}" 109 | weights = { 110 | "self_attn": { 111 | "q_proj": model_state_dict[f"{layer_key}.self_attn.q_proj.weight"], 112 | "k_proj": model_state_dict[f"{layer_key}.self_attn.k_proj.weight"], 113 | "v_proj": model_state_dict[f"{layer_key}.self_attn.v_proj.weight"], 114 | "o_proj": model_state_dict[f"{layer_key}.self_attn.o_proj.weight"], 115 | }, 116 | "mlp": { 117 | "gate_proj": model_state_dict[f"{layer_key}.mlp.gate_proj.weight"], 118 | "up_proj": model_state_dict[f"{layer_key}.mlp.up_proj.weight"], 119 | "down_proj": model_state_dict[f"{layer_key}.mlp.down_proj.weight"], 120 | }, 121 | "input_layernorm": model_state_dict[f"{layer_key}.input_layernorm.weight"], 122 | "post_attention_layernorm": model_state_dict[f"{layer_key}.post_attention_layernorm.weight"] 123 | } 124 | 125 | ret = Box(weights) 126 | return ret 127 | 128 | 129 | def apply_rotary_pos_emb(q, k, cos, sin, position_ids): 130 | def rotate_half(x): 131 | """Rotates half the hidden dims of the input.""" 132 | x1 = x[..., : x.shape[-1] // 2] 133 | x2 = x[..., x.shape[-1] // 2:] 134 | return torch.cat((-x2, x1), dim=-1) 135 | # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. 136 | cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] 137 | sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] 138 | cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] 139 | sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] 140 | q_embed = (q * cos) + (rotate_half(q) * sin) 141 | k_embed = (k * cos) + (rotate_half(k) * sin) 142 | 143 | return q_embed, k_embed 144 | 145 | def rmsnorm(weights, hidden_states, rms_norm_epsilon=1e-05): 146 | input_dtype = hidden_states.dtype 147 | hidden_states = hidden_states.to(torch.float32) 148 | variance = hidden_states.pow(2).mean(-1, keepdim=True) 149 | hidden_states = hidden_states * torch.rsqrt(variance + rms_norm_epsilon) 150 | return weights * hidden_states.to(input_dtype) 151 | 152 | def model_as_fn( 153 | model_state_dict, 154 | input_ids, 155 | position_ids, 156 | past_key_values, 157 | use_cache, 158 | attention_mask, 159 | ): 160 | B, T = input_ids.shape 161 | sequence_length_with_past = T 162 | past_key_values_length = past_key_values[0][0].shape[2] if past_key_values else 0 163 | sequence_length_with_past += past_key_values_length if past_key_values else 0 164 | 165 | embedding_layer = nn.Embedding.from_pretrained(model_state_dict["model.embed_tokens.weight"]) 166 | hidden_states = embedding_layer(input_ids) 167 | 168 | # Attention mask 169 | def prepare_attention_mask(attention_mask, input_shape, input_embeds, past_key_values_length): 170 | B, source_length = attention_mask.size() 171 | target_length = input_shape[-1] 172 | dtype = input_embeds.dtype 173 | device = input_embeds.device 174 | mask = torch.full((target_length, target_length), 175 | torch.finfo(dtype).min, device=device) 176 | mask_cond = torch.arange(mask.size(-1), device=device) 177 | mask.masked_fill_(mask_cond < ( 178 | mask_cond + 1).view(mask.size(-1), 1), 0) 179 | mask = mask.to(dtype) 180 | if past_key_values_length > 0: 181 | mask = torch.cat([torch.zeros( 182 | target_length, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) 183 | combined_attention_mask = mask[None, None, :, :].expand( 184 | B, 1, 
target_length, target_length + past_key_values_length) 185 | expanded_mask = attention_mask[:, None, None, :].expand( 186 | B, 1, target_length, source_length).to(dtype) 187 | inverted_mask = 1.0 - expanded_mask 188 | expanded_attn_mask = inverted_mask.masked_fill( 189 | inverted_mask.to(torch.bool), torch.finfo(dtype).min) 190 | combined_attention_mask = ( 191 | expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + 192 | combined_attention_mask 193 | ) 194 | return combined_attention_mask 195 | 196 | attention_mask = prepare_attention_mask( 197 | attention_mask, (B, T), hidden_states, past_key_values_length) 198 | 199 | next_decoder_cache = () if use_cache else None 200 | 201 | total_layers = get_total_layers(model_state_dict) 202 | for i in range(total_layers): 203 | decoder_layer = get_layer_weights(i, model_state_dict) 204 | past_key_value = past_key_values[i] if past_key_values is not None else None 205 | residual = hidden_states 206 | hidden_states = rmsnorm(decoder_layer.input_layernorm, hidden_states) 207 | 208 | # Self Attention 209 | def self_attn( 210 | layer_weights, 211 | hidden_states, 212 | attention_mask, 213 | position_ids, 214 | past_key_value, 215 | use_cache, 216 | num_key_value_heads=32, 217 | num_attention_heads=32, 218 | hidden_size=4096, 219 | max_position_embeddings=4096, 220 | rope_theta=10000 221 | ): 222 | head_dim = hidden_size // num_attention_heads 223 | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: 224 | """ 225 | This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, 226 | num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) 227 | # so why is it differetnt? 228 | # TODO: understand and avoid ugly 229 | """ 230 | batch, num_key_value_heads, slen, head_dim = hidden_states.shape 231 | if n_rep == 1: 232 | return hidden_states 233 | hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) 234 | return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) 235 | 236 | B, q_len, _ = hidden_states.size() 237 | query_states = hidden_states @ layer_weights.q_proj.T 238 | key_states = hidden_states @ layer_weights.k_proj.T 239 | value_states = hidden_states @ layer_weights.v_proj.T 240 | 241 | query_states = query_states.view( 242 | B, q_len, num_attention_heads, head_dim).transpose(1, 2) 243 | key_states = key_states.view( 244 | B, q_len, num_key_value_heads, head_dim).transpose(1, 2) 245 | value_states = value_states.view( 246 | B, q_len, num_key_value_heads, head_dim).transpose(1, 2) 247 | 248 | kv_seq_len = key_states.shape[-2] 249 | if past_key_value is not None: 250 | kv_seq_len += past_key_value[0].shape[-2] 251 | 252 | def get_rotary_embedding(value_states, dim, seq_len, max_position_embeddings, device, base=10000): 253 | inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) 254 | t = torch.arange(max_position_embeddings, device=device, dtype=inv_freq.dtype) 255 | freqs = torch.einsum("i,j->ij", t, inv_freq) 256 | emb = torch.cat((freqs, freqs), dim=-1) 257 | # TODO: Avoid recomputing somehow 258 | # How much faster is it *really*? 
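                # One way to address the TODO above (untested sketch, not wired in; the
                # names below are made up): cache the cos/sin tables per
                # (dim, max_position_embeddings, base, device), since that is all they
                # depend on, and just slice per call:
                #
                #   _ROPE_CACHE = {}
                #   def cached_rope_tables(dim, max_pos, base, device):
                #       key = (dim, max_pos, base, str(device))
                #       if key not in _ROPE_CACHE:
                #           inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
                #           t = torch.arange(max_pos, device=device, dtype=inv_freq.dtype)
                #           freqs = torch.einsum("i,j->ij", t, inv_freq)
                #           emb = torch.cat((freqs, freqs), dim=-1)
                #           _ROPE_CACHE[key] = (emb.cos()[None, None, :, :], emb.sin()[None, None, :, :])
                #       return _ROPE_CACHE[key]
                #
                # The .to(dtype=value_states.dtype) cast below would still happen per call.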
259 | cos = emb.cos()[None, None, :, :].to(inv_freq.dtype) 260 | sin = emb.sin()[None, None, :, :].to(inv_freq.dtype) 261 | return cos[:, :, :seq_len, ...].to(dtype=value_states.dtype), sin[:, :, :seq_len, ...].to(dtype=value_states.dtype) 262 | 263 | cos, sin = get_rotary_embedding( 264 | value_states, 265 | head_dim, 266 | max_position_embeddings=max_position_embeddings, 267 | base=rope_theta, 268 | device=input_ids.device, 269 | seq_len=kv_seq_len 270 | ) 271 | 272 | query_states, key_states = apply_rotary_pos_emb( 273 | query_states, key_states, cos, sin, position_ids) 274 | 275 | if past_key_value is not None: 276 | # reuse k, v, self_attention 277 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 278 | value_states = torch.cat( 279 | [past_key_value[1], value_states], dim=2) 280 | 281 | past_key_value = (key_states, value_states) if use_cache else None 282 | 283 | num_key_value_groups = num_attention_heads // num_key_value_heads 284 | # Should be nil op in the case of 7b 285 | key_states = repeat_kv(key_states, num_key_value_groups) 286 | value_states = repeat_kv(value_states, num_key_value_groups) 287 | attn_weights = torch.matmul( 288 | query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim) 289 | 290 | if attention_mask is not None: 291 | attn_weights = attn_weights + attention_mask 292 | 293 | # upcast attention to fp32 294 | # TODO: Why? Not sure why! I cargo'd from huggingface 295 | attn_weights = nn.functional.softmax( 296 | attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) 297 | attn_output = torch.matmul(attn_weights, value_states) 298 | 299 | attn_output = attn_output.transpose(1, 2).contiguous() 300 | attn_output = attn_output.reshape(B, q_len, hidden_size) 301 | 302 | attn_output = attn_output @ layer_weights.o_proj.T 303 | 304 | return attn_output, past_key_value 305 | 306 | hidden_states, present_key_value = self_attn( 307 | decoder_layer.self_attn, 308 | hidden_states=hidden_states, 309 | attention_mask=attention_mask, 310 | position_ids=position_ids, 311 | past_key_value=past_key_value, 312 | use_cache=use_cache, 313 | ) 314 | 315 | hidden_states = residual + hidden_states 316 | residual = hidden_states 317 | hidden_states = rmsnorm(decoder_layer.post_attention_layernorm, hidden_states) 318 | 319 | def mlp_new(layer_weights, x): 320 | up_proj_output = x @ layer_weights.up_proj.T 321 | gate_proj_output = x @ layer_weights.gate_proj.T 322 | down_proj_input = silu(gate_proj_output) * up_proj_output 323 | down_proj_output = down_proj_input @ layer_weights.down_proj.T 324 | return down_proj_output 325 | 326 | hidden_states = mlp_new(decoder_layer.mlp, hidden_states) 327 | hidden_states = residual + hidden_states 328 | layer_outputs = (hidden_states,) 329 | 330 | if use_cache: 331 | layer_outputs += (present_key_value,) 332 | 333 | hidden_states = layer_outputs[0] 334 | 335 | if use_cache: 336 | next_decoder_cache += ( 337 | layer_outputs[1],) 338 | 339 | norm_weights = model_state_dict["model.norm.weight"] 340 | lm_head_weights = model_state_dict["lm_head.weight"] 341 | hidden_states = rmsnorm(norm_weights, hidden_states) 342 | 343 | logits = hidden_states @ lm_head_weights.T 344 | 345 | return CausalLMOutputWithPast( 346 | logits=logits, 347 | past_key_values=next_decoder_cache, 348 | ) 349 | 350 | model_path = "/home/kache/models/llama2/7bhf/model" 351 | tokenizer_path = "/home/kache/models/llama2/7bhf/tokenizer" 352 | tokenizer = load_sterilized_hf_llama_tokenizer(tokenizer_path) 353 | model_state_dict = 
load_hf_llama_state_dict(model_path)
354 | 
355 | prompts = [
356 |     "I went to the store the other day",
357 |     "For what its worth, I really dont think that you should"
358 | ]
359 | inputs, attention_mask = tokenizer.encode(prompts)
360 | outputs = generate(
361 |     input_ids = inputs,
362 |     max_length = 100,
363 |     attention_mask = attention_mask,
364 |     use_cache = True,
365 |     model_as_fn = model_as_fn,
366 |     model_state_dict = model_state_dict,
367 |     eos_token_id=tokenizer.eos_token,
368 |     pad_token_id=tokenizer.pad_token,
369 | )
370 | print(tokenizer.decode(outputs[0], skip_special_tokens=True))
371 | print(tokenizer.decode(outputs[1], skip_special_tokens=True))
372 | 
--------------------------------------------------------------------------------
/quarantine/__pycache__/load.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yacineMTB/just-large-models/f1526aae256879e7f949bfece82f9075ef8ae4a6/quarantine/__pycache__/load.cpython-310.pyc
--------------------------------------------------------------------------------
/quarantine/load.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from transformers import LlamaTokenizer, LlamaForCausalLM
4 | from transformers import BitsAndBytesConfig
5 | 
6 | 
7 | # Takes a path to the model
8 | # Returns the model's state dict
9 | def load_hf_llama_state_dict(model_path):
10 |     # TODO: Handle bnb at a separate layer. (bnb_config below is built but not passed to from_pretrained yet, so the weights currently load in plain fp16.)
11 |     bnb_config = BitsAndBytesConfig(
12 |         load_in_4bit=True,
13 |         bnb_4bit_use_double_quant=True,
14 |         bnb_4bit_quant_type="nf4",
15 |         bnb_4bit_compute_dtype=torch.bfloat16
16 |     )
17 |     model = LlamaForCausalLM.from_pretrained(model_path, device_map={"": 0}, torch_dtype=torch.float16)
18 |     return model.state_dict()
19 | 
20 | class SterilizedHFLlamaTokenizer:
21 |     def __init__(self, tokenizer):
22 |         self.tokenizer = tokenizer
23 |         self.eos_token = tokenizer.eos_token_id
24 |         self.pad_token = tokenizer.pad_token_id
25 | 
26 |     def encode(self, text, return_tensors="pt", padding=True, device=0):
27 |         output = self.tokenizer(text, return_tensors=return_tensors, padding=padding).to(device)
28 |         return output['input_ids'], output['attention_mask']
29 | 
30 | 
31 |     def decode(self, outputs, skip_special_tokens=True):
32 |         return self.tokenizer.decode(outputs, skip_special_tokens=skip_special_tokens)
33 | 
34 | # Takes a path to the tokenizer
35 | # Returns an object with two functions, encode and decode
36 | # encode returns (input_ids, attention_mask); decode returns a plain string. I trust that my child assigns them to adequately named variables..
37 | def load_sterilized_hf_llama_tokenizer(model_path):
38 |     tokenizer = LlamaTokenizer.from_pretrained(model_path, legacy=True)
39 |     tokenizer.pad_token = tokenizer.eos_token
40 |     return SterilizedHFLlamaTokenizer(tokenizer)
41 | 
42 | # Usage
43 | # model_path = "/home/kache/models/llama1/7bhf/model"
44 | # tokenizer_path = "/home/kache/models/llama1/7bhf/tokenizer"
45 | # sterilized_tokenizer = load_sterilized_hf_llama_tokenizer(tokenizer_path)
46 | # model = load_hf_llama_state_dict(model_path)
--------------------------------------------------------------------------------