├── inference.py
├── README.md
└── benchmark.py

/inference.py:
--------------------------------------------------------------------------------
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "HuggingFaceTB/SmolLM3-3B"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    attn_implementation="kernels-community/flash-attn3:flash_attention",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
inputs = tokenizer("Hello, how are you?", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Experiments with Kernels 🪛

There's a new kid on the block: Kernels, a library for building and using faster, more efficient kernels in your ML models. Kernels lets you download and use pre-compiled compute kernels without installing and compiling them from scratch.

In other words, you don't need to spend 2 hours compiling Flash Attention in your Python environment. You can just pull it from the Hub, and the kernels library takes care of fetching the right build for your runtime.

Here's how this looks in practice when running SmolLM3 with Flash Attention.

Installation:

```bash
uv pip install -U transformers kernels
```

Followed by:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "HuggingFaceTB/SmolLM3-3B"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    attn_implementation="kernels-community/flash-attn",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
inputs = tokenizer("Hello, how are you?", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

Traditionally, just installing Flash Attention could take an hour or more, leading to compute downtime, not to mention how finicky the whole process is.

Now you can just point to the Flash Attention kernel repo on the Hub and that's it!
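
If you're curious about what actually gets pulled, the kernels library can also be used on its own to fetch a kernel from the Hub and call into it directly. Here's a minimal sketch, using the `kernels-community/activation` repo and its `gelu_fast` function from the library's own examples (a CUDA GPU is assumed):

```python
import torch
from kernels import get_kernel

# Downloads (and caches) a pre-compiled kernel from the Hub; no local build step.
activation = get_kernel("kernels-community/activation")

x = torch.randn(8, 1024, dtype=torch.float16, device="cuda")
out = torch.empty_like(x)
activation.gelu_fast(out, x)  # call into the kernel like a regular Python module
print(out.shape)
```

This is the same machinery transformers taps into when you pass a Hub repo id as `attn_implementation`.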

Best part: you can swap to Flash Attention 3 with just a one-line change.

```diff
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "HuggingFaceTB/SmolLM3-3B"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
-    attn_implementation="kernels-community/flash-attn",
+    attn_implementation="kernels-community/flash-attn3:flash_attention",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
inputs = tokenizer("Hello, how are you?", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

Let's run a quick benchmark:

=== Throughput & Memory Comparison (Single H100) ===
Model: HuggingFaceTB/SmolLM3-3B
Token budgets: [512, 2048]
Batch sizes: [1, 16, 32]
torch.utils.benchmark min_run_time=2.0s | mem_repeats=3

| Batch | Tokens | eager (Latency s / Tok/s / Alloc GiB / Reserved GiB) | sdpa (Latency s / Tok/s / Alloc GiB / Reserved GiB) | flash-attn3 (Latency s / Tok/s / Alloc GiB / Reserved GiB) | flash-attn (Latency s / Tok/s / Alloc GiB / Reserved GiB) |
|-------|--------|------------------------------------------------------|-----------------------------------------------------|------------------------------------------------------------|-----------------------------------------------------------|
| 1 | 512 | 17.895 / 28.6 / 5.80 / 5.83 | 14.180 / 36.1 / 5.80 / 5.81 | 18.530 / 27.6 / 5.80 / 5.81 | 18.650 / 27.5 / 5.80 / 5.81 |
| 16 | 512 | 18.036 / 454.2 / 6.46 / 6.86 | 17.595 / 465.6 / 6.46 / 6.86 | 23.669 / 346.1 / 6.41 / 6.62 | 23.612 / 346.9 / 6.41 / 6.62 |
| 32 | 512 | 18.572 / 882.2 / 7.15 / 8.46 | 17.478 / 937.4 / 7.15 / 8.46 | 23.887 / 685.9 / 7.08 / 8.43 | 24.025 / 682.0 / 7.08 / 8.47 |
| 1 | 2048 | 71.426 / 28.7 / 5.92 / 6.04 | 56.245 / 36.4 / 5.91 / 6.02 | 73.922 / 27.7 / 5.91 / 6.02 | 74.111 / 27.6 / 5.91 / 6.02 |
| 16 | 2048 | 71.743 / 456.7 / 8.34 / 14.83 | 70.460 / 465.1 / 8.33 / 14.83 | 95.068 / 344.7 / 8.17 / 23.79 | 94.579 / 346.5 / 8.17 / 23.79 |
| 32 | 2048 | 77.889 / 841.4 / 10.91 / 39.77 | 75.058 / 873.1 / 10.90 / 39.78 | 96.181 / 681.4 / 10.61 / 78.71 | 96.758 / 677.3 / 10.61 / 78.70 |

**Legend:**
- eager = `attn_implementation=eager`
- sdpa = `attn_implementation=sdpa`
- flash-attn3 = `attn_implementation=kernels-community/flash-attn3:flash_attention`
- flash-attn = `attn_implementation=kernels-community/flash-attn`
- Each cell: `Median Latency (s) / Tokens/s / Peak Alloc (GiB) / Peak Reserved (GiB)`

--------------------------------------------------------------------------------
/benchmark.py:
--------------------------------------------------------------------------------
import os
import gc
import statistics
from dataclasses import dataclass, field

import torch
from torch.utils import benchmark
from transformers import AutoTokenizer, AutoModelForCausalLM


# ----------------------------
# Config
# ----------------------------
MODEL_ID = "HuggingFaceTB/SmolLM3-3B"
DEVICE = "cuda"  # Single H100
SEED = 1234
NUM_THREADS = torch.get_num_threads()
TOKEN_BUDGETS = [512, 2048]
BATCH_SIZES = [1, 16, 32]  # batch sizes to test (processes multiple inputs simultaneously)
WARMUP_GENERATIONS = 1  # warmup calls (not timed)
MEM_REPEATS = 3  # times to measure peak memory per setting
TIMER_MIN_RUNTIME_S = 2.0  # torch.utils.benchmark blocked_autorange budget

# Attention implementations to benchmark
ATTN_IMPLEMENTATIONS = [
    "eager",
    "sdpa",
    "kernels-community/flash-attn3:flash_attention",
    "kernels-community/flash-attn",
]


# ----------------------------
# Utilities
# ----------------------------
def set_determinism(seed=SEED):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def to_gib(x_bytes: int) -> float:
    return x_bytes / (1024 ** 3)


def clear_cuda():
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    gc.collect()


@dataclass
class Point:
    latency_s: float
    tokens: int
    toks_per_s: float
    peak_alloc_gib: float
    peak_reserved_gib: float
    batch_size: int = 1


@dataclass
class ScenarioResult:
    label: str
    batch_size: int = 1
    # map: max_new_tokens -> Point
    by_tokens: dict = field(default_factory=dict)


# ----------------------------
# Model + inputs
# ----------------------------
def load_model(attn_implementation: str):
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype="auto",
        device_map=DEVICE,
        attn_implementation=attn_implementation,
    ).eval()
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    # Ensure tokenizer has a pad token for proper batching
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer, model, model.device


def build_inputs(tokenizer, device, batch_size=1):
    # Create different prompts for each item in the batch to simulate realistic usage
    base_prompts = [
        "What is Tensor Parallelism?",
        "Explain machine learning fundamentals.",
        "How do neural networks work?",
        "What are the benefits of distributed computing?",
        "Describe the attention mechanism in transformers.",
        "What is gradient descent?",
        "How does backpropagation work?",
        "Explain the concept of overfitting.",
    ]

    # Cycle through prompts to create a batch
    batch_messages = []
    for i in range(batch_size):
        prompt = base_prompts[i % len(base_prompts)]
        messages = [{"role": "system", "content": prompt}]
        batch_messages.append(messages)

    # Apply chat template to each conversation in the batch
    batch_texts = []
    for messages in batch_messages:
        text = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=False,
        )
        batch_texts.append(text)

    # Tokenize all texts together with padding to handle variable lengths
    if batch_size == 1:
        inputs = tokenizer(
            batch_texts[0],
            return_tensors="pt",
        )
        return inputs.to(device)
    else:
        # Use padding to handle variable-length sequences
        inputs = tokenizer(
            batch_texts,
            return_tensors="pt",
            padding=True,  # Pad to the longest sequence in the batch
            truncation=True,  # Ensure we don't exceed model's max length
        )
        return {k: v.to(device) for k, v in inputs.items()}


# ----------------------------
# Generation helpers
# ----------------------------
@torch.inference_mode()
def generate_once(model, model_inputs, max_new_tokens: int):
    # eos_token_id=-1 prevents early stop; disable_compile=True matches baseline
    return model.generate(
        **model_inputs,
        do_sample=False,
        temperature=None,
        max_new_tokens=max_new_tokens,
        eos_token_id=-1,
        disable_compile=True,
        return_dict_in_generate=True,  # so we can count actual generated length
    )


def run_memory_probe(model, model_inputs, device, max_new_tokens: int):
    torch.cuda.reset_peak_memory_stats(device)
    clear_cuda()

    # One measured pass
    torch.cuda.synchronize(device)
    _out = generate_once(model, model_inputs, max_new_tokens)
    torch.cuda.synchronize(device)

    peak_alloc = torch.cuda.max_memory_allocated(device)
    peak_reserved = torch.cuda.max_memory_reserved(device)
    return to_gib(peak_alloc), to_gib(peak_reserved)


def measure_latency_with_torch_benchmark(model, model_inputs, max_new_tokens: int):
    # Closure for timing; torch.utils.benchmark will handle CUDA syncs.
    def gen_fn():
        generate_once(model, model_inputs, max_new_tokens)

    t = benchmark.Timer(
        stmt="gen_fn()",
        globals={"gen_fn": gen_fn},
        num_threads=NUM_THREADS,
    )
    m = t.blocked_autorange(min_run_time=TIMER_MIN_RUNTIME_S)
    return float(m.median)  # seconds


# ----------------------------
# Benchmark driver
# ----------------------------
def benchmark_scenario(attn_implementation: str, batch_size: int = 1) -> ScenarioResult:
    set_determinism()
    tokenizer, model, device = load_model(attn_implementation)
    inputs = build_inputs(tokenizer, device, batch_size)

    # Warmups (not measured) - use smaller token budget for larger batches to avoid OOM
    # Scale down the warmup tokens based on batch size to prevent memory issues
    if batch_size >= 32:
        warmup_tokens = min(256, min(TOKEN_BUDGETS) // 2)  # Very conservative for large batches
    elif batch_size >= 16:
        warmup_tokens = min(TOKEN_BUDGETS) // 2  # Half the smallest budget for medium-large batches
    elif batch_size > 4:
        warmup_tokens = min(TOKEN_BUDGETS)  # Smallest budget for moderately large batches
    else:
        warmup_tokens = TOKEN_BUDGETS[0]  # First budget for small batches
    for _ in range(WARMUP_GENERATIONS):
        _ = generate_once(model, inputs, max_new_tokens=warmup_tokens)
    torch.cuda.synchronize(device)

    result = ScenarioResult(
        label=f"attn_implementation={attn_implementation}",
        batch_size=batch_size,
    )

    for toks in TOKEN_BUDGETS:
        try:
            # Clear memory before each token budget test
            clear_cuda()

            # Timing via torch.utils.benchmark
            latency_s = measure_latency_with_torch_benchmark(model, inputs, toks)

            # Actual generated token count (sanity; should equal toks per item in batch)
            out = generate_once(model, inputs, toks)
            actual_tokens_per_item = (out.sequences.shape[1] - inputs["input_ids"].shape[1])
            total_tokens = actual_tokens_per_item * batch_size
            del out
            torch.cuda.synchronize(device)

            # Memory: take median of multiple probes
            allocs, reserveds = [], []
            for _ in range(MEM_REPEATS):
                pa, pr = run_memory_probe(model, inputs, device, toks)
                allocs.append(pa)
                reserveds.append(pr)

            med_alloc = statistics.median(allocs)
            med_reserved = statistics.median(reserveds)

            result.by_tokens[toks] = Point(
                latency_s=latency_s,
                tokens=total_tokens,
                toks_per_s=(total_tokens / latency_s) if latency_s > 0 else float("nan"),
                peak_alloc_gib=med_alloc,
                peak_reserved_gib=med_reserved,
                batch_size=batch_size,
            )
        except torch.cuda.OutOfMemoryError as e:
            print(f"OOM for batch_size={batch_size}, tokens={toks}: {e}")
            # Clear memory and continue with next token budget
            clear_cuda()
            continue
        except Exception as e:
            print(f"Error for batch_size={batch_size}, tokens={toks}: {e}")
            clear_cuda()
            continue

    # Cleanup
    del tokenizer, model, inputs
    clear_cuda()
    return result


def print_comparison(results):
    # Pretty table per token budget
    print("\n=== Throughput & Memory Comparison (Single H100) ===")
    print(f"Model: {MODEL_ID}")
    print(f"Token budgets: {TOKEN_BUDGETS}")
    print(f"Batch sizes: {BATCH_SIZES}")
    print(f"torch.utils.benchmark min_run_time={TIMER_MIN_RUNTIME_S}s | mem_repeats={MEM_REPEATS}")
    print("-" * 140)

    header = (
        f"{'Attention Implementation':<40} {'Batch':>6} {'Tokens':>8} "
        f"{'Median Latency (s)':>20} {'Tokens/s':>12} "
        f"{'Peak Alloc (GiB)':>18} {'Peak Reserved (GiB)':>20}"
    )
    print(header)
    print("-" * 140)

    for toks in TOKEN_BUDGETS:
        for res in results:
            if toks in res.by_tokens:
                p = res.by_tokens[toks]
                print(
                    f"{res.label:<40} {res.batch_size:>6d} {toks:>8d} "
                    f"{p.latency_s:>20.3f} {p.toks_per_s:>12.1f} "
                    f"{p.peak_alloc_gib:>18.2f} {p.peak_reserved_gib:>20.2f}"
                )
            else:
                print(
                    f"{res.label:<40} {res.batch_size:>6d} {toks:>8d} "
                    f"{'OOM/FAILED':>20} {'N/A':>12} "
                    f"{'N/A':>18} {'N/A':>20}"
                )
        print("-" * 140)


# ----------------------------
# Main
# ----------------------------
if __name__ == "__main__":
    # Optional: ensure only one visible GPU (uncomment if needed)
    # os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    assert torch.cuda.is_available(), "CUDA is required"
    assert torch.device(DEVICE).type == "cuda"
    torch.set_num_threads(NUM_THREADS)

    results = []
    for attn_impl in ATTN_IMPLEMENTATIONS:
        for batch_size in BATCH_SIZES:
            print(f"\n>>> Benchmarking with attn_implementation={attn_impl}, batch_size={batch_size}")
            try:
                res = benchmark_scenario(attn_implementation=attn_impl, batch_size=batch_size)
                results.append(res)
            except Exception as e:
                print(f"Failed to benchmark {attn_impl} with batch_size={batch_size}: {e}")
                continue

    if results:
        print_comparison(results)
    else:
        print("No successful benchmarks completed.")
--------------------------------------------------------------------------------