├── open_intent_classifier ├── __init__.py ├── utils.py ├── consts.py ├── embedder.py └── model.py ├── .bumpversion.cfg ├── requirements.txt ├── tests ├── test_dspy_classifier.py ├── test_openai_classifier.py ├── test_devices.py ├── test_flan_t5_base.py ├── test_embeddings.py ├── test_flan_t5_small.py └── test_smollm2.py ├── setup.py ├── .gitignore ├── README.md └── notebooks ├── dspy_tutorial.ipynb ├── smollm2.ipynb └── dspy_training.ipynb /open_intent_classifier/__init__.py: -------------------------------------------------------------------------------- 1 | VERSION = "0.0.8" 2 | 3 | -------------------------------------------------------------------------------- /.bumpversion.cfg: -------------------------------------------------------------------------------- 1 | [bumpversion] 2 | current_version = 0.0.8 3 | commit = True 4 | tag = True 5 | 6 | [bumpversion:file:open_intent_classifier/__init__.py] 7 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers 2 | torch 3 | sentencepiece 4 | scikit-learn 5 | sentence-transformers 6 | bump2version 7 | jupyter 8 | datasets 9 | pandas 10 | -------------------------------------------------------------------------------- /open_intent_classifier/utils.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | 4 | def join_labels(labels): 5 | prompt_options = "Options:\n" 6 | for label in labels: 7 | prompt_options += f"# {label} \n" 8 | return prompt_options 9 | 10 | def labels_to_str(labels: List[str]): 11 | return "%".join(labels) -------------------------------------------------------------------------------- /tests/test_dspy_classifier.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from open_intent_classifier.model import DSPyClassifier 4 | 5 | 6 | class TestDSPyClassifier(unittest.TestCase): 7 | 8 | def test_basic(self): 9 | classifier = DSPyClassifier() 10 | labels = ["Cancel subscription", "Refund request"] 11 | text = "I want to cancel my subscription" 12 | result = classifier.predict(text, labels) 13 | self.assertEquals("Cancel subscription".lower(), result.class_name.lower()) 14 | 15 | -------------------------------------------------------------------------------- /tests/test_openai_classifier.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from open_intent_classifier.model import OpenAiIntentClassifier 4 | 5 | 6 | class TestOpenAiIntentClassifier(unittest.TestCase): 7 | 8 | def test_prediction(self): 9 | classifier = OpenAiIntentClassifier(model_name="gpt-4o-mini") 10 | result = classifier.predict(text="I want to cancel subscription", 11 | labels=["cancel subscription", "cancel refund", "permission issue"]) 12 | 13 | self.assertEquals("cancel subscription", result.class_name) 14 | -------------------------------------------------------------------------------- /tests/test_devices.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | import torch.cuda 4 | 5 | from open_intent_classifier.model import IntentClassifier 6 | 7 | 8 | class TestDevices(TestCase): 9 | 10 | def test_default(self): 11 | self.model = IntentClassifier() 12 | 13 | def test_cuda(self): 14 | if torch.cuda.is_available(): 15 | self.model = IntentClassifier(device="cuda") 16 | 17 | def test_cpu(self): 18 | self.model = IntentClassifier(device="cpu") 19 | 20 | def test_wrong_device(self): 21 | with self.assertRaises(RuntimeError): 22 | self.model = IntentClassifier(device="wrong") -------------------------------------------------------------------------------- /tests/test_flan_t5_base.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | from open_intent_classifier.model import IntentClassifier 4 | 5 | 6 | class TestFlanT5Base(TestCase): 7 | 8 | def setUp(self) -> None: 9 | self.model = IntentClassifier() 10 | 11 | def test_sanity(self): 12 | output = self.model.predict("I want to cancel subscription", ["Cancel Subscription", "Refund"]) 13 | self.assertTrue("Cancel Subscription", output) 14 | 15 | def test_cuda(self): 16 | output = self.model.predict("I want a refund", ["Refund Request", "Issues", "Bug Report"]) 17 | self.assertTrue("Refund Request", output) 18 | 19 | def test_cpu(self): 20 | output = self.model.predict("I don't remember my password, can you please help me?", 21 | ["Cancel Subscription", "Refund Request", "Transfer Request", "Login Issues"]) 22 | self.assertTrue("Login Issues", output) 23 | -------------------------------------------------------------------------------- /tests/test_embeddings.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | from open_intent_classifier.embedder import StaticLabelsEmbeddingClassifier 4 | 5 | 6 | class TestEmbeddings(TestCase): 7 | def setUp(self) -> None: 8 | self.model = StaticLabelsEmbeddingClassifier(["Cancel Subscription", "Refund Requests"]) 9 | 10 | def test_top_1(self): 11 | labels, probabilities = self.model.predict("I want to cancel my subscription", n=1) 12 | self.assertTrue(1, len(labels)) 13 | self.assertTrue(1, len(probabilities)) 14 | self.assertTrue("Cancel Subscription", labels[0]) 15 | 16 | def test_top_2(self): 17 | labels, probabilities = self.model.predict("I want to cancel subscription", n=2) 18 | self.assertTrue(2, len(labels)) 19 | self.assertTrue(2, len(probabilities)) 20 | self.assertTrue("Cancel Subscription", labels[0]) 21 | self.assertTrue("Refund Requests", labels[1]) 22 | 23 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | from open_intent_classifier import VERSION 4 | from pathlib import Path 5 | this_directory = Path(__file__).parent 6 | long_description = (this_directory / "README.md").read_text() 7 | 8 | setup( 9 | name="open_intent_classifier", 10 | version=VERSION, 11 | packages=["open_intent_classifier"], 12 | author="Serj Smorodinsky", 13 | url="https://github.com/SerjSmor/open-intent-classifier", 14 | author_email="serjsmor@gmail.com", 15 | description="This library has two purposes: 1. allow to easily test semantic classification with open labels (not pre defined) for intent recognition. " 16 | "2. allow to experiment with different n-shot classification components.", 17 | install_requires=["transformers", "torch", "sentencepiece", "scikit-learn", "sentence-transformers", "pandas", 18 | "python-dotenv", "openai", "dspy", "accelerate"], 19 | long_description=long_description, 20 | long_description_content_type='text/markdown' 21 | ) 22 | -------------------------------------------------------------------------------- /tests/test_flan_t5_small.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | from open_intent_classifier.consts import INTENT_CLASSIFIER_80M_FLAN_T5_SMALL 4 | from open_intent_classifier.model import IntentClassifier 5 | 6 | 7 | class TestFlanT5Small(TestCase): 8 | 9 | def setUp(self) -> None: 10 | self.model = IntentClassifier(model_name=INTENT_CLASSIFIER_80M_FLAN_T5_SMALL, verbose=True) 11 | 12 | def test_sanity(self): 13 | output = self.model.predict("I want to cancel subscription", ["Cancel Subscription", "Refund"]) 14 | self.assertTrue("Cancel Subscription", output) 15 | 16 | def test_cuda(self): 17 | output = self.model.predict("I want a refund", ["Refund Request", "Issues", "Bug Report"]) 18 | self.assertTrue("Refund Request", output) 19 | 20 | def test_cpu(self): 21 | output = self.model.predict("I don't remember my password, can you please help me?", ["Cancel Subscription", 22 | "Refund Request", "Transfer Request", "Login Issues"]) 23 | self.assertTrue("Login Issues", output) -------------------------------------------------------------------------------- /open_intent_classifier/consts.py: -------------------------------------------------------------------------------- 1 | PROMPT_TEMPLATE = "Topic %% Customer: {text}.\nEND MESSAGE\nChoose one topic that matches customer's issue.\n {labels} \nClass name: " 2 | 3 | OPENAI_PROMPT_TEMPLATE = ''' 4 | You are an expert in customer service domain. You need to classify a customer message into one of 5 | the following classes: {labels} 6 | Please return json object with the following structure: {'{class_name: '', reasoning: ''}'} class name should not contain a number. 7 | {examples} 8 | Customer message: {text} % 9 | Answer: . 10 | ''' 11 | 12 | SMOLLM2_PROMPT_TEMPLATE = '''Input format: 13 | Text: the message of the customer. 14 | Intent list: all of the possible labels, delimeted by a %. 15 | Answer: the name of the intent from the list that matches the intent from the message. 16 | 17 | Answer format is one of the intent from the list. 18 | Example 19 | --- 20 | {examples} 21 | --- 22 | Text: {text} 23 | Intent List: {labels} 24 | Answer: 25 | ''' 26 | 27 | INTENT_CLASSIFIER_248M_FLAN_T5_BASE = "Serj/intent-classifier" 28 | INTENT_CLASSIFIER_80M_FLAN_T5_SMALL = "Serj/intent-classifier-flan-t5-small" 29 | 30 | DEFAULT_MAX_LENGTH = 512 31 | DEFAULT_TRUNCATION = True 32 | 33 | SMOLLM2_1_7B = "HuggingFaceTB/SmolLM2-1.7B-Instruct" 34 | SMOLLM2_360M = "HuggingFaceTB/SmolLM2-360M-Instruct" -------------------------------------------------------------------------------- /open_intent_classifier/embedder.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import numpy as np 4 | from sentence_transformers import SentenceTransformer 5 | from sklearn.metrics.pairwise import cosine_similarity 6 | 7 | BAII_SMALL_V1_5 = "BAAI/bge-small-en-v1.5" 8 | 9 | class StaticLabelsEmbeddingClassifier: 10 | def __init__(self, labels: List[str], model_name: str = BAII_SMALL_V1_5): 11 | self.model = SentenceTransformer(model_name) 12 | self.embedded_labels = self.model.encode(labels, show_progress_bar=True) 13 | self.labels = labels 14 | 15 | def top_n(self, text: str, n: int, embeddings: np.array) -> (np.array, np.array): 16 | text_embedding = self.model.encode([text]) 17 | # for matrix multiplication we need the shape to be NXM MXN vector for example: vectpr 1X384, matrix 384X3 18 | angles = cosine_similarity(text_embedding, embeddings).squeeze() 19 | # https://stackoverflow.com/a/6910672 20 | # ::-1 reverses this list, # -n: top N 21 | sorted_indices = angles.argsort()[-n:][::-1] 22 | return sorted_indices, angles[sorted_indices] 23 | 24 | def predict(self, text: str, n=1) -> Tuple[List[str], np.array]: 25 | top_n_indices, top_n_angles = self.top_n(text, n=n, embeddings=self.embedded_labels) 26 | top_n_labels = [self.labels[i] for i in top_n_indices] 27 | 28 | return top_n_labels, top_n_angles 29 | 30 | -------------------------------------------------------------------------------- /tests/test_smollm2.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from open_intent_classifier.model import SmolLm2Classifier, ClassificationExample 4 | 5 | 6 | class TestSmolLm2(unittest.TestCase): 7 | 8 | def test_basic(self): 9 | labels = ["Cancel subscription", "Refund request"] 10 | text = "I want to cancel my subscription" 11 | 12 | c = SmolLm2Classifier() 13 | result = c.predict(text, labels) 14 | self.assertEquals("Cancel subscription".lower(), result.class_name.lower()) 15 | 16 | def test_few_shot(self): 17 | labels = ["Order flight ticket", "Abbrevations", "Fare code", "Cheap price", "Meals questions", "Seating questions"] 18 | text = "I want to get a flight ticket to NYC asap" 19 | 20 | example_flight = ClassificationExample("Do you have an open ticket to a flight to Chicago?", 21 | "Order flight ticket", labels) 22 | 23 | example_price = ClassificationExample("What is the cheapest route to London?", "Price questions", labels) 24 | 25 | c = SmolLm2Classifier(verbose=True, few_shot_examples=[example_flight, example_price]) 26 | 27 | result = c.predict(text, labels) 28 | self.assertEquals("Order flight ticket".lower(), result.class_name.lower()) 29 | 30 | 31 | def test_hierarchy(self): 32 | labels = ["flight ticket", "price", "meal", "seats"] 33 | text = "I want to get a flight ticket to NYC asap" 34 | 35 | c = SmolLm2Classifier(verbose=True) 36 | example_flight = ClassificationExample("Do you have an open ticket to a flight to Chicago?", "flight ticket", 37 | labels) 38 | 39 | example_price = ClassificationExample("What is the price for a flight to London?", "price", labels) 40 | 41 | result = c.predict(text, labels, few_shot_examples=[example_flight, example_price]) 42 | self.assertEquals("flight ticket".lower(), result.class_name.lower()) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Project specific 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | # For a library or package, you might want to ignore these files since the code is 89 | # intended to run in multiple environments; otherwise, check them in: 90 | # .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # poetry 100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 101 | # This is especially recommended for binary packages to ensure reproducibility, and is more 102 | # commonly ignored for libraries. 103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 104 | #poetry.lock 105 | 106 | # pdm 107 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 108 | #pdm.lock 109 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 110 | # in version control. 111 | # https://pdm.fming.dev/#use-with-ide 112 | .pdm.toml 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | .idea/ 163 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Open Intent Classification 2 | Closed intent classification uses a set of predefined labels to identify an intent. 3 | In comparison, open intent classification allows you to define as many labels you want, without fine tuning the model. 4 | 5 | This project implements different components that support open intent classification such as an embedder, a T5 based fine tuned model for intent classification and a verbalizer. 6 | If you are interested in finer detailes you can read my [blog post](https://medium.com/80-20-hacking-ai/components-that-optimize-n-shot-text-classification-f574184e0b81). 7 | 8 | The goal of this library is to enable you test your assumptions about your data as fast as possible and to be a one stop shop for everything "classification like", similarly to how [Bertopic](https://maartengr.github.io/BERTopic/index.html) is for clustering. 9 | 10 | ## Why should you use this? 11 | 1. You are researching nlp classification problems and want to test different embeddings, verbalizers and components with plug-and-play feel 12 | 2. You want to detect semantically user intents in text but either don't want to commit to pre-defined classes OR just want to test out the fastest way to classify text other than through an LLM 13 | 14 | 15 | [!IMPORTANT] 16 | > open-intent-classification project is in Alpha stage. 17 | > 1. Expect API changes 18 | > 2. Milage may vary. Quality of classifiers have been tested on Atis and Banking77 19 | 20 | ## Installation 21 | `$ pip install open-intent-classifier` 22 | 23 | ## Usage 24 | A full example is under [Atis Notebook](https://github.com/SerjSmor/open-intent-classifier/blob/main/notebooks/atis_example.ipynb) 25 | 26 | ### T5 Based Intent Classification 27 | ````python 28 | from open_intent_classifier.model import IntentClassifier 29 | model = IntentClassifier() 30 | labels = ["Cancel Subscription", "Refund Requests", "Broken Item", "And More..."] 31 | text = "I don't want to continue this subscription" 32 | predicted_label = model.predict(text, labels) 33 | ```` 34 | 35 | By default, the IntentClassifier is loading a small model with 80M parameters. 36 | 37 | For higher accuracy you can initialize the model with: 38 | ```python 39 | from open_intent_classifier.model import IntentClassifier 40 | from open_intent_classifier.consts import INTENT_CLASSIFIER_248M_FLAN_T5_BASE 41 | model = IntentClassifier(INTENT_CLASSIFIER_248M_FLAN_T5_BASE) 42 | ``` 43 | This will increase model latency as well. 44 | 45 | 46 | ### Embeddings Based Classification 47 | ```python 48 | from open_intent_classifier.embedder import StaticLabelsEmbeddingClassifier 49 | labels = ["Cancel Subscription", "Refund Requests", "Broken Item", "And More..."] 50 | text = "I don't want to continue this subscription" 51 | embeddings_classifier = StaticLabelsEmbeddingClassifier(labels) 52 | predicted_label = embeddings_classifier.predict(text) 53 | 54 | ``` 55 | 56 | ### LLM Based Classification 57 | Using LLM for classification is a viable option that sometimes provides the highest quality. 58 | Currently we have implemented only Open AI based LLMs. 59 | 60 | ```python 61 | from open_intent_classifier.model import OpenAiIntentClassifier 62 | labels = ["Cancel Subscription", "Refund Requests", "Broken Item", "And More..."] 63 | text = "I don't want to continue this subscription" 64 | model_name = "gpt-4o-mini" 65 | classifier = OpenAiIntentClassifier(model_name) 66 | result = classifier.predict(text=text, labels=labels) 67 | ``` 68 | 69 | ### LLM Few Shot Based Classification 70 | ```python 71 | from open_intent_classifier.model import OpenAiIntentClassifier, ClassificationExample 72 | labels = ["Cancel Subscription", "Refund Requests", "Broken Item", "And More..."] 73 | text = "I don't want to continue this subscription" 74 | model_name = "gpt-4o-mini" 75 | example = ClassificationExample(text="I want to abort my account", intent_labels=labels, intent="Cancel Subscription") 76 | classifier = OpenAiIntentClassifier(model_name, few_shot_examples=[example]) 77 | result = classifier.predict(text=text, labels=labels) 78 | ``` 79 | 80 | ### DSPy based classification 81 | Here's an example of using DSPy based classifier for both educational purposes and potential increase in quality. 82 | ```python 83 | from open_intent_classifier.model import DSPyClassifier 84 | classifier = DSPyClassifier() 85 | labels = ["Cancel subscription", "Refund request"] 86 | text = "I want to cancel my subscription" 87 | result = classifier.predict(text, labels) 88 | ``` 89 | 90 | ### SmolLM2 based classification 91 | Here's an example of using SmolLM2 based classifier 92 | ```python 93 | from open_intent_classifier.model import SmolLm2Classifier 94 | classifier = SmolLm2Classifier() 95 | labels = ["Cancel subscription", "Refund request"] 96 | text = "I want to cancel my subscription" 97 | result = classifier.predict(text, labels) 98 | ``` 99 | 100 | ### SmolLM2 Few Shot Based Classification 101 | from open_intent_classifier.model import SmolLm2Classifier 102 | ```python 103 | from open_intent_classifier.model import SmolLm2Classifier, ClassificationExample 104 | labels = ["Cancel subscription", "Refund request"] 105 | example = ClassificationExample(text="I want to abort my account", intent_labels=labels, intent="Cancel Subscription") 106 | classifier = SmolLm2Classifier(few_shot_examples=[example]) 107 | text = "I want to cancel my subscription" 108 | result = classifier.predict(text, labels) 109 | ``` 110 | 111 | 112 | ## Training the T5 base classifier 113 | The details of training of the classifier is in another repository. I have separated training from inference in order to allow each repository to be focused and extended. 114 | 115 | You can read about the training in the training repo: https://github.com/SerjSmor/intent_classification 116 | 117 | # Roadmap 118 | 119 | - [x] Add LLM based classification 120 | - [x] Add DSPy based classification 121 | - [x] Add small language models as classifiers 122 | - [x] Add few shot interface 123 | - [ ] Add Atis and Banking77 comparisons with all classifiers 124 | - [ ] Add embeddings filtering stage for classifiers 125 | - [ ] Add multithreading for LLM based classifiers 126 | - [ ] Add an option to ensemble embeddings and T5 (and additional models) 127 | - [ ] Create a recommender for fine-tuning 128 | -------------------------------------------------------------------------------- /open_intent_classifier/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | from typing import List 4 | import logging 5 | 6 | import dspy 7 | from transformers import T5ForConditionalGeneration, T5Tokenizer, AutoTokenizer, AutoModelForCausalLM 8 | from openai import OpenAI 9 | from dotenv import load_dotenv 10 | 11 | from open_intent_classifier.consts import INTENT_CLASSIFIER_80M_FLAN_T5_SMALL, DEFAULT_MAX_LENGTH, DEFAULT_TRUNCATION, \ 12 | PROMPT_TEMPLATE, OPENAI_PROMPT_TEMPLATE, SMOLLM2_1_7B, SMOLLM2_PROMPT_TEMPLATE 13 | 14 | from open_intent_classifier.utils import join_labels, labels_to_str 15 | 16 | CUDA = "cuda" 17 | 18 | GPT_4O_MINI = 'gpt-4o-mini' 19 | 20 | TEXT_PLACEHOLDER = "{text}" 21 | LABELS_PLACEHOLDER = "{labels}" 22 | EXAMPLES_PLACEHOLDER = "{examples}" 23 | 24 | # Create a logger with the filename 25 | logger = logging.getLogger(__name__) 26 | logging.basicConfig(level=logging.INFO, format='%(name)s - %(levelname)s - %(message)s') 27 | 28 | 29 | @dataclass 30 | class ClassificationResult: 31 | class_name: str 32 | reasoning: str 33 | 34 | 35 | @dataclass 36 | class ClassificationExample: 37 | text: str 38 | intent: str 39 | intent_labels: List[str] 40 | 41 | def __str__(self): 42 | return f''' 43 | ----- 44 | Example 45 | Text: {self.text} 46 | Intents list: {labels_to_str(self.intent_labels)} 47 | Answer: {self.intent} 48 | ----- 49 | ''' 50 | 51 | 52 | DEFAULT_EXAMPLE = ClassificationExample( 53 | text="I want to cancel subscription", 54 | intent="Cancel Subscription", 55 | intent_labels=["Refund Request", "Delivery Late", "Cancel Subscription"] 56 | ) 57 | 58 | 59 | class Classifier: 60 | def predict(self, text: str, labels: List[str], **kwargs) -> ClassificationResult: 61 | raise NotImplementedError() 62 | 63 | 64 | class Classification(dspy.Signature): 65 | """Classify the customer message into one of the intent labels. 66 | The output should be only the predicted class as a single intent label.""" 67 | 68 | customer_message = dspy.InputField(desc="Customer message during customer service interaction") 69 | intent_labels = dspy.InputField(desc="Labels that represent customer intent") 70 | intent_class = dspy.OutputField(desc="a label best matching customer's intent ") 71 | 72 | 73 | class DSPyClassifier(dspy.Module, Classifier): 74 | def __init__(self, model_name=GPT_4O_MINI): 75 | super().__init__() 76 | self.generate_answer = dspy.ChainOfThought(Classification) 77 | lm = dspy.OpenAI(model=model_name) 78 | dspy.settings.configure(lm=lm) 79 | 80 | def forward(self, customer_message: str, labels: str): 81 | return self.generate_answer(customer_message=customer_message, labels=labels) 82 | 83 | def predict(self, text: str, labels: List[str], **kwargs) -> ClassificationResult: 84 | labels = join_labels(labels) 85 | pred = self.forward(customer_message=text, labels=labels) 86 | return ClassificationResult(pred.intent_class, pred.rationale) 87 | 88 | 89 | class OpenAiIntentClassifier(Classifier): 90 | def __init__(self, model_name: str, openai_api_key: str = "", few_shot_examples: List[ClassificationExample] = None): 91 | if not openai_api_key: 92 | load_dotenv() 93 | openai_api_key = os.getenv("OPENAI_API_KEY") 94 | self.client = OpenAI(api_key=openai_api_key) 95 | self.model_name = model_name 96 | self.few_shot_examples = [] if few_shot_examples is None else [] 97 | 98 | def predict(self, text: str, labels: List[str], **kwargs) -> ClassificationResult: 99 | joined_labels = join_labels(labels) 100 | 101 | if len(self.few_shot_examples) == 0: 102 | self.few_shot_examples.append(DEFAULT_EXAMPLE) 103 | examples_str = "" 104 | for example in self.few_shot_examples: 105 | examples_str += str(example) 106 | 107 | prompt = OPENAI_PROMPT_TEMPLATE.replace(LABELS_PLACEHOLDER, joined_labels).replace(TEXT_PLACEHOLDER, text).\ 108 | replace(EXAMPLES_PLACEHOLDER, examples_str) 109 | 110 | completion = self.client.chat.completions.create( 111 | model=self.model_name, 112 | messages=[ 113 | {"role": "user", "content": prompt} 114 | ], 115 | response_format={"type": "json_object"} 116 | ) 117 | 118 | response_dict = eval(completion.choices[0].message.content) 119 | return ClassificationResult(response_dict["class_name"], response_dict["reasoning"]) 120 | 121 | 122 | class IntentClassifier: 123 | def __init__(self, model_name: str = INTENT_CLASSIFIER_80M_FLAN_T5_SMALL, device: str = None, 124 | verbose: bool = False): 125 | self.model = T5ForConditionalGeneration.from_pretrained(model_name).to(device) 126 | self.tokenizer = T5Tokenizer.from_pretrained(model_name) 127 | self.device = device 128 | if verbose: 129 | logger.setLevel(logging.DEBUG) 130 | 131 | def _build_prompt(self, text: str, labels: List[str], prompt_template: str = PROMPT_TEMPLATE): 132 | prompt_options = join_labels(labels) 133 | 134 | # first replace {labels} this way we know that {text} can't be misused with tokens that can overlap with {text} 135 | prompt = prompt_template.replace(LABELS_PLACEHOLDER, prompt_options) 136 | prompt = prompt.replace(TEXT_PLACEHOLDER, text) 137 | 138 | return prompt 139 | 140 | def predict(self, text: str, labels: List[str], **kwargs) -> ClassificationResult: 141 | prompt = self._build_prompt(text, labels) 142 | max_length = kwargs.get("max_length", DEFAULT_MAX_LENGTH) 143 | truncation = kwargs.get("truncation", DEFAULT_TRUNCATION) 144 | input_ids = self.tokenizer.encode(prompt, return_tensors="pt", max_length=max_length, 145 | truncation=truncation).to(self.device) 146 | # Generate the output 147 | output = self.model.generate(input_ids, **kwargs) 148 | # Decode the output tokens 149 | decoded_output = self.tokenizer.decode(output[0], skip_special_tokens=True) 150 | logger.debug(f"Full prompt: {prompt}") 151 | logger.debug(f"Decoded output: {decoded_output}") 152 | return ClassificationResult(class_name=decoded_output, reasoning="") 153 | 154 | 155 | class SmolLm2Classifier(Classifier): 156 | def __init__(self, model_name: str = SMOLLM2_1_7B, device: str = CUDA, verbose: bool = False, 157 | few_shot_examples: List[ClassificationExample] = None): 158 | tokenizer = AutoTokenizer.from_pretrained(model_name) 159 | model = AutoModelForCausalLM.from_pretrained(model_name).to(device) 160 | self.tokenizer = tokenizer 161 | self.model = model 162 | self.device = device 163 | self.verbose = verbose 164 | self.few_shot_examples = [] if few_shot_examples is None else few_shot_examples 165 | 166 | def predict(self, text: str, labels: List[str], **kwargs) -> ClassificationResult: 167 | labels_str = "%".join(labels) 168 | 169 | if len(self.few_shot_examples) == 0: 170 | self.few_shot_examples.append(DEFAULT_EXAMPLE) 171 | examples_str = "" 172 | for example in self.few_shot_examples: 173 | examples_str += str(example) 174 | 175 | prompt = SMOLLM2_PROMPT_TEMPLATE.replace(TEXT_PLACEHOLDER, text).replace(LABELS_PLACEHOLDER, labels_str).replace(EXAMPLES_PLACEHOLDER, examples_str) 176 | if self.verbose: 177 | logger.info(prompt) 178 | 179 | messages = [{"role": "system", 180 | "content": "You are a customer service expert. Your goal is to predict what is the intent of the user from a predfined list"}, 181 | {"role": "user", "content": prompt}] 182 | 183 | input_text = self.tokenizer.apply_chat_template(messages, tokenize=False) 184 | 185 | inputs = self.tokenizer.encode(input_text, return_tensors="pt").to(self.device) 186 | outputs = self.model.generate(inputs, max_new_tokens=350) 187 | output = self.tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True) 188 | 189 | if self.verbose: 190 | logger.info(output) 191 | 192 | last_answer_occurrence = output.rsplit("Answer: ", 2)[-1] 193 | return ClassificationResult(last_answer_occurrence, "") 194 | 195 | 196 | -------------------------------------------------------------------------------- /notebooks/dspy_tutorial.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "da5c9e91-65f3-4c06-91f6-d7863f129848", 6 | "metadata": {}, 7 | "source": [ 8 | "## Basic Example" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "5f476789-3bb1-4fc9-9521-8650c0aef2d0", 14 | "metadata": {}, 15 | "source": [ 16 | "### First define the signature" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 21, 22 | "id": "e6d35f00-46b6-4d0e-93d6-28caa10f1b40", 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "import dspy \n", 27 | "from open_intent_classifier.utils import join_labels\n", 28 | "\n", 29 | "class Classification(dspy.Signature):\n", 30 | " \"\"\"Classify the customer message into one of the intent labels.\n", 31 | " The output should be only the predicted class as a single intent label.\"\"\"\n", 32 | "\n", 33 | " customer_message = dspy.InputField(desc=\"Customer message during customer service interaction\")\n", 34 | " intent_labels = dspy.InputField(desc=\"Labels that represent customer intent\")\n", 35 | " intent_class = dspy.OutputField(desc=\"a label best matching customer's intent \")" 36 | ] 37 | }, 38 | { 39 | "cell_type": "markdown", 40 | "id": "09b42d77-5ef4-420b-9ae4-cafa3ca81c28", 41 | "metadata": {}, 42 | "source": [ 43 | "### Define the languae model " 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 24, 49 | "id": "43e0debd-85b0-4d6e-9e13-3b6319165e25", 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "lm_mini = dspy.OpenAI(model='gpt-4o-mini')\n", 54 | "dspy.settings.configure(lm=lm_mini)" 55 | ] 56 | }, 57 | { 58 | "cell_type": "markdown", 59 | "id": "17c8e00e-42a3-4650-adf6-99b3c096af15", 60 | "metadata": {}, 61 | "source": [ 62 | "### Predict" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 25, 68 | "id": "86e29e2a-f875-4614-9518-8129c3ff8a4e", 69 | "metadata": {}, 70 | "outputs": [ 71 | { 72 | "data": { 73 | "text/plain": [ 74 | "Prediction(\n", 75 | " intent_class='Cancel subscription'\n", 76 | ")" 77 | ] 78 | }, 79 | "execution_count": 25, 80 | "metadata": {}, 81 | "output_type": "execute_result" 82 | } 83 | ], 84 | "source": [ 85 | "predict = dspy.Predict(Classification)\n", 86 | "labels = join_labels([\"Cancel subscription\", \"Refund request\"])\n", 87 | "labels = \" \".join([\"Cancel subscription\", \"Refund request\"])\n", 88 | "text = \"I want to cancel my subscription\"\n", 89 | "pred = predict(customer_message=text, intent_labels=labels)\n", 90 | "pred" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 26, 96 | "id": "4b8d7f04-911d-4671-89a3-3105649fdd68", 97 | "metadata": {}, 98 | "outputs": [ 99 | { 100 | "name": "stdout", 101 | "output_type": "stream", 102 | "text": [ 103 | "\n", 104 | "\n", 105 | "\n", 106 | "Classify the customer message into one of the intent labels.\n", 107 | "The output should be only the predicted class as a single intent label.\n", 108 | "\n", 109 | "---\n", 110 | "\n", 111 | "Follow the following format.\n", 112 | "\n", 113 | "Customer Message: Customer message during customer service interaction\n", 114 | "Intent Labels: Labels that represent customer intent\n", 115 | "Intent Class: a label best matching customer's intent\n", 116 | "\n", 117 | "---\n", 118 | "\n", 119 | "Customer Message: I want to cancel my subscription\n", 120 | "Intent Labels: Cancel subscription Refund request\n", 121 | "Intent Class:\u001b[32m Cancel subscription\u001b[0m\n", 122 | "\n", 123 | "\n", 124 | "\n" 125 | ] 126 | }, 127 | { 128 | "data": { 129 | "text/plain": [ 130 | "\"\\n\\n\\nClassify the customer message into one of the intent labels.\\nThe output should be only the predicted class as a single intent label.\\n\\n---\\n\\nFollow the following format.\\n\\nCustomer Message: Customer message during customer service interaction\\nIntent Labels: Labels that represent customer intent\\nIntent Class: a label best matching customer's intent\\n\\n---\\n\\nCustomer Message: I want to cancel my subscription\\nIntent Labels: Cancel subscription Refund request\\nIntent Class:\\x1b[32m Cancel subscription\\x1b[0m\\n\\n\\n\"" 131 | ] 132 | }, 133 | "execution_count": 26, 134 | "metadata": {}, 135 | "output_type": "execute_result" 136 | } 137 | ], 138 | "source": [ 139 | "lm_mini.inspect_history(n=1)" 140 | ] 141 | }, 142 | { 143 | "cell_type": "markdown", 144 | "id": "c7dbd608-7cb3-4d3e-987e-853aa23e188a", 145 | "metadata": {}, 146 | "source": [ 147 | "## Using Chain Of Thought (CoT)" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": 27, 153 | "id": "61c51226-f2a0-4bd8-810e-95be5c43ff40", 154 | "metadata": {}, 155 | "outputs": [ 156 | { 157 | "name": "stderr", 158 | "output_type": "stream", 159 | "text": [ 160 | "httpx - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n" 161 | ] 162 | }, 163 | { 164 | "data": { 165 | "text/plain": [ 166 | "Prediction(\n", 167 | " rationale='produce the intent_class. We first identify that the customer is expressing a desire to terminate an ongoing service, which directly aligns with the intent of canceling a subscription. There is no mention of seeking a refund or any other request, so the primary focus is on the cancellation itself.',\n", 168 | " intent_class='Cancel subscription'\n", 169 | ")" 170 | ] 171 | }, 172 | "execution_count": 27, 173 | "metadata": {}, 174 | "output_type": "execute_result" 175 | } 176 | ], 177 | "source": [ 178 | "predict_with_chain_of_thought = dspy.ChainOfThought(Classification)\n", 179 | "pred = predict_with_chain_of_thought(customer_message=text, intent_labels=labels)\n", 180 | "pred" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": 28, 186 | "id": "52dac266-0db6-4b5c-9c8b-cf930175d9eb", 187 | "metadata": {}, 188 | "outputs": [ 189 | { 190 | "data": { 191 | "text/plain": [ 192 | "'Cancel subscription'" 193 | ] 194 | }, 195 | "execution_count": 28, 196 | "metadata": {}, 197 | "output_type": "execute_result" 198 | } 199 | ], 200 | "source": [ 201 | "pred.intent_class" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": 29, 207 | "id": "3534fb75-fb66-4a4b-887d-6fedb07bdf10", 208 | "metadata": {}, 209 | "outputs": [ 210 | { 211 | "data": { 212 | "text/plain": [ 213 | "'produce the intent_class. We first identify that the customer is expressing a desire to terminate an ongoing service, which directly aligns with the intent of canceling a subscription. There is no mention of seeking a refund or any other request, so the primary focus is on the cancellation itself.'" 214 | ] 215 | }, 216 | "execution_count": 29, 217 | "metadata": {}, 218 | "output_type": "execute_result" 219 | } 220 | ], 221 | "source": [ 222 | "pred.rationale" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": 30, 228 | "id": "73a919eb-8e23-440a-b26c-b66643d4f7ea", 229 | "metadata": {}, 230 | "outputs": [ 231 | { 232 | "name": "stdout", 233 | "output_type": "stream", 234 | "text": [ 235 | "\n", 236 | "\n", 237 | "\n", 238 | "Classify the customer message into one of the intent labels.\n", 239 | "The output should be only the predicted class as a single intent label.\n", 240 | "\n", 241 | "---\n", 242 | "\n", 243 | "Follow the following format.\n", 244 | "\n", 245 | "Customer Message: Customer message during customer service interaction\n", 246 | "\n", 247 | "Intent Labels: Labels that represent customer intent\n", 248 | "\n", 249 | "Reasoning: Let's think step by step in order to ${produce the intent_class}. We ...\n", 250 | "\n", 251 | "Intent Class: a label best matching customer's intent\n", 252 | "\n", 253 | "---\n", 254 | "\n", 255 | "Customer Message: I want to cancel my subscription\n", 256 | "\n", 257 | "Intent Labels: Cancel subscription Refund request\n", 258 | "\n", 259 | "Reasoning: Let's think step by step in order to\u001b[32m produce the intent_class. We first identify that the customer is expressing a desire to terminate an ongoing service, which directly aligns with the intent of canceling a subscription. There is no mention of seeking a refund or any other request, so the primary focus is on the cancellation itself.\n", 260 | "\n", 261 | "Intent Class: Cancel subscription\u001b[0m\n", 262 | "\n", 263 | "\n", 264 | "\n" 265 | ] 266 | }, 267 | { 268 | "data": { 269 | "text/plain": [ 270 | "\"\\n\\n\\nClassify the customer message into one of the intent labels.\\nThe output should be only the predicted class as a single intent label.\\n\\n---\\n\\nFollow the following format.\\n\\nCustomer Message: Customer message during customer service interaction\\n\\nIntent Labels: Labels that represent customer intent\\n\\nReasoning: Let's think step by step in order to ${produce the intent_class}. We ...\\n\\nIntent Class: a label best matching customer's intent\\n\\n---\\n\\nCustomer Message: I want to cancel my subscription\\n\\nIntent Labels: Cancel subscription Refund request\\n\\nReasoning: Let's think step by step in order to\\x1b[32m produce the intent_class. We first identify that the customer is expressing a desire to terminate an ongoing service, which directly aligns with the intent of canceling a subscription. There is no mention of seeking a refund or any other request, so the primary focus is on the cancellation itself.\\n\\nIntent Class: Cancel subscription\\x1b[0m\\n\\n\\n\"" 271 | ] 272 | }, 273 | "execution_count": 30, 274 | "metadata": {}, 275 | "output_type": "execute_result" 276 | } 277 | ], 278 | "source": [ 279 | "lm_mini.inspect_history(n=1)" 280 | ] 281 | }, 282 | { 283 | "cell_type": "markdown", 284 | "id": "f387a663-b0df-40f8-9ea7-f88d3f03edf8", 285 | "metadata": {}, 286 | "source": [ 287 | "## Example of a bug - not specifying output properly" 288 | ] 289 | }, 290 | { 291 | "cell_type": "markdown", 292 | "id": "92ec67e1-294e-4cf8-a0e7-babcb3299b50", 293 | "metadata": {}, 294 | "source": [ 295 | "### My instruction: \n", 296 | "\"\"\"Classify the customer message into one of the intent labels.\"\"\" \n", 297 | "### is not specific enough" 298 | ] 299 | }, 300 | { 301 | "cell_type": "code", 302 | "execution_count": 31, 303 | "id": "1461b8ea-13f1-4c5f-ad4a-e48e33c1678d", 304 | "metadata": {}, 305 | "outputs": [], 306 | "source": [ 307 | "import dspy \n", 308 | "from open_intent_classifier.utils import join_labels\n", 309 | "\n", 310 | "class Classification(dspy.Signature):\n", 311 | " \"\"\"Classify the customer message into one of the intent labels.\"\"\"\n", 312 | "\n", 313 | " customer_message = dspy.InputField(desc=\"Customer message during customer service interaction\")\n", 314 | " intent_labels = dspy.InputField(desc=\"Labels that represent customer intent\")\n", 315 | " intent_class = dspy.OutputField(desc=\"a label best matching customer's intent \")" 316 | ] 317 | }, 318 | { 319 | "cell_type": "markdown", 320 | "id": "55bdc918-e918-48b3-8aaf-22c26f35091f", 321 | "metadata": {}, 322 | "source": [ 323 | "### Intent class is the whole message, instead of being the class" 324 | ] 325 | }, 326 | { 327 | "cell_type": "code", 328 | "execution_count": 17, 329 | "id": "ec7d46bd-d247-4eac-afa9-92b168f34f3a", 330 | "metadata": {}, 331 | "outputs": [ 332 | { 333 | "name": "stderr", 334 | "output_type": "stream", 335 | "text": [ 336 | "httpx - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n" 337 | ] 338 | }, 339 | { 340 | "data": { 341 | "text/plain": [ 342 | "Prediction(\n", 343 | " intent_class='Customer Message: I want to cancel my subscription \\nIntent Labels: Cancel subscription \\nIntent Class: Cancel subscription'\n", 344 | ")" 345 | ] 346 | }, 347 | "execution_count": 17, 348 | "metadata": {}, 349 | "output_type": "execute_result" 350 | } 351 | ], 352 | "source": [ 353 | "predict = dspy.Predict(Classification)\n", 354 | "labels = \" \".join([\"Cancel subscription\", \"Refund request\"])\n", 355 | "text = \"I want to cancel my subscription\"\n", 356 | "pred = predict(customer_message=text, intent_labels=labels)\n", 357 | "pred" 358 | ] 359 | } 360 | ], 361 | "metadata": { 362 | "kernelspec": { 363 | "display_name": "Python 3 (ipykernel)", 364 | "language": "python", 365 | "name": "python3" 366 | }, 367 | "language_info": { 368 | "codemirror_mode": { 369 | "name": "ipython", 370 | "version": 3 371 | }, 372 | "file_extension": ".py", 373 | "mimetype": "text/x-python", 374 | "name": "python", 375 | "nbconvert_exporter": "python", 376 | "pygments_lexer": "ipython3", 377 | "version": "3.10.12" 378 | } 379 | }, 380 | "nbformat": 4, 381 | "nbformat_minor": 5 382 | } 383 | -------------------------------------------------------------------------------- /notebooks/smollm2.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "5a8142e2-4f4a-48c9-bd1a-82a596c08fbe", 6 | "metadata": {}, 7 | "source": [ 8 | "# 1.7B Instruct" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 8, 14 | "id": "299d5cd3-36e9-49f7-87e7-5802e1c6b75a", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "# pip install transformers\n", 19 | "from transformers import AutoModelForCausalLM, AutoTokenizer\n", 20 | "# checkpoint = \"HuggingFaceTB/SmolLM2-360M\"\n", 21 | "# checkpoint = \"HuggingFaceTB/SmolLM2-360M-Instruct\"\n", 22 | "def get_model_tokenizer(checkpoint):\n", 23 | " device = \"cuda\" # for GPU usage or \"cpu\" for CPU usage\n", 24 | " tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n", 25 | " # for multiple GPUs install accelerate and do `model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map=\"auto\")`\n", 26 | " model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)\n", 27 | " return model, tokenizer\n", 28 | "\n", 29 | "big_model = \"HuggingFaceTB/SmolLM2-1.7B-Instruct\"\n", 30 | "small_model = \"HuggingFaceTB/SmolLM2-360M-Instruct\"" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 2, 36 | "id": "9d8d0e81-7ea0-4629-8fa0-c49cbdae7a65", 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "model, tokenizer = get_model_tokenizer(big_model)" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 5, 46 | "id": "b837c84f-4d1d-4988-bad6-83bc899e0fc6", 47 | "metadata": {}, 48 | "outputs": [ 49 | { 50 | "name": "stderr", 51 | "output_type": "stream", 52 | "text": [ 53 | "The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n", 54 | "/home/serj/dev/open-intent-classifier/venv/lib/python3.10/site-packages/transformers/generation/utils.py:1258: UserWarning: Using the model-agnostic default `max_length` (=20) to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.\n", 55 | " warnings.warn(\n" 56 | ] 57 | }, 58 | { 59 | "name": "stdout", 60 | "output_type": "stream", 61 | "text": [ 62 | "Gravity is a fundamental force of nature that affects all objects with mass, from the smallest subatomic\n" 63 | ] 64 | } 65 | ], 66 | "source": [ 67 | "device = \"cuda\"\n", 68 | "inputs = tokenizer.encode(\"Gravity is\", return_tensors=\"pt\").to(device)\n", 69 | "outputs = model.generate(inputs)\n", 70 | "print(tokenizer.decode(outputs[0]))" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 31, 76 | "id": "ba388418-8b1a-48ef-9beb-18f6f83ac135", 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "# check the chat template\n", 81 | "prompt_template = '''You are a customer service expert. Your goal is to predict what is the intent of the user from a predfined list.\n", 82 | " Format, the list of intents is delimeted with %.\n", 83 | " You must choose only one\n", 84 | " Text: {text}\n", 85 | " Intent List: {labels}.\n", 86 | "\n", 87 | " Answer format: {class: }\n", 88 | " Answer: \n", 89 | " '''\n", 90 | "\n", 91 | "# def predict(text, labels, device=\"cuda\"): \n", 92 | "# labels_str = \"%\".join(labels)\n", 93 | "# prompt = prompt_template.replace(\"{text}\", text).replace(\"{labels}\", labels_str)\n", 94 | "# print(prompt)\n", 95 | "# inputs = tokenizer.encode(prompt, return_tensors=\"pt\").to(device)\n", 96 | "# outputs = model.generate(inputs, max_new_tokens=200)\n", 97 | "# return tokenizer.decode(outputs[0])\n", 98 | "\n", 99 | "def predict_roles(text, labels, device=\"cuda\"):\n", 100 | " \n", 101 | " labels_str = \"%\".join(labels)\n", 102 | "\n", 103 | " prompt_format = '''Format, the list of intents is delimeted with %.\n", 104 | " You must choose only one\n", 105 | " Text: {text}\n", 106 | " Intent List: {labels}.\n", 107 | " Answer format is json: {class: }\n", 108 | " You must follow this format'''\n", 109 | " prompt = prompt_format.replace(\"{text}\", text).replace(\"{labels}\", labels_str)\n", 110 | " messages = [{\"role\": \"system\", \"content\": \"You are a customer service expert. Your goal is to predict what is the intent of the user from a predfined list\"},\n", 111 | " {\"role\": \"user\", \"content\": prompt}]\n", 112 | " \n", 113 | " input_text=tokenizer.apply_chat_template(messages, tokenize=False)\n", 114 | " inputs = tokenizer.encode(input_text, return_tensors=\"pt\").to(device)\n", 115 | " outputs = model.generate(inputs, max_new_tokens=50)\n", 116 | " return tokenizer.decode(outputs[0])\n", 117 | "\n", 118 | "def predict_small_model(text, labels, device=\"cuda\"):\n", 119 | " labels_str = \"%\".join(labels)\n", 120 | "\n", 121 | " prompt_format = '''Format, the list of intents is delimeted with %.\n", 122 | " Text: is the text of the customer.\n", 123 | " Intent List: is the list of possible intents for the text. You need to choose one.\n", 124 | " You must choose only one\n", 125 | " --\n", 126 | " Your answer must include only the name of the class, without any prefix, suffix or inbetween nicities.\n", 127 | " You must follow this format.\n", 128 | " -- \n", 129 | " Here is an example:\n", 130 | " Text: I have an issues with my iPhone\n", 131 | " Intent List: technical issue % cancel subscription\n", 132 | " Class: technical issue\n", 133 | " --\n", 134 | " Text: {text}\n", 135 | " Intent List: {labels}.\n", 136 | " Class: '''\n", 137 | " prompt = prompt_format.replace(\"{text}\", text).replace(\"{labels}\", labels_str)\n", 138 | " messages = [{\"role\": \"system\", \"content\": \"You are a customer service expert. Your goal is to predict what is the intent of the user from a predfined list\"},\n", 139 | " {\"role\": \"user\", \"content\": prompt}]\n", 140 | " \n", 141 | " input_text=tokenizer.apply_chat_template(messages, tokenize=False)\n", 142 | " inputs = tokenizer.encode(input_text, return_tensors=\"pt\").to(device)\n", 143 | " outputs = model.generate(inputs, max_new_tokens=100)\n", 144 | " return tokenizer.decode(outputs[0])\n", 145 | "\n", 146 | "\n", 147 | "def predict_small_model_v2(text, labels, device=\"cuda\"):\n", 148 | " labels_str = \"%\".join(labels)\n", 149 | "\n", 150 | " # prompt_format = '''Format, the list of intents is delimeted with %.\n", 151 | " # Text: is the text of the customer.\n", 152 | " # Intent List: is the list of possible intents for the text. You need to choose one.\n", 153 | " # You must choose only one\n", 154 | " # --\n", 155 | " # Your answer must include only the name of the class, without any prefix, suffix or inbetween nicities.\n", 156 | " # You must follow this format.\n", 157 | " # -- \n", 158 | " # Here is an example:\n", 159 | " # Text: I have an issues with my iPhone\n", 160 | " # Intent List: technical issue % cancel subscription\n", 161 | " # Class: technical issue\n", 162 | " # --\n", 163 | " # Text: {text}\n", 164 | " # Intent List: {labels}.\n", 165 | " # Class: '''\n", 166 | " \n", 167 | " \n", 168 | " # prompt = prompt_format.replace(\"{text}\", text).replace(\"{labels}\", labels_str)\n", 169 | " prompt = f\"Customer Text: {text}. Most matching intent from list: \"\n", 170 | " messages = [{\"role\": \"system\", \"content\": f\"You are a customer service expert. Your goal is to predict what is the intent of the user from a predfined list: {labels_str}. Do not return any text other than the text of the most matching intent from the list.\"},\n", 171 | " {\"role\": \"user\", \"content\": prompt}]\n", 172 | " \n", 173 | " input_text=tokenizer.apply_chat_template(messages, tokenize=False)\n", 174 | " inputs = tokenizer.encode(input_text, return_tensors=\"pt\").to(device)\n", 175 | " outputs = model.generate(inputs, max_new_tokens=100)\n", 176 | " return tokenizer.decode(outputs[0])\n" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": 10, 182 | "id": "78119cd1-de18-4273-bf2c-33912f36f918", 183 | "metadata": {}, 184 | "outputs": [ 185 | { 186 | "name": "stdout", 187 | "output_type": "stream", 188 | "text": [ 189 | "You are a customer service expert. Your goal is to predict what is the intent of the user from a predfined list.\n", 190 | " Format, the list of intents is delimeted with %.\n", 191 | " You must choose only one\n", 192 | " Text: Hey, I want to cancel subscription\n", 193 | " Intent List: cancel subscription%refund requests%bug%issue%unknown.\n", 194 | "\n", 195 | " Answer format: {class: }\n", 196 | " Answer: \n", 197 | " \n" 198 | ] 199 | }, 200 | { 201 | "data": { 202 | "text/plain": [ 203 | "'You are a customer service expert. Your goal is to predict what is the intent of the user from a predfined list.\\n Format, the list of intents is delimeted with %.\\n You must choose only one\\n Text: Hey, I want to cancel subscription\\n Intent List: cancel subscription%refund requests%bug%issue%unknown.\\n\\n Answer format: {class: }\\n Answer: \\n 1. cancel subscription\\n 2. refund requests\\n 3. bug\\n 4. issue\\n 5. unknown\\n\\nText: I want to cancel my subscription\\nIntent List: cancel subscription%refund requests%bug%issue%unknown.\\n\\n Answer format: {class: }\\n Answer: \\n 1. cancel subscription\\n 2. refund requests\\n 3. bug\\n 4. issue\\n 5. unknown\\n\\nText: I want to cancel my subscription\\nIntent List: cancel subscription%refund requests%bug%issue%unknown.\\n\\n Answer format: {class: }\\n Answer: \\n 1. cancel subscription\\n 2. refund requests\\n 3. bug\\n 4. issue\\n 5. unknown\\n\\nText: I want to cancel my subscription\\nIntent List: cancel subscription%refund requests%bug%issue%unknown.\\n\\n Answer format: {class: }\\n Answer: \\n 1. cancel subscription\\n 2. refund requests\\n 3. bug\\n '" 204 | ] 205 | }, 206 | "execution_count": 10, 207 | "metadata": {}, 208 | "output_type": "execute_result" 209 | } 210 | ], 211 | "source": [ 212 | "text = \"Hey, I want to cancel subscription\"\n", 213 | "labels = [\"cancel subscription\", \"refund requests\", \"bug\", \"issue\", \"unknown\"]\n", 214 | "predict(text, labels)" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": 14, 220 | "id": "d89e9619-0bbe-40a0-be94-1ac02f1358e1", 221 | "metadata": {}, 222 | "outputs": [ 223 | { 224 | "data": { 225 | "text/plain": [ 226 | "'<|im_start|>system\\nYou are a customer service expert. Your goal is to predict what is the intent of the user from a predfined list<|im_end|>\\n<|im_start|>user\\nFormat, the list of intents is delimeted with %.\\n You must choose only one\\n Text: Hey, I want to cancel subscription\\n Intent List: cancel subscription%refund requests%bug%issue%unknown.\\n Answer format is json: {class: }\\n You must follow this format<|im_end|>\\n<|im_start|>assistant\\n{\\n \"class\": \"cancel_subscription\"\\n}<|im_end|>'" 227 | ] 228 | }, 229 | "execution_count": 14, 230 | "metadata": {}, 231 | "output_type": "execute_result" 232 | } 233 | ], 234 | "source": [ 235 | "predict_roles(text, labels)" 236 | ] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "execution_count": 15, 241 | "id": "e82ac4a8-a30d-4b86-b24a-f3eafc839b35", 242 | "metadata": {}, 243 | "outputs": [ 244 | { 245 | "data": { 246 | "text/plain": [ 247 | "'<|im_start|>system\\nYou are a customer service expert. Your goal is to predict what is the intent of the user from a predfined list<|im_end|>\\n<|im_start|>user\\nFormat, the list of intents is delimeted with %.\\n You must choose only one\\n Text: I hate this job and your service, I have to get my money back\\n Intent List: cancel subscription%refund requests%bug%issue%unknown.\\n Answer format is json: {class: }\\n You must follow this format<|im_end|>\\n<|im_start|>assistant\\n{\\n \"class\": \"refund_requests\"\\n}<|im_end|>'" 248 | ] 249 | }, 250 | "execution_count": 15, 251 | "metadata": {}, 252 | "output_type": "execute_result" 253 | } 254 | ], 255 | "source": [ 256 | "predict_roles(\"I hate this job and your service, I have to get my money back\", labels)" 257 | ] 258 | }, 259 | { 260 | "cell_type": "markdown", 261 | "id": "1c66a6e3-2d4d-43ae-90c6-a63badd8fa5d", 262 | "metadata": {}, 263 | "source": [ 264 | "# 360M Instruct" 265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "execution_count": 2, 270 | "id": "c6d24c4c-bd0b-4632-b116-cfd4b74eab21", 271 | "metadata": {}, 272 | "outputs": [], 273 | "source": [ 274 | "model, tokenizer = get_model_tokenizer(small_model)" 275 | ] 276 | }, 277 | { 278 | "cell_type": "code", 279 | "execution_count": 32, 280 | "id": "08fbfb4c-c8ee-47a2-b3ec-df3a70fbcc25", 281 | "metadata": {}, 282 | "outputs": [ 283 | { 284 | "data": { 285 | "text/plain": [ 286 | "'<|im_start|>system\\nYou are a customer service expert. Your goal is to predict what is the intent of the user from a predfined list: cancel subscription%refund requests%bug%issue%unknown. Do not return any text other than the text of the most matching intent from the list.<|im_end|>\\n<|im_start|>user\\nCustomer Text: Hey, I want to cancel subscription. Most matching intent from list: <|im_end|>\\n<|im_start|>assistant\\nCancel subscription.<|im_end|>'" 287 | ] 288 | }, 289 | "execution_count": 32, 290 | "metadata": {}, 291 | "output_type": "execute_result" 292 | } 293 | ], 294 | "source": [ 295 | "predict_small_model_v2(text, labels)\n" 296 | ] 297 | }, 298 | { 299 | "cell_type": "code", 300 | "execution_count": 33, 301 | "id": "06a28a7a-51fa-4010-ba06-80d9177e4a5d", 302 | "metadata": {}, 303 | "outputs": [ 304 | { 305 | "data": { 306 | "text/plain": [ 307 | "\"<|im_start|>system\\nYou are a customer service expert. Your goal is to predict what is the intent of the user from a predfined list: cancel subscription%refund requests%bug%issue%unknown. Do not return any text other than the text of the most matching intent from the list.<|im_end|>\\n<|im_start|>user\\nCustomer Text: I hate this job and your service, I have to get my money back. Most matching intent from list: <|im_end|>\\n<|im_start|>assistant\\nI apologize for the inconvenience. I'm sorry to hear that you're having issues with your service. I'm here to help. Can you please provide more details about the issue you're experiencing?<|im_end|>\"" 308 | ] 309 | }, 310 | "execution_count": 33, 311 | "metadata": {}, 312 | "output_type": "execute_result" 313 | } 314 | ], 315 | "source": [ 316 | "predict_small_model_v2(\"I hate this job and your service, I have to get my money back\", labels)" 317 | ] 318 | }, 319 | { 320 | "cell_type": "markdown", 321 | "id": "c34634a1-ebef-464c-b9f1-1292e76f2a31", 322 | "metadata": {}, 323 | "source": [ 324 | "# 1.7B with DSPY" 325 | ] 326 | }, 327 | { 328 | "cell_type": "code", 329 | "execution_count": 16, 330 | "id": "97c9d700-9ea1-4532-ad45-1474cf9f502d", 331 | "metadata": {}, 332 | "outputs": [ 333 | { 334 | "name": "stderr", 335 | "output_type": "stream", 336 | "text": [ 337 | "accelerate.utils.modeling - INFO - We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).\n", 338 | "accelerate.big_modeling - WARNING - Some parameters are on the meta device because they were offloaded to the cpu.\n" 339 | ] 340 | } 341 | ], 342 | "source": [ 343 | "import dspy\n", 344 | "from open_intent_classifier.model import Classification\n", 345 | "\n", 346 | "llm = dspy.HFModel(model=big_model)\n", 347 | "dspy.settings.configure(lm=llm)\n", 348 | "predict = dspy.Predict(Classification)" 349 | ] 350 | }, 351 | { 352 | "cell_type": "code", 353 | "execution_count": 19, 354 | "id": "6e6ffc8c-754e-447b-b606-455425338c2e", 355 | "metadata": {}, 356 | "outputs": [ 357 | { 358 | "data": { 359 | "text/plain": [ 360 | "['Please answer. what is gravity?']" 361 | ] 362 | }, 363 | "execution_count": 19, 364 | "metadata": {}, 365 | "output_type": "execute_result" 366 | } 367 | ], 368 | "source": [ 369 | "llm(\"Please answer. what is gravity?\")" 370 | ] 371 | }, 372 | { 373 | "cell_type": "code", 374 | "execution_count": 10, 375 | "id": "92626ce7-105c-4d50-b46f-972c8a7186ac", 376 | "metadata": {}, 377 | "outputs": [ 378 | { 379 | "name": "stderr", 380 | "output_type": "stream", 381 | "text": [ 382 | "/home/serj/dev/open-intent-classifier/venv/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:567: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.0` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n", 383 | " warnings.warn(\n" 384 | ] 385 | }, 386 | { 387 | "data": { 388 | "text/plain": [ 389 | "Prediction(\n", 390 | " intent_class=\"Classify the customer message into one of the intent labels.\\nThe output should be only the predicted class as a single intent label.\\n\\n---\\n\\nFollow the following format.\\n\\nCustomer Message: Customer message during customer service interaction\\nIntent Labels: Labels that represent customer intent\\nIntent Class: a label best matching customer's intent\\n\\n---\\n\\nCustomer Message: I want to cancel my subscription\\nIntent Labels: Cancel subscription Refund request\\nIntent Class: Cancel subscription\\n\\nCustomer Message: I want to change my subscription plan\\nIntent Labels: Change subscription plan Refund request\\nIntent Class: Change subscription plan\\n\\nCustomer Message: I want to cancel my subscription and get a refund\\nIntent Labels: Cancel subscription Refund request\\nIntent Class: Cancel subscription\\n\\nCustomer Message: I want to change my subscription plan and get a refund\\nIntent Labels: Change subscription plan Refund request\\nIntent Class: Change subscription plan\\n\\nCustomer Message: I want to cancel my subscription and get a refund, but I want to keep my account\\nIntent Labels: Cancel subscription Refund request\\nIntent Class: Cancel subscription\"\n", 391 | ")" 392 | ] 393 | }, 394 | "execution_count": 10, 395 | "metadata": {}, 396 | "output_type": "execute_result" 397 | } 398 | ], 399 | "source": [ 400 | "predict = dspy.Predict(Classification)\n", 401 | "labels = ([\"Cancel subscription\", \"Refund request\"])\n", 402 | "labels = \" \".join([\"Cancel subscription\", \"Refund request\"])\n", 403 | "text = \"I want to cancel my subscription\"\n", 404 | "pred = predict(customer_message=text, intent_labels=labels)\n", 405 | "pred\n" 406 | ] 407 | }, 408 | { 409 | "cell_type": "markdown", 410 | "id": "655a933f-dcb8-4d18-88b5-9dc8792fb2cd", 411 | "metadata": {}, 412 | "source": [ 413 | "# 360M with DSPy" 414 | ] 415 | }, 416 | { 417 | "cell_type": "code", 418 | "execution_count": 5, 419 | "id": "64956afb-876b-44cf-b298-579141b9a224", 420 | "metadata": {}, 421 | "outputs": [ 422 | { 423 | "name": "stderr", 424 | "output_type": "stream", 425 | "text": [ 426 | "accelerate.utils.modeling - INFO - We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).\n" 427 | ] 428 | } 429 | ], 430 | "source": [ 431 | "import dspy\n", 432 | "from open_intent_classifier.model import Classification\n", 433 | "\n", 434 | "llm = dspy.HFModel(model=small_model)\n", 435 | "dspy.settings.configure(lm=llm)\n", 436 | "predict = dspy.Predict(Classification)" 437 | ] 438 | }, 439 | { 440 | "cell_type": "code", 441 | "execution_count": 7, 442 | "id": "1518566f-78f3-4221-9499-c0700cbd70df", 443 | "metadata": {}, 444 | "outputs": [ 445 | { 446 | "name": "stderr", 447 | "output_type": "stream", 448 | "text": [ 449 | "root - WARNING - \t*** In DSPy 2.5, all LM clients except `dspy.LM` are deprecated, underperform, and are about to be deleted. ***\n", 450 | " \t\tYou are using the client HFModel, which will be removed in DSPy 2.6.\n", 451 | " \t\tChanging the client is straightforward and will let you use new features (Adapters) that improve the consistency of LM outputs, especially when using chat LMs. \n", 452 | "\n", 453 | " \t\tLearn more about the changes and how to migrate at\n", 454 | " \t\thttps://github.com/stanfordnlp/dspy/blob/main/examples/migration.ipynb\n", 455 | "/home/serj/dev/open-intent-classifier/venv/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:567: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.0` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n", 456 | " warnings.warn(\n" 457 | ] 458 | }, 459 | { 460 | "data": { 461 | "text/plain": [ 462 | "Prediction(\n", 463 | " intent_class=\"Classify the customer message into one of the intent labels.\\nThe output should be only the predicted class as a single intent label.\\n\\n---\\n\\nFollow the following format.\\n\\nCustomer Message: Customer message during customer service interaction\\nIntent Labels: Labels that represent customer intent\\nIntent Class: a label best matching customer's intent\\n\\n---\\n\\nCustomer Message: I want to cancel my subscription\\nIntent Labels: Cancel subscription Refund request\\nIntent Class: Refund request\\n\\n---\\n\\nCustomer Message: I want to cancel my subscription\\nIntent Labels: Cancel subscription Refund request\\nIntent Class: Refund request\\n\\n---\\n\\nCustomer Message: I want to cancel my subscription\\nIntent Labels: Cancel subscription Refund request\\nIntent Class: Refund request\\n\\n---\\n\\nCustomer Message: I want to cancel my subscription\\nIntent Labels: Cancel subscription Refund request\\nIntent Class: Refund request\\n\\n---\\n\\nCustomer Message: I want to cancel my subscription\\nIntent Labels: Cancel subscription Refund request\\nIntent Class: Refund request\\n\\n---\\n\\nCustomer Message: I want to\"\n", 464 | ")" 465 | ] 466 | }, 467 | "execution_count": 7, 468 | "metadata": {}, 469 | "output_type": "execute_result" 470 | } 471 | ], 472 | "source": [ 473 | "predict = dspy.Predict(Classification)\n", 474 | "labels = ([\"Cancel subscription\", \"Refund request\"])\n", 475 | "labels = \" \".join([\"Cancel subscription\", \"Refund request\"])\n", 476 | "text = \"I want to cancel my subscription\"\n", 477 | "pred = predict(customer_message=text, intent_labels=labels)\n", 478 | "pred\n" 479 | ] 480 | } 481 | ], 482 | "metadata": { 483 | "kernelspec": { 484 | "display_name": "Python 3 (ipykernel)", 485 | "language": "python", 486 | "name": "python3" 487 | }, 488 | "language_info": { 489 | "codemirror_mode": { 490 | "name": "ipython", 491 | "version": 3 492 | }, 493 | "file_extension": ".py", 494 | "mimetype": "text/x-python", 495 | "name": "python", 496 | "nbconvert_exporter": "python", 497 | "pygments_lexer": "ipython3", 498 | "version": "3.10.12" 499 | } 500 | }, 501 | "nbformat": 4, 502 | "nbformat_minor": 5 503 | } 504 | -------------------------------------------------------------------------------- /notebooks/dspy_training.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "24b400eb-2039-4e22-812a-c784161f329d", 6 | "metadata": {}, 7 | "source": [ 8 | "# Content\n", 9 | "## Define Classifier\n", 10 | "## Define Datasets\n", 11 | "## Define DSPy.Examples\n", 12 | "## Define Optimizers\n", 13 | "## Define Evaluation\n", 14 | "## Run Evaluation\n", 15 | "## Save Prompts" 16 | ] 17 | }, 18 | { 19 | "cell_type": "markdown", 20 | "id": "b94e80f3-3447-43ea-9543-ff73df3c4805", 21 | "metadata": {}, 22 | "source": [ 23 | "# Define classifier" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 1, 29 | "id": "39778e03-1631-4d90-a7cd-1c82f5b948ac", 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "import dspy \n", 34 | "import random\n", 35 | "import pandas as pd\n", 36 | "from datasets import load_dataset\n", 37 | "\n", 38 | "class Classification(dspy.Signature):\n", 39 | " \"\"\"Classify the customer message into one of the intent labels.\n", 40 | " The output should be only the predicted class as a single intent label.\"\"\"\n", 41 | "\n", 42 | " customer_message = dspy.InputField(desc=\"Customer message during customer service interaction\")\n", 43 | " intent_labels = dspy.InputField(desc=\"Labels that represent customer intent\")\n", 44 | " answer = dspy.OutputField(desc=\"a label best matching customer's intent \")\n", 45 | "\n", 46 | "lm_mini = dspy.OpenAI(model='gpt-4o-mini')\n", 47 | "dspy.settings.configure(lm=lm_mini)\n", 48 | "cot_predictor = dspy.ChainOfThought(Classification)\n" 49 | ] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "id": "cf11469a-598f-4962-97c4-4ff2055d8086", 54 | "metadata": {}, 55 | "source": [ 56 | "# Parse Atis Dataset" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 2, 62 | "id": "16594105-5855-4c13-801d-bb3055580136", 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "\n", 67 | "dataset = load_dataset(\"tuetschek/atis\")\n", 68 | "dataset.set_format(type=\"pandas\")\n", 69 | "\n", 70 | "df_train: pd.DataFrame = dataset[\"train\"][:]\n", 71 | "df_test: pd.DataFrame = dataset[\"test\"][:]\n", 72 | "small_test = df_test.head(100)" 73 | ] 74 | }, 75 | { 76 | "cell_type": "markdown", 77 | "id": "a1cba8a9-00dd-4ba3-9486-925b0bb2ea8d", 78 | "metadata": {}, 79 | "source": [ 80 | "## x column: text, y column: intent" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": 3, 86 | "id": "e08e4e74-d202-48a4-8a14-0ab68729a462", 87 | "metadata": {}, 88 | "outputs": [ 89 | { 90 | "data": { 91 | "text/plain": [ 92 | "id 0\n", 93 | "intent flight\n", 94 | "text i want to fly from boston at 838 am and arrive...\n", 95 | "slots O O O O O B-fromloc.city_name O B-depart_time....\n", 96 | "Name: 0, dtype: object" 97 | ] 98 | }, 99 | "execution_count": 3, 100 | "metadata": {}, 101 | "output_type": "execute_result" 102 | } 103 | ], 104 | "source": [ 105 | "df_train.iloc[0]" 106 | ] 107 | }, 108 | { 109 | "cell_type": "markdown", 110 | "id": "78305be5-30a9-4cf5-a2ac-9c66ecca03cc", 111 | "metadata": {}, 112 | "source": [ 113 | "## prepare labels" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 4, 119 | "id": "3632bcc1-c391-43f9-b927-63672ad7d056", 120 | "metadata": {}, 121 | "outputs": [ 122 | { 123 | "data": { 124 | "text/plain": [ 125 | "'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no'" 126 | ] 127 | }, 128 | "execution_count": 4, 129 | "metadata": {}, 130 | "output_type": "execute_result" 131 | } 132 | ], 133 | "source": [ 134 | "labels = df_train[\"intent\"].unique().tolist()\n", 135 | "labels_str = \"%\".join(labels)\n", 136 | "labels_str" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": 5, 142 | "id": "681e4647-fd7e-4b80-a570-77f44f966485", 143 | "metadata": {}, 144 | "outputs": [ 145 | { 146 | "name": "stderr", 147 | "output_type": "stream", 148 | "text": [ 149 | "WARNING:root:\t*** In DSPy 2.5, all LM clients except `dspy.LM` are deprecated, underperform, and are about to be deleted. ***\n", 150 | " \t\tYou are using the client GPT3, which will be removed in DSPy 2.6.\n", 151 | " \t\tChanging the client is straightforward and will let you use new features (Adapters) that improve the consistency of LM outputs, especially when using chat LMs. \n", 152 | "\n", 153 | " \t\tLearn more about the changes and how to migrate at\n", 154 | " \t\thttps://github.com/stanfordnlp/dspy/blob/main/examples/migration.ipynb\n" 155 | ] 156 | }, 157 | { 158 | "name": "stdout", 159 | "output_type": "stream", 160 | "text": [ 161 | "customer message: i want to fly from boston at 838 am and arrive in denver at 1110 in the morning,real class: flight\n" 162 | ] 163 | }, 164 | { 165 | "data": { 166 | "text/plain": [ 167 | "Prediction(\n", 168 | " rationale='produce the answer. We first identify that the customer is expressing a desire to fly from Boston to Denver, specifying both the departure time (8:38 AM) and the arrival time (11:10 AM). This indicates that the customer is looking for flight information, particularly related to the flight time and possibly the airfare. The intent is primarily focused on booking or inquiring about a flight, which aligns with the \"flight_time\" and \"airfare\" labels. However, since the customer is providing specific times, the most relevant label is \"flight_time\".',\n", 169 | " answer='flight_time'\n", 170 | ")" 171 | ] 172 | }, 173 | "execution_count": 5, 174 | "metadata": {}, 175 | "output_type": "execute_result" 176 | } 177 | ], 178 | "source": [ 179 | "## run prediction\n", 180 | "first_row = df_train.iloc[0]\n", 181 | "print(f\"customer message: {first_row['text']},real class: {first_row['intent']}\")\n", 182 | "cot_predictor(customer_message=first_row[\"text\"], intent_labels=labels_str)\n" 183 | ] 184 | }, 185 | { 186 | "cell_type": "markdown", 187 | "id": "6faec3e6-114e-4397-996f-e3c0a0e805c4", 188 | "metadata": {}, 189 | "source": [ 190 | "# Define Examples" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": 6, 196 | "id": "1e01c3c9-288d-4477-a20b-47694f583c8a", 197 | "metadata": {}, 198 | "outputs": [ 199 | { 200 | "name": "stdout", 201 | "output_type": "stream", 202 | "text": [ 203 | "90 45\n" 204 | ] 205 | } 206 | ], 207 | "source": [ 208 | "# nowhere else there was an example of this thing\n", 209 | "\n", 210 | "# we want k examples per class \n", 211 | "def get_dspy_examples(df, k) -> dspy.example:\n", 212 | " dspy_examples = []\n", 213 | " for label in labels:\n", 214 | " try:\n", 215 | " label_df = df[df[\"intent\"] == label].sample(n=k)\n", 216 | " for index, row in label_df.iterrows():\n", 217 | " dspy_examples.append(\n", 218 | " dspy.Example(customer_message=row[\"text\"], answer=row[\"intent\"], \n", 219 | " intent_labels=labels_str).with_inputs(\"customer_message\", \"intent_labels\")\n", 220 | " )\n", 221 | " except:\n", 222 | " # there are classes that don't have any representatives\n", 223 | " continue\n", 224 | "\n", 225 | " return dspy_examples\n", 226 | "\n", 227 | "\n", 228 | "train_examples = get_dspy_examples(df_train, k=2)\n", 229 | "all_test_examples = get_dspy_examples(df_test, k=10)\n", 230 | "print(len(all_test_examples), len(all_test_examples) // 2)\n", 231 | "dev_examples = random.sample(all_test_examples, len(all_test_examples) // 2)\n", 232 | "test_examples = [example for example in all_test_examples if example not in dev_examples]" 233 | ] 234 | }, 235 | { 236 | "cell_type": "markdown", 237 | "id": "b061f1af-b8fc-4742-951e-cfeaf1b696a9", 238 | "metadata": {}, 239 | "source": [ 240 | "# Define LabeledFewShot Optimizer\n", 241 | "LabeledFewShot is the simplest optimizer. Its compile method injects samples intro the prompt.\n", 242 | "There is not optimization going on.\n" 243 | ] 244 | }, 245 | { 246 | "cell_type": "code", 247 | "execution_count": 7, 248 | "id": "24b21df7-57c4-454e-bdf6-d63e8f855e70", 249 | "metadata": {}, 250 | "outputs": [], 251 | "source": [ 252 | "from dspy.teleprompt import LabeledFewShot\n", 253 | "\n", 254 | "few_shot_demos = random.sample(train_examples, k=10)\n", 255 | "labeled_fewshot_optimizer = LabeledFewShot(k=len(few_shot_demos))\n", 256 | "few_shot_model = labeled_fewshot_optimizer.compile(student=cot_predictor, trainset=few_shot_demos)\n" 257 | ] 258 | }, 259 | { 260 | "cell_type": "markdown", 261 | "id": "369a899d-4994-4371-9025-ea7f36f86bba", 262 | "metadata": {}, 263 | "source": [ 264 | "### What is happenning under the hood?\n", 265 | "### LabeledFewShot randomly selects labels" 266 | ] 267 | }, 268 | { 269 | "cell_type": "markdown", 270 | "id": "bf6ac178-fc1b-4e34-b944-ddc781428257", 271 | "metadata": {}, 272 | "source": [ 273 | "### DSPy SOURCE CODE: https://github.com/stanfordnlp/dspy/blob/793530c65a0e1721997dac0d2636f0f70ad649b6/dspy/teleprompt/vanilla.py#L6\n", 274 | "\n", 275 | "class LabeledFewShot(Teleprompter):\n", 276 | " def __init__(self, k=16):\n", 277 | " self.k = k\n", 278 | "\n", 279 | " def compile(self, student, *, trainset, sample=True):\n", 280 | " self.student = student.reset_copy()\n", 281 | " self.trainset = trainset\n", 282 | "\n", 283 | " if len(self.trainset) == 0:\n", 284 | " return self.student\n", 285 | "\n", 286 | " rng = random.Random(0)\n", 287 | "\n", 288 | " for predictor in self.student.predictors():\n", 289 | " if sample:\n", 290 | " predictor.demos = rng.sample(self.trainset, min(self.k, len(self.trainset)))\n", 291 | " else:\n", 292 | " predictor.demos = self.trainset[: min(self.k, len(self.trainset))]\n", 293 | "\n", 294 | " return self.student" 295 | ] 296 | }, 297 | { 298 | "cell_type": "markdown", 299 | "id": "d2598608-3292-49a8-98f4-2e5f46d1366f", 300 | "metadata": {}, 301 | "source": [ 302 | "### My own summary of the implementation\n", 303 | "DSPy samples randomly a portion of the samples as examples for in-context learning. \n", 304 | "There's no actual optimization process." 305 | ] 306 | }, 307 | { 308 | "cell_type": "markdown", 309 | "id": "6bdad12e-aff6-4b2e-bd93-f6c7f594b807", 310 | "metadata": {}, 311 | "source": [ 312 | "### How does the prompt looks like?" 313 | ] 314 | }, 315 | { 316 | "cell_type": "code", 317 | "execution_count": 20, 318 | "id": "12c40bb7-707d-4b34-b5a1-3643d85425ff", 319 | "metadata": {}, 320 | "outputs": [ 321 | { 322 | "name": "stdout", 323 | "output_type": "stream", 324 | "text": [ 325 | "\n", 326 | "\n", 327 | "\n", 328 | "Classify the customer message into one of the intent labels.\n", 329 | "The output should be only the predicted class as a single intent label.\n", 330 | "\n", 331 | "---\n", 332 | "\n", 333 | "Follow the following format.\n", 334 | "\n", 335 | "Customer Message: Customer message during customer service interaction\n", 336 | "\n", 337 | "Intent Labels: Labels that represent customer intent\n", 338 | "\n", 339 | "Reasoning: Let's think step by step in order to ${produce the answer}. We ...\n", 340 | "\n", 341 | "Answer: a label best matching customer's intent\n", 342 | "\n", 343 | "---\n", 344 | "\n", 345 | "Customer Message: what are the air restrictions on flights from pittsburgh to atlanta for the airfare of 416 dollars\n", 346 | "Intent Labels: flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no\n", 347 | "Answer: restriction\n", 348 | "\n", 349 | "---\n", 350 | "\n", 351 | "Customer Message: what is the flight number of the earliest flight between boston and washington dc\n", 352 | "Intent Labels: flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no\n", 353 | "Answer: flight_no\n", 354 | "\n", 355 | "---\n", 356 | "\n", 357 | "Customer Message: what times does continental depart from boston to san francisco\n", 358 | "Intent Labels: flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no\n", 359 | "Answer: flight_time\n", 360 | "\n", 361 | "---\n", 362 | "\n", 363 | "Customer Message: what ground transportation is available in san francisco\n", 364 | "Intent Labels: flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no\n", 365 | "Answer: ground_service\n", 366 | "\n", 367 | "---\n", 368 | "\n", 369 | "Customer Message: do i get a meal on the atlanta to bwi flight eastern 210\n", 370 | "Intent Labels: flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no\n", 371 | "Answer: meal\n", 372 | "\n", 373 | "---\n", 374 | "\n", 375 | "Customer Message: round trip fares from denver to philadelphia less than 1000 dollars\n", 376 | "Intent Labels: flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no\n", 377 | "Answer: airfare\n", 378 | "\n", 379 | "---\n", 380 | "\n", 381 | "Customer Message: which airlines fly between boston and pittsburgh\n", 382 | "Intent Labels: flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no\n", 383 | "Answer: airline\n", 384 | "\n", 385 | "---\n", 386 | "\n", 387 | "Customer Message: houston airports\n", 388 | "Intent Labels: flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no\n", 389 | "Answer: airport\n", 390 | "\n", 391 | "---\n", 392 | "\n", 393 | "Customer Message: what is the distance between pittsburgh airport and downtown pittsburgh\n", 394 | "Intent Labels: flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no\n", 395 | "Answer: distance\n", 396 | "\n", 397 | "---\n", 398 | "\n", 399 | "Customer Message: i would like the flight number and the time for the cheapest fare that is the least expensive first class fare from san francisco to pittsburgh leaving after 8 pm monday night\n", 400 | "Intent Labels: flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no\n", 401 | "Answer: flight_no\n", 402 | "\n", 403 | "---\n", 404 | "\n", 405 | "Customer Message: show flights from minneapolis to kansas city\n", 406 | "\n", 407 | "Intent Labels: flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no\n", 408 | "\n", 409 | "Reasoning: Let's think step by step in order to\u001b[32m produce the answer. We need to identify the main intent of the customer message. The customer is asking to see flights from one city to another, which indicates they are looking for information related to flights. The most appropriate label that matches this intent is \"flight.\"\n", 410 | "\n", 411 | "Answer: flight\u001b[0m\n", 412 | "\n", 413 | "\n", 414 | "\n" 415 | ] 416 | }, 417 | { 418 | "data": { 419 | "text/plain": [ 420 | "'\\n\\n\\nClassify the customer message into one of the intent labels.\\nThe output should be only the predicted class as a single intent label.\\n\\n---\\n\\nFollow the following format.\\n\\nCustomer Message: Customer message during customer service interaction\\n\\nIntent Labels: Labels that represent customer intent\\n\\nReasoning: Let\\'s think step by step in order to ${produce the answer}. We ...\\n\\nAnswer: a label best matching customer\\'s intent\\n\\n---\\n\\nCustomer Message: what are the air restrictions on flights from pittsburgh to atlanta for the airfare of 416 dollars\\nIntent Labels: flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no\\nAnswer: restriction\\n\\n---\\n\\nCustomer Message: what is the flight number of the earliest flight between boston and washington dc\\nIntent Labels: flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no\\nAnswer: flight_no\\n\\n---\\n\\nCustomer Message: what times does continental depart from boston to san francisco\\nIntent Labels: flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no\\nAnswer: flight_time\\n\\n---\\n\\nCustomer Message: what ground transportation is available in san francisco\\nIntent Labels: flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no\\nAnswer: ground_service\\n\\n---\\n\\nCustomer Message: do i get a meal on the atlanta to bwi flight eastern 210\\nIntent Labels: flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no\\nAnswer: meal\\n\\n---\\n\\nCustomer Message: round trip fares from denver to philadelphia less than 1000 dollars\\nIntent Labels: flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no\\nAnswer: airfare\\n\\n---\\n\\nCustomer Message: which airlines fly between boston and pittsburgh\\nIntent Labels: flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no\\nAnswer: airline\\n\\n---\\n\\nCustomer Message: houston airports\\nIntent Labels: flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no\\nAnswer: airport\\n\\n---\\n\\nCustomer Message: what is the distance between pittsburgh airport and downtown pittsburgh\\nIntent Labels: flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no\\nAnswer: distance\\n\\n---\\n\\nCustomer Message: i would like the flight number and the time for the cheapest fare that is the least expensive first class fare from san francisco to pittsburgh leaving after 8 pm monday night\\nIntent Labels: flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no\\nAnswer: flight_no\\n\\n---\\n\\nCustomer Message: show flights from minneapolis to kansas city\\n\\nIntent Labels: flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no\\n\\nReasoning: Let\\'s think step by step in order to\\x1b[32m produce the answer. We need to identify the main intent of the customer message. The customer is asking to see flights from one city to another, which indicates they are looking for information related to flights. The most appropriate label that matches this intent is \"flight.\"\\n\\nAnswer: flight\\x1b[0m\\n\\n\\n'" 421 | ] 422 | }, 423 | "execution_count": 20, 424 | "metadata": {}, 425 | "output_type": "execute_result" 426 | } 427 | ], 428 | "source": [ 429 | "example = test_examples[0]\n", 430 | "# without inputs(), we won't inject the inputs of the example\n", 431 | "pred = few_shot_model(**example.inputs())\n", 432 | "# Produce a prediction from our `cot` module, using the `example` above as input.\n", 433 | "lm_mini.inspect_history(n=1)" 434 | ] 435 | }, 436 | { 437 | "cell_type": "markdown", 438 | "id": "4be6e2df-2e27-43c3-b40b-f6dd70a35ef3", 439 | "metadata": {}, 440 | "source": [ 441 | "## Define BootstrapFewShot Optimizer\n", 442 | "This family of optimizers is focused on optimizing the few shot examples. Let's take an example of a Sample pipeline and see how we can use this optimizer to optimize it. From: https://dspy.ai/deep-dive/optimizers/bootstrap-fewshot/" 443 | ] 444 | }, 445 | { 446 | "cell_type": "code", 447 | "execution_count": 9, 448 | "id": "9e86aad2-7f4e-463c-9ee0-5db62ffec24b", 449 | "metadata": {}, 450 | "outputs": [], 451 | "source": [ 452 | "from dspy.evaluate import answer_exact_match as metric\n", 453 | "from dspy.teleprompt import BootstrapFewShot\n", 454 | "\n", 455 | "optimizer = BootstrapFewShot(\n", 456 | " metric=metric,\n", 457 | " max_bootstrapped_demos=10,\n", 458 | " max_labeled_demos=10,\n", 459 | " max_rounds=10,\n", 460 | ")" 461 | ] 462 | }, 463 | { 464 | "cell_type": "markdown", 465 | "id": "d7281605-91de-4bd6-8244-df8c8dc6441c", 466 | "metadata": {}, 467 | "source": [ 468 | "### Optimize" 469 | ] 470 | }, 471 | { 472 | "cell_type": "code", 473 | "execution_count": 10, 474 | "id": "45e59037-ced5-44b7-8d31-8337fe9d45ea", 475 | "metadata": {}, 476 | "outputs": [ 477 | { 478 | "name": "stderr", 479 | "output_type": "stream", 480 | "text": [ 481 | " 28%|████████████████████████████████████▍ | 10/36 [00:14<00:38, 1.47s/it]" 482 | ] 483 | }, 484 | { 485 | "name": "stdout", 486 | "output_type": "stream", 487 | "text": [ 488 | "Bootstrapped 10 full traces after 10 examples for up to 10 rounds, amounting to 10 attempts.\n" 489 | ] 490 | }, 491 | { 492 | "name": "stderr", 493 | "output_type": "stream", 494 | "text": [ 495 | "\n" 496 | ] 497 | } 498 | ], 499 | "source": [ 500 | "# documentation is wrong - there is not valset: https://dspy.ai/deep-dive/optimizers/bootstrap-fewshot/\n", 501 | "cot_few_shot_optimized = optimizer.compile(cot_predictor, trainset=train_examples)\n" 502 | ] 503 | }, 504 | { 505 | "cell_type": "markdown", 506 | "id": "71cb5854-1b23-44f0-b28d-7a77e0458f67", 507 | "metadata": {}, 508 | "source": [ 509 | "## Peek under the hood of DSPy source code for BootStrapFewShot training" 510 | ] 511 | }, 512 | { 513 | "cell_type": "markdown", 514 | "id": "b7f51b93-76ae-47bf-b644-53f1719af833", 515 | "metadata": {}, 516 | "source": [ 517 | "### DSPy source code for training\n", 518 | "class BootstrapFewShot()\n", 519 | " def _train(self):\n", 520 | " rng = random.Random(0)\n", 521 | " raw_demos = self.validation\n", 522 | "\n", 523 | " for name, predictor in self.student.named_predictors():\n", 524 | " augmented_demos = self.name2traces[name][: self.max_bootstrapped_demos]\n", 525 | "\n", 526 | " sample_size = min(self.max_labeled_demos - len(augmented_demos), len(raw_demos))\n", 527 | " sample_size = max(0, sample_size)\n", 528 | "\n", 529 | " raw_demos = rng.sample(raw_demos, sample_size)\n", 530 | "\n", 531 | " if dspy.settings.release >= 20230928:\n", 532 | " predictor.demos = raw_demos + augmented_demos\n", 533 | " else:\n", 534 | " predictor.demos = augmented_demos + raw_demos\n", 535 | "\n", 536 | " return self.student" 537 | ] 538 | }, 539 | { 540 | "cell_type": "markdown", 541 | "id": "ceade59e-3c28-4938-9119-1599aee4ec2c", 542 | "metadata": {}, 543 | "source": [ 544 | "I consulted with ChatGPT about this method. Source code: https://github.com/stanfordnlp/dspy/blob/main/dspy/teleprompt/bootstrap.py\n", 545 | "\n", 546 | "_train() Purpose\n", 547 | "Once _bootstrap() has collected and validated a set of bootstrapped demos, _train() takes over to:\n", 548 | "\n", 549 | "Compile Final Demos for Predictors: _train() assembles the demos (both bootstrapped and labeled) for each predictor within the student model. For each predictor, it selects a mix of bootstrapped demos (from _bootstrap()) and labeled examples (raw demos from the validation set) to create a final demo set.\n", 550 | "Random Sampling: The method performs a random sample from the raw labeled demos, ensuring the demos meet the configuration limits, such as max_labeled_demos.\n", 551 | "Set Demos for Each Predictor: Finally, _train() updates each predictor in the student model with this finalized set of demos, effectively preparing it for use.\n", 552 | "In essence, _bootstrap() is responsible for creating and validating bootstrapped demos, while _train() assembles a balanced set of these demos and labeled examples to finalize the student model’s training." 553 | ] 554 | }, 555 | { 556 | "cell_type": "markdown", 557 | "id": "81a6d61a-b30f-41fd-96a6-4abf5108eeab", 558 | "metadata": {}, 559 | "source": [ 560 | "### My Own Summary\n", 561 | "BootstrapFewShot has two main properties:\n", 562 | "1. Enable you to generate additional examples\n", 563 | "2. DSPy tests which predictions pass the validation and keep only those" 564 | ] 565 | }, 566 | { 567 | "cell_type": "markdown", 568 | "id": "0aa483be-ed96-4fbe-8ac0-32af8a04534b", 569 | "metadata": {}, 570 | "source": [ 571 | "## Define BootstrapFewShotWithRandomSearch" 572 | ] 573 | }, 574 | { 575 | "cell_type": "code", 576 | "execution_count": 11, 577 | "id": "86165cf0-fd28-4f74-9803-ebec811c4d6d", 578 | "metadata": {}, 579 | "outputs": [ 580 | { 581 | "name": "stdout", 582 | "output_type": "stream", 583 | "text": [ 584 | "Going to sample between 1 and 10 traces per predictor.\n", 585 | "Will attempt to bootstrap 5 candidate sets.\n" 586 | ] 587 | } 588 | ], 589 | "source": [ 590 | "from dspy.teleprompt import BootstrapFewShotWithRandomSearch\n", 591 | "\n", 592 | "optimizer = BootstrapFewShotWithRandomSearch(\n", 593 | " metric=metric, \n", 594 | " max_bootstrapped_demos=10, \n", 595 | " max_labeled_demos=10,\n", 596 | " num_threads=10,\n", 597 | " num_candidate_programs=5\n", 598 | ")" 599 | ] 600 | }, 601 | { 602 | "cell_type": "code", 603 | "execution_count": 12, 604 | "id": "852ff802-23c1-49ed-90f4-af0dd9210eb5", 605 | "metadata": {}, 606 | "outputs": [ 607 | { 608 | "name": "stderr", 609 | "output_type": "stream", 610 | "text": [ 611 | "Average Metric: 29 / 36 (80.6): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 36/36 [00:05<00:00, 6.66it/s]\n", 612 | "2024/11/18 09:16:35 INFO dspy.evaluate.evaluate: Average Metric: 29 / 36 (80.6%)\n" 613 | ] 614 | }, 615 | { 616 | "name": "stdout", 617 | "output_type": "stream", 618 | "text": [ 619 | "New best score: 80.56 for seed -3\n", 620 | "Scores so far: [80.56]\n", 621 | "Best score so far: 80.56\n" 622 | ] 623 | }, 624 | { 625 | "name": "stderr", 626 | "output_type": "stream", 627 | "text": [ 628 | "Average Metric: 31 / 36 (86.1): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 36/36 [00:04<00:00, 8.78it/s]\n", 629 | "2024/11/18 09:16:39 INFO dspy.evaluate.evaluate: Average Metric: 31 / 36 (86.1%)\n" 630 | ] 631 | }, 632 | { 633 | "name": "stdout", 634 | "output_type": "stream", 635 | "text": [ 636 | "New best score: 86.11 for seed -2\n", 637 | "Scores so far: [80.56, 86.11]\n", 638 | "Best score so far: 86.11\n" 639 | ] 640 | }, 641 | { 642 | "name": "stderr", 643 | "output_type": "stream", 644 | "text": [ 645 | " 28%|███████████████████████████████████▊ | 10/36 [00:00<00:00, 1520.39it/s]\n" 646 | ] 647 | }, 648 | { 649 | "name": "stdout", 650 | "output_type": "stream", 651 | "text": [ 652 | "Bootstrapped 10 full traces after 10 examples for up to 1 rounds, amounting to 10 attempts.\n" 653 | ] 654 | }, 655 | { 656 | "name": "stderr", 657 | "output_type": "stream", 658 | "text": [ 659 | "Average Metric: 30 / 36 (83.3): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 36/36 [00:07<00:00, 5.06it/s]\n", 660 | "2024/11/18 09:16:47 INFO dspy.evaluate.evaluate: Average Metric: 30 / 36 (83.3%)\n" 661 | ] 662 | }, 663 | { 664 | "name": "stdout", 665 | "output_type": "stream", 666 | "text": [ 667 | "Scores so far: [80.56, 86.11, 83.33]\n", 668 | "Best score so far: 86.11\n" 669 | ] 670 | }, 671 | { 672 | "name": "stderr", 673 | "output_type": "stream", 674 | "text": [ 675 | " 28%|████████████████████████████████████▍ | 10/36 [00:25<01:05, 2.53s/it]\n" 676 | ] 677 | }, 678 | { 679 | "name": "stdout", 680 | "output_type": "stream", 681 | "text": [ 682 | "Bootstrapped 7 full traces after 10 examples for up to 1 rounds, amounting to 10 attempts.\n" 683 | ] 684 | }, 685 | { 686 | "name": "stderr", 687 | "output_type": "stream", 688 | "text": [ 689 | "Average Metric: 31 / 36 (86.1): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 36/36 [00:06<00:00, 5.38it/s]\n", 690 | "2024/11/18 09:17:19 INFO dspy.evaluate.evaluate: Average Metric: 31 / 36 (86.1%)\n" 691 | ] 692 | }, 693 | { 694 | "name": "stdout", 695 | "output_type": "stream", 696 | "text": [ 697 | "Scores so far: [80.56, 86.11, 83.33, 86.11]\n", 698 | "Best score so far: 86.11\n" 699 | ] 700 | }, 701 | { 702 | "name": "stderr", 703 | "output_type": "stream", 704 | "text": [ 705 | " 11%|██████████████▋ | 4/36 [00:06<00:55, 1.72s/it]\n" 706 | ] 707 | }, 708 | { 709 | "name": "stdout", 710 | "output_type": "stream", 711 | "text": [ 712 | "Bootstrapped 3 full traces after 4 examples for up to 1 rounds, amounting to 4 attempts.\n" 713 | ] 714 | }, 715 | { 716 | "name": "stderr", 717 | "output_type": "stream", 718 | "text": [ 719 | "Average Metric: 31 / 36 (86.1): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 36/36 [00:07<00:00, 4.59it/s]\n", 720 | "2024/11/18 09:17:33 INFO dspy.evaluate.evaluate: Average Metric: 31 / 36 (86.1%)\n" 721 | ] 722 | }, 723 | { 724 | "name": "stdout", 725 | "output_type": "stream", 726 | "text": [ 727 | "Scores so far: [80.56, 86.11, 83.33, 86.11, 86.11]\n", 728 | "Best score so far: 86.11\n" 729 | ] 730 | }, 731 | { 732 | "name": "stderr", 733 | "output_type": "stream", 734 | "text": [ 735 | " 3%|███▋ | 1/36 [00:01<00:40, 1.17s/it]\n" 736 | ] 737 | }, 738 | { 739 | "name": "stdout", 740 | "output_type": "stream", 741 | "text": [ 742 | "Bootstrapped 1 full traces after 1 examples for up to 1 rounds, amounting to 1 attempts.\n" 743 | ] 744 | }, 745 | { 746 | "name": "stderr", 747 | "output_type": "stream", 748 | "text": [ 749 | "Average Metric: 32 / 36 (88.9): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 36/36 [00:07<00:00, 4.80it/s]\n", 750 | "2024/11/18 09:17:42 INFO dspy.evaluate.evaluate: Average Metric: 32 / 36 (88.9%)\n" 751 | ] 752 | }, 753 | { 754 | "name": "stdout", 755 | "output_type": "stream", 756 | "text": [ 757 | "New best score: 88.89 for seed 2\n", 758 | "Scores so far: [80.56, 86.11, 83.33, 86.11, 86.11, 88.89]\n", 759 | "Best score so far: 88.89\n" 760 | ] 761 | }, 762 | { 763 | "name": "stderr", 764 | "output_type": "stream", 765 | "text": [ 766 | " 14%|██████████████████▎ | 5/36 [00:08<00:53, 1.72s/it]\n" 767 | ] 768 | }, 769 | { 770 | "name": "stdout", 771 | "output_type": "stream", 772 | "text": [ 773 | "Bootstrapped 4 full traces after 5 examples for up to 1 rounds, amounting to 5 attempts.\n" 774 | ] 775 | }, 776 | { 777 | "name": "stderr", 778 | "output_type": "stream", 779 | "text": [ 780 | "Average Metric: 29 / 36 (80.6): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 36/36 [00:07<00:00, 5.10it/s]\n", 781 | "2024/11/18 09:17:58 INFO dspy.evaluate.evaluate: Average Metric: 29 / 36 (80.6%)\n" 782 | ] 783 | }, 784 | { 785 | "name": "stdout", 786 | "output_type": "stream", 787 | "text": [ 788 | "Scores so far: [80.56, 86.11, 83.33, 86.11, 86.11, 88.89, 80.56]\n", 789 | "Best score so far: 88.89\n" 790 | ] 791 | }, 792 | { 793 | "name": "stderr", 794 | "output_type": "stream", 795 | "text": [ 796 | " 19%|█████████████████████████▋ | 7/36 [00:11<00:46, 1.60s/it]\n" 797 | ] 798 | }, 799 | { 800 | "name": "stdout", 801 | "output_type": "stream", 802 | "text": [ 803 | "Bootstrapped 4 full traces after 7 examples for up to 1 rounds, amounting to 7 attempts.\n" 804 | ] 805 | }, 806 | { 807 | "name": "stderr", 808 | "output_type": "stream", 809 | "text": [ 810 | "Average Metric: 31 / 36 (86.1): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 36/36 [00:06<00:00, 5.33it/s]\n", 811 | "2024/11/18 09:18:16 INFO dspy.evaluate.evaluate: Average Metric: 31 / 36 (86.1%)\n" 812 | ] 813 | }, 814 | { 815 | "name": "stdout", 816 | "output_type": "stream", 817 | "text": [ 818 | "Scores so far: [80.56, 86.11, 83.33, 86.11, 86.11, 88.89, 80.56, 86.11]\n", 819 | "Best score so far: 88.89\n", 820 | "8 candidate programs found.\n" 821 | ] 822 | } 823 | ], 824 | "source": [ 825 | "cot_few_shot_rs_optimized = optimizer.compile(cot_predictor, trainset=train_examples)" 826 | ] 827 | }, 828 | { 829 | "cell_type": "markdown", 830 | "id": "6c393252-33c3-42d4-82cb-02072a1e757d", 831 | "metadata": {}, 832 | "source": [ 833 | "## Peek under the hood of the source code Implementation" 834 | ] 835 | }, 836 | { 837 | "cell_type": "markdown", 838 | "id": "f463a5df-875c-47b2-b8b1-e857e914714c", 839 | "metadata": {}, 840 | "source": [ 841 | "### Source code\n", 842 | "From: https://github.com/stanfordnlp/dspy/blob/main/dspy/teleprompt/random_search.py\n", 843 | "\n", 844 | " assert seed >= 0, seed\n", 845 | "\n", 846 | " random.Random(seed).shuffle(trainset_copy)\n", 847 | " size = random.Random(seed).randint(self.min_num_samples, self.max_num_samples)\n", 848 | "\n", 849 | " optimizer = BootstrapFewShot(\n", 850 | " metric=self.metric,\n", 851 | " metric_threshold=self.metric_threshold,\n", 852 | " max_bootstrapped_demos=size,\n", 853 | " max_labeled_demos=self.max_labeled_demos,\n", 854 | " teacher_settings=self.teacher_settings,\n", 855 | " max_rounds=self.max_rounds,\n", 856 | " max_errors=self.max_errors,\n", 857 | " )\n", 858 | "\n", 859 | " program = optimizer.compile(student, teacher=teacher, trainset=trainset_copy)\n", 860 | "\n", 861 | " evaluate = Evaluate(\n", 862 | " devset=self.valset,\n", 863 | " metric=self.metric,\n", 864 | " num_threads=self.num_threads,\n", 865 | " max_errors=self.max_errors,\n", 866 | " display_table=False,\n", 867 | " display_progress=True,\n", 868 | " )\n", 869 | "\n", 870 | " score, subscores = evaluate(program, return_all_scores=True)\n", 871 | "\n", 872 | " all_subscores.append(subscores)" 873 | ] 874 | }, 875 | { 876 | "cell_type": "markdown", 877 | "id": "cc4e963c-4c49-469c-9ebb-3486bd40cad8", 878 | "metadata": {}, 879 | "source": [ 880 | "### My own summary\n", 881 | "Given the number of programs we will generate each time a different seed and run BootStrapFewShot with that " 882 | ] 883 | }, 884 | { 885 | "cell_type": "markdown", 886 | "id": "c8f8ed4e-59fe-4657-a760-61e7d212c99f", 887 | "metadata": {}, 888 | "source": [ 889 | "# Evaluation" 890 | ] 891 | }, 892 | { 893 | "cell_type": "markdown", 894 | "id": "97081940-62e1-470a-b227-c5465b1ddc7c", 895 | "metadata": {}, 896 | "source": [ 897 | "## Single Evaluation" 898 | ] 899 | }, 900 | { 901 | "cell_type": "code", 902 | "execution_count": 13, 903 | "id": "16175d77-0aa0-4265-8da0-b8ad3dd7325a", 904 | "metadata": {}, 905 | "outputs": [ 906 | { 907 | "name": "stdout", 908 | "output_type": "stream", 909 | "text": [ 910 | "Example({'customer_message': 'show flights from minneapolis to kansas city', 'answer': 'flight', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no'}) (input_keys={'customer_message', 'intent_labels'})\n", 911 | "Prediction(\n", 912 | " rationale='identify the customer\\'s intent. The customer is asking to see flights from one city (Minneapolis) to another (Kansas City). This indicates that they are looking for information related to flights, specifically the availability of flights between these two locations. The intent clearly aligns with searching for flights, which is represented by the label \"flight.\"',\n", 913 | " answer='flight'\n", 914 | ")\n", 915 | "Customer message: \t show flights from minneapolis to kansas city\n", 916 | "\n", 917 | "Gold Response: \t flight\n", 918 | "\n", 919 | "Predicted Response: \t flight\n", 920 | "\n", 921 | "Exact match score: 1.00\n" 922 | ] 923 | } 924 | ], 925 | "source": [ 926 | "from dspy.evaluate import answer_exact_match\n", 927 | "\n", 928 | "# Instantiate the metric.\n", 929 | "metric = answer_exact_match\n", 930 | "\n", 931 | "example = test_examples[0]\n", 932 | "# Produce a prediction from our `cot` module, using the `example` above as input.\n", 933 | "print(example)\n", 934 | "pred = cot_predictor(**example.inputs())\n", 935 | "print(pred)\n", 936 | "\n", 937 | "# Compute the metric score for the prediction.\n", 938 | "score = metric(example, pred)\n", 939 | "\n", 940 | "print(f\"Customer message: \\t {example.customer_message}\\n\")\n", 941 | "print(f\"Gold Response: \\t {example.answer}\\n\")\n", 942 | "print(f\"Predicted Response: \\t {pred.answer}\\n\")\n", 943 | "print(f\"Exact match score: {score:.2f}\")" 944 | ] 945 | }, 946 | { 947 | "cell_type": "markdown", 948 | "id": "8b6a33cf-0fdc-458c-a698-6cb2e0fe1732", 949 | "metadata": {}, 950 | "source": [ 951 | "## Setup Evaluation" 952 | ] 953 | }, 954 | { 955 | "cell_type": "code", 956 | "execution_count": 14, 957 | "id": "64bc0f61-727b-4a12-88e2-71d7bea426ba", 958 | "metadata": {}, 959 | "outputs": [ 960 | { 961 | "name": "stdout", 962 | "output_type": "stream", 963 | "text": [ 964 | "36\n" 965 | ] 966 | } 967 | ], 968 | "source": [ 969 | "from dspy.evaluate.evaluate import Evaluate\n", 970 | "# Set up the `evaluate_atis` function. We'll use this many times below.\n", 971 | "print(len(train_examples))\n", 972 | "evaluate_atis = Evaluate(devset=test_examples, num_threads=8, display_progress=True, display_table=5, provide_traceback=True)" 973 | ] 974 | }, 975 | { 976 | "cell_type": "markdown", 977 | "id": "7688f4de-726c-426a-a048-fbcbc3723a0e", 978 | "metadata": {}, 979 | "source": [ 980 | "## Evaluate zero shot CoT " 981 | ] 982 | }, 983 | { 984 | "cell_type": "code", 985 | "execution_count": 15, 986 | "id": "06aca234-925e-4e8c-abb7-f9de1d54e4a4", 987 | "metadata": {}, 988 | "outputs": [ 989 | { 990 | "name": "stderr", 991 | "output_type": "stream", 992 | "text": [ 993 | "Average Metric: 37 / 42 (88.1): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 42/42 [00:08<00:00, 5.17it/s]\n", 994 | "2024/11/18 09:19:01 INFO dspy.evaluate.evaluate: Average Metric: 37 / 42 (88.1%)\n" 995 | ] 996 | }, 997 | { 998 | "data": { 999 | "text/html": [ 1000 | "\n", 1014 | "\n", 1015 | " \n", 1016 | " \n", 1017 | " \n", 1018 | " \n", 1019 | " \n", 1020 | " \n", 1021 | " \n", 1022 | " \n", 1023 | " \n", 1024 | " \n", 1025 | " \n", 1026 | " \n", 1027 | " \n", 1028 | " \n", 1029 | " \n", 1030 | " \n", 1031 | " \n", 1032 | " \n", 1033 | " \n", 1034 | " \n", 1035 | " \n", 1036 | " \n", 1037 | " \n", 1038 | " \n", 1039 | " \n", 1040 | " \n", 1041 | " \n", 1042 | " \n", 1043 | " \n", 1044 | " \n", 1045 | " \n", 1046 | " \n", 1047 | " \n", 1048 | " \n", 1049 | " \n", 1050 | " \n", 1051 | " \n", 1052 | " \n", 1053 | " \n", 1054 | " \n", 1055 | " \n", 1056 | " \n", 1057 | " \n", 1058 | " \n", 1059 | " \n", 1060 | " \n", 1061 | " \n", 1062 | " \n", 1063 | " \n", 1064 | " \n", 1065 | " \n", 1066 | " \n", 1067 | " \n", 1068 | " \n", 1069 | " \n", 1070 | " \n", 1071 | " \n", 1072 | " \n", 1073 | "
 customer_messageexample_answerintent_labelsrationalepred_answeranswer_exact_match
0show flights from minneapolis to kansas cityflightflight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_noidentify the customer's intent. The customer is asking to see flights from one city (Minneapolis) to another (Kansas City). This indicates that they are looking...flight✔️ [True]
1which flights leave on wednesday april thirteenth from indianapolis and arrive in montreal in the morningflightflight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_noproduce the answer. We need to identify the main focus of the customer's message. The customer is inquiring about flights, specifically looking for details about...flight_time
2i would like flight information from phoenix to denverflightflight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_noidentify the customer's intent. The customer is asking for specific information regarding a flight from Phoenix to Denver. This request involves details about the flight,...flight✔️ [True]
3how much is coach flight from pittsburgh to atlantaairfareflight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_nodetermine the customer's intent. The customer is inquiring about the cost of a coach flight from Pittsburgh to Atlanta, which directly relates to airfare. The...airfare✔️ [True]
4list airfares for first class round trip from detroit to st. petersburgairfareflight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_noidentify the customer's intent. The customer is asking for information about airfares specifically for a first-class round trip flight from Detroit to St. Petersburg. This...airfare✔️ [True]
\n" 1074 | ], 1075 | "text/plain": [ 1076 | "" 1077 | ] 1078 | }, 1079 | "metadata": {}, 1080 | "output_type": "display_data" 1081 | }, 1082 | { 1083 | "data": { 1084 | "text/html": [ 1085 | "\n", 1086 | "
\n", 1092 | " ... 37 more rows not displayed ...\n", 1093 | "
\n", 1094 | " " 1095 | ], 1096 | "text/plain": [ 1097 | "" 1098 | ] 1099 | }, 1100 | "metadata": {}, 1101 | "output_type": "display_data" 1102 | }, 1103 | { 1104 | "data": { 1105 | "text/plain": [ 1106 | "88.1" 1107 | ] 1108 | }, 1109 | "execution_count": 15, 1110 | "metadata": {}, 1111 | "output_type": "execute_result" 1112 | } 1113 | ], 1114 | "source": [ 1115 | "# Evaluate the program with the `answer_exact_match` metric.\n", 1116 | "# Launch evaluation.\n", 1117 | "evaluate_atis(cot_predictor, metric=metric)\n" 1118 | ] 1119 | }, 1120 | { 1121 | "cell_type": "markdown", 1122 | "id": "eefa60d2-dac0-4646-afed-ee34e6675619", 1123 | "metadata": {}, 1124 | "source": [ 1125 | "## Evaluate few shot CoT" 1126 | ] 1127 | }, 1128 | { 1129 | "cell_type": "code", 1130 | "execution_count": 16, 1131 | "id": "142910a7-ef4a-44c0-b48d-b9e999e74553", 1132 | "metadata": {}, 1133 | "outputs": [ 1134 | { 1135 | "name": "stderr", 1136 | "output_type": "stream", 1137 | "text": [ 1138 | "Average Metric: 33 / 42 (78.6): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 42/42 [00:14<00:00, 2.98it/s]\n", 1139 | "2024/11/18 09:19:22 INFO dspy.evaluate.evaluate: Average Metric: 33 / 42 (78.6%)\n" 1140 | ] 1141 | }, 1142 | { 1143 | "data": { 1144 | "text/html": [ 1145 | "\n", 1159 | "\n", 1160 | " \n", 1161 | " \n", 1162 | " \n", 1163 | " \n", 1164 | " \n", 1165 | " \n", 1166 | " \n", 1167 | " \n", 1168 | " \n", 1169 | " \n", 1170 | " \n", 1171 | " \n", 1172 | " \n", 1173 | " \n", 1174 | " \n", 1175 | " \n", 1176 | " \n", 1177 | " \n", 1178 | " \n", 1179 | " \n", 1180 | " \n", 1181 | " \n", 1182 | " \n", 1183 | " \n", 1184 | " \n", 1185 | " \n", 1186 | " \n", 1187 | " \n", 1188 | " \n", 1189 | " \n", 1190 | " \n", 1191 | " \n", 1192 | " \n", 1193 | " \n", 1194 | " \n", 1195 | " \n", 1196 | " \n", 1197 | " \n", 1198 | " \n", 1199 | " \n", 1200 | " \n", 1201 | " \n", 1202 | " \n", 1203 | " \n", 1204 | " \n", 1205 | " \n", 1206 | " \n", 1207 | " \n", 1208 | " \n", 1209 | " \n", 1210 | " \n", 1211 | " \n", 1212 | " \n", 1213 | " \n", 1214 | " \n", 1215 | " \n", 1216 | " \n", 1217 | " \n", 1218 | "
 customer_messageexample_answerintent_labelsrationalepred_answeranswer_exact_match
0show flights from minneapolis to kansas cityflightflight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_noproduce the answer. We need to identify the main intent of the customer message. The customer is asking to see flights from one city to...flight✔️ [True]
1which flights leave on wednesday april thirteenth from indianapolis and arrive in montreal in the morningflightflight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_nodetermine the intent of the customer's message. The customer is asking for information about flights that leave from Indianapolis and arrive in Montreal on a...flight_time
2i would like flight information from phoenix to denverflightflight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_noproduce the answer. We need to identify the main focus of the customer's request. The customer is asking for flight information, which typically includes details...flight✔️ [True]
3how much is coach flight from pittsburgh to atlantaairfareflight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_noproduce the answer. We need to identify the main focus of the customer's inquiry. The customer is asking about the cost of a coach flight...airfare✔️ [True]
4list airfares for first class round trip from detroit to st. petersburgairfareflight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_noproduce the answer. We need to identify the main focus of the customer's request. The customer is asking for airfares specifically for first-class round trips,...airfare✔️ [True]
\n" 1219 | ], 1220 | "text/plain": [ 1221 | "" 1222 | ] 1223 | }, 1224 | "metadata": {}, 1225 | "output_type": "display_data" 1226 | }, 1227 | { 1228 | "data": { 1229 | "text/html": [ 1230 | "\n", 1231 | "
\n", 1237 | " ... 37 more rows not displayed ...\n", 1238 | "
\n", 1239 | " " 1240 | ], 1241 | "text/plain": [ 1242 | "" 1243 | ] 1244 | }, 1245 | "metadata": {}, 1246 | "output_type": "display_data" 1247 | }, 1248 | { 1249 | "data": { 1250 | "text/plain": [ 1251 | "78.57" 1252 | ] 1253 | }, 1254 | "execution_count": 16, 1255 | "metadata": {}, 1256 | "output_type": "execute_result" 1257 | } 1258 | ], 1259 | "source": [ 1260 | "evaluate_atis(few_shot_model, metric=metric)" 1261 | ] 1262 | }, 1263 | { 1264 | "cell_type": "markdown", 1265 | "id": "4e91d4f9-a39f-4d5a-8d0a-bec4c0ba6501", 1266 | "metadata": {}, 1267 | "source": [ 1268 | "## Evaluate BootstrapedFewShot" 1269 | ] 1270 | }, 1271 | { 1272 | "cell_type": "code", 1273 | "execution_count": 17, 1274 | "id": "5f20ef32-d986-4fda-875a-deaa690651ac", 1275 | "metadata": {}, 1276 | "outputs": [ 1277 | { 1278 | "name": "stderr", 1279 | "output_type": "stream", 1280 | "text": [ 1281 | "Average Metric: 37 / 42 (88.1): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 42/42 [00:10<00:00, 4.06it/s]\n", 1282 | "2024/11/18 09:19:43 INFO dspy.evaluate.evaluate: Average Metric: 37 / 42 (88.1%)\n" 1283 | ] 1284 | }, 1285 | { 1286 | "data": { 1287 | "text/html": [ 1288 | "\n", 1302 | "\n", 1303 | " \n", 1304 | " \n", 1305 | " \n", 1306 | " \n", 1307 | " \n", 1308 | " \n", 1309 | " \n", 1310 | " \n", 1311 | " \n", 1312 | " \n", 1313 | " \n", 1314 | " \n", 1315 | " \n", 1316 | " \n", 1317 | " \n", 1318 | " \n", 1319 | " \n", 1320 | " \n", 1321 | " \n", 1322 | " \n", 1323 | " \n", 1324 | " \n", 1325 | " \n", 1326 | " \n", 1327 | " \n", 1328 | " \n", 1329 | " \n", 1330 | " \n", 1331 | " \n", 1332 | " \n", 1333 | " \n", 1334 | " \n", 1335 | " \n", 1336 | " \n", 1337 | " \n", 1338 | " \n", 1339 | " \n", 1340 | " \n", 1341 | " \n", 1342 | " \n", 1343 | " \n", 1344 | " \n", 1345 | " \n", 1346 | " \n", 1347 | " \n", 1348 | " \n", 1349 | " \n", 1350 | " \n", 1351 | " \n", 1352 | " \n", 1353 | " \n", 1354 | " \n", 1355 | " \n", 1356 | " \n", 1357 | " \n", 1358 | " \n", 1359 | " \n", 1360 | " \n", 1361 | "
 customer_messageexample_answerintent_labelsrationalepred_answeranswer_exact_match
0show flights from minneapolis to kansas cityflightflight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_noidentify the main focus of the customer's request. The customer is asking to see flights from Minneapolis to Kansas City, which indicates a clear interest...flight✔️ [True]
1which flights leave on wednesday april thirteenth from indianapolis and arrive in montreal in the morningflightflight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_noidentify the intent of the customer's message. The customer is inquiring about flights that leave on a specific day (Wednesday, April 13th) from Indianapolis to...flight✔️ [True]
2i would like flight information from phoenix to denverflightflight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_noidentify the main focus of the customer's request. The customer is explicitly asking for flight information regarding a route from Phoenix to Denver. This indicates...flight✔️ [True]
3how much is coach flight from pittsburgh to atlantaairfareflight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_noidentify the main focus of the customer's inquiry. The customer is asking about the cost of a coach flight from Pittsburgh to Atlanta, which indicates...airfare✔️ [True]
4list airfares for first class round trip from detroit to st. petersburgairfareflight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_noidentify the main focus of the customer's request. The customer is asking for airfares specifically for first-class round trip flights from Detroit to St. Petersburg....airfare✔️ [True]
\n" 1362 | ], 1363 | "text/plain": [ 1364 | "" 1365 | ] 1366 | }, 1367 | "metadata": {}, 1368 | "output_type": "display_data" 1369 | }, 1370 | { 1371 | "data": { 1372 | "text/html": [ 1373 | "\n", 1374 | "
\n", 1380 | " ... 37 more rows not displayed ...\n", 1381 | "
\n", 1382 | " " 1383 | ], 1384 | "text/plain": [ 1385 | "" 1386 | ] 1387 | }, 1388 | "metadata": {}, 1389 | "output_type": "display_data" 1390 | }, 1391 | { 1392 | "data": { 1393 | "text/plain": [ 1394 | "88.1" 1395 | ] 1396 | }, 1397 | "execution_count": 17, 1398 | "metadata": {}, 1399 | "output_type": "execute_result" 1400 | } 1401 | ], 1402 | "source": [ 1403 | "evaluate_atis(cot_few_shot_optimized, metric=metric)" 1404 | ] 1405 | }, 1406 | { 1407 | "cell_type": "markdown", 1408 | "id": "67f56d55-d44f-4a9a-b5ba-3e93670fc8fd", 1409 | "metadata": {}, 1410 | "source": [ 1411 | "## Evaluate Boostraped Random Search" 1412 | ] 1413 | }, 1414 | { 1415 | "cell_type": "code", 1416 | "execution_count": 18, 1417 | "id": "8a0491f8-dd36-4910-8db5-f3036a00960a", 1418 | "metadata": {}, 1419 | "outputs": [ 1420 | { 1421 | "name": "stderr", 1422 | "output_type": "stream", 1423 | "text": [ 1424 | "Average Metric: 40 / 42 (95.2): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 42/42 [00:09<00:00, 4.22it/s]\n", 1425 | "2024/11/18 09:19:58 INFO dspy.evaluate.evaluate: Average Metric: 40 / 42 (95.2%)\n" 1426 | ] 1427 | }, 1428 | { 1429 | "data": { 1430 | "text/html": [ 1431 | "\n", 1445 | "\n", 1446 | " \n", 1447 | " \n", 1448 | " \n", 1449 | " \n", 1450 | " \n", 1451 | " \n", 1452 | " \n", 1453 | " \n", 1454 | " \n", 1455 | " \n", 1456 | " \n", 1457 | " \n", 1458 | " \n", 1459 | " \n", 1460 | " \n", 1461 | " \n", 1462 | " \n", 1463 | " \n", 1464 | " \n", 1465 | " \n", 1466 | " \n", 1467 | " \n", 1468 | " \n", 1469 | " \n", 1470 | " \n", 1471 | " \n", 1472 | " \n", 1473 | " \n", 1474 | " \n", 1475 | " \n", 1476 | " \n", 1477 | " \n", 1478 | " \n", 1479 | " \n", 1480 | " \n", 1481 | " \n", 1482 | " \n", 1483 | " \n", 1484 | " \n", 1485 | " \n", 1486 | " \n", 1487 | " \n", 1488 | " \n", 1489 | " \n", 1490 | " \n", 1491 | " \n", 1492 | " \n", 1493 | " \n", 1494 | " \n", 1495 | " \n", 1496 | " \n", 1497 | " \n", 1498 | " \n", 1499 | " \n", 1500 | " \n", 1501 | " \n", 1502 | " \n", 1503 | " \n", 1504 | "
 customer_messageexample_answerintent_labelsrationalepred_answeranswer_exact_match
0show flights from minneapolis to kansas cityflightflight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_nodetermine the intent of the customer's message. The customer is asking to see flights between two specific cities, which indicates they are looking for information...flight✔️ [True]
1which flights leave on wednesday april thirteenth from indianapolis and arrive in montreal in the morningflightflight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_nodetermine the intent of the customer's message. The customer is asking for specific flight information regarding flights that leave from a particular city on a...flight✔️ [True]
2i would like flight information from phoenix to denverflightflight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_nodetermine the intent of the customer's message. The customer is requesting flight information between two specific cities, which indicates they are looking for details related...flight✔️ [True]
3how much is coach flight from pittsburgh to atlantaairfareflight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_nodetermine the intent of the customer's message. The customer is inquiring about the cost of a specific type of flight (coach) from one city to...airfare✔️ [True]
4list airfares for first class round trip from detroit to st. petersburgairfareflight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_nodetermine the intent of the customer's message. The customer is asking for information about airfares specifically for first class round trip flights between two cities....airfare✔️ [True]
\n" 1505 | ], 1506 | "text/plain": [ 1507 | "" 1508 | ] 1509 | }, 1510 | "metadata": {}, 1511 | "output_type": "display_data" 1512 | }, 1513 | { 1514 | "data": { 1515 | "text/html": [ 1516 | "\n", 1517 | "
\n", 1523 | " ... 37 more rows not displayed ...\n", 1524 | "
\n", 1525 | " " 1526 | ], 1527 | "text/plain": [ 1528 | "" 1529 | ] 1530 | }, 1531 | "metadata": {}, 1532 | "output_type": "display_data" 1533 | }, 1534 | { 1535 | "data": { 1536 | "text/plain": [ 1537 | "95.24" 1538 | ] 1539 | }, 1540 | "execution_count": 18, 1541 | "metadata": {}, 1542 | "output_type": "execute_result" 1543 | } 1544 | ], 1545 | "source": [ 1546 | "evaluate_atis(cot_few_shot_rs_optimized, metric=metric)" 1547 | ] 1548 | }, 1549 | { 1550 | "cell_type": "markdown", 1551 | "id": "bca34c63-f758-48c3-b7d8-c25b54045125", 1552 | "metadata": {}, 1553 | "source": [ 1554 | "# Save / Load models" 1555 | ] 1556 | }, 1557 | { 1558 | "cell_type": "code", 1559 | "execution_count": 21, 1560 | "id": "b2476c94-4ff1-4467-a9ee-acf8e4121cf1", 1561 | "metadata": {}, 1562 | "outputs": [], 1563 | "source": [ 1564 | "cot_predictor.save(\"cot_zero_shot.json\")\n", 1565 | "few_shot_model.save(\"cot_few_shot.json\")\n", 1566 | "cot_few_shot_optimized.save(\"cot_boostraped_few_shot.json\")\n", 1567 | "cot_few_shot_rs_optimized.save(\"cot_bootstraped_rs_few_shot.json\")" 1568 | ] 1569 | }, 1570 | { 1571 | "cell_type": "code", 1572 | "execution_count": 22, 1573 | "id": "ee875833-b02d-4ef8-a218-3b242dba4292", 1574 | "metadata": {}, 1575 | "outputs": [], 1576 | "source": [ 1577 | "cot_predictor.load(\"cot_zero_shot.json\")\n", 1578 | "few_shot_model.load(\"cot_few_shot.json\")\n", 1579 | "cot_few_shot_optimized.load(\"cot_boostraped_few_shot.json\")\n", 1580 | "cot_few_shot_rs_optimized.load(\"cot_bootstraped_rs_few_shot.json\")" 1581 | ] 1582 | }, 1583 | { 1584 | "cell_type": "code", 1585 | "execution_count": null, 1586 | "id": "63c9f057-c9bb-40d7-812b-8f570dc78d05", 1587 | "metadata": {}, 1588 | "outputs": [], 1589 | "source": [] 1590 | }, 1591 | { 1592 | "cell_type": "code", 1593 | "execution_count": null, 1594 | "id": "7f29f4cd-3511-4e4c-9bd3-c3a6adf687e5", 1595 | "metadata": {}, 1596 | "outputs": [], 1597 | "source": [] 1598 | }, 1599 | { 1600 | "cell_type": "code", 1601 | "execution_count": 27, 1602 | "id": "40ca28e6-e75f-4976-9a1c-2237102b7932", 1603 | "metadata": {}, 1604 | "outputs": [ 1605 | { 1606 | "data": { 1607 | "text/plain": [ 1608 | "[Predict(StringSignature(customer_message, intent_labels -> rationale, answer\n", 1609 | " instructions='Classify the customer message into one of the intent labels.\\nThe output should be only the predicted class as a single intent label.'\n", 1610 | " customer_message = Field(annotation=str required=True json_schema_extra={'desc': 'Customer message during customer service interaction', '__dspy_field_type': 'input', 'prefix': 'Customer Message:'})\n", 1611 | " intent_labels = Field(annotation=str required=True json_schema_extra={'desc': 'Labels that represent customer intent', '__dspy_field_type': 'input', 'prefix': 'Intent Labels:'})\n", 1612 | " rationale = Field(annotation=str required=True json_schema_extra={'prefix': \"Reasoning: Let's think step by step in order to\", 'desc': '${produce the answer}. We ...', '__dspy_field_type': 'output'})\n", 1613 | " answer = Field(annotation=str required=True json_schema_extra={'desc': \"a label best matching customer's intent \", '__dspy_field_type': 'output', 'prefix': 'Answer:'})\n", 1614 | " ))]" 1615 | ] 1616 | }, 1617 | "execution_count": 27, 1618 | "metadata": {}, 1619 | "output_type": "execute_result" 1620 | } 1621 | ], 1622 | "source": [ 1623 | "cot_predictor.parameters()" 1624 | ] 1625 | }, 1626 | { 1627 | "cell_type": "code", 1628 | "execution_count": 30, 1629 | "id": "c176b4f0-ce8c-40a0-b905-be11ff0d4332", 1630 | "metadata": {}, 1631 | "outputs": [ 1632 | { 1633 | "data": { 1634 | "text/plain": [ 1635 | "[Example({'customer_message': 'what kind of plane flies from boston to pittsburgh after noon', 'answer': 'aircraft', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no'}) (input_keys={'customer_message', 'intent_labels'}),\n", 1636 | " Example({'customer_message': 'what are the restrictions on the cheapest one way fare between boston and oakland', 'answer': 'restriction', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no'}) (input_keys={'customer_message', 'intent_labels'}),\n", 1637 | " Example({'customer_message': 'fine can you give me information on ground transportation in washington dc to downtown', 'answer': 'ground_service', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no'}) (input_keys={'customer_message', 'intent_labels'}),\n", 1638 | " Example({'customer_message': 'what is the distance from toronto international airport to toronto', 'answer': 'distance', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no'}) (input_keys={'customer_message', 'intent_labels'}),\n", 1639 | " Example({'customer_message': 'list flights from denver to philadelphia', 'answer': 'flight', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no'}) (input_keys={'customer_message', 'intent_labels'}),\n", 1640 | " Example({'customer_message': 'what is the meaning of restriction ap80', 'answer': 'abbreviation', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no'}) (input_keys={'customer_message', 'intent_labels'}),\n", 1641 | " Example({'customer_message': 'what is restriction ap57', 'answer': 'restriction', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no'}) (input_keys={'customer_message', 'intent_labels'}),\n", 1642 | " Example({'customer_message': 'what time zone is denver in', 'answer': 'city', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no'}) (input_keys={'customer_message', 'intent_labels'}),\n", 1643 | " Example({'customer_message': 'airports', 'answer': 'airport', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no'}) (input_keys={'customer_message', 'intent_labels'}),\n", 1644 | " Example({'customer_message': 'list daily flights from boston to oakland using twa', 'answer': 'flight', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no'}) (input_keys={'customer_message', 'intent_labels'})]" 1645 | ] 1646 | }, 1647 | "execution_count": 30, 1648 | "metadata": {}, 1649 | "output_type": "execute_result" 1650 | } 1651 | ], 1652 | "source": [ 1653 | "few_shot_model.demos" 1654 | ] 1655 | }, 1656 | { 1657 | "cell_type": "code", 1658 | "execution_count": 31, 1659 | "id": "a1df2ccc-04aa-48e3-a2c7-daf235662e7a", 1660 | "metadata": {}, 1661 | "outputs": [ 1662 | { 1663 | "data": { 1664 | "text/plain": [ 1665 | "[Example({'augmented': True, 'customer_message': 'show me the flights from los angeles to pittsburgh which arrive at pittsburgh on monday', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no', 'rationale': 'produce the answer. We need to identify the main focus of the customer\\'s request. The customer is asking for flights from Los Angeles to Pittsburgh, specifically mentioning the arrival day (Monday). This indicates that the intent is related to finding flights. The most relevant label that matches this request is \"flight\".', 'answer': 'flight'}) (input_keys=None),\n", 1666 | " Example({'augmented': True, 'customer_message': 'what time does the flight leave denver going to san francisco on continental airlines', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no', 'rationale': 'determine the intent of the customer\\'s message. The customer is asking for the departure time of a specific flight from Denver to San Francisco on Continental Airlines. This indicates that the primary focus of the inquiry is related to the timing of the flight. \\n\\nThe intent labels include various categories, but the most relevant one for this question is \"flight_time,\" as it directly pertains to the time of departure for the flight.', 'answer': 'flight_time'}) (input_keys=None),\n", 1667 | " Example({'augmented': True, 'customer_message': 'show me times for coach flights between boston and baltimore on wednesday', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no', 'rationale': 'produce the answer. We need to identify the main focus of the customer\\'s request. The customer is asking for \"times\" related to \"coach flights\" between two cities on a specific day. This indicates that the customer is interested in the schedule or timing of flights. Therefore, the intent label that best matches this request is related to flight times.', 'answer': 'flight_time'}) (input_keys=None),\n", 1668 | " Example({'augmented': True, 'customer_message': 'how much does it cost to fly on american from dallas to baltimore', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no', 'rationale': 'produce the answer. We need to identify the main focus of the customer\\'s inquiry. The customer is asking about the cost of a flight on American Airlines from Dallas to Baltimore. This indicates that the primary intent is related to the fare or price of the flight. Therefore, the best matching intent label for this inquiry is \"airfare.\"', 'answer': 'airfare'}) (input_keys=None),\n", 1669 | " Example({'augmented': True, 'customer_message': 'please show me fares from denver to san francisco', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no', 'rationale': \"produce the answer. We need to identify the main focus of the customer's request. The customer is asking for fares, which indicates they are interested in the cost of flights from Denver to San Francisco. This aligns with the intent label related to airfare.\", 'answer': 'airfare'}) (input_keys=None),\n", 1670 | " Example({'augmented': True, 'customer_message': \"i 'd like to know what type of plane flies from boston to dallas fort worth after 530\", 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no', 'rationale': \"determine the intent of the customer's message. The customer is asking about the type of plane that operates on a specific route (from Boston to Dallas Fort Worth) and at a specific time (after 5:30). This indicates that the customer is interested in information related to the aircraft used for that flight.\", 'answer': 'aircraft'}) (input_keys=None),\n", 1671 | " Example({'augmented': True, 'customer_message': 'what is the smallest aircraft that flies from pittsburgh to baltimore arriving may seventh', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no', 'rationale': \"determine the intent of the customer's message. The customer is asking about the smallest aircraft that operates on a specific route (from Pittsburgh to Baltimore) and on a specific date (arriving May seventh). This indicates that the customer is interested in information related to the type of aircraft used for that flight.\", 'answer': 'aircraft'}) (input_keys=None),\n", 1672 | " Example({'augmented': True, 'customer_message': 'what types of ground transportation are there to san francisco airport', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no', 'rationale': 'identify the intent of the customer message. The customer is asking about the types of ground transportation available to San Francisco airport. This indicates that the focus is on transportation services rather than flights, fares, or other aviation-related queries. Therefore, the most appropriate intent label that matches this inquiry is related to ground services.', 'answer': 'ground_service'}) (input_keys=None),\n", 1673 | " Example({'augmented': True, 'customer_message': 'is there ground transportation from the milwaukee airport to the downtown area', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no', 'rationale': 'identify the intent of the customer message. The customer is inquiring about ground transportation options from the Milwaukee airport to the downtown area. This indicates a need for information related to ground services rather than flights or fares. Therefore, the most appropriate intent label for this message is related to ground services.', 'answer': 'ground_service'}) (input_keys=None),\n", 1674 | " Example({'augmented': True, 'customer_message': 'houston airports', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no', 'rationale': 'identify the intent of the customer message. The customer is asking about \"houston airports,\" which indicates they are looking for information related to airports in Houston. This aligns with the intent label that represents airport-related inquiries.', 'answer': 'airport'}) (input_keys=None)]" 1675 | ] 1676 | }, 1677 | "execution_count": 31, 1678 | "metadata": {}, 1679 | "output_type": "execute_result" 1680 | } 1681 | ], 1682 | "source": [ 1683 | "cot_few_shot_optimized.demos" 1684 | ] 1685 | }, 1686 | { 1687 | "cell_type": "code", 1688 | "execution_count": 78, 1689 | "id": "8956c58a-069d-4da9-92a2-865181e45305", 1690 | "metadata": {}, 1691 | "outputs": [ 1692 | { 1693 | "data": { 1694 | "text/plain": [ 1695 | "[Example({'augmented': True, 'customer_message': 'show me the flights from boston to oakland', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no', 'rationale': 'produce the answer. We need to identify the main request in the customer message. The customer is asking for information about flights from Boston to Oakland, which indicates they are looking for flight details. The most relevant intent label that matches this request is \"flight\".', 'answer': 'flight'}) (input_keys=None),\n", 1696 | " Example({'augmented': True, 'customer_message': 'does delta have an early afternoon flight from boston to san francisco', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no', 'rationale': \"identify the main request in the customer's message. The customer is asking about the availability of a specific flight operated by Delta Airlines from Boston to San Francisco during a particular time frame (early afternoon). This indicates an intent to inquire about a flight schedule rather than asking for flights, fares, or other related information. \\n\\nSince the customer is specifically asking about the existence of a flight, the best matching intent label is 'flight'.\", 'answer': 'flight'}) (input_keys=None),\n", 1697 | " Example({'augmented': True, 'customer_message': 'please give me the flight times the morning on united airlines for september twentieth from philadelphia to san francisco', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no', 'rationale': \"produce the answer. We need to identify the main focus of the customer's request. The customer is asking for flight times specifically for a flight on United Airlines from Philadelphia to San Francisco on a specific date, which indicates they are interested in the timing of the flights. Therefore, the intent label that best matches this request is flight_time.\", 'answer': 'flight_time'}) (input_keys=None),\n", 1698 | " Example({'augmented': True, 'customer_message': 'what times does continental depart from boston to san francisco', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no', 'rationale': \"determine the intent of the customer's message. The customer is asking for specific departure times for a flight operated by Continental Airlines from Boston to San Francisco. This request is focused on the timing of the flight, which falls under the category of flight times.\", 'answer': 'flight_time'}) (input_keys=None),\n", 1699 | " Example({'customer_message': 'can you list costs of denver rental cars', 'answer': 'ground_fare', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no'}) (input_keys={'customer_message', 'intent_labels'}),\n", 1700 | " Example({'customer_message': 'all flights and fares from pittsburgh to dallas round trip after 12 pm less than 1100 dollars', 'answer': 'flight+airfare', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no'}) (input_keys={'customer_message', 'intent_labels'}),\n", 1701 | " Example({'customer_message': 'what is the ap 57 restriction', 'answer': 'abbreviation', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no'}) (input_keys={'customer_message', 'intent_labels'}),\n", 1702 | " Example({'customer_message': 'what is the total seating capacity of all aircraft of american airlines', 'answer': 'capacity', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no'}) (input_keys={'customer_message', 'intent_labels'}),\n", 1703 | " Example({'customer_message': 'names of airports', 'answer': 'airport', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no'}) (input_keys={'customer_message', 'intent_labels'}),\n", 1704 | " Example({'customer_message': 'how much is the ground transportation between atlanta and downtown', 'answer': 'ground_fare', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no'}) (input_keys={'customer_message', 'intent_labels'}),\n", 1705 | " Example({'customer_message': 'what ground transportation is there in atlanta', 'answer': 'ground_service', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no'}) (input_keys={'customer_message', 'intent_labels'}),\n", 1706 | " Example({'customer_message': 'what are the seating capacities of planes between pittsburgh and baltimore', 'answer': 'capacity', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no'}) (input_keys={'customer_message', 'intent_labels'}),\n", 1707 | " Example({'customer_message': 'what are all the available meals', 'answer': 'meal', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no'}) (input_keys={'customer_message', 'intent_labels'}),\n", 1708 | " Example({'customer_message': 'show me all meals on flights from atlanta to washington', 'answer': 'meal', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no'}) (input_keys={'customer_message', 'intent_labels'}),\n", 1709 | " Example({'customer_message': 'airline and flight number from columbus to minneapolis', 'answer': 'airline+flight_no', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no'}) (input_keys={'customer_message', 'intent_labels'}),\n", 1710 | " Example({'customer_message': 'flight numbers from columbus to minneapolis tomorrow', 'answer': 'flight_no', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no'}) (input_keys={'customer_message', 'intent_labels'})]" 1711 | ] 1712 | }, 1713 | "execution_count": 78, 1714 | "metadata": {}, 1715 | "output_type": "execute_result" 1716 | } 1717 | ], 1718 | "source": [ 1719 | "cot_few_shot_rs_optimized.demos" 1720 | ] 1721 | }, 1722 | { 1723 | "cell_type": "markdown", 1724 | "id": "4a0856d6-bc3b-4c47-8ae3-6934975e0839", 1725 | "metadata": {}, 1726 | "source": [ 1727 | "# Remove bootstrapping" 1728 | ] 1729 | }, 1730 | { 1731 | "cell_type": "code", 1732 | "execution_count": 41, 1733 | "id": "d5f289d7-750a-4d4d-b959-c8ff74fa639a", 1734 | "metadata": {}, 1735 | "outputs": [ 1736 | { 1737 | "name": "stderr", 1738 | "output_type": "stream", 1739 | "text": [ 1740 | " 11%|██████████████▋ | 4/36 [00:01<00:11, 2.85it/s]\n" 1741 | ] 1742 | }, 1743 | { 1744 | "name": "stdout", 1745 | "output_type": "stream", 1746 | "text": [ 1747 | "Bootstrapped 4 full traces after 4 examples for up to 20 rounds, amounting to 8 attempts.\n" 1748 | ] 1749 | }, 1750 | { 1751 | "name": "stderr", 1752 | "output_type": "stream", 1753 | "text": [ 1754 | "Average Metric: 32 / 40 (80.0): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [00:11<00:00, 3.63it/s]\n", 1755 | "2024/11/17 06:42:28 INFO dspy.evaluate.evaluate: Average Metric: 32 / 40 (80.0%)\n" 1756 | ] 1757 | }, 1758 | { 1759 | "data": { 1760 | "text/html": [ 1761 | "\n", 1775 | "\n", 1776 | " \n", 1777 | " \n", 1778 | " \n", 1779 | " \n", 1780 | " \n", 1781 | " \n", 1782 | " \n", 1783 | " \n", 1784 | " \n", 1785 | " \n", 1786 | " \n", 1787 | " \n", 1788 | " \n", 1789 | " \n", 1790 | " \n", 1791 | " \n", 1792 | " \n", 1793 | " \n", 1794 | " \n", 1795 | " \n", 1796 | " \n", 1797 | " \n", 1798 | " \n", 1799 | " \n", 1800 | " \n", 1801 | " \n", 1802 | " \n", 1803 | " \n", 1804 | " \n", 1805 | " \n", 1806 | " \n", 1807 | " \n", 1808 | " \n", 1809 | " \n", 1810 | " \n", 1811 | " \n", 1812 | " \n", 1813 | " \n", 1814 | " \n", 1815 | " \n", 1816 | " \n", 1817 | " \n", 1818 | " \n", 1819 | " \n", 1820 | " \n", 1821 | " \n", 1822 | " \n", 1823 | " \n", 1824 | " \n", 1825 | " \n", 1826 | " \n", 1827 | " \n", 1828 | " \n", 1829 | " \n", 1830 | " \n", 1831 | " \n", 1832 | " \n", 1833 | " \n", 1834 | "
 customer_messageexample_answerintent_labelsrationalepred_answeranswer_exact_match
0list airports in new yorkairportflight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_nodetermine the intent of the customer's message. The customer is asking for a list of airports located in New York. This request is focused on...airport✔️ [True]
1show me flights from montreal to orlandoflightflight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_noproduce the answer. The customer is requesting information about flights from Montreal to Orlando. This indicates that the intent is related to finding flights between...flight✔️ [True]
2which flights leave chicago next tuesday and arrive in detroit around 6 pmflightflight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_nodetermine the intent of the customer's message. The customer is asking for information about flights that leave from Chicago next Tuesday and arrive in Detroit...flight✔️ [True]
3list the flights from dallas to baltimore arriving july firstflightflight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_nodetermine the intent of the customer's message. The customer is asking for a list of flights from Dallas to Baltimore that are arriving on a...flight✔️ [True]
4show flights tomorrow evening from milwaukee to st. louisflightflight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_noproduce the answer. The customer is requesting information about flights that are scheduled for tomorrow evening from Milwaukee to St. Louis. This indicates that the...flight✔️ [True]
\n" 1835 | ], 1836 | "text/plain": [ 1837 | "" 1838 | ] 1839 | }, 1840 | "metadata": {}, 1841 | "output_type": "display_data" 1842 | }, 1843 | { 1844 | "data": { 1845 | "text/html": [ 1846 | "\n", 1847 | "
\n", 1853 | " ... 35 more rows not displayed ...\n", 1854 | "
\n", 1855 | " " 1856 | ], 1857 | "text/plain": [ 1858 | "" 1859 | ] 1860 | }, 1861 | "metadata": {}, 1862 | "output_type": "display_data" 1863 | }, 1864 | { 1865 | "data": { 1866 | "text/plain": [ 1867 | "80.0" 1868 | ] 1869 | }, 1870 | "execution_count": 41, 1871 | "metadata": {}, 1872 | "output_type": "execute_result" 1873 | } 1874 | ], 1875 | "source": [ 1876 | "from dspy.evaluate import answer_exact_match as metric\n", 1877 | "from dspy.teleprompt import BootstrapFewShot\n", 1878 | "\n", 1879 | "optimizer = BootstrapFewShot(\n", 1880 | " metric=metric,\n", 1881 | " max_bootstrapped_demos=4, # no need for boostrapped\n", 1882 | " max_labeled_demos=40, # increase examples\n", 1883 | " max_rounds=20,\n", 1884 | ")\n", 1885 | "\n", 1886 | "# documentation is wrong - there is not valset: https://dspy.ai/deep-dive/optimizers/bootstrap-fewshot/\n", 1887 | "cot_few_shot_to_bootstrap = optimizer.compile(cot_predictor, trainset=train_examples)\n", 1888 | "\n", 1889 | "\n", 1890 | "evaluate_atis(cot_few_shot_to_bootstrap, metric=metric)\n", 1891 | "# bootrsapped_demos: generated demos?\n", 1892 | "# max_labeled_demos: input demos" 1893 | ] 1894 | } 1895 | ], 1896 | "metadata": { 1897 | "kernelspec": { 1898 | "display_name": "Python 3 (ipykernel)", 1899 | "language": "python", 1900 | "name": "python3" 1901 | }, 1902 | "language_info": { 1903 | "codemirror_mode": { 1904 | "name": "ipython", 1905 | "version": 3 1906 | }, 1907 | "file_extension": ".py", 1908 | "mimetype": "text/x-python", 1909 | "name": "python", 1910 | "nbconvert_exporter": "python", 1911 | "pygments_lexer": "ipython3", 1912 | "version": "3.10.12" 1913 | } 1914 | }, 1915 | "nbformat": 4, 1916 | "nbformat_minor": 5 1917 | } 1918 | --------------------------------------------------------------------------------