├── results ├── benchmark_plot.png ├── performance_report.txt └── results.json ├── .gitignore ├── setup.py ├── python_bindings.cpp ├── utils.py ├── run_all.sh ├── kernels.cu ├── README.md └── benchmark.py /results/benchmark_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shreshthkapai/cuda_latency_benchmark/HEAD/results/benchmark_plot.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore Python virtual environments 2 | cuda_env/ 3 | 4 | # Ignore build artifacts 5 | build/ 6 | *.so 7 | *.pyc 8 | __pycache__/ 9 | 10 | # Ignore IDE settings 11 | .vscode/ 12 | 13 | # Ignore logs, outputs, and generated results 14 | results/ 15 | .ninja_log 16 | *.log 17 | -------------------------------------------------------------------------------- /results/performance_report.txt: -------------------------------------------------------------------------------- 1 | ================================================================================ 2 | 🚀 GPU TASK QUEUE PERFORMANCE REPORT 3 | ================================================================================ 4 | 5 | 🏆 Best Performer: gemv_b32_i64_o32 (0.011ms median) 6 | 🐌 Worst Performer: price_b32_a64_f32 (0.044ms median) 7 | ⚡ Average Speedup: 3.8x 8 | 🚀 Maximum Speedup: 7.3x 9 | 10 | 📊 DETAILED RESULTS: 11 | -------------------------------------------------------------------------------- 12 | 13 | gemv_b32_i64_o32: 14 | Latency: 0.011ms (median), 0.076ms (P95) 15 | Throughput: 93563 ops/sec 16 | 🚀 Speedup: 7.3x (629.5% improvement) 17 | Stability: ±0.032ms std dev 18 | 19 | gemv_b32_i64_o64: 20 | Latency: 0.012ms (median), 0.147ms (P95) 21 | Throughput: 82672 ops/sec 22 | 🚀 Speedup: 5.2x (424.2% improvement) 23 | Stability: ±0.049ms std dev 24 | 25 | softmax_b32_d64: 26 | Latency: 0.041ms (median), 0.178ms (P95) 27 | Throughput: 24357 ops/sec 28 | 🚀 Speedup: 1.3x (29.7% improvement) 29 | Stability: ±0.042ms std dev 30 | 31 | price_b32_a64_f32: 32 | Latency: 0.044ms (median), 0.176ms (P95) 33 | Throughput: 22498 ops/sec 34 | 🚀 Speedup: 1.5x (53.6% improvement) 35 | Stability: ±0.042ms std dev -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | import torch 5 | import pybind11 6 | 7 | # Safe fallback for CUDA_HOME 8 | if "CUDA_HOME" not in os.environ: 9 | os.environ["CUDA_HOME"] = torch.utils.cpp_extension.CUDA_HOME or "/usr/local/cuda" 10 | 11 | # Get compute capability (e.g. 
sm_75 for GTX 1650) 12 | def get_cuda_arch(): 13 | if torch.cuda.is_available(): 14 | major, minor = torch.cuda.get_device_capability() 15 | return f"-gencode=arch=compute_{major}{minor},code=sm_{major}{minor}" 16 | return "-gencode=arch=compute_75,code=sm_75" 17 | 18 | cuda_flags = [ 19 | "--use_fast_math", 20 | "-O3", 21 | get_cuda_arch(), 22 | "--extended-lambda", 23 | "-DNVTX_DISABLE", 24 | ] 25 | 26 | cpp_flags = [ 27 | "-O3", 28 | "-std=c++17", 29 | "-DWITH_CUDA", 30 | "-DNVTX_DISABLE", 31 | ] 32 | 33 | cuda_extension = CUDAExtension( 34 | name="cuda_task_queue", 35 | sources=[ 36 | "kernels.cu", 37 | "python_bindings.cpp" 38 | ], 39 | extra_compile_args={ 40 | "cxx": cpp_flags, 41 | "nvcc": cuda_flags, 42 | }, 43 | include_dirs=[ 44 | pybind11.get_cmake_dir() + "/../../../include", 45 | ], 46 | ) 47 | 48 | setup( 49 | name="cuda-latency-bench", 50 | version="1.0.0", 51 | description="Sub-millisecond GPU task queue for real-time inference", 52 | author="GPU Performance Engineer", 53 | ext_modules=[cuda_extension], 54 | cmdclass={"build_ext": BuildExtension}, 55 | python_requires=">=3.8", 56 | install_requires=[ 57 | "torch>=1.12.0", 58 | "numpy>=1.21.0", 59 | "pybind11>=2.6.0", 60 | ], 61 | zip_safe=False, 62 | ) 63 | -------------------------------------------------------------------------------- /python_bindings.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | // Forward declarations from kernels.cu 8 | extern "C" { 9 | void launch_batched_gemv( 10 | const float* weights, const float* inputs, float* outputs, 11 | int batch_size, int input_dim, int output_dim, 12 | cudaStream_t stream 13 | ); 14 | 15 | void launch_batched_softmax( 16 | const float* inputs, float* outputs, 17 | int batch_size, int dim, 18 | cudaStream_t stream 19 | ); 20 | 21 | void launch_price_vector_processing( 22 | const float* prices, const float* weights, float* features, 23 | int batch_size, int n_assets, int n_features, 24 | cudaStream_t stream 25 | ); 26 | } 27 | 28 | // PyTorch tensor wrappers for zero-copy GPU operations 29 | torch::Tensor batched_gemv( 30 | torch::Tensor weights, 31 | torch::Tensor inputs, 32 | torch::Tensor outputs 33 | ) { 34 | TORCH_CHECK(weights.is_cuda(), "weights must be CUDA tensor"); 35 | TORCH_CHECK(inputs.is_cuda(), "inputs must be CUDA tensor"); 36 | TORCH_CHECK(outputs.is_cuda(), "outputs must be CUDA tensor"); 37 | 38 | int batch_size = inputs.size(0); 39 | int input_dim = inputs.size(1); 40 | int output_dim = outputs.size(1); 41 | 42 | launch_batched_gemv( 43 | weights.data_ptr(), 44 | inputs.data_ptr(), 45 | outputs.data_ptr(), 46 | batch_size, input_dim, output_dim, 47 | c10::cuda::getCurrentCUDAStream() 48 | ); 49 | 50 | return outputs; 51 | } 52 | 53 | torch::Tensor batched_softmax(torch::Tensor inputs, torch::Tensor outputs) { 54 | TORCH_CHECK(inputs.is_cuda(), "inputs must be CUDA tensor"); 55 | TORCH_CHECK(outputs.is_cuda(), "outputs must be CUDA tensor"); 56 | 57 | int batch_size = inputs.size(0); 58 | int dim = inputs.size(1); 59 | 60 | launch_batched_softmax( 61 | inputs.data_ptr(), 62 | outputs.data_ptr(), 63 | batch_size, dim, 64 | c10::cuda::getCurrentCUDAStream() 65 | ); 66 | 67 | return outputs; 68 | } 69 | 70 | torch::Tensor process_price_vectors( 71 | torch::Tensor prices, 72 | torch::Tensor weights, 73 | torch::Tensor features 74 | ) { 75 | TORCH_CHECK(prices.is_cuda(), "prices must be CUDA tensor"); 76 | TORCH_CHECK(weights.is_cuda(), "weights must 
be CUDA tensor"); 77 | TORCH_CHECK(features.is_cuda(), "features must be CUDA tensor"); 78 | 79 | int batch_size = prices.size(0); 80 | int n_assets = prices.size(1); 81 | int n_features = features.size(1); 82 | 83 | launch_price_vector_processing( 84 | prices.data_ptr(), 85 | weights.data_ptr(), 86 | features.data_ptr(), 87 | batch_size, n_assets, n_features, 88 | c10::cuda::getCurrentCUDAStream() 89 | ); 90 | 91 | return features; 92 | } 93 | 94 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 95 | m.doc() = "CUDA task queue kernels for sub-millisecond inference"; 96 | 97 | m.def("batched_gemv", &batched_gemv, "Batched matrix-vector multiply"); 98 | m.def("batched_softmax", &batched_softmax, "Batched softmax operation"); 99 | m.def("process_price_vectors", &process_price_vectors, "Price vector processing"); 100 | } 101 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import csv 3 | import json 4 | import time 5 | from pathlib import Path 6 | from typing import Dict 7 | 8 | def save_results_csv(results_data: Dict, filename: str = "benchmark_results.csv") -> None: 9 | Path(filename).parent.mkdir(parents=True, exist_ok=True) 10 | 11 | results_dict = results_data.get('optimized', {}) 12 | baseline_dict = results_data.get('baseline', {}) 13 | speedup_dict = results_data.get('speedup', {}) 14 | 15 | if not results_dict: 16 | print("No optimized results to save.") 17 | return 18 | 19 | all_keys = set() 20 | for stats in results_dict.values(): 21 | all_keys.update(stats.keys()) 22 | 23 | fieldnames = ['config'] + sorted(all_keys) 24 | if baseline_dict: 25 | fieldnames += ['baseline_median_ms'] 26 | if speedup_dict: 27 | extra_keys = set() 28 | for s in speedup_dict.values(): 29 | extra_keys.update(s.keys()) 30 | fieldnames += sorted(extra_keys) 31 | 32 | with open(filename, 'w', newline='') as csvfile: 33 | writer = csv.DictWriter(csvfile, fieldnames=fieldnames) 34 | writer.writeheader() 35 | 36 | for config, stats in results_dict.items(): 37 | row = {'config': config, **stats} 38 | if config in baseline_dict: 39 | row['baseline_median_ms'] = baseline_dict[config].get('median_ms') 40 | if config in speedup_dict: 41 | row.update(speedup_dict[config]) 42 | writer.writerow(row) 43 | 44 | print(f"💾 Results successfully exported to {filename}") 45 | 46 | def save_results_json(results_data: Dict, filename: str = "benchmark_results.json") -> None: 47 | Path(filename).parent.mkdir(parents=True, exist_ok=True) 48 | output = { 49 | 'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'), 50 | 'device_info': { 51 | 'gpu_name': torch.cuda.get_device_name() if torch.cuda.is_available() else 'N/A', 52 | 'cuda_version': torch.version.cuda if torch.cuda.is_available() else 'N/A', 53 | 'pytorch_version': torch.__version__ 54 | }, 55 | 'benchmark_results': results_data 56 | } 57 | with open(filename, 'w') as f: 58 | json.dump(output, f, indent=2) 59 | print(f"💾 Results and metadata successfully saved to {filename}") 60 | 61 | def merge_results(results_data): 62 | merged = {} 63 | for k in results_data.get("optimized", {}): 64 | merged[k] = results_data["optimized"][k].copy() 65 | if k in results_data.get("speedup", {}): 66 | merged[k].update(results_data["speedup"][k]) 67 | return merged 68 | 69 | def format_performance_report(results_data: Dict) -> str: 70 | report_lines = [] 71 | report_lines.append("="*80) 72 | report_lines.append("🚀 GPU TASK QUEUE PERFORMANCE REPORT") 73 | 
report_lines.append("="*80) 74 | 75 | if not results_data: 76 | return "No results available." 77 | 78 | latencies = {k: v['median_ms'] for k, v in results_data.items() if 'median_ms' in v} 79 | if not latencies: 80 | return "No valid latency data found." 81 | 82 | best_kernel = min(latencies.items(), key=lambda x: x[1]) 83 | worst_kernel = max(latencies.items(), key=lambda x: x[1]) 84 | speedups = [v.get("speedup_median", 0.0) for v in results_data.values()] 85 | avg_speedup = sum(speedups) / len(speedups) 86 | geo_speedup = 1.0 87 | for s in speedups: 88 | geo_speedup *= s 89 | geo_speedup = geo_speedup ** (1 / len(speedups)) if speedups else 1.0 90 | max_speedup = max(speedups) 91 | 92 | report_lines.append("") 93 | report_lines.append(f"🏆 Best Performer: {best_kernel[0]} ({best_kernel[1]:.3f}ms median)") 94 | report_lines.append(f"🐌 Worst Performer: {worst_kernel[0]} ({worst_kernel[1]:.3f}ms median)") 95 | report_lines.append(f"⚡ Average Speedup: {avg_speedup:.1f}x") 96 | report_lines.append(f"🚀 Maximum Speedup: {max_speedup:.1f}x") 97 | report_lines.append("") 98 | 99 | report_lines.append("📊 DETAILED RESULTS:") 100 | report_lines.append("-"*80) 101 | 102 | for kernel, stats in results_data.items(): 103 | median_ms = stats.get('median_ms', 0.0) 104 | p95_ms = stats.get('p95_ms', 0.0) 105 | std_ms = stats.get('std_ms', 0.0) 106 | throughput = 1000.0 / median_ms if median_ms > 0 else 0.0 107 | 108 | report_lines.append(f"\n{kernel}:") 109 | report_lines.append(f" Latency: {median_ms:.3f}ms (median), {p95_ms:.3f}ms (P95)") 110 | report_lines.append(f" Throughput: {throughput:.0f} ops/sec") 111 | report_lines.append(f" 🚀 Speedup: {stats.get('speedup_median', 0.0):.1f}x ({stats.get('improvement_pct', 0.0):.1f}% improvement)") 112 | report_lines.append(f" Stability: ±{std_ms:.3f}ms std dev") 113 | 114 | return "\n".join(report_lines) 115 | -------------------------------------------------------------------------------- /results/results.json: -------------------------------------------------------------------------------- 1 | { 2 | "timestamp": "2025-07-23 20:57:21", 3 | "device_info": { 4 | "gpu_name": "NVIDIA GeForce GTX 1650", 5 | "cuda_version": "12.1", 6 | "pytorch_version": "2.5.1+cu121" 7 | }, 8 | "benchmark_results": { 9 | "optimized": { 10 | "gemv_b32_i64_o32": { 11 | "kernel": "CUDA_GEMV", 12 | "mean_ms": 0.024624976025894283, 13 | "median_ms": 0.010688000358641148, 14 | "p95_ms": 0.07580960318446159, 15 | "p99_ms": 0.15981951504945754, 16 | "min_ms": 0.004991999827325344, 17 | "max_ms": 0.2991360127925873, 18 | "std_ms": 0.0315219763849146, 19 | "samples": 2000 20 | }, 21 | "gemv_b32_i64_o64": { 22 | "kernel": "CUDA_GEMV", 23 | "mean_ms": 0.04235577614884824, 24 | "median_ms": 0.012095999903976917, 25 | "p95_ms": 0.14745600521564484, 26 | "p99_ms": 0.15731488570570945, 27 | "min_ms": 0.006496000103652477, 28 | "max_ms": 0.7761920094490051, 29 | "std_ms": 0.04871345477368859, 30 | "samples": 2000 31 | }, 32 | "softmax_b32_d64": { 33 | "kernel": "CUDA_Softmax", 34 | "mean_ms": 0.059247296028770505, 35 | "median_ms": 0.04105599969625473, 36 | "p95_ms": 0.17820799350738525, 37 | "p99_ms": 0.18768032595515252, 38 | "min_ms": 0.019680000841617584, 39 | "max_ms": 0.21939200162887573, 40 | "std_ms": 0.041830511446051113, 41 | "samples": 2000 42 | }, 43 | "price_b32_a64_f32": { 44 | "kernel": "CUDA_PriceVectors", 45 | "mean_ms": 0.05726216000504792, 46 | "median_ms": 0.04444799944758415, 47 | "p95_ms": 0.17598080560564994, 48 | "p99_ms": 0.20976223543286324, 49 | "min_ms": 
0.018688000738620758, 50 | "max_ms": 0.2666560113430023, 51 | "std_ms": 0.041545308810031704, 52 | "samples": 2000 53 | } 54 | }, 55 | "baseline": { 56 | "gemv_b32_i64_o32": { 57 | "kernel": "Baseline_GEMV", 58 | "mean_ms": 0.11040817595482803, 59 | "median_ms": 0.07796799764037132, 60 | "p95_ms": 0.2027519941329956, 61 | "p99_ms": 0.22732800245285034, 62 | "min_ms": 0.004927999805659056, 63 | "max_ms": 0.5736640095710754, 64 | "std_ms": 0.0551369126225431, 65 | "samples": 2000 66 | }, 67 | "gemv_b32_i64_o64": { 68 | "kernel": "Baseline_GEMV", 69 | "mean_ms": 0.08599102384201252, 70 | "median_ms": 0.06340799853205681, 71 | "p95_ms": 0.20749280005693435, 72 | "p99_ms": 0.2212198331952095, 73 | "min_ms": 0.007615999784320593, 74 | "max_ms": 0.7986559867858887, 75 | "std_ms": 0.05551890511137104, 76 | "samples": 2000 77 | }, 78 | "softmax_b32_d64": { 79 | "kernel": "Baseline_Softmax", 80 | "mean_ms": 0.05727793595800176, 81 | "median_ms": 0.053247999399900436, 82 | "p95_ms": 0.1754239946603775, 83 | "p99_ms": 0.18761984542012214, 84 | "min_ms": 0.004575999919325113, 85 | "max_ms": 0.7603520154953003, 86 | "std_ms": 0.04476570688447945, 87 | "samples": 2000 88 | }, 89 | "price_b32_a64_f32": { 90 | "kernel": "Baseline_PriceVectors", 91 | "mean_ms": 0.07857756791170686, 92 | "median_ms": 0.06828799843788147, 93 | "p95_ms": 0.19767519310116766, 94 | "p99_ms": 0.23835199683904645, 95 | "min_ms": 0.00854399986565113, 96 | "max_ms": 0.3007040023803711, 97 | "std_ms": 0.050395789304012936, 98 | "samples": 2000 99 | } 100 | }, 101 | "speedup": { 102 | "gemv_b32_i64_o32": { 103 | "speedup_median": 7.29490971408276, 104 | "speedup_mean": 4.4835851145084895, 105 | "speedup_p95": 2.6744895846460905, 106 | "baseline_median_ms": 0.07796799764037132, 107 | "optimized_median_ms": 0.010688000358641148, 108 | "improvement_pct": 629.490971408276 109 | }, 110 | "gemv_b32_i64_o64": { 111 | "speedup_median": 5.242063412319436, 112 | "speedup_mean": 2.030207722786608, 113 | "speedup_p95": 1.407150558252881, 114 | "baseline_median_ms": 0.06340799853205681, 115 | "optimized_median_ms": 0.012095999903976917, 116 | "improvement_pct": 424.2063412319436 117 | }, 118 | "softmax_b32_d64": { 119 | "speedup_median": 1.296960244394144, 120 | "speedup_mean": 0.9667603383990314, 121 | "speedup_p95": 0.9843778116109456, 122 | "baseline_median_ms": 0.053247999399900436, 123 | "optimized_median_ms": 0.04105599969625473, 124 | "improvement_pct": 29.696024439414394 125 | }, 126 | "price_b32_a64_f32": { 127 | "speedup_median": 1.5363570753822324, 128 | "speedup_mean": 1.3722424705037304, 129 | "speedup_p95": 1.123277009790102, 130 | "baseline_median_ms": 0.06828799843788147, 131 | "optimized_median_ms": 0.04444799944758415, 132 | "improvement_pct": 53.635707538223244 133 | } 134 | } 135 | } 136 | } -------------------------------------------------------------------------------- /run_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # CUDA Task Queue - Production Build & Benchmark Automation 4 | # Directory: scripts/run_all.sh 5 | 6 | set -e # Exit on any error 7 | 8 | # Colors for output 9 | RED='\033[0;31m' 10 | GREEN='\033[0;32m' 11 | BLUE='\033[0;34m' 12 | YELLOW='\033[1;33m' 13 | NC='\033[0m' # No Color 14 | 15 | # Configuration 16 | DEFAULT_BATCH_SIZE=32 17 | DEFAULT_DIM=64 18 | DEFAULT_REPEATS=100 19 | RESULTS_DIR="./results" 20 | BUILD_DIR="./build" 21 | 22 | # Parse command line arguments 23 | BATCH_SIZE=${1:-$DEFAULT_BATCH_SIZE} 24 | DIM=${2:-$DEFAULT_DIM} 25 | 
REPEATS=${3:-$DEFAULT_REPEATS} 26 | 27 | print_header() { 28 | echo -e "${BLUE}================================${NC}" 29 | echo -e "${BLUE}🚀 CUDA Task Queue Benchmark${NC}" 30 | echo -e "${BLUE}================================${NC}" 31 | echo -e "Batch Size: ${YELLOW}${BATCH_SIZE}${NC}" 32 | echo -e "Dimensions: ${YELLOW}${DIM}${NC}" 33 | echo -e "Repeats: ${YELLOW}${REPEATS}${NC}" 34 | echo -e "${BLUE}================================${NC}" 35 | } 36 | 37 | cleanup_build() { 38 | echo -e "${YELLOW}[0/4] Cleaning previous builds...${NC}" 39 | rm -rf ${BUILD_DIR} *.so cuda_task_queue.* || true 40 | mkdir -p ${RESULTS_DIR} 41 | } 42 | 43 | compile_kernels() { 44 | echo -e "${BLUE}[1/4] Compiling CUDA kernels...${NC}" 45 | 46 | # Check CUDA availability 47 | if ! command -v nvcc &> /dev/null; then 48 | echo -e "${RED}❌ NVCC not found. Please install CUDA toolkit.${NC}" 49 | exit 1 50 | fi 51 | 52 | # Check PyTorch CUDA support 53 | python -c "import torch; assert torch.cuda.is_available(), 'PyTorch CUDA not available'" 2>/dev/null || { 54 | echo -e "${RED}❌ PyTorch CUDA support not detected.${NC}" 55 | exit 1 56 | } 57 | 58 | # Compile extension 59 | python setup.py build_ext --inplace || { 60 | echo -e "${RED}❌ CUDA compilation failed.${NC}" 61 | exit 1 62 | } 63 | 64 | echo -e "${GREEN}✅ CUDA kernels compiled successfully${NC}" 65 | } 66 | 67 | run_benchmark() { 68 | echo -e "${BLUE}[2/4] Running performance benchmark...${NC}" 69 | 70 | # Launch benchmark with the specified configuration 71 | python -c " 72 | import sys 73 | from benchmark import BenchmarkConfig, GPUTaskQueueBenchmark 74 | 75 | config = BenchmarkConfig( 76 | batch_sizes=[${BATCH_SIZE}], 77 | input_dims=[${DIM}], 78 | output_dims=[${DIM}//2, ${DIM}], 79 | num_trials=${REPEATS}, 80 | run_baseline=True # Enable baseline comparison 81 | ) 82 | 83 | benchmark = GPUTaskQueueBenchmark(config) 84 | results = benchmark.run_comprehensive_benchmark() 85 | benchmark.print_summary() 86 | benchmark.plot_results('${RESULTS_DIR}/benchmark_plot.png') 87 | 88 | # Save results in CSV and JSON formats 89 | from utils import save_results_csv, save_results_json 90 | save_results_csv(results, '${RESULTS_DIR}/results.csv') 91 | save_results_json(results, '${RESULTS_DIR}/results.json') 92 | " || { 93 | echo -e "${RED}❌ Benchmark execution failed.${NC}" 94 | exit 1 95 | } 96 | 97 | echo -e "${GREEN}✅ Benchmark completed successfully${NC}" 98 | } 99 | 100 | generate_report() { 101 | echo -e "${BLUE}[3/4] Generating performance report...${NC}" 102 | 103 | python -c " 104 | import json 105 | from utils import format_performance_report 106 | 107 | with open('${RESULTS_DIR}/results.json', 'r') as f: 108 | data = json.load(f) 109 | 110 | from utils import format_performance_report, merge_results 111 | 112 | report = format_performance_report(merge_results(data['benchmark_results'])) 113 | print(report) 114 | 115 | with open('${RESULTS_DIR}/performance_report.txt', 'w') as f: 116 | f.write(report) 117 | " 118 | 119 | echo -e "${GREEN}✅ Report generated: ${RESULTS_DIR}/performance_report.txt${NC}" 120 | } 121 | 122 | finalize() { 123 | echo -e "${BLUE}[4/4] Finalizing results...${NC}" 124 | 125 | # List generated files 126 | echo -e "\n${GREEN}📊 Generated Files:${NC}" 127 | ls -la ${RESULTS_DIR}/ | grep -E '\.(csv|json|png|txt)$' | while read line; do 128 | echo -e " ${YELLOW}•${NC} $line" 129 | done 130 | 131 | # Quick performance summary 132 | if [ -f "${RESULTS_DIR}/results.json" ]; then 133 | echo -e "\n${GREEN}⚡ Quick Summary:${NC}" 134 | python -c " 
135 | import json 136 | with open('${RESULTS_DIR}/results.json', 'r') as f: 137 | data = json.load(f) 138 | 139 | # Handle new nested result format 140 | if 'optimized' in data: 141 | results = data['optimized'] 142 | speedup_data = data.get('speedup', {}) 143 | else: 144 | results = data.get('results', {}) 145 | speedup_data = {} 146 | 147 | if results: 148 | best_cuda = min([v for v in results.values() if 'CUDA' in v.get('kernel', '')], 149 | key=lambda x: x['median_ms'], default=None) 150 | if best_cuda: 151 | print(f' Best CUDA Kernel: {best_cuda[\"median_ms\"]:.3f}ms median latency') 152 | print(f' Throughput: {1000/best_cuda[\"median_ms\"]:.0f} ops/sec') 153 | 154 | # Show speedup if available 155 | if speedup_data: 156 | max_speedup = max([s['speedup_median'] for s in speedup_data.values()]) 157 | print(f' Maximum Speedup: {max_speedup:.1f}x over baseline') 158 | " 159 | fi 160 | 161 | echo -e "\n${GREEN}🎯 Benchmark completed! Check ${RESULTS_DIR}/ for detailed results.${NC}" 162 | } 163 | 164 | # Error handler 165 | handle_error() { 166 | echo -e "\n${RED}❌ Error occurred. Cleaning up...${NC}" 167 | cleanup_build > /dev/null 2>&1 || true 168 | exit 1 169 | } 170 | 171 | # Set error trap 172 | trap handle_error ERR 173 | 174 | # Main execution 175 | main() { 176 | print_header 177 | cleanup_build 178 | compile_kernels 179 | run_benchmark 180 | generate_report 181 | finalize 182 | } 183 | 184 | # Run with usage info 185 | if [[ "$1" == "--help" ]] || [[ "$1" == "-h" ]]; then 186 | echo "Usage: $0 [BATCH_SIZE] [DIM] [REPEATS]" 187 | echo "Example: $0 64 128 200" 188 | echo "Defaults: batch_size=32, dim=64, repeats=100" 189 | exit 0 190 | fi 191 | 192 | main "$@" 193 | 194 | -------------------------------------------------------------------------------- /kernels.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #ifdef _WIN32 5 | #define NVTX_DISABLE 6 | #endif 7 | 8 | #ifdef NVTX_DISABLE 9 | #define nvtxRangePush(name) 10 | #define nvtxRangePop() 11 | #else 12 | #include 13 | #endif 14 | 15 | namespace cg = cooperative_groups; 16 | 17 | // Keep the winning GEMV kernel exactly as it was 18 | __global__ void batched_gemv_kernel( 19 | const float* __restrict__ weights, 20 | const float* __restrict__ inputs, 21 | float* __restrict__ outputs, 22 | int batch_size, 23 | int input_dim, 24 | int output_dim 25 | ) { 26 | int batch_idx = blockIdx.x; 27 | int output_idx = threadIdx.x; 28 | 29 | if (batch_idx >= batch_size || output_idx >= output_dim) return; 30 | 31 | extern __shared__ float shared_input[]; 32 | auto block = cg::this_thread_block(); 33 | 34 | const float* input_ptr = inputs + batch_idx * input_dim; 35 | int tid = threadIdx.x; 36 | int num_threads = blockDim.x; 37 | 38 | if ((input_dim & 3) == 0 && ((uintptr_t)input_ptr & 15) == 0) { 39 | float4* shared_input4 = (float4*)shared_input; 40 | const float4* input_ptr4 = (const float4*)input_ptr; 41 | 42 | for (int i = tid; i * 4 < input_dim; i += num_threads) { 43 | if (i * 4 < input_dim) { 44 | shared_input4[i] = __ldg(&input_ptr4[i]); 45 | } 46 | } 47 | } else { 48 | for (int i = tid; i < input_dim; i += num_threads) { 49 | shared_input[i] = __ldg(&input_ptr[i]); 50 | } 51 | } 52 | __syncthreads(); 53 | 54 | float result = 0.0f; 55 | const float* weight_row = weights + batch_idx * input_dim * output_dim + output_idx; 56 | 57 | int i = 0; 58 | for (; i <= input_dim - 8; i += 8) { 59 | result += shared_input[i] * weight_row[i * output_dim] + 60 | shared_input[i+1] * 
weight_row[(i+1) * output_dim] + 61 | shared_input[i+2] * weight_row[(i+2) * output_dim] + 62 | shared_input[i+3] * weight_row[(i+3) * output_dim] + 63 | shared_input[i+4] * weight_row[(i+4) * output_dim] + 64 | shared_input[i+5] * weight_row[(i+5) * output_dim] + 65 | shared_input[i+6] * weight_row[(i+6) * output_dim] + 66 | shared_input[i+7] * weight_row[(i+7) * output_dim]; 67 | } 68 | for (; i < input_dim; i++) { 69 | result += shared_input[i] * weight_row[i * output_dim]; 70 | } 71 | 72 | outputs[batch_idx * output_dim + output_idx] = result; 73 | } 74 | 75 | // Back to working softmax with just block size tweak 76 | __global__ void batched_softmax_kernel( 77 | const float* __restrict__ inputs, 78 | float* __restrict__ outputs, 79 | int batch_size, 80 | int dim 81 | ) { 82 | int batch_idx = blockIdx.x; 83 | int tid = threadIdx.x; 84 | 85 | if (batch_idx >= batch_size) return; 86 | 87 | extern __shared__ float sdata[]; 88 | const float* input_batch = inputs + batch_idx * dim; 89 | float* output_batch = outputs + batch_idx * dim; 90 | 91 | // Simple max reduction 92 | float local_max = -INFINITY; 93 | for (int i = tid; i < dim; i += blockDim.x) { 94 | local_max = fmaxf(local_max, input_batch[i]); 95 | } 96 | 97 | sdata[tid] = local_max; 98 | __syncthreads(); 99 | 100 | for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) { 101 | if (tid < stride) { 102 | sdata[tid] = fmaxf(sdata[tid], sdata[tid + stride]); 103 | } 104 | __syncthreads(); 105 | } 106 | float global_max = sdata[0]; 107 | 108 | // Simple exp and sum 109 | float local_sum = 0.0f; 110 | for (int i = tid; i < dim; i += blockDim.x) { 111 | float exp_val = __expf(input_batch[i] - global_max); 112 | output_batch[i] = exp_val; 113 | local_sum += exp_val; 114 | } 115 | 116 | sdata[tid] = local_sum; 117 | __syncthreads(); 118 | 119 | for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) { 120 | if (tid < stride) { 121 | sdata[tid] += sdata[tid + stride]; 122 | } 123 | __syncthreads(); 124 | } 125 | float total_sum = sdata[0]; 126 | 127 | // Simple normalization 128 | float inv_sum = __fdividef(1.0f, total_sum); 129 | for (int i = tid; i < dim; i += blockDim.x) { 130 | output_batch[i] *= inv_sum; 131 | } 132 | } 133 | 134 | // Back to simple price vectors 135 | __global__ void process_price_vectors_kernel( 136 | const float* __restrict__ prices, 137 | const float* __restrict__ weights, 138 | float* __restrict__ features, 139 | int batch_size, 140 | int n_assets, 141 | int n_features 142 | ) { 143 | int batch_idx = blockIdx.x; 144 | int feature_idx = threadIdx.x; 145 | 146 | if (batch_idx >= batch_size || feature_idx >= n_features) return; 147 | 148 | const float* price_vector = prices + batch_idx * n_assets; 149 | float result = 0.0f; 150 | 151 | // Simple dot product with unrolling 152 | int i = 0; 153 | for (; i <= n_assets - 4; i += 4) { 154 | result += price_vector[i] * weights[i * n_features + feature_idx] + 155 | price_vector[i+1] * weights[(i+1) * n_features + feature_idx] + 156 | price_vector[i+2] * weights[(i+2) * n_features + feature_idx] + 157 | price_vector[i+3] * weights[(i+3) * n_features + feature_idx]; 158 | } 159 | for (; i < n_assets; i++) { 160 | result += price_vector[i] * weights[i * n_features + feature_idx]; 161 | } 162 | 163 | features[batch_idx * n_features + feature_idx] = result; 164 | } 165 | 166 | extern "C" { 167 | 168 | void launch_batched_gemv( 169 | const float* weights, const float* inputs, float* outputs, 170 | int batch_size, int input_dim, int output_dim, 171 | cudaStream_t stream = 0 
172 | ) { 173 | nvtxRangePush("batched_gemv"); 174 | 175 | dim3 grid(batch_size); 176 | dim3 block(min(output_dim, 1024)); 177 | int shared_mem = ((input_dim * sizeof(float) + 127) & ~127); 178 | 179 | batched_gemv_kernel<<>>( 180 | weights, inputs, outputs, batch_size, input_dim, output_dim 181 | ); 182 | 183 | nvtxRangePop(); 184 | } 185 | 186 | void launch_batched_softmax( 187 | const float* inputs, float* outputs, 188 | int batch_size, int dim, 189 | cudaStream_t stream = 0 190 | ) { 191 | nvtxRangePush("batched_softmax"); 192 | 193 | dim3 grid(batch_size); 194 | dim3 block(64); // Try smaller block size 195 | int shared_mem = block.x * sizeof(float); 196 | 197 | batched_softmax_kernel<<>>( 198 | inputs, outputs, batch_size, dim 199 | ); 200 | 201 | nvtxRangePop(); 202 | } 203 | 204 | void launch_price_vector_processing( 205 | const float* prices, const float* weights, float* features, 206 | int batch_size, int n_assets, int n_features, 207 | cudaStream_t stream = 0 208 | ) { 209 | nvtxRangePush("price_vectors"); 210 | 211 | dim3 grid(batch_size); 212 | dim3 block(min(n_features, 256)); // Try smaller block 213 | 214 | process_price_vectors_kernel<<>>( 215 | prices, weights, features, batch_size, n_assets, n_features 216 | ); 217 | 218 | nvtxRangePop(); 219 | } 220 | 221 | } 222 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Sub-Millisecond GPU Task Queue for Real-Time Inference 2 | 3 | This project implements a high-performance, sub-millisecond GPU task queue designed for real-time inference and other latency-sensitive workloads. By leveraging custom CUDA kernels, it provides significant performance improvements over standard PyTorch implementations for specific operations like batched GEMV (Generalized Matrix-Vector Multiplication), softmax, and financial feature engineering. 4 | 5 | Check out this Medium Article for detailed Technical Writeup [Medium](https://medium.com/@shreshthkapai/sub-millisecond-gpu-task-queue-breaking-pytorchs-latency-bottleneck-b6f3d3f2e895) 6 | 7 | The core of the repository is a C++/CUDA extension for PyTorch, which is benchmarked to demonstrate its low-latency capabilities. The system is designed for scenarios where minimizing computational latency is critical, such as in high-frequency trading, real-time bidding, or interactive services. 8 | 9 | ## Table of Contents 10 | - [Overview](#overview) 11 | - [Key Features](#key-features) 12 | - [Performance Highlights](#performance-highlights) 13 | - [Technical Specifications](#technical-specifications) 14 | - [Architecture](#architecture) 15 | - [Repository Structure](#repository-structure) 16 | - [Building the Extension](#building-the-extension) 17 | - [Running Benchmarks](#running-benchmarks) 18 | - [Core Components](#core-components) 19 | - [Dependencies](#dependencies) 20 | - [Contributing](#contributing) 21 | - [License](#license) 22 | 23 | ## Overview 24 | This project implements a high-performance, sub-millisecond GPU task queue designed for real-time inference and other latency-sensitive workloads. By leveraging custom CUDA kernels, it provides significant performance improvements over standard PyTorch implementations for specific operations like batched GEMV (Generalized Matrix-Vector Multiplication), softmax, and financial feature engineering. 
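As a quick orientation, the sketch below shows how the compiled extension is typically driven from Python. It is a minimal, illustrative example rather than shipped tooling: the function names and argument order follow `python_bindings.cpp`, the tensor shapes mirror those used in `benchmark.py`, and the concrete sizes are placeholders.

```python
# Minimal usage sketch. Assumes the extension was built in-place
# (python setup.py build_ext --inplace) and a CUDA device is present.
import torch
import cuda_task_queue

device = torch.device("cuda:0")
batch, in_dim, out_dim = 32, 64, 32            # illustrative sizes

# Batched GEMV: a per-batch weight matrix and one input vector per batch entry.
weights = torch.randn(batch, in_dim, out_dim, device=device)
inputs  = torch.randn(batch, in_dim, device=device)
outputs = torch.empty(batch, out_dim, device=device)
cuda_task_queue.batched_gemv(weights, inputs, outputs)   # writes into `outputs`

# Batched softmax over the last dimension, again into a pre-allocated tensor.
probs = torch.empty(batch, in_dim, device=device)
cuda_task_queue.batched_softmax(inputs, probs)

torch.cuda.synchronize()                       # kernels launch asynchronously
print(outputs.shape, probs.shape)
```

Pre-allocating the output tensors and reusing them, as the benchmark suite does, keeps allocations out of the latency-critical path.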
25 | 26 | The core of the repository is a C++/CUDA extension for PyTorch, which is benchmarked to demonstrate its low-latency capabilities. The system is designed for scenarios where minimizing computational latency is critical, such as in high-frequency trading, real-time bidding, or interactive services. 27 | 28 | ## Key Features 29 | - **Custom CUDA Kernels**: Highly-optimized kernels for common machine learning and financial tasks (batched_gemv, batched_softmax, process_price_vectors). 30 | - **Low-Level Optimization**: Utilizes shared memory, vectorized operations, and manual loop unrolling for maximum instruction-level parallelism. 31 | - **Asynchronous Execution**: Leverages CUDA streams for efficient data transfer and kernel execution, overlapping computation and memory copies. 32 | - **PyTorch Integration**: Seamlessly exposed to Python via Pybind11 and PyTorch's C++ extension mechanism, allowing for easy integration into existing ML pipelines. 33 | - **Comprehensive Benchmarking**: Includes a robust benchmarking suite to measure and analyze kernel performance, including latency statistics (median, P95, P99) and throughput. 34 | - **Automated Tooling**: A `run_all.sh` script automates the entire process of cleaning, building, benchmarking, and reporting. 35 | 36 | ## Performance Highlights 37 | The benchmarks below were executed on the hardware specified in the Technical Specifications section. They demonstrate the exceptional performance of the custom CUDA kernels, achieving sub-millisecond latencies across various configurations. 38 | 39 | ``` 40 | 🚀 GPU TASK QUEUE PERFORMANCE REPORT 41 | ================================================================================ 42 | 43 | 🏆 Best Performer: gemv_b32_i64_o32 (0.008ms median) 44 | 🐌 Worst Performer: price_b32_a64_f32 (0.043ms median) 45 | ⚡ Average Speedup: 4.5x 46 | 🚀 Maximum Speedup: 7.5x 47 | 48 | 📊 DETAILED RESULTS: 49 | -------------------------------------------------------------------------------- 50 | 51 | gemv_b32_i64_o32: 52 | Latency: 0.008ms (median), 0.032ms (P95) 53 | Throughput: 131579 ops/sec 54 | 🚀 Speedup: 7.5x (650.1% improvement) 55 | Stability: ±0.009ms std dev 56 | 57 | gemv_b32_i64_o64: 58 | Latency: 0.010ms (median), 0.078ms (P95) 59 | Throughput: 96154 ops/sec 60 | 🚀 Speedup: 5.1x (414.5% improvement) 61 | Stability: ±0.020ms std dev 62 | 63 | softmax_b32_d64: 64 | Latency: 0.040ms (median), 0.123ms (P95) 65 | Throughput: 24733 ops/sec 66 | 🚀 Speedup: 1.2x (23.2% improvement) 67 | Stability: ±0.145ms std dev 68 | 69 | price_b32_a64_f32: 70 | Latency: 0.043ms (median), 0.069ms (P95) 71 | Throughput: 23391 ops/sec 72 | 🚀 Speedup: 4.2x (319.6% improvement) 73 | Stability: ±0.014ms std dev 74 | ``` 75 | 76 | ## Technical Specifications 77 | The performance benchmarks were conducted on the following hardware: 78 | 79 | - **GPU**: NVIDIA GeForce GTX 1650 80 | - **VRAM**: 4 GB GDDR6 81 | - **CUDA Compute Capability**: 7.5 82 | 83 | ## Architecture 84 | The project is composed of several key components that work together: 85 | 86 | - **CUDA Kernels (`kernels.cu`)**: The C++ source file containing the low-level CUDA C++ code for the `batched_gemv`, `batched_softmax`, and `process_price_vectors` kernels. This is where the core GPU computations are defined. 87 | - **Python Bindings (`python_bindings.cpp`)**: This file uses Pybind11 to create a bridge between the C++ functions that launch the CUDA kernels and the Python interpreter. It converts PyTorch tensors into C++ pointers that can be used by the kernels. 
88 | - **Setup Script (`setup.py`)**: A standard Python setuptools script used to compile the CUDA and C++ code into a Python extension module named `cuda_task_queue`. 89 | - **Benchmarking Framework (`benchmark.py`)**: A Python class that orchestrates the performance tests. It handles tensor allocation, GPU synchronization, and runs both the custom CUDA kernels and their PyTorch equivalents for comparison. 90 | - **Utility Functions (`utils.py`)**: Helper functions for exporting benchmark results to CSV and JSON, merging optimized and speedup results, and formatting the performance report. 91 | - **Automation Script (`run_all.sh`)**: A bash script that automates the entire workflow: cleaning old builds, compiling the CUDA extension, running the benchmarks, and generating a final performance report. 92 | 93 | ## Repository Structure 94 | ``` 95 | . 96 | ├── kernels.cu 97 | ├── python_bindings.cpp 98 | ├── setup.py 99 | ├── benchmark.py 100 | ├── utils.py 101 | └── run_all.sh 102 | ``` 103 | 104 | ## Building the Extension 105 | The CUDA kernels are compiled into a Python module using `setuptools`. The `run_all.sh` script automates this process. 106 | 107 | ### Prerequisites: 108 | - NVIDIA CUDA Toolkit (`nvcc`) 109 | - PyTorch with CUDA support 110 | - Pybind11 111 | 112 | ### Build Steps: 113 | 1. Navigate to the project root directory. 114 | 2. Run the setup script: 115 | ```bash 116 | python setup.py build_ext --inplace 117 | ``` 118 | This command will compile the `kernels.cu` and `python_bindings.cpp` files and create a `cuda_task_queue*.so` file in the current directory, which is the importable Python module. 119 | 120 | ## Running Benchmarks 121 | The `run_all.sh` script is the easiest way to build the extension and run the full benchmark suite. 122 | 123 | ### Usage: 124 | ```bash 125 | # Run with default parameters (batch_size=32, dim=64, repeats=100) 126 | ./run_all.sh 127 | 128 | # Run with custom parameters 129 | # Usage: ./run_all.sh [BATCH_SIZE] [DIM] [REPEATS] 130 | ./run_all.sh 64 128 200 131 | ``` 132 | The script will: 133 | 1. Clean any previous builds. 134 | 2. Compile the CUDA kernels. 135 | 3. Run the benchmark suite with the specified parameters. 136 | 4. Print a summary to the console and save detailed results and plots to the `./results/` directory. 137 | 138 | ## Core Components 139 | 140 | ### CUDA Kernels 141 | - **`batched_gemv_kernel`**: Optimized for small vectors, this kernel uses shared memory to cache input vectors, enabling coalesced memory access and reducing global memory bandwidth consumption. 142 | - **`batched_softmax_kernel`**: Implements parallel reductions for the max value and sum of the softmax calculation within a single kernel launch, significantly improving efficiency over naive approaches. 143 | - **`process_price_vectors_kernel`**: A high-throughput kernel for financial applications, performing a batched dot product with manual loop unrolling to maximize instruction-level parallelism and throughput. 144 | 145 | ### Benchmarking 146 | The `GPUTaskQueueBenchmark` class in `benchmark.py` provides a structured way to evaluate performance. It uses CUDA events for precise timing (`torch.cuda.Event`) and pre-allocates pinned memory (`pin_memory=True`) to enable fast, asynchronous data transfers between the host and the GPU. 147 | 148 | ## Dependencies 149 | - `torch >= 1.12.0` 150 | - `numpy >= 1.21.0` 151 | - `pybind11 >= 2.6.0` 152 | - `matplotlib` 153 | - `seaborn` 154 | 155 | ## Contributing 156 | Contributions are welcome!
Please feel free to submit a pull request or open an issue if you have suggestions for improvements. 157 | 158 | When contributing to this repository, please first discuss the change you wish to make via issue, email, or any other method with the owners of this repository before making a change. 159 | 160 | -------------------------------------------------------------------------------- /benchmark.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import time 4 | import statistics 5 | from dataclasses import dataclass 6 | from typing import List, Tuple, Dict 7 | import matplotlib.pyplot as plt 8 | import seaborn as sns 9 | 10 | 11 | try: 12 | import cuda_task_queue 13 | CUDA_AVAILABLE = True 14 | except ImportError: 15 | CUDA_AVAILABLE = False 16 | print("Warning: CUDA extension not compiled. Run 'python setup.py build_ext --inplace'") 17 | 18 | @dataclass 19 | class BenchmarkConfig: 20 | # Configuration for benchmarking parameters 21 | batch_sizes: List[int] = None 22 | input_dims: List[int] = None 23 | output_dims: List[int] = None 24 | num_warmup: int = 50 25 | num_trials: int = 1000 26 | device: str = "cuda:0" 27 | run_baseline: bool = True # Enable baseline comparison 28 | 29 | def __post_init__(self): 30 | # Set default values if not provided 31 | if self.batch_sizes is None: 32 | self.batch_sizes = [8, 16, 32, 64, 128] 33 | if self.input_dims is None: 34 | self.input_dims = [16, 32, 64, 128] 35 | if self.output_dims is None: 36 | self.output_dims = [16, 32, 64] 37 | 38 | class GPUTaskQueueBenchmark: 39 | def __init__(self, config: BenchmarkConfig): 40 | # Initialize benchmark with configuration and resources 41 | self.config = config 42 | self.device = torch.device(config.device) 43 | self.stream = torch.cuda.Stream() 44 | self.results = {} 45 | self.baseline_results = {} # Store baseline results 46 | self.speedup_results = {} # Store speedup calculations 47 | self.start_event = torch.cuda.Event(enable_timing=True) 48 | self.end_event = torch.cuda.Event(enable_timing=True) 49 | 50 | def allocate_pinned_tensors(self, batch_size: int, input_dim: int, output_dim: int) -> Dict: 51 | # Allocate host (pinned) and device tensors for data transfer and computation 52 | return { 53 | 'weights_host': torch.randn(batch_size, input_dim, output_dim, pin_memory=True), 54 | 'inputs_host': torch.randn(batch_size, input_dim, pin_memory=True), 55 | 'outputs_host': torch.zeros(batch_size, output_dim, pin_memory=True), 56 | 'weights_gpu': torch.empty(batch_size, input_dim, output_dim, device=self.device), 57 | 'inputs_gpu': torch.empty(batch_size, input_dim, device=self.device), 58 | 'outputs_gpu': torch.empty(batch_size, output_dim, device=self.device), 59 | } 60 | 61 | def async_h2d_copy(self, tensors: Dict): 62 | # Asynchronously copy host data to device using a CUDA stream 63 | with torch.cuda.stream(self.stream): 64 | tensors['weights_gpu'].copy_(tensors['weights_host'], non_blocking=True) 65 | tensors['inputs_gpu'].copy_(tensors['inputs_host'], non_blocking=True) 66 | 67 | def benchmark_gemv_kernel(self, batch_size: int, input_dim: int, output_dim: int) -> Dict: 68 | # Benchmark the custom CUDA GEMV kernel or fallback to PyTorch if extension unavailable 69 | if not CUDA_AVAILABLE: 70 | return self._fallback_pytorch_gemv(batch_size, input_dim, output_dim) 71 | 72 | tensors = self.allocate_pinned_tensors(batch_size, input_dim, output_dim) 73 | latencies = [] 74 | 75 | # Warmup phase to stabilize
performance 76 | for _ in range(self.config.num_warmup): 77 | self.async_h2d_copy(tensors) 78 | self.stream.synchronize() 79 | cuda_task_queue.batched_gemv( 80 | tensors['weights_gpu'], 81 | tensors['inputs_gpu'], 82 | tensors['outputs_gpu'] 83 | ) 84 | torch.cuda.synchronize() 85 | 86 | # Benchmarking phase with timing 87 | for _ in range(self.config.num_trials): 88 | self.async_h2d_copy(tensors) 89 | self.stream.synchronize() 90 | 91 | self.start_event.record() 92 | cuda_task_queue.batched_gemv( 93 | tensors['weights_gpu'], 94 | tensors['inputs_gpu'], 95 | tensors['outputs_gpu'] 96 | ) 97 | self.end_event.record() 98 | 99 | torch.cuda.synchronize() 100 | latency_ms = self.start_event.elapsed_time(self.end_event) 101 | latencies.append(latency_ms) 102 | 103 | return self._compute_stats(latencies, "CUDA_GEMV") 104 | 105 | def benchmark_softmax_kernel(self, batch_size: int, dim: int) -> Dict: 106 | # Benchmark the custom CUDA softmax kernel or fallback to PyTorch if extension unavailable 107 | if not CUDA_AVAILABLE: 108 | return self._fallback_pytorch_softmax(batch_size, dim) 109 | 110 | inputs_host = torch.randn(batch_size, dim, pin_memory=True) 111 | outputs_host = torch.zeros(batch_size, dim, pin_memory=True) 112 | inputs_gpu = torch.empty(batch_size, dim, device=self.device) 113 | outputs_gpu = torch.empty(batch_size, dim, device=self.device) 114 | 115 | latencies = [] 116 | 117 | # Warmup phase 118 | for _ in range(self.config.num_warmup): 119 | inputs_gpu.copy_(inputs_host, non_blocking=True) 120 | torch.cuda.synchronize() 121 | cuda_task_queue.batched_softmax(inputs_gpu, outputs_gpu) 122 | torch.cuda.synchronize() 123 | 124 | # Benchmarking phase with timing 125 | for _ in range(self.config.num_trials): 126 | inputs_gpu.copy_(inputs_host, non_blocking=True) 127 | torch.cuda.synchronize() 128 | 129 | self.start_event.record() 130 | cuda_task_queue.batched_softmax(inputs_gpu, outputs_gpu) 131 | self.end_event.record() 132 | 133 | torch.cuda.synchronize() 134 | latencies.append(self.start_event.elapsed_time(self.end_event)) 135 | 136 | return self._compute_stats(latencies, "CUDA_Softmax") 137 | 138 | def benchmark_price_vectors(self, batch_size: int, n_assets: int, n_features: int) -> Dict: 139 | # Benchmark the custom CUDA kernel for price vector processing or fallback to PyTorch if extension unavailable 140 | if not CUDA_AVAILABLE: 141 | return self._fallback_pytorch_price_vectors(batch_size, n_assets, n_features) 142 | 143 | prices_host = torch.randn(batch_size, n_assets, pin_memory=True) * 100 144 | weights_host = torch.randn(n_assets, n_features, pin_memory=True) 145 | features_host = torch.zeros(batch_size, n_features, pin_memory=True) 146 | 147 | prices_gpu = torch.empty(batch_size, n_assets, device=self.device) 148 | weights_gpu = torch.empty(n_assets, n_features, device=self.device) 149 | features_gpu = torch.empty(batch_size, n_features, device=self.device) 150 | 151 | latencies = [] 152 | 153 | # Warmup phase 154 | for _ in range(self.config.num_warmup): 155 | prices_gpu.copy_(prices_host, non_blocking=True) 156 | weights_gpu.copy_(weights_host, non_blocking=True) 157 | torch.cuda.synchronize() 158 | cuda_task_queue.process_price_vectors(prices_gpu, weights_gpu, features_gpu) 159 | torch.cuda.synchronize() 160 | 161 | # Benchmarking phase with timing 162 | for _ in range(self.config.num_trials): 163 | prices_gpu.copy_(prices_host, non_blocking=True) 164 | weights_gpu.copy_(weights_host, non_blocking=True) 165 | torch.cuda.synchronize() 166 | 167 | self.start_event.record() 
168 | cuda_task_queue.process_price_vectors(prices_gpu, weights_gpu, features_gpu) 169 | self.end_event.record() 170 | 171 | torch.cuda.synchronize() 172 | latencies.append(self.start_event.elapsed_time(self.end_event)) 173 | 174 | return self._compute_stats(latencies, "CUDA_PriceVectors") 175 | 176 | def _fallback_pytorch_gemv(self, batch_size: int, input_dim: int, output_dim: int) -> Dict: 177 | # PyTorch baseline implementation for GEMV (batched matrix-vector multiplication) 178 | weights = torch.randn(batch_size, input_dim, output_dim, device=self.device) 179 | inputs = torch.randn(batch_size, input_dim, device=self.device) 180 | 181 | latencies = [] 182 | 183 | for _ in range(self.config.num_warmup): 184 | torch.bmm(inputs.unsqueeze(1), weights).squeeze(1) 185 | torch.cuda.synchronize() 186 | 187 | for _ in range(self.config.num_trials): 188 | self.start_event.record() 189 | torch.bmm(inputs.unsqueeze(1), weights).squeeze(1) 190 | self.end_event.record() 191 | 192 | torch.cuda.synchronize() 193 | latencies.append(self.start_event.elapsed_time(self.end_event)) 194 | 195 | return self._compute_stats(latencies, "PyTorch_GEMV") 196 | 197 | def _fallback_pytorch_softmax(self, batch_size: int, dim: int) -> Dict: 198 | # PyTorch baseline implementation for softmax 199 | inputs = torch.randn(batch_size, dim, device=self.device) 200 | 201 | latencies = [] 202 | 203 | for _ in range(self.config.num_warmup): 204 | torch.softmax(inputs, dim=1) 205 | torch.cuda.synchronize() 206 | 207 | for _ in range(self.config.num_trials): 208 | self.start_event.record() 209 | torch.softmax(inputs, dim=1) 210 | self.end_event.record() 211 | 212 | torch.cuda.synchronize() 213 | latencies.append(self.start_event.elapsed_time(self.end_event)) 214 | 215 | return self._compute_stats(latencies, "PyTorch_Softmax") 216 | 217 | def _fallback_pytorch_price_vectors(self, batch_size: int, n_assets: int, n_features: int) -> Dict: 218 | # PyTorch baseline implementation for price vector processing 219 | prices = torch.randn(batch_size, n_assets, device=self.device) * 100 220 | weights = torch.randn(n_assets, n_features, device=self.device) 221 | 222 | latencies = [] 223 | 224 | for _ in range(self.config.num_warmup): 225 | torch.mm(prices, weights) 226 | torch.cuda.synchronize() 227 | 228 | for _ in range(self.config.num_trials): 229 | self.start_event.record() 230 | torch.mm(prices, weights) 231 | self.end_event.record() 232 | 233 | torch.cuda.synchronize() 234 | latencies.append(self.start_event.elapsed_time(self.end_event)) 235 | 236 | return self._compute_stats(latencies, "PyTorch_PriceVectors") 237 | 238 | # FIXED: Optimized PyTorch baselines instead of Python loops 239 | def _benchmark_baseline_gemv(self, batch_size: int, input_dim: int, output_dim: int) -> Dict: 240 | """Benchmark optimized PyTorch GEMV implementation""" 241 | weights = torch.randn(batch_size, input_dim, output_dim, device=self.device) 242 | inputs = torch.randn(batch_size, input_dim, device=self.device) 243 | 244 | latencies = [] 245 | 246 | # Warmup 247 | for _ in range(self.config.num_warmup): 248 | # Use optimized PyTorch batched matrix multiplication 249 | torch.bmm(inputs.unsqueeze(1), weights).squeeze(1) 250 | torch.cuda.synchronize() 251 | 252 | # Benchmark 253 | for _ in range(self.config.num_trials): 254 | self.start_event.record() 255 | torch.bmm(inputs.unsqueeze(1), weights).squeeze(1) 256 | self.end_event.record() 257 | 258 | torch.cuda.synchronize() 259 | latencies.append(self.start_event.elapsed_time(self.end_event)) 260 | 261 | return 
self._compute_stats(latencies, "Baseline_GEMV") 262 | 263 | def _benchmark_baseline_softmax(self, batch_size: int, dim: int) -> Dict: 264 | """Benchmark optimized PyTorch softmax implementation""" 265 | inputs = torch.randn(batch_size, dim, device=self.device) 266 | 267 | latencies = [] 268 | 269 | # Warmup 270 | for _ in range(self.config.num_warmup): 271 | # Use optimized PyTorch softmax 272 | torch.softmax(inputs, dim=-1) 273 | torch.cuda.synchronize() 274 | 275 | # Benchmark 276 | for _ in range(self.config.num_trials): 277 | self.start_event.record() 278 | torch.softmax(inputs, dim=-1) 279 | self.end_event.record() 280 | 281 | torch.cuda.synchronize() 282 | latencies.append(self.start_event.elapsed_time(self.end_event)) 283 | 284 | return self._compute_stats(latencies, "Baseline_Softmax") 285 | 286 | def _benchmark_baseline_price_vectors(self, batch_size: int, n_assets: int, n_features: int) -> Dict: 287 | """Benchmark optimized PyTorch price vector processing""" 288 | prices = torch.randn(batch_size, n_assets, device=self.device) * 100 289 | weights = torch.randn(n_assets, n_features, device=self.device) 290 | 291 | latencies = [] 292 | 293 | # Warmup 294 | for _ in range(self.config.num_warmup): 295 | # Use optimized PyTorch matrix multiplication 296 | torch.mm(prices, weights) 297 | torch.cuda.synchronize() 298 | 299 | # Benchmark 300 | for _ in range(self.config.num_trials): 301 | self.start_event.record() 302 | torch.mm(prices, weights) 303 | self.end_event.record() 304 | 305 | torch.cuda.synchronize() 306 | latencies.append(self.start_event.elapsed_time(self.end_event)) 307 | 308 | return self._compute_stats(latencies, "Baseline_PriceVectors") 309 | 310 | def _compute_stats(self, latencies: List[float], kernel_name: str) -> Dict: 311 | # Compute statistical metrics from latency measurements 312 | return { 313 | 'kernel': kernel_name, 314 | 'mean_ms': statistics.mean(latencies), 315 | 'median_ms': statistics.median(latencies), 316 | 'p95_ms': np.percentile(latencies, 95), 317 | 'p99_ms': np.percentile(latencies, 99), 318 | 'min_ms': min(latencies), 319 | 'max_ms': max(latencies), 320 | 'std_ms': statistics.stdev(latencies), 321 | 'samples': len(latencies) 322 | } 323 | 324 | def _calculate_speedup(self, optimized_stats: Dict, baseline_stats: Dict) -> Dict: 325 | """Calculate speedup metrics between optimized and baseline implementations""" 326 | speedup_median = baseline_stats['median_ms'] / optimized_stats['median_ms'] 327 | speedup_mean = baseline_stats['mean_ms'] / optimized_stats['mean_ms'] 328 | speedup_p95 = baseline_stats['p95_ms'] / optimized_stats['p95_ms'] 329 | 330 | return { 331 | 'speedup_median': speedup_median, 332 | 'speedup_mean': speedup_mean, 333 | 'speedup_p95': speedup_p95, 334 | 'baseline_median_ms': baseline_stats['median_ms'], 335 | 'optimized_median_ms': optimized_stats['median_ms'], 336 | 'improvement_pct': ((speedup_median - 1.0) * 100) 337 | } 338 | 339 | def run_comprehensive_benchmark(self) -> Dict: 340 | # Execute all benchmarks across parameter sweeps and collect results 341 | results = {} 342 | baseline_results = {} 343 | speedup_results = {} 344 | 345 | print("🚀 Starting GPU Task Queue Benchmark") 346 | print(f"Device: {self.device}") 347 | print(f"Trials per config: {self.config.num_trials}") 348 | print(f"Running baseline comparison: {self.config.run_baseline}") 349 | 350 | # Run GEMV benchmarks for all parameter combinations 351 | for batch_size in self.config.batch_sizes: 352 | for input_dim in self.config.input_dims: 353 | for output_dim in 
self.config.output_dims: 354 | key = f"gemv_b{batch_size}_i{input_dim}_o{output_dim}" 355 | print(f"⚡ Benchmarking {key}...") 356 | 357 | # Run optimized version 358 | results[key] = self.benchmark_gemv_kernel(batch_size, input_dim, output_dim) 359 | 360 | # Run baseline if enabled 361 | if self.config.run_baseline: 362 | print(f"📊 Running baseline for {key}...") 363 | baseline_results[key] = self._benchmark_baseline_gemv(batch_size, input_dim, output_dim) 364 | speedup_results[key] = self._calculate_speedup(results[key], baseline_results[key]) 365 | 366 | # Run softmax benchmarks for all relevant dimensions 367 | for batch_size in self.config.batch_sizes: 368 | for dim in self.config.input_dims: 369 | key = f"softmax_b{batch_size}_d{dim}" 370 | print(f"⚡ Benchmarking {key}...") 371 | 372 | # Run optimized version 373 | results[key] = self.benchmark_softmax_kernel(batch_size, dim) 374 | 375 | # Run baseline if enabled 376 | if self.config.run_baseline: 377 | print(f"📊 Running baseline for {key}...") 378 | baseline_results[key] = self._benchmark_baseline_softmax(batch_size, dim) 379 | speedup_results[key] = self._calculate_speedup(results[key], baseline_results[key]) 380 | 381 | # Run price vector benchmarks for specified settings 382 | for batch_size in self.config.batch_sizes: 383 | key = f"price_b{batch_size}_a64_f32" 384 | print(f"⚡ Benchmarking {key}...") 385 | 386 | # Run optimized version 387 | results[key] = self.benchmark_price_vectors(batch_size, 64, 32) 388 | 389 | # Run baseline if enabled 390 | if self.config.run_baseline: 391 | print(f"📊 Running baseline for {key}...") 392 | baseline_results[key] = self._benchmark_baseline_price_vectors(batch_size, 64, 32) 393 | speedup_results[key] = self._calculate_speedup(results[key], baseline_results[key]) 394 | 395 | self.results = results 396 | self.baseline_results = baseline_results 397 | self.speedup_results = speedup_results 398 | 399 | return { 400 | 'optimized': results, 401 | 'baseline': baseline_results, 402 | 'speedup': speedup_results 403 | } 404 | 405 | def plot_results(self, save_path: str = "benchmark_results.png"): 406 | # Generate and save plots summarizing benchmark results with baseline comparison 407 | if not self.results: 408 | print("No results to plot. Run benchmark first.") 409 | return 410 | 411 | # Enhanced plotting with baseline comparison 412 | fig_height = 12 if self.config.run_baseline else 10 413 | fig, axes = plt.subplots(3 if self.config.run_baseline else 2, 2, figsize=(15, fig_height)) 414 | fig.suptitle("GPU Task Queue Performance Analysis", fontsize=16) 415 | 416 | # Plot median latency for each kernel 417 | kernels = [r['kernel'] for r in self.results.values()] 418 | medians = [r['median_ms'] for r in self.results.values()] 419 | 420 | axes[0, 0].bar(range(len(kernels)), medians, color='skyblue', label='Optimized') 421 | 422 | # Add baseline bars if available 423 | if self.config.run_baseline and self.baseline_results: 424 | baseline_medians = [r['median_ms'] for r in self.baseline_results.values()] 425 | x_pos = np.arange(len(kernels)) 426 | axes[0, 0].bar(x_pos + 0.35, baseline_medians, width=0.35, color='coral', alpha=0.7, label='Baseline') 427 | axes[0, 0].set_xticks(x_pos + 0.175) 428 | 429 | axes[0, 0].set_title("Median Latency: Optimized vs Baseline") 430 | axes[0, 0].set_ylabel("Latency (ms)") 431 | axes[0, 0].set_xticklabels(kernels, rotation=45, ha='right') 432 | axes[0, 0].legend() 433 | 434 | # Plot P95 vs. 
Median latency to show latency distribution tails 435 | p95s = [r['p95_ms'] for r in self.results.values()] 436 | axes[0, 1].scatter(medians, p95s, alpha=0.7, color='coral', label='Optimized') 437 | 438 | if self.config.run_baseline and self.baseline_results: 439 | baseline_p95s = [r['p95_ms'] for r in self.baseline_results.values()] 440 | baseline_medians = [r['median_ms'] for r in self.baseline_results.values()] 441 | axes[0, 1].scatter(baseline_medians, baseline_p95s, alpha=0.7, color='red', 442 | marker='x', s=100, label='Baseline') 443 | 444 | axes[0, 1].plot([0, max(medians)], [0, max(medians)], 'k--', alpha=0.5) 445 | axes[0, 1].set_xlabel("Median Latency (ms)") 446 | axes[0, 1].set_ylabel("P95 Latency (ms)") 447 | axes[0, 1].set_title("Latency Tail Distribution") 448 | axes[0, 1].legend() 449 | 450 | # Speedup visualization 451 | if self.config.run_baseline and self.speedup_results: 452 | speedups = [s['speedup_median'] for s in self.speedup_results.values()] 453 | config_names = list(self.speedup_results.keys()) 454 | 455 | bars = axes[1, 0].bar(range(len(config_names)), speedups, color='green', alpha=0.7) 456 | axes[1, 0].axhline(y=1.0, color='red', linestyle='--', alpha=0.7, label='No improvement') 457 | axes[1, 0].set_title("Speedup (Baseline / Optimized)") 458 | axes[1, 0].set_ylabel("Speedup Factor") 459 | axes[1, 0].set_xticks(range(len(config_names))) 460 | axes[1, 0].set_xticklabels(config_names, rotation=45, ha='right') 461 | 462 | # Annotate bars with speedup values 463 | for bar, speedup in zip(bars, speedups): 464 | axes[1, 0].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.05, 465 | f'{speedup:.1f}x', ha='center', va='bottom', fontweight='bold') 466 | 467 | # Performance improvement percentage 468 | improvements = [s['improvement_pct'] for s in self.speedup_results.values()] 469 | axes[1, 1].bar(range(len(config_names)), improvements, color='purple', alpha=0.7) 470 | axes[1, 1].set_title("Performance Improvement (%)") 471 | axes[1, 1].set_ylabel("Improvement %") 472 | axes[1, 1].set_xticks(range(len(config_names))) 473 | axes[1, 1].set_xticklabels(config_names, rotation=45, ha='right') 474 | 475 | # Add third row for baseline comparison if available 476 | if len(axes) > 2: 477 | # Throughput comparison 478 | opt_throughput = [1000.0 / r['median_ms'] for r in self.results.values()] 479 | base_throughput = [1000.0 / r['median_ms'] for r in self.baseline_results.values()] 480 | 481 | x_pos = np.arange(len(kernels)) 482 | axes[2, 0].bar(x_pos - 0.2, opt_throughput, width=0.4, color='skyblue', label='Optimized') 483 | axes[2, 0].bar(x_pos + 0.2, base_throughput, width=0.4, color='coral', label='Baseline') 484 | axes[2, 0].set_title("Throughput Comparison") 485 | axes[2, 0].set_ylabel("Operations/sec") 486 | axes[2, 0].set_xticks(x_pos) 487 | axes[2, 0].set_xticklabels(kernels, rotation=45, ha='right') 488 | axes[2, 0].legend() 489 | 490 | # Summary metrics 491 | avg_speedup = np.mean(speedups) 492 | max_speedup = max(speedups) 493 | axes[2, 1].text(0.1, 0.8, f"Average Speedup: {avg_speedup:.1f}x", 494 | transform=axes[2, 1].transAxes, fontsize=14, fontweight='bold') 495 | axes[2, 1].text(0.1, 0.6, f"Maximum Speedup: {max_speedup:.1f}x", 496 | transform=axes[2, 1].transAxes, fontsize=14, fontweight='bold') 497 | axes[2, 1].text(0.1, 0.4, f"Configs Tested: {len(speedups)}", 498 | transform=axes[2, 1].transAxes, fontsize=12) 499 | axes[2, 1].set_title("Optimization Summary") 500 | axes[2, 1].axis('off') 501 | 502 | plt.tight_layout() 503 | plt.savefig(save_path, 
dpi=300, bbox_inches='tight') 504 | print(f"📊 Results saved to {save_path}") 505 | 506 | def print_summary(self): 507 | # Print a summary of benchmark results to the console with baseline comparison 508 | if not self.results: 509 | print("No results available. Run benchmark first.") 510 | return 511 | 512 | print("\n" + "="*80) 513 | print("🎯 BENCHMARK RESULTS SUMMARY") 514 | print("="*80) 515 | 516 | for key, stats in self.results.items(): 517 | print(f"\n{key}:") 518 | print(f" Optimized Kernel: {stats['kernel']}") 519 | print(f" Median: {stats['median_ms']:.3f}ms") 520 | print(f" P95: {stats['p95_ms']:.3f}ms") 521 | print(f" Mean: {stats['mean_ms']:.3f}ms ± {stats['std_ms']:.3f}ms") 522 | 523 | # Add baseline comparison if available 524 | if self.config.run_baseline and key in self.baseline_results: 525 | baseline_stats = self.baseline_results[key] 526 | speedup_stats = self.speedup_results[key] 527 | 528 | print(f" Baseline Median: {baseline_stats['median_ms']:.3f}ms") 529 | print(f" 🚀 SPEEDUP: {speedup_stats['speedup_median']:.1f}x") 530 | print(f" 📈 IMPROVEMENT: {speedup_stats['improvement_pct']:.1f}%") 531 | 532 | # Identify the best performing CUDA kernel by median latency 533 | cuda_results = {k: v for k, v in self.results.items() if 'CUDA' in v['kernel']} 534 | if cuda_results: 535 | best = min(cuda_results.items(), key=lambda x: x[1]['median_ms']) 536 | print(f"\n🏆 Best Performance: {best[0]} with {best[1]['median_ms']:.3f}ms median latency") 537 | 538 | # Show best speedup if baseline available 539 | if self.config.run_baseline and self.speedup_results: 540 | best_speedup = max(self.speedup_results.items(), key=lambda x: x[1]['speedup_median']) 541 | print(f"🚀 Best Speedup: {best_speedup[0]} with {best_speedup[1]['speedup_median']:.1f}x improvement") 542 | 543 | # Overall statistics 544 | all_speedups = [s['speedup_median'] for s in self.speedup_results.values()] 545 | print(f"📊 Average Speedup: {np.mean(all_speedups):.1f}x") 546 | print(f"📊 Geometric Mean Speedup: {np.exp(np.mean(np.log(all_speedups))):.1f}x") 547 | --------------------------------------------------------------------------------
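For readers who prefer to drive the suite from Python instead of `run_all.sh`, the sketch below wires the same pieces together. The class, method, and helper names come from `benchmark.py` and `utils.py` above; the parameter values are illustrative, not a recommended configuration.

```python
# Standalone driver mirroring what run_all.sh does (illustrative values).
from pathlib import Path

from benchmark import BenchmarkConfig, GPUTaskQueueBenchmark
from utils import save_results_json, format_performance_report, merge_results

config = BenchmarkConfig(
    batch_sizes=[32],          # run_all.sh sweeps a single batch size per invocation
    input_dims=[64],
    output_dims=[32, 64],
    num_trials=100,
    run_baseline=True,         # collect PyTorch baselines so speedups are reported
)

bench = GPUTaskQueueBenchmark(config)
results = bench.run_comprehensive_benchmark()   # {'optimized', 'baseline', 'speedup'}
bench.print_summary()

Path("results").mkdir(exist_ok=True)            # run_all.sh normally creates this
bench.plot_results("results/benchmark_plot.png")
save_results_json(results, "results/results.json")
print(format_performance_report(merge_results(results)))
```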