├── .gitignore ├── .python-version ├── LICENSE ├── README.md ├── actions.py ├── benchmark_runner.py ├── example-benchmark.yml ├── example-prompt.txt ├── llm_openai_sql_queries.py ├── llm_sql_queries.py ├── metrics.py ├── reason-act-llm.png ├── requirements.txt ├── rescore.py ├── run_interface.py └── states.yml /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | __pycache__ 3 | *.pyc 4 | 5 | *.db* 6 | *.json 7 | *.log 8 | *.gguf 9 | 10 | traces/ 11 | traces.*/ 12 | experiments/ 13 | data/ 14 | results/ 15 | 16 | *-benchmark.yml 17 | 18 | .envrc 19 | -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.11.6 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Brandon Roberts 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Reason-Act SQLite Demo 2 | 3 | ![Reason and Act Flow Chart](https://github.com/brandonrobertz/reason-act-sqlite-py/blob/main/reason-act-llm.png?raw=true) 4 | 5 | This is a demonstration of how to use [reason and act][react-paper] with [llama.cpp][llama-git] and an [LLM][xwin] to pose plain English queries to a SQLite database using one of two strategies: 6 | 7 | 1. Actions that mimic interaction with a frontend like [Datasette][datasette-git]. Actions: list tables, list table columns, facet, filter 8 | 2. Let the LLM use SQLite queries directly. Actions: list tables, list table schema, execute sql 9 | 10 | The things you'll need to do are: 11 | 12 | 1. Provide a SQLite database (named `example.db` or you need to change the name in the Python files) 2. Change the prompts in both Python scripts (the `prompt` string inside the `execute` functions) to be specific to your data and problems. You'll also want to update the `DATA_HELP` table and column descriptions in `actions.py`. 13 | 3. Download a GGUF model for use. The default is to look for [dolphin-2.2.1-mistral-7b.Q5_K_M.gguf][dolphin-2.2.1-mistral-7b] in the current dir. If you want to use a different model, edit the script you're running.
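For example, one quick way to fetch the default model is with the `huggingface_hub` Python API (a sketch — it assumes `pip install huggingface-hub` and that TheBloke's repo linked below still hosts this exact quantization filename, so double-check both):

```python
# Sketch: download the default GGUF model into the current directory.
from huggingface_hub import hf_hub_download

hf_hub_download(
    repo_id="TheBloke/dolphin-2.2.1-mistral-7B-GGUF",
    filename="dolphin-2.2.1-mistral-7b.Q5_K_M.gguf",
    local_dir=".",
)
```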
15 | 16 | There are some dependencies for this project that you'll need to install first. You can install them using pip: 17 | 18 | ``` 19 | pip install -r requirements.txt 20 | ``` 21 | 22 | Once you have everything installed and configured, you can kick off a session by coming up with a question and asking it on the command line: 23 | 24 | ``` 25 | python run_interface.py "What kind of data do I have available?" 26 | python llm_sql_queries.py "What are some interesting records in the database?" 27 | ``` 28 | 29 | The model output will be printed to stdout. 30 | 31 | [react-paper]: https://blog.research.google/2022/11/react-synergizing-reasoning-and-acting.html?m=1 32 | "ReAct: Synergizing Reasoning and Acting in Language Models" 33 | 34 | [llama-git]: https://github.com/ggerganov/llama.cpp 35 | "Port of Facebook's LLaMA model in C/C++" 36 | 37 | [xwin]: https://huggingface.co/TheBloke/Xwin-LM-13B-V0.1-GGUF 38 | "Xwin-LM-13B-V0.1-GGUF on Huggingface courtesy of TheBloke" 39 | 40 | [datasette-git]: https://github.com/simonw/datasette 41 | "An open source multi-tool for exploring and publishing data" 42 | 43 | [llama-cpp-py]: https://github.com/abetlen/llama-cpp-python 44 | "Python bindings for llama.cpp" 45 | 46 | [sqlite-utils]: https://github.com/simonw/sqlite-utils 47 | "Python CLI utility and library for manipulating SQLite databases" 48 | 49 | [dolphin-2.2.1-mistral-7b]: https://huggingface.co/TheBloke/dolphin-2.2.1-mistral-7B-GGUF/tree/main 50 | "Dolphin 2.2.1 Mistral 7B GGUF on HuggingFace" 51 | -------------------------------------------------------------------------------- /actions.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import re 4 | import sys 5 | import sqlite3 6 | 7 | import sqlite_utils 8 | 9 | 10 | DB_PATH = "example.db" 11 | 12 | # columns: useful for getting table columns. input 1: table name. 13 | ACTIONS = """ 14 | tables: useful for getting the names of tables available. no input. 15 | schema: useful for looking at the schema of a table. input 1: table name. 16 | help: useful for getting helpful context about how to use tables and their columns. input 1: table name. (optional) input 2: column name. 17 | sql-query: useful for analyzing data and getting the top 5 results of a query. input 1: a valid sqlite sql query. 18 | """ 19 | 20 | DATA_HELP = { 21 | "users": { 22 | None: "profiles of individuals (sometimes called creators) who are seeking work, have worked on projects, or are looking to hire other people.", 23 | "creatorUserId": "this is the primary key for a user. the experiences table references it on the creatorUserId field", 24 | "createdUtc": "an ISO8601 datetime string of the user creation date", 25 | "updatedUtc": "an ISO8601 datetime string of the user's last updated date", 26 | "isPublic": "a boolean describing if the user's profile is public. all of these values will be true", 27 | "isContactAllowed": "a boolean describing whether or not the user allows people to contact them", 28 | "creatorDescription": "a free-text field the user has supplied describing themselves, their interests, work preferences and occasionally age/location.
details like this are sometimes present: 'I am 23 years old' or 'been building games for 8 years'.", 29 | "isOpenToWork": "whether or not the user is actively looking for work", 30 | "interestDescription": "a text field describing the user's interests", 31 | "linkTypes": "an array of platforms/methods the users can be contacted on", 32 | "preferredContactLinkType": "the user's preferred platform/method of contact", 33 | "socialLinks": "an array of JSON data describing the user's social media accounts", 34 | "jobTypes": "the type of jobs and work the user is seeking", 35 | "skillTypes": "an array containing skills the user has", 36 | "requiresAction": "always set to \"noAction\"", 37 | }, 38 | "experiences": { 39 | # table description 40 | None: "the experiences table describes previous jobs that users have worked on.", 41 | "experienceId": "the primary key for this experience", 42 | "creatorUserId": "the primary key from the users table of the user who put this experience on this profile", 43 | "createdUtc": "an ISO8601 experience create date", 44 | "updatedUtc": "an ISO8601 experience update date", 45 | "projectName": "the name of the project the user worked on. this matches the name column of the games table.", 46 | "experienceDescription": "a free-text field describing the work the user performed", 47 | "jobRole": "a free-text field describing the user's job title on the project", 48 | "experienceMedia": "an array of JSON objects describing media displayed on the profile. the names of the media may have descriptive meaning", 49 | "experienceLinks": "an array of text containing Markdown links to the game(s) for this experience. the ID in the games URLs may be present in the games table", 50 | "teamName": "the name of the team the user worked with, if applicable", 51 | "teamId": "the team ID, which may be different from the teamName across experiences", 52 | "startedUtc": "an ISO8601 date of when the user began work on the project", 53 | "endedUtc": "an ISO8601 date of when the user ended work on the project", 54 | "isCurrent": "whether or not the project is ongoing", 55 | }, 56 | "games": { 57 | None: "these are games that are available to play. users mention them in their experience descriptions.", 58 | "placeId": "this is the game's primary key. this could show up as an ID in a games URL elsewhere", 59 | "name": "the name of the game", 60 | "description": "a free-form text description of the game", 61 | "sourceName": "the name of the game", 62 | "sourceDescription": "a free-form text description of the game", 63 | "url": "the URL to the game's page. can be found in experienceLinks or user descriptions. the placeId appears in the URL", 64 | "builder": "the name of the user who built the game", 65 | "builderId": "the ID of the user who built the game", 66 | "hasVerifiedBadge": "whether or not the game has gone through the verification process", 67 | "isPlayable": "whether or not the game is playable", 68 | "reasonProhibited": "this field, if it's not \"None\", provides the reason why the game isn't currently playable", 69 | "universeId": "I'm not sure what this field is. this ID may appear in user or experience descriptions", 70 | "universeRootPlaceId": "I'm not sure what this field is. this ID may appear in user or experience descriptions", 71 | "price": "the price of this game", 72 | "imageToken": "an ID representing the game's image", 73 | }, 74 | "game_stats": { 75 | None: "game stats record popularity metrics and game categories.
the numbers here are all intended to be short and human-readable, not sortable", 76 | "Active": "how many users are active on the game", 77 | "Favorites": "how many users have favorited the game", 78 | "Visits": "how many users have visited or played the game", 79 | "Created": "when the game was created", 80 | "Updated": "the last time the game was updated", 81 | "Server Size": "how large the server is", 82 | "Genre": "a category the game falls into", 83 | "Allowed Gear": "this field is always blank", 84 | "placeId": "a primary key of the game. this matches a row in the games table", 85 | }, 86 | "game_passes": { 87 | None: "game passes are add-ons that players can purchase that grant additional capabilities in games", 88 | "name": "the title of the add-on. this is not the game's name. that must be found via joining the games table on placeId=game_id.", 89 | "price": "how much this add-on costs", 90 | "seller_name": "the name of the game this add-on applies to", 91 | "description": "a free-text description of this add-on and what it provides", 92 | "down": "how many people purchased this add-on and disliked it", 93 | "up": "how many people purchased this add-on and liked it", 94 | }, 95 | "jobs": { 96 | None: "these are listings for jobs to work on games. users can apply to them privately. jobs list some details about the work arrangement and requirements", 97 | "id": "the primary key of the job listing", 98 | "jobPosterId": "the primary key of the user who created the job listing", 99 | "title": "the title of the job listing", 100 | "description": "a free-text field explaining details about the job and its requirements", 101 | "jobType": "the type of job, can be one of: \"Commission\", \"FullTime\", \"PartTime\"", 102 | "paymentTypes": "this describes the payment scheme and can be one of: \"Currency\", \"RevenuePercent\", \"Robux\"", 103 | "skillTypes": "a JSON array listing the types of skills required for this job", 104 | "publishedUtc": "an ISO8601 job listing publish date", 105 | "expiresUtc": "an ISO8601 job listing expiration date", 106 | "minAgeRequirement": "the job's minimum age requirements", 107 | "isVerifiedRequirement": "this field is always \"false\"", 108 | "isVerified": "whether or not the job has been verified, can be \"true\" or \"false\"", 109 | "paymentAmount": "a float describing how much the job pays", 110 | "paymentAmountType": "a category the paymentAmount falls into", 111 | } 112 | } 113 | 114 | IGNORED_TABLES = [ 115 | "ar_internal_metadata", 116 | "schema_migrations", 117 | "filing_folders", 118 | ] 119 | IGNORED_COLUMNS = [] 120 | 121 | 122 | def load_db(path): 123 | assert os.path.exists(path), f"Database doesn't exist: {path}" 124 | db = sqlite_utils.Database(path) 125 | return db 126 | 127 | 128 | def clean_truncate(results, n=3): 129 | return [ 130 | {k: v for k, v in r.items()} 131 | for r in results[:n] 132 | ] 133 | 134 | 135 | ## ACTIONS 136 | def tables(db): 137 | return [ 138 | name 139 | for name in db.table_names() 140 | # game stats confuses the model 141 | if ( 142 | "_fts" not in name 143 | and name not in IGNORED_TABLES 144 | and not name.endswith("_history") 145 | ) 146 | ] 147 | 148 | 149 | def schema(db, table_name): 150 | table_names = tables(db) 151 | if table_name not in table_names: 152 | return f"Error: Invalid table.
Valid tables are: {table_names}" 153 | return re.sub(r'\s+', ' ', db[table_name].schema) 154 | 155 | 156 | def columns(db, table_name): 157 | table_names = tables(db) 158 | if table_name not in table_names: 159 | return f"Error: Invalid table. Valid tables are: {table_names}" 160 | return [ 161 | c.name 162 | for c in db[table_name].columns 163 | if c.name not in IGNORED_COLUMNS 164 | ] 165 | 166 | 167 | def help(db, *args): 168 | if not args: 169 | return "Error: The help action requires at least one argument" 170 | table_name = args[0] 171 | column = None 172 | if len(args) == 2: 173 | column = args[1] 174 | if table_name not in DATA_HELP: 175 | available_tables = tables(db) 176 | return f"Error: The table {table_name} doesn't exist. Valid tables: {available_tables}" 177 | if column not in DATA_HELP[table_name]: 178 | available_columns = [ 179 | c.name 180 | for c in db[table_name].columns 181 | if c.name not in IGNORED_COLUMNS 182 | ] 183 | return f"Error: The column {column} isn't in the {table_name} table. Valid columns: {available_columns}" 184 | help_text = DATA_HELP[table_name][column] 185 | # table help requested 186 | if column is None: 187 | return help_text 188 | # column help requested, add common values 189 | analysis = db[table_name].analyze_column(column, common_limit=2) 190 | common_values = ", ".join([f"{value}" for value, count in analysis.most_common]) 191 | return f"{help_text} the top two values are: {common_values}" 192 | 193 | 194 | def sql_query(db, query): 195 | if query.lower().startswith("select *"): 196 | return "Error: Select some specific columns, not *" 197 | try: 198 | results = list(db.query(query)) 199 | except sqlite3.OperationalError as e: 200 | return f"Your query has an error: {e}" 201 | return clean_truncate(results, n=5) 202 | -------------------------------------------------------------------------------- /benchmark_runner.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from datetime import datetime 3 | import copy 4 | import multiprocessing 5 | import json 6 | import os 7 | import re 8 | import sys 9 | import time 10 | 11 | import numpy as np 12 | import pymeteor.pymeteor as pymeteor 13 | from nltk.corpus import stopwords 14 | from nltk import download 15 | import spacy 16 | from yaml import load, dump 17 | try: 18 | from yaml import CLoader as Loader, CDumper as Dumper 19 | except ImportError: 20 | from yaml import Loader, Dumper 21 | 22 | from metrics import get_keyword_matches 23 | from llm_sql_queries import execute 24 | from llm_openai_sql_queries import execute as execute_openai 25 | 26 | 27 | USE_EXAMPLE_INJECTION = True 28 | # HACK: globals 29 | nlp = None 30 | stop_words = None 31 | 32 | 33 | def load_yml_file(filename): 34 | with open(filename, "r") as f: 35 | return load(f, Loader=Loader) 36 | 37 | 38 | def preprocess(sentence): 39 | return [w for w in sentence.lower().split() if w not in stop_words] 40 | 41 | 42 | def best_matching_injectable(question, injectables): 43 | best = [0.0, injectables[0]["prompt"]] 44 | q1 = " ".join(preprocess(question)) 45 | question_vec = nlp(q1) 46 | for injectable in injectables: 47 | q2 = " ".join(preprocess(injectable["question"])) 48 | inj_question_vec = nlp(q2) 49 | sim = inj_question_vec.similarity(question_vec) 50 | print(sim, "Q:", q1, "Q2:", q2) 51 | if sim > best[0]: 52 | best = [sim, injectable["prompt"]] 53 | return best[1] 54 | 55 | 56 | def maybe_inject_prompts(prompt_data, question, injectables=None): 57 | new_prompt_data =
copy.deepcopy(prompt_data) 58 | if not USE_EXAMPLE_INJECTION: 59 | return new_prompt_data 60 | 61 | if not injectables: 62 | return new_prompt_data 63 | 64 | if not nlp: 65 | return new_prompt_data 66 | 67 | similar_injectable = best_matching_injectable(question, injectables) 68 | 69 | # first: truncate the examples by looking for the inject_before: True 70 | # on the prompt items 71 | truncate_at = None 72 | for i, item in enumerate(new_prompt_data): 73 | if item.get("inject_before"): 74 | truncate_at = i 75 | break 76 | 77 | if truncate_at is None: 78 | return new_prompt_data 79 | 80 | # truncating here also drops the final question item; re-append it below 81 | truncated_prompt_data = new_prompt_data[:truncate_at] + similar_injectable 82 | # append the question now 83 | truncated_prompt_data.append(new_prompt_data[-1]) 84 | return truncated_prompt_data 85 | 86 | 87 | def prompt_data_to_openai(prompt_data, question, injectables=None): 88 | prompt_completed = maybe_inject_prompts( 89 | prompt_data, question, injectables=injectables 90 | ) 91 | prompt_completed[-1]["content"] = prompt_completed[-1]["content"].format( 92 | question=question 93 | ) 94 | print("Final instruction in prepared prompt:", prompt_completed[-1]) 95 | # clean up the prompt because openAI explodes if any unexpected keys 96 | # are supplied 97 | openai_allowed_keys = ["role", "content"] 98 | finalized_prompt = [ 99 | {k: v for k, v in item.items() if k in openai_allowed_keys} 100 | for item in prompt_completed 101 | ] 102 | return finalized_prompt 103 | 104 | 105 | def prompt_data_to_raw(prompt_data, question, injectables=None): 106 | prompt_completed = maybe_inject_prompts(prompt_data, question, injectables=injectables) 107 | prompt_raw = "" 108 | for item in prompt_completed: 109 | line = item["content"].format(question=question) 110 | prompt_raw += line 111 | prompt_raw += "\n" 112 | if "Final Answer:" in line: 113 | prompt_raw += "\n" 114 | return prompt_raw.strip() 115 | 116 | 117 | def prompt_data_to_chatml(prompt_data, question, injectables=None): 118 | prompt_completed = maybe_inject_prompts(prompt_data, question, injectables=injectables) 119 | prompt_raw = "" 120 | last_item = len(prompt_completed) - 1 121 | for i, item in enumerate(prompt_completed): 122 | line = item["content"].format(question=question).strip() 123 | if item["role"] == "system": 124 | prompt_raw += "<|im_start|>system\n" 125 | prompt_raw += f"{line}\n<|im_end|>\n" 126 | 127 | if item["role"] == "assistant": 128 | prompt_raw += "<|im_start|>system name=example_assistant\n" 129 | prompt_raw += f"{line}\n<|im_end|>\n" 130 | if "Final Answer: " in line: 131 | prompt_raw += "\n" 132 | 133 | if item["role"] == "user" and i != (last_item): 134 | prompt_raw += "<|im_start|>system name=example_user\n" 135 | prompt_raw += f"{line}\n<|im_end|>\n" 136 | 137 | # the final one is the question with the lead out for completion 138 | if item["role"] == "user" and i == (last_item): 139 | prompt_raw += "<|im_start|>user\n" 140 | prompt_raw += f"{line}\n<|im_end|>\n" 141 | prompt_raw += "<|im_start|>assistant\n" 142 | prompt_raw += "Thought: " 143 | 144 | return prompt_raw.strip() 145 | 146 | 147 | def get_model_name(model_file): 148 | model_name = re.sub(r'[^A-Za-z0-9\-_]+', "_", os.path.basename(model_file)) 149 | return model_name 150 | 151 | 152 | def get_tracefile(model_file): 153 | model_name = get_model_name(model_file) 154 | now = datetime.now().strftime("%Y-%m-%d_%H:%M:%S.%f") 155 | tracefile = f"./traces/experiment_{model_name}_{now}.log" 156 | return tracefile 157 | 158 |
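# NOTE: run_llm (below) executes the model in a separate child process so that
# a hung or runaway generation can be killed after `timeout` seconds. Results
# travel back through a multiprocessing.Manager dict because a terminated
# child process cannot return a value directly.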
def run_llm(*args, timeout=30*60, **kwargs): 160 | # shared dict for transferring results back from the proc 161 | manager = multiprocessing.Manager() 162 | return_dict = manager.dict() 163 | kwargs["return_dict"] = return_dict 164 | 165 | execute_fn = execute 166 | if args[0].startswith("openai:"): 167 | execute_fn = execute_openai 168 | 169 | p = multiprocessing.Process( 170 | target=execute_fn, name="LLM", 171 | args=args, kwargs=kwargs 172 | ) 173 | p.start() 174 | p.join(timeout) 175 | if p.is_alive(): 176 | p.terminate() 177 | p.join() 178 | raise Exception(f"Timed out after {timeout}s") 179 | 180 | if not return_dict: 181 | print("Blank return_dict. Likely an error!") 182 | 183 | return return_dict.get("final_answer"), return_dict.get("trace") 184 | 185 | 186 | def save_experiment_data(experiment_output, experiment_data): 187 | print("Writing experiment data to", experiment_output) 188 | with open(experiment_output, "w") as f: 189 | f.write(json.dumps(experiment_data, indent=2)) 190 | 191 | 192 | def run_experiment( 193 | model_path, prompt_data, qa, experiment_output, 194 | cooldown=None, n_tries=10, n_gpu_layers=0, 195 | temp=None, top_p=None, 196 | injectables=None, timeout=30*60 197 | ): 198 | experiment_data = { 199 | "question_results": [], 200 | "model_name": get_model_name(model_path), 201 | "model_path": model_path, 202 | "prompt": prompt_data, 203 | } 204 | for q_data in qa: 205 | q_result = copy.deepcopy(q_data) 206 | print() 207 | print("="*72) 208 | print("Beginning with question:", q_result["question"]) 209 | q_result["scores"] = [] 210 | q_result["tracefiles"] = [] 211 | q_result["errors"] = [] 212 | q_result["keyword_matches"] = [] 213 | q_result["answers"] = [] 214 | for i in range(n_tries): 215 | print("-" * 72) 216 | print(f"Attempt: {i}") 217 | 218 | question = q_result["question"] 219 | print("Preparing prompt with Question:", question) 220 | prompt = None 221 | if experiment_prompt == "raw": 222 | prompt = prompt_data_to_raw(prompt_data, question, injectables=injectables) 223 | elif experiment_prompt == "chatml": 224 | prompt = prompt_data_to_chatml(prompt_data, question, injectables=injectables) 225 | elif experiment_prompt == "openai": 226 | prompt = prompt_data_to_openai(prompt_data, question, injectables=injectables) 227 | 228 | tracefile = get_tracefile(model_path) 229 | correct_keywords = q_result["correct_keywords"] 230 | 231 | print("Writing to:", tracefile) 232 | answer = None 233 | error = None 234 | try: 235 | answer, trace = run_llm( 236 | model_path, outfile=tracefile, 237 | debug=False, prompt=prompt, 238 | n_gpu_layers=n_gpu_layers, 239 | timeout=timeout, temp=temp, 240 | top_p=top_p 241 | ) 242 | except Exception as e: 243 | print(f"ERROR: {e}") 244 | error = f"{e}" 245 | 246 | print("Answer:", answer) 247 | reference = q_result["correct_answer"] 248 | candidate = answer or "" 249 | 250 | meteor_score = pymeteor.meteor(reference, candidate, print_details=True) 251 | print("Score:", meteor_score) 252 | keyword_matches = get_keyword_matches(candidate, correct_keywords) 253 | print("Keyword Matches:", keyword_matches) 254 | 255 | q_result["scores"].append(meteor_score) 256 | q_result["tracefiles"].append(tracefile) 257 | q_result["errors"].append(error) 258 | q_result["keyword_matches"].append(keyword_matches) 259 | q_result["answers"].append(answer) 260 | 261 | save_experiment_data(experiment_output, experiment_data) 262 | 263 | if cooldown: 264 | print(f"Cooling down for {cooldown}s...") 265 | time.sleep(cooldown) 266 | 267 | print("Appending 
experiment") 268 | experiment_data["question_results"].append(q_result) 269 | print(len(experiment_data["question_results"]), 270 | "experiments have been completed") 271 | 272 | save_experiment_data(experiment_output, experiment_data) 273 | 274 | return experiment_data 275 | 276 | 277 | if __name__ == "__main__": 278 | try: 279 | experiment_plan_file = sys.argv[1] 280 | except IndexError: 281 | print("USAGE: benchmark.py experiment_plan_file") 282 | print("Where experiment_plan_file points to a yaml file describing the experiments to be performed.") 283 | sys.exit(1) 284 | 285 | print("Loading experiment plan file", experiment_plan_file) 286 | experiment_plan = load_yml_file(experiment_plan_file) 287 | 288 | today=datetime.now().strftime("%Y-%m-%d") 289 | exp_name = experiment_plan["EXPERIMENT_NAME"] 290 | timeout = experiment_plan.get("TIMEOUT") 291 | 292 | prompt_data = experiment_plan["PROMPT_DATA"] 293 | injectables = experiment_plan.get("AVAILABLE_INJECT_PROMPTS") 294 | 295 | if USE_EXAMPLE_INJECTION: 296 | print("Loading NLP models") 297 | nlp = spacy.load("en_core_web_lg") 298 | print("Loading stopwords") 299 | download('stopwords') # Download stopwords list. 300 | stop_words = stopwords.words('english') 301 | 302 | for model_data in experiment_plan["MODELS"]: 303 | experiment_model = model_data["path"] 304 | experiment_prompt = model_data["prompt_type"] 305 | 306 | model_name = get_model_name(experiment_model) 307 | experiment_output = f"./experiments/{exp_name}_{today}_{model_name}.json" 308 | 309 | experiment_data = run_experiment( 310 | experiment_model, 311 | prompt_data, 312 | experiment_plan["QA"], 313 | experiment_output, 314 | cooldown=model_data.get("cooldown") or experiment_plan.get("COOLDOWN"), 315 | n_tries=experiment_plan["N_TRIES"], 316 | n_gpu_layers=model_data.get("n_gpu_layers", 0), 317 | temp=experiment_plan.get("temp"), 318 | top_p=experiment_plan.get("top_p"), 319 | injectables=injectables, 320 | timeout=model_data.get("timeout", timeout) 321 | ) 322 | save_experiment_data(experiment_output, experiment_data) 323 | -------------------------------------------------------------------------------- /example-benchmark.yml: -------------------------------------------------------------------------------- 1 | EXPERIMENT_NAME: "EXPERIMENT_SQL" 2 | # which models to try and their prompt format 3 | MODELS: 4 | - path: ../models/dolphin-2.1-mistral-7b.Q5_K_S.gguf 5 | prompt_type: chatml 6 | n_gpu_layers: -1 7 | - path: ../models/mistral-7b-v0.1.Q5_K_M.gguf 8 | prompt_type: raw 9 | n_gpu_layers: -1 10 | - path: ../models/llama-2-7b.Q5_K_M.gguf 11 | prompt_type: raw 12 | n_gpu_layers: -1 13 | - path: ../models/llama-2-13b.Q5_K_M.gguf 14 | prompt_type: raw 15 | - path: ../models/xwin-lm-7b-v0.1.Q5_K_M.gguf 16 | prompt_type: raw 17 | n_gpu_layers: -1 18 | - path: ../models/xwin-lm-13b-v0.2.Q5_K_M.gguf 19 | prompt_type: raw 20 | - path: ../models/ultralm-13b-v2.0.Q5_K_M.gguf 21 | prompt_type: raw 22 | - path: ../models/sqlcoder.Q5_K_M.gguf 23 | prompt_type: raw 24 | - path: ../models/llama-2-70b-orca-200k.Q5_K_S.gguf 25 | prompt_type: raw 26 | timeout: 14400 27 | # # model temperature setting 28 | # temp: 0.0 29 | # # model Top-P nucleus sampling threshold not 30 | # # used if temp is 0 31 | # top_p: 0.1 32 | # wait this long between runs, let the GPU cool down 33 | COOLDOWN: 30 34 | # 120 mins max runtime (should be PLENTY) 35 | TIMEOUT: 7200 36 | # On a local 70B model we might want more time, though 37 | # "TIMEOUT": 14400 38 | # how many times to try each question 39 | 
N_TRIES: 10 40 | QA: [{ 41 | "question": "first question", 42 | "correct_answer": "The correct answer (approximately)", 43 | "correct_keywords": [ 44 | "A Full phrase to match", 45 | # this will match with and without the dollar signs and commas 46 | "$10,000,000", 47 | ] 48 | }] 49 | # some prompts that should be matched based on similarity to the input question and injected 50 | AVAILABLE_INJECT_PROMPTS: 51 | - question: An example question to match 52 | prompt: 53 | - role: user 54 | content: |- 55 | Question: An example question to match 56 | 57 | - role: assistant 58 | content: |- 59 | Thought: I should begin with ... 60 | Action: tables 61 | 62 | - role: user 63 | content: |- 64 | Observation: etc etc continue the trace yourself 65 | 66 | PROMPT_DATA: 67 | - role: system 68 | content: |- 69 | Answer the following questions as best you can. You have access to the following tools: 70 | 71 | action1: description. input 1: arg desc. etc ... 72 | 73 | Use the following format: 74 | 75 | Question: the input question you must answer 76 | Thought: you should always think about what to do 77 | Action: the action to take, should be one of: tables, schema, help, sql-query 78 | Action Input 1: the first input to the action. 79 | Observation: the result of the action 80 | ... (this Thought/Action/Action Input/Observation can repeat N times) 81 | Thought: I now know the final answer 82 | Final Answer: the final answer to the original input question 83 | 84 | - role: user 85 | content: |- 86 | Question: What information do I have about users? 87 | 88 | - role: assistant 89 | content: |- 90 | Thought: continue the trace ... 91 | Action: ... 92 | 93 | - role: user 94 | content: |- 95 | Observation: ```...``` 96 | 97 | - role: assistant 98 | content: |- 99 | Thought: ... 100 | Final Answer: ... example final answer ... 101 | 102 | - role: user 103 | # this marks the point where, if we're going to inject examples, we'll do it before here 104 | # and skip the prompt items from here up to (but not including) the final question at the end 105 | inject_before: True 106 | content: |- 107 | Question: ... 108 | 109 | - role: assistant 110 | content: |- 111 | Thought: ... 112 | Final Answer: ... 113 | 114 | - role: user 115 | content: |- 116 | Question: {question} 117 | -------------------------------------------------------------------------------- /example-prompt.txt: -------------------------------------------------------------------------------- 1 | Answer the following questions as best you can. You have access to the following tools: 2 | 3 | tables: Useful for getting the names of tables available. No input. 4 | schema: Useful for looking at the schema (columns and data types) of a table. Input 1: table name. 5 | help: Returns helpful information describing a table or a table's column. Useful for understanding things like the relationship between tables and how to interpret columns. Input 1: table name. (optional) Input 2: column name. 6 | sql-query: Useful for analyzing data and getting the top 5 results of a query. Input 1: a valid SQLite3 SQL query. 7 | 8 | Use the following format: 9 | 10 | Question: the input question you must answer 11 | Thought: you should always think about what to do 12 | Action: the action to take, should be one of: tables, schema, help, sql-query 13 | Action Input 1: the first input to the action. 14 | Observation: the result of the action 15 | ...
(this Thought/Action/Action Input/Observation can repeat N times) 16 | Thought: I now know the final answer 17 | Final Answer: the final answer to the original input question 18 | 19 | # Here's an example: 20 | 21 | Question: What information do I have about users? 22 | Thought: I should check to see if I have any users tables. 23 | Action: tables 24 | Observation: ```["jobs", "users", "games", "game_passes"]``` 25 | Thought: I should inspect the columns on the users table. 26 | Action: schema 27 | Action Input 1: ```users``` 28 | Observation: ```CREATE TABLE [users] ( [creatorUserId] INTEGER PRIMARY KEY, [isContactAllowed] INTEGER, [creatorDescription] TEXT, [isOpenToWork] INTEGER, [interestDescription] TEXT, [linkTypes] TEXT, [preferredContactLinkType] TEXT, [socialLinks] TEXT, [jobTypes] TEXT, [skillTypes] TEXT, [requiresAction] TEXT )``` 29 | Thought: I should see what ways the users table can be helpful by checking the help. 30 | Action: help 31 | Action Input 1: ```users``` 32 | Observation: users are individuals who are seeking work, have worked, or are looking to hire people to work on games. sometimes they mention personal details about themselves in their description. 33 | Thought: I have all of the information I need. 34 | Final Answer: I have information about users' skills, job types, and whether or not they're actively open to work. For some users, we have personal information in their profiles. 35 | 36 | # Here's another example: 37 | 38 | Question: How many jobs have been offered? 39 | Thought: I should look at which tables I have available. 40 | Action: tables 41 | Observation: ```["jobs", "users", "games", "game_passes"]``` 42 | Thought: I should count the rows of the jobs table using SQLite SQL. 43 | Action: sql-query 44 | Action Input 1: ```select count(id) from jobs;``` 45 | Observation: ```[{{'count(id)': 6753}}]``` 46 | Thought: This query has given me the count of jobs in the table. I have a final answer. 47 | Final Answer: There have been 6,753 total jobs offered according to the database. 48 | 49 | # You can run any SQLite query to answer aggregate questions: 50 | 51 | Question: What are the most common payment amounts that have been offered for jobs? 52 | Thought: I should look at which tables I have available. 53 | Action: tables 54 | Observation: ```["jobs", "users", "games", "game_passes"]``` 55 | Thought: I should use a SQL group by query to see the top jobs.paymentAmount values. 56 | Action: sql-query 57 | Action Input 1: ```select paymentAmount, count(paymentAmount) as n from jobs group by paymentAmount order by n desc limit 3;``` 58 | Observation: ```[{{"paymentAmount": 0.0, "n": 713}}, {{"paymentAmount": 1.0, "n": 157}}, {{"paymentAmount": 2.0, "n": 22}}]``` 59 | Thought: This query has given me the count of the top payment amounts in the 60 | jobs table. I have a final answer. 61 | Final Answer: The top three payment amounts offered for jobs are 0.0 (713 jobs), 1.0 (157 jobs), and 2.0 (22 jobs). 62 | 63 | # You can get helpful information about tables and columns by using the help action: 64 | 65 | Question: What kind of information shows up in the user description? 66 | Thought: I should read the help for the description column in the users table. 67 | Action: help 68 | Action Input 1: ```users``` 69 | Action Input 2: ```creatorDescription``` 70 | Observation: The users table's creatorDescription column is a free-text field the user has supplied, describing themselves, their interests and work preferences.
71 | Thought: I have some information about what appears in user descriptions. 72 | Final Answer: Users sometimes put their interests, work preferences and demographic information in their description. 73 | 74 | Question: {question} 75 | Thought: 76 | -------------------------------------------------------------------------------- /llm_openai_sql_queries.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import re 4 | import sys 5 | import sqlite3 6 | import time 7 | 8 | import openai 9 | import sqlite_utils 10 | 11 | from llm_sql_queries import ( 12 | DB_PATH, load_db, 13 | tables, schema, help, sql_query 14 | ) 15 | 16 | 17 | # Larger context sizes will reduce quality, but some models 18 | # support large contexts better than others. 19 | #CONTEXT_SIZE=2048 20 | CONTEXT_SIZE=2048*2 21 | # how many tokens to allow the model to output in a single go w/o stopping 22 | MAX_TOKENS=400 23 | 24 | 25 | def execute(model_path, outfile=None, debug=True, return_dict=None, 26 | prompt=None, n_gpu_layers=0, temp=None, top_p=None): 27 | assert prompt, "You didn't supply a prompt" 28 | db = load_db(DB_PATH) 29 | openai.organization = os.environ["OPENAI_ORG_ID"] 30 | openai.api_key = os.environ["OPENAI_API_KEY"] 31 | assert openai.organization and openai.api_key, "No OpenAI credentials" 32 | action_fns = { 33 | "tables": tables, 34 | # "columns": columns, 35 | "schema": schema, 36 | "help": help, 37 | "sql-query": sql_query, 38 | } 39 | 40 | if debug: 41 | print(json.dumps(prompt, indent=2)) 42 | 43 | total_tokens = 0 44 | done = False 45 | while not done: 46 | model_name = model_path.split(":", 1)[1] 47 | print("Running OpenAI model:", model_name) 48 | print("Last prompt line:", json.dumps(prompt[-1], indent=2)) 49 | model_kwargs = dict( 50 | # model="gpt-4", 51 | model=model_name, 52 | # string / array / null 53 | # Up to 4 sequences where the API will stop generating 54 | # further tokens. The returned text will not contain the 55 | # stop sequence.
56 | stop=[ 57 | "Question:", "Observation:", 58 | "<|im_end|>", "<|im_start|>user", 59 | ], 60 | stream=True, 61 | messages=prompt, 62 | ) 63 | 64 | # Open AI recommends not using BOTH temperature and top-p 65 | if temp is not None: 66 | model_kwargs["temperature"] = temp 67 | elif top_p is not None: 68 | model_kwargs["top_p"] = top_p 69 | 70 | try: 71 | stream = openai.ChatCompletion.create( 72 | **model_kwargs 73 | ) 74 | except openai.error.RateLimitError: 75 | print("Cooling down...") 76 | time.sleep(30) 77 | continue # retry the request instead of falling through with no stream 78 | with open("debug-openai.log", "a") as f: 79 | f.write(json.dumps(prompt)) 80 | f.write('\n') 81 | 82 | response = "" 83 | for i, item in enumerate(stream): 84 | # { 85 | # "choices": [ 86 | # { 87 | # "delta": { 88 | # "role": "assistant" 89 | # # OR, once started a role 90 | # "content": "\n\n" 91 | # }, 92 | # "finish_reason": null | "stop", 93 | # "index": 0 94 | # } 95 | # ], 96 | # "created": 1677825464, 97 | # "id": "chatcmpl-6ptKyqKOGXZT6iQnqiXAH8adNLUzD", 98 | # "model": "gpt-3.5-turbo-0301", 99 | # "object": "chat.completion.chunk" 100 | # } 101 | if i > MAX_TOKENS: 102 | break 103 | choice = item['choices'][0] 104 | print(i, json.dumps(choice), end=" \r") 105 | 106 | # if it gives a non-assistant role, end 107 | role = choice["delta"].get("role") 108 | if role and role != "assistant": 109 | break 110 | # if it wants to stop (or hits a stopword) let it 111 | if choice.get("finish_reason") == "stop": 112 | break 113 | 114 | total_tokens += 1 115 | if total_tokens > CONTEXT_SIZE: 116 | done = True 117 | break 118 | 119 | # otherwise assume we have another token 120 | token = choice["delta"]["content"] 121 | response += token 122 | 123 | with open("debug-openai.log", "a") as f: 124 | f.write(json.dumps(item)) 125 | f.write('\n') 126 | 127 | # Update the prompt 128 | prompt.append({"role": "assistant", "content": response}) 129 | 130 | if debug: 131 | print(response) 132 | 133 | with open("debug-openai.log", "a") as f: 134 | f.write(json.dumps(prompt)) 135 | f.write('\n') 136 | 137 | if outfile: 138 | print("Writing to tracefile", outfile) 139 | with open(outfile, "w") as f: 140 | f.write(json.dumps(prompt, indent=2)) 141 | 142 | if done: 143 | break 144 | 145 | try: 146 | action = re.findall(r"Action: (.*)", response, re.M)[0] 147 | except IndexError: 148 | action = None 149 | 150 | try: 151 | final_answer = re.findall(r'Final Answer: (.*)', response, re.M|re.S)[0] 152 | except IndexError: 153 | final_answer = None 154 | 155 | if action and action not in action_fns: 156 | action_names = ", ".join(list(action_fns.keys())) 157 | prompt.append({ 158 | "role": "user", 159 | "content": f"Observation: That's an invalid action. Valid actions: {action_names}" 160 | }) 161 | 162 | elif action: 163 | print("Action in response", response) 164 | # NOTE: we could change 1 for the number of args of selected action 165 | actionInputs = re.findall( 166 | r'Action Input (\d): ```([^`]+)```', response, re.M|re.S 167 | ) 168 | # try and recover actions without backticks 169 | if not actionInputs: 170 | actionInputs = re.findall( 171 | r'Action Input (\d): ([^`]+)', response, re.M|re.S 172 | ) 173 | print("actionInputs", actionInputs) 174 | 175 | args = [ 176 | inp[1] 177 | for inp in actionInputs 178 | ] 179 | action_fn = action_fns[action] 180 | observation_text = "" 181 | try: 182 | print("Running action", action_fn, end="...
\t") 183 | result = action_fn(db, *args) 184 | print("Done!", end="\r") 185 | result_text = json.dumps(result) 186 | observation_text = f"```{result_text}```" 187 | except TypeError as e: 188 | if "positional argument" not in str(e): 189 | raise e 190 | # trim off the name of the action from msg like: 191 | # hi() takes 1 positional argument but 2 were given 192 | # and turn it into: 193 | # The action hi takes 1 Action Input but 2 were given 194 | args_err_msg = str(e).split(" ", 1)[1].replace( 195 | "positional argument", "Action Input" 196 | ).replace( 197 | "positional arguments", "Action Inputs" 198 | ).split(": '", 1)[0] 199 | observation_text = f"The action {action} {args_err_msg}" 200 | prompt.append({ 201 | "role": "user", 202 | "content": f"Observation: {observation_text}" 203 | }) 204 | 205 | elif final_answer: 206 | if return_dict is not None: 207 | return_dict["final_answer"] = final_answer 208 | return_dict["trace"] = prompt 209 | return final_answer, prompt 210 | 211 | # TODO: truncate the prompt if it's grown too long 212 | # using tiktoken and some keep_n value of context 213 | 214 | if return_dict is not None: 215 | return_dict["final_answer"] = None 216 | return_dict["trace"] = prompt 217 | 218 | return None, prompt 219 | -------------------------------------------------------------------------------- /llm_sql_queries.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import re 4 | import sys 5 | import sqlite3 6 | 7 | try: 8 | from llama_cpp import Llama 9 | except ModuleNotFoundError: 10 | print("llama_cpp not installed, continuing without") 11 | 12 | from actions import ( 13 | DB_PATH, load_db, 14 | tables, schema, help, sql_query 15 | ) 16 | 17 | 18 | # Larger context sizes will reduce quality, but some models 19 | # support large contexts better than others.
20 | #CONTEXT_SIZE=2048 21 | CONTEXT_SIZE=2048*2 22 | # how many tokens to allow the model to output in a single go w/o stopping 23 | MAX_TOKENS=400 24 | 25 | 26 | # Utils n stuff 27 | def load_model(model_path, n_gpu_layers=0, n_threads=os.cpu_count() - 1, 28 | n_ctx=CONTEXT_SIZE, temp=None, top_p=None): 29 | # for LLaMA2 70B models add kwarg: n_gqa=8 (NOTE: not required for GGUF models) 30 | print("Loading model", model_path) 31 | print("CTX:", n_ctx, "GPU layers:", n_gpu_layers, "CPU threads:", n_threads) 32 | print("Temperature:", temp, "Top-p Sampling:", top_p) 33 | kwargs = dict( 34 | model_path=model_path, 35 | n_ctx=n_ctx, 36 | n_gpu_layers=n_gpu_layers, 37 | n_threads=n_threads, 38 | verbose=False 39 | ) 40 | if temp is not None: 41 | kwargs["temp"] = temp 42 | if top_p is not None: 43 | kwargs["top_p"] = top_p 44 | llm = Llama(**kwargs) 45 | return llm 46 | 47 | 48 | def execute(model_path, outfile=None, debug=True, return_dict=None, 49 | prompt=None, n_gpu_layers=0, temp=None, top_p=None): 50 | llm = load_model(model_path, n_gpu_layers=n_gpu_layers, temp=temp, 51 | top_p=top_p) 52 | db = load_db(DB_PATH) 53 | action_fns = { 54 | "tables": tables, 55 | "schema": schema, 56 | "help": help, 57 | "sql-query": sql_query, 58 | } 59 | action_names_text = ", ".join(list(action_fns.keys())) 60 | prompt_is_chatml = "<|im_start|>" in prompt 61 | if debug: 62 | print(prompt) 63 | 64 | n_sequential_whitespace = 0 65 | n_thoughts_seen = 0 66 | done = False 67 | while not done: 68 | stream = llm( 69 | prompt, 70 | max_tokens=MAX_TOKENS, 71 | stop=["Question:", "Observation:", "<|im_end|>", "<|im_start|>user"], 72 | stream=True, 73 | echo=True 74 | ) 75 | response = "" 76 | for i, token in enumerate(stream): 77 | choice = token['choices'][0] 78 | print(i, choice, end="\t\t\t\t\t\r") 79 | token = choice["text"] 80 | response += token 81 | if token in ["", "\n"]: 82 | n_sequential_whitespace += 1 83 | else: 84 | n_sequential_whitespace = 0 85 | # detect repeating loop 86 | if response.count("Thought: ") > 4: 87 | done = True 88 | break 89 | if n_sequential_whitespace > 20: 90 | done = True 91 | break 92 | 93 | with open("debug.log", "a") as f: 94 | f.write(json.dumps(token)) 95 | f.write('\n') 96 | 97 | if prompt_is_chatml and not response.strip().endswith("<|im_end|>"): 98 | response = f"{response.strip()}\n<|im_end|>\n" 99 | 100 | # Update the prompt 101 | prompt = f"{prompt}{response}".strip() 102 | 103 | if debug: 104 | print(response) 105 | 106 | if outfile: 107 | print("Writing to tracefile", outfile) 108 | with open(outfile, "w") as f: 109 | f.write(prompt) 110 | 111 | if done: 112 | break 113 | 114 | try: 115 | action = re.findall(r"Action: (.*)", response, re.M)[0] 116 | except IndexError: 117 | action = None 118 | 119 | try: 120 | final_answer = re.findall(r'Final Answer: (.*)', response, re.M|re.S)[0] 121 | except IndexError: 122 | final_answer = None 123 | 124 | if action and action not in action_fns: 125 | action_names = ", ".join(list(action_fns.keys())) 126 | if prompt_is_chatml: 127 | prompt += f""" 128 | <|im_start|>user 129 | Observation: That's an invalid action. Valid actions: {action_names} 130 | <|im_end|> 131 | <|im_start|>assistant 132 | Thought: """ 133 | else: 134 | prompt += f"""Observation: That's an invalid action.
Valid actions: {action_names} 135 | Thought: """ 136 | 137 | elif action: 138 | # NOTE: we could change 1 for the number of args of selected action 139 | actionInputs = re.findall( 140 | r'Action Input (\d): ```([^`]+)```', response, re.M|re.S 141 | ) 142 | args = [ 143 | inp[1] 144 | for inp in actionInputs 145 | ] 146 | action_fn = action_fns[action] 147 | observation_text = "" 148 | try: 149 | print("Running action", action_fn, end="... \t") 150 | result = action_fn(db, *args) 151 | print("Done!", end="\r") 152 | result_text = json.dumps(result) 153 | observation_text = f"```{result_text}```" 154 | except TypeError as e: 155 | if "positional argument" not in str(e): 156 | raise e 157 | # trim off the name of the action from msg like: 158 | # hi() takes 1 positional argument but 2 were given 159 | # and turn it into: 160 | # The action hi takes 1 Action Input but 2 were given 161 | args_err_msg = str(e).split(" ", 1)[1].replace( 162 | "positional argument", "Action Input" 163 | ).replace( 164 | "positional arguments", "Action Inputs" 165 | ).split(": '", 1)[0] 166 | observation_text = f"The action {action} {args_err_msg}" 167 | if prompt_is_chatml: 168 | prompt += f""" 169 | <|im_start|>user 170 | Observation: {observation_text} 171 | <|im_end|> 172 | <|im_start|>assistant 173 | Thought: """ 174 | else: 175 | prompt += f""" 176 | Observation: {observation_text} 177 | Thought: """ 178 | 179 | elif final_answer: 180 | if return_dict is not None: 181 | return_dict["final_answer"] = final_answer.replace( 182 | "<|im_end|>", "" 183 | ).strip() 184 | return_dict["trace"] = prompt 185 | return final_answer, prompt 186 | 187 | # TODO: truncate the prompt if it's grown too long 188 | # using tiktoken and some keep_n value of context 189 | 190 | if return_dict is not None: 191 | return_dict["final_answer"] = None 192 | return_dict["trace"] = prompt 193 | return None, prompt 194 | 195 | 196 | if __name__ == "__main__": 197 | question = sys.argv[1] 198 | model_path = "dolphin-2.2.1-mistral-7b.Q5_K_M.gguf" 199 | with open("example-prompt.txt", "r") as f: 200 | prompt = f.read().format(question=question.strip()) 201 | answer, trace = execute( 202 | model_path, outfile=None, 203 | debug=False, prompt=prompt, 204 | n_gpu_layers=0, 205 | temp=0, 206 | top_p=None 207 | ) 208 | 209 | print("Trace", trace) 210 | print("Final Answer:", answer) 211 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | STATES = { 5 | 'AK':'Alaska', 6 | 'AL':'Alabama', 7 | 'AR':'Arkansas', 8 | 'AZ':'Arizona', 9 | 'CA':'California', 10 | 'CO':'Colorado', 11 | 'CT':'Connecticut', 12 | 'DC':'District of Columbia', 13 | 'DE':'Delaware', 14 | 'FL':'Florida', 15 | 'GA':'Georgia', 16 | 'HI':'Hawaii', 17 | 'IA':'Iowa', 18 | 'ID':'Idaho', 19 | 'IL':'Illinois', 20 | 'IN':'Indiana', 21 | 'KS':'Kansas', 22 | 'KY':'Kentucky', 23 | 'LA':'Louisiana', 24 | 'MA':'Massachusetts', 25 | 'MD':'Maryland', 26 | 'ME':'Maine', 27 | 'MI':'Michigan', 28 | 'MN':'Minnesota', 29 | 'MO':'Missouri', 30 | 'MS':'Mississippi', 31 | 'MT':'Montana', 32 | 'NC':'North Carolina', 33 | 'ND':'North Dakota', 34 | 'NE':'Nebraska', 35 | 'NH':'New Hampshire', 36 | 'NJ':'New Jersey', 37 | 'NM':'New Mexico', 38 | 'NV':'Nevada', 39 | 'NY':'New York', 40 | 'OH':'Ohio', 41 | 'OK':'Oklahoma', 42 | 'OR':'Oregon', 43 | 'PA':'Pennsylvania', 44 | 'RI':'Rhode Island', 45 | 'SC':'South Carolina', 46 | 'SD':'South Dakota', 47 |
'TN':'Tennessee', 48 | 'TX':'Texas', 49 | 'UT':'Utah', 50 | 'VA':'Virginia', 51 | 'VT':'Vermont', 52 | 'WA':'Washington', 53 | 'WI':'Wisconsin', 54 | 'WV':'West Virginia', 55 | 'WY':'Wyoming' 56 | } 57 | 58 | 59 | def get_keyword_matches(result, correct_keywords, return_texts=False): 60 | match_texts = [] 61 | matches = 0 62 | if not result: 63 | if return_texts: 64 | return matches, match_texts 65 | return matches 66 | for keyword in correct_keywords: 67 | keyword_nocomma = re.sub(r"[$,]+", "", str(keyword)) 68 | keyword_re = rf"(?:^|\b|[(\s])({keyword_nocomma})(?:[).,\s]|\b|$)"  # NOTE: \b inside a character class would match a literal backspace, so it's kept outside the brackets 69 | # dollar amounts look for the full int sans symbols 70 | if isinstance(keyword, (int, float)) or str(keyword).startswith("$"): 71 | res_nocomma = re.sub(r"[$,]+", "", result) 72 | found = re.findall(keyword_re, res_nocomma, re.I) 73 | if len(found) > 0: 74 | matches += 1 75 | match_texts.append(found) 76 | # # if we have a state, check for case-sensitive abbrev + full name 77 | # elif keyword in STATES: 78 | # if f" {keyword}" in result: 79 | # matches += 1 80 | # elif STATES[keyword] in result: 81 | # matches += 1 82 | # case insensitive match on phrases 83 | else: 84 | found = re.findall(keyword_re, result, re.I) 85 | if len(found) > 0: 86 | matches += 1 87 | match_texts.append(found) 88 | if return_texts: 89 | return matches, match_texts 90 | return matches 91 | -------------------------------------------------------------------------------- /reason-act-llm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/brandonrobertz/reason-act-sqlite-py/da60bb8525fc576d443679c83099225c93a769e7/reason-act-llm.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # python 3.11.6 2 | llama_cpp_python>=0.2.7 3 | # pip install --index-url https://test.pypi.org/simple/ pymeteor 4 | pymeteor @ https://test-files.pythonhosted.org/packages/e9/8a/c72ff9c96ccc49340a8f5cdb331e25ab1b9f105e4bb2fc8b230509dcad24/pymeteor-0.0.1-py3-none-any.whl#sha256=5784be3be1a247a31cea8ad7c296323cbfb35f2873b3f002264ec21a8431fa62 5 | en-core-web-lg @ https://github.com/explosion/spacy-models/releases/download/en_core_web_lg-3.7.0/en_core_web_lg-3.7.0-py3-none-any.whl#sha256=708da1110fbe1163d059de34a2cbedb1db65c26e1e624ca925897a2711cb7d77 6 | nltk>=3.8.1,<4.0 7 | numpy>=1.26.2,<2.0 8 | openai>=0.28.0,<1.0 9 | PyYAML>=6.0.1,<7.0 10 | scipy>=1.11.3,<2.0 11 | spacy>=3.7.2,<4.0 12 | sqlite-utils>=3.35.2,<4.0 13 | # needed by rescore.py 14 | pandas 15 | -------------------------------------------------------------------------------- /rescore.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import json 3 | import os 4 | import re 5 | import sys 6 | 7 | import pandas as pd 8 | import pymeteor.pymeteor as pymeteor 9 | from yaml import load, dump 10 | try: 11 | from yaml import CLoader as Loader, CDumper as Dumper 12 | except ImportError: 13 | from yaml import Loader, Dumper 14 | 15 | from metrics import get_keyword_matches 16 | 17 | 18 | def load_yml_file(filename): 19 | with open(filename, "r") as f: 20 | return load(f, Loader=Loader) 21 | 22 | 23 | def final_answer_from_trace_or_result(tracepath, result=None): 24 | final_answer = "" 25 | if not os.path.exists(tracepath): 26 | return result or "" 27 | with open(tracepath, "r") as f: 28 | lines = f.readlines() 29 | i = 0 30 | while "Final Answer: " in "\n".join(lines): 31 | try: 32 | line
= lines.pop() 33 | except IndexError: 34 | return final_answer 35 | if not line.startswith("Final Answer:"): 36 | continue 37 | final_answer += line 38 | if line.startswith("<|im_end|>"): 39 | break 40 | if line.startswith("Thought: "): 41 | break 42 | if line.startswith("Question: "): 43 | break 44 | i += 1 45 | return final_answer.replace("<|im_end|>", "") 46 | 47 | 48 | if __name__ == "__main__": 49 | benchmark_plan_file = None 50 | results_outfile = None 51 | try: 52 | benchmark_plan_file = sys.argv[1] 53 | results_outfile = sys.argv[2] 54 | except IndexError: 55 | print("USAGE: rescore.py benchmark-file.yml results_output.csv") 56 | sys.exit(1) 57 | 58 | plan = load_yml_file(benchmark_plan_file) 59 | question_keywords = { 60 | qa["question"]: qa["correct_keywords"] 61 | for qa in plan["QA"] 62 | } 63 | question_answers = { 64 | qa["question"]: qa["correct_answer"] 65 | for qa in plan["QA"] 66 | } 67 | 68 | results = [ 69 | ["Experiment", "Model", "Task", "Keyword(s)", "METEOR", "Match Texts"], 70 | ] 71 | for basedir, subdirs, filenames in os.walk("./experiments/"): 72 | for filename in filenames: 73 | print("Loading experiment file", filename) 74 | with open(os.path.join(basedir, filename), "r") as f: 75 | try: 76 | experiment = json.load(f) 77 | except json.decoder.JSONDecodeError: 78 | continue 79 | try: 80 | experiment_name = re.findall(r"^(.+)_\d{4}-\d\d-\d\d_.*", filename)[0] 81 | except Exception: 82 | continue 83 | model_name = experiment["model_name"] 84 | print("=" * 72) 85 | print("Experiment:", experiment_name, "Model name:", model_name) 86 | 87 | # TODO: replace these with cli arg filters 88 | if "ROBLOX" not in experiment_name or "ROBLOX" not in filename: 89 | continue 90 | # if "OPENAI" not in experiment_name: 91 | # continue 92 | 93 | for q_n, result in enumerate(experiment["question_results"]): 94 | question = result["question"] 95 | correct_keywords = question_keywords[question] 96 | correct_answer = question_answers[question] 97 | 98 | keyword_matches_all = [] 99 | meteor_score_all = [] 100 | match_texts_all = [] 101 | for index in range(len(result["scores"])): 102 | print("-" * 72) 103 | print(f"Q_{q_n}_{index}") 104 | 105 | error = result["errors"][index] 106 | 107 | tracefile = result["tracefiles"][index] 108 | print("tracefile", tracefile) 109 | if not os.path.exists(tracefile): 110 | tracefile = "missing" 111 | 112 | final_answer = result["answers"][index] or "" 113 | # final_answer = final_answer_from_trace_or_result( 114 | # tracefile, result=result["answers"][index] 115 | # ) 116 | print("Final answer:", final_answer) 117 | 118 | keyword_score, match_texts = get_keyword_matches( 119 | final_answer, correct_keywords, return_texts=True 120 | ) 121 | match_texts_all.append(match_texts) 122 | print("Calculating METEOR on:", final_answer) 123 | meteor_score = pymeteor.meteor( 124 | correct_answer, final_answer 125 | ) 126 | exp_result = [ 127 | experiment_name, 128 | model_name, 129 | f"Q_{q_n}_{index}", 130 | keyword_score, 131 | meteor_score, 132 | match_texts, 133 | ] 134 | print(exp_result) 135 | results.append(exp_result) 136 | meteor_score_all.append(meteor_score) 137 | print("METEOR score:", meteor_score) 138 | keyword_matches_all.append(keyword_score) 139 | print("Keyword score:", keyword_score) 140 | 141 | print("-" * 72) 142 | print(f"Q_{q_n} Aggregates") 143 | exp_result = [ 144 | experiment_name, 145 | model_name, 146 | f"Q_{q_n}_Avg", 147 | sum(keyword_matches_all) / len(keyword_matches_all), 148 | sum(meteor_score_all) / len(meteor_score_all), 149 |
match_texts_all, 150 | ] 151 | results.append(exp_result) 152 | print(exp_result) 153 | 154 | exp_result = [ 155 | experiment_name, 156 | model_name, 157 | f"Q_{q_n}_Max", 158 | max(keyword_matches_all), 159 | max(meteor_score_all), 160 | match_texts_all, 161 | ] 162 | results.append(exp_result) 163 | print(exp_result) 164 | 165 | results_df = pd.DataFrame(results[1:], columns=results[0]) 166 | with open(results_outfile, "w") as f: 167 | f.write(results_df.to_csv()) 168 | 169 | # import IPython 170 | # IPython.embed() 171 | # import time 172 | # time.sleep(2) 173 | -------------------------------------------------------------------------------- /run_interface.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import re 4 | import sys 5 | import sqlite3 6 | 7 | from llama_cpp import Llama 8 | import sqlite_utils 9 | 10 | 11 | DB_PATH = "example.db" 12 | MODEL_PATH = "dolphin-2.2.1-mistral-7b.Q5_K_M.gguf" 13 | # columns to not ever use or show 14 | IGNORED_COLUMNS = ["rowid", "created_at", "_meta_score"] 15 | TABLE_FTS = { 16 | "users": ["creatorDescription"], 17 | "jobs": ["title", "description"], 18 | "experiences": ["projectName", "experienceDescription", "jobRole"], 19 | "games": ["name", "description"], 20 | "game_passes": ["name", "description"], 21 | } 22 | # search-tables: useful for getting a list of tables that support full-text search. input 1: table name. 23 | # search: a full-text search engine. useful to find records with descriptions containing some text. input 1: table name, input 2: a search query. 24 | ACTIONS = """ 25 | tables: useful for getting the names of tables available. no input. 26 | columns: useful for looking at all of the columns for a given table. input 1: table name. 27 | facets: useful for looking at the unique values and counts for a given column. input 1: table name, input 2: column name. 28 | filter: useful for getting the first row where the column matches a given value. input 1: table name, input 2: column name, input 3: a value to filter on. 29 | """ 30 | 31 | 32 | def load_db(path): 33 | assert os.path.exists(path), f"Database doesn't exist: {path}" 34 | db = sqlite_utils.Database(path) 35 | for table, fields in TABLE_FTS.items(): 36 | try: 37 | db[table].enable_fts(fields, create_triggers=True) 38 | except sqlite3.OperationalError: 39 | pass 40 | return db 41 | 42 | 43 | def is_array_field(db, table_name, column): 44 | # test if it's an array type, first result will always start with '[' 45 | rows = db.query(f""" 46 | select {column} as value 47 | from {table_name} 48 | where {column} is not null and {column} != "" 49 | limit 1 50 | """) 51 | for row in rows: 52 | if isinstance(row["value"], int): 53 | return False 54 | return row["value"].startswith("[") 55 | 56 | 57 | def clean_truncate(results, n=3): 58 | return [ 59 | {k: v for k, v in r.items() if k not in IGNORED_COLUMNS} 60 | for r in results[:n] 61 | ] 62 | 63 | 64 | ## ACTIONS 65 | def tables(db): 66 | return [ 67 | name 68 | for name in db.table_names() 69 | if "_fts" not in name 70 | ] 71 | 72 | 73 | def columns(db, table_name): 74 | table_names = tables(db) 75 | if table_name not in table_names: 76 | return f"Invalid table. Valid tables are: {table_names}" 77 | return [ 78 | c.name 79 | for c in db[table_name].columns 80 | if c.name not in IGNORED_COLUMNS 81 | ] 82 | 83 | 84 | def facets(db, table_name, column): 85 | table_names = tables(db) 86 | if table_name not in table_names: 87 | return f"Invalid table.
88 |     column_names = columns(db, table_name)
89 |     if column not in column_names:
90 |         return f"Invalid column. Valid columns are: {column_names}"
91 |     if is_array_field(db, table_name, column):
92 |         results = db.query(f"""
93 |             SELECT value, count(*) AS count
94 |             FROM (SELECT j.value AS value
95 |                   FROM {table_name}
96 |                   CROSS JOIN json_each({table_name}.{column}) AS j)
97 |             GROUP BY value
98 |             ORDER BY count DESC
99 |             LIMIT 5;
100 |         """)
101 |     else:
102 |         # if it's not an array type
103 |         results = db.query(f"""
104 |             SELECT {column} AS value, count({column}) AS count
105 |             FROM {table_name}
106 |             GROUP BY {column}
107 |             ORDER BY count DESC
108 |             LIMIT 5
109 |         """)
110 |     return [
111 |         [r["value"], r["count"]]
112 |         for r in results
113 |     ]
114 | 
115 | 
116 | def filter(db, table_name, column, value):
117 |     table_names = tables(db)
118 |     if table_name not in table_names:
119 |         return f"Invalid table. Valid tables are: {table_names}"
120 |     column_names = columns(db, table_name)
121 |     if column not in column_names:
122 |         return f"Invalid column. Valid columns are: {column_names}"
123 |     if is_array_field(db, table_name, column):
124 |         results = list(db.query(f"""
125 |             SELECT *
126 |             FROM {table_name}
127 |             WHERE EXISTS (SELECT 1 FROM json_each({column}) WHERE value = ?)
128 |         """, [value]))
129 |     else:
130 |         results = list(db[table_name].rows_where(f"{column} = ?", [value]))
131 |     return clean_truncate(results, n=1)
132 | 
133 | 
134 | def search(db, table_name, query):
135 |     results = list(db[table_name].search(query))
136 |     return clean_truncate(results)
137 | 
138 | 
139 | # Utils n stuff
140 | def load_model(model_path):
141 |     return Llama(model_path=model_path, n_ctx=2048)
142 | 
143 | 
144 | def execute(llm, question):
145 |     action_fns = {
146 |         "tables": tables,
147 |         "columns": columns,
148 |         "facets": facets,
149 |         "filter": filter,
150 |     }
151 |     action_names_text = ", ".join(list(action_fns.keys()))
152 |     prompt = f"""
153 | Answer the following questions as best you can. You have access to the following tools:
154 | 
155 | {ACTIONS.strip()}
156 | 
157 | Use the following format:
158 | 
159 | Question: the input question you must answer
160 | Thought: you should always think about what to do
161 | Action: the action to take, should be one of [{action_names_text}]
162 | Action Input 1: the first input to the action.
163 | Action Input 2: the second input to the action, if more than one.
164 | Observation: the result of the action
165 | ... (this Thought/Action/Action Input/Observation can repeat N times)
166 | Thought: I now know the final answer
167 | Final Answer: the final answer to the original input question
168 | 
169 | Here's an example:
170 | 
171 | Question: What information do I have about users?
172 | Thought: I should check to see if I have any users tables.
173 | Action: tables
174 | Observation: ["jobs", "users", "games", "game_stats", "game_passes"]
175 | Thought: I should inspect the columns on the users table.
176 | Action: columns
177 | Action Input 1: "users"
178 | Observation: ["creatorUserId", "isPublic", "isContactAllowed", "creatorDescription", "isOpenToWork", "interestDescription", "jobTypes", "skillTypes"]
179 | Thought: I have all of the information I need.
180 | Final Answer: I have the following fields about users: ["creatorUserId", "isPublic", "isContactAllowed", "creatorDescription", "isOpenToWork", "interestDescription", "jobTypes", "skillTypes"]
181 | 
182 | Here's another example:
183 | 
184 | Question: What job types are being offered?
185 | Thought: I should look at which tables I have available.
186 | Action: tables
187 | Observation: ["jobs", "users", "games", "game_stats", "game_passes"]
188 | Thought: I should see which columns are available for jobs.
189 | Action: columns
190 | Action Input 1: "jobs"
191 | Observation: ["jobType", "paymentTypes", "isVerified", "paymentAmountType", "skillTypes"]
192 | Thought: I should look at the unique values the jobType column has on the jobs table.
193 | Action: facets
194 | Action Input 1: "jobs"
195 | Action Input 2: "jobType"
196 | Observation: [["PartTime", 6753]]
197 | Thought: It looks like there is only one job type so I have the final answer.
198 | Final Answer: There is only one type of job being offered: PartTime (6,753 jobs).
199 | 
200 | Begin!
201 | 
202 | Question: {question.strip()}
203 | Thought:""".strip()
204 |     print(prompt)
205 | 
206 |     # give the LLM up to 15 attempts to reach a final answer
207 |     attempts = 0
208 |     while attempts < 15:
209 |         attempts += 1
210 |         # print(f"** running attempt: {attempts}")
211 | 
212 |         output = llm(
213 |             prompt,
214 |             max_tokens=256,
215 |             stop=["Question:", "Observation:"],
216 |             echo=True
217 |         )
218 |         # print("** output:", output)
219 | 
220 |         # get the response minus our prompt (just the new stuff)
221 |         response = output["choices"][0]["text"].replace(prompt, "").strip()
222 |         # print(f"** response:\n{response}")
223 |         print(f"\n{response}")
224 | 
225 |         try:
226 |             action = re.findall(r"Action: (.*)", response, re.M)[0]
227 |         except IndexError:
228 |             action = None
229 |         # print("** action:", action)
230 | 
231 |         try:
232 |             final_answer = re.findall(r'Final Answer: (.*)', response, re.M)[0]
233 |         except IndexError:
234 |             final_answer = None
235 |         # print("** final answer:", final_answer)
236 | 
237 |         if action and action not in action_fns:
238 |             prompt = output["choices"][0]["text"].strip()
239 |             action_names = ", ".join(list(action_fns.keys()))
240 |             prompt += f"\nObservation: That's an invalid action. Valid actions: {action_names}\n"
241 | 
242 |         elif action:
243 |             # Action Inputs are required to be double-quoted; extract the quoted values
244 |             action_inputs = re.findall(
245 |                 r'Action Input (\d): "(.*)"', response, re.M
246 |             )
247 |             args = [
248 |                 inp[1]
249 |                 for inp in action_inputs
250 |             ]
251 |             action_fn = action_fns[action]
252 |             # print("** action_fn:", action_fn, "args", args)
253 |             try:
254 |                 observation = action_fn(db, *args)
255 |             except TypeError as e:
256 |                 if "positional argument" not in str(e):
257 |                     raise
258 |                 # trim off the name of the action from msg like:
259 |                 # hi() takes 1 positional argument but 2 were given
260 |                 # and turn it into:
261 |                 # The action hi takes 1 Action Input but 2 were given
262 |                 args_err_msg = str(e).split(" ", 1)[1].replace(
263 |                     "positional argument", "Action Input"
264 |                 ).replace(
265 |                     "positional arguments", "Action Inputs"
266 |                 ).split(": '", 1)[0]
267 |                 observation = f"The action {action} {args_err_msg}"
268 |             # print("** observation:", observation)
269 |             observation_text = json.dumps(observation)
270 |             print(f"Observation: {observation_text}\n")
271 |             prompt = output["choices"][0]["text"].strip()
272 |             prompt += f"\nObservation: {observation_text}\n"
273 | 
274 |         elif final_answer:
275 |             trace = output["choices"][0]["text"]
276 |             return final_answer, trace
277 | 
278 |     # TODO: truncate the prompt if it's grown too long
279 |     # using tiktoken and some keep_n value of context
280 | 
281 |     return None, output["choices"][0]["text"]
282 | 
283 | 
284 | if __name__ == "__main__":
285 |     question = sys.argv[1]
286 |     db = load_db(DB_PATH)
287 |     llm = load_model(MODEL_PATH)
288 |     answer, trace = execute(llm, question)
289 | 
--------------------------------------------------------------------------------
/states.yml:
--------------------------------------------------------------------------------
1 | {
2 | 'AK':'Alaska',
3 | 'AL':'Alabama',
4 | 'AR':'Arkansas',
5 | 'AZ':'Arizona',
6 | 'CA':'California',
7 | 'CO':'Colorado',
8 | 'CT':'Connecticut',
9 | 'DC':'District of Columbia',
10 | 'DE':'Delaware',
11 | 'FL':'Florida',
12 | 'GA':'Georgia',
13 | 'HI':'Hawaii',
14 | 'IA':'Iowa',
15 | 'ID':'Idaho',
16 | 'IL':'Illinois',
17 | 'IN':'Indiana',
18 | 'KS':'Kansas',
19 | 'KY':'Kentucky',
20 | 'LA':'Louisiana',
21 | 'MA':'Massachusetts',
22 | 'MD':'Maryland',
23 | 'ME':'Maine',
24 | 'MI':'Michigan',
25 | 'MN':'Minnesota',
26 | 'MO':'Missouri',
27 | 'MS':'Mississippi',
28 | 'MT':'Montana',
29 | 'NC':'North Carolina',
30 | 'ND':'North Dakota',
31 | 'NE':'Nebraska',
32 | 'NH':'New Hampshire',
33 | 'NJ':'New Jersey',
34 | 'NM':'New Mexico',
35 | 'NV':'Nevada',
36 | 'NY':'New York',
37 | 'OH':'Ohio',
38 | 'OK':'Oklahoma',
39 | 'OR':'Oregon',
40 | 'PA':'Pennsylvania',
41 | 'RI':'Rhode Island',
42 | 'SC':'South Carolina',
43 | 'SD':'South Dakota',
44 | 'TN':'Tennessee',
45 | 'TX':'Texas',
46 | 'UT':'Utah',
47 | 'VA':'Virginia',
48 | 'VT':'Vermont',
49 | 'WA':'Washington',
50 | 'WI':'Wisconsin',
51 | 'WV':'West Virginia',
52 | 'WY':'Wyoming'
53 | }
54 | 
--------------------------------------------------------------------------------