├── README.md ├── examples ├── get_24.py └── mayonnaise.py └── tree_of_thoughts.py /README.md: -------------------------------------------------------------------------------- 1 | # Tree of thoughts 2 | An LMQL implementation of something like tree of thoughts. Applies a natural selection process to steer reasoning and constrain the results. 3 | 4 | Many many improvements to be made. 5 | 6 | ## Features 7 | I meant this to be as "engineerable" as possible. Each tree instance is configured to solve a specific problem, and be used as a function. It can apply a callback function to the result so it doesn't necessarily have to return text. 8 | 9 | Some of the main features: 10 | - Asynchronous 11 | - Configurable 12 | - Prompt-based and programmatic result validation 13 | 14 | Some planned features: 15 | - Multiple arguments and argument types 16 | - Feature weighting: option to assign relative importance to selection criteria 17 | - Dynamic width: method for determining how many branches should stem from each thought 18 | 19 | ## How it works 20 | Each iteration consists of a selection phase, a review phase, a generation phase, and an evaluation phase. 21 | - **Selection:** The top-k scoring lines of thought are selected 22 | - **Review:** Selected lines of thought are checked to see if they contain an answer. 23 | - **Generation:** A fixed number of branching thoughts are generated from selected leaf thoughts. If a selected leaf contains an answer, a conclusion is generated instead. 24 | - **Evaluation:** New thoughts are scored against defined criteria to determine the relative strength of the threads. If any conclusions were generated, they are validated and returned if they pass. 25 | 26 | ## Usage 27 | For now see the `examples` folder to get a sense of it. 
In a nutshell there's three configurations: one for the initial prompt, one that governs the reasoning dynamics (evaluation, answer recognition), and one that describes how answer attempts are handled (conclusion generation, callbacks, validation). 28 | 29 | -------------------------------------------------------------------------------- /examples/get_24.py: -------------------------------------------------------------------------------- 1 | from tree_of_thoughts import TreeOfThoughts 2 | 3 | tree_config = { 4 | "initial": { # sandwiches the argument passed to self.reason 5 | "prefix": "Question: use 4 numbers and basic arithmetic operations (+-*/) to obtain ", 6 | "suffix": ". Only choose one number each step.\nAnswer: Let's think step by step.", 7 | }, 8 | "reasoning": { 9 | "graded": { 10 | "prefix": "Please assess the following reasoning, and choose an option for each point:\n```\n", 11 | "suffix": "\n```\n\n", 12 | "items": [ 13 | "The reasoning is reliable and repeatable: ", 14 | "We are getting closer to the answer: " 15 | "It's impossible that there is a mistake: ", 16 | "The reasoning is clear and easy to follow: ", 17 | ], 18 | }, 19 | # both vital and fatal are applied to the reasoning after each new thought is generated 20 | "vital": { 21 | "prefix": "Please assess the following reasoning, and choose an option for each point:\n```\n", 22 | "suffix": "\n```\n", 23 | "items": [ # if any of these questions are answered "no" the new leaf dies 24 | "There is not a single math mistake in the reasoning: ", 25 | ] 26 | }, 27 | "fatal": { # if any of these questions are answered "yes" the new leaf dies 28 | "prefix": "Please assess the following reasoning, and choose an option for each point.\n```\n", 29 | "suffix": "\n```\n", 30 | "items": [ 31 | "There is a math mistake in the reasoning: ", 32 | ] 33 | }, 34 | "stopping": { # Applied at the start of each iteration to flag potential answers 35 | "prefix": "Has the following reasoning achieved a correct and 
satisfying answer to the initial question?\n```\n", 36 | "suffix": "\n```\n\nAnswer: ", 37 | }, 38 | }, 39 | "answer": { 40 | "callback_prompt": { # Applied to the reasoning leading up to an answer 41 | "suffix": "In conclusion, using (+,-,x,/) and obey PEDMAS, in one expression it is written as: ", 42 | }, 43 | "callback_fn": lambda x: x.replace(".", "").strip(), # Applied to the result of whatever follows from the callback prompt 44 | "validation" : { # validations can be a yes no question and expected answer, or any bool returning function 45 | "prefix": "Please answer the following questions about the expression `", 46 | "suffix": "`. ", 47 | "items": [ # TODO: allow argument usage in all prompts and not just validation 48 | ("Are four numbers used to obtain $arg?", True), # $arg substitutes the reasoning argument into the validation prompt 49 | ("Does the expression really equal $arg?", True) 50 | ] 51 | }, 52 | }, 53 | } 54 | 55 | number_maker = TreeOfThoughts(**tree_config) 56 | 57 | # The top n_active_leafs scoring leafs are selected on each iteration 58 | # n_branches new thoughts breanch from them if they are not answers 59 | answers = number_maker.reason("24", n_active_leaves=2, n_branches=3, verbose=True) 60 | 61 | print() 62 | print("FINAL ANSWERS") 63 | print(answers) 64 | -------------------------------------------------------------------------------- /examples/mayonnaise.py: -------------------------------------------------------------------------------- 1 | from tree_of_thoughts import TreeOfThoughts 2 | 3 | tree_config = { 4 | "initial": { 5 | "prefix": "Iterate through the word '", 6 | "suffix": "' letter by letter, and note the index of each letter.", 7 | }, 8 | "reasoning": { 9 | "graded": { 10 | "items": [ 11 | "On a scale of 1-9 the instructions are being respected so far: " 12 | ], 13 | }, 14 | }, 15 | "answer": { 16 | "callback_prompt": { 17 | "suffix": "\nThe total number of of 'n' occurences is: ", 18 | }, 19 | "callback_fn": lambda x: 
x.replace(".", "").strip(), 20 | "validation" : { 21 | "items": [ 22 | lambda s: s.isdigit() 23 | ] 24 | }, 25 | }, 26 | } 27 | 28 | letter_n_counter = TreeOfThoughts(**tree_config, max_iterations=15) 29 | 30 | answers = letter_n_counter.reason("mayonnaise", n_active_leaves=1, n_branches=1, verbose=True) 31 | 32 | print() 33 | print("FINAL ANSWERS") 34 | print(answers) 35 | -------------------------------------------------------------------------------- /tree_of_thoughts.py: -------------------------------------------------------------------------------- 1 | import lmql 2 | import asyncio 3 | from collections import namedtuple 4 | 5 | color= { 6 | "black": lambda text: f"\033[30m{text}\033[0m", 7 | "red": lambda text: f"\033[31m{text}\033[0m", 8 | "green": lambda text: f"\033[32m{text}\033[0m", 9 | "yellow": lambda text: f"\033[33m{text}\033[0m", 10 | "blue": lambda text: f"\033[34m{text}\033[0m", 11 | "magenta": lambda text: f"\033[35m{text}\033[0m", 12 | "cyan": lambda text: f"\033[36m{text}\033[0m", 13 | "white": lambda text: f"\033[37m{text}\033[0m", 14 | } 15 | 16 | # TODO: debug output writer 17 | 18 | PromptSandwich = namedtuple("PromptSandwich", ["prefix", "suffix", "items"]) 19 | ReasoningPrompt = namedtuple("ReasoningPrompt", ["graded", "vital", "fatal", "stopping"]) 20 | AnswerPrompt = namedtuple("AnswerPrompt", ["callback_prompt", "callback_fn", "validation"]) 21 | 22 | def create_prompt_sandwich(data): 23 | prefix = data.get("prefix", "") 24 | suffix = data.get("suffix", "") 25 | items = data.get("items", []) 26 | return PromptSandwich(prefix=prefix, suffix=suffix, items=items) 27 | 28 | def create_prompt_reasoning(data): 29 | graded = create_prompt_sandwich(data.get("graded", {})) 30 | vital = create_prompt_sandwich(data.get("vital", {})) 31 | fatal = create_prompt_sandwich(data.get("fatal", {})) 32 | stopping = create_prompt_sandwich(data.get("stopping", {})) 33 | return ReasoningPrompt(graded=graded, vital=vital, fatal=fatal, stopping=stopping) 34 | 
def create_prompt_answer(data):
    """Build an AnswerPrompt from a config dict; absent sections become empty prompts."""
    callback_prompt = create_prompt_sandwich(data.get("callback_prompt", {}))
    callback_fn = data.get("callback_fn", None)
    validation = create_prompt_sandwich(data.get("validation", {}))
    return AnswerPrompt(callback_prompt=callback_prompt, callback_fn=callback_fn, validation=validation)

class Node:
    """A single thought in the reasoning tree."""

    # FIX: `value` was annotated `int | float`, but every caller (Tree.push,
    # Tree.add_root) stores a prompt/thought string here.
    def __init__(self, id: int, value: str, parent_id: int | None):
        self.id = id
        self.value = value
        self.parent_id = parent_id

class Tree:
    """Tree of thoughts. `stack` maps live (selectable) leaf ids to scores."""

    def __init__(self):
        self.nodes = {}    # id -> Node
        self.stack = {}    # leaf id -> score, for leaves still eligible for selection
        self.answers = []  # (leaf_id, root_id) pairs flagged as validated answers
        self.id_counter = 0

    def push(self, value: str, score: int | float, parent: Node):
        """Add a child thought under `parent` and register it as a live leaf."""
        # nodes have unique names, still determined by counter
        self.id_counter += 1

        if parent.id not in self.nodes:
            raise ValueError(f"Parent node {parent.value} ({parent.id}) not in tree")

        self.nodes[self.id_counter] = Node(self.id_counter, value, parent.id)
        self.stack[self.id_counter] = score

    def add_root(self, value: str) -> Node:
        """Add a parentless root node (roots are not pushed onto the leaf stack)."""
        self.id_counter += 1
        root_node = Node(self.id_counter, value, None)
        self.nodes[self.id_counter] = root_node

        return root_node

    def mark_as_answer(self, id: int, root_id: int):
        self.answers.append((id, root_id))

    def leaves_pop_top(self, n: int) -> list[int]:
        """Pop and return the ids of (up to) the n top-scoring live leaves.

        FIX: the previous loop ignored scores entirely and simply took the n
        most recently created leaves; the README ("top-k scoring lines of
        thought are selected") and the score-sorted code that had been left
        commented out here show the intended behavior. Ties break toward the
        newer leaf.
        """
        selected_leaf_ids = sorted(self.stack, key=lambda leaf_id: (self.stack[leaf_id], leaf_id), reverse=True)[:n]

        for leaf_id in selected_leaf_ids:
            # roots never enter the stack, so this guard is purely defensive
            if self.nodes[leaf_id].parent_id is not None:
                del self.stack[leaf_id]

        return selected_leaf_ids

    def get_path(self, id: int) -> tuple[Node, str, dict]:
        """Return (leaf node, newline-joined root-to-parent path, fresh attrs dict).

        The leaf's own value is excluded from the path string; callers
        re-append it as `path + "\\n" + leaf.value`.
        """
        if id not in self.nodes:
            raise ValueError(f"Node {id} not in tree")

        node_values = []
        leaf_node = self.nodes[id]
        current_node = self.nodes[id]

        while current_node is not None:
            node_values.append(current_node.value)
            if current_node.parent_id is not None:
                current_node = self.nodes[current_node.parent_id]
            else:
                current_node = None

        reasoning_path = node_values[1:]  # drop the leaf itself
        reasoning_path = "\n".join(reversed(reasoning_path))

        return leaf_node, reasoning_path, {}

    def paths_pop_top(self, n) -> list[tuple[Node, str, dict]]:
        """Pop the n best leaves and return their (leaf, path, attrs) triples."""
        selected_leaf_ids = self.leaves_pop_top(n)
        return [self.get_path(id) for id in selected_leaf_ids]

class TreeOfThoughts:
    """A configured tree-of-thoughts reasoner; call `reason(argument, ...)` to run it."""

    initial: PromptSandwich
    reasoning: ReasoningPrompt
    answer: AnswerPrompt

    def __init__(self, initial, reasoning, answer, max_iterations=10):

        self.initial = create_prompt_sandwich(initial)
        self.reasoning = create_prompt_reasoning(reasoning)
        self.answer = create_prompt_answer(answer)

        self.max_iterations = max_iterations

        self.penalties = []  # TODO: investigate if these are useful, would add into evaluations
        self.bonuses = []

        # self.params = {criteria: (1, 0) for criteria in self.graded_criteria} # TODO: use these in self.process_rating

        self.tree = Tree()

        # TODO: memory, error propagation

        self.verbose_buffer = ""

    def reason(self, argument, n_active_leaves, n_branches, verbose=False):
        """Synchronous entry point; returns a list of validated answers (possibly empty)."""
        return asyncio.run(self.async_reason(argument, n_active_leaves, n_branches, verbose))

    def print_verbose(self):
        # clear screen, then dump the accumulated progress log
        print("\033c", end="")
        print(self.verbose_buffer)

    async def async_reason(self, argument, n_active_leaves, n_branches, verbose=False):
        """Run select/review/generate/evaluate iterations until answers pass validation
        or max_iterations is exhausted."""
        self.verbose_buffer = ""
        self.argument = argument

        root_value = self.initial.prefix + argument + self.initial.suffix
        root = self.tree.add_root(root_value)

        if verbose:
            self.verbose_buffer += color['cyan']("ROOT ------------------------------------------------------\n")
            self.verbose_buffer += root_value + "\n\n"
            self.print_verbose()

        current = 1

        while current <= self.max_iterations:

            if verbose:
                self.verbose_buffer += color['green'](f"ITERATION {current}\n")
                self.verbose_buffer += color['cyan']("CHECKING FOR ANSWERABLE THOUGHTS --------------------------\n")
                self.print_verbose()

            # get (node, path_string, attrs) triples, and default to the root if all leaves die
            selected_leaves = self.tree.paths_pop_top(n_active_leaves)

            if selected_leaves:
                can_answer = await asyncio.gather(*[self.is_finished(path + "\n" + thought.value) for thought, path, attrs in selected_leaves])
                can_answer = [x[0] for x in can_answer]
                for i, is_answerable in enumerate(can_answer):
                    selected_leaves[i][2]["precedes_answer"] = is_answerable
            else:
                selected_leaves = [(root, "", {"precedes_answer": False})]

            if verbose:
                tally = sum(1 for _, _, meta in selected_leaves if meta.get('precedes_answer', False))
                self.verbose_buffer += f" {tally} selected leaves are potential answers\n\n"
                self.verbose_buffer += color['cyan']("GENERATING NEXT THOUGHTS ----------------------------------\n")
                self.print_verbose()

            if verbose:
                self.verbose_buffer += color['cyan']("\n------------------------------\n").join([reasoning_path + "\n" + color['blue'](leaf_node.value) + "\n" for leaf_node, reasoning_path, attrs in selected_leaves]) + "\n"
                self.print_verbose()

            # answer-bearing leaves get one conclusion; others branch n_branches ways
            next_thoughts_list = []
            for leaf_thought, reasoning_path, attrs in selected_leaves:
                if attrs["precedes_answer"]:
                    next_thoughts_list.append(self.final_result(reasoning_path + "\n" + leaf_thought.value))
                else:
                    next_thoughts_list.append(self.get_next_thoughts(n_branches, reasoning_path + "\n" + leaf_thought.value))

            next_thoughts_list = await asyncio.gather(*next_thoughts_list)
            # normalize: lmql results arrive wrapped; conclusions stay single-element
            next_thoughts_list = [x if isinstance(x[0], str) else [y[0] for y in x] for x in next_thoughts_list]

            if verbose:
                tally = sum(len(x) for x in next_thoughts_list)
                self.verbose_buffer += f" {tally} new thoughts from here\n\n"
                self.verbose_buffer += color['cyan']("ASSESSING THOUGHT PATHS -----------------------------------\n")
                self.print_verbose()

            thought_scores_list = []
            for leaf_thought, next_thoughts in zip(selected_leaves, next_thoughts_list):
                leaf, reasoning_path, attrs = leaf_thought
                if attrs["precedes_answer"]:
                    thought_scores_list.append(self.validate_result(next_thoughts[0]))  # attempted answers only have one branch
                else:
                    thought_scores_list.append(asyncio.gather(*[self.evaluate_reasoning(reasoning_path + "\n" + leaf.value + "\n" + next_thought) for next_thought in next_thoughts]))

            thought_scores_list = await asyncio.gather(*thought_scores_list)
            thought_scores_list = [x if isinstance(x, list) else [x] for x in thought_scores_list]

            if verbose:
                n_thoughts = sum(len(x) for x in thought_scores_list)
                n_true = sum(1 for x in thought_scores_list for num in x if num > 0)
                self.verbose_buffer += f" {n_true}/{n_thoughts} of the new thoughts are viable\n"
                self.print_verbose()

            answers = []
            for leaf_thought, next_thoughts, next_thought_ratings in zip(selected_leaves, next_thoughts_list, thought_scores_list):
                leaf, reasoning_path, attrs = leaf_thought
                for next_thought, rating in sorted(zip(next_thoughts, next_thought_ratings), key=lambda x: x[1], reverse=True):
                    if rating > 0:
                        self.tree.push(next_thought, score=rating, parent=leaf)

                if attrs["precedes_answer"] and next_thought_ratings[0] > 0:
                    answers.append(next_thoughts[0])
                    self.tree.mark_as_answer(leaf.id, root.id)

            if verbose:
                self.verbose_buffer += f" {len(answers)} answers passing validation\n\n"
                self.print_verbose()

            current += 1

            if answers:
                return answers

        if verbose:
            self.verbose_buffer += color['cyan']("NO ANSWERS FOUND IN MAX STEPS -----------------------------\n\n")
            self.print_verbose()

        return []

    # Generates the conclusion for a finished line of reasoning and applies
    # the configured callback to it.
    @lmql.query
    async def final_result(self, reasoning):
        '''lmql
        sample()
            "{self.answer.callback_prompt.prefix}"
            "{reasoning}"
            "{self.answer.callback_prompt.suffix}"
            "[result]"
            if self.answer.callback_fn:
                return self.answer.callback_fn(result)
            return result
        from
            "openai/gpt-3.5-turbo"
        '''

    async def validate_result(self, result):
        """Score an attempted answer: 1 if every configured validation passes, else 0."""
        if self.answer.validation.items:
            loop = asyncio.get_event_loop()
            answer_validations = []
            for validation in self.answer.validation.items:
                if isinstance(validation, tuple):
                    # (yes/no question, expected answer) -> prompt-based check
                    answer_validations.append(self.prompt_validate(result, validation[0], validation[1]))
                else:
                    # plain callable -> run off the event loop
                    answer_validations.append(loop.run_in_executor(None, validation, result))

            answer_validations = await asyncio.gather(*answer_validations)
            answer_validations = [x[0] if isinstance(x, list) else x for x in answer_validations]

            if not all(answer_validations):
                return 0  # below survival threshold

        return 1  # above survival threshold

    # Asks a yes/no question about the result; passes when the model's answer
    # matches `should_be`. "$arg" in the question is replaced by the reasoning argument.
    @lmql.query
    async def prompt_validate(self, result, validation, should_be):
        """lmql
        argmax
            "( yes/no )\n"
            "{self.answer.validation.prefix}"
            "{result}"
            "{self.answer.validation.suffix}"
            parsed_validation = validation.replace('$arg', self.argument)
            "{parsed_validation}"
            "[yn]"
            if yn.split()[-1] in ["yes", "Yes"]:
                answer = True
            else:
                answer = False

            return answer == should_be
        from
            "openai/gpt-3.5-turbo"
        where
            STOPS_AT(yn, "yes") and
            STOPS_AT(yn, "no") and
            STOPS_AT(yn, "Yes") and
            STOPS_AT(yn, "No") and
            len(TOKENS(yn)) < 20
        """

    async def get_next_thoughts(self, n, reasoning):
        """Sample n independent continuations of `reasoning` concurrently."""
        thoughts = [self.get_next_thought(reasoning) for _ in range(n)]
        return await asyncio.gather(*thoughts)

    # TODO: add continuation prompt (e.g. This next step is very important, so I am paying very close attention...)
    @lmql.query
    async def get_next_thought(self, reasoning):
        '''lmql
        sample()
            "{reasoning}\n"
            "[thought]"
            return thought
        from
            "openai/gpt-3.5-turbo"
        where
            STOPS_BEFORE(thought, "\\n") and
            STOPS_BEFORE(thought, "\n")
        '''

    # Flags reasoning that looks ready to produce a final answer.
    @lmql.query
    async def is_finished(self, reasoning):
        '''lmql
        argmax
            "(yes/no)\n"
            "{self.reasoning.stopping.prefix}"
            "{reasoning}"
            "{self.reasoning.stopping.suffix}"
            "[yn]"
            if yn.split()[-1] in ["yes", "Yes"]:
                return True
            return False
        from
            "openai/gpt-3.5-turbo"
        where
            STOPS_AT(yn, "yes") and
            STOPS_AT(yn, "no") and
            STOPS_AT(yn, "Yes") and
            STOPS_AT(yn, "No") and
            len(TOKENS(yn)) < 20
        '''

    # TODO: programmatic constraints and evaluations
    # TODO: explore metaprompting for rating criteria
    async def evaluate_reasoning(self, reasoning):
        """Score a candidate line of reasoning.

        Returns 0 (leaf dies) if any fatal check answers "yes" or any vital
        check answers "no"; otherwise returns the sum of the graded criteria
        (each in -4..4, centered on 5).
        """
        thought_validations = [self.validate_thought(self.reasoning.fatal.prefix, self.reasoning.fatal.suffix, statement, reasoning, should_be=False) for statement in self.reasoning.fatal.items]
        thought_validations += [self.validate_thought(self.reasoning.vital.prefix, self.reasoning.vital.suffix, statement, reasoning, should_be=True) for statement in self.reasoning.vital.items]
        thought_validations = await asyncio.gather(*thought_validations)
        thought_validations = [x[0] for x in thought_validations]
        if not all(thought_validations):
            return 0

        evaluations = [self.grade(statement, reasoning) for statement in self.reasoning.graded.items]
        evaluations = await asyncio.gather(*evaluations)
        evaluations = [x[0] for x in evaluations]
        return sum(evaluations)

    # Yes/no gate for vital (should_be=True) and fatal (should_be=False) criteria.
    @lmql.query
    async def validate_thought(self, prefix, suffix, statement, reasoning, should_be=True):
        '''lmql
        argmax
            default = "yes" if should_be else "no"
            "( Answer yes/no. If not applicable, default to {default}. )\n"
            "{prefix}"
            "{reasoning}"
            "{suffix}"
            "{statement}: [yn]"
            if yn.split()[-1] in ["yes", "Yes"]:
                answer = True
            else:
                answer = False

            return answer == should_be
        from
            "openai/gpt-3.5-turbo"
        where
            STOPS_AT(yn, "yes") and
            STOPS_AT(yn, "no") and
            STOPS_AT(yn, "Yes") and
            STOPS_AT(yn, "No") and
            len(TOKENS(yn)) < 10
        '''

    # TODO: replace ridiculous list of stops_at constraints if "in" constraints are supported for chat
    # Rates one graded criterion on a 1-9 scale, recentered to -4..4 (0 if unparseable).
    @lmql.query
    async def grade(self, statement, reasoning):
        '''lmql
        argmax
            "( rate each point from 1 - 9 where 5 is neutral. If N/A choose 5. )\n"
            "{self.reasoning.graded.prefix}"
            "{reasoning}"
            "{self.reasoning.graded.suffix}"
            "{statement}: [rating]"
            if rating[-1] in ["1", "2", "3", "4", "5", "6", "7", "8", "9"]:
                rating = int(rating[-1])
                rating = rating - 5
            else:
                rating = 0 # no information if improperly answered
            return rating
        from
            "openai/gpt-3.5-turbo"
        where
            STOPS_AT(rating, "1") and
            STOPS_AT(rating, "2") and
            STOPS_AT(rating, "3") and
            STOPS_AT(rating, "4") and
            STOPS_AT(rating, "5") and
            STOPS_AT(rating, "6") and
            STOPS_AT(rating, "7") and
            STOPS_AT(rating, "8") and
            STOPS_AT(rating, "9") and
            len(TOKENS(rating)) < 10
        '''