├── open_intent_classifier
├── __init__.py
├── utils.py
├── consts.py
├── embedder.py
└── model.py
├── .bumpversion.cfg
├── requirements.txt
├── tests
├── test_dspy_classifier.py
├── test_openai_classifier.py
├── test_devices.py
├── test_flan_t5_base.py
├── test_embeddings.py
├── test_flan_t5_small.py
└── test_smollm2.py
├── setup.py
├── .gitignore
├── README.md
└── notebooks
├── dspy_tutorial.ipynb
├── smollm2.ipynb
└── dspy_training.ipynb
/open_intent_classifier/__init__.py:
--------------------------------------------------------------------------------
# Package version string; setup.py imports it and bump2version rewrites it (see .bumpversion.cfg).
VERSION = "0.0.8"
2 |
3 |
--------------------------------------------------------------------------------
/.bumpversion.cfg:
--------------------------------------------------------------------------------
1 | [bumpversion]
2 | current_version = 0.0.8
3 | commit = True
4 | tag = True
5 |
6 | [bumpversion:file:open_intent_classifier/__init__.py]
7 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | transformers
2 | torch
3 | sentencepiece
4 | scikit-learn
5 | sentence-transformers
6 | bump2version
7 | jupyter
8 | datasets
9 | pandas
10 | python-dotenv
11 | openai
12 | dspy
13 | accelerate
14 |
--------------------------------------------------------------------------------
/open_intent_classifier/utils.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 |
def join_labels(labels):
    """Render ``labels`` as an ``Options:`` header followed by one ``# <label> `` line per label.

    The space before each newline is part of the prompt format and is kept on purpose.
    An empty ``labels`` yields just the header line.
    """
    option_lines = [f"# {label} \n" for label in labels]
    return "Options:\n" + "".join(option_lines)
9 |
def labels_to_str(labels: List[str]):
    """Join ``labels`` with a ``%`` delimiter, e.g. ``["a", "b"]`` -> ``"a%b"``."""
    delimiter = "%"
    return delimiter.join(labels)
--------------------------------------------------------------------------------
/tests/test_dspy_classifier.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from open_intent_classifier.model import DSPyClassifier
4 |
5 |
class TestDSPyClassifier(unittest.TestCase):
    """End-to-end check of DSPyClassifier, which by default talks to an OpenAI model (gpt-4o-mini)."""

    def test_basic(self):
        classifier = DSPyClassifier()
        labels = ["Cancel subscription", "Refund request"]
        text = "I want to cancel my subscription"
        result = classifier.predict(text, labels)
        # assertEqual rather than the deprecated assertEquals alias (removed in Python 3.12).
        self.assertEqual("Cancel subscription".lower(), result.class_name.lower())
14 |
15 |
--------------------------------------------------------------------------------
/tests/test_openai_classifier.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from open_intent_classifier.model import OpenAiIntentClassifier
4 |
5 |
class TestOpenAiIntentClassifier(unittest.TestCase):
    """End-to-end check of OpenAiIntentClassifier against gpt-4o-mini (needs an OpenAI API key)."""

    def test_prediction(self):
        classifier = OpenAiIntentClassifier(model_name="gpt-4o-mini")
        result = classifier.predict(text="I want to cancel subscription",
                                    labels=["cancel subscription", "cancel refund", "permission issue"])

        # assertEqual rather than the deprecated assertEquals alias (removed in Python 3.12).
        self.assertEqual("cancel subscription", result.class_name)
14 |
--------------------------------------------------------------------------------
/tests/test_devices.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 |
3 | import torch.cuda
4 |
5 | from open_intent_classifier.model import IntentClassifier
6 |
7 |
class TestDevices(TestCase):
    """Checks that IntentClassifier can be constructed for the requested device and records it."""

    def test_default(self):
        self.model = IntentClassifier()
        # No device argument means the stored device stays None.
        self.assertIsNone(self.model.device)

    def test_cuda(self):
        # Report a skip instead of silently passing on machines without a GPU.
        if not torch.cuda.is_available():
            self.skipTest("CUDA is not available")
        self.model = IntentClassifier(device="cuda")
        self.assertEqual("cuda", self.model.device)

    def test_cpu(self):
        self.model = IntentClassifier(device="cpu")
        self.assertEqual("cpu", self.model.device)

    def test_wrong_device(self):
        # Moving the model to an unknown device string is expected to raise RuntimeError.
        with self.assertRaises(RuntimeError):
            self.model = IntentClassifier(device="wrong")
--------------------------------------------------------------------------------
/tests/test_flan_t5_base.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 |
3 | from open_intent_classifier.model import IntentClassifier
4 |
5 |
class TestFlanT5Base(TestCase):
    """Sanity checks for the 248M Flan-T5 *base* intent classifier."""

    def setUp(self) -> None:
        from open_intent_classifier.consts import INTENT_CLASSIFIER_248M_FLAN_T5_BASE
        # IntentClassifier() defaults to the 80M small checkpoint, so name the base model explicitly.
        self.model = IntentClassifier(INTENT_CLASSIFIER_248M_FLAN_T5_BASE)

    def _assert_predicts(self, expected, result):
        """Assert that ``result`` names ``expected`` (case- and surrounding-whitespace-insensitive).

        ``assertTrue(expected, output)`` always passed: its second argument is only the failure message.
        """
        self.assertEqual(expected.lower(), result.class_name.strip().lower())

    def test_sanity(self):
        output = self.model.predict("I want to cancel subscription", ["Cancel Subscription", "Refund"])
        self._assert_predicts("Cancel Subscription", output)

    def test_cuda(self):
        """Refund example; despite the name, setUp does not select a CUDA device."""
        output = self.model.predict("I want a refund", ["Refund Request", "Issues", "Bug Report"])
        self._assert_predicts("Refund Request", output)

    def test_cpu(self):
        """Login example; despite the name, setUp does not select a CPU device."""
        output = self.model.predict("I don't remember my password, can you please help me?",
                                    ["Cancel Subscription", "Refund Request", "Transfer Request", "Login Issues"])
        self._assert_predicts("Login Issues", output)
23 |
--------------------------------------------------------------------------------
/tests/test_embeddings.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 |
3 | from open_intent_classifier.embedder import StaticLabelsEmbeddingClassifier
4 |
5 |
class TestEmbeddings(TestCase):
    """Ranking checks for StaticLabelsEmbeddingClassifier over two static labels."""

    def setUp(self) -> None:
        self.model = StaticLabelsEmbeddingClassifier(["Cancel Subscription", "Refund Requests"])

    def test_top_1(self):
        labels, similarities = self.model.predict("I want to cancel my subscription", n=1)
        # assertEqual, not assertTrue: assertTrue(a, b) treats b as a message and passes for any truthy a.
        self.assertEqual(1, len(labels))
        self.assertEqual(1, len(similarities))
        self.assertEqual("Cancel Subscription", labels[0])

    def test_top_2(self):
        labels, similarities = self.model.predict("I want to cancel subscription", n=2)
        self.assertEqual(2, len(labels))
        self.assertEqual(2, len(similarities))
        self.assertEqual(["Cancel Subscription", "Refund Requests"], labels)
        # Results come back best match first.
        self.assertGreaterEqual(similarities[0], similarities[1])
22 |
23 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from pathlib import Path

from setuptools import setup

from open_intent_classifier import VERSION

this_directory = Path(__file__).parent
# Decode explicitly as UTF-8: without an encoding, read_text() uses the locale's preferred
# encoding, so a non-ASCII README can fail to build on some platforms (e.g. cp1252 on Windows).
long_description = (this_directory / "README.md").read_text(encoding="utf-8")

setup(
    name="open_intent_classifier",
    version=VERSION,
    packages=["open_intent_classifier"],
    author="Serj Smorodinsky",
    url="https://github.com/SerjSmor/open-intent-classifier",
    author_email="serjsmor@gmail.com",
    description="This library has two purposes: 1. allow to easily test semantic classification with open labels (not pre defined) for intent recognition. "
                "2. allow to experiment with different n-shot classification components.",
    install_requires=["transformers", "torch", "sentencepiece", "scikit-learn", "sentence-transformers", "pandas",
                      "python-dotenv", "openai", "dspy", "accelerate"],
    long_description=long_description,
    long_description_content_type='text/markdown'
)
22 |
--------------------------------------------------------------------------------
/tests/test_flan_t5_small.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 |
3 | from open_intent_classifier.consts import INTENT_CLASSIFIER_80M_FLAN_T5_SMALL
4 | from open_intent_classifier.model import IntentClassifier
5 |
6 |
class TestFlanT5Small(TestCase):
    """Sanity checks for the 80M Flan-T5 *small* intent classifier."""

    def setUp(self) -> None:
        self.model = IntentClassifier(model_name=INTENT_CLASSIFIER_80M_FLAN_T5_SMALL, verbose=True)

    def _assert_predicts(self, expected, result):
        """Assert that ``result`` names ``expected`` (case- and surrounding-whitespace-insensitive).

        ``assertTrue(expected, output)`` always passed: its second argument is only the failure message.
        """
        self.assertEqual(expected.lower(), result.class_name.strip().lower())

    def test_sanity(self):
        output = self.model.predict("I want to cancel subscription", ["Cancel Subscription", "Refund"])
        self._assert_predicts("Cancel Subscription", output)

    def test_cuda(self):
        """Refund example; despite the name, setUp does not select a CUDA device."""
        output = self.model.predict("I want a refund", ["Refund Request", "Issues", "Bug Report"])
        self._assert_predicts("Refund Request", output)

    def test_cpu(self):
        """Login example; despite the name, setUp does not select a CPU device."""
        output = self.model.predict("I don't remember my password, can you please help me?",
                                    ["Cancel Subscription", "Refund Request", "Transfer Request", "Login Issues"])
        self._assert_predicts("Login Issues", output)
--------------------------------------------------------------------------------
/open_intent_classifier/consts.py:
--------------------------------------------------------------------------------
1 | PROMPT_TEMPLATE = "Topic %% Customer: {text}.\nEND MESSAGE\nChoose one topic that matches customer's issue.\n {labels} \nClass name: "  # Flan-T5 prompt used by IntentClassifier._build_prompt; filled with str.replace, so "%%" stays literal
2 |
3 | OPENAI_PROMPT_TEMPLATE = '''
4 | You are an expert in customer service domain. You need to classify a customer message into one of
5 | the following classes: {labels}
6 | Please return json object with the following structure: {'{class_name: '', reasoning: ''}'} class name should not contain a number.
7 | {examples}
8 | Customer message: {text} %
9 | Answer: .
10 | '''  # filled via str.replace in OpenAiIntentClassifier.predict, so the braces in the JSON hint are literal; NOTE(review): the " %" after {text} and the "." after "Answer:" look unintentional — confirm before editing the prompt
11 |
12 | SMOLLM2_PROMPT_TEMPLATE = '''Input format:
13 | Text: the message of the customer.
14 | Intent list: all of the possible labels, delimeted by a %.
15 | Answer: the name of the intent from the list that matches the intent from the message.
16 |
17 | Answer format is one of the intent from the list.
18 | Example
19 | ---
20 | {examples}
21 | ---
22 | Text: {text}
23 | Intent List: {labels}
24 | Answer:
25 | '''  # filled via str.replace in SmolLm2Classifier.predict; {labels} receives the "%"-joined label list
26 |
27 | INTENT_CLASSIFIER_248M_FLAN_T5_BASE = "Serj/intent-classifier"  # larger Flan-T5 base checkpoint; README: higher accuracy, higher latency
28 | INTENT_CLASSIFIER_80M_FLAN_T5_SMALL = "Serj/intent-classifier-flan-t5-small"  # IntentClassifier's default checkpoint
29 |
30 | DEFAULT_MAX_LENGTH = 512  # default tokenizer max_length in IntentClassifier.predict
31 | DEFAULT_TRUNCATION = True  # default tokenizer truncation flag in IntentClassifier.predict
32 |
33 | SMOLLM2_1_7B = "HuggingFaceTB/SmolLM2-1.7B-Instruct"  # SmolLm2Classifier's default model
34 | SMOLLM2_360M = "HuggingFaceTB/SmolLM2-360M-Instruct"  # smaller SmolLM2 checkpoint; not referenced in model.py
--------------------------------------------------------------------------------
/open_intent_classifier/embedder.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple
2 |
3 | import numpy as np
4 | from sentence_transformers import SentenceTransformer
5 | from sklearn.metrics.pairwise import cosine_similarity
6 |
BAII_SMALL_V1_5 = "BAAI/bge-small-en-v1.5"


class StaticLabelsEmbeddingClassifier:
    """Rank a fixed set of labels by cosine similarity between their embeddings and a text's embedding.

    The labels are embedded once in the constructor; each ``predict`` call embeds only the query text.

    :param labels: the candidate labels, in the order used to index results.
    :param model_name: SentenceTransformer model used for both labels and queries.
    """

    def __init__(self, labels: List[str], model_name: str = BAII_SMALL_V1_5):
        self.model = SentenceTransformer(model_name)
        self.embedded_labels = self.model.encode(labels, show_progress_bar=True)
        self.labels = labels

    def top_n(self, text: str, n: int, embeddings: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Return indices and cosine similarities of the ``n`` rows of ``embeddings`` closest to ``text``.

        Results are ordered from most to least similar. An ``n`` larger than the number of rows
        returns every row; ``n <= 0`` returns empty arrays.
        """
        text_embedding = self.model.encode([text])
        # cosine_similarity returns a (1, num_rows) matrix for one query; take row 0 so the result
        # stays 1-D even with a single label (squeeze() would collapse it to a 0-d array).
        similarities = cosine_similarity(text_embedding, embeddings)[0]
        # argsort is ascending: reverse for best-first, then keep n. Clamping n avoids the
        # `[-0:]` slice pitfall, which returned every row for n == 0.
        sorted_indices = similarities.argsort()[::-1][:max(n, 0)]
        return sorted_indices, similarities[sorted_indices]

    def predict(self, text: str, n=1) -> Tuple[List[str], np.ndarray]:
        """Return the ``n`` best-matching labels and their cosine similarities, best first."""
        top_n_indices, top_n_similarities = self.top_n(text, n=n, embeddings=self.embedded_labels)
        top_n_labels = [self.labels[i] for i in top_n_indices]

        return top_n_labels, top_n_similarities
29 |
30 |
--------------------------------------------------------------------------------
/tests/test_smollm2.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from open_intent_classifier.model import SmolLm2Classifier, ClassificationExample
4 |
5 |
class TestSmolLm2(unittest.TestCase):
    """Integration tests for SmolLm2Classifier, which runs on its default device ("cuda")."""

    def test_basic(self):
        labels = ["Cancel subscription", "Refund request"]
        text = "I want to cancel my subscription"

        c = SmolLm2Classifier()
        result = c.predict(text, labels)
        # assertEqual rather than the deprecated assertEquals alias (removed in Python 3.12).
        self.assertEqual("Cancel subscription".lower(), result.class_name.lower())

    def test_few_shot(self):
        labels = ["Order flight ticket", "Abbrevations", "Fare code", "Cheap price", "Meals questions", "Seating questions"]
        text = "I want to get a flight ticket to NYC asap"

        example_flight = ClassificationExample("Do you have an open ticket to a flight to Chicago?",
                                               "Order flight ticket", labels)

        # A few-shot example's answer must be one of its listed labels ("Price questions" was not).
        example_price = ClassificationExample("What is the cheapest route to London?", "Cheap price", labels)

        c = SmolLm2Classifier(verbose=True, few_shot_examples=[example_flight, example_price])

        result = c.predict(text, labels)
        self.assertEqual("Order flight ticket".lower(), result.class_name.lower())

    def test_hierarchy(self):
        labels = ["flight ticket", "price", "meal", "seats"]
        text = "I want to get a flight ticket to NYC asap"

        c = SmolLm2Classifier(verbose=True)
        example_flight = ClassificationExample("Do you have an open ticket to a flight to Chicago?", "flight ticket",
                                               labels)

        example_price = ClassificationExample("What is the price for a flight to London?", "price", labels)

        # Few-shot examples are supplied per call here rather than at construction.
        result = c.predict(text, labels, few_shot_examples=[example_flight, example_price])
        self.assertEqual("flight ticket".lower(), result.class_name.lower())
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Project specific
2 |
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | share/python-wheels/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 | MANIFEST
30 |
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 |
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 |
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .nox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 | *.py,cover
52 | .hypothesis/
53 | .pytest_cache/
54 | cover/
55 |
56 | # Translations
57 | *.mo
58 | *.pot
59 |
60 | # Django stuff:
61 | *.log
62 | local_settings.py
63 | db.sqlite3
64 | db.sqlite3-journal
65 |
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 |
70 | # Scrapy stuff:
71 | .scrapy
72 |
73 | # Sphinx documentation
74 | docs/_build/
75 |
76 | # PyBuilder
77 | .pybuilder/
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 |
87 | # pyenv
88 | # For a library or package, you might want to ignore these files since the code is
89 | # intended to run in multiple environments; otherwise, check them in:
90 | # .python-version
91 |
92 | # pipenv
93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
96 | # install all needed dependencies.
97 | #Pipfile.lock
98 |
99 | # poetry
100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
101 | # This is especially recommended for binary packages to ensure reproducibility, and is more
102 | # commonly ignored for libraries.
103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
104 | #poetry.lock
105 |
106 | # pdm
107 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
108 | #pdm.lock
109 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
110 | # in version control.
111 | # https://pdm.fming.dev/#use-with-ide
112 | .pdm.toml
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | .idea/
163 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Open Intent Classification
2 | Closed intent classification uses a set of predefined labels to identify an intent.
3 | In comparison, open intent classification allows you to define as many labels as you want, without fine-tuning the model.
4 |
5 | This project implements different components that support open intent classification such as an embedder, a T5 based fine tuned model for intent classification and a verbalizer.
6 | If you are interested in finer details, you can read my [blog post](https://medium.com/80-20-hacking-ai/components-that-optimize-n-shot-text-classification-f574184e0b81).
7 |
8 | The goal of this library is to enable you to test your assumptions about your data as fast as possible and to be a one-stop shop for everything "classification like", similarly to how [BERTopic](https://maartengr.github.io/BERTopic/index.html) is for clustering.
9 |
10 | ## Why should you use this?
11 | 1. You are researching NLP classification problems and want to test different embeddings, verbalizers and components with a plug-and-play feel
12 | 2. You want to semantically detect user intents in text but either don't want to commit to predefined classes OR just want to test the fastest way to classify text other than through an LLM
13 |
14 |
15 | > [!IMPORTANT]
16 | > open-intent-classification project is in Alpha stage.
17 | > 1. Expect API changes
18 | > 2. Mileage may vary. Classifier quality has been tested on ATIS and Banking77.
19 |
20 | ## Installation
21 | `$ pip install open-intent-classifier`
22 |
23 | ## Usage
24 | A full example is under [Atis Notebook](https://github.com/SerjSmor/open-intent-classifier/blob/main/notebooks/atis_example.ipynb)
25 |
26 | ### T5 Based Intent Classification
27 | ````python
28 | from open_intent_classifier.model import IntentClassifier
29 | model = IntentClassifier()
30 | labels = ["Cancel Subscription", "Refund Requests", "Broken Item", "And More..."]
31 | text = "I don't want to continue this subscription"
32 | predicted_label = model.predict(text, labels)
33 | ````
34 |
35 | By default, IntentClassifier loads a small model with 80M parameters.
36 |
37 | For higher accuracy you can initialize the model with:
38 | ```python
39 | from open_intent_classifier.model import IntentClassifier
40 | from open_intent_classifier.consts import INTENT_CLASSIFIER_248M_FLAN_T5_BASE
41 | model = IntentClassifier(INTENT_CLASSIFIER_248M_FLAN_T5_BASE)
42 | ```
43 | This will increase model latency as well.
44 |
45 |
46 | ### Embeddings Based Classification
47 | ```python
48 | from open_intent_classifier.embedder import StaticLabelsEmbeddingClassifier
49 | labels = ["Cancel Subscription", "Refund Requests", "Broken Item", "And More..."]
50 | text = "I don't want to continue this subscription"
51 | embeddings_classifier = StaticLabelsEmbeddingClassifier(labels)
52 | predicted_label = embeddings_classifier.predict(text)
53 |
54 | ```
55 |
56 | ### LLM Based Classification
57 | Using LLM for classification is a viable option that sometimes provides the highest quality.
58 | Currently, only OpenAI models are supported for this hosted-LLM classifier; see the SmolLM2 sections below for a small model that runs locally.
59 |
60 | ```python
61 | from open_intent_classifier.model import OpenAiIntentClassifier
62 | labels = ["Cancel Subscription", "Refund Requests", "Broken Item", "And More..."]
63 | text = "I don't want to continue this subscription"
64 | model_name = "gpt-4o-mini"
65 | classifier = OpenAiIntentClassifier(model_name)
66 | result = classifier.predict(text=text, labels=labels)
67 | ```
68 |
69 | ### LLM Few Shot Based Classification
70 | ```python
71 | from open_intent_classifier.model import OpenAiIntentClassifier, ClassificationExample
72 | labels = ["Cancel Subscription", "Refund Requests", "Broken Item", "And More..."]
73 | text = "I don't want to continue this subscription"
74 | model_name = "gpt-4o-mini"
75 | example = ClassificationExample(text="I want to abort my account", intent_labels=labels, intent="Cancel Subscription")
76 | classifier = OpenAiIntentClassifier(model_name, few_shot_examples=[example])
77 | result = classifier.predict(text=text, labels=labels)
78 | ```
79 |
80 | ### DSPy based classification
81 | Here's an example of using the DSPy-based classifier, both for educational purposes and for a potential increase in quality.
82 | ```python
83 | from open_intent_classifier.model import DSPyClassifier
84 | classifier = DSPyClassifier()
85 | labels = ["Cancel subscription", "Refund request"]
86 | text = "I want to cancel my subscription"
87 | result = classifier.predict(text, labels)
88 | ```
89 |
90 | ### SmolLM2 based classification
91 | Here's an example of using SmolLM2 based classifier
92 | ```python
93 | from open_intent_classifier.model import SmolLm2Classifier
94 | classifier = SmolLm2Classifier()
95 | labels = ["Cancel subscription", "Refund request"]
96 | text = "I want to cancel my subscription"
97 | result = classifier.predict(text, labels)
98 | ```
99 |
100 | ### SmolLM2 Few Shot Based Classification
101 | Here's an example of using the SmolLM2-based classifier with few-shot examples:
102 | ```python
103 | from open_intent_classifier.model import SmolLm2Classifier, ClassificationExample
104 | labels = ["Cancel subscription", "Refund request"]
105 | example = ClassificationExample(text="I want to abort my account", intent_labels=labels, intent="Cancel Subscription")
106 | classifier = SmolLm2Classifier(few_shot_examples=[example])
107 | text = "I want to cancel my subscription"
108 | result = classifier.predict(text, labels)
109 | ```
110 |
111 |
112 | ## Training the T5 base classifier
113 | The details of training the classifier are in another repository. I have separated training from inference so that each repository can stay focused and be extended independently.
114 |
115 | You can read about the training in the training repo: https://github.com/SerjSmor/intent_classification
116 |
117 | # Roadmap
118 |
119 | - [x] Add LLM based classification
120 | - [x] Add DSPy based classification
121 | - [x] Add small language models as classifiers
122 | - [x] Add few shot interface
123 | - [ ] Add Atis and Banking77 comparisons with all classifiers
124 | - [ ] Add embeddings filtering stage for classifiers
125 | - [ ] Add multithreading for LLM based classifiers
126 | - [ ] Add an option to ensemble embeddings and T5 (and additional models)
127 | - [ ] Create a recommender for fine-tuning
128 |
--------------------------------------------------------------------------------
/open_intent_classifier/model.py:
--------------------------------------------------------------------------------
1 | import os
2 | from dataclasses import dataclass
3 | from typing import List
4 | import logging
5 |
6 | import dspy
7 | from transformers import T5ForConditionalGeneration, T5Tokenizer, AutoTokenizer, AutoModelForCausalLM
8 | from openai import OpenAI
9 | from dotenv import load_dotenv
10 |
11 | from open_intent_classifier.consts import INTENT_CLASSIFIER_80M_FLAN_T5_SMALL, DEFAULT_MAX_LENGTH, DEFAULT_TRUNCATION, \
12 | PROMPT_TEMPLATE, OPENAI_PROMPT_TEMPLATE, SMOLLM2_1_7B, SMOLLM2_PROMPT_TEMPLATE
13 |
14 | from open_intent_classifier.utils import join_labels, labels_to_str
15 |
# Device string used as SmolLm2Classifier's default.
CUDA = "cuda"

# Default OpenAI model name for DSPyClassifier.
GPT_4O_MINI = "gpt-4o-mini"

# Placeholders in the prompt templates; the templates are filled with str.replace, not str.format.
TEXT_PLACEHOLDER, LABELS_PLACEHOLDER, EXAMPLES_PLACEHOLDER = "{text}", "{labels}", "{examples}"

# Module logger, named after this module.
logger = logging.getLogger(__name__)
# NOTE(review): basicConfig configures the *root* logger as an import-time side effect, which a library
# usually leaves to the application; presumably kept so verbose INFO output is visible — confirm.
logging.basicConfig(format='%(name)s - %(levelname)s - %(message)s', level=logging.INFO)
27 |
28 |
@dataclass
class ClassificationResult:
    """Outcome of one ``predict`` call: the chosen label plus any model-provided explanation."""

    # Predicted intent label as produced by the classifier.
    class_name: str
    # Explanation text; classifiers without one (T5, SmolLM2) return "".
    reasoning: str
34 |
35 | @dataclass
36 | class ClassificationExample:  # few-shot example; str() renders it into the OpenAI and SmolLM2 prompts
37 | text: str  # example customer message
38 | intent: str  # the correct label for text
39 | intent_labels: List[str]  # candidate labels, rendered "%"-joined via labels_to_str
40 |
41 | def __str__(self):  # the rendering is prompt text, so its exact layout matters
42 | return f'''
43 | -----
44 | Example
45 | Text: {self.text}
46 | Intents list: {labels_to_str(self.intent_labels)}
47 | Answer: {self.intent}
48 | -----
49 | '''  # NOTE(review): labelled "Intents list:" while SMOLLM2_PROMPT_TEMPLATE uses "Intent List:" — confirm the mismatch is intended
50 |
51 |
52 | DEFAULT_EXAMPLE = ClassificationExample(  # used as the only few-shot example when a classifier has none configured
53 | text="I want to cancel subscription",
54 | intent="Cancel Subscription",
55 | intent_labels=["Refund Request", "Delivery Late", "Cancel Subscription"]
56 | )
57 |
58 |
class Classifier:
    """Common interface: classify a text into one of the supplied candidate labels."""

    def predict(self, text: str, labels: List[str], **kwargs) -> ClassificationResult:
        """Return the best label for ``text``; concrete classifiers must override this."""
        raise NotImplementedError
62 |
63 |
64 | class Classification(dspy.Signature):  # task signature wrapped by dspy.ChainOfThought in DSPyClassifier
65 | """Classify the customer message into one of the intent labels.
66 | The output should be only the predicted class as a single intent label."""  # NOTE(review): presumably consumed by DSPy as the task instruction, so edits change the prompt — confirm
67 |
68 | customer_message = dspy.InputField(desc="Customer message during customer service interaction")  # the text being classified
69 | intent_labels = dspy.InputField(desc="Labels that represent customer intent")  # the candidate intent labels
70 | intent_class = dspy.OutputField(desc="a label best matching customer's intent ")  # read back as pred.intent_class in DSPyClassifier.predict
71 |
72 |
class DSPyClassifier(dspy.Module, Classifier):
    """Intent classifier that runs ``dspy.ChainOfThought`` over the ``Classification`` signature.

    The constructor configures DSPy's *global* language model via ``dspy.settings.configure``,
    so the most recently constructed instance decides the model every instance uses.

    :param model_name: OpenAI model name (default ``gpt-4o-mini``).
    """

    def __init__(self, model_name=GPT_4O_MINI):
        super().__init__()
        self.generate_answer = dspy.ChainOfThought(Classification)
        # Older DSPy releases expose ``dspy.OpenAI``; newer ones replaced it with the generic
        # ``dspy.LM`` client, which addresses OpenAI models as "openai/<model>".
        if hasattr(dspy, "OpenAI"):
            lm = dspy.OpenAI(model=model_name)
        else:
            lm = dspy.LM(f"openai/{model_name}")
        dspy.settings.configure(lm=lm)

    def forward(self, customer_message: str, labels: str):
        """Run the chain-of-thought predictor; ``labels`` is the rendered label-options string."""
        # The signature's input field is ``intent_labels``; passing ``labels=`` left that field
        # unbound, so the candidate labels never reached the signature.
        return self.generate_answer(customer_message=customer_message, intent_labels=labels)

    def predict(self, text: str, labels: List[str], **kwargs) -> ClassificationResult:
        """Classify ``text`` into one of ``labels`` and return the label with the model's reasoning."""
        options = join_labels(labels)
        pred = self.forward(customer_message=text, labels=options)
        # Older DSPy names the chain-of-thought output ``rationale``; newer releases use ``reasoning``.
        reasoning = getattr(pred, "reasoning", None) or getattr(pred, "rationale", "")
        return ClassificationResult(pred.intent_class, reasoning)
87 |
88 |
class OpenAiIntentClassifier(Classifier):
    """Intent classifier that prompts an OpenAI chat model and parses its JSON reply.

    :param model_name: OpenAI chat model name, e.g. ``"gpt-4o-mini"``.
    :param openai_api_key: API key; when empty, ``.env`` is loaded and ``OPENAI_API_KEY`` is read.
    :param few_shot_examples: examples rendered into every prompt; ``DEFAULT_EXAMPLE`` is used when none are set.
    """

    def __init__(self, model_name: str, openai_api_key: str = "", few_shot_examples: List[ClassificationExample] = None):
        if not openai_api_key:
            load_dotenv()
            openai_api_key = os.getenv("OPENAI_API_KEY")
        self.client = OpenAI(api_key=openai_api_key)
        self.model_name = model_name
        # Previously both branches produced [], silently discarding the caller's examples.
        # Copy so the caller's list is never mutated through this instance.
        self.few_shot_examples = [] if few_shot_examples is None else list(few_shot_examples)

    def predict(self, text: str, labels: List[str], **kwargs) -> ClassificationResult:
        """Classify ``text`` into one of ``labels``.

        :param kwargs: ``few_shot_examples`` optionally overrides the instance's examples for this call.
        :raises json.JSONDecodeError: if the model reply is not valid JSON.
        :raises KeyError: if the reply has no ``class_name`` key.
        """
        import json

        # Fall back to DEFAULT_EXAMPLE locally instead of appending it to instance state.
        examples = kwargs.get("few_shot_examples") or self.few_shot_examples or [DEFAULT_EXAMPLE]
        examples_str = "".join(str(example) for example in examples)

        # Substitute the customer text last so placeholder-like text inside it is never re-expanded.
        prompt = (OPENAI_PROMPT_TEMPLATE
                  .replace(LABELS_PLACEHOLDER, join_labels(labels))
                  .replace(EXAMPLES_PLACEHOLDER, examples_str)
                  .replace(TEXT_PLACEHOLDER, text))

        completion = self.client.chat.completions.create(
            model=self.model_name,
            messages=[
                {"role": "user", "content": prompt}
            ],
            response_format={"type": "json_object"}
        )

        # The reply is requested as a JSON object: parse it with json.loads instead of eval(), which
        # executed model output as Python and failed on JSON literals such as true/false/null.
        response_dict = json.loads(completion.choices[0].message.content)
        return ClassificationResult(response_dict["class_name"], response_dict.get("reasoning", ""))
120 |
121 |
class IntentClassifier(Classifier):
    """Flan-T5 based open intent classifier (defaults to the 80M-parameter small checkpoint).

    :param model_name: Hugging Face id of a T5 checkpoint trained on ``PROMPT_TEMPLATE``.
    :param device: torch device string (e.g. ``"cpu"``, ``"cuda"``), passed straight to ``.to()``.
    :param verbose: when True, the module logger is set to DEBUG so prompts and outputs are logged.
    """

    def __init__(self, model_name: str = INTENT_CLASSIFIER_80M_FLAN_T5_SMALL, device: str = None,
                 verbose: bool = False):
        self.model = T5ForConditionalGeneration.from_pretrained(model_name).to(device)
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.device = device
        if verbose:
            logger.setLevel(logging.DEBUG)

    def _build_prompt(self, text: str, labels: List[str], prompt_template: str = PROMPT_TEMPLATE):
        """Fill ``prompt_template`` with the rendered label options and ``text``.

        Labels are substituted first and the text last, so a literal ``{labels}`` inside the
        customer text is left untouched.
        """
        prompt_options = join_labels(labels)
        prompt = prompt_template.replace(LABELS_PLACEHOLDER, prompt_options)
        return prompt.replace(TEXT_PLACEHOLDER, text)

    def predict(self, text: str, labels: List[str], **kwargs) -> ClassificationResult:
        """Classify ``text`` into one of ``labels`` by generating the label name.

        :param kwargs: ``max_length`` (default ``DEFAULT_MAX_LENGTH``) and ``truncation`` (default
            ``DEFAULT_TRUNCATION``) control prompt tokenization; ``max_length`` is also forwarded to
            ``generate`` when given, as before. All other kwargs go to ``model.generate``.
        """
        prompt = self._build_prompt(text, labels)
        max_length = kwargs.get("max_length", DEFAULT_MAX_LENGTH)
        # ``truncation`` is a tokenizer-only option; forwarding it to generate() makes transformers
        # reject it as an unused model kwarg, so it is removed from the generate kwargs.
        generate_kwargs = dict(kwargs)
        truncation = generate_kwargs.pop("truncation", DEFAULT_TRUNCATION)
        logger.debug(f"Full prompt: {prompt}")
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt", max_length=max_length,
                                          truncation=truncation).to(self.device)
        output = self.model.generate(input_ids, **generate_kwargs)
        decoded_output = self.tokenizer.decode(output[0], skip_special_tokens=True)
        logger.debug(f"Decoded output: {decoded_output}")
        return ClassificationResult(class_name=decoded_output, reasoning="")
153 |
154 |
class SmolLm2Classifier(Classifier):
    """Few-shot intent classifier that prompts a SmolLM2 instruct model through its chat template.

    :param model_name: Hugging Face id of the causal LM (default SmolLM2-1.7B-Instruct).
    :param device: torch device for the model and inputs (default ``"cuda"``).
    :param verbose: log the prompt and the raw generation at INFO level.
    :param few_shot_examples: examples rendered into every prompt; ``DEFAULT_EXAMPLE`` is used when none are set.
    """

    def __init__(self, model_name: str = SMOLLM2_1_7B, device: str = CUDA, verbose: bool = False,
                 few_shot_examples: List[ClassificationExample] = None):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
        self.device = device
        self.verbose = verbose
        # Copy so the caller's list is never mutated through this instance.
        self.few_shot_examples = [] if few_shot_examples is None else list(few_shot_examples)

    def predict(self, text: str, labels: List[str], **kwargs) -> ClassificationResult:
        """Classify ``text`` into one of ``labels``.

        :param kwargs: ``few_shot_examples`` overrides the instance's examples for this call;
            ``max_new_tokens`` (default 350) bounds the generation length.
        :return: the text after the last ``"Answer: "`` in the generation, with surrounding whitespace stripped.
        """
        # Fall back to DEFAULT_EXAMPLE locally instead of appending it to (possibly caller-owned) state.
        examples = kwargs.get("few_shot_examples") or self.few_shot_examples or [DEFAULT_EXAMPLE]
        examples_str = "".join(str(example) for example in examples)

        # Labels use the same "%"-delimited format as ClassificationExample. The customer text is
        # substituted last so placeholder-like text inside it is never re-expanded.
        prompt = (SMOLLM2_PROMPT_TEMPLATE
                  .replace(LABELS_PLACEHOLDER, labels_to_str(labels))
                  .replace(EXAMPLES_PLACEHOLDER, examples_str)
                  .replace(TEXT_PLACEHOLDER, text))
        if self.verbose:
            logger.info(prompt)

        messages = [{"role": "system",
                     "content": "You are a customer service expert. Your goal is to predict what is the intent of the user from a predfined list"},
                    {"role": "user", "content": prompt}]

        input_text = self.tokenizer.apply_chat_template(messages, tokenize=False)

        inputs = self.tokenizer.encode(input_text, return_tensors="pt").to(self.device)
        outputs = self.model.generate(inputs, max_new_tokens=kwargs.get("max_new_tokens", 350))
        # Decode only the newly generated tokens, not the echoed prompt.
        output = self.tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)

        if self.verbose:
            logger.info(output)

        # The model may repeat the "Answer:" format; keep what follows the last occurrence and strip
        # newlines/whitespace so the label can be compared directly.
        last_answer_occurrence = output.rsplit("Answer: ", 2)[-1]
        return ClassificationResult(last_answer_occurrence.strip(), "")
194 |
195 |
196 |
--------------------------------------------------------------------------------
/notebooks/dspy_tutorial.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "da5c9e91-65f3-4c06-91f6-d7863f129848",
6 | "metadata": {},
7 | "source": [
8 | "## Basic Example"
9 | ]
10 | },
11 | {
12 | "cell_type": "markdown",
13 | "id": "5f476789-3bb1-4fc9-9521-8650c0aef2d0",
14 | "metadata": {},
15 | "source": [
16 | "### First, define the signature"
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": 21,
22 | "id": "e6d35f00-46b6-4d0e-93d6-28caa10f1b40",
23 | "metadata": {},
24 | "outputs": [],
25 | "source": [
26 | "import dspy \n",
27 | "from open_intent_classifier.utils import join_labels\n",
28 | "\n",
29 | "class Classification(dspy.Signature):\n",
30 | " \"\"\"Classify the customer message into one of the intent labels.\n",
31 | " The output should be only the predicted class as a single intent label.\"\"\"\n",
32 | "\n",
33 | " customer_message = dspy.InputField(desc=\"Customer message during customer service interaction\")\n",
34 | " intent_labels = dspy.InputField(desc=\"Labels that represent customer intent\")\n",
35 | " intent_class = dspy.OutputField(desc=\"a label best matching customer's intent \")"
36 | ]
37 | },
38 | {
39 | "cell_type": "markdown",
40 | "id": "09b42d77-5ef4-420b-9ae4-cafa3ca81c28",
41 | "metadata": {},
42 | "source": [
43 | "### Define the language model"
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "execution_count": 24,
49 | "id": "43e0debd-85b0-4d6e-9e13-3b6319165e25",
50 | "metadata": {},
51 | "outputs": [],
52 | "source": [
53 | "lm_mini = dspy.OpenAI(model='gpt-4o-mini')\n",
54 | "dspy.settings.configure(lm=lm_mini)"
55 | ]
56 | },
57 | {
58 | "cell_type": "markdown",
59 | "id": "17c8e00e-42a3-4650-adf6-99b3c096af15",
60 | "metadata": {},
61 | "source": [
62 | "### Predict"
63 | ]
64 | },
65 | {
66 | "cell_type": "code",
67 | "execution_count": 25,
68 | "id": "86e29e2a-f875-4614-9518-8129c3ff8a4e",
69 | "metadata": {},
70 | "outputs": [
71 | {
72 | "data": {
73 | "text/plain": [
74 | "Prediction(\n",
75 | " intent_class='Cancel subscription'\n",
76 | ")"
77 | ]
78 | },
79 | "execution_count": 25,
80 | "metadata": {},
81 | "output_type": "execute_result"
82 | }
83 | ],
84 | "source": [
85 | "predict = dspy.Predict(Classification)\n",
86 | "labels = join_labels([\"Cancel subscription\", \"Refund request\"])\n",
87 | "labels = \" \".join([\"Cancel subscription\", \"Refund request\"])\n",
88 | "text = \"I want to cancel my subscription\"\n",
89 | "pred = predict(customer_message=text, intent_labels=labels)\n",
90 | "pred"
91 | ]
92 | },
93 | {
94 | "cell_type": "code",
95 | "execution_count": 26,
96 | "id": "4b8d7f04-911d-4671-89a3-3105649fdd68",
97 | "metadata": {},
98 | "outputs": [
99 | {
100 | "name": "stdout",
101 | "output_type": "stream",
102 | "text": [
103 | "\n",
104 | "\n",
105 | "\n",
106 | "Classify the customer message into one of the intent labels.\n",
107 | "The output should be only the predicted class as a single intent label.\n",
108 | "\n",
109 | "---\n",
110 | "\n",
111 | "Follow the following format.\n",
112 | "\n",
113 | "Customer Message: Customer message during customer service interaction\n",
114 | "Intent Labels: Labels that represent customer intent\n",
115 | "Intent Class: a label best matching customer's intent\n",
116 | "\n",
117 | "---\n",
118 | "\n",
119 | "Customer Message: I want to cancel my subscription\n",
120 | "Intent Labels: Cancel subscription Refund request\n",
121 | "Intent Class:\u001b[32m Cancel subscription\u001b[0m\n",
122 | "\n",
123 | "\n",
124 | "\n"
125 | ]
126 | },
127 | {
128 | "data": {
129 | "text/plain": [
130 | "\"\\n\\n\\nClassify the customer message into one of the intent labels.\\nThe output should be only the predicted class as a single intent label.\\n\\n---\\n\\nFollow the following format.\\n\\nCustomer Message: Customer message during customer service interaction\\nIntent Labels: Labels that represent customer intent\\nIntent Class: a label best matching customer's intent\\n\\n---\\n\\nCustomer Message: I want to cancel my subscription\\nIntent Labels: Cancel subscription Refund request\\nIntent Class:\\x1b[32m Cancel subscription\\x1b[0m\\n\\n\\n\""
131 | ]
132 | },
133 | "execution_count": 26,
134 | "metadata": {},
135 | "output_type": "execute_result"
136 | }
137 | ],
138 | "source": [
139 | "lm_mini.inspect_history(n=1)"
140 | ]
141 | },
142 | {
143 | "cell_type": "markdown",
144 | "id": "c7dbd608-7cb3-4d3e-987e-853aa23e188a",
145 | "metadata": {},
146 | "source": [
147 | "## Using Chain Of Thought (CoT)"
148 | ]
149 | },
150 | {
151 | "cell_type": "code",
152 | "execution_count": 27,
153 | "id": "61c51226-f2a0-4bd8-810e-95be5c43ff40",
154 | "metadata": {},
155 | "outputs": [
156 | {
157 | "name": "stderr",
158 | "output_type": "stream",
159 | "text": [
160 | "httpx - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n"
161 | ]
162 | },
163 | {
164 | "data": {
165 | "text/plain": [
166 | "Prediction(\n",
167 | " rationale='produce the intent_class. We first identify that the customer is expressing a desire to terminate an ongoing service, which directly aligns with the intent of canceling a subscription. There is no mention of seeking a refund or any other request, so the primary focus is on the cancellation itself.',\n",
168 | " intent_class='Cancel subscription'\n",
169 | ")"
170 | ]
171 | },
172 | "execution_count": 27,
173 | "metadata": {},
174 | "output_type": "execute_result"
175 | }
176 | ],
177 | "source": [
178 | "predict_with_chain_of_thought = dspy.ChainOfThought(Classification)\n",
179 | "pred = predict_with_chain_of_thought(customer_message=text, intent_labels=labels)\n",
180 | "pred"
181 | ]
182 | },
183 | {
184 | "cell_type": "code",
185 | "execution_count": 28,
186 | "id": "52dac266-0db6-4b5c-9c8b-cf930175d9eb",
187 | "metadata": {},
188 | "outputs": [
189 | {
190 | "data": {
191 | "text/plain": [
192 | "'Cancel subscription'"
193 | ]
194 | },
195 | "execution_count": 28,
196 | "metadata": {},
197 | "output_type": "execute_result"
198 | }
199 | ],
200 | "source": [
201 | "pred.intent_class"
202 | ]
203 | },
204 | {
205 | "cell_type": "code",
206 | "execution_count": 29,
207 | "id": "3534fb75-fb66-4a4b-887d-6fedb07bdf10",
208 | "metadata": {},
209 | "outputs": [
210 | {
211 | "data": {
212 | "text/plain": [
213 | "'produce the intent_class. We first identify that the customer is expressing a desire to terminate an ongoing service, which directly aligns with the intent of canceling a subscription. There is no mention of seeking a refund or any other request, so the primary focus is on the cancellation itself.'"
214 | ]
215 | },
216 | "execution_count": 29,
217 | "metadata": {},
218 | "output_type": "execute_result"
219 | }
220 | ],
221 | "source": [
222 | "pred.rationale"
223 | ]
224 | },
225 | {
226 | "cell_type": "code",
227 | "execution_count": 30,
228 | "id": "73a919eb-8e23-440a-b26c-b66643d4f7ea",
229 | "metadata": {},
230 | "outputs": [
231 | {
232 | "name": "stdout",
233 | "output_type": "stream",
234 | "text": [
235 | "\n",
236 | "\n",
237 | "\n",
238 | "Classify the customer message into one of the intent labels.\n",
239 | "The output should be only the predicted class as a single intent label.\n",
240 | "\n",
241 | "---\n",
242 | "\n",
243 | "Follow the following format.\n",
244 | "\n",
245 | "Customer Message: Customer message during customer service interaction\n",
246 | "\n",
247 | "Intent Labels: Labels that represent customer intent\n",
248 | "\n",
249 | "Reasoning: Let's think step by step in order to ${produce the intent_class}. We ...\n",
250 | "\n",
251 | "Intent Class: a label best matching customer's intent\n",
252 | "\n",
253 | "---\n",
254 | "\n",
255 | "Customer Message: I want to cancel my subscription\n",
256 | "\n",
257 | "Intent Labels: Cancel subscription Refund request\n",
258 | "\n",
259 | "Reasoning: Let's think step by step in order to\u001b[32m produce the intent_class. We first identify that the customer is expressing a desire to terminate an ongoing service, which directly aligns with the intent of canceling a subscription. There is no mention of seeking a refund or any other request, so the primary focus is on the cancellation itself.\n",
260 | "\n",
261 | "Intent Class: Cancel subscription\u001b[0m\n",
262 | "\n",
263 | "\n",
264 | "\n"
265 | ]
266 | },
267 | {
268 | "data": {
269 | "text/plain": [
270 | "\"\\n\\n\\nClassify the customer message into one of the intent labels.\\nThe output should be only the predicted class as a single intent label.\\n\\n---\\n\\nFollow the following format.\\n\\nCustomer Message: Customer message during customer service interaction\\n\\nIntent Labels: Labels that represent customer intent\\n\\nReasoning: Let's think step by step in order to ${produce the intent_class}. We ...\\n\\nIntent Class: a label best matching customer's intent\\n\\n---\\n\\nCustomer Message: I want to cancel my subscription\\n\\nIntent Labels: Cancel subscription Refund request\\n\\nReasoning: Let's think step by step in order to\\x1b[32m produce the intent_class. We first identify that the customer is expressing a desire to terminate an ongoing service, which directly aligns with the intent of canceling a subscription. There is no mention of seeking a refund or any other request, so the primary focus is on the cancellation itself.\\n\\nIntent Class: Cancel subscription\\x1b[0m\\n\\n\\n\""
271 | ]
272 | },
273 | "execution_count": 30,
274 | "metadata": {},
275 | "output_type": "execute_result"
276 | }
277 | ],
278 | "source": [
279 | "lm_mini.inspect_history(n=1)"
280 | ]
281 | },
282 | {
283 | "cell_type": "markdown",
284 | "id": "f387a663-b0df-40f8-9ea7-f88d3f03edf8",
285 | "metadata": {},
286 | "source": [
287 | "## Example of a bug: the output format is not specified precisely"
288 | ]
289 | },
290 | {
291 | "cell_type": "markdown",
292 | "id": "92ec67e1-294e-4cf8-a0e7-babcb3299b50",
293 | "metadata": {},
294 | "source": [
295 | "### My instruction: \n",
296 | "\"\"\"Classify the customer message into one of the intent labels.\"\"\" \n",
297 | "### is not specific enough: it never says the output must be only a single intent label"
298 | ]
299 | },
300 | {
301 | "cell_type": "code",
302 | "execution_count": 31,
303 | "id": "1461b8ea-13f1-4c5f-ad4a-e48e33c1678d",
304 | "metadata": {},
305 | "outputs": [],
306 | "source": [
307 | "import dspy \n",
308 | "from open_intent_classifier.utils import join_labels\n",
309 | "\n",
310 | "class Classification(dspy.Signature):\n",
311 | " \"\"\"Classify the customer message into one of the intent labels.\"\"\"\n",
312 | "\n",
313 | " customer_message = dspy.InputField(desc=\"Customer message during customer service interaction\")\n",
314 | " intent_labels = dspy.InputField(desc=\"Labels that represent customer intent\")\n",
315 | " intent_class = dspy.OutputField(desc=\"a label best matching customer's intent \")"
316 | ]
317 | },
318 | {
319 | "cell_type": "markdown",
320 | "id": "55bdc918-e918-48b3-8aaf-22c26f35091f",
321 | "metadata": {},
322 | "source": [
323 | "### The predicted intent class echoes the whole prompt instead of containing only the label"
324 | ]
325 | },
326 | {
327 | "cell_type": "code",
328 | "execution_count": 17,
329 | "id": "ec7d46bd-d247-4eac-afa9-92b168f34f3a",
330 | "metadata": {},
331 | "outputs": [
332 | {
333 | "name": "stderr",
334 | "output_type": "stream",
335 | "text": [
336 | "httpx - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n"
337 | ]
338 | },
339 | {
340 | "data": {
341 | "text/plain": [
342 | "Prediction(\n",
343 | " intent_class='Customer Message: I want to cancel my subscription \\nIntent Labels: Cancel subscription \\nIntent Class: Cancel subscription'\n",
344 | ")"
345 | ]
346 | },
347 | "execution_count": 17,
348 | "metadata": {},
349 | "output_type": "execute_result"
350 | }
351 | ],
352 | "source": [
353 | "predict = dspy.Predict(Classification)\n",
354 | "labels = \" \".join([\"Cancel subscription\", \"Refund request\"])\n",
355 | "text = \"I want to cancel my subscription\"\n",
356 | "pred = predict(customer_message=text, intent_labels=labels)\n",
357 | "pred"
358 | ]
359 | }
360 | ],
361 | "metadata": {
362 | "kernelspec": {
363 | "display_name": "Python 3 (ipykernel)",
364 | "language": "python",
365 | "name": "python3"
366 | },
367 | "language_info": {
368 | "codemirror_mode": {
369 | "name": "ipython",
370 | "version": 3
371 | },
372 | "file_extension": ".py",
373 | "mimetype": "text/x-python",
374 | "name": "python",
375 | "nbconvert_exporter": "python",
376 | "pygments_lexer": "ipython3",
377 | "version": "3.10.12"
378 | }
379 | },
380 | "nbformat": 4,
381 | "nbformat_minor": 5
382 | }
383 |
--------------------------------------------------------------------------------
/notebooks/smollm2.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "5a8142e2-4f4a-48c9-bd1a-82a596c08fbe",
6 | "metadata": {},
7 | "source": [
8 | "# 1.7B Instruct"
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": 8,
14 | "id": "299d5cd3-36e9-49f7-87e7-5802e1c6b75a",
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "# pip install transformers\n",
19 | "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
20 | "# checkpoint = \"HuggingFaceTB/SmolLM2-360M\"\n",
21 | "# checkpoint = \"HuggingFaceTB/SmolLM2-360M-Instruct\"\n",
22 | "def get_model_tokenizer(checkpoint):\n",
23 | " device = \"cuda\" # for GPU usage or \"cpu\" for CPU usage\n",
24 | " tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n",
25 | " # for multiple GPUs install accelerate and do `model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map=\"auto\")`\n",
26 | " model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)\n",
27 | " return model, tokenizer\n",
28 | "\n",
29 | "big_model = \"HuggingFaceTB/SmolLM2-1.7B-Instruct\"\n",
30 | "small_model = \"HuggingFaceTB/SmolLM2-360M-Instruct\""
31 | ]
32 | },
33 | {
34 | "cell_type": "code",
35 | "execution_count": 2,
36 | "id": "9d8d0e81-7ea0-4629-8fa0-c49cbdae7a65",
37 | "metadata": {},
38 | "outputs": [],
39 | "source": [
40 | "model, tokenizer = get_model_tokenizer(big_model)"
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "execution_count": 5,
46 | "id": "b837c84f-4d1d-4988-bad6-83bc899e0fc6",
47 | "metadata": {},
48 | "outputs": [
49 | {
50 | "name": "stderr",
51 | "output_type": "stream",
52 | "text": [
53 | "The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
54 | "/home/serj/dev/open-intent-classifier/venv/lib/python3.10/site-packages/transformers/generation/utils.py:1258: UserWarning: Using the model-agnostic default `max_length` (=20) to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.\n",
55 | " warnings.warn(\n"
56 | ]
57 | },
58 | {
59 | "name": "stdout",
60 | "output_type": "stream",
61 | "text": [
62 | "Gravity is a fundamental force of nature that affects all objects with mass, from the smallest subatomic\n"
63 | ]
64 | }
65 | ],
66 | "source": [
67 | "device = \"cuda\"\n",
68 | "inputs = tokenizer.encode(\"Gravity is\", return_tensors=\"pt\").to(device)\n",
69 | "outputs = model.generate(inputs)\n",
70 | "print(tokenizer.decode(outputs[0]))"
71 | ]
72 | },
73 | {
74 | "cell_type": "code",
75 | "execution_count": 31,
76 | "id": "ba388418-8b1a-48ef-9beb-18f6f83ac135",
77 | "metadata": {},
78 | "outputs": [],
79 | "source": [
80 | "# check the chat template\n",
81 | "prompt_template = '''You are a customer service expert. Your goal is to predict what is the intent of the user from a predfined list.\n",
82 | " Format, the list of intents is delimeted with %.\n",
83 | " You must choose only one\n",
84 | " Text: {text}\n",
85 | " Intent List: {labels}.\n",
86 | "\n",
87 | " Answer format: {class: }\n",
88 | " Answer: \n",
89 | " '''\n",
90 | "\n",
91 | "# def predict(text, labels, device=\"cuda\"): \n",
92 | "# labels_str = \"%\".join(labels)\n",
93 | "# prompt = prompt_template.replace(\"{text}\", text).replace(\"{labels}\", labels_str)\n",
94 | "# print(prompt)\n",
95 | "# inputs = tokenizer.encode(prompt, return_tensors=\"pt\").to(device)\n",
96 | "# outputs = model.generate(inputs, max_new_tokens=200)\n",
97 | "# return tokenizer.decode(outputs[0])\n",
98 | "\n",
99 | "def predict_roles(text, labels, device=\"cuda\"):\n",
100 | " \n",
101 | " labels_str = \"%\".join(labels)\n",
102 | "\n",
103 | " prompt_format = '''Format, the list of intents is delimeted with %.\n",
104 | " You must choose only one\n",
105 | " Text: {text}\n",
106 | " Intent List: {labels}.\n",
107 | " Answer format is json: {class: }\n",
108 | " You must follow this format'''\n",
109 | " prompt = prompt_format.replace(\"{text}\", text).replace(\"{labels}\", labels_str)\n",
110 | " messages = [{\"role\": \"system\", \"content\": \"You are a customer service expert. Your goal is to predict what is the intent of the user from a predfined list\"},\n",
111 | " {\"role\": \"user\", \"content\": prompt}]\n",
112 | " \n",
113 | " input_text=tokenizer.apply_chat_template(messages, tokenize=False)\n",
114 | " inputs = tokenizer.encode(input_text, return_tensors=\"pt\").to(device)\n",
115 | " outputs = model.generate(inputs, max_new_tokens=50)\n",
116 | " return tokenizer.decode(outputs[0])\n",
117 | "\n",
118 | "def predict_small_model(text, labels, device=\"cuda\"):\n",
119 | " labels_str = \"%\".join(labels)\n",
120 | "\n",
121 | " prompt_format = '''Format, the list of intents is delimeted with %.\n",
122 | " Text: is the text of the customer.\n",
123 | " Intent List: is the list of possible intents for the text. You need to choose one.\n",
124 | " You must choose only one\n",
125 | " --\n",
126 | " Your answer must include only the name of the class, without any prefix, suffix or inbetween nicities.\n",
127 | " You must follow this format.\n",
128 | " -- \n",
129 | " Here is an example:\n",
130 | " Text: I have an issues with my iPhone\n",
131 | " Intent List: technical issue % cancel subscription\n",
132 | " Class: technical issue\n",
133 | " --\n",
134 | " Text: {text}\n",
135 | " Intent List: {labels}.\n",
136 | " Class: '''\n",
137 | " prompt = prompt_format.replace(\"{text}\", text).replace(\"{labels}\", labels_str)\n",
138 | " messages = [{\"role\": \"system\", \"content\": \"You are a customer service expert. Your goal is to predict what is the intent of the user from a predfined list\"},\n",
139 | " {\"role\": \"user\", \"content\": prompt}]\n",
140 | " \n",
141 | " input_text=tokenizer.apply_chat_template(messages, tokenize=False)\n",
142 | " inputs = tokenizer.encode(input_text, return_tensors=\"pt\").to(device)\n",
143 | " outputs = model.generate(inputs, max_new_tokens=100)\n",
144 | " return tokenizer.decode(outputs[0])\n",
145 | "\n",
146 | "\n",
147 | "def predict_small_model_v2(text, labels, device=\"cuda\"):\n",
148 | " labels_str = \"%\".join(labels)\n",
149 | "\n",
150 | " # prompt_format = '''Format, the list of intents is delimeted with %.\n",
151 | " # Text: is the text of the customer.\n",
152 | " # Intent List: is the list of possible intents for the text. You need to choose one.\n",
153 | " # You must choose only one\n",
154 | " # --\n",
155 | " # Your answer must include only the name of the class, without any prefix, suffix or inbetween nicities.\n",
156 | " # You must follow this format.\n",
157 | " # -- \n",
158 | " # Here is an example:\n",
159 | " # Text: I have an issues with my iPhone\n",
160 | " # Intent List: technical issue % cancel subscription\n",
161 | " # Class: technical issue\n",
162 | " # --\n",
163 | " # Text: {text}\n",
164 | " # Intent List: {labels}.\n",
165 | " # Class: '''\n",
166 | " \n",
167 | " \n",
168 | " # prompt = prompt_format.replace(\"{text}\", text).replace(\"{labels}\", labels_str)\n",
169 | " prompt = f\"Customer Text: {text}. Most matching intent from list: \"\n",
170 | " messages = [{\"role\": \"system\", \"content\": f\"You are a customer service expert. Your goal is to predict what is the intent of the user from a predfined list: {labels_str}. Do not return any text other than the text of the most matching intent from the list.\"},\n",
171 | " {\"role\": \"user\", \"content\": prompt}]\n",
172 | " \n",
173 | " input_text=tokenizer.apply_chat_template(messages, tokenize=False)\n",
174 | " inputs = tokenizer.encode(input_text, return_tensors=\"pt\").to(device)\n",
175 | " outputs = model.generate(inputs, max_new_tokens=100)\n",
176 | " return tokenizer.decode(outputs[0])\n"
177 | ]
178 | },
179 | {
180 | "cell_type": "code",
181 | "execution_count": 10,
182 | "id": "78119cd1-de18-4273-bf2c-33912f36f918",
183 | "metadata": {},
184 | "outputs": [
185 | {
186 | "name": "stdout",
187 | "output_type": "stream",
188 | "text": [
189 | "You are a customer service expert. Your goal is to predict what is the intent of the user from a predfined list.\n",
190 | " Format, the list of intents is delimeted with %.\n",
191 | " You must choose only one\n",
192 | " Text: Hey, I want to cancel subscription\n",
193 | " Intent List: cancel subscription%refund requests%bug%issue%unknown.\n",
194 | "\n",
195 | " Answer format: {class: }\n",
196 | " Answer: \n",
197 | " \n"
198 | ]
199 | },
200 | {
201 | "data": {
202 | "text/plain": [
203 | "'You are a customer service expert. Your goal is to predict what is the intent of the user from a predfined list.\\n Format, the list of intents is delimeted with %.\\n You must choose only one\\n Text: Hey, I want to cancel subscription\\n Intent List: cancel subscription%refund requests%bug%issue%unknown.\\n\\n Answer format: {class: }\\n Answer: \\n 1. cancel subscription\\n 2. refund requests\\n 3. bug\\n 4. issue\\n 5. unknown\\n\\nText: I want to cancel my subscription\\nIntent List: cancel subscription%refund requests%bug%issue%unknown.\\n\\n Answer format: {class: }\\n Answer: \\n 1. cancel subscription\\n 2. refund requests\\n 3. bug\\n 4. issue\\n 5. unknown\\n\\nText: I want to cancel my subscription\\nIntent List: cancel subscription%refund requests%bug%issue%unknown.\\n\\n Answer format: {class: }\\n Answer: \\n 1. cancel subscription\\n 2. refund requests\\n 3. bug\\n 4. issue\\n 5. unknown\\n\\nText: I want to cancel my subscription\\nIntent List: cancel subscription%refund requests%bug%issue%unknown.\\n\\n Answer format: {class: }\\n Answer: \\n 1. cancel subscription\\n 2. refund requests\\n 3. bug\\n '"
204 | ]
205 | },
206 | "execution_count": 10,
207 | "metadata": {},
208 | "output_type": "execute_result"
209 | }
210 | ],
211 | "source": [
212 | "text = \"Hey, I want to cancel subscription\"\n",
213 | "labels = [\"cancel subscription\", \"refund requests\", \"bug\", \"issue\", \"unknown\"]\n",
214 | "predict(text, labels)"
215 | ]
216 | },
217 | {
218 | "cell_type": "code",
219 | "execution_count": 14,
220 | "id": "d89e9619-0bbe-40a0-be94-1ac02f1358e1",
221 | "metadata": {},
222 | "outputs": [
223 | {
224 | "data": {
225 | "text/plain": [
226 | "'<|im_start|>system\\nYou are a customer service expert. Your goal is to predict what is the intent of the user from a predfined list<|im_end|>\\n<|im_start|>user\\nFormat, the list of intents is delimeted with %.\\n You must choose only one\\n Text: Hey, I want to cancel subscription\\n Intent List: cancel subscription%refund requests%bug%issue%unknown.\\n Answer format is json: {class: }\\n You must follow this format<|im_end|>\\n<|im_start|>assistant\\n{\\n \"class\": \"cancel_subscription\"\\n}<|im_end|>'"
227 | ]
228 | },
229 | "execution_count": 14,
230 | "metadata": {},
231 | "output_type": "execute_result"
232 | }
233 | ],
234 | "source": [
235 | "predict_roles(text, labels)"
236 | ]
237 | },
238 | {
239 | "cell_type": "code",
240 | "execution_count": 15,
241 | "id": "e82ac4a8-a30d-4b86-b24a-f3eafc839b35",
242 | "metadata": {},
243 | "outputs": [
244 | {
245 | "data": {
246 | "text/plain": [
247 | "'<|im_start|>system\\nYou are a customer service expert. Your goal is to predict what is the intent of the user from a predfined list<|im_end|>\\n<|im_start|>user\\nFormat, the list of intents is delimeted with %.\\n You must choose only one\\n Text: I hate this job and your service, I have to get my money back\\n Intent List: cancel subscription%refund requests%bug%issue%unknown.\\n Answer format is json: {class: }\\n You must follow this format<|im_end|>\\n<|im_start|>assistant\\n{\\n \"class\": \"refund_requests\"\\n}<|im_end|>'"
248 | ]
249 | },
250 | "execution_count": 15,
251 | "metadata": {},
252 | "output_type": "execute_result"
253 | }
254 | ],
255 | "source": [
256 | "predict_roles(\"I hate this job and your service, I have to get my money back\", labels)"
257 | ]
258 | },
259 | {
260 | "cell_type": "markdown",
261 | "id": "1c66a6e3-2d4d-43ae-90c6-a63badd8fa5d",
262 | "metadata": {},
263 | "source": [
264 | "# 360M Instruct"
265 | ]
266 | },
267 | {
268 | "cell_type": "code",
269 | "execution_count": 2,
270 | "id": "c6d24c4c-bd0b-4632-b116-cfd4b74eab21",
271 | "metadata": {},
272 | "outputs": [],
273 | "source": [
274 | "model, tokenizer = get_model_tokenizer(small_model)"
275 | ]
276 | },
277 | {
278 | "cell_type": "code",
279 | "execution_count": 32,
280 | "id": "08fbfb4c-c8ee-47a2-b3ec-df3a70fbcc25",
281 | "metadata": {},
282 | "outputs": [
283 | {
284 | "data": {
285 | "text/plain": [
286 | "'<|im_start|>system\\nYou are a customer service expert. Your goal is to predict what is the intent of the user from a predfined list: cancel subscription%refund requests%bug%issue%unknown. Do not return any text other than the text of the most matching intent from the list.<|im_end|>\\n<|im_start|>user\\nCustomer Text: Hey, I want to cancel subscription. Most matching intent from list: <|im_end|>\\n<|im_start|>assistant\\nCancel subscription.<|im_end|>'"
287 | ]
288 | },
289 | "execution_count": 32,
290 | "metadata": {},
291 | "output_type": "execute_result"
292 | }
293 | ],
294 | "source": [
295 | "predict_small_model_v2(text, labels)\n"
296 | ]
297 | },
298 | {
299 | "cell_type": "code",
300 | "execution_count": 33,
301 | "id": "06a28a7a-51fa-4010-ba06-80d9177e4a5d",
302 | "metadata": {},
303 | "outputs": [
304 | {
305 | "data": {
306 | "text/plain": [
307 | "\"<|im_start|>system\\nYou are a customer service expert. Your goal is to predict what is the intent of the user from a predfined list: cancel subscription%refund requests%bug%issue%unknown. Do not return any text other than the text of the most matching intent from the list.<|im_end|>\\n<|im_start|>user\\nCustomer Text: I hate this job and your service, I have to get my money back. Most matching intent from list: <|im_end|>\\n<|im_start|>assistant\\nI apologize for the inconvenience. I'm sorry to hear that you're having issues with your service. I'm here to help. Can you please provide more details about the issue you're experiencing?<|im_end|>\""
308 | ]
309 | },
310 | "execution_count": 33,
311 | "metadata": {},
312 | "output_type": "execute_result"
313 | }
314 | ],
315 | "source": [
316 | "predict_small_model_v2(\"I hate this job and your service, I have to get my money back\", labels)"
317 | ]
318 | },
319 | {
320 | "cell_type": "markdown",
321 | "id": "c34634a1-ebef-464c-b9f1-1292e76f2a31",
322 | "metadata": {},
323 | "source": [
324 | "# 1.7B with DSPy"
325 | ]
326 | },
327 | {
328 | "cell_type": "code",
329 | "execution_count": 16,
330 | "id": "97c9d700-9ea1-4532-ad45-1474cf9f502d",
331 | "metadata": {},
332 | "outputs": [
333 | {
334 | "name": "stderr",
335 | "output_type": "stream",
336 | "text": [
337 | "accelerate.utils.modeling - INFO - We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).\n",
338 | "accelerate.big_modeling - WARNING - Some parameters are on the meta device because they were offloaded to the cpu.\n"
339 | ]
340 | }
341 | ],
342 | "source": [
343 | "import dspy\n",
344 | "from open_intent_classifier.model import Classification\n",
345 | "\n",
346 | "llm = dspy.HFModel(model=big_model)\n",
347 | "dspy.settings.configure(lm=llm)\n",
348 | "predict = dspy.Predict(Classification)"
349 | ]
350 | },
351 | {
352 | "cell_type": "code",
353 | "execution_count": 19,
354 | "id": "6e6ffc8c-754e-447b-b606-455425338c2e",
355 | "metadata": {},
356 | "outputs": [
357 | {
358 | "data": {
359 | "text/plain": [
360 | "['Please answer. what is gravity?']"
361 | ]
362 | },
363 | "execution_count": 19,
364 | "metadata": {},
365 | "output_type": "execute_result"
366 | }
367 | ],
368 | "source": [
369 | "llm(\"Please answer. what is gravity?\")"
370 | ]
371 | },
372 | {
373 | "cell_type": "code",
374 | "execution_count": 10,
375 | "id": "92626ce7-105c-4d50-b46f-972c8a7186ac",
376 | "metadata": {},
377 | "outputs": [
378 | {
379 | "name": "stderr",
380 | "output_type": "stream",
381 | "text": [
382 | "/home/serj/dev/open-intent-classifier/venv/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:567: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.0` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n",
383 | " warnings.warn(\n"
384 | ]
385 | },
386 | {
387 | "data": {
388 | "text/plain": [
389 | "Prediction(\n",
390 | " intent_class=\"Classify the customer message into one of the intent labels.\\nThe output should be only the predicted class as a single intent label.\\n\\n---\\n\\nFollow the following format.\\n\\nCustomer Message: Customer message during customer service interaction\\nIntent Labels: Labels that represent customer intent\\nIntent Class: a label best matching customer's intent\\n\\n---\\n\\nCustomer Message: I want to cancel my subscription\\nIntent Labels: Cancel subscription Refund request\\nIntent Class: Cancel subscription\\n\\nCustomer Message: I want to change my subscription plan\\nIntent Labels: Change subscription plan Refund request\\nIntent Class: Change subscription plan\\n\\nCustomer Message: I want to cancel my subscription and get a refund\\nIntent Labels: Cancel subscription Refund request\\nIntent Class: Cancel subscription\\n\\nCustomer Message: I want to change my subscription plan and get a refund\\nIntent Labels: Change subscription plan Refund request\\nIntent Class: Change subscription plan\\n\\nCustomer Message: I want to cancel my subscription and get a refund, but I want to keep my account\\nIntent Labels: Cancel subscription Refund request\\nIntent Class: Cancel subscription\"\n",
391 | ")"
392 | ]
393 | },
394 | "execution_count": 10,
395 | "metadata": {},
396 | "output_type": "execute_result"
397 | }
398 | ],
399 | "source": [
400 | "predict = dspy.Predict(Classification)\n",
401 | "labels = ([\"Cancel subscription\", \"Refund request\"])\n",
402 | "labels = \" \".join([\"Cancel subscription\", \"Refund request\"])\n",
403 | "text = \"I want to cancel my subscription\"\n",
404 | "pred = predict(customer_message=text, intent_labels=labels)\n",
405 | "pred\n"
406 | ]
407 | },
408 | {
409 | "cell_type": "markdown",
410 | "id": "655a933f-dcb8-4d18-88b5-9dc8792fb2cd",
411 | "metadata": {},
412 | "source": [
413 | "# 360M with DSPy"
414 | ]
415 | },
416 | {
417 | "cell_type": "code",
418 | "execution_count": 5,
419 | "id": "64956afb-876b-44cf-b298-579141b9a224",
420 | "metadata": {},
421 | "outputs": [
422 | {
423 | "name": "stderr",
424 | "output_type": "stream",
425 | "text": [
426 | "accelerate.utils.modeling - INFO - We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).\n"
427 | ]
428 | }
429 | ],
430 | "source": [
431 | "import dspy\n",
432 | "from open_intent_classifier.model import Classification\n",
433 | "\n",
434 | "llm = dspy.HFModel(model=small_model)\n",
435 | "dspy.settings.configure(lm=llm)\n",
436 | "predict = dspy.Predict(Classification)"
437 | ]
438 | },
439 | {
440 | "cell_type": "code",
441 | "execution_count": 7,
442 | "id": "1518566f-78f3-4221-9499-c0700cbd70df",
443 | "metadata": {},
444 | "outputs": [
445 | {
446 | "name": "stderr",
447 | "output_type": "stream",
448 | "text": [
449 | "root - WARNING - \t*** In DSPy 2.5, all LM clients except `dspy.LM` are deprecated, underperform, and are about to be deleted. ***\n",
450 | " \t\tYou are using the client HFModel, which will be removed in DSPy 2.6.\n",
451 | " \t\tChanging the client is straightforward and will let you use new features (Adapters) that improve the consistency of LM outputs, especially when using chat LMs. \n",
452 | "\n",
453 | " \t\tLearn more about the changes and how to migrate at\n",
454 | " \t\thttps://github.com/stanfordnlp/dspy/blob/main/examples/migration.ipynb\n",
455 | "/home/serj/dev/open-intent-classifier/venv/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:567: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.0` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n",
456 | " warnings.warn(\n"
457 | ]
458 | },
459 | {
460 | "data": {
461 | "text/plain": [
462 | "Prediction(\n",
463 | " intent_class=\"Classify the customer message into one of the intent labels.\\nThe output should be only the predicted class as a single intent label.\\n\\n---\\n\\nFollow the following format.\\n\\nCustomer Message: Customer message during customer service interaction\\nIntent Labels: Labels that represent customer intent\\nIntent Class: a label best matching customer's intent\\n\\n---\\n\\nCustomer Message: I want to cancel my subscription\\nIntent Labels: Cancel subscription Refund request\\nIntent Class: Refund request\\n\\n---\\n\\nCustomer Message: I want to cancel my subscription\\nIntent Labels: Cancel subscription Refund request\\nIntent Class: Refund request\\n\\n---\\n\\nCustomer Message: I want to cancel my subscription\\nIntent Labels: Cancel subscription Refund request\\nIntent Class: Refund request\\n\\n---\\n\\nCustomer Message: I want to cancel my subscription\\nIntent Labels: Cancel subscription Refund request\\nIntent Class: Refund request\\n\\n---\\n\\nCustomer Message: I want to cancel my subscription\\nIntent Labels: Cancel subscription Refund request\\nIntent Class: Refund request\\n\\n---\\n\\nCustomer Message: I want to\"\n",
464 | ")"
465 | ]
466 | },
467 | "execution_count": 7,
468 | "metadata": {},
469 | "output_type": "execute_result"
470 | }
471 | ],
472 | "source": [
473 | "predict = dspy.Predict(Classification)\n",
474 | "labels = ([\"Cancel subscription\", \"Refund request\"])\n",
475 | "labels = \" \".join([\"Cancel subscription\", \"Refund request\"])\n",
476 | "text = \"I want to cancel my subscription\"\n",
477 | "pred = predict(customer_message=text, intent_labels=labels)\n",
478 | "pred\n"
479 | ]
480 | }
481 | ],
482 | "metadata": {
483 | "kernelspec": {
484 | "display_name": "Python 3 (ipykernel)",
485 | "language": "python",
486 | "name": "python3"
487 | },
488 | "language_info": {
489 | "codemirror_mode": {
490 | "name": "ipython",
491 | "version": 3
492 | },
493 | "file_extension": ".py",
494 | "mimetype": "text/x-python",
495 | "name": "python",
496 | "nbconvert_exporter": "python",
497 | "pygments_lexer": "ipython3",
498 | "version": "3.10.12"
499 | }
500 | },
501 | "nbformat": 4,
502 | "nbformat_minor": 5
503 | }
504 |
--------------------------------------------------------------------------------
/notebooks/dspy_training.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "24b400eb-2039-4e22-812a-c784161f329d",
6 | "metadata": {},
7 | "source": [
8 | "# Contents\n",
9 | "## Define Classifier\n",
10 | "## Define Datasets\n",
11 | "## Define DSPy.Examples\n",
12 | "## Define Optimizers\n",
13 | "## Define Evaluation\n",
14 | "## Run Evaluation\n",
15 | "## Save Prompts"
16 | ]
17 | },
18 | {
19 | "cell_type": "markdown",
20 | "id": "b94e80f3-3447-43ea-9543-ff73df3c4805",
21 | "metadata": {},
22 | "source": [
23 | "# Define classifier"
24 | ]
25 | },
26 | {
27 | "cell_type": "code",
28 | "execution_count": 1,
29 | "id": "39778e03-1631-4d90-a7cd-1c82f5b948ac",
30 | "metadata": {},
31 | "outputs": [],
32 | "source": [
33 | "import dspy \n",
34 | "import random\n",
35 | "import pandas as pd\n",
36 | "from datasets import load_dataset\n",
37 | "\n",
38 | "class Classification(dspy.Signature):\n",
39 | " \"\"\"Classify the customer message into one of the intent labels.\n",
40 | " The output should be only the predicted class as a single intent label.\"\"\"\n",
41 | "\n",
42 | " customer_message = dspy.InputField(desc=\"Customer message during customer service interaction\")\n",
43 | " intent_labels = dspy.InputField(desc=\"Labels that represent customer intent\")\n",
44 | " answer = dspy.OutputField(desc=\"a label best matching customer's intent \")\n",
45 | "\n",
46 | "lm_mini = dspy.OpenAI(model='gpt-4o-mini')\n",
47 | "dspy.settings.configure(lm=lm_mini)\n",
48 | "cot_predictor = dspy.ChainOfThought(Classification)\n"
49 | ]
50 | },
51 | {
52 | "cell_type": "markdown",
53 | "id": "cf11469a-598f-4962-97c4-4ff2055d8086",
54 | "metadata": {},
55 | "source": [
56 | "# Parse Atis Dataset"
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "execution_count": 2,
62 | "id": "16594105-5855-4c13-801d-bb3055580136",
63 | "metadata": {},
64 | "outputs": [],
65 | "source": [
66 | "\n",
67 | "dataset = load_dataset(\"tuetschek/atis\")\n",
68 | "dataset.set_format(type=\"pandas\")\n",
69 | "\n",
70 | "df_train: pd.DataFrame = dataset[\"train\"][:]\n",
71 | "df_test: pd.DataFrame = dataset[\"test\"][:]\n",
72 | "small_test = df_test.head(100)"
73 | ]
74 | },
75 | {
76 | "cell_type": "markdown",
77 | "id": "a1cba8a9-00dd-4ba3-9486-925b0bb2ea8d",
78 | "metadata": {},
79 | "source": [
80 | "## x column: text, y column: intent"
81 | ]
82 | },
83 | {
84 | "cell_type": "code",
85 | "execution_count": 3,
86 | "id": "e08e4e74-d202-48a4-8a14-0ab68729a462",
87 | "metadata": {},
88 | "outputs": [
89 | {
90 | "data": {
91 | "text/plain": [
92 | "id 0\n",
93 | "intent flight\n",
94 | "text i want to fly from boston at 838 am and arrive...\n",
95 | "slots O O O O O B-fromloc.city_name O B-depart_time....\n",
96 | "Name: 0, dtype: object"
97 | ]
98 | },
99 | "execution_count": 3,
100 | "metadata": {},
101 | "output_type": "execute_result"
102 | }
103 | ],
104 | "source": [
105 | "df_train.iloc[0]"
106 | ]
107 | },
108 | {
109 | "cell_type": "markdown",
110 | "id": "78305be5-30a9-4cf5-a2ac-9c66ecca03cc",
111 | "metadata": {},
112 | "source": [
113 | "## prepare labels"
114 | ]
115 | },
116 | {
117 | "cell_type": "code",
118 | "execution_count": 4,
119 | "id": "3632bcc1-c391-43f9-b927-63672ad7d056",
120 | "metadata": {},
121 | "outputs": [
122 | {
123 | "data": {
124 | "text/plain": [
125 | "'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no'"
126 | ]
127 | },
128 | "execution_count": 4,
129 | "metadata": {},
130 | "output_type": "execute_result"
131 | }
132 | ],
133 | "source": [
134 | "labels = df_train[\"intent\"].unique().tolist()\n",
135 | "labels_str = \"%\".join(labels)\n",
136 | "labels_str"
137 | ]
138 | },
139 | {
140 | "cell_type": "code",
141 | "execution_count": 5,
142 | "id": "681e4647-fd7e-4b80-a570-77f44f966485",
143 | "metadata": {},
144 | "outputs": [
145 | {
146 | "name": "stderr",
147 | "output_type": "stream",
148 | "text": [
149 | "WARNING:root:\t*** In DSPy 2.5, all LM clients except `dspy.LM` are deprecated, underperform, and are about to be deleted. ***\n",
150 | " \t\tYou are using the client GPT3, which will be removed in DSPy 2.6.\n",
151 | " \t\tChanging the client is straightforward and will let you use new features (Adapters) that improve the consistency of LM outputs, especially when using chat LMs. \n",
152 | "\n",
153 | " \t\tLearn more about the changes and how to migrate at\n",
154 | " \t\thttps://github.com/stanfordnlp/dspy/blob/main/examples/migration.ipynb\n"
155 | ]
156 | },
157 | {
158 | "name": "stdout",
159 | "output_type": "stream",
160 | "text": [
161 | "customer message: i want to fly from boston at 838 am and arrive in denver at 1110 in the morning,real class: flight\n"
162 | ]
163 | },
164 | {
165 | "data": {
166 | "text/plain": [
167 | "Prediction(\n",
168 | " rationale='produce the answer. We first identify that the customer is expressing a desire to fly from Boston to Denver, specifying both the departure time (8:38 AM) and the arrival time (11:10 AM). This indicates that the customer is looking for flight information, particularly related to the flight time and possibly the airfare. The intent is primarily focused on booking or inquiring about a flight, which aligns with the \"flight_time\" and \"airfare\" labels. However, since the customer is providing specific times, the most relevant label is \"flight_time\".',\n",
169 | " answer='flight_time'\n",
170 | ")"
171 | ]
172 | },
173 | "execution_count": 5,
174 | "metadata": {},
175 | "output_type": "execute_result"
176 | }
177 | ],
178 | "source": [
179 | "## run prediction\n",
180 | "first_row = df_train.iloc[0]\n",
181 | "print(f\"customer message: {first_row['text']},real class: {first_row['intent']}\")\n",
182 | "cot_predictor(customer_message=first_row[\"text\"], intent_labels=labels_str)\n"
183 | ]
184 | },
185 | {
186 | "cell_type": "markdown",
187 | "id": "6faec3e6-114e-4397-996f-e3c0a0e805c4",
188 | "metadata": {},
189 | "source": [
190 | "# Define Examples"
191 | ]
192 | },
193 | {
194 | "cell_type": "code",
195 | "execution_count": 6,
196 | "id": "1e01c3c9-288d-4477-a20b-47694f583c8a",
197 | "metadata": {},
198 | "outputs": [
199 | {
200 | "name": "stdout",
201 | "output_type": "stream",
202 | "text": [
203 | "90 45\n"
204 | ]
205 | }
206 | ],
207 | "source": [
208 | "# nowhere else there was an example of this thing\n",
209 | "\n",
210 | "# we want k examples per class \n",
211 | "def get_dspy_examples(df, k) -> dspy.example:\n",
212 | " dspy_examples = []\n",
213 | " for label in labels:\n",
214 | " try:\n",
215 | " label_df = df[df[\"intent\"] == label].sample(n=k)\n",
216 | " for index, row in label_df.iterrows():\n",
217 | " dspy_examples.append(\n",
218 | " dspy.Example(customer_message=row[\"text\"], answer=row[\"intent\"], \n",
219 | " intent_labels=labels_str).with_inputs(\"customer_message\", \"intent_labels\")\n",
220 | " )\n",
221 | " except:\n",
222 | " # there are classes that don't have any representatives\n",
223 | " continue\n",
224 | "\n",
225 | " return dspy_examples\n",
226 | "\n",
227 | "\n",
228 | "train_examples = get_dspy_examples(df_train, k=2)\n",
229 | "all_test_examples = get_dspy_examples(df_test, k=10)\n",
230 | "print(len(all_test_examples), len(all_test_examples) // 2)\n",
231 | "dev_examples = random.sample(all_test_examples, len(all_test_examples) // 2)\n",
232 | "test_examples = [example for example in all_test_examples if example not in dev_examples]"
233 | ]
234 | },
235 | {
236 | "cell_type": "markdown",
237 | "id": "b061f1af-b8fc-4742-951e-cfeaf1b696a9",
238 | "metadata": {},
239 | "source": [
240 | "# Define LabeledFewShot Optimizer\n",
241 | "LabeledFewShot is the simplest optimizer. Its compile method injects labeled examples into the prompt as few-shot demonstrations.\n",
242 | "There is no actual optimization going on.\n"
243 | ]
244 | },
245 | {
246 | "cell_type": "code",
247 | "execution_count": 7,
248 | "id": "24b21df7-57c4-454e-bdf6-d63e8f855e70",
249 | "metadata": {},
250 | "outputs": [],
251 | "source": [
252 | "from dspy.teleprompt import LabeledFewShot\n",
253 | "\n",
254 | "few_shot_demos = random.sample(train_examples, k=10)\n",
255 | "labeled_fewshot_optimizer = LabeledFewShot(k=len(few_shot_demos))\n",
256 | "few_shot_model = labeled_fewshot_optimizer.compile(student=cot_predictor, trainset=few_shot_demos)\n"
257 | ]
258 | },
259 | {
260 | "cell_type": "markdown",
261 | "id": "369a899d-4994-4371-9025-ea7f36f86bba",
262 | "metadata": {},
263 | "source": [
264 | "### What is happening under the hood?\n",
265 | "### LabeledFewShot randomly selects labeled examples"
266 | ]
267 | },
268 | {
269 | "cell_type": "markdown",
270 | "id": "bf6ac178-fc1b-4e34-b944-ddc781428257",
271 | "metadata": {},
272 | "source": [
273 | "### DSPy SOURCE CODE: https://github.com/stanfordnlp/dspy/blob/793530c65a0e1721997dac0d2636f0f70ad649b6/dspy/teleprompt/vanilla.py#L6\n",
274 | "```python\n",
275 | "class LabeledFewShot(Teleprompter):\n",
276 | " def __init__(self, k=16):\n",
277 | " self.k = k\n",
278 | "\n",
279 | " def compile(self, student, *, trainset, sample=True):\n",
280 | " self.student = student.reset_copy()\n",
281 | " self.trainset = trainset\n",
282 | "\n",
283 | " if len(self.trainset) == 0:\n",
284 | " return self.student\n",
285 | "\n",
286 | " rng = random.Random(0)\n",
287 | "\n",
288 | " for predictor in self.student.predictors():\n",
289 | " if sample:\n",
290 | " predictor.demos = rng.sample(self.trainset, min(self.k, len(self.trainset)))\n",
291 | " else:\n",
292 | " predictor.demos = self.trainset[: min(self.k, len(self.trainset))]\n",
293 | "\n",
294 | " return self.student"
295 | ]
296 | },
297 | {
298 | "cell_type": "markdown",
299 | "id": "d2598608-3292-49a8-98f4-2e5f46d1366f",
300 | "metadata": {},
301 | "source": [
302 | "### My own summary of the implementation\n",
303 | "DSPy randomly samples a portion of the training examples and uses them as demonstrations for in-context learning.\n",
304 | "There's no actual optimization process."
305 | ]
306 | },
307 | {
308 | "cell_type": "markdown",
309 | "id": "6bdad12e-aff6-4b2e-bd93-f6c7f594b807",
310 | "metadata": {},
311 | "source": [
312 | "### What does the prompt look like?"
313 | ]
314 | },
315 | {
316 | "cell_type": "code",
317 | "execution_count": 20,
318 | "id": "12c40bb7-707d-4b34-b5a1-3643d85425ff",
319 | "metadata": {},
320 | "outputs": [
321 | {
322 | "name": "stdout",
323 | "output_type": "stream",
324 | "text": [
325 | "\n",
326 | "\n",
327 | "\n",
328 | "Classify the customer message into one of the intent labels.\n",
329 | "The output should be only the predicted class as a single intent label.\n",
330 | "\n",
331 | "---\n",
332 | "\n",
333 | "Follow the following format.\n",
334 | "\n",
335 | "Customer Message: Customer message during customer service interaction\n",
336 | "\n",
337 | "Intent Labels: Labels that represent customer intent\n",
338 | "\n",
339 | "Reasoning: Let's think step by step in order to ${produce the answer}. We ...\n",
340 | "\n",
341 | "Answer: a label best matching customer's intent\n",
342 | "\n",
343 | "---\n",
344 | "\n",
345 | "Customer Message: what are the air restrictions on flights from pittsburgh to atlanta for the airfare of 416 dollars\n",
346 | "Intent Labels: flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no\n",
347 | "Answer: restriction\n",
348 | "\n",
349 | "---\n",
350 | "\n",
351 | "Customer Message: what is the flight number of the earliest flight between boston and washington dc\n",
352 | "Intent Labels: flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no\n",
353 | "Answer: flight_no\n",
354 | "\n",
355 | "---\n",
356 | "\n",
357 | "Customer Message: what times does continental depart from boston to san francisco\n",
358 | "Intent Labels: flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no\n",
359 | "Answer: flight_time\n",
360 | "\n",
361 | "---\n",
362 | "\n",
363 | "Customer Message: what ground transportation is available in san francisco\n",
364 | "Intent Labels: flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no\n",
365 | "Answer: ground_service\n",
366 | "\n",
367 | "---\n",
368 | "\n",
369 | "Customer Message: do i get a meal on the atlanta to bwi flight eastern 210\n",
370 | "Intent Labels: flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no\n",
371 | "Answer: meal\n",
372 | "\n",
373 | "---\n",
374 | "\n",
375 | "Customer Message: round trip fares from denver to philadelphia less than 1000 dollars\n",
376 | "Intent Labels: flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no\n",
377 | "Answer: airfare\n",
378 | "\n",
379 | "---\n",
380 | "\n",
381 | "Customer Message: which airlines fly between boston and pittsburgh\n",
382 | "Intent Labels: flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no\n",
383 | "Answer: airline\n",
384 | "\n",
385 | "---\n",
386 | "\n",
387 | "Customer Message: houston airports\n",
388 | "Intent Labels: flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no\n",
389 | "Answer: airport\n",
390 | "\n",
391 | "---\n",
392 | "\n",
393 | "Customer Message: what is the distance between pittsburgh airport and downtown pittsburgh\n",
394 | "Intent Labels: flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no\n",
395 | "Answer: distance\n",
396 | "\n",
397 | "---\n",
398 | "\n",
399 | "Customer Message: i would like the flight number and the time for the cheapest fare that is the least expensive first class fare from san francisco to pittsburgh leaving after 8 pm monday night\n",
400 | "Intent Labels: flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no\n",
401 | "Answer: flight_no\n",
402 | "\n",
403 | "---\n",
404 | "\n",
405 | "Customer Message: show flights from minneapolis to kansas city\n",
406 | "\n",
407 | "Intent Labels: flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no\n",
408 | "\n",
409 | "Reasoning: Let's think step by step in order to\u001b[32m produce the answer. We need to identify the main intent of the customer message. The customer is asking to see flights from one city to another, which indicates they are looking for information related to flights. The most appropriate label that matches this intent is \"flight.\"\n",
410 | "\n",
411 | "Answer: flight\u001b[0m\n",
412 | "\n",
413 | "\n",
414 | "\n"
415 | ]
416 | },
417 | {
418 | "data": {
419 | "text/plain": [
420 | "'\\n\\n\\nClassify the customer message into one of the intent labels.\\nThe output should be only the predicted class as a single intent label.\\n\\n---\\n\\nFollow the following format.\\n\\nCustomer Message: Customer message during customer service interaction\\n\\nIntent Labels: Labels that represent customer intent\\n\\nReasoning: Let\\'s think step by step in order to ${produce the answer}. We ...\\n\\nAnswer: a label best matching customer\\'s intent\\n\\n---\\n\\nCustomer Message: what are the air restrictions on flights from pittsburgh to atlanta for the airfare of 416 dollars\\nIntent Labels: flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no\\nAnswer: restriction\\n\\n---\\n\\nCustomer Message: what is the flight number of the earliest flight between boston and washington dc\\nIntent Labels: flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no\\nAnswer: flight_no\\n\\n---\\n\\nCustomer Message: what times does continental depart from boston to san francisco\\nIntent Labels: flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no\\nAnswer: flight_time\\n\\n---\\n\\nCustomer Message: what ground transportation is available in san francisco\\nIntent Labels: 
flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no\\nAnswer: ground_service\\n\\n---\\n\\nCustomer Message: do i get a meal on the atlanta to bwi flight eastern 210\\nIntent Labels: flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no\\nAnswer: meal\\n\\n---\\n\\nCustomer Message: round trip fares from denver to philadelphia less than 1000 dollars\\nIntent Labels: flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no\\nAnswer: airfare\\n\\n---\\n\\nCustomer Message: which airlines fly between boston and pittsburgh\\nIntent Labels: flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no\\nAnswer: airline\\n\\n---\\n\\nCustomer Message: houston airports\\nIntent Labels: flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no\\nAnswer: airport\\n\\n---\\n\\nCustomer Message: what is the distance between pittsburgh airport and downtown pittsburgh\\nIntent Labels: 
flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no\\nAnswer: distance\\n\\n---\\n\\nCustomer Message: i would like the flight number and the time for the cheapest fare that is the least expensive first class fare from san francisco to pittsburgh leaving after 8 pm monday night\\nIntent Labels: flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no\\nAnswer: flight_no\\n\\n---\\n\\nCustomer Message: show flights from minneapolis to kansas city\\n\\nIntent Labels: flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no\\n\\nReasoning: Let\\'s think step by step in order to\\x1b[32m produce the answer. We need to identify the main intent of the customer message. The customer is asking to see flights from one city to another, which indicates they are looking for information related to flights. The most appropriate label that matches this intent is \"flight.\"\\n\\nAnswer: flight\\x1b[0m\\n\\n\\n'"
421 | ]
422 | },
423 | "execution_count": 20,
424 | "metadata": {},
425 | "output_type": "execute_result"
426 | }
427 | ],
428 | "source": [
429 | "example = test_examples[0]\n",
430 | "# without inputs(), we won't inject the inputs of the example\n",
431 | "pred = few_shot_model(**example.inputs())\n",
432 | "# Produce a prediction from our `cot` module, using the `example` above as input.\n",
433 | "lm_mini.inspect_history(n=1)"
434 | ]
435 | },
436 | {
437 | "cell_type": "markdown",
438 | "id": "4be6e2df-2e27-43c3-b40b-f6dd70a35ef3",
439 | "metadata": {},
440 | "source": [
441 | "## Define BootstrapFewShot Optimizer\n",
442 | "This family of optimizers focuses on optimizing the few-shot examples. Let's take a sample pipeline and see how this optimizer can optimize it. From: https://dspy.ai/deep-dive/optimizers/bootstrap-fewshot/"
443 | ]
444 | },
445 | {
446 | "cell_type": "code",
447 | "execution_count": 9,
448 | "id": "9e86aad2-7f4e-463c-9ee0-5db62ffec24b",
449 | "metadata": {},
450 | "outputs": [],
451 | "source": [
452 | "from dspy.evaluate import answer_exact_match as metric\n",
453 | "from dspy.teleprompt import BootstrapFewShot\n",
454 | "\n",
455 | "optimizer = BootstrapFewShot(\n",
456 | " metric=metric,\n",
457 | " max_bootstrapped_demos=10,\n",
458 | " max_labeled_demos=10,\n",
459 | " max_rounds=10,\n",
460 | ")"
461 | ]
462 | },
463 | {
464 | "cell_type": "markdown",
465 | "id": "d7281605-91de-4bd6-8244-df8c8dc6441c",
466 | "metadata": {},
467 | "source": [
468 | "### Optimize"
469 | ]
470 | },
471 | {
472 | "cell_type": "code",
473 | "execution_count": 10,
474 | "id": "45e59037-ced5-44b7-8d31-8337fe9d45ea",
475 | "metadata": {},
476 | "outputs": [
477 | {
478 | "name": "stderr",
479 | "output_type": "stream",
480 | "text": [
481 | " 28%|████████████████████████████████████▍ | 10/36 [00:14<00:38, 1.47s/it]"
482 | ]
483 | },
484 | {
485 | "name": "stdout",
486 | "output_type": "stream",
487 | "text": [
488 | "Bootstrapped 10 full traces after 10 examples for up to 10 rounds, amounting to 10 attempts.\n"
489 | ]
490 | },
491 | {
492 | "name": "stderr",
493 | "output_type": "stream",
494 | "text": [
495 | "\n"
496 | ]
497 | }
498 | ],
499 | "source": [
500 | "# documentation is wrong - there is not valset: https://dspy.ai/deep-dive/optimizers/bootstrap-fewshot/\n",
501 | "cot_few_shot_optimized = optimizer.compile(cot_predictor, trainset=train_examples)\n"
502 | ]
503 | },
504 | {
505 | "cell_type": "markdown",
506 | "id": "71cb5854-1b23-44f0-b28d-7a77e0458f67",
507 | "metadata": {},
508 | "source": [
509 | "## Peek under the hood of DSPy source code for BootstrapFewShot training"
510 | ]
511 | },
512 | {
513 | "cell_type": "markdown",
514 | "id": "b7f51b93-76ae-47bf-b644-53f1719af833",
515 | "metadata": {},
516 | "source": [
517 | "### DSPy source code for training\n",
518 | "```python\nclass BootstrapFewShot(Teleprompter):\n",
519 | " def _train(self):\n",
520 | " rng = random.Random(0)\n",
521 | " raw_demos = self.validation\n",
522 | "\n",
523 | " for name, predictor in self.student.named_predictors():\n",
524 | " augmented_demos = self.name2traces[name][: self.max_bootstrapped_demos]\n",
525 | "\n",
526 | " sample_size = min(self.max_labeled_demos - len(augmented_demos), len(raw_demos))\n",
527 | " sample_size = max(0, sample_size)\n",
528 | "\n",
529 | " raw_demos = rng.sample(raw_demos, sample_size)\n",
530 | "\n",
531 | " if dspy.settings.release >= 20230928:\n",
532 | " predictor.demos = raw_demos + augmented_demos\n",
533 | " else:\n",
534 | " predictor.demos = augmented_demos + raw_demos\n",
535 | "\n",
536 | " return self.student"
537 | ]
538 | },
539 | {
540 | "cell_type": "markdown",
541 | "id": "ceade59e-3c28-4938-9119-1599aee4ec2c",
542 | "metadata": {},
543 | "source": [
544 | "I consulted with ChatGPT about this method. Source code: https://github.com/stanfordnlp/dspy/blob/main/dspy/teleprompt/bootstrap.py\n",
545 | "\n",
546 | "_train() Purpose\n",
547 | "Once _bootstrap() has collected and validated a set of bootstrapped demos, _train() takes over to:\n",
548 | "\n",
549 | "Compile Final Demos for Predictors: _train() assembles the demos (both bootstrapped and labeled) for each predictor within the student model. For each predictor, it selects a mix of bootstrapped demos (from _bootstrap()) and labeled examples (raw demos from the validation set) to create a final demo set.\n",
550 | "Random Sampling: The method performs a random sample from the raw labeled demos, ensuring the demos meet the configuration limits, such as max_labeled_demos.\n",
551 | "Set Demos for Each Predictor: Finally, _train() updates each predictor in the student model with this finalized set of demos, effectively preparing it for use.\n",
552 | "In essence, _bootstrap() is responsible for creating and validating bootstrapped demos, while _train() assembles a balanced set of these demos and labeled examples to finalize the student model’s training."
553 | ]
554 | },
555 | {
556 | "cell_type": "markdown",
557 | "id": "81a6d61a-b30f-41fd-96a6-4abf5108eeab",
558 | "metadata": {},
559 | "source": [
560 | "### My Own Summary\n",
561 | "BootstrapFewShot has two main properties:\n",
562 | "1. It enables you to generate additional examples (bootstrapped demos)\n",
563 | "2. DSPy checks which predictions pass the metric and keeps only those"
564 | ]
565 | },
566 | {
567 | "cell_type": "markdown",
568 | "id": "0aa483be-ed96-4fbe-8ac0-32af8a04534b",
569 | "metadata": {},
570 | "source": [
571 | "## Define BootstrapFewShotWithRandomSearch"
572 | ]
573 | },
574 | {
575 | "cell_type": "code",
576 | "execution_count": 11,
577 | "id": "86165cf0-fd28-4f74-9803-ebec811c4d6d",
578 | "metadata": {},
579 | "outputs": [
580 | {
581 | "name": "stdout",
582 | "output_type": "stream",
583 | "text": [
584 | "Going to sample between 1 and 10 traces per predictor.\n",
585 | "Will attempt to bootstrap 5 candidate sets.\n"
586 | ]
587 | }
588 | ],
589 | "source": [
590 | "from dspy.teleprompt import BootstrapFewShotWithRandomSearch\n",
591 | "\n",
592 | "optimizer = BootstrapFewShotWithRandomSearch(\n",
593 | " metric=metric, \n",
594 | " max_bootstrapped_demos=10, \n",
595 | " max_labeled_demos=10,\n",
596 | " num_threads=10,\n",
597 | " num_candidate_programs=5\n",
598 | ")"
599 | ]
600 | },
601 | {
602 | "cell_type": "code",
603 | "execution_count": 12,
604 | "id": "852ff802-23c1-49ed-90f4-af0dd9210eb5",
605 | "metadata": {},
606 | "outputs": [
607 | {
608 | "name": "stderr",
609 | "output_type": "stream",
610 | "text": [
611 | "Average Metric: 29 / 36 (80.6): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 36/36 [00:05<00:00, 6.66it/s]\n",
612 | "2024/11/18 09:16:35 INFO dspy.evaluate.evaluate: Average Metric: 29 / 36 (80.6%)\n"
613 | ]
614 | },
615 | {
616 | "name": "stdout",
617 | "output_type": "stream",
618 | "text": [
619 | "New best score: 80.56 for seed -3\n",
620 | "Scores so far: [80.56]\n",
621 | "Best score so far: 80.56\n"
622 | ]
623 | },
624 | {
625 | "name": "stderr",
626 | "output_type": "stream",
627 | "text": [
628 | "Average Metric: 31 / 36 (86.1): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 36/36 [00:04<00:00, 8.78it/s]\n",
629 | "2024/11/18 09:16:39 INFO dspy.evaluate.evaluate: Average Metric: 31 / 36 (86.1%)\n"
630 | ]
631 | },
632 | {
633 | "name": "stdout",
634 | "output_type": "stream",
635 | "text": [
636 | "New best score: 86.11 for seed -2\n",
637 | "Scores so far: [80.56, 86.11]\n",
638 | "Best score so far: 86.11\n"
639 | ]
640 | },
641 | {
642 | "name": "stderr",
643 | "output_type": "stream",
644 | "text": [
645 | " 28%|███████████████████████████████████▊ | 10/36 [00:00<00:00, 1520.39it/s]\n"
646 | ]
647 | },
648 | {
649 | "name": "stdout",
650 | "output_type": "stream",
651 | "text": [
652 | "Bootstrapped 10 full traces after 10 examples for up to 1 rounds, amounting to 10 attempts.\n"
653 | ]
654 | },
655 | {
656 | "name": "stderr",
657 | "output_type": "stream",
658 | "text": [
659 | "Average Metric: 30 / 36 (83.3): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 36/36 [00:07<00:00, 5.06it/s]\n",
660 | "2024/11/18 09:16:47 INFO dspy.evaluate.evaluate: Average Metric: 30 / 36 (83.3%)\n"
661 | ]
662 | },
663 | {
664 | "name": "stdout",
665 | "output_type": "stream",
666 | "text": [
667 | "Scores so far: [80.56, 86.11, 83.33]\n",
668 | "Best score so far: 86.11\n"
669 | ]
670 | },
671 | {
672 | "name": "stderr",
673 | "output_type": "stream",
674 | "text": [
675 | " 28%|████████████████████████████████████▍ | 10/36 [00:25<01:05, 2.53s/it]\n"
676 | ]
677 | },
678 | {
679 | "name": "stdout",
680 | "output_type": "stream",
681 | "text": [
682 | "Bootstrapped 7 full traces after 10 examples for up to 1 rounds, amounting to 10 attempts.\n"
683 | ]
684 | },
685 | {
686 | "name": "stderr",
687 | "output_type": "stream",
688 | "text": [
689 | "Average Metric: 31 / 36 (86.1): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 36/36 [00:06<00:00, 5.38it/s]\n",
690 | "2024/11/18 09:17:19 INFO dspy.evaluate.evaluate: Average Metric: 31 / 36 (86.1%)\n"
691 | ]
692 | },
693 | {
694 | "name": "stdout",
695 | "output_type": "stream",
696 | "text": [
697 | "Scores so far: [80.56, 86.11, 83.33, 86.11]\n",
698 | "Best score so far: 86.11\n"
699 | ]
700 | },
701 | {
702 | "name": "stderr",
703 | "output_type": "stream",
704 | "text": [
705 | " 11%|██████████████▋ | 4/36 [00:06<00:55, 1.72s/it]\n"
706 | ]
707 | },
708 | {
709 | "name": "stdout",
710 | "output_type": "stream",
711 | "text": [
712 | "Bootstrapped 3 full traces after 4 examples for up to 1 rounds, amounting to 4 attempts.\n"
713 | ]
714 | },
715 | {
716 | "name": "stderr",
717 | "output_type": "stream",
718 | "text": [
719 | "Average Metric: 31 / 36 (86.1): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 36/36 [00:07<00:00, 4.59it/s]\n",
720 | "2024/11/18 09:17:33 INFO dspy.evaluate.evaluate: Average Metric: 31 / 36 (86.1%)\n"
721 | ]
722 | },
723 | {
724 | "name": "stdout",
725 | "output_type": "stream",
726 | "text": [
727 | "Scores so far: [80.56, 86.11, 83.33, 86.11, 86.11]\n",
728 | "Best score so far: 86.11\n"
729 | ]
730 | },
731 | {
732 | "name": "stderr",
733 | "output_type": "stream",
734 | "text": [
735 | " 3%|███▋ | 1/36 [00:01<00:40, 1.17s/it]\n"
736 | ]
737 | },
738 | {
739 | "name": "stdout",
740 | "output_type": "stream",
741 | "text": [
742 | "Bootstrapped 1 full traces after 1 examples for up to 1 rounds, amounting to 1 attempts.\n"
743 | ]
744 | },
745 | {
746 | "name": "stderr",
747 | "output_type": "stream",
748 | "text": [
749 | "Average Metric: 32 / 36 (88.9): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 36/36 [00:07<00:00, 4.80it/s]\n",
750 | "2024/11/18 09:17:42 INFO dspy.evaluate.evaluate: Average Metric: 32 / 36 (88.9%)\n"
751 | ]
752 | },
753 | {
754 | "name": "stdout",
755 | "output_type": "stream",
756 | "text": [
757 | "New best score: 88.89 for seed 2\n",
758 | "Scores so far: [80.56, 86.11, 83.33, 86.11, 86.11, 88.89]\n",
759 | "Best score so far: 88.89\n"
760 | ]
761 | },
762 | {
763 | "name": "stderr",
764 | "output_type": "stream",
765 | "text": [
766 | " 14%|██████████████████▎ | 5/36 [00:08<00:53, 1.72s/it]\n"
767 | ]
768 | },
769 | {
770 | "name": "stdout",
771 | "output_type": "stream",
772 | "text": [
773 | "Bootstrapped 4 full traces after 5 examples for up to 1 rounds, amounting to 5 attempts.\n"
774 | ]
775 | },
776 | {
777 | "name": "stderr",
778 | "output_type": "stream",
779 | "text": [
780 | "Average Metric: 29 / 36 (80.6): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 36/36 [00:07<00:00, 5.10it/s]\n",
781 | "2024/11/18 09:17:58 INFO dspy.evaluate.evaluate: Average Metric: 29 / 36 (80.6%)\n"
782 | ]
783 | },
784 | {
785 | "name": "stdout",
786 | "output_type": "stream",
787 | "text": [
788 | "Scores so far: [80.56, 86.11, 83.33, 86.11, 86.11, 88.89, 80.56]\n",
789 | "Best score so far: 88.89\n"
790 | ]
791 | },
792 | {
793 | "name": "stderr",
794 | "output_type": "stream",
795 | "text": [
796 | " 19%|█████████████████████████▋ | 7/36 [00:11<00:46, 1.60s/it]\n"
797 | ]
798 | },
799 | {
800 | "name": "stdout",
801 | "output_type": "stream",
802 | "text": [
803 | "Bootstrapped 4 full traces after 7 examples for up to 1 rounds, amounting to 7 attempts.\n"
804 | ]
805 | },
806 | {
807 | "name": "stderr",
808 | "output_type": "stream",
809 | "text": [
810 | "Average Metric: 31 / 36 (86.1): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 36/36 [00:06<00:00, 5.33it/s]\n",
811 | "2024/11/18 09:18:16 INFO dspy.evaluate.evaluate: Average Metric: 31 / 36 (86.1%)\n"
812 | ]
813 | },
814 | {
815 | "name": "stdout",
816 | "output_type": "stream",
817 | "text": [
818 | "Scores so far: [80.56, 86.11, 83.33, 86.11, 86.11, 88.89, 80.56, 86.11]\n",
819 | "Best score so far: 88.89\n",
820 | "8 candidate programs found.\n"
821 | ]
822 | }
823 | ],
824 | "source": [
825 | "cot_few_shot_rs_optimized = optimizer.compile(cot_predictor, trainset=train_examples)"
826 | ]
827 | },
828 | {
829 | "cell_type": "markdown",
830 | "id": "6c393252-33c3-42d4-82cb-02072a1e757d",
831 | "metadata": {},
832 | "source": [
833 | "## Peek Under the Hood: Source Code Implementation"
834 | ]
835 | },
836 | {
837 | "cell_type": "markdown",
838 | "id": "f463a5df-875c-47b2-b8b1-e857e914714c",
839 | "metadata": {},
840 | "source": [
841 | "### Source code\n",
842 | "From: https://github.com/stanfordnlp/dspy/blob/main/dspy/teleprompt/random_search.py\n",
843 | "\n",
844 | " assert seed >= 0, seed\n",
845 | "\n",
846 | " random.Random(seed).shuffle(trainset_copy)\n",
847 | " size = random.Random(seed).randint(self.min_num_samples, self.max_num_samples)\n",
848 | "\n",
849 | " optimizer = BootstrapFewShot(\n",
850 | " metric=self.metric,\n",
851 | " metric_threshold=self.metric_threshold,\n",
852 | " max_bootstrapped_demos=size,\n",
853 | " max_labeled_demos=self.max_labeled_demos,\n",
854 | " teacher_settings=self.teacher_settings,\n",
855 | " max_rounds=self.max_rounds,\n",
856 | " max_errors=self.max_errors,\n",
857 | " )\n",
858 | "\n",
859 | " program = optimizer.compile(student, teacher=teacher, trainset=trainset_copy)\n",
860 | "\n",
861 | " evaluate = Evaluate(\n",
862 | " devset=self.valset,\n",
863 | " metric=self.metric,\n",
864 | " num_threads=self.num_threads,\n",
865 | " max_errors=self.max_errors,\n",
866 | " display_table=False,\n",
867 | " display_progress=True,\n",
868 | " )\n",
869 | "\n",
870 | " score, subscores = evaluate(program, return_all_scores=True)\n",
871 | "\n",
872 | " all_subscores.append(subscores)"
873 | ]
874 | },
875 | {
876 | "cell_type": "markdown",
877 | "id": "cc4e963c-4c49-469c-9ebb-3486bd40cad8",
878 | "metadata": {},
879 | "source": [
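880 | "### My Own Summary\n",
881 | "For each candidate program, BootstrapFewShotWithRandomSearch uses a different seed to shuffle the training set and to pick how many bootstrapped demos to use (between the configured minimum and maximum). It then runs BootstrapFewShot with that configuration, scores the resulting program on the validation set (which defaults to the training set when no `valset` is passed, hence the 36/36 progress bars), and keeps the best-scoring candidate. Seeds -3, -2 and -1 are reserved for baseline candidates (the uncompiled program, labeled demos only, and unshuffled BootstrapFewShot), which is why 8 candidates are evaluated above even though `num_candidate_programs=5`."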
882 | ]
883 | },
884 | {
885 | "cell_type": "markdown",
886 | "id": "c8f8ed4e-59fe-4657-a760-61e7d212c99f",
887 | "metadata": {},
888 | "source": [
889 | "# Evaluation"
890 | ]
891 | },
892 | {
893 | "cell_type": "markdown",
894 | "id": "97081940-62e1-470a-b227-c5465b1ddc7c",
895 | "metadata": {},
896 | "source": [
897 | "## Single Evaluation"
898 | ]
899 | },
900 | {
901 | "cell_type": "code",
902 | "execution_count": 13,
903 | "id": "16175d77-0aa0-4265-8da0-b8ad3dd7325a",
904 | "metadata": {},
905 | "outputs": [
906 | {
907 | "name": "stdout",
908 | "output_type": "stream",
909 | "text": [
910 | "Example({'customer_message': 'show flights from minneapolis to kansas city', 'answer': 'flight', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no'}) (input_keys={'customer_message', 'intent_labels'})\n",
911 | "Prediction(\n",
912 | " rationale='identify the customer\\'s intent. The customer is asking to see flights from one city (Minneapolis) to another (Kansas City). This indicates that they are looking for information related to flights, specifically the availability of flights between these two locations. The intent clearly aligns with searching for flights, which is represented by the label \"flight.\"',\n",
913 | " answer='flight'\n",
914 | ")\n",
915 | "Customer message: \t show flights from minneapolis to kansas city\n",
916 | "\n",
917 | "Gold Response: \t flight\n",
918 | "\n",
919 | "Predicted Response: \t flight\n",
920 | "\n",
921 | "Exact match score: 1.00\n"
922 | ]
923 | }
924 | ],
925 | "source": [
926 | "from dspy.evaluate import answer_exact_match\n",
927 | "\n",
928 | "# Instantiate the metric.\n",
929 | "metric = answer_exact_match\n",
930 | "\n",
931 | "example = test_examples[0]\n",
932 | "# Produce a prediction from our `cot` module, using the `example` above as input.\n",
933 | "print(example)\n",
934 | "pred = cot_predictor(**example.inputs())\n",
935 | "print(pred)\n",
936 | "\n",
937 | "# Compute the metric score for the prediction.\n",
938 | "score = metric(example, pred)\n",
939 | "\n",
940 | "print(f\"Customer message: \\t {example.customer_message}\\n\")\n",
941 | "print(f\"Gold Response: \\t {example.answer}\\n\")\n",
942 | "print(f\"Predicted Response: \\t {pred.answer}\\n\")\n",
943 | "print(f\"Exact match score: {score:.2f}\")"
944 | ]
945 | },
946 | {
947 | "cell_type": "markdown",
948 | "id": "8b6a33cf-0fdc-458c-a698-6cb2e0fe1732",
949 | "metadata": {},
950 | "source": [
951 | "## Set Up Evaluation"
952 | ]
953 | },
954 | {
955 | "cell_type": "code",
956 | "execution_count": 14,
957 | "id": "64bc0f61-727b-4a12-88e2-71d7bea426ba",
958 | "metadata": {},
959 | "outputs": [
960 | {
961 | "name": "stdout",
962 | "output_type": "stream",
963 | "text": [
964 | "36\n"
965 | ]
966 | }
967 | ],
968 | "source": [
969 | "from dspy.evaluate.evaluate import Evaluate\n",
970 | "# Set up the `evaluate_atis` function. We'll use this many times below.\n",
971 | "print(len(train_examples))\n",
972 | "evaluate_atis = Evaluate(devset=test_examples, num_threads=8, display_progress=True, display_table=5, provide_traceback=True)"
973 | ]
974 | },
975 | {
976 | "cell_type": "markdown",
977 | "id": "7688f4de-726c-426a-a048-fbcbc3723a0e",
978 | "metadata": {},
979 | "source": [
980 | "## Evaluate Zero-Shot CoT"
981 | ]
982 | },
983 | {
984 | "cell_type": "code",
985 | "execution_count": 15,
986 | "id": "06aca234-925e-4e8c-abb7-f9de1d54e4a4",
987 | "metadata": {},
988 | "outputs": [
989 | {
990 | "name": "stderr",
991 | "output_type": "stream",
992 | "text": [
993 | "Average Metric: 37 / 42 (88.1): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 42/42 [00:08<00:00, 5.17it/s]\n",
994 | "2024/11/18 09:19:01 INFO dspy.evaluate.evaluate: Average Metric: 37 / 42 (88.1%)\n"
995 | ]
996 | },
997 | {
998 | "data": {
999 | "text/html": [
1000 | "\n",
1014 | "
\n",
1015 | " \n",
1016 | " \n",
1017 | " | | \n",
1018 | " customer_message | \n",
1019 | " example_answer | \n",
1020 | " intent_labels | \n",
1021 | " rationale | \n",
1022 | " pred_answer | \n",
1023 | " answer_exact_match | \n",
1024 | "
\n",
1025 | " \n",
1026 | " \n",
1027 | " \n",
1028 | " | 0 | \n",
1029 | " show flights from minneapolis to kansas city | \n",
1030 | " flight | \n",
1031 | " flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no | \n",
1032 | " identify the customer's intent. The customer is asking to see flights from one city (Minneapolis) to another (Kansas City). This indicates that they are looking... | \n",
1033 | " flight | \n",
1034 | " ✔️ [True] | \n",
1035 | "
\n",
1036 | " \n",
1037 | " | 1 | \n",
1038 | " which flights leave on wednesday april thirteenth from indianapolis and arrive in montreal in the morning | \n",
1039 | " flight | \n",
1040 | " flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no | \n",
1041 | " produce the answer. We need to identify the main focus of the customer's message. The customer is inquiring about flights, specifically looking for details about... | \n",
1042 | " flight_time | \n",
1043 | " | \n",
1044 | "
\n",
1045 | " \n",
1046 | " | 2 | \n",
1047 | " i would like flight information from phoenix to denver | \n",
1048 | " flight | \n",
1049 | " flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no | \n",
1050 | " identify the customer's intent. The customer is asking for specific information regarding a flight from Phoenix to Denver. This request involves details about the flight,... | \n",
1051 | " flight | \n",
1052 | " ✔️ [True] | \n",
1053 | "
\n",
1054 | " \n",
1055 | " | 3 | \n",
1056 | " how much is coach flight from pittsburgh to atlanta | \n",
1057 | " airfare | \n",
1058 | " flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no | \n",
1059 | " determine the customer's intent. The customer is inquiring about the cost of a coach flight from Pittsburgh to Atlanta, which directly relates to airfare. The... | \n",
1060 | " airfare | \n",
1061 | " ✔️ [True] | \n",
1062 | "
\n",
1063 | " \n",
1064 | " | 4 | \n",
1065 | " list airfares for first class round trip from detroit to st. petersburg | \n",
1066 | " airfare | \n",
1067 | " flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no | \n",
1068 | " identify the customer's intent. The customer is asking for information about airfares specifically for a first-class round trip flight from Detroit to St. Petersburg. This... | \n",
1069 | " airfare | \n",
1070 | " ✔️ [True] | \n",
1071 | "
\n",
1072 | " \n",
1073 | "
\n"
1074 | ],
1075 | "text/plain": [
1076 | ""
1077 | ]
1078 | },
1079 | "metadata": {},
1080 | "output_type": "display_data"
1081 | },
1082 | {
1083 | "data": {
1084 | "text/html": [
1085 | "\n",
1086 | " \n",
1092 | " ... 37 more rows not displayed ...\n",
1093 | "
\n",
1094 | " "
1095 | ],
1096 | "text/plain": [
1097 | ""
1098 | ]
1099 | },
1100 | "metadata": {},
1101 | "output_type": "display_data"
1102 | },
1103 | {
1104 | "data": {
1105 | "text/plain": [
1106 | "88.1"
1107 | ]
1108 | },
1109 | "execution_count": 15,
1110 | "metadata": {},
1111 | "output_type": "execute_result"
1112 | }
1113 | ],
1114 | "source": [
1115 | "# Evaluate the program with the `answer_exact_match` metric.\n",
1116 | "# Launch evaluation.\n",
1117 | "evaluate_atis(cot_predictor, metric=metric)\n"
1118 | ]
1119 | },
1120 | {
1121 | "cell_type": "markdown",
1122 | "id": "eefa60d2-dac0-4646-afed-ee34e6675619",
1123 | "metadata": {},
1124 | "source": [
1125 | "## Evaluate Few-Shot CoT"
1126 | ]
1127 | },
1128 | {
1129 | "cell_type": "code",
1130 | "execution_count": 16,
1131 | "id": "142910a7-ef4a-44c0-b48d-b9e999e74553",
1132 | "metadata": {},
1133 | "outputs": [
1134 | {
1135 | "name": "stderr",
1136 | "output_type": "stream",
1137 | "text": [
1138 | "Average Metric: 33 / 42 (78.6): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 42/42 [00:14<00:00, 2.98it/s]\n",
1139 | "2024/11/18 09:19:22 INFO dspy.evaluate.evaluate: Average Metric: 33 / 42 (78.6%)\n"
1140 | ]
1141 | },
1142 | {
1143 | "data": {
1144 | "text/html": [
1145 | "\n",
1159 | "\n",
1160 | " \n",
1161 | " \n",
1162 | " | | \n",
1163 | " customer_message | \n",
1164 | " example_answer | \n",
1165 | " intent_labels | \n",
1166 | " rationale | \n",
1167 | " pred_answer | \n",
1168 | " answer_exact_match | \n",
1169 | "
\n",
1170 | " \n",
1171 | " \n",
1172 | " \n",
1173 | " | 0 | \n",
1174 | " show flights from minneapolis to kansas city | \n",
1175 | " flight | \n",
1176 | " flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no | \n",
1177 | " produce the answer. We need to identify the main intent of the customer message. The customer is asking to see flights from one city to... | \n",
1178 | " flight | \n",
1179 | " ✔️ [True] | \n",
1180 | "
\n",
1181 | " \n",
1182 | " | 1 | \n",
1183 | " which flights leave on wednesday april thirteenth from indianapolis and arrive in montreal in the morning | \n",
1184 | " flight | \n",
1185 | " flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no | \n",
1186 | " determine the intent of the customer's message. The customer is asking for information about flights that leave from Indianapolis and arrive in Montreal on a... | \n",
1187 | " flight_time | \n",
1188 | " | \n",
1189 | "
\n",
1190 | " \n",
1191 | " | 2 | \n",
1192 | " i would like flight information from phoenix to denver | \n",
1193 | " flight | \n",
1194 | " flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no | \n",
1195 | " produce the answer. We need to identify the main focus of the customer's request. The customer is asking for flight information, which typically includes details... | \n",
1196 | " flight | \n",
1197 | " ✔️ [True] | \n",
1198 | "
\n",
1199 | " \n",
1200 | " | 3 | \n",
1201 | " how much is coach flight from pittsburgh to atlanta | \n",
1202 | " airfare | \n",
1203 | " flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no | \n",
1204 | " produce the answer. We need to identify the main focus of the customer's inquiry. The customer is asking about the cost of a coach flight... | \n",
1205 | " airfare | \n",
1206 | " ✔️ [True] | \n",
1207 | "
\n",
1208 | " \n",
1209 | " | 4 | \n",
1210 | " list airfares for first class round trip from detroit to st. petersburg | \n",
1211 | " airfare | \n",
1212 | " flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no | \n",
1213 | " produce the answer. We need to identify the main focus of the customer's request. The customer is asking for airfares specifically for first-class round trips,... | \n",
1214 | " airfare | \n",
1215 | " ✔️ [True] | \n",
1216 | "
\n",
1217 | " \n",
1218 | "
\n"
1219 | ],
1220 | "text/plain": [
1221 | ""
1222 | ]
1223 | },
1224 | "metadata": {},
1225 | "output_type": "display_data"
1226 | },
1227 | {
1228 | "data": {
1229 | "text/html": [
1230 | "\n",
1231 | " \n",
1237 | " ... 37 more rows not displayed ...\n",
1238 | "
\n",
1239 | " "
1240 | ],
1241 | "text/plain": [
1242 | ""
1243 | ]
1244 | },
1245 | "metadata": {},
1246 | "output_type": "display_data"
1247 | },
1248 | {
1249 | "data": {
1250 | "text/plain": [
1251 | "78.57"
1252 | ]
1253 | },
1254 | "execution_count": 16,
1255 | "metadata": {},
1256 | "output_type": "execute_result"
1257 | }
1258 | ],
1259 | "source": [
1260 | "evaluate_atis(few_shot_model, metric=metric)"
1261 | ]
1262 | },
1263 | {
1264 | "cell_type": "markdown",
1265 | "id": "4e91d4f9-a39f-4d5a-8d0a-bec4c0ba6501",
1266 | "metadata": {},
1267 | "source": [
1268 | "## Evaluate BootstrapFewShot"
1269 | ]
1270 | },
1271 | {
1272 | "cell_type": "code",
1273 | "execution_count": 17,
1274 | "id": "5f20ef32-d986-4fda-875a-deaa690651ac",
1275 | "metadata": {},
1276 | "outputs": [
1277 | {
1278 | "name": "stderr",
1279 | "output_type": "stream",
1280 | "text": [
1281 | "Average Metric: 37 / 42 (88.1): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 42/42 [00:10<00:00, 4.06it/s]\n",
1282 | "2024/11/18 09:19:43 INFO dspy.evaluate.evaluate: Average Metric: 37 / 42 (88.1%)\n"
1283 | ]
1284 | },
1285 | {
1286 | "data": {
1287 | "text/html": [
1288 | "\n",
1302 | "\n",
1303 | " \n",
1304 | " \n",
1305 | " | | \n",
1306 | " customer_message | \n",
1307 | " example_answer | \n",
1308 | " intent_labels | \n",
1309 | " rationale | \n",
1310 | " pred_answer | \n",
1311 | " answer_exact_match | \n",
1312 | "
\n",
1313 | " \n",
1314 | " \n",
1315 | " \n",
1316 | " | 0 | \n",
1317 | " show flights from minneapolis to kansas city | \n",
1318 | " flight | \n",
1319 | " flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no | \n",
1320 | " identify the main focus of the customer's request. The customer is asking to see flights from Minneapolis to Kansas City, which indicates a clear interest... | \n",
1321 | " flight | \n",
1322 | " ✔️ [True] | \n",
1323 | "
\n",
1324 | " \n",
1325 | " | 1 | \n",
1326 | " which flights leave on wednesday april thirteenth from indianapolis and arrive in montreal in the morning | \n",
1327 | " flight | \n",
1328 | " flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no | \n",
1329 | " identify the intent of the customer's message. The customer is inquiring about flights that leave on a specific day (Wednesday, April 13th) from Indianapolis to... | \n",
1330 | " flight | \n",
1331 | " ✔️ [True] | \n",
1332 | "
\n",
1333 | " \n",
1334 | " | 2 | \n",
1335 | " i would like flight information from phoenix to denver | \n",
1336 | " flight | \n",
1337 | " flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no | \n",
1338 | " identify the main focus of the customer's request. The customer is explicitly asking for flight information regarding a route from Phoenix to Denver. This indicates... | \n",
1339 | " flight | \n",
1340 | " ✔️ [True] | \n",
1341 | "
\n",
1342 | " \n",
1343 | " | 3 | \n",
1344 | " how much is coach flight from pittsburgh to atlanta | \n",
1345 | " airfare | \n",
1346 | " flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no | \n",
1347 | " identify the main focus of the customer's inquiry. The customer is asking about the cost of a coach flight from Pittsburgh to Atlanta, which indicates... | \n",
1348 | " airfare | \n",
1349 | " ✔️ [True] | \n",
1350 | "
\n",
1351 | " \n",
1352 | " | 4 | \n",
1353 | " list airfares for first class round trip from detroit to st. petersburg | \n",
1354 | " airfare | \n",
1355 | " flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no | \n",
1356 | " identify the main focus of the customer's request. The customer is asking for airfares specifically for first-class round trip flights from Detroit to St. Petersburg.... | \n",
1357 | " airfare | \n",
1358 | " ✔️ [True] | \n",
1359 | "
\n",
1360 | " \n",
1361 | "
\n"
1362 | ],
1363 | "text/plain": [
1364 | ""
1365 | ]
1366 | },
1367 | "metadata": {},
1368 | "output_type": "display_data"
1369 | },
1370 | {
1371 | "data": {
1372 | "text/html": [
1373 | "\n",
1374 | " \n",
1380 | " ... 37 more rows not displayed ...\n",
1381 | "
\n",
1382 | " "
1383 | ],
1384 | "text/plain": [
1385 | ""
1386 | ]
1387 | },
1388 | "metadata": {},
1389 | "output_type": "display_data"
1390 | },
1391 | {
1392 | "data": {
1393 | "text/plain": [
1394 | "88.1"
1395 | ]
1396 | },
1397 | "execution_count": 17,
1398 | "metadata": {},
1399 | "output_type": "execute_result"
1400 | }
1401 | ],
1402 | "source": [
1403 | "evaluate_atis(cot_few_shot_optimized, metric=metric)"
1404 | ]
1405 | },
1406 | {
1407 | "cell_type": "markdown",
1408 | "id": "67f56d55-d44f-4a9a-b5ba-3e93670fc8fd",
1409 | "metadata": {},
1410 | "source": [
1411 | "## Evaluate BootstrapFewShotWithRandomSearch"
1412 | ]
1413 | },
1414 | {
1415 | "cell_type": "code",
1416 | "execution_count": 18,
1417 | "id": "8a0491f8-dd36-4910-8db5-f3036a00960a",
1418 | "metadata": {},
1419 | "outputs": [
1420 | {
1421 | "name": "stderr",
1422 | "output_type": "stream",
1423 | "text": [
1424 | "Average Metric: 40 / 42 (95.2): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 42/42 [00:09<00:00, 4.22it/s]\n",
1425 | "2024/11/18 09:19:58 INFO dspy.evaluate.evaluate: Average Metric: 40 / 42 (95.2%)\n"
1426 | ]
1427 | },
1428 | {
1429 | "data": {
1430 | "text/html": [
1431 | "\n",
1445 | "\n",
1446 | " \n",
1447 | " \n",
1448 | " | | \n",
1449 | " customer_message | \n",
1450 | " example_answer | \n",
1451 | " intent_labels | \n",
1452 | " rationale | \n",
1453 | " pred_answer | \n",
1454 | " answer_exact_match | \n",
1455 | "
\n",
1456 | " \n",
1457 | " \n",
1458 | " \n",
1459 | " | 0 | \n",
1460 | " show flights from minneapolis to kansas city | \n",
1461 | " flight | \n",
1462 | " flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no | \n",
1463 | " determine the intent of the customer's message. The customer is asking to see flights between two specific cities, which indicates they are looking for information... | \n",
1464 | " flight | \n",
1465 | " ✔️ [True] | \n",
1466 | "
\n",
1467 | " \n",
1468 | " | 1 | \n",
1469 | " which flights leave on wednesday april thirteenth from indianapolis and arrive in montreal in the morning | \n",
1470 | " flight | \n",
1471 | " flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no | \n",
1472 | " determine the intent of the customer's message. The customer is asking for specific flight information regarding flights that leave from a particular city on a... | \n",
1473 | " flight | \n",
1474 | " ✔️ [True] | \n",
1475 | "
\n",
1476 | " \n",
1477 | " | 2 | \n",
1478 | " i would like flight information from phoenix to denver | \n",
1479 | " flight | \n",
1480 | " flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no | \n",
1481 | " determine the intent of the customer's message. The customer is requesting flight information between two specific cities, which indicates they are looking for details related... | \n",
1482 | " flight | \n",
1483 | " ✔️ [True] | \n",
1484 | "
\n",
1485 | " \n",
1486 | " | 3 | \n",
1487 | " how much is coach flight from pittsburgh to atlanta | \n",
1488 | " airfare | \n",
1489 | " flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no | \n",
1490 | " determine the intent of the customer's message. The customer is inquiring about the cost of a specific type of flight (coach) from one city to... | \n",
1491 | " airfare | \n",
1492 | " ✔️ [True] | \n",
1493 | "
\n",
1494 | " \n",
1495 | " | 4 | \n",
1496 | " list airfares for first class round trip from detroit to st. petersburg | \n",
1497 | " airfare | \n",
1498 | " flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no | \n",
1499 | " determine the intent of the customer's message. The customer is asking for information about airfares specifically for first class round trip flights between two cities.... | \n",
1500 | " airfare | \n",
1501 | " ✔️ [True] | \n",
1502 | "
\n",
1503 | " \n",
1504 | "
\n"
1505 | ],
1506 | "text/plain": [
1507 | ""
1508 | ]
1509 | },
1510 | "metadata": {},
1511 | "output_type": "display_data"
1512 | },
1513 | {
1514 | "data": {
1515 | "text/html": [
1516 | "\n",
1517 | " \n",
1523 | " ... 37 more rows not displayed ...\n",
1524 | "
\n",
1525 | " "
1526 | ],
1527 | "text/plain": [
1528 | ""
1529 | ]
1530 | },
1531 | "metadata": {},
1532 | "output_type": "display_data"
1533 | },
1534 | {
1535 | "data": {
1536 | "text/plain": [
1537 | "95.24"
1538 | ]
1539 | },
1540 | "execution_count": 18,
1541 | "metadata": {},
1542 | "output_type": "execute_result"
1543 | }
1544 | ],
1545 | "source": [
1546 | "evaluate_atis(cot_few_shot_rs_optimized, metric=metric)"
1547 | ]
1548 | },
1549 | {
1550 | "cell_type": "markdown",
1551 | "id": "bca34c63-f758-48c3-b7d8-c25b54045125",
1552 | "metadata": {},
1553 | "source": [
1554 | "# Save / Load models"
1555 | ]
1556 | },
1557 | {
1558 | "cell_type": "code",
1559 | "execution_count": 21,
1560 | "id": "b2476c94-4ff1-4467-a9ee-acf8e4121cf1",
1561 | "metadata": {},
1562 | "outputs": [],
1563 | "source": [
1564 | "cot_predictor.save(\"cot_zero_shot.json\")\n",
1565 | "few_shot_model.save(\"cot_few_shot.json\")\n",
1566 | "cot_few_shot_optimized.save(\"cot_boostraped_few_shot.json\")\n",
1567 | "cot_few_shot_rs_optimized.save(\"cot_bootstraped_rs_few_shot.json\")"
1568 | ]
1569 | },
1570 | {
1571 | "cell_type": "code",
1572 | "execution_count": 22,
1573 | "id": "ee875833-b02d-4ef8-a218-3b242dba4292",
1574 | "metadata": {},
1575 | "outputs": [],
1576 | "source": [
1577 | "cot_predictor.load(\"cot_zero_shot.json\")\n",
1578 | "few_shot_model.load(\"cot_few_shot.json\")\n",
1579 | "cot_few_shot_optimized.load(\"cot_boostraped_few_shot.json\")\n",
1580 | "cot_few_shot_rs_optimized.load(\"cot_bootstraped_rs_few_shot.json\")"
1581 | ]
1582 | },
1583 | {
1584 | "cell_type": "code",
1585 | "execution_count": null,
1586 | "id": "63c9f057-c9bb-40d7-812b-8f570dc78d05",
1587 | "metadata": {},
1588 | "outputs": [],
1589 | "source": []
1590 | },
1591 | {
1592 | "cell_type": "code",
1593 | "execution_count": null,
1594 | "id": "7f29f4cd-3511-4e4c-9bd3-c3a6adf687e5",
1595 | "metadata": {},
1596 | "outputs": [],
1597 | "source": []
1598 | },
1599 | {
1600 | "cell_type": "code",
1601 | "execution_count": 27,
1602 | "id": "40ca28e6-e75f-4976-9a1c-2237102b7932",
1603 | "metadata": {},
1604 | "outputs": [
1605 | {
1606 | "data": {
1607 | "text/plain": [
1608 | "[Predict(StringSignature(customer_message, intent_labels -> rationale, answer\n",
1609 | " instructions='Classify the customer message into one of the intent labels.\\nThe output should be only the predicted class as a single intent label.'\n",
1610 | " customer_message = Field(annotation=str required=True json_schema_extra={'desc': 'Customer message during customer service interaction', '__dspy_field_type': 'input', 'prefix': 'Customer Message:'})\n",
1611 | " intent_labels = Field(annotation=str required=True json_schema_extra={'desc': 'Labels that represent customer intent', '__dspy_field_type': 'input', 'prefix': 'Intent Labels:'})\n",
1612 | " rationale = Field(annotation=str required=True json_schema_extra={'prefix': \"Reasoning: Let's think step by step in order to\", 'desc': '${produce the answer}. We ...', '__dspy_field_type': 'output'})\n",
1613 | " answer = Field(annotation=str required=True json_schema_extra={'desc': \"a label best matching customer's intent \", '__dspy_field_type': 'output', 'prefix': 'Answer:'})\n",
1614 | " ))]"
1615 | ]
1616 | },
1617 | "execution_count": 27,
1618 | "metadata": {},
1619 | "output_type": "execute_result"
1620 | }
1621 | ],
1622 | "source": [
1623 | "cot_predictor.parameters()"
1624 | ]
1625 | },
1626 | {
1627 | "cell_type": "code",
1628 | "execution_count": 30,
1629 | "id": "c176b4f0-ce8c-40a0-b905-be11ff0d4332",
1630 | "metadata": {},
1631 | "outputs": [
1632 | {
1633 | "data": {
1634 | "text/plain": [
1635 | "[Example({'customer_message': 'what kind of plane flies from boston to pittsburgh after noon', 'answer': 'aircraft', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no'}) (input_keys={'customer_message', 'intent_labels'}),\n",
1636 | " Example({'customer_message': 'what are the restrictions on the cheapest one way fare between boston and oakland', 'answer': 'restriction', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no'}) (input_keys={'customer_message', 'intent_labels'}),\n",
1637 | " Example({'customer_message': 'fine can you give me information on ground transportation in washington dc to downtown', 'answer': 'ground_service', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no'}) (input_keys={'customer_message', 'intent_labels'}),\n",
1638 | " Example({'customer_message': 'what is the distance from toronto international airport to toronto', 'answer': 'distance', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no'}) (input_keys={'customer_message', 'intent_labels'}),\n",
1639 | " Example({'customer_message': 'list flights from denver to philadelphia', 'answer': 'flight', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no'}) (input_keys={'customer_message', 'intent_labels'}),\n",
1640 | " Example({'customer_message': 'what is the meaning of restriction ap80', 'answer': 'abbreviation', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no'}) (input_keys={'customer_message', 'intent_labels'}),\n",
1641 | " Example({'customer_message': 'what is restriction ap57', 'answer': 'restriction', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no'}) (input_keys={'customer_message', 'intent_labels'}),\n",
1642 | " Example({'customer_message': 'what time zone is denver in', 'answer': 'city', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no'}) (input_keys={'customer_message', 'intent_labels'}),\n",
1643 | " Example({'customer_message': 'airports', 'answer': 'airport', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no'}) (input_keys={'customer_message', 'intent_labels'}),\n",
1644 | " Example({'customer_message': 'list daily flights from boston to oakland using twa', 'answer': 'flight', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no'}) (input_keys={'customer_message', 'intent_labels'})]"
1645 | ]
1646 | },
1647 | "execution_count": 30,
1648 | "metadata": {},
1649 | "output_type": "execute_result"
1650 | }
1651 | ],
1652 | "source": [
1653 | "few_shot_model.demos"
1654 | ]
1655 | },
1656 | {
1657 | "cell_type": "code",
1658 | "execution_count": 31,
1659 | "id": "a1df2ccc-04aa-48e3-a2c7-daf235662e7a",
1660 | "metadata": {},
1661 | "outputs": [
1662 | {
1663 | "data": {
1664 | "text/plain": [
1665 | "[Example({'augmented': True, 'customer_message': 'show me the flights from los angeles to pittsburgh which arrive at pittsburgh on monday', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no', 'rationale': 'produce the answer. We need to identify the main focus of the customer\\'s request. The customer is asking for flights from Los Angeles to Pittsburgh, specifically mentioning the arrival day (Monday). This indicates that the intent is related to finding flights. The most relevant label that matches this request is \"flight\".', 'answer': 'flight'}) (input_keys=None),\n",
1666 | " Example({'augmented': True, 'customer_message': 'what time does the flight leave denver going to san francisco on continental airlines', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no', 'rationale': 'determine the intent of the customer\\'s message. The customer is asking for the departure time of a specific flight from Denver to San Francisco on Continental Airlines. This indicates that the primary focus of the inquiry is related to the timing of the flight. \\n\\nThe intent labels include various categories, but the most relevant one for this question is \"flight_time,\" as it directly pertains to the time of departure for the flight.', 'answer': 'flight_time'}) (input_keys=None),\n",
1667 | " Example({'augmented': True, 'customer_message': 'show me times for coach flights between boston and baltimore on wednesday', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no', 'rationale': 'produce the answer. We need to identify the main focus of the customer\\'s request. The customer is asking for \"times\" related to \"coach flights\" between two cities on a specific day. This indicates that the customer is interested in the schedule or timing of flights. Therefore, the intent label that best matches this request is related to flight times.', 'answer': 'flight_time'}) (input_keys=None),\n",
1668 | " Example({'augmented': True, 'customer_message': 'how much does it cost to fly on american from dallas to baltimore', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no', 'rationale': 'produce the answer. We need to identify the main focus of the customer\\'s inquiry. The customer is asking about the cost of a flight on American Airlines from Dallas to Baltimore. This indicates that the primary intent is related to the fare or price of the flight. Therefore, the best matching intent label for this inquiry is \"airfare.\"', 'answer': 'airfare'}) (input_keys=None),\n",
1669 | " Example({'augmented': True, 'customer_message': 'please show me fares from denver to san francisco', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no', 'rationale': \"produce the answer. We need to identify the main focus of the customer's request. The customer is asking for fares, which indicates they are interested in the cost of flights from Denver to San Francisco. This aligns with the intent label related to airfare.\", 'answer': 'airfare'}) (input_keys=None),\n",
1670 | " Example({'augmented': True, 'customer_message': \"i 'd like to know what type of plane flies from boston to dallas fort worth after 530\", 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no', 'rationale': \"determine the intent of the customer's message. The customer is asking about the type of plane that operates on a specific route (from Boston to Dallas Fort Worth) and at a specific time (after 5:30). This indicates that the customer is interested in information related to the aircraft used for that flight.\", 'answer': 'aircraft'}) (input_keys=None),\n",
1671 | " Example({'augmented': True, 'customer_message': 'what is the smallest aircraft that flies from pittsburgh to baltimore arriving may seventh', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no', 'rationale': \"determine the intent of the customer's message. The customer is asking about the smallest aircraft that operates on a specific route (from Pittsburgh to Baltimore) and on a specific date (arriving May seventh). This indicates that the customer is interested in information related to the type of aircraft used for that flight.\", 'answer': 'aircraft'}) (input_keys=None),\n",
1672 | " Example({'augmented': True, 'customer_message': 'what types of ground transportation are there to san francisco airport', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no', 'rationale': 'identify the intent of the customer message. The customer is asking about the types of ground transportation available to San Francisco airport. This indicates that the focus is on transportation services rather than flights, fares, or other aviation-related queries. Therefore, the most appropriate intent label that matches this inquiry is related to ground services.', 'answer': 'ground_service'}) (input_keys=None),\n",
1673 | " Example({'augmented': True, 'customer_message': 'is there ground transportation from the milwaukee airport to the downtown area', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no', 'rationale': 'identify the intent of the customer message. The customer is inquiring about ground transportation options from the Milwaukee airport to the downtown area. This indicates a need for information related to ground services rather than flights or fares. Therefore, the most appropriate intent label for this message is related to ground services.', 'answer': 'ground_service'}) (input_keys=None),\n",
1674 | " Example({'augmented': True, 'customer_message': 'houston airports', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no', 'rationale': 'identify the intent of the customer message. The customer is asking about \"houston airports,\" which indicates they are looking for information related to airports in Houston. This aligns with the intent label that represents airport-related inquiries.', 'answer': 'airport'}) (input_keys=None)]"
1675 | ]
1676 | },
1677 | "execution_count": 31,
1678 | "metadata": {},
1679 | "output_type": "execute_result"
1680 | }
1681 | ],
1682 | "source": [
1683 | "cot_few_shot_optimized.demos"
1684 | ]
1685 | },
1686 | {
1687 | "cell_type": "code",
1688 | "execution_count": 78,
1689 | "id": "8956c58a-069d-4da9-92a2-865181e45305",
1690 | "metadata": {},
1691 | "outputs": [
1692 | {
1693 | "data": {
1694 | "text/plain": [
1695 | "[Example({'augmented': True, 'customer_message': 'show me the flights from boston to oakland', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no', 'rationale': 'produce the answer. We need to identify the main request in the customer message. The customer is asking for information about flights from Boston to Oakland, which indicates they are looking for flight details. The most relevant intent label that matches this request is \"flight\".', 'answer': 'flight'}) (input_keys=None),\n",
1696 | " Example({'augmented': True, 'customer_message': 'does delta have an early afternoon flight from boston to san francisco', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no', 'rationale': \"identify the main request in the customer's message. The customer is asking about the availability of a specific flight operated by Delta Airlines from Boston to San Francisco during a particular time frame (early afternoon). This indicates an intent to inquire about a flight schedule rather than asking for flights, fares, or other related information. \\n\\nSince the customer is specifically asking about the existence of a flight, the best matching intent label is 'flight'.\", 'answer': 'flight'}) (input_keys=None),\n",
1697 | " Example({'augmented': True, 'customer_message': 'please give me the flight times the morning on united airlines for september twentieth from philadelphia to san francisco', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no', 'rationale': \"produce the answer. We need to identify the main focus of the customer's request. The customer is asking for flight times specifically for a flight on United Airlines from Philadelphia to San Francisco on a specific date, which indicates they are interested in the timing of the flights. Therefore, the intent label that best matches this request is flight_time.\", 'answer': 'flight_time'}) (input_keys=None),\n",
1698 | " Example({'augmented': True, 'customer_message': 'what times does continental depart from boston to san francisco', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no', 'rationale': \"determine the intent of the customer's message. The customer is asking for specific departure times for a flight operated by Continental Airlines from Boston to San Francisco. This request is focused on the timing of the flight, which falls under the category of flight times.\", 'answer': 'flight_time'}) (input_keys=None),\n",
1699 | " Example({'customer_message': 'can you list costs of denver rental cars', 'answer': 'ground_fare', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no'}) (input_keys={'customer_message', 'intent_labels'}),\n",
1700 | " Example({'customer_message': 'all flights and fares from pittsburgh to dallas round trip after 12 pm less than 1100 dollars', 'answer': 'flight+airfare', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no'}) (input_keys={'customer_message', 'intent_labels'}),\n",
1701 | " Example({'customer_message': 'what is the ap 57 restriction', 'answer': 'abbreviation', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no'}) (input_keys={'customer_message', 'intent_labels'}),\n",
1702 | " Example({'customer_message': 'what is the total seating capacity of all aircraft of american airlines', 'answer': 'capacity', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no'}) (input_keys={'customer_message', 'intent_labels'}),\n",
1703 | " Example({'customer_message': 'names of airports', 'answer': 'airport', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no'}) (input_keys={'customer_message', 'intent_labels'}),\n",
1704 | " Example({'customer_message': 'how much is the ground transportation between atlanta and downtown', 'answer': 'ground_fare', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no'}) (input_keys={'customer_message', 'intent_labels'}),\n",
1705 | " Example({'customer_message': 'what ground transportation is there in atlanta', 'answer': 'ground_service', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no'}) (input_keys={'customer_message', 'intent_labels'}),\n",
1706 | " Example({'customer_message': 'what are the seating capacities of planes between pittsburgh and baltimore', 'answer': 'capacity', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no'}) (input_keys={'customer_message', 'intent_labels'}),\n",
1707 | " Example({'customer_message': 'what are all the available meals', 'answer': 'meal', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no'}) (input_keys={'customer_message', 'intent_labels'}),\n",
1708 | " Example({'customer_message': 'show me all meals on flights from atlanta to washington', 'answer': 'meal', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no'}) (input_keys={'customer_message', 'intent_labels'}),\n",
1709 | " Example({'customer_message': 'airline and flight number from columbus to minneapolis', 'answer': 'airline+flight_no', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no'}) (input_keys={'customer_message', 'intent_labels'}),\n",
1710 | " Example({'customer_message': 'flight numbers from columbus to minneapolis tomorrow', 'answer': 'flight_no', 'intent_labels': 'flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no'}) (input_keys={'customer_message', 'intent_labels'})]"
1711 | ]
1712 | },
1713 | "execution_count": 78,
1714 | "metadata": {},
1715 | "output_type": "execute_result"
1716 | }
1717 | ],
1718 | "source": [
1719 | "cot_few_shot_rs_optimized.demos"
1720 | ]
1721 | },
1722 | {
1723 | "cell_type": "markdown",
1724 | "id": "4a0856d6-bc3b-4c47-8ae3-6934975e0839",
1725 | "metadata": {},
1726 | "source": [
1727 | "# Remove bootstrapping"
1728 | ]
1729 | },
1730 | {
1731 | "cell_type": "code",
1732 | "execution_count": 41,
1733 | "id": "d5f289d7-750a-4d4d-b959-c8ff74fa639a",
1734 | "metadata": {},
1735 | "outputs": [
1736 | {
1737 | "name": "stderr",
1738 | "output_type": "stream",
1739 | "text": [
1740 | " 11%|██████████████▋ | 4/36 [00:01<00:11, 2.85it/s]\n"
1741 | ]
1742 | },
1743 | {
1744 | "name": "stdout",
1745 | "output_type": "stream",
1746 | "text": [
1747 | "Bootstrapped 4 full traces after 4 examples for up to 20 rounds, amounting to 8 attempts.\n"
1748 | ]
1749 | },
1750 | {
1751 | "name": "stderr",
1752 | "output_type": "stream",
1753 | "text": [
1754 | "Average Metric: 32 / 40 (80.0): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [00:11<00:00, 3.63it/s]\n",
1755 | "2024/11/17 06:42:28 INFO dspy.evaluate.evaluate: Average Metric: 32 / 40 (80.0%)\n"
1756 | ]
1757 | },
1758 | {
1759 | "data": {
1760 | "text/html": [
1761 | "\n",
1775 | "\n",
1776 | " \n",
1777 | " \n",
1778 | " | | \n",
1779 | " customer_message | \n",
1780 | " example_answer | \n",
1781 | " intent_labels | \n",
1782 | " rationale | \n",
1783 | " pred_answer | \n",
1784 | " answer_exact_match | \n",
1785 | "
\n",
1786 | " \n",
1787 | " \n",
1788 | " \n",
1789 | " | 0 | \n",
1790 | " list airports in new york | \n",
1791 | " airport | \n",
1792 | " flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no | \n",
1793 | " determine the intent of the customer's message. The customer is asking for a list of airports located in New York. This request is focused on... | \n",
1794 | " airport | \n",
1795 | " ✔️ [True] | \n",
1796 | "
\n",
1797 | " \n",
1798 | " | 1 | \n",
1799 | " show me flights from montreal to orlando | \n",
1800 | " flight | \n",
1801 | " flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no | \n",
1802 | " produce the answer. The customer is requesting information about flights from Montreal to Orlando. This indicates that the intent is related to finding flights between... | \n",
1803 | " flight | \n",
1804 | " ✔️ [True] | \n",
1805 | "
\n",
1806 | " \n",
1807 | " | 2 | \n",
1808 | " which flights leave chicago next tuesday and arrive in detroit around 6 pm | \n",
1809 | " flight | \n",
1810 | " flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no | \n",
1811 | " determine the intent of the customer's message. The customer is asking for information about flights that leave from Chicago next Tuesday and arrive in Detroit... | \n",
1812 | " flight | \n",
1813 | " ✔️ [True] | \n",
1814 | "
\n",
1815 | " \n",
1816 | " | 3 | \n",
1817 | " list the flights from dallas to baltimore arriving july first | \n",
1818 | " flight | \n",
1819 | " flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no | \n",
1820 | " determine the intent of the customer's message. The customer is asking for a list of flights from Dallas to Baltimore that are arriving on a... | \n",
1821 | " flight | \n",
1822 | " ✔️ [True] | \n",
1823 | "
\n",
1824 | " \n",
1825 | " | 4 | \n",
1826 | " show flights tomorrow evening from milwaukee to st. louis | \n",
1827 | " flight | \n",
1828 | " flight%flight_time%airfare%aircraft%ground_service%airport%airline%distance%abbreviation%ground_fare%quantity%city%flight_no%capacity%flight+airfare%meal%restriction%airline+flight_no%ground_service+ground_fare%airfare+flight_time%cheapest%aircraft+flight+flight_no | \n",
1829 | " produce the answer. The customer is requesting information about flights that are scheduled for tomorrow evening from Milwaukee to St. Louis. This indicates that the... | \n",
1830 | " flight | \n",
1831 | " ✔️ [True] | \n",
1832 | "
\n",
1833 | " \n",
1834 | "
\n"
1835 | ],
1836 | "text/plain": [
1837 | ""
1838 | ]
1839 | },
1840 | "metadata": {},
1841 | "output_type": "display_data"
1842 | },
1843 | {
1844 | "data": {
1845 | "text/html": [
1846 | "\n",
1847 | " \n",
1853 | " ... 35 more rows not displayed ...\n",
1854 | "
\n",
1855 | " "
1856 | ],
1857 | "text/plain": [
1858 | ""
1859 | ]
1860 | },
1861 | "metadata": {},
1862 | "output_type": "display_data"
1863 | },
1864 | {
1865 | "data": {
1866 | "text/plain": [
1867 | "80.0"
1868 | ]
1869 | },
1870 | "execution_count": 41,
1871 | "metadata": {},
1872 | "output_type": "execute_result"
1873 | }
1874 | ],
1875 | "source": [
1876 | "from dspy.evaluate import answer_exact_match as metric\n",
1877 | "from dspy.teleprompt import BootstrapFewShot\n",
1878 | "\n",
1879 | "optimizer = BootstrapFewShot(\n",
1880 | " metric=metric,\n",
1881 | " max_bootstrapped_demos=4, # still bootstraps up to 4 generated demos; set to 0 to skip bootstrapping\n",
1882 | " max_labeled_demos=40, # use up to 40 labeled training examples as demos\n",
1883 | " max_rounds=20,\n",
1884 | ")\n",
1885 | "\n",
1886 | "# the linked documentation is wrong - compile() takes no valset argument: https://dspy.ai/deep-dive/optimizers/bootstrap-fewshot/\n",
1887 | "cot_few_shot_to_bootstrap = optimizer.compile(cot_predictor, trainset=train_examples)\n",
1888 | "\n",
1889 | "\n",
1890 | "evaluate_atis(cot_few_shot_to_bootstrap, metric=metric)\n",
1891 | "# max_bootstrapped_demos: demos generated by running the program on training examples\n",
1892 | "# max_labeled_demos: demos taken directly from the labeled trainset"
1893 | ]
1894 | }
1895 | ],
1896 | "metadata": {
1897 | "kernelspec": {
1898 | "display_name": "Python 3 (ipykernel)",
1899 | "language": "python",
1900 | "name": "python3"
1901 | },
1902 | "language_info": {
1903 | "codemirror_mode": {
1904 | "name": "ipython",
1905 | "version": 3
1906 | },
1907 | "file_extension": ".py",
1908 | "mimetype": "text/x-python",
1909 | "name": "python",
1910 | "nbconvert_exporter": "python",
1911 | "pygments_lexer": "ipython3",
1912 | "version": "3.10.12"
1913 | }
1914 | },
1915 | "nbformat": 4,
1916 | "nbformat_minor": 5
1917 | }
1918 |
--------------------------------------------------------------------------------