├── prompts ├── two-plus-two.txt ├── color-sky.txt ├── capital-france.txt └── largest-planet.txt ├── example1.gif ├── llmwalk ├── __init__.py ├── __main__.py ├── llm.py └── cli.py ├── tests ├── __snapshots__ │ └── test_prompts │ │ ├── test_prompt_snapshot[default-color-sky].json │ │ ├── test_prompt_snapshot[mistral-color-sky].json │ │ ├── test_prompt_snapshot[default-capital-france].json │ │ ├── test_prompt_snapshot[default-two-plus-two].json │ │ ├── test_prompt_snapshot[mistral-two-plus-two].json │ │ ├── test_prompt_snapshot[mistral-largest-planet].json │ │ ├── test_prompt_snapshot[default-largest-planet].json │ │ └── test_prompt_snapshot[mistral-capital-france].json └── test_prompts.py ├── LICENSE ├── pyproject.toml ├── AGENTS.md ├── README.md └── .gitignore /prompts/two-plus-two.txt: -------------------------------------------------------------------------------- 1 | Answering with just one word, what is 2+2? 2 | -------------------------------------------------------------------------------- /example1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/samwho/llmwalk/HEAD/example1.gif -------------------------------------------------------------------------------- /prompts/color-sky.txt: -------------------------------------------------------------------------------- 1 | Answering with just 1 word, what color is the sky? 2 | -------------------------------------------------------------------------------- /llmwalk/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ["__version__"] 2 | 3 | __version__ = "0.1.0" 4 | -------------------------------------------------------------------------------- /prompts/capital-france.txt: -------------------------------------------------------------------------------- 1 | Answering with 1 word, what is the capital of France? 2 | -------------------------------------------------------------------------------- /prompts/largest-planet.txt: -------------------------------------------------------------------------------- 1 | Answering with just one word, what is the largest planet in our solar system? 
2 | -------------------------------------------------------------------------------- /llmwalk/__main__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from .cli import main 4 | 5 | if __name__ == "__main__": 6 | main() 7 | -------------------------------------------------------------------------------- /tests/__snapshots__/test_prompts/test_prompt_snapshot[default-color-sky].json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "answer": "Blue", 4 | "finish_reason": "eos_token", 5 | "probability": 0.635196369415354, 6 | "tokens": [ 7 | { 8 | "probability": 0.9809126853942871, 9 | "token": "Blue" 10 | }, 11 | { 12 | "probability": 0.647556483745575, 13 | "token": "" 14 | } 15 | ] 16 | } 17 | ] -------------------------------------------------------------------------------- /tests/__snapshots__/test_prompts/test_prompt_snapshot[mistral-color-sky].json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "answer": "Blue", 4 | "finish_reason": "eos_token", 5 | "probability": 0.706212235187877, 6 | "tokens": [ 7 | { 8 | "probability": 0.9996299743652344, 9 | "token": "Blue" 10 | }, 11 | { 12 | "probability": 0.7064736485481262, 13 | "token": "" 14 | } 15 | ] 16 | } 17 | ] -------------------------------------------------------------------------------- /tests/__snapshots__/test_prompts/test_prompt_snapshot[default-capital-france].json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "answer": "Paris", 4 | "finish_reason": "eos_token", 5 | "probability": 0.5851167404268445, 6 | "tokens": [ 7 | { 8 | "probability": 0.9878690838813782, 9 | "token": "Paris" 10 | }, 11 | { 12 | "probability": 0.5923019051551819, 13 | "token": "" 14 | } 15 | ] 16 | } 17 | ] -------------------------------------------------------------------------------- /tests/__snapshots__/test_prompts/test_prompt_snapshot[default-two-plus-two].json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "answer": "Four.", 4 | "finish_reason": "eos_token", 5 | "probability": 0.4089153564667007, 6 | "tokens": [ 7 | { 8 | "probability": 0.585510790348053, 9 | "token": "Four" 10 | }, 11 | { 12 | "probability": 0.6984281539916992, 13 | "token": "." 14 | }, 15 | { 16 | "probability": 0.9999465346336365, 17 | "token": "" 18 | } 19 | ] 20 | } 21 | ] -------------------------------------------------------------------------------- /tests/__snapshots__/test_prompts/test_prompt_snapshot[mistral-two-plus-two].json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "answer": "Four.", 4 | "finish_reason": "eos_token", 5 | "probability": 0.6346089192323131, 6 | "tokens": [ 7 | { 8 | "probability": 0.9988828897476196, 9 | "token": "Four" 10 | }, 11 | { 12 | "probability": 0.6585058569908142, 13 | "token": "." 
14 | }, 15 | { 16 | "probability": 0.9647881388664246, 17 | "token": "" 18 | } 19 | ] 20 | } 21 | ] -------------------------------------------------------------------------------- /tests/__snapshots__/test_prompts/test_prompt_snapshot[mistral-largest-planet].json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "answer": "Jupiter", 4 | "finish_reason": "eos_token", 5 | "probability": 0.9310452838409994, 6 | "tokens": [ 7 | { 8 | "probability": 0.9999580383300781, 9 | "token": "J" 10 | }, 11 | { 12 | "probability": 0.9997254014015198, 13 | "token": "upiter" 14 | }, 15 | { 16 | "probability": 0.9313400983810425, 17 | "token": "" 18 | } 19 | ] 20 | } 21 | ] -------------------------------------------------------------------------------- /tests/__snapshots__/test_prompts/test_prompt_snapshot[default-largest-planet].json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "answer": "Jupiter.", 4 | "finish_reason": "eos_token", 5 | "probability": 0.7771506816745445, 6 | "tokens": [ 7 | { 8 | "probability": 0.8974856734275818, 9 | "token": "J" 10 | }, 11 | { 12 | "probability": 0.9881630539894104, 13 | "token": "upiter" 14 | }, 15 | { 16 | "probability": 0.8763476610183716, 17 | "token": "." 18 | }, 19 | { 20 | "probability": 0.9999370574951172, 21 | "token": "" 22 | } 23 | ] 24 | } 25 | ] -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2025 Sam Rose 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 8 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling>=1.22.0"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "llmwalk" 7 | version = "0.1.3" 8 | description = "Explore the answer-space for any prompt and any MLX-supported model." 
9 | readme = "README.md" 10 | requires-python = ">=3.10" 11 | dependencies = [ 12 | "mlx-lm==0.29.1", 13 | "rich==14.2.0", 14 | "sortedcontainers==2.4.0", 15 | "transformers===4.57.3", 16 | ] 17 | 18 | [project.optional-dependencies] 19 | dev = [ 20 | "pytest>=8.0.0", 21 | "syrupy>=4.0.0", 22 | ] 23 | 24 | [project.scripts] 25 | llmwalk = "llmwalk.cli:main" 26 | 27 | [tool.hatch.build] 28 | exclude = [ 29 | "example1.gif", 30 | ] 31 | 32 | [tool.hatch.build.targets.wheel] 33 | packages = ["llmwalk"] 34 | 35 | [tool.pytest.ini_options] 36 | filterwarnings = [ 37 | # Noise from the sentencepiece SWIG extension when running with Mistral tokenizers. 38 | "ignore:.*SwigPyPacked.*:DeprecationWarning", 39 | "ignore:.*SwigPyObject.*:DeprecationWarning", 40 | "ignore:.*swigvarlink.*:DeprecationWarning", 41 | ] 42 | -------------------------------------------------------------------------------- /AGENTS.md: -------------------------------------------------------------------------------- 1 | # Repository Guidelines 2 | 3 | ## Project Structure & Modules 4 | - Core package lives in `llmwalk/` (`cli.py` for CLI wiring and output rendering, `llm.py` for search logic, `__main__.py` for entrypoint). 5 | - Prompt examples sit in `prompts/` and feed snapshot tests; update or add `.txt` files there to cover new cases. 6 | - Tests and snapshots are under `tests/` with Syrupy fixtures in `tests/__snapshots__/`. 7 | - Top-level assets: `README.md` for user-facing usage, `example1.gif` for demo, `pyproject.toml` for deps/metadata. 8 | 9 | ## Build, Run, and Dev Commands 10 | - Run the tool with dependencies resolved via uv: `uv run llmwalk -p "Your prompt"`. 11 | - Export structured results: append `--format json` or `--format csv`. 12 | 13 | ## Coding Style & Naming 14 | - Python 3.10+, 4-space indentation, type hints throughout; prefer dataclasses for config/data carriers (`SearchConfig`, `Branch`). 15 | - Keep CLI arguments descriptive and validated early (see `parse_args`); guard user-facing output with clear defaults. 16 | - Follow existing naming: `PromptTreeSearch`, `render_*` for display helpers 17 | - Rich output colors are hex strings; keep ASCII for code/doc unless the runtime already emits symbols. 18 | 19 | ## Testing Guidelines 20 | - Primary suite uses `pytest` plus Syrupy snapshots: `uv run pytest`. 21 | - Update snapshots intentionally with `uv run pytest --snapshot-update`; ensure prompt fixtures in `prompts/` match expected outputs. 22 | - Tests currently focus on deterministic JSON output; add new prompts or behaviors alongside snapshot coverage where possible. 23 | 24 | ## Commit & PR Practices 25 | - Commit messages in history are short, imperative, and scoped (e.g., “Allow -p to be a file path.”); mirror that style and keep changesets focused. 26 | - For PRs, describe model/runtime expectations (model id, `--top-k`, `--top-p`, `--temperature`) and note any snapshot updates or offline-mode considerations. 27 | - Include reproducible commands (`uvx`/`uv run pytest`) and mention any HuggingFace downloads or cache needs.*** 28 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # llmwalk 2 | 3 | Explore the answer-space for any prompt and any MLX-supported model. See 4 | for supported models. 
5 | 6 | ![Usage example gif](example1.gif) 7 | 8 | Instead of sampling from the possible tokens each step, llmwalk branches out 9 | and completes all of the branches the sampler would consider based on 10 | `--top-k`, `--top-p` and `--temperature`, ranking the results by probability 11 | as it goes. 12 | 13 | The tree is walked prioritising the most likely branches, until it finds `-n` 14 | branches and then it stops. It doesn't enumerate all possibilities, just enough 15 | to know for sure it has found the `-n` most likely branches. 16 | 17 | ## Usage 18 | 19 | - `uvx llmwalk -p "In what year was Barack Obama born?"` 20 | - `uvx llmwalk -p "Write a haiku about compilers" -n 5` 21 | - `uvx llmwalk -p "Give me one word: " --top-k 200 --temperature 0.7` 22 | 23 | ## Options 24 | 25 | - `-p, --prompt TEXT`: Prompt to score (wrapped with the model’s chat template). 26 | - `-m, --model MODEL`: MLX-LM model identifier or path (default: `mlx-community/Llama-3.2-1B-Instruct-4bit`), supported models can be found at 27 | - `-n N`: Number of answers to show. The search stops once it has `N` finished answers and no unfinished branch can beat the worst of those `N`. 28 | - `--min-probability FLOAT`: Any branch whose cumulative probability falls below this is marked finished (`low_probability`) and not expanded further. 29 | - `--top-k INT`: At each step, expand at most `k` next tokens (highest probability). 30 | - `--top-p FLOAT`: Nucleus cutoff applied *within the top-k tokens* at each step (keep adding tokens until cumulative probability ≥ `p`). 31 | - `--temperature FLOAT`: Softmax temperature applied when computing per-step probabilities (`1.0` is the model distribution; must be `> 0`). 32 | - `--stats-interval SECONDS`: How often to refresh the live view (`<= 0` disables periodic refresh; still renders at start/end). 33 | - `--format {csv,json}`: Output format for machine-readable output. When specified, disables the interactive display and prints results to stdout when the job completes. 34 | 35 | ## Machine-readable output 36 | 37 | Use `--format` to get structured output for scripting or further processing: 38 | 39 | ```bash 40 | # JSON output 41 | uvx llmwalk -p "What is 2+2?" --format json 42 | 43 | # CSV output 44 | uvx llmwalk -p "What is 2+2?" --format csv 45 | ``` 46 | 47 | JSON output includes detailed token-level information: 48 | 49 | ```json 50 | [ 51 | { 52 | "answer": "4", 53 | "probability": 0.95, 54 | "finish_reason": "eos_token", 55 | "tokens": [ 56 | {"token": "4", "probability": 0.95} 57 | ] 58 | } 59 | ] 60 | ``` 61 | 62 | CSV output provides a simpler tabular format with columns: `answer`, `probability`, `finish_reason`. 
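For ad-hoc scripting, the JSON output can also be loaded directly. A minimal sketch (the `answers.json` filename is just an assumption, produced by redirecting `--format json` output to a file):

```python
import json

# Assumes: uvx llmwalk -p "What is 2+2?" --format json > answers.json
with open("answers.json") as f:
    answers = json.load(f)

# Print each answer with its overall probability, most likely first.
for entry in sorted(answers, key=lambda e: e["probability"], reverse=True):
    print(f"{entry['probability']:.2%}  {entry['answer']}  ({entry['finish_reason']})")
```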
63 | -------------------------------------------------------------------------------- /tests/test_prompts.py: -------------------------------------------------------------------------------- 1 | """Snapshot tests for llmwalk prompts.""" 2 | 3 | from __future__ import annotations 4 | 5 | import gc 6 | import json 7 | from pathlib import Path 8 | 9 | import pytest 10 | from syrupy.assertion import SnapshotAssertion 11 | from syrupy.extensions.json import JSONSnapshotExtension 12 | 13 | from llmwalk.cli import main 14 | 15 | PROMPTS_DIR = Path(__file__).parent.parent / "prompts" 16 | DEFAULT_MODEL = "mlx-community/Llama-3.2-1B-Instruct-4bit" 17 | MISTRAL_MODEL = "mlx-community/Mistral-7B-Instruct-v0.3-4bit" 18 | 19 | try: 20 | from huggingface_hub.errors import ( 21 | LocalEntryNotFoundError as HFLocalEntryNotFoundError, 22 | ) 23 | except Exception: # pragma: no cover - optional dependency details 24 | HFLocalEntryNotFoundError = None # type: ignore[assignment] 25 | 26 | 27 | class JSONSnapshotExtensionPretty(JSONSnapshotExtension): 28 | """JSON snapshot extension with pretty printing.""" 29 | 30 | def serialize(self, data, **kwargs): 31 | return json.dumps(data, indent=2, sort_keys=True) 32 | 33 | 34 | @pytest.fixture 35 | def snapshot_json(snapshot: SnapshotAssertion) -> SnapshotAssertion: 36 | return snapshot.use_extension(JSONSnapshotExtensionPretty) 37 | 38 | 39 | def get_prompt_files() -> list[Path]: 40 | """Get all prompt files from the prompts directory.""" 41 | return sorted(PROMPTS_DIR.glob("*.txt")) 42 | 43 | 44 | def get_models() -> list[str]: 45 | return [DEFAULT_MODEL, MISTRAL_MODEL] 46 | 47 | 48 | def _should_skip_missing_model_error(exc: Exception) -> bool: 49 | if isinstance(exc, FileNotFoundError): 50 | return True 51 | if HFLocalEntryNotFoundError is not None and isinstance( 52 | exc, HFLocalEntryNotFoundError 53 | ): 54 | return True 55 | return False 56 | 57 | 58 | @pytest.fixture(autouse=True) 59 | def purge_model_memory(): 60 | yield 61 | gc.collect() 62 | try: 63 | import mlx.core as mx 64 | 65 | if hasattr(mx, "metal") and hasattr(mx.metal, "clear_cache"): 66 | mx.metal.clear_cache() 67 | except Exception: 68 | pass 69 | 70 | 71 | @pytest.mark.parametrize( 72 | "prompt_file", 73 | get_prompt_files(), 74 | ids=[p.stem for p in get_prompt_files()], 75 | ) 76 | @pytest.mark.parametrize( 77 | "model", 78 | get_models(), 79 | ids=["default", "mistral"], 80 | ) 81 | def test_prompt_snapshot( 82 | prompt_file: Path, 83 | model: str, 84 | snapshot_json: SnapshotAssertion, 85 | capsys, 86 | ) -> None: 87 | """Test that running llmwalk on each prompt produces consistent output.""" 88 | # Run with JSON format for deterministic output 89 | try: 90 | main( 91 | [ 92 | "-p", 93 | str(prompt_file), 94 | "--format", 95 | "json", 96 | "--model", 97 | model, 98 | "--offline", 99 | "-n", 100 | "3", 101 | "--top-p", 102 | "0.5", 103 | "--top-k", 104 | "10", 105 | ] 106 | ) 107 | except Exception as e: 108 | if _should_skip_missing_model_error(e): 109 | pytest.skip(f"Model not available locally: {model}") 110 | raise 111 | 112 | captured = capsys.readouterr() 113 | output = json.loads(captured.out) 114 | 115 | # Snapshot the structured output 116 | assert output == snapshot_json 117 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.toptal.com/developers/gitignore/api/python 2 | # Edit at 
https://www.toptal.com/developers/gitignore?templates=python 3 | 4 | ### Python ### 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | cover/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | .pybuilder/ 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | # For a library or package, you might want to ignore these files since the code is 91 | # intended to run in multiple environments; otherwise, check them in: 92 | # .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # poetry 102 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 103 | # This is especially recommended for binary packages to ensure reproducibility, and is more 104 | # commonly ignored for libraries. 105 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 106 | #poetry.lock 107 | 108 | # pdm 109 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 110 | #pdm.lock 111 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 112 | # in version control. 113 | # https://pdm.fming.dev/#use-with-ide 114 | .pdm.toml 115 | 116 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 117 | __pypackages__/ 118 | 119 | # Celery stuff 120 | celerybeat-schedule 121 | celerybeat.pid 122 | 123 | # SageMath parsed files 124 | *.sage.py 125 | 126 | # Environments 127 | .env 128 | .venv 129 | env/ 130 | venv/ 131 | ENV/ 132 | env.bak/ 133 | venv.bak/ 134 | 135 | # Spyder project settings 136 | .spyderproject 137 | .spyproject 138 | 139 | # Rope project settings 140 | .ropeproject 141 | 142 | # mkdocs documentation 143 | /site 144 | 145 | # mypy 146 | .mypy_cache/ 147 | .dmypy.json 148 | dmypy.json 149 | 150 | # Pyre type checker 151 | .pyre/ 152 | 153 | # pytype static type analyzer 154 | .pytype/ 155 | 156 | # Cython debug symbols 157 | cython_debug/ 158 | 159 | # PyCharm 160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 162 | # and can be added to the global gitignore or merged into this file. For a more nuclear 163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 164 | #.idea/ 165 | 166 | ### Python Patch ### 167 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration 168 | poetry.toml 169 | 170 | # ruff 171 | .ruff_cache/ 172 | 173 | # LSP config files 174 | pyrightconfig.json 175 | 176 | # End of https://www.toptal.com/developers/gitignore/api/python 177 | -------------------------------------------------------------------------------- /tests/__snapshots__/test_prompts/test_prompt_snapshot[mistral-capital-france].json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "answer": "Paris (not a typo, Paris is the capital city of France)", 4 | "finish_reason": "eos_token", 5 | "probability": 0.002477645966623216, 6 | "tokens": [ 7 | { 8 | "probability": 0.9997292160987854, 9 | "token": "Paris" 10 | }, 11 | { 12 | "probability": 0.7607424259185791, 13 | "token": "(" 14 | }, 15 | { 16 | "probability": 0.19734331965446472, 17 | "token": "not" 18 | }, 19 | { 20 | "probability": 0.8402746319770813, 21 | "token": "a" 22 | }, 23 | { 24 | "probability": 0.5315064191818237, 25 | "token": "typ" 26 | }, 27 | { 28 | "probability": 0.9996547698974609, 29 | "token": "o" 30 | }, 31 | { 32 | "probability": 0.9821969866752625, 33 | "token": "," 34 | }, 35 | { 36 | "probability": 0.7557774782180786, 37 | "token": "Paris" 38 | }, 39 | { 40 | "probability": 0.9985285401344299, 41 | "token": "is" 42 | }, 43 | { 44 | "probability": 0.7320148348808289, 45 | "token": "the" 46 | }, 47 | { 48 | "probability": 0.670669436454773, 49 | "token": "capital" 50 | }, 51 | { 52 | "probability": 0.5634616017341614, 53 | "token": "city" 54 | }, 55 | { 56 | "probability": 0.46561211347579956, 57 | "token": "of" 58 | }, 59 | { 60 | "probability": 0.9999542236328125, 61 | "token": "France" 62 | }, 63 | { 64 | "probability": 0.3925216794013977, 65 | "token": ")" 66 | }, 67 | { 68 | "probability": 0.9867410659790039, 69 | "token": "" 70 | } 71 | ] 72 | }, 73 | { 74 | "answer": "Paris (not a typo, Paris is the capital city, not the capital state)", 75 | "finish_reason": "eos_token", 76 | "probability": 0.000701938764540693, 77 | "tokens": [ 78 | { 79 | "probability": 0.9997292160987854, 80 | "token": "Paris" 81 | }, 82 | { 83 | "probability": 0.7607424259185791, 84 | "token": "(" 85 | }, 86 | { 87 | "probability": 0.19734331965446472, 88 | "token": "not" 89 | }, 90 | { 91 | 
"probability": 0.8402746319770813, 92 | "token": "a" 93 | }, 94 | { 95 | "probability": 0.5315064191818237, 96 | "token": "typ" 97 | }, 98 | { 99 | "probability": 0.9996547698974609, 100 | "token": "o" 101 | }, 102 | { 103 | "probability": 0.9821969866752625, 104 | "token": "," 105 | }, 106 | { 107 | "probability": 0.7557774782180786, 108 | "token": "Paris" 109 | }, 110 | { 111 | "probability": 0.9985285401344299, 112 | "token": "is" 113 | }, 114 | { 115 | "probability": 0.7320148348808289, 116 | "token": "the" 117 | }, 118 | { 119 | "probability": 0.670669436454773, 120 | "token": "capital" 121 | }, 122 | { 123 | "probability": 0.5634616017341614, 124 | "token": "city" 125 | }, 126 | { 127 | "probability": 0.48795729875564575, 128 | "token": "," 129 | }, 130 | { 131 | "probability": 0.5095872282981873, 132 | "token": "not" 133 | }, 134 | { 135 | "probability": 0.9860091805458069, 136 | "token": "the" 137 | }, 138 | { 139 | "probability": 0.8910145163536072, 140 | "token": "capital" 141 | }, 142 | { 143 | "probability": 0.486347496509552, 144 | "token": "state" 145 | }, 146 | { 147 | "probability": 0.4830770790576935, 148 | "token": ")" 149 | }, 150 | { 151 | "probability": 0.9954100251197815, 152 | "token": "" 153 | } 154 | ] 155 | }, 156 | { 157 | "answer": "Paris (not a typo, Paris is the capital city, not the capital state or province)", 158 | "finish_reason": "eos_token", 159 | "probability": 0.00013999200194631084, 160 | "tokens": [ 161 | { 162 | "probability": 0.9997292160987854, 163 | "token": "Paris" 164 | }, 165 | { 166 | "probability": 0.7607424259185791, 167 | "token": "(" 168 | }, 169 | { 170 | "probability": 0.19734331965446472, 171 | "token": "not" 172 | }, 173 | { 174 | "probability": 0.8402746319770813, 175 | "token": "a" 176 | }, 177 | { 178 | "probability": 0.5315064191818237, 179 | "token": "typ" 180 | }, 181 | { 182 | "probability": 0.9996547698974609, 183 | "token": "o" 184 | }, 185 | { 186 | "probability": 0.9821969866752625, 187 | "token": "," 188 | }, 189 | { 190 | "probability": 0.7557774782180786, 191 | "token": "Paris" 192 | }, 193 | { 194 | "probability": 0.9985285401344299, 195 | "token": "is" 196 | }, 197 | { 198 | "probability": 0.7320148348808289, 199 | "token": "the" 200 | }, 201 | { 202 | "probability": 0.670669436454773, 203 | "token": "capital" 204 | }, 205 | { 206 | "probability": 0.5634616017341614, 207 | "token": "city" 208 | }, 209 | { 210 | "probability": 0.48795729875564575, 211 | "token": "," 212 | }, 213 | { 214 | "probability": 0.5095872282981873, 215 | "token": "not" 216 | }, 217 | { 218 | "probability": 0.9860091805458069, 219 | "token": "the" 220 | }, 221 | { 222 | "probability": 0.8910145163536072, 223 | "token": "capital" 224 | }, 225 | { 226 | "probability": 0.486347496509552, 227 | "token": "state" 228 | }, 229 | { 230 | "probability": 0.19215478003025055, 231 | "token": "or" 232 | }, 233 | { 234 | "probability": 0.7145478129386902, 235 | "token": "province" 236 | }, 237 | { 238 | "probability": 0.7032752633094788, 239 | "token": ")" 240 | }, 241 | { 242 | "probability": 0.99314945936203, 243 | "token": "" 244 | } 245 | ] 246 | } 247 | ] -------------------------------------------------------------------------------- /llmwalk/llm.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import heapq 4 | from collections.abc import Iterable 5 | from dataclasses import dataclass 6 | from datetime import datetime 7 | from functools import cache 8 | 9 | import mlx.core 
as mx 10 | from mlx.nn import Module 11 | from mlx_lm.models.cache import KVCache 12 | from mlx_lm.tokenizer_utils import SPMStreamingDetokenizer, TokenizerWrapper 13 | from sortedcontainers import SortedList 14 | 15 | 16 | @dataclass 17 | class SearchConfig: 18 | """Configuration for the tree search algorithm.""" 19 | 20 | n: int = 10 21 | top_k: int = 50 22 | top_p: float = 1.0 23 | temperature: float = 1.0 24 | min_probability: float = 0.0001 25 | 26 | 27 | @dataclass 28 | class OutputToken: 29 | token: int 30 | prob: float 31 | 32 | 33 | @dataclass(eq=False) 34 | class Branch: 35 | parent: Branch | None 36 | token: OutputToken | None 37 | probability: float = 1.0 38 | finish_reason: str | None = None 39 | cache: list[KVCache] | None = None 40 | 41 | def answer_tokens(self) -> list[OutputToken]: 42 | toks: list[OutputToken] = [] 43 | cur: Branch | None = self 44 | while cur is not None and cur.token is not None: 45 | toks.append(cur.token) 46 | cur = cur.parent 47 | toks.reverse() 48 | return toks 49 | 50 | 51 | def _clone_kv_cache(c: KVCache) -> KVCache: 52 | cloned = KVCache() 53 | cloned.offset = c.offset 54 | cloned.keys = mx.array(c.keys) if c.keys is not None else None 55 | cloned.values = mx.array(c.values) if c.values is not None else None 56 | return cloned 57 | 58 | 59 | def _clone_prompt_cache(cache: list[KVCache]) -> list[KVCache]: 60 | return [_clone_kv_cache(c) for c in cache] 61 | 62 | 63 | def _infer_num_layers(model: Module) -> int | None: 64 | for obj in (model, getattr(model, "model", None)): 65 | if obj is None: 66 | continue 67 | 68 | n = getattr(obj, "num_hidden_layers", None) 69 | if isinstance(n, int) and n > 0: 70 | return n 71 | 72 | layers = getattr(obj, "layers", None) 73 | if layers is None: 74 | continue 75 | try: 76 | n_layers = len(layers) # type: ignore[arg-type] 77 | except TypeError: 78 | n_layers = None 79 | if isinstance(n_layers, int) and n_layers > 0: 80 | return n_layers 81 | 82 | return None 83 | 84 | 85 | def _make_kv_cache(model: Module) -> list[KVCache] | None: 86 | make_cache = getattr(model, "make_cache", None) 87 | if callable(make_cache): 88 | cache = make_cache() 89 | if cache is None: 90 | return None 91 | if isinstance(cache, list): 92 | return cache 93 | if isinstance(cache, Iterable): 94 | return list(cache) 95 | return None 96 | 97 | n_layers = _infer_num_layers(model) 98 | if n_layers is None: 99 | return None 100 | return [KVCache() for _ in range(n_layers)] 101 | 102 | 103 | def _top_tokens_from_logprobs( 104 | logprobs: mx.array, config: SearchConfig 105 | ) -> list[OutputToken]: 106 | vocab = int(logprobs.shape[0]) 107 | k = min(config.top_k, vocab) 108 | part = mx.argpartition(logprobs, vocab - k) 109 | top_idx = part[vocab - k :] 110 | top_lp = mx.take(logprobs, top_idx) 111 | order = mx.argsort(top_lp)[::-1] 112 | sorted_indices = mx.take(top_idx, order) 113 | 114 | if config.temperature == 1.0: 115 | probs = mx.exp(mx.take(logprobs, sorted_indices)) 116 | else: 117 | lse = mx.logsumexp(logprobs / config.temperature, axis=-1) 118 | probs = mx.exp(mx.take(logprobs, sorted_indices) / config.temperature - lse) 119 | 120 | mx.eval(sorted_indices, probs) 121 | token_ids = list(sorted_indices.astype(mx.int64).tolist()) # type: ignore[arg-type] 122 | token_probs = list(mx.reshape(probs, (-1,)).tolist()) # type: ignore[arg-type] 123 | 124 | output_tokens: list[OutputToken] = [] 125 | cum_prob = 0.0 126 | for token_id, prob in zip(token_ids, token_probs): # type: ignore[call-arg] 127 | if output_tokens and cum_prob >= config.top_p: 
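# Nucleus (top-p) cutoff within the top-k candidates: stop once the kept tokens' cumulative probability has reached top_p; the `output_tokens` guard always keeps at least one token.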
128 | break 129 | output_tokens.append(OutputToken(token=token_id, prob=float(prob))) 130 | cum_prob += float(prob) 131 | 132 | return output_tokens 133 | 134 | 135 | class PromptTreeSearch: 136 | model: Module 137 | tokenizer: TokenizerWrapper 138 | prompt: list[int] 139 | config: SearchConfig 140 | _frontier: list[tuple[float, int, Branch]] 141 | _finished_eos: SortedList # SortedList[Branch] 142 | _heap_counter: int = 0 143 | _stopped: bool = False 144 | 145 | tokens: int = 0 146 | pruned: int = 0 147 | 148 | _low_watermark: float | None = None 149 | _start: datetime | None = None 150 | _end: datetime | None = None 151 | 152 | def __init__( 153 | self, 154 | model: Module, 155 | tokenizer: TokenizerWrapper, 156 | prompt: list[int], 157 | config: SearchConfig | None = None, 158 | ) -> None: 159 | self.model = model 160 | self.tokenizer = tokenizer 161 | self.prompt = prompt 162 | self.config = config or SearchConfig() 163 | self._frontier = [] 164 | self._finished_eos = SortedList(key=lambda b: -b.probability) 165 | 166 | root = Branch(parent=None, token=None) 167 | self.branches = SortedList(key=lambda b: -b.probability) 168 | self.branches.add(root) 169 | self._push_frontier(root) 170 | 171 | def _run_model(self, cache: list[KVCache] | None, input_ids: list[int]) -> mx.array: 172 | self.tokens += 1 173 | inputs = mx.array([input_ids], mx.int32) 174 | logits = self.model(inputs, cache=cache)[:, -1, :] 175 | logits = logits.astype(mx.float32) 176 | logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) 177 | return mx.reshape(logprobs, (-1,)) 178 | 179 | @cache 180 | def is_sentencepiece_model(self) -> bool: 181 | return isinstance(self.tokenizer.detokenizer, SPMStreamingDetokenizer) 182 | 183 | @property 184 | def active(self) -> int: 185 | return len(self._frontier) 186 | 187 | def top_branches(self, n: int) -> list[Branch]: 188 | return list(self.branches[:n]) 189 | 190 | def _push_frontier(self, branch: Branch) -> None: 191 | self._heap_counter += 1 192 | heapq.heappush( 193 | self._frontier, (-branch.probability, self._heap_counter, branch) 194 | ) 195 | 196 | def _update_low_watermark(self) -> None: 197 | if len(self._finished_eos) < self.config.n: 198 | self._low_watermark = None 199 | return 200 | self._low_watermark = self._finished_eos[self.config.n - 1].probability 201 | 202 | def stop(self) -> None: 203 | self._stopped = True 204 | 205 | def should_stop(self) -> bool: 206 | if self._stopped: 207 | return True 208 | if not self._frontier: 209 | return True 210 | if self._low_watermark is None: 211 | return False 212 | best_prob = -self._frontier[0][0] 213 | return best_prob < self._low_watermark 214 | 215 | def step(self) -> None: 216 | if self._start is None: 217 | self._start = datetime.now() 218 | 219 | if self.should_stop(): 220 | return 221 | 222 | _, _, branch = heapq.heappop(self._frontier) 223 | 224 | if self._low_watermark is not None and branch.probability < self._low_watermark: 225 | self.pruned += 1 226 | branch.finish_reason = "pruned" 227 | branch.cache = None 228 | return 229 | 230 | if branch.token is None: # root branch 231 | cache_after = _make_kv_cache(self.model) 232 | logprobs = self._run_model(cache_after, self.prompt) 233 | else: 234 | if branch.cache is None: 235 | input_ids = self.prompt + [t.token for t in branch.answer_tokens()] 236 | cache_after = None 237 | logprobs = self._run_model(cache_after, input_ids) 238 | else: 239 | cache_after = _clone_prompt_cache(branch.cache) 240 | logprobs = self._run_model(cache_after, [branch.token.token]) 
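# With a cached parent, only the branch's newest token needs a forward pass; the parent's KV cache is cloned first so sibling branches never mutate shared cache state.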
241 | 242 | self.branches.remove(branch) 243 | 244 | new_branches: list[Branch] = [] 245 | frontier_add: list[Branch] = [] 246 | eos_add: list[Branch] = [] 247 | for tok in _top_tokens_from_logprobs(logprobs, self.config): 248 | new_prob = branch.probability * tok.prob 249 | 250 | if new_prob < self.config.min_probability: 251 | self.pruned += 1 252 | new_branch = Branch( 253 | parent=branch, 254 | token=tok, 255 | probability=new_prob, 256 | finish_reason="low_probability", 257 | ) 258 | new_branches.append(new_branch) 259 | continue 260 | 261 | if tok.token in self.tokenizer.eos_token_ids: 262 | new_branch = Branch( 263 | parent=branch, 264 | token=tok, 265 | probability=new_prob, 266 | finish_reason="eos_token", 267 | ) 268 | new_branches.append(new_branch) 269 | eos_add.append(new_branch) 270 | continue 271 | 272 | new_branch = Branch( 273 | parent=branch, 274 | token=tok, 275 | probability=new_prob, 276 | cache=cache_after, 277 | ) 278 | new_branches.append(new_branch) 279 | frontier_add.append(new_branch) 280 | 281 | for b in new_branches: 282 | self.branches.add(b) 283 | for b in frontier_add: 284 | self._push_frontier(b) 285 | for b in eos_add: 286 | self._finished_eos.add(b) 287 | if eos_add: 288 | self._update_low_watermark() 289 | 290 | branch.cache = None 291 | -------------------------------------------------------------------------------- /llmwalk/cli.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | # Check for --offline flag before any HuggingFace imports 4 | import os 5 | import sys 6 | 7 | from mlx_lm.tokenizer_utils import SPMStreamingDetokenizer 8 | 9 | if "--offline" in sys.argv: 10 | os.environ["HF_HUB_OFFLINE"] = "1" 11 | 12 | # Metal GPU trace capture requires this to be set at process start. 
13 | # See https://ml-explore.github.io/mlx/build/html/dev/metal_debugger.html 14 | if any( 15 | arg == "--metal-capture" or arg.startswith("--metal-capture=") for arg in sys.argv 16 | ): 17 | os.environ.setdefault("MTL_CAPTURE_ENABLED", "1") 18 | 19 | import argparse 20 | import cProfile 21 | import csv 22 | import io 23 | import json 24 | import pstats 25 | import time 26 | from datetime import datetime 27 | from importlib.metadata import version as pkg_version 28 | 29 | import mlx.core as mx 30 | from mlx_lm import load 31 | from rich.console import Console, Group 32 | from rich.live import Live 33 | from rich.style import Style 34 | from rich.table import Table 35 | from rich.text import Text 36 | 37 | from .llm import PromptTreeSearch, SearchConfig 38 | 39 | args: argparse.Namespace 40 | 41 | _BAND_COLORS = [ 42 | "#7f7f7f", # 0-10%: grey 43 | "#ff3b30", # 10-20%: red 44 | "#ff6a00", # 20-30%: orange 45 | "#ff8c00", # 30-40%: dark orange 46 | "#ffb000", # 40-50%: amber 47 | "#ffd000", # 50-60%: yellow 48 | "#d7e500", # 60-70%: yellow-green 49 | "#a8e600", # 70-80%: greenish 50 | "#4cd964", # 80-90%: green 51 | "#00c853", # 90-100%: bright green 52 | ] 53 | _BAND_STYLES = [Style(color=c) for c in _BAND_COLORS] 54 | 55 | 56 | def style_for_token_probability(prob: float) -> Style: 57 | if prob != prob: # NaN 58 | prob = 0.0 59 | elif prob < 0.0: 60 | prob = 0.0 61 | elif prob > 1.0: 62 | prob = 1.0 63 | 64 | band = min(int(prob * 10), 9) # 0..9 65 | return _BAND_STYLES[band] 66 | 67 | 68 | def render_probability_legend() -> Text: 69 | legend = Text("Legend: ", style="bold", no_wrap=True, overflow="ellipsis") 70 | for i in range(9, -1, -1): 71 | style = style_for_token_probability((i + 0.5) / 10) 72 | if i == 9: 73 | label = "90%+" 74 | elif i == 0: 75 | label = "0–10%" 76 | else: 77 | label = f"{i * 10}%+" 78 | 79 | legend.append("■", style=style) 80 | legend.append(f" {label}") 81 | if i != 0: 82 | legend.append(" ") 83 | 84 | return legend 85 | 86 | 87 | _PROBABILITY_LEGEND = render_probability_legend() 88 | 89 | 90 | def render_branches(walker: PromptTreeSearch) -> Table: 91 | table = Table(expand=True, show_header=False, show_edge=False) 92 | table.add_column("Prob.", justify="right", no_wrap=True, width=8) 93 | table.add_column("Answer", ratio=1) 94 | 95 | branches = walker.top_branches(args.n) 96 | for i in range(args.n): 97 | if i >= len(branches): 98 | table.add_row("", "", "") 99 | continue 100 | 101 | branch = branches[i] 102 | answer_text = Text() 103 | 104 | if walker.is_sentencepiece_model(): 105 | answer_text = Text( 106 | walker.tokenizer.decode([tok.token for tok in branch.answer_tokens()]) # type: ignore[call-arg] 107 | ) 108 | else: 109 | for tok in branch.answer_tokens(): 110 | piece = walker.tokenizer.decode( # type: ignore[call-arg] 111 | [tok.token], 112 | skip_special_tokens=True, 113 | ) 114 | if not piece: 115 | continue 116 | piece = piece.replace("\n", "\\n") 117 | answer_text.append(piece, style=style_for_token_probability(tok.prob)) 118 | 119 | status: Text 120 | if branch.finish_reason == "eos_token": 121 | status = Text("✓", style="green") 122 | elif branch.finish_reason == "low_probability": 123 | status = Text("?", style="yellow") 124 | elif branch.finish_reason == "pruned": 125 | status = Text("-", style="dim") 126 | else: 127 | status = Text(" ") 128 | 129 | probability_text = Text.assemble(status, f"{branch.probability * 100:6.2f}%") 130 | 131 | table.add_row(probability_text, answer_text) 132 | 133 | return table 134 | 135 | 136 | def 
render_stats_bar(walker: PromptTreeSearch) -> Table: 137 | elapsed = (datetime.now() - walker._start).total_seconds() if walker._start else 0.0 138 | tps = walker.tokens / elapsed if elapsed > 0 else 0.0 139 | left = f"frontier {walker.active} pruned {walker.pruned} tps {tps:0.1f}" 140 | grid = Table.grid(expand=True) 141 | grid.add_column(ratio=1) 142 | grid.add_column(justify="right", no_wrap=True) 143 | grid.add_row( 144 | Text(left, overflow="ellipsis", no_wrap=True), 145 | Text( 146 | f"top_k={args.top_k} top_p={args.top_p} temp={args.temperature}", 147 | no_wrap=True, 148 | ), 149 | ) 150 | return grid 151 | 152 | 153 | def render_view(walker: PromptTreeSearch) -> Group: 154 | if args.minimal: 155 | return Group(render_branches(walker)) 156 | return Group( 157 | _PROBABILITY_LEGEND, 158 | render_branches(walker), 159 | render_stats_bar(walker), 160 | ) 161 | 162 | 163 | def format_results_json(walker: PromptTreeSearch) -> str: 164 | branches = walker.top_branches(args.n) 165 | results = [] 166 | for branch in branches: 167 | tokens = [] 168 | for tok in branch.answer_tokens(): 169 | tokens.append( 170 | { 171 | "token": walker.tokenizer.decode([tok.token]), # type: ignore[call-arg] 172 | "probability": tok.prob, 173 | } 174 | ) 175 | answer_text = walker.tokenizer.decode( # type: ignore[attr-defined] 176 | [tok.token for tok in branch.answer_tokens()] 177 | ) 178 | results.append( 179 | { 180 | "answer": answer_text, 181 | "probability": branch.probability, 182 | "finish_reason": branch.finish_reason, 183 | "tokens": tokens, 184 | } 185 | ) 186 | return json.dumps(results, indent=2) 187 | 188 | 189 | def format_results_csv(walker: PromptTreeSearch) -> str: 190 | branches = walker.top_branches(args.n) 191 | output = io.StringIO() 192 | writer = csv.writer(output) 193 | writer.writerow(["answer", "probability", "finish_reason"]) 194 | for branch in branches: 195 | answer_text = walker.tokenizer.decode( # type: ignore[attr-defined] 196 | [tok.token for tok in branch.answer_tokens()] 197 | ) 198 | writer.writerow([answer_text, branch.probability, branch.finish_reason or ""]) 199 | return output.getvalue() 200 | 201 | 202 | def run() -> None: 203 | load_resp = load(args.model) 204 | model = load_resp[0] 205 | tokenizer = load_resp[1] 206 | 207 | model.eval() 208 | 209 | prompt = tokenizer.apply_chat_template( # type: ignore[call-arg] 210 | [{"role": "user", "content": args.prompt}], 211 | add_generation_prompt=True, 212 | ) 213 | 214 | config = SearchConfig( 215 | n=args.n, 216 | top_k=args.top_k, 217 | top_p=args.top_p, 218 | temperature=args.temperature, 219 | min_probability=args.min_probability, 220 | ) 221 | 222 | walker = PromptTreeSearch(model, tokenizer, prompt, config) 223 | 224 | profiler: cProfile.Profile | None = None 225 | if args.cprofile is not None: 226 | profiler = cProfile.Profile() 227 | 228 | capture_enabled = False 229 | if args.metal_capture is not None and hasattr(mx, "metal"): 230 | start_capture = getattr(mx.metal, "start_capture", None) 231 | if callable(start_capture): 232 | try: 233 | start_capture(args.metal_capture) 234 | capture_enabled = True 235 | except Exception as exc: 236 | print( 237 | f"Warning: failed to start Metal capture: {exc}", 238 | file=sys.stderr, 239 | ) 240 | 241 | steps = 0 242 | 243 | def do_step_loop() -> None: 244 | nonlocal steps 245 | try: 246 | while not walker.should_stop(): 247 | walker.step() 248 | steps += 1 249 | if args.max_steps > 0 and steps >= args.max_steps: 250 | walker.stop() 251 | break 252 | except KeyboardInterrupt: 
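# Ctrl-C stops the search early; the results found so far are still printed below.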
253 | walker.stop() 254 | 255 | try: 256 | if args.format: 257 | # Machine-readable output: no interactive display 258 | try: 259 | if profiler is not None: 260 | profiler.enable() 261 | do_step_loop() 262 | finally: 263 | if profiler is not None: 264 | profiler.disable() 265 | 266 | if args.format == "json": 267 | print(format_results_json(walker)) 268 | elif args.format == "csv": 269 | print(format_results_csv(walker), end="") 270 | else: 271 | # Interactive display 272 | console = Console() 273 | try: 274 | with Live(console=console, transient=False) as live: 275 | interval = max(0.1, args.stats_interval) 276 | next_render = time.monotonic() 277 | live.update(render_view(walker)) 278 | try: 279 | if profiler is not None: 280 | profiler.enable() 281 | while not walker.should_stop(): 282 | walker.step() 283 | steps += 1 284 | if args.max_steps > 0 and steps >= args.max_steps: 285 | walker.stop() 286 | break 287 | if ( 288 | args.stats_interval > 0 289 | and time.monotonic() >= next_render 290 | ): 291 | live.update(render_view(walker)) 292 | next_render = time.monotonic() + interval 293 | finally: 294 | if profiler is not None: 295 | profiler.disable() 296 | live.update(render_view(walker)) 297 | except KeyboardInterrupt: 298 | walker.stop() 299 | finally: 300 | if capture_enabled and hasattr(mx, "metal"): 301 | stop_capture = getattr(mx.metal, "stop_capture", None) 302 | if callable(stop_capture): 303 | try: 304 | stop_capture() 305 | except Exception as exc: 306 | print( 307 | f"Warning: failed to stop Metal capture: {exc}", 308 | file=sys.stderr, 309 | ) 310 | 311 | if profiler is not None: 312 | if args.cprofile == "-": 313 | stream = io.StringIO() 314 | stats = pstats.Stats(profiler, stream=stream) 315 | stats.strip_dirs().sort_stats("cumtime").print_stats(50) 316 | print(stream.getvalue(), file=sys.stderr, end="") 317 | else: 318 | profiler.dump_stats(args.cprofile) 319 | 320 | 321 | def parse_args(argv: list[str] | None) -> argparse.Namespace: 322 | parser = argparse.ArgumentParser() 323 | parser.add_argument( 324 | "-p", 325 | "--prompt", 326 | default="What is 2+2?", 327 | help="The prompt to walk. Can be a file path, in which case the file contents will be used.", 328 | ) 329 | parser.add_argument( 330 | "-m", 331 | "--model", 332 | default="mlx-community/Llama-3.2-1B-Instruct-4bit", 333 | help="Which model to use. Must be an mlx-community/ model from HuggingFace.", 334 | ) 335 | parser.add_argument( 336 | "-n", 337 | default=10, 338 | type=int, 339 | help="The top N answers to track. Search will stop after the top N answers have been found, so increasing this can increase runtime.", 340 | ) 341 | parser.add_argument( 342 | "--min-probability", 343 | type=float, 344 | default=0.0001, 345 | help="A minimum probability threshold for branches. If a branch becomes less likely than this, we stop walking it. Lowering this can increase runtime.", 346 | ) 347 | parser.add_argument( 348 | "--top-k", 349 | dest="top_k", 350 | default=50, 351 | type=int, 352 | help="How many tokens to branch on at each step. Increasing this will increase runtime.", 353 | ) 354 | parser.add_argument( 355 | "--top-p", 356 | dest="top_p", 357 | default=1.0, 358 | type=float, 359 | help="Like --top-k, this will limit the tokens branched on at each step, but by cumulative probability instead of a static number. 
Decreasing this can reduce runtime.", 360 | ) 361 | parser.add_argument( 362 | "--temperature", 363 | type=float, 364 | default=1.0, 365 | help='Sampling temperature, decreasing this will "sharpen" the token probabilities and make high-probability tokens more likely. Increasing it will make the distribution more uniform, making less likely tokens more likely.', 366 | ) 367 | parser.add_argument( 368 | "--stats-interval", 369 | type=float, 370 | default=0.1, 371 | help="In interactive mode, i.e. no --format, this will control how often the table is updated.", 372 | ) 373 | parser.add_argument( 374 | "--format", 375 | choices=["csv", "json"], 376 | default=None, 377 | help="Output format for machine-readable output (disables interactive display)", 378 | ) 379 | parser.add_argument( 380 | "--max-steps", 381 | type=int, 382 | default=0, 383 | help="Stop after this many search steps (0 = no limit). Useful for profiling and quick runs.", 384 | ) 385 | parser.add_argument( 386 | "--cprofile", 387 | default=None, 388 | help="Write a Python cProfile to this path (use '-' to print top stats to stderr).", 389 | ) 390 | parser.add_argument( 391 | "--metal-capture", 392 | default=None, 393 | help="Write a Metal capture trace to this path (macOS only).", 394 | ) 395 | parser.add_argument( 396 | "--offline", 397 | action="store_true", 398 | help="Run in offline mode (skip HuggingFace Hub network requests)", 399 | ) 400 | parser.add_argument( 401 | "--minimal", 402 | action="store_true", 403 | help="Hide the legend and stats bar for a cleaner display", 404 | ) 405 | parser.add_argument( 406 | "--version", 407 | action="version", 408 | version=f"%(prog)s {pkg_version('llmwalk')}", 409 | ) 410 | 411 | raw = list(sys.argv[1:] if argv is None else argv) 412 | filtered = [a for a in raw if a != "--"] 413 | parsed = parser.parse_args(filtered) 414 | 415 | if parsed.temperature <= 0: 416 | parser.error("--temperature must be > 0") 417 | if not (0 < parsed.top_p <= 1): 418 | parser.error("--top-p must be in the range (0, 1]") 419 | if parsed.max_steps < 0: 420 | parser.error("--max-steps must be >= 0") 421 | 422 | # If prompt is a file path, read the file contents 423 | from pathlib import Path 424 | 425 | prompt_path = Path(parsed.prompt) 426 | if prompt_path.is_file(): 427 | parsed.prompt = prompt_path.read_text() 428 | 429 | return parsed 430 | 431 | 432 | def main(argv: list[str] | None = None) -> None: 433 | global args 434 | args = parse_args(argv) 435 | run() 436 | --------------------------------------------------------------------------------