187 |
283 |
284 |
285 |
286 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | warnings.filterwarnings("ignore")
4 | import os
5 | import pickle
6 | import argparse
7 | import time
8 | from typing import Optional
9 |
10 | import cProfile
11 | import pstats
12 | import io
13 |
14 | import torch
15 | import torch._inductor.config
16 | import torch._functorch.config
17 | import torch.nn.functional as F
18 | import tiktoken
19 | import numpy as np
20 | import onnxruntime
21 | from html2term import printc
22 |
23 | torch._inductor.config.coordinate_descent_tuning = True
24 | torch._inductor.config.triton.unique_kernel_names = True
25 | torch._inductor.config.fx_graph_cache = True
26 | torch._functorch.config.enable_autograd_cache = True
27 |
28 | from model import GPTConfig
29 |
30 | from export_utils import create_onnx_model_for_inference
31 |
32 | def _apply_sampling(
33 | logits: torch.Tensor, temp: float, top_p: Optional[float], top_k: Optional[int]
34 | ) -> int:
35 | """Apply temperature, top-p, and top-k sampling to logits, expecting a GPU tensor."""
36 | if temp == 0.0:
37 | return torch.argmax(logits, dim=-1).item()
38 |
39 | logits.div_(temp)
40 |
41 | if top_p is not None and 0.0 < top_p < 1.0:
42 | probs = F.softmax(logits, dim=-1)
43 | sorted_probs, sorted_indices = torch.sort(probs, descending=True)
44 | cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
45 |
46 | sorted_indices_to_remove = cumulative_probs > top_p
47 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
48 | sorted_indices_to_remove[..., 0] = 0
49 |
50 | indices_to_remove = torch.zeros_like(logits, dtype=torch.bool).scatter_(
51 | -1, sorted_indices, sorted_indices_to_remove
52 | )
53 | logits[indices_to_remove] = -float("Inf")
54 |
55 | elif top_k is not None and top_k > 0:
56 | v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
57 | logits[logits < v[..., -1, None]] = -float("Inf")
58 |
59 | probs = F.softmax(logits, dim=-1)
60 | return torch.multinomial(probs, num_samples=1).item()
61 |
62 |
63 | def run_chat_loop_io_binding(onnx_model_path, config, tokenizer, device, args):
64 | """Runs a highly optimized interactive chat loop using I/O Binding to keep data on the GPU."""
65 | if not args.profile:
66 | printc(" --- Starting Chat with Fused FP16 ONNX Model (I/O Binding Enabled) ---")
67 | sampling_params = f"temp={args.temperature}"
68 | if args.temperature == 0: sampling_params = "greedy"
69 | elif args.top_p is not None: sampling_params += f", top_p={args.top_p}"
70 | elif args.top_k is not None: sampling_params += f", top_k={args.top_k}"
71 | printc(f"<#cccccc>Params: {sampling_params}, max_new_tokens={args.max_new_tokens}. Type 'exit' or 'quit' to end.#cccccc>")
72 |
73 | options = onnxruntime.SessionOptions()
74 | options.log_severity_level = 3
75 |
76 | base_name = os.path.splitext(os.path.basename(onnx_model_path))[0]
77 | optimized_model_path = os.path.join(os.path.dirname(onnx_model_path), f"{base_name}.ort")
78 | options.optimized_model_filepath = optimized_model_path
79 | options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
80 |
81 | providers = []
82 | if device == "cuda":
83 | provider = "CUDAExecutionProvider"
84 | if provider not in onnxruntime.get_available_providers():
85 | raise RuntimeError(f"{provider} not available, please check your ONNX Runtime and CUDA setup.")
86 | device_id = torch.cuda.current_device()
87 | provider_options = {'device_id': device_id, 'arena_extend_strategy': 'kSameAsRequested'}
88 | providers = [
89 | ('CUDAExecutionProvider', provider_options),
90 | 'CPUExecutionProvider',
91 | ]
92 |         if not args.profile: printc(f"Using ONNX Runtime providers: {providers[0][0]}, {providers[1]}")
93 | device_name = 'cuda'
94 | torch_device = torch.device(f"cuda:{device_id}")
95 | else:
96 | provider = "CPUExecutionProvider"
97 | if not args.profile: printc(f"Using ONNX Runtime provider: {provider}")
98 | providers.append(provider)
99 | device_name = 'cpu'
100 | torch_device = torch.device("cpu")
101 |
102 | session = onnxruntime.InferenceSession(onnx_model_path, sess_options=options, providers=providers)
103 | stop_ids = [tokenizer.encode_single_token(t) for t in ["<|endoftext|>", "<|user|>"]]
104 |
105 | conversation_history_ids = []
106 |
107 | while True:
108 | try:
109 | if args.profile:
110 | prompt = "What is the capital of France and what is its history?"
111 | printc(f" You: {prompt}")
112 | else:
113 | printc(" You: ", end="")
114 | prompt = input()
115 |
116 | if prompt.lower() in ["exit", "quit"]: break
117 |
118 | if not conversation_history_ids:
119 | tokens_to_process = tokenizer.encode(f"<|startoftext|><|user|>{prompt}<|assistant|>", allowed_special="all")
120 | else:
121 | tokens_to_process = tokenizer.encode(f"<|user|>{prompt}<|assistant|>", allowed_special="all")
122 |
123 | conversation_history_ids.extend(tokens_to_process)
124 |
125 | if len(conversation_history_ids) > config.block_size:
126 | printc(" [CONTEXT RESET - Model has forgotten the conversation]")
127 | conversation_history_ids = tokenizer.encode(f"<|startoftext|><|user|>{prompt}<|assistant|>", allowed_special="all")
128 | tokens_to_process = conversation_history_ids
129 |
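            # Prefill: run the full prompt once with an empty KV cache; I/O binding
            # keeps the inputs and outputs on the execution provider's device.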
130 | binding = session.io_binding()
131 |
132 | input_ids_np = np.array([tokens_to_process], dtype=np.int64)
133 | input_ids_ort = onnxruntime.OrtValue.ortvalue_from_numpy(input_ids_np, device_name, 0)
134 | binding.bind_ortvalue_input('input_ids', input_ids_ort)
135 |
136 | dtype = np.float16
137 | empty_past = np.zeros((1, config.n_kv_heads, 0, config.n_embd // config.n_heads), dtype=dtype)
138 | empty_past_ort = onnxruntime.OrtValue.ortvalue_from_numpy(empty_past, device_name, 0)
139 | for i in range(config.n_layers):
140 | binding.bind_ortvalue_input(f'past_key_{i}', empty_past_ort)
141 | binding.bind_ortvalue_input(f'past_value_{i}', empty_past_ort)
142 |
143 | binding.bind_output('logits', device_name)
144 | for i in range(config.n_layers):
145 | binding.bind_output(f'present_key_{i}', device_name)
146 | binding.bind_output(f'present_value_{i}', device_name)
147 |
148 | session.run_with_iobinding(binding)
149 | ort_outs = binding.get_outputs()
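            # Output order mirrors the bindings above: logits first, then the present
            # key/value tensors interleaved per layer (key_0, value_0, key_1, value_1, ...).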
150 | logits_ort, past_key_values = ort_outs[0], ort_outs[1:]
151 |
152 |             logits_torch = torch.tensor(logits_ort.numpy(), device=torch_device)
153 | next_token_id = _apply_sampling(logits_torch[0, -1, :], args.temperature, args.top_p, args.top_k)
154 |
155 | printc("Bot: ", end="", flush=True)
156 | generated_response_ids = []
157 | start_time = time.perf_counter()
158 |
159 | max_tokens = min(args.max_new_tokens, config.block_size - len(conversation_history_ids))
160 | if max_tokens <= 0:
161 | printc(" [CONTEXT FULL - Cannot generate more tokens]")
162 | if args.profile: break
163 | continue
164 |
165 | single_token_input_ort = onnxruntime.OrtValue.ortvalue_from_numpy(
166 | np.array([[next_token_id]], dtype=np.int64), device_name, 0
167 | )
168 |
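            # Decode loop: feed a single token per step and reuse the KV cache
            # returned by the previous run instead of reprocessing the whole context.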
169 | for _ in range(max_tokens):
170 | if next_token_id in stop_ids: break
171 |
172 | generated_response_ids.append(next_token_id)
173 | print(tokenizer.decode([next_token_id]), end="", flush=True)
174 |
175 | binding.bind_ortvalue_input('input_ids', single_token_input_ort)
176 | for j in range(config.n_layers):
177 | binding.bind_ortvalue_input(f'past_key_{j}', past_key_values[j*2])
178 | binding.bind_ortvalue_input(f'past_value_{j}', past_key_values[j*2+1])
179 |
180 | binding.bind_output('logits', device_name)
181 | for j in range(config.n_layers):
182 | binding.bind_output(f'present_key_{j}', device_name)
183 | binding.bind_output(f'present_value_{j}', device_name)
184 |
185 | session.run_with_iobinding(binding)
186 | ort_outs = binding.get_outputs()
187 |
188 | logits_ort, past_key_values = ort_outs[0], ort_outs[1:]
189 |
190 |                 logits_torch = torch.tensor(logits_ort.numpy(), device=torch_device)
191 | next_token_id = _apply_sampling(logits_torch[0, 0, :], args.temperature, args.top_p, args.top_k)
192 |
193 | single_token_input_ort.update_inplace(np.array([[next_token_id]], dtype=np.int64))
194 |
195 |
196 | end_time = time.perf_counter()
197 | conversation_history_ids.extend(generated_response_ids)
198 | printc(" ")
199 |
200 | num_generated = len(generated_response_ids)
201 | time_taken = end_time - start_time
202 | tokens_per_sec = num_generated / time_taken if time_taken > 0 else 0
203 | printc(f"<#cccccc>Generated {num_generated} tokens in {time_taken:.2f}s ({tokens_per_sec:.2f} tokens/s)#cccccc>")
204 |
205 | if args.profile:
206 | break
207 |
208 | except KeyboardInterrupt:
209 | printc(" Exiting chat mode.")
210 | break
211 | except Exception as e:
212 | printc(f" An error occurred: {e}")
213 | import traceback
214 | traceback.print_exc()
215 | break
216 |
217 | def main():
218 | parser = argparse.ArgumentParser(
219 | description="Optimize a GPT model to Fused FP16 ONNX and run a fast chat session."
220 | )
221 | parser.add_argument("--checkpoint_path", type=str, default="checkpoints_ft/best_model.pt", help="Path to the PyTorch model checkpoint.")
222 | parser.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda"], help="Device for model loading and ONNX export.")
223 | parser.add_argument("--max_new_tokens", type=int, default=480, help="Maximum number of new tokens to generate.")
224 | parser.add_argument("--temperature", type=float, default=0.5, help="Sampling temperature. 0 for greedy.")
225 | parser.add_argument("--top_k", type=int, default=None, help="Top-k sampling.")
226 | parser.add_argument("--top_p", type=float, default=0.95, help="Top-p (nucleus) sampling.")
227 | parser.add_argument("--profile", action="store_true", help="Enable cProfile to analyze performance of one generation cycle.")
228 | args = parser.parse_args()
229 |
230 | if args.device == "cuda" and not torch.cuda.is_available():
231 | printc("CUDA is selected but not available. Please check your environment.")
232 | exit(1)
233 |
234 | if args.temperature == 0:
235 | printc("Temperature is 0, using greedy decoding.")
236 |
237 | ONNX_MODEL_DIR = "onnx_models"
238 | os.makedirs(ONNX_MODEL_DIR, exist_ok=True)
239 |
240 | base_name = os.path.splitext(os.path.basename(args.checkpoint_path))[0]
241 | onnx_fp16_fused_path = os.path.join(ONNX_MODEL_DIR, f"{base_name}_fp16_kv_fused.onnx")
242 |
243 | with open("tokenizer/Hastings.pkl", "rb") as f:
244 | hastings = pickle.load(f)
245 | enc = tiktoken.core.Encoding(hastings.pop("name"), **hastings)
246 |
247 | create_onnx_model_for_inference(args.checkpoint_path, onnx_fp16_fused_path, args.device)
248 |
249 | checkpoint = torch.load(args.checkpoint_path, map_location='cpu')
250 | config = GPTConfig(**checkpoint['model_args'])
251 | del checkpoint
252 |
253 | if args.profile:
254 | printc(" --- PROFILING MODE ENABLED --- ")
255 | pr = cProfile.Profile()
256 | pr.enable()
257 | run_chat_loop_io_binding(onnx_fp16_fused_path, config, enc, args.device, args)
258 | pr.disable()
259 | s = io.StringIO()
260 | ps = pstats.Stats(pr, stream=s).sort_stats('cumtime')
261 | ps.print_stats(30)
262 | printc(" --- Profiler Results (Top 30 by Cumulative Time) ---")
263 | print(s.getvalue())
264 | else:
265 | run_chat_loop_io_binding(onnx_fp16_fused_path, config, enc, args.device, args)
266 |
267 |
268 | if __name__ == "__main__":
269 | main()
270 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import functional as F
4 | from dataclasses import dataclass, asdict
5 | from typing import Optional, Tuple, Union, List
6 |
7 | @dataclass
8 | class GPTConfig:
9 | """
10 | Configuration for the GPT model, inspired by GPT-OSS/Llama but adapted for this project.
11 | """
12 | n_embd: int = 768
13 | n_layers: int = 12
14 | n_heads: int = 12
15 | vocab_size: int = 32000
16 | block_size: int = 512
17 | dropout: float = 0.1
18 | layer_norm_eps: float = 1e-5
19 | n_kv_heads: Optional[int] = 4
20 | rope_theta: float = 10000.0
21 |
22 | def to_dict(self):
23 | return asdict(self)
24 |
25 | class RMSNorm(nn.Module):
26 | """
27 | Root Mean Square Layer Normalization, as used in GPT-OSS and Llama.
28 | """
29 | def __init__(self, dim: int, eps: float = 1e-5):
30 | super().__init__()
31 | self.eps = eps
32 | self.weight = nn.Parameter(torch.ones(dim))
33 |
34 | def _norm(self, x):
35 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
36 |
37 | def forward(self, x):
38 | output = self._norm(x.float()).type_as(x)
39 | return output * self.weight
40 |
41 | def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
42 | """
43 | Efficiently repeat the key and value tensors for Grouped-Query Attention.
44 | [B, n_kv_heads, T, head_dim] -> [B, n_q_heads, T, head_dim]
45 | """
46 | B, n_kv_heads, T, head_dim = x.shape
47 | if n_rep == 1:
48 | return x
49 | return (
50 | x[:, :, None, :, :]
51 | .expand(B, n_kv_heads, n_rep, T, head_dim)
52 | .reshape(B, n_kv_heads * n_rep, T, head_dim)
53 | )
54 |
55 | class RotaryPositionalEmbedding(nn.Module):
56 | """
57 | Original RoPE implementation, kept for its efficiency in training.
58 | """
59 |     def __init__(self, dim: int, max_seq_len: int, base: float = 10000.0):
60 | super().__init__()
61 | inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
62 | self.register_buffer("inv_freq", inv_freq)
63 |
64 | t = torch.arange(max_seq_len, device=self.inv_freq.device)
65 | freqs = torch.einsum("i,j->ij", t, self.inv_freq)
66 | emb = torch.cat((freqs, freqs), dim=-1)
67 | self.register_buffer("cos_cached", emb.cos()[None, None, :, :])
68 | self.register_buffer("sin_cached", emb.sin()[None, None, :, :])
69 |
70 | def forward(self, x, seq_len: int):
71 | return (
72 | self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
73 | self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
74 | )
75 |
76 | def apply_rotary_pos_emb(q, k, cos, sin):
77 | def rotate_half(x):
78 | return torch.cat([-x[..., 1::2], x[..., ::2]], dim=-1)
79 |
80 | q_embed = (q * cos) + (rotate_half(q) * sin)
81 | k_embed = (k * cos) + (rotate_half(k) * sin)
82 | return q_embed, k_embed
83 |
84 | class Attention(nn.Module):
85 | """
86 | Attention module with pre-normalization, based on Llama/GPT-OSS style.
87 | """
88 | def __init__(self, config: GPTConfig):
89 | super().__init__()
90 | self.n_q_heads = config.n_heads
91 | self.n_kv_heads = config.n_kv_heads if config.n_kv_heads is not None else config.n_heads
92 | self.n_rep = self.n_q_heads // self.n_kv_heads
93 | self.head_dim = config.n_embd // self.n_q_heads
94 |
95 | self.qkv_proj = nn.Linear(config.n_embd, (self.n_q_heads + 2 * self.n_kv_heads) * self.head_dim, bias=False)
96 |
97 | q_heads_concat_dim = self.n_q_heads * self.head_dim
98 | self.out_proj = nn.Linear(q_heads_concat_dim, config.n_embd, bias=False)
99 |
100 | self.dropout = nn.Dropout(config.dropout)
101 | self.norm = RMSNorm(config.n_embd, eps=config.layer_norm_eps)
102 | self.out_proj.GPT_SCALE_INIT = 1
103 |
104 | def forward(self, x: torch.Tensor, rotary_emb: Tuple[torch.Tensor, torch.Tensor], past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, attn_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
105 | B, T, C = x.shape
106 |
107 | h = self.norm(x)
108 |
109 | qkv = self.qkv_proj(h)
110 | q_len = self.n_q_heads * self.head_dim
111 | k_len = self.n_kv_heads * self.head_dim
112 |
113 | q = qkv[..., :q_len].view(B, T, self.n_q_heads, self.head_dim).transpose(1, 2)
114 | k = qkv[..., q_len : q_len + k_len].view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
115 | v = qkv[..., q_len + k_len :].view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
116 |
117 | cos, sin = rotary_emb
118 | q, k = apply_rotary_pos_emb(q, k, cos, sin)
119 |
120 | if past_kv is not None:
121 | past_k, past_v = past_kv
122 | k = torch.cat((past_k, k), dim=2)
123 | v = torch.cat((past_v, v), dim=2)
124 |
125 | present_kv = (k.to(x.dtype), v.to(x.dtype))
126 |
127 | k = repeat_kv(k, self.n_rep)
128 | v = repeat_kv(v, self.n_rep)
129 |
130 | is_causal_for_sdpa = False
131 |
132 | y = F.scaled_dot_product_attention(
133 | q, k, v,
134 | attn_mask=attn_mask,
135 | is_causal=is_causal_for_sdpa,
136 | dropout_p=self.dropout.p if self.training else 0.0
137 | )
138 |
139 | y = y.transpose(1, 2).contiguous().view(B, T, -1)
140 | y = self.out_proj(y)
141 |
142 | return x + y, present_kv
143 |
144 | class FeedForward(nn.Module):
145 | """
146 | FeedForward block with pre-normalization and SwiGLU, based on Llama/GPT-OSS style.
147 | """
148 | def __init__(self, config: GPTConfig):
149 | super().__init__()
150 | hidden_dim = 4 * config.n_embd
151 | hidden_dim = int(2 * hidden_dim / 3)
152 | multiple_of = 256
153 | hidden_dim = multiple_of * round(hidden_dim / multiple_of)
154 |
155 | self.norm = RMSNorm(config.n_embd, eps=config.layer_norm_eps)
156 | self.gate_proj = nn.Linear(config.n_embd, hidden_dim, bias=False)
157 | self.up_proj = nn.Linear(config.n_embd, hidden_dim, bias=False)
158 | self.down_proj = nn.Linear(hidden_dim, config.n_embd, bias=False)
159 |
160 | self.down_proj.GPT_SCALE_INIT = 1
161 |
162 | def forward(self, x):
163 | h = self.norm(x)
164 | gate = F.silu(self.gate_proj(h))
165 | up = self.up_proj(h)
166 | fused = gate * up
167 | return x + self.down_proj(fused)
168 |
169 | class Block(nn.Module):
170 | """
171 | Transformer Block in the Llama/GPT-OSS pre-normalization style.
172 | """
173 | def __init__(self, config: GPTConfig):
174 | super().__init__()
175 | self.attention = Attention(config)
176 | self.feed_forward = FeedForward(config)
177 |
178 | def forward(self, x: torch.Tensor, rotary_emb: Tuple[torch.Tensor, torch.Tensor], past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, attn_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
179 | h, present_kv = self.attention(x, rotary_emb, past_kv, attn_mask=attn_mask)
180 | out = self.feed_forward(h)
181 | return out, present_kv
182 |
183 | class GPT(nn.Module):
184 | """
185 | The main GPT model, composed of the new Llama/GPT-OSS-style blocks.
186 | """
187 | def __init__(self, config: GPTConfig):
188 | super().__init__()
189 | self.config = config
190 |
191 | self.tok_embeddings = nn.Embedding(config.vocab_size, config.n_embd)
192 | self.rotary_emb = RotaryPositionalEmbedding(config.n_embd // config.n_heads, config.block_size, base=config.rope_theta)
193 | self.layers = nn.ModuleList([Block(config) for _ in range(config.n_layers)])
194 | self.norm = RMSNorm(config.n_embd, eps=config.layer_norm_eps)
195 |
196 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
197 | self.lm_head.weight = self.tok_embeddings.weight
198 |
199 | self.apply(self._init_weights)
200 |
201 | def _init_weights(self, module):
202 | if isinstance(module, nn.Linear):
203 | std = 0.02
204 | if hasattr(module, 'GPT_SCALE_INIT'):
205 | std *= (2 * self.config.n_layers) ** -0.5
206 | torch.nn.init.normal_(module.weight, mean=0.0, std=std)
207 | if module.bias is not None:
208 | torch.nn.init.zeros_(module.bias)
209 | elif isinstance(module, nn.Embedding):
210 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
211 |
212 | def get_input_embeddings(self):
213 | """
214 | Returns the model's input embeddings.
215 | Required by the Hugging Face PreTrainedModel interface.
216 | """
217 | return self.tok_embeddings
218 |
219 | def set_input_embeddings(self, new_embeddings):
220 | """
221 | Sets the model's input embeddings.
222 | Required by the Hugging Face PreTrainedModel interface.
223 | """
224 | self.tok_embeddings = new_embeddings
225 |
226 | def forward(self, input_ids: torch.Tensor, past_kv_cache: Optional[list] = None, use_cache: bool = False, attn_mask: Optional[torch.Tensor] = None) -> tuple:
227 | B, T = input_ids.size()
228 | seq_len_offset = past_kv_cache[0][0].shape[2] if past_kv_cache is not None else 0
229 | total_sequence_length = seq_len_offset + T
230 |
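        # Causal mask that accounts for the KV-cache offset: query position i may
        # attend to every key position <= seq_len_offset + i.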
231 | q_indices = torch.arange(T, device=input_ids.device) + seq_len_offset
232 | k_indices = torch.arange(total_sequence_length, device=input_ids.device)
233 | causal_mask = q_indices.unsqueeze(1) >= k_indices.unsqueeze(0)
234 |
235 | if attn_mask is not None:
236 | padding_mask = attn_mask[:, :total_sequence_length]
237 | combined_mask = causal_mask.unsqueeze(0) & padding_mask.unsqueeze(1)
238 | else:
239 | combined_mask = causal_mask.unsqueeze(0)
240 |
241 | final_sdpa_mask = combined_mask.unsqueeze(1)
242 |
243 | h = self.tok_embeddings(input_ids)
244 |
245 | cos, sin = self.rotary_emb(h, seq_len=total_sequence_length)
246 | cos = cos[:, :, seq_len_offset:, :]
247 | sin = sin[:, :, seq_len_offset:, :]
248 | rotary_emb = (cos, sin)
249 |
250 | present_kv_cache = []
251 | for i, layer in enumerate(self.layers):
252 | past_kv = past_kv_cache[i] if past_kv_cache is not None else None
253 | h, present_kv = layer(h, rotary_emb, past_kv, attn_mask=final_sdpa_mask)
254 | present_kv_cache.append(present_kv)
255 |
256 | h = self.norm(h)
257 | logits = self.lm_head(h)
258 |
259 | return logits, present_kv_cache
260 |
261 | @torch.inference_mode()
262 | def generate(self, idx: torch.Tensor, max_new_tokens: int, temperature: float = 1.0, top_k: Optional[int] = None, top_p: Optional[float] = None, stop_on_token: Optional[Union[int, List[int]]] = None, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
263 | past_kv_cache = None
264 | current_attn_mask = attn_mask
265 |
266 | for _ in range(max_new_tokens):
267 | B, T = idx.shape
268 |
269 | if T >= self.config.block_size:
270 | break
271 |
272 | current_input = idx[:, -1:] if past_kv_cache is not None else idx
273 |
274 | logits, past_kv_cache = self(current_input, past_kv_cache=past_kv_cache, use_cache=True, attn_mask=current_attn_mask)
275 |
276 | logits = logits[:, -1, :] / temperature
277 |
278 | if top_k is not None:
279 | v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
280 | logits[logits < v[:, [-1]]] = -float('inf')
281 |
282 | if top_p is not None:
283 | sorted_probs, sorted_indices = torch.sort(F.softmax(logits, dim=-1), descending=True)
284 | cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
285 | sorted_indices_to_remove = cumulative_probs > top_p
286 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
287 | sorted_indices_to_remove[..., 0] = 0
288 | indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
289 | logits[indices_to_remove] = -float('inf')
290 |
291 | probs = F.softmax(logits, dim=-1)
292 | idx_next = torch.multinomial(probs, num_samples=1)
293 | idx = torch.cat((idx, idx_next), dim=1)
294 |
295 | if current_attn_mask is not None:
296 | new_mask_col = torch.ones((B, 1), dtype=current_attn_mask.dtype, device=current_attn_mask.device)
297 | current_attn_mask = torch.cat((current_attn_mask, new_mask_col), dim=1)
298 |
299 | if stop_on_token is not None:
300 | stop_tokens = stop_on_token if isinstance(stop_on_token, (list, tuple, set)) else [stop_on_token]
301 | if idx_next.item() in stop_tokens:
302 | break
303 | return idx
304 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Lille 130M
2 |
3 | 
4 |
5 | ## Table of Contents
6 | 1. [Model Summary](#-model-summary)
7 | 2. [Evaluation](#-evaluation)
8 | 3. [How to Use](#-how-to-use)
9 | 4. [Training and Finetuning](#-training-and-finetuning)
10 | 5. [Training Details](#-training-details)
11 | 6. [Limitations](#limitations)
12 | 7. [The Truly Open-Source Stack](#-the-truly-open-source-repos)
13 | 8. [License](#-license)
14 | 9. [Citation](#citation)
15 |
16 | ## ✨ Model Summary
17 |
18 | **Lille** is a 130-million-parameter language model built from the ground up as a core component of a completely open-source deep learning stack. The name Lille reflects both its compact size and strong capabilities - capturing the idea that less can be more. It draws on the Norwegian word lille (‘small’ or ‘little’) as well as the French city Lille, giving it both meaning and place. It was trained using a custom tokenizer, a curated dataset, and a memory-efficient optimizer, all of which are publicly available.
19 |
20 | The model comes in two versions:
21 | * **`Lille-130M-Base`**: The foundational model pretrained on 4.27 billion tokens from the [FineWeb-Edu](https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu) dataset, with a post-processing step added to keep only the highest-quality content. It has strong general knowledge and text completion abilities.
22 | * **`Lille-130M-Instruct`**: The instruction-tuned version, fine-tuned on the **[Kyoto-Corpus](https://huggingface.co/datasets/Nikity/Kyoto-Corpus)**. It excels at following user commands, engaging in chat, and performing a variety of instruction-based tasks.
23 |
24 | The model architecture is a modern Transformer decoder featuring Grouped-Query Attention (GQA), RoPE, and RMSNorm, making it efficient and performant for its size.
25 |
26 | *Note on parameter count: While the model name is `130M` for simplicity, the actual parameter count is 127.17 million.*
27 |
28 | ## 📊 Evaluation
29 |
30 | All evaluations were conducted using **[simple-eval](https://github.com/Nikityyy/simple-eval)**, our open-source evaluation framework. Benchmarks are run in a zero-shot setting unless specified otherwise.
31 |
32 | #### `Lille-130M-Instruct`
33 |
34 | 
35 |
36 | > Evaluations for other LLMs are sourced from the Open LLM Leaderboard, or from their respective model cards when leaderboard data is unavailable. For Lille 130M Instruct, evaluations are performed using simple-eval. ARC-C and ARC-E for SmolLM2 are also evaluated using simple-eval.
37 |
38 | ## 🚀 How to Use
39 |
40 | There are several ways to use the Lille models, from easy-to-use graphical interfaces to advanced programmatic control.
41 |
42 | ### 1. LM Studio (Easiest for Chat)
43 |
44 | LM Studio provides a simple graphical interface to run LLMs on your local machine. It's the easiest way to start chatting with Lille.
45 |
46 | 1. **Download & Install:** Get [LM Studio](https://lmstudio.ai/) for your operating system (Windows, Mac, or Linux).
47 | 2. **Search for the Model:** Open LM Studio and click the **magnifying glass** icon on the left.
48 | 3. **Find Lille:** In the search bar, type `Lille` or `Nikity`. You will find the models I have uploaded.
49 | 4. **Download a GGUF:** On the right-hand side, you'll see a list of GGUF files. Download a recommended version like `lille-130m-instruct-f16.gguf`.
50 | 5. **Chat:** Click the **speech bubble** icon on the left. At the top, select the model you just downloaded. Now you can start a conversation!
51 |
52 | ### 2. SimpleAI SDK (Recommended for Programmatic Use)
53 |
54 | The easiest way to use Lille programmatically is with the `simpleai-sdk`, which handles all the boilerplate for you and provides a simple, high-level API for both Hugging Face and ONNX backends.
55 |
56 | ```bash
57 | pip install simpleai-sdk
58 | ```
59 |
60 | ```python
61 | from simple_ai import lille
62 |
63 | # This will download and cache the model on first run.
64 | # Specify the model version: "130m-instruct" (default) or "130m-base"
65 | # Specify the backend: "huggingface" (default) or "onnx"
66 | model = lille("huggingface", "130m-instruct")
67 |
68 | # --- For Chat (with instruct model) ---
69 | print("--- Chat Example ---")
70 | response1 = model.chat("What is the capital of France?", max_new_tokens=50)
71 | print(f"Bot: {response1}")
72 |
73 | response2 = model.chat("And what is its population?", max_new_tokens=50, top_p=0.90)
74 | print(f"Bot: {response2}")
75 |
76 | # This resets the chat history
77 | model.reset_chat()
78 |
79 | # --- For Text Completion (with base or instruct model) ---
80 | prompt = "Artificial Intelligence is"
81 | response = model.generate(prompt, max_new_tokens=50, temperature=0.9)
82 | print(f"\n--- Completion Example ---\n{prompt}{response}")
83 | ```
84 |
85 | ### 3. Standard Hugging Face Transformers (currently still requires `simpleai-sdk`)
86 |
87 | You can also use the model directly with the `transformers` library for more advanced use cases.
88 |
89 | ```bash
90 | pip install transformers torch simpleai-sdk
91 | ```
92 |
93 | ```python
94 | import torch
95 | from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
96 | from simple_ai.model_hf import LilleConfig, LilleForCausalLM
97 |
98 | # 1. Register the custom model architecture with Hugging Face
99 | AutoConfig.register("lille-130m", LilleConfig)
100 | AutoModelForCausalLM.register(LilleConfig, LilleForCausalLM)
101 |
102 | # 2. Define constants and setup device
103 | MODEL = "Nikity/lille-130m-instruct"
104 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
105 |
106 | # 3. Load tokenizer and model
107 | tokenizer = AutoTokenizer.from_pretrained(MODEL)
108 | model = AutoModelForCausalLM.from_pretrained(
109 | MODEL,
110 | torch_dtype="auto",
111 | device_map=DEVICE,
112 | )
113 |
114 | # 4. Prepare chat prompt and tokenize it
115 | chat = [{"role": "user", "content": "What is the capital of France?"}]
116 | inputs = tokenizer.apply_chat_template(
117 | chat,
118 | add_generation_prompt=True,
119 | return_tensors="pt"
120 | ).to(DEVICE)
121 |
122 | # 5. Generate a response
123 | with torch.inference_mode():
124 | outputs = model.generate(
125 | input_ids=inputs,
126 | max_new_tokens=512,
127 | eos_token_id=tokenizer.eos_token_id,
128 | pad_token_id=tokenizer.pad_token_id,
129 | do_sample=True,
130 | temperature=0.5,
131 | top_p=0.95,
132 | )
133 |
134 | # 6. Decode and print the response
135 | response = tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
136 | print(response)
137 | ```
138 |
139 | ## 🚀 Training and Finetuning
140 |
141 | You can replicate the pretraining of `Lille-130M-Base` or fine-tune it on your own dataset using the provided scripts.
142 |
143 | #### 1. Setup
144 |
145 | First, clone the repository and install the required dependencies:
146 |
147 | ```bash
148 | git clone https://github.com/Nikityyy/lille
149 | cd lille
150 | pip install -r requirements.txt
151 | ```
152 |
153 | **Note on the Optimizer:** The default `Sophia-Triton` optimizer requires the [Triton](https://triton-lang.org/main/getting-started/installation.html) library. Triton is officially supported on Linux with NVIDIA GPUs. While experimental installation on Windows is possible, it can be a complex and difficult process. For a much simpler setup on **Windows and macOS**, or if you prefer not to install Triton, it is highly recommended to use a pure PyTorch implementation of Sophia instead:
154 |
155 | 1. Replace the contents of the `sophia_triton.py` file with the code from [this link](https://github.com/Nikityyy/Sophia-Triton/blob/main/sophia.py).
156 | 2. The `train.py` script should work without any import changes, as the class name `SophiaG` is the same (see the sketch below).
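
If you make that swap, the optimizer is constructed exactly as in `configure_optimizers` in `train.py`; a minimal sketch using the default hyperparameters from the top of `train.py`, assuming the pure-PyTorch constructor does not accept the `bs` keyword used by the Triton version:

```python
from sophia_triton import SophiaG  # file contents replaced with the pure-PyTorch SophiaG

# Same call as train.py, minus the `bs` argument (see the comment in configure_optimizers).
optimizer = SophiaG(
    model.parameters(),   # train.py builds decay/no-decay parameter groups instead
    lr=1e-4,
    betas=(0.9, 0.95),
    rho=0.05,
    weight_decay=0.2,
)
```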
157 |
158 | #### 2. Data Preparation
159 |
160 | The training script expects data in a specific `.npz` format containing tokenized documents and their offsets.
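
Concretely, `NpzDataset` in `train.py` expects each `.npz` file to hold a flat `tokens` array and an `offsets` array whose consecutive entries mark document boundaries. A minimal sketch of producing such a file by hand (the dtypes and paths here are illustrative assumptions; the prepare scripts handle this for you):

```python
import pickle
import numpy as np
import tiktoken

# Load the Hastings tokenizer the same way train.py does.
with open("tokenizer/Hastings.pkl", "rb") as f:
    hastings = pickle.load(f)
enc = tiktoken.core.Encoding(hastings.pop("name"), **hastings)

docs = ["First document ...", "Second document ..."]  # placeholder documents

tokens, offsets = [], [0]
for doc in docs:
    tokens.extend(enc.encode(doc, allowed_special="all"))
    offsets.append(len(tokens))  # tokens[offsets[i]:offsets[i+1]] is document i

np.savez(
    "data/my_dataset/train.npz",
    tokens=np.array(tokens, dtype=np.uint16),   # a 32,768-token vocab fits in uint16
    offsets=np.array(offsets, dtype=np.int64),
)
```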
161 |
162 | **For Pretraining (like FineWeb-Edu):**
163 |
164 | Use the `prepare_dataset_fineweb.py` script. It will stream the dataset from Hugging Face, apply filters, tokenize the text, and save it in the required format.
165 |
166 | ```bash
167 | python prepare_dataset_fineweb.py
168 | ```
169 | This will create `data/fineweb_edu_sample_10BT/train.npz` and `val.npz`.
170 |
171 | **For Finetuning (Instruction Datasets):**
172 |
173 | Use the `prepare_dataset.py` script. Your input data should be a single `.txt` file where each example is separated by the `<|endoftext|>` token.
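
Each example holds the full conversation markup. A sketch of one example in such a file, assuming the same special-token format that `inference.py` constructs at generation time:

```text
<|startoftext|><|user|>What is the capital of France?<|assistant|>The capital of France is Paris.<|endoftext|>
```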
174 |
175 | 1. Place your data file, for example, at `data/my_dataset/train.txt`.
176 | 2. Modify the `input_file_path` and `output_dir` variables in `prepare_dataset.py`.
177 | 3. Run the script:
178 |
179 | ```bash
180 | python prepare_dataset.py
181 | ```
182 | This will create `train.npz` and `val.npz` in your specified output directory.
183 |
184 | #### 3. Running the Training Script
185 |
186 | All training logic is handled by `train.py`. You can configure hyperparameters directly at the top of this file.
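
The most commonly edited settings sit in the configuration block at the top of `train.py`; an excerpt (values shown are the repository defaults):

```python
out_dir = 'checkpoints'            # checkpoint directory for pretraining
resume_checkpoint = None           # set to a .pt path to resume or fine-tune

finetune = False                   # True switches to the fine-tuning settings
finetune_out_dir = 'checkpoints_ft'
finetune_data_dir = 'data/smol-sft'
finetune_learning_rate = 1e-5

data_dir = 'data/fineweb_edu_sample_10BT'
batch_size = 16
block_size = 512
gradient_accumulation_steps = 2
learning_rate = 1e-4
```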
187 |
188 | **To Pretrain from Scratch:**
189 |
190 | 1. Ensure you have prepared a pretraining dataset.
191 | 2. In `train.py`, set `finetune = False`.
192 | 3. Configure pretraining parameters like `data_dir`, `batch_size`, etc.
193 | 4. Run the script:
194 |
195 | ```bash
196 | python train.py
197 | ```
198 |
199 | **To Fine-tune a Pretrained Model:**
200 |
201 | 1. Ensure you have prepared a fine-tuning dataset.
202 | 2. In `train.py`, set `finetune = True`.
203 | 3. Set `resume_checkpoint` to the path of the pretrained model checkpoint (e.g., `checkpoints/best_model.pt`).
204 | 4. Configure fine-tuning parameters like `finetune_data_dir` and `finetune_learning_rate`.
205 | 5. Run the script:
206 |
207 | ```bash
208 | python train.py
209 | ```
210 |
211 | Checkpoints will be saved in the directory specified by `out_dir` (for pretraining) or `finetune_out_dir` (for fine-tuning). The best model based on validation loss will be saved as `best_model.pt`.
212 |
213 | ## 🛠️ Training Details
214 |
215 | ### Pretraining (`Lille-130M-Base`)
216 | * **Dataset:** Pretrained on **4.27 billion tokens** from the `sample-10BT` configuration of the [HuggingFaceFW/fineweb-edu](https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu) dataset.
217 | * **Tokenizer:** The custom **[Hastings](https://github.com/Nikityyy/Hastings)** tokenizer with a 32,768 vocabulary size.
218 | * **Optimizer:** The memory-efficient **[Sophia-Triton](https://github.com/Nikityyy/Sophia-Triton)** optimizer.
219 | * **Hardware:** Trained on a single NVIDIA RTX 4070-TI.
220 | * **Precision:** bfloat16.
221 |
222 | ### Instruction Tuning (`Lille-130M-Instruct`)
223 | * **Dataset:** Supervised Fine-Tuning (SFT) was performed on the **[Kyoto-Corpus](https://github.com/Nikityyy/Kyoto-Corpus)**, a high-quality, curated collection of conversational and instructional data.
224 |
225 | ### Model Architecture
226 | * **Type:** Transformer Decoder
227 | * **Layers:** 24
228 | * **Embedding Size:** 640
229 | * **Attention Heads:** 10
230 | * **KV Heads (GQA):** 2
231 | * **Context Length:** 512 tokens
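
In terms of the `GPTConfig` dataclass in `model.py`, this corresponds roughly to the following instantiation (a sketch; the vocabulary size comes from the Hastings tokenizer, and remaining fields keep the defaults in `model.py`):

```python
from model import GPTConfig

config = GPTConfig(
    n_layers=24,
    n_embd=640,        # 64-dim heads: 640 / 10
    n_heads=10,
    n_kv_heads=2,      # Grouped-Query Attention: 5 query heads per KV head
    block_size=512,
    vocab_size=32768,  # Hastings tokenizer vocabulary
)
```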
232 |
233 | ## Limitations
234 |
235 | Lille models primarily understand and generate content in English. While powerful for their size, they can produce text that may not always be factually accurate, logically consistent, or free from biases present in the training data. These models should be used as assistive tools rather than definitive sources of information. Users should always verify important information and critically evaluate any generated content.
236 |
237 | ## 🛠️ The truly open-source repos
238 |
239 | Lille is a key component of my initiative to build and release a complete, truly open-source stack for language modeling. All components are designed to work together seamlessly.
240 |
241 | * **Tokenizer:** **[Hastings](https://github.com/Nikityyy/Hastings)** - A modern, efficient tokenizer with a 32k vocabulary.
242 | * **Dataset:** **[Kyoto-Corpus](https://github.com/Nikityyy/Kyoto-Corpus)** - A high-quality, small-scale dataset for instruction tuning.
243 | * **Model:** **[lille](https://github.com/Nikityyy/lille)** (this repository) - A powerful 130-million-parameter model trained from scratch.
244 | * **Optimizer:** **[Sophia-Triton](https://github.com/Nikityyy/Sophia-Triton)** - A memory-efficient, Triton-based implementation of the SophiaG optimizer.
245 | * **Evaluations:** **[simple-eval](https://github.com/Nikityyy/simple-eval)** - A straightforward framework for evaluating model performance using an LLM as a Judge.
246 |
247 | ## 🙏 Credits
248 |
249 | Lille’s training scripts and architecture were inspired by and built upon the work of:
250 |
251 | * **nanoGPT** – A minimal and efficient PyTorch implementation of GPT training: [https://github.com/karpathy/nanoGPT](https://github.com/karpathy/nanoGPT)
252 | * **gpt-oss** – The open models from OpenAI: [https://github.com/openai/gpt-oss](https://github.com/openai/gpt-oss)
253 |
254 | ## 📜 License
255 |
256 | This project is licensed under the Apache-2.0 License.
257 |
258 | ## Citation
259 |
260 | If you use Lille or any part of this open-source stack in your work, please consider citing it:
261 |
262 | ```bibtex
263 | @misc{lille-130m,
264 | author = {Nikita Berger},
265 | title = {Lille: A Truly Open-Source 130M Language Model},
266 | year = {2025},
267 | publisher = {GitHub},
268 | journal = {GitHub repository},
269 | howpublished = {\url{https://github.com/Nikityyy/lille}}
270 | }
271 | ```
272 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import math
4 | import pickle
5 | import threading
6 | import collections
7 | import queue
8 | from contextlib import nullcontext
9 |
10 | import numpy as np
11 | import torch
12 | import torch.nn.functional as F
13 | import torch.distributed as dist
14 | from torch.nn.parallel import DistributedDataParallel as DDP
15 | from tqdm import tqdm
16 | import tiktoken
17 | import wandb
18 | from html2term import printc
19 |
20 | from sophia_triton import SophiaG
21 | from model import GPT, GPTConfig
22 |
23 | # --- General Settings ---
24 | out_dir = 'checkpoints'
25 | eval_interval = 500
26 | log_interval = 1
27 | eval_iters = 100
28 | resume_checkpoint = None
29 | # resume_checkpoint = "checkpoints/best_model.pt"
30 |
31 | # --- Finetuning Settings ---
32 | finetune = False
33 | finetune_out_dir = 'checkpoints_ft'
34 | finetune_data_dir = 'data/smol-sft'
35 | finetune_learning_rate = 1e-5
36 | finetune_num_epochs = 3
37 |
38 | # --- W&B Logging ---
39 | wandb_log = True
40 | wandb_project = 'modern-gpt-pretrain'
41 | wandb_run_name = f'run-modern-gpt-{time.strftime("%Y-%m-%d-%H-%M-%S")}'
42 |
43 | # --- Data Settings ---
44 | data_dir = 'data/fineweb_edu_sample_10BT'
45 | pretrain_data_dir = data_dir
46 | batch_size = 16
47 | block_size = 512
48 | num_epochs = 1
49 | gradient_accumulation_steps = 2
50 |
51 | # --- Model Architecture ---
52 | n_layers = 24
53 | n_embd = 640
54 | n_heads = 10
55 | n_kv_heads = 2
56 | dropout = 0.1
57 | layer_norm_eps = 1e-5
58 |
59 | # --- Optimizer & LR Schedule ---
60 | learning_rate = 1e-4
61 | weight_decay = 0.2
62 | beta1 = 0.9
63 | beta2 = 0.95
64 | grad_clip = 1.0
65 | decay_lr = True
66 | warmup_iters = 2000
67 | min_lr = learning_rate / 10
68 | hess_interval = 10
69 |
70 | # --- Hardware & Performance ---
71 | device = 'cuda'
72 | dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'
73 | compile = True
74 |
75 | class NpzDataset:
76 | """A simple lazy-loading dataset for the tokens/offsets .npz format."""
77 | def __init__(self, file_path):
78 | self.data = np.load(file_path, mmap_mode='r')
79 | self.tokens = self.data['tokens']
80 | self.offsets = self.data['offsets']
81 | self.num_docs = len(self.offsets) - 1
82 |
83 | def __len__(self):
84 | return self.num_docs
85 |
86 | def __getitem__(self, idx):
87 | start = self.offsets[idx]
88 | end = self.offsets[idx + 1]
89 | return self.tokens[start:end].tolist()
90 |
91 | class DataPrefetcher:
92 | """ An asynchronous data prefetcher that prepares batches on the CPU. """
93 | def __init__(self, data, block_size, batch_size, max_prefetch=2):
94 | self.data = data
95 | self.block_size = block_size
96 | self.batch_size = batch_size
97 | self.queue = queue.Queue(maxsize=max_prefetch)
98 | self.is_running = True
99 | self.thread = threading.Thread(target=self.run, daemon=True)
100 | self.thread.start()
101 |
102 | def _preload(self):
103 | return get_batch('train')
104 |
105 | def run(self):
106 | while self.is_running:
107 | try:
108 | self.queue.put(self._preload(), timeout=1)
109 | except queue.Full:
110 | continue
111 |
112 | def next(self):
113 | return self.queue.get()
114 |
115 | def close(self):
116 | self.is_running = False
117 | while not self.queue.empty():
118 | try:
119 | self.queue.get_nowait()
120 | except queue.Empty:
121 | break
122 | self.thread.join()
123 |
124 | def get_batch(split, pretrain=False):
125 | """
126 | Get a batch of data. Handles padding for sequences shorter than block_size.
127 | For supervised fine-tuning, it masks out the loss for prompt tokens.
128 | """
129 | if pretrain:
130 | data = train_data_pretrain if split == 'train' else val_data_pretrain
131 | else:
132 | data = train_data if split == 'train' else val_data
133 |
134 | ix = torch.randint(len(data), (batch_size,))
135 | batch_raw = [data[i] for i in ix]
136 |
137 | x_padded = torch.full((batch_size, block_size), pad_token_id, dtype=torch.long)
138 | y_padded = torch.full((batch_size, block_size), -100, dtype=torch.long)
139 |
140 | is_finetune_split = finetune and not pretrain
141 |
142 | for i, tokens in enumerate(batch_raw):
143 | seq_len = min(len(tokens), block_size)
144 | x_padded[i, :seq_len] = torch.tensor(tokens[:seq_len], dtype=torch.long)
145 |
146 | targets = torch.tensor(tokens[1:seq_len], dtype=torch.long)
147 |
148 | if is_finetune_split and assistant_token_id is not None:
149 | x_seq = x_padded[i, :seq_len]
150 | assistant_indices = (x_seq == assistant_token_id).nonzero(as_tuple=True)[0]
151 |
152 | if len(assistant_indices) > 0:
153 | last_assistant_idx = assistant_indices[-1]
154 | targets[:last_assistant_idx] = -100
155 |
156 | y_padded[i, :seq_len-1] = targets
157 |
158 | return x_padded, y_padded
159 |
160 | @torch.no_grad()
161 | def estimate_loss(pretrain=False):
162 | """
163 | Estimate loss on train and validation splits.
164 | """
165 | out = {}
166 | model.eval()
167 | for split in ['train', 'val']:
168 | losses = torch.zeros(eval_iters, device=device)
169 | for k in range(eval_iters):
170 | X_cpu, Y_cpu = get_batch(split, pretrain=pretrain)
171 | X = X_cpu.to(device, non_blocking=True)
172 | Y = Y_cpu.to(device, non_blocking=True)
173 | attn_mask = (X != pad_token_id)
174 | with ctx:
175 | logits, _ = model(X, attn_mask=attn_mask)
176 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), Y.view(-1))
177 | losses[k] = loss
178 | if ddp:
179 | dist.all_reduce(losses, op=dist.ReduceOp.SUM)
180 | losses /= ddp_world_size
181 | out[split] = losses.mean()
182 | model.train()
183 | return out
184 |
185 | def configure_optimizers(model, weight_decay, learning_rate, betas):
186 | """
187 | Configure optimizer with weight decay for 2D parameters.
188 | """
189 | param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad}
190 | decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
191 | nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
192 | optim_groups = [
193 | {'params': decay_params, 'weight_decay': weight_decay},
194 | {'params': nodecay_params, 'weight_decay': 0.0}
195 | ]
196 | num_decay_params = sum(p.numel() for p in decay_params)
197 | num_nodecay_params = sum(p.numel() for p in nodecay_params)
198 | if master_process:
199 | printc(f" <#cccccc>Num decayed parameter tensors:#cccccc> {len(decay_params)}, with {num_decay_params:,} parameters")
200 | printc(f" <#cccccc>Num non-decayed parameter tensors:#cccccc> {len(nodecay_params)}, with {num_nodecay_params:,} parameters ")
201 | # If you get the following error: "got an unexpected keyword argument 'bs'", then remove bs=tokens_per_optimizer_step
202 | optimizer = SophiaG(optim_groups, lr=learning_rate, betas=betas, rho=0.05, weight_decay=weight_decay, bs=tokens_per_optimizer_step)
203 | return optimizer
204 |
205 | def get_cosine_schedule_with_warmup_scheduler(optimizer, num_warmup_steps, num_training_steps, min_lr_ratio_val=0.1):
206 | """
207 | Create a learning rate scheduler with a cosine decay and linear warmup.
208 | """
209 | def lr_lambda_func(current_step):
210 | if current_step < num_warmup_steps:
211 | return float(current_step) / float(max(1, num_warmup_steps))
212 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
213 | cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
214 | return min_lr_ratio_val + (1.0 - min_lr_ratio_val) * cosine_decay
215 | return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda_func)
216 |
217 | def save_checkpoint_async(checkpoint, path, force=False, old_path_to_delete=None):
218 | """
219 | Save a checkpoint asynchronously in a separate thread.
220 | """
221 | temp_path = path + ".tmp"
222 | try:
223 | torch.save(checkpoint, temp_path)
224 | os.replace(temp_path, path)
225 | if old_path_to_delete and os.path.exists(old_path_to_delete):
226 | os.remove(old_path_to_delete)
227 | except Exception as e:
228 | printc(f"Error saving checkpoint to {path}: {e}")
229 | if os.path.exists(temp_path):
230 | os.remove(temp_path)
231 |
232 | if finetune:
233 | out_dir = finetune_out_dir
234 | learning_rate = finetune_learning_rate
235 | num_epochs = finetune_num_epochs
236 | data_dir = finetune_data_dir
237 | warmup_iters = 1000
238 | weight_decay = 0.01
239 | dropout = 0.0
240 | wandb_project = 'modern-gpt-finetune'
241 | if not resume_checkpoint:
242 | raise ValueError("For finetuning, a `resume_checkpoint` must be provided.")
243 |     printc("=" * 50)
244 | printc("|| FINETUNING MODE ENABLED")
245 | printc(f"|| Output directory: {out_dir}")
246 | printc(f"|| Data directory: {data_dir}")
247 | printc(f"|| Learning rate: {learning_rate}")
248 | printc(f"|| Epochs: {num_epochs}")
249 |     printc("=" * 50)
250 | else:
251 |     printc("=" * 50)
252 | printc("|| PRETRAINING MODE ENABLED")
253 | printc(f"|| Output directory: {out_dir}")
254 | printc(f"|| Data directory: {data_dir}")
255 | printc(f"|| Learning rate: {learning_rate}")
256 | printc(f"|| Epochs: {num_epochs}")
257 |     printc("=" * 50)
258 |
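# Distributed setup: torchrun exports RANK / LOCAL_RANK / WORLD_SIZE, in which case
# we initialize NCCL and pin each process to its local GPU; otherwise run single-process.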
259 | ddp = int(os.environ.get('RANK', -1)) != -1
260 | if ddp:
261 | dist.init_process_group(backend='nccl')
262 | ddp_rank = int(os.environ['RANK'])
263 | ddp_local_rank = int(os.environ['LOCAL_RANK'])
264 | ddp_world_size = int(os.environ['WORLD_SIZE'])
265 | device = f'cuda:{ddp_local_rank}'
266 | torch.cuda.set_device(device)
267 | master_process = ddp_rank == 0
268 | seed_offset = ddp_rank
269 | else:
270 | master_process = True
271 | seed_offset = 0
272 | ddp_world_size = 1
273 |
274 | if master_process:
275 | os.makedirs(out_dir, exist_ok=True)
276 |
277 | torch.manual_seed(1337 + seed_offset)
278 | torch.backends.cuda.matmul.allow_tf32 = True
279 | torch.backends.cudnn.allow_tf32 = True
280 | device_type = 'cuda' if 'cuda' in device else 'cpu'
281 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
282 | ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type, dtype=ptdtype)
283 |
284 | if wandb_log and master_process:
285 | config_dict = {k: v for k, v in locals().items() if isinstance(v, (int, float, str, bool))}
286 | wandb.init(project=wandb_project, name=wandb_run_name, config=config_dict)
287 |
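# Load the custom tiktoken encoding from the pickled "Hastings" spec and resolve the
# special tokens used for assistant-response masking and padding.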
288 | with open('tokenizer/Hastings.pkl', 'rb') as f:
289 | hastings = pickle.load(f)
290 | enc = tiktoken.core.Encoding(hastings.pop('name'), **hastings)
291 | vocab_size = enc.n_vocab
292 | assistant_token_id = enc.encode_single_token("<|assistant|>")
293 |
294 | try:
295 | pad_token_id = enc.encode_single_token("<|pad|>")
296 | except KeyError:
297 | printc("Warning: '<|pad|>' token not found. Using '<|endoftext|>' as a pad token.")
298 | pad_token_id = enc.encode_single_token("<|endoftext|>")
299 |
300 | printc(" Loading dataset using NpzDataset...")
301 | train_data = NpzDataset(os.path.join(data_dir, 'train.npz'))
302 | val_data = NpzDataset(os.path.join(data_dir, 'val.npz'))
303 | if data_dir != pretrain_data_dir:
304 | train_data_pretrain = NpzDataset(os.path.join(pretrain_data_dir, 'train.npz'))
305 | val_data_pretrain = NpzDataset(os.path.join(pretrain_data_dir, 'val.npz'))
306 | else:
307 | train_data_pretrain, val_data_pretrain = train_data, val_data
308 |
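# Translate the epoch budget into optimizer steps: one optimizer step consumes
# batch_size * block_size * world_size * gradient_accumulation_steps tokens.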
309 | train_tokens = len(train_data.tokens)
310 | tokens_per_optimizer_step = batch_size * block_size * ddp_world_size * gradient_accumulation_steps
311 | max_optimizer_steps = (train_tokens // tokens_per_optimizer_step) * num_epochs
312 | iters_per_epoch_optimizer_steps = train_tokens // tokens_per_optimizer_step
313 | lr_decay_iters = max_optimizer_steps
314 |
315 | model_args = dict(
316 | n_layers=n_layers, n_embd=n_embd, vocab_size=vocab_size, block_size=block_size,
317 | dropout=dropout, n_heads=n_heads, n_kv_heads=n_kv_heads, layer_norm_eps=layer_norm_eps
318 | )
319 | config = GPTConfig(**model_args)
320 | model = GPT(config)
321 | model.to(device)
322 | num_params = sum(p.numel() for p in model.parameters())
323 | if master_process:
324 | printc(f"Model has {num_params / 1e6:.2f}M parameters.")
325 |
326 | scaler = torch.amp.GradScaler(enabled=(dtype == 'float16'))
327 | optimizer = configure_optimizers(model, weight_decay, learning_rate, (beta1, beta2))
328 |
329 | min_lr_ratio_for_scheduler = min_lr / learning_rate
330 | scheduler = get_cosine_schedule_with_warmup_scheduler(
331 | optimizer, num_warmup_steps=warmup_iters,
332 | num_training_steps=max_optimizer_steps, min_lr_ratio_val=min_lr_ratio_for_scheduler
333 | )
334 |
335 | iter_num = 0
336 | best_val_loss = 1e9
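# Optionally resume from a checkpoint: load weights (stripping any torch.compile
# '_orig_mod.' prefix) and, when continuing the same run, restore the optimizer
# state, step counter, best validation loss, and scheduler position.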
337 | if resume_checkpoint and os.path.exists(resume_checkpoint):
338 |     if master_process: printc(f"Loading checkpoint: {resume_checkpoint}")
339 | checkpoint = torch.load(resume_checkpoint, map_location=device)
340 | ckpt_model_args = checkpoint['model_args']
341 | for k, v in model_args.items():
342 | if k not in ckpt_model_args or ckpt_model_args[k] != v:
343 | if master_process: printc(f" Warning: Mismatch in model config: '{k}'")
344 | state_dict = checkpoint['model_state_dict']
345 | unwanted_prefix = '_orig_mod.'
346 | for k,v in list(state_dict.items()):
347 | if k.startswith(unwanted_prefix):
348 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
349 | model.load_state_dict(state_dict, strict=False)
350 |
351 |     if 'optimizer_state_dict' in checkpoint and (not finetune or resume_checkpoint.startswith(finetune_out_dir)):
352 | if master_process: printc("Resuming training with optimizer state.")
353 | optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
354 | for group in optimizer.param_groups:
355 | group.setdefault('bs', tokens_per_optimizer_step)
356 | group.setdefault('eps', 1e-15)
357 | iter_num = checkpoint['iter_num']
358 | best_val_loss = checkpoint['best_val_loss']
359 | scheduler.last_epoch = iter_num
360 |
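# Optionally compile the model with torch.compile (inductor, max-autotune) and run a
# single no-grad forward pass on the master process to warm up the compiled graph.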
361 | if compile:
362 | if master_process: printc("Compiling the model...")
363 | model = torch.compile(model, backend="inductor", mode="max-autotune")
364 | if master_process:
365 | printc(" Warming up the compiled model...")
366 | with ctx:
367 | with torch.no_grad():
368 | x_warm_cpu, _ = get_batch('train')
369 | x_warm = x_warm_cpu.to(device, non_blocking=True)
370 | attn_mask_warm = (x_warm != pad_token_id)
371 | _, _ = model(x_warm, attn_mask=attn_mask_warm)
372 | printc(" Warm-up complete. ")
373 |
374 | if ddp:
375 | model = DDP(model, device_ids=[ddp_local_rank], find_unused_parameters=False)
376 |
377 | if master_process: printc("Setting up asynchronous data prefetcher for training...")
378 | train_prefetcher = DataPrefetcher(train_data, block_size, batch_size)
379 |
380 | t0 = time.time()
381 | checkpoint_threads = collections.deque()
382 | last_interval_checkpoint_path = None
383 | pbar = tqdm(range(iter_num, max_optimizer_steps), disable=not master_process)
384 |
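# Main training loop: one iteration of pbar is one optimizer step. Each step runs
# periodic evaluation and checkpointing, gradient accumulation over micro-batches,
# a SophiaG update, and an occasional Hessian refresh.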
385 | for optimizer_step in pbar:
386 | if optimizer_step > iter_num and optimizer_step % eval_interval == 0:
387 | losses = estimate_loss()
388 | if finetune:
389 | losses_pt = estimate_loss(pretrain=True)
390 | current_epoch = optimizer_step / iters_per_epoch_optimizer_steps
391 | if master_process:
392 | printc(f" Epoch {current_epoch:.2f} | Step {optimizer_step}")
393 | if finetune:
394 | printc(f" Finetune Loss -> Train: {losses['train']:.4f}, Val: {losses['val']:.4f}")
395 | printc(f" <#cccccc>Pretrain Loss#cccccc> -> Train: {losses_pt['train']:.4f}, Val: {losses_pt['val']:.4f}")
396 | else:
397 | printc(f" Pretrain Loss -> Train: {losses['train']:.4f}, Val: {losses['val']:.4f}")
398 | if wandb_log:
399 | log_data = {'eval/train_loss': losses['train'], 'eval/val_loss': losses['val'], 'trainer/epoch': current_epoch}
400 | if finetune:
401 | log_data.update({'eval/pretrain_train_loss': losses_pt['train'], 'eval/pretrain_val_loss': losses_pt['val']})
402 | wandb.log(log_data, step=optimizer_step)
403 |
404 | while checkpoint_threads and not checkpoint_threads[0].is_alive():
405 | checkpoint_threads.popleft().join()
406 | raw_model = model.module if ddp else model
407 | checkpoint = {
408 | 'model_state_dict': raw_model.state_dict(),
409 | 'optimizer_state_dict': optimizer.state_dict(),
410 | 'model_args': raw_model.config.to_dict(),
411 | 'iter_num': optimizer_step,
412 | 'best_val_loss': best_val_loss
413 | }
414 | checkpoint_path = os.path.join(out_dir, f'ckpt_iter_{optimizer_step}.pt')
415 | thread = threading.Thread(target=save_checkpoint_async, args=(checkpoint.copy(), checkpoint_path, False, last_interval_checkpoint_path))
416 | thread.start()
417 | checkpoint_threads.append(thread)
418 | last_interval_checkpoint_path = checkpoint_path
419 |
420 | if losses['val'] < best_val_loss:
421 | best_val_loss = losses['val']
422 | checkpoint['best_val_loss'] = best_val_loss
423 | best_model_path = os.path.join(out_dir, 'best_model.pt')
424 | thread = threading.Thread(target=save_checkpoint_async, args=(checkpoint.copy(), best_model_path, True))
425 | thread.start()
426 | checkpoint_threads.append(thread)
427 |                 printc(f" Started saving new best model to {best_model_path}")
428 |
429 | optimizer.zero_grad(set_to_none=True)
430 |
431 | if compile and device_type == 'cuda':
432 | torch.compiler.cudagraph_mark_step_begin()
433 |
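    # Gradient accumulation: each micro-batch contributes loss / gradient_accumulation_steps,
    # so the accumulated gradient matches the average over the full effective batch.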
434 | for micro_step in range(gradient_accumulation_steps):
435 | X_cpu, Y_cpu = train_prefetcher.next()
436 | X = X_cpu.pin_memory().to(device, non_blocking=True)
437 | Y = Y_cpu.pin_memory().to(device, non_blocking=True)
438 | attn_mask = (X != pad_token_id)
439 | with ctx:
440 | logits, _ = model(X, attn_mask=attn_mask)
441 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), Y.view(-1)) / gradient_accumulation_steps
442 | scaler.scale(loss).backward()
443 |
444 | scaler.unscale_(optimizer)
445 | if grad_clip > 0.0:
446 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip)
447 |
448 | if optimizer_step == 0:
449 |         optimizer.update_hessian()  # initialize the optimizer's Hessian estimate before the first step
450 |
451 | scaler.step(optimizer)
452 | scaler.update()
453 |
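    # Sophia Hessian refresh: every hess_interval steps, re-run the forward pass on the
    # last micro-batch, sample labels from the model's own softmax, and backprop that
    # loss (the Gauss-Newton-Bartlett estimator) so schedule_hessian_update() can refresh
    # the Hessian estimate; the extra gradients are then zeroed.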
454 | if optimizer_step % hess_interval == 0:
455 | with ctx:
456 | logits, _ = model(X, attn_mask=attn_mask)
457 | probs = F.softmax(logits, dim=-1)
458 | y_sample = torch.multinomial(probs.view(-1, logits.size(-1)), num_samples=1).view_as(Y)
459 | loss_sampled = F.cross_entropy(logits.view(-1, logits.size(-1)), y_sample.view(-1))
460 | scaler.scale(loss_sampled).backward()
461 | scaler.unscale_(optimizer)
462 | if grad_clip > 0.0:
463 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip)
464 | optimizer.schedule_hessian_update()
465 | optimizer.zero_grad(set_to_none=True)
466 |
467 | if decay_lr:
468 | scheduler.step()
469 |
470 | t1 = time.time()
471 | dt = t1 - t0
472 | t0 = t1
473 | if optimizer_step % log_interval == 0 and master_process:
474 | lossf = loss.item() * gradient_accumulation_steps
475 | current_lr = optimizer.param_groups[0]['lr']
476 | pbar.set_description(f"step {optimizer_step + 1}: loss {lossf:.4f}, time {dt*1000:.2f}ms, lr {current_lr:e}")
477 | if wandb_log:
478 | wandb.log({'train/loss': lossf, 'trainer/lr': current_lr, 'trainer/dt_ms': dt * 1000}, step=optimizer_step)
479 |
480 | pbar.close()
481 | if master_process: printc(" Training loop finished. Closing data prefetcher...")
482 | train_prefetcher.close()
483 |
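# Save the final checkpoint, then wait for any background checkpoint writes to finish.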
484 | if master_process:
485 | printc("Saving final model and waiting for all saves to complete...")
486 | raw_model = model.module if ddp else model
487 | final_checkpoint = {
488 | 'model_state_dict': raw_model.state_dict(),
489 | 'optimizer_state_dict': optimizer.state_dict(),
490 | 'model_args': raw_model.config.to_dict(),
491 | 'iter_num': max_optimizer_steps,
492 | 'best_val_loss': best_val_loss
493 | }
494 | final_checkpoint_path = os.path.join(out_dir, 'ckpt_final.pt')
495 | thread = threading.Thread(target=save_checkpoint_async, args=(final_checkpoint, final_checkpoint_path, True))
496 | thread.start()
497 | checkpoint_threads.append(thread)
498 |
499 | while checkpoint_threads:
500 | printc(f" <#cccccc>Waiting for {len(checkpoint_threads)} remaining checkpoint(s) to save...#cccccc>")
501 | checkpoint_threads.popleft().join()
502 |
503 | if wandb_log:
504 | wandb.finish()
505 |
506 | if ddp:
507 | dist.destroy_process_group()
508 |
509 | if master_process: printc(" ✅ Training complete and all checkpoints saved.")
510 |
--------------------------------------------------------------------------------