├── .env-copy ├── DSPy-Text2SQL.ipynb ├── Makefile ├── README.md ├── assets └── dspy.png ├── model_serve ├── chat_template.jinja └── serve_model.py ├── poetry.lock ├── pyproject.toml └── src ├── __init__.py └── starling.py /.env-copy: -------------------------------------------------------------------------------- 1 | OPENAI_API_KEY= -------------------------------------------------------------------------------- /DSPy-Text2SQL.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "#### Libraries" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import dspy\n", 17 | "import random\n", 18 | "from dotenv import load_dotenv\n", 19 | "from dspy.datasets import DataLoader\n", 20 | "from dspy.evaluate import Evaluate\n", 21 | "from dspy.teleprompt import BootstrapFewShotWithRandomSearch, LabeledFewShot\n", 22 | "\n", 23 | "from src.starling import StarlingLM # <- Custom Local Model Client for Starling7B\n", 24 | "\n", 25 | "_ = load_dotenv()" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": {}, 31 | "source": [ 32 | "#### LLM" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "metadata": {}, 38 | "source": [ 39 | "The model that will be used in this notebook is [Starling7B](https://huggingface.co/Nexusflow/Starling-LM-7B-beta), and the evaluation model will be [GPT-4 Turbo](https://openai.com/gpt-4)." 
40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 2, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "# Share generation args between models\n", 49 | "generation_args = {\n", 50 | " \"temperature\":0,\n", 51 | " \"max_tokens\":500,\n", 52 | " \"stop\":\"\\n\\n\",\n", 53 | " \"model_type\":\"chat\",\n", 54 | " \"n\": 1\n", 55 | "}\n", 56 | "# Model specific args\n", 57 | "model_info = {\n", 58 | " \"gpt-4\": {\"model\": \"gpt-4-0125-preview\", \"api_base\": \"https://api.openai.com/v1/\"},\n", 59 | " \"starling\": {\"model\": \"Nexusflow/Starling-LM-7B-beta\"}\n", 60 | "}" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 3, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "# Set up the models\n", 70 | "lm = StarlingLM(**model_info[\"starling\"], **generation_args)\n", 71 | "evaluator_lm = dspy.OpenAI(**model_info[\"gpt-4\"], **generation_args)\n", 72 | "\n", 73 | "dspy.configure(lm=lm)" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": 4, 79 | "metadata": {}, 80 | "outputs": [ 81 | { 82 | "data": { 83 | "text/plain": [ 84 | "[' The capital of Colombia is Bogotá. It is the largest city in the country and serves as the political, economic, and cultural center of Colombia. 
Located in the Andean region of the country, Bogotá has a rich history and is known for its vibrant arts scene, diverse architecture, and numerous museums and cultural institutions.']" 85 | ] 86 | }, 87 | "execution_count": 4, 88 | "metadata": {}, 89 | "output_type": "execute_result" 90 | } 91 | ], 92 | "source": [ 93 | "# Testing inference of Starling\n", 94 | "lm(\"What is the capital of Colombia?\")" 95 | ] 96 | }, 97 | { 98 | "cell_type": "markdown", 99 | "metadata": {}, 100 | "source": [ 101 | "#### Load dataset" 102 | ] 103 | }, 104 | { 105 | "cell_type": "markdown", 106 | "metadata": {}, 107 | "source": [ 108 | "The dataset that will be used in this notebook is [gretelai/synthetic_text_to_sql](https://huggingface.co/datasets/gretelai/synthetic_text_to_sql)" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": 5, 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "# Define random seed\n", 118 | "random.seed(1399)" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": 6, 124 | "metadata": {}, 125 | "outputs": [ 126 | { 127 | "data": { 128 | "text/plain": [ 129 | "(75, 25, 75)" 130 | ] 131 | }, 132 | "execution_count": 6, 133 | "metadata": {}, 134 | "output_type": "execute_result" 135 | } 136 | ], 137 | "source": [ 138 | "# Load dataset\n", 139 | "dl = DataLoader()\n", 140 | "trainset = dl.from_huggingface(\n", 141 | " dataset_name=\"gretelai/synthetic_text_to_sql\", # Dataset name from Huggingface\n", 142 | " fields=(\"sql_prompt\", \"sql_context\", \"sql\"), # Fields needed\n", 143 | " input_keys=(\"sql_prompt\", \"sql_context\"), # What our model expects to recieve to generate an output\n", 144 | " split=\"train\"\n", 145 | ")\n", 146 | "\n", 147 | "testset = dl.from_huggingface(\n", 148 | " dataset_name=\"gretelai/synthetic_text_to_sql\", # Dataset name from Huggingface\n", 149 | " fields=(\"sql_prompt\", \"sql_context\", \"sql\"), # Fields needed\n", 150 | " input_keys=(\"sql_prompt\", 
\"sql_context\"), # What our model expects to recieve to generate an output\n", 151 | " split=\"test\"\n", 152 | ")\n", 153 | "\n", 154 | "trainset = dl.sample(dataset=trainset, n=100)\n", 155 | "testset = dl.sample(dataset=testset, n=75)\n", 156 | "\n", 157 | "_trainval = dl.train_test_split(dataset=trainset, test_size=0.25, random_state=1399) # 25% of training data for validation\n", 158 | "trainset, valset = _trainval[\"train\"], _trainval[\"test\"]\n", 159 | "\n", 160 | "len(trainset), len(valset), len(testset)" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": 7, 166 | "metadata": {}, 167 | "outputs": [ 168 | { 169 | "name": "stdout", 170 | "output_type": "stream", 171 | "text": [ 172 | "\n", 173 | "SQL_PROMPT:\n", 174 | "\n", 175 | "List the top 5 countries with the highest number of satellites in orbit as of 2022-01-01, ordered by the number of satellites in descending order.\n", 176 | "\n", 177 | "SQL_CONTEXT:\n", 178 | "\n", 179 | "CREATE TABLE countries(id INT, name VARCHAR(255), population INT, satellites_in_orbit INT, last_census_date DATE);\n", 180 | "\n", 181 | "SQL:\n", 182 | "\n", 183 | "SELECT name, satellites_in_orbit FROM countries WHERE last_census_date <= '2022-01-01' GROUP BY name ORDER BY satellites_in_orbit DESC LIMIT 5;\n" 184 | ] 185 | } 186 | ], 187 | "source": [ 188 | "# Verify an example of the dataset\n", 189 | "sample = dl.sample(dataset=trainset, n=1)[0]\n", 190 | "for k, v in sample.items():\n", 191 | " print(f\"\\n{k.upper()}:\\n\")\n", 192 | " print(v)" 193 | ] 194 | }, 195 | { 196 | "cell_type": "markdown", 197 | "metadata": {}, 198 | "source": [ 199 | "#### Signature (Input/Output)" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": 8, 205 | "metadata": {}, 206 | "outputs": [], 207 | "source": [ 208 | "class TextToSql(dspy.Signature):\n", 209 | " \"\"\"Transform a natural language query into a SQL query.\"\"\"\n", 210 | "\n", 211 | " sql_prompt = dspy.InputField(desc=\"Natural 
language query\")\n", 212 | " sql_context = dspy.InputField(desc=\"Context for the query\")\n", 213 | " sql = dspy.OutputField(desc=\"SQL query\")" 214 | ] 215 | }, 216 | { 217 | "cell_type": "markdown", 218 | "metadata": {}, 219 | "source": [ 220 | "### Inference" 221 | ] 222 | }, 223 | { 224 | "cell_type": "markdown", 225 | "metadata": {}, 226 | "source": [ 227 | "#### Baseline Inference" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": 9, 233 | "metadata": {}, 234 | "outputs": [ 235 | { 236 | "name": "stdout", 237 | "output_type": "stream", 238 | "text": [ 239 | "\n", 240 | "SQL:\n", 241 | "\n", 242 | "SELECT name, satellites_in_orbit\n", 243 | "FROM countries\n", 244 | "WHERE last_census_date = '2022-01-01'\n", 245 | "ORDER BY satellites_in_orbit DESC\n", 246 | "LIMIT 5;\n" 247 | ] 248 | } 249 | ], 250 | "source": [ 251 | "generate_sql_query = dspy.Predict(signature=TextToSql)\n", 252 | "\n", 253 | "result = generate_sql_query(\n", 254 | " sql_prompt=sample[\"sql_prompt\"],\n", 255 | " sql_context=sample[\"sql_context\"]\n", 256 | ")\n", 257 | "\n", 258 | "for k, v in result.items():\n", 259 | " print(f\"\\n{k.upper()}:\\n\")\n", 260 | " print(v)" 261 | ] 262 | }, 263 | { 264 | "cell_type": "markdown", 265 | "metadata": {}, 266 | "source": [ 267 | "#### ChainOfThought Inference" 268 | ] 269 | }, 270 | { 271 | "cell_type": "code", 272 | "execution_count": 10, 273 | "metadata": {}, 274 | "outputs": [ 275 | { 276 | "name": "stdout", 277 | "output_type": "stream", 278 | "text": [ 279 | "\n", 280 | "RATIONALE:\n", 281 | "\n", 282 | "produce the SQL query. We need to:\n", 283 | "\n", 284 | "SQL:\n", 285 | "\n", 286 | "1. Select the relevant columns from the table: In this case, we need to select the country name and the number of satellites in orbit.\n", 287 | "2. Filter the data based on the date: We need to filter the data to only include records with a last_census_date of 2022-01-01.\n", 288 | "3. 
Order the results: We need to order the results by the number of satellites in orbit in descending order.\n" 289 | ] 290 | } 291 | ], 292 | "source": [ 293 | "generate_sql_query = dspy.ChainOfThought(signature=TextToSql)\n", 294 | "\n", 295 | "result = generate_sql_query(\n", 296 | " sql_prompt=sample[\"sql_prompt\"],\n", 297 | " sql_context=sample[\"sql_context\"]\n", 298 | ")\n", 299 | "\n", 300 | "for k, v in result.items():\n", 301 | " print(f\"\\n{k.upper()}:\\n\")\n", 302 | " print(v)" 303 | ] 304 | }, 305 | { 306 | "cell_type": "markdown", 307 | "metadata": {}, 308 | "source": [ 309 | "### Metric of evaluation" 310 | ] 311 | }, 312 | { 313 | "cell_type": "markdown", 314 | "metadata": {}, 315 | "source": [ 316 | "#### Metric definition" 317 | ] 318 | }, 319 | { 320 | "cell_type": "code", 321 | "execution_count": 11, 322 | "metadata": {}, 323 | "outputs": [], 324 | "source": [ 325 | "class Correctness(dspy.Signature):\n", 326 | " \"\"\"Assess if the SQL query accurately answers the given natural language query based on the provided context.\"\"\"\n", 327 | "\n", 328 | " sql_prompt = dspy.InputField(desc=\"Natural language query \")\n", 329 | " sql_context = dspy.InputField(desc=\"Context for the query\")\n", 330 | " sql = dspy.InputField(desc=\"SQL query\")\n", 331 | " correct = dspy.OutputField(desc=\"Indicate whether the SQL query correctly answers the natural language query based on the given context\", prefix=\"Yes/No:\")" 332 | ] 333 | }, 334 | { 335 | "cell_type": "code", 336 | "execution_count": 12, 337 | "metadata": {}, 338 | "outputs": [], 339 | "source": [ 340 | "def correctness_metric(example, pred, trace=None):\n", 341 | " sql_prompt, sql_context, sql = example.sql_prompt, example.sql_context, pred.sql\n", 342 | "\n", 343 | " correctness = dspy.Predict(Correctness)\n", 344 | "\n", 345 | " with dspy.context(lm=evaluator_lm): \n", 346 | " correct = correctness(\n", 347 | " sql_prompt=sql_prompt,\n", 348 | " sql_context=sql_context,\n", 349 | " 
sql=sql,\n", 350 | " )\n", 351 | " \n", 352 | " score = int(correct.correct==\"Yes\")\n", 353 | "\n", 354 | " if trace is not None:\n", 355 | " return score == 1\n", 356 | "\n", 357 | " return score" 358 | ] 359 | }, 360 | { 361 | "cell_type": "markdown", 362 | "metadata": {}, 363 | "source": [ 364 | "#### Evaluate single data point" 365 | ] 366 | }, 367 | { 368 | "cell_type": "code", 369 | "execution_count": 13, 370 | "metadata": {}, 371 | "outputs": [ 372 | { 373 | "name": "stdout", 374 | "output_type": "stream", 375 | "text": [ 376 | "Correct SQL query: No\n" 377 | ] 378 | } 379 | ], 380 | "source": [ 381 | "_correctness = correctness_metric(\n", 382 | " example=sample,\n", 383 | " pred=result\n", 384 | ")\n", 385 | "print(f\"Correct SQL query: {'Yes' if _correctness else 'No'}\")" 386 | ] 387 | }, 388 | { 389 | "cell_type": "code", 390 | "execution_count": 14, 391 | "metadata": {}, 392 | "outputs": [ 393 | { 394 | "name": "stdout", 395 | "output_type": "stream", 396 | "text": [ 397 | "\n", 398 | "\n", 399 | "\n", 400 | "\n", 401 | "Assess if the SQL query accurately answers the given natural language query based on the provided context.\n", 402 | "\n", 403 | "---\n", 404 | "\n", 405 | "Follow the following format.\n", 406 | "\n", 407 | "Sql Prompt: Natural language query\n", 408 | "\n", 409 | "Sql Context: Context for the query\n", 410 | "\n", 411 | "Sql: SQL query\n", 412 | "\n", 413 | "Yes/No: Indicate whether the SQL query correctly answers the natural language query based on the given context\n", 414 | "\n", 415 | "---\n", 416 | "\n", 417 | "Sql Prompt: List the top 5 countries with the highest number of satellites in orbit as of 2022-01-01, ordered by the number of satellites in descending order.\n", 418 | "\n", 419 | "Sql Context: CREATE TABLE countries(id INT, name VARCHAR(255), population INT, satellites_in_orbit INT, last_census_date DATE);\n", 420 | "\n", 421 | "Sql: 1. 
Select the relevant columns from the table: In this case, we need to select the country name and the number of satellites in orbit. 2. Filter the data based on the date: We need to filter the data to only include records with a last_census_date of 2022-01-01. 3. Order the results: We need to order the results by the number of satellites in orbit in descending order.\n", 422 | "\n", 423 | "Yes/No:\u001b[32m No\u001b[0m\n", 424 | "\n", 425 | "\n", 426 | "\n" 427 | ] 428 | } 429 | ], 430 | "source": [ 431 | "evaluator_lm.inspect_history(n=1)" 432 | ] 433 | }, 434 | { 435 | "cell_type": "markdown", 436 | "metadata": {}, 437 | "source": [ 438 | "#### Evaluate entire dataset - GPT 3.5" 439 | ] 440 | }, 441 | { 442 | "cell_type": "markdown", 443 | "metadata": {}, 444 | "source": [ 445 | "

📊 Baseline Evaluation

Without any optimization, GPT-3.5 Turbo achieves 80% correctness in validation (25 samples) and 70.07% correctness in test (75 samples).

" 446 | ] 447 | }, 448 | { 449 | "cell_type": "code", 450 | "execution_count": 16, 451 | "metadata": {}, 452 | "outputs": [ 453 | { 454 | "name": "stdout", 455 | "output_type": "stream", 456 | "text": [ 457 | "GPT 3.5 Turbo - Validation Score: \n", 458 | "\n" 459 | ] 460 | }, 461 | { 462 | "name": "stderr", 463 | "output_type": "stream", 464 | "text": [ 465 | " 0%| | 0/25 [00:00

⚠️ Evaluation Stage 1

Without any optimization, Starling7B achieves 72% correctness in validation (25 samples) and 50.67% correctness in test (75 samples).

" 555 | ] 556 | }, 557 | { 558 | "cell_type": "code", 559 | "execution_count": 19, 560 | "metadata": {}, 561 | "outputs": [ 562 | { 563 | "name": "stdout", 564 | "output_type": "stream", 565 | "text": [ 566 | "Starling7b - Validation Score: \n", 567 | "\n" 568 | ] 569 | }, 570 | { 571 | "name": "stderr", 572 | "output_type": "stream", 573 | "text": [ 574 | "Average Metric: 18 / 25 (72.0): 100%|██████████| 25/25 [00:10<00:00, 2.44it/s]" 575 | ] 576 | }, 577 | { 578 | "name": "stdout", 579 | "output_type": "stream", 580 | "text": [ 581 | "Average Metric: 18 / 25 (72.0%)\n" 582 | ] 583 | }, 584 | { 585 | "name": "stderr", 586 | "output_type": "stream", 587 | "text": [ 588 | "\n", 589 | "/home/jjmov99/dspy-testing/.venv/lib/python3.11/site-packages/dspy/evaluate/evaluate.py:187: FutureWarning: DataFrame.applymap has been deprecated. Use DataFrame.map instead.\n", 590 | " df = df.applymap(truncate_cell)\n" 591 | ] 592 | }, 593 | { 594 | "data": { 595 | "text/plain": [ 596 | "72.0" 597 | ] 598 | }, 599 | "execution_count": 19, 600 | "metadata": {}, 601 | "output_type": "execute_result" 602 | } 603 | ], 604 | "source": [ 605 | "print(\"Starling7b - Validation Score: \\n\")\n", 606 | "evaluate = Evaluate(devset=valset, metric=correctness_metric, num_threads=10, display_progress=True, display_table=0)\n", 607 | "evaluate(generate_sql_query)" 608 | ] 609 | }, 610 | { 611 | "cell_type": "code", 612 | "execution_count": 20, 613 | "metadata": {}, 614 | "outputs": [ 615 | { 616 | "name": "stdout", 617 | "output_type": "stream", 618 | "text": [ 619 | "Starling7b - Test Score: \n", 620 | "\n" 621 | ] 622 | }, 623 | { 624 | "name": "stderr", 625 | "output_type": "stream", 626 | "text": [ 627 | "Average Metric: 38 / 75 (50.7): 100%|██████████| 75/75 [00:31<00:00, 2.37it/s]" 628 | ] 629 | }, 630 | { 631 | "name": "stdout", 632 | "output_type": "stream", 633 | "text": [ 634 | "Average Metric: 38 / 75 (50.7%)\n" 635 | ] 636 | }, 637 | { 638 | "name": "stderr", 639 | "output_type": 
"stream", 640 | "text": [ 641 | "\n", 642 | "/home/jjmov99/dspy-testing/.venv/lib/python3.11/site-packages/dspy/evaluate/evaluate.py:187: FutureWarning: DataFrame.applymap has been deprecated. Use DataFrame.map instead.\n", 643 | " df = df.applymap(truncate_cell)\n" 644 | ] 645 | }, 646 | { 647 | "data": { 648 | "text/plain": [ 649 | "50.67" 650 | ] 651 | }, 652 | "execution_count": 20, 653 | "metadata": {}, 654 | "output_type": "execute_result" 655 | } 656 | ], 657 | "source": [ 658 | "print(\"Starling7b - Test Score: \\n\")\n", 659 | "evaluate = Evaluate(devset=testset, metric=correctness_metric, num_threads=10, display_progress=True, display_table=0)\n", 660 | "evaluate(generate_sql_query)" 661 | ] 662 | }, 663 | { 664 | "cell_type": "markdown", 665 | "metadata": {}, 666 | "source": [ 667 | "### Optimize for Text2SQL" 668 | ] 669 | }, 670 | { 671 | "cell_type": "markdown", 672 | "metadata": {}, 673 | "source": [ 674 | "#### Create program" 675 | ] 676 | }, 677 | { 678 | "cell_type": "code", 679 | "execution_count": 21, 680 | "metadata": {}, 681 | "outputs": [], 682 | "source": [ 683 | "# Define the program ~ You can think of this a Pytorch model.\n", 684 | "class TextToSqlProgram(dspy.Module):\n", 685 | " def __init__(self):\n", 686 | " super().__init__()\n", 687 | " self.program = dspy.ChainOfThought(signature=TextToSql)\n", 688 | " \n", 689 | " def forward(self, sql_prompt, sql_context):\n", 690 | " return self.program(\n", 691 | " sql_prompt=sql_prompt,\n", 692 | " sql_context=sql_context\n", 693 | " )" 694 | ] 695 | }, 696 | { 697 | "cell_type": "markdown", 698 | "metadata": {}, 699 | "source": [ 700 | "### FewShot" 701 | ] 702 | }, 703 | { 704 | "cell_type": "code", 705 | "execution_count": 22, 706 | "metadata": {}, 707 | "outputs": [], 708 | "source": [ 709 | "# Execute the optimizer -> this only adds few shots to the prompt\n", 710 | "optimizer = LabeledFewShot(k=4)\n", 711 | "optmized_program = optimizer.compile(student=TextToSqlProgram(), 
trainset=trainset)" 712 | ] 713 | }, 714 | { 715 | "cell_type": "code", 716 | "execution_count": 23, 717 | "metadata": {}, 718 | "outputs": [ 719 | { 720 | "data": { 721 | "text/plain": [ 722 | "Prediction(\n", 723 | " rationale='produce the SQL query. We need to filter the countries based on the date and then order them by the number of satellites in orbit.',\n", 724 | " sql=\"SELECT name, satellites_in_orbit\\nFROM countries\\nWHERE last_census_date <= '2022-01-01'\\nORDER BY satellites_in_orbit DESC\\nLIMIT 5;\"\n", 725 | ")" 726 | ] 727 | }, 728 | "execution_count": 23, 729 | "metadata": {}, 730 | "output_type": "execute_result" 731 | } 732 | ], 733 | "source": [ 734 | "optmized_program(sql_context=sample[\"sql_context\"], sql_prompt=sample[\"sql_prompt\"])" 735 | ] 736 | }, 737 | { 738 | "cell_type": "markdown", 739 | "metadata": {}, 740 | "source": [ 741 | "#### What is happening inside?" 742 | ] 743 | }, 744 | { 745 | "cell_type": "code", 746 | "execution_count": 24, 747 | "metadata": {}, 748 | "outputs": [ 749 | { 750 | "name": "stdout", 751 | "output_type": "stream", 752 | "text": [ 753 | "\n", 754 | "\n", 755 | "\n", 756 | "\n", 757 | "Transform a natural language query into a SQL query.\n", 758 | "\n", 759 | "---\n", 760 | "\n", 761 | "Sql Prompt: What is the total number of electric vehicles sold by manufacturer 'XYZ'?\n", 762 | "Sql Context: CREATE TABLE sales_data (manufacturer VARCHAR(10), vehicle_type VARCHAR(10), quantity INT);\n", 763 | "Sql: SELECT manufacturer, SUM(quantity) FROM sales_data WHERE vehicle_type = 'Electric' AND manufacturer = 'XYZ' GROUP BY manufacturer;\n", 764 | "\n", 765 | "Sql Prompt: What is the total number of amphibians in the 'animals' table with a population size greater than 1000?\n", 766 | "Sql Context: CREATE TABLE animals (id INT, name VARCHAR(50), species VARCHAR(50), population_size INT); INSERT INTO animals (id, name, species, population_size) VALUES (1, 'Frog', 'Anura', 1200);\n", 767 | "Sql: SELECT COUNT(*) FROM 
animals WHERE species = 'Anura' AND population_size > 1000;\n", 768 | "\n", 769 | "Sql Prompt: What is the minimum depth a marine species can live at?\n", 770 | "Sql Context: CREATE TABLE species (id INT, name VARCHAR(255), habitat VARCHAR(255), depth FLOAT); INSERT INTO species (id, name, habitat, depth) VALUES (1, 'Clownfish', 'Coral Reef', 20.0); INSERT INTO species (id, name, habitat, depth) VALUES (2, 'Blue Whale', 'Open Ocean', 2000.0); INSERT INTO species (id, name, habitat, depth) VALUES (3, 'Sea Otter', 'Kelp Forest', 50.0);\n", 771 | "Sql: SELECT MIN(depth) FROM species;\n", 772 | "\n", 773 | "Sql Prompt: Delete all investments with ESG scores less than 70.\n", 774 | "Sql Context: CREATE TABLE investments (id INT, sector VARCHAR(20), esg_score FLOAT); INSERT INTO investments (id, sector, esg_score) VALUES (1, 'Education', 75.00), (2, 'Healthcare', 70.00), (3, 'Renewable Energy', 65.00);\n", 775 | "Sql: DELETE FROM investments WHERE esg_score < 70;\n", 776 | "\n", 777 | "---\n", 778 | "\n", 779 | "Follow the following format.\n", 780 | "\n", 781 | "Sql Prompt: Natural language query\n", 782 | "\n", 783 | "Sql Context: Context for the query\n", 784 | "\n", 785 | "Reasoning: Let's think step by step in order to ${produce the sql}. We ...\n", 786 | "\n", 787 | "Sql: SQL query\n", 788 | "\n", 789 | "---\n", 790 | "\n", 791 | "Sql Prompt: List the top 5 countries with the highest number of satellites in orbit as of 2022-01-01, ordered by the number of satellites in descending order.\n", 792 | "\n", 793 | "Sql Context: CREATE TABLE countries(id INT, name VARCHAR(255), population INT, satellites_in_orbit INT, last_census_date DATE);\n", 794 | "\n", 795 | "Reasoning: Let's think step by step in order to produce the SQL query. 
We need to filter the countries based on the date and then order them by the number of satellites in orbit.\n", 796 | "\n", 797 | "Sql:\u001b[32m SELECT name, satellites_in_orbit\n", 798 | "FROM countries\n", 799 | "WHERE last_census_date <= '2022-01-01'\n", 800 | "ORDER BY satellites_in_orbit DESC\n", 801 | "LIMIT 5;\u001b[0m\n", 802 | "\n", 803 | "\n", 804 | "\n" 805 | ] 806 | } 807 | ], 808 | "source": [ 809 | "lm.inspect_history(n=1)" 810 | ] 811 | }, 812 | { 813 | "cell_type": "markdown", 814 | "metadata": {}, 815 | "source": [ 816 | "#### Evaluate the optimized program" 817 | ] 818 | }, 819 | { 820 | "cell_type": "markdown", 821 | "metadata": {}, 822 | "source": [ 823 | "\n", 824 | "

🌟 Evaluation Stage 2

With few-shot optimization, Starling7B achieves 64% correctness in validation (25 samples) and 60% correctness in test (75 samples).

" 825 | ] 826 | }, 827 | { 828 | "cell_type": "code", 829 | "execution_count": 25, 830 | "metadata": {}, 831 | "outputs": [ 832 | { 833 | "name": "stdout", 834 | "output_type": "stream", 835 | "text": [ 836 | "Starling7b + FewShotOptimizer - Validation Score: \n", 837 | "\n" 838 | ] 839 | }, 840 | { 841 | "name": "stderr", 842 | "output_type": "stream", 843 | "text": [ 844 | "Average Metric: 16 / 25 (64.0): 100%|██████████| 25/25 [00:15<00:00, 1.59it/s]" 845 | ] 846 | }, 847 | { 848 | "name": "stdout", 849 | "output_type": "stream", 850 | "text": [ 851 | "Average Metric: 16 / 25 (64.0%)\n" 852 | ] 853 | }, 854 | { 855 | "name": "stderr", 856 | "output_type": "stream", 857 | "text": [ 858 | "\n", 859 | "/home/jjmov99/dspy-testing/.venv/lib/python3.11/site-packages/dspy/evaluate/evaluate.py:187: FutureWarning: DataFrame.applymap has been deprecated. Use DataFrame.map instead.\n", 860 | " df = df.applymap(truncate_cell)\n" 861 | ] 862 | }, 863 | { 864 | "data": { 865 | "text/plain": [ 866 | "64.0" 867 | ] 868 | }, 869 | "execution_count": 25, 870 | "metadata": {}, 871 | "output_type": "execute_result" 872 | } 873 | ], 874 | "source": [ 875 | "print(\"Starling7b + FewShotOptimizer - Validation Score: \\n\")\n", 876 | "evaluate = Evaluate(devset=valset, metric=correctness_metric, num_threads=10, display_progress=True, display_table=0)\n", 877 | "evaluate(optmized_program)" 878 | ] 879 | }, 880 | { 881 | "cell_type": "code", 882 | "execution_count": 26, 883 | "metadata": {}, 884 | "outputs": [ 885 | { 886 | "name": "stdout", 887 | "output_type": "stream", 888 | "text": [ 889 | "Starling7b + FewShotOptimizer - Test Score: \n", 890 | "\n" 891 | ] 892 | }, 893 | { 894 | "name": "stderr", 895 | "output_type": "stream", 896 | "text": [ 897 | "Average Metric: 45 / 75 (60.0): 100%|██████████| 75/75 [00:31<00:00, 2.35it/s]" 898 | ] 899 | }, 900 | { 901 | "name": "stdout", 902 | "output_type": "stream", 903 | "text": [ 904 | "Average Metric: 45 / 75 (60.0%)\n" 905 | ] 906 | }, 
907 | { 908 | "name": "stderr", 909 | "output_type": "stream", 910 | "text": [ 911 | "\n" 912 | ] 913 | }, 914 | { 915 | "data": { 916 | "text/plain": [ 917 | "60.0" 918 | ] 919 | }, 920 | "execution_count": 26, 921 | "metadata": {}, 922 | "output_type": "execute_result" 923 | } 924 | ], 925 | "source": [ 926 | "print(\"Starling7b + FewShotOptimizer - Test Score: \\n\")\n", 927 | "evaluate = Evaluate(devset=testset, metric=correctness_metric, num_threads=10, display_progress=True, display_table=0)\n", 928 | "evaluate(optmized_program)" 929 | ] 930 | }, 931 | { 932 | "cell_type": "markdown", 933 | "metadata": {}, 934 | "source": [ 935 | "### BootstrapFewShotWithRandomSearch" 936 | ] 937 | }, 938 | { 939 | "cell_type": "markdown", 940 | "metadata": {}, 941 | "source": [ 942 | "[DSPy docs](https://dspy-docs.vercel.app/docs/building-blocks/optimizers) recommend that in a setup like the one we have at hand, with ~50 samples, the best option is to use `BootstrapFewShotWithRandomSearch`:\n", 943 | "\n", 944 | "![image](assets/dspy.png)" 945 | ] 946 | }, 947 | { 948 | "cell_type": "code", 949 | "execution_count": 27, 950 | "metadata": {}, 951 | "outputs": [ 952 | { 953 | "name": "stdout", 954 | "output_type": "stream", 955 | "text": [ 956 | "Going to sample between 1 and 2 traces per predictor.\n", 957 | "Will attempt to train 8 candidate sets.\n" 958 | ] 959 | }, 960 | { 961 | "name": "stderr", 962 | "output_type": "stream", 963 | "text": [ 964 | "Average Metric: 18 / 25 (72.0): 100%|██████████| 25/25 [00:15<00:00, 1.64it/s]\n", 965 | "/home/jjmov99/dspy-testing/.venv/lib/python3.11/site-packages/dspy/evaluate/evaluate.py:187: FutureWarning: DataFrame.applymap has been deprecated.
Use DataFrame.map instead.\n", 966 | " df = df.applymap(truncate_cell)\n" 967 | ] 968 | }, 969 | { 970 | "name": "stdout", 971 | "output_type": "stream", 972 | "text": [ 973 | "Average Metric: 18 / 25 (72.0%)\n", 974 | "Score: 72.0 for set: [0]\n", 975 | "New best score: 72.0 for seed -3\n", 976 | "Scores so far: [72.0]\n", 977 | "Best score: 72.0\n" 978 | ] 979 | }, 980 | { 981 | "name": "stderr", 982 | "output_type": "stream", 983 | "text": [ 984 | "Average Metric: 19 / 25 (76.0): 100%|██████████| 25/25 [00:24<00:00, 1.00it/s]\n" 985 | ] 986 | }, 987 | { 988 | "name": "stdout", 989 | "output_type": "stream", 990 | "text": [ 991 | "Average Metric: 19 / 25 (76.0%)\n", 992 | "Score: 76.0 for set: [16]\n", 993 | "New best score: 76.0 for seed -2\n", 994 | "Scores so far: [72.0, 76.0]\n", 995 | "Best score: 76.0\n" 996 | ] 997 | }, 998 | { 999 | "name": "stderr", 1000 | "output_type": "stream", 1001 | "text": [ 1002 | " 4%|▍ | 3/75 [00:11<04:44, 3.95s/it]\n" 1003 | ] 1004 | }, 1005 | { 1006 | "name": "stdout", 1007 | "output_type": "stream", 1008 | "text": [ 1009 | "Bootstrapped 2 full traces after 4 examples in round 0.\n" 1010 | ] 1011 | }, 1012 | { 1013 | "name": "stderr", 1014 | "output_type": "stream", 1015 | "text": [ 1016 | "Average Metric: 18 / 25 (72.0): 100%|██████████| 25/25 [00:25<00:00, 1.02s/it]\n" 1017 | ] 1018 | }, 1019 | { 1020 | "name": "stdout", 1021 | "output_type": "stream", 1022 | "text": [ 1023 | "Average Metric: 18 / 25 (72.0%)\n", 1024 | "Score: 72.0 for set: [16]\n", 1025 | "Scores so far: [72.0, 76.0, 72.0]\n", 1026 | "Best score: 76.0\n", 1027 | "Average of max per entry across top 1 scores: 0.76\n", 1028 | "Average of max per entry across top 2 scores: 0.84\n", 1029 | "Average of max per entry across top 3 scores: 0.84\n", 1030 | "Average of max per entry across top 5 scores: 0.84\n", 1031 | "Average of max per entry across top 8 scores: 0.84\n", 1032 | "Average of max per entry across top 9999 scores: 0.84\n" 1033 | ] 1034 | }, 1035 | { 
1036 | "name": "stderr", 1037 | "output_type": "stream", 1038 | "text": [ 1039 | " 4%|▍ | 3/75 [00:10<04:08, 3.45s/it]\n" 1040 | ] 1041 | }, 1042 | { 1043 | "name": "stdout", 1044 | "output_type": "stream", 1045 | "text": [ 1046 | "Bootstrapped 2 full traces after 4 examples in round 0.\n" 1047 | ] 1048 | }, 1049 | { 1050 | "name": "stderr", 1051 | "output_type": "stream", 1052 | "text": [ 1053 | "Average Metric: 20 / 25 (80.0): 100%|██████████| 25/25 [00:26<00:00, 1.07s/it]\n" 1054 | ] 1055 | }, 1056 | { 1057 | "name": "stdout", 1058 | "output_type": "stream", 1059 | "text": [ 1060 | "Average Metric: 20 / 25 (80.0%)\n", 1061 | "Score: 80.0 for set: [16]\n", 1062 | "New best score: 80.0 for seed 0\n", 1063 | "Scores so far: [72.0, 76.0, 72.0, 80.0]\n", 1064 | "Best score: 80.0\n", 1065 | "Average of max per entry across top 1 scores: 0.8\n", 1066 | "Average of max per entry across top 2 scores: 0.84\n", 1067 | "Average of max per entry across top 3 scores: 0.88\n", 1068 | "Average of max per entry across top 5 scores: 0.88\n", 1069 | "Average of max per entry across top 8 scores: 0.88\n", 1070 | "Average of max per entry across top 9999 scores: 0.88\n" 1071 | ] 1072 | }, 1073 | { 1074 | "name": "stderr", 1075 | "output_type": "stream", 1076 | "text": [ 1077 | " 1%|▏ | 1/75 [00:02<03:34, 2.90s/it]\n" 1078 | ] 1079 | }, 1080 | { 1081 | "name": "stdout", 1082 | "output_type": "stream", 1083 | "text": [ 1084 | "Bootstrapped 1 full traces after 2 examples in round 0.\n" 1085 | ] 1086 | }, 1087 | { 1088 | "name": "stderr", 1089 | "output_type": "stream", 1090 | "text": [ 1091 | "Average Metric: 20 / 25 (80.0): 100%|██████████| 25/25 [00:25<00:00, 1.01s/it]\n" 1092 | ] 1093 | }, 1094 | { 1095 | "name": "stdout", 1096 | "output_type": "stream", 1097 | "text": [ 1098 | "Average Metric: 20 / 25 (80.0%)\n", 1099 | "Score: 80.0 for set: [16]\n", 1100 | "Scores so far: [72.0, 76.0, 72.0, 80.0, 80.0]\n", 1101 | "Best score: 80.0\n", 1102 | "Average of max per entry across top 1 
scores: 0.8\n", 1103 | "Average of max per entry across top 2 scores: 0.88\n", 1104 | "Average of max per entry across top 3 scores: 0.88\n", 1105 | "Average of max per entry across top 5 scores: 0.88\n", 1106 | "Average of max per entry across top 8 scores: 0.88\n", 1107 | "Average of max per entry across top 9999 scores: 0.88\n" 1108 | ] 1109 | }, 1110 | { 1111 | "name": "stderr", 1112 | "output_type": "stream", 1113 | "text": [ 1114 | " 1%|▏ | 1/75 [00:03<04:09, 3.37s/it]\n" 1115 | ] 1116 | }, 1117 | { 1118 | "name": "stdout", 1119 | "output_type": "stream", 1120 | "text": [ 1121 | "Bootstrapped 1 full traces after 2 examples in round 0.\n" 1122 | ] 1123 | }, 1124 | { 1125 | "name": "stderr", 1126 | "output_type": "stream", 1127 | "text": [ 1128 | "Average Metric: 20 / 25 (80.0): 100%|██████████| 25/25 [00:26<00:00, 1.08s/it]\n" 1129 | ] 1130 | }, 1131 | { 1132 | "name": "stdout", 1133 | "output_type": "stream", 1134 | "text": [ 1135 | "Average Metric: 20 / 25 (80.0%)\n", 1136 | "Score: 80.0 for set: [16]\n", 1137 | "Scores so far: [72.0, 76.0, 72.0, 80.0, 80.0, 80.0]\n", 1138 | "Best score: 80.0\n", 1139 | "Average of max per entry across top 1 scores: 0.8\n", 1140 | "Average of max per entry across top 2 scores: 0.88\n", 1141 | "Average of max per entry across top 3 scores: 0.96\n", 1142 | "Average of max per entry across top 5 scores: 0.96\n", 1143 | "Average of max per entry across top 8 scores: 0.96\n", 1144 | "Average of max per entry across top 9999 scores: 0.96\n" 1145 | ] 1146 | }, 1147 | { 1148 | "name": "stderr", 1149 | "output_type": "stream", 1150 | "text": [ 1151 | " 1%|▏ | 1/75 [00:03<04:00, 3.25s/it]\n" 1152 | ] 1153 | }, 1154 | { 1155 | "name": "stdout", 1156 | "output_type": "stream", 1157 | "text": [ 1158 | "Bootstrapped 1 full traces after 2 examples in round 0.\n" 1159 | ] 1160 | }, 1161 | { 1162 | "name": "stderr", 1163 | "output_type": "stream", 1164 | "text": [ 1165 | "Average Metric: 19 / 25 (76.0): 100%|██████████| 25/25 [00:23<00:00, 
1.05it/s]\n" 1166 | ] 1167 | }, 1168 | { 1169 | "name": "stdout", 1170 | "output_type": "stream", 1171 | "text": [ 1172 | "Average Metric: 19 / 25 (76.0%)\n", 1173 | "Score: 76.0 for set: [16]\n", 1174 | "Scores so far: [72.0, 76.0, 72.0, 80.0, 80.0, 80.0, 76.0]\n", 1175 | "Best score: 80.0\n", 1176 | "Average of max per entry across top 1 scores: 0.8\n", 1177 | "Average of max per entry across top 2 scores: 0.88\n", 1178 | "Average of max per entry across top 3 scores: 0.96\n", 1179 | "Average of max per entry across top 5 scores: 0.96\n", 1180 | "Average of max per entry across top 8 scores: 0.96\n", 1181 | "Average of max per entry across top 9999 scores: 0.96\n" 1182 | ] 1183 | }, 1184 | { 1185 | "name": "stderr", 1186 | "output_type": "stream", 1187 | "text": [ 1188 | " 1%|▏ | 1/75 [00:03<04:18, 3.49s/it]\n" 1189 | ] 1190 | }, 1191 | { 1192 | "name": "stdout", 1193 | "output_type": "stream", 1194 | "text": [ 1195 | "Bootstrapped 1 full traces after 2 examples in round 0.\n" 1196 | ] 1197 | }, 1198 | { 1199 | "name": "stderr", 1200 | "output_type": "stream", 1201 | "text": [ 1202 | "Average Metric: 20 / 25 (80.0): 100%|██████████| 25/25 [00:25<00:00, 1.04s/it]\n" 1203 | ] 1204 | }, 1205 | { 1206 | "name": "stdout", 1207 | "output_type": "stream", 1208 | "text": [ 1209 | "Average Metric: 20 / 25 (80.0%)\n", 1210 | "Score: 80.0 for set: [16]\n", 1211 | "Scores so far: [72.0, 76.0, 72.0, 80.0, 80.0, 80.0, 76.0, 80.0]\n", 1212 | "Best score: 80.0\n", 1213 | "Average of max per entry across top 1 scores: 0.8\n", 1214 | "Average of max per entry across top 2 scores: 0.88\n", 1215 | "Average of max per entry across top 3 scores: 0.96\n", 1216 | "Average of max per entry across top 5 scores: 1.0\n", 1217 | "Average of max per entry across top 8 scores: 1.0\n", 1218 | "Average of max per entry across top 9999 scores: 1.0\n" 1219 | ] 1220 | }, 1221 | { 1222 | "name": "stderr", 1223 | "output_type": "stream", 1224 | "text": [ 1225 | " 5%|▌ | 4/75 [00:13<03:57, 
3.34s/it]\n" 1226 | ] 1227 | }, 1228 | { 1229 | "name": "stdout", 1230 | "output_type": "stream", 1231 | "text": [ 1232 | "Bootstrapped 2 full traces after 5 examples in round 0.\n" 1233 | ] 1234 | }, 1235 | { 1236 | "name": "stderr", 1237 | "output_type": "stream", 1238 | "text": [ 1239 | "Average Metric: 20 / 25 (80.0): 100%|██████████| 25/25 [00:25<00:00, 1.02s/it]\n" 1240 | ] 1241 | }, 1242 | { 1243 | "name": "stdout", 1244 | "output_type": "stream", 1245 | "text": [ 1246 | "Average Metric: 20 / 25 (80.0%)\n", 1247 | "Score: 80.0 for set: [16]\n", 1248 | "Scores so far: [72.0, 76.0, 72.0, 80.0, 80.0, 80.0, 76.0, 80.0, 80.0]\n", 1249 | "Best score: 80.0\n", 1250 | "Average of max per entry across top 1 scores: 0.8\n", 1251 | "Average of max per entry across top 2 scores: 0.88\n", 1252 | "Average of max per entry across top 3 scores: 0.96\n", 1253 | "Average of max per entry across top 5 scores: 1.0\n", 1254 | "Average of max per entry across top 8 scores: 1.0\n", 1255 | "Average of max per entry across top 9999 scores: 1.0\n" 1256 | ] 1257 | }, 1258 | { 1259 | "name": "stderr", 1260 | "output_type": "stream", 1261 | "text": [ 1262 | " 1%|▏ | 1/75 [00:03<04:52, 3.95s/it]\n" 1263 | ] 1264 | }, 1265 | { 1266 | "name": "stdout", 1267 | "output_type": "stream", 1268 | "text": [ 1269 | "Bootstrapped 1 full traces after 2 examples in round 0.\n" 1270 | ] 1271 | }, 1272 | { 1273 | "name": "stderr", 1274 | "output_type": "stream", 1275 | "text": [ 1276 | "Average Metric: 19 / 25 (76.0): 100%|██████████| 25/25 [00:27<00:00, 1.10s/it]\n" 1277 | ] 1278 | }, 1279 | { 1280 | "name": "stdout", 1281 | "output_type": "stream", 1282 | "text": [ 1283 | "Average Metric: 19 / 25 (76.0%)\n", 1284 | "Score: 76.0 for set: [16]\n", 1285 | "Scores so far: [72.0, 76.0, 72.0, 80.0, 80.0, 80.0, 76.0, 80.0, 80.0, 76.0]\n", 1286 | "Best score: 80.0\n", 1287 | "Average of max per entry across top 1 scores: 0.8\n", 1288 | "Average of max per entry across top 2 scores: 0.88\n", 1289 | "Average 
of max per entry across top 3 scores: 0.96\n", 1290 | "Average of max per entry across top 5 scores: 1.0\n", 1291 | "Average of max per entry across top 8 scores: 1.0\n", 1292 | "Average of max per entry across top 9999 scores: 1.0\n" 1293 | ] 1294 | }, 1295 | { 1296 | "name": "stderr", 1297 | "output_type": "stream", 1298 | "text": [ 1299 | " 4%|▍ | 3/75 [00:11<04:41, 3.90s/it]\n" 1300 | ] 1301 | }, 1302 | { 1303 | "name": "stdout", 1304 | "output_type": "stream", 1305 | "text": [ 1306 | "Bootstrapped 2 full traces after 4 examples in round 0.\n" 1307 | ] 1308 | }, 1309 | { 1310 | "name": "stderr", 1311 | "output_type": "stream", 1312 | "text": [ 1313 | "Average Metric: 20 / 25 (80.0): 100%|██████████| 25/25 [00:27<00:00, 1.09s/it]" 1314 | ] 1315 | }, 1316 | { 1317 | "name": "stdout", 1318 | "output_type": "stream", 1319 | "text": [ 1320 | "Average Metric: 20 / 25 (80.0%)\n", 1321 | "Score: 80.0 for set: [16]\n", 1322 | "Scores so far: [72.0, 76.0, 72.0, 80.0, 80.0, 80.0, 76.0, 80.0, 80.0, 76.0, 80.0]\n", 1323 | "Best score: 80.0\n", 1324 | "Average of max per entry across top 1 scores: 0.8\n", 1325 | "Average of max per entry across top 2 scores: 0.88\n", 1326 | "Average of max per entry across top 3 scores: 0.96\n", 1327 | "Average of max per entry across top 5 scores: 1.0\n", 1328 | "Average of max per entry across top 8 scores: 1.0\n", 1329 | "Average of max per entry across top 9999 scores: 1.0\n", 1330 | "11 candidate programs found.\n" 1331 | ] 1332 | }, 1333 | { 1334 | "name": "stderr", 1335 | "output_type": "stream", 1336 | "text": [ 1337 | "\n" 1338 | ] 1339 | } 1340 | ], 1341 | "source": [ 1342 | "optimizer2 = BootstrapFewShotWithRandomSearch(metric=correctness_metric, max_bootstrapped_demos=2, num_candidate_programs=8, num_threads=5)\n", 1343 | "optmized_program_2 = optimizer2.compile(student = TextToSqlProgram(), trainset=trainset, valset=valset)" 1344 | ] 1345 | }, 1346 | { 1347 | "cell_type": "code", 1348 | "execution_count": 28, 1349 | "metadata": 
{}, 1350 | "outputs": [ 1351 | { 1352 | "data": { 1353 | "text/plain": [ 1354 | "Prediction(\n", 1355 | " rationale='produce the SQL query. We need to find the top 5 countries with the highest number of satellites in orbit as of 2022-01-01, ordered by the number of satellites in descending order.',\n", 1356 | " sql=\"SELECT name, satellites_in_orbit\\nFROM countries\\nWHERE last_census_date <= '2022-01-01'\\nORDER BY satellites_in_orbit DESC\\nLIMIT 5;\"\n", 1357 | ")" 1358 | ] 1359 | }, 1360 | "execution_count": 28, 1361 | "metadata": {}, 1362 | "output_type": "execute_result" 1363 | } 1364 | ], 1365 | "source": [ 1366 | "optmized_program_2(sql_context=sample[\"sql_context\"], sql_prompt=sample[\"sql_prompt\"])" 1367 | ] 1368 | }, 1369 | { 1370 | "cell_type": "code", 1371 | "execution_count": 29, 1372 | "metadata": {}, 1373 | "outputs": [ 1374 | { 1375 | "name": "stdout", 1376 | "output_type": "stream", 1377 | "text": [ 1378 | "\n", 1379 | "\n", 1380 | "\n", 1381 | "\n", 1382 | "Transform a natural language query into a SQL query.\n", 1383 | "\n", 1384 | "---\n", 1385 | "\n", 1386 | "Sql Prompt: What is the total number of open pedagogy courses offered by each country?\n", 1387 | "Sql Context: CREATE TABLE country (country_id INT, country_name VARCHAR(255)); CREATE TABLE open_pedagogy_courses (country_id INT, course_id INT); INSERT INTO country (country_id, country_name) VALUES (6001, 'Country X'), (6002, 'Country Y'), (6003, 'Country Z'); INSERT INTO open_pedagogy_courses (country_id, course_id) VALUES (6001, 7001), (6001, 7002), (6002, 7003);\n", 1388 | "Sql: SELECT country_name, COUNT(course_id) as total_courses FROM country JOIN open_pedagogy_courses ON country.country_id = open_pedagogy_courses.country_id GROUP BY country_name;\n", 1389 | "\n", 1390 | "Sql Prompt: Top 3 most popular songs of 2021 in 'concert_ticket_sales' table?\n", 1391 | "Sql Context: CREATE TABLE concert_ticket_sales (ticket_id INT, song_id INT, quantity INT, price FLOAT, sale_date 
DATE);\n", 1392 | "Sql: SELECT song_id, SUM(quantity) as total_quantity FROM concert_ticket_sales WHERE sale_date >= '2021-01-01' GROUP BY song_id ORDER BY total_quantity DESC LIMIT 3;\n", 1393 | "\n", 1394 | "Sql Prompt: What is the total amount donated by each donor in the 'Donors' table, ordered by the total amount donated in descending order?\n", 1395 | "Sql Context: CREATE TABLE Donors (DonorID INT, Name VARCHAR(50), TotalDonated DECIMAL(10, 2));\n", 1396 | "Sql: SELECT DonorID, Name, SUM(TotalDonated) AS TotalDonatedSum FROM Donors GROUP BY DonorID ORDER BY TotalDonatedSum DESC;\n", 1397 | "\n", 1398 | "Sql Prompt: What was the total sales amount for each product category in Africa in 2021?\n", 1399 | "Sql Context: CREATE TABLE sales_2021 AS SELECT * FROM sales WHERE sale_date BETWEEN '2021-01-01' AND '2021-12-31'; ALTER TABLE sales_2021 ADD COLUMN sale_country VARCHAR(50); UPDATE sales_2021 SET sale_country = CASE WHEN sale_city IN ('Accra', 'Lagos', 'Cairo') THEN 'Africa' ELSE sale_country END; ALTER TABLE sales_2021 ADD COLUMN product_category VARCHAR(50); UPDATE sales_2021 SET product_category = CASE WHEN product_id = 1 THEN 'Tops' WHEN product_id = 2 THEN 'Bottoms' WHEN product_id = 3 THEN 'Outerwear' WHEN product_id = 4 THEN 'Accessories' END;\n", 1400 | "Sql: SELECT sale_country, product_category, SUM(sale_amount) FROM sales_2021 WHERE sale_country = 'Africa' GROUP BY sale_country, product_category;\n", 1401 | "\n", 1402 | "Sql Prompt: What is the total transaction value for each customer in the past week, split by currency, for customers in the United States?\n", 1403 | "Sql Context: CREATE TABLE customers (customer_id INT, age INT, name VARCHAR(255), country VARCHAR(50)); CREATE TABLE transactions (transaction_id INT, customer_id INT, product_id INT, category_id INT, transaction_date DATE, amount DECIMAL(10,2), currency VARCHAR(10));\n", 1404 | "Sql: SELECT c.country, c.name, t.currency, SUM(t.amount) as total_transaction_value FROM customers c INNER 
JOIN transactions t ON c.customer_id = t.customer_id WHERE c.country = 'United States' AND t.transaction_date >= DATE_SUB(CURRENT_DATE, INTERVAL 1 WEEK) GROUP BY c.country, c.name, t.currency;\n", 1405 | "\n", 1406 | "Sql Prompt: Who are the top 5 customers by transaction count?\n", 1407 | "Sql Context: CREATE TABLE customer_transactions (customer_id INT, customer_name VARCHAR(20), transaction_id INT); INSERT INTO customer_transactions (customer_id, customer_name, transaction_id) VALUES (1, 'Juan Pérez', 1), (2, 'María Rodríguez', 2), (3, 'Carlos López', 3), (4, 'Laura González', 4), (5, 'José Hernández', 5), (6, 'Ana Sánchez', 6), (7, 'Pedro Martínez', 7);\n", 1408 | "Sql: SELECT c.customer_name, COUNT(ct.transaction_id) as transaction_count FROM customers c JOIN customer_transactions ct ON c.customer_id = ct.customer_id GROUP BY c.customer_name ORDER BY transaction_count DESC LIMIT 5;\n", 1409 | "\n", 1410 | "Sql Prompt: How many accounts are associated with each risk category and what is the total assets value for each category?\n", 1411 | "Sql Context: CREATE TABLE account_risk (id INT, account_id INT, risk_category VARCHAR(255)); INSERT INTO account_risk (id, account_id, risk_category) VALUES (1, 1, 'High'), (2, 2, 'Medium'), (3, 3, 'Low'), (4, 4, 'High'), (5, 5, 'Medium'); CREATE TABLE accounts (id INT, customer_id INT, total_assets DECIMAL(10, 2)); INSERT INTO accounts (id, customer_id, total_assets) VALUES (1, 1, 100000), (2, 2, 150000), (3, 3, 80000), (4, 4, 120000), (5, 5, 90000);\n", 1412 | "Sql: SELECT r.risk_category, COUNT(r.account_id) AS num_accounts, SUM(a.total_assets) AS total_assets FROM account_risk r INNER JOIN accounts a ON r.account_id = a.id GROUP BY r.risk_category;\n", 1413 | "\n", 1414 | "Sql Prompt: What is the average production per acre for corn and soy in each region?\n", 1415 | "Sql Context: CREATE TABLE crop_production (region TEXT, crop_type TEXT, acres FLOAT, production INT); INSERT INTO crop_production (region, crop_type, acres, 
production) VALUES ('North', 'Corn', 150, 270), ('North', 'Soy', 250, 120), ('South', 'Corn', 220, 900), ('East', 'Soy', 150, 300);\n", 1416 | "Sql: SELECT region, crop_type, AVG(production/acres) as avg_production_per_acre FROM crop_production WHERE crop_type IN ('Corn', 'Soy') GROUP BY region, crop_type;\n", 1417 | "\n", 1418 | "Sql Prompt: What is the minimum policy issuance year for policy number 1001?\n", 1419 | "Sql Context: CREATE TABLE policies (policy_id INT, policy_issue_year INT); INSERT INTO policies (policy_id, policy_issue_year) VALUES (1001, 2018), (1002, 2017), (1003, 2016), (1004, 2019);\n", 1420 | "Sql: SELECT MIN(policy_issue_year) FROM policies WHERE policy_id = 1001;\n", 1421 | "\n", 1422 | "Sql Prompt: What is the total number of hospital beds in rural areas of each state?\n", 1423 | "Sql Context: CREATE TABLE beds (bed_id INT, hospital_id INT, location VARCHAR(20));\n", 1424 | "Sql: SELECT hospital_id, COUNT(*) FROM beds WHERE location = 'Rural' GROUP BY hospital_id;\n", 1425 | "\n", 1426 | "Sql Prompt: What is the average age of employees with the title 'Supervisor' in the 'employees' table?\n", 1427 | "Sql Context: CREATE TABLE employees(id INT, name VARCHAR(255), title VARCHAR(255), age INT); INSERT INTO employees(id, name, title, age) VALUES ('1', 'Jane Smith', 'Mining Supervisor', '55');\n", 1428 | "Sql: SELECT AVG(age) FROM employees WHERE title LIKE '%Supervisor%';\n", 1429 | "\n", 1430 | "Sql Prompt: What was the total number of citizens in Asia in 2017?\n", 1431 | "Sql Context: CREATE TABLE asia_population (id INT PRIMARY KEY, year INT, num_citizens INT); INSERT INTO asia_population (id, year, num_citizens) VALUES (1, 2017, 4000000000);\n", 1432 | "Sql: SELECT num_citizens FROM asia_population WHERE year = 2017;\n", 1433 | "\n", 1434 | "Sql Prompt: What is the average number of citations for algorithmic fairness papers published in 2018 and 2019?\n", 1435 | "Sql Context: CREATE TABLE algorithmic_fairness_papers (year INT, paper_title 
VARCHAR(255), author_name VARCHAR(255), num_citations INT); INSERT INTO algorithmic_fairness_papers (year, paper_title, author_name, num_citations) VALUES ('2018', 'Algorithmic Fairness: A Review', 'Alice Johnson', '50');\n", 1436 | "Sql: SELECT AVG(num_citations) as avg_citations FROM algorithmic_fairness_papers WHERE year IN (2018, 2019);\n", 1437 | "\n", 1438 | "Sql Prompt: What is the average water temperature in salmon farms in Norway?\n", 1439 | "Sql Context: CREATE TABLE salmon_farms (id INT, name TEXT, country TEXT); CREATE TABLE temperature_readings (id INT, farm_id INT, temperature FLOAT); INSERT INTO salmon_farms (id, name, country) VALUES (1, 'Farm X', 'Norway'), (2, 'Farm Y', 'Norway'), (3, 'Farm Z', 'Canada'); INSERT INTO temperature_readings (id, farm_id, temperature) VALUES (1, 1, 12.5), (2, 1, 13.0), (3, 2, 11.0), (4, 2, 11.5), (5, 3, 7.0);\n", 1440 | "Sql: SELECT AVG(temperature) FROM temperature_readings TR JOIN salmon_farms SF ON TR.farm_id = SF.id WHERE SF.country = 'Norway';\n", 1441 | "\n", 1442 | "---\n", 1443 | "\n", 1444 | "Follow the following format.\n", 1445 | "\n", 1446 | "Sql Prompt: Natural language query\n", 1447 | "\n", 1448 | "Sql Context: Context for the query\n", 1449 | "\n", 1450 | "Reasoning: Let's think step by step in order to ${produce the sql}. We ...\n", 1451 | "\n", 1452 | "Sql: SQL query\n", 1453 | "\n", 1454 | "---\n", 1455 | "\n", 1456 | "Sql Prompt: Insert a new contractor with \"contractor_id\" 1001, \"name\" \"ABC Construction\", \"location\" \"New York, NY\", and \"license_number\" \"1234567890\" into the \"Contractors\" table.\n", 1457 | "\n", 1458 | "Sql Context: CREATE TABLE Contractors (contractor_id INT, name VARCHAR(255), location VARCHAR(255), license_number VARCHAR(50));\n", 1459 | "\n", 1460 | "Reasoning: Let's think step by step in order to ${insert the new contractor into the Contractors table}. 
We will first identify the necessary columns and their corresponding values for the new contractor, and then use the INSERT INTO statement to add the new record to the table.\n", 1461 | "\n", 1462 | "Sql: INSERT INTO Contractors (contractor_id, name, location, license_number) VALUES (1001, 'ABC Construction', 'New York, NY', '1234567890');\n", 1463 | "\n", 1464 | "---\n", 1465 | "\n", 1466 | "Sql Prompt: Show the number of vegan hair care products launched in Canada each year.\n", 1467 | "\n", 1468 | "Sql Context: CREATE TABLE hair_care_launches (id INT, product VARCHAR(50), launch_date DATE, country VARCHAR(50), vegan BOOLEAN); INSERT INTO hair_care_launches (id, product, launch_date, country, vegan) VALUES (1, 'Vegan Shampoo', '2021-02-15', 'Canada', TRUE);\n", 1469 | "\n", 1470 | "Reasoning: Let's think step by step in order to produce the SQL query. We need to find the number of vegan hair care products launched in Canada each year.\n", 1471 | "\n", 1472 | "Sql: SELECT YEAR(launch_date) AS launch_year, COUNT(*) AS vegan_products_launched FROM hair_care_launches WHERE country = 'Canada' AND vegan = TRUE GROUP BY launch_year ORDER BY launch_year;\n", 1473 | "\n", 1474 | "---\n", 1475 | "\n", 1476 | "Sql Prompt: List the top 5 countries with the highest number of satellites in orbit as of 2022-01-01, ordered by the number of satellites in descending order.\n", 1477 | "\n", 1478 | "Sql Context: CREATE TABLE countries(id INT, name VARCHAR(255), population INT, satellites_in_orbit INT, last_census_date DATE);\n", 1479 | "\n", 1480 | "Reasoning: Let's think step by step in order to produce the SQL query. 
We need to find the top 5 countries with the highest number of satellites in orbit as of 2022-01-01, ordered by the number of satellites in descending order.\n", 1481 | "\n", 1482 | "Sql:\u001b[32m SELECT name, satellites_in_orbit\n", 1483 | "FROM countries\n", 1484 | "WHERE last_census_date <= '2022-01-01'\n", 1485 | "ORDER BY satellites_in_orbit DESC\n", 1486 | "LIMIT 5;\u001b[0m\n", 1487 | "\n", 1488 | "\n", 1489 | "\n" 1490 | ] 1491 | } 1492 | ], 1493 | "source": [ 1494 | "lm.inspect_history(n=1)" 1495 | ] 1496 | }, 1497 | { 1498 | "cell_type": "markdown", 1499 | "metadata": {}, 1500 | "source": [ 1501 | "#### Evaluate the optimized program" 1502 | ] 1503 | }, 1504 | { 1505 | "cell_type": "markdown", 1506 | "metadata": {}, 1507 | "source": [ 1508 | "

✅ Evaluation Stage 3

With BootstrapFewShotWithRandomSearch optimization, Starling7B reaches 80% correctness on the validation set (25 samples) and 68% on the test set (75 samples).

" 1509 | ] 1510 | }, 1511 | { 1512 | "cell_type": "code", 1513 | "execution_count": 30, 1514 | "metadata": {}, 1515 | "outputs": [ 1516 | { 1517 | "name": "stdout", 1518 | "output_type": "stream", 1519 | "text": [ 1520 | "Starling7b + BootstrapFewShotWithRandomSearch - Validation Score: \n", 1521 | "\n" 1522 | ] 1523 | }, 1524 | { 1525 | "name": "stderr", 1526 | "output_type": "stream", 1527 | "text": [] 1528 | }, 1529 | { 1530 | "name": "stderr", 1531 | "output_type": "stream", 1532 | "text": [ 1533 | "Average Metric: 20 / 25 (80.0): 100%|██████████| 25/25 [00:20<00:00, 1.24it/s]" 1534 | ] 1535 | }, 1536 | { 1537 | "name": "stdout", 1538 | "output_type": "stream", 1539 | "text": [ 1540 | "Average Metric: 20 / 25 (80.0%)\n" 1541 | ] 1542 | }, 1543 | { 1544 | "name": "stderr", 1545 | "output_type": "stream", 1546 | "text": [ 1547 | "\n", 1548 | "/home/jjmov99/dspy-testing/.venv/lib/python3.11/site-packages/dspy/evaluate/evaluate.py:187: FutureWarning: DataFrame.applymap has been deprecated. 
Use DataFrame.map instead.\n", 1549 | " df = df.applymap(truncate_cell)\n" 1550 | ] 1551 | }, 1552 | { 1553 | "data": { 1554 | "text/plain": [ 1555 | "80.0" 1556 | ] 1557 | }, 1558 | "execution_count": 30, 1559 | "metadata": {}, 1560 | "output_type": "execute_result" 1561 | } 1562 | ], 1563 | "source": [ 1564 | "print(\"Starling7b + BootstrapFewShotWithRandomSearch - Validation Score: \\n\")\n", 1565 | "evaluate = Evaluate(devset=valset, metric=correctness_metric, num_threads=10, display_progress=True, display_table=0)\n", 1566 | "evaluate(optmized_program_2)" 1567 | ] 1568 | }, 1569 | { 1570 | "cell_type": "code", 1571 | "execution_count": 31, 1572 | "metadata": {}, 1573 | "outputs": [ 1574 | { 1575 | "name": "stdout", 1576 | "output_type": "stream", 1577 | "text": [ 1578 | "Starling7b + BootstrapFewShotWithRandomSearch - Test Score: \n", 1579 | "\n" 1580 | ] 1581 | }, 1582 | { 1583 | "name": "stderr", 1584 | "output_type": "stream", 1585 | "text": [ 1586 | "Average Metric: 51 / 75 (68.0): 100%|██████████| 75/75 [01:01<00:00, 1.22it/s]" 1587 | ] 1588 | }, 1589 | { 1590 | "name": "stdout", 1591 | "output_type": "stream", 1592 | "text": [ 1593 | "Average Metric: 51 / 75 (68.0%)\n" 1594 | ] 1595 | }, 1596 | { 1597 | "name": "stderr", 1598 | "output_type": "stream", 1599 | "text": [ 1600 | "\n", 1601 | "/home/jjmov99/dspy-testing/.venv/lib/python3.11/site-packages/dspy/evaluate/evaluate.py:187: FutureWarning: DataFrame.applymap has been deprecated. 
Use DataFrame.map instead.\n", 1602 | " df = df.applymap(truncate_cell)\n" 1603 | ] 1604 | }, 1605 | { 1606 | "data": { 1607 | "text/plain": [ 1608 | "68.0" 1609 | ] 1610 | }, 1611 | "execution_count": 31, 1612 | "metadata": {}, 1613 | "output_type": "execute_result" 1614 | } 1615 | ], 1616 | "source": [ 1617 | "print(\"Starling7b + BootstrapFewShotWithRandomSearch - Test Score: \\n\")\n", 1618 | "evaluate = Evaluate(devset=testset, metric=correctness_metric, num_threads=10, display_progress=True, display_table=0)\n", 1619 | "evaluate(optmized_program_2)" 1620 | ] 1621 | } 1622 | ], 1623 | "metadata": { 1624 | "kernelspec": { 1625 | "display_name": ".venv", 1626 | "language": "python", 1627 | "name": "python3" 1628 | }, 1629 | "language_info": { 1630 | "codemirror_mode": { 1631 | "name": "ipython", 1632 | "version": 3 1633 | }, 1634 | "file_extension": ".py", 1635 | "mimetype": "text/x-python", 1636 | "name": "python", 1637 | "nbconvert_exporter": "python", 1638 | "pygments_lexer": "ipython3", 1639 | "version": "3.11.6" 1640 | } 1641 | }, 1642 | "nbformat": 4, 1643 | "nbformat_minor": 2 1644 | } 1645 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | pycache: 2 | find ./ -type d -name '__pycache__' -exec rm -rf {} + -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Optimizing LM for Text2SQL using DSPy 2 | 3 | ## What is in this project? 4 | In this project, I decided to test DSPy for the task of converting natural language to SQL syntax. For this, I used a 7 billion parameter open-source model ([Nexusflow/Starling-LM-7B-beta](https://huggingface.co/Nexusflow/Starling-LM-7B-beta)) to see if I can reach GPT-3.5-Turbo performance levels, or even surpass them. 
The cool thing about this is that with DSPy, you can actually focus on programming rather than creating overcrafted prompts that are not failure proof. I invite you to see the main notebook that contains the entire process. I think you will find it really useful. These were the main results: 5 | 6 | | Model | Optimization | Validation Correctness (n = 25) | Test Correctness (n = 75) | 7 | | ----- | ------------ | ------------------------------- | ------------------------- | 8 | | GPT 3.5 Turbo | N/A | 80.0% | 70.07% | 9 | | Starling7B | N/A | 72.0% | 50.67% | 10 | | Starling7B | LabeledFewShot | 64.0% | 60.0% | 11 | | Starling7B | BootstrapFewShotWithRandomSearch | 80.0% | 68.0% | 12 | 13 | ## Repo Structure 14 | | Stage | Notebook/Script | Tech Stack | 15 | |------------------------------------------|------------------------------------------------------------------------------------------------------------------|--------------------------------| 16 | | Serverless Deployment of Starling 7B | [model_serve/](https://github.com/jjovalle99/DSPy-Text2SQL/tree/23a0a347db2d7515c5a28c305dacaea00d09dddc/model_serve) | vLLM, Modal, HuggingFace | 17 | | DSPy Optimization for Text2SQL | [DSPy-Text2SQL.ipynb](https://github.com/jjovalle99/DSPy-Text2SQL/blob/23a0a347db2d7515c5a28c305dacaea00d09dddc/DSPy-Text2SQL.ipynb) | DSPy, HuggingFace | 18 | 19 | ## What is DSPy? 20 | #### [Docs](https://dspy-docs.vercel.app/) 21 | 22 | _DSPy is a framework for algorithmically optimizing LM prompts and weights, especially when LMs are used one or more times within a pipeline. To use LMs to build a complex system without DSPy, you generally have to: (1) break the problem down into steps, (2) prompt your LM well until each step works well in isolation, (3) tweak the steps to work well together, (4) generate synthetic examples to tune each step, and (5) use these examples to fine-tune smaller LMs to cut costs. 
Currently, this is hard and messy: every time you change your pipeline, your LM, or your data, all prompts (or fine-tuning steps) may need to change._ 23 | 24 | _To make this more systematic and much more powerful, DSPy does two things. First, it separates the flow of your program (modules) from the parameters (LM prompts and weights) of each step. Second, DSPy introduces new optimizers, which are LM-driven algorithms that can tune the prompts and/or the weights of your LM calls, given a metric you want to maximize._ -------------------------------------------------------------------------------- /assets/dspy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jjovalle99/DSPy-Text2SQL/988895266bfe1d22a51b8cc0f162bc3efa2b26ed/assets/dspy.png -------------------------------------------------------------------------------- /model_serve/chat_template.jinja: -------------------------------------------------------------------------------- 1 | {{ bos_token }}{% for message in messages %}{{ 'GPT4 Correct ' + message['role'].title() + ': ' + message['content'] + '<|end_of_turn|>'}}{% endfor %}{% if add_generation_prompt %}{{ 'GPT4 Correct Assistant:' }}{% endif %} -------------------------------------------------------------------------------- /model_serve/serve_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shlex 3 | import subprocess 4 | from pathlib import Path 5 | 6 | from modal import Image, Mount, Secret, Stub, gpu, web_server 7 | 8 | MODEL_DIR = "/model" 9 | BASE_MODEL = "Nexusflow/Starling-LM-7B-beta" 10 | GPU_CONFIG = gpu.A100(count=1) 11 | 12 | 13 | # Download the model weights (skipping *.pt/*.gguf/*.bin files) into MODEL_DIR; executed at image build time via Image.run_function below 14 | def download_model_to_folder(): 15 | from huggingface_hub import snapshot_download 16 | from transformers.utils import move_cache 17 | 18 | os.makedirs(MODEL_DIR, exist_ok=True) 19 | 20 | snapshot_download( 21 | BASE_MODEL, 22 | local_dir=MODEL_DIR, 23 |
ignore_patterns=["*.pt", "*.gguf", "*.bin"], 24 | ) 25 | move_cache() 26 | 27 | 28 | # Image definition: CUDA 12.1 base image + pinned vLLM/torch/HF packages; the model weights are baked in by run_function(download_model_to_folder) 29 | image = ( 30 | Image.from_registry("nvidia/cuda:12.1.1-devel-ubuntu22.04", add_python="3.11") 31 | .pip_install( 32 | "vllm==0.3.2", 33 | "huggingface_hub==0.19.4", 34 | "hf-transfer==0.1.4", 35 | "torch==2.1.2", 36 | ) 37 | .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"}) 38 | .run_function( 39 | download_model_to_folder, 40 | timeout=60 * 20, 41 | ) 42 | ) 43 | 44 | stub = Stub("starling7b-ft", image=image) 45 | mounts_map = { 46 | "chat_template": { 47 | "local_path": Path(__file__).parent / "chat_template.jinja", 48 | "remote_path": "/model/chat_template.jinja", 49 | }, 50 | } 51 | 52 | 53 | @stub.function( 54 | gpu=GPU_CONFIG, 55 | mounts=[Mount.from_local_file(**mounts_map["chat_template"])], 56 | allow_concurrent_inputs=50, 57 | timeout=60 * 60, 58 | keep_warm=1, 59 | ) 60 | @web_server(port=8000, startup_timeout=60) 61 | def serve_model(): 62 | base_model = shlex.quote(str(BASE_MODEL))  # NOTE(review): vLLM is given the Hub id, not MODEL_DIR where download_model_to_folder() saved the weights; confirm the baked-in files are reused rather than re-downloaded at container startup (and that it fits in startup_timeout=60) 63 | cmd = ( 64 | f"python -m vllm.entrypoints.openai.api_server --model {base_model} " 65 | "--chat-template /model/chat_template.jinja " 66 | ) 67 | subprocess.Popen(cmd, shell=True)  # starts the OpenAI-compatible server in the background without waiting; presumably web_server then waits for port 8000 68 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "dspy-testing" 3 | version = "0.1.0" 4 | description = "" 5 | authors = ["nn"] 6 | readme = "README.md" 7 | 8 | [tool.poetry.dependencies] 9 | python = "~3.11" 10 | dspy-ai = "^2.4.0" 11 | modal = "^0.62.43" 12 | jinja2 = "^3.1.3" 13 | 14 | 15 | [tool.poetry.group.dev.dependencies] 16 | ipykernel = "^6.29.4" 17 | ipywidgets = "^8.1.2" 18 | python-dotenv = "^1.0.1" 19 | 20 | 21 | [tool.poetry.group.format.dependencies] 22 | ruff = "^0.3.5" 23 | 24 | [build-system] 25 | requires = ["poetry-core"] 26 | build-backend = "poetry.core.masonry.api" 27 | 28 | [tool.ruff] 29 | line-length = 120
# --------------------------------------------------------------------------------
# /src/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/jjovalle99/DSPy-Text2SQL/988895266bfe1d22a51b8cc0f162bc3efa2b26ed/src/__init__.py
# --------------------------------------------------------------------------------
# /src/starling.py:
# --------------------------------------------------------------------------------
"""DSPy language-model client for Starling-LM-7B behind an OpenAI-compatible chat-completions API."""

from typing import Any, Literal

import requests
from dsp import LM

# Full chat-completions URL of the hosted model (presumably the Modal app in model_serve/serve_model.py).
DEFAULT_BASE_URL = "https://jjovalle99--starling7b-ft-serve-model.modal.run/v1/chat/completions"
# Upper bound, in seconds, on one HTTP request so a stalled server cannot hang a DSPy run forever.
DEFAULT_REQUEST_TIMEOUT_S = 300.0


class StarlingLM(LM):
    """Client that sends DSPy prompts to a self-hosted Starling-LM-7B chat endpoint.

    Every prompt is sent as a single-turn chat (one ``user`` message containing the
    whole prompt) to an OpenAI-style ``/v1/chat/completions`` URL.

    Args:
        model: Model name placed in the ``model`` field of every request body.
        model_type: ``"chat"`` or ``"text"``; stored on the instance. The request itself
            is always chat-formatted.
        base_url: Full URL of the chat-completions endpoint (keyword-only).
        request_timeout: Seconds to wait for the server before ``requests`` raises
            ``requests.Timeout``; ``None`` waits indefinitely (keyword-only).
        **kwargs: Default generation arguments (e.g. ``temperature``, ``max_tokens``,
            ``stop``, ``n``) merged into every request body.
    """

    def __init__(
        self,
        model: str,
        model_type: Literal["chat", "text"],
        *,
        base_url: str = DEFAULT_BASE_URL,
        request_timeout: float | None = DEFAULT_REQUEST_TIMEOUT_S,
        **kwargs: Any,
    ) -> None:
        self.provider = "openai"
        self.model = model
        self.base_url = base_url
        self.request_timeout = request_timeout
        self.history: list[dict[str, Any]] = []
        self.kwargs = kwargs
        self.model_type = model_type

    def basic_request(self, prompt: str, **kwargs: Any) -> dict[str, Any]:
        """POST *prompt* as a single user message and return the decoded JSON response.

        Per-call ``kwargs`` override the constructor defaults. Each successful exchange is
        appended to ``self.history`` as ``{"prompt", "response", "kwargs"}``.

        Raises:
            requests.HTTPError: If the server answers with a non-2xx status.
            requests.Timeout: If the server does not answer within ``request_timeout`` seconds.
        """
        kwargs = {**self.kwargs, **kwargs}
        payload = {**kwargs, "model": self.model, "messages": [{"role": "user", "content": prompt}]}

        http_response = requests.post(
            self.base_url,
            headers={"content-type": "application/json"},
            json=payload,
            timeout=self.request_timeout,
        )
        # Surface server errors here instead of as a confusing KeyError on "choices" later.
        http_response.raise_for_status()
        response = http_response.json()

        self.history.append(
            {
                "prompt": prompt,
                "response": response,
                "kwargs": kwargs,
            }
        )
        return response

    def _get_choice_text(self, choice: dict[str, Any]) -> str:
        """Return the generated text of one entry of the response's ``choices`` list.

        Requests always go to a chat endpoint, so chat-shaped choices
        (``{"message": {"content": ...}}``) are read as such even when ``model_type`` is
        ``"text"``. Otherwise, completion-shaped choices fall back to ``choice["text"]``.
        """
        if self.model_type == "chat" or "message" in choice:
            return choice["message"]["content"]
        return choice["text"]

    def __call__(
        self,
        prompt: str,
        only_completed: bool = True,
        return_sorted: bool = False,
        **kwargs: Any,
    ) -> list[str]:
        """Generate completions for *prompt* and return their texts.

        Args:
            prompt: Full prompt text built by DSPy.
            only_completed: Drop choices cut off by the token limit
                (``finish_reason == "length"``), unless that would leave no choices.
            return_sorted: Accepted for DSPy compatibility only; sorting by log-probability
                is not supported, so choices keep the server's order.
            **kwargs: Per-call generation overrides, forwarded through ``self.request``.

        Returns:
            One generated string per (kept) choice.
        """
        response = self.request(prompt, **kwargs)
        choices = response["choices"]

        completed = [choice for choice in choices if choice.get("finish_reason") != "length"]
        if only_completed and completed:
            choices = completed

        return [self._get_choice_text(choice) for choice in choices]
# --------------------------------------------------------------------------------