├── assets ├── evaluations.png └── lille-header.png ├── tokenizer └── Hastings.pkl ├── requirements.txt ├── .gitignore ├── huggingface ├── inference_hf.py ├── tokenizer_hf.py ├── export_hf.py └── model_hf.py ├── prepare_dataset.py ├── export_utils.py ├── sophia_triton.py ├── prepare_dataset_fineweb.py ├── gguf └── export_gguf.py ├── eval-table └── index.html ├── LICENSE ├── inference.py ├── model.py ├── README.md └── train.py /assets/evaluations.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nikityyy/lille/HEAD/assets/evaluations.png -------------------------------------------------------------------------------- /assets/lille-header.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nikityyy/lille/HEAD/assets/lille-header.png -------------------------------------------------------------------------------- /tokenizer/Hastings.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nikityyy/lille/HEAD/tokenizer/Hastings.pkl -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | tiktoken 3 | html2term 4 | datasets 5 | numpy 6 | tqdm 7 | wandb 8 | onnx 9 | onnxsim 10 | onnxruntime_tools 11 | onnxruntime-gpu 12 | triton 13 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | checkpoints/best_model.pt 2 | checkpoints_ft/best_model.pt 3 | data/fineweb_edu_sample_10BT/train.npz 4 | data/fineweb_edu_sample_10BT/val.npz 5 | data/smol-sft/train.npz 6 | data/smol-sft/train.txt 7 | data/smol-sft/val.npz 8 | check.py 9 | -------------------------------------------------------------------------------- /huggingface/inference_hf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM 4 | from model_hf import LilleConfig, LilleForCausalLM 5 | 6 | print("Registering custom 'lille-130m' architecture...") 7 | AutoConfig.register("lille-130m", LilleConfig) 8 | AutoModelForCausalLM.register(LilleConfig, LilleForCausalLM) 9 | print("Registration complete.") 10 | 11 | MODEL_DIR = "model" 12 | DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' 13 | 14 | torch_dtype = torch.float32 15 | 16 | if torch.cuda.is_available(): 17 | if torch.cuda.is_bf16_supported(): 18 | torch_dtype = torch.bfloat16 19 | print("Hardware supports bfloat16, using it for better performance.") 20 | else: 21 | torch_dtype = torch.float16 22 | print("Hardware does not support bfloat16, falling back to float16.") 23 | 24 | print(f"Loading tokenizer from {MODEL_DIR}...") 25 | tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR) 26 | 27 | print(f"Loading model from {MODEL_DIR}...") 28 | model = AutoModelForCausalLM.from_pretrained( 29 | MODEL_DIR, 30 | torch_dtype=torch_dtype, 31 | device_map=DEVICE, 32 | ) 33 | print("Model loaded successfully!") 34 | 35 | model.eval() 36 | 37 | print("Compiling the model with torch.compile... 
(This may take a moment)") 38 | model = torch.compile(model, mode="reduce-overhead", fullgraph=True) 39 | print("Model compiled successfully!") 40 | 41 | print("Performing a warmup run...") 42 | with torch.inference_mode(): 43 | _ = model.generate( 44 | tokenizer("<|startoftext|>", return_tensors="pt").input_ids.to(DEVICE), 45 | max_new_tokens=2, 46 | eos_token_id=tokenizer.eos_token_id, 47 | pad_token_id=tokenizer.pad_token_id 48 | ) 49 | print("Warmup complete.") 50 | 51 | 52 | chat = [ 53 | {"role": "user", "content": "What is the capital of France?"}, 54 | ] 55 | prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True) 56 | 57 | print(f"\n--- Prompt ---\n{prompt}") 58 | 59 | inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE) 60 | 61 | start_time = time.time() 62 | with torch.inference_mode(): 63 | outputs = model.generate( 64 | **inputs, 65 | max_new_tokens=512, 66 | eos_token_id=tokenizer.eos_token_id, 67 | pad_token_id=tokenizer.pad_token_id, 68 | do_sample=True, 69 | temperature=0.5, 70 | top_p=0.95, 71 | use_cache=True 72 | ) 73 | end_time = time.time() 74 | 75 | response_ids = outputs[0][inputs['input_ids'].shape[1]:] 76 | response_text = tokenizer.decode(response_ids, skip_special_tokens=True) 77 | 78 | num_tokens = len(response_ids) 79 | elapsed_time = end_time - start_time 80 | tokens_per_second = num_tokens / elapsed_time if elapsed_time > 0 else float('inf') 81 | 82 | print(f"\n--- Response ---\n{response_text}") 83 | print(f"\n--- Statistics ---") 84 | print(f"Tokens generated: {num_tokens}") 85 | print(f"Time taken: {elapsed_time:.2f} seconds") 86 | print(f"Tokens/second: {tokens_per_second:.2f}") 87 | -------------------------------------------------------------------------------- /prepare_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from tqdm import tqdm 4 | import tiktoken 5 | import pickle 6 | import gc 7 | import random 8 | from html2term import printc 9 | 10 | dataset = "smol-sft" 11 | input_file_path = f"data/{dataset}/train.txt" 12 | output_dir = f"data/{dataset}" 13 | val_split = 0.1 14 | EOT_TOKEN = "<|endoftext|>" 15 | TOKENIZER_PATH = 'tokenizer/Hastings.pkl' 16 | 17 | def load_tokenizer(tokenizer_path=TOKENIZER_PATH): 18 | """Loads the tokenizer and returns the encoding instance.""" 19 | with open(tokenizer_path, 'rb') as f: 20 | hastings = pickle.load(f) 21 | enc = tiktoken.core.Encoding(hastings.pop('name'), **hastings) 22 | printc(f"Tokenizer loaded. Vocab size: {enc.n_vocab}") 23 | try: 24 | enc.encode_single_token(EOT_TOKEN) 25 | except KeyError: 26 | raise ValueError(f"The EOT token '{EOT_TOKEN}' is not in the tokenizer vocabulary!") 27 | return enc 28 | 29 | def process_and_save(lines, enc, output_path, split_name): 30 | """Tokenizes lines and saves them in the efficient tokens/offsets format.""" 31 | printc(f"
Processing and tokenizing {len(lines):,} documents for the {split_name} split...") 32 | 33 | all_tokenized = [] 34 | for doc in tqdm(lines, desc=f"Tokenizing {split_name}"): 35 | if doc: 36 | tokens = enc.encode(doc, allowed_special='all') 37 | all_tokenized.append(tokens) 38 | 39 | total_tokens = sum(len(doc) for doc in all_tokenized) 40 | num_docs = len(all_tokenized) 41 | 42 | tokens_arr = np.empty(total_tokens, dtype=np.uint16) 43 | offsets_arr = np.empty(num_docs + 1, dtype=np.uint64) 44 | offsets_arr[0] = 0 45 | 46 | token_pos = 0 47 | for i, doc in enumerate(all_tokenized): 48 | doc_len = len(doc) 49 | tokens_arr[token_pos : token_pos + doc_len] = doc 50 | token_pos += doc_len 51 | offsets_arr[i + 1] = token_pos 52 | 53 | del all_tokenized 54 | gc.collect() 55 | 56 | printc(f" Saving {split_name} data to {output_path}...") 57 | np.savez_compressed(output_path, tokens=tokens_arr, offsets=offsets_arr) 58 | printc(f" Saved {num_docs:,} documents ({total_tokens:,} tokens).") 59 | 60 | def main(): 61 | os.makedirs(output_dir, exist_ok=True) 62 | enc = load_tokenizer() 63 | 64 | with open(input_file_path, 'r', encoding='utf-8') as f: 65 | content = f.read() 66 | 67 | docs = content.split(EOT_TOKEN) 68 | documents = [doc.strip() + EOT_TOKEN for doc in docs if doc.strip()] 69 | printc(f"Found and processed {len(documents):,} documents.") 70 | 71 | random.seed(42) 72 | random.shuffle(documents) 73 | printc("Shuffled documents randomly to ensure a representative validation split.") 74 | 75 | split_idx = int(len(documents) * (1 - val_split)) 76 | train_docs = documents[:split_idx] 77 | val_docs = documents[split_idx:] 78 | 79 | train_output_path = os.path.join(output_dir, "train.npz") 80 | val_output_path = os.path.join(output_dir, "val.npz") 81 | 82 | process_and_save(train_docs, enc, train_output_path, "train") 83 | process_and_save(val_docs, enc, val_output_path, "validation") 84 | 85 | printc("
✅ Processing complete.") 86 | 87 | if __name__ == "__main__": 88 | main() 89 | -------------------------------------------------------------------------------- /huggingface/tokenizer_hf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import json 4 | import tempfile 5 | 6 | from html2term import printc 7 | from transformers import GPT2Tokenizer 8 | from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode 9 | 10 | def create_tokenizer_from_custom_pickle(pickle_path: str) -> GPT2Tokenizer: 11 | printc(f"Building tokenizer from custom file: {pickle_path}") 12 | 13 | if not os.path.exists(pickle_path): 14 | raise FileNotFoundError(f"Custom tokenizer pickle file not found at {pickle_path}") 15 | 16 | with open(pickle_path, 'rb') as f: 17 | hastings_data = pickle.load(f) 18 | 19 | byte_decoder = bytes_to_unicode() 20 | final_token_to_rank = {} 21 | 22 | for token_bytes, rank in hastings_data['mergeable_ranks'].items(): 23 | token_str = "".join([byte_decoder[b] for b in token_bytes]) 24 | final_token_to_rank[token_str] = rank 25 | 26 | for token_str, rank in hastings_data['special_tokens'].items(): 27 | final_token_to_rank[token_str] = rank 28 | 29 | valid_token_strings = set(final_token_to_rank.keys()) 30 | 31 | sorted_vocab = sorted(final_token_to_rank.items(), key=lambda item: item[1]) 32 | final_vocab_dict = {token: i for i, (token, rank) in enumerate(sorted_vocab)} 33 | 34 | printc("Reconstructing BPE merges based on the new vocabulary...") 35 | base_tokenizer = GPT2Tokenizer.from_pretrained("gpt2", use_fast=False) 36 | 37 | re_ranked_merges = [] 38 | for pair, original_rank in base_tokenizer.bpe_ranks.items(): 39 | p1, p2 = pair 40 | merged_token_str = p1 + p2 41 | 42 | if p1 in valid_token_strings and p2 in valid_token_strings and merged_token_str in valid_token_strings: 43 | new_rank = final_token_to_rank[merged_token_str] 44 | re_ranked_merges.append((new_rank, pair)) 45 | 46 | re_ranked_merges.sort(key=lambda x: x[0]) 47 | 48 | final_merges_formatted = [f"{p1} {p2}" for rank, (p1, p2) in re_ranked_merges] 49 | 50 | printc(f"Target vocab size: {hastings_data['explicit_n_vocab']}. 
Reconstructed valid merges: {len(final_merges_formatted)}") 51 | 52 | with tempfile.TemporaryDirectory() as temp_dir: 53 | vocab_path = os.path.join(temp_dir, 'vocab.json') 54 | with open(vocab_path, 'w', encoding='utf-8') as f: 55 | json.dump(final_vocab_dict, f, ensure_ascii=False) 56 | 57 | merges_path = os.path.join(temp_dir, 'merges.txt') 58 | with open(merges_path, 'w', encoding='utf-8') as f: 59 | f.write("#version: 0.2\n") 60 | f.write("\n".join(final_merges_formatted)) 61 | 62 | config_path = os.path.join(temp_dir, 'tokenizer_config.json') 63 | tokenizer_config = { 64 | "model_max_length": 512, 65 | "add_prefix_space": True, 66 | } 67 | with open(config_path, 'w', encoding='utf-8') as f: 68 | json.dump(tokenizer_config, f, ensure_ascii=False) 69 | 70 | tokenizer = GPT2Tokenizer.from_pretrained(temp_dir) 71 | 72 | special_tokens_map = { 73 | "bos_token": "<|startoftext|>", 74 | "eos_token": "<|endoftext|>", 75 | "pad_token": "<|pad|>", 76 | "additional_special_tokens": ["<|assistant|>", "<|user|>"] 77 | } 78 | tokenizer.add_special_tokens(special_tokens_map) 79 | printc("Custom tokenizer built successfully!") 80 | return tokenizer 81 | -------------------------------------------------------------------------------- /huggingface/export_hf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | from html2term import printc 5 | 6 | from tokenizer_hf import create_tokenizer_from_custom_pickle 7 | from model_hf import LilleConfig, LilleForCausalLM 8 | 9 | 10 | def main(): 11 | parser = argparse.ArgumentParser( 12 | description="Export a trained GPT model to a local directory in Hugging Face format." 13 | ) 14 | parser.add_argument( 15 | "--checkpoint_path", 16 | type=str, 17 | default="../checkpoints_ft/best_model.pt", 18 | help="Path to the PyTorch model checkpoint (.pt file).", 19 | ) 20 | parser.add_argument( 21 | "--output_dir", 22 | type=str, 23 | required=True, 24 | help="Directory where the Hugging Face compatible model files will be saved.", 25 | ) 26 | parser.add_argument( 27 | "--tokenizer_pickle_path", 28 | type=str, 29 | default="../tokenizer/Hastings.pkl", 30 | help="Path to the custom tokenizer pickle file.", 31 | ) 32 | args = parser.parse_args() 33 | 34 | os.makedirs(args.output_dir, exist_ok=True) 35 | 36 | tokenizer = create_tokenizer_from_custom_pickle(args.tokenizer_pickle_path) 37 | 38 | printc(f"Loading checkpoint from: {args.checkpoint_path}") 39 | checkpoint = torch.load(args.checkpoint_path, map_location="cpu") 40 | model_args = checkpoint.get("model_args") 41 | if not model_args: 42 | printc("Error: 'model_args' not found in the checkpoint.") 43 | return 44 | 45 | printc(f"Using vocab size from checkpoint: {model_args['vocab_size']}") 46 | printc( 47 | f"Custom tokenizer length is: {len(tokenizer)} (This should be consistent with the checkpoint)" 48 | ) 49 | 50 | if "n_layers" in model_args: 51 | model_args["n_layer"] = model_args.pop("n_layers") 52 | if "n_heads" in model_args: 53 | model_args["n_head"] = model_args.pop("n_heads") 54 | 55 | config = LilleConfig(**model_args) 56 | model = LilleForCausalLM(config) 57 | 58 | printc("Loading model weights...") 59 | state_dict = checkpoint["model_state_dict"] 60 | 61 | if config.tie_word_embeddings: 62 | if "lm_head.weight" in state_dict: 63 | printc( 64 | "NOTE: Removing 'lm_head.weight' from state_dict due to tied weights." 65 | ) 66 | state_dict.pop("lm_head.weight") 67 | 68 | unwanted_prefix = "_orig_mod." 
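# Checkpoints saved from a torch.compile()-d model prefix every parameter key with "_orig_mod."; the loop below strips that prefix so the keys match the names expected by load_state_dict.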
69 | for k, v in list(state_dict.items()): 70 | if k.startswith(unwanted_prefix): 71 | state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k) 72 | 73 | model.transformer.load_state_dict(state_dict, strict=True) 74 | model.eval() 75 | printc( 76 | f"Model loaded successfully with {sum(p.numel() for p in model.parameters())/1e6:.2f}M parameters." 77 | ) 78 | 79 | chat_template = ( 80 | "{{ bos_token }}" 81 | "{% for message in messages %}" 82 | "{% if message['role'] == 'user' %}" 83 | "{{ '<|user|>' + message['content'] + '<|assistant|>' }}" 84 | "{% elif message['role'] == 'assistant' %}" 85 | "{{ message['content'] + eos_token }}" 86 | "{% endif %}" 87 | "{% endfor %}" 88 | ) 89 | 90 | tokenizer.chat_template = chat_template 91 | 92 | printc(f"
Saving model and tokenizer to: {args.output_dir}") 93 | model.save_pretrained(args.output_dir, safe_serialization=True) 94 | tokenizer.save_pretrained(args.output_dir) 95 | printc(f"\n✅ Export complete!") 96 | 97 | 98 | if __name__ == "__main__": 99 | main() 100 | -------------------------------------------------------------------------------- /huggingface/model_hf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import inspect 4 | 5 | import torch 6 | import torch.onnx 7 | 8 | from transformers import PreTrainedModel, PretrainedConfig 9 | from transformers.generation import GenerationMixin 10 | from transformers.modeling_outputs import CausalLMOutputWithPast 11 | 12 | sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 13 | from model import GPT as OriginalGPT, GPTConfig as OriginalGPTConfig 14 | 15 | class LilleConfig(PretrainedConfig): 16 | model_type = "lille-130m" 17 | def __init__( 18 | self, 19 | block_size: int = 512, 20 | vocab_size: int = 32768, 21 | n_layer: int = 12, 22 | n_head: int = 12, 23 | n_embd: int = 768, 24 | dropout: float = 0.1, 25 | layer_norm_eps: float = 1e-5, 26 | n_kv_heads: int | None = 4, 27 | rope_theta: float = 10000.0, 28 | **kwargs 29 | ): 30 | self.block_size = block_size 31 | self.vocab_size = vocab_size 32 | self.n_layer = n_layer 33 | self.n_head = n_head 34 | self.n_embd = n_embd 35 | self.dropout = dropout 36 | self.layer_norm_eps = layer_norm_eps 37 | self.n_kv_heads = n_kv_heads 38 | self.rope_theta = rope_theta 39 | self.tie_word_embeddings = True 40 | super().__init__(**kwargs) 41 | 42 | class LilleForCausalLM(PreTrainedModel, GenerationMixin): 43 | config_class = LilleConfig 44 | 45 | @property 46 | def main_input_name(self): 47 | return "input_ids" 48 | 49 | def __init__(self, config: LilleConfig): 50 | super().__init__(config) 51 | hf_config_dict = config.to_dict() 52 | if 'n_layer' in hf_config_dict: 53 | hf_config_dict['n_layers'] = hf_config_dict.pop('n_layer') 54 | if 'n_head' in hf_config_dict: 55 | hf_config_dict['n_heads'] = hf_config_dict.pop('n_head') 56 | 57 | expected_keys = inspect.signature(OriginalGPTConfig).parameters.keys() 58 | filtered_args = {key: hf_config_dict[key] for key in expected_keys if key in hf_config_dict} 59 | original_config = OriginalGPTConfig(**filtered_args) 60 | 61 | self.transformer = OriginalGPT(original_config) 62 | self.post_init() 63 | 64 | if self.config.tie_word_embeddings: 65 | self._tied_weights_keys = ["transformer.lm_head.weight"] 66 | 67 | def get_input_embeddings(self): 68 | return self.transformer.tok_embeddings 69 | 70 | def set_input_embeddings(self, new_embeddings): 71 | self.transformer.tok_embeddings = new_embeddings 72 | 73 | def get_output_embeddings(self): 74 | return self.transformer.lm_head 75 | 76 | def forward( 77 | self, 78 | input_ids: torch.LongTensor, 79 | attention_mask: torch.LongTensor = None, 80 | past_key_values: list[torch.Tensor] | None = None, 81 | use_cache: bool | None = None, 82 | **kwargs 83 | ): 84 | past_kv_cache_for_model = None 85 | if past_key_values is not None: 86 | past_kv_cache_for_model = past_key_values 87 | 88 | if attention_mask is not None: 89 | attention_mask = attention_mask.bool() 90 | 91 | if past_kv_cache_for_model is not None and len(past_kv_cache_for_model) > 0: 92 | past_seq_len = past_kv_cache_for_model[0][0].shape[2] 93 | current_seq_len = input_ids.shape[1] 94 | total_seq_len = past_seq_len + current_seq_len 95 | 96 | if attention_mask.shape[1] < 
total_seq_len: 97 | batch_size = attention_mask.shape[0] 98 | cached_mask = torch.ones((batch_size, past_seq_len), 99 | dtype=attention_mask.dtype, 100 | device=attention_mask.device) 101 | attention_mask = torch.cat([cached_mask, attention_mask], dim=1) 102 | 103 | outputs = self.transformer( 104 | input_ids, 105 | past_kv_cache=past_kv_cache_for_model, 106 | use_cache=use_cache, 107 | attn_mask=attention_mask 108 | ) 109 | 110 | if use_cache: 111 | logits, present_kv_cache = outputs 112 | new_past_kv = tuple(present_kv_cache) if present_kv_cache else None 113 | else: 114 | logits = outputs[0] 115 | new_past_kv = None 116 | 117 | if torch.onnx.is_in_onnx_export(): 118 | if new_past_kv is not None: 119 | flat_present = [t for pair in new_past_kv for t in pair] 120 | return (logits, *flat_present) 121 | else: 122 | return (logits,) 123 | 124 | return CausalLMOutputWithPast( 125 | logits=logits, 126 | past_key_values=new_past_kv 127 | ) 128 | 129 | def prepare_inputs_for_generation( 130 | self, 131 | input_ids, 132 | past_key_values=None, 133 | attention_mask=None, 134 | inputs_embeds=None, 135 | **kwargs 136 | ): 137 | if past_key_values: 138 | input_ids = input_ids[:, -1:] 139 | 140 | model_inputs = { 141 | "input_ids": input_ids, 142 | "past_key_values": past_key_values, 143 | "use_cache": kwargs.get("use_cache"), 144 | "attention_mask": attention_mask, 145 | } 146 | 147 | return model_inputs 148 | -------------------------------------------------------------------------------- /export_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import onnx 5 | import onnxsim 6 | from onnxruntime.transformers import optimizer 7 | from html2term import printc 8 | 9 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 10 | 11 | from model import GPT, GPTConfig 12 | 13 | def load_pytorch_model(checkpoint_path, device="cpu", use_fp16=False): 14 | """ 15 | Loads the PyTorch model from a checkpoint and optionally converts to FP16. 
16 | """ 17 | printc(f"Loading PyTorch checkpoint from: {checkpoint_path}") 18 | checkpoint = torch.load(checkpoint_path, map_location=device) 19 | model_args = checkpoint["model_args"] 20 | config = GPTConfig(**model_args) 21 | model = GPT(config) 22 | state_dict = checkpoint["model_state_dict"] 23 | 24 | unwanted_prefixes = ["_orig_mod.", "module."] 25 | for k, v in list(state_dict.items()): 26 | for prefix in unwanted_prefixes: 27 | if k.startswith(prefix): 28 | state_dict[k[len(prefix) :]] = state_dict.pop(k) 29 | break 30 | model.load_state_dict(state_dict, strict=True) 31 | model.eval() 32 | 33 | if use_fp16: 34 | printc("Converting PyTorch model to FP16...") 35 | model.half() 36 | 37 | model.to(device) 38 | printc("PyTorch model loaded successfully.") 39 | return model, config 40 | 41 | 42 | def export_unified_onnx_model(model, onnx_path, device): 43 | """Exports a single, unified ONNX model for both prefill and decode.""" 44 | printc(f"Exporting UNIFIED model to ONNX (FP16) at: {onnx_path}") 45 | config = model.config 46 | 47 | input_names = ["input_ids"] 48 | output_names = ["logits"] 49 | 50 | dynamic_axes = { 51 | "input_ids": {0: "batch_size", 1: "sequence_length"}, 52 | "logits": {0: "batch_size", 1: "sequence_length"}, 53 | } 54 | 55 | model_dtype = torch.float16 56 | dummy_past_kv = [] 57 | for i in range(config.n_layers): 58 | past_key, past_val = f"past_key_{i}", f"past_value_{i}" 59 | present_key, present_val = f"present_key_{i}", f"present_value_{i}" 60 | 61 | input_names.extend([past_key, past_val]) 62 | output_names.extend([present_key, present_val]) 63 | 64 | dynamic_axes.update({ 65 | past_key: {0: "batch_size", 2: "past_sequence_len"}, 66 | past_val: {0: "batch_size", 2: "past_sequence_len"}, 67 | present_key: {0: "batch_size", 2: "total_sequence_len"}, 68 | present_val: {0: "batch_size", 2: "total_sequence_len"}, 69 | }) 70 | 71 | dummy_past_kv.append(( 72 | torch.randn(1, config.n_kv_heads, 12, config.n_embd // config.n_heads, device=device, dtype=model_dtype), 73 | torch.randn(1, config.n_kv_heads, 12, config.n_embd // config.n_heads, device=device, dtype=model_dtype), 74 | )) 75 | 76 | dummy_input_ids = torch.ones(1, 1, dtype=torch.long, device=device) 77 | model_args = (dummy_input_ids, dummy_past_kv, True) 78 | 79 | torch.onnx.export( 80 | model, 81 | model_args, 82 | onnx_path, 83 | input_names=input_names, 84 | output_names=output_names, 85 | do_constant_folding=True, 86 | opset_version=17, 87 | dynamic_axes=dynamic_axes, 88 | ) 89 | printc("Unified ONNX export complete.") 90 | 91 | 92 | def simplify_and_optimize_onnx(unsimplified_path, final_path, config): 93 | """Simplifies and optimizes a single ONNX model.""" 94 | printc(f"Simplifying and optimizing: {unsimplified_path}") 95 | temp_simplified_path = unsimplified_path.replace(".onnx", "_simplified.onnx") 96 | 97 | onnx_model = onnx.load(unsimplified_path) 98 | model_simplified, check = onnxsim.simplify(onnx_model) 99 | if not check: 100 | printc("ONNX simplification failed. 
Using unsimplified model.") 101 | onnx.save(onnx_model, temp_simplified_path) 102 | else: 103 | onnx.save(model_simplified, temp_simplified_path) 104 | 105 | opt_model = optimizer.optimize_model( 106 | input=temp_simplified_path, 107 | model_type="gpt2", 108 | num_heads=config.n_heads, 109 | hidden_size=config.n_embd, 110 | opt_level=2, 111 | use_gpu=True, 112 | only_onnxruntime=False, 113 | ) 114 | 115 | printc("Converting optimized model to FP16...") 116 | opt_model.convert_model_float32_to_float16() 117 | 118 | opt_model.save_model_to_file(final_path) 119 | printc(f"Optimization complete. Final model at: {final_path}") 120 | 121 | if os.path.exists(temp_simplified_path): 122 | os.remove(temp_simplified_path) 123 | 124 | 125 | def create_onnx_model_for_inference( 126 | checkpoint_path, onnx_model_path, device="cuda" 127 | ): 128 | """Full pipeline to create a final, optimized ONNX model.""" 129 | if os.path.exists(onnx_model_path): 130 | printc("Final ONNX model already exists, skipping generation.") 131 | _, config = load_pytorch_model(checkpoint_path, device, use_fp16=False) 132 | return config 133 | 134 | temp_unsimplified_path = onnx_model_path.replace(".onnx", "_temp_unsimplified.onnx") 135 | 136 | try: 137 | model, config = load_pytorch_model(checkpoint_path, device, use_fp16=True) 138 | 139 | export_unified_onnx_model(model, temp_unsimplified_path, device) 140 | 141 | del model 142 | if "cuda" in device: 143 | torch.cuda.empty_cache() 144 | 145 | simplify_and_optimize_onnx(temp_unsimplified_path, onnx_model_path, config) 146 | 147 | return config 148 | 149 | finally: 150 | if os.path.exists(temp_unsimplified_path): 151 | os.remove(temp_unsimplified_path) 152 | printc(f"Cleaned up temporary file: {temp_unsimplified_path}") 153 | -------------------------------------------------------------------------------- /sophia_triton.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim.optimizer import Optimizer 3 | 4 | import triton 5 | import triton.language as tl 6 | 7 | @triton.jit 8 | def _update_hessian_kernel( 9 | hessian_ptr, 10 | grad_ptr, 11 | beta2, 12 | n_elements, 13 | BLOCK_SIZE: tl.constexpr, 14 | ): 15 | pid = tl.program_id(axis=0) 16 | block_start = pid * BLOCK_SIZE 17 | offsets = block_start + tl.arange(0, BLOCK_SIZE) 18 | mask = offsets < n_elements 19 | 20 | hessian = tl.load(hessian_ptr + offsets, mask=mask) 21 | grad = tl.load(grad_ptr + offsets, mask=mask) 22 | 23 | hessian_fp32 = hessian.to(tl.float32) 24 | grad_fp32 = grad.to(tl.float32) 25 | 26 | new_hessian = beta2 * hessian_fp32 + (1.0 - beta2) * grad_fp32 * grad_fp32 27 | 28 | tl.store(hessian_ptr + offsets, new_hessian.to(hessian.dtype), mask=mask) 29 | 30 | 31 | @triton.jit 32 | def _step_kernel( 33 | p_ptr, 34 | grad_ptr, 35 | exp_avg_ptr, 36 | hessian_ptr, 37 | lr, 38 | beta1, 39 | rho, 40 | bs, 41 | weight_decay, 42 | eps, 43 | n_elements, 44 | p_dtype: tl.constexpr, 45 | BLOCK_SIZE: tl.constexpr, 46 | ): 47 | pid = tl.program_id(axis=0) 48 | block_start = pid * BLOCK_SIZE 49 | offsets = block_start + tl.arange(0, BLOCK_SIZE) 50 | mask = offsets < n_elements 51 | 52 | p = tl.load(p_ptr + offsets, mask=mask) 53 | grad = tl.load(grad_ptr + offsets, mask=mask) 54 | exp_avg = tl.load(exp_avg_ptr + offsets, mask=mask) 55 | hessian = tl.load(hessian_ptr + offsets, mask=mask) 56 | 57 | p_fp32 = p.to(tl.float32) 58 | grad_fp32 = grad.to(tl.float32) 59 | exp_avg_fp32 = exp_avg.to(tl.float32) 60 | hessian_fp32 = hessian.to(tl.float32) 61 | 62 | 
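# Sophia update: momentum EMA of the gradient, decoupled weight decay, then a sign step whose magnitude is |m| / max(rho * bs * h, eps), clipped at 1.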
new_exp_avg = beta1 * exp_avg_fp32 + (1.0 - beta1) * grad_fp32 63 | 64 | p_decayed = p_fp32 * (1.0 - lr * weight_decay) 65 | 66 | denominator = tl.maximum(rho * bs * hessian_fp32, eps) 67 | ratio = tl.abs(new_exp_avg) / denominator 68 | clamped_ratio = tl.minimum(ratio, 1.0) 69 | 70 | sign_new_exp_avg = tl.where(new_exp_avg > 0, 1.0, tl.where(new_exp_avg < 0, -1.0, 0.0)) 71 | update = lr * sign_new_exp_avg * clamped_ratio 72 | new_p = p_decayed - update 73 | 74 | tl.store(p_ptr + offsets, new_p.to(p_dtype), mask=mask) 75 | tl.store(exp_avg_ptr + offsets, new_exp_avg.to(exp_avg.dtype), mask=mask) 76 | 77 | 78 | class SophiaG(Optimizer): 79 | def __init__(self, params, lr=1e-4, betas=(0.965, 0.99), rho=0.04, 80 | weight_decay=1e-1, *, maximize: bool = False, 81 | capturable: bool = False, eps: float = 1e-15, bs: int): 82 | if not 0.0 <= lr: 83 | raise ValueError(f"Invalid learning rate: {lr}") 84 | if not 0.0 <= betas[0] < 1.0: 85 | raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") 86 | if not 0.0 <= betas[1] < 1.0: 87 | raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") 88 | if not 0.0 <= rho: 89 | raise ValueError(f"Invalid rho parameter: {rho}") 90 | if not 0.0 <= weight_decay: 91 | raise ValueError(f"Invalid weight_decay value: {weight_decay}") 92 | if capturable: 93 | raise ValueError("Capturable mode is not supported by this Triton implementation.") 94 | if not 0.0 <= eps: 95 | raise ValueError(f"Invalid epsilon value: {eps}") 96 | if not bs > 0: 97 | raise ValueError(f"Invalid batch size (bs): {bs}") 98 | 99 | defaults = dict(lr=lr, betas=betas, rho=rho, weight_decay=weight_decay, 100 | maximize=maximize, eps=eps, bs=bs) 101 | super(SophiaG, self).__init__(params, defaults) 102 | 103 | self.hessian_update_stream = torch.cuda.Stream() 104 | 105 | def _init_state(self, p): 106 | """Initializes optimizer state for a parameter.""" 107 | state = self.state[p] 108 | if len(state) == 0: 109 | state['step'] = 0 110 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) 111 | state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format) 112 | 113 | @torch.no_grad() 114 | def update_hessian(self): 115 | """ 116 | Synchronizes the Hessian update stream with the current stream. 117 | This ensures that the Hessian update from the previous step is complete 118 | before the current optimizer step uses it. Also handles state initialization. 119 | """ 120 | torch.cuda.current_stream().wait_stream(self.hessian_update_stream) 121 | for group in self.param_groups: 122 | for p in group['params']: 123 | state = self.state[p] 124 | 125 | if len(state) > 0 and state['exp_avg'].shape != p.shape: 126 | print(f"SophiaG: Detected shape mismatch for a parameter (state: {state['exp_avg'].shape}, param: {p.shape}). Re-initializing state.") 127 | state.clear() 128 | 129 | self._init_state(p) 130 | 131 | @torch.no_grad() 132 | def schedule_hessian_update(self): 133 | """ 134 | This allows the update to overlap with the backward pass of the next iteration, 135 | hiding its latency and improving GPU utilization. 
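        Call this after backward() so that gradients are populated; the next call to
        update_hessian() makes the current stream wait on this stream before the
        optimizer state is read again.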
136 | """ 137 | with torch.cuda.stream(self.hessian_update_stream): 138 | for group in self.param_groups: 139 | beta1, beta2 = group['betas'] 140 | for p in group['params']: 141 | if p.grad is None: 142 | continue 143 | 144 | grad = p.grad 145 | if grad.is_sparse: 146 | raise RuntimeError('SophiaG does not support sparse gradients') 147 | 148 | state = self.state[p] 149 | 150 | if len(state) == 0: 151 | raise RuntimeError(f"SophiaG: State not initialized for parameter with shape {p.shape}, but it has a gradient. Ensure `optimizer.update_hessian()` is called before `backward()`.") 152 | 153 | hessian = state['hessian'] 154 | n_elements = p.numel() 155 | 156 | grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) 157 | _update_hessian_kernel[grid]( 158 | hessian, grad, beta2, n_elements, BLOCK_SIZE=1024 159 | ) 160 | 161 | @torch.no_grad() 162 | def step(self, closure=None): 163 | loss = None 164 | if closure is not None: 165 | with torch.enable_grad(): 166 | loss = closure() 167 | 168 | for group in self.param_groups: 169 | beta1, beta2 = group['betas'] 170 | lr = group['lr'] 171 | weight_decay = group['weight_decay'] 172 | rho = group['rho'] 173 | eps = group['eps'] 174 | bs = group['bs'] 175 | maximize = group['maximize'] 176 | 177 | for p in group['params']: 178 | if p.grad is None: 179 | continue 180 | 181 | grad = p.grad 182 | if maximize: 183 | grad = -grad 184 | 185 | if grad.is_sparse: 186 | raise RuntimeError('SophiaG does not support sparse gradients') 187 | 188 | state = self.state[p] 189 | if len(state) == 0: 190 | raise RuntimeError("Optimizer state not initialized. Call update_hessian() before step().") 191 | 192 | state['step'] += 1 193 | 194 | exp_avg = state['exp_avg'] 195 | hessian = state['hessian'] 196 | n_elements = p.numel() 197 | 198 | p_dtype = p.dtype 199 | if p_dtype == torch.float16: 200 | p_dtype_tl = tl.float16 201 | elif p_dtype == torch.bfloat16: 202 | p_dtype_tl = tl.bfloat16 203 | else: 204 | p_dtype_tl = tl.float32 205 | 206 | grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) 207 | _step_kernel[grid]( 208 | p, grad, exp_avg, hessian, 209 | lr, beta1, rho, float(bs), weight_decay, eps, 210 | n_elements, 211 | p_dtype=p_dtype_tl, 212 | BLOCK_SIZE=1024, 213 | ) 214 | return loss 215 | -------------------------------------------------------------------------------- /prepare_dataset_fineweb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import json 4 | import gc 5 | import pickle 6 | import numpy as np 7 | from tqdm import tqdm 8 | from datasets import load_dataset 9 | import tiktoken 10 | from html2term import printc 11 | 12 | dataset_name = "fineweb_edu_sample_10BT" 13 | output_dir = f"data/{dataset_name}" 14 | temp_dir = os.path.join(output_dir, "tmp") 15 | state_file = os.path.join(temp_dir, "progress.json") 16 | TOKENIZER_PATH = "tokenizer/Hastings.pkl" 17 | 18 | hf_dataset_name = "HuggingFaceFW/fineweb-edu" 19 | hf_dataset_config = "sample-10BT" 20 | hf_dataset_split = "train" 21 | num_documents = 9_672_101 22 | 23 | CHUNK_SIZE = 200_000 24 | language_filter = "en" 25 | language_score_threshold = 0.95 26 | val_split = 0.0005 27 | EOT_TOKEN = "<|endoftext|>" 28 | REMOVE_CHUNKS_AFTER_CONSOLIDATION = True 29 | 30 | def load_tokenizer(tokenizer_path=TOKENIZER_PATH): 31 | """Loads the tokenizer from a pickle file.""" 32 | with open(tokenizer_path, "rb") as f: 33 | hastings = pickle.load(f) 34 | enc = tiktoken.core.Encoding(hastings.pop("name"), **hastings) 35 | 
printc(f"Tokenizer loaded. Vocab size: {enc.n_vocab}") 36 | try: 37 | enc.encode_single_token(EOT_TOKEN) 38 | except KeyError: 39 | raise ValueError(f"EOT token '{EOT_TOKEN}' is not in the tokenizer vocabulary.") 40 | return enc 41 | 42 | def load_progress(): 43 | """Loads the processing progress from the state file.""" 44 | if os.path.exists(state_file): 45 | with open(state_file, "r") as f: 46 | p = json.load(f) 47 | return p.get("processed_docs", 0), p.get("chunk_index", 0) 48 | return 0, 0 49 | 50 | def save_progress(processed_docs, chunk_index): 51 | """Saves the current processing progress.""" 52 | os.makedirs(temp_dir, exist_ok=True) 53 | with open(state_file, "w") as f: 54 | json.dump({"processed_docs": processed_docs, "chunk_index": chunk_index}, f) 55 | 56 | def save_chunk(docs, chunk_index, split_name): 57 | """Saves a chunk of tokenized documents to a temporary file.""" 58 | os.makedirs(temp_dir, exist_ok=True) 59 | file_path = os.path.join(temp_dir, f"{split_name}_chunk_{chunk_index}.npy") 60 | np.save(file_path, np.array(docs, dtype=object)) 61 | 62 | def _get_chunk_files(temp_dir, split_name): 63 | """Helper to list and sort temporary chunk files.""" 64 | pattern = re.compile(rf"^{re.escape(split_name)}_chunk_(\d+)\.npy$") 65 | files = [] 66 | if not os.path.isdir(temp_dir): 67 | return [] 68 | for fn in os.listdir(temp_dir): 69 | m = pattern.match(fn) 70 | if m: 71 | idx = int(m.group(1)) 72 | files.append((idx, os.path.join(temp_dir, fn))) 73 | files.sort(key=lambda x: x[0]) 74 | return [p for _, p in files] 75 | 76 | def consolidate_chunks_to_npz(temp_dir, output_dir, split_name, remove_chunks=REMOVE_CHUNKS_AFTER_CONSOLIDATION): 77 | """Consolidates temporary chunks into a final compressed .npz file.""" 78 | chunk_files = _get_chunk_files(temp_dir, split_name) 79 | if not chunk_files: 80 | printc(f"No chunk files found for split '{split_name}'. Skipping.") 81 | return 82 | 83 | printc(f"Pass 1/2: Counting tokens in {len(chunk_files)} chunks for '{split_name}'...") 84 | total_docs = 0 85 | total_tokens = 0 86 | for p in tqdm(chunk_files, desc=f"Counting {split_name}"): 87 | arr = np.load(p, allow_pickle=True) 88 | total_docs += arr.shape[0] 89 | total_tokens += sum(len(x) for x in arr) 90 | del arr 91 | gc.collect() 92 | 93 | if total_docs == 0: 94 | printc(f"No documents found in chunks for '{split_name}'. 
Skipping.") 95 | return 96 | 97 | printc(f" Found {total_docs:,} documents and {total_tokens:,} tokens for '{split_name}'.") 98 | 99 | tokens_arr = np.empty(total_tokens, dtype=np.uint16) 100 | offsets_arr = np.empty(total_docs + 1, dtype=np.uint64) 101 | offsets_arr[0] = 0 102 | 103 | printc(f"Pass 2/2: Consolidating chunks into final arrays...") 104 | token_pos = 0 105 | doc_idx = 0 106 | for p in tqdm(chunk_files, desc=f"Consolidating {split_name}"): 107 | chunk = np.load(p, allow_pickle=True) 108 | for doc in chunk: 109 | doc_len = len(doc) 110 | tokens_arr[token_pos : token_pos + doc_len] = doc 111 | token_pos += doc_len 112 | doc_idx += 1 113 | offsets_arr[doc_idx] = token_pos 114 | del chunk 115 | gc.collect() 116 | 117 | output_path = os.path.join(output_dir, f"{split_name}.npz") 118 | printc(f" Saving consolidated data to {output_path}...") 119 | np.savez_compressed(output_path, tokens=tokens_arr, offsets=offsets_arr) 120 | printc(f" Successfully saved {output_path}") 121 | 122 | if remove_chunks: 123 | printc(" Cleaning up temporary chunk files...") 124 | for p in chunk_files: 125 | try: 126 | os.remove(p) 127 | except OSError as e: 128 | printc(f" Error removing {p}: {e}") 129 | printc(" Cleanup complete.") 130 | 131 | def tokenize_and_create_chunks(enc): 132 | """ 133 | Loads the dataset, filters, tokenizes, and saves data in chunks. 134 | """ 135 | processed_docs_count, chunk_index = load_progress() 136 | eot_token_id = enc.encode_single_token(EOT_TOKEN) 137 | 138 | printc("Loading dataset (streaming)...") 139 | ds = load_dataset(hf_dataset_name, name=hf_dataset_config, split=hf_dataset_split, streaming=True) 140 | 141 | if processed_docs_count > 0: 142 | printc(f"Resume: skipping {processed_docs_count:,} already processed docs.") 143 | ds = ds.skip(processed_docs_count) 144 | chunk_index += 1 145 | 146 | val_docs_count = int(num_documents * val_split) 147 | train_docs_count = num_documents - val_docs_count 148 | 149 | train_chunk, val_chunk = [], [] 150 | pbar = tqdm(initial=processed_docs_count, total=num_documents, unit="docs", desc="Processing documents") 151 | try: 152 | for doc in ds: 153 | if doc.get("language") == language_filter and doc.get("language_score", 0) >= language_score_threshold: 154 | text = doc.get("text") 155 | if text: 156 | tokens = enc.encode(text, allowed_special="all") 157 | tokens.append(eot_token_id) 158 | 159 | if pbar.n < train_docs_count: 160 | train_chunk.append(tokens) 161 | else: 162 | val_chunk.append(tokens) 163 | 164 | if len(train_chunk) >= CHUNK_SIZE or len(val_chunk) >= CHUNK_SIZE: 165 | if train_chunk: 166 | save_chunk(train_chunk, chunk_index, "train") 167 | if val_chunk: 168 | save_chunk(val_chunk, chunk_index, "val") 169 | 170 | processed_docs_count = pbar.n + 1 171 | save_progress(processed_docs_count, chunk_index) 172 | train_chunk, val_chunk = [], [] 173 | chunk_index += 1 174 | 175 | pbar.update(1) 176 | 177 | except (KeyboardInterrupt, Exception) as e: 178 | printc(f"
Process interrupted: {e}. Progress saved.") 179 | finally: 180 | pbar.close() 181 | 182 | if train_chunk or val_chunk: 183 | if train_chunk: save_chunk(train_chunk, chunk_index, "train") 184 | if val_chunk: save_chunk(val_chunk, chunk_index, "val") 185 | 186 | if os.path.exists(state_file): 187 | os.remove(state_file) 188 | 189 | printc("
Tokenization finished.") 190 | return True 191 | 192 | def main(): 193 | os.makedirs(output_dir, exist_ok=True) 194 | 195 | if os.path.exists(os.path.join(output_dir, "train.npz")) and os.path.exists(os.path.join(output_dir, "val.npz")): 196 | printc("Final .npz files already exist. Skipping processing.") 197 | return 198 | 199 | enc = load_tokenizer() 200 | tokenize_and_create_chunks(enc) 201 | 202 | printc("
Starting consolidation process...") 203 | consolidate_chunks_to_npz(temp_dir, output_dir, "train") 204 | consolidate_chunks_to_npz(temp_dir, output_dir, "val") 205 | 206 | try: 207 | if os.path.isdir(temp_dir) and not os.listdir(temp_dir): 208 | os.rmdir(temp_dir) 209 | except Exception as e: 210 | printc(f"Could not remove temp directory: {e}") 211 | 212 | printc("
✅ All done.") 213 | 214 | if __name__ == "__main__": 215 | main() 216 | -------------------------------------------------------------------------------- /gguf/export_gguf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import subprocess 4 | import torch 5 | from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM 6 | from simple_ai.model_hf import LilleConfig, LilleForCausalLM 7 | from typing import Iterable 8 | from torch import Tensor 9 | 10 | LLAMA_CPP_DIR = "llama.cpp" 11 | if not os.path.isdir(LLAMA_CPP_DIR): 12 | print(f"'{LLAMA_CPP_DIR}' directory not found.") 13 | print("Attempting to clone the repository...") 14 | LLAMA_CPP_REPO_URL = "https://github.com/ggml-org/llama.cpp.git" 15 | try: 16 | subprocess.run( 17 | ["git", "clone", "--depth=1", LLAMA_CPP_REPO_URL, LLAMA_CPP_DIR], 18 | check=True, 19 | capture_output=True, 20 | text=True, 21 | ) 22 | print(f"Successfully cloned '{LLAMA_CPP_REPO_URL}' into '{LLAMA_CPP_DIR}'.") 23 | except FileNotFoundError: 24 | print("Error: 'git' command not found.") 25 | print( 26 | "Please install Git and ensure it is in your system's PATH to automatically clone llama.cpp." 27 | ) 28 | sys.exit(1) 29 | except subprocess.CalledProcessError as e: 30 | print(f"Error: Failed to clone the llama.cpp repository.") 31 | print(f"Git command failed with exit code {e.returncode}.") 32 | print(f"Stderr: {e.stderr.strip()}") 33 | sys.exit(1) 34 | 35 | sys.path.insert(0, LLAMA_CPP_DIR) 36 | 37 | try: 38 | from convert_hf_to_gguf import main as convert_main, ModelBase, gguf, LlamaModel # type: ignore 39 | except ImportError as e: 40 | print( 41 | f"Error: Failed to import from {os.path.join(LLAMA_CPP_DIR, 'convert_hf_to_gguf.py')}: {e}" 42 | ) 43 | print("Please ensure your llama.cpp clone is up-to-date and the file exists.") 44 | sys.exit(1) 45 | 46 | 47 | # 1. 
Define and Register the GGUF Conversion Class --- 48 | @ModelBase.register("LilleForCausalLM") 49 | class LilleModel(LlamaModel): 50 | model_arch = gguf.MODEL_ARCH.LLAMA 51 | 52 | def set_gguf_parameters(self): 53 | self.gguf_writer.add_block_count(self.hparams["n_layer"]) 54 | self.gguf_writer.add_context_length(self.hparams["block_size"]) 55 | self.gguf_writer.add_embedding_length(self.hparams["n_embd"]) 56 | self.gguf_writer.add_head_count(self.hparams["n_head"]) 57 | self.gguf_writer.add_head_count_kv(self.hparams["n_kv_heads"]) 58 | self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_eps"]) 59 | self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"]) 60 | 61 | multiple_of = 256 62 | n_embd = self.hparams["n_embd"] 63 | ff_dim = int(2 * (4 * n_embd) / 3) 64 | ff_dim = multiple_of * ((ff_dim + multiple_of - 1) // multiple_of) 65 | self.gguf_writer.add_feed_forward_length(ff_dim) 66 | 67 | def set_vocab(self): 68 | tokens: list[str] = [] 69 | toktypes: list[int] = [] 70 | 71 | tokenizer = AutoTokenizer.from_pretrained(self.dir_model) 72 | 73 | vocab_size = self.hparams.get("vocab_size", len(tokenizer.vocab)) 74 | assert max(tokenizer.vocab.values()) < vocab_size 75 | 76 | tokpre = "gpt-2" 77 | 78 | reverse_vocab = { 79 | id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items() 80 | } 81 | added_vocab = tokenizer.get_added_vocab() 82 | added_tokens_decoder = tokenizer.added_tokens_decoder 83 | 84 | for i in range(vocab_size): 85 | if i not in reverse_vocab: 86 | tokens.append(f"[PAD{i}]") 87 | toktypes.append(gguf.TokenType.UNUSED) 88 | else: 89 | token: str = reverse_vocab[i] 90 | if token in added_vocab: 91 | if added_tokens_decoder[i].special or self.does_token_look_special( 92 | token 93 | ): 94 | toktypes.append(gguf.TokenType.CONTROL) 95 | else: 96 | toktypes.append(gguf.TokenType.USER_DEFINED) 97 | else: 98 | toktypes.append(gguf.TokenType.NORMAL) 99 | tokens.append(token) 100 | 101 | self.gguf_writer.add_tokenizer_model("gpt2") 102 | self.gguf_writer.add_tokenizer_pre(tokpre) 103 | self.gguf_writer.add_token_list(tokens) 104 | self.gguf_writer.add_token_types(toktypes) 105 | 106 | special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True) 107 | special_vocab.add_to_gguf(self.gguf_writer) 108 | 109 | def modify_tensors( 110 | self, data_torch: Tensor, name: str, bid: int | None 111 | ) -> Iterable[tuple[str, Tensor]]: 112 | if name.endswith((".cos_cached", ".sin_cached")): 113 | return 114 | 115 | n_head = self.hparams["n_head"] 116 | n_kv_head = self.hparams["n_kv_heads"] 117 | 118 | if name.endswith("attention.qkv_proj.weight"): 119 | n_embd = self.hparams["n_embd"] 120 | head_dim = n_embd // n_head 121 | 122 | q_size = n_head * head_dim 123 | k_size = n_kv_head * head_dim 124 | v_size = n_kv_head * head_dim 125 | 126 | q_proj, k_proj, v_proj = torch.split( 127 | data_torch, [q_size, k_size, v_size], dim=0 128 | ) 129 | 130 | q_proj = LlamaModel.permute(q_proj, n_head, n_head) 131 | k_proj = LlamaModel.permute(k_proj, n_head, n_kv_head) 132 | 133 | yield self.map_tensor_name( 134 | f"model.layers.{bid}.self_attn.q_proj.weight" 135 | ), q_proj 136 | yield self.map_tensor_name( 137 | f"model.layers.{bid}.self_attn.k_proj.weight" 138 | ), k_proj 139 | yield self.map_tensor_name( 140 | f"model.layers.{bid}.self_attn.v_proj.weight" 141 | ), v_proj 142 | return 143 | 144 | rename_map = { 145 | "transformer.tok_embeddings.weight": "model.embed_tokens.weight", 146 | "transformer.norm.weight": "model.norm.weight", 147 | "transformer.output.weight": 
"lm_head.weight", 148 | f"transformer.layers.{bid}.attention.out_proj.weight": f"model.layers.{bid}.self_attn.o_proj.weight", 149 | f"transformer.layers.{bid}.attention.norm.weight": f"model.layers.{bid}.input_layernorm.weight", 150 | f"transformer.layers.{bid}.feed_forward.gate_proj.weight": f"model.layers.{bid}.mlp.gate_proj.weight", 151 | f"transformer.layers.{bid}.feed_forward.up_proj.weight": f"model.layers.{bid}.mlp.up_proj.weight", 152 | f"transformer.layers.{bid}.feed_forward.down_proj.weight": f"model.layers.{bid}.mlp.down_proj.weight", 153 | f"transformer.layers.{bid}.feed_forward.norm.weight": f"model.layers.{bid}.post_attention_layernorm.weight", 154 | } 155 | 156 | if name in rename_map: 157 | translated_name = rename_map[name] 158 | yield self.map_tensor_name(translated_name), data_torch 159 | return 160 | 161 | raise ValueError(f"Can not map tensor {name!r}") 162 | 163 | 164 | # 2. Register the custom HF model architecture 165 | AutoConfig.register("lille-130m", LilleConfig) 166 | AutoModelForCausalLM.register(LilleConfig, LilleForCausalLM) 167 | 168 | # 3. Define constants 169 | MODEL_NAME = "Nikity/lille-130m-instruct" 170 | LOCAL_MODEL_DIR = "./lille-130m-instruct" 171 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu" 172 | 173 | # 4. Define quantization types 174 | QUANTIZATION_TYPES = ["q8_0", "f16", "f32"] 175 | 176 | # 5. Download and save the model and tokenizer 177 | print("Downloading and saving model...") 178 | try: 179 | if not os.path.exists(LOCAL_MODEL_DIR): 180 | tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) 181 | model = AutoModelForCausalLM.from_pretrained( 182 | MODEL_NAME, torch_dtype="auto", device_map=DEVICE 183 | ) 184 | model.save_pretrained(LOCAL_MODEL_DIR) 185 | tokenizer.save_pretrained(LOCAL_MODEL_DIR) 186 | print(f"Model and tokenizer saved to {LOCAL_MODEL_DIR}") 187 | else: 188 | print(f"Model already exists at {LOCAL_MODEL_DIR}. Skipping download.") 189 | except Exception as e: 190 | print(f"Failed to download/save model: {e}") 191 | raise 192 | 193 | # 6. Inspect model configuration (for debugging) 194 | print("Model configuration:") 195 | config = AutoConfig.from_pretrained(LOCAL_MODEL_DIR) 196 | print(config) 197 | 198 | # 7. Convert to GGUF for each quantization type 199 | for quant_type in QUANTIZATION_TYPES: 200 | out_dir = "gguf_models" 201 | os.makedirs(out_dir, exist_ok=True) 202 | output_gguf = os.path.join(out_dir, f"lille-130m-instruct-{quant_type}.gguf") 203 | print(f"Attempting to convert to GGUF with quantization {quant_type}...") 204 | 205 | # Build argument list 206 | command_args = [LOCAL_MODEL_DIR, "--outfile", output_gguf, "--outtype", quant_type] 207 | 208 | original_argv = sys.argv 209 | try: 210 | sys.argv = ["convert_hf_to_gguf.py"] + command_args 211 | 212 | convert_main() 213 | 214 | print(f"Conversion successful! GGUF file created at {output_gguf}") 215 | except Exception as e: 216 | print(f"GGUF conversion failed for {quant_type}: {e}") 217 | raise 218 | finally: 219 | sys.argv = original_argv 220 | -------------------------------------------------------------------------------- /eval-table/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | Markdown Table to HTML/PNG Converter 7 | 8 | 9 | 10 | 11 | 13 | 160 | 161 | 162 | 163 |
(eval-table/index.html: markup not preserved in this dump — the page is the "Markdown Table to HTML/PNG Converter", with a "Markdown Input" textarea on one side and an "HTML Preview" pane on the other; its styles and conversion script are omitted here.)
187 | 283 | 284 | 285 | 286 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | warnings.filterwarnings("ignore") 4 | import os 5 | import pickle 6 | import argparse 7 | import time 8 | from typing import Optional 9 | 10 | import cProfile 11 | import pstats 12 | import io 13 | 14 | import torch 15 | import torch._inductor.config 16 | import torch._functorch.config 17 | import torch.nn.functional as F 18 | import tiktoken 19 | import numpy as np 20 | import onnxruntime 21 | from html2term import printc 22 | 23 | torch._inductor.config.coordinate_descent_tuning = True 24 | torch._inductor.config.triton.unique_kernel_names = True 25 | torch._inductor.config.fx_graph_cache = True 26 | torch._functorch.config.enable_autograd_cache = True 27 | 28 | from model import GPTConfig 29 | 30 | from export_utils import create_onnx_model_for_inference 31 | 32 | def _apply_sampling( 33 | logits: torch.Tensor, temp: float, top_p: Optional[float], top_k: Optional[int] 34 | ) -> int: 35 | """Apply temperature, top-p, and top-k sampling to logits, expecting a GPU tensor.""" 36 | if temp == 0.0: 37 | return torch.argmax(logits, dim=-1).item() 38 | 39 | logits.div_(temp) 40 | 41 | if top_p is not None and 0.0 < top_p < 1.0: 42 | probs = F.softmax(logits, dim=-1) 43 | sorted_probs, sorted_indices = torch.sort(probs, descending=True) 44 | cumulative_probs = torch.cumsum(sorted_probs, dim=-1) 45 | 46 | sorted_indices_to_remove = cumulative_probs > top_p 47 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() 48 | sorted_indices_to_remove[..., 0] = 0 49 | 50 | indices_to_remove = torch.zeros_like(logits, dtype=torch.bool).scatter_( 51 | -1, sorted_indices, sorted_indices_to_remove 52 | ) 53 | logits[indices_to_remove] = -float("Inf") 54 | 55 | elif top_k is not None and top_k > 0: 56 | v, _ = torch.topk(logits, min(top_k, logits.size(-1))) 57 | logits[logits < v[..., -1, None]] = -float("Inf") 58 | 59 | probs = F.softmax(logits, dim=-1) 60 | return torch.multinomial(probs, num_samples=1).item() 61 | 62 | 63 | def run_chat_loop_io_binding(onnx_model_path, config, tokenizer, device, args): 64 | """Runs a highly optimized interactive chat loop using I/O Binding to keep data on the GPU.""" 65 | if not 
args.profile: 66 | printc("
--- Starting Chat with Fused FP16 ONNX Model (I/O Binding Enabled) ---") 67 | sampling_params = f"temp={args.temperature}" 68 | if args.temperature == 0: sampling_params = "greedy" 69 | elif args.top_p is not None: sampling_params += f", top_p={args.top_p}" 70 | elif args.top_k is not None: sampling_params += f", top_k={args.top_k}" 71 | printc(f"<#cccccc>Params: {sampling_params}, max_new_tokens={args.max_new_tokens}. Type 'exit' or 'quit' to end.") 72 | 73 | options = onnxruntime.SessionOptions() 74 | options.log_severity_level = 3 75 | 76 | base_name = os.path.splitext(os.path.basename(onnx_model_path))[0] 77 | optimized_model_path = os.path.join(os.path.dirname(onnx_model_path), f"{base_name}.ort") 78 | options.optimized_model_filepath = optimized_model_path 79 | options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL 80 | 81 | providers = [] 82 | if device == "cuda": 83 | provider = "CUDAExecutionProvider" 84 | if provider not in onnxruntime.get_available_providers(): 85 | raise RuntimeError(f"{provider} not available, please check your ONNX Runtime and CUDA setup.") 86 | device_id = torch.cuda.current_device() 87 | provider_options = {'device_id': device_id, 'arena_extend_strategy': 'kSameAsRequested'} 88 | providers = [ 89 | ('CUDAExecutionProvider', provider_options), 90 | 'CPUExecutionProvider', 91 | ] 92 | if not args.profile: printc(f"Using ONNX Runtime providers: {providers[0][0]} (with CUDA Graph), {providers[1]}") 93 | device_name = 'cuda' 94 | torch_device = torch.device(f"cuda:{device_id}") 95 | else: 96 | provider = "CPUExecutionProvider" 97 | if not args.profile: printc(f"Using ONNX Runtime provider: {provider}") 98 | providers.append(provider) 99 | device_name = 'cpu' 100 | torch_device = torch.device("cpu") 101 | 102 | session = onnxruntime.InferenceSession(onnx_model_path, sess_options=options, providers=providers) 103 | stop_ids = [tokenizer.encode_single_token(t) for t in ["<|endoftext|>", "<|user|>"]] 104 | 105 | conversation_history_ids = [] 106 | 107 | while True: 108 | try: 109 | if args.profile: 110 | prompt = "What is the capital of France and what is its history?" 111 | printc(f"
You: {prompt}") 112 | else: 113 | printc("
You: ", end="") 114 | prompt = input() 115 | 116 | if prompt.lower() in ["exit", "quit"]: break 117 | 118 | if not conversation_history_ids: 119 | tokens_to_process = tokenizer.encode(f"<|startoftext|><|user|>{prompt}<|assistant|>", allowed_special="all") 120 | else: 121 | tokens_to_process = tokenizer.encode(f"<|user|>{prompt}<|assistant|>", allowed_special="all") 122 | 123 | conversation_history_ids.extend(tokens_to_process) 124 | 125 | if len(conversation_history_ids) > config.block_size: 126 | printc("
[CONTEXT RESET - Model has forgotten the conversation]") 127 | conversation_history_ids = tokenizer.encode(f"<|startoftext|><|user|>{prompt}<|assistant|>", allowed_special="all") 128 | tokens_to_process = conversation_history_ids 129 | 130 | binding = session.io_binding() 131 | 132 | input_ids_np = np.array([tokens_to_process], dtype=np.int64) 133 | input_ids_ort = onnxruntime.OrtValue.ortvalue_from_numpy(input_ids_np, device_name, 0) 134 | binding.bind_ortvalue_input('input_ids', input_ids_ort) 135 | 136 | dtype = np.float16 137 | empty_past = np.zeros((1, config.n_kv_heads, 0, config.n_embd // config.n_heads), dtype=dtype) 138 | empty_past_ort = onnxruntime.OrtValue.ortvalue_from_numpy(empty_past, device_name, 0) 139 | for i in range(config.n_layers): 140 | binding.bind_ortvalue_input(f'past_key_{i}', empty_past_ort) 141 | binding.bind_ortvalue_input(f'past_value_{i}', empty_past_ort) 142 | 143 | binding.bind_output('logits', device_name) 144 | for i in range(config.n_layers): 145 | binding.bind_output(f'present_key_{i}', device_name) 146 | binding.bind_output(f'present_value_{i}', device_name) 147 | 148 | session.run_with_iobinding(binding) 149 | ort_outs = binding.get_outputs() 150 | logits_ort, past_key_values = ort_outs[0], ort_outs[1:] 151 | 152 | logits_torch = torch.tensor(logits_ort.numpy(), device="cuda") 153 | next_token_id = _apply_sampling(logits_torch[0, -1, :], args.temperature, args.top_p, args.top_k) 154 | 155 | printc("Bot: ", end="", flush=True) 156 | generated_response_ids = [] 157 | start_time = time.perf_counter() 158 | 159 | max_tokens = min(args.max_new_tokens, config.block_size - len(conversation_history_ids)) 160 | if max_tokens <= 0: 161 | printc("
[CONTEXT FULL - Cannot generate more tokens]") 162 | if args.profile: break 163 | continue 164 | 165 | single_token_input_ort = onnxruntime.OrtValue.ortvalue_from_numpy( 166 | np.array([[next_token_id]], dtype=np.int64), device_name, 0 167 | ) 168 | 169 | for _ in range(max_tokens): 170 | if next_token_id in stop_ids: break 171 | 172 | generated_response_ids.append(next_token_id) 173 | print(tokenizer.decode([next_token_id]), end="", flush=True) 174 | 175 | binding.bind_ortvalue_input('input_ids', single_token_input_ort) 176 | for j in range(config.n_layers): 177 | binding.bind_ortvalue_input(f'past_key_{j}', past_key_values[j*2]) 178 | binding.bind_ortvalue_input(f'past_value_{j}', past_key_values[j*2+1]) 179 | 180 | binding.bind_output('logits', device_name) 181 | for j in range(config.n_layers): 182 | binding.bind_output(f'present_key_{j}', device_name) 183 | binding.bind_output(f'present_value_{j}', device_name) 184 | 185 | session.run_with_iobinding(binding) 186 | ort_outs = binding.get_outputs() 187 | 188 | logits_ort, past_key_values = ort_outs[0], ort_outs[1:] 189 | 190 | logits_torch = torch.tensor(logits_ort.numpy(), device="cuda") 191 | next_token_id = _apply_sampling(logits_torch[0, 0, :], args.temperature, args.top_p, args.top_k) 192 | 193 | single_token_input_ort.update_inplace(np.array([[next_token_id]], dtype=np.int64)) 194 | 195 | 196 | end_time = time.perf_counter() 197 | conversation_history_ids.extend(generated_response_ids) 198 | printc("
") 199 | 200 | num_generated = len(generated_response_ids) 201 | time_taken = end_time - start_time 202 | tokens_per_sec = num_generated / time_taken if time_taken > 0 else 0 203 | printc(f"<#cccccc>Generated {num_generated} tokens in {time_taken:.2f}s ({tokens_per_sec:.2f} tokens/s)") 204 | 205 | if args.profile: 206 | break 207 | 208 | except KeyboardInterrupt: 209 | printc("
Exiting chat mode.") 210 | break 211 | except Exception as e: 212 | printc(f"
An error occurred: {e}") 213 | import traceback 214 | traceback.print_exc() 215 | break 216 | 217 | def main(): 218 | parser = argparse.ArgumentParser( 219 | description="Optimize a GPT model to Fused FP16 ONNX and run a fast chat session." 220 | ) 221 | parser.add_argument("--checkpoint_path", type=str, default="checkpoints_ft/best_model.pt", help="Path to the PyTorch model checkpoint.") 222 | parser.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda"], help="Device for model loading and ONNX export.") 223 | parser.add_argument("--max_new_tokens", type=int, default=480, help="Maximum number of new tokens to generate.") 224 | parser.add_argument("--temperature", type=float, default=0.5, help="Sampling temperature. 0 for greedy.") 225 | parser.add_argument("--top_k", type=int, default=None, help="Top-k sampling.") 226 | parser.add_argument("--top_p", type=float, default=0.95, help="Top-p (nucleus) sampling.") 227 | parser.add_argument("--profile", action="store_true", help="Enable cProfile to analyze performance of one generation cycle.") 228 | args = parser.parse_args() 229 | 230 | if args.device == "cuda" and not torch.cuda.is_available(): 231 | printc("CUDA is selected but not available. Please check your environment.") 232 | exit(1) 233 | 234 | if args.temperature == 0: 235 | printc("Temperature is 0, using greedy decoding.") 236 | 237 | ONNX_MODEL_DIR = "onnx_models" 238 | os.makedirs(ONNX_MODEL_DIR, exist_ok=True) 239 | 240 | base_name = os.path.splitext(os.path.basename(args.checkpoint_path))[0] 241 | onnx_fp16_fused_path = os.path.join(ONNX_MODEL_DIR, f"{base_name}_fp16_kv_fused.onnx") 242 | 243 | with open("tokenizer/Hastings.pkl", "rb") as f: 244 | hastings = pickle.load(f) 245 | enc = tiktoken.core.Encoding(hastings.pop("name"), **hastings) 246 | 247 | create_onnx_model_for_inference(args.checkpoint_path, onnx_fp16_fused_path, args.device) 248 | 249 | checkpoint = torch.load(args.checkpoint_path, map_location='cpu') 250 | config = GPTConfig(**checkpoint['model_args']) 251 | del checkpoint 252 | 253 | if args.profile: 254 | printc(" --- PROFILING MODE ENABLED --- ") 255 | pr = cProfile.Profile() 256 | pr.enable() 257 | run_chat_loop_io_binding(onnx_fp16_fused_path, config, enc, args.device, args) 258 | pr.disable() 259 | s = io.StringIO() 260 | ps = pstats.Stats(pr, stream=s).sort_stats('cumtime') 261 | ps.print_stats(30) 262 | printc("
--- Profiler Results (Top 30 by Cumulative Time) ---") 263 | print(s.getvalue()) 264 | else: 265 | run_chat_loop_io_binding(onnx_fp16_fused_path, config, enc, args.device, args) 266 | 267 | 268 | if __name__ == "__main__": 269 | main() 270 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | from dataclasses import dataclass, asdict 5 | from typing import Optional, Tuple, Union, List 6 | 7 | @dataclass 8 | class GPTConfig: 9 | """ 10 | Configuration for the GPT model, inspired by GPT-OSS/Llama but adapted for this project. 11 | """ 12 | n_embd: int = 768 13 | n_layers: int = 12 14 | n_heads: int = 12 15 | vocab_size: int = 32000 16 | block_size: int = 512 17 | dropout: float = 0.1 18 | layer_norm_eps: float = 1e-5 19 | n_kv_heads: Optional[int] = 4 20 | rope_theta: float = 10000.0 21 | 22 | def to_dict(self): 23 | return asdict(self) 24 | 25 | class RMSNorm(nn.Module): 26 | """ 27 | Root Mean Square Layer Normalization, as used in GPT-OSS and Llama. 28 | """ 29 | def __init__(self, dim: int, eps: float = 1e-5): 30 | super().__init__() 31 | self.eps = eps 32 | self.weight = nn.Parameter(torch.ones(dim)) 33 | 34 | def _norm(self, x): 35 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 36 | 37 | def forward(self, x): 38 | output = self._norm(x.float()).type_as(x) 39 | return output * self.weight 40 | 41 | def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: 42 | """ 43 | Efficiently repeat the key and value tensors for Grouped-Query Attention. 44 | [B, n_kv_heads, T, head_dim] -> [B, n_q_heads, T, head_dim] 45 | """ 46 | B, n_kv_heads, T, head_dim = x.shape 47 | if n_rep == 1: 48 | return x 49 | return ( 50 | x[:, :, None, :, :] 51 | .expand(B, n_kv_heads, n_rep, T, head_dim) 52 | .reshape(B, n_kv_heads * n_rep, T, head_dim) 53 | ) 54 | 55 | class RotaryPositionalEmbedding(nn.Module): 56 | """ 57 | Original RoPE implementation, kept for its efficiency in training. 58 | """ 59 | def __init__(self, dim: int, max_seq_len: int, base: int = 10000): 60 | super().__init__() 61 | inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) 62 | self.register_buffer("inv_freq", inv_freq) 63 | 64 | t = torch.arange(max_seq_len, device=self.inv_freq.device) 65 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 66 | emb = torch.cat((freqs, freqs), dim=-1) 67 | self.register_buffer("cos_cached", emb.cos()[None, None, :, :]) 68 | self.register_buffer("sin_cached", emb.sin()[None, None, :, :]) 69 | 70 | def forward(self, x, seq_len: int): 71 | return ( 72 | self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), 73 | self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), 74 | ) 75 | 76 | def apply_rotary_pos_emb(q, k, cos, sin): 77 | def rotate_half(x): 78 | return torch.cat([-x[..., 1::2], x[..., ::2]], dim=-1) 79 | 80 | q_embed = (q * cos) + (rotate_half(q) * sin) 81 | k_embed = (k * cos) + (rotate_half(k) * sin) 82 | return q_embed, k_embed 83 | 84 | class Attention(nn.Module): 85 | """ 86 | Attention module with pre-normalization, based on Llama/GPT-OSS style. 
87 | """ 88 | def __init__(self, config: GPTConfig): 89 | super().__init__() 90 | self.n_q_heads = config.n_heads 91 | self.n_kv_heads = config.n_kv_heads if config.n_kv_heads is not None else config.n_heads 92 | self.n_rep = self.n_q_heads // self.n_kv_heads 93 | self.head_dim = config.n_embd // self.n_q_heads 94 | 95 | self.qkv_proj = nn.Linear(config.n_embd, (self.n_q_heads + 2 * self.n_kv_heads) * self.head_dim, bias=False) 96 | 97 | q_heads_concat_dim = self.n_q_heads * self.head_dim 98 | self.out_proj = nn.Linear(q_heads_concat_dim, config.n_embd, bias=False) 99 | 100 | self.dropout = nn.Dropout(config.dropout) 101 | self.norm = RMSNorm(config.n_embd, eps=config.layer_norm_eps) 102 | self.out_proj.GPT_SCALE_INIT = 1 103 | 104 | def forward(self, x: torch.Tensor, rotary_emb: Tuple[torch.Tensor, torch.Tensor], past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, attn_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: 105 | B, T, C = x.shape 106 | 107 | h = self.norm(x) 108 | 109 | qkv = self.qkv_proj(h) 110 | q_len = self.n_q_heads * self.head_dim 111 | k_len = self.n_kv_heads * self.head_dim 112 | 113 | q = qkv[..., :q_len].view(B, T, self.n_q_heads, self.head_dim).transpose(1, 2) 114 | k = qkv[..., q_len : q_len + k_len].view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) 115 | v = qkv[..., q_len + k_len :].view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) 116 | 117 | cos, sin = rotary_emb 118 | q, k = apply_rotary_pos_emb(q, k, cos, sin) 119 | 120 | if past_kv is not None: 121 | past_k, past_v = past_kv 122 | k = torch.cat((past_k, k), dim=2) 123 | v = torch.cat((past_v, v), dim=2) 124 | 125 | present_kv = (k.to(x.dtype), v.to(x.dtype)) 126 | 127 | k = repeat_kv(k, self.n_rep) 128 | v = repeat_kv(v, self.n_rep) 129 | 130 | is_causal_for_sdpa = False 131 | 132 | y = F.scaled_dot_product_attention( 133 | q, k, v, 134 | attn_mask=attn_mask, 135 | is_causal=is_causal_for_sdpa, 136 | dropout_p=self.dropout.p if self.training else 0.0 137 | ) 138 | 139 | y = y.transpose(1, 2).contiguous().view(B, T, -1) 140 | y = self.out_proj(y) 141 | 142 | return x + y, present_kv 143 | 144 | class FeedForward(nn.Module): 145 | """ 146 | FeedForward block with pre-normalization and SwiGLU, based on Llama/GPT-OSS style. 147 | """ 148 | def __init__(self, config: GPTConfig): 149 | super().__init__() 150 | hidden_dim = 4 * config.n_embd 151 | hidden_dim = int(2 * hidden_dim / 3) 152 | multiple_of = 256 153 | hidden_dim = multiple_of * round(hidden_dim / multiple_of) 154 | 155 | self.norm = RMSNorm(config.n_embd, eps=config.layer_norm_eps) 156 | self.gate_proj = nn.Linear(config.n_embd, hidden_dim, bias=False) 157 | self.up_proj = nn.Linear(config.n_embd, hidden_dim, bias=False) 158 | self.down_proj = nn.Linear(hidden_dim, config.n_embd, bias=False) 159 | 160 | self.down_proj.GPT_SCALE_INIT = 1 161 | 162 | def forward(self, x): 163 | h = self.norm(x) 164 | gate = F.silu(self.gate_proj(h)) 165 | up = self.up_proj(h) 166 | fused = gate * up 167 | return x + self.down_proj(fused) 168 | 169 | class Block(nn.Module): 170 | """ 171 | Transformer Block in the Llama/GPT-OSS pre-normalization style. 
172 | """ 173 | def __init__(self, config: GPTConfig): 174 | super().__init__() 175 | self.attention = Attention(config) 176 | self.feed_forward = FeedForward(config) 177 | 178 | def forward(self, x: torch.Tensor, rotary_emb: Tuple[torch.Tensor, torch.Tensor], past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, attn_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: 179 | h, present_kv = self.attention(x, rotary_emb, past_kv, attn_mask=attn_mask) 180 | out = self.feed_forward(h) 181 | return out, present_kv 182 | 183 | class GPT(nn.Module): 184 | """ 185 | The main GPT model, composed of the new Llama/GPT-OSS-style blocks. 186 | """ 187 | def __init__(self, config: GPTConfig): 188 | super().__init__() 189 | self.config = config 190 | 191 | self.tok_embeddings = nn.Embedding(config.vocab_size, config.n_embd) 192 | self.rotary_emb = RotaryPositionalEmbedding(config.n_embd // config.n_heads, config.block_size, base=config.rope_theta) 193 | self.layers = nn.ModuleList([Block(config) for _ in range(config.n_layers)]) 194 | self.norm = RMSNorm(config.n_embd, eps=config.layer_norm_eps) 195 | 196 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 197 | self.lm_head.weight = self.tok_embeddings.weight 198 | 199 | self.apply(self._init_weights) 200 | 201 | def _init_weights(self, module): 202 | if isinstance(module, nn.Linear): 203 | std = 0.02 204 | if hasattr(module, 'GPT_SCALE_INIT'): 205 | std *= (2 * self.config.n_layers) ** -0.5 206 | torch.nn.init.normal_(module.weight, mean=0.0, std=std) 207 | if module.bias is not None: 208 | torch.nn.init.zeros_(module.bias) 209 | elif isinstance(module, nn.Embedding): 210 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 211 | 212 | def get_input_embeddings(self): 213 | """ 214 | Returns the model's input embeddings. 215 | Required by the Hugging Face PreTrainedModel interface. 216 | """ 217 | return self.tok_embeddings 218 | 219 | def set_input_embeddings(self, new_embeddings): 220 | """ 221 | Sets the model's input embeddings. 222 | Required by the Hugging Face PreTrainedModel interface. 
223 | """ 224 | self.tok_embeddings = new_embeddings 225 | 226 | def forward(self, input_ids: torch.Tensor, past_kv_cache: Optional[list] = None, use_cache: bool = False, attn_mask: Optional[torch.Tensor] = None) -> tuple: 227 | B, T = input_ids.size() 228 | seq_len_offset = past_kv_cache[0][0].shape[2] if past_kv_cache is not None else 0 229 | total_sequence_length = seq_len_offset + T 230 | 231 | q_indices = torch.arange(T, device=input_ids.device) + seq_len_offset 232 | k_indices = torch.arange(total_sequence_length, device=input_ids.device) 233 | causal_mask = q_indices.unsqueeze(1) >= k_indices.unsqueeze(0) 234 | 235 | if attn_mask is not None: 236 | padding_mask = attn_mask[:, :total_sequence_length] 237 | combined_mask = causal_mask.unsqueeze(0) & padding_mask.unsqueeze(1) 238 | else: 239 | combined_mask = causal_mask.unsqueeze(0) 240 | 241 | final_sdpa_mask = combined_mask.unsqueeze(1) 242 | 243 | h = self.tok_embeddings(input_ids) 244 | 245 | cos, sin = self.rotary_emb(h, seq_len=total_sequence_length) 246 | cos = cos[:, :, seq_len_offset:, :] 247 | sin = sin[:, :, seq_len_offset:, :] 248 | rotary_emb = (cos, sin) 249 | 250 | present_kv_cache = [] 251 | for i, layer in enumerate(self.layers): 252 | past_kv = past_kv_cache[i] if past_kv_cache is not None else None 253 | h, present_kv = layer(h, rotary_emb, past_kv, attn_mask=final_sdpa_mask) 254 | present_kv_cache.append(present_kv) 255 | 256 | h = self.norm(h) 257 | logits = self.lm_head(h) 258 | 259 | return logits, present_kv_cache 260 | 261 | @torch.inference_mode() 262 | def generate(self, idx: torch.Tensor, max_new_tokens: int, temperature: float = 1.0, top_k: Optional[int] = None, top_p: Optional[float] = None, stop_on_token: Optional[Union[int, List[int]]] = None, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor: 263 | past_kv_cache = None 264 | current_attn_mask = attn_mask 265 | 266 | for _ in range(max_new_tokens): 267 | B, T = idx.shape 268 | 269 | if T >= self.config.block_size: 270 | break 271 | 272 | current_input = idx[:, -1:] if past_kv_cache is not None else idx 273 | 274 | logits, past_kv_cache = self(current_input, past_kv_cache=past_kv_cache, use_cache=True, attn_mask=current_attn_mask) 275 | 276 | logits = logits[:, -1, :] / temperature 277 | 278 | if top_k is not None: 279 | v, _ = torch.topk(logits, min(top_k, logits.size(-1))) 280 | logits[logits < v[:, [-1]]] = -float('inf') 281 | 282 | if top_p is not None: 283 | sorted_probs, sorted_indices = torch.sort(F.softmax(logits, dim=-1), descending=True) 284 | cumulative_probs = torch.cumsum(sorted_probs, dim=-1) 285 | sorted_indices_to_remove = cumulative_probs > top_p 286 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() 287 | sorted_indices_to_remove[..., 0] = 0 288 | indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) 289 | logits[indices_to_remove] = -float('inf') 290 | 291 | probs = F.softmax(logits, dim=-1) 292 | idx_next = torch.multinomial(probs, num_samples=1) 293 | idx = torch.cat((idx, idx_next), dim=1) 294 | 295 | if current_attn_mask is not None: 296 | new_mask_col = torch.ones((B, 1), dtype=current_attn_mask.dtype, device=current_attn_mask.device) 297 | current_attn_mask = torch.cat((current_attn_mask, new_mask_col), dim=1) 298 | 299 | if stop_on_token is not None: 300 | stop_tokens = stop_on_token if isinstance(stop_on_token, (list, tuple, set)) else [stop_on_token] 301 | if idx_next.item() in stop_tokens: 302 | break 303 | return idx 304 | 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Lille 130M 2 | 3 | ![Lille-Header](assets/lille-header.png) 4 | 5 | ## Table of Contents 6 | 1. [Model Summary](#-model-summary) 7 | 2. [Evaluation](#-evaluation) 8 | 3. [How to Use](#-how-to-use) 9 | 4. [Training and Finetuning](#-training-and-finetuning) 10 | 5. [Training Details](#-training-details) 11 | 6. [Limitations](#-limitations) 12 | 7. [The Truly Open-Source Stack](#-the-truly-open-source-repos) 13 | 8. [License](#-license) 14 | 9. [Citation](#-citation) 15 | 16 | ## ✨ Model Summary 17 | 18 | **Lille** is a 130-million-parameter language model built from the ground up as a core component of a completely open-source deep learning stack. The name Lille reflects both its compact size and strong capabilities - capturing the idea that less can be more. It draws on the Norwegian word lille (‘small’ or ‘little’) as well as the French city Lille, giving it both meaning and place. It was trained using a custom tokenizer, a curated dataset, and a memory-efficient optimizer, all of which are publicly available. 19 | 20 | The model comes in two versions: 21 | * **`Lille-130M-Base`**: The foundational model pretrained on 4.27 billion tokens from the [FineWeb-Edu](https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu) dataset. A post-processing step was added to keep only the highest-quality content. It has strong general knowledge and text completion abilities. 22 | * **`Lille-130M-Instruct`**: The instruction-tuned version, fine-tuned on the **[Kyoto-Corpus](https://huggingface.co/datasets/Nikity/Kyoto-Corpus)**. It excels at following user commands, engaging in chat, and performing a variety of instruction-based tasks. 23 | 24 | The model architecture is a modern Transformer decoder featuring Grouped-Query Attention (GQA), RoPE, and RMSNorm, making it efficient and performant for its size. 25 | 26 | *Note on parameter count: While the model name is `130M` for simplicity, the actual parameter count is 127.17 million.* 27 | 28 | ## 📊 Evaluation 29 | 30 | All evaluations were conducted using **[simple-eval](https://github.com/Nikityyy/simple-eval)**, our open-source evaluation framework. Benchmarks are run in a zero-shot setting unless specified otherwise. 31 | 32 | #### `Lille-130M-Instruct` 33 | 34 | ![Evaluations](assets/evaluations.png) 35 | 36 | > Evaluations for other LLMs are sourced from the Open LLM Leaderboard or their respective model cards when benchmark data is unavailable. For Lille 130M Instruct, evaluations are performed using simple-eval. ARC-C and ARC-E for Smollm2 are also evaluated using simple-eval. 37 | 38 | ## 🚀 How to Use 39 | 40 | There are several ways to use the Lille models, from easy-to-use graphical interfaces to advanced programmatic control. 41 | 42 | ### 1. LM Studio (Easiest for Chat) 43 | 44 | LM Studio provides a simple graphical interface to run LLMs on your local machine. It's the easiest way to start chatting with Lille. 45 | 46 | 1. **Download & Install:** Get [LM Studio](https://lmstudio.ai/) for your operating system (Windows, Mac, or Linux). 47 | 2. **Search for the Model:** Open LM Studio and click the **magnifying glass** icon on the left. 48 | 3. **Find Lille:** In the search bar, type `Lille` or `Nikity`. You will find the models I have uploaded. 49 | 4. **Download a GGUF:** On the right-hand side, you'll see a list of GGUF files. 
Download a recommended version like `lille-130m-instruct-f16.gguf`. 50 | 5. **Chat:** Click the **speech bubble** icon on the left. At the top, select the model you just downloaded. Now you can start a conversation! 51 | 52 | ### 2. SimpleAI SDK (Recommended for Programmatic Use) 53 | 54 | The easiest way to use Lille programmatically is with the `simpleai-sdk`, which handles all the boilerplate for you and provides a simple, high-level API for both Hugging Face and ONNX backends. 55 | 56 | ```bash 57 | pip install simpleai-sdk 58 | ``` 59 | 60 | ```python 61 | from simple_ai import lille 62 | 63 | # This will download and cache the model on first run. 64 | # Specify the model version: "130m-instruct" (default) or "130m-base" 65 | # Specify the backend: "huggingface" (default) or "onnx" 66 | model = lille("huggingface", "130m-instruct") 67 | 68 | # --- For Chat (with instruct model) --- 69 | print("--- Chat Example ---") 70 | response1 = model.chat("What is the capital of France?", max_new_tokens=50) 71 | print(f"Bot: {response1}") 72 | 73 | response2 = model.chat("And what is its population?", max_new_tokens=50, top_p=0.90) 74 | print(f"Bot: {response2}") 75 | 76 | # This resets the chat history 77 | model.reset_chat() 78 | 79 | # --- For Text Completion (with base or instruct model) --- 80 | prompt = "Artificial Intelligence is" 81 | response = model.generate(prompt, max_new_tokens=50, temperature=0.9) 82 | print(f"\n--- Completion Example ---\n{prompt}{response}") 83 | ``` 84 | 85 | ### 3. Standard Hugging Face Transformers (this also needs `simpleai-sdk` currently) 86 | 87 | You can also use the model directly with the `transformers` library for more advanced use cases. 88 | 89 | ```bash 90 | pip install transformers torch simpleai-sdk 91 | ``` 92 | 93 | ```python 94 | import torch 95 | from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM 96 | from simple_ai.model_hf import LilleConfig, LilleForCausalLM 97 | 98 | # 1. Register the custom model architecture with Hugging Face 99 | AutoConfig.register("lille-130m", LilleConfig) 100 | AutoModelForCausalLM.register(LilleConfig, LilleForCausalLM) 101 | 102 | # 2. Define constants and setup device 103 | MODEL = "Nikity/lille-130m-instruct" 104 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu" 105 | 106 | # 3. Load tokenizer and model 107 | tokenizer = AutoTokenizer.from_pretrained(MODEL) 108 | model = AutoModelForCausalLM.from_pretrained( 109 | MODEL, 110 | torch_dtype="auto", 111 | device_map=DEVICE, 112 | ) 113 | 114 | # 4. Prepare chat prompt and tokenize it 115 | chat = [{"role": "user", "content": "What is the capital of France?"}] 116 | inputs = tokenizer.apply_chat_template( 117 | chat, 118 | add_generation_prompt=True, 119 | return_tensors="pt" 120 | ).to(DEVICE) 121 | 122 | # 5. Generate a response 123 | with torch.inference_mode(): 124 | outputs = model.generate( 125 | input_ids=inputs, 126 | max_new_tokens=512, 127 | eos_token_id=tokenizer.eos_token_id, 128 | pad_token_id=tokenizer.pad_token_id, 129 | do_sample=True, 130 | temperature=0.5, 131 | top_p=0.95, 132 | ) 133 | 134 | # 6. Decode and print the response 135 | response = tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True) 136 | print(response) 137 | ``` 138 | 139 | ## 🚀 Training and Finetuning 140 | 141 | You can replicate the pretraining of `Lille-130M-Base` or fine-tune it on your own dataset using the provided scripts. 142 | 143 | #### 1. 
Setup 144 | 145 | First, clone the repository and install the required dependencies: 146 | 147 | ```bash 148 | git clone https://github.com/Nikityyy/lille 149 | cd lille 150 | pip install -r requirements.txt 151 | ``` 152 | 153 | **Note on the Optimizer:** The default `Sophia-Triton` optimizer requires the [Triton](https://triton-lang.org/main/getting-started/installation.html) library. Triton is officially supported on Linux with NVIDIA GPUs. While experimental installation on Windows is possible, it can be a complex and difficult process. For a much simpler setup on **Windows and macOS**, or if you prefer not to install Triton, it is highly recommended to use a pure PyTorch implementation of Sophia instead: 154 | 155 | 1. Replace the contents of the `sophia_triton.py` file with the code from [this link](https://github.com/Nikityyy/Sophia-Triton/blob/main/sophia.py). 156 | 2. The `train.py` script should work without any import changes, as the class name `SophiaG` is the same. 157 | 158 | #### 2. Data Preparation 159 | 160 | The training script expects data in a specific `.npz` format containing tokenized documents and their offsets. 161 | 162 | **For Pretraining (like FineWeb-Edu):** 163 | 164 | Use the `prepare_dataset_fineweb.py` script. It will stream the dataset from Hugging Face, apply filters, tokenize the text, and save it in the required format. 165 | 166 | ```bash 167 | python prepare_dataset_fineweb.py 168 | ``` 169 | This will create `data/fineweb_edu_sample_10BT/train.npz` and `val.npz`. 170 | 171 | **For Finetuning (Instruction Datasets):** 172 | 173 | Use the `prepare_dataset.py` script. Your input data should be a single `.txt` file where each example is separated by the `<|endoftext|>` token. 174 | 175 | 1. Place your data file, for example, at `data/my_dataset/train.txt`. 176 | 2. Modify the `input_file_path` and `output_dir` variables in `prepare_dataset.py`. 177 | 3. Run the script: 178 | 179 | ```bash 180 | python prepare_dataset.py 181 | ``` 182 | This will create `train.npz` and `val.npz` in your specified output directory. 183 | 184 | #### 3. Running the Training Script 185 | 186 | All training logic is handled by `train.py`. You can configure hyperparameters directly at the top of this file. 187 | 188 | **To Pretrain from Scratch:** 189 | 190 | 1. Ensure you have prepared a pretraining dataset. 191 | 2. In `train.py`, set `finetune = False`. 192 | 3. Configure pretraining parameters like `data_dir`, `batch_size`, etc. 193 | 4. Run the script: 194 | 195 | ```bash 196 | python train.py 197 | ``` 198 | 199 | **To Fine-tune a Pretrained Model:** 200 | 201 | 1. Ensure you have prepared a fine-tuning dataset. 202 | 2. In `train.py`, set `finetune = True`. 203 | 3. Set `resume_checkpoint` to the path of the pretrained model checkpoint (e.g., `checkpoints/best_model.pt`). 204 | 4. Configure fine-tuning parameters like `finetune_data_dir` and `finetune_learning_rate`. 205 | 5. Run the script: 206 | 207 | ```bash 208 | python train.py 209 | ``` 210 | 211 | Checkpoints will be saved in the directory specified by `out_dir` (for pretraining) or `finetune_out_dir` (for fine-tuning). The best model based on validation loss will be saved as `best_model.pt`. 212 | 213 | ## 🛠️ Training Details 214 | 215 | ### Pretraining (`Lille-130M-Base`) 216 | * **Dataset:** Pretrained on **4.27 billion tokens** from the `sample-10BT` configuration of the [HuggingFaceFW/fineweb-edu](https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu) dataset. 
217 | * **Tokenizer:** The custom **[Hastings](https://github.com/Nikityyy/Hastings)** tokenizer with a 32,768-token vocabulary. 218 | * **Optimizer:** The memory-efficient **[Sophia-Triton](https://github.com/Nikityyy/Sophia-Triton)** optimizer. 219 | * **Hardware:** Trained on a single NVIDIA RTX 4070 Ti. 220 | * **Precision:** bfloat16. 221 | 222 | ### Instruction Tuning (`Lille-130M-Instruct`) 223 | * **Dataset:** Supervised Fine-Tuning (SFT) was performed on the **[Kyoto-Corpus](https://github.com/Nikityyy/Kyoto-Corpus)**, a high-quality, curated collection of conversational and instructional data. 224 | 225 | ### Model Architecture 226 | * **Type:** Transformer Decoder 227 | * **Layers:** 24 228 | * **Embedding Size:** 640 229 | * **Attention Heads:** 10 230 | * **KV Heads (GQA):** 2 231 | * **Context Length:** 512 tokens 232 | 233 | ## Limitations 234 | 235 | Lille models primarily understand and generate content in English. While powerful for their size, they can produce text that may not always be factually accurate, logically consistent, or free from biases present in the training data. These models should be used as assistive tools rather than definitive sources of information. Users should always verify important information and critically evaluate any generated content. 236 | 237 | ## 🛠️ The truly open-source repos 238 | 239 | Lille is a key component of my initiative to build and release a complete, truly open-source stack for language modeling. All components are designed to work together seamlessly. 240 | 241 | * **Tokenizer:** **[Hastings](https://github.com/Nikityyy/Hastings)** - A modern, efficient tokenizer with a 32k vocabulary. 242 | * **Dataset:** **[Kyoto-Corpus](https://github.com/Nikityyy/Kyoto-Corpus)** - A high-quality, small-scale dataset for instruction tuning. 243 | * **Model:** **[lille](https://github.com/Nikityyy/lille)** (this repository) - A powerful 130-million-parameter model trained from scratch. 244 | * **Optimizer:** **[Sophia-Triton](https://github.com/Nikityyy/Sophia-Triton)** - A memory-efficient, Triton-based implementation of the SophiaG optimizer. 245 | * **Evaluations:** **[simple-eval](https://github.com/Nikityyy/simple-eval)** - A straightforward framework for evaluating model performance using an LLM as a Judge. 246 | 247 | ## 🙏 Credits 248 | 249 | Lille’s training scripts and architecture were inspired by and built upon the work of: 250 | 251 | * **nanoGPT** – A minimal and efficient PyTorch implementation of GPT training: [https://github.com/karpathy/nanoGPT](https://github.com/karpathy/nanoGPT) 252 | * **gpt-oss** – The open models from OpenAI: [https://github.com/openai/gpt-oss](https://github.com/openai/gpt-oss) 253 | 254 | ## 📜 License 255 | 256 | This project is licensed under the Apache-2.0 License. 
257 | 258 | ## Citation 259 | 260 | If you use Lille or any part of this open-source stack in your work, please consider citing it: 261 | 262 | ```bibtex 263 | @misc{lille-130m, 264 | author = {Nikita Berger}, 265 | title = {Lille: A Truly Open-Source 130M Language Model}, 266 | year = {2025}, 267 | publisher = {GitHub}, 268 | journal = {GitHub repository}, 269 | howpublished = {\url{https://github.com/Nikityyy/lille}} 270 | } 271 | ``` 272 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import math 4 | import pickle 5 | import threading 6 | import collections 7 | import queue 8 | from contextlib import nullcontext 9 | 10 | import numpy as np 11 | import torch 12 | import torch.nn.functional as F 13 | import torch.distributed as dist 14 | from torch.nn.parallel import DistributedDataParallel as DDP 15 | from tqdm import tqdm 16 | import tiktoken 17 | import wandb 18 | from html2term import printc 19 | 20 | from sophia_triton import SophiaG 21 | from model import GPT, GPTConfig 22 | 23 | # --- General Settings --- 24 | out_dir = 'checkpoints' 25 | eval_interval = 500 26 | log_interval = 1 27 | eval_iters = 100 28 | resume_checkpoint = None 29 | # resume_checkpoint = "checkpoints/best_model.pt" 30 | 31 | # --- Finetuning Settings --- 32 | finetune = False 33 | finetune_out_dir = 'checkpoints_ft' 34 | finetune_data_dir = 'data/smol-sft' 35 | finetune_learning_rate = 1e-5 36 | finetune_num_epochs = 3 37 | 38 | # --- W&B Logging --- 39 | wandb_log = True 40 | wandb_project = 'modern-gpt-pretrain' 41 | wandb_run_name = f'run-modern-gpt-{time.strftime("%Y-%m-%d-%H-%M-%S")}' 42 | 43 | # --- Data Settings --- 44 | data_dir = 'data/fineweb_edu_sample_10BT' 45 | pretrain_data_dir = data_dir 46 | batch_size = 16 47 | block_size = 512 48 | num_epochs = 1 49 | gradient_accumulation_steps = 2 50 | 51 | # --- Model Architecture --- 52 | n_layers = 24 53 | n_embd = 640 54 | n_heads = 10 55 | n_kv_heads = 2 56 | dropout = 0.1 57 | layer_norm_eps = 1e-5 58 | 59 | # --- Optimizer & LR Schedule --- 60 | learning_rate = 1e-4 61 | weight_decay = 0.2 62 | beta1 = 0.9 63 | beta2 = 0.95 64 | grad_clip = 1.0 65 | decay_lr = True 66 | warmup_iters = 2000 67 | min_lr = learning_rate / 10 68 | hess_interval = 10 69 | 70 | # --- Hardware & Performance --- 71 | device = 'cuda' 72 | dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' 73 | compile = True 74 | 75 | class NpzDataset: 76 | """A simple lazy-loading dataset for the tokens/offsets .npz format.""" 77 | def __init__(self, file_path): 78 | self.data = np.load(file_path, mmap_mode='r') 79 | self.tokens = self.data['tokens'] 80 | self.offsets = self.data['offsets'] 81 | self.num_docs = len(self.offsets) - 1 82 | 83 | def __len__(self): 84 | return self.num_docs 85 | 86 | def __getitem__(self, idx): 87 | start = self.offsets[idx] 88 | end = self.offsets[idx + 1] 89 | return self.tokens[start:end].tolist() 90 | 91 | class DataPrefetcher: 92 | """ An asynchronous data prefetcher that prepares batches on the CPU. 
""" 93 | def __init__(self, data, block_size, batch_size, max_prefetch=2): 94 | self.data = data 95 | self.block_size = block_size 96 | self.batch_size = batch_size 97 | self.queue = queue.Queue(maxsize=max_prefetch) 98 | self.is_running = True 99 | self.thread = threading.Thread(target=self.run, daemon=True) 100 | self.thread.start() 101 | 102 | def _preload(self): 103 | return get_batch('train') 104 | 105 | def run(self): 106 | while self.is_running: 107 | try: 108 | self.queue.put(self._preload(), timeout=1) 109 | except queue.Full: 110 | continue 111 | 112 | def next(self): 113 | return self.queue.get() 114 | 115 | def close(self): 116 | self.is_running = False 117 | while not self.queue.empty(): 118 | try: 119 | self.queue.get_nowait() 120 | except queue.Empty: 121 | break 122 | self.thread.join() 123 | 124 | def get_batch(split, pretrain=False): 125 | """ 126 | Get a batch of data. Handles padding for sequences shorter than block_size. 127 | For supervised fine-tuning, it masks out the loss for prompt tokens. 128 | """ 129 | if pretrain: 130 | data = train_data_pretrain if split == 'train' else val_data_pretrain 131 | else: 132 | data = train_data if split == 'train' else val_data 133 | 134 | ix = torch.randint(len(data), (batch_size,)) 135 | batch_raw = [data[i] for i in ix] 136 | 137 | x_padded = torch.full((batch_size, block_size), pad_token_id, dtype=torch.long) 138 | y_padded = torch.full((batch_size, block_size), -100, dtype=torch.long) 139 | 140 | is_finetune_split = finetune and not pretrain 141 | 142 | for i, tokens in enumerate(batch_raw): 143 | seq_len = min(len(tokens), block_size) 144 | x_padded[i, :seq_len] = torch.tensor(tokens[:seq_len], dtype=torch.long) 145 | 146 | targets = torch.tensor(tokens[1:seq_len], dtype=torch.long) 147 | 148 | if is_finetune_split and assistant_token_id is not None: 149 | x_seq = x_padded[i, :seq_len] 150 | assistant_indices = (x_seq == assistant_token_id).nonzero(as_tuple=True)[0] 151 | 152 | if len(assistant_indices) > 0: 153 | last_assistant_idx = assistant_indices[-1] 154 | targets[:last_assistant_idx] = -100 155 | 156 | y_padded[i, :seq_len-1] = targets 157 | 158 | return x_padded, y_padded 159 | 160 | @torch.no_grad() 161 | def estimate_loss(pretrain=False): 162 | """ 163 | Estimate loss on train and validation splits. 164 | """ 165 | out = {} 166 | model.eval() 167 | for split in ['train', 'val']: 168 | losses = torch.zeros(eval_iters, device=device) 169 | for k in range(eval_iters): 170 | X_cpu, Y_cpu = get_batch(split, pretrain=pretrain) 171 | X = X_cpu.to(device, non_blocking=True) 172 | Y = Y_cpu.to(device, non_blocking=True) 173 | attn_mask = (X != pad_token_id) 174 | with ctx: 175 | logits, _ = model(X, attn_mask=attn_mask) 176 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), Y.view(-1)) 177 | losses[k] = loss 178 | if ddp: 179 | dist.all_reduce(losses, op=dist.ReduceOp.SUM) 180 | losses /= ddp_world_size 181 | out[split] = losses.mean() 182 | model.train() 183 | return out 184 | 185 | def configure_optimizers(model, weight_decay, learning_rate, betas): 186 | """ 187 | Configure optimizer with weight decay for 2D parameters. 
188 | """ 189 | param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad} 190 | decay_params = [p for n, p in param_dict.items() if p.dim() >= 2] 191 | nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] 192 | optim_groups = [ 193 | {'params': decay_params, 'weight_decay': weight_decay}, 194 | {'params': nodecay_params, 'weight_decay': 0.0} 195 | ] 196 | num_decay_params = sum(p.numel() for p in decay_params) 197 | num_nodecay_params = sum(p.numel() for p in nodecay_params) 198 | if master_process: 199 | printc(f" <#cccccc>Num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters") 200 | printc(f" <#cccccc>Num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters
") 201 | # If you get the following error: "got an unexpected keyword argument 'bs'", then remove bs=tokens_per_optimizer_step 202 | optimizer = SophiaG(optim_groups, lr=learning_rate, betas=betas, rho=0.05, weight_decay=weight_decay, bs=tokens_per_optimizer_step) 203 | return optimizer 204 | 205 | def get_cosine_schedule_with_warmup_scheduler(optimizer, num_warmup_steps, num_training_steps, min_lr_ratio_val=0.1): 206 | """ 207 | Create a learning rate scheduler with a cosine decay and linear warmup. 208 | """ 209 | def lr_lambda_func(current_step): 210 | if current_step < num_warmup_steps: 211 | return float(current_step) / float(max(1, num_warmup_steps)) 212 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) 213 | cosine_decay = 0.5 * (1 + math.cos(math.pi * progress)) 214 | return min_lr_ratio_val + (1.0 - min_lr_ratio_val) * cosine_decay 215 | return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda_func) 216 | 217 | def save_checkpoint_async(checkpoint, path, force=False, old_path_to_delete=None): 218 | """ 219 | Save a checkpoint asynchronously in a separate thread. 220 | """ 221 | temp_path = path + ".tmp" 222 | try: 223 | torch.save(checkpoint, temp_path) 224 | os.replace(temp_path, path) 225 | if old_path_to_delete and os.path.exists(old_path_to_delete): 226 | os.remove(old_path_to_delete) 227 | except Exception as e: 228 | printc(f"Error saving checkpoint to {path}: {e}") 229 | if os.path.exists(temp_path): 230 | os.remove(temp_path) 231 | 232 | if finetune: 233 | out_dir = finetune_out_dir 234 | learning_rate = finetune_learning_rate 235 | num_epochs = finetune_num_epochs 236 | data_dir = finetune_data_dir 237 | warmup_iters = 1000 238 | weight_decay = 0.01 239 | dropout = 0.0 240 | wandb_project = 'modern-gpt-finetune' 241 | if not resume_checkpoint: 242 | raise ValueError("For finetuning, a `resume_checkpoint` must be provided.") 243 | printc("" + "="*50 + "") 244 | printc("|| FINETUNING MODE ENABLED") 245 | printc(f"|| Output directory: {out_dir}") 246 | printc(f"|| Data directory: {data_dir}") 247 | printc(f"|| Learning rate: {learning_rate}") 248 | printc(f"|| Epochs: {num_epochs}") 249 | printc("" + "="*50 + "
") 250 | else: 251 | printc("" + "="*50 + "") 252 | printc("|| PRETRAINING MODE ENABLED") 253 | printc(f"|| Output directory: {out_dir}") 254 | printc(f"|| Data directory: {data_dir}") 255 | printc(f"|| Learning rate: {learning_rate}") 256 | printc(f"|| Epochs: {num_epochs}") 257 | printc("" + "="*50 + "
") 258 | 259 | ddp = int(os.environ.get('RANK', -1)) != -1 260 | if ddp: 261 | dist.init_process_group(backend='nccl') 262 | ddp_rank = int(os.environ['RANK']) 263 | ddp_local_rank = int(os.environ['LOCAL_RANK']) 264 | ddp_world_size = int(os.environ['WORLD_SIZE']) 265 | device = f'cuda:{ddp_local_rank}' 266 | torch.cuda.set_device(device) 267 | master_process = ddp_rank == 0 268 | seed_offset = ddp_rank 269 | else: 270 | master_process = True 271 | seed_offset = 0 272 | ddp_world_size = 1 273 | 274 | if master_process: 275 | os.makedirs(out_dir, exist_ok=True) 276 | 277 | torch.manual_seed(1337 + seed_offset) 278 | torch.backends.cuda.matmul.allow_tf32 = True 279 | torch.backends.cudnn.allow_tf32 = True 280 | device_type = 'cuda' if 'cuda' in device else 'cpu' 281 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] 282 | ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type, dtype=ptdtype) 283 | 284 | if wandb_log and master_process: 285 | config_dict = {k: v for k, v in locals().items() if isinstance(v, (int, float, str, bool))} 286 | wandb.init(project=wandb_project, name=wandb_run_name, config=config_dict) 287 | 288 | with open('tokenizer/Hastings.pkl', 'rb') as f: 289 | hastings = pickle.load(f) 290 | enc = tiktoken.core.Encoding(hastings.pop('name'), **hastings) 291 | vocab_size = enc.n_vocab 292 | assistant_token_id = enc.encode_single_token("<|assistant|>") 293 | 294 | try: 295 | pad_token_id = enc.encode_single_token("<|pad|>") 296 | except KeyError: 297 | printc("Warning: '<|pad|>' token not found. Using '<|endoftext|>' as a pad token.") 298 | pad_token_id = enc.encode_single_token("<|endoftext|>") 299 | 300 | printc("
Loading dataset using NpzDataset...") 301 | train_data = NpzDataset(os.path.join(data_dir, 'train.npz')) 302 | val_data = NpzDataset(os.path.join(data_dir, 'val.npz')) 303 | if data_dir != pretrain_data_dir: 304 | train_data_pretrain = NpzDataset(os.path.join(pretrain_data_dir, 'train.npz')) 305 | val_data_pretrain = NpzDataset(os.path.join(pretrain_data_dir, 'val.npz')) 306 | else: 307 | train_data_pretrain, val_data_pretrain = train_data, val_data 308 | 309 | train_tokens = len(train_data.tokens) 310 | tokens_per_optimizer_step = batch_size * block_size * ddp_world_size * gradient_accumulation_steps 311 | max_optimizer_steps = (train_tokens // tokens_per_optimizer_step) * num_epochs 312 | iters_per_epoch_optimizer_steps = train_tokens // tokens_per_optimizer_step 313 | lr_decay_iters = max_optimizer_steps 314 | 315 | model_args = dict( 316 | n_layers=n_layers, n_embd=n_embd, vocab_size=vocab_size, block_size=block_size, 317 | dropout=dropout, n_heads=n_heads, n_kv_heads=n_kv_heads, layer_norm_eps=layer_norm_eps 318 | ) 319 | config = GPTConfig(**model_args) 320 | model = GPT(config) 321 | model.to(device) 322 | num_params = sum(p.numel() for p in model.parameters()) 323 | if master_process: 324 | printc(f"Model has {num_params / 1e6:.2f}M parameters.") 325 | 326 | scaler = torch.amp.GradScaler(enabled=(dtype == 'float16')) 327 | optimizer = configure_optimizers(model, weight_decay, learning_rate, (beta1, beta2)) 328 | 329 | min_lr_ratio_for_scheduler = min_lr / learning_rate 330 | scheduler = get_cosine_schedule_with_warmup_scheduler( 331 | optimizer, num_warmup_steps=warmup_iters, 332 | num_training_steps=max_optimizer_steps, min_lr_ratio_val=min_lr_ratio_for_scheduler 333 | ) 334 | 335 | iter_num = 0 336 | best_val_loss = 1e9 337 | if resume_checkpoint and os.path.exists(resume_checkpoint): 338 | if master_process: printc(f"Loading checkpoint: {resume_checkpoint}") 339 | checkpoint = torch.load(resume_checkpoint, map_location=device) 340 | ckpt_model_args = checkpoint['model_args'] 341 | for k, v in model_args.items(): 342 | if k not in ckpt_model_args or ckpt_model_args[k] != v: 343 | if master_process: printc(f" Warning: Mismatch in model config: '{k}'") 344 | state_dict = checkpoint['model_state_dict'] 345 | unwanted_prefix = '_orig_mod.' 346 | for k,v in list(state_dict.items()): 347 | if k.startswith(unwanted_prefix): 348 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) 349 | model.load_state_dict(state_dict, strict=False) 350 | 351 | if not finetune or (finetune and 'optimizer_state_dict' in checkpoint and resume_checkpoint.startswith(finetune_out_dir)): 352 | if master_process: printc("Resuming training with optimizer state.") 353 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 354 | for group in optimizer.param_groups: 355 | group.setdefault('bs', tokens_per_optimizer_step) 356 | group.setdefault('eps', 1e-15) 357 | iter_num = checkpoint['iter_num'] 358 | best_val_loss = checkpoint['best_val_loss'] 359 | scheduler.last_epoch = iter_num 360 | 361 | if compile: 362 | if master_process: printc("Compiling the model...") 363 | model = torch.compile(model, backend="inductor", mode="max-autotune") 364 | if master_process: 365 | printc(" Warming up the compiled model...") 366 | with ctx: 367 | with torch.no_grad(): 368 | x_warm_cpu, _ = get_batch('train') 369 | x_warm = x_warm_cpu.to(device, non_blocking=True) 370 | attn_mask_warm = (x_warm != pad_token_id) 371 | _, _ = model(x_warm, attn_mask=attn_mask_warm) 372 | printc(" Warm-up complete.
") 373 | 374 | if ddp: 375 | model = DDP(model, device_ids=[ddp_local_rank], find_unused_parameters=False) 376 | 377 | if master_process: printc("Setting up asynchronous data prefetcher for training...") 378 | train_prefetcher = DataPrefetcher(train_data, block_size, batch_size) 379 | 380 | t0 = time.time() 381 | checkpoint_threads = collections.deque() 382 | last_interval_checkpoint_path = None 383 | pbar = tqdm(range(iter_num, max_optimizer_steps), disable=not master_process) 384 | 385 | for optimizer_step in pbar: 386 | if optimizer_step > iter_num and optimizer_step % eval_interval == 0: 387 | losses = estimate_loss() 388 | if finetune: 389 | losses_pt = estimate_loss(pretrain=True) 390 | current_epoch = optimizer_step / iters_per_epoch_optimizer_steps 391 | if master_process: 392 | printc(f"
Epoch {current_epoch:.2f} | Step {optimizer_step}") 393 | if finetune: 394 | printc(f" Finetune Loss -> Train: {losses['train']:.4f}, Val: {losses['val']:.4f}") 395 | printc(f" <#cccccc>Pretrain Loss -> Train: {losses_pt['train']:.4f}, Val: {losses_pt['val']:.4f}") 396 | else: 397 | printc(f" Pretrain Loss -> Train: {losses['train']:.4f}, Val: {losses['val']:.4f}") 398 | if wandb_log: 399 | log_data = {'eval/train_loss': losses['train'], 'eval/val_loss': losses['val'], 'trainer/epoch': current_epoch} 400 | if finetune: 401 | log_data.update({'eval/pretrain_train_loss': losses_pt['train'], 'eval/pretrain_val_loss': losses_pt['val']}) 402 | wandb.log(log_data, step=optimizer_step) 403 | 404 | while checkpoint_threads and not checkpoint_threads[0].is_alive(): 405 | checkpoint_threads.popleft().join() 406 | raw_model = model.module if ddp else model 407 | checkpoint = { 408 | 'model_state_dict': raw_model.state_dict(), 409 | 'optimizer_state_dict': optimizer.state_dict(), 410 | 'model_args': raw_model.config.to_dict(), 411 | 'iter_num': optimizer_step, 412 | 'best_val_loss': best_val_loss 413 | } 414 | checkpoint_path = os.path.join(out_dir, f'ckpt_iter_{optimizer_step}.pt') 415 | thread = threading.Thread(target=save_checkpoint_async, args=(checkpoint.copy(), checkpoint_path, False, last_interval_checkpoint_path)) 416 | thread.start() 417 | checkpoint_threads.append(thread) 418 | last_interval_checkpoint_path = checkpoint_path 419 | 420 | if losses['val'] < best_val_loss: 421 | best_val_loss = losses['val'] 422 | checkpoint['best_val_loss'] = best_val_loss 423 | best_model_path = os.path.join(out_dir, 'best_model.pt') 424 | thread = threading.Thread(target=save_checkpoint_async, args=(checkpoint.copy(), best_model_path, True)) 425 | thread.start() 426 | checkpoint_threads.append(thread) 427 | printc(f" Started saving new best model to {best_model_path}") 428 | 429 | optimizer.zero_grad(set_to_none=True) 430 | 431 | if compile and device_type == 'cuda': 432 | torch.compiler.cudagraph_mark_step_begin() 433 | 434 | for micro_step in range(gradient_accumulation_steps): 435 | X_cpu, Y_cpu = train_prefetcher.next() 436 | X = X_cpu.pin_memory().to(device, non_blocking=True) 437 | Y = Y_cpu.pin_memory().to(device, non_blocking=True) 438 | attn_mask = (X != pad_token_id) 439 | with ctx: 440 | logits, _ = model(X, attn_mask=attn_mask) 441 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), Y.view(-1)) / gradient_accumulation_steps 442 | scaler.scale(loss).backward() 443 | 444 | scaler.unscale_(optimizer) 445 | if grad_clip > 0.0: 446 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip) 447 | 448 | if optimizer_step == 0: 449 | optimizer.update_hessian() 450 | 451 | scaler.step(optimizer) 452 | scaler.update() 453 | 454 | if optimizer_step % hess_interval == 0: 455 | with ctx: 456 | logits, _ = model(X, attn_mask=attn_mask) 457 | probs = F.softmax(logits, dim=-1) 458 | y_sample = torch.multinomial(probs.view(-1, logits.size(-1)), num_samples=1).view_as(Y) 459 | loss_sampled = F.cross_entropy(logits.view(-1, logits.size(-1)), y_sample.view(-1)) 460 | scaler.scale(loss_sampled).backward() 461 | scaler.unscale_(optimizer) 462 | if grad_clip > 0.0: 463 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip) 464 | optimizer.schedule_hessian_update() 465 | optimizer.zero_grad(set_to_none=True) 466 | 467 | if decay_lr: 468 | scheduler.step() 469 | 470 | t1 = time.time() 471 | dt = t1 - t0 472 | t0 = t1 473 | if optimizer_step % log_interval == 0 and master_process: 
474 | lossf = loss.item() * gradient_accumulation_steps 475 | current_lr = optimizer.param_groups[0]['lr'] 476 | pbar.set_description(f"step {optimizer_step + 1}: loss {lossf:.4f}, time {dt*1000:.2f}ms, lr {current_lr:e}") 477 | if wandb_log: 478 | wandb.log({'train/loss': lossf, 'trainer/lr': current_lr, 'trainer/dt_ms': dt * 1000}, step=optimizer_step) 479 | 480 | pbar.close() 481 | if master_process: printc("
Training loop finished. Closing data prefetcher...") 482 | train_prefetcher.close() 483 | 484 | if master_process: 485 | printc("Saving final model and waiting for all saves to complete...") 486 | raw_model = model.module if ddp else model 487 | final_checkpoint = { 488 | 'model_state_dict': raw_model.state_dict(), 489 | 'optimizer_state_dict': optimizer.state_dict(), 490 | 'model_args': raw_model.config.to_dict(), 491 | 'iter_num': max_optimizer_steps, 492 | 'best_val_loss': best_val_loss 493 | } 494 | final_checkpoint_path = os.path.join(out_dir, 'ckpt_final.pt') 495 | thread = threading.Thread(target=save_checkpoint_async, args=(final_checkpoint, final_checkpoint_path, True)) 496 | thread.start() 497 | checkpoint_threads.append(thread) 498 | 499 | while checkpoint_threads: 500 | printc(f" <#cccccc>Waiting for {len(checkpoint_threads)} remaining checkpoint(s) to save...") 501 | checkpoint_threads.popleft().join() 502 | 503 | if wandb_log: 504 | wandb.finish() 505 | 506 | if ddp: 507 | dist.destroy_process_group() 508 | 509 | printc("
✅ Training complete and all checkpoints saved.") 510 | --------------------------------------------------------------------------------
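For readers reconstructing the data pipeline, the `.npz` layout consumed by `NpzDataset` in `train.py` (and produced by the `prepare_dataset*.py` scripts) is just two arrays: a flat `tokens` stream and an `offsets` array with one extra entry, where `offsets[i]:offsets[i+1]` delimits document `i`. The sketch below is illustrative only; the file name, token ids, and dtypes are made up and are not necessarily what the preparation scripts emit.

```python
import numpy as np

# Three already-tokenized documents (token ids invented for illustration).
docs = [[5, 42, 7, 9], [13, 13, 2], [99, 1, 4, 4, 8]]

# Flatten into one contiguous token stream and record where each document starts.
tokens = np.concatenate([np.asarray(d, dtype=np.uint16) for d in docs])
offsets = np.zeros(len(docs) + 1, dtype=np.int64)
offsets[1:] = np.cumsum([len(d) for d in docs])

np.savez("toy_train.npz", tokens=tokens, offsets=offsets)

# Read a single document back, mirroring NpzDataset.__getitem__ in train.py.
data = np.load("toy_train.npz")
i = 1
start, end = data["offsets"][i], data["offsets"][i + 1]
print(data["tokens"][start:end].tolist())  # -> [13, 13, 2]
```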