├── entropixing ├── __init__.py ├── attn_stats.py ├── utils.py ├── kv_cache.py ├── generate.py ├── model.py ├── sampler.py └── llama_cpp_impl.py ├── .python-version ├── .gitignore ├── requirements.txt ├── README.md ├── chat_llama_cpp.py ├── elyza_tasks.py ├── main.py ├── chat.py ├── server_llama_cpp.py └── server.py /entropixing/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.12 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | venv/ 2 | __pycache__/ 3 | .aider* 4 | *.sh 5 | *.gguf -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers 2 | torch 3 | accelerate 4 | fastapi 5 | uvicorn 6 | openai 7 | rich 8 | datasets 9 | torchao 10 | cupy 11 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Support status 2 | ## Transformers 3 | - [x] Gemma2 4 | - [x] Llama 5 | - [x] Mistral 6 | - [-] Qwen2(Somehow broken) 7 | - [-] Phi3(Somehow broken) 8 | ## Llama-cpp-python 9 | all -------------------------------------------------------------------------------- /entropixing/attn_stats.py: -------------------------------------------------------------------------------- 1 | from typing import NamedTuple 2 | 3 | import torch 4 | 5 | 6 | class AttnStats(NamedTuple): 7 | entropy: torch.Tensor # (bsz, n_layers, num_heads) 8 | varentropy: torch.Tensor # (bsz, n_layers, num_heads) 9 | n_layers: int 10 | n_heads: int 11 | 12 | @classmethod 13 | def init(self, bsz: int, n_layers: int, n_heads: int, device: torch.device): 14 | return self( 15 | entropy=torch.zeros((bsz, n_layers, n_heads), device=device), 16 | varentropy=torch.zeros((bsz, n_layers, n_heads), device=device), 17 | n_layers=n_layers, 18 | n_heads=n_heads, 19 | ) 20 | 21 | @property 22 | def avg_entropy(self): 23 | return self.entropy.sum(dim=-1, keepdim=False) # Average across heads 24 | 25 | @property 26 | def std_error(self): 27 | return torch.sqrt(torch.mean(self.varentropy)) / (self.n_heads * self.n_layers) 28 | 29 | def update(self, scores: torch.Tensor, layer_idx: int): 30 | # scores shape: (bsz, n_heads, seqlen, n_words) 31 | probs = torch.nn.functional.softmax(scores, dim=-1) 32 | new_entropy = -torch.sum( 33 | torch.where(probs > 0, probs * torch.log(probs), torch.tensor(0.0)), dim=-1 34 | ) 35 | new_varentropy = torch.sum( 36 | probs * (torch.log(probs) + new_entropy.unsqueeze(-1)) ** 2, dim=-1 37 | ) 38 | 39 | # Update entropy and varentropy tensors 40 | self.entropy[:, layer_idx, :] = new_entropy 41 | self.varentropy[:, layer_idx, :] = new_varentropy 42 | 43 | return self 44 | -------------------------------------------------------------------------------- /chat_llama_cpp.py: -------------------------------------------------------------------------------- 1 | from entropixing.llama_cpp_impl import generate_response 2 | from rich.console import Console 3 | from transformers import AutoTokenizer 4 | from llama_cpp import Llama, GGML_TYPE_Q4_0 5 | 6 | 7 | def main(): 8 | from argparse import ArgumentParser 9 | 
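    # Interactive CLI chat against a GGUF model served through llama-cpp-python.
    # The conversation is rendered with the HF tokenizer's chat template, and the
    # leading BOS token is sliced off the prompt string before it is handed to
    # llama.cpp (presumably to avoid a duplicated BOS, since llama.cpp adds one
    # itself when it tokenizes the prompt).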
10 | global device 11 | console = Console() 12 | parser = ArgumentParser() 13 | parser.add_argument("--tokenizer", type=str, default="google/gemma-2-2b-it") 14 | parser.add_argument("--model", type=str, required=True, default="./model.gguf") 15 | parser.add_argument("--max_length", type=int, default=4096) 16 | parser.add_argument("--context_length", type=int, default=16384) 17 | parser.add_argument("--ngl", type=int, default=0) 18 | parser.add_argument("--repetition_penalty", type=float, default=1.0) 19 | args = parser.parse_args() 20 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer) 21 | conv = [] 22 | weights = Llama( 23 | args.model, 24 | n_gpu_layers=args.ngl, 25 | n_ctx=args.context_length, 26 | verbose=False, 27 | flash_attn=True, 28 | type_k=GGML_TYPE_Q4_0, 29 | type_v=GGML_TYPE_Q4_0, 30 | ) 31 | while True: 32 | console.print("User: ", end="", style="green") 33 | inp = input("").strip() 34 | if inp == "exit": 35 | break 36 | elif inp == "clear": 37 | conv.clear() 38 | continue 39 | conv.append({"role": "user", "content": inp}) 40 | inputs = tokenizer.apply_chat_template( 41 | conv, tokenize=False, add_generation_prompt=True 42 | )[len(tokenizer.bos_token) if tokenizer.bos_token else 0 :] 43 | it = generate_response( 44 | weights, 45 | inputs, 46 | args.max_length, 47 | stop=[tokenizer.eos_token], 48 | ) 49 | console.print("Assistant: ", end="", style="green") 50 | text = "" 51 | for token in it: 52 | console.print(token, end="") 53 | text += token 54 | conv.append({"role": "assistant", "content": text.strip()}) 55 | console.print() 56 | 57 | 58 | if __name__ == "__main__": 59 | main() 60 | -------------------------------------------------------------------------------- /elyza_tasks.py: -------------------------------------------------------------------------------- 1 | from openai import OpenAI 2 | from transformers import AutoTokenizer 3 | from datasets import load_dataset 4 | from tqdm import tqdm 5 | from argparse import ArgumentParser 6 | import time 7 | 8 | parser = ArgumentParser() 9 | parser.add_argument("--api_key", type=str, required=True) 10 | parser.add_argument("--model_name", type=str, default="gpt-3.5-turbo") 11 | parser.add_argument("--api_base", type=str, default=None) 12 | parser.add_argument("--model_base", type=str, default="http://localhost:8000") 13 | parser.add_argument("--llama_cpp", action="store_true") 14 | args = parser.parse_args() 15 | 16 | ai = OpenAI(base_url=args.api_base, api_key=args.api_key) 17 | ai2 = OpenAI(base_url=args.model_base, api_key="hello") 18 | 19 | 20 | def generate_response(q: str): 21 | res = ai2.chat.completions.create( 22 | model="default", 23 | messages=[{"role": "user", "content": q}], 24 | max_completion_tokens=2048, 25 | ) 26 | return res.choices[0].message.content.strip() 27 | 28 | 29 | def generate_score(p: str): 30 | try: 31 | res = ai.chat.completions.create( 32 | model=args.model_name, 33 | messages=[{"role": "user", "content": p}], 34 | extra_body={"grammar": "root ::= [1-5]"} if args.llama_cpp else {}, 35 | ) 36 | return float(res.choices[0].message.content.strip()) 37 | except Exception as e: 38 | time.sleep(2) 39 | print(f"Error: {e}") 40 | return generate_score(p) 41 | 42 | 43 | def eval_one(q: str, a: str, aspect: str): 44 | pred = generate_response(q) 45 | res = generate_score( 46 | f"""あなたは採点者です。 47 | 48 | 問題, 正解例, 採点基準, 回答 が与えられます。 49 | 50 | 採点基準と正解例を参考にして、回答を1,2,3,4,5の5段階で採点し、数字のみを出力してください。 51 | 52 | # 問題 53 | {q} 54 | 55 | # 正解例 56 | {a} 57 | 58 | # 採点基準 59 | 基本的な採点基準 60 | - 1点: 誤っている、 指示に従えていない 61 | - 
2点: 誤っているが、方向性は合っている 62 | - 3点: 部分的に誤っている、 部分的に合っている 63 | - 4点: 合っている 64 | - 5点: 役に立つ 65 | 66 | 基本的な減点項目 67 | - 不自然な日本語: -1点 68 | - 部分的に事実と異なる内容を述べている: -1点 69 | - 「倫理的に答えられません」のように過度に安全性を気にしてしまっている: 2点にする 70 | 71 | 問題固有の採点基準 72 | {aspect} 73 | 74 | # 回答 75 | {pred}""" 76 | ) 77 | return res, pred 78 | 79 | 80 | if __name__ == "__main__": 81 | ds = load_dataset("elyza/ELYZA-tasks-100", split="test") 82 | score = 0.0 83 | res = [] 84 | for entry in tqdm(ds): 85 | q = entry["input"] 86 | a = entry["output"] 87 | aspect = entry["eval_aspect"] 88 | s, pred = eval_one( 89 | q, 90 | a, 91 | aspect, 92 | ) 93 | score += s 94 | print(f"Score: {s}\nQuestion: {q}\nAnswer: {a}\nPred: {pred}") 95 | res.append( 96 | {"score": int(s), "question": q, "pure_answer": a, "pred_answer": pred} 97 | ) 98 | print(f"Score: {score / len(ds)}") 99 | -------------------------------------------------------------------------------- /entropixing/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoConfig, PretrainedConfig 3 | 4 | 5 | def apply_scaling(freqs: torch.Tensor) -> torch.Tensor: 6 | SCALE_FACTOR = 8.0 7 | LOW_FREQ_FACTOR = 1.0 8 | HIGH_FREQ_FACTOR = 4.0 9 | OLD_CONTEXT_LEN = 8192 # original llama3 length 10 | 11 | low_freq_wavelen = OLD_CONTEXT_LEN / LOW_FREQ_FACTOR 12 | high_freq_wavelen = OLD_CONTEXT_LEN / HIGH_FREQ_FACTOR 13 | 14 | def scale_freq(freq: torch.Tensor) -> torch.Tensor: 15 | wavelen = 2 * torch.pi / freq 16 | 17 | # Calculate smooth factor 18 | smooth = (OLD_CONTEXT_LEN / wavelen - LOW_FREQ_FACTOR) / ( 19 | HIGH_FREQ_FACTOR - LOW_FREQ_FACTOR 20 | ) 21 | smooth = torch.clamp(smooth, 0.0, 1.0) # Ensure smooth is between 0 and 1 22 | 23 | # Calculate scaled frequency 24 | scaled = (1 - smooth) * freq / SCALE_FACTOR + smooth * freq 25 | 26 | # Apply conditional scaling 27 | scaled = torch.where( 28 | wavelen < high_freq_wavelen, 29 | freq, # No scaling 30 | torch.where( 31 | wavelen > low_freq_wavelen, 32 | freq / SCALE_FACTOR, # Apply scaling factor 33 | scaled, # Apply smooth scaling 34 | ), 35 | ) 36 | return scaled 37 | 38 | scaled_freqs = torch.vmap(scale_freq)(freqs) 39 | 40 | return scaled_freqs 41 | 42 | 43 | def precompute_freqs_cis( 44 | dim: int, 45 | end: int, 46 | theta: float = 500000.0, 47 | use_scaled: bool = False, 48 | dtype: torch.dtype = torch.float32, 49 | device: torch.device = "cpu", 50 | ) -> torch.Tensor: 51 | freqs = 1.0 / ( 52 | theta 53 | ** (torch.arange(0, dim, 2, dtype=dtype, device=device)[: (dim // 2)] / dim) 54 | ) 55 | if use_scaled: 56 | freqs = apply_scaling(freqs) 57 | 58 | t = torch.arange(end, dtype=dtype, device=device).unsqueeze(1) # Shape: (end, 1) 59 | freqs = freqs.unsqueeze(0) # Shape: (1, dim//2) 60 | freqs = t * freqs # Broadcasting to shape: (end, dim//2) 61 | return torch.exp(1j * freqs) 62 | 63 | 64 | def build_attn_mask( 65 | seqlen: int, 66 | start_pos: int, 67 | device: torch.device = "cpu", 68 | ) -> torch.Tensor: 69 | mask = None 70 | if seqlen > 1: 71 | mask = torch.full((seqlen, seqlen), float("-inf")) 72 | mask = torch.triu(mask, diagonal=1) 73 | mask = ( 74 | torch.hstack([torch.zeros((seqlen, start_pos)), mask]) 75 | .to(torch.float32) 76 | .to(device) 77 | ) 78 | return mask 79 | 80 | 81 | def is_supported_model(mod: str): 82 | arch: PretrainedConfig = AutoConfig.from_pretrained(mod) 83 | return arch.architectures[0] in [ 84 | "Gemma2ForCausalLM", 85 | "LlamaForCausalLM", 86 | "Qwen2ForCausalLM", 87 | "MistralForCausalLM", 88 | "Phi3ForCausalLM", 89 
| ] 90 | -------------------------------------------------------------------------------- /entropixing/kv_cache.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | 4 | 5 | class KVCache(nn.Module): 6 | def __init__( 7 | self, 8 | layers: int, 9 | bsz: int, 10 | max_seq_len: int, 11 | kv_heads: int, 12 | head_dim: int, 13 | device: torch.device, 14 | dtype: torch.dtype, 15 | ): 16 | super(KVCache, self).__init__() 17 | # Initialize k and v as buffers to ensure they're part of the module state 18 | self.register_buffer( 19 | "k", 20 | torch.zeros( 21 | (layers, bsz, max_seq_len, kv_heads, head_dim), 22 | dtype=dtype, 23 | device=device, 24 | ), 25 | ) 26 | self.register_buffer( 27 | "v", 28 | torch.zeros( 29 | (layers, bsz, max_seq_len, kv_heads, head_dim), 30 | dtype=dtype, 31 | device=device, 32 | ), 33 | ) 34 | 35 | def update( 36 | self, 37 | xk: torch.Tensor, 38 | xv: torch.Tensor, 39 | layer_idx: int, 40 | cur_pos: int, 41 | n_rep: int, 42 | ): 43 | """ 44 | Updates the cache with new key and value tensors. 45 | 46 | Args: 47 | xk (torch.Tensor): New key tensor to insert. Shape should align with (bsz, insert_len, kv_heads, head_dim). 48 | xv (torch.Tensor): New value tensor to insert. Shape should align with (bsz, insert_len, kv_heads, head_dim). 49 | layer_idx (int): The index of the layer to update. 50 | cur_pos (int): The current position in the sequence to start inserting. 51 | n_rep (int): The number of times to repeat the keys and values along the sequence dimension. 52 | 53 | Returns: 54 | Tuple[torch.Tensor, torch.Tensor]: 55 | - keys: Updated or repeated keys tensor. 56 | - values: Updated or repeated values tensor. 57 | """ 58 | # Ensure xk and xv have the correct device and dtype 59 | xk = xk.to(self.k.dtype) 60 | xv = xv.to(self.v.dtype) 61 | 62 | # Update the k and v tensors in the specified layer and position 63 | insert_len = xk.size( 64 | 1 65 | ) # Assuming xk shape is (bsz, insert_len, kv_heads, head_dim) 66 | self.k[layer_idx, :, cur_pos : cur_pos + insert_len, :, :] = xk 67 | self.v[layer_idx, :, cur_pos : cur_pos + insert_len, :, :] = xv 68 | 69 | if cur_pos == 0: 70 | # If inserting at the beginning, repeat the new keys and values 71 | keys = xk.repeat_interleave(n_rep, dim=2) 72 | values = xv.repeat_interleave(n_rep, dim=2) 73 | else: 74 | # Otherwise, repeat the existing keys and values from the cache 75 | keys = self.k[layer_idx].repeat_interleave(n_rep, dim=2) 76 | values = self.v[layer_idx].repeat_interleave(n_rep, dim=2) 77 | 78 | return keys, values, self 79 | 80 | def clear(self): 81 | """Resets the k and v caches to zeros.""" 82 | self.k.zero_() 83 | self.v.zero_() 84 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from transformers import ( 2 | AutoTokenizer, 3 | AutoModelForCausalLM, 4 | ) 5 | import torch 6 | from entropixing.generate import generate, stream 7 | from entropixing.utils import is_supported_model 8 | from rich.console import Console 9 | 10 | if torch.backends.mps.is_available(): 11 | device = torch.device("mps") 12 | elif torch.cuda.is_available(): 13 | device = torch.device("cuda") 14 | else: 15 | device = torch.device("cpu") 16 | 17 | print(f"Default device: {device}") 18 | 19 | torch.set_float32_matmul_precision("high") 20 | torch.backends.cudnn.benchmark = True 21 | 22 | 23 | def main(): 24 | from argparse import ArgumentParser 
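    # Completion-style entry point: generates a continuation of --prompt and colours
    # the streamed tokens by their sampling metrics (bold when a token's logit
    # entropy exceeds 3, blue when its varentropy exceeds 15). With --go_back the
    # generator may retract uncertain tokens; each retraction is printed as a red ⌫
    # when --print_back is set, otherwise it is erased with a backspace sequence.
    # Example invocation (the model name is just the argparse default shown below):
    #   python main.py --model google/gemma-2-2b-jpn-it --prompt "Hello, my name is " --go_back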
25 | 26 | global device 27 | parser = ArgumentParser() 28 | parser.add_argument( 29 | "--model", type=str, required=True, default="google/gemma-2-2b-jpn-it" 30 | ) 31 | parser.add_argument( 32 | "--dtype", 33 | type=str, 34 | choices=["float16", "bfloat16", "float32"], 35 | default="bfloat16", 36 | ) 37 | parser.add_argument("--max_length", type=int, default=4096) 38 | parser.add_argument("--context_length", type=int) 39 | parser.add_argument("--prompt", type=str, default="Hello, my name is ") 40 | parser.add_argument("--device", type=str, default=device.type) 41 | parser.add_argument("--top_p", type=float, default=0.95) 42 | parser.add_argument("--top_k", type=int, default=40) 43 | parser.add_argument("--min_p", type=int, default=0) 44 | parser.add_argument("--repetition_penalty", type=float, default=1.0) 45 | parser.add_argument("--seed", type=int) 46 | parser.add_argument("--print_back", action="store_true") 47 | parser.add_argument("--go_back", action="store_true") 48 | args = parser.parse_args() 49 | device = torch.device(args.device) 50 | console = Console() 51 | print(f"Using device: {device}") 52 | if not is_supported_model(args.model): 53 | raise ValueError("Unsupported model") 54 | dtype = getattr(torch, args.dtype) 55 | weights = AutoModelForCausalLM.from_pretrained( 56 | args.model, 57 | device_map=device, 58 | torch_dtype=dtype, 59 | ).eval() 60 | 61 | tokenizer = AutoTokenizer.from_pretrained(args.model) 62 | inputs = tokenizer.encode(args.prompt, return_tensors="pt") 63 | 64 | console.print(args.prompt, style="green", end="") 65 | it = generate( 66 | weights, 67 | inputs, 68 | device, 69 | dtype, 70 | [tokenizer.eos_token_id], 71 | args.max_length, 72 | args.top_p, 73 | args.top_k, 74 | args.min_p, 75 | args.repetition_penalty, 76 | args.seed, 77 | args.go_back, 78 | args.context_length, 79 | ) 80 | for token in stream(it, tokenizer): 81 | if "text" in token: 82 | style = "" 83 | if token["entropy"] > 3: 84 | style = "bold" 85 | elif token["varentropy"] > 15: 86 | style += "blue" 87 | console.print(token["text"], style=style, end="") 88 | elif "back" in token: 89 | if args.print_back: 90 | console.print("⌫", style="red", end="") 91 | else: 92 | console.print("\b \b", end="") 93 | 94 | 95 | if __name__ == "__main__": 96 | main() 97 | -------------------------------------------------------------------------------- /chat.py: -------------------------------------------------------------------------------- 1 | from transformers import ( 2 | AutoTokenizer, 3 | AutoModelForCausalLM, 4 | ) 5 | import torch 6 | from entropixing.generate import generate, stream 7 | from rich.console import Console 8 | 9 | from entropixing.utils import is_supported_model 10 | 11 | if torch.backends.mps.is_available(): 12 | device = torch.device("mps") 13 | elif torch.cuda.is_available(): 14 | device = torch.device("cuda") 15 | else: 16 | device = torch.device("cpu") 17 | 18 | print(f"Default device: {device}") 19 | 20 | torch.set_float32_matmul_precision("high") 21 | torch.backends.cudnn.benchmark = True 22 | 23 | 24 | def main(): 25 | from argparse import ArgumentParser 26 | 27 | global device 28 | console = Console() 29 | parser = ArgumentParser() 30 | parser.add_argument( 31 | "--model", type=str, required=True, default="google/gemma-2-2b-it" 32 | ) 33 | parser.add_argument( 34 | "--dtype", 35 | type=str, 36 | choices=["float16", "bfloat16", "float32"], 37 | default="bfloat16", 38 | ) 39 | parser.add_argument("--max_length", type=int, default=4096) 40 | parser.add_argument("--context_length", 
type=int) 41 | parser.add_argument("--device", type=str, default=device.type) 42 | parser.add_argument("--top_p", type=float, default=0.95) 43 | parser.add_argument("--top_k", type=int, default=40) 44 | parser.add_argument("--min_p", type=int, default=0) 45 | parser.add_argument("--repetition_penalty", type=float, default=1.0) 46 | parser.add_argument("--seed", type=int) 47 | parser.add_argument("--print_back", action="store_true") 48 | parser.add_argument("--go_back", action="store_true") 49 | args = parser.parse_args() 50 | device = torch.device(args.device) 51 | print(f"Using device: {device}") 52 | if not is_supported_model(args.model): 53 | raise ValueError("Unsupported model") 54 | dtype = getattr(torch, args.dtype) 55 | weights = AutoModelForCausalLM.from_pretrained( 56 | args.model, 57 | device_map=device, 58 | torch_dtype=dtype, 59 | ).eval() 60 | tokenizer = AutoTokenizer.from_pretrained(args.model) 61 | conv = [] 62 | while True: 63 | console.print("User: ", end="", style="green") 64 | inp = input("").strip() 65 | if inp == "exit": 66 | break 67 | elif inp == "clear": 68 | conv.clear() 69 | continue 70 | conv.append({"role": "user", "content": inp}) 71 | inputs = tokenizer.apply_chat_template( 72 | conv, return_tensors="pt", add_generation_prompt=True 73 | ) 74 | it = generate( 75 | weights, 76 | inputs, 77 | device, 78 | dtype, 79 | [tokenizer.eos_token_id], 80 | args.max_length, 81 | args.top_p, 82 | args.top_k, 83 | args.min_p, 84 | args.repetition_penalty, 85 | args.seed, 86 | args.go_back, 87 | args.context_length, 88 | ) 89 | console.print("Assistant: ", end="", style="green") 90 | text = "" 91 | for token in stream(it, tokenizer): 92 | if "text" in token: 93 | style = "" 94 | if token["entropy"] > 3: 95 | style = "bold" 96 | elif token["varentropy"] > 15: 97 | style += "blue" 98 | console.print(token["text"], style=style, end="") 99 | text += token["text"] 100 | elif "back" in token: 101 | if args.print_back: 102 | console.print("⌫", style="red", end="") 103 | else: 104 | console.print("\b \b", end="") 105 | conv.append({"role": "assistant", "content": text.strip()}) 106 | console.print() 107 | 108 | 109 | if __name__ == "__main__": 110 | main() 111 | -------------------------------------------------------------------------------- /entropixing/generate.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from .model import forward 4 | from .sampler import sample, calculate_metrics 5 | from .kv_cache import KVCache 6 | from .utils import build_attn_mask, precompute_freqs_cis 7 | import torch 8 | from transformers import AutoModelForCausalLM, AutoConfig 9 | 10 | 11 | @torch.inference_mode() 12 | def generate( 13 | weights: AutoModelForCausalLM, 14 | tokens: torch.tensor, 15 | device: torch.device, 16 | dtype: torch.dtype, 17 | stop_tokens: list[int], 18 | max_length: int, 19 | top_p: float = 0.95, 20 | top_k: int = 27, 21 | min_p: int = 0, 22 | repetition_penalty: float = 1.0, 23 | seed: Optional[int] = None, 24 | go_back: bool = True, 25 | context_len: int = None, 26 | ): 27 | if seed is not None: 28 | generator = torch.Generator(device=device).manual_seed(seed) 29 | else: 30 | generator = None 31 | config: AutoConfig = weights.config 32 | gen_tokens = None 33 | cur_pos = 0 34 | bsz, seqlen = tokens.shape 35 | attn_mask = build_attn_mask( 36 | seqlen, 37 | cur_pos, 38 | device=device, 39 | ) 40 | if not hasattr(config, "head_dim"): 41 | setattr(config, "head_dim", config.hidden_size // 
config.num_attention_heads) 42 | weights.config = config 43 | freqs_cis = precompute_freqs_cis( 44 | config.head_dim, 45 | config.max_position_embeddings, 46 | config.rope_theta, 47 | (config.rope_scaling is not None if hasattr(config, "rope_scaling") else False), 48 | device=device, 49 | ) 50 | kvcache = KVCache( 51 | config.num_hidden_layers, 52 | bsz, 53 | context_len or config.max_position_embeddings, 54 | config.num_key_value_heads, 55 | config.head_dim, 56 | device, 57 | dtype, 58 | ).to(device) 59 | logits, kvcache, scores, _stats = forward( 60 | weights, 61 | config, 62 | tokens, 63 | cur_pos, 64 | freqs_cis[:seqlen], 65 | kvcache, 66 | attn_mask=attn_mask, 67 | device=device, 68 | ) 69 | next_token = torch.argmax(logits[:, -1], dim=-1, keepdim=True).to(torch.int32) 70 | gen_tokens = next_token 71 | metrics = calculate_metrics(logits, scores) 72 | ent, vent = metrics.logits_entropy, metrics.logits_varentropy 73 | yield { 74 | "token": next_token.item(), 75 | "entropy": ent.item(), 76 | "varentropy": vent.item(), 77 | } 78 | cur_pos = seqlen 79 | stop = torch.tensor(stop_tokens, device=device, dtype=torch.int32) 80 | num_recent_deletes = 0 81 | should_noise = False 82 | while cur_pos < max_length: 83 | cur_pos += 1 84 | logits, kvcache, scores, _stats = forward( 85 | weights, 86 | config, 87 | next_token, 88 | cur_pos, 89 | freqs_cis[cur_pos : cur_pos + 1], 90 | kvcache, 91 | device=device, 92 | ) 93 | if should_noise: 94 | logits = logits + torch.randn_like(logits) * 0.1 95 | metrics = calculate_metrics(logits, scores) 96 | ent, vent = metrics.logits_entropy, metrics.logits_varentropy 97 | del metrics 98 | 99 | # basic weighting to prevent backspacing too much 100 | threshold = 5.0 + 2 * num_recent_deletes 101 | if ent > threshold and vent > threshold and cur_pos > seqlen + 4 and go_back: 102 | # backspace and pop the last token 103 | num_recent_deletes += 1 104 | # reset to the position before the last token, regenerate the token 105 | cur_pos -= 2 106 | next_token = gen_tokens[:, -2].unsqueeze(0) 107 | gen_tokens = gen_tokens[:, :-1] 108 | yield {"back": True} 109 | should_noise = True 110 | continue 111 | else: 112 | num_recent_deletes = max(0, num_recent_deletes - 0.5) 113 | should_noise = False 114 | 115 | temperature = 0.7 + (0.5 * num_recent_deletes) 116 | next_token = sample( 117 | gen_tokens, 118 | logits, 119 | scores, 120 | temperature=temperature, 121 | top_p=top_p, 122 | top_k=top_k, 123 | min_p=min_p, 124 | repetition_penalty=repetition_penalty, 125 | generator=generator, 126 | ) 127 | gen_tokens = torch.cat((gen_tokens, next_token), dim=1) 128 | yield { 129 | "token": next_token.tolist()[0][0], 130 | "entropy": ent.item(), 131 | "varentropy": vent.item(), 132 | } 133 | if torch.isin(next_token, stop).any(): 134 | break 135 | 136 | 137 | def is_valid_str(s: str): 138 | try: 139 | s.encode("utf-8").decode("utf-8") 140 | return True 141 | except UnicodeDecodeError: 142 | return False 143 | 144 | 145 | def stream(it, tokenizer): 146 | text_cache = [] 147 | for token in it: 148 | if "back" in token: 149 | yield {"back": True} 150 | else: 151 | text_cache.append(token["token"]) 152 | dec = tokenizer.decode(text_cache, skip_special_tokens=True) 153 | if is_valid_str(dec): 154 | yield { 155 | "text": dec, 156 | "entropy": token["entropy"], 157 | "varentropy": token["varentropy"], 158 | } 159 | text_cache = [] 160 | -------------------------------------------------------------------------------- /server_llama_cpp.py: 
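Both server_llama_cpp.py and server.py below expose an OpenAI-compatible /chat/completions and /models API (default port 8000, no authentication). A minimal client sketch, assuming one of the servers is already running locally; the api_key, message, and max_completion_tokens here are placeholders, and the model field is only echoed back in the response, mirroring how elyza_tasks.py talks to the server:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000", api_key="unused")
resp = client.chat.completions.create(
    model="entropix-any",
    messages=[{"role": "user", "content": "Hello!"}],
    max_completion_tokens=256,
)
print(resp.choices[0].message.content)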
-------------------------------------------------------------------------------- 1 | from pydantic import TypeAdapter 2 | from transformers import AutoTokenizer 3 | from llama_cpp import Llama, GGML_TYPE_Q4_0 4 | import json 5 | from entropixing.llama_cpp_impl import generate_response 6 | from asyncio import Lock 7 | from uvicorn import run 8 | from fastapi import FastAPI, Response, Request 9 | from fastapi.responses import StreamingResponse 10 | from openai.types.model import Model 11 | from openai.types.chat import ChatCompletionChunk, ChatCompletion, ChatCompletionMessage 12 | from openai.types.chat.chat_completion_chunk import Choice, ChoiceDelta 13 | from openai.types.chat.chat_completion import Choice as NostreamChoice 14 | from openai.types.chat.completion_create_params import ( 15 | CompletionCreateParams, 16 | ) 17 | from uuid import uuid4 18 | import time 19 | 20 | adapter: TypeAdapter[CompletionCreateParams] = TypeAdapter(CompletionCreateParams) 21 | 22 | 23 | def main(): 24 | from argparse import ArgumentParser 25 | 26 | global weights 27 | global tokenizer 28 | global lock 29 | lock = Lock() 30 | parser = ArgumentParser() 31 | parser.add_argument("--port", type=int, default=8000) 32 | parser.add_argument("--host", type=str, default="0.0.0.0") 33 | parser.add_argument("--tokenizer", type=str, default="google/gemma-2-2b-it") 34 | parser.add_argument("--model", type=str, required=True, default="./model.gguf") 35 | parser.add_argument("--max_length", type=int, default=4096) 36 | parser.add_argument("--context_length", type=int, default=16384) 37 | parser.add_argument("--ngl", type=int, default=0) 38 | parser.add_argument("--repetition_penalty", type=float, default=1.0) 39 | args = parser.parse_args() 40 | weights = Llama( 41 | args.model, 42 | n_gpu_layers=args.ngl, 43 | n_ctx=args.context_length, 44 | verbose=False, 45 | flash_attn=True, 46 | type_k=GGML_TYPE_Q4_0, 47 | type_v=GGML_TYPE_Q4_0, 48 | ) 49 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer) 50 | app = FastAPI() 51 | 52 | @app.post("/chat/completions") 53 | async def chat_completion(body: Request) -> Response: 54 | j = adapter.validate_python(await body.json()) 55 | max_length = j.get("max_completion_tokens") or args.max_length 56 | messages = list(j["messages"]) 57 | if j.get("stream") == True: 58 | 59 | async def stream_generator(): 60 | async for chunk in gen( 61 | messages, 62 | max_length, 63 | ): 64 | yield f"data: {ChatCompletionChunk( 65 | id=str(uuid4()), 66 | choices=[Choice(delta=ChoiceDelta(content=chunk), index=0)], 67 | created=time.time() // 1000, 68 | model=j["model"], 69 | object="chat.completion.chunk", 70 | ).model_dump_json()}\n\n" 71 | yield f"data: {ChatCompletionChunk( 72 | id=str(uuid4()), 73 | choices=[ 74 | Choice(delta=ChoiceDelta(), finish_reason="stop", index=0) 75 | ], 76 | created=time.time() // 1000, 77 | model=j["model"], 78 | object="chat.completion.chunk", 79 | ).model_dump_json()}\n\n" 80 | 81 | return StreamingResponse( 82 | content=stream_generator(), media_type="text/event-stream" 83 | ) 84 | else: 85 | text = await gen_no_stream( 86 | messages, 87 | max_length, 88 | ) 89 | return Response( 90 | content=ChatCompletion( 91 | id=str(uuid4()), 92 | choices=[ 93 | NostreamChoice( 94 | finish_reason="stop", 95 | message=ChatCompletionMessage( 96 | content=text, role="assistant" 97 | ), 98 | index=0, 99 | ) 100 | ], 101 | created=time.time() // 1000, 102 | model=j["model"], 103 | object="chat.completion", 104 | ).model_dump_json(), 105 | media_type="application/json", 106 | ) 
| )
107 | 108 | @app.get("/models") 109 | async def models(): 110 | return Response( 111 | content=json.dumps( 112 | { 113 | "data": [ 114 | json.loads( 115 | Model( 116 | id="entropix-any", 117 | object="model", 118 | created=1, 119 | owned_by="someone", 120 | ).model_dump_json() 121 | ) 122 | ] 123 | } 124 | ), 125 | media_type="application/json", 126 | ) 127 | 128 | run(app, host=args.host, port=args.port) 129 | 130 | 131 | async def gen_no_stream( 132 | conv, 133 | max_length, 134 | stop=None, 135 | ): 136 | text = "" 137 | async for chunk in gen( 138 | conv, 139 | max_length, 140 | stop, 141 | ): 142 | text += chunk 143 | return text 144 | 145 | 146 | async def gen(conv, max_length, stop=None): 147 | inputs = ( 148 | tokenizer.apply_chat_template(conv, tokenize=False, add_generation_prompt=True)[ 149 | len(tokenizer.bos_token) if tokenizer.bos_token else 0 : 150 | ] 151 | if isinstance(conv, list) 152 | else conv 153 | ) 154 | stops = [tokenizer.eos_token] 155 | if stop: 156 | stops.extend(stop) 157 | async with lock: 158 | it = generate_response( 159 | weights, 160 | inputs, 161 | max_new_tokens=max_length, 162 | stop=stops, 163 | ) 164 | for token in it: 165 | print(token, end="", flush=True) 166 | yield token 167 | 168 | 169 | if __name__ == "__main__": 170 | main() 171 | -------------------------------------------------------------------------------- /server.py: -------------------------------------------------------------------------------- 1 | from pydantic import TypeAdapter 2 | from transformers import AutoTokenizer, AutoModelForCausalLM, TorchAoConfig 3 | import torch 4 | import json 5 | from entropixing.generate import generate, stream 6 | from asyncio import Lock 7 | from uvicorn import run 8 | from fastapi import FastAPI, Response, Request 9 | from fastapi.responses import StreamingResponse 10 | from openai.types.model import Model 11 | from openai.types.chat import ChatCompletionChunk, ChatCompletion, ChatCompletionMessage 12 | from openai.types.chat.chat_completion_chunk import Choice, ChoiceDelta 13 | from openai.types.chat.chat_completion import Choice as NostreamChoice 14 | from openai.types.chat.completion_create_params import CompletionCreateParams 15 | from uuid import uuid4 16 | import time 17 | 18 | from entropixing.utils import is_supported_model 19 | 20 | if torch.backends.mps.is_available(): 21 | device = torch.device("mps") 22 | elif torch.cuda.is_available(): 23 | device = torch.device("cuda") 24 | else: 25 | device = torch.device("cpu") 26 | 27 | print(f"Default device: {device}") 28 | 29 | torch.set_float32_matmul_precision("high") 30 | torch.backends.cudnn.benchmark = True 31 | adapter: TypeAdapter[CompletionCreateParams] = TypeAdapter(CompletionCreateParams) 32 | 33 | 34 | def main(): 35 | from argparse import ArgumentParser 36 | 37 | global dtype 38 | global device 39 | global weights 40 | global tokenizer 41 | global lock 42 | lock = Lock() 43 | parser = ArgumentParser() 44 | parser.add_argument( 45 | "--model", type=str, required=True, default="google/gemma-2-2b-jpn-it" 46 | ) 47 | parser.add_argument( 48 | "--dtype", 49 | type=str, 50 | choices=["float16", "bfloat16", "float32"], 51 | default="bfloat16", 52 | ) 53 | parser.add_argument("--device", type=str, default=device.type) 54 | parser.add_argument("--port", type=int, default=8000) 55 | parser.add_argument("--host", type=str, default="0.0.0.0") 56 | parser.add_argument("--max_length", type=int, default=512) 57 | parser.add_argument("--context_length", type=int) 58 | parser.add_argument("--top_p", 
type=float, default=0.95) 59 | parser.add_argument("--top_k", type=int, default=40) 60 | parser.add_argument("--min_p", type=int, default=0) 61 | parser.add_argument("--repetition_penalty", type=float, default=1.0) 62 | parser.add_argument("--seed", type=int) 63 | parser.add_argument("--quantize", action="store_true") 64 | args = parser.parse_args() 65 | device = torch.device(args.device) 66 | print(f"Using device: {device}") 67 | if not is_supported_model(args.model): 68 | raise ValueError("Unsupported model") 69 | dtype = getattr(torch, args.dtype) 70 | weights = AutoModelForCausalLM.from_pretrained( 71 | args.model, 72 | device_map=device, 73 | torch_dtype=dtype, 74 | quantization_config=( 75 | TorchAoConfig("int4_weight_only", ["self_attn"], group_size=64) 76 | if args.quantize 77 | else None 78 | ), 79 | ).eval() 80 | tokenizer = AutoTokenizer.from_pretrained(args.model) 81 | app = FastAPI() 82 | 83 | @app.post("/chat/completions") 84 | async def chat_completion(body: Request) -> Response: 85 | j = adapter.validate_python(await body.json()) 86 | max_length = j.get("max_completion_tokens") or args.max_length 87 | messages = list(j["messages"]) 88 | top_p = j.get("top_p", args.top_p) 89 | top_k = j.get("top_logprobs", args.top_k) 90 | min_p = args.min_p 91 | repetition_penalty = j.get("frequency_penalty", args.repetition_penalty) 92 | seed = j.get("seed", args.seed) 93 | if j.get("stream"): 94 | 95 | async def stream_generator(): 96 | async for chunk in gen( 97 | messages, 98 | max_length, 99 | top_p, 100 | top_k, 101 | min_p, 102 | repetition_penalty, 103 | seed, 104 | args.context_length, 105 | ): 106 | if "text" in chunk: 107 | yield f"data: {ChatCompletionChunk( 108 | id=str(uuid4()), 109 | choices=[Choice(delta=ChoiceDelta(content=chunk["text"]), index=0)], 110 | created=time.time() // 1000, 111 | model=j["model"], 112 | object="chat.completion.chunk", 113 | ).model_dump_json()}\n\n" 114 | else: 115 | yield f"data: {ChatCompletionChunk( 116 | id=str(uuid4()), 117 | choices=[Choice(delta=ChoiceDelta(content="⌫"), index=0)], 118 | created=time.time() // 1000, 119 | model=j["model"], 120 | object="chat.completion.chunk", 121 | ).model_dump_json()}\n\n" 122 | yield f"data: {ChatCompletionChunk( 123 | id=str(uuid4()), 124 | choices=[ 125 | Choice(delta=ChoiceDelta(), finish_reason="stop", index=0) 126 | ], 127 | created=time.time() // 1000, 128 | model=j["model"], 129 | object="chat.completion.chunk", 130 | ).model_dump_json()}\n\n" 131 | 132 | return StreamingResponse( 133 | content=stream_generator(), media_type="text/event-stream" 134 | ) 135 | else: 136 | text = await gen_no_stream( 137 | messages, 138 | max_length, 139 | top_p, 140 | top_k, 141 | min_p, 142 | repetition_penalty, 143 | seed, 144 | args.context_length, 145 | ) 146 | return Response( 147 | content=ChatCompletion( 148 | id=str(uuid4()), 149 | choices=[ 150 | NostreamChoice( 151 | finish_reason="stop", 152 | message=ChatCompletionMessage( 153 | content=text, role="assistant" 154 | ), 155 | index=0, 156 | ) 157 | ], 158 | created=time.time() // 1000, 159 | model=j["model"], 160 | object="chat.completion", 161 | ).model_dump_json(), 162 | media_type="application/json", 163 | ) 164 | 165 | @app.get("/models") 166 | async def models(): 167 | return Response( 168 | content=json.dumps( 169 | { 170 | "data": [ 171 | json.loads( 172 | Model( 173 | id="entropix-any", 174 | object="model", 175 | created=1, 176 | owned_by="someone", 177 | ).model_dump_json() 178 | ) 179 | ] 180 | } 181 | ), 182 | media_type="application/json", 183 
| ) 184 | 185 | run(app, host=args.host, port=args.port) 186 | 187 | 188 | async def gen_no_stream( 189 | conv, 190 | max_length, 191 | top_p, 192 | top_k, 193 | min_p, 194 | repetition_penalty, 195 | seed, 196 | context_length, 197 | ): 198 | text = "" 199 | async for chunk in gen( 200 | conv, max_length, top_p, top_k, min_p, repetition_penalty, seed, context_length 201 | ): 202 | if "text" in chunk: 203 | text += chunk["text"] 204 | return text 205 | 206 | 207 | async def gen( 208 | conv, 209 | max_length, 210 | top_p, 211 | top_k, 212 | min_p, 213 | repetition_penalty, 214 | seed, 215 | context_length, 216 | ): 217 | inputs = tokenizer.apply_chat_template( 218 | conv, return_tensors="pt", add_generation_prompt=True 219 | ) 220 | async with lock: 221 | it = generate( 222 | weights, 223 | inputs, 224 | device, 225 | dtype, 226 | [tokenizer.eos_token_id], 227 | max_length, 228 | top_p, 229 | top_k, 230 | min_p, 231 | repetition_penalty, 232 | seed, 233 | False, 234 | context_length, 235 | ) 236 | for token in stream(it, tokenizer): 237 | yield token 238 | 239 | 240 | if __name__ == "__main__": 241 | main() 242 | -------------------------------------------------------------------------------- /entropixing/model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | from .kv_cache import KVCache 5 | from .attn_stats import AttnStats 6 | from transformers import ( 7 | PretrainedConfig, 8 | AutoModelForCausalLM, 9 | Gemma2ForCausalLM, 10 | LlamaForCausalLM, 11 | Qwen2ForCausalLM, 12 | MistralForCausalLM, 13 | Phi3ForCausalLM, 14 | ) 15 | 16 | DEFAULT_MASK_VALUE = -0.7 * float(torch.finfo(torch.float32).max) 17 | 18 | from typing import Tuple, Optional 19 | 20 | 21 | def apply_rotary_emb( 22 | xq: torch.Tensor, 23 | xk: torch.Tensor, 24 | freqs_cis: torch.Tensor, 25 | dtype: torch.dtype = torch.float32, 26 | ) -> Tuple[torch.Tensor, torch.Tensor]: 27 | reshape_xq = xq.float().reshape(*xq.shape[:-1], -1, 2) 28 | reshape_xk = xk.float().reshape(*xk.shape[:-1], -1, 2) 29 | xq_ = torch.complex(reshape_xq[..., 0], reshape_xq[..., 1]) 30 | xk_ = torch.complex(reshape_xk[..., 0], reshape_xk[..., 1]) 31 | xq_out = xq_ * freqs_cis.unsqueeze(0).unsqueeze(2) 32 | xk_out = xk_ * freqs_cis.unsqueeze(0).unsqueeze(2) 33 | xq_out = torch.stack((xq_out.real, xq_out.imag), dim=-1).reshape( 34 | *xq_out.shape[:-1], -1 35 | ) 36 | xk_out = torch.stack((xk_out.real, xk_out.imag), dim=-1).reshape( 37 | *xk_out.shape[:-1], -1 38 | ) 39 | return xq_out.to(dtype), xk_out.to(dtype) 40 | 41 | 42 | def reverse_permute( 43 | tensor: torch.Tensor, n_heads: int = 32, dim1: int = 4096, dim2: int = 4096 44 | ) -> torch.Tensor: 45 | return ( 46 | tensor.view(n_heads, 2, dim1 // n_heads // 2, dim2) 47 | .transpose(1, 2) 48 | .reshape(dim1, dim2) 49 | ) 50 | 51 | 52 | def linear( 53 | emb: torch.Tensor, 54 | nah: int, 55 | linear: torch.nn.Linear, 56 | ) -> torch.Tensor: 57 | weight = linear.weight 58 | weight = reverse_permute( 59 | weight, n_heads=nah, dim1=weight.size(-2), dim2=weight.size(-1) 60 | ) 61 | bias = ( 62 | reverse_permute( 63 | linear.bias.view(1, -1), n_heads=nah, dim1=linear.bias.size(-1), dim2=1 64 | ).squeeze() 65 | if linear.bias is not None 66 | else None 67 | ) 68 | return F.linear(emb, weight, bias) 69 | 70 | 71 | def attention( 72 | weights: AutoModelForCausalLM, 73 | x: torch.Tensor, 74 | layer_weights, 75 | model_params: PretrainedConfig, 76 | cur_pos: int, 77 | layer_idx: int, 78 | freqs_cis: 
torch.Tensor, 79 | kvcache: KVCache, 80 | attn_mask: Optional[torch.Tensor] = None, 81 | ) -> Tuple[torch.Tensor, KVCache, torch.Tensor]: 82 | bsz, q_len, _ = x.shape 83 | n_rep = model_params.num_attention_heads // model_params.num_key_value_heads 84 | if isinstance(weights, Phi3ForCausalLM): 85 | qkv = linear( 86 | x, 87 | model_params.num_attention_heads 88 | + model_params.num_key_value_heads * 2, # Total heads * 3 for Q, K, V 89 | layer_weights.qkv_proj, 90 | ) 91 | xq, xk, xv = torch.chunk(qkv, 3, dim=-1) 92 | xq = xq.reshape( 93 | bsz, q_len, model_params.num_attention_heads, model_params.head_dim 94 | ) 95 | xk = xk.reshape( 96 | bsz, q_len, model_params.num_key_value_heads, model_params.head_dim 97 | ) 98 | xv = torch.chunk(layer_weights.qkv_proj(x), 3, dim=-1)[2].reshape( 99 | bsz, q_len, model_params.num_key_value_heads, model_params.head_dim 100 | ) 101 | else: 102 | xq = linear(x, model_params.num_attention_heads, layer_weights.q_proj).reshape( 103 | bsz, q_len, model_params.num_attention_heads, model_params.head_dim 104 | ) 105 | xk = linear(x, model_params.num_key_value_heads, layer_weights.k_proj).reshape( 106 | bsz, q_len, model_params.num_key_value_heads, model_params.head_dim 107 | ) 108 | xv = layer_weights.v_proj(x).reshape( 109 | bsz, q_len, model_params.num_key_value_heads, model_params.head_dim 110 | ) 111 | xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis, dtype=xq.dtype) 112 | keys, values, kvcache = kvcache.update(xk, xv, layer_idx, cur_pos, n_rep) 113 | xq = torch.permute(xq, (0, 2, 1, 3)) # (bs, n_heads, seqlen, head_dim) 114 | keys = torch.permute( 115 | keys, (0, 2, 3, 1) 116 | ) # (bs, n_heads, head_dim, cache_len + seqlen) 117 | values = torch.permute( 118 | values, (0, 2, 1, 3) 119 | ) # (bs, n_heads, cache_len + seqlen, head_dim) 120 | scores = torch.matmul(xq, keys) 121 | if isinstance(weights, Gemma2ForCausalLM): 122 | scores = scores * (model_params.query_pre_attn_scalar**-0.5) 123 | if model_params.attn_logit_softcapping is not None: 124 | scores = scores / model_params.attn_logit_softcapping 125 | scores = torch.tanh(scores) 126 | scores = scores * model_params.attn_logit_softcapping 127 | elif ( 128 | isinstance(weights, LlamaForCausalLM) 129 | or isinstance(weights, MistralForCausalLM) 130 | or isinstance(weights, Qwen2ForCausalLM) 131 | or isinstance(weights, Phi3ForCausalLM) 132 | ): 133 | scores = scores / math.sqrt(model_params.head_dim) 134 | pre_scores = scores 135 | scores = pre_scores.to(torch.float32) # Always do attention softmax at float32 136 | if cur_pos == 0: 137 | scores = scores + attn_mask 138 | mask = torch.where(scores != 0.0, scores, DEFAULT_MASK_VALUE) 139 | masked_logits = torch.where( 140 | (mask >= DEFAULT_MASK_VALUE * 0.5), scores, DEFAULT_MASK_VALUE 141 | ) 142 | scores = F.softmax(masked_logits, dim=-1).to(values.dtype) 143 | if ( 144 | hasattr(model_params, "attention_dropout") 145 | and model_params.attention_dropout is not None 146 | ): 147 | scores = F.dropout(scores, p=model_params.attention_dropout) 148 | output = torch.matmul(scores, values) 149 | output = output.transpose(1, 2) 150 | output = output.reshape(xq.shape[0], xq.shape[2], -1) 151 | output = layer_weights.o_proj(output) 152 | return output, kvcache, pre_scores 153 | 154 | 155 | def forward( 156 | weights: AutoModelForCausalLM, 157 | model_params: PretrainedConfig, 158 | tokens: torch.Tensor, 159 | cur_pos: int, 160 | freqs_cis: torch.Tensor, 161 | kvcache: KVCache, 162 | attn_mask: Optional[torch.Tensor] = None, 163 | device: torch.device = "cpu", 
164 | ) -> Tuple[torch.Tensor, KVCache, torch.Tensor, AttnStats]: 165 | h = weights.model.embed_tokens.weight[tokens] 166 | if isinstance(weights, Gemma2ForCausalLM): 167 | normalizer = torch.tensor(model_params.hidden_size**0.5, dtype=h.dtype) 168 | h = h * normalizer 169 | attn_stats = AttnStats.init( 170 | bsz=tokens.shape[0], 171 | n_layers=model_params.num_hidden_layers, 172 | n_heads=model_params.num_attention_heads, 173 | device=device, 174 | ) 175 | for i in range(model_params.num_hidden_layers): 176 | layer = weights.model.layers[i] 177 | if isinstance(weights, Gemma2ForCausalLM): 178 | if not bool(i % 2) and attn_mask is not None: 179 | min_dtype = torch.finfo(h.dtype).min 180 | sliding_window_mask = torch.tril( 181 | torch.ones_like(attn_mask, dtype=torch.bool), 182 | diagonal=-model_params.sliding_window, 183 | ) 184 | attn_mask = torch.where(sliding_window_mask, min_dtype, attn_mask) 185 | if attn_mask.shape[-1] <= 1: # when decoding 186 | attn_mask = attn_mask[:, :, :, -model_params.sliding_window :] 187 | norm_x = layer.input_layernorm(h) 188 | h_attn, kvcache, scores = attention( 189 | weights, 190 | norm_x, 191 | layer.self_attn, 192 | model_params, 193 | cur_pos, 194 | i, 195 | freqs_cis, 196 | kvcache, 197 | attn_mask=attn_mask, 198 | ) 199 | if isinstance(weights, Phi3ForCausalLM): 200 | h_attn = layer.resid_attn_dropout(h_attn) 201 | if ( 202 | isinstance(weights, LlamaForCausalLM) 203 | or isinstance(weights, MistralForCausalLM) 204 | or isinstance(weights, Qwen2ForCausalLM) 205 | or isinstance(weights, Phi3ForCausalLM) 206 | ): 207 | h = h + h_attn 208 | h_mlp = layer.post_attention_layernorm(h) 209 | elif isinstance(weights, Gemma2ForCausalLM): 210 | h = h + layer.post_attention_layernorm(h_attn) 211 | h_mlp = layer.pre_feedforward_layernorm(h) 212 | attn_stats = attn_stats.update(scores[:, :, -1, :], i) 213 | h_mlp = layer.mlp(h_mlp) 214 | if isinstance(weights, Gemma2ForCausalLM): 215 | h_mlp = layer.post_feedforward_layernorm(h_mlp) 216 | elif isinstance(weights, Phi3ForCausalLM): 217 | h_mlp = layer.resid_mlp_dropout(h_mlp) 218 | h = h + h_mlp 219 | logits = weights.lm_head(weights.model.norm(h)) 220 | if isinstance(weights, Gemma2ForCausalLM): 221 | if model_params.final_logit_softcapping is not None: 222 | logits = logits / model_params.final_logit_softcapping 223 | logits = torch.tanh(logits) 224 | logits = logits * model_params.final_logit_softcapping 225 | return logits, kvcache, scores, attn_stats 226 | -------------------------------------------------------------------------------- /entropixing/sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from typing import Tuple, Optional 4 | from dataclasses import dataclass 5 | 6 | LN_2 = 0.69314718056 # ln(2) = 1.0 / LOG2_E 7 | 8 | 9 | def calculate_varentropy_logsoftmax( 10 | logits: torch.Tensor, axis: int = -1 11 | ) -> Tuple[torch.Tensor, torch.Tensor]: 12 | """Calculate the entropy and varentropy of the probability distribution using logsoftmax.""" 13 | log_probs = F.log_softmax(logits, dim=axis) 14 | probs = torch.exp(log_probs) 15 | entropy = -torch.sum(probs * log_probs, dim=axis) / LN_2 # Convert to base-2 16 | varentropy = torch.sum( 17 | probs * (log_probs / LN_2 + entropy.unsqueeze(-1)) ** 2, dim=axis 18 | ) 19 | return entropy, varentropy 20 | 21 | 22 | def multinomial_sample_one( 23 | probs_sort: torch.Tensor, generator: torch.Generator 24 | ) -> torch.Tensor: 25 | """Samples one token from a 
multinomial distribution with sorted probabilities.""" 26 | # Use torch.rand instead of Exponential distribution 27 | q = torch.rand(probs_sort.shape, generator=generator, device=probs_sort.device) 28 | return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(torch.int32) 29 | 30 | 31 | def _sample( 32 | logits: torch.Tensor, 33 | temperature=0.666, 34 | top_p=0.90, 35 | top_k=27, 36 | min_p: float = 0.0, 37 | repetition_penalty: float = 1.0, 38 | prev_tokens: torch.Tensor = None, 39 | generator: torch.Generator = None, 40 | ) -> torch.Tensor: 41 | device = logits.device 42 | bsz = logits.shape[0] 43 | logit = logits[:, -1] 44 | if repetition_penalty != 1.0 and prev_tokens is not None: 45 | score = torch.gather(logit, 1, prev_tokens) 46 | # if score < 0 then repetition penalty has to be multiplied 47 | # if score > 0 then repetition penalty has to be divided 48 | score = torch.where( 49 | score < 0, score * repetition_penalty, score / repetition_penalty 50 | ) 51 | logit.scatter_(1, prev_tokens, score) 52 | probs = F.softmax(logit / temperature, dim=-1) 53 | 54 | # Apply min_p sampling 55 | if min_p > 0.0: 56 | p_max = torch.max(probs, dim=-1, keepdim=True).values 57 | indices_to_remove = probs < (min_p * p_max) 58 | logit = torch.where( 59 | indices_to_remove, torch.full_like(logit, float("-inf")), logit 60 | ) 61 | 62 | # Apply top-k sampling 63 | top_k_probs, top_k_indices = torch.topk(probs, k=min(top_k, probs.shape[-1])) 64 | probs_sort = torch.flip(top_k_probs, dims=[-1]) 65 | probs_idx = torch.flip(top_k_indices, dims=[-1]) 66 | probs_sum = torch.cumsum(probs_sort, dim=-1) 67 | # Apply top-p sampling 68 | mask = torch.where( 69 | probs_sum - probs_sort > top_p, 70 | torch.tensor(1.0, device=device), 71 | torch.tensor(0.0, device=device), 72 | ) 73 | probs_sort = probs_sort * (1 - mask) 74 | probs_sort = probs_sort / torch.sum(probs_sort, dim=-1, keepdim=True) 75 | next_token = multinomial_sample_one(probs_sort, generator) 76 | # Convert next_token to int64 before using it in gather 77 | next_token_g = torch.gather( 78 | probs_idx, -1, next_token.reshape(bsz, 1).to(torch.int64) 79 | ) 80 | return next_token_g.to(torch.int32) 81 | 82 | 83 | @dataclass 84 | class CaluclateMetricsOutput: 85 | logits_entropy: torch.Tensor 86 | logits_varentropy: torch.Tensor 87 | attn_entropy: torch.Tensor 88 | attn_varentropy: torch.Tensor 89 | agreement: torch.Tensor 90 | interaction_strength: torch.Tensor 91 | 92 | 93 | def calculate_metrics(logits: torch.Tensor, attention_scores: torch.Tensor): 94 | entropy, varentropy = calculate_varentropy_logsoftmax(logits) 95 | # NB chua: filter to non-zero values because future values are always zero (causal mask) 96 | # another implementation would be to pass in or calculate the number of indices at play 97 | # this is _probably_ fine though. 
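    # The attention-side metrics below are derived from the raw pre-softmax
    # attention scores passed in:
    #   attn_entropy         - per-head entropy (base 2) of the attention weights,
    #                          computed after mapping exact zeros to -inf so that
    #                          positions hidden by the causal mask do not contribute.
    #   attn_varentropy      - variance of that entropy across heads (dim=1), with
    #                          NaNs (the all-equal case) replaced by zero.
    #   agreement            - mean absolute deviation of each head's attention
    #                          weights from the across-head mean distribution.
    #   interaction_strength - nanmean of |scores| over the unmasked entries.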
98 | attention_scores = torch.where( 99 | attention_scores != 0.0, 100 | attention_scores, 101 | torch.full_like(attention_scores, float("-inf")), 102 | ) 103 | attention_probs = F.softmax(attention_scores, dim=-1) 104 | attn_entropy = -torch.sum( 105 | attention_probs * torch.log2(torch.clamp(attention_probs, 1e-10, 1.0)), dim=-1 106 | ) 107 | attn_varentropy = torch.var(attn_entropy, dim=1) 108 | 109 | # Add a small epsilon to avoid NaN when all values are the same 110 | attn_varentropy = torch.where( 111 | torch.isnan(attn_varentropy), torch.zeros_like(attn_varentropy), attn_varentropy 112 | ) 113 | mean_attention = torch.mean(attention_probs, dim=1) 114 | agreement = torch.mean( 115 | torch.abs(attention_probs - mean_attention.unsqueeze(1)), dim=(1, 2) 116 | ) 117 | non_inf_attn_scores = torch.where( 118 | attention_scores != float("-inf"), 119 | attention_scores, 120 | torch.full_like(attention_scores, torch.nan), 121 | ) 122 | interaction_strength = torch.nanmean(torch.abs(non_inf_attn_scores), dim=(1, 2, 3)) 123 | 124 | return CaluclateMetricsOutput( 125 | logits_entropy=torch.mean(entropy), 126 | logits_varentropy=torch.mean(varentropy), 127 | attn_entropy=torch.mean(attn_entropy), 128 | attn_varentropy=torch.mean(attn_varentropy), 129 | agreement=torch.mean(agreement), 130 | interaction_strength=interaction_strength, 131 | ) 132 | 133 | 134 | def adaptive_sample( 135 | logits: torch.Tensor, 136 | metrics: CaluclateMetricsOutput, 137 | _gen_tokens: torch.Tensor, 138 | n_samples: int, 139 | base_temp: float = 0.666, 140 | base_top_p: float = 0.90, 141 | base_top_k: int = 40, 142 | base_min_p: float = 0.01, 143 | generator: torch.Generator = None, 144 | ) -> torch.Tensor: 145 | logits_uncertainty = metrics.logits_entropy + metrics.logits_varentropy 146 | attn_uncertainty = metrics.attn_entropy + metrics.attn_varentropy 147 | 148 | temperature = base_temp * ( 149 | 1 + 0.3 * logits_uncertainty + 0.2 * attn_uncertainty - 2 * metrics.agreement 150 | ) 151 | top_p = torch.clamp(base_top_p * (1 + 0.1 * metrics.attn_varentropy), 0.1, 1.0) 152 | top_k = int( 153 | torch.clamp( 154 | torch.round( 155 | torch.tensor(base_top_k) 156 | * ( 157 | 1 158 | + 0.3 * metrics.interaction_strength.item() 159 | - 2 * metrics.agreement.item() 160 | ) 161 | ), 162 | min=1, 163 | max=100, 164 | ).item() 165 | ) 166 | min_p = torch.clamp(base_min_p * (2 - 00.5 * logits_uncertainty), 0.01, 0.5) 167 | samples = [] 168 | for _ in range(n_samples): 169 | sample = _sample( 170 | logits, 171 | temperature=temperature, 172 | top_p=top_p, 173 | top_k=top_k, 174 | min_p=min_p, 175 | generator=generator, 176 | ) 177 | samples.append(sample) 178 | 179 | def score_sample(sample): 180 | # Flatten the sample tensor and convert to long (int64) 181 | sample_flat = sample.flatten().to(torch.long) 182 | 183 | # Create one-hot encoding 184 | one_hot = F.one_hot(sample_flat, logits.shape[-1]) 185 | 186 | # Reshape log_softmax output to match one_hot 187 | log_probs = F.log_softmax(logits, dim=-1).view(-1, logits.shape[-1]) 188 | 189 | # Calculate log probability 190 | log_prob = torch.sum(log_probs * one_hot) 191 | 192 | confidence_score = ( 193 | (1 - metrics.logits_entropy) * 0.1 194 | + (1 - metrics.attn_entropy) * 0.2 195 | + (1 - metrics.logits_varentropy) * 0.3 196 | + (1 - metrics.attn_varentropy) * 0.4 197 | + metrics.agreement * 0.5 198 | + metrics.interaction_strength * 0.6 199 | ) 200 | return log_prob + confidence_score 201 | 202 | sample_scores = torch.stack([score_sample(sample) for sample in samples]) 203 | 
best_sample_idx = torch.argmax(sample_scores) 204 | return samples[best_sample_idx] 205 | 206 | 207 | def sample( 208 | gen_tokens: torch.Tensor, 209 | logits: torch.Tensor, 210 | attention_scores: torch.Tensor, 211 | temperature=0.666, 212 | top_p=0.90, 213 | top_k=27, 214 | min_p: float = 0.0, 215 | repetition_penalty: float = 1.0, 216 | generator: Optional[torch.Generator] = None, 217 | ) -> torch.Tensor: 218 | device = logits.device 219 | if generator is None: 220 | generator = torch.Generator(device=device).manual_seed(42) 221 | metrics = calculate_metrics(logits, attention_scores) 222 | ent, vent = metrics.logits_entropy, metrics.logits_varentropy 223 | attn_ent, attn_vent = metrics.attn_entropy, metrics.attn_varentropy 224 | agreement = metrics.agreement 225 | interaction_strength = metrics.interaction_strength 226 | 227 | # Low Entropy, Low Varentropy: "flowing with unspoken intent" 228 | if ent < 0.1 and vent < 0.1: 229 | return torch.argmax(logits[:, -1], dim=-1, keepdim=True).to(torch.int32) 230 | 231 | # High Entropy, Low Varentropy: "treading carefully, asking clarifying questions" 232 | elif ent > 3.0 and vent < 0.1: 233 | # Insert a clarifying question token if not already present 234 | if not torch.isin(gen_tokens[:, -1], torch.tensor([2564], device=device)).any(): 235 | return torch.tensor( 236 | [[2564]], dtype=torch.int32, device=device 237 | ) # Assuming 2564 is our "ask clarifying question" token 238 | else: 239 | # If we've just asked a question, sample with slightly higher temperature 240 | temp_adj = ( 241 | 1.3 + 0.2 * attn_ent 242 | ) # Increase temperature based on attention entropy 243 | return _sample( 244 | logits, 245 | temperature=min(1.5, temperature * temp_adj), 246 | top_p=top_p, 247 | top_k=top_k, 248 | min_p=min_p, 249 | repetition_penalty=repetition_penalty, 250 | generator=generator, 251 | ) 252 | 253 | # Low Entropy, High Varentropy: "exploring forks in the path" 254 | elif ent < 5.0 and vent > 5.0: 255 | temp_adj = ( 256 | 1.2 + 0.03 * interaction_strength 257 | ) # Increase temperature based on interaction strength 258 | top_k_adj = max( 259 | 5, int(top_k * (1 + 0.5 * (1 - agreement))) 260 | ) # Increase top_k when agreement is low 261 | return _sample( 262 | logits, 263 | temperature=min(1.5, temperature * temp_adj), 264 | top_p=top_p, 265 | top_k=top_k_adj, 266 | min_p=min_p, 267 | generator=generator, 268 | ) 269 | 270 | # High Entropy, High Varentropy: "resampling in the mist" 271 | elif ent > 5.0 and vent > 5.0: 272 | # Use high temperature and adjusted top_p based on attention metrics 273 | temp_adj = ( 274 | 2.0 + 0.5 * attn_vent 275 | ) # Increase temperature based on attention varentropy 276 | top_p_adj = max( 277 | 0.5, top_p - 0.2 * attn_ent 278 | ) # Decrease top_p when attention entropy is high 279 | return _sample( 280 | logits, 281 | temperature=max(2.0, temperature * temp_adj), 282 | top_p=top_p_adj, 283 | top_k=top_k, 284 | min_p=min_p, 285 | generator=generator, 286 | ) 287 | 288 | # Middle ground: use adaptive sampling 289 | else: 290 | return adaptive_sample( 291 | logits, 292 | metrics, 293 | gen_tokens, 294 | n_samples=5, 295 | base_temp=temperature, 296 | base_top_p=top_p, 297 | base_top_k=top_k, 298 | generator=generator, 299 | ) 300 | -------------------------------------------------------------------------------- /entropixing/llama_cpp_impl.py: -------------------------------------------------------------------------------- 1 | # from https://github.com/EdwardDali/EntropixLab/blob/main/main.py 2 | from llama_cpp import 
LogitsProcessorList, LogitsProcessor, Llama 3 | from typing import List, Dict 4 | from enum import Enum 5 | 6 | try: 7 | import cupy as np 8 | import cupyx.scipy.special as sp 9 | 10 | use_cupy = True 11 | except ImportError: 12 | import numpy as np 13 | import scipy.special as sp 14 | 15 | use_cupy = False 16 | from collections import Counter, deque 17 | 18 | LN_2 = 0.69314718056 # ln(2) = 1.0 / LOG2_E 19 | 20 | 21 | class SamplerState(Enum): 22 | ARGMAX = 0 23 | SAMPLE = 1 24 | INSERT_COT = 2 25 | RESAMPLE = 3 26 | ADAPTIVE = 4 # New adaptive sampling strategy 27 | 28 | 29 | class SamplerConfig: 30 | def __init__(self): 31 | self.entropy_threshold = 1.0 32 | self.varentropy_threshold = 1.5 33 | self.cot_token = "[COT]" 34 | self.resample_count = 5 35 | self.strategy_params: Dict[SamplerState, Dict[str, float]] = { 36 | SamplerState.ARGMAX: { 37 | "temperature": 0.1, 38 | "top_p": 1.0, 39 | "top_k": 1, 40 | "min_p": 0.0, 41 | }, 42 | SamplerState.SAMPLE: { 43 | "temperature": 0.7, 44 | "top_p": 0.9, 45 | "top_k": 50, 46 | "min_p": 0.02, 47 | }, 48 | SamplerState.INSERT_COT: { 49 | "temperature": 0.8, 50 | "top_p": 0.95, 51 | "top_k": 100, 52 | "min_p": 0.01, 53 | }, 54 | SamplerState.RESAMPLE: { 55 | "temperature": 1.0, 56 | "top_p": 0.98, 57 | "top_k": 200, 58 | "min_p": 0.005, 59 | }, 60 | SamplerState.ADAPTIVE: { 61 | "temperature": 0.666, 62 | "top_p": 0.90, 63 | "top_k": 27, 64 | "min_p": 0.03, 65 | }, 66 | } 67 | self.repetition_penalty = 1.2 68 | self.max_ngram_size = 5 69 | self.max_ngram_repeat = 3 70 | self.strategy_change_batch_size = 5 71 | self.window_size = 50 # Size of the sliding window for weighted average 72 | self.decay_factor = 0.95 # Exponential decay factor for weighting 73 | 74 | # Adaptive sampling parameters 75 | self.n_adaptive_samples = 5 76 | self.ada_temp_logits = 0.3 77 | self.ada_temp_attn = 0.2 78 | self.ada_temp_agree = 0.2 79 | self.ada_top_p = 0.1 80 | self.ada_top_k_int = 0.3 81 | self.ada_top_k_agree = 0.2 82 | self.ada_min_p = 0.5 83 | self.ada_score_logits_ent = 0.1 84 | self.ada_score_attn_ent = 0.2 85 | self.ada_score_logits_vent = 0.3 86 | self.ada_score_attn_vent = 0.4 87 | self.ada_score_agree = 0.5 88 | self.ada_score_int = 0.6 89 | 90 | 91 | class VarentropyLogitsProcessor(LogitsProcessor): 92 | def __init__(self, config: SamplerConfig): 93 | self.config = config 94 | self.strategy_counter = Counter() 95 | self.recent_tokens = deque(maxlen=100) 96 | self.current_batch = [] 97 | self.current_strategy = SamplerState.SAMPLE 98 | self.tokens_since_last_change = 0 99 | self.entropy_window = deque(maxlen=self.config.window_size) 100 | self.varentropy_window = deque(maxlen=self.config.window_size) 101 | 102 | def __call__(self, input_ids: List[int], logits: np.ndarray) -> List[float]: 103 | if use_cupy: 104 | logits = np.asarray(logits) 105 | # Calculate entropy and varentropy for the current token 106 | entropy, varentropy = self.calculate_varentropy_logsoftmax(logits) 107 | self.entropy_window.append(entropy) 108 | self.varentropy_window.append(varentropy) 109 | 110 | # Check if it's time to recalculate the strategy 111 | if self.tokens_since_last_change % self.config.strategy_change_batch_size == 0: 112 | avg_entropy = self.weighted_average(self.entropy_window) 113 | avg_varentropy = self.weighted_average(self.varentropy_window) 114 | 115 | self.current_strategy = self.determine_strategy(avg_entropy, avg_varentropy) 116 | self.tokens_since_last_change = 0 117 | 118 | # Use the current strategy to sample 119 | if self.current_strategy == 
120 |             sampled_token = self._adaptive_sample(logits)
121 |         else:
122 |             params = self.config.strategy_params[self.current_strategy]
123 |             sampled_token = self._sample(logits, **params)
124 | 
125 |         # Update counters and lists
126 |         self.strategy_counter[self.current_strategy.name] += 1
127 |         self.tokens_since_last_change += 1
128 |         self.current_batch.append(sampled_token)
129 |         self.recent_tokens.append(sampled_token)
130 | 
131 |         # Check for n-gram repetition in the current batch
132 |         if self.check_ngram_repetition(self.current_batch):
133 |             # Increase temperature and top_k to encourage diversity
134 |             temp_config = SamplerConfig()
135 |             temp_config.strategy_params[SamplerState.SAMPLE]["temperature"] = 1.2
136 |             temp_config.strategy_params[SamplerState.SAMPLE]["top_k"] = 100
137 |             sampled_token = self._sample(
138 |                 logits, **temp_config.strategy_params[SamplerState.SAMPLE]
139 |             )
140 | 
141 |         # Reset batch if it reaches the configured batch size
142 |         if len(self.current_batch) == self.config.strategy_change_batch_size:
143 |             self.current_batch = []
144 | 
145 |         # Set all logits to negative infinity except the sampled token
146 |         new_scores = [-float("inf")] * len(logits)
147 |         new_scores[sampled_token] = 0
148 | 
149 |         return new_scores
150 | 
151 |     def weighted_average(self, values):
152 |         if not values:
153 |             return 0
154 |         weights = [self.config.decay_factor**i for i in range(len(values) - 1, -1, -1)]
155 |         return sum(w * v for w, v in zip(weights, values)) / sum(weights)
156 | 
157 |     def determine_strategy(self, entropy: float, varentropy: float) -> SamplerState:
158 |         if entropy < self.config.entropy_threshold:
159 |             if varentropy < self.config.varentropy_threshold:
160 |                 return SamplerState.ARGMAX
161 |             else:
162 |                 return SamplerState.SAMPLE
163 |         else:
164 |             if varentropy < self.config.varentropy_threshold:
165 |                 return SamplerState.INSERT_COT
166 |             elif (
167 |                 varentropy > self.config.varentropy_threshold * 1.5
168 |             ):  # Adjust this threshold as needed
169 |                 return SamplerState.RESAMPLE
170 |             else:
171 |                 return SamplerState.ADAPTIVE
172 | 
173 |     def calculate_varentropy_logsoftmax(
174 |         self, logits: np.ndarray, axis: int = -1
175 |     ) -> tuple[float, float]:
176 |         log_probs = sp.log_softmax(logits, axis=axis)
177 |         probs = np.exp(log_probs)
178 |         entropy = -np.sum(probs * log_probs, axis=axis) / np.log(2)
179 |         entropy_expanded = np.expand_dims(entropy, axis=axis)
180 |         varentropy = np.sum(
181 |             probs * (log_probs / np.log(2) + entropy_expanded) ** 2, axis=axis
182 |         )
183 |         return float(entropy), float(varentropy)
184 | 
185 |     def _sample(
186 |         self,
187 |         logits: np.ndarray,
188 |         temperature: float,
189 |         top_p: float,
190 |         top_k: int,
191 |         min_p: float,
192 |     ) -> int:
193 |         # Apply temperature and convert to probabilities
194 |         logits = logits / temperature
195 |         # Subtract max for numerical stability
196 |         logits = logits - np.max(logits)
197 |         probs = np.exp(logits) / np.sum(np.exp(logits))
198 | 
199 |         # Apply min_p sampling
200 |         if min_p > 0.0:
201 |             p_max = np.max(probs)
202 |             probs[probs < (min_p * p_max)] = 0
203 |             # Renormalize
204 |             probs_sum = np.sum(probs)
205 |             if probs_sum > 0:
206 |                 probs = probs / probs_sum
207 | 
208 |         # Apply top-k sampling
209 |         if top_k > 0:
210 |             top_k = min(top_k, len(probs))
211 |             indices = np.argpartition(probs, -top_k)[-top_k:]
212 |             top_k_probs = probs[indices]
213 |             sorted_idx = np.argsort(-top_k_probs)  # Sort in descending order
214 |             top_k_probs = top_k_probs[sorted_idx]
215 |             indices = indices[sorted_idx]
216 |         else:
217 |             top_k_probs = probs
218 |             indices = np.arange(len(probs))
219 | 
220 |         # Apply top-p (nucleus) sampling
221 |         if 0.0 < top_p < 1.0:
222 |             cumulative_probs = np.cumsum(top_k_probs)
223 |             cutoff_idx = np.searchsorted(
224 |                 cumulative_probs, np.array(top_p), side="right"
225 |             )
226 |             if cutoff_idx == 0:
227 |                 cutoff_idx = 1
228 |             top_k_probs = top_k_probs[:cutoff_idx]
229 |             indices = indices[:cutoff_idx]
230 | 
231 |         # Renormalize
232 |         top_k_probs = top_k_probs / np.sum(top_k_probs)
233 | 
234 |         # If all probabilities are zero, return the highest probability token
235 |         if np.sum(top_k_probs) <= 0:
236 |             return np.argmax(probs).tolist()
237 | 
238 |         # Sample from the filtered distribution
239 |         try:
240 |             sample_idx = np.random.choice(len(top_k_probs), p=top_k_probs, size=1)
241 |             return indices[sample_idx].tolist()[0]
242 |         except ValueError:
243 |             # If sampling fails, fall back to argmax
244 |             return np.argmax(probs).tolist()
245 | 
246 |     def _adaptive_sample(self, logits: np.ndarray) -> int:
247 |         # Calculate metrics (simplified version as we don't have access to attention scores)
248 |         entropy, varentropy = self.calculate_varentropy_logsoftmax(logits)
249 | 
250 |         # Adaptive sampling parameters (using fixed values from config)
251 |         temperature = self.config.strategy_params[SamplerState.ADAPTIVE]["temperature"]
252 |         top_p = self.config.strategy_params[SamplerState.ADAPTIVE]["top_p"]
253 |         top_k = self.config.strategy_params[SamplerState.ADAPTIVE]["top_k"]
254 |         min_p = self.config.strategy_params[SamplerState.ADAPTIVE]["min_p"]
255 | 
256 |         # Sample multiple times
257 |         samples = []
258 |         for _ in range(self.config.n_adaptive_samples):
259 |             sample = self._sample(logits, temperature, top_p, top_k, min_p)
260 |             samples.append(sample)
261 | 
262 |         # Score samples (simplified version)
263 |         def score_sample(sample):
264 |             log_prob = np.log(sp.softmax(logits, axis=-1)[sample])
265 |             confidence_score = (1 - entropy) * self.config.ada_score_logits_ent + (
266 |                 1 - varentropy
267 |             ) * self.config.ada_score_logits_vent
268 |             return log_prob + confidence_score
269 | 
270 |         sample_scores = [score_sample(sample) for sample in samples]
271 |         best_sample_idx = np.argmax(np.array(sample_scores)).tolist()
272 |         return samples[best_sample_idx]
273 | 
274 |     def check_ngram_repetition(self, tokens: List[int]) -> bool:
275 |         for n in range(2, self.config.max_ngram_size + 1):
276 |             ngrams = [tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1)]
277 |             for ngram in set(ngrams):
278 |                 if ngrams.count(ngram) > self.config.max_ngram_repeat:
279 |                     return True
280 |         return False
281 | 
282 | 
283 | def generate_response(
284 |     model: Llama,
285 |     prompt: str,
286 |     max_new_tokens=4096,
287 |     batch_size=10,
288 |     stop: List[str] = [],
289 |     **kwargs
290 | ):
291 |     cfg = SamplerConfig()
292 |     cfg.strategy_change_batch_size = batch_size
293 |     logits_processor = VarentropyLogitsProcessor(cfg)
294 |     logits_processors = LogitsProcessorList([logits_processor])
295 |     default_params = cfg.strategy_params[SamplerState.SAMPLE]
296 |     generation_params = {
297 |         "prompt": prompt,
298 |         "max_tokens": max_new_tokens,
299 |         "logits_processor": logits_processors,
300 |         "echo": False,
301 |         "temperature": default_params["temperature"],
302 |         "top_p": default_params["top_p"],
303 |         "top_k": default_params["top_k"],
304 |         "stream": True,
305 |     }
306 |     for k, v in kwargs.items():
307 |         generation_params[k] = v
308 |     generated_text = ""
309 |     for output in model(**generation_params):
output["choices"][0]["text"] 311 | generated_text += token 312 | yield token 313 | if any([x in generated_text for x in stop]): 314 | break 315 | --------------------------------------------------------------------------------