├── .gitignore ├── main.py ├── readme.md ├── requirements.txt └── vram_usage.py /.gitignore: -------------------------------------------------------------------------------- 1 | # IDE specific files 2 | .idea/ 3 | .vscode/ 4 | *.sublime-project 5 | *.sublime-workspace 6 | 7 | # Python specific 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | .env 12 | venv/ 13 | .venv/ 14 | env/ 15 | 16 | # OS specific 17 | .DS_Store 18 | Thumbs.db -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import statistics 4 | import logging 5 | from datetime import datetime 6 | from typing import Tuple 7 | import ollama 8 | from ollama import GenerateResponse 9 | from functools import wraps 10 | import timeout_decorator 11 | from vram_usage import get_vram_info 12 | 13 | import os 14 | 15 | 16 | def setup_logging(model_name: str) -> str: 17 | """Setup logging configuration and return the log filename.""" 18 | timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") 19 | log_filename = f"context_test_{model_name}_{timestamp}.log" 20 | 21 | # Ensure the logs directory exists 22 | log_dir = "logs" 23 | os.makedirs(log_dir, exist_ok=True) 24 | log_path = os.path.join(log_dir, log_filename) 25 | 26 | logging.basicConfig( 27 | level=logging.INFO, 28 | format='%(asctime)s - %(levelname)s - %(message)s', 29 | handlers=[ 30 | logging.FileHandler(log_path), 31 | logging.StreamHandler() 32 | ] 33 | ) 34 | 35 | return log_path 36 | 37 | 38 | def timeout_handler(signum, frame): 39 | raise TimeoutError("Query timed out") 40 | 41 | 42 | def retry_on_timeout(max_retries=3, timeout_seconds=60): 43 | def decorator(func): 44 | @wraps(func) 45 | def wrapper(*args, **kwargs): 46 | for attempt in range(max_retries): 47 | try: 48 | # Set the timeout 49 | @timeout_decorator.timeout(timeout_seconds) 50 | def run_with_timeout(): 51 | return func(*args, **kwargs) 52 | 53 | return run_with_timeout() 54 | 55 | except (TimeoutError, timeout_decorator.TimeoutError) as e: 56 | logging.warning(f"Attempt {attempt + 1}/{max_retries} timed out after {timeout_seconds} seconds") 57 | if attempt == max_retries - 1: 58 | raise TimeoutError(f"All {max_retries} attempts timed out") 59 | logging.info("Retrying...") 60 | time.sleep(1) # Brief pause before retry 61 | 62 | return None # Should never reach here due to raise in last attempt 63 | 64 | return wrapper 65 | 66 | return decorator 67 | 68 | 69 | def analyze_test_sentence() -> Tuple[str, int]: 70 | """Return the test sentence and its actual token count.""" 71 | test_sentence = "This is a test sentence to measure context performance. " 72 | 73 | # Actual tokens (approximately): 74 | # "This" = 1 75 | # "is" = 1 76 | # "a" = 1 77 | # "test" = 1 78 | # "sentence" = 1 79 | # "to" = 1 80 | # "measure" = 1 81 | # "context" = 1 82 | # "performance" = 1 83 | # "." = 1 84 | # " " = several tokens, roughly 1-2 additional tokens 85 | 86 | actual_tokens = 11 # More accurate token count 87 | 88 | return test_sentence, actual_tokens 89 | 90 | 91 | def generate_test_prompt(context_size: int) -> Tuple[str, int, int]: 92 | """Generate prompt and return prompt, its token count, and repetitions.""" 93 | base_prompt = "Count the number of characters in the following text and explain your counting process. Here's the text:\n\n" 94 | # Base prompt tokens: 95 | # Approximately 15-17 tokens for the base prompt 96 | base_prompt_tokens = 16 97 | 98 | test_sentence, tokens_per_sentence = analyze_test_sentence() 99 | 100 | # Calculate how many repetitions we can fit 101 | available_tokens = context_size - base_prompt_tokens 102 | repetitions = max(1, available_tokens // tokens_per_sentence) 103 | 104 | repeated_text = test_sentence * repetitions 105 | full_prompt = base_prompt + repeated_text 106 | 107 | total_tokens = base_prompt_tokens + (repetitions * tokens_per_sentence) 108 | 109 | return full_prompt, total_tokens, repetitions 110 | 111 | 112 | @retry_on_timeout(max_retries=3, timeout_seconds=60) 113 | def run_ollama_query(model: str, context_size: int) -> Tuple[GenerateResponse, str, int]: 114 | """Run a query to Ollama with a specific context size and return the response metrics.""" 115 | try: 116 | full_prompt, estimated_tokens, repetitions = generate_test_prompt(context_size) 117 | 118 | logging.debug(f"Estimated tokens in prompt: {estimated_tokens}") 119 | logging.debug(f"Number of repetitions: {repetitions}") 120 | 121 | client = ollama.Client(host='http://localhost:11434') # Explicit host 122 | response = client.generate( 123 | model=model, 124 | prompt=full_prompt, 125 | options={ 126 | "num_ctx": context_size 127 | } 128 | ) 129 | return response, full_prompt, estimated_tokens 130 | except ConnectionError as e: 131 | logging.error(f"Failed to connect to Ollama server: {str(e)}") 132 | logging.info("Make sure Ollama is running and accessible at http://localhost:11434") 133 | raise 134 | 135 | 136 | def calculate_tokens_per_second(response: GenerateResponse) -> float: 137 | """Calculate the tokens per second rate from the Ollama response.""" 138 | eval_count = response.eval_count if hasattr(response, 'eval_count') else 0 139 | eval_duration = response.eval_duration if hasattr(response, 'eval_duration') else 1 140 | tokens_per_second = eval_count / (eval_duration * 1e-9) 141 | return tokens_per_second 142 | 143 | def test_context_size(model: str, context_size: int, num_tests: int = 3) -> tuple[float, float, float]: 144 | """Run multiple tests at a specific context size and return the average tokens/sec and VRAM info.""" 145 | tokens_per_second_list = [] 146 | 147 | logging.info(f"\nContext Size: {context_size}") 148 | logging.info("-" * 50) 149 | 150 | # Get initial VRAM reading 151 | current_vram, max_vram = get_vram_info() 152 | logging.info(f"Initial VRAM Usage: {current_vram:.0f}M / {max_vram:.0f}M ({(current_vram/max_vram*100):.1f}%)") 153 | 154 | for i in range(num_tests): 155 | try: 156 | response, prompt, estimated_tokens = run_ollama_query(model, context_size) 157 | tokens_per_second = calculate_tokens_per_second(response) 158 | tokens_per_second_list.append(tokens_per_second) 159 | 160 | # Get VRAM reading after test 161 | current_vram, max_vram = get_vram_info() 162 | vram_percent = (current_vram / max_vram * 100) if max_vram > 0 else 0 163 | 164 | # Log detailed test information 165 | test_info = { 166 | "Test Number": i + 1, 167 | "Prompt Length (chars)": len(prompt), 168 | "Prompt Tokens": estimated_tokens, 169 | "Response Length (chars)": len(response.response), 170 | "Response Words": len(response.response.split()), 171 | "Response Estimated Tokens": int(len(response.response.split()) * 1.3), 172 | "Total Tokens Processed": response.eval_count, 173 | "Tokens/sec": f"{tokens_per_second:.2f}", 174 | "Eval Duration": f"{response.eval_duration * 1e-9:.2f}s", 175 | "VRAM Usage": f"{current_vram:.0f}M / {max_vram:.0f}M ({vram_percent:.1f}%)" 176 | } 177 | 178 | logging.info(f"Test {i + 1} Details:") 179 | for key, value in test_info.items(): 180 | logging.info(f" {key}: {value}") 181 | 182 | logging.info("Prompt Preview (first 200 chars):") 183 | logging.info(f" {prompt[:200]}...") 184 | logging.info("Response Preview (first 200 chars):") 185 | logging.info(f" {response.response[:200]}...\n") 186 | 187 | except (TimeoutError, ollama.ResponseError) as e: 188 | logging.error(f"Error in test {i + 1}: {str(e)}") 189 | continue 190 | 191 | if tokens_per_second_list: 192 | avg_tokens_per_second = statistics.mean(tokens_per_second_list) 193 | logging.info(f"Average tokens/sec for context size {context_size}: {avg_tokens_per_second:.2f}") 194 | return avg_tokens_per_second, current_vram, max_vram 195 | else: 196 | logging.warning(f"All tests failed for context size {context_size}") 197 | return 0.0, current_vram, max_vram 198 | 199 | 200 | def find_max_context(model: str, start_size: int = 1024, step_size: int = 1024, 201 | minimum_token_rate: int = 10, num_tests: int = 3) -> Tuple[int, float]: 202 | """Find the maximum context size for a model that maintains acceptable performance.""" 203 | context_size = start_size 204 | previous_context_size = start_size 205 | previous_tokens_per_second = float('inf') 206 | 207 | logging.info(f"Starting maximum context size test for model: {model}") 208 | logging.info(f"Parameters:") 209 | logging.info(f" Minimum acceptable token rate: {minimum_token_rate} tokens/sec") 210 | logging.info(f" Starting context size: {start_size}") 211 | logging.info(f" Step size: {step_size}") 212 | logging.info(f" Tests per context size: {num_tests}") 213 | logging.info("=" * 50) 214 | 215 | while True: 216 | try: 217 | logging.info(f"\nTesting context size: {context_size}") 218 | avg_tokens_per_second, current_vram, max_vram = test_context_size(model, context_size, num_tests) 219 | vram_percent = (current_vram / max_vram * 100) if max_vram > 0 else 0 220 | 221 | # Check both token rate and VRAM usage 222 | is_good = avg_tokens_per_second >= minimum_token_rate and vram_percent < 99 223 | status = "GOOD" if is_good else "SLOW/HIGH VRAM" 224 | 225 | logging.info(f"Status: {status}") 226 | logging.info(f"Token Rate: {avg_tokens_per_second:.2f} tokens/sec") 227 | logging.info(f"VRAM Usage: {current_vram:.0f}M / {max_vram:.0f}M ({vram_percent:.1f}%)") 228 | 229 | if not is_good: 230 | reason = [] 231 | if avg_tokens_per_second < minimum_token_rate: 232 | reason.append(f"Token rate below minimum threshold of {minimum_token_rate}") 233 | if vram_percent >= 99: 234 | reason.append("VRAM usage too high") 235 | logging.info(f"Stopping due to: {', '.join(reason)}") 236 | return previous_context_size, previous_tokens_per_second 237 | 238 | previous_context_size = context_size 239 | previous_tokens_per_second = avg_tokens_per_second 240 | context_size += step_size 241 | 242 | except ollama.ResponseError as e: 243 | logging.error(f"Error occurred at context size {context_size}: {str(e)}") 244 | return previous_context_size, previous_tokens_per_second 245 | 246 | 247 | def run_context_test(model: str, min_token_rate: int = 10, start: int = 1024, 248 | step: int = 1024, tests: int = 3) -> None: 249 | """ 250 | Run the context size test with the given parameters and log results. 251 | 252 | Args: 253 | model: The Ollama model name 254 | min_token_rate: Minimum acceptable tokens per second 255 | start: Starting context size 256 | step: Step size for context increments 257 | tests: Number of tests per context size 258 | """ 259 | # Setup logging 260 | log_filename = setup_logging(model) 261 | logging.info(f"Log file created: {log_filename}") 262 | 263 | max_context, final_tokens_per_second = find_max_context( 264 | model, start, step, min_token_rate, tests 265 | ) 266 | 267 | # Log final results 268 | logging.info("\n" + "=" * 60) 269 | logging.info("FINAL RESULTS:") 270 | logging.info(f"Maximum recommended context size: {max_context}") 271 | logging.info(f"Average tokens per second at max context: {final_tokens_per_second:.2f}") 272 | logging.info(f"Minimum token rate threshold: {min_token_rate}") 273 | logging.info("=" * 60) 274 | 275 | # Also print to console 276 | print(f"\nResults have been saved to: {log_filename}") 277 | 278 | 279 | if __name__ == "__main__": 280 | parser = argparse.ArgumentParser(description="Find the maximum usable context size for an Ollama model") 281 | parser.add_argument("model", help="The Ollama model name (e.g., 'codestral:latest')") 282 | parser.add_argument("--min_token_rate", type=int, default=10, 283 | help="The minimum acceptable tokens per second rate (default: 10)") 284 | parser.add_argument("--start", type=int, default=1024, help="Starting context size") 285 | parser.add_argument("--step", type=int, default=1024, help="Step size for context increments") 286 | parser.add_argument("--tests", type=int, default=3, help="Number of tests per context size") 287 | 288 | args = parser.parse_args() 289 | run_context_test(**vars(args)) -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Ollama Context Size Tester 2 | 3 | A tool to determine the maximum usable context size for Ollama models while monitoring performance and VRAM usage. This tool helps you find the optimal balance between context size and performance for your specific hardware setup. 4 | 5 | ## Overview 6 | 7 | This tool tests increasing context sizes with your chosen Ollama model to find the maximum size that maintains acceptable performance. It monitors: 8 | - Token processing speed (tokens per second) 9 | - VRAM usage 10 | - Response times 11 | - Model behavior at different context lengths 12 | 13 | ## Prerequisites 14 | 15 | ### Windows 16 | - Windows 10/11 17 | - Python 3.8+ 18 | - Ollama (https://ollama.com/) installed and running 19 | - For VRAM monitoring: 20 | - NVIDIA GPU: No additional setup needed (uses nvidia-smi, included with drivers) 21 | - AMD GPU: ROCm for Windows (if available for your GPU) 22 | 23 | ### Linux 24 | - Python 3.8+ 25 | - Ollama installed and running 26 | - For VRAM monitoring: 27 | - NVIDIA GPU: No additional setup needed (uses nvidia-smi) 28 | - AMD GPU: Either ROCm or radeontop (`sudo apt install radeontop` on Ubuntu/Debian) 29 | 30 | ## Installation 31 | 32 | 1. Clone this repository: 33 | ```bash 34 | git clone https://github.com/scionero/maxcontextfinder 35 | cd maxcontextfinder 36 | ``` 37 | 38 | 2. Install required Python packages: 39 | ```bash 40 | pip install -r requirements.txt 41 | ``` 42 | 43 | ## Usage 44 | 45 | Basic usage: 46 | ```bash 47 | python main.py MODEL_NAME 48 | ``` 49 | 50 | Example: 51 | ```bash 52 | python main.py codellama:latest 53 | ``` 54 | 55 | ### Command Line Options 56 | 57 | - `model`: (Required) The Ollama model name (e.g., 'codellama:latest', 'llama2:13b') 58 | - `--min_token_rate`: Minimum acceptable tokens per second (default: 10) 59 | - `--start`: Starting context size (default: 1024) 60 | - `--step`: Step size for context increments (default: 1024) 61 | - `--tests`: Number of tests per context size (default: 3) 62 | 63 | Example with all options: 64 | ```bash 65 | python main.py mistral:7b --min_token_rate 15 --start 2048 --step 2048 --tests 5 66 | ``` 67 | 68 | ### Output 69 | 70 | The tool generates detailed logs including: 71 | - Test parameters and configuration 72 | - Performance metrics for each test 73 | - VRAM usage statistics (when available) 74 | - Token processing speeds 75 | - Final recommended context size 76 | 77 | Logs are saved to the `logs` directory with names: `context_test_MODEL_TIMESTAMP.log` 78 | 79 | ## VRAM Monitoring Support 80 | 81 | ### Windows 82 | - NVIDIA GPUs: Fully supported through nvidia-smi 83 | - AMD GPUs: Supported through ROCm when available 84 | - Intel GPUs: Not currently supported 85 | 86 | ### Linux 87 | - NVIDIA GPUs: Fully supported through nvidia-smi 88 | - AMD GPUs: Supported through either: 89 | - ROCm (preferred when available) 90 | - radeontop (fallback option) 91 | - Intel GPUs: Not currently supported 92 | 93 | ## Important Notes 94 | 95 | 1. **GPU Support**: 96 | - NVIDIA GPUs are fully supported on both Windows and Linux 97 | - AMD GPU support varies by platform and available tools 98 | - Systems without supported GPUs will run without VRAM monitoring 99 | 100 | 2. **Framework Specificity**: Results are specific to Ollama and may differ from other frameworks like: 101 | - Pure llama.cpp 102 | - vLLM 103 | - Different quantization methods 104 | - Other serving frameworks 105 | 106 | 3. **Hardware Dependence**: Results depend on your hardware: 107 | - GPU memory and performance 108 | - CPU capabilities 109 | - System memory 110 | - Storage speed 111 | 112 | ## Understanding Results 113 | 114 | The tool stops testing larger context sizes when: 115 | - Token processing speed drops below the minimum threshold 116 | - VRAM usage approaches 100% 117 | - Model encounters errors or timeouts 118 | 119 | The "maximum recommended context size" is the largest size that maintained acceptable performance across all metrics. 120 | 121 | ## Contributing 122 | 123 | Contributions are welcome! Areas for improvement: 124 | - Additional GPU support 125 | - More performance metrics 126 | - Support for other frameworks 127 | - Better token counting accuracy 128 | - Alternative testing methodologies 129 | 130 | Please feel free to: 131 | - Open issues for bugs or feature requests 132 | - Submit pull requests with improvements 133 | - Share your testing results 134 | - Suggest better testing methodologies 135 | 136 | ## Disclaimer 137 | 138 | Results should be considered approximate. Real-world performance may vary based on: 139 | - Specific prompt content 140 | - Model implementation details 141 | - System load and conditions 142 | - Hardware configuration 143 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ollama==0.4.7 2 | timeout-decorator==0.5.0 3 | -------------------------------------------------------------------------------- /vram_usage.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import re 3 | import logging 4 | import shutil 5 | from typing import Tuple 6 | 7 | def is_rocm_smi_available() -> bool: 8 | """Check if rocm-smi is available.""" 9 | try: 10 | paths = [ 11 | shutil.which('rocm-smi'), 12 | '/opt/rocm/bin/rocm-smi', 13 | 'C:\\Program Files\\AMD\\ROCm\\bin\\rocm-smi.exe' 14 | ] 15 | for path in paths: 16 | if path: 17 | subprocess.run([path, '--showmeminfo', 'vram'], 18 | stdout=subprocess.PIPE, 19 | stderr=subprocess.PIPE, 20 | timeout=2) 21 | return True 22 | except (subprocess.SubprocessError, subprocess.TimeoutExpired): 23 | pass 24 | return False 25 | 26 | def get_gpu_type(): 27 | """Detect whether system has NVIDIA or AMD GPU with monitoring tools.""" 28 | if shutil.which('nvidia-smi'): 29 | try: 30 | subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE, timeout=2) 31 | return "nvidia" 32 | except (subprocess.SubprocessError, subprocess.TimeoutExpired): 33 | pass 34 | 35 | # Check for ROCm SMI first (official AMD tool) 36 | if is_rocm_smi_available(): 37 | return "amd_rocm" 38 | 39 | # Fallback to radeontop on Linux 40 | if shutil.which('radeontop'): 41 | try: 42 | subprocess.run(['radeontop', '-d', '-', '-l', '1'], stdout=subprocess.PIPE, timeout=2) 43 | return "amd_radeontop" 44 | except (subprocess.SubprocessError, subprocess.TimeoutExpired): 45 | pass 46 | 47 | return None 48 | 49 | def get_nvidia_vram(): 50 | """Get VRAM info using nvidia-smi.""" 51 | try: 52 | cmd = ["nvidia-smi", "--query-gpu=memory.used,memory.total", "--format=csv,noheader,nounits"] 53 | result = subprocess.run(cmd, capture_output=True, text=True, timeout=2) 54 | if result.returncode == 0: 55 | used, total = map(float, result.stdout.strip().split(',')) 56 | return used, total 57 | except (subprocess.SubprocessError, ValueError) as e: 58 | logging.warning(f"Error getting NVIDIA VRAM info: {str(e)}") 59 | return 0.0, 0.0 60 | 61 | def get_amd_rocm_vram() -> Tuple[float, float]: 62 | """Get VRAM info using rocm-smi.""" 63 | try: 64 | cmd = shutil.which('rocm-smi') or '/opt/rocm/bin/rocm-smi' 65 | if not cmd: 66 | cmd = 'C:\\Program Files\\AMD\\ROCm\\bin\\rocm-smi.exe' 67 | 68 | result = subprocess.run( 69 | [cmd, '--showmeminfo', 'vram', '--json'], 70 | stdout=subprocess.PIPE, 71 | stderr=subprocess.PIPE, 72 | text=True, 73 | timeout=2 74 | ) 75 | 76 | if result.returncode == 0: 77 | import json 78 | data = json.loads(result.stdout) 79 | for card in data.values(): 80 | vram_info = card.get('VRAM Memory') 81 | if vram_info: 82 | used_mb = float(vram_info['used']) / 1024 / 1024 83 | total_mb = float(vram_info['total']) / 1024 / 1024 84 | return used_mb, total_mb 85 | 86 | except Exception as e: 87 | logging.warning(f"Error getting VRAM info via ROCm SMI: {str(e)}") 88 | return 0.0, 0.0 89 | 90 | def get_amd_radeontop_vram(): 91 | """Get VRAM info using radeontop (Linux fallback).""" 92 | try: 93 | cmd = ["radeontop", "-d", "-", "-l", "1"] 94 | process = subprocess.Popen(cmd, stdout=subprocess.PIPE, text=True) 95 | 96 | for line in process.stdout: 97 | data = line.strip().split() 98 | if data[0] == "Dumping": 99 | continue 100 | 101 | line_str = ' '.join(data) 102 | vram_match = re.search(r'vram \d+\.\d+% (\d+\.\d+)mb', line_str) 103 | if vram_match: 104 | current_mb = float(vram_match.group(1)) 105 | vram_percent = float(re.search(r'vram (\d+\.\d+)%', line_str).group(1)) 106 | max_mb = current_mb / (vram_percent / 100) if vram_percent > 0 else 0 107 | process.terminate() 108 | return current_mb, max_mb 109 | 110 | process.terminate() 111 | except (subprocess.SubprocessError, Exception) as e: 112 | logging.warning(f"Error getting AMD VRAM info via radeontop: {str(e)}") 113 | return 0.0, 0.0 114 | 115 | def get_vram_info(): 116 | """Get VRAM usage information. Returns (current_mb, max_mb).""" 117 | if not hasattr(get_vram_info, '_gpu_type'): 118 | get_vram_info._gpu_type = get_gpu_type() 119 | if get_vram_info._gpu_type: 120 | logging.info(f"Detected {get_vram_info._gpu_type.upper()} GPU") 121 | else: 122 | logging.warning("No supported GPU monitoring tools found") 123 | logging.info("Continuing without VRAM monitoring") 124 | 125 | if get_vram_info._gpu_type == "nvidia": 126 | return get_nvidia_vram() 127 | elif get_vram_info._gpu_type == "amd_rocm": 128 | return get_amd_rocm_vram() 129 | elif get_vram_info._gpu_type == "amd_radeontop": 130 | return get_amd_radeontop_vram() 131 | return 0.0, 0.0 --------------------------------------------------------------------------------