├── .env.template
├── .gitignore
├── README.md
├── __pycache__
│   └── opentom_evaluator.cpython-311.pyc
├── cot.py
├── cot_modules.pkl
├── cot_with_thought.py
├── datasets.pkl
├── dspy.ipynb
├── get_data.py
├── main.py
├── opentom_evaluator.py
├── poetry.lock
└── pyproject.toml

/.env.template:
--------------------------------------------------------------------------------
1 | OPENAI_API_KEY=
2 | NEPTUNE_API_TOKEN=
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .env
2 | 
3 | .venv
4 | 
5 | .DS_Store
6 | 
7 | *.log
8 | 
9 | __pycache__
10 | 
11 | .neptune
12 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DSPy OpenToM
2 | 
3 | This repo contains scripts for optimizing DSPy modules for the OpenToM benchmark. We support Chain of Thought and an experimental variant that first generates a "thought" about the context to aid in answering the question (spoiler -- it didn't work better than plain Chain of Thought optimized with `BootstrapFewShotWithRandomSearch`).
4 | 
5 | CLI Usage:
6 | ```
7 | usage: main.py [-h] [--student STUDENT] [--teacher TEACHER] [--train_size TRAIN_SIZE] [--download_dataset DOWNLOAD_DATASET]
8 |                [--question_types [QUESTION_TYPES ...]]
9 |                experiment_title dspy_method dspy_optimizer
10 | 
11 | Run DSPY method.
12 | 
13 | positional arguments:
14 |   experiment_title      Title of new experiment
15 |   dspy_method           The DSPY method to run
16 |   dspy_optimizer        The DSPY optimizer to use
17 | 
18 | options:
19 |   -h, --help            show this help message and exit
20 |   --student STUDENT     The LLM to optimize prompts for
21 |   --teacher TEACHER     Teacher LLM for optimizing prompts. Defaults to Student LLM
22 |   --train_size TRAIN_SIZE
23 |                         Number of training examples to use for optimization
24 |   --download_dataset DOWNLOAD_DATASET
25 |                         Download dataset
26 |   --question_types [QUESTION_TYPES ...]
27 |                         Question types. Defaults to all
28 | ```
29 | 
30 | Come chat with us in our [discord](https://discord.gg/plasticlabs) or in the [DSPy thread](https://discord.com/channels/1161519468141355160/1214629969318252574)
31 | 
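32 | For example, to optimize a Chain of Thought module with `BootstrapFewShotWithRandomSearch` on a couple of question types (requires `OPENAI_API_KEY` and `NEPTUNE_API_TOKEN` in your `.env`; the experiment title is just a label for the Neptune run):
33 | 
34 | ```
35 | python main.py cot-baseline cot bootstrap_fewshot_with_random_search --train_size 25 --question_types attitude multihop-fo
36 | ```
37 | 
38 | A run saves the processed dataset to `datasets.pkl` (via `get_data.py`) and the optimized modules to `cot_modules.pkl`. A minimal sketch for reusing them afterwards (run it from the repo root so `cot.py` and `get_data.py` are importable; the file names and question types follow `main.py` and `get_data.py`):
39 | 
40 | ```python
41 | import pickle
42 | 
43 | import dspy
44 | from dotenv import load_dotenv
45 | 
46 | from get_data import default_factory  # the pickled defaultdict in datasets.pkl references this factory
47 | 
48 | load_dotenv()  # expects OPENAI_API_KEY in .env
49 | dspy.settings.configure(lm=dspy.OpenAI(model="gpt-3.5-turbo", max_tokens=200))
50 | 
51 | # cot_modules.pkl maps question type -> compiled DSPy module
52 | with open("cot_modules.pkl", "rb") as f:
53 |     modules = pickle.load(f)
54 | 
55 | # datasets.pkl maps question type -> {"train": [...], "test": [...]} of dspy.Example objects
56 | with open("datasets.pkl", "rb") as f:
57 |     datasets = pickle.load(f)
58 | 
59 | example = datasets["attitude"]["test"][0]
60 | pred = modules["attitude"](question=example.question, context=example.context, answer_choices=example.answer_choices)
61 | 
62 | print(f"Predicted: {pred.answer} | Gold: {example.answer}")
63 | ```
64 | 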
--------------------------------------------------------------------------------
/__pycache__/opentom_evaluator.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/plastic-labs/dspy-opentom/58a3715d3245690740163ad27256971f7a0a5df8/__pycache__/opentom_evaluator.cpython-311.pyc
--------------------------------------------------------------------------------
/cot.py:
--------------------------------------------------------------------------------
1 | import dspy
2 | 
3 | 
4 | # DSPy code
5 | class GenerateAnswer(dspy.Signature):
6 |     """Generate answers to the questions"""
7 | 
8 |     context = dspy.InputField(desc="may contain relevant facts and psychological insights")
9 |     question = dspy.InputField()
10 |     answer_choices = dspy.InputField()
11 |     answer = dspy.OutputField(desc="often between 1 and 5 words")
12 | 
13 | 
14 | class CoTSimplifiedBaleen(dspy.Module):
15 |     def __init__(self):
16 |         super().__init__()
17 |         self.generate_answer = dspy.ChainOfThought(GenerateAnswer)
18 | 
19 |     def forward(self, question, context, answer_choices):
20 |         pred = self.generate_answer(context=context, question=question, answer_choices=answer_choices)
21 |         return dspy.Prediction(context=context, answer=pred.answer)
22 | 
--------------------------------------------------------------------------------
/cot_modules.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/plastic-labs/dspy-opentom/58a3715d3245690740163ad27256971f7a0a5df8/cot_modules.pkl
--------------------------------------------------------------------------------
/cot_with_thought.py:
--------------------------------------------------------------------------------
1 | import dspy
2 | 
3 | 
4 | # DSPy code
5 | class GenerateAnswer(dspy.Signature):
6 |     """Generate answers to the questions"""
7 | 
8 |     context = dspy.InputField(desc="may contain relevant facts and psychological insights")
9 |     question = dspy.InputField()
10 |     thought = dspy.InputField(desc="a thought that might help answer the question")
11 |     answer_choices = dspy.InputField()
12 |     answer = dspy.OutputField(desc="often between 1 and 5 words")
13 | 
14 | 
15 | class GenerateThought(dspy.Signature):
16 |     """Generate thoughts about questions"""
17 | 
18 |     context = dspy.InputField(desc="may contain relevant facts and psychological insights")
19 |     question = dspy.InputField()
20 |     thought = dspy.OutputField(desc="a thought that might help answer the question")
21 | 
22 | 
23 | class CoTWithThoughtSimplifiedBaleen(dspy.Module):
24 |     def __init__(self):
25 |         super().__init__()
26 |         self.generate_thought = dspy.ChainOfThought(GenerateThought)
27 |         self.generate_answer = dspy.ChainOfThought(GenerateAnswer)
28 | 
29 |     def forward(self, question, context, answer_choices):
30 |         pred_thought = self.generate_thought(context=context, question=question)
31 |         pred = self.generate_answer(
32 |             context=context, question=question, thought=pred_thought.thought, answer_choices=answer_choices
33 |         )
34 |         return dspy.Prediction(context=context, answer=pred.answer)
35 | 
--------------------------------------------------------------------------------
/datasets.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/plastic-labs/dspy-opentom/58a3715d3245690740163ad27256971f7a0a5df8/datasets.pkl -------------------------------------------------------------------------------- /dspy.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# DSPy + OpenTom\n", 8 | "\n", 9 | "Goal of this notebook is to explore the OpenToM dataset and see if we can write some DSPy code to optimize prompts for answering the questions.\n", 10 | "\n", 11 | "They've evaluated the performance of CoT and SimToM on their dataset, I now wonder how much extra performance we can get from using a framework like DSPy." 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "# SETUP: run poetry install + shell in the terminal, then i just say `cursor .` to open my editor and it runs this nb in the venv\n", 21 | "# GETTING STARTED: let's import the packages and get the data\n", 22 | "import dspy\n", 23 | "import requests\n", 24 | "import random\n", 25 | "import pandas as pd\n", 26 | "from dotenv import load_dotenv\n", 27 | "\n", 28 | "load_dotenv() # need ur api keys set beforehand\n", 29 | "\n", 30 | "turbo = dspy.OpenAI(model='gpt-3.5-turbo', max_tokens=200)\n", 31 | "dspy.settings.configure(lm=turbo)\n", 32 | "\n", 33 | "# dataset isn't able to be loaded using hf datasets package so let's read it from github raw\n", 34 | "# also let's keep it simple and just go for the opentom_long.json\n", 35 | "# this is the one that they sampled 100 existing OpenToM plots to produce \"extra long\" narratives\n", 36 | "# url = \"https://raw.githubusercontent.com/SeacowX/OpenToM/main/data/opentom_long.json\"\n", 37 | "url = \"https://raw.githubusercontent.com/SeacowX/OpenToM/main/data/opentom.json\"\n", 38 | "response = requests.get(url).json()\n", 39 | "\n", 40 | "df = pd.DataFrame(response)\n" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "df.head()" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "df.loc[0]['plot_info']" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "type_counts = df['question'].apply(lambda x: x['type']).value_counts()\n", 68 | "type_counts # fo means first-order, so means second-order\n", 69 | "\n", 70 | "# first order questions directly ask about a character’s perception of the world, while\n", 71 | "# second order questions ask about a character’s belief of another character's mental state" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": null, 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "# Assuming 'df' is your DataFrame and it contains a 'question' column with dictionaries having 'type' and 'answer' keys\n", 81 | "\n", 82 | "# Extract 'type' and 'answer' into separate columns\n", 83 | "df['type'] = df['question'].apply(lambda x: x['type'])\n", 84 | "df['answer'] = df['question'].apply(lambda x: x['answer'])\n", 85 | "\n", 86 | "# Group by 'type' and get unique 'answer' values for each 'type'\n", 87 | "unique_answers_by_type = df.groupby('type')['answer'].unique()\n", 88 | "\n", 89 | "print(unique_answers_by_type)" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 
94 | "execution_count": null, 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "import json\n", 99 | "\n", 100 | "# convert the dataset to what DSPy expects (list of Example objects)\n", 101 | "dataset = []\n", 102 | "\n", 103 | "for index, row in df.iterrows():\n", 104 | " context = row['narrative']\n", 105 | " question = row['question']['question']\n", 106 | " answer = row['question']['answer']\n", 107 | " type = row['question']['type']\n", 108 | " plot_info = json.dumps(row['plot_info']) # Keeping each example field as a string might be a good idea\n", 109 | "\n", 110 | " if \"location\" in type and (answer.lower().strip() != \"yes\" and answer.lower().strip() != \"no\"): # don't provide answer choices for fine grained location questions\n", 111 | " answer_choices = \"n/a, list a specific location\"\n", 112 | " elif \"location\" in type:\n", 113 | " answer_choices = \"No, Yes\"\n", 114 | " else:\n", 115 | " answer_choices = \", \".join(unique_answers_by_type[type])\n", 116 | "\n", 117 | " dataset.append(dspy.Example(context=context, question=question, answer=answer, type=type, plot_info=plot_info, answer_choices=answer_choices).with_inputs(\"context\", \"question\", \"answer_choices\"))" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": null, 123 | "metadata": {}, 124 | "outputs": [], 125 | "source": [ 126 | "# split datasets by question types \n", 127 | "from collections import defaultdict\n", 128 | "\n", 129 | "datasets = defaultdict(lambda: [])\n", 130 | "\n", 131 | "for example in dataset:\n", 132 | " datasets[example.type].append(example)\n", 133 | "\n", 134 | "datasets.keys()\n", 135 | "[len(dataset) for dataset in datasets.values()]" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": null, 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "# create train test split\n", 145 | "for question_type, dataset in datasets.items():\n", 146 | " random.shuffle(dataset)\n", 147 | "\n", 148 | " datasets[question_type] = {\n", 149 | " \"train\": dataset[:int(len(dataset) * 0.8)],\n", 150 | " \"test\": dataset[int(len(dataset) * 0.8):],\n", 151 | " }\n", 152 | "\n", 153 | " print(f\"Now Train {question_type}: {len(datasets[question_type]['train'])}\")\n", 154 | " print(f\"Now Test {question_type}: {len(datasets[question_type]['test'])}\")" 155 | ] 156 | }, 157 | { 158 | "cell_type": "markdown", 159 | "metadata": {}, 160 | "source": [ 161 | "# Define the Signatures\n", 162 | "\n", 163 | "Using a \"Baleen\" pipeline [(Khattab et al., 2021)](https://arxiv.org/abs/2101.00436)\n" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": null, 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "# answer the question\n", 173 | "class GenerateAnswer(dspy.Signature):\n", 174 | " \"\"\"Generate answers to the questions\"\"\"\n", 175 | "\n", 176 | " context = dspy.InputField(desc=\"may contain relevant facts and psychological insights\")\n", 177 | " question = dspy.InputField()\n", 178 | " answer_choices = dspy.InputField()\n", 179 | " answer = dspy.OutputField(desc=\"often between 1 and 5 words\")\n", 180 | "\n", 181 | "# generate a question to help you better answer the question\n", 182 | "# class GenerateSearchQuery(dspy.Signature):\n", 183 | "# \"\"\"Write a simple search query that will help answer a complex question.\"\"\"\n", 184 | "\n", 185 | "# context = dspy.InputField(desc=\"may contain relevant facts and psychological insights\")\n", 186 | "# question = dspy.InputField()\n", 
187 | "# query = dspy.OutputField(desc=\"a thought that might help answer the question\") \n", 188 | "\n", 189 | "# class GenerateSearchAnswer(dspy.Signature):\n", 190 | "# \"\"\"Generate a long form answer to the question given the context\"\"\"\n", 191 | "\n", 192 | "# context = dspy.InputField(desc=\"may contain relevant facts and psychological insights\")\n", 193 | "# question = dspy.InputField()\n", 194 | "# answer = dspy.OutputField(desc=\"a thought about what the answer to the question may be\")\n", 195 | "\n" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": null, 201 | "metadata": {}, 202 | "outputs": [], 203 | "source": [ 204 | "from dsp.utils import deduplicate\n", 205 | "\n", 206 | "class SimplifiedBaleen(dspy.Module):\n", 207 | " # def __init__(self, max_hops=2):\n", 208 | " # super().__init__()\n", 209 | " def __init__(self):\n", 210 | " super().__init__()\n", 211 | "\n", 212 | " # self.generate_query = [dspy.ChainOfThought(GenerateSearchQuery) for _ in range(max_hops)]\n", 213 | " # self.generate_search_answer = dspy.ChainOfThought(GenerateSearchAnswer)\n", 214 | " self.generate_answer = dspy.ChainOfThought(GenerateAnswer)\n", 215 | " # self.max_hops = max_hops\n", 216 | " \n", 217 | " def forward(self, question, context, answer_choices):\n", 218 | " # final_context = []\n", 219 | " \n", 220 | " # for hop in range(self.max_hops):\n", 221 | " # query = self.generate_query[hop](context=context, question=question).query\n", 222 | " # filtered_context = self.generate_search_answer(context=context, question=query).answer\n", 223 | " # final_context = (context + filtered_context)\n", 224 | "\n", 225 | " pred = self.generate_answer(context=context, question=question, answer_choices=answer_choices)\n", 226 | " return dspy.Prediction(context=context, answer=pred.answer)" 227 | ] 228 | }, 229 | { 230 | "cell_type": "markdown", 231 | "metadata": {}, 232 | "source": [ 233 | "Here we're defining a simple signature just to generate the answer given the context, question, and answer choices.\n", 234 | "\n", 235 | "# Executing the Pipeline\n", 236 | "\n", 237 | "Let's see how this works in a zero-shot setting" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": null, 243 | "metadata": {}, 244 | "outputs": [], 245 | "source": [ 246 | "my_question = datasets[\"attitude\"][\"test\"][0].question\n", 247 | "my_context = datasets[\"attitude\"][\"test\"][0].context\n", 248 | "my_answer_choices = datasets[\"attitude\"][\"test\"][0].answer_choices\n", 249 | "\n", 250 | "# Get the prediction. 
This contains `pred.context` and `pred.answer`.\n",
251 |     "uncompiled_baleen = SimplifiedBaleen() # uncompiled (i.e., zero-shot) program\n",
252 |     "pred = uncompiled_baleen(question=my_question, context=my_context, answer_choices=my_answer_choices)\n",
253 |     "\n",
254 |     "# Print the contexts and the answer.\n",
255 |     "print(f\"Question: {my_question}\")\n",
256 |     "print(f\"True Answer: {datasets['attitude']['test'][0].answer}\")\n",
257 |     "print(f\"Predicted Answer: {pred.answer}\")\n",
258 |     "print(f\"Answer Choices: {my_answer_choices}\")\n"
259 |    ]
260 |   },
261 |   {
262 |    "cell_type": "code",
263 |    "execution_count": null,
264 |    "metadata": {},
265 |    "outputs": [],
266 |    "source": [
267 |     "from opentom_evaluator import OpenToMEvaluatorDspy\n",
268 |     "eval = OpenToMEvaluatorDspy()\n",
269 |     "eval.dspy_metric(datasets[\"attitude\"][\"test\"][0], pred)\n",
270 |     "\n"
271 |    ]
272 |   },
273 |   {
274 |    "cell_type": "markdown",
275 |    "metadata": {},
276 |    "source": [
277 |     "We can inspect the last few calls to the LM (here there is just one call per prediction: the chain-of-thought answer generation) using:"
278 |    ]
279 |   },
280 |   {
281 |    "cell_type": "code",
282 |    "execution_count": null,
283 |    "metadata": {},
284 |    "outputs": [],
285 |    "source": [
286 |     "turbo.inspect_history(n=3)"
287 |    ]
288 |   },
289 |   {
290 |    "cell_type": "markdown",
291 |    "metadata": {},
292 |    "source": [
293 |     "# Optimizing the Pipeline\n",
294 |     "\n",
295 |     "However, a zero-shot approach quickly falls short for more specialized tasks, novel domains/settings, and more efficient (or open) models.\n",
296 |     "\n",
297 |     "To address this, DSPy offers compilation. Let's compile our SimplifiedBaleen program."
298 |    ]
299 |   },
300 |   {
301 |    "cell_type": "code",
302 |    "execution_count": null,
303 |    "metadata": {},
304 |    "outputs": [],
305 |    "source": [
306 |     "from opentom_evaluator import OpenToMEvaluatorDspy\n",
307 |     "from dspy.teleprompt import BootstrapFewShotWithRandomSearch\n",
308 |     "import time\n",
309 |     "\n",
310 |     "eval_question_types = [\"attitude\", \"multihop-fo\", \"multihop-so\", \"location-fo\", \"location-so\"] # question types to optimize a module for\n",
311 |     "modules = {}\n",
312 |     "\n",
313 |     "# define modules for each question type\n",
314 |     "for question_type in eval_question_types:\n",
315 |     "    print(f\"TYPE: {question_type}\")\n",
316 |     "    evaluator = OpenToMEvaluatorDspy(model_name=\"(training set) compiled baleen\")\n",
317 |     "    optimizer = BootstrapFewShotWithRandomSearch(metric=evaluator.dspy_metric, num_threads=1)\n",
318 |     "    compiled_baleen = optimizer.compile(SimplifiedBaleen(), trainset=datasets[question_type][\"train\"][:25])\n",
319 |     "\n",
320 |     "    modules[question_type] = compiled_baleen\n",
321 |     "    time.sleep(60)\n"
322 |    ]
323 |   },
324 |   {
325 |    "cell_type": "code",
326 |    "execution_count": null,
327 |    "metadata": {},
328 |    "outputs": [],
329 |    "source": [
330 |     "from dspy.evaluate.evaluate import Evaluate\n",
331 |     "\n",
332 |     "print(\"Macro Averaged F1 Scores\")\n",
333 |     "for question_type in eval_question_types:\n",
334 |     "    test = datasets[question_type][\"test\"]\n",
335 |     "    compiled_baleen = modules[question_type]\n",
336 |     "\n",
337 |     "    # Set up the `evaluate_on_opentom` function.\n",
338 |     "    evaluate_on_opentom = Evaluate(devset=test[:10], num_threads=1, display_progress=True, display_table=10)\n",
339 |     "\n",
340 |     "    uncompiled_baleen_evaluator = OpenToMEvaluatorDspy(model_name='uncompiled_baleen')\n",
341 |     "    uncompiled_baleen_retrieval_score = 
evaluate_on_opentom(uncompiled_baleen, metric=uncompiled_baleen_evaluator.dspy_metric, display=False)\n", 342 | " uncompiled_baleen_evaluator.print_f1_results()\n", 343 | "\n", 344 | " compiled_baleen_evaluator = OpenToMEvaluatorDspy(model_name='compiled_baleen')\n", 345 | " compiled_baleen_retrieval_score = evaluate_on_opentom(compiled_baleen, metric=compiled_baleen_evaluator.dspy_metric, display=False)\n", 346 | " compiled_baleen_evaluator.print_f1_results()" 347 | ] 348 | }, 349 | { 350 | "cell_type": "code", 351 | "execution_count": null, 352 | "metadata": {}, 353 | "outputs": [], 354 | "source": [ 355 | "uncompiled_baleen.dump_state()" 356 | ] 357 | }, 358 | { 359 | "cell_type": "code", 360 | "execution_count": null, 361 | "metadata": {}, 362 | "outputs": [], 363 | "source": [ 364 | "compiled_baleen.dump_state()" 365 | ] 366 | } 367 | ], 368 | "metadata": { 369 | "kernelspec": { 370 | "display_name": ".venv", 371 | "language": "python", 372 | "name": "python3" 373 | }, 374 | "language_info": { 375 | "codemirror_mode": { 376 | "name": "ipython", 377 | "version": 3 378 | }, 379 | "file_extension": ".py", 380 | "mimetype": "text/x-python", 381 | "name": "python", 382 | "nbconvert_exporter": "python", 383 | "pygments_lexer": "ipython3", 384 | "version": "3.11.4" 385 | } 386 | }, 387 | "nbformat": 4, 388 | "nbformat_minor": 2 389 | } 390 | -------------------------------------------------------------------------------- /get_data.py: -------------------------------------------------------------------------------- 1 | import dspy 2 | import requests 3 | import pickle 4 | import json 5 | import random 6 | from collections import defaultdict 7 | import pandas as pd 8 | 9 | 10 | # this is the one that they sampled 100 existing OpenToM plots to produce "extra long" narratives 11 | # URL = "https://raw.githubusercontent.com/SeacowX/OpenToM/main/data/opentom_long.json" 12 | URL = "https://raw.githubusercontent.com/SeacowX/OpenToM/main/data/opentom.json" 13 | 14 | 15 | def default_factory(): 16 | return [] 17 | 18 | 19 | def load_dataset(): 20 | response = requests.get(URL).json() 21 | 22 | df = pd.DataFrame(response) 23 | 24 | # Extract 'type' and 'answer' into separate columns 25 | df["type"] = df["question"].apply(lambda x: x["type"]) 26 | df["answer"] = df["question"].apply(lambda x: x["answer"]) 27 | 28 | unique_answers_by_type = df.groupby("type")["answer"].unique() 29 | 30 | # convert the dataset to what DSPy expects (list of Example objects) 31 | dataset = [] 32 | 33 | for index, row in df.iterrows(): 34 | context = row["narrative"] 35 | question = row["question"]["question"] 36 | answer = row["question"]["answer"] 37 | type = row["question"]["type"] 38 | plot_info = json.dumps(row["plot_info"]) # Keeping each example field as a string might be a good idea 39 | 40 | # update the type value if location is coarse or fine 41 | if "location" in type: 42 | location_granularity = "fine" if answer.lower().strip() != "yes" and answer.lower().strip() != "no" else "coarse" 43 | type = f"{type}-{location_granularity}" 44 | 45 | # Answer choices 46 | if "location" in type and ( 47 | answer.lower().strip() != "yes" and answer.lower().strip() != "no" 48 | ): # don't provide answer choices for fine grained location questions 49 | answer_choices = "n/a, list a specific location" 50 | elif "location" in type: 51 | answer_choices = "No, Yes" 52 | else: 53 | answer_choices = ", ".join(unique_answers_by_type[type]) 54 | 55 | dataset.append( 56 | dspy.Example( 57 | context=context, question=question, 
answer=answer, type=type, plot_info=plot_info, answer_choices=answer_choices
58 |             ).with_inputs("context", "question", "answer_choices")
59 |         )
60 | 
61 |     # split datasets by question types
62 |     datasets = defaultdict(default_factory)
63 | 
64 |     for example in dataset:
65 |         datasets[example.type].append(example)
66 | 
67 |     datasets.keys()
68 |     [len(dataset) for dataset in datasets.values()]
69 | 
70 |     # create train test split
71 |     for question_type, dataset in datasets.items():
72 |         random.shuffle(dataset)
73 | 
74 |         datasets[question_type] = {
75 |             "train": dataset[int(len(dataset) * 0.8) :],  # 80% test, 20% train
76 |             "test": dataset[: int(len(dataset) * 0.8)],
77 |         }
78 | 
79 |         print(f"Train {question_type}: {len(datasets[question_type]['train'])}")
80 |         print(f"Test {question_type}: {len(datasets[question_type]['test'])}")
81 | 
82 |     # Serialize and save the datasets object to a file
83 |     with open("datasets.pkl", "wb") as file:
84 |         pickle.dump(datasets, file)
85 | 
86 |     print("🫡 Datasets object has been saved to 'datasets.pkl' 🫡")
87 | 
88 | 
89 | if __name__ == "__main__":
90 |     load_dataset()
91 | 
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | # example: python main.py my_experiment cot bootstrap_fewshot_with_random_search
2 | 
3 | import pickle
4 | import time
5 | import argparse
6 | from typing import Optional
7 | from opentom_evaluator import OpenToMEvaluatorDspy
8 | import dspy
9 | from dspy.teleprompt import BootstrapFewShotWithRandomSearch, SignatureOptimizer
10 | from dspy.evaluate.evaluate import Evaluate
11 | from cot import CoTSimplifiedBaleen
12 | from cot_with_thought import CoTWithThoughtSimplifiedBaleen
13 | from get_data import default_factory, load_dataset  # default_factory is referenced by the pickled defaultdict in datasets.pkl
14 | from collections import defaultdict
15 | from dotenv import load_dotenv
16 | import neptune
17 | import numpy as np
18 | 
19 | load_dotenv()
20 | 
21 | # initialize neptune
22 | run = neptune.init_run(
23 |     project="dspy-opentom/dspy-evaluation",
24 |     capture_hardware_metrics=False,
25 |     capture_stderr=True,
26 |     capture_stdout=True,
27 |     capture_traceback=True,
28 | )
29 | 
30 | EVAL_QUESTION_TYPES = [
31 |     "attitude",
32 |     "multihop-fo",
33 |     "multihop-so",
34 |     "location-fo-coarse",
35 |     "location-fo-fine",
36 |     "location-so-coarse",
37 |     "location-so-fine",
38 | ]
39 | 
40 | 
41 | def dump_state(data, filename):
42 |     with open(filename, "wb") as file:
43 |         pickle.dump(data, file)
44 | 
45 | 
46 | def main(dspy_method, dspy_optimizer, download_dataset, question_types, teacher_lm, train_size):
47 |     # load dataset
48 |     if download_dataset:
49 |         load_dataset()
50 | 
51 |     # read in the datasets pickle object
52 |     with open("datasets.pkl", "rb") as file:
53 |         datasets = pickle.load(file)
54 | 
55 |     if dspy_method == "cot":
56 |         module_type = CoTSimplifiedBaleen
57 |     elif dspy_method == "cot_with_thought":
58 |         module_type = CoTWithThoughtSimplifiedBaleen
59 |     else:
60 |         raise Exception(f"Dspy method '{dspy_method}' is not valid")
61 | 
62 |     modules = {}
63 |     # define modules for each question type
64 |     for question_type in question_types:
65 |         print(f"TYPE: {question_type}")
66 |         evaluator = OpenToMEvaluatorDspy(model_name="(training set) compiled baleen")
67 | 
68 |         if dspy_optimizer == "bootstrap_fewshot_with_random_search":
69 |             optimizer = BootstrapFewShotWithRandomSearch(
70 |                 metric=evaluator.dspy_metric,
71 |                 num_candidate_programs=25,
72 |                 num_threads=1,
73 |                 teacher_settings=dict(lm=teacher_lm),
74 |             )
75 |             compiled_baleen = optimizer.compile(module_type(), trainset=datasets[question_type]["train"][:train_size])
76 |         elif dspy_optimizer == "signature_optimizer":
77 |             optimizer = SignatureOptimizer(
78 |                 metric=evaluator.dspy_metric,
79 |                 breadth=10,
80 |                 depth=3,
81 |                 init_temperature=1.4,
82 |                 verbose=True,
83 |                 track_stats=True,
84 |                 prompt_model=teacher_lm,
85 |             )
86 |             eval_kwargs = dict(num_threads=1, display_progress=True, display_table=0)
87 |             compiled_baleen = optimizer.compile(
88 |                 module_type(),
89 |                 devset=datasets[question_type]["train"][:train_size],
90 |                 eval_kwargs=eval_kwargs,
91 |             )
92 |         else:
93 |             raise Exception(f"Invalid dspy optimizer type: {dspy_optimizer}")
94 | 
95 |         modules[question_type] = compiled_baleen
96 |         time.sleep(10)
97 | 
98 |     uncompiled_baleen = CoTSimplifiedBaleen()  # regular cot is always the uncompiled baseline
99 | 
100 |     print("Beginning Evaluation")
101 |     for question_type in question_types:
102 |         compiled_baleen = modules[question_type]
103 | 
104 |         # Evaluation Procedure: Calculate the F1 Score for a randomly drawn batch of 50 questions 5 times and average the F1 Scores
105 |         batch_size = 50
106 |         num_batches = 5
107 | 
108 |         assert len(datasets[question_type]["test"]) >= batch_size * num_batches
109 |         test = datasets[question_type]["test"][: batch_size * num_batches]
110 |         test_sets = [test[i * batch_size : (i + 1) * batch_size] for i in range(num_batches)]  # disjoint batches of batch_size questions
111 | 
112 |         uncompiled_f1_scores = []
113 |         compiled_f1_scores = []
114 | 
115 |         for test in test_sets:
116 |             # Set up the `evaluate_on_opentom` function.
117 |             evaluate_on_opentom = Evaluate(devset=test, num_threads=1, display_progress=True, display_table=0)
118 | 
119 |             uncompiled_baleen_evaluator = OpenToMEvaluatorDspy(model_name="uncompiled_baleen")
120 |             evaluate_on_opentom(uncompiled_baleen, metric=uncompiled_baleen_evaluator.dspy_metric, display=True)
121 |             uncompiled_f1_scores.append(uncompiled_baleen_evaluator.f1_score()[question_type]["macro_averaged"])
122 | 
123 |             compiled_baleen_evaluator = OpenToMEvaluatorDspy(model_name="compiled_baleen")
124 |             evaluate_on_opentom(compiled_baleen, metric=compiled_baleen_evaluator.dspy_metric, display=True)
125 |             compiled_f1_scores.append(compiled_baleen_evaluator.f1_score()[question_type]["macro_averaged"])
126 | 
127 |         # overall f1 scores
128 |         uncompiled_mean_f1 = np.mean(uncompiled_f1_scores)
129 |         uncompiled_std_f1 = np.std(uncompiled_f1_scores)
130 | 
131 |         compiled_mean_f1 = np.mean(compiled_f1_scores)
132 |         compiled_std_f1 = np.std(compiled_f1_scores)
133 | 
134 |         run[f"evaluation/{question_type}/uncompiled/mean_macro_averaged_f1"] = uncompiled_mean_f1
135 |         run[f"evaluation/{question_type}/uncompiled/std_macro_averaged_f1"] = uncompiled_std_f1
136 |         run[f"evaluation/{question_type}/compiled/mean_macro_averaged_f1"] = compiled_mean_f1
137 |         run[f"evaluation/{question_type}/compiled/std_macro_averaged_f1"] = compiled_std_f1
138 | 
139 |         print(
140 |             f"Mean Macro Averaged F1 Scores (± std dev.) - {question_type} - Aggregated from {num_batches} batches of {batch_size} questions"
141 |         )
142 |         print(f"uncompiled: {uncompiled_mean_f1:.3f} ± {uncompiled_std_f1:.3f}")
143 |         print(f"compiled: {compiled_mean_f1:.3f} ± {compiled_std_f1:.3f}")
144 | 
145 |     dump_state(modules, "cot_modules.pkl")
146 |     run["cot_modules"].upload("cot_modules.pkl")
147 | 
148 | 
149 | if __name__ == "__main__":
150 |     parser = argparse.ArgumentParser(description="Run DSPY method.")
151 | 
152 |     # dspy arguments
153 |     parser.add_argument("experiment_title", type=str, help="Title of new experiment")
154 |     parser.add_argument("dspy_method", type=str, help="The DSPY method to run")
155 |     parser.add_argument("dspy_optimizer", type=str, help="The DSPY optimizer to use")
156 |     parser.add_argument("--student", default="gpt-3.5-turbo", type=str, help="The LLM to optimize prompts for")
157 |     parser.add_argument("--teacher", default=None, type=str, help="Teacher LLM for optimizing prompts. Defaults to Student LLM")
158 |     parser.add_argument("--train_size", default=50, type=int, help="Number of training examples to use for optimization")
159 |     parser.add_argument("--download_dataset", default=True, type=lambda s: s.lower() in ("true", "1", "yes"), help="Download dataset")  # type=bool would treat any non-empty string as True
160 |     parser.add_argument("--question_types", default=EVAL_QUESTION_TYPES, nargs="*", help="Question types. Defaults to all")
161 | 
162 |     args = parser.parse_args()
163 | 
164 |     # setup LLMs
165 |     student_lm = dspy.OpenAI(model=args.student, max_tokens=1000)
166 |     args.teacher = args.student if args.teacher is None else args.teacher
167 |     teacher_lm = dspy.OpenAI(model=args.teacher, max_tokens=1000)
168 |     dspy.settings.configure(lm=student_lm)
169 | 
170 |     # validate question types
171 |     question_types = args.question_types
172 |     assert all([question_type in EVAL_QUESTION_TYPES for question_type in question_types])
173 |     args.question_types = ", ".join(question_types)  # turn list into string for neptune logging
174 | 
175 |     # log run parameters
176 |     run["parameters"] = vars(args)  # log the argparse namespace as a dict
177 |     run["sys/name"] = args.experiment_title
178 | 
179 |     main(args.dspy_method, args.dspy_optimizer, args.download_dataset, question_types, teacher_lm, args.train_size)
--------------------------------------------------------------------------------
/opentom_evaluator.py:
--------------------------------------------------------------------------------
1 | # taken from https://github.com/seacowx/OpenToM/blob/main/src/evaluate/opentom_evaluator.py
2 | # modified for usability
3 | 
4 | from collections import defaultdict
5 | import json
6 | import traceback
7 | 
8 | 
9 | class OpenToMEvaluatorDspy:
10 | 
11 |     def __init__(self, model_name="") -> None:
12 |         self.true_positives = defaultdict(lambda: 0)
13 |         self.false_positives = defaultdict(lambda: 0)
14 |         self.false_negatives = defaultdict(lambda: 0)
15 |         self.model_name = model_name
16 | 
17 |     def dspy_metric(self, example, pred_answer, trace=None):
18 |         type = example.type
19 | 
20 |         eval_result = self.check_answer(example, pred_answer.answer)
21 |         if eval_result is None:  # Hm what is the correct value to return as a dspy metric when there's an invalid example?
22 | return None 23 | gt, pred = eval_result # ground truth answer class, predicted answer class 24 | 25 | # store positive/negative results by class so we can calculate the f1 scores later 26 | if gt == pred: 27 | self.true_positives[f"{type}_{pred}"] += 1 28 | else: 29 | self.false_positives[f"{type}_{pred}"] += 1 30 | self.false_negatives[f"{type}_{gt}"] += 1 31 | 32 | # print("done", example.type, gt, pred, example.answer, pred_answer.answer) 33 | 34 | return gt == pred 35 | 36 | # this method was added to make dspy evaluation easier 37 | def check_answer( 38 | self, 39 | example, 40 | pred_answer, 41 | cot_flag=False, 42 | perspective="all", 43 | ): 44 | mover, affected_char, eoi, original_place, move_to_place = json.loads(example.plot_info).values() 45 | 46 | cur_question_type = example.type 47 | question_content = example.question 48 | 49 | gt_answer = example.answer.strip() 50 | pred_answer = pred_answer.strip() 51 | 52 | # NOTE: evaluate based on the character 53 | if perspective == "observer": 54 | if mover in question_content and affected_char not in question_content: 55 | return None 56 | 57 | if mover in question_content and affected_char in question_content: 58 | question_tokens = question_content.replace("'s", "").replace(",", "").split() 59 | 60 | mover_idx = question_tokens.index(mover) 61 | affected_char_idx = question_tokens.index(affected_char) 62 | 63 | if mover_idx < affected_char_idx: 64 | return None 65 | 66 | elif perspective == "mover": 67 | if mover not in question_content and affected_char in question_content: 68 | return None 69 | 70 | if mover in question_content and affected_char in question_content: 71 | question_tokens = question_content.replace("'s", "").replace(",", "").split() 72 | 73 | mover_idx = question_tokens.index(mover) 74 | affected_char_idx = question_tokens.index(affected_char) 75 | 76 | if mover_idx > affected_char_idx: 77 | return None 78 | 79 | if cot_flag: 80 | pred_answer = self.parse_cot_answer(pred_answer) 81 | 82 | if cur_question_type == "location-fo-coarse": 83 | gt, pred = self.check_answer_for_cg_location(pred_answer, gt_answer) 84 | return gt, pred 85 | 86 | elif cur_question_type == "location-fo-fine": 87 | gt, pred = self.check_answer_for_fg_location(pred_answer, gt_answer, original_place, move_to_place) 88 | return gt, pred 89 | 90 | elif cur_question_type == "location-so-coarse": 91 | gt, pred = self.check_answer_for_cg_location(pred_answer, gt_answer) 92 | return gt, pred 93 | 94 | elif cur_question_type == "location-so-fine": 95 | gt, pred = self.check_answer_for_fg_location(pred_answer, gt_answer, original_place, move_to_place) 96 | return gt, pred 97 | 98 | elif cur_question_type == "multihop-fo": 99 | if "fullness" in question_content: 100 | gt, pred = self.check_fullness_answer(pred_answer, gt_answer) 101 | return gt, pred 102 | 103 | elif "accessibility" in question_content: 104 | if "|" in gt_answer: 105 | gt_answer = "equally accessible" 106 | 107 | if isinstance(gt_answer, list): 108 | gt_answer = [ele for ele in gt_answer if ele != "corrupted"] 109 | assert len(gt_answer) == 1 110 | gt_answer = gt_answer[0] 111 | 112 | gt, pred = self.check_accessibility_answer(pred_answer, gt_answer) 113 | return gt, pred 114 | 115 | elif cur_question_type == "multihop-so": 116 | if "fullness" in question_content: 117 | gt, pred = self.check_fullness_answer(pred_answer, gt_answer) 118 | return gt, pred 119 | 120 | elif "accessibility" in question_content: 121 | if "|" in gt_answer: 122 | gt_answer = "equally accessible" 123 | 124 | 
if isinstance(gt_answer, list):
125 |                     gt_answer = [ele for ele in gt_answer if ele != "corrupted"]
126 |                     assert len(gt_answer) == 1
127 |                     gt_answer = gt_answer[0]
128 | 
129 |                 gt, pred = self.check_accessibility_answer(pred_answer, gt_answer)
130 |                 return gt, pred
131 | 
132 |         elif cur_question_type == "attitude":
133 |             gt, pred = self.check_attitude_answer(pred_answer, gt_answer)
134 |             return gt, pred
135 | 
136 |     def f1_score(self):
137 |         true_positives = self.true_positives
138 |         false_positives = self.false_positives
139 |         false_negatives = self.false_negatives
140 |         f1_scores = defaultdict(lambda: {"by_class": {}})
141 | 
142 |         for _class in true_positives.keys() | false_positives.keys() | false_negatives.keys():
143 |             question_type, _ = _class.split("_")
144 |             class_true_positives = true_positives[_class]
145 |             class_false_positives = false_positives[_class]
146 |             class_false_negatives = false_negatives[_class]
147 |             class_precision = (
148 |                 class_true_positives / (class_true_positives + class_false_positives) if class_true_positives > 0.0 else 0.0
149 |             )  # avoid dividing by zero
150 |             class_recall = (
151 |                 class_true_positives / (class_true_positives + class_false_negatives) if class_true_positives > 0.0 else 0.0
152 |             )
153 |             class_f1_score = (
154 |                 (2 * class_precision * class_recall) / (class_precision + class_recall)
155 |                 if class_precision > 0.0 or class_recall > 0.0
156 |                 else 0.0
157 |             )
158 |             f1_scores[question_type]["by_class"][_class] = class_f1_score
159 | 
160 |         for question_type, type_f1_scores in f1_scores.items():
161 |             type_f1_scores = type_f1_scores["by_class"]
162 |             macro_averaged_f1_score = sum(list(type_f1_scores.values())) / len(type_f1_scores)
163 |             f1_scores[question_type]["macro_averaged"] = macro_averaged_f1_score
164 | 
165 |         return f1_scores
166 | 
167 |     # pretty print macro averaged f1 scores for each question type
168 |     def print_f1_results(self, round_decimal=2, print_header=False):
169 |         f1_scores = self.f1_score()
170 |         if print_header:
171 |             print("Macro Averaged F1 Scores by question type")
172 | 
173 |         print(self.model_name, end=" - ")
174 |         for question_type, type_f1_scores in f1_scores.items():
175 |             print(
176 |                 f'{question_type}: {round(type_f1_scores["macro_averaged"], ndigits=round_decimal+2) * 100}',
177 |                 end="\t",
178 |             )
179 |         print()
180 | 
181 |     @staticmethod
182 |     def remove_determinant(word: str) -> str:
183 |         determinants = ["a", "an", "the"]
184 |         for det in determinants:
185 |             if word.startswith(det + " "):  # require a trailing space so e.g. "attic" isn't stripped to "ttic"
186 |                 return word[len(det) :].strip()
187 |         return word
188 | 
189 |     @staticmethod
190 |     def compute_lexical_overlap(pred: str, location: str) -> float:
191 |         pred = pred.lower().replace("_", " ").replace("'s", "")
192 |         location = location.lower().replace("_", " ").replace("'s", "")
193 |         score = 0
194 |         pred = pred.replace(".", "").split()
195 |         location = location.split()
196 |         visited_word = []
197 | 
198 |         for word in pred:
199 |             if word in location and word not in visited_word:
200 |                 score += 1
201 |                 visited_word.append(word)
202 | 
203 |         return score / len(location)
204 | 
205 |     def parse_cot_answer(self, answer: str) -> str:
206 |         # cot typically generates the answer in the last sentence or paragraph
207 |         if "\n" in answer:
208 |             answer = answer.split("\n")[-1]
209 |         else:
210 |             answer = answer.split("Therefore")[-1]
211 |         return answer
212 | 
213 |     def check_answer_for_fg_location(self, prediction: str, answer: str, original_place: str, move_to_place: str) -> list:
214 | 
215 |         # truncate prediction as some of them contain explanations
216
| answer = self.remove_determinant(answer).lower() 217 | original_place = self.remove_determinant(original_place).lower() 218 | move_to_place = self.remove_determinant(move_to_place).lower() 219 | gt_label, pred_label = None, None 220 | original_place_score = self.compute_lexical_overlap(prediction, original_place) 221 | move_to_place_score = self.compute_lexical_overlap(prediction, move_to_place) 222 | 223 | if original_place_score == move_to_place_score: 224 | pred_label = 3 225 | if original_place_score > move_to_place_score: 226 | pred_label = 1 227 | elif original_place_score < move_to_place_score: 228 | pred_label = 2 229 | 230 | if original_place == answer: 231 | gt_label = 1 232 | elif move_to_place == answer: 233 | gt_label = 2 234 | 235 | return [gt_label, pred_label] 236 | 237 | def check_answer_for_cg_location(self, prediction: str, answer: str) -> list: 238 | prediction = prediction.lower() 239 | answer = answer.lower() 240 | 241 | if "no" in prediction and "yes" not in prediction: 242 | pred_label = 0 243 | elif "yes" in prediction and "no" not in prediction: 244 | pred_label = 1 245 | else: 246 | pred_label = -1 247 | 248 | if "no" in answer: 249 | gt_label = 0 250 | elif "yes" in answer: 251 | gt_label = 1 252 | 253 | return [gt_label, pred_label] 254 | 255 | def check_fullness_answer(self, prediction: str, answer: str) -> list: 256 | prediction = prediction.replace(".", "").lower() 257 | less_full_answer_list = ["less full", "emptier", "more empty"] 258 | more_full_answer_list = ["more full", "fuller"] 259 | pred_label, gt_label = None, None 260 | for less_full_ans in less_full_answer_list: 261 | if less_full_ans in prediction: 262 | pred_label = 1 263 | 264 | if not pred_label: 265 | for more_full_ans in more_full_answer_list: 266 | if more_full_ans in prediction: 267 | pred_label = 2 268 | 269 | if not pred_label: 270 | if "equally full" in prediction: 271 | pred_label = 3 272 | 273 | if not pred_label: 274 | pred_label = -1 # corrupted 275 | 276 | if answer == "less full": 277 | gt_label = 1 278 | elif answer == "more full": 279 | gt_label = 2 280 | elif answer == "equally full": 281 | gt_label = 3 282 | 283 | return [gt_label, pred_label] 284 | 285 | def check_accessibility_answer(self, prediction: str, answer: str) -> list: 286 | prediction = prediction.replace(".", "").lower() 287 | pred_label, gt_label = None, None 288 | if "more accessible" in prediction: 289 | pred_label = 1 290 | elif "less accessible" in prediction: 291 | pred_label = 2 292 | elif "equally accessible" in prediction: 293 | pred_label = 3 294 | else: 295 | pred_label = -1 # corrupted 296 | 297 | if answer == "more accessible": 298 | gt_label = 1 299 | elif answer == "less accessible": 300 | gt_label = 2 301 | else: 302 | gt_label = 3 303 | 304 | return [gt_label, pred_label] 305 | 306 | def check_attitude_answer(self, prediction: str, answer: str) -> list: 307 | prediction = prediction.lower() 308 | answer = answer.lower() 309 | answer_map = {"a": "positive", "b": "neutral", "c": "negative"} 310 | prediction_token = prediction.split("\n\n")[-1].split(":")[-1].split(".")[0].strip().lower() 311 | gt_label, pred_label = None, None 312 | 313 | if answer == "positive": 314 | gt_label = 1 315 | elif answer == "negative": 316 | gt_label = 2 317 | else: 318 | gt_label = 3 319 | 320 | try: 321 | prediction = answer_map[prediction_token] 322 | if prediction == "positive": 323 | pred_label = 1 324 | elif prediction == "negative": 325 | pred_label = 2 326 | else: 327 | pred_label = 3 328 | 329 | except: 330 | 
if "positive" in prediction_token and "negative" in prediction_token: 331 | pred_label = -1 332 | elif "positive" in prediction_token and "neutral" in prediction_token: 333 | pred_label = -1 334 | elif "neutral" in prediction_token and "negative" in prediction_token: 335 | pred_label = -1 336 | elif "positive" in prediction_token: 337 | pred_label = 1 338 | elif "negative" in prediction_token: 339 | pred_label = 2 340 | elif "neutral" in prediction_token: 341 | pred_label = 3 342 | else: 343 | pred_label = -1 344 | 345 | return [gt_label, pred_label] 346 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "opentom-dspy" 3 | version = "0.1.0" 4 | description = "" 5 | authors = ["vintro "] 6 | readme = "README.md" 7 | 8 | [tool.poetry.dependencies] 9 | python = "^3.11" 10 | jupyterlab = "^4.1.2" 11 | python-dotenv = "^1.0.1" 12 | dspy-ai = "^2.1.10" 13 | neptune = "^1.9.1" 14 | 15 | 16 | [build-system] 17 | requires = ["poetry-core"] 18 | build-backend = "poetry.core.masonry.api" 19 | --------------------------------------------------------------------------------