├── .gitignore ├── LICENSE ├── README.md ├── media ├── default.gif ├── default.mp4 ├── n8loom.gif └── n8loom.mp4 ├── pyproject.toml ├── requirements.txt └── src └── n8loom ├── __init__.py ├── cache_utils.py ├── examples ├── __init__.py ├── majority-benchmark.py ├── server.py ├── static │ ├── index.html │ └── main.js └── token-benchmark.py ├── loom.py ├── models └── llama.py ├── sample_utils.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .aider* 2 | .env 3 | Llama-3.2-3B-8bit 4 | Llama-3.2-3B-Instruct-4bit 5 | Mistral-Small-24B-Instruct-4bit 6 | /dist 7 | mistralai 8 | src/n8loom.egg-info 9 | __pycache__ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Creative Commons Legal Code 2 | 3 | CC0 1.0 Universal 4 | 5 | CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE 6 | LEGAL SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN 7 | ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS 8 | INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES 9 | REGARDING THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS 10 | PROVIDED HEREUNDER, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM 11 | THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED 12 | HEREUNDER. 13 | 14 | Statement of Purpose 15 | 16 | The laws of most jurisdictions throughout the world automatically confer 17 | exclusive Copyright and Related Rights (defined below) upon the creator 18 | and subsequent owner(s) (each and all, an "owner") of an original work of 19 | authorship and/or a database (each, a "Work"). 20 | 21 | Certain owners wish to permanently relinquish those rights to a Work for 22 | the purpose of contributing to a commons of creative, cultural and 23 | scientific works ("Commons") that the public can reliably and without fear 24 | of later claims of infringement build upon, modify, incorporate in other 25 | works, reuse and redistribute as freely as possible in any form whatsoever 26 | and for any purposes, including without limitation commercial purposes. 27 | These owners may contribute to the Commons to promote the ideal of a free 28 | culture and the further production of creative, cultural and scientific 29 | works, or to gain reputation or greater distribution for their Work in 30 | part through the use and efforts of others. 31 | 32 | For these and/or other purposes and motivations, and without any 33 | expectation of additional consideration or compensation, the person 34 | associating CC0 with a Work (the "Affirmer"), to the extent that he or she 35 | is an owner of Copyright and Related Rights in the Work, voluntarily 36 | elects to apply CC0 to the Work and publicly distribute the Work under its 37 | terms, with knowledge of his or her Copyright and Related Rights in the 38 | Work and the meaning and intended legal effect of CC0 on those rights. 39 | 40 | 1. Copyright and Related Rights. A Work made available under CC0 may be 41 | protected by copyright and related or neighboring rights ("Copyright and 42 | Related Rights"). Copyright and Related Rights include, but are not 43 | limited to, the following: 44 | 45 | i. the right to reproduce, adapt, distribute, perform, display, 46 | communicate, and translate a Work; 47 | ii. moral rights retained by the original author(s) and/or performer(s); 48 | iii. 
publicity and privacy rights pertaining to a person's image or 49 | likeness depicted in a Work; 50 | iv. rights protecting against unfair competition in regards to a Work, 51 | subject to the limitations in paragraph 4(a), below; 52 | v. rights protecting the extraction, dissemination, use and reuse of data 53 | in a Work; 54 | vi. database rights (such as those arising under Directive 96/9/EC of the 55 | European Parliament and of the Council of 11 March 1996 on the legal 56 | protection of databases, and under any national implementation 57 | thereof, including any amended or successor version of such 58 | directive); and 59 | vii. other similar, equivalent or corresponding rights throughout the 60 | world based on applicable law or treaty, and any national 61 | implementations thereof. 62 | 63 | 2. Waiver. To the greatest extent permitted by, but not in contravention 64 | of, applicable law, Affirmer hereby overtly, fully, permanently, 65 | irrevocably and unconditionally waives, abandons, and surrenders all of 66 | Affirmer's Copyright and Related Rights and associated claims and causes 67 | of action, whether now known or unknown (including existing as well as 68 | future claims and causes of action), in the Work (i) in all territories 69 | worldwide, (ii) for the maximum duration provided by applicable law or 70 | treaty (including future time extensions), (iii) in any current or future 71 | medium and for any number of copies, and (iv) for any purpose whatsoever, 72 | including without limitation commercial, advertising or promotional 73 | purposes (the "Waiver"). Affirmer makes the Waiver for the benefit of each 74 | member of the public at large and to the detriment of Affirmer's heirs and 75 | successors, fully intending that such Waiver shall not be subject to 76 | revocation, rescission, cancellation, termination, or any other legal or 77 | equitable action to disrupt the quiet enjoyment of the Work by the public 78 | as contemplated by Affirmer's express Statement of Purpose. 79 | 80 | 3. Public License Fallback. Should any part of the Waiver for any reason 81 | be judged legally invalid or ineffective under applicable law, then the 82 | Waiver shall be preserved to the maximum extent permitted taking into 83 | account Affirmer's express Statement of Purpose. In addition, to the 84 | extent the Waiver is so judged Affirmer hereby grants to each affected 85 | person a royalty-free, non transferable, non sublicensable, non exclusive, 86 | irrevocable and unconditional license to exercise Affirmer's Copyright and 87 | Related Rights in the Work (i) in all territories worldwide, (ii) for the 88 | maximum duration provided by applicable law or treaty (including future 89 | time extensions), (iii) in any current or future medium and for any number 90 | of copies, and (iv) for any purpose whatsoever, including without 91 | limitation commercial, advertising or promotional purposes (the 92 | "License"). The License shall be deemed effective as of the date CC0 was 93 | applied by Affirmer to the Work. 
Should any part of the License for any 94 | reason be judged legally invalid or ineffective under applicable law, such 95 | partial invalidity or ineffectiveness shall not invalidate the remainder 96 | of the License, and in such case Affirmer hereby affirms that he or she 97 | will not (i) exercise any of his or her remaining Copyright and Related 98 | Rights in the Work or (ii) assert any associated claims and causes of 99 | action with respect to the Work, in either case contrary to Affirmer's 100 | express Statement of Purpose. 101 | 102 | 4. Limitations and Disclaimers. 103 | 104 | a. No trademark or patent rights held by Affirmer are waived, abandoned, 105 | surrendered, licensed or otherwise affected by this document. 106 | b. Affirmer offers the Work as-is and makes no representations or 107 | warranties of any kind concerning the Work, express, implied, 108 | statutory or otherwise, including without limitation warranties of 109 | title, merchantability, fitness for a particular purpose, non 110 | infringement, or the absence of latent or other defects, accuracy, or 111 | the present or absence of errors, whether or not discoverable, all to 112 | the greatest extent permissible under applicable law. 113 | c. Affirmer disclaims responsibility for clearing rights of other persons 114 | that may apply to the Work or any use thereof, including without 115 | limitation any person's Copyright and Related Rights in the Work. 116 | Further, Affirmer disclaims responsibility for obtaining any necessary 117 | consents, permissions or other rights required for any use of the 118 | Work. 119 | d. Affirmer understands and acknowledges that Creative Commons is not a 120 | party to this document and has no duty or obligation with respect to 121 | this CC0 or use of the Work. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # N8Loom: For Fast Tree-of-Thought Inference 2 | 3 | N8Loom is a Python library built on top of [mlx_lm](https://github.com/ml-explore/mlx-examples/tree/main/llms/mlx_lm) and [Transformers](https://huggingface.co/transformers/) that enables structured, tree-based interactions with language models. Its main selling point is its KV cache tree: it stores individual 'fragments' of the KV cache at each node in the tree, which it then concatenates to form the full cache when generating text from a node in the tree. This allows the caches of many different branches of the tree to be maintained in parallel and merged together only when needed, giving the inference speedups of caching without the overhead of storing the entire prefix cache at each node. 4 | 5 | Below is a visualization of the critical difference when generating from a single node in the tree - a standard prompt cache must recompute the cache for parent nodes each time, while the KV cache tree can simply concatenate the cache fragments stored at each node to form the full cache. 6 | 7 | | Standard Prompt Cache | Loom Cache | 8 | |------------------------|------------| 9 | | ![](media/default.gif) | ![](media/n8loom.gif) | 10 | 11 | 12 | It additionally provides a set of utilities to manage internal model caches, generate text in parallel and streaming modes, and build reasoning trees where each node represents a model “thought” (called a *Heddle*) that can branch off into multiple potential continuations. The library also includes a FastAPI server example for deploying a web service.
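To make the cache-tree idea concrete, here is a minimal sketch (toy tensor shapes, not the library's real API) of what happens when generating from a node deep in the tree: the key/value fragments stored along the root-to-node path are concatenated along the sequence axis instead of re-running the parent prompt through the model. The actual implementation lives in `fuse_cache_frags` in `cache_utils.py`.

```python
import mlx.core as mx

# Each node stores keys/values only for its own tokens: (batch, heads, tokens, head_dim).
root_keys   = mx.zeros((1, 8, 32, 64))  # 32 prompt tokens cached at the root
branch_keys = mx.zeros((1, 8, 5, 64))   # 5 tokens generated at a child node
leaf_keys   = mx.zeros((1, 8, 3, 64))   # 3 more tokens at a grandchild node

# Generating from the grandchild just concatenates the fragments along its path.
prefix_keys = mx.concat([root_keys, branch_keys, leaf_keys], axis=2)
print(prefix_keys.shape)  # (1, 8, 40, 64)
```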
13 | 14 | ## Table of Contents 15 | 16 | - [N8Loom: For Fast Tree-of-Thought Inference](#n8loom-for-fast-tree-of-thought-inference) 17 | - [Table of Contents](#table-of-contents) 18 | - [Overview](#overview) 19 | - [Installation](#installation) 20 | - [Usage](#usage) 21 | - [Basic Script Example](#basic-script-example) 22 | - [Running the FastAPI Server](#running-the-fastapi-server) 23 | - [To run the server locally:](#to-run-the-server-locally) 24 | - [API Documentation](#api-documentation) 25 | - [Core Classes (`loom.py`)](#core-classes-loompy) 26 | - [`class Heddle`](#class-heddle) 27 | - [`class Loom` *(subclass of Heddle)*](#class-loom-subclass-of-heddle) 28 | - [Generation Utilities (`utils.py`)](#generation-utilities-utilspy) 29 | - [Cache Utilities (`cache_utils.py`)](#cache-utilities-cache_utilspy) 30 | - [Contributing](#contributing) 31 | - [License](#license) 32 | - [Acknowledgements](#acknowledgements) 33 | 34 | ## Overview 35 | 36 | N8Loom makes it easy to interact with language models by allowing you to: 37 | - **Cache and manipulate intermediate model states.** 38 | Utilities in `cache_utils.py` extract, clip, and fuse key-value caches (KV caches) for each model layer. 39 | - **Create and manage reasoning trees.** 40 | The core abstractions are the `Heddle` and `Loom` classes (in `loom.py`), which represent individual reasoning nodes and the overall prompt tree respectively. 41 | - **Generate responses in batches and streams.** 42 | Use the functions in `utils.py` to prefill caches, sample model outputs in parallel, or yield token-by-token updates. 43 | 44 | ## Installation 45 | 46 | Ensure you have Python 3.7+ installed. Then, install the required dependencies: 47 | 48 | ```bash 49 | pip install -r requirements.txt 50 | ``` 51 | 52 | Or install: 53 | 54 | ```bash 55 | pip install n8loom 56 | ``` 57 | 58 | ## Usage 59 | 60 | ### Basic Script Example 61 | 62 | Note that n8loom only works w/ the llama architecture. 63 | 64 | Below is an example (from `examples/reflection.py`) demonstrating how to load a model, create a reasoning tree (a *Loom*), and expand it with multiple potential answers: 65 | 66 | ```python 67 | from mlx_lm import load 68 | from n8loom import Loom, load_for_loom 69 | 70 | # Load the model and tokenizer 71 | model, tokenizer = load_for_loom("Llama-3.2-3B-Instruct-4bit") 72 | 73 | # Define a problem prompt 74 | prompt = ( 75 | "Tobias is buying a new pair of shoes that costs $95. He has been saving up his money each month " 76 | "for the past three months. He gets a $5 allowance a month. He also mows lawns and shovels driveways. " 77 | "He charges $15 to mow a lawn and $7 to shovel. After buying the shoes, he has $15 in change. " 78 | "If he mows 4 lawns, how many driveways did he shovel?" 79 | ) 80 | 81 | # Create a Loom (root of the reasoning tree) 82 | root = Loom(model, tokenizer, prompt) 83 | 84 | # Add an initial text child to guide the model's reasoning 85 | assistant_start = root.add_text_child("I will solve this problem step by step and be mindful of mistakes.") 86 | 87 | # Expand the reasoning tree by generating 8 potential response branches 88 | assistant_start.ramify(n=8, temp=0.6, max_tokens=512, min_p=0.05) 89 | 90 | # Apply further reasoning to all leaf nodes, incorporating reflection 91 | answers = assistant_start.apply_at_leaves( 92 | lambda x: x.ramify("\n...Wait. I need to look at the problem again. Let's think about what I could've gotten wrong. 
I could've") 93 | if x.terminal else None, 94 | lambda x: x.ramify(n=2, temp=0.6, max_tokens=512, min_p=0.05), 95 | lambda x: x.crown() 96 | ) 97 | 98 | # Print the generated answers 99 | for i, answer in enumerate(answers): 100 | print(f"Answer {i+1}:\n{answer}\n") 101 | ``` 102 | 103 | ### Running the FastAPI Server 104 | 105 | The library also comes with an example FastAPI server (see `examples/server.py`) that exposes endpoints to manage models, create looms, expand nodes, and export/import reasoning trees. 106 | 107 | Make sure you have an mlx-lm model in the root directory (the parent directory of src). You can quickly do this w/: 108 | 109 | ```bash 110 | pip install huggingface_hub hf_transfer 111 | 112 | huggingface-cli download --local-dir Llama-3.2-3B-Instruct-4bit mlx-community/Llama-3.2-3B-Instruct-4bit 113 | ``` 114 | 115 | #### To run the server locally: 116 | 117 | ```bash 118 | python src/n8loom/examples/server.py 119 | ``` 120 | 121 | ## API Documentation 122 | 123 | ### Core Classes (`loom.py`) 124 | #### `class Heddle` 125 | 126 | A *Heddle* represents a node in a reasoning tree. Each node contains a segment of text, its tokenized form, cache fragments from the model, and potential child nodes. This structure enables branching reasoning and interactive exploration of model-generated responses. 127 | 128 | - **Attributes:** 129 | - `model`: The language model (an instance of `nn.Module`) used for generating responses and cache fragments. 130 | - `tokenizer`: The tokenizer (a `PreTrainedTokenizer` or `TokenizerWrapper`) used to encode text into tokens and decode tokens back to text. 131 | - `text`: The text content of this node. 132 | - `tokens`: The tokenized representation (a list of token IDs) for the node’s text. 133 | - `frag`: A list of cache fragments (`KVFrag`) that store model cache information corresponding to the tokens. 134 | - `children`: A list of child Heddle nodes representing subsequent branches in the reasoning tree. 135 | - `parent`: A reference to the parent Heddle node (or `None` if this node is the root). 136 | - `terminal`: A Boolean flag indicating whether further expansion (generation) is disallowed. 137 | 138 | - **Constructor:** 139 | - `__init__(model, tokenizer, text, frags, children, parent=None, trim_toks=1)` 140 | - **Purpose:** Initializes a new Heddle node. 141 | - **Parameters:** 142 | - `model`: The language model to use. 143 | - `tokenizer`: The tokenizer to encode/decode text. 144 | - `text`: The text prompt for the node. 145 | - `frags`: An optional list of pre-computed cache fragments. If `None`, the fragments are generated based on the text. 146 | - `children`: An optional list of child nodes (defaults to an empty list if not provided). 147 | - `parent`: The parent node (defaults to `None` for the root). 148 | - `trim_toks`: The number of initial tokens to trim from the token list (default is 1). 149 | 150 | - **Key Methods:** 151 | - `clip(token_limit: int)` 152 | - **Purpose:** Clips the node’s tokens, text, and cache fragments to a specified token limit. 153 | - **Details:** 154 | - If `token_limit` is negative, it retains `len(tokens) + token_limit` tokens. 155 | - If the number of tokens exceeds the limit, the node’s tokens are truncated, the text is updated via decoding, the cache fragments are clipped accordingly, and all children are removed. 156 | - **Returns:** The current Heddle instance. 157 | 158 | - `trim(token_trim: int)` 159 | - **Purpose:** Removes the last `token_trim` tokens from the node. 
160 | - **Details:** Internally calls `clip` with a negative token limit. 161 | - **Returns:** The current Heddle instance. 162 | 163 | - `to_leaf()` 164 | - **Purpose:** Converts the current node into a leaf node by removing all its children. 165 | - **Returns:** The current Heddle instance. 166 | 167 | - `add_child(child: Heddle)` 168 | - **Purpose:** Adds an existing Heddle node as a child. 169 | - **Details:** Also sets the added child’s `parent` attribute to this node. 170 | - **Returns:** The added child node. 171 | 172 | - `add_text_child(text: str)` 173 | - **Purpose:** Creates a new child node from a text prompt and adds it as a child. 174 | - **Returns:** The newly created child node. 175 | 176 | - `remove_child(child: Heddle)` 177 | - **Purpose:** Removes a specified child node from the current node. 178 | - **Returns:** The removed child node. 179 | 180 | - `get_prefix_cache() -> List[KVCache]` 181 | - **Purpose:** Retrieves the cumulative cache from the root node up to the current node. 182 | - **Details:** Collects and fuses cache fragments from all ancestor nodes to form a complete context cache. 183 | - **Returns:** A list of fused `KVCache` objects. 184 | 185 | - `make_children(n: int = 4, temp: float = 0.8, max_tokens: int = 8, min_p: float = 0.05, **kwargs)` 186 | - **Purpose:** Generates multiple child nodes using batched model generation. 187 | - **Details:** 188 | - Uses the current node’s cumulative cache as context. 189 | - Calls a batched generation routine to generate new text completions. 190 | - For each generated text, a new child is created. 191 | - If generation signals termination (via an `ended` flag), the child is marked as terminal. 192 | - Clears the model cache after generation. 193 | - **Parameters:** 194 | - `n`: Number of children to generate. 195 | - `temp`: Sampling temperature. 196 | - `max_tokens`: Maximum number of tokens to generate for each child. 197 | - `min_p`: Minimum probability threshold for generation. 198 | - **Returns:** A list of newly created child nodes. 199 | 200 | - `ramify(arg: Optional[Union[str, List[str]]] = None, **kwargs)` 201 | - **Purpose:** Expands the node by either adding text children or by generating new responses. 202 | - **Details:** 203 | - If `arg` is a string, creates a single child using that text. 204 | - If `arg` is a list of strings, creates a child for each string. 205 | - If `arg` is not provided, uses model generation: 206 | - If `stream=True` is provided in `kwargs`, streaming generation is used via `make_child_stream`. 207 | - Otherwise, batched generation is performed via `make_children`. 208 | - **Returns:** A single child, a list of children, or a streaming generator, depending on the input. 209 | 210 | - `make_child_stream(n: int = 4, temp: float = 0.8, max_tokens: int = 8, min_p: float = 0.05, **kwargs)` 211 | - **Purpose:** Generates child nodes using a streaming generation process. 212 | - **Details:** 213 | - Yields incremental updates (as dictionaries) from the generation process. 214 | - Upon receiving a final update (indicated by `"type": "final"`), creates child nodes from the generated texts. 215 | - Clears the model cache after finalization. 216 | - **Parameters:** 217 | - `n`: Number of children to generate. 218 | - `temp`: Sampling temperature. 219 | - `max_tokens`: Maximum number of tokens to generate for each child. 220 | - `min_p`: Minimum probability threshold for generation. 221 | - **Yields:** Updates (as dictionaries) during the generation stream. 
222 | - **Returns:** A list of newly created child nodes after the final update. 223 | 224 | - `get_prefix_text(exclude: int = 0) -> str` 225 | - **Purpose:** Retrieves concatenated text from all ancestor nodes (including the current node). 226 | - **Parameters:** 227 | - `exclude`: Number of initial nodes to exclude from the prefix (default is 0). 228 | - **Returns:** A single string of the concatenated prefix text. 229 | 230 | - `get_display_text(exclude: int = 0) -> str` 231 | - **Purpose:** Similar to `get_prefix_text` but uses each node's `display_text()` method. 232 | - **Parameters:** 233 | - `exclude`: Number of initial nodes to exclude (default is 0). 234 | - **Returns:** A concatenated string suitable for display. 235 | 236 | - `crown() -> str` 237 | - **Purpose:** Returns the cumulative text from the root node up to this node, excluding the root’s text. 238 | 239 | - `display_text() -> str` 240 | - **Purpose:** Returns the text content of the current node. 241 | - **Details:** This method may be overridden in subclasses to provide formatted or additional context. 242 | 243 | - `get_prefix_tokens(exclude: int = 0) -> List[int]` 244 | - **Purpose:** Retrieves a concatenated list of token IDs from all ancestor nodes (including the current node). 245 | - **Parameters:** 246 | - `exclude`: Number of initial nodes to exclude (default is 0). 247 | - **Returns:** A list of token IDs. 248 | 249 | - `apply_all_children(func: Callable[[Heddle], Any], apply_self: bool = False, leaves_only: bool = False) -> List[Any]` 250 | - **Purpose:** Applies a given function to all descendant nodes. 251 | - **Parameters:** 252 | - `func`: A function that takes a Heddle node as input. 253 | - `apply_self`: Whether to apply the function to the current node as well. 254 | - `leaves_only`: If True, applies the function only to leaf nodes. 255 | - **Returns:** A list of results from applying the function to the nodes. 256 | 257 | - `at_all_leaves(func: Callable[[Heddle], Any]) -> List[Any]` 258 | - **Purpose:** Convenience method to apply a function only to all leaf nodes. 259 | - **Returns:** A list of results from applying the function to each leaf. 260 | 261 | - `apply_at_leaves(*funcs: Callable[[Heddle], Any])` 262 | - **Purpose:** Sequentially applies multiple functions to all leaf nodes. 263 | - **Details:** All functions except the last are applied for their side effects; the final function’s results are returned. 264 | - **Returns:** A list of results from the final function applied to each leaf. 265 | 266 | - `get_all_children(depth: int = 0) -> List[Heddle]` 267 | - **Purpose:** Recursively retrieves all descendant nodes in the subtree. 268 | - **Parameters:** 269 | - `depth`: Used internally to decide whether to include the current node (included if `depth > 0`). 270 | - **Returns:** A flat list of all descendant Heddle nodes. 271 | 272 | - `get_all_leaves() -> List[Heddle]` 273 | - **Purpose:** Retrieves all leaf nodes (nodes with no children) in the subtree. 274 | - **Returns:** A list of leaf nodes. 275 | 276 | - `count_children()` 277 | - **Purpose:** Counts the total number of nodes in the subtree rooted at this node (including itself). 278 | - **Returns:** An integer count of the nodes. 279 | 280 | - `count_leaves()` 281 | - **Purpose:** Counts the total number of leaf nodes in the subtree. 282 | - **Details:** If there are no children, returns 1 (the current node itself). 283 | - **Returns:** An integer count of the leaf nodes. 
284 | 285 | - `__repr__()` 286 | - **Purpose:** Returns a string representation of the Heddle node. 287 | - **Returns:** A string displaying the node’s text and a summary of its children. 288 | #### `class Loom` *(subclass of Heddle)* 289 | 290 | *Loom* is a specialized subclass of Heddle used as the root node in chat-based or conversational settings. 291 | 292 | - **Additional Attributes:** 293 | - `user_prompt`: The original prompt provided by the user. 294 | - `chat_template_used`: A flag indicating whether a chat template was applied to the prompt. 295 | 296 | - **Constructor:** 297 | - `__init__(model, tokenizer, prompt)` 298 | - **Purpose:** Initializes a Loom instance. 299 | - **Details:** 300 | - Stores the original user prompt. 301 | - Attempts to apply a chat template via `tokenizer.apply_chat_template` (if available). 302 | - Calls the Heddle constructor with appropriate parameters. 303 | 304 | - **Overridden Methods:** 305 | - `display_text() -> str` 306 | - **Purpose:** Returns formatted text for display. 307 | - **Details:** If a chat template was used, it prefixes the output with "Prompt:" followed by the original user prompt and then a "Response:" section. Otherwise, it returns the plain text as defined in Heddle. 308 | 309 | ### Generation Utilities (`utils.py`) 310 | 311 | - **`prompt_to_cache(model, tokenizer, prompt_ids, c=None, prefill_step_size=512)`** 312 | Processes a prompt through the model in steps, filling the key-value cache. 313 | 314 | - **`generate_batched(...)`** 315 | Generates multiple responses in parallel from a prompt. Returns generated texts, updated caches, token counts, and flags indicating termination. 316 | 317 | - **`generate_batched_stream(...)`** 318 | Similar to `generate_batched`, but yields incremental generation updates in a streaming manner. 319 | 320 | 321 | ### Cache Utilities (`cache_utils.py`) 322 | 323 | - **`frag_cache(cache: List[KVCache], start_idx: int = 0, end_idx: int = -1) -> List[KVFrag]`** 324 | Extracts key-value cache fragments from each layer between the specified indices. 325 | 326 | - **`clip_frag(frags: List[KVFrag], token_limit: int) -> List[KVFrag]`** 327 | Clips the keys and values in cache fragments to the given token limit. 328 | 329 | - **`frag_batch_gen(cache: List[KVCache], total_prompt_len: int, generated_lengths: List[int]) -> List[List[KVFrag]]`** 330 | Creates cache fragments for each batch instance based on prompt and generated token lengths. 331 | 332 | - **`fuse_cache_frags(frags: List[List[KVFrag]]) -> List[KVCache]`** 333 | Merges a list of cache fragments into full KVCache objects, concatenating along the sequence dimension. 334 | 335 | ## Contributing 336 | 337 | Contributions to enhance N8Loom (e.g., new features, bug fixes, or improved documentation) are very welcome. Please file issues or submit pull requests on the project's repository. 338 | 339 | ## License 340 | 341 | This project is licensed under the CC0 License. 342 | 343 | ## Acknowledgements 344 | 345 | - **mlx_lm:** The library builds upon the efficient language model framework provided by mlx_lm. 346 | - **Transformers:** For model and tokenizer support. 347 | - **FastAPI & Uvicorn:** For providing a lightweight web server example. 348 | 349 | --- 350 | 351 | This documentation and the included examples should help you get started with building interactive, tree-based language model applications using N8Loom. Happy coding! 
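For a compact recap of the classes and helpers documented above, the sketch below builds a small tree by hand: it branches with `ramify`, prunes with `clip`, and reads finished branches back with `crown`. All names come from the API documentation and the bundled examples; the prompt and sampling parameters are illustrative, so treat this as a sketch rather than a tested recipe.

```python
from n8loom import Loom, load_for_loom

model, tokenizer = load_for_loom("Llama-3.2-3B-Instruct-4bit")

root = Loom(model, tokenizer, "Name three prime numbers and explain why each is prime.")
draft = root.add_text_child("Sure - let me reason carefully.")

# Branch into four candidate continuations.
children = draft.ramify(n=4, temp=0.8, max_tokens=128, min_p=0.05)

# Keep each branch short (clip also removes any sub-branches), then extend the leaves once.
for child in children:
    child.clip(64)
draft.at_all_leaves(lambda leaf: None if leaf.terminal else leaf.ramify(n=1, temp=0.2, max_tokens=64))

# Read back the full text of every branch; crown() excludes the root prompt.
for i, leaf in enumerate(root.get_all_leaves()):
    print(f"--- Branch {i + 1} ---")
    print(leaf.crown())
print("Total leaves:", root.count_leaves())
```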
352 | -------------------------------------------------------------------------------- /media/default.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/N8python/n8loom/44b9cc670e47ecf0f0cfc73b191cb2d222ef4fe6/media/default.gif -------------------------------------------------------------------------------- /media/default.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/N8python/n8loom/44b9cc670e47ecf0f0cfc73b191cb2d222ef4fe6/media/default.mp4 -------------------------------------------------------------------------------- /media/n8loom.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/N8python/n8loom/44b9cc670e47ecf0f0cfc73b191cb2d222ef4fe6/media/n8loom.gif -------------------------------------------------------------------------------- /media/n8loom.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/N8python/n8loom/44b9cc670e47ecf0f0cfc73b191cb2d222ef4fe6/media/n8loom.mp4 -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=64.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "n8loom" 7 | version = "0.2.0" 8 | dependencies = [ 9 | "annotated-types==0.7.0", 10 | "anyio==4.8.0", 11 | "certifi==2025.1.31", 12 | "charset-normalizer==3.4.1", 13 | "click==8.1.8", 14 | "exceptiongroup==1.2.2", 15 | "fastapi==0.115.8", 16 | "filelock==3.17.0", 17 | "fsspec==2025.2.0", 18 | "h11==0.14.0", 19 | "huggingface-hub==0.28.1", 20 | "idna==3.10", 21 | "Jinja2==3.1.5", 22 | "MarkupSafe==3.0.2", 23 | "mlx==0.22.0", 24 | "mlx-lm==0.21.1", 25 | "numpy==2.2.2", 26 | "packaging==24.2", 27 | "protobuf==5.29.3", 28 | "pydantic==2.10.6", 29 | "pydantic_core==2.27.2", 30 | "PyYAML==6.0.2", 31 | "regex==2024.11.6", 32 | "requests==2.32.3", 33 | "safetensors==0.5.2", 34 | "sentencepiece==0.2.0", 35 | "setuptools==75.8.0", 36 | "sniffio==1.3.1", 37 | "starlette==0.45.3", 38 | "tokenizers==0.21.0", 39 | "tqdm==4.67.1", 40 | "transformers==4.48.2", 41 | "typing_extensions==4.12.2", 42 | "urllib3==2.3.0", 43 | "uvicorn==0.34.0" 44 | ] 45 | 46 | [tool.setuptools] 47 | package-dir = {"" = "src"} 48 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | annotated-types==0.7.0 2 | anyio==4.8.0 3 | certifi==2025.1.31 4 | charset-normalizer==3.4.1 5 | click==8.1.8 6 | exceptiongroup==1.2.2 7 | fastapi==0.115.8 8 | filelock==3.17.0 9 | fsspec==2025.2.0 10 | h11==0.14.0 11 | huggingface-hub==0.28.1 12 | idna==3.10 13 | Jinja2==3.1.5 14 | MarkupSafe==3.0.2 15 | mlx==0.22.0 16 | mlx-lm==0.21.1 17 | numpy==2.2.2 18 | packaging==24.2 19 | pip==25.0 20 | protobuf==5.29.3 21 | pydantic==2.10.6 22 | pydantic_core==2.27.2 23 | PyYAML==6.0.2 24 | regex==2024.11.6 25 | requests==2.32.3 26 | safetensors==0.5.2 27 | sentencepiece==0.2.0 28 | setuptools==75.8.0 29 | sniffio==1.3.1 30 | starlette==0.45.3 31 | tokenizers==0.21.0 32 | tqdm==4.67.1 33 | transformers==4.48.2 34 | typing_extensions==4.12.2 35 | urllib3==2.3.0 36 | uvicorn==0.34.0 37 | wheel==0.45.1 38 | 
-------------------------------------------------------------------------------- /src/n8loom/__init__.py: -------------------------------------------------------------------------------- 1 | from .loom import Loom, Heddle 2 | from .utils import load_for_loom -------------------------------------------------------------------------------- /src/n8loom/cache_utils.py: -------------------------------------------------------------------------------- 1 | 2 | from collections import namedtuple 3 | from mlx_lm.models.cache import KVCache 4 | import mlx.core as mx 5 | from typing import Any, Dict, List, Optional, Union, Callable 6 | import numpy as np 7 | KVFrag = namedtuple("KVFrag", ["keys", "values"]) 8 | def mx_copy(x: mx.array) -> mx.array: 9 | return mx.array(np.array(x)) 10 | def frag_cache(cache: List[KVCache], start_idx: int = 0, end_idx: int = -1) -> List[KVFrag]: 11 | """Extracts and converts a slice of key-value pairs from model layer caches into fragments. 12 | 13 | Args: 14 | cache: List of KVCache objects, one per model layer, containing cached key-value pairs 15 | start_idx: Starting index for extraction (default: 0) 16 | end_idx: Ending index for extraction (default: -1) 17 | 18 | Returns: 19 | List of KVFrag objects, one per layer, each containing the extracted keys and values 20 | arrays from the specified index range 21 | 22 | Example: 23 | >>> layer_caches = [KVCache(...), KVCache(...)] # List of caches for each layer 24 | >>> fragments = frag_cache(layer_caches, 0, 10) # Get first 10 positions 25 | """ 26 | frags = [] 27 | for layer_cache in cache: 28 | keys = mx_copy(layer_cache.keys[:, :, start_idx:end_idx]) 29 | values = mx_copy(layer_cache.values[:, :, start_idx:end_idx]) 30 | frags.append(KVFrag(keys, values)) 31 | return frags 32 | def clip_frag(frags: List[KVFrag], token_limit: int) -> List[KVFrag]: 33 | """Clips a list of key-value fragments to a specified token limit. 34 | 35 | Args: 36 | frags: List of KVFrag objects - one per layer - to clip 37 | token_limit: Maximum number of tokens to retain in each fragment 38 | 39 | Returns: 40 | List of KVFrag objects, each containing the clipped keys and values arrays 41 | """ 42 | clipped_frags = [] 43 | for frag in frags: 44 | keys = mx_copy(frag.keys[:, :, :token_limit]) 45 | values = mx_copy(frag.values[:, :, :token_limit]) 46 | clipped_frags.append(KVFrag(keys, values)) 47 | return clipped_frags 48 | def frag_batch_gen(cache: List[KVCache], total_prompt_len: int, generated_lengths: List[int]) -> List[List[KVFrag]]: 49 | frags = [] 50 | B = cache[0].keys.shape[0] 51 | for i in range(B): 52 | batch_frags = [] 53 | for layer_cache in cache: 54 | keys = mx_copy(layer_cache.keys[i:i+1, :, total_prompt_len:total_prompt_len + generated_lengths[i]]) 55 | values = mx_copy(layer_cache.values[i:i+1, :, total_prompt_len:total_prompt_len + generated_lengths[i]]) 56 | batch_frags.append(KVFrag(keys, values)) 57 | frags.append(batch_frags) 58 | return frags 59 | 60 | def fuse_cache_frags(frags: List[List[KVFrag]], offset: int) -> List[KVCache]: 61 | """Fuses a list of key-value fragments into a list of model layer caches. 62 | 63 | Args: 64 | frags: List of lists of KVFrag objects - first dimension is the layer index, second is the list of fragments to merge 65 | 66 | Returns: 67 | List of KVCache objects, one per model layer, containing the fused key-value pairs from the fragments, concatenated along the sequence dimension. 
68 | 69 | Example: 70 | >>> fragments = [[KVFrag(...), KVFrag(...)], [KVFrag(...), KVFrag(...)]] 71 | >>> layer_caches = fuse_cache_frags(fragments) 72 | """ 73 | caches = [] 74 | for layer_frags in frags: 75 | keys = mx.concat([ 76 | frag.keys 77 | for i, frag in enumerate(layer_frags) 78 | ], axis=2) 79 | 80 | values = mx.concat([ 81 | frag.values 82 | for i, frag in enumerate(layer_frags) 83 | ], axis=2) 84 | cache = KVCache() 85 | cache.keys = keys 86 | cache.values = values 87 | cache.offset = keys.shape[2] 88 | caches.append(cache) 89 | return caches 90 | -------------------------------------------------------------------------------- /src/n8loom/examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/N8python/n8loom/44b9cc670e47ecf0f0cfc73b191cb2d222ef4fe6/src/n8loom/examples/__init__.py -------------------------------------------------------------------------------- /src/n8loom/examples/majority-benchmark.py: -------------------------------------------------------------------------------- 1 | import re 2 | import math 3 | import time 4 | import csv 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | 8 | from mlx_lm import load 9 | from n8loom import Loom, load_for_loom 10 | import mlx.core as mx 11 | 12 | # ------------------------------------------------------------------- 13 | # Load the model and tokenizer once (this may take some time) 14 | # ------------------------------------------------------------------- 15 | print("Loading model ...") 16 | model, tokenizer = load_for_loom("Llama-3.2-3B-Instruct-4bit") 17 | print("Model loaded.\n") 18 | 19 | # ------------------------------------------------------------------- 20 | # Function to run a single trial with a given answer_count. 21 | # It returns the execution time and the peak memory usage (in GB). 22 | # ------------------------------------------------------------------- 23 | def run_trial(answer_count): 24 | start_time = time.perf_counter() 25 | 26 | prompt = ( 27 | "Tobias is buying a new pair of shoes that costs $95. He has been saving up his money each month " 28 | "for the past three months. He gets a $5 allowance a month. He also mows lawns and shovels driveways. " 29 | "He charges $15 to mow a lawn and $7 to shovel. After buying the shoes, he has $15 in change. " 30 | "If he mows 4 lawns, how many driveways did he shovel?" 31 | ) 32 | 33 | # Build the chain using Loom 34 | root = Loom(model, tokenizer, prompt) 35 | assistant_start = root.add_text_child("I will solve this problem step by step and be mindful of mistakes.") 36 | assistant_start.ramify(n=answer_count, temp=0.6, max_tokens=512, min_p=0.05) 37 | 38 | answers = assistant_start.apply_at_leaves( 39 | lambda x: x.ramify("\n...Alright, I'll put my final answer between XML tags. 
My answer is ") if x.terminal else None, 40 | lambda x: x.ramify(n=1, temp=0.0, max_tokens=32, min_p=0.05), 41 | lambda x: x.crown() 42 | ) 43 | 44 | # Use a regex to capture everything between and 45 | pattern_content = re.compile(r"(.*?)", re.IGNORECASE | re.DOTALL) 46 | parsed_final_answers = [] 47 | for answer in answers: 48 | match = pattern_content.search(answer) 49 | if match: 50 | content = match.group(1) 51 | number_match = re.search(r'\d+', content) 52 | if number_match: 53 | try: 54 | number = int(number_match.group(0)) 55 | parsed_final_answers.append(number) 56 | except ValueError: 57 | parsed_final_answers.append(float('nan')) 58 | else: 59 | parsed_final_answers.append(float('nan')) 60 | else: 61 | parsed_final_answers.append(float('nan')) 62 | 63 | # (Optional) Aggregate answers for majority voting. 64 | answer_counts = {} 65 | for answer in parsed_final_answers: 66 | if math.isnan(answer): 67 | continue 68 | answer_counts[answer] = answer_counts.get(answer, 0) + 1 69 | sorted_answers = sorted(answer_counts.items(), key=lambda x: x[1], reverse=True) 70 | 71 | # Get peak memory usage (in GB) 72 | peak_memory = mx.metal.get_peak_memory() / 1e9 73 | 74 | end_time = time.perf_counter() 75 | execution_time = end_time - start_time 76 | 77 | return execution_time, sorted_answers, peak_memory 78 | 79 | # ------------------------------------------------------------------- 80 | # Benchmark function that runs trials for answer_count values in powers 81 | # of 2 from 1 to 128. For each answer_count, it repeats trials until either 82 | # the standard error of execution time falls below 2 seconds or 10 trials are done. 83 | # 84 | # After each trial, the script waits 10 seconds to help avoid thermal throttling. 85 | # 86 | # Finally, the function saves the benchmark results to a CSV file and 87 | # creates a dual-axis plot showing both mean response time and mean peak memory. 88 | # ------------------------------------------------------------------- 89 | def benchmark(): 90 | # Answer counts: 1, 2, 4, ..., 128 91 | answer_counts_list = [2 ** i for i in range(0, 7)] 92 | results = [] # Each entry: (answer_count, mean_time, time_std_err, trials, mean_peak_memory, memory_std_err) 93 | 94 | for count in answer_counts_list: 95 | trial_times = [] 96 | trial_memories = [] 97 | trial = 0 98 | print(f"Starting trials for answer_count = {count}") 99 | while True: 100 | trial += 1 101 | print(f" Running trial {trial}...") 102 | exec_time, sorted_answers, peak_memory = run_trial(count) 103 | trial_times.append(exec_time) 104 | trial_memories.append(peak_memory) 105 | print(f" Execution time: {exec_time:.2f} seconds, Peak memory: {peak_memory:.2f} GB") 106 | 107 | # Wait 60 seconds after each trial to avoid thermal throttling 108 | time.sleep(20 * (trial + 1)) 109 | 110 | # Stop if 10 trials have been run. 111 | if trial >= 10: 112 | break 113 | # If more than one trial has been done, check if the standard error is below 2 seconds. 
114 | if len(trial_times) > 1: 115 | std_err = np.std(trial_times, ddof=1) / np.sqrt(len(trial_times)) 116 | if std_err < 2.0: 117 | break 118 | 119 | mean_time = np.mean(trial_times) 120 | time_std_err = np.std(trial_times, ddof=1) / np.sqrt(len(trial_times)) if len(trial_times) > 1 else 0.0 121 | mean_memory = np.mean(trial_memories) 122 | memory_std_err = np.std(trial_memories, ddof=1) / np.sqrt(len(trial_memories)) if len(trial_memories) > 1 else 0.0 123 | 124 | results.append((count, mean_time, time_std_err, trial, mean_memory, memory_std_err)) 125 | print(f"Results for answer_count = {count}: Mean Time = {mean_time:.2f}s, Std Error = {time_std_err:.2f}s, Trials = {trial}, Mean Peak Memory = {mean_memory:.2f} GB\n") 126 | 127 | # Save results to a CSV file. 128 | csv_filename = "benchmark_results.csv" 129 | with open(csv_filename, "w", newline="") as csvfile: 130 | writer = csv.writer(csvfile) 131 | writer.writerow(["answer_count", "mean_time", "std_error", "trials", "mean_peak_memory", "memory_std_error"]) 132 | for row in results: 133 | writer.writerow(row) 134 | print(f"Saved benchmark results to {csv_filename}") 135 | 136 | # Prepare data for plotting. 137 | counts = [r[0] for r in results] 138 | mean_times = [r[1] for r in results] 139 | time_errors = [r[2] for r in results] 140 | mean_memories = [r[4] for r in results] 141 | memory_errors = [r[5] for r in results] 142 | 143 | # Create a dual-axis plot. 144 | fig, ax1 = plt.subplots(figsize=(10, 6)) 145 | 146 | # Plot mean response time (left y-axis) 147 | ax1.errorbar(counts, mean_times, yerr=time_errors, fmt='o-', capsize=5, color='tab:blue', label='Mean Time (s)') 148 | ax1.set_xlabel("Answer Count") 149 | ax1.set_ylabel("Mean Response Time (seconds)", color='tab:blue') 150 | ax1.tick_params(axis='y', labelcolor='tab:blue') 151 | ax1.set_xscale('log', base=2) 152 | 153 | # Create a second y-axis for peak memory. 154 | ax2 = ax1.twinx() 155 | ax2.errorbar(counts, mean_memories, yerr=memory_errors, fmt='s--', capsize=5, color='tab:red', label='Mean Peak Memory (GB)') 156 | ax2.set_ylabel("Mean Peak Memory (GB)", color='tab:red') 157 | ax2.tick_params(axis='y', labelcolor='tab:red') 158 | 159 | # Title and grid. 160 | fig.suptitle("Benchmark: Answer Count vs Mean Response Time and Peak Memory") 161 | ax1.grid(True, which="both", ls="--") 162 | 163 | # Add a legend combining both axes. 164 | lines_1, labels_1 = ax1.get_legend_handles_labels() 165 | lines_2, labels_2 = ax2.get_legend_handles_labels() 166 | ax1.legend(lines_1 + lines_2, labels_1 + labels_2, loc="upper left", bbox_to_anchor=(0.1, 0.9)) 167 | 168 | # Save the plot. 169 | plot_filename = "benchmark_plot.png" 170 | plt.savefig(plot_filename) 171 | print(f"Saved benchmark plot to {plot_filename}") 172 | plt.show() 173 | 174 | # ------------------------------------------------------------------- 175 | # Run the benchmark if this script is executed as the main module. 
176 | # ------------------------------------------------------------------- 177 | if __name__ == "__main__": 178 | benchmark() 179 | -------------------------------------------------------------------------------- /src/n8loom/examples/server.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI, Body, HTTPException, Response 2 | from fastapi.middleware.cors import CORSMiddleware 3 | from fastapi.staticfiles import StaticFiles 4 | from fastapi.responses import JSONResponse 5 | from pydantic import BaseModel 6 | from typing import List, Optional, Dict, Union 7 | import uvicorn 8 | import os 9 | from n8loom import Heddle, Loom, load_for_loom 10 | from mlx_lm import load 11 | 12 | app = FastAPI() 13 | 14 | # For storing loaded models and created nodes. 15 | # In a real application, consider a proper database or other persistent store. 16 | model_store: Dict[str, Dict] = {} # Map of model_path -> {"model": model, "tokenizer": tokenizer} 17 | loom_store: Dict[str, Loom] = {} # Map of loom_id -> Loom (the root heddle) 18 | heddle_store: Dict[str, Heddle] = {} # All nodes (including loom roots), keyed by a unique ID 19 | 20 | # Simple integer counters for unique IDs 21 | COUNTERS = { 22 | "loom_id": 0, 23 | "heddle_id": 0 24 | } 25 | 26 | def get_next_id(prefix: str) -> str: 27 | COUNTERS[prefix] += 1 28 | return f"{prefix}-{COUNTERS[prefix]}" 29 | 30 | # ------------------------------ 31 | # Pydantic models for request/response schemas 32 | # ------------------------------ 33 | class LoadModelRequest(BaseModel): 34 | model_path: str 35 | 36 | class LoadModelResponse(BaseModel): 37 | model_id: str 38 | 39 | class CreateLoomRequest(BaseModel): 40 | model_id: str 41 | prompt: str 42 | 43 | class CreateLoomResponse(BaseModel): 44 | loom_id: str 45 | heddle_id: str # same as the loom's root node ID 46 | 47 | class NodeInfo(BaseModel): 48 | node_id: str 49 | text: str 50 | display_text: str 51 | children_ids: List[str] 52 | terminal: bool 53 | 54 | class RamifyRequest(BaseModel): 55 | node_id: str 56 | # Provide either "text" or generation parameters 57 | text: Optional[Union[str, List[str]]] = None 58 | 59 | # generation parameters if we are sampling from the model 60 | n: Optional[int] = 4 61 | temp: Optional[float] = 0.8 62 | max_tokens: Optional[int] = 8 63 | stream: Optional[bool] = False 64 | 65 | class RamifyResponse(BaseModel): 66 | node_id: str 67 | created_children: List[str] 68 | 69 | class ClipRequest(BaseModel): 70 | node_id: str 71 | token_limit: int 72 | 73 | class TrimRequest(BaseModel): 74 | node_id: str 75 | token_trim: int 76 | 77 | # New models for the loom management endpoints 78 | class LoomInfo(BaseModel): 79 | loom_id: str 80 | root_heddle_id: str 81 | prompt: str 82 | 83 | class ImportLoomRequest(BaseModel): 84 | model_id: str 85 | loom_data: Dict 86 | 87 | class ImportLoomResponse(BaseModel): 88 | loom_id: str 89 | heddle_id: str 90 | 91 | # ------------------------------ 92 | # Helper functions 93 | # ------------------------------ 94 | def serialize_heddle(node: Heddle, node_id: str) -> NodeInfo: 95 | # Return basic information about a node 96 | return NodeInfo( 97 | node_id=node_id, 98 | text=node.text, 99 | display_text=node.display_text(), 100 | children_ids=[ 101 | _id for _id, h in heddle_store.items() if h.parent is node 102 | ], 103 | terminal=node.terminal 104 | ) 105 | 106 | def build_subtree_dict(node: Heddle, node_id: str) -> Dict: 107 | """Recursively build a JSON-serializable dict describing the 
subtree.""" 108 | return { 109 | "node_id": node_id, 110 | "text": node.text, 111 | "display_text": node.display_text(), 112 | "terminal": node.terminal, 113 | "children": [ 114 | build_subtree_dict(child, _id) 115 | for _id, child in heddle_store.items() if child.parent is node 116 | ] 117 | } 118 | 119 | # ------------------------------ 120 | # API Endpoints 121 | # ------------------------------ 122 | 123 | @app.post("/load_model", response_model=LoadModelResponse) 124 | def load_model(req: LoadModelRequest): 125 | """ 126 | Load a model + tokenizer using `mlx_lm.load` and store them under the model_path. 127 | If the model is already loaded, return the existing path. 128 | """ 129 | if req.model_path not in model_store: 130 | # load model only if not already loaded 131 | model, tokenizer = load_for_loom(req.model_path) 132 | model_store[req.model_path] = { 133 | "model": model, 134 | "tokenizer": tokenizer 135 | } 136 | return LoadModelResponse(model_id=req.model_path) 137 | 138 | 139 | @app.post("/create_loom", response_model=CreateLoomResponse) 140 | def create_loom(req: CreateLoomRequest): 141 | """ 142 | Create a new Loom with the given model_path and user prompt. 143 | """ 144 | if req.model_id not in model_store: 145 | raise HTTPException(status_code=400, detail="Model path not found") 146 | 147 | model_data = model_store[req.model_id] 148 | model = model_data["model"] 149 | tokenizer = model_data["tokenizer"] 150 | 151 | loom_id = get_next_id("loom_id") 152 | root_loom = Loom(model, tokenizer, req.prompt) 153 | 154 | # Store the loom in memory 155 | loom_store[loom_id] = root_loom 156 | 157 | # Also store it in the heddle store 158 | heddle_id = get_next_id("heddle_id") 159 | heddle_store[heddle_id] = root_loom 160 | 161 | return CreateLoomResponse(loom_id=loom_id, heddle_id=heddle_id) 162 | 163 | 164 | @app.get("/loom/{loom_id}") 165 | def get_loom_info(loom_id: str): 166 | """ 167 | Returns a JSON subtree of the entire Loom structure. 168 | """ 169 | if loom_id not in loom_store: 170 | raise HTTPException(status_code=404, detail="Loom not found") 171 | 172 | loom_root = loom_store[loom_id] 173 | # We need to find which heddle_id references loom_root 174 | root_heddle_id = None 175 | for hid, node in heddle_store.items(): 176 | if node is loom_root: 177 | root_heddle_id = hid 178 | break 179 | 180 | if root_heddle_id is None: 181 | raise HTTPException(status_code=500, detail="Root node not found in heddle store.") 182 | 183 | return build_subtree_dict(loom_root, root_heddle_id) 184 | 185 | # Add this new model near the others: 186 | class RenameLoomRequest(BaseModel): 187 | new_id: str 188 | 189 | # New endpoint to rename a loom. 190 | @app.post("/looms/{loom_id}/rename") 191 | def rename_loom(loom_id: str, req: RenameLoomRequest): 192 | """ 193 | Rename an existing loom to a new id. 194 | The new id must not already exist. 195 | """ 196 | if loom_id not in loom_store: 197 | raise HTTPException(status_code=404, detail="Loom not found") 198 | if req.new_id in loom_store: 199 | raise HTTPException(status_code=400, detail="New loom id already exists") 200 | # Remove the old entry and reassign the loom under the new id. 201 | loom = loom_store.pop(loom_id) 202 | loom_store[req.new_id] = loom 203 | return {"old_loom_id": loom_id, "new_loom_id": req.new_id} 204 | @app.delete("/looms/{loom_id}") 205 | def delete_loom(loom_id: str): 206 | """ 207 | Delete a loom from the store, and remove its root node from the heddle store. 
208 | """ 209 | if loom_id not in loom_store: 210 | raise HTTPException(status_code=404, detail="Loom not found") 211 | # Remove the loom 212 | loom = loom_store.pop(loom_id) 213 | # Also remove its corresponding root node from heddle_store. 214 | root_heddle_id = None 215 | for hid, node in list(heddle_store.items()): 216 | if node is loom: 217 | root_heddle_id = hid 218 | del heddle_store[hid] 219 | break 220 | return {"deleted_loom_id": loom_id, "deleted_root_heddle_id": root_heddle_id} 221 | 222 | @app.get("/node/{node_id}", response_model=NodeInfo) 223 | def get_node_info(node_id: str): 224 | """ 225 | Get basic info about a node: text, child IDs, terminal status. 226 | """ 227 | if node_id not in heddle_store: 228 | raise HTTPException(status_code=404, detail="Node not found") 229 | 230 | node = heddle_store[node_id] 231 | return serialize_heddle(node, node_id) 232 | 233 | 234 | @app.post("/node/ramify") 235 | def ramify_node(req: RamifyRequest): 236 | """ 237 | - If `text` is given as a string, create one text child. 238 | - If `text` is given as a list of strings, create multiple text children. 239 | - Otherwise, create children by sampling from the model (n, temp, max_tokens). 240 | If `stream` is True, stream generation updates. 241 | """ 242 | if req.node_id not in heddle_store: 243 | raise HTTPException(status_code=404, detail="Node not found") 244 | node = heddle_store[req.node_id] 245 | result = node.ramify(req.text, n=req.n, temp=req.temp, max_tokens=req.max_tokens, stream=req.stream) 246 | if req.stream: 247 | import json 248 | from fastapi.responses import StreamingResponse 249 | def event_generator(): 250 | for update in result: 251 | if 'children' in update: 252 | for child in update['children']: 253 | child_id = get_next_id("heddle_id") 254 | heddle_store[child_id] = child 255 | update['children'] = [child_id] 256 | update['children'] = len(update['children']) 257 | yield json.dumps(update) + "\n" 258 | return StreamingResponse(event_generator(), media_type="application/json") 259 | 260 | created_children_ids = [] 261 | 262 | if isinstance(result, Heddle): 263 | # Single child created 264 | child_id = get_next_id("heddle_id") 265 | heddle_store[child_id] = result 266 | created_children_ids.append(child_id) 267 | elif isinstance(result, list) and all(isinstance(r, Heddle) for r in result): 268 | # Multiple children created 269 | for child in result: 270 | child_id = get_next_id("heddle_id") 271 | heddle_store[child_id] = child 272 | created_children_ids.append(child_id) 273 | else: 274 | # Possibly no children created (if node was terminal) 275 | pass 276 | 277 | return RamifyResponse(node_id=req.node_id, created_children=created_children_ids) 278 | 279 | 280 | @app.post("/node/clip") 281 | def clip_node(req: ClipRequest): 282 | """ 283 | Clip the node (and remove its children) to `token_limit`. 284 | """ 285 | if req.node_id not in heddle_store: 286 | raise HTTPException(status_code=404, detail="Node not found") 287 | node = heddle_store[req.node_id] 288 | node.clip(req.token_limit) 289 | return {"node_id": req.node_id, "clipped_to": req.token_limit} 290 | 291 | 292 | @app.post("/node/trim") 293 | def trim_node(req: TrimRequest): 294 | """ 295 | Trim the last N tokens from the node (and remove its children). 
296 | """ 297 | if req.node_id not in heddle_store: 298 | raise HTTPException(status_code=404, detail="Node not found") 299 | node = heddle_store[req.node_id] 300 | node.trim(req.token_trim) 301 | return {"node_id": req.node_id, "trimmed_tokens": req.token_trim} 302 | 303 | 304 | @app.delete("/node/{node_id}") 305 | def delete_node(node_id: str): 306 | """ 307 | Delete a node and its children from the store. 308 | Cannot delete root nodes of looms. 309 | """ 310 | if node_id not in heddle_store: 311 | raise HTTPException(status_code=404, detail="Node not found") 312 | 313 | # Check if this is a root node 314 | node = heddle_store[node_id] 315 | for loom in loom_store.values(): 316 | if node is loom: 317 | raise HTTPException(status_code=400, detail="Cannot delete root node") 318 | 319 | # Recursively collect all child node IDs 320 | def get_child_ids(node): 321 | children = [] 322 | for child_id, child in heddle_store.items(): 323 | if child.parent is node: 324 | children.append(child_id) 325 | children.extend(get_child_ids(child)) 326 | return children 327 | 328 | # Delete all children first 329 | child_ids = get_child_ids(node) 330 | for child_id in child_ids: 331 | del heddle_store[child_id] 332 | 333 | # Delete the node itself 334 | del heddle_store[node_id] 335 | 336 | return {"node_id": node_id, "deleted_children": child_ids} 337 | 338 | 339 | @app.get("/node/{node_id}/subtree") 340 | def get_subtree(node_id: str): 341 | """ 342 | Returns the entire subtree (recursively) from the given node as JSON. 343 | """ 344 | if node_id not in heddle_store: 345 | raise HTTPException(status_code=404, detail="Node not found") 346 | node = heddle_store[node_id] 347 | return build_subtree_dict(node, node_id) 348 | 349 | 350 | # ------------------------------ 351 | # New endpoints for loom management 352 | # ------------------------------ 353 | 354 | @app.get("/looms", response_model=List[LoomInfo]) 355 | def list_looms(): 356 | """ 357 | Returns a list of all looms currently in memory. 358 | """ 359 | looms = [] 360 | for loom_id, loom in loom_store.items(): 361 | root_heddle_id = None 362 | for hid, node in heddle_store.items(): 363 | if node is loom: 364 | root_heddle_id = hid 365 | break 366 | if root_heddle_id: 367 | looms.append(LoomInfo(loom_id=loom_id, root_heddle_id=root_heddle_id, prompt=loom.text)) 368 | return looms 369 | 370 | @app.get("/loom/{loom_id}/export") 371 | def export_loom(loom_id: str): 372 | """ 373 | Export a given loom (its full tree) as JSON. 374 | """ 375 | if loom_id not in loom_store: 376 | raise HTTPException(status_code=404, detail="Loom not found") 377 | loom_root = loom_store[loom_id] 378 | root_heddle_id = None 379 | for hid, node in heddle_store.items(): 380 | if node is loom_root: 381 | root_heddle_id = hid 382 | break 383 | if root_heddle_id is None: 384 | raise HTTPException(status_code=500, detail="Root node not found in heddle store.") 385 | exported_tree = build_subtree_dict(loom_root, root_heddle_id) 386 | return JSONResponse(content=exported_tree) 387 | 388 | @app.post("/looms/import", response_model=ImportLoomResponse) 389 | def import_loom(req: ImportLoomRequest): 390 | """ 391 | Import a loom from an exported JSON structure. 392 | The client must provide the model_id (which must already be loaded) 393 | and the loom_data (the exported JSON). The loom will be re-instantiated 394 | using the provided model/tokenizer and the tree structure rebuilt. 
395 | """ 396 | if req.model_id not in model_store: 397 | raise HTTPException(status_code=400, detail="Model id not found") 398 | model_data = model_store[req.model_id] 399 | model = model_data["model"] 400 | tokenizer = model_data["tokenizer"] 401 | 402 | loom_json = req.loom_data 403 | prompt = loom_json.get("text", "Imported Loom") 404 | new_loom = Loom(model, tokenizer, prompt) 405 | new_loom_id = get_next_id("loom_id") 406 | loom_store[new_loom_id] = new_loom 407 | new_heddle_id = get_next_id("heddle_id") 408 | heddle_store[new_heddle_id] = new_loom 409 | 410 | def import_subtree(parent, children_list): 411 | for child in children_list: 412 | child_text = child.get("text", "") 413 | # Use the parent's ramify method to add a child with the provided text. 414 | result = parent.ramify(child_text) 415 | if result is None: 416 | continue 417 | # If multiple children are returned, we take the first. 418 | if isinstance(result, list): 419 | result = result[0] 420 | child_id = get_next_id("heddle_id") 421 | heddle_store[child_id] = result 422 | if child.get("children"): 423 | import_subtree(result, child["children"]) 424 | import_subtree(new_loom, loom_json.get("children", [])) 425 | return ImportLoomResponse(loom_id=new_loom_id, heddle_id=new_heddle_id) 426 | 427 | # ------------------------------ 428 | # CORS, static files, and root endpoint 429 | # ------------------------------ 430 | app.add_middleware( 431 | CORSMiddleware, 432 | allow_origins=["*"], # Allows all origins 433 | allow_credentials=True, 434 | allow_methods=["GET", "POST", "DELETE"], # Explicitly list allowed methods 435 | allow_headers=["*"], # Allows all headers 436 | ) 437 | 438 | current_dir = os.path.dirname(os.path.abspath(__file__)) 439 | # Join it with 'static' to get the full path 440 | static_dir = os.path.join(current_dir, "static") 441 | 442 | app.mount("/static", StaticFiles(directory=static_dir), name="static") 443 | 444 | @app.get("/") 445 | def read_root(): 446 | return Response( 447 | content=""" 448 | 449 | 450 | 451 | 452 | 453 |

Redirecting to the client...

454 | 455 | 456 | """, 457 | media_type="text/html", 458 | ) 459 | 460 | # ------------------------------ 461 | # Run the server (for local testing) 462 | # ------------------------------ 463 | if __name__ == "__main__": 464 | uvicorn.run(app, host="localhost", port=8000) 465 | -------------------------------------------------------------------------------- /src/n8loom/examples/static/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | Node Tree Visualizer 8 | 381 | 382 | 383 | 384 |
385 |
386 |

Control Panel

387 | 388 |
389 | 390 | 391 |
392 | 393 | 394 | 395 |
396 | 397 | 398 |
399 | 400 | 401 | 402 |
403 | 404 | 405 |
406 | 407 |
408 | 409 | 410 |
411 | 412 |
413 | 414 | 415 |
416 | 417 |

Existing Looms:

418 |
419 | 420 |
421 | 422 |
423 |
424 |
425 | 426 |
427 |

Loom

428 |
429 |
430 |
431 | 432 | 433 | 434 | 435 | -------------------------------------------------------------------------------- /src/n8loom/examples/static/main.js: -------------------------------------------------------------------------------- 1 | const BASE_URL = 'http://localhost:8000'; 2 | let selectedNode = null; 3 | let currentModelId = null; 4 | let currentLoomId = null; 5 | let loomData = null; // Will store the entire tree structure once fetched 6 | 7 | // Helper functions for API calls 8 | async function postJSON(url, data) { 9 | const response = await fetch(url, { 10 | method: 'POST', 11 | headers: { 12 | 'Content-Type': 'application/json' 13 | }, 14 | body: JSON.stringify(data) 15 | }); 16 | 17 | if (!response.ok) { 18 | const errorText = await response.text(); 19 | throw new Error(`Error ${response.status}: ${errorText}`); 20 | } 21 | 22 | return await response.json(); 23 | } 24 | 25 | async function getJSON(url) { 26 | const response = await fetch(url); 27 | if (!response.ok) { 28 | const errorText = await response.text(); 29 | throw new Error(`Error ${response.status}: ${errorText}`); 30 | } 31 | return await response.json(); 32 | } 33 | 34 | function showStatus(message, isError = false) { 35 | const status = document.getElementById('status'); 36 | status.textContent = message; 37 | status.className = 'status ' + (isError ? 'error' : 'success'); 38 | } 39 | 40 | // ------------------ Tree Rendering Logic ------------------ 41 | 42 | // Recursively search for node + path from root to nodeId 43 | // Calculate the maximum depth below a node 44 | function getMaxDepth(node) { 45 | if (!node.children || node.children.length === 0) { 46 | return 0; 47 | } 48 | return 1 + Math.max(...node.children.map(child => getMaxDepth(child))); 49 | } 50 | 51 | function countNonzeroDepths(node) { 52 | if (!node.children || node.children.length === 0) { 53 | return 0; 54 | } 55 | let nonZeroDepths = 0; 56 | for (let child of node.children) { 57 | if (getMaxDepth(child) > 0) { 58 | nonZeroDepths++; 59 | } 60 | } 61 | return nonZeroDepths; 62 | } 63 | 64 | function findNodeAndPath(root, targetId, path = []) { 65 | // If this is the node 66 | if (root.node_id === targetId) { 67 | return [...path, root]; 68 | } 69 | 70 | // If children exist, search them 71 | if (root.children) { 72 | for (let child of root.children) { 73 | const result = findNodeAndPath(child, targetId, [...path, root]); 74 | if (result) { 75 | return result; 76 | } 77 | } 78 | } 79 | return null; // Not found in this branch 80 | } 81 | 82 | /** 83 | * Renders the entire "relative" view inside #tree: 84 | * 1) All ancestor texts (including selected) concatenated at the top 85 | * - Each ancestor's text is in its own clickable . 
86 | * 2) The selected node in the center 87 | * 3) The immediate children below 88 | */ 89 | function renderRelativeView(selectedId) { 90 | const treeContainer = document.getElementById('tree'); 91 | treeContainer.innerHTML = ''; 92 | 93 | // If no data or no selected node, just return 94 | if (!loomData || !selectedId) { 95 | return; 96 | } 97 | 98 | const path = findNodeAndPath(loomData, selectedId); 99 | if (!path) { 100 | treeContainer.textContent = 'Selected node not found.'; 101 | return; 102 | } 103 | 104 | // 1) Build a parent text container 105 | const parentsDiv = document.createElement('div'); 106 | parentsDiv.className = 'parents-text'; 107 | 108 | // For each ancestor node, create a span that: 109 | // - shows node.text 110 | // - highlights on hover 111 | // - on click, re-selects that ancestor 112 | path.forEach((node, index) => { 113 | const branches = countNonzeroDepths(node); 114 | const span = document.createElement('pre'); 115 | if (index === path.length - 1) { 116 | span.classList.add('selected-chunk'); 117 | } else { 118 | span.classList.add('ancestor-chunk'); 119 | } 120 | if (branches > 1) { 121 | span.style.backgroundColor = 'blue'; 122 | } 123 | // Add a space after each node's text except maybe the last 124 | span.textContent = (node.display_text || 'Empty node'); 125 | 126 | // Clicking on this chunk re-selects that node 127 | span.onclick = (e) => { 128 | e.stopPropagation(); 129 | selectNode(node.node_id); 130 | }; 131 | parentsDiv.appendChild(span); 132 | }); 133 | 134 | // Add the parentsDiv to the DOM 135 | treeContainer.appendChild(parentsDiv); 136 | 137 | // 2) The selected node (last in path) 138 | const selectedNodeData = path[path.length - 1]; 139 | 140 | const selectedDiv = document.createElement('div'); 141 | selectedDiv.className = 'selected-node'; 142 | 143 | const buttonGroup = document.createElement('div'); 144 | buttonGroup.className = 'button-group'; 145 | 146 | const ramifyBtn = document.createElement('button'); 147 | ramifyBtn.classList.add('ramify-btn'); 148 | ramifyBtn.classList.add('icon-btn'); 149 | ramifyBtn.textContent = '🌱'; 150 | ramifyBtn.onclick = async(e) => { 151 | e.stopPropagation(); 152 | await ramifySelected(); 153 | } 154 | if (!selectedNodeData.terminal) { 155 | buttonGroup.appendChild(ramifyBtn); 156 | } 157 | 158 | const extendBtn = document.createElement('button'); 159 | extendBtn.classList.add('extend-btn'); 160 | extendBtn.classList.add('icon-btn'); 161 | extendBtn.textContent = '📝'; 162 | extendBtn.onclick = (e) => { 163 | e.stopPropagation(); 164 | 165 | // Create textarea element 166 | const textarea = document.createElement('textarea'); 167 | textarea.className = 'extend-textarea'; 168 | textarea.value = ''; 169 | 170 | // Insert after the last pre element in parentsDiv 171 | const lastPre = parentsDiv.querySelector('pre:last-of-type'); 172 | lastPre.after(textarea); 173 | textarea.focus(); 174 | 175 | // Auto-resize function 176 | const adjustHeight = () => { 177 | textarea.style.height = 'auto'; 178 | textarea.style.height = textarea.scrollHeight + 'px'; 179 | }; 180 | 181 | textarea.addEventListener('input', adjustHeight); 182 | 183 | textarea.onkeydown = async(e) => { 184 | if (e.key === 'Enter' && e.shiftKey) { 185 | e.preventDefault(); 186 | const text = textarea.value; 187 | if (text) { 188 | const { created_children } = await postJSON(`${BASE_URL}/node/ramify`, { 189 | node_id: selectedNodeData.node_id, 190 | text: text 191 | }); 192 | selectedNode = created_children[0]; 193 | 194 | // Refresh the tree 195 
| const treeData = await getJSON(`${BASE_URL}/loom/${currentLoomId}`); 196 | updateTree(treeData); 197 | } 198 | textarea.remove(); 199 | } else if (e.key === 'Escape') { 200 | textarea.remove(); 201 | } 202 | }; 203 | } 204 | if (!selectedNodeData.terminal) { 205 | buttonGroup.appendChild(extendBtn); 206 | } 207 | 208 | // Create a delete button 209 | const deleteBtn = document.createElement('button'); 210 | deleteBtn.classList.add('delete-btn'); 211 | deleteBtn.classList.add('icon-btn'); 212 | deleteBtn.textContent = '🗑️'; 213 | deleteBtn.onclick = async(e) => { 214 | e.stopPropagation(); 215 | selectNode(path[path.length - 2].node_id); 216 | await deleteNode(selectedNodeData.node_id); 217 | }; 218 | if (path.length > 1) { 219 | buttonGroup.appendChild(deleteBtn); 220 | } 221 | 222 | selectedDiv.appendChild(buttonGroup); 223 | 224 | // 3) Children of this node 225 | const childrenDiv = document.createElement('div'); 226 | childrenDiv.className = 'children-container'; 227 | 228 | if (selectedNodeData.children && selectedNodeData.children.length > 0) { 229 | selectedNodeData.children.forEach(child => { 230 | const childDiv = document.createElement('div'); 231 | childDiv.className = 'child-node'; 232 | const depth = getMaxDepth(child); 233 | childDiv.innerHTML = "
" + (child.display_text || 'Empty child') + "
" + 234 | (depth > 0 ? `(...${depth} more levels)` : ''); 235 | 236 | // Create a container for the buttons that will be inline 237 | const buttonContainer = document.createElement('div'); 238 | buttonContainer.style.display = 'inline-flex'; 239 | buttonContainer.style.alignItems = 'center'; 240 | buttonContainer.style.marginLeft = '10px'; 241 | childDiv.appendChild(buttonContainer); 242 | 243 | // Clicking a child sets it as selected 244 | childDiv.onclick = (e) => { 245 | e.stopPropagation(); 246 | selectNode(child.node_id); 247 | }; 248 | 249 | // Add a delete button for this child 250 | const childDeleteBtn = document.createElement('button'); 251 | childDeleteBtn.classList.add('delete-btn'); 252 | childDeleteBtn.classList.add('icon-btn'); 253 | childDeleteBtn.textContent = '🗑️'; 254 | childDeleteBtn.style.width = 'fit-content'; 255 | childDeleteBtn.onclick = async(e) => { 256 | e.stopPropagation(); 257 | await deleteNode(child.node_id); 258 | }; 259 | 260 | buttonContainer.appendChild(childDeleteBtn); 261 | 262 | childrenDiv.appendChild(childDiv); 263 | }); 264 | } 265 | 266 | parentsDiv.appendChild(selectedDiv); 267 | treeContainer.appendChild(childrenDiv); 268 | } 269 | 270 | /** 271 | * Updates the global loomData and triggers rendering of the relative view. 272 | * If there is no selected node, we set it to the root node (loomData.node_id). 273 | */ 274 | function updateTree(treeData) { 275 | loomData = treeData; 276 | 277 | // If there is no current selection, default to the root node 278 | if (!selectedNode) { 279 | selectedNode = loomData.node_id; 280 | } 281 | 282 | renderRelativeView(selectedNode); 283 | } 284 | 285 | /** 286 | * Sets the selected node and re-renders the relative view. 287 | */ 288 | function selectNode(nodeId) { 289 | selectedNode = nodeId; 290 | renderRelativeView(selectedNode); 291 | } 292 | 293 | 294 | // ------------------ API Wrappers ------------------ 295 | 296 | async function loadModel() { 297 | try { 298 | const modelPath = document.getElementById('modelPath').value; 299 | const response = await postJSON(`${BASE_URL}/load_model`, { 300 | model_path: modelPath 301 | }); 302 | currentModelId = response.model_id; 303 | showStatus(`Model loaded successfully: ${currentModelId}`); 304 | } catch (error) { 305 | showStatus(error.message, true); 306 | } 307 | } 308 | 309 | async function createLoom() { 310 | if (!currentModelId) { 311 | showStatus('Please load a model first', true); 312 | return; 313 | } 314 | 315 | try { 316 | const prompt = document.getElementById('prompt').value; 317 | const response = await postJSON(`${BASE_URL}/create_loom`, { 318 | model_id: currentModelId, 319 | prompt: prompt 320 | }); 321 | selectedNode = null; 322 | currentLoomId = response.loom_id; 323 | 324 | // Fetch and display the initial tree 325 | const treeData = await getJSON(`${BASE_URL}/loom/${currentLoomId}`); 326 | updateTree(treeData); 327 | 328 | showStatus(`Loom created successfully: ${currentLoomId}`); 329 | } catch (error) { 330 | showStatus(error.message, true); 331 | } 332 | refreshLoomList(); 333 | } 334 | 335 | async function deleteNode(nodeId) { 336 | try { 337 | await fetch(`${BASE_URL}/node/${nodeId}`, { 338 | method: 'DELETE' 339 | }); 340 | 341 | // Refresh the tree 342 | const treeData = await getJSON(`${BASE_URL}/loom/${currentLoomId}`); 343 | updateTree(treeData); 344 | 345 | // If we deleted the currently selected node, reset selection to root 346 | if (selectedNode === nodeId) { 347 | selectedNode = loomData.node_id; 348 | 
renderRelativeView(selectedNode); 349 | } 350 | 351 | showStatus('Node deleted successfully'); 352 | } catch (error) { 353 | showStatus(error.message, true); 354 | } 355 | } 356 | 357 | async function ramifySelected() { 358 | if (!selectedNode) { 359 | showStatus('Please select a node first', true); 360 | return; 361 | } 362 | try { 363 | const body = { 364 | node_id: selectedNode, 365 | stream: true, 366 | n: parseInt(document.getElementById('numSamples').value, 10), 367 | temp: parseFloat(document.getElementById('temperature').value), 368 | max_tokens: parseInt(document.getElementById('maxTokens').value, 10) 369 | }; 370 | const response = await fetch(`${BASE_URL}/node/ramify`, { 371 | method: 'POST', 372 | headers: { 'Content-Type': 'application/json' }, 373 | body: JSON.stringify(body) 374 | }); 375 | 376 | if (!response.ok) { 377 | const errorText = await response.text(); 378 | throw new Error(`Error ${response.status}: ${errorText}`); 379 | } 380 | const childrenDiv = document.querySelector('.children-container'); 381 | // Create side-by-side streaming child nodes for each sample 382 | const streamChildNodes = []; 383 | const batchSize = parseInt(document.getElementById('numSamples').value, 10); 384 | for (let i = 0; i < batchSize; i++) { 385 | const streamChild = document.createElement('div'); 386 | streamChild.className = 'child-node streaming'; 387 | streamChild.style.pointerEvents = 'none'; 388 | streamChild.innerHTML = "
<pre>Generating...</pre>
"; 389 | childrenDiv.appendChild(streamChild); 390 | streamChildNodes.push(streamChild); 391 | } 392 | const reader = response.body.getReader(); 393 | const decoder = new TextDecoder(); 394 | let resultText = ""; 395 | while (true) { 396 | const { value, done } = await reader.read(); 397 | if (done) break; 398 | resultText += decoder.decode(value, { stream: true }); 399 | const lines = resultText.split("\n"); 400 | for (let i = 0; i < lines.length - 1; i++) { 401 | const line = lines[i].trim(); 402 | if (line) { 403 | const data = JSON.parse(line); 404 | if (data.type === "update") { 405 | for (let i = 0; i < data.decoded_texts.length; i++) { 406 | streamChildNodes[i].querySelector('pre').textContent = data.decoded_texts[i]; 407 | } 408 | } else if (data.type === "final") { 409 | for (let i = 0; i < data.decoded_texts.length; i++) { 410 | streamChildNodes[i].querySelector('pre').textContent = data.decoded_texts[i]; 411 | streamChildNodes[i].style.pointerEvents = 'auto'; 412 | streamChildNodes[i].classList.remove('streaming'); 413 | streamChildNodes[i].onclick = () => { 414 | selectNode(selectedNode); 415 | }; 416 | } 417 | } 418 | } 419 | } 420 | resultText = lines[lines.length - 1]; 421 | } 422 | const treeData = await getJSON(`${BASE_URL}/loom/${currentLoomId}`); 423 | updateTree(treeData); 424 | showStatus('Node ramified successfully (stream)'); 425 | } catch (error) { 426 | showStatus(error.message, true); 427 | } 428 | } 429 | 430 | 431 | // ------------------ New Functions for Loom Management ------------------ 432 | async function refreshLoomList() { 433 | try { 434 | const looms = await getJSON(`${BASE_URL}/looms`); 435 | const loomListDiv = document.getElementById('loomList'); 436 | loomListDiv.innerHTML = ''; 437 | looms.forEach(loom => { 438 | const loomDiv = document.createElement('div'); 439 | loomDiv.style.marginBottom = '10px'; 440 | 441 | // Create a span to display only the loom id. 442 | const loomIdSpan = document.createElement('span'); 443 | loomIdSpan.textContent = loom.loom_id; 444 | loomIdSpan.style.cursor = 'pointer'; 445 | loomIdSpan.style.marginRight = '10px'; 446 | // Single click loads the loom. 447 | loomIdSpan.onclick = () => { 448 | loadLoomById(loom.loom_id); 449 | }; 450 | // Double-click to edit (inline renaming) 451 | loomIdSpan.ondblclick = () => { 452 | const input = document.createElement('input'); 453 | input.type = 'text'; 454 | input.value = loom.loom_id; 455 | input.style.width = '100px'; 456 | input.onblur = async() => { 457 | if (input.value && input.value !== loom.loom_id) { 458 | try { 459 | await postJSON(`${BASE_URL}/looms/${loom.loom_id}/rename`, { new_id: input.value }); 460 | showStatus(`Loom renamed to ${input.value}`); 461 | refreshLoomList(); 462 | } catch (error) { 463 | showStatus(error.message, true); 464 | } 465 | } 466 | loomIdSpan.textContent = input.value; 467 | loomDiv.replaceChild(loomIdSpan, input); 468 | }; 469 | input.onkeydown = (e) => { 470 | if (e.key === 'Enter') { 471 | input.blur(); 472 | } 473 | }; 474 | loomDiv.replaceChild(input, loomIdSpan); 475 | input.focus(); 476 | }; 477 | loomDiv.appendChild(loomIdSpan); 478 | 479 | // Create a small export button rendered as an up arrow. 480 | const exportBtn = document.createElement('button'); 481 | exportBtn.className = 'icon-btn'; 482 | exportBtn.style.padding = '5px 8px'; 483 | exportBtn.textContent = '↑'; 484 | exportBtn.onclick = () => exportLoom(loom.loom_id); 485 | loomDiv.appendChild(exportBtn); 486 | 487 | // NEW: Create a delete button rendered as a trash icon. 
488 | // NEW: Create a delete button rendered as a trash icon. 489 | const deleteBtn = document.createElement('button'); 490 | deleteBtn.className = 'icon-btn'; 491 | deleteBtn.style.padding = '5px 8px'; 492 | deleteBtn.style.marginLeft = '5px'; 493 | deleteBtn.textContent = '🗑️'; 494 | deleteBtn.onclick = async() => { 495 | // Use SweetAlert instead of confirm() 496 | swal({ 497 | title: "Are you sure?", 498 | text: `Are you sure you want to delete loom ${loom.loom_id}?`, 499 | icon: "warning", 500 | buttons: true, 501 | dangerMode: true, 502 | }).then(async(willDelete) => { 503 | if (willDelete) { 504 | try { 505 | await fetch(`${BASE_URL}/looms/${loom.loom_id}`, { method: 'DELETE' }); 506 | showStatus(`Loom ${loom.loom_id} deleted successfully`); 507 | refreshLoomList(); 508 | } catch (error) { 509 | showStatus(error.message, true); 510 | } 511 | } 512 | }); 513 | }; 514 | loomDiv.appendChild(deleteBtn); 515 | 516 | 517 | loomListDiv.appendChild(loomDiv); 518 | }); 519 | } catch (error) { 520 | showStatus(error.message, true); 521 | } 522 | } 523 | 524 | 525 | function findDeepestNode(node, depth = 0) { 526 | let maxDepth = depth; 527 | let deepestNode = node; 528 | 529 | if (node.children && node.children.length > 0) { 530 | for (const child of node.children) { 531 | const [childNode, childDepth] = findDeepestNode(child, depth + 1); 532 | if (childDepth > maxDepth) { 533 | maxDepth = childDepth; 534 | deepestNode = childNode; 535 | } 536 | } 537 | } 538 | 539 | return [deepestNode, maxDepth]; 540 | } 541 | 542 | async function loadLoomById(loomId) { 543 | try { 544 | const treeData = await getJSON(`${BASE_URL}/loom/${loomId}`); 545 | currentLoomId = loomId; 546 | 547 | // Find deepest node 548 | const [deepestNode, _] = findDeepestNode(treeData); 549 | selectedNode = deepestNode.node_id; 550 | 551 | updateTree(treeData); 552 | showStatus(`Loaded loom ${loomId}`); 553 | } catch (error) { 554 | showStatus(error.message, true); 555 | } 556 | } 557 | 558 | function loadLoomFromInput() { 559 | const loomId = document.getElementById('loadLoomId').value; 560 | if (loomId) { 561 | loadLoomById(loomId); 562 | } 563 | } 564 | 565 | async function exportLoom(loomId) { 566 | try { 567 | const exportData = await getJSON(`${BASE_URL}/loom/${loomId}/export`); 568 | const dataStr = "data:text/json;charset=utf-8," + encodeURIComponent(JSON.stringify(exportData, null, 2)); 569 | const dlAnchorElem = document.createElement('a'); 570 | dlAnchorElem.setAttribute("href", dataStr); 571 | dlAnchorElem.setAttribute("download", `loom_${loomId}.json`); 572 | dlAnchorElem.click(); 573 | showStatus(`Loom ${loomId} exported successfully`); 574 | } catch (error) { 575 | showStatus(error.message, true); 576 | } 577 | } 578 | 579 | async function importLoomFile(event) { 580 | const file = event.target.files[0]; 581 | if (!file) return; 582 | const reader = new FileReader(); 583 | reader.onload = async(e) => { 584 | try { 585 | const loomData = JSON.parse(e.target.result); 586 | if (!currentModelId) { 587 | showStatus('Please load a model first before importing', true); 588 | return; 589 | } 590 | const importResponse = await postJSON(`${BASE_URL}/looms/import`, { 591 | model_id: currentModelId, 592 | loom_data: loomData 593 | }); 594 | showStatus(`Loom imported successfully: ${importResponse.loom_id}`); 595 | // Optionally refresh the loom list 596 | refreshLoomList(); 597 | } catch (error) { 598 | showStatus(error.message, true); 599 | } 600 | }; 601 | reader.readAsText(file); 602 | refreshLoomList(); 603 | } 604 | 
605 | refreshLoomList(); 606 | 607 | // Add 'copy parent texts' utility button in the top right of the visualization div. 608 | function copyParentTexts() { 609 | const parentsDiv = document.querySelector('.parents-text'); 610 | if (!parentsDiv) { 611 | showStatus('No parent texts to copy', true); 612 | return; 613 | } 614 | const pres = parentsDiv.querySelectorAll('pre'); 615 | let combinedText = ''; 616 | pres.forEach(pre => { 617 | combinedText += pre.textContent + '\n'; 618 | }); 619 | navigator.clipboard.writeText(combinedText) 620 | .then(() => { 621 | showStatus('Copied parent texts to clipboard'); 622 | }) 623 | .catch(err => { 624 | showStatus('Failed to copy: ' + err, true); 625 | }); 626 | } 627 | 628 | window.addEventListener('load', () => { 629 | const visualizationDiv = document.querySelector('.visualization'); 630 | if (visualizationDiv) { 631 | const copyBtn = document.createElement('button'); 632 | copyBtn.className = 'icon-btn'; 633 | copyBtn.style.position = 'absolute'; 634 | copyBtn.style.top = '10px'; 635 | copyBtn.style.right = '10px'; 636 | copyBtn.textContent = '📄'; 637 | copyBtn.title = 'Copy parent texts'; 638 | copyBtn.onclick = copyParentTexts; 639 | visualizationDiv.appendChild(copyBtn); 640 | } 641 | }); 642 | -------------------------------------------------------------------------------- /src/n8loom/examples/token-benchmark.py: -------------------------------------------------------------------------------- 1 | import time 2 | import matplotlib.pyplot as plt 3 | 4 | from mlx_lm import load 5 | import mlx.core as mx 6 | from n8loom import Loom, load_for_loom 7 | 8 | # Load the model and tokenizer (adjust model name if needed) 9 | model, tokenizer = load_for_loom("Llama-3.2-3B-Instruct-4bit") 10 | 11 | # Define the prompt (using your long story) 12 | prompt = """ 13 | The Epic of Gilgamesh (/ˈɡɪlɡəmɛʃ/)[2] is an epic from ancient Mesopotamia. The literary history of Gilgamesh begins with five Sumerian poems about Gilgamesh (formerly read as Sumerian "Bilgames"[3]), king of Uruk, some of which may date back to the Third Dynasty of Ur (c. 2100 BCE).[1] These independent stories were later used as source material for a combined epic in Akkadian. The first surviving version of this combined epic, known as the "Old Babylonian" version, dates back to the 18th century BCE and is titled after its incipit, Shūtur eli sharrī ("Surpassing All Other Kings"). Only a few tablets of it have survived. The later Standard Babylonian version compiled by Sîn-lēqi-unninni dates to somewhere between the 13th to the 10th centuries BCE and bears the incipit Sha naqba īmuru[note 1] ("He who Saw the Deep(s)", lit. '"He who Sees the Unknown"'). Approximately two-thirds of this longer, twelve-tablet version have been recovered. Some of the best copies were discovered in the library ruins of the 7th-century BCE Assyrian king Ashurbanipal. 14 | 15 | The first half of the story discusses Gilgamesh (who was king of Uruk) and Enkidu, a wild man created by the gods to stop Gilgamesh from oppressing the people of Uruk. After Enkidu becomes civilized through sexual initiation with Shamhat, he travels to Uruk, where he challenges Gilgamesh to a test of strength. Gilgamesh wins the contest; nonetheless, the two become friends. Together, they make a six-day journey to the legendary Cedar Forest, where they ultimately slay its Guardian, Humbaba, and cut down the sacred Cedar.[5] The goddess Ishtar sends the Bull of Heaven to punish Gilgamesh for spurning her advances. 
Gilgamesh and Enkidu kill the Bull of Heaven, insulting Ishtar in the process, after which the gods decide to sentence Enkidu to death and kill him by giving him a fatal illness. 16 | 17 | In the second half of the epic, distress over Enkidu's death causes Gilgamesh to undertake a long and perilous journey to discover the secret of eternal life. Finally, he meets Utnapishtim, who with his wife were the only humans to survive the Flood triggered by the gods (cf. Athra-Hasis). Gilgamesh learns from him that "Life, which you look for, you will never find. For when the gods created man, they let death be his share, and life withheld in their own hands".[6][7] 18 | 19 | The epic is regarded as a foundational work in religion and the tradition of heroic sagas, with Gilgamesh forming the prototype for later heroes like Heracles (Hercules) and the epic itself serving as an influence for Homeric epics.[8] It has been translated into many languages and is featured in several works of popular fiction. 20 | 21 | Analyze the above summary of Gilgamesh and comment on what it shows about humanity. Do so in at least three paragraphs. 22 | """ 23 | 24 | print("Prompt length:", len(tokenizer.encode(prompt))) 25 | 26 | # Define the batch sizes (n values) to test 27 | n_values = [1, 2, 4, 8, 16, 32, 64, 128] 28 | tokens_per_sec_list = [] 29 | runtimes = [] 30 | memory = [] 31 | # For each n, we generate a batch of children using a fixed max_tokens per child. 32 | # (We assume that the generated text in each child is roughly the new tokens produced.) 33 | for n in n_values: 34 | # (Optional) Reset any peak memory counters, if desired 35 | mx.metal.reset_peak_memory() 36 | # Create a new Loom instance for the current iteration 37 | root = Loom(model, tokenizer, prompt) 38 | 39 | # Use a fixed number of tokens to generate per child. 40 | max_tokens = 128 41 | 42 | start_time = time.time() 43 | children = root.ramify(n=n, temp=0.6, max_tokens=max_tokens, min_p=0.05) 44 | elapsed_time = time.time() - start_time 45 | 46 | # Count the total number of tokens generated by all children. 47 | # Note: Each child is initialized with its generated text. 48 | total_generated_tokens = sum(len(child.tokens) for child in children) 49 | 50 | # Compute tokens per second (protect against division by zero) 51 | tokens_sec = total_generated_tokens / elapsed_time if elapsed_time > 0 else 0 52 | mem_usage_gb = mx.metal.get_peak_memory() / 1e9 # convert bytes to GB 53 | print(f"n={n}, Total Generated Tokens: {total_generated_tokens}, " 54 | f"Tokens/sec: {tokens_sec:.2f}, Runtime: {elapsed_time:.2f} seconds, Peak Memory: {mem_usage_gb:.2f} GB") 55 | 56 | tokens_per_sec_list.append(tokens_sec) 57 | runtimes.append(elapsed_time) 58 | memory.append(mem_usage_gb) 59 | 60 | # Clear caches and delete the root to free memory for the next iteration. 
61 | mx.metal.clear_cache() 62 | del root 63 | 64 | # Create figure with proper subplot layout 65 | plt.figure(figsize=(15, 5)) 66 | 67 | # Plot Tokens per Second vs n 68 | plt.subplot(1, 3, 1) 69 | plt.plot(n_values, tokens_per_sec_list, marker='o') 70 | plt.xlabel('n') 71 | plt.ylabel('Tokens per Second') 72 | plt.title('Generation Throughput\n(Tokens/sec) vs n') 73 | 74 | # Plot Runtime vs n 75 | plt.subplot(1, 3, 2) 76 | plt.plot(n_values, runtimes, marker='o', color='red') 77 | plt.xlabel('n') 78 | plt.ylabel('Runtime (seconds)') 79 | plt.title('Runtime vs n') 80 | 81 | # Plot Peak Memory Usage vs n 82 | plt.subplot(1, 3, 3) 83 | plt.plot(n_values, memory, marker='o', color='green') 84 | plt.xlabel('n') 85 | plt.ylabel('Peak Memory Usage (GB)') 86 | plt.title('Peak Memory Usage vs n') 87 | 88 | # Adjust layout to prevent overlap 89 | plt.tight_layout() 90 | plt.show() -------------------------------------------------------------------------------- /src/n8loom/loom.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations # This allows using Heddle directly in annotations 2 | from typing import Any, Dict, List, Optional, Union, Callable 3 | from transformers import PreTrainedTokenizer 4 | from mlx_lm.tokenizer_utils import TokenizerWrapper 5 | 6 | import mlx.core as mx 7 | import mlx.nn as nn 8 | from mlx.utils import tree_flatten, tree_map, tree_unflatten 9 | from mlx_lm.models.cache import KVCache 10 | from mlx_lm import load, generate 11 | from .utils import generate_batched, generate_batched_stream, prompt_to_cache 12 | import numpy as np 13 | import copy 14 | from collections import namedtuple 15 | from .cache_utils import KVFrag, frag_cache, frag_batch_gen, fuse_cache_frags, clip_frag 16 | 17 | 18 | class Heddle: 19 | """ 20 | Rewritten to use tokens as the internal representation while maintaining 21 | the same user-facing (text-based) API. 
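    Illustrative usage (a minimal sketch; assumes a model and tokenizer loaded via
    load_for_loom and an arbitrary prompt):

        root = Loom(model, tokenizer, "Explain why the sky is blue.")
        children = root.ramify(n=4, temp=0.8, max_tokens=64)  # sample 4 continuations
        best = children[0]
        best.ramify("In short, ")                             # append a fixed text child
        print(best.children[0].crown())                       # generated text, prompt excluded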
22 | """ 23 | model: nn.Module 24 | tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper] 25 | parent: Optional[Heddle] 26 | children: List[Heddle] 27 | frag: List[KVFrag] 28 | terminal: bool 29 | 30 | def __init__( 31 | self, 32 | model: nn.Module, 33 | tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper], 34 | text: str, 35 | frags: Optional[List[KVFrag]], 36 | children: Optional[List[Heddle]], 37 | parent: Optional[Heddle] = None, 38 | trim_toks: int = 1 39 | ): 40 | self.model = model 41 | self.tokenizer = tokenizer 42 | self.parent = parent 43 | # Underlying representation is tokens only: 44 | if type(text) == list: 45 | self.tokens = text 46 | else: 47 | self.tokens = tokenizer.encode(text)[trim_toks:] 48 | # Store trim_toks for potential future resets: 49 | self._trim_toks = trim_toks 50 | 51 | # If no fragment was passed in, generate one: 52 | if frags is None: 53 | c = None 54 | if self.parent is None: 55 | c, l = prompt_to_cache(model, tokenizer, self.tokens, offset=0) 56 | c = frag_cache(c, 0, l - 1) 57 | else: 58 | parent_cache = self.parent.get_prefix_cache() 59 | p_len = parent_cache[0].offset 60 | c, l = prompt_to_cache( 61 | model, 62 | tokenizer, 63 | self.tokens, 64 | c=parent_cache, 65 | offset=0 66 | ) 67 | c = frag_cache(c, p_len, p_len + l - 1) 68 | frags = c 69 | 70 | self.frag = frags 71 | 72 | # Children 73 | if children is None: 74 | children = [] 75 | self.children = children 76 | for child in self.children: 77 | child.parent = self 78 | 79 | self.terminal = False 80 | 81 | @property 82 | def text(self) -> str: 83 | """ 84 | Returns the decoded text from the underlying tokens. 85 | """ 86 | return self.tokenizer.decode(self.tokens) 87 | 88 | @text.setter 89 | def text(self, new_text: str): 90 | """ 91 | Sets the underlying tokens based on new_text, respecting the original trim. 
92 | """ 93 | self.tokens = self.tokenizer.encode(new_text)[self._trim_toks:] 94 | 95 | def clip(self, token_limit: int) -> Heddle: 96 | if token_limit < 0: 97 | token_limit = max(len(self.tokens) + token_limit, 0) 98 | if len(self.tokens) > token_limit: 99 | self.tokens = self.tokens[:token_limit] 100 | self.frag = clip_frag(self.frag, token_limit) 101 | self.children = [] 102 | return self 103 | 104 | def trim(self, token_trim: int) -> Heddle: 105 | return self.clip(-token_trim) 106 | 107 | def to_leaf(self) -> Heddle: 108 | if self.children: 109 | self.children = [] 110 | return self 111 | 112 | def add_child(self, child: Heddle) -> Heddle: 113 | self.children.append(child) 114 | child.parent = self 115 | return child 116 | 117 | def add_text_child(self, text: str) -> Heddle: 118 | child = Heddle(self.model, self.tokenizer, text, None, [], self) 119 | self.add_child(child) 120 | return child 121 | 122 | def remove_child(self, child: Heddle) -> Heddle: 123 | self.children.remove(child) 124 | child.parent = None 125 | return child 126 | 127 | def get_prefix_cache(self) -> List[KVCache]: 128 | parents = [self] 129 | parent = self.parent 130 | while parent is not None: 131 | parents.append(parent) 132 | parent = parent.parent 133 | parents.reverse() 134 | cache = [[] for _ in range(len(self.frag))] 135 | for parent in parents: 136 | for i, frag in enumerate(parent.frag): 137 | cache[i].append(frag) 138 | fused = fuse_cache_frags(cache, offset=1) 139 | return fused 140 | 141 | def make_children( 142 | self, 143 | n: int = 4, 144 | temp: float = 0.8, 145 | max_tokens: int = 8, 146 | min_p: float = 0.05, 147 | stop_strings: List[str] = [], 148 | **kwargs 149 | ) -> Optional[List[Heddle]]: 150 | if self.terminal: 151 | return None 152 | 153 | c = self.get_prefix_cache() 154 | decoded_texts, prompt_cache, total_prompt_len, generated_lengths, ended = generate_batched( 155 | self.model, 156 | self.tokenizer, 157 | prompt=self.tokens, 158 | batch_size=n, 159 | min_p=min_p, 160 | prompt_cache=c, 161 | verbose=False, 162 | temp=temp, 163 | max_tokens=max_tokens, 164 | stop_strings=stop_strings 165 | ) 166 | fragments = frag_batch_gen(prompt_cache, 0, generated_lengths) 167 | made_kids = [] 168 | for i in range(len(fragments)): 169 | child = Heddle( 170 | self.model, 171 | self.tokenizer, 172 | decoded_texts[i], 173 | fragments[i], 174 | [] 175 | ) 176 | if ended[i]: 177 | child.terminal = True 178 | self.add_child(child) 179 | 180 | made_kids.append(child) 181 | mx.metal.clear_cache() 182 | return made_kids 183 | 184 | def ramify(self, arg=None, **kwargs): 185 | """ 186 | Convenience method to expand children. 
Takes either: 187 | - a single string -> returns a single child 188 | - a list of strings -> returns multiple children 189 | - else uses model generation (batch or stream) 190 | """ 191 | if isinstance(arg, str): 192 | return self.add_text_child(arg) 193 | elif isinstance(arg, list) and all(isinstance(x, str) for x in arg): 194 | children = [self.add_text_child(text) for text in arg] 195 | return children 196 | else: 197 | if kwargs.get('stream', False): 198 | return self.make_child_stream(**kwargs) 199 | else: 200 | return self.make_children(**kwargs) 201 | 202 | def make_child_stream( 203 | self, 204 | n: int = 4, 205 | temp: float = 0.8, 206 | max_tokens: int = 8, 207 | min_p: float = 0.05, 208 | stop_strings: List[str] = [], 209 | **kwargs 210 | ): 211 | c = self.get_prefix_cache() 212 | stream = generate_batched_stream( 213 | self.model, 214 | self.tokenizer, 215 | self.tokens, 216 | batch_size=n, 217 | prompt_cache=c, 218 | verbose=False, 219 | temp=temp, 220 | min_p=min_p, 221 | max_tokens=max_tokens, 222 | stop_strings=stop_strings 223 | ) 224 | made_kids = [] 225 | for update in stream: 226 | if update.get("type") == "final": 227 | final_texts = update.get("decoded_texts", []) 228 | generated_lengths = update.get("generated_lengths", []) 229 | total_prompt_len = update.get("total_prompt_len", 0) 230 | prompt_cache = update.get("prompt_cache", []) 231 | ended = update.get("ended", []) 232 | # free to avoid large intermediate state 233 | update["prompt_cache"] = None 234 | 235 | fragments = frag_batch_gen(prompt_cache, 0, generated_lengths) 236 | made_kids = [] 237 | for i, text_str in enumerate(final_texts): 238 | child = Heddle( 239 | self.model, 240 | self.tokenizer, 241 | text_str, 242 | fragments[i], 243 | [] 244 | ) 245 | if ended[i]: 246 | child.terminal = True 247 | self.add_child(child) 248 | made_kids.append(child) 249 | update["children"] = made_kids 250 | 251 | yield update 252 | return made_kids 253 | 254 | def get_prefix_text(self, exclude: int = 0) -> str: 255 | parents = [self] 256 | parent = self.parent 257 | while parent is not None: 258 | parents.append(parent) 259 | parent = parent.parent 260 | parents.reverse() 261 | return "".join([p.text for p in parents[exclude:]]) 262 | 263 | def get_display_text(self, exclude: int = 0) -> str: 264 | parents = [self] 265 | parent = self.parent 266 | while parent is not None: 267 | parents.append(parent) 268 | parent = parent.parent 269 | parents.reverse() 270 | return "".join([p.display_text() for p in parents[exclude:]]) 271 | 272 | def crown(self) -> str: 273 | return self.get_prefix_text(1) 274 | 275 | def display_text(self) -> str: 276 | return self.text 277 | 278 | def get_prefix_tokens(self, exclude: int = 0) -> List[int]: 279 | parents = [self] 280 | parent = self.parent 281 | while parent is not None: 282 | parents.append(parent) 283 | parent = parent.parent 284 | parents.reverse() 285 | return [token for p in parents[exclude:] for token in p.tokens] 286 | 287 | def apply_all_children( 288 | self, 289 | func: Callable[[Heddle], Any], 290 | apply_self: bool = False, 291 | leaves_only: bool = False 292 | ) -> List[Any]: 293 | """ 294 | Applies a function to all children (and optionally self) in the subtree. 295 | 296 | Args: 297 | func: A function that takes a Heddle and returns any value. 298 | apply_self: Whether to apply the function to the root node itself. 299 | leaves_only: If True, only apply to leaf nodes (i.e., nodes with no children). 
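            Example (illustrative): collect the token count of every leaf in the subtree:
                leaf_lengths = node.apply_all_children(lambda n: len(n.tokens), leaves_only=True)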
300 | 301 | Returns: 302 | A list of results from applying the function to each matching node. 303 | """ 304 | results = [] 305 | if apply_self: 306 | results.append(func(self)) 307 | pre_children = self.get_all_children() 308 | for child in pre_children: 309 | if leaves_only and child.children: 310 | continue 311 | results.append(func(child)) 312 | return results 313 | 314 | def at_all_leaves(self, func: Callable[[Heddle], Any]) -> List[Any]: 315 | return self.apply_all_children(func, apply_self=False, leaves_only=True) 316 | 317 | def apply_at_leaves(self, *funcs: Callable[[Heddle], Any]): 318 | for func in funcs[:-1]: 319 | self.at_all_leaves(func) 320 | return self.at_all_leaves(funcs[-1]) 321 | 322 | def get_all_children(self, depth: int = 0) -> List[Heddle]: 323 | """ 324 | Returns all descendants (the entire subtree). 325 | `depth` is just for internal recursion tracking. 326 | """ 327 | nodes = [] 328 | if depth > 0: 329 | nodes.append(self) 330 | for child in self.children: 331 | nodes.extend(child.get_all_children(depth + 1)) 332 | return nodes 333 | 334 | def get_all_leaves(self) -> List[Heddle]: 335 | return [child for child in self.get_all_children() if not child.children] 336 | 337 | def count_children(self) -> int: 338 | return 1 + sum([child.count_children() for child in self.children]) 339 | 340 | def count_leaves(self) -> int: 341 | if not self.children: 342 | return 1 343 | return sum([child.count_leaves() for child in self.children]) 344 | 345 | def __repr__(self): 346 | return f"Heddle({self.text}, {self.children})" 347 | 348 | 349 | class Loom(Heddle): 350 | def __init__(self, model, tokenizer, prompt): 351 | self.user_prompt = prompt 352 | self.chat_template_used = False 353 | 354 | messages = [{"role": "user", "content": prompt}] 355 | try: 356 | # Attempt to apply chat template, if available 357 | prompt = tokenizer.apply_chat_template( 358 | messages, 359 | add_generation_prompt=True, 360 | tokenize=False 361 | ) 362 | self.chat_template_used = True 363 | except: 364 | pass 365 | 366 | # Initialize base Heddle without trimming 367 | super().__init__(model, tokenizer, prompt, None, [], None, trim_toks=0) 368 | 369 | def display_text(self): 370 | if self.chat_template_used: 371 | return "Prompt: " + self.user_prompt + "\n\nResponse:\n\n" 372 | return super().display_text() 373 | -------------------------------------------------------------------------------- /src/n8loom/models/llama.py: -------------------------------------------------------------------------------- 1 | 2 | from dataclasses import dataclass 3 | from typing import Any, Dict, Optional, Union, List 4 | 5 | import mlx.core as mx 6 | import mlx.nn as nn 7 | from mlx_lm import load 8 | import numpy as np 9 | from mlx_lm.models.base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention, create_causal_mask 10 | from mlx_lm.models.rope_utils import initialize_rope 11 | from mlx_lm.models.cache import KVCache 12 | from mlx_lm.utils import wired_limit, cache, maybe_quantize_kv_cache, GenerationResponse 13 | 14 | @dataclass 15 | class ModelArgs(BaseModelArgs): 16 | model_type: str 17 | hidden_size: int 18 | num_hidden_layers: int 19 | intermediate_size: int 20 | num_attention_heads: int 21 | rms_norm_eps: float 22 | vocab_size: int 23 | head_dim: Optional[int] = None 24 | max_position_embeddings: Optional[int] = None 25 | num_key_value_heads: Optional[int] = None 26 | attention_bias: bool = False 27 | mlp_bias: bool = False 28 | rope_theta: float = 10000 29 | rope_traditional: bool = False 
30 | rope_scaling: Optional[Dict[str, Union[float, str]]] = None 31 | tie_word_embeddings: bool = True 32 | 33 | def __post_init__(self): 34 | if self.num_key_value_heads is None: 35 | self.num_key_value_heads = self.num_attention_heads 36 | 37 | #@mx.compile 38 | def double_attn(queries, keys, prefix_keys, values, prefix_values, scale, expand, B, L): 39 | prefix_keys = mx.repeat(prefix_keys, repeats=B, axis=0) 40 | prefix_values = mx.repeat(prefix_values, repeats=B, axis=0) 41 | keys = mx.concat([prefix_keys, keys], axis=-2) 42 | values = mx.concat([prefix_values, values], axis=-2) 43 | output = mx.fast.scaled_dot_product_attention(queries, keys, values, scale=scale) 44 | return output.transpose(0, 2, 1, 3).reshape(B, L, -1) 45 | 46 | class Attention(nn.Module): 47 | def __init__(self, args: ModelArgs): 48 | super().__init__() 49 | 50 | dim = args.hidden_size 51 | self.n_heads = n_heads = args.num_attention_heads 52 | self.n_kv_heads = n_kv_heads = args.num_key_value_heads 53 | 54 | self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads 55 | 56 | self.scale = head_dim**-0.5 57 | if hasattr(args, "attention_bias"): 58 | attention_bias = args.attention_bias 59 | else: 60 | attention_bias = False 61 | 62 | self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias) 63 | self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias) 64 | self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias) 65 | self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias) 66 | self.rope = initialize_rope( 67 | self.head_dim, 68 | args.rope_theta, 69 | args.rope_traditional, 70 | args.rope_scaling, 71 | args.max_position_embeddings, 72 | ) 73 | 74 | def __call__( 75 | self, 76 | x: mx.array, 77 | mask: Optional[mx.array] = None, 78 | cache: Optional[Any] = None, 79 | cache_gen: Optional[Any] = None, 80 | ) -> mx.array: 81 | B, L, D = x.shape 82 | 83 | queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) 84 | 85 | # Prepare the queries, keys and values for the attention computation 86 | queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) 87 | keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) 88 | values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) 89 | if not cache_gen: 90 | if cache is not None: 91 | queries = self.rope(queries, offset=cache.offset) 92 | keys = self.rope(keys, offset=cache.offset) 93 | keys, values = cache.update_and_fetch(keys, values) 94 | else: 95 | queries = self.rope(queries) 96 | keys = self.rope(keys) 97 | 98 | output = scaled_dot_product_attention( 99 | queries, keys, values, cache=cache, scale=self.scale, mask=mask 100 | ) 101 | 102 | output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) 103 | return self.o_proj(output) 104 | elif cache_gen and cache: 105 | # Apply RoPE with the appropriate offset 106 | queries = self.rope(queries, offset=cache.offset + cache_gen.offset) 107 | keys = self.rope(keys, offset=cache.offset + cache_gen.offset) 108 | 109 | # prefix_* are the "old" cached KV tensors (batch=1) from previous prompt tokens 110 | prefix_keys = cache.keys[..., : cache.offset, :] # shape: (1, n_kv_heads, prefix_len, head_dim) 111 | prefix_values = cache.values[..., : cache.offset, :] # shape: (1, n_kv_heads, prefix_len, head_dim) 112 | 113 | # Then fetch updated keys/values for the newly generated token(s) 114 | keys, values = cache_gen.update_and_fetch(keys, values) 115 | 116 | output = double_attn(queries, keys, prefix_keys, values, 
prefix_values, self.scale, self.n_heads // self.n_kv_heads, B, L) 117 | return self.o_proj(output) 118 | else: 119 | raise ValueError("cache_gen is True but cache is None") 120 | 121 | 122 | class MLP(nn.Module): 123 | def __init__(self, args: ModelArgs): 124 | super().__init__() 125 | 126 | dim = args.hidden_size 127 | hidden_dim = args.intermediate_size 128 | if hasattr(args, "mlp_bias"): 129 | mlp_bias = args.mlp_bias 130 | else: 131 | mlp_bias = False 132 | 133 | self.gate_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias) 134 | self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias) 135 | self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias) 136 | 137 | def __call__(self, x) -> mx.array: 138 | return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) 139 | 140 | 141 | class TransformerBlock(nn.Module): 142 | def __init__(self, args: ModelArgs): 143 | super().__init__() 144 | self.num_attention_heads = args.num_attention_heads 145 | self.hidden_size = args.hidden_size 146 | self.self_attn = Attention(args) 147 | self.mlp = MLP(args) 148 | self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) 149 | self.post_attention_layernorm = nn.RMSNorm( 150 | args.hidden_size, eps=args.rms_norm_eps 151 | ) 152 | self.args = args 153 | 154 | def __call__( 155 | self, 156 | x: mx.array, 157 | mask: Optional[mx.array] = None, 158 | cache: Optional[Any] = None, 159 | cache_gen: Optional[Any] = None, 160 | ) -> mx.array: 161 | r = self.self_attn(self.input_layernorm(x), mask, cache, cache_gen) 162 | h = x + r 163 | r = self.mlp(self.post_attention_layernorm(h)) 164 | out = h + r 165 | return out 166 | 167 | 168 | class LlamaModel(nn.Module): 169 | def __init__(self, args: ModelArgs): 170 | super().__init__() 171 | self.args = args 172 | self.vocab_size = args.vocab_size 173 | self.num_hidden_layers = args.num_hidden_layers 174 | assert self.vocab_size > 0 175 | self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) 176 | self.layers = [ 177 | TransformerBlock(args=args) for _ in range(args.num_hidden_layers) 178 | ] 179 | self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) 180 | 181 | def __call__( 182 | self, 183 | inputs: mx.array, 184 | mask: mx.array = None, 185 | cache=None, 186 | cache_gen=None, 187 | ): 188 | h = self.embed_tokens(inputs) 189 | if cache_gen is None: 190 | cache_gen = [None] * len(self.layers) 191 | 192 | if mask is None: 193 | mask = create_attention_mask(h, cache) 194 | 195 | if cache is None: 196 | cache = [None] * len(self.layers) 197 | 198 | for layer, c, cg in zip(self.layers, cache, cache_gen): 199 | #print(c, cg) 200 | h = layer(h, mask, cache=c, cache_gen=cg) 201 | 202 | return self.norm(h) 203 | 204 | 205 | class Model(nn.Module): 206 | def __init__(self, args: ModelArgs): 207 | super().__init__() 208 | self.args = args 209 | self.model_type = args.model_type 210 | self.model = LlamaModel(args) 211 | if not args.tie_word_embeddings: 212 | self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) 213 | 214 | def __call__( 215 | self, 216 | inputs: mx.array, 217 | mask: mx.array = None, 218 | cache=None, 219 | cache_gen=None, 220 | ): 221 | out = self.model(inputs, mask, cache, cache_gen) 222 | if self.args.tie_word_embeddings: 223 | out = self.model.embed_tokens.as_linear(out) 224 | else: 225 | out = self.lm_head(out) 226 | return out 227 | 228 | def sanitize(self, weights): 229 | # Remove unused precomputed rotary freqs 230 | return { 231 | k: v for k, v in weights.items() if 
"self_attn.rotary_emb.inv_freq" not in k 232 | } 233 | 234 | @property 235 | def layers(self): 236 | return self.model.layers -------------------------------------------------------------------------------- /src/n8loom/sample_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023-2024 Apple Inc. 2 | 3 | import math 4 | from functools import partial 5 | from typing import Callable, Dict, Optional 6 | 7 | import mlx.core as mx 8 | 9 | 10 | def make_sampler( 11 | temp: float = 0.0, 12 | top_p: float = 0.0, 13 | min_p: float = 0.0, 14 | min_tokens_to_keep: int = 1, 15 | top_k: int = -1, 16 | ) -> Callable[mx.array, mx.array]: 17 | """ 18 | Make a sampler function for use with ``generate_step``. 19 | 20 | Args: 21 | temp (float): The temperature for sampling, if 0 the argmax is used. 22 | Default: ``0``. 23 | top_p (float, optional): Nulceus sampling, higher means model considers 24 | more less likely words. 25 | min_p (float, optional): The minimum value (scaled by the top token's 26 | probability) that a token probability must have to be considered. 27 | min_tokens_to_keep (int, optional): Minimum number of tokens that cannot 28 | be filtered by min_p sampling. 29 | top_k (int, optional): The top k tokens ranked by probability to constrain 30 | the sampling to. 31 | 32 | Returns: 33 | Callable[mx.array, mx.array]: 34 | A sampler which takes log-probabilities and returns tokens. 35 | """ 36 | if temp == 0: 37 | return lambda x: mx.argmax(x, axis=-1) 38 | elif top_p > 0 and top_p < 1.0: 39 | return lambda x: top_p_sampling(x, top_p, temp) 40 | elif min_p != 0.0: 41 | return lambda x: min_p_sampling(x, min_p, min_tokens_to_keep, temp) 42 | elif top_k > 0: 43 | return lambda x: top_k_sampling(x, top_k, temp) 44 | else: 45 | return lambda x: categorical_sampling(x, temp) 46 | 47 | 48 | def make_logits_processors( 49 | logit_bias: Optional[Dict[int, float]] = None, 50 | repetition_penalty: Optional[float] = None, 51 | repetition_context_size: Optional[int] = 20, 52 | ): 53 | """ 54 | Make logits processors for use with ``generate_step``. 55 | 56 | Args: 57 | repetition_penalty (float, optional): The penalty factor for repeating 58 | tokens. 59 | repetition_context_size (int, optional): The number of tokens to 60 | consider for repetition penalty. Default: ``20``. 61 | logit_bias (dictionary, optional): Additive logit bias. 62 | 63 | Returns: 64 | List[Callable[[mx.array, mx.array], mx.array]]: 65 | A list of logits processors. Each processor in the list is a 66 | callable which takes an array of tokens and an array of logits 67 | and returns the updated logits. 68 | """ 69 | logits_processors = [] 70 | if logit_bias: 71 | indices = mx.array(list(logit_bias.keys())) 72 | values = mx.array(list(logit_bias.values())) 73 | 74 | def logit_bias_processor(_, logits): 75 | logits[:, indices] += values 76 | return logits 77 | 78 | logits_processors.append(logit_bias_processor) 79 | 80 | if repetition_penalty and repetition_penalty != 0.0: 81 | logits_processors.append( 82 | make_repetition_penalty(repetition_penalty, repetition_context_size) 83 | ) 84 | return logits_processors 85 | 86 | 87 | @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) 88 | def top_k_sampling( 89 | logprobs: mx.array, 90 | top_k: int, 91 | temperature=1.0, 92 | ) -> mx.array: 93 | """ 94 | Sample from only the top K tokens ranked by probability. 95 | 96 | Args: 97 | logprobs: A vector of log probabilities. 98 | top_k (int): Top k tokens to sample from. 
99 | """ 100 | vocab_size = logprobs.shape[-1] 101 | if not isinstance(top_k, int) or not (0 < top_k < vocab_size): 102 | raise ValueError( 103 | f"`top_k` has to be an integer in the (0, {vocab_size}] interval," 104 | f" but is {top_k}." 105 | ) 106 | logprobs = logprobs * (1 / temperature) 107 | mask_idx = mx.argpartition(-logprobs, kth=top_k - 1, axis=-1)[..., top_k:] 108 | masked_logprobs = mx.put_along_axis( 109 | logprobs, mask_idx, mx.array(-float("inf"), logprobs.dtype), axis=-1 110 | ) 111 | return mx.random.categorical(masked_logprobs, axis=-1) 112 | 113 | 114 | @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) 115 | def min_p_sampling( 116 | logprobs: mx.array, 117 | min_p: float, 118 | min_tokens_to_keep: int = 1, 119 | temperature=1.0, 120 | ) -> mx.array: 121 | """ 122 | Apply min-p sampling to the logprobs. 123 | 124 | Min-p keeps all tokens that are above a minimum probability, scaled by the 125 | probability of the most likely token. As a result, the filter is more 126 | aggressive given a very high-probability token. 127 | 128 | Args: 129 | logprobs: A vector of log probabilities. 130 | min_p (float): Minimum token probability. Typical values are in the 131 | 0.01-0.2 range, comparably selective as setting `top_p` in the 132 | 0.99-0.8 range. 133 | min_tokens_to_keep (int, optional): Minimum number of tokens that cannot 134 | be filtered. Default: ``1``. 135 | 136 | """ 137 | if not (0 <= min_p <= 1.0): 138 | raise ValueError( 139 | f"`min_p` has to be a float in the [0, 1] interval, but is {min_p}" 140 | ) 141 | if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 1): 142 | raise ValueError( 143 | f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}" 144 | ) 145 | # reference implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L531-L605 146 | 147 | logprobs = logprobs * (1 / temperature) 148 | 149 | # Indices sorted in decreasing order 150 | sorted_indices = mx.argsort(-logprobs, axis=-1) 151 | sorted_logprobs = mx.take_along_axis(logprobs, sorted_indices, axis=-1) 152 | 153 | # Top probability 154 | top_logprobs = sorted_logprobs[:, 0:1] 155 | 156 | # Calculate the min_p threshold 157 | scaled_min_p = top_logprobs + math.log(min_p) 158 | 159 | # Mask tokens that have a probability less than the scaled min_p 160 | tokens_to_remove = sorted_logprobs < scaled_min_p 161 | tokens_to_remove[..., :min_tokens_to_keep] = False 162 | 163 | # Create pool of tokens with probability less than scaled min_p 164 | selected_logprobs = mx.where(tokens_to_remove, -float("inf"), sorted_logprobs) 165 | 166 | # Return sampled tokens 167 | sorted_tokens = mx.random.categorical(selected_logprobs, axis=-1)[:, None] 168 | return mx.take_along_axis(sorted_indices, sorted_tokens, axis=-1).squeeze(1) 169 | 170 | 171 | @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) 172 | def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.array: 173 | """ 174 | Apply top-p (nucleus) sampling to logits. 175 | 176 | Args: 177 | logits: The logits from the model's output. 178 | top_p: The cumulative probability threshold for top-p filtering. 179 | temperature: Temperature parameter for softmax distribution reshaping. 180 | Returns: 181 | token selected based on the top-p criterion. 
182 | """ 183 | # referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460 184 | probs = mx.softmax(logits * (1 / temperature), axis=-1) 185 | 186 | # sort probs in ascending order 187 | sorted_indices = mx.argsort(probs, axis=-1) 188 | sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=-1) 189 | 190 | cumulative_probs = mx.cumsum(sorted_probs, axis=-1) 191 | 192 | # select tokens with cumulative probs below threshold 193 | top_probs = mx.where( 194 | cumulative_probs > 1 - top_p, 195 | sorted_probs, 196 | 0, 197 | ) 198 | 199 | sorted_tokens = mx.random.categorical(mx.log(top_probs), axis=-1)[:, None] 200 | return mx.take_along_axis(sorted_indices, sorted_tokens, axis=-1).squeeze(1) 201 | 202 | 203 | @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) 204 | def categorical_sampling(logits, temp): 205 | return mx.random.categorical(logits * (1 / temp)) 206 | 207 | 208 | def make_repetition_penalty(penalty: float, context_size: int = 20): 209 | """ 210 | Make repetition penalty processor. 211 | 212 | Paper: https://arxiv.org/abs/1909.05858 213 | 214 | Args: 215 | penalty (float): The repetition penalty factor to be applied. 216 | context_size (int): The number of previous tokens to use. 217 | Default: ``20``. 218 | 219 | Returns: 220 | Callable[[mx.array, List[int]], mx.array]: 221 | The repetition penalty processor. 222 | """ 223 | if penalty < 0 or not isinstance(penalty, (int, float)): 224 | raise ValueError(f"penalty must be a non-negative float, got {penalty}") 225 | 226 | def repetition_penalty_processor(tokens, logits): 227 | if len(tokens) > 0: 228 | tokens = tokens[-context_size:] 229 | selected_logits = logits[:, tokens] 230 | selected_logits = mx.where( 231 | selected_logits < 0, 232 | selected_logits * penalty, 233 | selected_logits / penalty, 234 | ) 235 | logits[:, tokens] = selected_logits 236 | return logits 237 | 238 | return repetition_penalty_processor -------------------------------------------------------------------------------- /src/n8loom/utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import copy 3 | from typing import List, Optional, Callable, Union, Tuple 4 | 5 | import mlx.core as mx 6 | import mlx.nn as nn 7 | from mlx.utils import tree_flatten, tree_map, tree_unflatten 8 | from transformers import PreTrainedTokenizer 9 | from mlx_lm.tokenizer_utils import TokenizerWrapper 10 | from .sample_utils import make_sampler 11 | from .models.llama import Model 12 | from mlx_lm.utils import wired_limit, cache, maybe_quantize_kv_cache, GenerationResponse, get_model_path, load_model, load_adapters, load_tokenizer 13 | from mlx_lm import load, generate, stream_generate 14 | import importlib 15 | import math 16 | MODEL_REMAPPING = { 17 | "mistral": "llama", # mistral is compatible with llama 18 | "phi-msft": "phixtral", 19 | "falcon_mamba": "mamba", 20 | } 21 | 22 | def _get_classes(config: dict): 23 | """ 24 | Retrieve the model and model args classes based on the configuration. 25 | 26 | Args: 27 | config (dict): The model configuration. 28 | 29 | Returns: 30 | A tuple containing the Model class and the ModelArgs class. 31 | """ 32 | model_type = config["model_type"] 33 | model_type = MODEL_REMAPPING.get(model_type, model_type) 34 | try: 35 | arch = importlib.import_module(f"n8loom.models.{model_type}") 36 | except ImportError: 37 | msg = f"Model type {model_type} not supported." 
--------------------------------------------------------------------------------
/src/n8loom/utils.py:
--------------------------------------------------------------------------------
1 | import time
2 | import copy
3 | from typing import List, Optional, Callable, Union, Tuple
4 | 
5 | import mlx.core as mx
6 | import mlx.nn as nn
7 | from mlx.utils import tree_flatten, tree_map, tree_unflatten
8 | from transformers import PreTrainedTokenizer
9 | from mlx_lm.tokenizer_utils import TokenizerWrapper
10 | from .sample_utils import make_sampler
11 | from .models.llama import Model
12 | from mlx_lm.utils import wired_limit, cache, maybe_quantize_kv_cache, GenerationResponse, get_model_path, load_model, load_adapters, load_tokenizer
13 | from mlx_lm import load, generate, stream_generate
14 | import importlib
15 | import math
16 | MODEL_REMAPPING = {
17 |     "mistral": "llama",  # mistral is compatible with llama
18 |     "phi-msft": "phixtral",
19 |     "falcon_mamba": "mamba",
20 | }
21 | 
22 | def _get_classes(config: dict):
23 |     """
24 |     Retrieve the model and model args classes based on the configuration.
25 | 
26 |     Args:
27 |         config (dict): The model configuration.
28 | 
29 |     Returns:
30 |         A tuple containing the Model class and the ModelArgs class.
31 |     """
32 |     model_type = config["model_type"]
33 |     model_type = MODEL_REMAPPING.get(model_type, model_type)
34 |     try:
35 |         arch = importlib.import_module(f"n8loom.models.{model_type}")
36 |     except ImportError:
37 |         msg = f"Model type {model_type} not supported."
38 |         raise ValueError(msg)
39 | 
40 |     return arch.Model, arch.ModelArgs
41 | 
42 | def load_for_loom(
43 |     path_or_hf_repo: str,
44 |     tokenizer_config={},
45 |     model_config={},
46 |     adapter_path: Optional[str] = None,
47 |     lazy: bool = False,
48 | ) -> Tuple[nn.Module, TokenizerWrapper]:
49 |     """
50 |     Load the model and tokenizer from a given path or a huggingface repository.
51 | 
52 |     Args:
53 |         path_or_hf_repo (str): The path or the huggingface repository to load the model from.
54 |         tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
55 |             Defaults to an empty dictionary.
56 |         model_config (dict, optional): Configuration parameters specifically for the model.
57 |             Defaults to an empty dictionary.
58 |         adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
59 |             to the model. Default: ``None``.
60 |         lazy (bool): If ``False``, eval the model parameters to make sure they are
61 |             loaded in memory before returning; otherwise they will be loaded
62 |             when needed. Default: ``False``.
63 |     Returns:
64 |         Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.
65 | 
66 |     Raises:
67 |         FileNotFoundError: If the config file or safetensors are not found.
68 |         ValueError: If the model class or args class are not found.
69 |     """
70 |     model_path = get_model_path(path_or_hf_repo)
71 | 
72 |     model, config = load_model(model_path, lazy, get_model_classes=_get_classes)
73 |     if adapter_path is not None:
74 |         model = load_adapters(model, adapter_path)
75 |     model.eval()
76 |     tokenizer = load_tokenizer(
77 |         model_path, tokenizer_config, eos_token_ids=config.get("eos_token_id", None)
78 |     )
79 | 
80 |     return model, tokenizer
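
# Example (illustrative, not part of this module): loading a model and tokenizer for
# the loom. The repository id below is only an example; any MLX-format checkpoint
# whose model_type maps to a supported architecture should work.
model, tokenizer = load_for_loom("mlx-community/Llama-3.2-3B-Instruct-4bit")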
81 | 
82 | def prompt_to_cache(
83 |     model: nn.Module,
84 |     tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
85 |     prompt_ids: List[int],
86 |     c: Optional[List[cache.KVCache]] = None,
87 |     prefill_step_size: int = 512,
88 |     offset: int = 0,
89 | ) -> tuple[List[cache.KVCache], int]:
90 |     """
91 |     Process a prompt and fill the KV cache.
92 | 
93 |     Args:
94 |         model (nn.Module): The language model.
95 |         tokenizer (PreTrainedTokenizer or TokenizerWrapper): The tokenizer.
96 |         prompt_ids (List[int]): List of token IDs for the prompt.
97 |         c (Optional[List[cache.KVCache]]): The KV cache to fill. If None, a new cache is created.
98 |         prefill_step_size (int): Step size used when processing the prompt.
99 |         offset (int): Subtracted from the prompt length; only the first ``len(prompt_ids) - offset`` tokens are processed. Default: ``0``.
100 |     Returns:
101 |         Tuple[List[cache.KVCache], int]: The filled KV cache and total prompt length.
102 |     """
103 |     prompt_ids = mx.array(prompt_ids)
104 |     if c is None:
105 |         c = cache.make_prompt_cache(model)
106 |     if not isinstance(tokenizer, TokenizerWrapper):
107 |         tokenizer = TokenizerWrapper(tokenizer)
108 | 
109 |     total_prompt_len = prompt_ids.shape[0] - offset
110 |     processed = 0
111 |     while processed < total_prompt_len:
112 |         chunk_end = min(processed + prefill_step_size, total_prompt_len)
113 |         inputs_chunk = prompt_ids[processed:chunk_end]
114 |         _ = model(inputs_chunk[None], cache=c)
115 |         processed = chunk_end
116 | 
117 |     return c, total_prompt_len
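
# Example (illustrative, not part of this module): prefilling a KV cache from a chat
# prompt. This sketch assumes the cache should hold every prompt token except the
# last one, since the generation helpers below feed only prompt[-1] themselves.
messages = [{"role": "user", "content": "Explain KV caching in one sentence."}]
prompt_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
kv_cache, cached_len = prompt_to_cache(model, tokenizer, prompt_ids[:-1])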
118 | 
119 | 
120 | def _prefill_cache(
121 |     model: nn.Module,
122 |     prompt_ids: mx.array,
123 |     prompt_cache: List[cache.KVCache],
124 |     generation_stream: mx.Stream,
125 |     prefill_step_size: int,
126 | ) -> tuple[int, float]:
127 |     """
128 |     Prefill the prompt cache by running the prompt through the model in chunks.
129 | 
130 |     Returns:
131 |         total_prompt_len: number of tokens in the prompt.
132 |         prompt_tps: prompt tokens per second (for logging purposes).
133 |     """
134 |     total_prompt_len = prompt_ids.shape[0]
135 |     processed = 0
136 |     with wired_limit(model, [generation_stream]):
137 |         tic = time.perf_counter()
138 |         while processed < total_prompt_len:
139 |             chunk_end = min(processed + prefill_step_size, total_prompt_len)
140 |             inputs_chunk = prompt_ids[processed:chunk_end]
141 |             with mx.stream(generation_stream):
142 |                 _ = model(inputs_chunk[None], cache=prompt_cache)
143 |             mx.eval([c.state for c in prompt_cache])
144 |             processed = chunk_end
145 |             mx.metal.clear_cache()
146 |         prompt_time = time.perf_counter() - tic
147 |         prompt_tps = total_prompt_len / prompt_time if prompt_time > 0 else 0.0  # Single prompt sequence.
148 |     return total_prompt_len, prompt_tps
149 | 
150 | 
151 | def generate_batched(
152 |     model: nn.Module,
153 |     tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
154 |     prompt: List[int],
155 |     batch_size: int,
156 |     *,
157 |     prompt_cache: Optional[List[cache.KVCache]] = None,
158 |     verbose: bool = False,
159 |     max_tokens: int = 256,
160 |     temp: float = 0.0,
161 |     top_p: float = 0.0,
162 |     min_p: float = 0.0,
163 |     min_tokens_to_keep: int = 1,
164 |     repetition_penalty: float = 1.0,  # Note: repetition penalty settings are accepted but not applied in this function.
165 |     repetition_context_size: int = 20,
166 |     prefill_step_size: int = 512,
167 |     stop_strings: List[str] = [],
168 |     **kwargs,
169 | ) -> Tuple[
170 |     List[List[int]],      # Generated token sequences
171 |     List[cache.KVCache],  # Generation cache
172 |     int,                  # total_prompt_len
173 |     List[int],            # lengths of the generated token sequences
174 |     List[bool]            # ended flags
175 | ]:
176 |     """
177 |     Generate multiple responses in parallel from the same prompt, returning token sequences.
178 | 
179 |     Returns:
180 |         - generated_token_seqs: list of token sequences (one per batch element).
181 |         - cache_gen: the per-token generation cache (list of KVCache).
182 |         - total_prompt_len: number of prompt tokens used.
183 |         - generated_lengths: lengths of the token sequences generated for each batch element.
184 |         - ended: boolean flags indicating whether each sequence ended via an EOS token.
185 |     """
186 |     if not isinstance(tokenizer, TokenizerWrapper):
187 |         tokenizer = TokenizerWrapper(tokenizer)
188 | 
189 |     # Only the last prompt token is fed to the model; earlier prompt tokens are expected to already live in `prompt_cache` (non-traditional attention pattern).
190 |     prompt_ids = mx.array([prompt[-1]])
191 | 
192 |     ended = [False] * batch_size
193 |     ended_due_to_eos = [False] * batch_size
194 |     sampler = make_sampler(temp, top_p, min_p, min_tokens_to_keep)
195 | 
196 |     generation_stream = mx.new_stream(mx.default_device())
197 |     if prompt_cache is None:
198 |         prompt_cache = cache.make_prompt_cache(model)
199 |         with wired_limit(model, [generation_stream]):
200 |             total_prompt_len, prompt_tps = _prefill_cache(
201 |                 model, prompt_ids, prompt_cache, generation_stream, prefill_step_size
202 |             )
203 |     else:
204 |         total_prompt_len = prompt_cache[0].offset
205 | 
206 |     # Create a separate generation cache for newly generated tokens.
207 |     cache_gen = cache.make_prompt_cache(model)
208 |     cache_step = max(2 ** math.floor(math.log2(256 / batch_size)), 4)
209 |     for c in cache_gen:
210 |         c.step = cache_step
211 | 
212 |     # The initial (last) token is repeated `batch_size` times
213 |     y = mx.repeat(prompt_ids[-1:][None, :], repeats=batch_size, axis=0)
214 | 
215 |     # We accumulate generated tokens for each batch element
216 |     tokens_so_far = [[] for _ in range(batch_size)]
217 |     stop_string_tok_lengths = [len(tokenizer.encode(s)) for s in stop_strings]
218 |     tic = time.perf_counter()
219 |     n = 0
220 |     has_stop_strings = len(stop_strings) > 0
221 |     with wired_limit(model, [generation_stream]):
222 |         while n < max_tokens:
223 |             logits = model(y, cache=prompt_cache, cache_gen=cache_gen)
224 |             logits = logits[:, -1, :]  # only the logits for the new position
225 |             mx.async_eval(logits)
226 | 
227 |             # Compute probabilities and sample new tokens
228 |             logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
229 |             sampled_tokens = sampler(logprobs).tolist()
230 | 
231 |             next_tokens_list = []
232 |             for i in range(batch_size):
233 |                 if ended[i]:
234 |                     # If sequence already ended, repeat the previous token
235 |                     next_tokens_list.append(y[i, 0])
236 |                     continue
237 | 
238 |                 token = sampled_tokens[i]
239 |                 if token in tokenizer.eos_token_ids:
240 |                     ended[i] = True
241 |                     ended_due_to_eos[i] = True
242 |                 # Check for stop strings
243 |                 if has_stop_strings:
244 |                     for j, stop_string in enumerate(stop_strings):
245 |                         max_length_to_check = 2 * stop_string_tok_lengths[j]
246 |                         last_toks = tokens_so_far[i][-max_length_to_check:]
247 |                         last_toks_str = tokenizer.decode(last_toks)
248 |                         if stop_string in last_toks_str:
249 |                             ended[i] = True
250 |                 next_tokens_list.append(token)
251 |                 if not ended[i]:
252 |                     tokens_so_far[i].append(token)
253 | 
254 |             y = mx.array(next_tokens_list).reshape(batch_size, 1)
255 |             n += 1
256 |             if n % cache_step == 0:
257 |                 mx.metal.clear_cache()
258 |             if all(ended):
259 |                 break
260 | 
261 | 
262 |     generation_time = time.perf_counter() - tic
263 |     total_generated_tokens = sum(len(seq) for seq in tokens_so_far)
264 | 
265 |     if verbose:
266 |         for i, seq in enumerate(tokens_so_far):
267 |             print("=" * 10)
268 |             print(f"Batch {i} tokens:", seq)
269 |             print("Decoded text:", tokenizer.decode(seq))
270 |             print("=" * 10)
271 |         if not tokens_so_far:
272 |             print("No tokens generated for this prompt.")
273 |         else:
274 |             print(
275 |                 f"Prompt tokens (per sequence): {total_prompt_len}"
276 |             )
277 |             print(
278 |                 f"Generation tokens (max per sequence): {n}, "
279 |                 f"Overall generation TPS: "
280 |                 f"{(total_generated_tokens)/(generation_time+1e-9):.3f}"
281 |             )
282 |             peak_mem = mx.metal.get_peak_memory() / 1e9
283 |             print(f"Peak memory: {peak_mem:.3f} GB")
284 | 
285 |     mx.metal.clear_cache()
286 | 
287 |     # Return the token sequences instead of decoded strings
288 |     return (
289 |         tokens_so_far,
290 |         cache_gen,
291 |         total_prompt_len,
292 |         [len(seq) for seq in tokens_so_far],
293 |         ended_due_to_eos,
294 |     )
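
# Example (illustrative, not part of this module): sampling four completions of the
# same prompt in parallel, reusing `model`, `tokenizer`, `prompt_ids`, and `kv_cache`
# from the sketches above. Parameter values are arbitrary.
token_seqs, gen_cache, prompt_len, lengths, ended = generate_batched(
    model,
    tokenizer,
    prompt_ids,
    batch_size=4,
    prompt_cache=kv_cache,
    max_tokens=128,
    temp=0.8,
)
for seq in token_seqs:
    print(tokenizer.decode(seq))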
295 | def generate_batched_stream(
296 |     model: nn.Module,
297 |     tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
298 |     prompt: List[int],
299 |     batch_size: int,
300 |     *,
301 |     prompt_cache: Optional[List[cache.KVCache]] = None,
302 |     verbose: bool = False,
303 |     max_tokens: int = 256,
304 |     temp: float = 0.0,
305 |     top_p: float = 0.0,
306 |     min_p: float = 0.0,
307 |     min_tokens_to_keep: int = 1,
308 |     repetition_penalty: float = 1.0,  # Note: repetition penalty settings are accepted but not applied in this function.
309 |     repetition_context_size: int = 20,
310 |     prefill_step_size: int = 512,
311 |     stop_strings: List[str] = [],
312 |     **kwargs,
313 | ):
314 |     """
315 |     Generate multiple responses in parallel from the same prompt, yielding updates as tokens are generated.
316 |     Each update carries token sequences (rather than decoded strings) and also includes the decoded text.
317 | 
318 |     Supports stop strings: if any decoded text contains one of the stop strings (within a window of
319 |     twice the stop string token length), that sequence is ended.
320 |     """
321 |     if not isinstance(tokenizer, TokenizerWrapper):
322 |         tokenizer = TokenizerWrapper(tokenizer)
323 | 
324 |     # Use the last token of the prompt as the starting token.
325 |     prompt_ids = mx.array([prompt[-1]])
326 |     ended = [False] * batch_size
327 |     ended_due_to_eos = [False] * batch_size
328 |     sampler = make_sampler(temp, top_p, min_p, min_tokens_to_keep)
329 | 
330 |     generation_stream = mx.new_stream(mx.default_device())
331 |     if prompt_cache is None:
332 |         prompt_cache = cache.make_prompt_cache(model)
333 |         with wired_limit(model, [generation_stream]):
334 |             total_prompt_len, _ = _prefill_cache(
335 |                 model, prompt_ids, prompt_cache, generation_stream, prefill_step_size
336 |             )
337 |     else:
338 |         total_prompt_len = prompt_cache[0].offset
339 | 
340 |     # Create a separate generation cache for new tokens.
341 |     cache_gen = cache.make_prompt_cache(model)
342 |     cache_step = max(2 ** math.floor(math.log2(256 / batch_size)), 4)
343 |     for c in cache_gen:
344 |         c.step = cache_step
345 | 
346 |     # Repeat the starting token for each batch element.
347 |     y = mx.repeat(prompt_ids[-1:][None, :], repeats=batch_size, axis=0)
348 |     tokens_so_far = [[] for _ in range(batch_size)]
349 | 
350 |     # Pre-compute token lengths for each stop string.
351 |     stop_string_tok_lengths = [len(tokenizer.encode(s)) for s in stop_strings]
352 |     has_stop_strings = len(stop_strings) > 0
353 | 
354 |     n = 0
355 |     with wired_limit(model, [generation_stream]):
356 |         while n < max_tokens:
357 |             logits = model(y, cache=prompt_cache, cache_gen=cache_gen)
358 |             logits = logits[:, -1, :]  # Only the logits for the new position.
359 |             mx.async_eval(logits)
360 |             sampled_tokens = sampler(logits - mx.logsumexp(logits, axis=-1, keepdims=True)).tolist()
361 | 
362 |             new_tokens = []
363 |             for i in range(batch_size):
364 |                 if ended[i]:
365 |                     # If this sequence has already ended, repeat its last token.
366 |                     new_tokens.append(y[i, 0])
367 |                     continue
368 | 
369 |                 token = sampled_tokens[i]
370 |                 if token in tokenizer.eos_token_ids:
371 |                     ended[i] = True
372 |                     ended_due_to_eos[i] = True
373 | 
374 |                 # Check if any stop string is found in a window of the generated tokens.
375 |                 if has_stop_strings:
376 |                     for j, stop_string in enumerate(stop_strings):
377 |                         max_length_to_check = 2 * stop_string_tok_lengths[j]
378 |                         # Consider the last tokens generated so far.
379 |                         last_toks = tokens_so_far[i][-max_length_to_check:]
380 |                         last_toks_str = tokenizer.decode(last_toks)
381 |                         if stop_string in last_toks_str:
382 |                             ended[i] = True
383 |                 new_tokens.append(token)
384 |                 if not ended[i]:
385 |                     tokens_so_far[i].append(token)
386 | 
387 |             y = mx.array(new_tokens).reshape(batch_size, 1)
388 |             n += 1
389 |             if n % cache_step == 0:
390 |                 mx.metal.clear_cache()
391 |             # Yield partial update with tokens and decoded text.
392 |             yield {
393 |                 "type": "update",
394 |                 "tokens": [ts[:] for ts in tokens_so_far],  # copy current tokens
395 |                 "decoded_texts": [tokenizer.decode(ts) for ts in tokens_so_far],
396 |                 "ended": ended[:],
397 |             }
398 |             if all(ended):
399 |                 break
400 | 
401 |     mx.metal.clear_cache()
402 |     # Final yield with all tokens and additional generation metadata.
403 |     yield {
404 |         "type": "final",
405 |         "tokens": tokens_so_far,
406 |         "decoded_texts": [tokenizer.decode(seq) for seq in tokens_so_far],
407 |         "generated_lengths": [len(seq) for seq in tokens_so_far],
408 |         "total_prompt_len": total_prompt_len,
409 |         "ended": ended_due_to_eos,
410 |         "prompt_cache": cache_gen,  # generation cache holding the newly generated tokens
411 |     }
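
# Example (illustrative, not part of this module): consuming the streaming generator,
# again reusing objects from the sketches above. Intermediate "update" items carry
# partial tokens; only the final decoded texts are printed here.
for update in generate_batched_stream(
    model, tokenizer, prompt_ids, 2,
    prompt_cache=kv_cache, max_tokens=64, temp=0.7,
):
    if update["type"] == "final":
        for text in update["decoded_texts"]:
            print(text)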
412 | 
--------------------------------------------------------------------------------