├── .gitignore
├── .prettierrc
├── LICENSE
├── README.md
├── api
├── execute.py
└── requirements.txt
├── docs
└── source
│ └── _static
│ └── logo.png
├── index.html
├── package-lock.json
├── package.json
├── public
├── meta-full.png
└── meta.jpg
├── src
├── components
│ ├── App.tsx
│ ├── NodeInfo.tsx
│ ├── Prompt.tsx
│ ├── modals
│ │ ├── APIKeyModal.tsx
│ │ └── SettingsModal.tsx
│ ├── nodes
│ │ ├── CustomNode.tsx
│ │ ├── LabelUpdaterNode.tsx
│ │ ├── useAnimatedNodes.tsx
│ │ └── useExpandCollapse.tsx
│ ├── tree.ts
│ └── utils
│ │ ├── APIKeyInput.tsx
│ │ ├── BigButton.tsx
│ │ ├── LabeledInputs.tsx
│ │ ├── Markdown.tsx
│ │ ├── NavigationBar.tsx
│ │ └── Whisper.tsx
├── index.css
├── main.tsx
├── types
│ └── highlightjs-solidity.d.ts
├── utils
│ ├── apikey.ts
│ ├── branchesEdge.ts
│ ├── branchesNode.ts
│ ├── chakra.tsx
│ ├── clipboard.ts
│ ├── color.ts
│ ├── constants.ts
│ ├── debounce.ts
│ ├── humanEval.ts
│ ├── human_eval_problems.json
│ ├── llm.ts
│ ├── lstore.ts
│ ├── mod.ts
│ ├── models.ts
│ ├── nodeId.ts
│ ├── platform.ts
│ ├── prompt.ts
│ ├── qparams.ts
│ ├── rand.ts
│ ├── resize.ts
│ ├── tot.ts
│ └── types.ts
└── vite-env.d.ts
├── tsconfig.json
├── tsconfig.node.json
├── vercel.json
└── vite.config.ts
/.gitignore:
--------------------------------------------------------------------------------
1 | # Logs
2 | logs
3 | *.log
4 | npm-debug.log*
5 | yarn-debug.log*
6 | yarn-error.log*
7 | pnpm-debug.log*
8 | lerna-debug.log*
9 |
10 | node_modules
11 | dist
12 | dist-ssr
13 | *.local
14 |
15 | # Editor directories and files
16 | .vscode/*
17 | !.vscode/extensions.json
18 | .idea
19 | .DS_Store
20 | *.suo
21 | *.ntvs*
22 | *.njsproj
23 | *.sln
24 | *.sw?
25 | .env
26 | .vercel
27 |
--------------------------------------------------------------------------------
/.prettierrc:
--------------------------------------------------------------------------------
1 | {
2 | "printWidth": 90
3 | }
4 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 t11s
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |

4 |
5 | # Branches
6 |
7 | Prototype advanced LLM algorithms for reasoning and planning.
8 |
9 | [Try Online](https://code-gen-tree.vercel.app) •
10 | [Report a Bug](https://github.com/normal-computing/branches/issues) •
11 | [Stay tuned](#stay-tuned-for)
12 |
13 |
14 |
15 | 
16 |
17 | ***Tree-search visualization during code generation.** We visualize a reasoning algorithm which learns from feedback, automatically correcting itself by analyzing error tracebacks to refine its solutions. In this case, we benchmark Python programming problems from the HumanEval dataset.*
18 |
19 | ## About
20 |
21 | Branches is an AI tool for graph-based prototyping of advanced algorithms for LLM reasoning and planning -- like Tree of Thoughts and Reflexion. Branches is adapted from [Flux](https://github.com/paradigmxyz/flux).
22 |
23 | Designed for researchers and developers, it allows users to directly interact with AI reasoning processes, streamlining the exploration of complex coding challenges and strategic problem-solving.
24 |
25 | ### Code Generation (HumanEval)
26 |
27 | Branches automatically expands decision trees to solve programming problems from the [HumanEval dataset](https://huggingface.co/datasets/openai_humaneval), visualizing reasoning chains and facilitating self-correction through error tracebacks. This is found on the `main` branch and is currently hosted.
28 |
29 | ### Game of 24
30 | Branches includes a specialized evaluation mechanism for the [Game of 24 puzzle](https://en.wikipedia.org/wiki/24_(puzzle)), leveraging a scoring system to enhance breadth-first search (BFS) by prioritizing promising paths. This is found on the `game-of-24` branch.
31 |
32 | ## Features
33 |
34 | - [x] 🌳 **Automated Tree Expansion**: Leveraging Tree of Thoughts for dynamic expansion in problem-solving.
35 | - [x] 🧠 **Pre-loaded Prompts**: Curated for search-based reasoning to solve specific problems.
36 | - [x] 💻 **Code Interpretation**: Instant execution and error analysis for self-correcting AI-generated code.
37 | - [x] 🔍 **Scoring Mechanism**: Advanced BFS for the Game of 24 with node evaluation for search optimization.
38 | - [x] 📊 **Interactive Visualization**: Graphical representation of tree searches for easy analysis and education. Largely adapted from [Flux](https://github.com/paradigmxyz/flux).
39 |
40 | ## Usage
41 |
42 | To get started with Branches, you can either visit [code-gen-tree.vercel.app](https://code-gen-tree.vercel.app) for the hosted version or run it locally by following the instructions below.
43 |
44 | ## Deploy to Vercel
45 | ```sh
46 | npm i -g vercel
47 | vercel
48 | ```
49 |
50 | ## Stay Tuned For
51 |
52 | Our commitment to enhancing Branches continues, with exciting new developments on the way:
53 |
54 | - More reasoning and planning algorithms beyond the defaults ([#10](https://github.com/normal-computing/branches/issues/10))
55 | - Node Value Editing and Regenerate Subtree Functionality ([#5](https://github.com/normal-computing/branches/issues/5))
56 | - UI Color Fixes and Customization Features ([#6](https://github.com/normal-computing/branches/issues/6))
57 | - Address Model/UI Timeout Issues ([#7](https://github.com/normal-computing/branches/issues/7))
58 | - Enhance Game of 24 Logic, Model Cost Tracking, and Prompt Engineering ([#8](https://github.com/normal-computing/branches/issues/8))
59 |
60 | ## Contributing
61 |
62 | Your contributions make Branches better. Whether it’s bug reports, new features, or feedback, we welcome it all! Report bugs or request features by creating an issue [here](https://github.com/normal-computing/Branches/issues).
63 |
64 | ## License
65 |
66 | Branches is open-source and continues to uphold the [MIT license](LICENSE).
67 |
--------------------------------------------------------------------------------
/api/execute.py:
--------------------------------------------------------------------------------
1 | from http import HTTPStatus
2 | import json
3 | from concurrent.futures import ThreadPoolExecutor
4 | from human_eval.execution import check_correctness
5 | from flask import Flask, request, jsonify
6 |
7 |
8 | app = Flask(__name__)
9 | # Worker pool shared by all requests; at most 5 submitted checks run at once.
10 | executor = ThreadPoolExecutor(max_workers=5)
11 |
12 |
13 | @app.route("/execute", methods=["POST"])
14 | def execute():
15 | data = request.json
16 |
17 | problem = data.get("problem", "")
18 | completion = data.get("completion", "")
19 | timeout = data.get("timeout", 5.0)
20 | args = (problem, completion, timeout)
21 |
22 | if not completion:
23 | response = jsonify({"error": "No completion provided"})
24 | response.status_code = HTTPStatus.BAD_REQUEST
25 | return response
26 |
27 | try:
28 | future = executor.submit(check_correctness, problem, completion, timeout)
29 | result = future.result()
30 | return jsonify({"result": result})
31 | except Exception as e:
32 | response = jsonify({"error": str(e)})
33 | response.status_code = HTTPStatus.INTERNAL_SERVER_ERROR
34 | return response
35 |
36 |
37 | # check if a 500 error code is thrown
38 | @app.errorhandler(500)
39 | def internal_error(error):
40 | return "500 error: {}".format(str(error)), 500
41 |
--------------------------------------------------------------------------------
/api/requirements.txt:
--------------------------------------------------------------------------------
1 | openai
2 | git+https://github.com/arunpatro/human-eval.git@pipgit
3 | Flask
--------------------------------------------------------------------------------
/docs/source/_static/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/normal-computing/branches/eb7111da6edb7a762bfbe4bdad79b6978890657c/docs/source/_static/logo.png
--------------------------------------------------------------------------------
/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 | Branches
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
--------------------------------------------------------------------------------
/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "branches",
3 | "private": true,
4 | "version": "0.0.0",
5 | "scripts": {
6 | "dev": "vite",
7 | "build": "tsc && vite build",
8 | "preview": "vite preview"
9 | },
10 | "dependencies": {
11 | "@chakra-ui/icons": "^2.0.17",
12 | "@chakra-ui/react": "^2.5.5",
13 | "@emotion/react": "^11.10.6",
14 | "@emotion/styled": "^11.10.6",
15 | "d3-hierarchy": "^3.1.2",
16 | "d3-timer": "^3.0.1",
17 | "framer-motion": "^9.0.4",
18 | "highlightjs-solidity": "^2.0.6",
19 | "js-tiktoken": "^1.0.7",
20 | "mathjs": "^11.11.0",
21 | "mixpanel-browser": "^2.46.0",
22 | "nunjucks": "^3.2.4",
23 | "openai": "^4.5.0",
24 | "openai-streams": "^4.2.0",
25 | "re-resizable": "^6.9.9",
26 | "react": "^18.2.0",
27 | "react-beforeunload": "^2.5.3",
28 | "react-dom": "^18.2.0",
29 | "react-hotkeys-hook": "^4.3.7",
30 | "react-icons": "^4.11.0",
31 | "react-markdown": "^8.0.6",
32 | "react-textarea-autosize": "^8.4.0",
33 | "reactflow": "^11.9.4",
34 | "rehype-highlight": "^6.0.0",
35 | "yield-stream": "^2.3.0"
36 | },
37 | "devDependencies": {
38 | "@types/mixpanel-browser": "^2.38.1",
39 | "@types/node": "^18.14.2",
40 | "@types/nunjucks": "^3.2.5",
41 | "@types/react": "^18.0.27",
42 | "@types/react-beforeunload": "^2.1.1",
43 | "@types/react-dom": "^18.0.10",
44 | "@vitejs/plugin-react-swc": "^3.0.0",
45 | "typescript": "^5.1.6",
46 | "vite": "^5.3.3"
47 | }
48 | }
49 |
--------------------------------------------------------------------------------
/public/meta-full.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/normal-computing/branches/eb7111da6edb7a762bfbe4bdad79b6978890657c/public/meta-full.png
--------------------------------------------------------------------------------
/public/meta.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/normal-computing/branches/eb7111da6edb7a762bfbe4bdad79b6978890657c/public/meta.jpg
--------------------------------------------------------------------------------
/src/components/App.tsx:
--------------------------------------------------------------------------------
1 | import { MIXPANEL_TOKEN } from "../main";
2 | import { isValidAPIKey } from "../utils/apikey";
3 | import { Column, Row } from "../utils/chakra";
4 | import {
5 | API_KEY_LOCAL_STORAGE_KEY,
6 | DEFAULT_SETTINGS,
7 | FIT_VIEW_SETTINGS,
8 | MODEL_SETTINGS_LOCAL_STORAGE_KEY,
9 | REACT_FLOW_NODE_TYPES,
10 | REACT_FLOW_LOCAL_STORAGE_KEY,
11 | TOAST_CONFIG,
12 | SAVED_CHAT_SIZE_LOCAL_STORAGE_KEY,
13 | } from "../utils/constants";
14 | import { useDebouncedEffect } from "../utils/debounce";
15 | import { newBranchesEdge } from "../utils/branchesEdge";
16 | import {
17 | getBranchesNode,
18 | newBranchesNode,
19 | appendTextToBranchesNodeAsGPT,
20 | getBranchesNodeLineage,
21 | modifyBranchesNodeText,
22 | markOnlyNodeAsSelected,
23 | getConnectionAllowed,
24 | } from "../utils/branchesNode";
25 | import { useLocalStorage } from "../utils/lstore";
26 | import { getAvailableChatModels } from "../utils/models";
27 | import { generateNodeId, generateStreamId } from "../utils/nodeId";
28 | import {
29 | explanationMessage,
30 | humanEvalMessageFromNode,
31 | regenMessage,
32 | } from "../utils/prompt";
33 | import { resetURL } from "../utils/qparams";
34 | import { useDebouncedWindowResize } from "../utils/resize";
35 | import {
36 | ToTNodeData,
37 | BranchesNodeType,
38 | Settings,
39 | HumanEvalProblemsType,
40 | } from "../utils/types";
41 | import { NodeInfo } from "./NodeInfo";
42 | import { APIKeyModal } from "./modals/APIKeyModal";
43 | import { SettingsModal } from "./modals/SettingsModal";
44 | import { NavigationBar } from "./utils/NavigationBar";
45 | import { CheckCircleIcon } from "@chakra-ui/icons";
46 | import { Box, useDisclosure, Spinner, useToast } from "@chakra-ui/react";
47 | import mixpanel from "mixpanel-browser";
48 | import { OpenAI } from "openai-streams";
49 | import { Resizable } from "re-resizable";
50 | import { useEffect, useState, useCallback, useRef } from "react";
51 | import { useBeforeunload } from "react-beforeunload";
52 | import rawHumanEvalProblems from "../utils/human_eval_problems.json";
53 | import useExpandCollapse from "./nodes/useExpandCollapse";
54 | import useAnimatedNodes from "./nodes/useAnimatedNodes";
55 |
56 | import ReactFlow, {
57 | addEdge,
58 | Background,
59 | Connection,
60 | Node,
61 | Edge,
62 | NodeMouseHandler,
63 | useEdgesState,
64 | useNodesState,
65 | SelectionMode,
66 | ReactFlowInstance,
67 | ReactFlowJsonObject,
68 | useReactFlow,
69 | updateEdge,
70 | } from "reactflow";
71 | import "reactflow/dist/style.css";
72 | import { yieldStream } from "yield-stream";
73 | import { treeDemo } from "./tree";
74 | import { getBranchesNodeColor } from "../utils/color";
75 | import { getEncoding, encodingForModel } from "js-tiktoken";
76 |
77 | const HUMAN_EVAL_PROBLEMS = rawHumanEvalProblems as HumanEvalProblemsType;
78 |
79 | function App() {
80 | const toast = useToast();
81 |
82 | /*//////////////////////////////////////////////////////////////
83 | CORE REACT FLOW LOGIC
84 | //////////////////////////////////////////////////////////////*/
85 |
86 | type NodeWithText = {
87 | node: Node;
88 | text: string;
89 | };
90 |
91 | const { setViewport, fitView } = useReactFlow();
92 |
93 | const [reactFlow, setReactFlow] = useState(null);
94 |
95 | const [nodes, setNodes, onNodesChange] = useNodesState([]);
96 | const [edges, setEdges, onEdgesChange] = useEdgesState([]);
97 |
98 | const treeWidth: number = 220;
99 | const treeHeight: number = 150;
100 | const animationDuration: number = 200;
101 |
102 | const { nodes: visibleNodes, edges: visibleEdges } = useExpandCollapse(nodes, edges, {
103 | treeWidth,
104 | treeHeight,
105 | });
106 | const { nodes: animatedNodes } = useAnimatedNodes(visibleNodes, { animationDuration });
107 |
108 | const [filteredNodes, setFilteredNodes] = useState([]);
109 | const [showAnswerPathOnly, setShowAnswerPathOnly] = useState(false);
110 |
111 | const [inputTokenCount, setInputTokenCount] = useState(0);
112 | const [outputTokenCount, setOutputTokenCount] = useState(0);
113 |
114 | const edgeUpdateSuccessful = useRef(true);
115 |
116 | const onEdgeUpdateStart = useCallback(() => {
117 | edgeUpdateSuccessful.current = false;
118 | }, []);
119 |
120 | // Re-attach an existing edge to a new source/target. A connection rejected by
121 | // getConnectionAllowed returns early and leaves edgeUpdateSuccessful untouched.
122 | const onEdgeUpdate = (oldEdge: Edge, newConnection: Connection) => {
123 |   const { source, target } = newConnection;
124 |   const isAllowed = getConnectionAllowed(nodes, edges, {
125 |     source: source!,
126 |     target: target!,
127 |   });
128 |   if (!isAllowed) return;
129 |
130 |   edgeUpdateSuccessful.current = true;
131 |   setEdges((currentEdges) => updateEdge(oldEdge, newConnection, currentEdges));
132 | };
133 |
134 | // Drop the edge when its update never succeeded, then re-arm the flag.
135 | const onEdgeUpdateEnd = (_: unknown, edge: Edge) => {
136 |   const wasReconnected = edgeUpdateSuccessful.current;
137 |   edgeUpdateSuccessful.current = true;
138 |   if (wasReconnected) return;
139 |   setEdges((currentEdges) => currentEdges.filter(({ id }) => id !== edge.id));
140 | };
141 |
142 | // Add a user-drawn edge, ignoring connections that getConnectionAllowed rejects.
143 | const onConnect = (connection: Edge | Connection) => {
144 |   const { source, target } = connection;
145 |   const isAllowed = getConnectionAllowed(nodes, edges, {
146 |     source: source!,
147 |     target: target!,
148 |   });
149 |   if (!isAllowed) return;
150 |
151 |   setEdges((currentEdges) => addEdge({ ...connection }, currentEdges));
152 | };
153 |
154 | const autoZoom = () => setTimeout(() => fitView(FIT_VIEW_SETTINGS), 50);
155 |
156 | const autoZoomIfNecessary = () => {
157 | if (settings.autoZoom) autoZoom();
158 | };
159 |
160 | const save = () => {
161 | if (reactFlow) {
162 | localStorage.setItem(
163 | REACT_FLOW_LOCAL_STORAGE_KEY,
164 | JSON.stringify(reactFlow.toObject())
165 | );
166 | }
167 | };
168 |
169 | // Auto save.
170 | const isSavingReactFlow = useDebouncedEffect(
171 | save,
172 | 1000, // 1 second.
173 | [reactFlow, nodes, edges]
174 | );
175 |
176 | // Auto restore on load.
177 | useEffect(() => {
178 | if (reactFlow) {
179 | // const rawFlow = undefined;
180 |
181 | // const flow: ReactFlowJsonObject = rawFlow ? JSON.parse(rawFlow) : null;
182 | const flow: ReactFlowJsonObject = treeDemo;
183 |
184 | if (flow !== null) {
185 | setEdges(flow.edges || []);
186 | setViewport(flow.viewport);
187 |
188 | const nodes = flow.nodes; // For brevity.
189 |
190 | if (nodes.length > 0) {
191 | // Either the first selected node we find, or the first node in the array.
192 | const toSelect = nodes.find((node) => node.selected)?.id ?? nodes[0].id;
193 |
194 | // Add the nodes to the React Flow array and select the node.
195 | selectNode(toSelect, () => nodes);
196 |
197 | // If there was a newTreeWith query param, create a new tree with that content.
198 | // We pass false for forceAutoZoom because we'll do it 500ms later to avoid lag.
199 | }
200 | }
201 |
202 | setTimeout(() => {
203 | // Do this with a more generous timeout to make sure
204 | // the nodes are rendered and the settings have loaded in.
205 | if (settings.autoZoom) fitView(FIT_VIEW_SETTINGS);
206 | }, 500);
207 |
208 | resetURL(); // Get rid of the query params.
209 | }
210 | }, [reactFlow]);
211 |
212 | /*//////////////////////////////////////////////////////////////
213 | AI PROMPT CALLBACKS
214 | //////////////////////////////////////////////////////////////*/
215 |
216 | // Takes a prompt, submits it to the GPT API with n responses,
217 | // then creates a child node for each response under the selected node.
218 | const submitPrompt = async () => {
219 | const temp = settings.temp;
220 | const model = settings.model;
221 | const parentNode = selectedNodeLineage[0];
222 | const submittedNode = getBranchesNode(nodes, parentNode.id)!;
223 |
224 | console.log("current node", submittedNode);
225 |
226 | type SetNodes = React.Dispatch>;
227 | type SetEdges = React.Dispatch>;
228 | type BranchesNodeInput = {
229 | id?: string;
230 | x: number;
231 | y: number;
232 | branchesNodeType: BranchesNodeType;
233 | input: string;
234 | text: string;
235 | streamId?: string;
236 | steps: string[];
237 | solutions: any[];
238 | errors: any[];
239 | style: any;
240 | explanations: any[];
241 | };
242 | type BranchesEdgeInput = {
243 | source: string;
244 | target: string;
245 | animated: boolean;
246 | }
247 |
248 | const createNewNodeAndEdge = (
249 | currentNode: Node,
250 | newBranchesNode: (node: BranchesNodeInput) => Node,
251 | newBranchesEdge: (node: BranchesEdgeInput) => Edge,
252 | setNodes: SetNodes,
253 | setEdges: SetEdges,
254 | streamId: string,
255 | isSolutionNode: boolean,
256 | callback: (newNode: Node) => void
257 | ) => {
258 | const currentChildNodeId = generateNodeId();
259 |
260 | setNodes((prevNodes: Node[]) => {
261 | const matchingNode = prevNodes.find((n) => n.id === currentNode.id);
262 | if (!matchingNode) {
263 | throw new Error("Node not found");
264 | }
265 |
266 | // Create a new node using the currentErrors
267 | const newNode = newBranchesNode({
268 | id: currentChildNodeId,
269 | x: matchingNode.position.x + 10,
270 | y: matchingNode.position.y + 100,
271 | branchesNodeType: BranchesNodeType.GPT,
272 | input: matchingNode.data.input,
273 | text: "",
274 | streamId,
275 | steps: [...matchingNode.data.steps, ""],
276 | solutions: isSolutionNode
277 | ? [...matchingNode.data.solutions, ""]
278 | : [...matchingNode.data.solutions],
279 | style: { background: getBranchesNodeColor(!isSolutionNode, true, false, true, 0) },
280 | errors: [...matchingNode.data.errors],
281 | explanations: isSolutionNode
282 | ? [...matchingNode.data.explanations]
283 | : [...matchingNode.data.explanations, ""],
284 | });
285 |
286 | callback(newNode);
287 |
288 | return [...prevNodes, newNode];
289 | });
290 |
291 | setEdges((prevEdges) => [
292 | ...prevEdges,
293 | newBranchesEdge({
294 | source: currentNode.id,
295 | target: currentChildNodeId,
296 | animated: true,
297 | }),
298 | ]);
299 |
300 | setTimeout(autoZoomIfNecessary, 500);
301 | };
302 |
303 | // Recompute the background colour of the node with id `nodeId`.
304 | // Colour inputs are read from the node's current data: `isTerminal`, whether
305 | // `errors` is missing or empty, and `score`; missing values fall back to false / 0.
306 | const updateNodeColor = (
307 |   nodeId: string,
308 |   setNodes: SetNodes,
309 |   isExplanation?: boolean
310 | ) => {
311 |   setNodes((prevNodes: Node[]) =>
312 |     prevNodes.map((node) => {
313 |       if (node.id !== nodeId) return node;
314 |       const hasNoErrors = !node.data.errors || node.data.errors.length === 0;
315 |       return {
316 |         ...node,
317 |         style: {
318 |           background: getBranchesNodeColor(
319 |             isExplanation || false,
320 |             false,
321 |             node.data.isTerminal || false,
322 |             hasNoErrors,
323 |             node.data.score || 0
324 |           ),
325 |         },
326 |       };
327 |     })
328 |   );
329 | };
330 |
331 | // Stop the edge animation for a finished node: every edge whose target is
332 | // `currentChildNodeId` gets `animated: false`; all other edges are returned as-is.
333 | const updatePreviousEdge = (currentChildNodeId: string, setEdges: SetEdges) => {
334 |   const stopAnimation = (edge: Edge): Edge => {
335 |     if (edge.target !== currentChildNodeId) return edge;
336 |     return { ...edge, animated: false };
337 |   };
338 |
339 |   setEdges((prevEdges: Edge[]) => prevEdges.map(stopAnimation));
340 | };
341 |
342 | // TODO: use the selected model's name here. Currently `enc` is defined inside the function, which is very slow.
343 | const enc = encodingForModel("gpt-3.5-turbo");
344 | function countTokens(text: string): number {
345 | const tokens = enc.encode(text);
346 | return tokens.length;
347 | }
348 |
349 | // Append `error` to the `data.errors` list of the node with id `nodeId`.
350 | // A node without an `errors` array is treated as having an empty one.
351 | const addError = (nodeId: string, error: string, setNodes: SetNodes) => {
352 |   const withError = (node: Node): Node => {
353 |     if (node.id !== nodeId) return node;
354 |
355 |     const previousErrors = node.data.errors || [];
356 |     return {
357 |       ...node,
358 |       data: {
359 |         ...node.data,
360 |         errors: [...previousErrors, error],
361 |       },
362 |     };
363 |   };
364 |
365 |   setNodes((prevNodes: Node[]) => prevNodes.map(withError));
366 | };
367 |
368 | async function executeInterpreter(
369 | node: Node,
370 | solutionText: string,
371 | finalNode: boolean
372 | ): Promise {
373 | let data = {
374 | problem: HUMAN_EVAL_PROBLEMS[node.data.input],
375 | completion: solutionText,
376 | };
377 |
378 | console.log("node solution text", solutionText);
379 | console.log("data", JSON.stringify(data));
380 |
381 | let url = "/execute";
382 | let response = await fetch(url, {
383 | method: "POST",
384 | headers: {
385 | "Content-Type": "application/json",
386 | },
387 | body: JSON.stringify(data),
388 | });
389 |
390 | // Parse JSON response
391 | let jsonResponse = await response.json();
392 | console.log("json response", jsonResponse);
393 |
394 | const passed = jsonResponse["result"]["passed"];
395 | console.log("passed", passed);
396 |
397 | if (passed) {
398 | handleFinishedNode(node, true, false);
399 | return null;
400 | } else {
401 | const error = jsonResponse["result"]["result"];
402 | addError(node.id, error, setNodes);
403 | updateNodeColor(node.id, setNodes);
404 |
405 | if (!finalNode) {
406 | const explanationPromises = Array(settings.N_EXPLANATION_FANOUT)
407 | .fill(null)
408 | .map(async () => {
409 | return await generateChild(node, "explanation", error, false);
410 | });
411 |
412 | const explanationChildrenWithText: NodeWithText[] = await Promise.all(
413 | explanationPromises
414 | );
415 |
416 | const regenChildrenPromises: Promise[] =
417 | explanationChildrenWithText.map(async (explanationChildWithText) => {
418 | // Create N_ANSWER_FANOUT number of promises for each explanation child
419 | const regenPromises = Array(settings.N_ANSWER_FANOUT)
420 | .fill(null)
421 | .map(async () => {
422 | // Assuming that `error` is available in the current scope
423 | return await generateChild(
424 | explanationChildWithText.node,
425 | "regen",
426 | error,
427 | true
428 | );
429 | });
430 |
431 | // Await all regenPromises for the current explanation child
432 | return await Promise.all(regenPromises);
433 | });
434 |
435 | // Await all regenChildrenPromises for all explanation children
436 | const regenChildrenArrays: NodeWithText[][] = await Promise.all(
437 | regenChildrenPromises
438 | );
439 |
440 | // Flatten the array of arrays into a single array
441 | const regenChildrenWithText: NodeWithText[] = ([] as NodeWithText[]).concat(...regenChildrenArrays);
442 |
443 | return regenChildrenWithText;
444 | }
445 | return null;
446 | }
447 | }
448 |
449 | /**
450 |  * Flag `targetNodeId` and all of its ancestors with `data.isInAnswerPath = true`.
451 |  *
452 |  * Ancestors are found by walking edges backwards (edge.target -> edge.source).
453 |  * Visited ids are tracked, so a node reachable through several edges is handled
454 |  * once and a cyclic edge list cannot recurse forever.
455 |  *
456 |  * @param targetNodeId id of the node the answer path ends at
457 |  * @param setNodes state setter used to write the flag onto the nodes
458 |  * @param setEdges state setter used only to read the latest edges
459 |  */
460 | const markAsAnswerPath = (
461 |   targetNodeId: string,
462 |   setNodes: SetNodes,
463 |   setEdges: SetEdges
464 | ) => {
465 |   setEdges((prevEdges) => {
466 |     // `answerPathIds` doubles as the visited set for the backwards walk.
467 |     const answerPathIds = new Set<string>([targetNodeId]);
468 |     const pending = [targetNodeId];
469 |
470 |     while (pending.length > 0) {
471 |       const nodeId = pending.pop()!;
472 |       for (const edge of prevEdges) {
473 |         if (edge.target === nodeId && !answerPathIds.has(edge.source)) {
474 |           answerPathIds.add(edge.source);
475 |           pending.push(edge.source);
476 |         }
477 |       }
478 |     }
479 |
480 |     // Rewrite only the nodes on the path; every other node object is reused.
481 |     setNodes((prevNodes) =>
482 |       prevNodes.map((node) =>
483 |         answerPathIds.has(node.id)
484 |           ? { ...node, data: { ...node.data, isInAnswerPath: true } }
485 |           : node
486 |       )
487 |     );
488 |
489 |     // Edges are not modified; hand the same array back.
490 |     return prevEdges;
491 |   });
492 | };
493 |
494 | async function handleFinishedNode(
495 | finishedNode: Node,
496 | isTerminal: boolean,
497 | isExplanation: boolean
498 | ): Promise> {
499 | console.log("handling node", finishedNode);
500 | console.log("is explanation?", isExplanation);
501 | let modifiedNode = { ...finishedNode };
502 | if (isTerminal) {
503 | console.log("found terminal node");
504 | markAsAnswerPath(finishedNode.id, setNodes, setEdges);
505 | setNodes((prevNodes: Node[]) => {
506 | const newNodes = prevNodes.map((node) => {
507 | if (node.id === finishedNode?.id) {
508 | modifiedNode = {
509 | ...node,
510 | style: {
511 | background: getBranchesNodeColor(
512 | isExplanation,
513 | false,
514 | isTerminal,
515 | true,
516 | finishedNode.data.score || 0
517 | ),
518 | },
519 | data: {
520 | ...node.data,
521 | isTerminal: true,
522 | },
523 | };
524 | return modifiedNode;
525 | }
526 | return node;
527 | });
528 | return newNodes;
529 | });
530 | }
531 |
532 | updateNodeColor(finishedNode?.id!, setNodes, isExplanation);
533 | updatePreviousEdge(finishedNode?.id!, setEdges);
534 |
535 | return modifiedNode;
536 | }
537 |
538 | async function generateChild(
539 | node: Node,
540 | nodeType: string,
541 | error: string,
542 | isSolutionNode: boolean
543 | ): Promise {
544 | console.log("generating from this node", node);
545 | const DECODER = new TextDecoder();
546 |
547 | const abortController = new AbortController();
548 |
549 | const streamId = generateStreamId();
550 | console.log("new stream id", streamId);
551 | let isNewNode = true;
552 |
553 | const question = HUMAN_EVAL_PROBLEMS[node.data.input]["prompt"];
554 | let answer = node.data.steps[0];
555 | console.log("this is the node we're generating from", node);
556 | let explanation = "";
557 | if (nodeType == "regen") {
558 | explanation = node.data.steps[1];
559 | }
560 |
561 | const messages =
562 | nodeType == "explanation"
563 | ? explanationMessage(question, answer, error)
564 | : nodeType == "regen"
565 | ? regenMessage(question, answer, error, explanation)
566 | : humanEvalMessageFromNode(node);
567 | const newInputTokens = countTokens(messages[0]["content"]);
568 | setInputTokenCount((prevCount) => prevCount + newInputTokens);
569 |
570 | const stream = await OpenAI(
571 | "chat",
572 | {
573 | model,
574 | temperature: temp,
575 | messages,
576 | },
577 | { apiKey: apiKey!, mode: "raw" }
578 | );
579 | let currentText: string = "";
580 | let currentChildNode: Node | null = null;
581 |
582 | for await (const chunk of yieldStream(stream, abortController)) {
583 | if (abortController.signal.aborted) break;
584 |
585 | try {
586 | const decoded = JSON.parse(DECODER.decode(chunk));
587 | const choice = decoded.choices[0];
588 |
589 | if (choice.delta?.content) {
590 | const chars = choice.delta.content;
591 | const newTokens = countTokens(chars);
592 | setOutputTokenCount((prevCount) => prevCount + newTokens);
593 |
594 | // new node
595 | if (isNewNode) {
596 | createNewNodeAndEdge(
597 | node,
598 | newBranchesNode,
599 | newBranchesEdge,
600 | setNodes,
601 | setEdges,
602 | streamId,
603 | isSolutionNode,
604 | (newNode) => {
605 | currentChildNode = newNode;
606 | }
607 | );
608 | isNewNode = false;
609 | }
610 | currentText += chars;
611 |
612 | setNodes((prevNodes: Node[]) => {
613 | return appendTextToBranchesNodeAsGPT(
614 | prevNodes,
615 | {
616 | id: currentChildNode?.id!,
617 | text: currentText,
618 | streamId,
619 | },
620 | isSolutionNode
621 | );
622 | });
623 |
624 | // We cannot return within the loop, and we do
625 | // not want to execute the code below, so we break.
626 | if (abortController.signal.aborted) break;
627 | }
628 | } catch (err) {
629 | console.error(err);
630 | }
631 | }
632 |
633 | const finalChild: Node = await handleFinishedNode(
634 | currentChildNode!,
635 | false,
636 | nodeType == "explanation"
637 | );
638 | return { node: finalChild, text: currentText };
639 | }
640 |
641 | const promises = Array(settings.N_ANSWER_FANOUT)
642 | .fill(null)
643 | .map(async () => {
644 | return await generateChild(submittedNode, "normal", "", true);
645 | });
646 |
647 | const childrenWithText = await Promise.all(promises);
648 |
649 | autoZoomIfNecessary();
650 |
651 | const interpretChildrenPromises = childrenWithText.map(async (childWithText) => {
652 | return await executeInterpreter(childWithText.node, childWithText.text, false);
653 | });
654 |
655 | const regenChildren = await Promise.all(interpretChildrenPromises);
656 |
657 | autoZoomIfNecessary();
658 |
659 | // Final pass: run the interpreter on every regenerated solution with
660 | // finalNode = true, so a failing solution is not expanded any further. The runs
661 | // are awaited so errors reject submitPrompt instead of going unhandled.
662 | const finalCandidates = regenChildren.flatMap((children) => children ?? []);
663 | await Promise.all(
664 |   finalCandidates.map((regenChild: NodeWithText) =>
665 |     executeInterpreter(regenChild.node, regenChild.text, true)
666 |   )
667 | );
668 |
669 | autoZoomIfNecessary();
670 |
671 | if (MIXPANEL_TOKEN) mixpanel.track("Submitted Prompt"); // KPI
672 | };
673 |
674 | /*//////////////////////////////////////////////////////////////
675 | SELECTED NODE LOGIC
676 | //////////////////////////////////////////////////////////////*/
677 |
678 | const [selectedNodeId, setSelectedNodeId] = useState(null);
679 |
680 | const selectedNodeLineage =
681 | selectedNodeId !== null ? getBranchesNodeLineage(nodes, edges, selectedNodeId) : [];
682 |
683 | /*//////////////////////////////////////////////////////////////
684 | NODE MUTATION CALLBACKS
685 | //////////////////////////////////////////////////////////////*/
686 |
687 | /*//////////////////////////////////////////////////////////////
688 | NODE SELECTION CALLBACKS
689 | //////////////////////////////////////////////////////////////*/
690 |
// Records `id` as the selected node and updates the node list so the
// selection marker is applied through markOnlyNodeAsSelected.
//
// id: id of the node to select.
// computeNewNodes: optional transform applied to the current nodes before
//   the selection is marked; when omitted the current nodes are used as-is.
const selectNode = (
  id: string,
  computeNewNodes?: (currNodes: Node[]) => Node[]
) => {
  setSelectedNodeId(id);

  setNodes((currNodes) => {
    // Run the caller's transform first (if any) so the selection is marked
    // on the transformed list within the same state update.
    const baseNodes = computeNewNodes ? computeNewNodes(currNodes) : currNodes;
    return markOnlyNodeAsSelected(baseNodes, id);
  });
};
701 |
702 | /*//////////////////////////////////////////////////////////////
703 | SETTINGS MODAL LOGIC
704 | //////////////////////////////////////////////////////////////*/
705 |
706 | const {
707 | isOpen: isSettingsModalOpen,
708 | onOpen: onOpenSettingsModal,
709 | onClose: onCloseSettingsModal,
710 | } = useDisclosure();
711 |
// Settings state, lazily initialised from localStorage.
// Falls back to DEFAULT_SETTINGS when nothing is stored, when the stored
// value is not valid JSON, or when it does not decode to an object. This
// keeps a corrupted entry from throwing during the first render.
const [settings, setSettings] = useState(() => {
  const rawSettings = localStorage.getItem(MODEL_SETTINGS_LOCAL_STORAGE_KEY);

  if (rawSettings === null) return DEFAULT_SETTINGS;

  try {
    const parsed: unknown = JSON.parse(rawSettings);

    // JSON such as `null`, `42` or `"text"` parses fine but is not a
    // settings object; treat it the same as a missing entry.
    if (parsed === null || typeof parsed !== "object") return DEFAULT_SETTINGS;

    return parsed as Settings;
  } catch (err) {
    console.error("Failed to parse saved settings, using defaults instead.", err);
    return DEFAULT_SETTINGS;
  }
});
721 |
722 | const isGPT4 = settings.model.includes("gpt-4");
723 |
724 | // Auto save.
725 | const isSavingSettings = useDebouncedEffect(
726 | () => {
727 | localStorage.setItem(MODEL_SETTINGS_LOCAL_STORAGE_KEY, JSON.stringify(settings));
728 | },
729 | 1000, // 1 second.
730 | [settings]
731 | );
732 |
733 | /*//////////////////////////////////////////////////////////////
734 | API KEY LOGIC
735 | //////////////////////////////////////////////////////////////*/
736 |
737 | const [apiKey, setApiKey] = useLocalStorage(API_KEY_LOCAL_STORAGE_KEY);
738 |
739 | const [availableModels, setAvailableModels] = useState(null);
740 |
741 | // modelsLoadCounter lets us discard the results of the requests if a concurrent newer one was made.
742 | const modelsLoadCounter = useRef(0);
743 | useEffect(() => {
744 | if (isValidAPIKey(apiKey)) {
745 | const modelsLoadIndex = modelsLoadCounter.current + 1;
746 | modelsLoadCounter.current = modelsLoadIndex;
747 |
748 | setAvailableModels(null);
749 |
750 | (async () => {
751 | let modelList: string[] = [];
752 | try {
753 | modelList = await getAvailableChatModels(apiKey!);
754 | } catch (e) {
755 | toast({
756 | title: "Failed to load model list!",
757 | status: "error",
758 | ...TOAST_CONFIG,
759 | });
760 | }
761 | if (modelsLoadIndex !== modelsLoadCounter.current) return;
762 |
763 | if (modelList.length === 0) modelList.push(settings.model);
764 |
765 | setAvailableModels(modelList);
766 |
767 | if (!modelList.includes(settings.model)) {
768 | const oldModel = settings.model;
769 | const newModel = modelList.includes(DEFAULT_SETTINGS.model)
770 | ? DEFAULT_SETTINGS.model
771 | : modelList[0];
772 |
773 | setSettings((settings) => ({ ...settings, model: newModel }));
774 |
775 | toast({
776 | title: `Model "${oldModel}" no longer available!`,
777 | description: `Switched to "${newModel}"`,
778 | status: "warning",
779 | ...TOAST_CONFIG,
780 | });
781 | }
782 | })();
783 | }
784 | }, [apiKey]);
785 |
786 | useEffect(() => {
787 | const updatedNodes: Node[] = animatedNodes.filter(
788 | (node) => node.data?.isInAnswerPath
789 | );
790 | setFilteredNodes(updatedNodes);
791 | }, [nodes]);
792 |
793 | const isAnythingSaving = isSavingReactFlow || isSavingSettings;
794 | const isAnythingLoading = isAnythingSaving || availableModels === null;
795 |
796 | useBeforeunload((event: BeforeUnloadEvent) => {
797 | // Prevent leaving the page before saving.
798 | if (isAnythingSaving) event.preventDefault();
799 | });
800 |
801 | /*//////////////////////////////////////////////////////////////
802 | WINDOW RESIZE LOGIC
803 | //////////////////////////////////////////////////////////////*/
804 |
805 | useDebouncedWindowResize(autoZoomIfNecessary, 100);
806 |
807 | /*//////////////////////////////////////////////////////////////
808 | CHAT RESIZE LOGIC
809 | //////////////////////////////////////////////////////////////*/
810 |
811 | const [savedChatSize, setSavedChatSize] = useLocalStorage(
812 | SAVED_CHAT_SIZE_LOCAL_STORAGE_KEY
813 | );
814 |
815 | /*//////////////////////////////////////////////////////////////
816 | APP
817 | //////////////////////////////////////////////////////////////*/
818 |
// Click handler for graph nodes: remembers the clicked node as the selected
// one and flips its `data.expanded` flag. All other nodes are returned
// unchanged.
const onNodeClick: NodeMouseHandler = useCallback(
  (_, node) => {
    const clickedId = node.id;
    setSelectedNodeId(clickedId);

    setNodes((nds) =>
      nds.map((n) =>
        n.id !== clickedId
          ? n
          : // New object references so the toggled node is seen as changed.
            { ...n, data: { ...n.data, expanded: !n.data.expanded } }
      )
    );
  },
  [setNodes]
);
837 |
838 | return (
839 | <>
840 | {!isValidAPIKey(apiKey) && }
841 |
842 |
851 |
857 |
858 | {
877 | setSavedChatSize(ref.style.width);
878 | autoZoomIfNecessary();
879 |
880 | if (MIXPANEL_TOKEN) mixpanel.track("Resized chat window");
881 | }}
882 | >
883 |
890 |
899 | {
901 | onOpenSettingsModal();
902 |
903 | if (MIXPANEL_TOKEN) mixpanel.track("Opened Settings Modal"); // KPI
904 | }}
905 | onToggleAnswerFilter={() => {
906 | setShowAnswerPathOnly(!showAnswerPathOnly);
907 | }}
908 | showAnswerPathOnly={showAnswerPathOnly}
909 | />
910 |
911 |
912 | Input Token Count: {inputTokenCount}
913 |
914 |
915 |
916 | Output Token Count: {outputTokenCount}
917 |
918 |
919 | {/*
920 | Total Cost (GPT-4): ${((inputTokenCount * 0.03 / 1000) + (outputTokenCount * 0.06 / 1000)).toFixed(2)}
921 | */}
922 |
923 |
924 | {isAnythingLoading ? (
925 |
926 | ) : (
927 |
928 | )}
929 |
930 |
931 |
932 |
961 |
962 |
963 |
964 |
965 |
966 |
967 | {
972 | setNodes((nodes) =>
973 | modifyBranchesNodeText(nodes, {
974 | asHuman: true,
975 | id: selectedNodeId!,
976 | text,
977 | isRunning: false,
978 | })
979 | );
980 | }}
981 | apiKey={apiKey}
982 | nodes={nodes}
983 | edges={edges}
984 | />
985 |
986 |
987 |
988 | >
989 | );
990 | }
991 |
992 | export default App;
993 |
--------------------------------------------------------------------------------
/src/components/NodeInfo.tsx:
--------------------------------------------------------------------------------
1 | import { Node, Edge } from "reactflow";
2 | import rawHumanEvalProblems from "../utils/human_eval_problems.json";
3 | import {
4 | Box,
5 | Text,
6 | Tag,
7 | TagLeftIcon,
8 | TagLabel,
9 | Textarea,
10 | List,
11 | ListItem,
12 | ListIcon,
13 | Input,
14 | Heading,
15 | Flex,
16 | } from "@chakra-ui/react";
17 | import { CheckIcon } from "@chakra-ui/icons";
18 | import { MdCheck, MdQuestionMark, MdClose, MdThumbUpOffAlt } from "react-icons/md";
19 | import { Settings, ToTNodeData, HumanEvalProblemsType } from "../utils/types";
20 | import { Prompt } from "./Prompt";
21 | import { getBranchesNodeParent } from "../utils/branchesNode";
22 | import { useEffect, useState } from "react";
23 | import { Markdown } from "./utils/Markdown";
24 | import { Row } from "../utils/chakra";
25 |
26 | const HUMAN_EVAL_PROBLEMS = rawHumanEvalProblems as HumanEvalProblemsType;
27 |
28 | function EvalListItem({ item }: { item: string }) {
29 | if (item) {
30 | const lines = item.split("\n");
31 | const lastLine = lines[lines.length - 1];
32 | let icon = null;
33 | if (lastLine === "sure") {
34 | icon = ;
35 | } else if (lastLine === "impossible") {
36 | icon = ;
37 | } else if (lastLine === "likely") {
38 | icon = ;
39 | }
40 |
41 | return (
42 |
43 | {icon}
44 | {lines.map((line, i) => {
45 | return (
46 |
47 | {line}
48 |
49 |
50 | );
51 | })}
52 |
53 | );
54 | }
55 | return ;
56 | }
57 |
58 | export function NodeInfo({
59 | lineage,
60 | selectNode,
61 | submitPrompt,
62 | apiKey,
63 | onPromptType,
64 | nodes,
65 | edges,
66 | }: {
67 | lineage: Node[] | null;
68 | settings?: Settings;
69 | setSettings?: (settings: Settings) => void;
70 | isGPT4?: boolean;
71 | submitPrompt: () => Promise;
72 | selectNode: (id: string) => void;
73 | apiKey: string | null;
74 | onPromptType: (text: string) => void;
75 | nodes: Node[];
76 | edges: Edge[];
77 | }) {
78 | const selectedNode =
79 | lineage &&
80 | (lineage.find((n) => n.selected === true) as Node | undefined);
81 | const selectedNodeId = selectedNode?.id ?? null;
82 | const rootNode = lineage ? lineage[lineage.length - 1] : undefined;
83 |
84 | const [selectedNodeParent, setSelectedNodeParent] = useState<
85 | Node | null | undefined
86 | >(null);
87 |
88 | useEffect(() => {
89 | const newSelectedNodeParent =
90 | selectedNodeId !== null ? getBranchesNodeParent(nodes, edges, selectedNodeId) : null;
91 | setSelectedNodeParent(newSelectedNodeParent);
92 | }, [selectedNodeId, nodes, edges]);
93 |
94 | return (
95 |
96 | {selectedNode?.data.isTerminal ? (
97 |
103 |
104 | Terminal
105 |
106 | ) : null}
107 | {selectedNode?.data.isValid ? (
108 |
109 |
110 | Valid
111 |
112 | ) : null}
113 |
114 | {/*
115 | Input
116 |
117 | {selectedNodeParent || selectedNodeId == null ? (
118 |
{selectedNode?.data.input ?? ""}
119 | ) : (
120 |
222 | );
223 | }
224 |
--------------------------------------------------------------------------------
/src/components/Prompt.tsx:
--------------------------------------------------------------------------------
1 | import { MIXPANEL_TOKEN } from "../main";
2 | import { Row, Center, Column } from "../utils/chakra";
3 | import { getBranchesNodeColor, getBranchesNodeTypeDarkColor } from "../utils/color";
4 | import { setBranchesNodeStreamId } from "../utils/branchesNode";
5 | import { ToTNodeData, BranchesNodeType, Settings } from "../utils/types";
6 | import { BigButton } from "./utils/BigButton";
7 | import { Markdown } from "./utils/Markdown";
8 | import { NotAllowedIcon } from "@chakra-ui/icons";
9 | import { Spinner, Button, Heading } from "@chakra-ui/react";
10 | import mixpanel from "mixpanel-browser";
11 | import { useState, useEffect, useRef } from "react";
12 | import { Node, useReactFlow } from "reactflow";
13 |
14 | export function Prompt({
15 | submitPrompt,
16 | selectedNode,
17 | lineage,
18 | selectNode,
19 | }: {
20 | submitPrompt: () => Promise;
21 | selectedNode?: Node | null;
22 | lineage: Node[] | null;
23 | selectNode: (id: string) => void;
24 | }) {
25 | const { setNodes } = useReactFlow();
26 |
27 | // const promptNode = lineage[0];
28 |
29 | // const promptNodeType = selectedNode.data.branchesNodeType;
30 |
31 | const onMainButtonClick = () => {
32 | submitPrompt();
33 | };
34 |
35 | const stopGenerating = () => {
36 | // Reset the stream id.
37 | setNodes((nodes) =>
38 | setBranchesNodeStreamId(nodes, { id: selectedNode?.id ?? '', streamId: undefined })
39 | );
40 |
41 | if (MIXPANEL_TOKEN) mixpanel.track("Stopped generating response");
42 | };
43 |
44 | /*//////////////////////////////////////////////////////////////
45 | STATE
46 | //////////////////////////////////////////////////////////////*/
47 |
48 | const [hoveredNodeId, setHoveredNodeId] = useState(null);
49 |
50 | /*//////////////////////////////////////////////////////////////
51 | EFFECTS
52 | //////////////////////////////////////////////////////////////*/
53 |
54 | const textOffsetRef = useRef(-1);
55 |
56 | // Scroll to the prompt buttons
57 | // when the bottom node is swapped.
58 | // useEffect(() => {
59 | // window.document
60 | // .getElementById("promptButtons")
61 | // ?.scrollIntoView(/* { behavior: "smooth" } */);
62 | // }, [selectedNode.id]);
63 |
64 | /*//////////////////////////////////////////////////////////////
65 | APP
66 | //////////////////////////////////////////////////////////////*/
67 |
68 | return (
69 | <>
70 | {selectedNode &&
71 | selectedNode?.data &&
72 | selectedNode.data.solutions.reverse().map((solution, i) => {
73 | const data = selectedNode.data;
74 | const errors = data.errors || [];
75 | const explanations = data.explanations || [];
76 | const currNode: Node | undefined = lineage
77 | ?.slice(0, lineage.length - 1)
78 | .reverse()[i * 2];
79 | const currExplanationNode =
80 | lineage?.length || 0 > 2
81 | ? lineage?.slice(0, lineage.length - 1).reverse()[1]
82 | : null;
83 |
84 | return (
85 | <>
86 | setHoveredNodeId(node.id)}
99 | ///onMouseLeave={() => setHoveredNodeId(null)}
100 | bg={currNode?.style?.background as string || "#FFFFFF"}
101 | key={currNode?.id}
102 | onClick={() => {
103 | const selection = window.getSelection();
104 |
105 | if (selection?.isCollapsed) {
106 | selectNode(currNode!.id);
107 | }
108 | }}
109 | cursor="pointer"
110 | >
111 | {data.streamId && data.text === "" ? (
112 |
113 |
114 |
115 | ) : (
116 | <>
117 |
133 |
145 | {solution && (
146 |
152 | Solution
153 |
154 |
155 | )}
156 | {errors[i] && (
157 |
164 | Error
165 |
166 |
167 | )}
168 |
169 | >
170 | )}
171 |
172 | {explanations[i] && (
173 | {
188 | const selection = window.getSelection();
189 |
190 | if (selection?.isCollapsed) {
191 | selectNode(currExplanationNode!.id);
192 | }
193 | }}
194 | cursor="pointer"
195 | >
196 |
203 | Explanation
204 |
205 |
206 |
207 | )}
208 | >
209 | );
210 | })}
211 |
212 | {
213 |
220 |
228 | Generate children nodes
229 |
230 |
231 | }
232 | >
233 | );
234 | }
235 |
--------------------------------------------------------------------------------
/src/components/modals/APIKeyModal.tsx:
--------------------------------------------------------------------------------
1 | import mixpanel from "mixpanel-browser";
2 |
3 | import { Modal, ModalOverlay, ModalContent, Link, Text } from "@chakra-ui/react";
4 |
5 | import { MIXPANEL_TOKEN } from "../../main";
6 |
7 | import { Column } from "../../utils/chakra";
8 | import { isValidAPIKey } from "../../utils/apikey";
9 | import { APIKeyInput } from "../utils/APIKeyInput";
10 |
11 | export function APIKeyModal({
12 | apiKey,
13 | setApiKey,
14 | }: {
15 | apiKey: string | null;
16 | setApiKey: (apiKey: string) => void;
17 | }) {
// Stores the API key and, once the key passes isValidAPIKey, records the
// analytics event and moves focus to the prompt box.
const setApiKeyTracked = (apiKey: string) => {
  setApiKey(apiKey);

  // Nothing else to do until the key is valid.
  if (!isValidAPIKey(apiKey)) return;

  if (MIXPANEL_TOKEN) mixpanel.track("Entered API Key"); // KPI

  // Hacky way to get the prompt box to focus after the
  // modal closes. Long term should probably use a ref.
  const focusPromptBox = () => window.document.getElementById("promptBox")?.focus();
  setTimeout(focusPromptBox, 50);
};
29 |
30 | return (
31 | {}}
34 | size="3xl"
35 | isCentered={true}
36 | motionPreset="none"
37 | >
38 |
39 |
40 |
41 |
42 |
43 | We will never upload, log, or store your API key outside of your
44 | browser's local storage. Verify for yourself{" "}
45 |
46 | here
47 |
48 | .
49 |
50 |
51 |
52 |
53 | );
54 | }
55 |
--------------------------------------------------------------------------------
/src/components/modals/SettingsModal.tsx:
--------------------------------------------------------------------------------
1 | import { MIXPANEL_TOKEN } from "../../main";
2 | import { getBranchesNodeTypeDarkColor } from "../../utils/color";
3 | import { DEFAULT_SETTINGS } from "../../utils/constants";
4 | import { Settings, BranchesNodeType } from "../../utils/types";
5 | import { APIKeyInput } from "../utils/APIKeyInput";
6 | import { LabeledSelect, LabeledSlider } from "../utils/LabeledInputs";
7 |
8 | import {
9 | Button,
10 | Modal,
11 | ModalBody,
12 | ModalCloseButton,
13 | ModalContent,
14 | ModalFooter,
15 | ModalHeader,
16 | ModalOverlay,
17 | Checkbox,
18 | } from "@chakra-ui/react";
19 | import mixpanel from "mixpanel-browser";
20 | import { ChangeEvent, memo } from "react";
21 |
22 | export const SettingsModal = memo(function SettingsModal({
23 | isOpen,
24 | onClose,
25 | settings,
26 | setSettings,
27 | apiKey,
28 | setApiKey,
29 | availableModels,
30 | }: {
31 | isOpen: boolean;
32 | onClose: () => void;
33 | settings: Settings;
34 | setSettings: (settings: Settings) => void;
35 | apiKey: string | null;
36 | setApiKey: (apiKey: string) => void;
37 | availableModels: string[] | null;
38 | }) {
// Restores DEFAULT_SETTINGS after the user confirms the prompt.
const reset = () => {
  const confirmed = confirm(
    "Are you sure you want to reset your settings to default? This cannot be undone!"
  );
  if (!confirmed) return;

  setSettings(DEFAULT_SETTINGS);

  if (MIXPANEL_TOKEN) mixpanel.track("Restored defaults");
};
50 |
// Wipes all locally stored data and reloads the page, after two explicit
// confirmations from the user.
const hardReset = () => {
  const confirmedOnce = confirm(
    "Are you sure you want to delete ALL data (including your saved API key, conversations, etc?) This cannot be undone!"
  );
  // The second prompt is only shown when the first one was accepted.
  if (!confirmedOnce) return;

  const confirmedTwice = confirm(
    "Are you 100% sure? Reminder this cannot be undone and you will lose EVERYTHING!"
  );
  if (!confirmedTwice) return;

  // Record the event BEFORE clearing storage and reloading. Sending it after
  // reload() races the page unload, and anything the analytics client
  // persists would be written back after the wipe. The beacon transport is
  // meant for requests issued while the page is going away.
  if (MIXPANEL_TOKEN) {
    mixpanel.track("Performed hard reset", {}, { transport: "sendBeacon" });
  }

  // Clear local storage.
  localStorage.clear();

  // Ensure that the page is reloaded even if there are unsaved changes.
  window.onbeforeunload = null;

  // Reload the window.
  window.location.reload();
};
72 |
73 | return (
74 |
75 |
76 |
77 | Settings
78 |
79 |
80 | {
85 | setSettings({ ...settings, model: v });
86 |
87 | if (MIXPANEL_TOKEN) mixpanel.track("Changed model");
88 | }}
89 | />
90 |
91 |
92 |
93 | {
98 | setSettings({ ...settings, temp: v });
99 |
100 | if (MIXPANEL_TOKEN) mixpanel.track("Changed temperature");
101 | }}
102 | color={getBranchesNodeTypeDarkColor(BranchesNodeType.User)}
103 | max={1.25}
104 | min={0}
105 | step={0.01}
106 | />
107 |
108 | {
113 | setSettings({ ...settings, N_ANSWER_FANOUT: v });
114 |
115 | if (MIXPANEL_TOKEN) mixpanel.track("Changed answer fanout");
116 | }}
117 | color={getBranchesNodeTypeDarkColor(BranchesNodeType.User)}
118 | max={10}
119 | min={1}
120 | step={1}
121 | />
122 |
123 | {
128 | setSettings({ ...settings, N_EXPLANATION_FANOUT: v });
129 |
130 | if (MIXPANEL_TOKEN) mixpanel.track("Changed explanation fanout");
131 | }}
132 | color={getBranchesNodeTypeDarkColor(BranchesNodeType.User)}
133 | max={10}
134 | min={1}
135 | step={1}
136 | />
137 |
138 | ) => {
144 | setSettings({ ...settings, autoZoom: event.target.checked });
145 |
146 | if (MIXPANEL_TOKEN) mixpanel.track("Changed auto zoom");
147 | }}
148 | >
149 | Auto Zoom
150 |
151 |
152 |
153 |
154 |
157 |
158 |
161 |
162 |
163 |
164 | );
165 | });
166 |
--------------------------------------------------------------------------------
/src/components/nodes/CustomNode.tsx:
--------------------------------------------------------------------------------
1 | import { NodeProps, Position, Handle } from "reactflow";
2 | import { FaChevronDown, FaChevronRight } from "react-icons/fa";
3 |
4 | export default function CustomNode({ data }: NodeProps) {
5 | return (
6 |
7 |
8 |
21 |
29 |
41 | {data.label}
42 |
43 |
44 | {data.expandable && (
45 |
46 | {" "}
47 | {/* Adds padding to the right of the icon */}
48 | {data.expanded ? : }
49 |
50 | )}
51 |
52 |
53 |
54 | );
55 | }
56 |
--------------------------------------------------------------------------------
/src/components/nodes/LabelUpdaterNode.tsx:
--------------------------------------------------------------------------------
1 | import { MIXPANEL_TOKEN } from "../../main";
2 | import { Row } from "../../utils/chakra";
3 | import { modifyBranchesNodeLabel, modifyReactFlowNodeProperties } from "../../utils/branchesNode";
4 | import { ToTNodeData } from "../../utils/types";
5 | import { Box, Input, Tooltip } from "@chakra-ui/react";
6 | import mixpanel from "mixpanel-browser";
7 | import { useEffect, useState } from "react";
8 | import { Handle, Position, useReactFlow } from "reactflow";
9 | import { FaChevronDown, FaChevronRight } from "react-icons/fa"; // Import the icons
10 |
11 | export function LabelUpdaterNode({
12 | id,
13 | data,
14 | isConnectable,
15 | }: {
16 | id: string;
17 | data: ToTNodeData;
18 | isConnectable: boolean;
19 | }) {
20 | const { setNodes } = useReactFlow();
21 |
22 | const [renameLabel, setRenameLabel] = useState(data.label);
23 |
24 | const inputId = `renameInput-${id}`;
25 |
// Select the input element on mount.
useEffect(() => {
  const input = document.getElementById(inputId) as HTMLInputElement | null;

  // Have to do this with a bit of a delay to
  // ensure it works when triggered via navbar.
  const timeoutId = setTimeout(() => input?.select(), 50);

  // Cancel the pending selection if the node unmounts within the delay.
  return () => clearTimeout(timeoutId);
}, []);
34 |
// Abandons the rename: the typed label is not applied to the node.
const cancel = () => {
  setNodes((nodes) => {
    // Reset the node type to the default
    // type and make it draggable again.
    const restoredProps = { id, type: undefined, draggable: true };
    return modifyReactFlowNodeProperties(nodes, restoredProps);
  });

  if (MIXPANEL_TOKEN) mixpanel.track("Canceled renaming");
};
48 |
// Applies the label currently held in `renameLabel` to this node.
const submit = () => {
  const labelUpdate = { id, label: renameLabel };
  setNodes((nodes) => modifyBranchesNodeLabel(nodes, labelUpdate));

  if (MIXPANEL_TOKEN) mixpanel.track("Node renamed");
};
59 |
60 | return (
61 |
62 |
63 |
64 |
65 |
66 | {data.expanded ? : }{" "}
67 | setRenameLabel(e.target.value)}
72 | onKeyDown={(e) =>
73 | e.key === "Enter" ? submit() : e.key === "Escape" && cancel()
74 | }
75 | className="nodrag" // https://reactflow.dev/docs/api/nodes/custom-nodes/#prevent-dragging--selecting
76 | textAlign="center"
77 | size="xs"
78 | // px={6}
79 | />
80 |
81 |
82 |
83 |
84 |
85 | );
86 | }
87 |
--------------------------------------------------------------------------------
/src/components/nodes/useAnimatedNodes.tsx:
--------------------------------------------------------------------------------
1 | import { useEffect, useState } from "react";
2 | import { Node, useReactFlow } from "reactflow";
3 | import { timer } from "d3-timer";
4 |
5 | export type UseAnimatedNodeOptions = {
6 | animationDuration?: number;
7 | };
8 |
/**
 * Returns a copy of `nodes` whose positions are animated from where each node
 * currently sits in the React Flow store to the position given in `nodes`.
 *
 * @param nodes Target nodes; `node.position` is the end point of the animation.
 * @param options.animationDuration Length of the transition in milliseconds
 *   (default 300). A duration of 0 applies the target nodes on the first tick.
 * @returns `{ nodes }` holding the in-between nodes for the current frame and
 *   exactly `nodes` once the animation has finished.
 */
function useAnimatedNodes(
  nodes: Node[],
  { animationDuration = 300 }: UseAnimatedNodeOptions = {}
) {
  const [tmpNodes, setTmpNodes] = useState(nodes);
  const { getNode } = useReactFlow();

  useEffect(() => {
    // Start from the position currently held in the store; nodes that are not
    // in the store yet start (and stay) at their target position.
    const transitions = nodes.map((node) => ({
      from: getNode(node.id)?.position ?? node.position,
      to: node.position,
      node,
    }));

    const t = timer((elapsed) => {
      // Finish first: this avoids interpolating with a factor above 1 on the
      // last tick (a visible overshoot frame) and a division by zero when the
      // duration is 0.
      if (elapsed >= animationDuration) {
        // it's important to set the final nodes here to avoid glitches
        setTmpNodes(nodes);
        t.stop();
        return;
      }

      // Linear interpolation factor in [0, 1).
      const s = elapsed / animationDuration;

      setTmpNodes(
        transitions.map(({ node, from, to }) => ({
          ...node,
          position: {
            x: from.x + (to.x - from.x) * s,
            y: from.y + (to.y - from.y) * s,
          },
        }))
      );
    });

    // Stop a running animation when the inputs change or on unmount.
    return () => t.stop();
  }, [nodes, getNode, animationDuration]);

  return { nodes: tmpNodes };
}
48 |
49 | export default useAnimatedNodes;
50 |
--------------------------------------------------------------------------------
/src/components/nodes/useExpandCollapse.tsx:
--------------------------------------------------------------------------------
1 | import { useMemo } from "react";
2 | import { Node, Edge, XYPosition } from "reactflow";
3 | import { HierarchyNode, HierarchyPointNode, stratify, tree } from "d3-hierarchy";
4 | import { ToTNodeData } from "../../utils/types";
5 |
6 | type ExpandCollapseNode = Node;
7 |
8 | export type UseExpandCollapseOptions = {
9 | layoutNodes?: boolean;
10 | treeWidth?: number;
11 | treeHeight?: number;
12 | };
13 |
// Type guard: true when the hierarchy node carries numeric `x` and `y`
// coordinates.
function isHierarchyPointNode(
  pointNode: HierarchyNode | HierarchyPointNode
): pointNode is HierarchyPointNode {
  const { x, y } = pointNode as HierarchyPointNode;
  return typeof x === "number" && typeof y === "number";
}
22 |
/**
 * Derives the nodes and edges to display from the full graph, hiding the
 * descendants of every node whose `data.expanded` is falsy.
 *
 * @param nodes All nodes of the tree.
 * @param edges All edges; an edge's `source` is the parent of its `target`.
 * @param options.layoutNodes When true (default), positions come from a
 *   d3 tree layout; otherwise each node keeps its own `position`.
 * @param options.treeWidth Horizontal node size passed to the layout.
 * @param options.treeHeight Vertical node size passed to the layout.
 * @returns The visible nodes (typed "custom", with fresh `data` objects) and
 *   the edges whose two endpoints are both visible.
 */
function useExpandCollapse(
  nodes: Node[],
  edges: Edge[],
  { layoutNodes = true, treeWidth = 220, treeHeight = 100 }: UseExpandCollapseOptions = {}
): { nodes: Node[]; edges: Edge[] } {
  return useMemo(() => {
    if (nodes.length === 0) {
      return { nodes: [], edges: [] };
    }

    // Index parent ids by child id once instead of scanning `edges` for every
    // node. The first edge pointing at a node wins, like a linear search would.
    const parentIdByChildId = new Map<string, string>();
    for (const edge of edges) {
      if (!parentIdByChildId.has(edge.target)) {
        parentIdByChildId.set(edge.target, edge.source);
      }
    }

    const hierarchy = stratify<Node>()
      .id((d) => d.id)
      // Null (also used for an empty-string source) marks the node as root.
      .parentId((d: Node) => parentIdByChildId.get(d.id) || null)(nodes);

    // NOTE: this writes `expandable` onto the input nodes' `data` objects and
    // prunes the children of collapsed nodes from the hierarchy.
    hierarchy.descendants().forEach((d) => {
      d.data.data.expandable = !!d.children?.length;
      d.children = d.data.data.expanded ? d.children : undefined;
    });

    const layout = tree<Node>()
      .nodeSize([treeWidth, treeHeight])
      .separation(() => 1);

    const root = layoutNodes ? layout(hierarchy) : hierarchy;

    // Everything still reachable from the root after pruning is visible.
    const visible = root.descendants();
    const visibleIds = new Set(visible.map((d) => d.id));

    return {
      nodes: visible.map((d) => ({
        ...d.data,
        // This bit is super important! We *mutated* the object in the `forEach`
        // above so the reference is the same. React needs to see a new reference
        // to trigger a re-render of the node.
        data: { ...d.data.data },
        type: "custom",
        position: isHierarchyPointNode(d) ? { x: d.x, y: d.y } : d.data.position,
      })),
      edges: edges.filter(
        (edge) => visibleIds.has(edge.source) && visibleIds.has(edge.target)
      ),
    };
  }, [nodes, edges, layoutNodes, treeWidth, treeHeight]);
}
68 |
69 | export default useExpandCollapse;
70 |
--------------------------------------------------------------------------------
/src/components/tree.ts:
--------------------------------------------------------------------------------
1 | import { Node } from "reactflow";
2 | import { ToTNodeData, BranchesNodeType } from "../utils/types";
3 |
4 | export const treeDemoNodes: Node[] = [
5 | {
6 | id: "0",
7 | position: { x: 0, y: 0 },
8 | data: {
9 | id: "0",
10 | parent_id: "",
11 | input: "HumanEval/4",
12 | steps: [],
13 | solutions: [],
14 | output: "",
15 | label: "HumanEval/4",
16 | score: 60,
17 | evals: "",
18 | isValid: true,
19 | isTerminal: false,
20 | errors: [],
21 | explanations: [],
22 | expandable: true,
23 | expanded: true,
24 | },
25 | },
26 | ];
27 |
28 | export const treeDemoEdges = [];
29 |
// Builds a display-ready copy of a demo node: fills in a default node type,
// normalises `evals` to an array, derives `text` from steps, evals and label,
// and assigns size, selection state and colours. The input node's `data`
// object is copied, not modified.
function processNode(node: Node): Node {
  const data = { ...node.data } as ToTNodeData;

  // Nodes without an explicit type are treated as GPT nodes.
  data.branchesNodeType = data.branchesNodeType ?? BranchesNodeType.GPT;

  // if (data.text == null) {
  //   data.text = '';
  // }

  // Anything that is not already an array is replaced by an empty list.
  if (!Array.isArray(data.evals)) {
    data.evals = [];
  }

  data.text = `${data.steps.join("\n")}\n${data.evals.join("\n")}\n${data.label}`;

  // Terminal takes precedence over valid; everything else gets the default.
  const background = data.isTerminal
    ? "rgb(233, 216, 253)"
    : data.isValid
    ? "#d9f3d6"
    : "#f7d0a1";
  const outline = data.isTerminal ? "1px dashed #f7d0a1" : "";

  return {
    ...node,
    height: 38,
    width: 150,
    selected: false,
    style: { outline, background },
    data,
  };
}
62 |
63 | export const treeDemo = {
64 | nodes: treeDemoNodes.map(processNode),
65 | edges: treeDemoEdges,
66 | viewport: {
67 | x: 0,
68 | y: 0,
69 | zoom: 1,
70 | },
71 | };
72 |
--------------------------------------------------------------------------------
/src/components/utils/APIKeyInput.tsx:
--------------------------------------------------------------------------------
1 | import { BoxProps } from "@chakra-ui/react";
2 | import { LabeledPasswordInputWithLink } from "./LabeledInputs";
3 |
4 | export function APIKeyInput({
5 | apiKey,
6 | setApiKey,
7 | ...others
8 | }: {
9 | apiKey: string | null;
10 | setApiKey: (apiKey: string) => void;
11 | } & BoxProps) {
12 | return (
13 |
23 | );
24 | }
25 |
--------------------------------------------------------------------------------
/src/components/utils/BigButton.tsx:
--------------------------------------------------------------------------------
1 | import { ButtonProps, Button, Tooltip } from "@chakra-ui/react";
2 |
3 | import { adjustColor } from "../../utils/color";
4 |
5 | export function BigButton({
6 | color,
7 | tooltip,
8 | ...others
9 | }: { color: string; tooltip: string } & ButtonProps) {
10 | return (
11 |
12 |
20 |
21 | );
22 | }
23 |
--------------------------------------------------------------------------------
/src/components/utils/LabeledInputs.tsx:
--------------------------------------------------------------------------------
1 | import { Row } from "../../utils/chakra";
2 | import { ExternalLinkIcon } from "@chakra-ui/icons";
3 | import {
4 | Box,
5 | BoxProps,
6 | Button,
7 | Input,
8 | InputGroup,
9 | InputRightElement,
10 | Link,
11 | Select,
12 | Slider,
13 | SliderFilledTrack,
14 | SliderThumb,
15 | SliderTrack,
16 | Textarea,
17 | } from "@chakra-ui/react";
18 | import { useEffect, useState } from "react";
19 |
20 | export function LabeledSlider({
21 | label,
22 | value,
23 | setValue,
24 | color,
25 | max,
26 | min,
27 | step,
28 | ...others
29 | }: {
30 | label: string;
31 | value: number;
32 | setValue: (value: number) => void;
33 | color: string;
34 | max: number;
35 | min: number;
36 | step: number;
37 | } & BoxProps) {
38 | return (
39 |
40 | {label}: {value}
41 | setValue(v)}
46 | max={max}
47 | min={min}
48 | step={step}
49 | >
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 | );
58 | }
59 |
60 | export function LabeledInput({
61 | label,
62 | value,
63 | setValue,
64 | ...others
65 | }: {
66 | label: string;
67 | value: string;
68 | setValue: (value: string) => void;
69 | } & BoxProps) {
70 | return (
71 |
72 | {label}:
73 | setValue(e.target.value)} value={value} />
74 |
75 | );
76 | }
77 |
78 | export function LabeledPasswordInputWithLink({
79 | label,
80 | linkLabel,
81 | placeholder,
82 | link,
83 | value,
84 | setValue,
85 | ...others
86 | }: {
87 | label: string;
88 | linkLabel: string;
89 | placeholder?: string;
90 | link: string;
91 | value: string;
92 | setValue: (value: string) => void;
93 | } & BoxProps) {
94 | const [show, setShow] = useState(false);
95 |
96 | return (
97 |
98 |
99 | {label}:
100 |
107 | {linkLabel}
108 |
109 |
110 |
111 |
112 | setValue(e.target.value)}
117 | />
118 |
119 |
129 |
130 |
131 |
132 | );
133 | }
134 |
135 | export function LabeledTextArea({
136 | label,
137 | value,
138 | setValue,
139 | textAreaId,
140 | ...others
141 | }: {
142 | label: string;
143 | value: string;
144 | textAreaId: string | undefined;
145 | setValue: (value: string) => void;
146 | } & BoxProps) {
147 | return (
148 |
149 | {label}:
150 |
159 | );
160 | }
161 |
162 | export function SelfSelectingLabeledTextArea({
163 | label,
164 | value,
165 | setValue,
166 | id,
167 | ...others
168 | }: {
169 | label: string;
170 | value: string;
171 | id: string;
172 | setValue: (value: string) => void;
173 | } & BoxProps) {
174 | useEffect(
175 | () => (window.document.getElementById(id) as HTMLTextAreaElement | null)?.select(),
176 | []
177 | );
178 |
179 | return (
180 |
187 | );
188 | }
189 |
190 | export function LabeledSelect({
191 | label,
192 | value,
193 | setValue,
194 | options,
195 | ...others
196 | }: {
197 | label: string;
198 | value: string;
199 | options: string[];
200 | setValue: (value: string) => void;
201 | } & BoxProps) {
202 | return (
203 |
204 | {label}:
205 |
210 |
211 | );
212 | }
213 |
--------------------------------------------------------------------------------
/src/components/utils/Markdown.tsx:
--------------------------------------------------------------------------------
1 | import React, { useState, useEffect, useRef, ReactNode, RefObject } from "react";
2 | import ReactMarkdown, { Components } from "react-markdown";
3 | import "highlight.js/styles/atom-one-light.css";
4 | import rehypeHighlight from "rehype-highlight";
5 | import { Button, Box, Code, Text, useTheme, List, ListItem } from "@chakra-ui/react";
6 | import { CopyIcon } from "@chakra-ui/icons";
7 | import { Row, Column } from "../../utils/chakra";
8 | import { copySnippetToClipboard } from "../../utils/clipboard";
9 | import { solidity, yul } from "highlightjs-solidity";
10 | import { PluggableList } from "unified";
11 |
12 | const CodeblockTitleBar = ({
13 | language,
14 | codeRef,
15 | }: {
16 | language?: string;
17 | codeRef: RefObject;
18 | }) => {
19 | // Grabbing the default font family from Chakra via
20 | // useTheme to override the markdown code font family.
21 | const theme = useTheme();
22 |
23 | return (
24 |
37 | {language || "plaintext"}
38 |
39 |
40 | );
41 | };
42 |
43 | const CopyCodeButton = ({ codeRef }: { codeRef: RefObject }) => {
44 | const [copied, setCopied] = useState(false);
45 |
46 | const handleCopyButtonClick = async (e: React.MouseEvent) => {
47 | e.stopPropagation(); // Prevent this from triggering edit mode in the parent.
48 |
49 | if (await copySnippetToClipboard(stringifyChildren(codeRef.current ?? [])))
50 | setCopied(true);
51 | };
52 |
53 | useEffect(() => {
54 | if (copied) {
55 | const timer = setTimeout(() => setCopied(false), 2000);
56 | return () => clearTimeout(timer);
57 | }
58 | }, [copied]);
59 |
60 | return (
61 |
70 | );
71 | };
72 |
73 | const Codeblock = ({
74 | className,
75 | inline,
76 | children,
77 | ...props
78 | }: {
79 | className?: string;
80 | inline?: boolean;
81 | children: ReactNode[];
82 | }) => {
83 | const match = /language-(\w+)/.exec(className || "");
84 | const codeRef = useRef([]);
85 |
86 | useEffect(() => {
87 | codeRef.current = children;
88 | }, [children]);
89 |
90 | return !inline ? (
91 |
97 |
98 |
106 | {children}
107 |
108 |
109 | ) : (
110 |
116 | {children}
117 |
118 | );
119 | };
120 |
121 | export const Markdown = ({ text }: { text: string }) => {
122 | return (
123 |
124 |
125 | {text}
126 |
127 |
128 | );
129 | };
130 |
131 | const rehypePlugins: PluggableList = [
132 | [rehypeHighlight, { ignoreMissing: true, languages: { solidity, yul } }],
133 | ];
134 |
135 | const components: Components = {
136 | ul({ children }) {
137 | return (
138 |
139 | {children}
140 |
141 | );
142 | },
143 | ol({ children }) {
144 | return (
145 |
146 | {children}
147 |
148 | );
149 | },
150 | li({ children }) {
151 | return (
152 |
153 | {children?.filter(
154 | (child: ReactNode) => !(typeof child === "string" && child.trim() === "")
155 | )}
156 |
157 | );
158 | },
159 | blockquote({ children }) {
160 | return (
161 |
162 | {children?.filter(
163 | (child: ReactNode) => !(typeof child === "string" && child.trim() === "")
164 | )}
165 |
166 | );
167 | },
168 | code(props) {
169 | return ;
170 | },
171 | };
172 |
// Recursively extract the text content from the children prop of a ReactMarkdown
// component. Children may be plain text or (possibly nested) React elements, so
// simple concatenation of the top-level array is not enough.
//
// Only strings and numbers contribute text; booleans, null and undefined are
// skipped. Nested arrays and element children are walked recursively.
//
// A single trailing newline is removed from the FINAL result only. Trimming at
// every recursion level would silently delete newlines that sit at the end of a
// nested element but in the middle of the snippet.
const stringifyChildren = (children: ReactNode[]): string => {
  const collectText = (nodes: ReactNode[]): string =>
    nodes.reduce((collected: string, node: ReactNode) => {
      // Plain text. Handled explicitly so that the number 0 is not dropped.
      if (typeof node === "string" || typeof node === "number") {
        return collected + String(node);
      }

      // Nested arrays of children are flattened.
      if (Array.isArray(node)) {
        return collected + collectText(node);
      }

      if (React.isValidElement(node)) {
        const nested = node.props.children;

        if (nested === undefined || nested === null) return collected;

        return collected + collectText(Array.isArray(nested) ? nested : [nested]);
      }

      // Everything else (booleans, null, undefined, other objects) has no text,
      // which also avoids "[object Object]" showing up in the copied snippet.
      return collected;
    }, "");

  // react-markdown sometimes includes a newline at the end of the children array.
  // We remove it here to avoid a newline at the end of the copied text.
  return collectText(children).replace(/\n$/, "");
};
204 |
--------------------------------------------------------------------------------
/src/components/utils/NavigationBar.tsx:
--------------------------------------------------------------------------------
1 | import { Button, Box, Text } from "@chakra-ui/react";
2 |
3 | import { Row } from "../../utils/chakra";
4 |
5 | export function NavigationBar({
6 | onOpenSettingsModal,
7 | onToggleAnswerFilter,
8 | showAnswerPathOnly,
9 | }: {
10 | onOpenSettingsModal: () => void;
11 | onToggleAnswerFilter: () => void;
12 | showAnswerPathOnly: boolean;
13 | }) {
14 | return (
15 |
21 |
22 | Branches
23 |
24 |
25 |
26 |
27 |
36 |
45 |
46 | );
47 | }
48 |
--------------------------------------------------------------------------------
/src/components/utils/Whisper.tsx:
--------------------------------------------------------------------------------
1 | import { useState, useEffect } from "react";
2 | import { Button, Box, Spinner } from "@chakra-ui/react";
3 | import { FaMicrophone, FaMicrophoneSlash } from "react-icons/fa";
4 | import mixpanel from "mixpanel-browser";
5 | import { MIXPANEL_TOKEN } from "../../main";
6 | import { useHotkeys } from "react-hotkeys-hook";
7 | import { HOTKEY_CONFIG } from "../../utils/constants";
8 | import { getPlatformModifierKey } from "../../utils/platform";
9 |
10 | export const Whisper = ({
11 | onConvertedText,
12 | apiKey,
13 | }: {
14 | onConvertedText: (text: string) => void;
15 | apiKey: string | null;
16 | }) => {
17 | const [isRecording, setIsRecording] = useState(false);
18 | const [isTranscribing, setIsTranscribing] = useState(false);
19 | const [mediaRecorder, setMediaRecorder] = useState(null);
20 | const [hasRecordingSupport, setHasRecordingSupport] = useState(false);
21 | const [isDesktopDevice, setIsDesktopDevice] = useState(false);
22 |
23 | useEffect(() => {
24 | // Not inlined because of some TypeScript nonsense.
25 | if (navigator.mediaDevices && MediaRecorder) {
26 | setHasRecordingSupport(true);
27 | } else setHasRecordingSupport(false);
28 |
29 | setIsDesktopDevice(
30 | // https://stackoverflow.com/questions/11381673/detecting-a-mobile-browser
31 | !(
32 | window.navigator.userAgent?.toLowerCase()?.includes("mobi") ??
33 | window.innerWidth < 1024
34 | )
35 | );
36 | }, []);
37 |
38 | const onDataAvailable = (e: BlobEvent) => {
39 | const formData = new FormData();
40 | formData.append("file", e.data, "recording.webm");
41 | formData.append("model", "whisper-1");
42 |
43 | setIsTranscribing(true);
44 |
45 | fetch("https://api.openai.com/v1/audio/transcriptions", {
46 | method: "POST",
47 | headers: {
48 | Authorization: `Bearer ${apiKey}`,
49 | },
50 | body: formData,
51 | })
52 | .then((response) => response.json())
53 | .then((data) => onConvertedText(data.text))
54 | .catch((err) => console.error("Error transcribing: ", err))
55 | .finally(() => setIsTranscribing(false));
56 | };
57 |
58 | const startRecording = async () => {
59 | setIsRecording(true);
60 |
61 | try {
62 | const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
63 |
64 | const recorder = new MediaRecorder(stream);
65 |
66 | recorder.onstop = () => {
67 | stream.getTracks().forEach((track) => {
68 | track.stop();
69 | });
70 |
71 | if (stream.active) {
72 | stream.getTracks().forEach((track) => {
73 | stream.removeTrack(track);
74 | });
75 | }
76 | };
77 |
78 | recorder.addEventListener("dataavailable", onDataAvailable);
79 | recorder.start();
80 |
81 | setMediaRecorder(recorder);
82 |
83 | if (MIXPANEL_TOKEN) mixpanel.track("Started recording");
84 | } catch (error) {
85 | console.error("Error starting recorder: ", error);
86 | setIsRecording(false);
87 | }
88 | };
89 |
90 | const stopRecording = () => {
91 | if (mediaRecorder) mediaRecorder.stop();
92 |
93 | setIsRecording(false);
94 |
95 | if (MIXPANEL_TOKEN) mixpanel.track("Stopped recording");
96 | };
97 |
98 | const modifierKey = getPlatformModifierKey();
99 |
100 | useHotkeys(
101 | `${modifierKey}+L`,
102 | () => (isRecording ? stopRecording() : startRecording()),
103 | HOTKEY_CONFIG
104 | );
105 |
106 | return (
107 | <>
108 | {hasRecordingSupport && isDesktopDevice && (
109 |
110 |
130 |
131 | )}
132 | >
133 | );
134 | };
135 |
--------------------------------------------------------------------------------
/src/index.css:
--------------------------------------------------------------------------------
1 | .selected {
2 | box-shadow: 0px 0px 0px 2px #e73324, 0px 0px 20px 2px #e73324 !important;
3 | }
4 |
5 | .react-flow__node {
6 | border-radius: 6px;
7 | border-width: 0px;
8 | }
9 |
10 | .react-flow__node:not(.selected):hover {
11 | box-shadow: 0 0 0 0.5px #1a192b !important;
12 | }
13 |
14 | /* Enable if you
15 | want bigger handles:
16 | .react-flow__handle {
17 | width: 20px;
18 | height: 7px;
19 | border-radius: 5px;
20 | } */
21 |
22 | .markdown-wrapper h1,
23 | .markdown-wrapper h2,
24 | .markdown-wrapper h3,
25 | .markdown-wrapper h4,
26 | .markdown-wrapper h5,
27 | .markdown-wrapper h6 {
28 | font-size: inherit;
29 | font-weight: 500;
30 | }
31 |
32 | .markdown-wrapper blockquote {
33 | margin: revert;
34 | }
35 |
36 | .markdown-wrapper h1 {
37 | font-size: 2em;
38 | }
39 |
40 | .markdown-wrapper h2 {
41 | font-size: 1.5em;
42 | }
43 |
44 | .markdown-wrapper h3 {
45 | font-size: 1.17em;
46 | }
47 |
48 | .markdown-wrapper h4 {
49 | font-size: 1em;
50 | }
51 |
52 | .markdown-wrapper h5 {
53 | font-size: 0.83em;
54 | }
55 |
56 | .markdown-wrapper h6 {
57 | font-size: 0.67em;
58 | }
59 |
60 | .markdown-wrapper hr {
61 | background-color: currentColor;
62 | height: 2px;
63 | border: 0px;
64 | border-radius: 6px;
65 | }
66 |
67 | .node-info h4 {
68 | margin-bottom: 1rem;
69 | margin-top: 1rem;
70 | }
71 |
72 | .tab-panel-full-width {
73 | padding-left: 0 !important;
74 | padding-right: 0 !important;
75 | }
76 |
77 | .eval-list li span {
78 | margin-left: 1.5rem;
79 | }
80 | .eval-list li span:first-of-type {
81 | margin-left: 0;
82 | }
--------------------------------------------------------------------------------
/src/main.tsx:
--------------------------------------------------------------------------------
1 | import React from "react";
2 |
3 | import ReactDOM from "react-dom/client";
4 |
5 | import { ReactFlowProvider } from "reactflow";
6 |
7 | import { ChakraProvider } from "@chakra-ui/react";
8 |
9 | import mixpanel from "mixpanel-browser";
10 |
11 | import App from "./components/App";
12 |
13 | import "./index.css";
14 |
15 | export const MIXPANEL_TOKEN = import.meta.env.VITE_MIXPANEL_TOKEN;
16 |
17 | if (MIXPANEL_TOKEN) mixpanel.init(MIXPANEL_TOKEN);
18 |
19 | ReactDOM.createRoot(document.getElementById("root") as HTMLElement).render(
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 | );
28 |
--------------------------------------------------------------------------------
/src/types/highlightjs-solidity.d.ts:
--------------------------------------------------------------------------------
1 | declare module "highlightjs-solidity" {
2 | import { HLJSApi, LanguageFn } from "highlight.js";
3 |
4 | export const solidity: LanguageFn;
5 | export const yul: LanguageFn;
6 |
7 | function hljsDefineSolidity(hljs: HLJSApi): void;
8 |
9 | export default hljsDefineSolidity;
10 | }
11 |
--------------------------------------------------------------------------------
/src/utils/apikey.ts:
--------------------------------------------------------------------------------
/**
 * Checks whether a value has the shape this app expects of an API key:
 * exactly 51 characters long and beginning with "sk-".
 *
 * @param apiKey The candidate key, or null when none is available.
 * @returns false for null/empty input, otherwise whether the shape matches.
 */
export function isValidAPIKey(apiKey: string | null) {
  if (!apiKey) return false;

  const hasExpectedLength = apiKey.length == 51;

  return hasExpectedLength && apiKey.startsWith("sk-");
}
4 |
--------------------------------------------------------------------------------
/src/utils/branchesEdge.ts:
--------------------------------------------------------------------------------
1 | import { Edge } from "reactflow";
2 |
3 | /*//////////////////////////////////////////////////////////////
4 | CONSTRUCTORS
5 | //////////////////////////////////////////////////////////////*/
6 |
/**
 * Creates an edge object connecting `source` to `target`.
 * The edge id is the two node ids joined with a dash.
 */
export function newBranchesEdge(edgeInit: {
  source: string;
  target: string;
  animated: boolean;
}): Edge {
  const { source, target, animated } = edgeInit;
  const id = `${source}-${target}`;

  return { id, source, target, animated };
}
23 |
24 | /*//////////////////////////////////////////////////////////////
25 | TRANSFORMERS
26 | //////////////////////////////////////////////////////////////*/
27 |
/**
 * Returns a new edge list consisting of `existingEdges` followed by a freshly
 * built edge from `source` to `target`. The input array is not modified.
 */
export function addBranchesEdge(
  existingEdges: Edge[],
  { source, target, animated }: { source: string; target: string; animated: boolean }
): Edge[] {
  const appended = newBranchesEdge({ source, target, animated });

  return existingEdges.concat([appended]);
}
36 |
/**
 * Returns a new edge list in which the edge whose id is `${source}-${target}`
 * is replaced by a shallow copy carrying the given `animated` flag.
 * All other edges are passed through by reference.
 */
export function modifyBranchesEdge(
  existingEdges: Edge[],
  { source, target, animated }: { source: string; target: string; animated: boolean }
): Edge[] {
  const edgeIdToModify = `${source}-${target}`;

  return existingEdges.map((edge) =>
    edge.id === edgeIdToModify ? { ...edge, animated } : edge
  );
}
51 |
--------------------------------------------------------------------------------
/src/utils/branchesNode.ts:
--------------------------------------------------------------------------------
1 | import { Node, Edge } from "reactflow";
2 |
3 | import { STALE_STREAM_ERROR_MESSAGE, STREAM_CANCELED_ERROR_MESSAGE } from "./constants";
4 | import { BranchesNodeType, ToTNodeData, ReactFlowNodeTypes } from "./types";
5 | import { getBranchesNodeColor } from "./color";
6 | import { generateNodeId } from "./nodeId";
7 | import { formatAutoLabel, getCurrentNumbers } from "./prompt";
8 |
9 | /*//////////////////////////////////////////////////////////////
10 | CONSTRUCTORS
11 | //////////////////////////////////////////////////////////////*/
12 |
/**
 * Builds a node object at position (x, y).
 *
 * - `id` falls back to `generateNodeId()` when omitted.
 * - Only `style.background` is carried over from the passed style.
 * - `text` is stored both as `data.text` and as the initial `data.label`.
 * - `solutions` and `explanations` default to empty arrays when nullish.
 * - The node starts out `expanded` and `expandable`.
 */
export function newBranchesNode({
  id,
  x,
  y,
  branchesNodeType,
  input,
  text,
  streamId,
  steps,
  solutions,
  style,
  errors,
  explanations,
}: {
  id?: string;
  x: number;
  y: number;
  branchesNodeType: BranchesNodeType;
  input: string;
  text: string;
  streamId?: string;
  steps: string[];
  solutions?: string[];
  style: any;
  errors: string[];
  explanations: string[];
}): Node {
  const nodeId = id ?? generateNodeId();
  const position = { x, y };
  const nodeStyle = { background: style.background };

  const data = {
    expanded: true,
    expandable: true,
    label: text,
    branchesNodeType,
    errors,
    input,
    steps,
    solutions: solutions ?? [],
    explanations: explanations ?? [],
    streamId,
    text,
  };

  return { id: nodeId, position, style: nodeStyle, data };
}
61 |
62 | /*//////////////////////////////////////////////////////////////
63 | TRANSFORMERS
64 | //////////////////////////////////////////////////////////////*/
65 |
/**
 * Returns a new node list in which the node with the given `id` is replaced by
 * a copy (with a shallow-copied `data`) whose `type` and `draggable` are set to
 * the supplied values. Other nodes are passed through by reference.
 */
export function modifyReactFlowNodeProperties(
  existingNodes: Node[],
  {
    id,
    type,
    draggable,
  }: { id: string; type: ReactFlowNodeTypes | undefined; draggable: boolean }
): Node[] {
  return existingNodes.map((node) =>
    node.id === id ? { ...node, data: { ...node.data }, type, draggable } : node
  );
}
82 |
/**
 * Returns a new node list in which the node with the given `id` has its text
 * replaced. Other nodes are passed through by reference.
 *
 * For the matching node (a copy with shallow-copied `data`):
 * - `data.text` and `data.input` are both set to `text`.
 * - If edited by a human (`asHuman`) while its type is GPT, the type becomes
 *   TweakedGPT and the background color is recomputed.
 * - The label is regenerated from the text unless the node has a custom label,
 *   in which case the custom label is left untouched.
 *
 * @param asHuman   Whether the edit was made by a human.
 * @param id        Id of the node to modify.
 * @param text      The new text.
 * @param isRunning Forwarded to `getBranchesNodeColor` when recoloring.
 */
export function modifyBranchesNodeText(
  existingNodes: Node[],
  {
    asHuman,
    id,
    text,
    isRunning,
  }: { asHuman: boolean; id: string; text: string; isRunning: boolean }
): Node[] {
  return existingNodes.map((node) => {
    if (node.id !== id) return node;

    const copy = { ...node, data: { ...node.data } };

    copy.data.text = text;
    copy.data.input = text;

    // If the node's branchesNodeType is GPT and we're changing
    // it as a human then its type becomes GPT + Human.
    if (asHuman && copy.data.branchesNodeType === BranchesNodeType.GPT) {
      copy.style = {
        ...copy.style,
        background: getBranchesNodeColor(
          false,
          isRunning,
          copy.data.isTerminal || false,
          true,
          copy.data.score || 0
        ),
      };

      copy.data.branchesNodeType = BranchesNodeType.TweakedGPT;
    }

    // Generate the auto label from the prompt text. A custom label must survive
    // the edit, so the label is only ever written in this branch.
    if (!copy.data.hasCustomlabel) {
      copy.data.label = copy.data.text
        ? formatAutoLabel(copy.data.text)
        : displayNameFromBranchesNodeType(copy.data.branchesNodeType);
    }

    return copy;
  });
}
128 |
/**
 * Returns a new node list in which the node with the given `id` receives a
 * custom label: `data.label` is set, `data.hasCustomlabel` becomes true, the
 * node's `type` is set to the supplied `type` (possibly undefined) and
 * `draggable` is reset to undefined. Other nodes are passed through by reference.
 */
export function modifyBranchesNodeLabel(
  existingNodes: Node[],
  { id, type, label }: { id: string; type?: BranchesNodeType; label: string }
): Node[] {
  return existingNodes.map((node) =>
    node.id !== id
      ? node
      : {
          ...node,
          data: { ...node.data, label, hasCustomlabel: true },
          type,
          draggable: undefined,
        }
  );
}
146 |
/**
 * Returns a new node list in which the node with the given `id` is replaced by
 * a copy whose `data.streamId` is set to `streamId` (which may be undefined).
 * Other nodes are passed through by reference.
 */
export function setBranchesNodeStreamId(
  existingNodes: Node[],
  { id, streamId }: { id: string; streamId: string | undefined }
) {
  return existingNodes.map((node) =>
    node.id === id ? { ...node, data: { ...node.data, streamId } } : node
  );
}
157 |
/**
 * Reports whether `getCurrentNumbers(text)` is exactly the string "24".
 */
export function checkIfTerminal(text: string): boolean {
  return getCurrentNumbers(text) === "24";
}
162 |
/**
 * Applies streamed text to the node with the given `id` and returns a new node
 * list. Other nodes are passed through by reference.
 *
 * For the matching node:
 * - `data.text` is replaced with `text`.
 * - The last entry of `data.steps` is replaced with `text`.
 * - The last entry of `data.solutions` (when `isSolutionNode`) or of
 *   `data.explanations` (otherwise) is replaced with `text`.
 * - The label is refreshed from the text unless the node has a custom label.
 *
 * The arrays are copied before being written to, so the previous node object
 * (which shares them through the shallow `data` copy) is never mutated. An
 * empty array is left empty rather than gaining a stray `-1` property.
 *
 * @param isSolutionNode Selects whether `solutions` or `explanations` is updated.
 * @throws Error(STREAM_CANCELED_ERROR_MESSAGE) if the node has no stream id.
 * @throws Error(STALE_STREAM_ERROR_MESSAGE) if the node's stream id differs
 *         from `streamId`.
 */
export function appendTextToBranchesNodeAsGPT(
  existingNodes: Node[],
  { id, text, streamId }: { id: string; text: string; streamId: string },
  isSolutionNode: boolean
): Node[] {
  // Copy of `items` with its final entry swapped for the streamed text.
  const withLastEntryReplaced = (items: string[] | undefined): string[] => {
    const next = [...(items ?? [])];

    if (next.length > 0) next[next.length - 1] = text;

    return next;
  };

  return existingNodes.map((node) => {
    if (node.id !== id) return node;

    if (node.data.streamId === undefined) throw new Error(STREAM_CANCELED_ERROR_MESSAGE);
    if (node.data.streamId !== streamId) throw new Error(STALE_STREAM_ERROR_MESSAGE);

    // Must be read from the previous text, before it is replaced below.
    const isFirstToken = node.data.text.length === 0;

    const copy = {
      ...node,
      data: { ...node.data, text, steps: withLastEntryReplaced(node.data.steps) },
    };

    if (isSolutionNode) {
      copy.data.solutions = withLastEntryReplaced(node.data.solutions);
    } else {
      copy.data.explanations = withLastEntryReplaced(node.data.explanations);
    }

    // Preserve custom labels: return before the label is touched at all.
    if (copy.data.hasCustomlabel) return copy;

    copy.data.label = text;

    if (!copy.data.label.endsWith(" ...") || isFirstToken) {
      copy.data.label = formatAutoLabel(copy.data.text);
    }

    return copy;
  });
}
197 |
/**
 * Returns copies of all nodes where `selected` is true only for the node whose
 * id equals `id`, and false for every other node.
 */
export function markOnlyNodeAsSelected(
  existingNodes: Node[],
  id: string
): Node[] {
  return existingNodes.map((node) => {
    const selected = node.id === id;

    return { ...node, selected };
  });
}
206 |
207 | /*//////////////////////////////////////////////////////////////
208 | GETTERS
209 | //////////////////////////////////////////////////////////////*/
210 |
/**
 * Looks up the first node whose id equals `id`.
 *
 * @returns The matching node, or undefined when there is none.
 */
export function getBranchesNode(
  nodes: Node[],
  id: string
): Node | undefined {
  for (const candidate of nodes) {
    if (candidate.id === id) return candidate;
  }

  return undefined;
}
217 |
/**
 * Returns every node whose parent (as resolved by `getBranchesNodeParent`)
 * has the given `id`.
 */
export function getBranchesNodeChildren(
  existingNodes: Node[],
  existingEdges: Edge[],
  id: string
) {
  return existingNodes.filter((candidate) => {
    const parent = getBranchesNodeParent(existingNodes, existingEdges, candidate.id);

    return parent?.id === id;
  });
}
227 |
/**
 * Returns the children of `parentId` other than the node `nodeId` itself.
 */
export function getBranchesNodeSiblings(
  existingNodes: Node[],
  existingEdges: Edge[],
  parentId: string,
  nodeId: string
): Node[] {
  const childrenOfParent = getBranchesNodeChildren(existingNodes, existingEdges, parentId);

  // Everything under the same parent except the node we were asked about.
  return childrenOfParent.filter((child) => child.id !== nodeId);
}
240 |
/**
 * Finds the parent of the node `id`: the source node of the LAST edge in
 * `existingEdges` that targets `id`.
 *
 * The last matching edge wins so that we don't route through a stale
 * (now hidden) edge that appears earlier in the list.
 *
 * @returns The parent node, or undefined if no edge targets `id` or the
 *          edge's source node is not in `existingNodes`.
 */
export function getBranchesNodeParent(
  existingNodes: Node[],
  existingEdges: Edge[],
  id: string
): Node | undefined {
  const incomingEdges = existingEdges.filter((edge) => edge.target === id);

  if (incomingEdges.length === 0) return;

  const newestIncoming = incomingEdges[incomingEdges.length - 1];

  return existingNodes.find((node) => node.id === newestIncoming.source);
}
263 |
// Collect the lineage of the node by walking parent links upwards:
// index 0 is the node itself,
// index 1 is its parent,
// index 2 is its grandparent, and so on.
// Returns an empty array when no node has the given id.
// TODO: Eventually would be nice to have
// support for connecting multiple parents!
export function getBranchesNodeLineage(
  existingNodes: Node[],
  existingEdges: Edge[],
  id: string
): Node[] {
  const lineage: Node[] = [];

  for (
    let cursor = getBranchesNode(existingNodes, id);
    cursor !== undefined;
    cursor = getBranchesNodeParent(existingNodes, existingEdges, cursor.id)
  ) {
    lineage.push(cursor);
  }

  return lineage;
}
287 |
/**
 * Reports whether the node `nodeToCheck` appears in the lineage of
 * `nodeToGetLineageOf` (the node itself counts as part of its own lineage).
 */
export function isBranchesNodeInLineage(
  existingNodes: Node[],
  existingEdges: Edge[],
  { nodeToCheck, nodeToGetLineageOf }: { nodeToCheck: string; nodeToGetLineageOf: string }
): boolean {
  const lineageIds = getBranchesNodeLineage(
    existingNodes,
    existingEdges,
    nodeToGetLineageOf
  ).map((ancestor) => ancestor.id);

  return lineageIds.includes(nodeToCheck);
}
297 |
/**
 * Decides whether an edge from `source` to `target` may be created.
 *
 * A connection is allowed only when both hold:
 * 1. `target` is not in the lineage of `source` (no recursive connection).
 * 2. `target` does not already have a parent.
 */
export function getConnectionAllowed(
  existingNodes: Node[],
  existingEdges: Edge[],
  { source, target }: { source: string; target: string }
): boolean {
  // Check the lineage of the source node to make
  // sure we aren't creating a recursive connection.
  const wouldBeRecursive = isBranchesNodeInLineage(existingNodes, existingEdges, {
    nodeToCheck: target,
    nodeToGetLineageOf: source,
  });

  if (wouldBeRecursive) return false;

  // Check if the target node already has a parent.
  return getBranchesNodeParent(existingNodes, existingEdges, target) === undefined;
}
313 |
314 | /*//////////////////////////////////////////////////////////////
315 | RENDERERS
316 | //////////////////////////////////////////////////////////////*/
317 |
/**
 * Maps a node type to the name shown to the user.
 *
 * @param branchesNodeType The node type to describe.
 * @param isGPT4 For GPT-based types: undefined yields the generic "GPT",
 *               true yields "GPT-4" and false yields "GPT-3.5".
 * @returns The display name; edited GPT nodes get an " (edited)" suffix.
 */
export function displayNameFromBranchesNodeType(
  branchesNodeType: BranchesNodeType,
  isGPT4?: boolean
): string {
  switch (branchesNodeType) {
    case BranchesNodeType.System:
      return "System";
    case BranchesNodeType.User:
      return "User";
    case BranchesNodeType.TweakedGPT: {
      const gptName = displayNameFromBranchesNodeType(BranchesNodeType.GPT, isGPT4);

      return `${gptName} (edited)`;
    }
    case BranchesNodeType.GPT: {
      if (isGPT4 === undefined) return "GPT";

      return isGPT4 ? "GPT-4" : "GPT-3.5";
    }
  }
}
333 |
--------------------------------------------------------------------------------
/src/utils/chakra.tsx:
--------------------------------------------------------------------------------
1 | import { useState, useEffect } from "react";
2 | import { Flex, FlexProps } from "@chakra-ui/react";
3 |
4 | /* Typings */
5 | export type MainAxisAlignmentStrings =
6 | | "space-between"
7 | | "space-around"
8 | | "flex-start"
9 | | "center"
10 | | "flex-end";
11 |
12 | export type MainAxisAlignment =
13 | | MainAxisAlignmentStrings
14 | | { md: MainAxisAlignmentStrings; base: MainAxisAlignmentStrings };
15 |
16 | export type CrossAxisAlignmentStrings = "flex-start" | "center" | "flex-end" | "stretch";
17 |
18 | export type CrossAxisAlignment =
19 | | CrossAxisAlignmentStrings
20 | | {
21 | md: CrossAxisAlignmentStrings;
22 | base: CrossAxisAlignmentStrings;
23 | };
24 |
/** A length in pixels that can be read as a number or as a CSS "px" string. */
export class PixelMeasurement {
  size: number;

  constructor(num: number) {
    this.size = num;
  }

  /** The raw pixel count. */
  asNumber(): number {
    return this.size;
  }

  /** The pixel count followed by "px", e.g. "12px". */
  asPxString(): string {
    return `${this.size}px`;
  }

  /** Same as `asPxString()`, so the instance can be used where a string is expected. */
  toString(): string {
    return this.asPxString();
  }
}
44 |
/**
 * A size expressed as a fraction of the available space.
 * Construction throws when the fraction is greater than 1.
 */
export class PercentageSize {
  percent: number;

  constructor(num: number) {
    const exceedsWhole = num > 1;

    if (exceedsWhole) throw new Error("Cannot have a percentage higher than 1!");

    this.percent = num;
  }
}
56 |
/**
 * A size that carries both a fraction (`percent`) and a pixel count (`pixel`).
 * Construction throws when the fraction is greater than 1.
 */
export class PercentOnDesktopPixelOnMobileSize {
  percent: number;
  pixel: number;

  constructor(sizes: { percentageSize: number; pixelSize: number }) {
    const { percentageSize, pixelSize } = sizes;

    if (percentageSize > 1) throw new Error("Cannot have a percentage higher than 1!");

    this.percent = percentageSize;
    this.pixel = pixelSize;
  }
}
76 |
/** A fixed size expressed in pixels. */
export class PixelSize {
  pixel: number;

  constructor(num: number) {
    // Stored as-is; no validation is applied.
    this.pixel = num;
  }
}
84 |
/** A fixed pixel size with separate values for desktop and mobile. */
export class ResponsivePixelSize {
  desktop: number;
  mobile: number;

  constructor(sizes: { desktop: number; mobile: number }) {
    this.mobile = sizes.mobile;
    this.desktop = sizes.desktop;
  }
}
94 |
95 | /**************************************
96 | *
97 | *
98 | * Components
99 | * - Center.tsx
100 | * - Column.tsx
101 | * - Row.tsx
102 | * - RowOnDesktopColumnOnMobile.tsx
103 | * - RowOrColumn.tsx
104 | *
105 | ***************************************
106 | */
107 |
108 | /**
109 | * Center.tsx
110 | *
111 | * Creates a Flex where `justifyContent === 'center'` and `alignItems === 'center'`
112 | * If `expand === true` it will set the height and width of the Flex to 100%.
113 | * Passes all extra props to the Flex.
114 | */
115 |
116 | export type CenterProps = {
117 | children: React.ReactNode;
118 | expand?: boolean;
119 | } & FlexProps;
120 |
121 | export const Center = ({ children, expand, ...others }: CenterProps) => {
122 | if (expand) {
123 | others.height = "100%";
124 | others.width = "100%";
125 | }
126 |
127 | return (
128 |
129 | {children}
130 |
131 | );
132 | };
133 |
134 | /**
135 | * Column.tsx
136 | *
137 | * Creates a Flex with a column direction
138 | * and sets the `justifyContent` to the `mainAxisAlignment`
139 | * and the `alignItems` to the `crossAxisAlignment`.
140 | * If `expand === true` it will set the height and width of the Flex to 100%.
141 | * Passes all extra props to the Flex.
142 | */
143 |
144 | export type ColumnProps = {
145 | mainAxisAlignment: MainAxisAlignment;
146 | crossAxisAlignment: CrossAxisAlignment;
147 | children: React.ReactNode;
148 | expand?: boolean;
149 | } & FlexProps;
150 |
151 | export const Column = ({
152 | mainAxisAlignment,
153 | crossAxisAlignment,
154 | children,
155 | expand,
156 | ...others
157 | }: ColumnProps) => {
158 | if (expand) {
159 | others.height = "100%";
160 | others.width = "100%";
161 | }
162 |
163 | return (
164 |
170 | {children}
171 |
172 | );
173 | };
174 |
175 | /**
176 | * Row.tsx
177 | *
178 | * Creates a Flex with a row direction
179 | * and sets the `justifyContent` to the `mainAxisAlignment`
180 | * and the `alignItems` to the `crossAxisAlignment`.
181 | * If `expand === true` it will set the height and width of the Flex to 100%.
182 | * Passes all extra props to the Flex.
183 | */
184 |
185 | export type RowProps = {
186 | mainAxisAlignment: MainAxisAlignment;
187 | crossAxisAlignment: CrossAxisAlignment;
188 | children: React.ReactNode;
189 | expand?: boolean;
190 | } & FlexProps;
191 |
192 | export const Row = ({
193 | mainAxisAlignment,
194 | crossAxisAlignment,
195 | children,
196 | expand,
197 | ...others
198 | }: RowProps) => {
199 | if (expand) {
200 | others.height = "100%";
201 | others.width = "100%";
202 | }
203 |
204 | return (
205 |
211 | {children}
212 |
213 | );
214 | };
215 |
216 | /**
217 | * RowOnDesktopColumnOnMobile.tsx
218 | *
219 | * Creates a Flex with a row direction on desktop and a column direction on mobile.
220 | * and sets the `justifyContent` to the `mainAxisAlignment`
221 | * and the `alignItems` to the `crossAxisAlignment`.
222 | * If `expand === true` it will set the height and width of the Flex to 100%.
223 | * Passes all extra props to the Flex.
224 | */
225 | export const RowOnDesktopColumnOnMobile = ({
226 | mainAxisAlignment,
227 | crossAxisAlignment,
228 | children,
229 | expand,
230 | ...others
231 | }: RowProps) => {
232 | if (expand) {
233 | others.height = "100%";
234 | others.width = "100%";
235 | }
236 |
237 | return (
238 |
244 | {children}
245 |
246 | );
247 | };
248 |
249 | /**
250 | * RowOrColumn.tsx
251 | *
252 | * Creates a Flex which will be a row if `isRow` is true
253 | * and sets the `justifyContent` to the `mainAxisAlignment`
254 | * and the `alignItems` to the `crossAxisAlignment`.
255 | * If `expand === true` it will set the height and width of the Flex to 100%.
256 | * Passes all extra props to the Flex.
257 | */
258 | export const RowOrColumn = ({
259 | mainAxisAlignment,
260 | crossAxisAlignment,
261 | children,
262 | expand,
263 | isRow,
264 | ...others
265 | }: RowProps & { isRow: boolean }) => {
266 | if (expand) {
267 | others.height = "100%";
268 | others.width = "100%";
269 | }
270 |
271 | return (
272 |
278 | {children}
279 |
280 | );
281 | };
282 |
283 | /**************************************
284 | *
285 | *
286 | * Hooks
287 | * - useWindowSize.ts
288 | * - useLockedViewHeight.ts
289 | * - useIsMobile.ts
290 | * - useSpacedLayout.ts
291 | *
292 | ***************************************
293 | */
294 |
/**
 * useWindowSize.ts
 *
 * Tracks the inner width and height of the window, updating on every
 * window "resize" event.
 */
export const useWindowSize = () => {
  const measureWindow = () => ({
    width: window.innerWidth,
    height: window.innerHeight,
  });

  const [windowSize, setWindowSize] = useState(measureWindow());

  useEffect(() => {
    // Copies the current window dimensions into state.
    const syncWithWindow = () => setWindowSize(measureWindow());

    window.addEventListener("resize", syncWithWindow);

    // Run once immediately so state reflects the size at mount time.
    syncWithWindow();

    // Stop listening when the component unmounts.
    return () => window.removeEventListener("resize", syncWithWindow);
  }, []); // No dependencies: subscribe once on mount.

  return windowSize;
};
328 |
/**
 * useLockedViewHeight.ts
 *
 * Returns the window height as a PixelMeasurement, clamped to `min`/`max`.
 * `isLocked` is true whenever the height is at or beyond one of the bounds.
 */
export function useLockedViewHeight({
  min = -1,
  max = Number.MAX_SAFE_INTEGER,
}: {
  min?: number;
  max?: number;
}) {
  const { height } = useWindowSize();

  // At or below the lower bound: pin to the minimum.
  if (height <= min) {
    return { windowHeight: new PixelMeasurement(min), isLocked: true };
  }

  // At or above the upper bound: pin to the maximum.
  if (height >= max) {
    return { windowHeight: new PixelMeasurement(max), isLocked: true };
  }

  return { windowHeight: new PixelMeasurement(height), isLocked: false };
}
361 |
/**
 * useIsMobile.ts
 *
 * Returns true when the window is narrower than 768 pixels,
 * which makes it likely a mobile device.
 */
export function useIsMobile() {
  const windowSize = useWindowSize();

  return windowSize.width < 768;
}
372 |
/**
 * useSpacedLayout.ts
 *
 * Resolves a list of child size descriptions into concrete pixel sizes.
 *
 * The flexible height is the parent height minus the spacing between
 * children and minus every fixed-pixel child. Percentage-based children
 * receive their fraction of that flexible height; fixed children keep
 * their pixel size. Which value a responsive child uses depends on
 * `useIsMobile()`.
 *
 * Returns the parent height, the spacing and the resolved child sizes,
 * each wrapped in a PixelMeasurement.
 */
export function useSpacedLayout({
  parentHeight,
  spacing,
  childSizes,
}: {
  parentHeight: number;
  spacing: number;
  childSizes: (
    | PercentageSize
    | PercentOnDesktopPixelOnMobileSize
    | PixelSize
    | ResponsivePixelSize
  )[];
}) {
  const isMobile = useIsMobile();

  type ChildSize = (typeof childSizes)[number];

  // Pixels a child takes out of the flexible height (0 for percentage children).
  const fixedPixelsTakenBy = (size: ChildSize): number => {
    if (
      size instanceof PixelSize ||
      (size instanceof PercentOnDesktopPixelOnMobileSize && isMobile)
    ) {
      return size.pixel;
    }

    if (size instanceof ResponsivePixelSize) {
      return isMobile ? size.mobile : size.desktop;
    }

    return 0;
  };

  let fixedPixelTotal = 0;
  for (const size of childSizes) {
    fixedPixelTotal = fixedPixelTotal + fixedPixelsTakenBy(size);
  }

  const flexibleHeight =
    parentHeight - spacing * (childSizes.length - 1) - fixedPixelTotal;

  // Final pixel size of a single child.
  const resolvePixels = (size: ChildSize): number => {
    if (
      size instanceof PercentageSize ||
      (size instanceof PercentOnDesktopPixelOnMobileSize && !isMobile)
    ) {
      return size.percent * flexibleHeight;
    }

    if (size instanceof PercentOnDesktopPixelOnMobileSize && isMobile) {
      return size.pixel;
    }

    if (size instanceof ResponsivePixelSize) {
      return isMobile ? size.mobile : size.desktop;
    }

    return size.pixel;
  };

  const resolvedChildren: PixelMeasurement[] = childSizes.map(
    (size) => new PixelMeasurement(resolvePixels(size))
  );

  return {
    parentHeight: new PixelMeasurement(parentHeight),
    spacing: new PixelMeasurement(spacing),
    childSizes: resolvedChildren,
  };
}
439 |
--------------------------------------------------------------------------------
/src/utils/clipboard.ts:
--------------------------------------------------------------------------------
1 | // Copies `text` to the system clipboard. Resolves to true on success and to
2 | // false (after logging the error) when the clipboard write is rejected.
3 | export const copySnippetToClipboard = async (text: string): Promise<boolean> => {
4 |   try {
5 |     await navigator.clipboard.writeText(text);
6 |     return true;
7 |   } catch (err) {
8 |     console.error("Failed to copy to clipboard", err);
9 |     return false;
10 |   }
11 | };
12 |
--------------------------------------------------------------------------------
/src/utils/color.ts:
--------------------------------------------------------------------------------
1 | import { BranchesNodeType } from "./types";
2 |
3 | /**
4 |  * Lightens (positive `amount`) or darkens (negative `amount`) a hex color by
5 |  * adding `amount` to each RGB channel, clamped to the 0-255 range.
6 |  *
7 |  * @param color  "rrggbb" or shorthand "rgb", with or without a leading "#".
8 |  * @param amount Value added to every channel; the result is rounded to an integer.
9 |  * @returns The adjusted color, always formatted as "#rrggbb".
10 |  */
11 | export function adjustColor(color: string, amount: number) {
12 |   let hex = color.replace(/^#/, "");
13 |
14 |   // Expand shorthand ("abc" -> "aabbcc") so every channel is a full two-digit pair.
15 |   if (hex.length === 3) {
16 |     hex = hex.replace(/./g, (digit) => digit + digit);
17 |   }
18 |
19 |   const shifted = hex.replace(/../g, (channel) => {
20 |     const value = Math.min(255, Math.max(0, Math.round(parseInt(channel, 16) + amount)));
21 |     // padStart keeps single-digit channels two characters wide ("a" -> "0a").
22 |     return value.toString(16).padStart(2, "0");
23 |   });
24 |
25 |   return "#" + shifted;
26 | }
15 |
16 | /**
17 |  * Picks the background color of a node from its state flags.
18 |  *
19 |  * Precedence: running, then explanation, then terminal, then invalid. A node
20 |  * matching none of these gets the default color. `score` is accepted but not
21 |  * used.
22 |  */
23 | export function getBranchesNodeColor(
24 |   isExplanation: boolean,
25 |   isRunning: boolean,
26 |   isTerminal: boolean,
27 |   isValid: boolean,
28 |   score: number
29 | ) {
30 |   // Listed in precedence order: the first flag that holds decides the color.
31 |   const candidates: [boolean, string][] = [
32 |     [isRunning, "#c5e2f6"],
33 |     [isExplanation, "#EEEEEE"],
34 |     [isTerminal, "#e9d8fd"],
35 |     [!isValid, "#f7d0a1"],
36 |   ];
37 |
38 |   const winner = candidates.find(([applies]) => applies);
39 |   return winner ? winner[1] : "#c5e2f6";
40 | }
38 |
39 | // Returns the dark accent color for a node type. Every type currently maps to
40 | // the same color, so the argument is not inspected.
41 | export function getBranchesNodeTypeDarkColor(branchesNodeType: BranchesNodeType) {
42 |   const DARK_ACCENT_COLOR = "#619F83";
43 |   return DARK_ACCENT_COLOR;
44 | }
42 |
--------------------------------------------------------------------------------
/src/utils/constants.ts:
--------------------------------------------------------------------------------
1 | import { UseToastOptions } from "@chakra-ui/toast";
2 |
3 | import { Options } from "react-hotkeys-hook";
4 |
5 | import { NodeProps } from "reactflow";
6 |
7 | import { ReactFlowNodeTypes, Settings } from "./types";
8 |
9 | import CustomNode from "../components/nodes/CustomNode";
10 |
11 | export const REACT_FLOW_NODE_TYPES = { // node-type key -> component lookup
12 | custom: CustomNode, // the "custom" key maps to the CustomNode component
13 | };
14 |
15 | export const DEFAULT_SETTINGS: Settings = { // NOTE(review): presumably the values used when nothing is stored — confirm against callers
16 | temp: 1.2, // presumably the sampling temperature — confirm against Settings
17 | N_ANSWER_FANOUT: 3,
18 | N_EXPLANATION_FANOUT: 3,
19 | autoZoom: true,
20 | model: "gpt-3.5-turbo",
21 | defaultPreamble: `You are ChatGPT, a large language model trained by OpenAI. Answer as concisely as possible. Knowledge cutoff: 2021-09 Current date: ${
22 | new Date().toISOString().split("T")[0]
23 | }`, // the date part is YYYY-MM-DD taken from toISOString() (UTC), evaluated once when this module loads
24 | };
25 |
26 | export const HOTKEY_CONFIG: Options = { // typed as react-hotkeys-hook Options
27 | preventDefault: true,
28 | enableOnFormTags: true,
29 | };
30 |
31 | export const TOAST_CONFIG: UseToastOptions = { // typed as Chakra UI UseToastOptions
32 | isClosable: true,
33 | variant: "left-accent",
34 | position: "bottom-left",
35 | };
36 |
37 | export const MAX_HISTORY_SIZE = 256;
38 |
39 | export const OVERLAP_RANDOMNESS_MAX = 20;
40 |
41 | export const API_KEY_LOCAL_STORAGE_KEY = "FLUX_OPENAI_API_KEY"; // localStorage key names (this and the three below)
42 | export const REACT_FLOW_LOCAL_STORAGE_KEY = "FLUX_REACT_FLOW_DATA";
43 | export const MODEL_SETTINGS_LOCAL_STORAGE_KEY = "FLUX_MODEL_SETTINGS";
44 | export const SAVED_CHAT_SIZE_LOCAL_STORAGE_KEY = "FLUX_SAVED_CHAT_SIZE";
45 |
46 | export const NEW_TREE_CONTENT_QUERY_PARAM = "newTreeWith"; // name of a URL query parameter
47 |
48 | export const UNDEFINED_RESPONSE_STRING = "[UNDEFINED RESPONSE]";
49 |
50 | export const FIT_VIEW_SETTINGS = { padding: 0.1, duration: 200 }; // NOTE(review): presumably passed to React Flow fitView — confirm
51 |
52 | export const NEW_TREE_X_OFFSET = 600;
53 |
54 | export const STREAM_CANCELED_ERROR_MESSAGE = "STREAM_CANCELED";
55 | export const STALE_STREAM_ERROR_MESSAGE = "STALE_STREAM";
56 |
57 | // Character budget for auto-generated labels; chosen so the label text almost always fits on two lines.
58 | export const MAX_AUTOLABEL_CHARS = 32;
59 |
--------------------------------------------------------------------------------
/src/utils/debounce.ts:
--------------------------------------------------------------------------------
1 | import { useEffect, useState } from "react";
2 |
3 | /**
4 |  * Runs `callback` once `timeout` milliseconds have passed without `timeout`
5 |  * or any entry of `deps` changing.
6 |  *
7 |  * @returns true while a run is pending (handy for loading indicators).
8 |  */
9 | export function useDebouncedEffect(
10 |   callback: () => void,
11 |   timeout: number,
12 |   deps: any[]
13 | ): boolean {
14 |   const [isWaiting, setIsWaiting] = useState(false);
15 |
16 |   useEffect(() => {
17 |     const runAndSettle = () => {
18 |       callback();
19 |
20 |       setIsWaiting(false);
21 |     };
22 |
23 |     setIsWaiting(true);
24 |     const pendingTimer = setTimeout(runAndSettle, timeout);
25 |
26 |     // Re-running the effect first cancels the timer that is still pending,
27 |     // which is what makes this a debounce.
28 |     return () => {
29 |       clearTimeout(pendingTimer);
30 |     };
31 |   }, [timeout, ...deps]);
32 |
33 |   return isWaiting;
34 | }
26 |
--------------------------------------------------------------------------------
/src/utils/humanEval.ts:
--------------------------------------------------------------------------------
1 | import HUMAN_EVAL_PROBLEMS from "./human_eval_problems.json";
2 | import { llm } from "./llm"
3 | // Prompt Templates
4 | // const q_template = (question: string): string => {
5 | // return `
6 | // You are a smart and capable agent. Given the header of python function, only output the body of the program.
7 |
8 | // QUESTION:
9 | // ----
10 | // ${question}
11 | // ----
12 | // ANSWER:
13 | // ----
14 | // `;
15 | // }
16 |
17 | // Pass-through prompt template: the question text is used as the prompt unchanged.
18 | export const q_template = (question: string): string => `${question}`;
20 |
21 | const error2explanation = (question: string, answer: string, error: string): string => { // builds the prompt asking for an explanation of a failed answer's traceback
22 | return `
23 | You are a smart and capable agent and can learn from your mistakes. You can correctly debug and code a python program.
24 | Only output the explanation of the traceback error so that you can fix the previous answer by rewriting. Do not output code.
25 |
26 | QUESTION:
27 | ----
28 | ${question}
29 | ----
30 | ANSWER:
31 | ----
32 | ${answer}
33 | ----
34 | ERROR TRACEBACK:
35 | ----
36 | ${error}
37 | ----
38 | EXPLANATION:
39 | ----
40 | `; // the prompt ends on an open EXPLANATION section for the model to complete
41 | };
42 |
43 | const explanation2code = ( // builds the prompt asking for a corrected answer, given the failed answer, its traceback and an explanation
44 | question: string,
45 | answer: string,
46 | error: string,
47 | explanation: string
48 | ): string => {
49 | return `
50 | You are a smart and capable agent who can learn from mistakes. Given an incorrect code and its error traceback, correct the completion answer by incorporating the explanation.
51 | Only output the body of the completion answer.
52 |
53 | QUESTION:
54 | ----
55 | ${question}
56 | ----
57 | ANSWER:
58 | ----
59 | ${answer}
60 | ----
61 | ERROR TRACEBACK:
62 | ----
63 | ${error}
64 | ----
65 | EXPLANATION:
66 | ----
67 | ${explanation}
68 | ----
69 | ANSWER:
70 | ----
71 | `; // the prompt ends on a second, open ANSWER section for the model to complete
72 | };
73 |
74 | // The problem sent to the executor is looked up by task id, so the id must name
75 | // the same problem as the prompt in `q1`: `rescale_to_unit` is HumanEval/21.
76 | const t1 = "HumanEval/21";
77 |
78 | const q1 = `from typing import List\n\n\ndef rescale_to_unit(numbers: List[float]) -> List[float]:\n """ Given list of numbers (of at least two elements), apply a linear transform to that list,\n such that the smallest number will become 0 and the largest will become 1\n >>> rescale_to_unit([1.0, 2.0, 3.0, 4.0, 5.0])\n [0.0, 0.25, 0.5, 0.75, 1.0]\n """\n`;
79 |
80 | // Step 1: ask the model for a first completion.
81 | const prompt = q_template(q1);
82 | const answer = (await llm(prompt)) as string;
83 | console.log(answer);
84 | console.log(typeof answer);
85 |
86 | const data = {
87 |   // "task_id": t1,
88 |   // "prompt": prompt,
89 |   problem: HUMAN_EVAL_PROBLEMS[t1],
90 |   completion: answer,
91 | };
92 |
93 | // Step 2: send the completion to the execution server.
94 | const url = "http://127.0.0.1:5000/execute"; // Replace with your server's URL
95 | const response = await fetch(url, {
96 |   method: "POST",
97 |   headers: {
98 |     "Content-Type": "application/json",
99 |   },
100 |   body: JSON.stringify(data),
101 | });
102 |
103 | // Fail loudly on an HTTP error instead of trying to parse an error page as JSON.
104 | if (!response.ok) {
105 |   throw new Error(`Execution server returned ${response.status} ${response.statusText}`);
106 | }
107 |
108 | const jsonResponse = await response.json();
109 | console.log(jsonResponse);
110 |
111 | // Step 3: if the code failed, ask for an explanation and then for a corrected answer.
112 | if (jsonResponse.result.passed === false) {
113 |   const traceback = jsonResponse.result.result;
114 |
115 |   const ex_prompt = error2explanation(q1, answer, traceback);
116 |   const explanation = (await llm(ex_prompt, 1)) as string;
117 |   console.log(explanation);
118 |
119 |   const ans_prompt = explanation2code(q1, answer, traceback, explanation);
120 |   const re_ans = (await llm(ans_prompt, 1)) as string;
121 |   console.log(re_ans);
122 | }
115 |
--------------------------------------------------------------------------------
/src/utils/llm.ts:
--------------------------------------------------------------------------------
1 | import OpenAI from "openai";
2 |
3 | const openai = new OpenAI({
4 | // No options are passed: the client falls back to process.env["OPENAI_API_KEY"].
5 | // To hardcode the key instead (not recommended for production), add:
6 | // apiKey: 'YOUR_OPENAI_API_KEY'
7 | });
8 |
9 | const MAX_RETRIES = 5; // Retries allowed after the first failed attempt (so at most MAX_RETRIES + 1 calls)
10 | const RETRY_DELAY_BASE = 2000; // Base backoff delay in milliseconds; doubled for each successive retry
11 |
12 | /**
13 |  * Sends `promptContent` as a single user message to the chat completions API.
14 |  *
15 |  * @param promptContent Text of the user message.
16 |  * @param samples       Number of completions to request (`n`).
17 |  * @param max_tokens    Token limit passed through to the API.
18 |  * @returns The completion text when `samples` is 1, otherwise one string per choice.
19 |  * @throws The last API error once MAX_RETRIES retries have also failed.
20 |  */
21 | export async function llm(
22 |   promptContent: string,
23 |   samples: number = 1,
24 |   max_tokens: number = 300
25 | ): Promise<string | string[]> {
26 |   // attempt 0 is the initial call; attempts 1..MAX_RETRIES are retries.
27 |   for (let attempt = 0; ; attempt++) {
28 |     try {
29 |       const completion = await openai.chat.completions.create({
30 |         messages: [{ role: "user", content: promptContent }],
31 |         // model: 'gpt-4-0613',
32 |         // model: 'gpt-3.5-turbo-0613',
33 |         model: "gpt-4",
34 |         max_tokens, // use the parameter value
35 |         n: samples,
36 |         temperature: 0.7, // adjust if necessary
37 |       });
38 |
39 |       if (samples === 1) {
40 |         return completion.choices[0]["message"]["content"] as string;
41 |       }
42 |       return completion.choices.map((choice) => choice["message"]["content"] as string);
43 |     } catch (error) {
44 |       console.error("Error calling OpenAI API:", error);
45 |
46 |       if (attempt >= MAX_RETRIES) {
47 |         console.error("Max retry attempts reached. Giving up.");
48 |         throw error;
49 |       }
50 |
51 |       // Exponential backoff: base delay doubled for every retry already made.
52 |       const delay = Math.pow(2, attempt) * RETRY_DELAY_BASE;
53 |       console.log(`Retrying in ${delay}ms...`);
54 |       await new Promise((resolve) => setTimeout(resolve, delay));
55 |     }
56 |   }
57 | }
59 |
--------------------------------------------------------------------------------
/src/utils/lstore.ts:
--------------------------------------------------------------------------------
1 | import { useState, useEffect } from "react";
2 |
3 | /**
4 |  * Reads `key` from localStorage and JSON-parses the stored text.
5 |  *
6 |  * @returns The parsed value, or null when the entry is missing, empty, or not
7 |  *          valid JSON (the parse error is logged).
8 |  */
9 | export function readLocalStorage<T>(key: string): T | null {
10 |   const storedValue = localStorage.getItem(key);
11 |   if (!storedValue) return null;
12 |
13 |   try {
14 |     return JSON.parse(storedValue) as T;
15 |   } catch (err) {
16 |     // A corrupted entry should not take the caller down; report it and fall back.
17 |     console.error(`Ignoring unparsable localStorage entry "${key}"`, err);
18 |     return null;
19 |   }
20 | }
7 |
8 | // Stores `value` under `key` in localStorage as its JSON text.
9 | export function writeLocalStorage(key: string, value: any) {
10 |   const serialized = JSON.stringify(value);
11 |   localStorage.setItem(key, serialized);
12 | }
11 |
12 | /**
13 |  * React state that is mirrored to localStorage under `key`.
14 |  *
15 |  * The state starts from the stored value, is re-read whenever `key` changes,
16 |  * and is written back whenever `key` or the value changes.
17 |  *
18 |  * @returns A `[value, setValue]` pair, as with `useState`.
19 |  */
20 | export function useLocalStorage<T>(
21 |   key: string
22 | ): [T | null, React.Dispatch<React.SetStateAction<T | null>>] {
23 |   const [value, setValue] = useState<T | null>(() => readLocalStorage(key));
24 |
25 |   // Reload the state from storage when the caller switches to another key.
26 |   useEffect(() => setValue(readLocalStorage(key)), [key]);
27 |
28 |   // Persist the current value under the current key.
29 |   useEffect(() => writeLocalStorage(key, value), [key, value]);
30 |
31 |   return [value, setValue];
32 | }
23 |
--------------------------------------------------------------------------------
/src/utils/mod.ts:
--------------------------------------------------------------------------------
1 | // Modulo whose result takes the sign of `m`, unlike `%`, which follows the sign of `n`.
2 | export function mod(n: number, m: number) {
3 |   const remainder = n % m;
4 |   return (remainder + m) % m;
5 | }
4 |
--------------------------------------------------------------------------------
/src/utils/models.ts:
--------------------------------------------------------------------------------
1 | /**
2 |  * Fetches the ids of all models the given API key can access, sorted
3 |  * alphabetically.
4 |  *
5 |  * @throws Error when the request fails at the network level or the API
6 |  *         answers with a non-2xx status (for example an invalid key).
7 |  */
8 | export async function getAvailableModels(apiKey: string): Promise<string[]> {
9 |   const response = await fetch("https://api.openai.com/v1/models", {
10 |     method: "GET",
11 |     headers: {
12 |       Authorization: `Bearer ${apiKey}`,
13 |     },
14 |   });
15 |
16 |   // An error payload has no `data` array; report the HTTP status instead of
17 |   // failing later with an unrelated TypeError.
18 |   if (!response.ok) {
19 |     throw new Error(`Failed to list models: ${response.status} ${response.statusText}`);
20 |   }
21 |
22 |   const data = await response.json();
23 |   return data.data.map((model: any) => model.id).sort();
24 | }
17 |
18 | /**
19 |  * Fetches the available model ids and keeps only those starting with "gpt-".
20 |  * Rejects with whatever error getAvailableModels rejects with.
21 |  */
22 | export async function getAvailableChatModels(apiKey: string): Promise<string[]> {
23 |   const models = await getAvailableModels(apiKey);
24 |   return models.filter((model) => model.startsWith("gpt-"));
25 | }
29 |
--------------------------------------------------------------------------------
/src/utils/nodeId.ts:
--------------------------------------------------------------------------------
1 | // Builds a node id from the decimal text of Math.random() with the "0." removed.
2 | export function generateNodeId(): string {
3 |   const randomText = String(Math.random());
4 |   return randomText.replace("0.", "");
5 | }
4 |
5 | // Builds a stream id from the decimal text of Math.random() with the "0." removed.
6 | export function generateStreamId(): string {
7 |   const randomText = String(Math.random());
8 |   return randomText.replace("0.", "");
9 | }
8 |
--------------------------------------------------------------------------------
/src/utils/platform.ts:
--------------------------------------------------------------------------------
1 | // Hotkey modifier name: "meta" when navigator.platform is "MacIntel", otherwise "ctrl".
2 | export function getPlatformModifierKey() {
3 |   const isMac = window.navigator.platform === "MacIntel";
4 |   if (isMac) return "meta";
5 |   return "ctrl";
6 | }
4 |
5 | // Display text for the modifier key: "⌘" when navigator.platform is "MacIntel", otherwise " Ctrl ".
6 | export function getPlatformModifierKeyText() {
7 |   const isMac = window.navigator.platform === "MacIntel";
8 |   if (isMac) return "⌘";
9 |   return " Ctrl ";
10 | }
8 |
--------------------------------------------------------------------------------
/src/utils/prompt.ts:
--------------------------------------------------------------------------------
1 | import { ToTNodeData, HumanEvalProblemsType } from "./types";
2 | import { ChatCompletionRequestMessage } from "openai-streams";
3 | import { MAX_AUTOLABEL_CHARS } from "./constants";
4 | import { Node } from "reactflow";
5 | import * as nunjucks from "nunjucks";
6 | import rawHumanEvalProblems from "./human_eval_problems.json";
7 | const HUMAN_EVAL_PROBLEMS = rawHumanEvalProblems as HumanEvalProblemsType;
8 |
9 | /**
10 |  * Builds the single-message "propose next steps" request for a node.
11 |  *
12 |  * The numbers shown to the model are the node's input when it has no steps
13 |  * yet, otherwise the "left:" numbers parsed from its most recent step.
14 |  */
15 | export function messageFromNode(
16 |   currNode: Node
17 | ): ChatCompletionRequestMessage[] {
18 |   const { input, output, steps } = currNode.data;
19 |
20 |   console.log(input);
21 |   console.log(output);
22 |   console.log(steps);
23 |
24 |   const currNumsStr: string =
25 |     steps.length === 0 ? input : getCurrentNumbers(steps[steps.length - 1]);
26 |
27 |   const prompt = proposePrompt(currNumsStr);
28 |   console.log("this is the prompt", prompt);
29 |
30 |   const messages: ChatCompletionRequestMessage[] = [{ role: "user", content: prompt }];
31 |
32 |   console.table(messages);
33 |
34 |   return messages;
35 | }
38 |
39 | /**
40 |  * Builds the single-message request for a HumanEval node: the node's input is
41 |  * used as the key into HUMAN_EVAL_PROBLEMS and that problem's prompt is sent.
42 |  */
43 | export function humanEvalMessageFromNode(
44 |   currNode: Node
45 | ): ChatCompletionRequestMessage[] {
46 |   const problem = HUMAN_EVAL_PROBLEMS[currNode.data.input];
47 |   const prompt: string = problem["prompt"];
48 |   console.log("this is the human eval prompt", prompt);
49 |
50 |   const messages: ChatCompletionRequestMessage[] = [{ role: "user", content: prompt }];
51 |
52 |   console.table(messages);
53 |
54 |   return messages;
55 | }
56 |
57 | /**
58 |  * Wraps the error2explanation prompt for the given question, failed answer
59 |  * and error text in a single user message.
60 |  */
61 | export function explanationMessage(
62 |   question: string,
63 |   answer: string,
64 |   error: string
65 | ): ChatCompletionRequestMessage[] {
66 |   const messages: ChatCompletionRequestMessage[] = [
67 |     { role: "user", content: error2explanation(question, answer, error) },
68 |   ];
69 |
70 |   console.log("explanation message");
71 |   console.table(messages);
72 |
73 |   return messages;
74 | }
75 |
76 | /**
77 |  * Wraps the explanation2code prompt for the given question, failed answer,
78 |  * error text and explanation in a single user message.
79 |  */
80 | export function regenMessage(
81 |   question: string,
82 |   answer: string,
83 |   error: string,
84 |   explanation: string
85 | ): ChatCompletionRequestMessage[] {
86 |   const messages: ChatCompletionRequestMessage[] = [
87 |     { role: "user", content: explanation2code(question, answer, error, explanation) },
88 |   ];
89 |
90 |   console.log("regen message");
91 |   console.table(messages);
92 |
93 |   return messages;
94 | }
94 |
95 | const explanation2code = ( // builds the prompt asking for a corrected answer, given the failed answer, its traceback and an explanation
96 | question: string,
97 | answer: string,
98 | error: string,
99 | explanation: string
100 | ): string => {
101 | return `
102 | You are a smart and capable agent who can learn from mistakes. Given an incorrect code and its error traceback, correct the completion answer by incorporating the explanation.
103 | Only output the body of the completion answer.
104 |
105 | QUESTION:
106 | ----
107 | ${question}
108 | ----
109 | ANSWER:
110 | ----
111 | ${answer}
112 | ----
113 | ERROR TRACEBACK:
114 | ----
115 | ${error}
116 | ----
117 | EXPLANATION:
118 | ----
119 | ${explanation}
120 | ----
121 | ANSWER:
122 | ----
123 | `; // the prompt ends on a second, open ANSWER section for the model to complete
124 | };
125 |
126 | const error2explanation = (question: string, answer: string, error: string): string => { // builds the prompt asking for an explanation of a failed answer's traceback
127 | return `
128 | You are a smart and capable agent and can learn from your mistakes. You can correctly debug and code a python program.
129 | Only output the explanation of the traceback error so that you can fix the previous answer by rewriting. Do not output code.
130 |
131 | QUESTION:
132 | ----
133 | ${question}
134 | ----
135 | ANSWER:
136 | ----
137 | ${answer}
138 | ----
139 | ERROR TRACEBACK:
140 | ----
141 | ${error}
142 | ----
143 | EXPLANATION:
144 | ----
145 | `; // the prompt ends on an open EXPLANATION section for the model to complete
146 | };
147 |
148 | /**
149 |  * Extracts the remaining numbers from the last line of a step string, i.e.
150 |  * the text between the final "left: " and the next ")". When the line has no
151 |  * "left: " marker, everything before its first ")" is returned.
152 |  */
153 | export function getCurrentNumbers(val: string): string {
154 |   console.log("val", val);
155 |
156 |   const lines = val.trim().split("\n");
157 |   const lastLine = lines[lines.length - 1] || "";
158 |
159 |   const pieces = lastLine.split("left: ");
160 |   const afterMarker = pieces[pieces.length - 1];
161 |
162 |   return afterMarker?.split(")")[0] || "";
163 | }
153 |
154 | // PROPOSE PROMPT: few-shot template (Jinja-style, rendered with nunjucks) plus its default examples.
155 | const proposePromptTemplate = `{% for example in examples %}
156 | Input: {{ example.input }}
157 | Possible next steps:
158 | {% for next_step in example.next_steps %}{{ next_step }}
159 | {% endfor %}{% endfor %}
160 | Provide only 4 possible next steps.
161 | Input: {{ input }}
162 | Possible next steps:
163 | `; // template variables: `examples` (each with `input` and `next_steps`) and `input`
164 |
165 | const proposeExamples = [ // default few-shot examples; each step reads "a op b = c (left: remaining numbers)"
166 | {
167 | input: "3 8 9",
168 | next_steps: [
169 | "9 / 3 = 3 (left: 3 8)",
170 | "3 * 8 = 24 (left: 24 9)",
171 | "9 * 3 = 27 (left: 27 8)",
172 | "9 - 8 = 1 (left: 1 3)",
173 | ],
174 | },
175 | {
176 | input: "3 3 7",
177 | next_steps: [
178 | "3 + 7 = 10 (left: 10 3)",
179 | "3 * 3 = 9 (left: 9 7)",
180 | "7 - 3 = 4 (left: 4 3)",
181 | "3 - 3 = 0 (left: 0 7)",
182 | ],
183 | },
184 | ];
185 |
186 | // Returns a prompt builder: it renders the propose template with `input` and
187 | // `examples` (default: proposeExamples) and hands the rendered text to `fn`.
188 | function textPromptDecorator(fn: Function) {
189 |   return (input: string, examples = proposeExamples) => {
190 |     const context = { input, examples };
191 |     return fn(nunjucks.renderString(proposePromptTemplate, context));
192 |   };
193 | }
194 |
195 | // The propose prompt is the rendered template itself, with no post-processing.
196 | const proposePrompt = textPromptDecorator(
197 |   (renderedTemplate: string) => renderedTemplate
198 | );
200 |
201 | /**
202 |  * Turns free text into a short label: invalid characters are removed and
203 |  * anything longer than MAX_AUTOLABEL_CHARS is shortened and suffixed with " ...".
204 |  *
205 |  * Shortening cuts at the last space inside the limit so words stay whole. When
206 |  * there is no usable space (one very long word), the text is cut at the limit
207 |  * instead of being reduced to the bare " ..." suffix.
208 |  */
209 | export function formatAutoLabel(text: string) {
210 |   const formattedText = removeInvalidChars(text);
211 |
212 |   if (formattedText.length <= MAX_AUTOLABEL_CHARS) {
213 |     return formattedText;
214 |   }
215 |
216 |   const truncated = formattedText.slice(0, MAX_AUTOLABEL_CHARS);
217 |   const lastSpace = truncated.lastIndexOf(" ");
218 |   const head = lastSpace > 0 ? truncated.slice(0, lastSpace) : truncated;
219 |
220 |   return head + " ...";
221 | }
209 |
210 | // VALUE PROMPT: few-shot template asking for a sure/likely/impossible judgement, plus its default examples.
211 | const valuePromptTemplate = `Evaluate if given numbers can reach 24 (sure/likely/impossible)
212 | {% for example in examples %}
213 | Input: {{ example.input }}
214 | {% for step in example.steps %}
215 | {{ step }}
216 | {% endfor %}
217 | {{ example.output }}
218 | {% endfor %}
219 | Input: {{input}}
220 | `; // template variables: `examples` (each with `input`, `steps`, `output`) and `input`
221 |
222 | const value_examples = [ // default few-shot examples; `output` is the judgement shown after the steps
223 | { input: "10 14", steps: ["10 + 14 = 24"], output: "sure" },
224 | {
225 | input: "11 12",
226 | steps: ["11 + 12 = 23", "12 - 11 = 1", "11 * 12 = 132", "11 / 12 = 0.91"],
227 | output: "impossible",
228 | },
229 | {
230 | input: "4 4 10",
231 | steps: [
232 | "4 + 4 + 10 = 8 + 10 = 18",
233 | "4 * 10 - 4 = 40 - 4 = 36",
234 | "(10 - 4) * 4 = 6 * 4 = 24",
235 | ],
236 | output: "sure",
237 | },
238 | { input: "4 9 11", steps: ["9 + 11 + 4 = 20 + 4 = 24"], output: "sure" },
239 | {
240 | input: "5 7 8",
241 | steps: [
242 | "5 + 7 + 8 = 12 + 8 = 20",
243 | "(8 - 5) * 7 = 3 * 7 = 21",
244 | "I cannot obtain 24 now, but numbers are within a reasonable range",
245 | ],
246 | output: "likely",
247 | },
248 | {
249 | input: "5 6 6",
250 | steps: [
251 | "5 + 6 + 6 = 17",
252 | "(6 - 5) * 6 = 1 * 6 = 6",
253 | "I cannot obtain 24 now, but numbers are within a reasonable range",
254 | ],
255 | output: "likely",
256 | },
257 | {
258 | input: "10 10 11",
259 | steps: ["10 + 10 + 11 = 31", "(11 - 10) * 10 = 10", "10 10 10 are all too big"],
260 | output: "impossible",
261 | },
262 | {
263 | input: "1 3 3",
264 | steps: ["1 * 3 * 3 = 9", "(1 + 3) * 3 = 12", "1 3 3 are all too small"],
265 | output: "impossible",
266 | },
267 | { input: "24", steps: ["24 = 24 (solved, no steps needed)"], output: "sure" },
268 | ];
269 |
270 | // Returns a prompt builder: it renders the value template with `input` and
271 | // `examples` (default: value_examples) and hands the rendered text to `fn`.
272 | function valuePromptDecorator(fn: Function) {
273 |   return (input: string, examples = value_examples) => {
274 |     const context = { input, examples };
275 |     return fn(nunjucks.renderString(valuePromptTemplate, context));
276 |   };
277 | }
278 |
279 | // The value prompt is the rendered template itself, with no post-processing.
280 | const valuePrompt = valuePromptDecorator(
281 |   (renderedTemplate: string) => renderedTemplate
282 | );
283 |
284 | /**
285 |  * Builds the single-message evaluation request for a proposed step: the
286 |  * "left:" numbers are parsed out of `text` and placed into the value prompt.
287 |  */
288 | export function evalMessageFromText(text: string): ChatCompletionRequestMessage[] {
289 |   const remainingNumbers = getCurrentNumbers(text);
290 |   const prompt = valuePrompt(remainingNumbers);
291 |
292 |   const messages: ChatCompletionRequestMessage[] = [{ role: "user", content: prompt }];
293 |
294 |   console.table(messages);
295 |
296 |   return messages;
297 | }
300 |
301 | /**
302 |  * Sums the scores of the value-prompt judgements in `valueOutputs`.
303 |  *
304 |  * A sample's judgement is its last line once surrounding whitespace is
305 |  * trimmed, compared case-insensitively: "sure" = 20, "likely" = 1,
306 |  * "impossible" = 0.001. Any other text (including an empty sample) scores 0.
307 |  */
308 | export function parseAndCompute(valueOutputs: string[]): number {
309 |   const valueMap: { [key: string]: number } = {
310 |     impossible: 0.001,
311 |     likely: 1,
312 |     sure: 20,
313 |   };
314 |
315 |   function computeValue(sample: string): number {
316 |     // Trim first so a trailing newline does not make the "last line" empty.
317 |     const lines = sample.trim().split("\n");
318 |     const valueName = lines[lines.length - 1].trim().toLowerCase();
319 |
320 |     // Own-property check so words like "constructor" cannot resolve to
321 |     // inherited object members.
322 |     return Object.prototype.hasOwnProperty.call(valueMap, valueName)
323 |       ? valueMap[valueName]
324 |       : 0;
325 |   }
326 |
327 |   return valueOutputs.map(computeValue).reduce((a, b) => a + b, 0);
328 | }
315 |
316 | const valueLastStepPromptTemplate = `Use numbers and basic arithmetic operations (+ - * /) to obtain 24. Given an input and an answer, give a judgement (sure/impossible) if the answer is correct, i.e. it uses each input exactly once and no other numbers, and reach 24.
317 | {% for example in examples %}
318 | Input: {{ example.input }}
319 | Answer: {{ example.answer }}
320 | Judge: {{ example.judge }}
321 | {% endfor %}
322 | Input: {{input}}
323 | Answer: {{answer}}
324 | Judge:`; // template variables: `examples` (each with `input`, `answer`, `judge`), `input` and `answer`
325 |
326 | const value_last_step_examples = [ // default few-shot examples for judging a final answer
327 | { input: "3 3 5", answer: "(5 + 3) * 3 = 24", judge: "sure" },
328 | { input: "3 3 5", answer: "(3 - 3) * 5 = 24", judge: "impossible" },
329 | { input: "2 5 7", answer: "(7 + 5) * 2 = 24", judge: "sure" },
330 | { input: "2 5 7", answer: "(7 - 5) * 2 = 24", judge: "impossible" },
331 | { input: "5 8 8", answer: "(8 + 8) - 5 = 24", judge: "impossible" },
332 | { input: "5 8 8", answer: "(8 - 5) / 8 = 24", judge: "impossible" },
333 | ];
334 |
335 | // Returns a prompt builder: it renders the last-step judgement template with
336 | // `input`, `answer` and `examples` (default: value_last_step_examples) and
337 | // hands the rendered text to `fn`.
338 | function valueLastStepPromptDecorator(fn: Function) {
339 |   return (input: string, answer: string, examples = value_last_step_examples) => {
340 |     const context = { input, answer, examples };
341 |     return fn(nunjucks.renderString(valueLastStepPromptTemplate, context));
342 |   };
343 | }
345 |
346 | // COT PROMPT: few-shot chain-of-thought template; it ends on an open "Steps:" section.
347 | const cotPromptTemplate = `Use numbers and basic arithmetic operations (+ - * /) to obtain 24. Be sure to use numbers uniquely only once. Each step, you are only allowed to choose two of the remaining numbers to obtain a new number.
348 | {% for example in examples %}
349 | Input: {{ example.input }}
350 | Steps:
351 | {% for step in example.steps %}
352 | {{ step }}
353 | {% endfor %}
354 | Answer: {{ example.output }}
355 | {% endfor %}
356 | Input: {{input}}
357 | Steps:\n
358 | `; // template variables: `examples` (each with `input`, `steps`, `output`) and `input`
359 |
360 | // Default few-shot examples for the chain-of-thought prompt. Each `output`
361 | // must combine the example's own input numbers, each used exactly once, and
362 | // must follow from its `steps`.
363 | const cot_examples = [
364 |   {
365 |     input: "3 3 5",
366 |     steps: ["3 + 5 = 8 (left: 8 3)", "8 * 3 = 24 (left: 24)"],
367 |     output: "(3 + 5) * 3 = 24",
368 |   },
369 |   {
370 |     input: "3 8 9",
371 |     steps: ["9 / 3 = 3 (left: 3 8)", "3 * 8 = 24 (left: 24)"],
372 |     output: "(9 / 3) * 8 = 24",
373 |   },
374 |   {
375 |     input: "5 8 8",
376 |     steps: ["8 - 5 = 3 (left: 3 8)", "3 * 8 = 24 (left: 24)"],
377 |     // The second factor is the remaining input 8, not the intermediate 3.
378 |     output: "(8 - 5) * 8 = 24",
379 |   },
380 |   {
381 |     input: "3 3 9",
382 |     steps: ["9 * 3 = 27 (left: 27 3)", "27 - 3 = 24 (left: 24)"],
383 |     output: "(9 * 3) - 3 = 24",
384 |   },
385 |   {
386 |     input: "2 5 7",
387 |     steps: ["7 + 5 = 12 (left: 12 2)", "12 * 2 = 24 (left: 24)"],
388 |     output: "(7 + 5) * 2 = 24",
389 |   },
390 | ];
387 |
388 | type RenderFunction = (template: string) => string;
389 |
390 | /**
391 |  * Returns a prompt builder: it renders the chain-of-thought template with
392 |  * `input` and `examples` (default: cot_examples) and passes the rendered text
393 |  * through `fn`.
394 |  *
395 |  * The return type spells out the builder's real signature, so passing custom
396 |  * examples type-checks as well as the one-argument call.
397 |  */
398 | function cotPromptDecorator(
399 |   fn: RenderFunction
400 | ): (
401 |   input: string,
402 |   examples?: { input: string; steps: string[]; output: string }[]
403 | ) => string {
404 |   type Example = {
405 |     input: string;
406 |     steps: string[];
407 |     output: string;
408 |   };
409 |
410 |   return function (input: string, examples: Example[] = cot_examples): string {
411 |     const renderedTemplate = nunjucks.renderString(cotPromptTemplate, {
412 |       input,
413 |       examples,
414 |     });
415 |     return fn(renderedTemplate);
416 |   };
417 | }
418 |
419 | // The chain-of-thought prompt is the rendered template itself, with no post-processing.
420 | const cotPrompt = cotPromptDecorator((renderedTemplate: string) => {
421 |   return renderedTemplate;
422 | });
409 |
410 | /**
411 |  * Builds the single-message request asking for the final answer: the
412 |  * chain-of-thought prompt for the node's input, followed by all of the node's
413 |  * steps except the last (newline-separated), a newline, and `text`.
414 |  */
415 | export function cotMessageFromNode(
416 |   currNode: Node,
417 |   text: string
418 | ): ChatCompletionRequestMessage[] {
419 |   const promptHead = cotPrompt(currNode.data.input);
420 |   const earlierSteps = currNode.data.steps.slice(0, -1).join("\n");
421 |   const prompt = promptHead + earlierSteps + "\n" + text;
422 |
423 |   console.log("this is text for answer prompt", text);
424 |   console.log("this is answer prompt", prompt);
425 |
426 |   const messages: ChatCompletionRequestMessage[] = [{ role: "user", content: prompt }];
427 |
428 |   console.table(messages);
429 |
430 |   return messages;
431 | }
435 |
436 | // Flattens `text` onto one line and removes every character outside the
437 | // allow-list below: ASCII letters and digits, whitespace, and the punctuation
438 | // and symbols listed in the character class.
439 | function removeInvalidChars(text: string) {
440 |   // Negated class: matches runs of characters that are NOT on the allow-list.
441 |   const disallowedRuns = /[^a-zA-Z0-9.,'?!-\s+=*\/<>():%_{}[\]&|^~@;#$]+/g;
442 |
443 |   // Newlines become spaces first, so lines do not run together.
444 |   const singleLine = text.split("\n").join(" ");
445 |
446 |   return singleLine.replace(disallowedRuns, "");
447 | }
449 |
--------------------------------------------------------------------------------
/src/utils/qparams.ts:
--------------------------------------------------------------------------------
1 | // Returns the first value of `parameterName` in the current URL's query string, or null if absent.
2 | export function getQueryParam(parameterName: string) {
3 |   const queryString = window.location.search;
4 |   return new URLSearchParams(queryString).get(parameterName);
5 | }
5 |
6 | // Replaces the current history entry with the bare pathname, i.e. without the
7 | // query string, and without adding a new history entry.
8 | export function resetURL() {
9 |   const { pathname } = window.location;
10 |   history.replaceState({}, document.title, pathname);
11 | }
9 |
--------------------------------------------------------------------------------
/src/utils/rand.ts:
--------------------------------------------------------------------------------
1 | // Returns a random number that is at least `min` and less than `max` (for min < max); not rounded.
2 | export function randomNumber(min: number, max: number) {
3 |   const span = max - min;
4 |   return Math.random() * span + min;
5 | }
4 |
--------------------------------------------------------------------------------
/src/utils/resize.ts:
--------------------------------------------------------------------------------
1 | import { useEffect } from "react";
2 |
3 | /**
4 |  * Calls `callback` once the window has gone `timeout` milliseconds without a
5 |  * resize event. The listener and any pending timer are removed on cleanup.
6 |  */
7 | export function useDebouncedWindowResize(callback: () => void, timeout: number) {
8 |   useEffect(() => {
9 |     let pendingTimer: NodeJS.Timeout;
10 |
11 |     function onResize() {
12 |       // Every new resize event restarts the wait.
13 |       clearTimeout(pendingTimer);
14 |       pendingTimer = setTimeout(callback, timeout);
15 |     }
16 |
17 |     window.addEventListener("resize", onResize);
18 |
19 |     return function cleanup() {
20 |       window.removeEventListener("resize", onResize);
21 |       clearTimeout(pendingTimer);
22 |     };
23 |   }, [callback, timeout]);
24 | }
20 |
--------------------------------------------------------------------------------
/src/utils/tot.ts:
--------------------------------------------------------------------------------
1 | import * as nunjucks from "nunjucks";
2 | import { writeFile } from "fs";
3 | import * as fs from "fs";
4 | import { llm } from "./llm";
5 | import * as math from "mathjs";
6 |
7 | /**
8 |  * Checks a model answer for the game of 24.
9 |  *
10 |  * The expression is taken from the last line of `output` (lower-cased, with
11 |  * an "answer: " prefix removed and everything from the first "=" dropped).
12 |  * It is valid when it uses exactly the numbers that appear in `input` and
13 |  * evaluates to 24.
14 |  */
15 | function validateLLMOutput(input: string, output: string): boolean {
16 |   const lastLine = output.trim().split("\n").pop() ?? "";
17 |   const expression = lastLine.toLowerCase().replace("answer: ", "").split("=")[0] || "";
18 |
19 |   // Compare the numbers as sorted lists. The "," separator matters: without
20 |   // it "1 23" and "12 3" would both collapse to "123" and wrongly match.
21 |   const numberKey = (text: string) => (text.match(/\d+/g) || []).sort().join(",");
22 |   if (numberKey(expression) !== numberKey(input)) {
23 |     return false;
24 |   }
25 |
26 |   try {
27 |     const value = math.evaluate(expression);
28 |     // Allow for floating-point error, e.g. 8 / (3 - 8 / 3) is 23.999999999999996.
29 |     return typeof value === "number" && Math.abs(value - 24) < 1e-9;
30 |   } catch (e) {
31 |     return false;
32 |   }
33 | }
29 |
30 | nunjucks.configure({ autoescape: true }); // NOTE(review): autoescape presumably HTML-escapes interpolated values (e.g. "<", ">", "&") — confirm this is wanted for plain-text prompts
31 |
32 | // One state in the search tree: the problem input, the steps taken so far,
33 | // and an optional final output.
34 | class Node {
35 |   input: string;
36 |   steps: string[] = [];
37 |   output?: string | null = null;
38 |
39 |   // `steps` is only taken when truthy; `output` is taken unless it is
40 |   // undefined, so an explicit null is kept.
41 |   constructor(input: string, steps?: string[], output?: string | null) {
42 |     this.input = input;
43 |     if (steps) {
44 |       this.steps = steps;
45 |     }
46 |     if (output !== undefined) {
47 |       this.output = output;
48 |     }
49 |   }
50 |
51 |   // Text form: Node(input='..', steps=['..', '..'], output='..').
52 |   toRepr(): string {
53 |     const quotedSteps = this.steps.map((step) => "'" + step + "'").join(", ");
54 |     return (
55 |       "Node(input='" + this.input + "', steps=[" + quotedSteps + "], output='" + this.output + "')"
56 |     );
57 |   }
58 | }
50 |
51 | /**
52 |  * Writes `data` to `<filename>.json`, pretty-printed with 4-space indentation.
53 |  * Every Node found anywhere inside `data` is replaced by its toRepr() string.
54 |  * The write is asynchronous; a failure is logged, not thrown.
55 |  */
56 | function logDictToJson(data: any, filename: string): void {
57 |   // Recursively copies arrays and objects, swapping Node instances for text.
58 |   function convertNodesToRepr(obj: any): any {
59 |     if (obj instanceof Node) {
60 |       return obj.toRepr();
61 |     }
62 |     if (obj === null || typeof obj !== "object") {
63 |       return obj;
64 |     }
65 |
66 |     const copy: any = Array.isArray(obj) ? [] : {};
67 |     for (const key in obj) {
68 |       copy[key] = convertNodesToRepr(obj[key]);
69 |     }
70 |     return copy;
71 |   }
72 |
73 |   const serializable = convertNodesToRepr(data);
74 |   const jsonString = JSON.stringify(serializable, null, 4);
75 |
76 |   writeFile(`${filename}.json`, jsonString, (err) => {
77 |     if (err) {
78 |       console.error("Failed to save the JSON file:", err);
79 |     }
80 |   });
81 | }
73 |
/**
 * Extracts the remaining numbers from the last line of a step string.
 *
 * Takes the text after the last "left: " marker on the final line (or the whole line if
 * the marker is absent) and cuts it at the first ")" (if any).
 * Example: "9 / 3 = 3 (left: 3 8)" -> "3 8".
 */
function getCurrentNumbers(val: string): string {
  const MARKER = "left: ";

  const trimmed = val.trim();
  const lastLine = trimmed.slice(trimmed.lastIndexOf("\n") + 1);

  const markerAt = lastLine.lastIndexOf(MARKER);
  const afterMarker = markerAt < 0 ? lastLine : lastLine.slice(markerAt + MARKER.length);

  const closeAt = afterMarker.indexOf(")");
  return closeAt < 0 ? afterMarker : afterMarker.slice(0, closeAt);
}
78 |
79 | // PROPOSE PROMPT
// Nunjucks template for the "propose next steps" prompt. For each entry of `examples` it
// prints an "Input:" line followed by that entry's `next_steps`, one per line; it then
// prints the live `input` and ends on an open "Possible next steps:" header for the model
// to complete. The template text is whitespace-sensitive.
80 | const proposePromptTemplate = `{% for example in examples %}
81 | Input: {{ example.input }}
82 | Possible next steps:
83 | {% for next_step in example.next_steps %}{{ next_step }}
84 | {% endfor %}{% endfor %}
85 | Input: {{ input }}
86 | Possible next steps:
87 | `;
88 |
// Default few-shot examples for the propose prompt. Each entry pairs an `input` (the numbers
// currently available) with candidate `next_steps` written as
// "<a> <op> <b> = <result> (left: <remaining numbers>)".
89 | const proposeExamples = [
90 | {
91 | input: "3 8 9",
92 | next_steps: [
93 | "9 / 3 = 3 (left: 3 8)",
94 | "3 * 8 = 24 (left: 24 9)",
95 | "9 * 3 = 27 (left: 27 8)",
96 | "9 - 8 = 1 (left: 1 3)",
97 | ],
98 | },
99 | {
100 | input: "3 3 7",
101 | next_steps: [
102 | "3 + 7 = 10 (left: 10 3)",
103 | "3 * 3 = 9 (left: 9 7)",
104 | "7 - 3 = 4 (left: 4 3)",
105 | "3 - 3 = 0 (left: 0 7)",
106 | ],
107 | },
108 | ];
109 |
/**
 * Builds the propose-prompt function: the returned function renders
 * `proposePromptTemplate` with the given `input` and `examples` (defaulting to
 * `proposeExamples`) and hands the rendered text to `fn`, returning whatever `fn` returns.
 */
function textPromptDecorator(fn: Function) {
  const renderAndForward = (input: string, examples = proposeExamples) => {
    const context = { input, examples };
    const rendered = nunjucks.renderString(proposePromptTemplate, context);
    return fn(rendered);
  };
  return renderAndForward;
}
120 |
/** Renders the propose prompt for an input and returns the rendered text unchanged. */
export const proposePrompt = textPromptDecorator(
  (renderedTemplate: string) => renderedTemplate
);
124 |
type LLMOUTPUT = string[] | string;

/**
 * Splits LLM output into candidate next steps.
 *
 * For a string, each line is trimmed, blank lines are discarded (so stray empty lines or
 * CRLF residue never become empty "steps" that use up the fanout budget), and at most
 * `maxSteps` lines are kept. For an array, every element is processed the same way and
 * the results are concatenated, so the cap applies per element, not to the total.
 *
 * @param llmOutput One completion or several completions.
 * @param maxSteps  Maximum number of steps kept per completion (default 5).
 * @returns The extracted steps; an empty array for any other kind of value.
 */
function getNextSteps(llmOutput: LLMOUTPUT, maxSteps: number = 5): string[] {
  if (typeof llmOutput === "string") {
    return llmOutput
      .split("\n")
      .map((line) => line.trim())
      .filter((line) => line.length > 0)
      .slice(0, maxSteps);
  }

  if (Array.isArray(llmOutput)) {
    return llmOutput.reduce<string[]>(
      (all, item) => all.concat(getNextSteps(item, maxSteps)),
      []
    );
  }

  // Defensive fallback for values that are neither a string nor an array.
  return [];
}
140 |
141 | // COT PROMPT
// Nunjucks template for the chain-of-thought answer prompt: a task instruction, then for
// each entry of `examples` its input, steps and final "Answer:" line, then the live `input`
// followed by an open "Steps:" header.
// NOTE(review): unlike the propose template, the {% for %}/{% endfor %} tags here sit on
// their own lines, so their surrounding newlines presumably end up in the rendered prompt
// as blank lines unless block trimming is enabled — confirm this is intended.
142 | const cotPromptTemplate = `Use numbers and basic arithmetic operations (+ - * /) to obtain 24. Be sure to use numbers uniquely only once. Each step, you are only allowed to choose two of the remaining numbers to obtain a new number.
143 | {% for example in examples %}
144 | Input: {{ example.input }}
145 | Steps:
146 | {% for step in example.steps %}
147 | {{ step }}
148 | {% endfor %}
149 | Answer: {{ example.output }}
150 | {% endfor %}
151 | Input: {{input}}
152 | Steps:
153 | `;
154 |
// Default few-shot examples for the chain-of-thought prompt. Each entry gives the puzzle
// `input`, the intermediate `steps` (in "<a> <op> <b> = <result> (left: ...)" form) and the
// final `output` expression, which must use exactly the input numbers and equal 24.
const cot_examples = [
  {
    input: "3 3 5",
    steps: ["3 + 5 = 8 (left: 8 3)", "8 * 3 = 24 (left: 24)"],
    output: "(3 + 5) * 3 = 24",
  },
  {
    input: "3 8 9",
    steps: ["9 / 3 = 3 (left: 3 8)", "3 * 8 = 24 (left: 24)"],
    output: "(9 / 3) * 8 = 24",
  },
  {
    input: "5 8 8",
    steps: ["8 - 5 = 3 (left: 3 8)", "3 * 8 = 24 (left: 24)"],
    // The final expression multiplies by the remaining 8 (not by the intermediate 3),
    // so that it matches the steps above and uses only the input numbers.
    output: "(8 - 5) * 8 = 24",
  },
  {
    input: "3 3 9",
    steps: ["9 * 3 = 27 (left: 27 3)", "27 - 3 = 24 (left: 24)"],
    output: "(9 * 3) - 3 = 24",
  },
  {
    input: "2 5 7",
    steps: ["7 + 5 = 12 (left: 12 2)", "12 * 2 = 24 (left: 24)"],
    output: "(7 + 5) * 2 = 24",
  },
];
182 |
type RenderFunction = (template: string) => string;

/** One few-shot example for the chain-of-thought prompt. */
type CotExample = {
  input: string;
  steps: string[];
  output: string;
};

/** Signature of the function produced by `cotPromptDecorator`. */
type CotPromptFunction = (input: string, examples?: CotExample[]) => string;

/**
 * Builds the chain-of-thought prompt function.
 *
 * The returned function renders `cotPromptTemplate` with `input` and `examples`
 * (defaulting to `cot_examples`) and passes the rendered text through `fn`.
 *
 * The return type spells out the optional `examples` parameter so callers can supply
 * their own examples, as the sibling prompt decorators allow; calling the result with
 * just `input` works exactly as before.
 *
 * @param fn Post-processor applied to the rendered template text.
 */
function cotPromptDecorator(fn: RenderFunction): CotPromptFunction {
  return function (input: string, examples: CotExample[] = cot_examples): string {
    const renderedTemplate = nunjucks.renderString(cotPromptTemplate, {
      input,
      examples,
    });
    return fn(renderedTemplate);
  };
}
200 |
/** Renders the chain-of-thought prompt for an input and returns the rendered text unchanged. */
export const cotPrompt = cotPromptDecorator(
  (renderedTemplate: string) => renderedTemplate
);
204 |
/**
 * Expands a node by asking the LLM for candidate next steps.
 *
 * The numbers currently available are the node's input (no steps yet) or the "left: ..."
 * part of its most recent step. Each proposed step yields a child node whose steps are the
 * parent's steps plus that step. As soon as a proposed step leaves exactly "24", the LLM is
 * asked for the final answer expression and a single leaf node carrying that answer is
 * returned; any children collected before it are discarded.
 *
 * @param node   Node to expand.
 * @param fanout Maximum number of proposed steps taken from each LLM completion.
 * @returns The child nodes, or a one-element array holding the answered leaf node.
 */
async function nodeGenerator(node: Node, fanout: number = 5): Promise<Node[]> {
  const currNumsStr =
    node.steps.length === 0
      ? node.input
      : getCurrentNumbers(node.steps[node.steps.length - 1]);

  const llmOutput = await llm(proposePrompt(currNumsStr));
  const nextSteps = getNextSteps(llmOutput, fanout);

  const newNodes: Node[] = [];
  for (const step of nextSteps) {
    const stepsSoFar = node.steps.concat([step]);

    if (getCurrentNumbers(step) !== "24") {
      newNodes.push(new Node(node.input, stepsSoFar));
      continue;
    }

    // The rendered CoT prompt already ends with "Input: ...\nSteps:\n", so the steps are
    // appended directly rather than repeating the "Steps:" header.
    const answerPrompt = cotPrompt(node.input) + stepsSoFar.join("\n");
    const rawAnswer = await llm(answerPrompt);
    // llm may hand back a single completion or a list of them; use the first in that case.
    const answer: string = Array.isArray(rawAnswer) ? rawAnswer[0] : rawAnswer;

    // Record the complete path, including the step that reached 24, on the leaf.
    return [new Node(node.input, stepsSoFar, answer)];
  }
  return newNodes;
}
240 |
241 | // VALUE PROMPT
// Nunjucks template for the state-evaluation prompt: an instruction line, then for each
// entry of `examples` its input, its reasoning `steps` and its verdict (`output`), then the
// live `input` for the model to evaluate. The template text is whitespace-sensitive.
242 | const valuePromptTemplate = `Evaluate if given numbers can reach 24 (sure/likely/impossible)
243 | {% for example in examples %}
244 | Input: {{ example.input }}
245 | {% for step in example.steps %}
246 | {{ step }}
247 | {% endfor %}
248 | {{ example.output }}
249 | {% endfor %}
250 | Input: {{input}}
251 | `;
252 |
// Default few-shot examples for the value prompt. Each entry has the remaining numbers
// (`input`), free-form reasoning lines (`steps`) and a verdict (`output`) that is one of
// "sure", "likely" or "impossible".
253 | const value_examples = [
254 | { input: "10 14", steps: ["10 + 14 = 24"], output: "sure" },
255 | {
256 | input: "11 12",
257 | steps: ["11 + 12 = 23", "12 - 11 = 1", "11 * 12 = 132", "11 / 12 = 0.91"],
258 | output: "impossible",
259 | },
260 | {
261 | input: "4 4 10",
262 | steps: [
263 | "4 + 4 + 10 = 8 + 10 = 18",
264 | "4 * 10 - 4 = 40 - 4 = 36",
265 | "(10 - 4) * 4 = 6 * 4 = 24",
266 | ],
267 | output: "sure",
268 | },
269 | { input: "4 9 11", steps: ["9 + 11 + 4 = 20 + 4 = 24"], output: "sure" },
270 | {
271 | input: "5 7 8",
272 | steps: [
273 | "5 + 7 + 8 = 12 + 8 = 20",
274 | "(8 - 5) * 7 = 3 * 7 = 21",
275 | "I cannot obtain 24 now, but numbers are within a reasonable range",
276 | ],
277 | output: "likely",
278 | },
279 | {
280 | input: "5 6 6",
281 | steps: [
282 | "5 + 6 + 6 = 17",
283 | "(6 - 5) * 6 = 1 * 6 = 6",
284 | "I cannot obtain 24 now, but numbers are within a reasonable range",
285 | ],
286 | output: "likely",
287 | },
288 | {
289 | input: "10 10 11",
290 | steps: ["10 + 10 + 11 = 31", "(11 - 10) * 10 = 10", "10 10 10 are all too big"],
291 | output: "impossible",
292 | },
293 | {
294 | input: "1 3 3",
295 | steps: ["1 * 3 * 3 = 9", "(1 + 3) * 3 = 12", "1 3 3 are all too small"],
296 | output: "impossible",
297 | },
298 | { input: "24", steps: ["24 = 24 (solved, no steps needed)"], output: "sure" },
299 | ];
300 |
/**
 * Builds the value-prompt function: the returned function renders `valuePromptTemplate`
 * with the given `input` and `examples` (defaulting to `value_examples`) and hands the
 * rendered text to `fn`, returning whatever `fn` returns.
 */
function valuePromptDecorator(fn: Function) {
  const renderAndForward = (input: string, examples = value_examples) => {
    const context = { input, examples };
    const rendered = nunjucks.renderString(valuePromptTemplate, context);
    return fn(rendered);
  };
  return renderAndForward;
}
310 |
/** Renders the value prompt for an input and returns the rendered text unchanged. */
export const valuePrompt = valuePromptDecorator(
  (renderedTemplate: string) => renderedTemplate
);
314 |
// Nunjucks template for judging a final answer: an instruction, then for each entry of
// `examples` its input, answer and judgement, then the live `input` and `answer` followed
// by an open "Judge:" label (no trailing newline) for the model to complete.
315 | const valueLastStepPromptTemplate = `Use numbers and basic arithmetic operations (+ - * /) to obtain 24. Given an input and an answer, give a judgement (sure/impossible) if the answer is correct, i.e. it uses each input exactly once and no other numbers, and reach 24.
316 | {% for example in examples %}
317 | Input: {{ example.input }}
318 | Answer: {{ example.answer }}
319 | Judge: {{ example.judge }}
320 | {% endfor %}
321 | Input: {{input}}
322 | Answer: {{answer}}
323 | Judge:`;
324 |
// Default few-shot examples for the final-answer judgement prompt: each entry pairs an
// `input` and a proposed `answer` expression with the expected `judge` verdict
// ("sure" or "impossible").
325 | const value_last_step_examples = [
326 | { input: "3 3 5", answer: "(5 + 3) * 3 = 24", judge: "sure" },
327 | { input: "3 3 5", answer: "(3 - 3) * 5 = 24", judge: "impossible" },
328 | { input: "2 5 7", answer: "(7 + 5) * 2 = 24", judge: "sure" },
329 | { input: "2 5 7", answer: "(7 - 5) * 2 = 24", judge: "impossible" },
330 | { input: "5 8 8", answer: "(8 + 8) - 5 = 24", judge: "impossible" },
331 | { input: "5 8 8", answer: "(8 - 5) / 8 = 24", judge: "impossible" },
332 | ];
333 |
/**
 * Builds the final-answer judgement prompt function: the returned function renders
 * `valueLastStepPromptTemplate` with `input`, `answer` and `examples` (defaulting to
 * `value_last_step_examples`) and hands the rendered text to `fn`, returning its result.
 */
function valueLastStepPromptDecorator(fn: Function) {
  const renderAndForward = (
    input: string,
    answer: string,
    examples = value_last_step_examples
  ) => {
    const context = { input, answer, examples };
    const rendered = nunjucks.renderString(valueLastStepPromptTemplate, context);
    return fn(rendered);
  };
  return renderAndForward;
}
344 |
/** Renders the final-answer judgement prompt and returns the rendered text unchanged. */
export const valueLastStepPrompt = valueLastStepPromptDecorator(function (
  renderedTemplate: string
) {
  return renderedTemplate;
});
350 |
type ValueOutputs = string[] | string[][];

/**
 * Converts LLM value judgements into numeric scores.
 *
 * Each completion's verdict is read from its last line; surrounding whitespace (including
 * a trailing newline or "\r") and letter case are ignored, so "sure\n" and "Sure" both
 * count as "sure". A sample's score is the sum, over all its completions, of the weight of
 * the verdict ("impossible" 0.001, "likely" 1, "sure" 20); unrecognized verdicts add 0.
 *
 * @param valueOutputs Either one sample (array of completions) or an array of samples.
 * @returns One score for a single sample, or an array of scores for multiple samples.
 *          An empty array is treated as a single empty sample and scores 0.
 */
function parseAndCompute(valueOutputs: ValueOutputs): number | number[] {
  const valueMap: { [key: string]: number } = {
    impossible: 0.001,
    likely: 1,
    sure: 20,
  }; // TODO: ad hoc

  function computeValue(sample: string[]): number {
    // Trim first so that a trailing newline does not make the "last line" empty.
    const verdicts = sample.map((completion) =>
      completion.trim().split("\n").slice(-1)[0].trim().toLowerCase()
    );
    return Object.entries(valueMap).reduce((sum, [name, weight]) => {
      return sum + weight * verdicts.filter((verdict) => verdict === name).length;
    }, 0);
  }

  // An array whose first element is itself an array means "multiple samples".
  if (Array.isArray(valueOutputs[0])) {
    return (valueOutputs as string[][]).map(computeValue);
  }
  return computeValue(valueOutputs as string[]);
}
376 |
// TODO: CHECK VALIDITY OF NODE BY GRAMMAR CHECKING
/**
 * Placeholder validity check: every node is currently reported as valid, and the
 * argument is not inspected.
 */
function validNode(node: Node): boolean {
  const isValid = true;
  return isValid;
}
381 |
/**
 * Scores a batch of nodes with the LLM.
 *
 * For every node accepted by `validNode`, a prompt is built: nodes that already carry an
 * answer are judged with the last-step prompt (answer lower-cased, "answer: " prefix
 * removed); other nodes are judged with the value prompt on their remaining numbers
 * (taken from the latest step, or from the input when no step has been taken yet).
 * Each prompt is sent to `llm` asking for N_EVAL completions, and the completions are
 * turned into a score by `parseAndCompute`.
 *
 * @param nodes Nodes to evaluate.
 * @returns One `[score, completions]` pair per input node, in order. Nodes rejected by
 *          `validNode` get a score of -1 and a placeholder message.
 * @throws Error when any llm result is not an array of completions.
 */
async function nodeEvaluatorMulti(nodes: Node[]): Promise<[number, string[]][]> {
  const N_EVAL = 3;

  const prompts: string[] = [];
  const nodeValidity: boolean[] = [];

  for (const node of nodes) {
    if (!validNode(node)) {
      nodeValidity.push(false);
      continue;
    }

    if (node.output) {
      const ansExpr = node.output.toLowerCase().replace("answer: ", "");
      prompts.push(valueLastStepPrompt(node.input, ansExpr));
    } else {
      // A node without steps (e.g. a root) has no "left: ..." text yet; evaluate its input
      // instead of indexing into an empty steps array.
      const currNumsStr =
        node.steps.length === 0
          ? node.input
          : getCurrentNumbers(node.steps[node.steps.length - 1]);
      prompts.push(valuePrompt(currNumsStr));
    }
    nodeValidity.push(true);
  }

  const rawOutputs = await Promise.all(prompts.map((prompt) => llm(prompt, N_EVAL)));

  // Every result must be a list of completions, not only the first one.
  const llmOutputs: string[][] = rawOutputs.map((raw) => {
    if (!Array.isArray(raw)) {
      throw new Error("Unexpected output format from llm.");
    }
    return raw as string[];
  });

  const evaluations: [number, string[]][] = [];
  let nextOutput = 0; // index into llmOutputs, advanced only for valid nodes

  for (let i = 0; i < nodes.length; i++) {
    if (!nodeValidity[i]) {
      evaluations.push([-1, ["invalid node, will log reason later"]]);
      continue;
    }

    const completions = llmOutputs[nextOutput];
    nextOutput++;

    const valueOutput = parseAndCompute(completions);
    const value = Array.isArray(valueOutput) ? valueOutput[0] : valueOutput;
    evaluations.push([value, completions]);
  }

  return evaluations;
}
439 |
// One evaluated candidate produced during the search. As filled in by treeOfThoughtsBfs:
// `value` and `propEval` are the score and raw completions returned by nodeEvaluatorMulti,
// `isTerminal` is true when the node's output is not null, and `isValid` is true when the
// score is greater than -1.
440 | type Proposal = {
441 | node: Node;
442 | value: number;
443 | propEval: string[];
444 | isTerminal: boolean;
445 | isValid: boolean;
446 | };
447 |
/**
 * Runs a breadth-first Tree-of-Thoughts search for one puzzle input.
 *
 * For up to N_STEPS rounds, every node in the queue is expanded with `nodeGenerator` and
 * its children are scored with `nodeEvaluatorMulti`. The search stops as soon as any child
 * is terminal (carries an answer); otherwise the N_BEST highest-scoring valid children form
 * the next queue. Each completed round and the terminal proposals are recorded in the
 * module-level `loggingDict` under the input string.
 *
 * @param x Puzzle input, e.g. "3 3 8".
 * @returns Up to N_BEST answers from valid terminal proposals, best score first
 *          (empty when no terminal node was found).
 * @throws Error if a proposal recorded as terminal has no output.
 */
export async function treeOfThoughtsBfs(x: string): Promise<string[]> {
  const N_BEST = 5;
  const N_STEPS = 3;

  const terminalData: Proposal[] = [];
  const root = new Node(x);
  let queue: Node[] = [root];
  let foundTerminal = false;

  // Create the log entry before searching so that recording terminalData below cannot hit
  // a missing entry when a terminal node is found in the very first round.
  if (!loggingDict[root.input]) {
    loggingDict[root.input] = {};
  }
  const logEntry = loggingDict[root.input];
  if (!logEntry.steps) {
    logEntry.steps = [];
  }

  for (let step = 0; step < N_STEPS; step++) {
    const allProposalData: Proposal[] = [];

    for (const node of queue) {
      const nextNodes = await nodeGenerator(node);
      const nextNodeValuesAndLogs = await nodeEvaluatorMulti(nextNodes);

      for (let idx = 0; idx < nextNodes.length; idx++) {
        const [value, propEval] = nextNodeValuesAndLogs[idx];
        const proposal: Proposal = {
          node: nextNodes[idx],
          value,
          propEval,
          isTerminal: nextNodes[idx].output !== null,
          isValid: value > -1,
        };

        allProposalData.push(proposal);

        if (proposal.isTerminal) {
          terminalData.push(proposal);
          foundTerminal = true;
        }
      }
      if (foundTerminal) break;
    }
    if (foundTerminal) break;

    console.log(`>> step ${step + 1}: ${allProposalData.length} proposals`);

    logEntry.steps.push({
      queue: [...queue],
      allProposals: allProposalData,
    });

    // Keep only the best valid, non-terminal proposals for the next round.
    queue = allProposalData
      .filter((p) => p.isValid && !p.isTerminal)
      .sort((a, b) => b.value - a.value)
      .slice(0, N_BEST)
      .map((p) => p.node);
  }

  // Sanity check
  if (!terminalData.every((p) => p.node.output !== null)) {
    throw new Error("Sanity check failed: Not all terminal nodes have an output.");
  }

  // Log terminal nodes
  logEntry.terminalData = terminalData;

  return terminalData
    .filter((p) => p.isValid)
    .sort((a, b) => b.value - a.value)
    .slice(0, N_BEST)
    .map((p) => p.node.output!);
}
514 |
/**
 * Runs the search over every problem in the module-level `inputs` array and prints
 * summary statistics.
 *
 * `loggingDict` is reset first. After each problem its answers are stored in
 * `loggingDict` and the whole dictionary is written to "logging_dict_tot.json". Finally
 * each answer is checked with `validateLLMOutput` and the number of inputs with at least
 * one correct answer and the average number of correct answers per input are logged.
 */
async function testTotBfs(): Promise<void> {
  loggingDict = {}; // Resetting loggingDict

  // `inputs` is filled in by main(); treat a missing value as "no problems" instead of
  // failing on iteration.
  const problems: string[] = inputs ?? [];
  const outputs: string[][] = [];

  for (const x of problems) {
    const answers = await treeOfThoughtsBfs(x);
    outputs.push(answers);

    if (!loggingDict[x]) {
      loggingDict[x] = { answers: [] };
    }
    loggingDict[x]["answers"] = answers;
    console.log(`>> x = ${x}, answers = ${answers}`);

    logDictToJson(loggingDict, "logging_dict_tot");
  }

  // statusMatrix[i][j] tells whether the j-th answer for the i-th problem is correct.
  const statusMatrix: boolean[][] = problems.map((x, i) =>
    outputs[i].map((y) => validateLLMOutput(x, y))
  );

  const nAnyCorrect = statusMatrix.filter((row) => row.some(Boolean)).length;
  const totalCorrect = statusMatrix
    .map((row) => row.filter(Boolean).length)
    .reduce((a, b) => a + b, 0);
  // Avoid 0 / 0 (NaN) when there are no problems at all.
  const avgCorrect = problems.length === 0 ? 0 : totalCorrect / problems.length;

  console.log("ToT Prompting");
  console.log(`n_inputs: ${problems.length}`);
  console.log(`n_inputs (min 1 correct): ${nAnyCorrect}`);
  console.log(`avg number of correct answers per input: ${avgCorrect.toFixed(2)}`);
}
554 |
/**
 * Script entry point: loads the problems from the CSV file (first line is a header, each
 * following line is one problem), stores them in the module-level `inputs` array and runs
 * `testTotBfs`.
 *
 * If the file cannot be read the error is logged and the benchmark is skipped, since
 * there would be nothing to run it on.
 */
async function main() {
  const filePath = "../data/game24_3nums.csv";

  let fileContent: string;
  try {
    fileContent = fs.readFileSync(filePath, "utf-8");
  } catch (error) {
    console.error("Error reading the file:", error);
    return;
  }

  // Skip the header line; tolerate CRLF line endings and blank lines.
  inputs = fileContent
    .trim()
    .split(/\r?\n/)
    .slice(1)
    .map((line) => line.trim())
    .filter((line) => line.length > 0);
  console.log(inputs);

  // For a quick manual run, replace the assignment above with a short list such as:
  // inputs = ["1 2 8", "1 3 6", "3 4 4", "8 8 8", "4 4 5"];

  // Await so that failures inside the benchmark surface through main()'s promise.
  await testTotBfs();
}
588 |
// Per-input log of the search, keyed by the puzzle input string.
let loggingDict: { [key: string]: { [key: string]: any[] } } = {};
// Problems to solve; starts empty so that iterating it is always safe.
let inputs: string[] = [];

// Run the script and report any failure instead of leaving the promise unhandled.
main().catch((error) => {
  console.error("ToT run failed:", error);
  process.exitCode = 1;
});
592 |
--------------------------------------------------------------------------------
/src/utils/types.ts:
--------------------------------------------------------------------------------
1 | import { Node, Edge } from "reactflow";
2 |
3 | import { ChatCompletionResponseMessage } from "openai-streams";
4 |
// Base data payload of a graph node; ToTNodeData below extends it via intersection.
// `branchesNodeType` is one of the BranchesNodeType enum values declared in this file.
// NOTE(review): `streamId` and `hasCustomlabel` are optional; their exact use is not
// visible from this file.
5 | type BranchesNodeData = {
6 | label: string;
7 | branchesNodeType: BranchesNodeType;
8 | text: string;
9 | streamId?: string;
10 | hasCustomlabel?: boolean;
11 | };
12 |
// Node data for Tree-of-Thoughts nodes: everything in BranchesNodeData plus the fields
// below. Fields marked `?` are optional.
// NOTE(review): the field names suggest search bookkeeping (steps, score, validity,
// terminal flag, expand/collapse state) — confirm meanings against the code that fills them.
13 | export type ToTNodeData = BranchesNodeData & {
14 | errors: string[];
15 | evals?: string[];
16 | expandable: boolean;
17 | expanded: boolean;
18 | explanations: string[];
19 | input: string;
20 | isInAnswerPath?: boolean;
21 | isTerminal?: boolean;
22 | isValid?: boolean;
23 | output?: string;
24 | score?: number;
25 | solutions: string[];
26 | steps: string[];
27 | };
28 |
// Map from a string key to one problem record with five string fields.
// NOTE(review): presumably the shape of human_eval_problems.json, keyed by task id — verify.
29 | export type HumanEvalProblemsType = {
30 | [key: string]: {
31 | task_id: string;
32 | prompt: string;
33 | entry_point: string;
34 | canonical_solution: string;
35 | test: string;
36 | };
37 | };
38 |
// String-valued enum of node kinds. Note that TweakedGPT's value is "GPT (tweaked)",
// which differs from its member name.
39 | export enum BranchesNodeType {
40 | System = "System",
41 | User = "User",
42 | GPT = "GPT",
43 | TweakedGPT = "GPT (tweaked)",
44 | }
45 |
// Application settings object.
// NOTE(review): `model`/`temp` look like LLM request options and the two N_*_FANOUT values
// like per-step sample counts — confirm against the code that reads them.
46 | export type Settings = {
47 | defaultPreamble: string;
48 | autoZoom: boolean;
49 | model: string;
50 | temp: number;
51 | N_ANSWER_FANOUT: number;
52 | N_EXPLANATION_FANOUT: number;
53 | };
54 |
// String-valued enum of custom node type names; currently a single member.
// NOTE(review): presumably used as the React Flow `type` key for custom nodes — verify.
55 | export enum ReactFlowNodeTypes {
56 | LabelUpdater = "LabelUpdater",
57 | }
58 |
59 | // The stream response is weird and has a delta instead of message field.
// All three fields are optional; `delta` reuses the ChatCompletionResponseMessage type
// imported from "openai-streams".
60 | export interface CreateChatCompletionStreamResponseChoicesInner {
61 | index?: number;
62 | delta?: ChatCompletionResponseMessage;
63 | finish_reason?: string;
64 | }
65 |
// One snapshot of graph state: the React Flow nodes and edges plus the current and the
// previous selection (either may be null).
// NOTE(review): presumably an entry of an undo/redo history — confirm against its users.
66 | export type HistoryItem = {
67 | nodes: Node[];
68 | edges: Edge[];
69 | selectedNodeId: string | null;
70 | lastSelectedNodeId: string | null;
71 | };
72 |
--------------------------------------------------------------------------------
/src/vite-env.d.ts:
--------------------------------------------------------------------------------
1 | /// <reference types="vite/client" />
2 |
--------------------------------------------------------------------------------
/tsconfig.json:
--------------------------------------------------------------------------------
1 | {
2 | "compilerOptions": {
3 | "target": "ES2017",
4 | "useDefineForClassFields": true,
5 | "lib": [
6 | "DOM",
7 | "DOM.Iterable",
8 | "ESNext"
9 | ],
10 | "allowJs": false,
11 | "skipLibCheck": true,
12 | "esModuleInterop": false,
13 | "allowSyntheticDefaultImports": true,
14 | "strict": true,
15 | "forceConsistentCasingInFileNames": true,
16 | "module": "ESNext",
17 | "moduleResolution": "Node",
18 | "resolveJsonModule": true,
19 | "isolatedModules": true,
20 | "noEmit": true,
21 | "jsx": "react-jsx",
22 | "typeRoots": [
23 | "./node_modules/@types",
24 | "./types"
25 | ]
26 | },
27 | "include": [
28 | "src"
29 | ],
30 | "references": [
31 | {
32 | "path": "./tsconfig.node.json"
33 | }
34 | ],
35 | "skipLibCheck": true
36 | }
--------------------------------------------------------------------------------
/tsconfig.node.json:
--------------------------------------------------------------------------------
1 | {
2 | "compilerOptions": {
3 | "composite": true,
4 | "module": "ESNext",
5 | "moduleResolution": "Node",
6 | "allowSyntheticDefaultImports": true
7 | },
8 | "include": ["vite.config.ts"]
9 | }
10 |
--------------------------------------------------------------------------------
/vercel.json:
--------------------------------------------------------------------------------
1 | {
2 | "routes": [
3 | {
4 | "src": "/execute",
5 | "dest": "/api/execute.py"
6 | }
7 | ]
8 | }
--------------------------------------------------------------------------------
/vite.config.ts:
--------------------------------------------------------------------------------
1 | import { defineConfig } from "vite";
2 | import react from "@vitejs/plugin-react-swc";
3 |
// Vite configuration: registers the React plugin imported above from
// "@vitejs/plugin-react-swc"; no other option is set here.
4 | // https://vitejs.dev/config/
5 | export default defineConfig({
6 | plugins: [react()],
7 | });
8 |
--------------------------------------------------------------------------------