├── requirements.txt
├── README.md
└── main.py
/requirements.txt:
--------------------------------------------------------------------------------
openai==1.99.1
pydantic==2.11.7
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# K2-Think-Inference

## Dependencies

You will need to install the following Python packages:
```bash
pip install -r requirements.txt
```

## Configuration
You need to set the following arguments in `main.py` under `class Env`.
```python
# Endpoint configuration of planner llm
PLANNER_LLM_API_KEY: str = ''
PLANNER_LLM_BASE_URL: str = ''
PLANNER_LLM_MODEL: str = ''

# Endpoint configuration of solver llm
SOLVER_LLM_API_KEY: str = ''
SOLVER_LLM_BASE_URL: str = ''
SOLVER_LLM_MODEL: str = "K2-Think"
```
The planner LLM is responsible for extracting topics, generating plans, and comparing answer pairs. You may choose any OpenAI-compatible endpoint to act as the planner. For example, you can start a local [vLLM](https://docs.vllm.ai/en/stable/) endpoint serving any Hugging Face model. The following command starts an endpoint serving [Qwen/Qwen3-235B-A22B](https://huggingface.co/Qwen/Qwen3-235B-A22B):
```bash
vllm serve Qwen/Qwen3-235B-A22B \
    --tensor_parallel_size 8 \
    --served-model-name Qwen/Qwen3-235B-A22B \
    --port 8080
```
After the endpoint is up and running on localhost, you can set the following arguments:
```python
# Endpoint configuration of planner llm
PLANNER_LLM_API_KEY: str = ''
PLANNER_LLM_BASE_URL: str = 'http://localhost:8080/v1'
PLANNER_LLM_MODEL: str = 'Qwen/Qwen3-235B-A22B'
```
Similarly, you can choose your favorite reasoning model as the solver LLM, which is responsible for solving the input problems.

## Test the Script
```bash
python main.py
```
You can change `query` in the `if __name__ == "__main__":` block at the bottom of `main.py` to test any problem you like.
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import asyncio
import json
import logging
import time
import uuid

from openai import AsyncOpenAI
from openai.types.chat.chat_completion import ChatCompletion
from pydantic import BaseModel
from typing import List, Callable, Awaitable, Any


class Env:

    # Endpoint configuration of planner llm
    PLANNER_LLM_API_KEY: str = ''
    PLANNER_LLM_BASE_URL: str = ''
    PLANNER_LLM_MODEL: str = ''

    # Endpoint configuration of solver llm
    SOLVER_LLM_API_KEY: str = ''
    SOLVER_LLM_BASE_URL: str = ''
    SOLVER_LLM_MODEL: str = "K2-Think"

    # N in Best-of-N sampling
    N: int = 3

    # Adapted from AM-Thinking-v1: Advancing the Frontier of Reasoning at 32B Scale (Yunjie Ji et al.) https://arxiv.org/pdf/2505.08311
    SOLVER_PROMPT: str = "You are K2-Think, a helpful assistant trained by MBZUAI. To answer the user's question, you first think about the reasoning process and then provide the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>."
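    # Note: the prompt above asks the solver to wrap its reasoning in
    # <think> </think> tags and its final answer in <answer> </answer> tags;
    # select_best() below extracts answers by splitting on the <answer> tag,
    # so keep this prompt and that parsing logic in sync.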
    SOLVER_TEMPERATURE: float = 1.0


logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)
env = Env()


# Schemas for structured output
class BoNIndex(BaseModel):
    index: int  # must be 0 or 1
    explanation: str


class SearchList(BaseModel):
    is_hard_problem: bool
    plan: str
    search_list: list[str]


class K2ThinkPipeline:
    def __init__(self):
        self.solver_llm = AsyncOpenAI(
            api_key=env.SOLVER_LLM_API_KEY,
            base_url=env.SOLVER_LLM_BASE_URL,
            timeout=None
        )
        self.planner_llm = AsyncOpenAI(
            api_key=env.PLANNER_LLM_API_KEY,
            base_url=env.PLANNER_LLM_BASE_URL
        )
        self.bon_responses = {}

    async def run(self, question: str) -> ChatCompletion:
        return await self.best_of_n_sampling(question=question, n=env.N, timeout=1200)

    # We do not want to wait too long for alternate responses for Best-of-N.
    # Once `soft_timeout` seconds have passed, we collect all responses completed so far for Best-of-N selection.
    # If none have completed thus far, we keep polling and return whatever responses complete first.
    # At `hard_timeout` seconds, we throw an error.
    async def run_at_least_one(
        self,
        fn: Callable[..., Awaitable[Any]],
        args_list: List[Any] = [],
        soft_timeout: float = 9*60,
        hard_timeout: float = 120*60,
        poll_interval: float = 10
    ) -> List[Any]:
        start_time = time.monotonic()
        futures = [asyncio.ensure_future(fn(*args)) for args in args_list]
        pending = set(futures)

        is_first_iteration = True

        try:
            while pending:
                # Adjust the wait timeout based on how much time has passed
                elapsed = time.monotonic() - start_time
                if elapsed >= hard_timeout:
                    break
                if is_first_iteration:
                    is_first_iteration = False
                    timeout = soft_timeout
                    return_when = asyncio.ALL_COMPLETED
                else:
                    timeout = min(poll_interval, hard_timeout - elapsed)
                    return_when = asyncio.FIRST_COMPLETED
                done, pending = await asyncio.wait(
                    pending,
                    timeout=timeout,
                    return_when=return_when
                )
                if len(done) == 0:
                    continue
                results = []

                for fut in done:
                    try:
                        result = fut.result()
                        results.append(result)
                    except Exception as e:
                        log.error(f"Error in getting result: {e}")
                        continue  # Ignore failed tasks
                if len(results) > 0:
                    # Cancel the rest
                    try:
                        for p in pending:
                            p.cancel()
                    except Exception as e:
                        log.error(f"Error in canceling tasks: {e}")
                    return results
        finally:
            for fut in pending:
                fut.cancel()

        raise asyncio.TimeoutError(f"No task succeeded within hard timeout of {hard_timeout} seconds")

    async def select_best(self, question, completions):
        # We use a linear scan here for simplicity, but a tree-based tournament would be more performant.
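        # Extract the answer text from each completion first (falling back to the
        # tail of the raw output when no answer tag is found), then ask the planner
        # LLM to judge the running best against each remaining candidate via a
        # guided-JSON verdict (see BoNIndex).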
        answers = []
        for completion in completions:
            if completion is None:
                answers.append("No answer.")
                continue
            content = completion.choices[0].message.content
            if "<answer>" in content:
                answers.append(content.split("<answer>")[1])
            else:
                answers.append(f"No answer was found, but here was the tail end of the problem solving: {content[-2000:]}")

        best_index, best_completion, best_answer = 0, completions[0], answers[0]
        for index, completion in enumerate(completions[1:], start=1):
            response = await self.planner_llm.chat.completions.create(
                model=env.PLANNER_LLM_MODEL,
                messages=[
                    {
                        "role": "system",
                        "content": (
                            "You are a strict evaluator. Given a question and two responses, "
                            "return a JSON object with 'index' as 0 or 1 for the response "
                            "that best answers the question."
                        ),
                    },
                    {
                        "role": "user",
                        "content": f"Question: {question}\nResponse 0: {best_answer}\nResponse 1: {answers[index]}"
                    }
                ],
                extra_body={"guided_json": BoNIndex.model_json_schema()},
            )
            winner = json.loads(response.choices[0].message.content)["index"]
            if winner == 1:
                best_index = index
                best_completion = completion
                best_answer = answers[index]
        log.info(f"{best_index=}")
        return best_completion

    async def best_of_n_sampling(self, question: str, n: int = 3, timeout: float = 540) -> ChatCompletion | None:
        request_id = uuid.uuid4()
        self.bon_responses[request_id] = {
            "completions": [None] * n
        }
        args_list = [
            (request_id, bon_id, question)
            for bon_id in range(n)
        ]
        log.info(f"running {len(args_list)} tasks: {self.single_sampling.__name__}, {args_list}, {timeout}")
        try:
            results = await self.run_at_least_one(self.single_sampling, args_list, timeout)
            log.info(f"{results=}")
        except Exception as e:
            log.error(f"Error in best_of_n_sampling: {e}")
            return None

        completions = self.bon_responses[request_id]["completions"]
        best_completion = await self.select_best(question, completions)
        return best_completion

    async def single_sampling(self, request_id: uuid.UUID, bon_id: int, question: str):
        # Get a single completion
        response = await self.sampling_with_planning(question)
        self.bon_responses[request_id]["completions"][bon_id] = response
        log.info(f"Finish sampling {bon_id}.\nQuestion: {question}\nAnswer{bon_id}: {response.choices[0].message.content}")

    async def sampling_with_planning(self, question: str):

        topics_list = await self.create_topics_list(question)

        ideas = None
        if topics_list is not None:

            prompt_planning_topics: str = f'''
You are given a question and some useful topics:
{question}
{topics_list}
You need to generate a plan for solving the question based on the topics above WITHOUT disclosing the final or potential answer. DO NOT mention or give any hints of the final or potential answer in your plan. Wrap your plan inside <plan> </plan>.
'''.strip('\n')

            response = await self.planner_llm.chat.completions.create(
                model=env.PLANNER_LLM_MODEL,
                messages=[{"role": "user", "content": prompt_planning_topics}],
                extra_body={"chat_template_kwargs": {"enable_thinking": False}}
            )
            ideas = response.choices[0].message.content
            if "</plan>" in ideas:
                ideas = ideas.split("</plan>")[-2]
            if "<plan>" in ideas:
                ideas = ideas.split("<plan>")[-1]

            prompt: str = f"{question}\n" + f'''
Below are some helpful insights or ideas:
{ideas}
The ideas above may provide some insight into solving the challenge. Now please answer the original question.
'''.strip('\n')
        else:
            prompt = question
        response = await self.solver_llm.chat.completions.create(
            model=env.SOLVER_LLM_MODEL,
            messages=[
                {"role": "system", "content": env.SOLVER_PROMPT},
                {"role": "user", "content": prompt}
            ],
            temperature=env.SOLVER_TEMPERATURE
        )
        return response

    async def create_topics_list(self, question: str):

        json_schema = SearchList.model_json_schema()

        completion = await self.planner_llm.chat.completions.create(
            model=env.PLANNER_LLM_MODEL,
            messages=[
                {"role": "user", "content": (
                    "First determine if the user is asking a hard math, STEM, or coding problem, or any question where you would need more information from the internet. If so, construct a plan and then a JSON list of fewer than five things you would want to search for to help solve this hard problem: "
                    f'{question} '
                    'An example of a good thing to search is "Prime vs. composite power sum convergence conditions" or "Using inclusion/exclusion to prove combinatorics problems"'
                )}
            ],
            extra_body={"guided_json": json_schema},
        )
        body = json.loads(completion.choices[0].message.content)
        if body["is_hard_problem"]:
            topics_list = body["search_list"]
        else:
            topics_list = None
        return topics_list


async def main(query: str):
    pipeline = K2ThinkPipeline()
    return await pipeline.run(query)


if __name__ == "__main__":
    query: str = "Determine the least real number $M$ such that the inequality \\[|ab(a^{2}-b^{2})+bc(b^{2}-c^{2})+ca(c^{2}-a^{2})| \\leq M(a^{2}+b^{2}+c^{2})^{2}\\] holds for all real numbers $a$, $b$ and $c$."
    response = asyncio.run(main(query))
    log.info(f"Final answer generated.\nQuestion: {query}\nAnswer: {response.choices[0].message.content}")
--------------------------------------------------------------------------------