├── images
│   ├── overview.png
│   └── scenario.png
├── requirements.txt
├── llama
│   ├── __init__.py
│   ├── tokenizer.py
│   ├── generation.py
│   └── model.py
├── LICENSE
├── task_decomposition.py
├── code_generation.py
├── fine_tune.py
└── README.md

/images/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lemingshen/GPIoT/HEAD/images/overview.png
--------------------------------------------------------------------------------
/images/scenario.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lemingshen/GPIoT/HEAD/images/scenario.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | fairscale
3 | fire
4 | sentencepiece
5 | trl
6 | peft
7 | transformers
8 | datasets
9 | wandb
--------------------------------------------------------------------------------
/llama/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
3 | 
4 | from .generation import Llama, Dialog
5 | from .model import ModelArgs, Transformer
6 | from .tokenizer import Tokenizer
7 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2025 Leming Shen
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/task_decomposition.py:
--------------------------------------------------------------------------------
1 | from transformers import (
2 |     AutoModelForCausalLM,
3 |     AutoTokenizer,
4 | )
5 | 
6 | 
7 | def data_transform(message_list):
8 |     system_message = message_list[0]["content"]
9 |     user_message = message_list[1]["content"]
10 | 
11 |     output = "[INST] <<SYS>>\n{}\n<</SYS>>\n\n{}[/INST]".format(
12 |         system_message, user_message.strip()
13 |     )
14 | 
15 |     return output
16 | 
17 | 
18 | base_model_dir = "meta-llama/Llama-2-13b-chat-hf"
19 | lora_model_dir = "GPIoT_Task_Decomposition/checkpoint-13400"
20 | 
21 | tokenizer = AutoTokenizer.from_pretrained(base_model_dir)
22 | model = AutoModelForCausalLM.from_pretrained(lora_model_dir)
23 | model = model.to("cuda")
24 | model.eval()
25 | 
26 | prompt = [
27 |     {
28 |         "role": "system",
29 |         "content": "You are a professional IoT application developer. According to the user problem, you need to decompose the problem into multiple steps with implementation details. The output must be in the format of:\n\n1. description and implementation details of step 1\n\n2. description and implementation details of step 2\n\n......",
30 |     },
31 |     {
32 |         "role": "user",
33 |         "content": "Decompose the following task into multiple steps and describe the implementation details for each step. The task is to maximize throughput by digitally compensating for wireless impairments and removing residual interference from the transmit chain. ",
34 |     },
35 | ]
36 | 
37 | input_text = data_transform(prompt)
38 | inputs = tokenizer(input_text, return_tensors="pt").to("cuda")  # move the encoded prompt to the same device as the model
39 | output = model.generate(**inputs, max_length=1024, num_return_sequences=1)
40 | output_text = tokenizer.decode(output[0], skip_special_tokens=True)
41 | print(output_text)
42 | 
--------------------------------------------------------------------------------
/code_generation.py:
--------------------------------------------------------------------------------
1 | from transformers import (
2 |     AutoModelForCausalLM,
3 |     AutoTokenizer,
4 | )
5 | 
6 | 
7 | def data_transform(message_list):
8 |     system_message = message_list[0]["content"]
9 |     user_message = message_list[1]["content"]
10 | 
11 |     output = "[INST] <<SYS>>\n{}\n<</SYS>>\n\n{}[/INST]".format(
12 |         system_message, user_message.strip()
13 |     )
14 | 
15 |     return output
16 | 
17 | 
18 | base_model_dir = "meta-llama/Llama-2-13b-chat-hf"
19 | lora_model_dir = "GPIoT_Code_Generation/checkpoint-13000"
20 | 
21 | tokenizer = AutoTokenizer.from_pretrained(base_model_dir)
22 | model = AutoModelForCausalLM.from_pretrained(lora_model_dir)
23 | model = model.to("cuda")
24 | model.eval()
25 | 
26 | prompt = [
27 |     {
28 |         "role": "system",
29 |         "content": "You are a professional and skillful Python programmer, especially in the field of communication, signal processing, and machine learning. According to the user instruction, you need to generate one single Python function with detailed comments and documentation. The documentation should be in the Markdown format.",
30 |     },
31 |     {
32 |         "role": "user",
33 |         "content": "**Target**\nDefine a Python function to create a simple augmentation pipeline for image processing and provide detailed code comments.\n\n**Input Specifications**\n- `image_path` (str): The file path to the input image.\n\n**Output specifications**\nThe function does not explicitly return any value but visualizes the original and augmented images using matplotlib.",
34 |     },
35 | ]
36 | 
37 | input_text = data_transform(prompt)
38 | inputs = tokenizer(input_text, return_tensors="pt").to("cuda")  # move the encoded prompt to the same device as the model
39 | output = model.generate(**inputs, max_length=1024, num_return_sequences=1)
40 | output_text = tokenizer.decode(output[0], skip_special_tokens=True)
41 | print(output_text)
42 | 
--------------------------------------------------------------------------------
/llama/tokenizer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
3 | 
4 | import os
5 | from logging import getLogger
6 | from typing import List
7 | 
8 | from sentencepiece import SentencePieceProcessor
9 | 
10 | 
11 | logger = getLogger()
12 | 
13 | 
14 | class Tokenizer:
15 |     """tokenizing and encoding/decoding text using SentencePiece."""
16 |     def __init__(self, model_path: str):
17 |         """
18 |         Initializes the Tokenizer with a SentencePiece model.
19 | 
20 |         Args:
21 |             model_path (str): The path to the SentencePiece model file.
22 |         """
23 |         # reload tokenizer
24 |         assert os.path.isfile(model_path), model_path
25 |         self.sp_model = SentencePieceProcessor(model_file=model_path)
26 |         logger.info(f"Reloaded SentencePiece model from {model_path}")
27 | 
28 |         # BOS / EOS token IDs
29 |         self.n_words: int = self.sp_model.vocab_size()
30 |         self.bos_id: int = self.sp_model.bos_id()
31 |         self.eos_id: int = self.sp_model.eos_id()
32 |         self.pad_id: int = self.sp_model.pad_id()
33 |         logger.info(
34 |             f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
35 |         )
36 |         assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
37 | 
38 |     def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
39 |         """
40 |         Encodes a string into a list of token IDs.
41 | 
42 |         Args:
43 |             s (str): The input string to be encoded.
44 |             bos (bool): Whether to prepend the beginning-of-sequence token.
45 |             eos (bool): Whether to append the end-of-sequence token.
46 | 
47 |         Returns:
48 |             List[int]: A list of token IDs.
49 |         """
50 |         assert type(s) is str
51 |         t = self.sp_model.encode(s)
52 |         if bos:
53 |             t = [self.bos_id] + t
54 |         if eos:
55 |             t = t + [self.eos_id]
56 |         return t
57 | 
58 |     def decode(self, t: List[int]) -> str:
59 |         """
60 |         Decodes a list of token IDs into a string.
61 | 
62 |         Args:
63 |             t (List[int]): The list of token IDs to be decoded.
64 | 
65 |         Returns:
66 |             str: The decoded string.
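
        Example (an illustrative sketch; the SentencePiece model path is a placeholder):
            tokenizer = Tokenizer(model_path="tokenizer.model")
            ids = tokenizer.encode("IoT program synthesis", bos=True, eos=False)
            text = tokenizer.decode(ids)  # recovers "IoT program synthesis"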
67 |         """
68 |         return self.sp_model.decode(t)
69 | 
--------------------------------------------------------------------------------
/fine_tune.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import wandb
4 | from datasets import load_from_disk
5 | from trl import SFTTrainer
6 | from peft import LoraConfig, prepare_model_for_kbit_training
7 | from transformers import (
8 |     AutoModelForCausalLM,
9 |     AutoTokenizer,
10 |     BitsAndBytesConfig,
11 |     TrainingArguments,
12 |     pipeline,
13 |     logging,
14 | )
15 | 
16 | wandb.login(key="your key")
17 | 
18 | base_model = "meta-llama/Llama-2-13b-chat-hf"
19 | # new_model = "GPIoT_Code_Generation"
20 | new_model = "GPIoT_Task_Decomposition"
21 | # dataset = load_from_disk("dataset/Code_Generation_dataset")
22 | dataset = load_from_disk("dataset/Task_Decomposition_dataset")
23 | 
24 | compute_dtype = getattr(torch, "float16")
25 | 
26 | quantization_config = BitsAndBytesConfig(
27 |     load_in_8bit=True,
28 |     bnb_8bit_quant_type="nf8",
29 |     bnb_8bit_compute_dtype=compute_dtype,
30 |     bnb_8bit_use_double_quant=False,
31 | )
32 | 
33 | model = AutoModelForCausalLM.from_pretrained(
34 |     base_model, quantization_config=quantization_config
35 | )
36 | model.config.use_cache = False
37 | model = prepare_model_for_kbit_training(model)
38 | 
39 | tokenizer = AutoTokenizer.from_pretrained(base_model)
40 | tokenizer.pad_token = tokenizer.eos_token
41 | tokenizer.padding_side = "right"
42 | 
43 | peft_parameters = LoraConfig(
44 |     r=64, lora_alpha=16, lora_dropout=0.001, bias="lora_only", task_type="CAUSAL_LM"
45 | )
46 | 
47 | training_params = TrainingArguments(
48 |     # output_dir="GPIoT_Code_Generation",
49 |     output_dir="GPIoT_Task_Decomposition",
50 |     num_train_epochs=3,
51 |     per_device_train_batch_size=2,
52 |     per_device_eval_batch_size=2,
53 |     gradient_accumulation_steps=4,
54 |     optim="paged_adamw_32bit",
55 |     save_steps=200,
56 |     logging_steps=25,
57 |     learning_rate=2e-4,
58 |     weight_decay=0.001,
59 |     fp16=True,
60 |     bf16=False,
61 |     max_grad_norm=0.3,
62 |     max_steps=-1,
63 |     warmup_ratio=0.03,
64 |     group_by_length=True,
65 |     lr_scheduler_type="cosine",
66 |     report_to="wandb",
67 | )
68 | trainer = SFTTrainer(
69 |     model=model,
70 |     train_dataset=dataset,
71 |     # eval_dataset=dataset["test"],
72 |     peft_config=peft_parameters,
73 |     dataset_text_field="data",
74 |     max_seq_length=None,
75 |     tokenizer=tokenizer,
76 |     args=training_params,
77 |     packing=False,
78 | )
79 | 
80 | trainer.train()
81 | trainer.model.save_pretrained(new_model)
82 | trainer.tokenizer.save_pretrained(new_model)
83 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GPIoT: Tailoring Small Language Models for IoT Program Synthesis and Development
2 | - Foundation model: Llama-2-13b-chat-hf
3 | - LoRA fine-tuned with INT8 quantization
4 | 
5 | 
6 | 
7 | [![r](https://img.shields.io/badge/access-paper-blue)](https://lemingshen.github.io/assets/publication/conference/GPIoT/paper.pdf)   [![](https://img.shields.io/badge/visit_our-website-red)](https://lemingshen.github.io/projects/gpiot/)   [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
8 | 
9 | Code Large Language Models (LLMs) enhance software development efficiency by automatically generating code and documentation in response to user requirements. However, code LLMs cannot synthesize specialized programs when tasked with IoT applications that require domain knowledge. While Retrieval-Augmented Generation (RAG) offers a promising solution by fetching relevant domain knowledge, it necessitates powerful cloud LLMs (e.g., GPT-4) to process user requirements and retrieved contents, which raises significant privacy concerns. This approach also suffers from unstable networks and prohibitive LLM query costs. Moreover, it is challenging to ensure the correctness and relevance of the fetched contents. To address these issues, we propose GPIoT, a code generation system for IoT applications by fine-tuning locally deployable Small Language Models (SLMs) on IoT-specialized datasets. SLMs have smaller model sizes, allowing efficient local deployment and execution to mitigate privacy concerns and network uncertainty. Furthermore, by fine-tuning the SLMs with our IoT-specialized datasets, the SLMs' ability to synthesize IoT-related programs can be substantially improved. To evaluate GPIoT's capability in synthesizing programs for IoT applications, we develop a benchmark, IoTBench. Extensive experiments and user trials demonstrate the effectiveness of GPIoT in generating IoT-specialized code, outperforming state-of-the-art code LLMs with an average task accuracy increment of 64.7% and significant improvements in user satisfaction.
10 | 
11 | 
12 | 
13 | ## GPIoT Overview
14 | 
15 | ![System overview of GPIoT](images/overview.png)
16 | 
17 | ## Quick Start
18 | ### 1. Installation
19 | ```bash
20 | pip install -r requirements.txt
21 | ```
22 | 
23 | ### 2. Dataset Preparation
24 | - Our dataset for task decomposition: [link](https://huggingface.co/datasets/lemingshen/GPIoT_Task_Decomposition).
25 | - Our dataset for code generation: [link](https://huggingface.co/datasets/lemingshen/GPIoT_Code_Generation).
26 | 
27 | 
28 | ### 3. Model File
29 | - Our fine-tuned model for task decomposition: [link](https://huggingface.co/lemingshen/GPIoT_Task_Decomposition).
30 | - Our fine-tuned model for code generation: [link](https://huggingface.co/lemingshen/GPIoT_Code_Generation).
31 | 
32 | ### 4. IoTBench
33 | - To evaluate LLMs' capabilities in generating IoT-related programs, we develop IoTBench
34 | - You can download the benchmark here: [link](https://mypikpak.com/s/VOLPPwXhmHBnHMY7hW6oGHjMo1)
35 | 
36 | ### 5. Download Foundation Model
37 | - Download the `Llama-2-13b-chat-hf` foundation model from this [link](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf)
38 | - Make sure that the file structure looks like this:
39 | ```
40 | |-- GPIoT
41 |     |-- meta-llama
42 |         |-- Llama-2-13b-chat-hf
43 |             |-- .gitattributes
44 |             |-- config.json
45 |             |-- generation_config.json
46 |             |-- LICENSE.txt
47 |             |-- model.safetensors.index.json
48 |             |-- model-00001-of-00003.safetensors
49 |             |-- ......
50 | ```
51 | 
52 | ### 6. Perform Task Decomposition
53 | - Modify `task_decomposition.py` based on your IoT application and requirements
54 | - Directly execute `task_decomposition.py` (a minimal usage sketch is provided below)
55 | 
56 | ### 7. Perform Code Generation
57 | - Modify `code_generation.py` based on your IoT application and requirements
58 | - Directly execute `code_generation.py`
59 | 
60 | ### Please don't hesitate to reach out if you have any questions.
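
### Example: Querying the Task Decomposition Model
The snippet below is a minimal sketch of step 6. It reuses the prompt template from `task_decomposition.py` and only swaps in a custom task description; the `build_prompt` helper and the example task text are illustrative and not part of the repository, and the checkpoint path assumes the fine-tuned model from Section 3 has been downloaded to `GPIoT_Task_Decomposition/checkpoint-13400`.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

BASE_MODEL = "meta-llama/Llama-2-13b-chat-hf"
LORA_MODEL = "GPIoT_Task_Decomposition/checkpoint-13400"  # downloaded fine-tuned checkpoint

SYSTEM_MESSAGE = (
    "You are a professional IoT application developer. According to the user problem, "
    "you need to decompose the problem into multiple steps with implementation details."
)


def build_prompt(task_description: str) -> str:
    # Same Llama-2 chat template used by data_transform() in task_decomposition.py.
    return "[INST] <<SYS>>\n{}\n<</SYS>>\n\n{}[/INST]".format(
        SYSTEM_MESSAGE, task_description.strip()
    )


tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
model = AutoModelForCausalLM.from_pretrained(LORA_MODEL).to("cuda").eval()

# Illustrative IoT task; replace it with your own requirement.
task = (
    "Decompose the following task into multiple steps and describe the implementation "
    "details for each step. The task is to estimate human respiration rate from WiFi "
    "channel state information."
)
inputs = tokenizer(build_prompt(task), return_tensors="pt").to("cuda")
output = model.generate(**inputs, max_length=1024)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```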
61 | 
62 | ## Citation
63 | ```
64 | @inproceedings{shen2025gpiot,
65 |   title={GPIoT: Tailoring Small Language Models for IoT Program Synthesis and Development},
66 |   author={Shen, Leming and Yang, Qiang and Huang, Xinyu and Ma, Zijing and Zheng, Yuanqing},
67 |   booktitle={Proceedings of the 22nd ACM Conference on Embedded Networked Sensor Systems},
68 |   pages={1--14},
69 |   year={2025}
70 | }
71 | ```
72 | 
--------------------------------------------------------------------------------
/llama/generation.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
3 | 
4 | import json
5 | import os
6 | import sys
7 | import time
8 | from pathlib import Path
9 | from typing import List, Literal, Optional, Tuple, TypedDict
10 | 
11 | import torch
12 | import torch.nn.functional as F
13 | from fairscale.nn.model_parallel.initialize import (
14 |     get_model_parallel_rank,
15 |     initialize_model_parallel,
16 |     model_parallel_is_initialized,
17 | )
18 | 
19 | from llama.model import ModelArgs, Transformer
20 | from llama.tokenizer import Tokenizer
21 | 
22 | Role = Literal["system", "user", "assistant"]
23 | 
24 | 
25 | class Message(TypedDict):
26 |     role: Role
27 |     content: str
28 | 
29 | 
30 | class CompletionPrediction(TypedDict, total=False):
31 |     generation: str
32 |     tokens: List[str]  # not required
33 |     logprobs: List[float]  # not required
34 | 
35 | 
36 | class ChatPrediction(TypedDict, total=False):
37 |     generation: Message
38 |     tokens: List[str]  # not required
39 |     logprobs: List[float]  # not required
40 | 
41 | 
42 | Dialog = List[Message]
43 | 
44 | B_INST, E_INST = "[INST]", "[/INST]"
45 | B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
46 | 
47 | SPECIAL_TAGS = [B_INST, E_INST, "<<SYS>>", "<</SYS>>"]
48 | UNSAFE_ERROR = "Error: special tags are not allowed as part of the prompt."
49 | 
50 | 
51 | class Llama:
52 |     @staticmethod
53 |     def build(
54 |         ckpt_dir: str,
55 |         tokenizer_path: str,
56 |         max_seq_len: int,
57 |         max_batch_size: int,
58 |         model_parallel_size: Optional[int] = None,
59 |         seed: int = 1,
60 |     ) -> "Llama":
61 |         """
62 |         Build a Llama instance by initializing and loading a pre-trained model.
63 | 
64 |         Args:
65 |             ckpt_dir (str): Path to the directory containing checkpoint files.
66 |             tokenizer_path (str): Path to the tokenizer file.
67 |             max_seq_len (int): Maximum sequence length for input text.
68 |             max_batch_size (int): Maximum batch size for inference.
69 |             model_parallel_size (Optional[int], optional): Number of model parallel processes.
70 |                 If not provided, it's determined from the environment. Defaults to None.
71 | 
72 |         Returns:
73 |             Llama: An instance of the Llama class with the loaded model and tokenizer.
74 | 
75 |         Raises:
76 |             AssertionError: If there are no checkpoint files in the specified directory,
77 |                 or if the model parallel size does not match the number of checkpoint files.
78 | 
79 |         Note:
80 |             This method initializes the distributed process group, sets the device to CUDA,
81 |             and loads the pre-trained model and tokenizer.
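
        Example (an illustrative sketch; the checkpoint directory and tokenizer path are placeholders):
            generator = Llama.build(
                ckpt_dir="llama-2-13b-chat/",
                tokenizer_path="tokenizer.model",
                max_seq_len=2048,
                max_batch_size=4,
            )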
82 | 83 | """ 84 | if not torch.distributed.is_initialized(): 85 | torch.distributed.init_process_group("nccl") 86 | if not model_parallel_is_initialized(): 87 | if model_parallel_size is None: 88 | model_parallel_size = int(os.environ.get("WORLD_SIZE", 1)) 89 | initialize_model_parallel(model_parallel_size) 90 | 91 | local_rank = int(os.environ.get("LOCAL_RANK", 0)) 92 | torch.cuda.set_device(local_rank) 93 | 94 | # seed must be the same in all processes 95 | torch.manual_seed(seed) 96 | 97 | if local_rank > 0: 98 | sys.stdout = open(os.devnull, "w") 99 | 100 | start_time = time.time() 101 | checkpoints = sorted(Path(ckpt_dir).glob("*.pth")) 102 | assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}" 103 | assert model_parallel_size == len( 104 | checkpoints 105 | ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}" 106 | ckpt_path = checkpoints[get_model_parallel_rank()] 107 | checkpoint = torch.load(ckpt_path, map_location="cpu") 108 | with open(Path(ckpt_dir) / "params.json", "r") as f: 109 | params = json.loads(f.read()) 110 | 111 | model_args: ModelArgs = ModelArgs( 112 | max_seq_len=max_seq_len, 113 | max_batch_size=max_batch_size, 114 | **params, 115 | ) 116 | tokenizer = Tokenizer(model_path=tokenizer_path) 117 | model_args.vocab_size = tokenizer.n_words 118 | torch.set_default_tensor_type(torch.cuda.HalfTensor) 119 | model = Transformer(model_args) 120 | model.load_state_dict(checkpoint, strict=False) 121 | print(f"Loaded in {time.time() - start_time:.2f} seconds") 122 | 123 | return Llama(model, tokenizer) 124 | 125 | def __init__(self, model: Transformer, tokenizer: Tokenizer): 126 | self.model = model 127 | self.tokenizer = tokenizer 128 | 129 | @torch.inference_mode() 130 | def generate( 131 | self, 132 | prompt_tokens: List[List[int]], 133 | max_gen_len: int, 134 | temperature: float = 0.6, 135 | top_p: float = 0.9, 136 | logprobs: bool = False, 137 | echo: bool = False, 138 | ) -> Tuple[List[List[int]], Optional[List[List[float]]]]: 139 | """ 140 | Generate text sequences based on provided prompts using the language generation model. 141 | 142 | Args: 143 | prompt_tokens (List[List[int]]): List of tokenized prompts, where each prompt is represented as a list of integers. 144 | max_gen_len (int): Maximum length of the generated text sequence. 145 | temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6. 146 | top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9. 147 | logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False. 148 | echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False. 149 | 150 | Returns: 151 | Tuple[List[List[int]], Optional[List[List[float]]]]: A tuple containing generated token sequences and, if logprobs is True, corresponding token log probabilities. 152 | 153 | Note: 154 | This method uses the provided prompts as a basis for generating text. It employs nucleus sampling to produce text with controlled randomness. 155 | If logprobs is True, token log probabilities are computed for each generated token. 
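
        Example (an illustrative sketch; `llama` is an instance returned by Llama.build):
            prompt_tokens = [llama.tokenizer.encode("How does LoRA fine-tuning work?", bos=True, eos=False)]
            out_tokens, _ = llama.generate(prompt_tokens, max_gen_len=64, temperature=0.6, top_p=0.9)
            print(llama.tokenizer.decode(out_tokens[0]))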
156 | 157 | """ 158 | params = self.model.params 159 | bsz = len(prompt_tokens) 160 | assert bsz <= params.max_batch_size, (bsz, params.max_batch_size) 161 | 162 | min_prompt_len = min(len(t) for t in prompt_tokens) 163 | max_prompt_len = max(len(t) for t in prompt_tokens) 164 | assert max_prompt_len <= params.max_seq_len 165 | total_len = min(params.max_seq_len, max_gen_len + max_prompt_len) 166 | 167 | pad_id = self.tokenizer.pad_id 168 | tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda") 169 | for k, t in enumerate(prompt_tokens): 170 | tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda") 171 | if logprobs: 172 | token_logprobs = torch.zeros_like(tokens, dtype=torch.float) 173 | 174 | prev_pos = 0 175 | eos_reached = torch.tensor([False] * bsz, device="cuda") 176 | input_text_mask = tokens != pad_id 177 | if min_prompt_len == total_len: 178 | logits = self.model.forward(tokens, prev_pos) 179 | token_logprobs = -F.cross_entropy( 180 | input=logits.transpose(1, 2), 181 | target=tokens, 182 | reduction="none", 183 | ignore_index=pad_id, 184 | ) 185 | 186 | for cur_pos in range(min_prompt_len, total_len): 187 | logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos) 188 | if temperature > 0: 189 | probs = torch.softmax(logits[:, -1] / temperature, dim=-1) 190 | next_token = sample_top_p(probs, top_p) 191 | else: 192 | next_token = torch.argmax(logits[:, -1], dim=-1) 193 | 194 | next_token = next_token.reshape(-1) 195 | # only replace token if prompt has already been generated 196 | next_token = torch.where( 197 | input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token 198 | ) 199 | tokens[:, cur_pos] = next_token 200 | if logprobs: 201 | token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy( 202 | input=logits.transpose(1, 2), 203 | target=tokens[:, prev_pos + 1 : cur_pos + 1], 204 | reduction="none", 205 | ignore_index=pad_id, 206 | ) 207 | eos_reached |= (~input_text_mask[:, cur_pos]) & ( 208 | next_token == self.tokenizer.eos_id 209 | ) 210 | prev_pos = cur_pos 211 | if all(eos_reached): 212 | break 213 | 214 | if logprobs: 215 | token_logprobs = token_logprobs.tolist() 216 | out_tokens, out_logprobs = [], [] 217 | for i, toks in enumerate(tokens.tolist()): 218 | # cut to max gen len 219 | start = 0 if echo else len(prompt_tokens[i]) 220 | toks = toks[start : len(prompt_tokens[i]) + max_gen_len] 221 | probs = None 222 | if logprobs: 223 | probs = token_logprobs[i][start : len(prompt_tokens[i]) + max_gen_len] 224 | # cut to eos tok if any 225 | if self.tokenizer.eos_id in toks: 226 | eos_idx = toks.index(self.tokenizer.eos_id) 227 | toks = toks[:eos_idx] 228 | probs = probs[:eos_idx] if logprobs else None 229 | out_tokens.append(toks) 230 | out_logprobs.append(probs) 231 | return (out_tokens, out_logprobs if logprobs else None) 232 | 233 | def text_completion( 234 | self, 235 | prompts: List[str], 236 | temperature: float = 0.6, 237 | top_p: float = 0.9, 238 | max_gen_len: Optional[int] = None, 239 | logprobs: bool = False, 240 | echo: bool = False, 241 | ) -> List[CompletionPrediction]: 242 | """ 243 | Perform text completion for a list of prompts using the language generation model. 244 | 245 | Args: 246 | prompts (List[str]): List of text prompts for completion. 247 | temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6. 248 | top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9. 
249 | max_gen_len (Optional[int], optional): Maximum length of the generated completion sequence. 250 | If not provided, it's set to the model's maximum sequence length minus 1. 251 | logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False. 252 | echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False. 253 | 254 | Returns: 255 | List[CompletionPrediction]: List of completion predictions, each containing the generated text completion. 256 | 257 | Note: 258 | This method generates text completions for the provided prompts, employing nucleus sampling to introduce controlled randomness. 259 | If logprobs is True, token log probabilities are computed for each generated token. 260 | 261 | """ 262 | if max_gen_len is None: 263 | max_gen_len = self.model.params.max_seq_len - 1 264 | prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts] 265 | generation_tokens, generation_logprobs = self.generate( 266 | prompt_tokens=prompt_tokens, 267 | max_gen_len=max_gen_len, 268 | temperature=temperature, 269 | top_p=top_p, 270 | logprobs=logprobs, 271 | echo=echo, 272 | ) 273 | if logprobs: 274 | return [ 275 | { 276 | "generation": self.tokenizer.decode(t), 277 | "tokens": [self.tokenizer.decode(x) for x in t], 278 | "logprobs": logprobs_i, 279 | } 280 | for t, logprobs_i in zip(generation_tokens, generation_logprobs) 281 | ] 282 | return [{"generation": self.tokenizer.decode(t)} for t in generation_tokens] 283 | 284 | def chat_completion( 285 | self, 286 | dialogs: List[Dialog], 287 | temperature: float = 0.6, 288 | top_p: float = 0.9, 289 | max_gen_len: Optional[int] = None, 290 | logprobs: bool = False, 291 | ) -> List[ChatPrediction]: 292 | """ 293 | Generate assistant responses for a list of conversational dialogs using the language generation model. 294 | 295 | Args: 296 | dialogs (List[Dialog]): List of conversational dialogs, where each dialog is a list of messages. 297 | temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6. 298 | top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9. 299 | max_gen_len (Optional[int], optional): Maximum length of the generated response sequence. 300 | If not provided, it's set to the model's maximum sequence length minus 1. 301 | logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False. 302 | 303 | Returns: 304 | List[ChatPrediction]: List of chat predictions, each containing the assistant's generated response. 305 | 306 | Raises: 307 | AssertionError: If the last message in a dialog is not from the user. 308 | AssertionError: If the dialog roles are not in the required 'user', 'assistant', and optional 'system' order. 309 | 310 | Note: 311 | This method generates assistant responses for the provided conversational dialogs. 312 | It employs nucleus sampling to introduce controlled randomness in text generation. 313 | If logprobs is True, token log probabilities are computed for each generated token. 
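
        Example (an illustrative sketch; `llama` is an instance returned by Llama.build):
            dialogs = [[
                {"role": "system", "content": "You are a helpful IoT development assistant."},
                {"role": "user", "content": "List the main steps for denoising an accelerometer signal."},
            ]]
            results = llama.chat_completion(dialogs, max_gen_len=256)
            print(results[0]["generation"]["content"])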
314 | 315 | """ 316 | if max_gen_len is None: 317 | max_gen_len = self.model.params.max_seq_len - 1 318 | prompt_tokens = [] 319 | unsafe_requests = [] 320 | for dialog in dialogs: 321 | unsafe_requests.append( 322 | any([tag in msg["content"] for tag in SPECIAL_TAGS for msg in dialog]) 323 | ) 324 | if dialog[0]["role"] == "system": 325 | dialog = [ 326 | { 327 | "role": dialog[1]["role"], 328 | "content": B_SYS 329 | + dialog[0]["content"] 330 | + E_SYS 331 | + dialog[1]["content"], 332 | } 333 | ] + dialog[2:] 334 | assert all([msg["role"] == "user" for msg in dialog[::2]]) and all( 335 | [msg["role"] == "assistant" for msg in dialog[1::2]] 336 | ), ( 337 | "model only supports 'system', 'user' and 'assistant' roles, " 338 | "starting with 'system', then 'user' and alternating (u/a/u/a/u...)" 339 | ) 340 | dialog_tokens: List[int] = sum( 341 | [ 342 | self.tokenizer.encode( 343 | f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ", 344 | bos=True, 345 | eos=True, 346 | ) 347 | for prompt, answer in zip( 348 | dialog[::2], 349 | dialog[1::2], 350 | ) 351 | ], 352 | [], 353 | ) 354 | assert ( 355 | dialog[-1]["role"] == "user" 356 | ), f"Last message must be from user, got {dialog[-1]['role']}" 357 | dialog_tokens += self.tokenizer.encode( 358 | f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}", 359 | bos=True, 360 | eos=False, 361 | ) 362 | prompt_tokens.append(dialog_tokens) 363 | 364 | generation_tokens, generation_logprobs = self.generate( 365 | prompt_tokens=prompt_tokens, 366 | max_gen_len=max_gen_len, 367 | temperature=temperature, 368 | top_p=top_p, 369 | logprobs=logprobs, 370 | ) 371 | if logprobs: 372 | return [ 373 | { 374 | "generation": { 375 | "role": "assistant", 376 | "content": self.tokenizer.decode(t) 377 | if not unsafe 378 | else UNSAFE_ERROR, 379 | }, 380 | "tokens": [self.tokenizer.decode(x) for x in t], 381 | "logprobs": logprobs_i, 382 | } 383 | for t, logprobs_i, unsafe in zip( 384 | generation_tokens, generation_logprobs, unsafe_requests 385 | ) 386 | ] 387 | return [ 388 | { 389 | "generation": { 390 | "role": "assistant", 391 | "content": self.tokenizer.decode(t) if not unsafe else UNSAFE_ERROR, 392 | } 393 | } 394 | for t, unsafe in zip(generation_tokens, unsafe_requests) 395 | ] 396 | 397 | 398 | def sample_top_p(probs, p): 399 | """ 400 | Perform top-p (nucleus) sampling on a probability distribution. 401 | 402 | Args: 403 | probs (torch.Tensor): Probability distribution tensor. 404 | p (float): Probability threshold for top-p sampling. 405 | 406 | Returns: 407 | torch.Tensor: Sampled token indices. 408 | 409 | Note: 410 | Top-p sampling selects the smallest set of tokens whose cumulative probability mass 411 | exceeds the threshold p. The distribution is renormalized based on the selected tokens. 412 | 413 | """ 414 | probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) 415 | probs_sum = torch.cumsum(probs_sort, dim=-1) 416 | mask = probs_sum - probs_sort > p 417 | probs_sort[mask] = 0.0 418 | probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) 419 | next_token = torch.multinomial(probs_sort, num_samples=1) 420 | next_token = torch.gather(probs_idx, -1, next_token) 421 | return next_token 422 | -------------------------------------------------------------------------------- /llama/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | import math 5 | from dataclasses import dataclass 6 | from typing import Optional, Tuple 7 | 8 | import fairscale.nn.model_parallel.initialize as fs_init 9 | import torch 10 | import torch.nn.functional as F 11 | from fairscale.nn.model_parallel.layers import ( 12 | ColumnParallelLinear, 13 | ParallelEmbedding, 14 | RowParallelLinear, 15 | ) 16 | from torch import nn 17 | 18 | 19 | @dataclass 20 | class ModelArgs: 21 | dim: int = 4096 22 | n_layers: int = 32 23 | n_heads: int = 32 24 | n_kv_heads: Optional[int] = None 25 | vocab_size: int = -1 # defined later by tokenizer 26 | multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 27 | ffn_dim_multiplier: Optional[float] = None 28 | norm_eps: float = 1e-5 29 | 30 | max_batch_size: int = 32 31 | max_seq_len: int = 2048 32 | 33 | 34 | class RMSNorm(torch.nn.Module): 35 | def __init__(self, dim: int, eps: float = 1e-6): 36 | """ 37 | Initialize the RMSNorm normalization layer. 38 | 39 | Args: 40 | dim (int): The dimension of the input tensor. 41 | eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. 42 | 43 | Attributes: 44 | eps (float): A small value added to the denominator for numerical stability. 45 | weight (nn.Parameter): Learnable scaling parameter. 46 | 47 | """ 48 | super().__init__() 49 | self.eps = eps 50 | self.weight = nn.Parameter(torch.ones(dim)) 51 | 52 | def _norm(self, x): 53 | """ 54 | Apply the RMSNorm normalization to the input tensor. 55 | 56 | Args: 57 | x (torch.Tensor): The input tensor. 58 | 59 | Returns: 60 | torch.Tensor: The normalized tensor. 61 | 62 | """ 63 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 64 | 65 | def forward(self, x): 66 | """ 67 | Forward pass through the RMSNorm layer. 68 | 69 | Args: 70 | x (torch.Tensor): The input tensor. 71 | 72 | Returns: 73 | torch.Tensor: The output tensor after applying RMSNorm. 74 | 75 | """ 76 | output = self._norm(x.float()).type_as(x) 77 | return output * self.weight 78 | 79 | 80 | def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): 81 | """ 82 | Precompute the frequency tensor for complex exponentials (cis) with given dimensions. 83 | 84 | This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' 85 | and the end index 'end'. The 'theta' parameter scales the frequencies. 86 | The returned tensor contains complex values in complex64 data type. 87 | 88 | Args: 89 | dim (int): Dimension of the frequency tensor. 90 | end (int): End index for precomputing frequencies. 91 | theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. 92 | 93 | Returns: 94 | torch.Tensor: Precomputed frequency tensor with complex exponentials. 95 | 96 | 97 | 98 | 99 | """ 100 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) 101 | t = torch.arange(end, device=freqs.device) # type: ignore 102 | freqs = torch.outer(t, freqs).float() # type: ignore 103 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 104 | return freqs_cis 105 | 106 | 107 | def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): 108 | """ 109 | Reshape frequency tensor for broadcasting it with another tensor. 
110 | 111 | This function reshapes the frequency tensor to have the same shape as the target tensor 'x' 112 | for the purpose of broadcasting the frequency tensor during element-wise operations. 113 | 114 | Args: 115 | freqs_cis (torch.Tensor): Frequency tensor to be reshaped. 116 | x (torch.Tensor): Target tensor for broadcasting compatibility. 117 | 118 | Returns: 119 | torch.Tensor: Reshaped frequency tensor. 120 | 121 | Raises: 122 | AssertionError: If the frequency tensor doesn't match the expected shape. 123 | AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions. 124 | """ 125 | ndim = x.ndim 126 | assert 0 <= 1 < ndim 127 | assert freqs_cis.shape == (x.shape[1], x.shape[-1]) 128 | shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] 129 | return freqs_cis.view(*shape) 130 | 131 | 132 | def apply_rotary_emb( 133 | xq: torch.Tensor, 134 | xk: torch.Tensor, 135 | freqs_cis: torch.Tensor, 136 | ) -> Tuple[torch.Tensor, torch.Tensor]: 137 | """ 138 | Apply rotary embeddings to input tensors using the given frequency tensor. 139 | 140 | This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided 141 | frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor 142 | is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are 143 | returned as real tensors. 144 | 145 | Args: 146 | xq (torch.Tensor): Query tensor to apply rotary embeddings. 147 | xk (torch.Tensor): Key tensor to apply rotary embeddings. 148 | freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials. 149 | 150 | Returns: 151 | Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. 152 | 153 | 154 | 155 | """ 156 | xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) 157 | xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) 158 | freqs_cis = reshape_for_broadcast(freqs_cis, xq_) 159 | xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) 160 | xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) 161 | return xq_out.type_as(xq), xk_out.type_as(xk) 162 | 163 | 164 | def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: 165 | """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" 166 | bs, slen, n_kv_heads, head_dim = x.shape 167 | if n_rep == 1: 168 | return x 169 | return ( 170 | x[:, :, :, None, :] 171 | .expand(bs, slen, n_kv_heads, n_rep, head_dim) 172 | .reshape(bs, slen, n_kv_heads * n_rep, head_dim) 173 | ) 174 | 175 | 176 | class Attention(nn.Module): 177 | """Multi-head attention module.""" 178 | def __init__(self, args: ModelArgs): 179 | """ 180 | Initialize the Attention module. 181 | 182 | Args: 183 | args (ModelArgs): Model configuration parameters. 184 | 185 | Attributes: 186 | n_kv_heads (int): Number of key and value heads. 187 | n_local_heads (int): Number of local query heads. 188 | n_local_kv_heads (int): Number of local key and value heads. 189 | n_rep (int): Number of repetitions for local heads. 190 | head_dim (int): Dimension size of each attention head. 191 | wq (ColumnParallelLinear): Linear transformation for queries. 192 | wk (ColumnParallelLinear): Linear transformation for keys. 193 | wv (ColumnParallelLinear): Linear transformation for values. 194 | wo (RowParallelLinear): Linear transformation for output. 195 | cache_k (torch.Tensor): Cached keys for attention. 
196 | cache_v (torch.Tensor): Cached values for attention. 197 | 198 | """ 199 | super().__init__() 200 | self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads 201 | model_parallel_size = fs_init.get_model_parallel_world_size() 202 | self.n_local_heads = args.n_heads // model_parallel_size 203 | self.n_local_kv_heads = self.n_kv_heads // model_parallel_size 204 | self.n_rep = self.n_local_heads // self.n_local_kv_heads 205 | self.head_dim = args.dim // args.n_heads 206 | 207 | self.wq = ColumnParallelLinear( 208 | args.dim, 209 | args.n_heads * self.head_dim, 210 | bias=False, 211 | gather_output=False, 212 | init_method=lambda x: x, 213 | ) 214 | self.wk = ColumnParallelLinear( 215 | args.dim, 216 | self.n_kv_heads * self.head_dim, 217 | bias=False, 218 | gather_output=False, 219 | init_method=lambda x: x, 220 | ) 221 | self.wv = ColumnParallelLinear( 222 | args.dim, 223 | self.n_kv_heads * self.head_dim, 224 | bias=False, 225 | gather_output=False, 226 | init_method=lambda x: x, 227 | ) 228 | self.wo = RowParallelLinear( 229 | args.n_heads * self.head_dim, 230 | args.dim, 231 | bias=False, 232 | input_is_parallel=True, 233 | init_method=lambda x: x, 234 | ) 235 | 236 | self.cache_k = torch.zeros( 237 | ( 238 | args.max_batch_size, 239 | args.max_seq_len, 240 | self.n_local_kv_heads, 241 | self.head_dim, 242 | ) 243 | ).cuda() 244 | self.cache_v = torch.zeros( 245 | ( 246 | args.max_batch_size, 247 | args.max_seq_len, 248 | self.n_local_kv_heads, 249 | self.head_dim, 250 | ) 251 | ).cuda() 252 | 253 | def forward( 254 | self, 255 | x: torch.Tensor, 256 | start_pos: int, 257 | freqs_cis: torch.Tensor, 258 | mask: Optional[torch.Tensor], 259 | ): 260 | """ 261 | Forward pass of the attention module. 262 | 263 | Args: 264 | x (torch.Tensor): Input tensor. 265 | start_pos (int): Starting position for caching. 266 | freqs_cis (torch.Tensor): Precomputed frequency tensor. 267 | mask (torch.Tensor, optional): Attention mask tensor. 268 | 269 | Returns: 270 | torch.Tensor: Output tensor after attention. 
271 | 272 | """ 273 | bsz, seqlen, _ = x.shape 274 | xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) 275 | 276 | xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) 277 | xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) 278 | xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) 279 | 280 | xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) 281 | 282 | self.cache_k = self.cache_k.to(xq) 283 | self.cache_v = self.cache_v.to(xq) 284 | 285 | self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk 286 | self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv 287 | 288 | keys = self.cache_k[:bsz, : start_pos + seqlen] 289 | values = self.cache_v[:bsz, : start_pos + seqlen] 290 | 291 | # repeat k/v heads if n_kv_heads < n_heads 292 | keys = repeat_kv(keys, self.n_rep) # (bs, cache_len + seqlen, n_local_heads, head_dim) 293 | values = repeat_kv(values, self.n_rep) # (bs, cache_len + seqlen, n_local_heads, head_dim) 294 | 295 | xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) 296 | keys = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) 297 | values = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) 298 | scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) 299 | if mask is not None: 300 | scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen) 301 | scores = F.softmax(scores.float(), dim=-1).type_as(xq) 302 | output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim) 303 | output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) 304 | return self.wo(output) 305 | 306 | 307 | class FeedForward(nn.Module): 308 | def __init__( 309 | self, 310 | dim: int, 311 | hidden_dim: int, 312 | multiple_of: int, 313 | ffn_dim_multiplier: Optional[float], 314 | ): 315 | """ 316 | Initialize the FeedForward module. 317 | 318 | Args: 319 | dim (int): Input dimension. 320 | hidden_dim (int): Hidden dimension of the feedforward layer. 321 | multiple_of (int): Value to ensure hidden dimension is a multiple of this value. 322 | ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None. 323 | 324 | Attributes: 325 | w1 (ColumnParallelLinear): Linear transformation for the first layer. 326 | w2 (RowParallelLinear): Linear transformation for the second layer. 327 | w3 (ColumnParallelLinear): Linear transformation for the third layer. 328 | 329 | """ 330 | super().__init__() 331 | hidden_dim = int(2 * hidden_dim / 3) 332 | # custom dim factor multiplier 333 | if ffn_dim_multiplier is not None: 334 | hidden_dim = int(ffn_dim_multiplier * hidden_dim) 335 | hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) 336 | 337 | self.w1 = ColumnParallelLinear( 338 | dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x 339 | ) 340 | self.w2 = RowParallelLinear( 341 | hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x 342 | ) 343 | self.w3 = ColumnParallelLinear( 344 | dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x 345 | ) 346 | 347 | def forward(self, x): 348 | return self.w2(F.silu(self.w1(x)) * self.w3(x)) 349 | 350 | 351 | class TransformerBlock(nn.Module): 352 | def __init__(self, layer_id: int, args: ModelArgs): 353 | """ 354 | Initialize a TransformerBlock. 355 | 356 | Args: 357 | layer_id (int): Identifier for the layer. 358 | args (ModelArgs): Model configuration parameters. 
359 | 360 | Attributes: 361 | n_heads (int): Number of attention heads. 362 | dim (int): Dimension size of the model. 363 | head_dim (int): Dimension size of each attention head. 364 | attention (Attention): Attention module. 365 | feed_forward (FeedForward): FeedForward module. 366 | layer_id (int): Identifier for the layer. 367 | attention_norm (RMSNorm): Layer normalization for attention output. 368 | ffn_norm (RMSNorm): Layer normalization for feedforward output. 369 | 370 | """ 371 | super().__init__() 372 | self.n_heads = args.n_heads 373 | self.dim = args.dim 374 | self.head_dim = args.dim // args.n_heads 375 | self.attention = Attention(args) 376 | self.feed_forward = FeedForward( 377 | dim=args.dim, 378 | hidden_dim=4 * args.dim, 379 | multiple_of=args.multiple_of, 380 | ffn_dim_multiplier=args.ffn_dim_multiplier, 381 | ) 382 | self.layer_id = layer_id 383 | self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) 384 | self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) 385 | 386 | def forward( 387 | self, 388 | x: torch.Tensor, 389 | start_pos: int, 390 | freqs_cis: torch.Tensor, 391 | mask: Optional[torch.Tensor], 392 | ): 393 | """ 394 | Perform a forward pass through the TransformerBlock. 395 | 396 | Args: 397 | x (torch.Tensor): Input tensor. 398 | start_pos (int): Starting position for attention caching. 399 | freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. 400 | mask (torch.Tensor, optional): Masking tensor for attention. Defaults to None. 401 | 402 | Returns: 403 | torch.Tensor: Output tensor after applying attention and feedforward layers. 404 | 405 | """ 406 | h = x + self.attention( 407 | self.attention_norm(x), start_pos, freqs_cis, mask 408 | ) 409 | out = h + self.feed_forward(self.ffn_norm(h)) 410 | return out 411 | 412 | 413 | class Transformer(nn.Module): 414 | def __init__(self, params: ModelArgs): 415 | """ 416 | Initialize a Transformer model. 417 | 418 | Args: 419 | params (ModelArgs): Model configuration parameters. 420 | 421 | Attributes: 422 | params (ModelArgs): Model configuration parameters. 423 | vocab_size (int): Vocabulary size. 424 | n_layers (int): Number of layers in the model. 425 | tok_embeddings (ParallelEmbedding): Token embeddings. 426 | layers (torch.nn.ModuleList): List of Transformer blocks. 427 | norm (RMSNorm): Layer normalization for the model output. 428 | output (ColumnParallelLinear): Linear layer for final output. 429 | freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. 430 | 431 | """ 432 | super().__init__() 433 | self.params = params 434 | self.vocab_size = params.vocab_size 435 | self.n_layers = params.n_layers 436 | 437 | self.tok_embeddings = ParallelEmbedding( 438 | params.vocab_size, params.dim, init_method=lambda x: x 439 | ) 440 | 441 | self.layers = torch.nn.ModuleList() 442 | for layer_id in range(params.n_layers): 443 | self.layers.append(TransformerBlock(layer_id, params)) 444 | 445 | self.norm = RMSNorm(params.dim, eps=params.norm_eps) 446 | self.output = ColumnParallelLinear( 447 | params.dim, params.vocab_size, bias=False, init_method=lambda x: x 448 | ) 449 | 450 | self.freqs_cis = precompute_freqs_cis( 451 | # Note that self.params.max_seq_len is multiplied by 2 because the token limit for the Llama 2 generation of models is 4096. 452 | # Adding this multiplier instead of using 4096 directly allows for dynamism of token lengths while training or fine-tuning. 
453 | self.params.dim // self.params.n_heads, self.params.max_seq_len * 2 454 | ) 455 | 456 | @torch.inference_mode() 457 | def forward(self, tokens: torch.Tensor, start_pos: int): 458 | """ 459 | Perform a forward pass through the Transformer model. 460 | 461 | Args: 462 | tokens (torch.Tensor): Input token indices. 463 | start_pos (int): Starting position for attention caching. 464 | 465 | Returns: 466 | torch.Tensor: Output logits after applying the Transformer model. 467 | 468 | """ 469 | _bsz, seqlen = tokens.shape 470 | h = self.tok_embeddings(tokens) 471 | self.freqs_cis = self.freqs_cis.to(h.device) 472 | freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen] 473 | 474 | mask = None 475 | if seqlen > 1: 476 | mask = torch.full( 477 | (seqlen, seqlen), float("-inf"), device=tokens.device 478 | ) 479 | 480 | mask = torch.triu(mask, diagonal=1) 481 | 482 | # When performing key-value caching, we compute the attention scores 483 | # only for the new sequence. Thus, the matrix of scores is of size 484 | # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for 485 | # j > cache_len + i, since row i corresponds to token cache_len + i. 486 | mask = torch.hstack([ 487 | torch.zeros((seqlen, start_pos), device=tokens.device), 488 | mask 489 | ]).type_as(h) 490 | 491 | for layer in self.layers: 492 | h = layer(h, start_pos, freqs_cis, mask) 493 | h = self.norm(h) 494 | output = self.output(h).float() 495 | return output 496 | --------------------------------------------------------------------------------