├── requirements.txt
├── README.md
└── main.py
/requirements.txt:
--------------------------------------------------------------------------------
1 | openai==1.99.1
2 | pydantic==2.11.7
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # K2-Think-Inference
2 |
3 | ## Dependencies
4 |
5 | You will need to install the following python packages:
6 | ```bash
7 | pip install -r requirements.txt
8 | ```
9 |
10 | ## Configuration
11 | You need to set the following arguments in `main.py` under `class Env`.
12 | ```python
13 | # Endpoint configuration of planner llm
14 | PLANNER_LLM_API_KEY: str = ''
15 | PLANNER_LLM_BASE_URL: str = ''
16 | PLANNER_LLM_MODEL: str = ''
17 |
18 | # Endpoint configuration of solver llm
19 | SOLVER_LLM_API_KEY: str = ''
20 | SOLVER_LLM_BASE_URL: str = ''
21 | SOLVER_LLM_MODEL: str = "K2-Think"
22 | ```
23 | The planner LLM is responsible for extracting topics, generating plans, and comparing answer pairs. You may use any OpenAI-compatible endpoint as the planner, as long as it supports vLLM's `guided_json` structured-output extension (the pipeline relies on it for topic extraction and answer comparison). For example, you can start a [vllm](https://docs.vllm.ai/en/stable/) localhost endpoint of any Hugging Face model. The following script will start an endpoint serving [Qwen/Qwen3-235B-A22B](https://huggingface.co/Qwen/Qwen3-235B-A22B):
24 | ```bash
25 | vllm serve Qwen/Qwen3-235B-A22B \
26 | --tensor_parallel_size 8 \
27 | --served-model-name Qwen/Qwen3-235B-A22B \
28 | --port 8080
29 | ```
30 | After the endpoint is up and running on localhost, you can set the following arguments:
31 | ```python
32 | # Endpoint configuration of planner llm
33 | PLANNER_LLM_API_KEY: str = ''
34 | PLANNER_LLM_BASE_URL: str = 'http://localhost:8080/v1'
35 | PLANNER_LLM_MODEL: str = 'Qwen/Qwen3-235B-A22B'
36 | ```
37 | Similarly, you can choose your favorite reasoning model for the solver llm, which is responsible for solving the input problems.
38 |
39 | ## Test the Script
40 | ```bash
41 | python main.py
42 | ```
43 | You can change `query` in the `if __name__ == "__main__":` block at the bottom of `main.py` to test any problem you like.
44 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import asyncio
import json
import logging
import time
import uuid
from typing import Any, Awaitable, Callable, List, Literal

from openai import AsyncOpenAI
from openai.types.chat.chat_completion import ChatCompletion
from pydantic import BaseModel
11 |
12 |
class Env:
    """Static configuration for the K2-Think pipeline; edit these values before running."""

    # Endpoint configuration of planner llm
    PLANNER_LLM_API_KEY: str = ''
    PLANNER_LLM_BASE_URL: str = ''
    PLANNER_LLM_MODEL: str = ''

    # Endpoint configuration of solver llm
    SOLVER_LLM_API_KEY: str = ''
    SOLVER_LLM_BASE_URL: str = ''
    SOLVER_LLM_MODEL: str = "K2-Think"

    # N in Best-of-N sampling
    N: int = 3

    # Adapted from AM-Thinking-v1: Advancing the Frontier of Reasoning at 32B Scale (Yunjie Ji et al.) https://arxiv.org/pdf/2505.08311
    # The solver is told to wrap its reasoning in <think></think> and its final
    # answer in <answer></answer>, so the answer can be located in the response.
    SOLVER_PROMPT: str = (
        "You are K2-Think, a helpful assistant trained by MBZUAI. To answer the user's question, "
        "you first think about the reasoning process and then provide the user with the answer. "
        "The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> "
        "tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>."
    )
    SOLVER_TEMPERATURE: float = 1.0
32 |
# Single shared configuration object read by the pipeline below.
env = Env()

# Script-wide logging at INFO level, plus this module's logger.
logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)
36 |
37 |
38 | # Schema for structured output
# No class docstring on purpose: pydantic copies it into model_json_schema()
# as a "description", which would change the schema sent to the judge.
class BoNIndex(BaseModel):
    # Which of the two compared responses is better. Literal[0, 1] puts an
    # enum constraint into the JSON schema, so a schema-enforcing backend
    # (e.g. guided_json) is limited to emitting 0 or 1.
    index: Literal[0, 1]
    # The judge's justification for its choice.
    explanation: str
42 |
43 |
# Structured-output schema for the planner's topic-extraction request.
# (Comments, not a docstring: a docstring would alter model_json_schema().)
class SearchList(BaseModel):
    # Planner's verdict on whether the question warrants topics and a plan.
    is_hard_problem: bool
    # Free-text plan the planner writes before listing topics.
    plan: str
    # Topics the planner would look up to help solve the problem.
    search_list: List[str]
48 |
49 |
50 | class K2ThinkPipeline:
51 | def __init__(self):
52 | self.solver_llm = AsyncOpenAI(
53 | api_key=env.SOLVER_LLM_API_KEY,
54 | base_url=env.SOLVER_LLM_BASE_URL,
55 | timeout=None
56 | )
57 | self.planner_llm = AsyncOpenAI(
58 | api_key=env.PLANNER_LLM_API_KEY,
59 | base_url=env.PLANNER_LLM_BASE_URL
60 | )
61 | self.bon_responses = {}
62 |
63 | async def run(self, question: str) -> ChatCompletion:
64 | return await self.best_of_n_sampling(question=question, n=env.N, timeout=1200)
65 |
66 | # We do not want to wait too long for alternate responses for Best-of-N.
67 | # Once `soft_timeout` seconds have passed, we will collect all completed responses so far for Best-of-N selection.
68 | # If none have completed thusfar, we will wait and return whatever responses have been completed first.
69 | # At `hard_timeout` seconds, we throw an error.
70 | async def run_at_least_one(
71 | self,
72 | fn: Callable[[], Awaitable[Any]],
73 | args_list: List[Any] = [],
74 | soft_timeout: float = 9*60,
75 | hard_timeout: float = 120*60,
76 | poll_interval: float = 10
77 | ) -> List[Any]:
78 | start_time = time.monotonic()
79 | futures = [asyncio.ensure_future(fn(*args)) for args in args_list]
80 | pending = set(futures)
81 |
82 | is_first_iteration = True
83 |
84 | try:
85 | while pending:
86 | # Adjust soft timeout based on how much time has passed
87 | elapsed = time.monotonic() - start_time
88 | if elapsed >= hard_timeout:
89 | break
90 | if is_first_iteration:
91 | is_first_iteration = False
92 | timeout = soft_timeout
93 | return_when = asyncio.ALL_COMPLETED
94 | else:
95 | timeout = min(poll_interval, hard_timeout - elapsed)
96 | return_when = asyncio.FIRST_COMPLETED
97 | done, pending = await asyncio.wait(
98 | pending,
99 | timeout=timeout,
100 | return_when=return_when
101 | )
102 | if len(done) == 0:
103 | continue
104 | results = []
105 |
106 | for fut in done:
107 | try:
108 | result = fut.result()
109 | results.append(result)
110 | except Exception as e:
111 | log.error(f"Error in getting result: {e}")
112 | continue # Ignore failed tasks
113 | if len(results) > 0:
114 | # Cancel the rest
115 | try:
116 | for p in pending:
117 | p.cancel()
118 | except Exception as e:
119 | log.error(f"Error in canceling tasks: {e}")
120 | return results
121 | finally:
122 | for fut in pending:
123 | fut.cancel()
124 |
125 | raise asyncio.TimeoutError(f"No task succeeded within hard timeout of {hard_timeout} seconds")
126 |
127 | async def select_best(self, question, completions):
128 | # We use a linear scan here for simplicity, but a tree-based tournament would be more performant.
129 | answers = []
130 | for completion in completions:
131 | if completion is None:
132 | answers.append("No answer.")
133 | continue
134 | content = completion.choices[0].message.content
135 | if "" in content:
136 | answers.append(content.split("")[1])
137 | else:
138 | answers.append(f"No answer was found, but here was the tail end of the problem solving: {content[-2000:]}")
139 |
140 | best_index, best_completion, best_answer = 0, completions[0], answers[0]
141 | for index, completion in enumerate(completions[1:], start=1):
142 | response = await self.planner_llm.chat.completions.create(
143 | model=env.PLANNER_LLM_MODEL,
144 | messages=[
145 | {
146 | "role": "system",
147 | "content": (
148 | "You are a strict evaluator. Given a question and two responses, "
149 | "return a JSON object with 'better_index' as 0 or 1 for the response "
150 | "that best answers the question."
151 | ),
152 | },
153 | {
154 | "role": "user",
155 | "content": f"Question: {question}\nResponse 0: {best_answer}\nResponse 1: {answers[index]}"
156 | }
157 | ],
158 | extra_body={"guided_json": BoNIndex.model_json_schema()},
159 | )
160 | winner = json.loads(response.choices[0].message.content)["index"]
161 | if winner == 1:
162 | best_index = index
163 | best_completion = completion
164 | best_answer = answers[index]
165 | log.info(f"{best_index=}")
166 | return best_completion
167 |
168 | async def best_of_n_sampling(self, question: str, n: int = 3, timeout: float = 540) -> ChatCompletion | None:
169 | request_id = uuid.uuid4()
170 | self.bon_responses[request_id] = {
171 | "completions": [None] * n
172 | }
173 | args_list = [
174 | (request_id, bon_id, question)
175 | for bon_id in range(n)
176 | ]
177 | log.info(f"running {len(args_list)} tasks , {self.single_sampling.__name__}, {args_list}, {timeout}")
178 | try:
179 | results = await self.run_at_least_one(self.single_sampling, args_list, timeout)
180 | log.info(f"{results=}")
181 | except Exception as e:
182 | log.error(f"Error in best_of_n_sampling: {e}")
183 | return None
184 |
185 | completions = self.bon_responses[request_id]["completions"]
186 | best_completion = await self.select_best(question, completions)
187 | return best_completion
188 |
189 | async def single_sampling(self, request_id: uuid.UUID, bon_id: int, question: str):
190 | # Get a single completion
191 | response = await self.sampling_with_planning(question)
192 | self.bon_responses[request_id]["completions"][bon_id] = response
193 | log.info(f"Finish sampling {bon_id}.\nQuestion: {question}\nAnswer{bon_id}: {response.choices[0].message.content}")
194 |
195 | async def sampling_with_planning(self, question: str):
196 |
197 | topics_list = await self.create_topics_list(question)
198 |
199 | ideas = None
200 | if topics_list is not None:
201 |
202 | prompt_planning_topics: str = f'''
203 | You are given a question and some useful topics:
204 | {question}
205 | {topics_list}
206 | You need to generate a plan of solving the question based on the topics above WITHOUT disclosing the final or potential answer. DO NOT mention or give any hints of the final or potential answer in your plan. Wrap your plan inside .
207 | '''.strip('\n')
208 |
209 | response = await self.planner_llm.chat.completions.create(
210 | model=env.PLANNER_LLM_MODEL,
211 | messages=[{"role": "user", "content": prompt_planning_topics}],
212 | extra_body={"chat_template_kwargs": {"enable_thinking": False}}
213 | )
214 | ideas = response.choices[0].message.content
215 | if "" in ideas:
216 | ideas = ideas.split("")[-2]
217 | if "" in ideas:
218 | ideas = ideas.split("")[-1]
219 |
220 | prompt: str = f"{question}" + f'''
221 | Below are some helpful insights or ideas:
222 | {ideas}
223 | The ideas above may provide some insights in solving the challenge. Now please answer the original question.
224 | '''.strip('\n')
225 | else:
226 | prompt = question
227 | response = await self.solver_llm.chat.completions.create(
228 | model=env.SOLVER_LLM_MODEL,
229 | messages=[
230 | {"role": "system", "content": env.SOLVER_PROMPT},
231 | {"role": "user", "content": prompt}
232 | ],
233 | temperature=env.SOLVER_TEMPERATURE
234 | )
235 | return response
236 |
237 | async def create_topics_list(self, question: str):
238 |
239 | json_schema = SearchList.model_json_schema()
240 |
241 | completion = await self.planner_llm.chat.completions.create(
242 | model=env.PLANNER_LLM_MODEL,
243 | messages=[
244 | {"role": "user", "content": (
245 | "First determine if the user is asking a hard math, stem, or coding problem, or any question where you would need more information from the internet. If so, construct a plan and then a JSON list of less than five things you would want to search for to help solve this hard problem: "
246 | f'{question} '
247 | 'An example of a good thing to search is "Prime vs. composite power sum convergence conditions" or "Using inclusion/exclusion to prove combinatorics problems"'
248 | )}
249 | ],
250 | extra_body={"guided_json": json_schema},
251 | )
252 | body = json.loads(completion.choices[0].message.content)
253 | if body["is_hard_problem"]:
254 | topics_list = body["search_list"]
255 | else:
256 | topics_list = None
257 | return topics_list
258 |
259 |
async def main(query: str):
    """Answer ``query`` with a freshly constructed K2-Think pipeline."""
    return await K2ThinkPipeline().run(query)
263 |
264 |
if __name__ == "__main__":
    query: str = "Determine the least real number $M$ such that the inequality \\[|ab(a^{2}-b^{2})+bc(b^{2}-c^{2})+ca(c^{2}-a^{2})| \\leq M(a^{2}+b^{2}+c^{2})^{2}\\] holds for all real numbers $a$, $b$ and $c$."
    response = asyncio.run(main(query))
    # The pipeline returns None when no sample succeeded; report it instead of
    # crashing on `response.choices`.
    if response is None:
        log.error(f"No answer could be generated.\nQuestion: {query}")
        raise SystemExit(1)
    log.info(f"Final answer generated.\nQuestion: {query}\nAnswer: {response.choices[0].message.content}")
269 |
--------------------------------------------------------------------------------