├── .gitignore ├── src ├── output.png └── qwen3.py ├── LICENSE └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | model_cache -------------------------------------------------------------------------------- /src/output.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/saurabhaloneai/qwen3-exp/HEAD/src/output.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 saurabh 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # qwen3 jax implementation 2 | 3 | > [!IMPORTANT] 4 | > a clean, pure jax implementation of Qwen3-0.6B for inference with optimized memory usage and KV caching. 5 | 6 | ## output 7 | 8 | ![img](src/output.png) 9 | 10 | ## features 11 | 12 | - pure jax implementation with no PyTorch dependencies for inference 13 | - optimized KV caching for efficient text generation 14 | 15 | ## start here 16 | 17 | ```bash 18 | # Clone the repository 19 | git clone 20 | cd qwen3-exp 21 | 22 | # install dependencies 23 | pip install -U "jax[cuda12]" tokenizers torch safetensors huggingface-hub tqdm numpy 24 | 25 | # run inference 26 | python src/qwen3.py 27 | ``` 28 | 29 | ## usage 30 | 31 | The implementation automatically downloads the Qwen3-0.6B model from Hugging Face and runs inference: 32 | 33 | ```python 34 | from qwen3 import Qwen3Tokenizer, generate_kv_optimized, load_qwen3_weights_jax_optimized 35 | 36 | # initialize tokenizer and model 37 | tokenizer = Qwen3Tokenizer(repo_id="Qwen/Qwen3-0.6B") 38 | model = load_model() # Loads and converts weights to JAX 39 | 40 | # generate text 41 | prompt = "Give me a short introduction to large language models." 42 | output = generate_kv_optimized(model, prompt, max_new_tokens=50) 43 | ``` 44 | 45 | ## requirements 46 | 47 | - Python 3.8+ 48 | - JAX/JAXLib (GPU support recommended) 49 | - tokenizers 50 | - safetensors 51 | - huggingface-hub 52 | - numpy 53 | - tqdm 54 | 55 | ## license 56 | 57 | MIT License - see [LICENSE](LICENSE) for details. 58 | 59 | ## reference 60 | 61 | Based on the implementation from [LLMs-from-scratch](https://github.com/rasbt/LLMs-from-scratch/tree/main/ch05/11_qwen3) by Sebastian Raschka. -------------------------------------------------------------------------------- /src/qwen3.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from tokenizers import Tokenizer 4 | import torch 5 | from safetensors.numpy import load_file 6 | import os 7 | from pathlib import Path 8 | import gc 9 | from collections import defaultdict 10 | import numpy as np 11 | from tqdm import tqdm 12 | try: 13 | from huggingface_hub import snapshot_download 14 | except ImportError: 15 | snapshot_download = None 16 | 17 | os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false' 18 | os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.5' 19 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 20 | os.environ['JAX_PLATFORMS'] = 'gpu' 21 | 22 | device = jax.devices('gpu')[0] if jax.devices('gpu') else jax.devices('cpu')[0] 23 | 24 | QWEN3_CONFIG = { 25 | "vocab_size": 151936, "context_length": 40960, "emb_dim": 1024, "n_heads": 16, 26 | "n_layers": 28, "hidden_dim": 3072, "head_dim": 128, "qk_norm": True, 27 | "n_kv_groups": 8, "rope_base": 1000000.0, "dtype": torch.bfloat16, 28 | } 29 | 30 | class Qwen3Tokenizer(): 31 | def __init__(self, tokenizer_file_path="tokenizer.json", repo_id=None): 32 | if not Path(tokenizer_file_path).is_file() and repo_id and snapshot_download: 33 | snapshot_download(repo_id=repo_id, local_dir=Path(tokenizer_file_path).parent) 34 | self.tokenizer = Tokenizer.from_file(tokenizer_file_path) 35 | 36 | def encode(self, prompt): 37 | messages = [{"role": "user", "content": prompt}] 38 | formatted_prompt = self.format_qwen_chat(messages) 39 | return self.tokenizer.encode(formatted_prompt).ids 40 | 41 | def decode(self, token_ids): 42 | return self.tokenizer.decode(token_ids, skip_special_tokens=False) 43 | 44 | @staticmethod 45 | def format_qwen_chat(messages): 46 | prompt = "" 47 | for msg in messages: 48 | prompt += f"<|im_start|>{msg['role']}\n{msg['content']}<|im_end|>\n" 49 | prompt += "<|im_start|>assistant<|think>\n\n<|/think>\n\n" 50 | return prompt 51 | 52 | def download_model_from_hf(repo_id, local_dir="./model_cache"): 53 | local_dir = Path(local_dir) 54 | local_dir.mkdir(exist_ok=True) 55 | model_path = snapshot_download(repo_id=repo_id, local_dir=local_dir / repo_id.replace("/", "_"), local_dir_use_symlinks=False) 56 | return Path(model_path) 57 | 58 | def safe_convert_numpy_to_jax(numpy_array): 59 | if numpy_array.dtype in [np.float16]: 60 | numpy_array = numpy_array.astype(np.float32) 61 | return jnp.array(numpy_array) 62 | 63 | def batch_convert_numpy_weights(numpy_weights_dict): 64 | converted = {key: safe_convert_numpy_to_jax(array) for key, array in numpy_weights_dict.items()} 65 | return jax.tree.map(lambda x: jax.device_put(x, device), converted) 66 | 67 | def cleanup_memory(): 68 | if torch.cuda.is_available(): 69 | torch.cuda.empty_cache() 70 | torch.cuda.synchronize() 71 | gc.collect() 72 | 73 | @jax.jit 74 | def feedforward_forward(params, x): 75 | gate = jax.nn.silu(jnp.einsum('bse,eh->bsh', x, params["gate_proj"])) 76 | up = jnp.einsum('bse,eh->bsh', x, params["up_proj"]) 77 | return jnp.einsum('bsh,he->bse', gate * up, params["down_proj"]) 78 | 79 | @jax.jit 80 | def rmsnorm_forward(params, x, eps=1e-6): 81 | orig_dtype = x.dtype 82 | x = x.astype(jnp.float32) 83 | variance = jnp.mean(x ** 2, axis=-1, keepdims=True) 84 | norm_x = x * jax.lax.rsqrt(variance + eps) * params["scale"] 85 | return norm_x.astype(orig_dtype) 86 | 87 | def compute_rope_params(head_dim, theta_base=10000.0, context_length=4096): 88 | inv_freq = 1.0 / (theta_base ** (jnp.arange(0, head_dim, 2) / head_dim)) 89 | positions = jnp.arange(context_length) 90 | angles = jnp.concatenate([positions[:, None] * inv_freq[None, :]] * 2, axis=1) 91 | return jnp.cos(angles), jnp.sin(angles) 92 | 93 | def apply_rope(x, cos, sin): 94 | seq_len = x.shape[2] 95 | x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:] 96 | cos, sin = cos[:seq_len, :][None, None, :, :], sin[:seq_len, :][None, None, :, :] 97 | rotated = jnp.concatenate([-x2, x1], axis=-1) 98 | return ((x * cos) + (rotated * sin)).astype(x.dtype) 99 | 100 | def apply_rope_with_offset(x, cos, sin, position_offset=0): 101 | seq_len = x.shape[2] 102 | x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:] 103 | 104 | positions = jnp.arange(position_offset, position_offset + seq_len) 105 | cos_slice = cos[positions, :][None, None, :, :] 106 | sin_slice = sin[positions, :][None, None, :, :] 107 | 108 | rotated = jnp.concatenate([-x2, x1], axis=-1) 109 | return ((x * cos_slice) + (rotated * sin_slice)).astype(x.dtype) 110 | 111 | def apply_qk_norm(x, norm_params): 112 | b, h, s, d = x.shape 113 | x_reshaped = x.reshape(b * h * s, d) 114 | x_normed = rmsnorm_forward(norm_params, x_reshaped) 115 | return x_normed.reshape(b, h, s, d) 116 | 117 | def grouped_query_attention_forward_kv(params, x, mask, cos, sin, num_heads, num_kv_groups, head_dim, kv_cache=None, qk_norm=False): 118 | b, seq, d_in = x.shape 119 | group_size = num_heads // num_kv_groups 120 | 121 | if kv_cache is not None and kv_cache["keys"].shape[2] > 0: 122 | position_offset = kv_cache["keys"].shape[2] 123 | else: 124 | position_offset = 0 125 | 126 | queries = jnp.einsum('bsd,dh->bsh', x, params["W_query"]).reshape(b, seq, num_heads, head_dim).transpose(0,2,1,3) 127 | keys = jnp.einsum('bsd,dh->bsh', x, params["W_key"]).reshape(b, seq, num_kv_groups, head_dim).transpose(0,2,1,3) 128 | values = jnp.einsum('bsd,dh->bsh', x, params["W_value"]).reshape(b, seq, num_kv_groups, head_dim).transpose(0,2,1,3) 129 | 130 | if qk_norm and "q_norm" in params and "k_norm" in params: 131 | queries = apply_qk_norm(queries, params["q_norm"]) 132 | keys = apply_qk_norm(keys, params["k_norm"]) 133 | 134 | queries = apply_rope_with_offset(queries, cos, sin, position_offset) 135 | keys = apply_rope_with_offset(keys, cos, sin, position_offset) 136 | 137 | if kv_cache is not None and kv_cache["keys"].shape[2] > 0: 138 | keys = jnp.concatenate([kv_cache["keys"], keys], axis=2) 139 | values = jnp.concatenate([kv_cache["values"], values], axis=2) 140 | 141 | new_cache = {"keys": keys, "values": values} 142 | 143 | keys_expanded = jnp.repeat(keys, group_size, axis=1) 144 | values_expanded = jnp.repeat(values, group_size, axis=1) 145 | 146 | attn_scores = jnp.einsum('bnqh,bnkh->bnqk', queries, keys_expanded) / jnp.sqrt(head_dim) 147 | 148 | if kv_cache is None or kv_cache["keys"].shape[2] == 0: 149 | q_len, k_len = queries.shape[2], keys.shape[2] 150 | causal_mask = jnp.triu(jnp.ones((q_len, k_len)), k=1) 151 | attn_scores = jnp.where(causal_mask[None, None, :, :], -jnp.inf, attn_scores) 152 | 153 | attn_weights = jax.nn.softmax(attn_scores, axis=-1) 154 | context = jnp.einsum('bnqk,bnkh->bnqh', attn_weights, values_expanded) 155 | context = context.transpose(0,2,1,3).reshape(b, seq, num_heads * head_dim) 156 | output = jnp.einsum('bsh,hd->bsd', context, params["out_proj"]) 157 | 158 | return output, new_cache 159 | 160 | def transformer_block_forward_kv(params, x, mask, cos, sin, cfg, kv_cache=None): 161 | shortcut = x 162 | x = rmsnorm_forward(params["norm1"], x) 163 | x, new_cache = grouped_query_attention_forward_kv(params["att"], x, mask, cos, sin, cfg["n_heads"], cfg["n_kv_groups"], cfg["head_dim"], kv_cache, cfg["qk_norm"]) 164 | x = x + shortcut 165 | shortcut = x 166 | x = rmsnorm_forward(params["norm2"], x) 167 | x = feedforward_forward(params["ff"], x) 168 | return x + shortcut, new_cache 169 | 170 | def init_qwen3_params(key, cfg): 171 | k_emb, k_blocks, k_final_norm, k_out = jax.random.split(key, 4) 172 | tok_emb = jax.random.normal(k_emb, (cfg["vocab_size"], cfg["emb_dim"])) / jnp.sqrt(cfg["vocab_size"]) 173 | block_keys = jax.random.split(k_blocks, cfg["n_layers"]) 174 | 175 | def init_block_params(k): 176 | k_att, k_ff, k_norm1, k_norm2 = jax.random.split(k, 4) 177 | kq, kk, kv, ko = jax.random.split(k_att, 4) 178 | k_gate, k_up, k_down = jax.random.split(k_ff, 3) 179 | 180 | att_params = { 181 | "W_query": jax.random.normal(kq, (cfg["emb_dim"], cfg["n_heads"] * cfg["head_dim"])) / jnp.sqrt(cfg["emb_dim"]), 182 | "W_key": jax.random.normal(kk, (cfg["emb_dim"], cfg["n_kv_groups"] * cfg["head_dim"])) / jnp.sqrt(cfg["emb_dim"]), 183 | "W_value": jax.random.normal(kv, (cfg["emb_dim"], cfg["n_kv_groups"] * cfg["head_dim"])) / jnp.sqrt(cfg["emb_dim"]), 184 | "out_proj": jax.random.normal(ko, (cfg["n_heads"] * cfg["head_dim"], cfg["emb_dim"])) / jnp.sqrt(cfg["n_heads"] * cfg["head_dim"]), 185 | } 186 | 187 | if cfg["qk_norm"]: 188 | att_params["q_norm"] = {"scale": jnp.ones((cfg["head_dim"],))} 189 | att_params["k_norm"] = {"scale": jnp.ones((cfg["head_dim"],))} 190 | 191 | return { 192 | "att": att_params, 193 | "ff": { 194 | "gate_proj": jax.random.normal(k_gate, (cfg["emb_dim"], cfg["hidden_dim"])) / jnp.sqrt(cfg["emb_dim"]), 195 | "up_proj": jax.random.normal(k_up, (cfg["emb_dim"], cfg["hidden_dim"])) / jnp.sqrt(cfg["emb_dim"]), 196 | "down_proj": jax.random.normal(k_down, (cfg["hidden_dim"], cfg["emb_dim"])) / jnp.sqrt(cfg["hidden_dim"]), 197 | }, 198 | "norm1": {"scale": jnp.ones((cfg["emb_dim"],))}, 199 | "norm2": {"scale": jnp.ones((cfg["emb_dim"],))}, 200 | } 201 | 202 | trf_blocks = [init_block_params(k) for k in block_keys] 203 | final_norm = {"scale": jnp.ones((cfg["emb_dim"],))} 204 | out_head = jax.random.normal(k_out, (cfg["emb_dim"], cfg["vocab_size"])) / jnp.sqrt(cfg["emb_dim"]) 205 | cos, sin = compute_rope_params(cfg["head_dim"], cfg["rope_base"], cfg["context_length"]) 206 | 207 | params = {"tok_emb": tok_emb, "trf_blocks": trf_blocks, "final_norm": final_norm, "out_head": out_head, "cos": cos, "sin": sin} 208 | 209 | return jax.tree.map(lambda x: jax.device_put(x, device), params) 210 | 211 | def qwen3_forward_kv(params, x, cfg, kv_cache=None): 212 | x = params["tok_emb"][x] 213 | mask = jnp.triu(jnp.ones((cfg["context_length"], cfg["context_length"]), dtype=bool), k=1) 214 | 215 | new_cache = [] 216 | for i, block_params in enumerate(params["trf_blocks"]): 217 | layer_cache = kv_cache[i] if kv_cache else None 218 | x, updated_cache = transformer_block_forward_kv(block_params, x, mask, params["cos"], params["sin"], cfg, layer_cache) 219 | new_cache.append(updated_cache) 220 | 221 | x = rmsnorm_forward(params["final_norm"], x) 222 | logits = jnp.einsum('bse,ev->bsv', x, params["out_head"]) 223 | 224 | return logits, new_cache 225 | 226 | def generate_kv_optimized(model, idx, max_new_tokens, context_size, temperature=0.7, top_k=50, eos_id=None, batch_size=1): 227 | params, cfg = model["params"], model["cfg"] 228 | 229 | # Keep input on device 230 | cur_ids = jnp.array([idx] * batch_size) if batch_size > 1 else jnp.array([idx]) 231 | key = jax.random.PRNGKey(42) 232 | 233 | # Initialize KV cache for batch processing 234 | kv_cache = [{"keys": jnp.zeros((batch_size, cfg["n_kv_groups"], 0, cfg["head_dim"])), 235 | "values": jnp.zeros((batch_size, cfg["n_kv_groups"], 0, cfg["head_dim"]))} 236 | for _ in range(cfg["n_layers"])] 237 | 238 | logits, kv_cache = qwen3_forward_kv(params, cur_ids, cfg, kv_cache) 239 | 240 | for i in tqdm(range(max_new_tokens), desc="Generating"): 241 | next_token_logits = logits[:, -1, :] 242 | 243 | if top_k is not None and top_k > 0: 244 | # Vectorized top_k for batch processing 245 | top_k_logits, top_k_indices = jax.lax.top_k(next_token_logits, top_k) 246 | mask = jnp.full_like(next_token_logits, -jnp.inf) 247 | mask = jnp.take_along_axis(mask, top_k_indices, axis=-1) 248 | mask = jnp.where(jnp.arange(mask.shape[-1])[None, :] < top_k, top_k_logits, -jnp.inf) 249 | next_token_logits = jnp.full_like(next_token_logits, -jnp.inf) 250 | next_token_logits = next_token_logits.at[jnp.arange(batch_size)[:, None], top_k_indices].set(mask) 251 | 252 | if temperature > 0.0: 253 | next_token_logits = next_token_logits / temperature 254 | key, subkey = jax.random.split(key) 255 | next_token = jax.random.categorical(subkey, next_token_logits, axis=-1) 256 | else: 257 | next_token = jnp.argmax(next_token_logits, axis=-1) 258 | 259 | # Check EOS for all sequences in batch - keep on device 260 | if eos_id is not None and jnp.any(next_token == eos_id): 261 | break 262 | 263 | cur_ids = jnp.concatenate([cur_ids, next_token[:, None]], axis=1) 264 | 265 | # Process next tokens for entire batch 266 | logits, kv_cache = qwen3_forward_kv(params, next_token[:, None], cfg, kv_cache) 267 | 268 | return cur_ids 269 | 270 | def assign_layer_weights(block_params, converted_weights, qk_norm=False): 271 | weight_map = { 272 | "self_attn.q_proj.weight": ("att", "W_query", True), 273 | "self_attn.k_proj.weight": ("att", "W_key", True), 274 | "self_attn.v_proj.weight": ("att", "W_value", True), 275 | "self_attn.o_proj.weight": ("att", "out_proj", True), 276 | "input_layernorm.weight": ("norm1", "scale", False), 277 | "post_attention_layernorm.weight": ("norm2", "scale", False), 278 | "mlp.gate_proj.weight": ("ff", "gate_proj", True), 279 | "mlp.up_proj.weight": ("ff", "up_proj", True), 280 | "mlp.down_proj.weight": ("ff", "down_proj", True), 281 | } 282 | 283 | if qk_norm: 284 | weight_map.update({ 285 | "self_attn.q_norm.weight": ("att", "q_norm", "scale", False), 286 | "self_attn.k_norm.weight": ("att", "k_norm", "scale", False), 287 | }) 288 | 289 | for key, tensor in converted_weights.items(): 290 | if key in weight_map: 291 | if len(weight_map[key]) == 3: 292 | section, param, transpose = weight_map[key] 293 | block_params[section][param] = tensor.T if transpose else tensor 294 | elif len(weight_map[key]) == 4: 295 | section, subsection, param, transpose = weight_map[key] 296 | if subsection in block_params[section]: 297 | block_params[section][subsection][param] = tensor.T if transpose else tensor 298 | 299 | def load_and_convert_file_weights(file_path, jax_params, cfg): 300 | pt_params = load_file(str(file_path)) 301 | file_weights, layer_weights = {}, defaultdict(dict) 302 | 303 | for key, tensor in pt_params.items(): 304 | if key == "model.embed_tokens.weight": 305 | file_weights["tok_emb"] = tensor 306 | elif key == "model.norm.weight": 307 | file_weights["final_norm"] = tensor 308 | elif key == "lm_head.weight": 309 | file_weights["out_head"] = tensor 310 | elif key.startswith("model.layers."): 311 | parts = key.split(".") 312 | layer_idx = int(parts[2]) 313 | layer_weights[layer_idx][".".join(parts[3:])] = tensor 314 | 315 | if file_weights: 316 | converted_global = batch_convert_numpy_weights(file_weights) 317 | if "tok_emb" in converted_global: 318 | jax_params["tok_emb"] = converted_global["tok_emb"] 319 | if "final_norm" in converted_global: 320 | jax_params["final_norm"]["scale"] = converted_global["final_norm"] 321 | if "out_head" in converted_global: 322 | jax_params["out_head"] = converted_global["out_head"].T 323 | 324 | for layer_idx, weights in layer_weights.items(): 325 | if layer_idx < len(jax_params["trf_blocks"]): 326 | converted_layer = batch_convert_numpy_weights(weights) 327 | assign_layer_weights(jax_params["trf_blocks"][layer_idx], converted_layer, cfg["qk_norm"]) 328 | 329 | del pt_params 330 | cleanup_memory() 331 | 332 | def load_qwen3_weights_jax_optimized(param_config, jax_params, safetensors_files): 333 | for i, file_path in enumerate(safetensors_files): 334 | print(f"Loading file {i+1}/{len(safetensors_files)}: {file_path.name}") 335 | load_and_convert_file_weights(file_path, jax_params, param_config) 336 | cleanup_memory() 337 | 338 | if "lm_head.weight" not in [key for file_path in safetensors_files for key in load_file(str(file_path)).keys()]: 339 | if jax_params["tok_emb"] is not None: 340 | jax_params["out_head"] = jax_params["tok_emb"].T 341 | 342 | return jax_params 343 | 344 | if __name__ == "__main__": 345 | HF_REPO_ID = "Qwen/Qwen3-0.6B" 346 | 347 | model_path = download_model_from_hf(HF_REPO_ID) 348 | safetensors_files = list(Path(model_path).glob("*.safetensors")) 349 | safetensors_files.sort() 350 | 351 | tokenizer_path = model_path / "tokenizer.json" 352 | tokenizer = Qwen3Tokenizer(str(tokenizer_path) if tokenizer_path.exists() else "tokenizer.json", repo_id=HF_REPO_ID) 353 | 354 | prompt = "Give me a short introduction to large language models." 355 | input_ids = tokenizer.encode(prompt) 356 | if len(input_ids) > QWEN3_CONFIG["context_length"]: 357 | input_ids = input_ids[:QWEN3_CONFIG["context_length"]] 358 | 359 | # Keep input on device from start 360 | input_token_ids = jnp.array(input_ids) 361 | 362 | cfg = QWEN3_CONFIG 363 | key = jax.random.PRNGKey(0) 364 | params = init_qwen3_params(key, cfg) 365 | params = load_qwen3_weights_jax_optimized(cfg, params, safetensors_files) 366 | model = {"params": params, "cfg": cfg} 367 | 368 | import time 369 | start_time = time.time() 370 | 371 | # Generate with optimized function (batch_size=1 for single sequence) 372 | output_token_ids = generate_kv_optimized( 373 | model=model, idx=input_token_ids, max_new_tokens=50, 374 | context_size=QWEN3_CONFIG["context_length"], top_k=1, 375 | temperature=0, eos_id=None, batch_size=1 376 | ) 377 | 378 | generation_time = time.time() - start_time 379 | 380 | # Only move to CPU at the very end for decoding 381 | output_text = tokenizer.decode(list(output_token_ids[0])) 382 | print("\n" + "="*50) 383 | print("GENERATED TEXT :") 384 | print("="*50) 385 | print(output_text) 386 | print(f"Time taken: {generation_time:.2f}s") 387 | print("="*50) --------------------------------------------------------------------------------