├── .gitignore
├── .python-version
├── LICENSE
├── Makefile
├── README.md
├── all_sweeps.sh
├── config-schema.json
├── examples
│   └── tweet_writer
│       ├── config.json
│       ├── create_dataset.py
│       ├── guided.config.json
│       └── task.py
├── experiments
│   ├── __init__.py
│   ├── cs_tooluse
│   │   ├── config.json
│   │   └── evaluators.py
│   ├── elon_email_sweeps.jsonl
│   ├── email_cs
│   │   ├── config.json
│   │   └── evaluators.py
│   ├── email_cs10
│   │   ├── config.json
│   │   └── evaluators.py
│   ├── email_cs_simple
│   │   ├── config.json
│   │   └── evaluators.py
│   ├── email_elon
│   │   ├── config.json
│   │   └── evaluators.py
│   ├── extract_code
│   │   ├── __init__.py
│   │   ├── backup.config.json
│   │   ├── config.json
│   │   ├── evaluators.py
│   │   └── task.py
│   ├── extract_legal
│   │   ├── __init__.py
│   │   ├── backup.config.json
│   │   ├── config.json
│   │   ├── evaluators.py
│   │   └── task.py
│   ├── math_multi
│   │   ├── config.json
│   │   └── evaluators.py
│   ├── multiclass_email10
│   │   ├── __init__.py
│   │   ├── backup.config.json
│   │   ├── config.json
│   │   ├── evaluators.py
│   │   └── task.py
│   ├── multiclass_email3
│   │   ├── __init__.py
│   │   ├── backup.config.json
│   │   ├── config.json
│   │   ├── evaluators.py
│   │   └── task.py
│   ├── multiclass_health10
│   │   ├── __init__.py
│   │   ├── backup.config.json
│   │   ├── config.json
│   │   ├── evaluators.py
│   │   └── task.py
│   ├── multiclass_health3
│   │   ├── __init__.py
│   │   ├── backup.config.json
│   │   ├── config.json
│   │   ├── evaluators.py
│   │   └── task.py
│   ├── sweeps.jsonl
│   ├── tool_sweeps.jsonl
│   ├── tooluse_ecommerce
│   │   ├── __init__.py
│   │   ├── backup.config.json
│   │   ├── config.json
│   │   ├── evaluators.py
│   │   └── task.py
│   └── tooluse_finance
│       ├── __init__.py
│       ├── backup.config.json
│       ├── config.json
│       ├── evaluators.py
│       └── task.py
├── generate_schema.py
├── pyproject.toml
├── src
│   └── promptim
│       ├── __init__.py
│       ├── __main__.py
│       ├── _utils.py
│       ├── algorithms
│       │   ├── __init__.py
│       │   ├── base.py
│       │   ├── minibatch.py
│       │   ├── mipro.py
│       │   ├── phaseevo
│       │   │   ├── __init__.py
│       │   │   ├── algo.py
│       │   │   └── mutations.py
│       │   └── tpe_sampler.py
│       ├── config.py
│       ├── optimizers
│       │   ├── __init__.py
│       │   ├── base.py
│       │   ├── debate.py
│       │   ├── feedback_guided.py
│       │   ├── fewshot.py
│       │   └── metaprompt.py
│       ├── py.typed
│       ├── tasks
│       │   ├── __init__.py
│       │   ├── metaprompt.py
│       │   ├── scone.py
│       │   ├── simpleqa.py
│       │   ├── ticket_classification.py
│       │   └── tweet_generator.py
│       ├── trainer.py
│       └── types.py
├── static
│   └── optimizer.gif
├── test
│   └── cassettes
│       └── db688c10-764b-42ec-acce-4d62419600ed.yaml
├── tests
│   └── test_optimizers.py
└── uv.lock
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | *.pyc
6 | .DS_Store
7 | build/
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | *.py,cover
53 | .hypothesis/
54 | .pytest_cache/
55 | cover/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | .pybuilder/
79 | target/
80 |
81 | # Jupyter Notebook
82 | .ipynb_checkpoints
83 |
84 | # IPython
85 | profile_default/
86 | ipython_config.py
87 |
88 | # pyenv
89 | # For a library or package, you might want to ignore these files since the code is
90 | # intended to run in multiple environments; otherwise, check them in:
91 | # .python-version
92 |
93 | # pipenv
94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
97 | # install all needed dependencies.
98 | #Pipfile.lock
99 |
100 | # poetry
101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
102 | # This is especially recommended for binary packages to ensure reproducibility, and is more
103 | # commonly ignored for libraries.
104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
105 | #poetry.lock
106 |
107 | # pdm
108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
109 | #pdm.lock
110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
111 | # in version control.
112 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
113 | .pdm.toml
114 | .pdm-python
115 | .pdm-build/
116 |
117 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
118 | __pypackages__/
119 |
120 | # Celery stuff
121 | celerybeat-schedule
122 | celerybeat.pid
123 |
124 | # SageMath parsed files
125 | *.sage.py
126 |
127 | # Environments
128 | .env
129 | .venv
130 | env/
131 | venv/
132 | ENV/
133 | env.bak/
134 | venv.bak/
135 |
136 | # Spyder project settings
137 | .spyderproject
138 | .spyproject
139 |
140 | # Rope project settings
141 | .ropeproject
142 |
143 | # mkdocs documentation
144 | /site
145 |
146 | # mypy
147 | .mypy_cache/
148 | .dmypy.json
149 | dmypy.json
150 |
151 | # Pyre type checker
152 | .pyre/
153 |
154 | # pytype static type analyzer
155 | .pytype/
156 |
157 | # Cython debug symbols
158 | cython_debug/
159 |
160 | # PyCharm
161 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
163 | # and can be added to the global gitignore or merged into this file. For a more nuclear
164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
165 | #.idea/
166 |
--------------------------------------------------------------------------------
/.python-version:
--------------------------------------------------------------------------------
1 | 3.11
2 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 William FH
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: lint format
2 |
3 |
4 | lint:
5 | uv run ruff check .
6 | uv run ruff format . --diff
7 |
8 | format:
9 | uv run ruff check --fix .
10 | uv run ruff format .
11 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Promptim
2 |
3 | Promptim is an experimental **prompt** opt**im**ization library to help you systematically improve your AI systems.
4 |
5 | Promptim automates the process of improving prompts on specific tasks. You provide an initial prompt, a dataset, and custom evaluators (and optional human feedback), and `promptim` runs an optimization loop to produce a refined prompt that aims to outperform the original.
6 |
7 | For setup and usage details, see the quick start guide below.
8 |
9 | 
10 |
11 | ## Quick start
12 |
13 | Let's try prompt optimization on a simple tweet generation task.
14 |
15 | ### 1. Install
16 |
17 | First install the CLI.
18 |
19 | ```shell
20 | pip install -U promptim
21 | ```
22 |
23 | And make sure you have a valid [LangSmith API Key](https://smith.langchain.com/) in your environment. For the quick start task, we will use Anthropic's Claude model for our optimizer and for the target system.
24 |
25 | ```shell
26 | LANGSMITH_API_KEY=CHANGEME
27 | ANTHROPIC_API_KEY=CHANGEME
28 | ```
29 |
30 | ### 2. Create task
31 |
32 | Next, create a task to optimize over. Run the following command to generate a template:
33 |
34 | ```shell
35 | promptim create task ./my-tweet-task \
36 | --name my-tweet-task \
37 | --prompt langchain-ai/tweet-generator-example-with-nothing:starter \
38 | --dataset https://smith.langchain.com/public/6ed521df-c0d8-42b7-a0db-48dd73a0c680/d \
39 | --description "Write informative tweets on any subject." \
40 | -y
41 | ```
42 | This command will generate starter code, complete with the task's:
43 | 1. Name: Provide a useful name for the task (like "ticket classifier" or "report generator"). You may use the default here.
44 | 2. Prompt: This is an identifier in the LangSmith prompt hub. The public prompt referenced in the command above is a good starting point.
45 | 3. Dataset: This is the name (or public URL) for the dataset we are optimizing over. Optionally, it can have train/dev/test splits to report separate metrics throughout the training process.
46 | 4. Description: This is a high-level description of the purpose for this prompt. The optimizer uses this to help focus its improvements.
47 |
48 | Once you've completed the template creation, you should have two files in the `my-tweet-task` directory:
49 |
50 | ```shell
51 | └── my-tweet-task
52 | ├── config.json
53 | └── task.py
54 | ```
55 |
56 | We can ignore the `config.json` file for now (we'll discuss that later). The last thing we need to do before training is create an evaluator.
57 |
58 | ### 3. Define evaluators
59 |
60 | Next, we need to quantify prompt performance on our task. What do "good" and "bad" look like? We do this using evaluators.
61 |
62 | Open the evaluator stub written in `my-tweet-task/task.py` and find the line that assigns a score to a prediction:
63 |
64 | ```python
65 | # Implement your evaluation logic here
66 | score = len(str(predicted.content)) < 180 # Replace with actual score
67 | ```
68 |
69 | We are going to make this evaluator penalize outputs with hashtags. Here, `result` is the stringified model output (`str(predicted.content)`), as shown in the complete evaluator below. Update that line to be:
70 | ```python
71 | score = int("#" not in result)
72 | ```
73 |
74 | Next, update the evaluator name. We do this using the `key` field in the evaluator response.
75 | ```python
76 | "key": "tweet_omits_hashtags",
77 | ```
78 |
79 | To help the optimizer understand the ideal behavior, we can add additional instructions in the `comment` field of the response.
80 |
81 | Update the "comment" line to explicitly give pass/fail comments:
82 | ```python
83 | "comment": "Pass: tweet omits hashtags" if score == 1 else "Fail: omit all hashtags from generated tweets",
84 | ```
85 |
86 | And now we're ready to train! The final evaluator should look like:
87 |
88 | ```python
89 | def example_evaluator(run: Run, example: Example) -> dict:
90 | """An example evaluator. Larger numbers are better."""
91 | predicted: AIMessage = run.outputs["output"]
92 |
93 | result = str(predicted.content)
94 | score = int("#" not in result)
95 | return {
96 | "key": "tweet_omits_hashtags",
97 | "score": score,
98 | "comment": "Pass: tweet omits hashtags" if score == 1 else "Fail: omit all hashtags from generated tweets",
99 | }
100 |
101 | ```
102 |
103 | ### 4. Train
104 |
105 | To start optimizing your prompt, run the `train` command:
106 |
107 | ```shell
108 | promptim train --task ./my-tweet-task/config.json
109 | ```
110 |
111 | You will see the progress in your terminal. Once it's completed, the training job will print out the final "optimized" prompt, as well as a link to the commit in the hub.
112 |
113 | ### Explanation
114 |
115 | Whenever you run `promptim train`, promptim first loads the prompt and dataset specified in your configuration. It then evaluates your prompt on the dev split (if present; full dataset otherwise) using the evaluator(s) configured above. This gives us baseline metrics to compare against throughout the optimization process.
116 |
117 | After computing a baseline, `promptim` begins optimizing the prompt by looping over minibatches of training examples. For each minibatch, `promptim` computes the metrics and then applies a **metaprompt** to suggest changes to the current prompt. It then applies that updated prompt to the next minibatch of training examples and repeats the process. It does this over the entire **train** split (if present; full dataset otherwise).
118 |
119 | After `promptim` has consumed the whole `train` split, it computes metrics again on the `dev` split. If the metrics show improvement (average score is greater), then the updated prompt is retained for the next round. If the metrics are the same or worse than the current best score, the prompt is discarded.
120 |
121 | This process is repeated `--num-epochs` times before the process terminates.
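
For intuition, here is a minimal, self-contained sketch of that loop in plain Python. The helper names (`evaluate_split`, `propose_improved_prompt`) and the toy data are hypothetical stand-ins for promptim's internals, not its actual API:

```python
# Toy illustration of the loop described above; not promptim's real implementation.
import random

train = [{"topic": t} for t in ["NBA", "Movies", "AI", "Climate"]]
dev = [{"topic": t} for t in ["Space", "Crypto"]]


def evaluate_split(prompt: str, examples: list) -> float:
    """Run the prompt over a split and average the evaluator scores (stubbed with random numbers)."""
    return sum(random.random() for _ in examples) / len(examples)


def propose_improved_prompt(prompt: str, minibatch: list, score: float) -> str:
    """Apply a metaprompt to suggest a revised prompt (stubbed)."""
    return prompt + " (revised)"


def train_loop(prompt: str, num_epochs: int = 2, batch_size: int = 2) -> str:
    best_prompt, best_score = prompt, evaluate_split(prompt, dev)  # baseline on dev
    for _ in range(num_epochs):
        candidate = best_prompt
        # Loop over minibatches of the train split, refining the candidate each time.
        for i in range(0, len(train), batch_size):
            batch = train[i : i + batch_size]
            score = evaluate_split(candidate, batch)
            candidate = propose_improved_prompt(candidate, batch, score)
        # After consuming the train split, re-score on dev and keep only improvements.
        dev_score = evaluate_split(candidate, dev)
        if dev_score > best_score:
            best_prompt, best_score = candidate, dev_score
    return best_prompt


print(train_loop("Write a tweet about {topic}"))
```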
122 |
123 | ## How to:
124 |
125 | ### Add human labels
126 |
127 | To add human labeling using the annotation queue:
128 |
129 | 1. Set up an annotation queue:
130 | When running the `train` command, use the `--annotation-queue` option to specify a queue name:
131 | ```
132 | promptim train --task ./my-tweet-task/config.json --annotation-queue my_queue
133 | ```
134 |
135 | 2. During training, the system will pause after each batch and print out instructions on how to label the results. It will wait for human annotations.
136 |
137 | 3. Access the annotation interface:
138 | - Open the LangSmith UI
139 | - Navigate to the specified queue (e.g., "my_queue")
140 | - Review and label as many examples as you'd like, adding notes and scores
141 |
142 | 4. Resume:
143 | - Type 'c' in the terminal
144 | - The training loop will fetch your annotations and include them in the metaprompt's next optimization pass
145 |
146 | This human-in-the-loop approach allows you to guide the prompt optimization process by providing direct feedback on the model's outputs.
147 |
148 | ## Reference
149 |
150 | ### CLI Arguments
151 |
152 | The current CLI arguments are as follows. They are experimental and may change in the future:
153 |
154 | ```shell
155 | Usage: promptim [OPTIONS] COMMAND [ARGS]...
156 |
157 | Optimize prompts for AI tasks using automated evaluation and feedback.
158 |
159 | Promptim helps improve prompts for various AI tasks by running an
160 | optimization loop. You provide an initial prompt, a dataset, and custom
161 | evaluators. Promptim then iteratively refines the prompt to improve
162 | performance on your specific task.
163 |
164 | To get started, create a task configuration or use a pre-defined one, then
165 | run the 'train' command to begin optimization.
166 |
167 | Example: promptim train --task ./my-task/config.json
168 |
169 | Options:
170 | --version Show the version and exit.
171 | --help Show this message and exit.
172 |
173 | Commands:
174 | create Commands for creating new tasks.
175 | train Train and optimize prompts for different tasks.
176 | ```
177 |
178 | #### create
179 |
180 |
181 | ```shell
182 | Usage: promptim create [OPTIONS] COMMAND [ARGS]...
183 |
184 | Commands for creating new tasks and examples.
185 |
186 | Options:
187 | --help Show this message and exit.
188 |
189 | Commands:
190 | example Clone a pre-made tweet generation task
191 | task Walkthrough to create a new task directory from your own prompt and dataset
192 | ```
193 |
194 | `promptim create task`
195 |
196 | ```shell
197 | Usage: promptim create task [OPTIONS] PATH
198 |
199 | Create a new task directory with config.json and task file for a custom
200 | prompt and dataset.
201 |
202 | Options:
203 | --name TEXT Name for the task. If not provided, the directory name
204 | will be used as default. This name will be used in the
205 | config.json file.
206 | --prompt TEXT Name of the prompt in LangSmith to be optimized.
207 | If not provided, you'll be prompted to select or create
208 | one. This will be used as the initial prompt for
209 | optimization.
210 | --description TEXT Description of the task for the optimizer. This helps
211 | guide the optimization process by providing context
212 | about the task's objectives and constraints.
213 | --dataset TEXT Name or public URL of the dataset in LangSmith to be used for
214 | training and evaluation. If not provided, you'll be
215 | prompted to select or create one. This dataset will be
216 | used to test and improve the prompt.
217 | -y, --yes Automatically answer yes to all CLI prompts. Use with
218 | caution as it skips confirmation steps and uses defaults
219 | where applicable.
220 | --help Show this message and exit.
221 | ```
222 |
223 |
224 | #### train
225 |
226 | ```shell
227 | Usage: promptim train [OPTIONS]
228 |
229 | Train and optimize prompts for different tasks.
230 |
231 | Options:
232 | --task TEXT Task to optimize. Specify a pre-defined task name
233 | or path to a custom config file. The task defines
234 | the dataset, evaluators, and initial prompt to
235 | optimize. Example:
236 | 'examples/tweet_writer/config.json' for a custom
237 | task, or 'sentiment_analysis' for a pre-defined
238 | task.
239 | --batch-size INTEGER Number of examples to process in each optimization
240 | iteration. Larger batches may improve stability but
241 | are limited by the metaprompter's maximum context
242 | window size.
243 | --train-size INTEGER Maximum number of training examples to use per
244 | epoch. Useful for limiting optimization time on
245 | large datasets. If smaller than total available
246 | data, a random subset will be used each epoch.
247 | --epochs INTEGER Number of complete passes through the training
248 | data. More epochs may improve results but increase
249 | runtime.
250 | --debug Enable debug mode for verbose logging and
251 | sequential processing.
252 | --annotation-queue TEXT Name of the LangSmith annotation queue for manual
253 | review of optimization results. The queue will be
254 | cleared and updated on each batch.
255 | --no-commit Prevent committing the optimized prompt to the
256 | LangChain Hub. Use this for local experimentation.
257 | --help Show this message and exit.
258 | ```
259 |
260 | ### Configuration
261 |
262 | The schema for your `config.json` file can be found in [config-schema.json](./config-schema.json).
263 |
264 | It contains the following arguments:
265 |
266 | - `name` (string, required): The name of your task.
267 | - `dataset` (string, required): The name of the dataset in LangSmith to be used for training and evaluation.
268 | - `initial_prompt` (object, required): Configuration for the initial prompt to be optimized.
269 | - `identifier` (string, optional): Identifier for a prompt from the hub repository. Mutually exclusive with prompt_str.
270 | - `prompt_str` (string, optional): Raw prompt string to optimize locally. Mutually exclusive with identifier.
271 | - `model_config` (object, optional): Configuration dictionary specifying model parameters for optimization.
272 | - `which` (integer, default: 0): Index of the message to optimize within the prompt.
273 | - `description` (string, optional): A detailed explanation of the task's objectives and constraints.
274 | - `evaluator_descriptions` (object, optional): A mapping of evaluator names to their descriptions.
275 | - `optimizer` (object, optional): Configuration specifying model settings and hyperparameters. If not provided, default configuration will be used.
276 | - `model` (object, required): Model configuration dictionary specifying the model name, parameters, and other settings used by the optimizer.
277 | - `evaluators` (string, required): Import path to evaluator functions in format 'file_path:variable_name'. The functions should evaluate prompt quality. Example: `./task/evaluators.py:evaluators`
278 | - `system` (string, optional): Import path to system configuration in format 'file_path:variable_name'. Defines how prompts are executed. If not provided, a default system with just a prompt and LLM will be constructed. Example: `./task/my_system.py:chain`
279 |
280 | Below is an example `config.json` file:
281 |
282 | ```json
283 | {
284 | "name": "Tweet Generator",
285 | "dataset": "tweet_dataset",
286 | "initial_prompt": {
287 | "prompt_str": "Write a tweet about {topic} in the style of {author}",
288 | "which": 0
289 | },
290 | "description": "Generate engaging tweets on various topics in the style of different authors",
291 | "evaluator_descriptions": {
292 | "engagement_score": "Measures the potential engagement of the tweet",
293 | "style_match": "Evaluates how well the tweet matches the specified author's style"
294 | },
295 | "evaluators": "./tweet_evaluators.py:evaluators",
296 | "optimizer": {
297 | "model": {
298 | "name": "gpt-3.5-turbo",
299 | "temperature": 0.7
300 | }
301 | }
302 | }
303 | ```
304 |
--------------------------------------------------------------------------------
/all_sweeps.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -e # Exit on error
3 |
4 | VENV_PATH=".venv/bin/activate"
5 | DEFAULT_SWEEP="experiments/tool_sweeps.jsonl"
6 | SWEEP_FILE=""
7 | EPOCHS=10
8 |
9 | while [[ $# -gt 0 ]]; do
10 | case $1 in
11 | --sweep)
12 | SWEEP_FILE="$2"
13 | shift 2
14 | ;;
15 | --epochs)
16 | EPOCHS="$2"
17 | shift 2
18 | ;;
19 | *)
20 | echo "Unknown argument: $1"
21 | echo "Usage: $0 [--sweep path/to/sweep.jsonl] [--epochs N]"
22 | exit 1
23 | ;;
24 | esac
25 | done
26 |
27 | if [ -z "$SWEEP_FILE" ]; then
28 | SWEEP_FILE="$DEFAULT_SWEEP"
29 | fi
30 |
31 | # Fewshots:
32 | # math_multi -> 4o
33 | #####
34 | # cs10 -> o1, o1-mini, claude, 4o
35 | # tooluse -> o1, o1-mini, claude
36 | # email_cs_simple -> o1, o1-mini, claude, 4o
37 |
38 |
39 | directories=(
40 | "experiments/email_elon"
41 | "experiments/email_cs"
42 | "experiments/email_cs10"
43 | "experiments/math_multi"
44 | "experiments/email_cs_simple"
45 | )
46 |
47 | if [ ! -f "$VENV_PATH" ]; then
48 | echo "Error: Virtual environment not found at $VENV_PATH"
49 | exit 1
50 | fi
51 |
52 | source "$VENV_PATH" || {
53 | echo "Error: Failed to activate virtual environment"
54 | exit 1
55 | }
56 |
57 | # Verify sweep file exists
58 | if [ ! -f "$SWEEP_FILE" ]; then
59 | echo "Error: Sweep file not found at $SWEEP_FILE"
60 | exit 1
61 | fi
62 |
63 | for dir in "${directories[@]}"; do
64 | if [ ! -f "$dir/config.json" ]; then
65 | echo "Warning: Config file not found at $dir/config.json"
66 | continue
67 | fi
68 | echo "Starting sweep for $dir using sweep file: $SWEEP_FILE"
69 | promptim train --task "$dir/config.json" --sweep "$SWEEP_FILE" --epochs $EPOCHS &
70 | done
71 |
72 | wait
73 | echo "All sweeps completed"
74 |
--------------------------------------------------------------------------------
/config-schema.json:
--------------------------------------------------------------------------------
1 | {
2 | "$defs": {
3 | "OptimizerConfig": {
4 | "properties": {
5 | "model": {
6 | "description": "Model configuration dictionary specifying the model name, parameters, and other settings used during optimization.",
7 | "title": "Model",
8 | "type": "object"
9 | }
10 | },
11 | "required": [
12 | "model"
13 | ],
14 | "title": "OptimizerConfig",
15 | "type": "object"
16 | },
17 | "PromptConfig": {
18 | "properties": {
19 | "identifier": {
20 | "anyOf": [
21 | {
22 | "type": "string"
23 | },
24 | {
25 | "type": "null"
26 | }
27 | ],
28 | "default": null,
29 | "description": "Identifier for a prompt from the hub repository. Mutually exclusive with prompt_str.",
30 | "title": "Identifier"
31 | },
32 | "prompt_str": {
33 | "anyOf": [
34 | {
35 | "type": "string"
36 | },
37 | {
38 | "type": "null"
39 | }
40 | ],
41 | "default": null,
42 | "description": "Raw prompt string to optimize locally. Mutually exclusive with identifier.",
43 | "title": "Prompt Str"
44 | },
45 | "model_config": {
46 | "anyOf": [
47 | {
48 | "type": "object"
49 | },
50 | {
51 | "type": "null"
52 | }
53 | ],
54 | "default": null,
55 | "description": "Configuration dictionary specifying model parameters for optimization.",
56 | "title": "Model Config"
57 | },
58 | "which": {
59 | "default": 0,
60 | "description": "Index of the message to optimize within the prompt.",
61 | "title": "Which",
62 | "type": "integer"
63 | }
64 | },
65 | "title": "PromptConfig",
66 | "type": "object"
67 | }
68 | },
69 | "properties": {
70 | "name": {
71 | "title": "Name",
72 | "type": "string"
73 | },
74 | "dataset": {
75 | "title": "Dataset",
76 | "type": "string"
77 | },
78 | "initial_prompt": {
79 | "$ref": "#/$defs/PromptConfig"
80 | },
81 | "description": {
82 | "anyOf": [
83 | {
84 | "type": "string"
85 | },
86 | {
87 | "type": "null"
88 | }
89 | ],
90 | "default": "",
91 | "title": "Description"
92 | },
93 | "evaluator_descriptions": {
94 | "anyOf": [
95 | {
96 | "type": "object"
97 | },
98 | {
99 | "type": "null"
100 | }
101 | ],
102 | "title": "Evaluator Descriptions"
103 | },
104 | "baseline_experiment": {
105 | "anyOf": [
106 | {
107 | "format": "uuid",
108 | "type": "string"
109 | },
110 | {
111 | "type": "null"
112 | }
113 | ],
114 | "default": null,
115 | "title": "Baseline Experiment"
116 | },
117 | "optimizer": {
118 | "anyOf": [
119 | {
120 | "$ref": "#/$defs/OptimizerConfig"
121 | },
122 | {
123 | "type": "null"
124 | }
125 | ],
126 | "default": null,
127 | "description": "Optimization configuration specifying model settings and hyperparameters. If None, default configuration will be used."
128 | },
129 | "evaluators": {
130 | "anyOf": [
131 | {
132 | "type": "string"
133 | },
134 | {
135 | "type": "null"
136 | }
137 | ],
138 | "description": "Import path to evaluator functions in format 'file_path:variable_name'. The functions should evaluate prompt quality.\n\nExample:\n ./task/evaluators.py:evaluators",
139 | "title": "Evaluators"
140 | },
141 | "system": {
142 | "anyOf": [
143 | {
144 | "type": "string"
145 | },
146 | {
147 | "type": "null"
148 | }
149 | ],
150 | "default": null,
151 | "description": "Import path to system configuration in format 'file_path:variable_name'. Defines how prompts are executed.\n\nExample:\n ./task/my_system.py:chain",
152 | "title": "System"
153 | }
154 | },
155 | "required": [
156 | "name",
157 | "dataset",
158 | "initial_prompt",
159 | "evaluators"
160 | ],
161 | "title": "Config",
162 | "type": "object",
163 | "$schema": "http://json-schema.org/draft-07/schema#"
164 | }
--------------------------------------------------------------------------------
/examples/tweet_writer/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "Tweet Generator",
3 | "dataset": "tweet-optim",
4 | "evaluators": "./task.py:evaluators",
5 | "optimizer": {
6 | "model": {
7 | "model": "openai:gpt-4o"
8 | }
9 | },
10 | "initial_prompt": {
11 | "identifier": "tweet-generator-example:starter"
12 | },
13 | "evaluator_descriptions": {
14 | "under_180_chars": "Checks if the tweet is under 180 characters. 1 if true, 0 if false.",
15 | "no_hashtags": "Checks if the tweet contains no hashtags. 1 if true, 0 if false.",
16 | "multiline": "Fails if the tweet is not multiple lines. 1 if true, 0 if false. 0 is bad."
17 | }
18 | }
19 |
--------------------------------------------------------------------------------
/examples/tweet_writer/create_dataset.py:
--------------------------------------------------------------------------------
1 | ## Example of how to create the dataset
2 |
3 | if __name__ == "__main__":
4 | from langsmith import Client
5 |
6 | client = Client()
7 |
8 | topics = [
9 | "NBA",
10 | "NFL",
11 | "Movies",
12 | "Taylor Swift",
13 | "Artificial Intelligence",
14 | "Climate Change",
15 | "Space Exploration",
16 | "Cryptocurrency",
17 | "Healthy Living",
18 | "Travel Destinations",
19 | "Technology Trends",
20 | "Fashion",
21 | "Food and Cooking",
22 | "Music Festivals",
23 | "Entrepreneurship",
24 | "Fitness",
25 | "Gaming",
26 | "Politics",
27 | "Environmental Conservation",
28 | "Social Media Trends",
29 | "Education",
30 | "Mental Health",
31 | "Renewable Energy",
32 | "Virtual Reality",
33 | "Sustainable Fashion",
34 | "Robotics",
35 | "Quantum Computing",
36 | "Genetic Engineering",
37 | "Smart Cities",
38 | "Cybersecurity",
39 | "Augmented Reality",
40 | "Electric Vehicles",
41 | "Blockchain",
42 | "3D Printing",
43 | "Nanotechnology",
44 | "Biotechnology",
45 | "Internet of Things",
46 | "Cloud Computing",
47 | "Big Data",
48 | "Machine Learning",
49 | "Artificial General Intelligence",
50 | "Space Tourism",
51 | "Autonomous Vehicles",
52 | "Drones",
53 | "Wearable Technology",
54 | "Personalized Medicine",
55 | "Telemedicine",
56 | "Remote Work",
57 | "Digital Nomads",
58 | "Gig Economy",
59 | "Circular Economy",
60 | "Vertical Farming",
61 | "Lab-grown Meat",
62 | "Plant-based Diets",
63 | "Mindfulness",
64 | "Yoga",
65 | "Meditation",
66 | "Biohacking",
67 | "Nootropics",
68 | "Intermittent Fasting",
69 | "HIIT Workouts",
70 | "Esports",
71 | "Streaming Services",
72 | "Podcasting",
73 | "True Crime",
74 | "Tiny Houses",
75 | "Minimalism",
76 | "Zero Waste Living",
77 | "Upcycling",
78 | "Eco-tourism",
79 | "Voluntourism",
80 | "Digital Detox",
81 | "Slow Living",
82 | "Hygge",
83 | "Urban Gardening",
84 | "Permaculture",
85 | "Regenerative Agriculture",
86 | "Microplastics",
87 | "Ocean Conservation",
88 | "Rewilding",
89 | "Endangered Species",
90 | "Biodiversity",
91 | "Ethical AI",
92 | "Data Privacy",
93 | "Net Neutrality",
94 | "Deepfakes",
95 | "Fake News",
96 | "Social Media Activism",
97 | "Cancel Culture",
98 | "Meme Culture",
99 | "NFTs",
100 | "Decentralized Finance",
101 | "Universal Basic Income",
102 | "Gender Equality",
103 | "LGBTQ+ Rights",
104 | "Black Lives Matter",
105 | "Indigenous Rights",
106 | "Refugee Crisis",
107 | "Global Poverty",
108 | "Universal Healthcare",
109 | "Drug Decriminalization",
110 | "Prison Reform",
111 | "Gun Control",
112 | "Voting Rights",
113 | "Gerrymandering",
114 | "Campaign Finance Reform",
115 | "Term Limits",
116 | "Ranked Choice Voting",
117 | "Direct Democracy",
118 | "Space Debris",
119 | "Asteroid Mining",
120 | "Mars Colonization",
121 | "Extraterrestrial Life",
122 | "Dark Matter",
123 | "Black Holes",
124 | "Quantum Entanglement",
125 | "Fusion Energy",
126 | "Antimatter",
127 | "Cryonics",
128 | "Life Extension",
129 | "Transhumanism",
130 | "Cyborgs",
131 | "Brain-Computer Interfaces",
132 | "Memory Implants",
133 | "Holographic Displays",
134 | ]
135 |
136 | # Create datasets
137 | ds = client.create_dataset(dataset_name="tweet-optim")
138 |
139 | # Split topics into train, dev, and test sets
140 | train_topics = topics[:80]
141 | dev_topics = topics[80:90]
142 | test_topics = topics[90:]
143 |
144 | # Create examples for each dataset
145 | for split_name, dataset_topics in [
146 | ("train", train_topics),
147 | ("dev", dev_topics),
148 | ("test", test_topics),
149 | ]:
150 | client.create_examples(
151 | inputs=[{"topic": topic} for topic in dataset_topics],
152 | dataset_id=ds.id,
153 | splits=[split_name] * len(dataset_topics),
154 | )
155 |
156 | print("Dataset created successfully!")
157 |
--------------------------------------------------------------------------------
/examples/tweet_writer/guided.config.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "Tweet Generator",
3 | "dataset": "tweet-optim",
4 | "evaluators": "./task.py:evaluators",
5 | "optimizer": {
6 | "kind": "feedback_guided",
7 | "model": {
8 | "model": "claude-3-5-sonnet-20241022",
9 | "max_tokens_to_sample": 8192
10 | },
11 | "max_batch_size": 10
12 | },
13 | "initial_prompt": {
14 | "identifier": "tweet-generator-example:starter"
15 | },
16 | "evaluator_descriptions": {
17 | "under_180_chars": "Checks if the tweet is under 180 characters. 1 if true, 0 if false.",
18 | "no_hashtags": "Checks if the tweet contains no hashtags. 1 if true, 0 if false.",
19 | "multiline": "Fails if the tweet is not multiple lines. 1 if true, 0 if false. 0 is bad."
20 | }
21 | }
22 |
--------------------------------------------------------------------------------
/examples/tweet_writer/task.py:
--------------------------------------------------------------------------------
1 | def under_180_chars(run, example):
2 | """Evaluate if the tweet is under 180 characters."""
3 | result = run.outputs.get("tweet", "")
4 | score = int(len(result) < 180)
5 | comment = "Pass" if score == 1 else "Fail"
6 | return {
7 | "key": "under_180_chars",
8 | "score": score,
9 | "comment": comment,
10 | }
11 |
12 |
13 | def no_hashtags(run, example):
14 | """Evaluate if the tweet contains no hashtags."""
15 | result = run.outputs.get("tweet", "")
16 | score = int("#" not in result)
17 | comment = "Pass" if score == 1 else "Fail"
18 | return {
19 | "key": "no_hashtags",
20 | "score": score,
21 | "comment": comment,
22 | }
23 |
24 |
25 | def multiple_lines(run, example):
26 | """Evaluate if the tweet contains multiple lines."""
27 | result = run.outputs.get("tweet", "")
28 | score = int("\n" in result)
29 | comment = "Pass" if score == 1 else "Fail"
30 | return {
31 | "key": "multiline",
32 | "score": score,
33 | "comment": comment,
34 | }
35 |
36 |
37 | evaluators = [multiple_lines, no_hashtags, under_180_chars]
38 |
--------------------------------------------------------------------------------
/experiments/__init__.py:
--------------------------------------------------------------------------------
1 | """Task definitions and configurations for prompt optimization."""
2 |
3 | from .multiclass_email3 import EmailClassificationTask3
4 | from .multiclass_email10 import EmailClassificationTask10
5 | from .multiclass_health3 import HealthClassificationTask3
6 | from .multiclass_health10 import HealthClassificationTask10
7 | from .tooluse_finance import FinanceToolUseTask
8 | from .tooluse_ecommerce import EcommerceToolUseTask
9 | from .extract_code import CodeExtractionTask
10 | from .extract_legal import LegalExtractionTask
11 |
12 | __all__ = [
13 | "EmailClassificationTask3",
14 | "EmailClassificationTask10",
15 | "HealthClassificationTask3",
16 | "HealthClassificationTask10",
17 | "FinanceToolUseTask",
18 | "EcommerceToolUseTask",
19 | "CodeExtractionTask",
20 | "LegalExtractionTask",
21 | ]
22 |
--------------------------------------------------------------------------------
/experiments/cs_tooluse/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "cs_tooluse",
3 | "dataset": "CS Simulations",
4 | "description": "Triage emails",
5 | "evaluators": "./evaluators.py:check_output",
6 | "evaluator_descriptions": {
7 | "accuracy": "Evaluates if the predicted tool calls match the expected behavior"
8 | },
9 | "optimizer": {
10 | "model": {
11 | "model": "claude-3-5-sonnet-20241022"
12 | }
13 | },
14 | "initial_prompt": {
15 | "identifier": "langchain-ai/support-tool-use-demo:d6981321",
16 | "upload_to": "langchain-ai/support-tool-use-demo"
17 | },
18 | "$schema": "https://raw.githubusercontent.com/hinthornw/promptimizer/refs/heads/main/config-schema.json"
19 | }
20 |
--------------------------------------------------------------------------------
/experiments/cs_tooluse/evaluators.py:
--------------------------------------------------------------------------------
1 | from langchain_core.messages import AIMessage
2 | import asyncio
3 |
4 |
5 | async def check_tool_call(expected: dict, tool_calls: list) -> dict:
6 | matching = next((tc for tc in tool_calls if tc["name"] == expected["name"]), None)
7 | if matching is None:
8 | return {"comment": "", "score": 0.0}
9 | args: dict = matching["args"]
10 | expected_args: dict = expected["args"]
11 | pf = {}
12 | for arg_name, arg_value in expected_args.items():
13 | if arg_name not in args:
14 | return {"comment": f"missing arg: {arg_name}", "score": 0.0}
15 | if isinstance(arg_value, str) and len(arg_value.split(" ")) > 5:
16 | pf[arg_name] = (
17 | arg_name in args
18 | and isinstance(args[arg_name], str)
19 | and len(args[arg_name].split(" ")) > 5
20 | )
21 | continue
22 | pf[arg_name] = arg_value == args[arg_name]
23 | score = sum(pf.values()) / len(pf)
24 | if score < 1.0:
25 | score_per_arg = ", ".join(
26 | f"{arg_name}: {int(score)}" for arg_name, score in pf.items()
27 | )
28 | comment = f"Score per arg: {score_per_arg}"
29 | else:
30 | comment = "ok"
31 | return {"comment": comment, "score": score}
32 |
33 |
34 | async def check_output(outputs: dict, reference_outputs: dict) -> dict:
35 | expected: dict = reference_outputs["output"]
36 | actual: AIMessage = outputs["output"]
37 | actual_tcs = actual.tool_calls or []
38 |
39 | coros = []
40 | for tc in expected.get("tool_calls", []) or []:
41 | coros.append(check_tool_call(tc, actual_tcs))
42 |
43 | results = await asyncio.gather(*coros)
44 | if not results:
45 | return {
46 | "key": "quality",
47 | "score": 1.0,
48 | "comment": "No tool calls found to evaluate.",
49 | }
50 | score = sum(r["score"] for r in results) / len(results)
51 | passed = score == 1.0
52 | comment = ", ".join(r["comment"] for r in results if r["comment"])
53 | comment = f"{'Passed' if passed else 'Failed'}: {comment}"
54 |
55 | return {"key": "quality", "score": score, "comment": comment}
56 |
--------------------------------------------------------------------------------
/experiments/elon_email_sweeps.jsonl:
--------------------------------------------------------------------------------
1 | // {"optimizer": {"kind": "metaprompt", "model": {"model": "openai:o1"} }}
2 | // {"optimizer": {"kind": "feedback_guided", "model": {"model": "openai:o1"} }}
3 | // {"optimizer": {"model": {"model": "openai:o1"} }, "algorithm": {"kind": "phaseevo"}}
4 | {"optimizer": {"kind": "metaprompt", "model": {"model": "claude-3-5-sonnet-20241022","max_tokens_to_sample": 8192} }}
5 | // {"optimizer": {"kind": "metaprompt", "model": {"model": "openai:gpt-4o"} }}
6 | // {"optimizer": {"kind": "metaprompt", "model": {"model": "openai:o1"} }}
7 | // {"optimizer": {"kind": "feedback_guided", "model": {"model": "claude-3-5-sonnet-20241022","max_tokens_to_sample": 8192} }}
8 | // {"optimizer": {"kind": "feedback_guided", "model": {"model": "openai:gpt-4o"} }}
9 | // {"optimizer": {"model": {"model": "claude-3-5-sonnet-20241022","max_tokens_to_sample": 8192} }, "algorithm": {"kind": "phaseevo"}}
10 | // {"optimizer": {"model": {"model": "openai:gpt-4o"} }, "algorithm": {"kind": "phaseevo"}}
--------------------------------------------------------------------------------
/experiments/email_cs/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "email_cs",
3 | "dataset": "https://smith.langchain.com/public/816b9cfd-e049-4425-b9bf-9835225947a0/d",
4 | "description": "Triage emails",
5 | "evaluators": "./evaluators.py:accuracy_evaluator",
6 | "evaluator_descriptions": {
7 | "accuracy": "Evaluates if the predicted category matches the reference class"
8 | },
9 | "optimizer": {
10 | "model": {
11 | "model": "claude-3-5-sonnet-20241022"
12 | }
13 | },
14 | "initial_prompt": {
15 | "identifier": "langchain-ai/cs_email:ec65d706",
16 | "upload_to": "cs_email"
17 | },
18 | "$schema": "https://raw.githubusercontent.com/hinthornw/promptimizer/refs/heads/main/config-schema.json"
19 | }
--------------------------------------------------------------------------------
/experiments/email_cs/evaluators.py:
--------------------------------------------------------------------------------
1 | from langchain_core.messages import AIMessage
2 | from langsmith.schemas import Run, Example
3 |
4 |
5 | def accuracy_evaluator(run: Run, example: Example) -> dict:
6 | """Evaluator to check if the predicted triage class matches the reference."""
7 | reference_outputs = example.outputs
8 | predicted: AIMessage = run.outputs["route_to"]
9 | score = 1 if predicted == reference_outputs["route_to"] else 0
10 | return {
11 | "key": "accuracy",
12 | "score": score,
13 | "comment": "Pass: triage class is correct"
14 | if score == 1
15 | else "Fail: triage class is not correct",
16 | }
17 |
18 |
19 | evaluators = [accuracy_evaluator]
20 |
--------------------------------------------------------------------------------
/experiments/email_cs10/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "email_cs_10",
3 | "dataset": "https://smith.langchain.com/public/45cf669b-4e11-4315-a311-5a59f855df8b/d",
4 | "description": "Triage emails",
5 | "evaluators": "./evaluators.py:accuracy_evaluator",
6 | "evaluator_descriptions": {
7 | "accuracy": "Evaluates if the predicted category matches the reference class"
8 | },
9 | "optimizer": {
10 | "model": {
11 | "model": "claude-3-5-sonnet-20241022"
12 | }
13 | },
14 | "initial_prompt": {
15 | "identifier": "langchain-ai/cs_email10:6a110c73",
16 | "upload_to": "cs_email10"
17 | },
18 | "$schema": "https://raw.githubusercontent.com/hinthornw/promptimizer/refs/heads/main/config-schema.json"
19 | }
--------------------------------------------------------------------------------
/experiments/email_cs10/evaluators.py:
--------------------------------------------------------------------------------
1 | from langchain_core.messages import AIMessage
2 | from langsmith.schemas import Run, Example
3 |
4 |
5 | def accuracy_evaluator(run: Run, example: Example) -> dict:
6 | """Evaluator to check if the predicted triage class matches the reference."""
7 | reference_outputs = example.outputs
8 | predicted: AIMessage = run.outputs["route_to"]
9 | score = 1 if predicted == reference_outputs["route_to"] else 0
10 | return {
11 | "key": "accuracy",
12 | "score": score,
13 | "comment": "Pass: triage class is correct"
14 | if score == 1
15 | else "Fail: triage class is not correct",
16 | }
17 |
18 |
19 | evaluators = [accuracy_evaluator]
20 |
--------------------------------------------------------------------------------
/experiments/email_cs_simple/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "email_cs_simple",
3 | "dataset": "https://smith.langchain.com/public/c0af3eb3-6f91-47cd-965c-e3bea3140e09/d",
4 | "description": "Triage emails",
5 | "evaluators": "./evaluators.py:accuracy_evaluator",
6 | "evaluator_descriptions": {
7 | "accuracy": "Evaluates if the predicted category matches the reference class"
8 | },
9 | "optimizer": {
10 | "model": {
11 | "model": "claude-3-5-sonnet-20241022"
12 | }
13 | },
14 | "initial_prompt": {
15 | "identifier": "langchain-ai/email_cs_simple:933387ea",
16 | "upload_to": "email_cs_simple"
17 | },
18 | "$schema": "https://raw.githubusercontent.com/hinthornw/promptimizer/refs/heads/main/config-schema.json"
19 | }
20 |
--------------------------------------------------------------------------------
/experiments/email_cs_simple/evaluators.py:
--------------------------------------------------------------------------------
1 | from langsmith.schemas import Run, Example
2 |
3 |
4 | def accuracy_evaluator(run: Run, example: Example) -> dict:
5 | """Evaluator to check if the predicted triage action matches the reference."""
6 | score = 1 if run.outputs["action"] == example.outputs["action"] else 0
7 | return {
8 | "key": "accuracy",
9 | "score": score,
10 | "comment": (
11 | "Pass: triage class is correct"
12 | if score == 1
13 | else "Fail: triage class is not correct"
14 | ),
15 | }
16 |
17 |
18 | evaluators = [accuracy_evaluator]
19 |
--------------------------------------------------------------------------------
/experiments/email_elon/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "email_elon",
3 | "dataset": "https://smith.langchain.com/public/0cc60f5b-5c37-4020-8d54-88e403a18a9c/d",
4 | "description": "Triage emails",
5 | "evaluators": "email_elon.evaluators:accuracy_evaluator",
6 | "evaluator_descriptions": {
7 | "accuracy": "Evaluates if the predicted category matches the reference class"
8 | },
9 | "optimizer": {
10 | "model": {
11 | "model": "claude-3-5-sonnet-20241022"
12 | }
13 | },
14 | "initial_prompt": {
15 | "identifier": "langchain-ai/elon_email:38f8f365",
16 | "upload_to": "elon_email"
17 | },
18 | "$schema": "https://raw.githubusercontent.com/hinthornw/promptimizer/refs/heads/main/config-schema.json"
19 | }
--------------------------------------------------------------------------------
/experiments/email_elon/evaluators.py:
--------------------------------------------------------------------------------
1 | from langchain_core.messages import AIMessage
2 | from langsmith.schemas import Run, Example
3 |
4 |
5 | def accuracy_evaluator(run: Run, example: Example) -> dict:
6 | """Evaluator to check if the predicted classification matches the reference."""
7 | reference_outputs = example.outputs
8 | predicted: AIMessage = run.outputs["classification"]
9 | score = 1 if predicted == reference_outputs["classification"] else 0
10 | return {
11 | "key": "accuracy",
12 | "score": score,
13 | "comment": "Pass: triage class is correct"
14 | if score == 1
15 | else "Fail: triage class is not correct",
16 | }
17 |
18 |
19 | evaluators = [accuracy_evaluator]
20 |
--------------------------------------------------------------------------------
/experiments/extract_code/__init__.py:
--------------------------------------------------------------------------------
1 | from .task import CodeExtractionTask
2 |
3 | __all__ = ["CodeExtractionTask"]
4 |
--------------------------------------------------------------------------------
/experiments/extract_code/backup.config.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "extraction-code",
3 | "dataset": {
4 | "name": "code_extraction_dataset",
5 | "description": "A dataset of code snippets with extracted information",
6 | "url": "https://smith.langchain.com/public/84e18db4-b5f1-4eaa-8896-904532d167db/d"
7 | },
8 | "description": "Extract function name, parameters, and return type from code",
9 | "evaluators": "extract_code.evaluators:accuracy_evaluator",
10 | "evaluator_descriptions": {
11 | "accuracy": "Evaluates if the extracted information matches the reference"
12 | },
13 | "optimizer": {
14 | "model": {
15 | "model": "claude-3-5-sonnet-20241022"
16 | }
17 | },
18 | "initial_prompt": {
19 | "identifier": "emily-sentiment/extraction-code:15288745",
20 | "upload_to": "langchain-ai/extraction-code"
21 | },
22 | "$schema": "https://raw.githubusercontent.com/hinthornw/promptimizer/refs/heads/main/config-schema.json"
23 | }
--------------------------------------------------------------------------------
/experiments/extract_code/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "extraction-code",
3 | "dataset": "code_extraction_dataset",
4 | "description": "Extract function name, parameters, and return type from code",
5 | "evaluators": "extract_code.evaluators:accuracy_evaluator",
6 | "evaluator_descriptions": {
7 | "accuracy": "Evaluates if the extracted information matches the reference"
8 | },
9 | "optimizer": {
10 | "model": {
11 | "model": "claude-3-5-sonnet-20241022"
12 | }
13 | },
14 | "initial_prompt": {
15 | "identifier": "emily-sentiment/extraction-code:15288745",
16 | "upload_to": "langchain-ai/extraction-code"
17 | },
18 | "$schema": "https://raw.githubusercontent.com/hinthornw/promptimizer/refs/heads/main/config-schema.json"
19 | }
--------------------------------------------------------------------------------
/experiments/extract_code/evaluators.py:
--------------------------------------------------------------------------------
1 | from typing_extensions import TypedDict
2 | from langsmith.schemas import Run, Example
3 | from difflib import SequenceMatcher
4 | import json
5 |
6 |
7 | class Outputs(TypedDict):
8 | function_name: str
9 | parameters: list
10 | return_type: str
11 |
12 |
13 | def semantic_similarity(a, b):
14 | """compare two values with semantic similarity, handling different data types."""
15 | if a is None or b is None:
16 | return 0
17 |
18 | # convert to strings for comparison
19 | if isinstance(a, (dict, list)):
20 | a = json.dumps(a, sort_keys=True)
21 | if isinstance(b, (dict, list)):
22 | b = json.dumps(b, sort_keys=True)
23 |
24 | a = str(a).lower()
25 | b = str(b).lower()
26 | return SequenceMatcher(None, a, b).ratio()
27 |
28 |
29 | def accuracy_evaluator(run: Run, example: Example) -> dict:
30 | """evaluator for partial matching of function details."""
31 | try:
32 | # safely get reference outputs
33 | reference_outputs = example.outputs or {}
34 | extractions_str = reference_outputs.get("extractions", "{}")
35 |
36 | if isinstance(extractions_str, dict):
37 | nested_json = extractions_str
38 | else:
39 | nested_json = json.loads(extractions_str)
40 |
41 | # safely get run outputs
42 | run_outputs = run.outputs or {}
43 |
44 | if isinstance(run_outputs, str):
45 | try:
46 | run_outputs = json.loads(run_outputs)
47 | except json.JSONDecodeError:
48 | run_outputs = {}
49 |
50 | # calculate matches with semantic similarity
51 | matches = {
52 | "function_name": semantic_similarity(
53 | run_outputs.get("function_name"), nested_json.get("function_name")
54 | ),
55 | "parameters": semantic_similarity(
56 | run_outputs.get("parameters", []), nested_json.get("parameters", [])
57 | ),
58 | "return_type": semantic_similarity(
59 | run_outputs.get("return_type"), nested_json.get("return_type")
60 | ),
61 | }
62 |
63 | # calculate overall score
64 | score = sum(matches.values()) / len(matches)
65 |
66 | # generate detailed feedback
67 | if score == 1:
68 | comment = "Pass: Perfect match in function details."
69 | elif score > 0:
70 | comment = f"Partial match (score: {score:.2f})."
71 | else:
72 | comment = "Fail: No match in function details."
73 |
74 | # add specific field comparisons to comment
75 | field_feedback = [f"{k}: {v:.2f}" for k, v in matches.items()]
76 | comment += f"\nField-wise similarity scores:\n{', '.join(field_feedback)}"
77 |
78 | except Exception as e:
79 | # provide informative error feedback
80 | score = 0
81 | comment = (
82 | f"Error in evaluation: {str(e)}. Run outputs: {str(run.outputs)[:200]}..."
83 | )
84 |
85 | return {
86 | "key": "function_extraction_accuracy",
87 | "score": score,
88 | "comment": comment,
89 | }
90 |
91 |
92 | evaluators = [accuracy_evaluator]
93 |
--------------------------------------------------------------------------------
/experiments/extract_code/task.py:
--------------------------------------------------------------------------------
1 | """Task definition for code extraction."""
2 |
3 | from typing import List, Callable
4 | import json
5 | from pathlib import Path
6 |
7 | from langsmith.schemas import Run, Example
8 | from krishpromptim.prompt_types import Task, Dataset
9 | from .evaluators import evaluators as code_extraction_evaluators
10 |
11 |
12 | class CodeExtractionTask(Task):
13 | """Task for extracting function details from code."""
14 |
15 | def __init__(self, custom_evaluators: List[Callable[[Run, Example], dict]] = None):
16 | """Initialize the task with optional custom evaluators.
17 |
18 | Args:
19 | custom_evaluators: Optional list of custom evaluator functions. If provided,
20 | these will replace the default evaluator.
21 | """
22 | evaluators = (
23 | custom_evaluators
24 | if custom_evaluators is not None
25 | else code_extraction_evaluators
26 | )
27 |
28 | # Load config from json
29 | config_path = Path(__file__).parent / "config.json"
30 | with open(config_path) as f:
31 | config = json.load(f)
32 |
33 | super().__init__(
34 | name=config["name"],
35 | description=config["description"],
36 | dataset=Dataset(**config["dataset"]),
37 | evaluators=evaluators,
38 | evaluator_descriptions=config["evaluator_descriptions"],
39 | initial_prompt=config["initial_prompt"],
40 | optimizer=config["optimizer"],
41 | )
42 |
43 |
44 | # Export the task class
45 | __all__ = ["CodeExtractionTask"]
46 |
--------------------------------------------------------------------------------
/experiments/extract_legal/__init__.py:
--------------------------------------------------------------------------------
1 | from .task import LegalExtractionTask
2 |
3 | __all__ = ["LegalExtractionTask"]
4 |
--------------------------------------------------------------------------------
/experiments/extract_legal/backup.config.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "extraction-legal",
3 | "dataset": {
4 | "name": "legal_extraction_dataset",
5 | "description": "A dataset of legal documents with extracted information",
6 | "url": "https://smith.langchain.com/public/23723ece-9e09-4564-9ffd-e3243360bdcd/d"
7 | },
8 | "description": "Extract key information from legal documents",
9 | "evaluators": "extract_legal.evaluators:accuracy_evaluator",
10 | "evaluator_descriptions": {
11 | "accuracy": "Evaluates if the extracted information matches the reference"
12 | },
13 | "optimizer": {
14 | "model": {
15 | "model": "claude-3-5-sonnet-20241022",
16 | "max_tokens_to_sample": 8192
17 | }
18 | },
19 | "initial_prompt": {
20 | "identifier": "emily-sentiment/extraction-legal:8791ad13",
21 | "upload_to": "langchain-ai/extraction-legal"
22 | }
23 | }
--------------------------------------------------------------------------------
/experiments/extract_legal/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "extraction-legal",
3 | "dataset": "legal_extraction_dataset",
4 | "description": "Extract key information from legal documents",
5 | "evaluators": "extract_legal.evaluators:accuracy_evaluator",
6 | "evaluator_descriptions": {
7 | "accuracy": "Evaluates if the extracted information matches the reference"
8 | },
9 | "optimizer": {
10 | "model": {
11 | "model": "claude-3-5-sonnet-20241022",
12 | "max_tokens_to_sample": 8192
13 | }
14 | },
15 | "initial_prompt": {
16 | "identifier": "emily-sentiment/extraction-legal:8791ad13",
17 | "upload_to": "langchain-ai/extraction-legal"
18 | }
19 | }
--------------------------------------------------------------------------------
/experiments/extract_legal/evaluators.py:
--------------------------------------------------------------------------------
1 | from typing_extensions import TypedDict
2 | from langsmith.schemas import Run, Example
3 | from difflib import SequenceMatcher
4 | import json
5 |
6 |
7 | class Outputs(TypedDict):
8 | parties_involved: list
9 | effective_date: str
10 | termination_clauses: list
11 | jurisdiction: str
12 | governing_law: str
13 | payment_terms: dict
14 | liability_clauses: list
15 | confidentiality_terms: dict
16 |
17 |
18 | def semantic_similarity(a, b):
19 | """compare two values with semantic similarity, handling different data types."""
20 | if a is None or b is None:
21 | return 0
22 |
23 | # convert to strings for comparison
24 | if isinstance(a, (dict, list)):
25 | a = json.dumps(a, sort_keys=True)
26 | if isinstance(b, (dict, list)):
27 | b = json.dumps(b, sort_keys=True)
28 |
29 | a = str(a).lower()
30 | b = str(b).lower()
31 | return SequenceMatcher(None, a, b).ratio()
32 |
33 |
34 | def accuracy_evaluator(run: Run, example: Example) -> dict:
35 | """evaluator for partial matching of legal document details."""
36 | try:
37 | # parse reference outputs from nested json structure
38 | reference_outputs = example.outputs or {}
39 | extractions_str = reference_outputs.get("extractions", "{}")
40 |
41 | if isinstance(extractions_str, dict):
42 | reference_data = extractions_str
43 | else:
44 | reference_data = json.loads(extractions_str)
45 |
46 | # parse run outputs
47 | run_outputs = run.outputs
48 | if isinstance(run_outputs, str):
49 | try:
50 | outputs = json.loads(run_outputs)
51 | except json.JSONDecodeError:
52 | outputs = {}
53 | else:
54 | outputs = run_outputs or {}
55 |
56 | # calculate matches with semantic similarity
57 | matches = {
58 | "parties_involved": semantic_similarity(
59 | outputs.get("parties_involved"), reference_data.get("parties_involved")
60 | ),
61 | "effective_date": semantic_similarity(
62 | outputs.get("effective_date"), reference_data.get("effective_date")
63 | ),
64 | "termination_clauses": semantic_similarity(
65 | outputs.get("termination_clauses"),
66 | reference_data.get("termination_clauses"),
67 | ),
68 | "jurisdiction": semantic_similarity(
69 | outputs.get("jurisdiction"), reference_data.get("jurisdiction")
70 | ),
71 | "governing_law": semantic_similarity(
72 | outputs.get("governing_law"), reference_data.get("governing_law")
73 | ),
74 | "payment_terms": semantic_similarity(
75 | outputs.get("payment_terms"), reference_data.get("payment_terms")
76 | ),
77 | "liability_clauses": semantic_similarity(
78 | outputs.get("liability_clauses"),
79 | reference_data.get("liability_clauses"),
80 | ),
81 | "confidentiality_terms": semantic_similarity(
82 | outputs.get("confidentiality_terms"),
83 | reference_data.get("confidentiality_terms"),
84 | ),
85 | }
86 |
87 | # calculate overall score
88 | score = sum(matches.values()) / len(matches)
89 |
90 | # generate detailed feedback
91 | if score > 0.9:
92 | comment = "Pass: Very close match in legal document details."
93 | elif score > 0.7:
94 | comment = "Good: Strong match with minor differences."
95 | elif score > 0.5:
96 | comment = "Fair: Moderate match with some differences."
97 | else:
98 |             comment = "Needs improvement: Significant differences found."
99 |
100 | # add specific field comparisons to comment
101 | field_feedback = [f"{k}: {v:.2f}" for k, v in matches.items()]
102 | comment += f"\nField-wise similarity scores:\n{', '.join(field_feedback)}"
103 |
104 | except Exception as e:
105 | score = 0
106 | comment = f"Error in evaluation: {str(e)}. Check JSON structure and parsing."
107 |
108 | return {"key": "legal_extraction_accuracy", "score": score, "comment": comment}
109 |
110 |
111 | evaluators = [accuracy_evaluator]
112 |
--------------------------------------------------------------------------------
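`accuracy_evaluator` follows the LangSmith custom-evaluator shape: it receives a run plus its reference example and returns a dict with `key`, `score`, and `comment`. A minimal sketch of exercising it in isolation, using stand-in objects (all field values below are invented, and only the attributes the evaluator reads are provided):

from types import SimpleNamespace

# assumes accuracy_evaluator from the module above is importable
example = SimpleNamespace(
    outputs={"extractions": {"jurisdiction": "Delaware", "effective_date": "2024-01-01"}}
)
run = SimpleNamespace(outputs={"jurisdiction": "Delaware", "effective_date": "2024-01-01"})

result = accuracy_evaluator(run, example)
print(result["key"], round(result["score"], 2))  # partial credit: only 2 of the 8 fields are populated
--------------------------------------------------------------------------------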
/experiments/extract_legal/task.py:
--------------------------------------------------------------------------------
1 | """Task definition for legal document extraction."""
2 |
3 | from typing import List, Callable
4 | import json
5 | from pathlib import Path
6 |
7 | from langsmith.schemas import Run, Example
8 | from krishpromptim.prompt_types import Task, Dataset
9 | from .evaluators import evaluators as legal_extraction_evaluators
10 |
11 |
12 | class LegalExtractionTask(Task):
13 | """Task for extracting information from legal documents."""
14 |
15 | def __init__(self, custom_evaluators: List[Callable[[Run, Example], dict]] = None):
16 | """Initialize the task with optional custom evaluators.
17 |
18 | Args:
19 | custom_evaluators: Optional list of custom evaluator functions. If provided,
20 | these will replace the default evaluator.
21 | """
22 | evaluators = (
23 | custom_evaluators
24 | if custom_evaluators is not None
25 | else legal_extraction_evaluators
26 | )
27 |
28 | # Load config from json
29 | config_path = Path(__file__).parent / "config.json"
30 | with open(config_path) as f:
31 | config = json.load(f)
32 |
33 | super().__init__(
34 | name=config["name"],
35 | description=config["description"],
36 | dataset=Dataset(**config["dataset"]),
37 | evaluators=evaluators,
38 | evaluator_descriptions=config["evaluator_descriptions"],
39 | initial_prompt=config["initial_prompt"],
40 | optimizer=config["optimizer"],
41 | )
42 |
43 |
44 | # Export the task class
45 | __all__ = ["LegalExtractionTask"]
46 |
--------------------------------------------------------------------------------
/experiments/math_multi/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "math_word_problems",
3 | "dataset": "https://smith.langchain.com/public/dd06b406-c37e-429e-8274-db4e9bc832ef/d",
4 | "description": "Solve math word problems",
5 | "evaluators": "./evaluators.py:correctness_evaluator",
6 | "evaluator_descriptions": {
7 | "correctness": "Evaluates if the predicted answer equals the reference answer"
8 | },
9 | "optimizer": {
10 | "model": {
11 | "model": "claude-3-5-sonnet-20241022"
12 | }
13 | },
14 | "initial_prompt": {
15 | "identifier": "langchain-ai/math_word_problems:47c8c36d",
16 | "upload_to": "math_word_problems"
17 | },
18 | "$schema": "https://raw.githubusercontent.com/hinthornw/promptimizer/refs/heads/main/config-schema.json"
19 | }
20 |
--------------------------------------------------------------------------------
/experiments/math_multi/evaluators.py:
--------------------------------------------------------------------------------
1 | from langsmith.schemas import Run, Example
2 | from trustcall import create_extractor
3 | from langchain_openai import ChatOpenAI
4 | from typing import Literal
5 |
6 |
7 | def segment_error(
8 | value_correctness_analysis: str,
9 | value_correctness: bool,
10 | language_correctness_analysis: str,
11 | language_correctness: bool,
12 | error_type: Literal["logic", "language", "syntax"],
13 | ):
14 | """Analyze the failing test case to break down **why** the prompt failed. It could fail either because the value was wrong (logic error), the response language was wrong (language error) or just spelling error (syntax error; so the value is correct and langugae is correct but there was a small spacing or punctuation error)."""
15 | pass
16 |
17 |
18 | grader = create_extractor(
19 | ChatOpenAI(model="gpt-4o-mini"),
20 | tools=[segment_error],
21 | tool_choice="segment_error",
22 | )
23 |
24 |
25 | async def correctness_evaluator(run: Run, example: Example) -> dict:
26 | """Evaluator to check if the predicted answer matches the reference."""
27 | reference_outputs = example.outputs
28 | try:
29 | predicted = run.outputs["answer"]
30 | except KeyError:
31 | predicted = "Failed to generate answer"
32 | score = 1 if predicted.lower() == reference_outputs["answer"].lower() else 0
33 | if not score:
34 | response = await grader.ainvoke(
35 | "Analyze the following test case to break-down why it failed. First analyze the value in a language-agnostic sense, then analyze the language."
36 | " If it is in the correct language and the correct value but just written differently, then it is a syntax error. "
37 | "\n\nExample 1:\n\n"
38 | "Test case: 1 + 2\n"
39 | "Reference answer: 三\n"
40 | "Predicted answer: three\n"
41 | " Judgment: "
42 | " value_correctness_analysis: 三 and three both represent 3, so the value is correct.\n"
43 | " value_correctness: True\n"
44 | "language_correctness_analysis: three is in the incorrect language. To pass, the prompt should have responded in Mandarin.\n"
45 | "language_correctness: false"
46 | "error_type: language"
47 | " Example 2:"
48 | "Test case: 1 + 30\n"
49 | "Reference answer: thirty-three\n"
50 | "Predicted answer: thirty three\n"
51 | " Judgment: "
52 | " value_correctness_analysis: thirty-three and thirty three both represent 33, so the value is correct.\n"
53 | " value_correctness: True\n"
54 | "language_correctness_analysis: thirty three is in the correct language but written differently; the language is correct\n"
55 | "language_correctness: true"
56 | "error_type: syntax"
57 | "\n\n"
58 | "# Test:\n"
59 | f"Test case: {example.inputs['problem']}\n"
60 | f"Reference answer: {reference_outputs['answer']}\n"
61 | f"Predicted answer: {predicted}"
62 | )
63 | result = response["responses"][0]
64 | if result.error_type == "syntax":
65 | return {
66 | "key": "correctness",
67 | "score": 1,
68 | "comment": "Pass: answer is correct, modulo a small syntax error.",
69 | }
70 | return {
71 | "key": "correctness",
72 | "score": 0,
73 | "comment": f"Fail: answer is not correct. Error type: {result.error_type}. "
74 | f"Logical correctness analysis: {result.value_correctness_analysis}. Language correctness analysis: {result.language_correctness_analysis}.",
75 | }
76 | return {
77 | "key": "correctness",
78 | "score": score,
79 | "comment": (
80 | "Pass: answer is correct" if score == 1 else "Fail: answer is not correct"
81 | ),
82 | }
83 |
84 |
85 | evaluators = [correctness_evaluator]
86 |
--------------------------------------------------------------------------------
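Only failing comparisons reach the `grader` (and therefore the OpenAI API); a case-insensitive exact match short-circuits with a score of 1. A small illustrative sketch of that fast path (inputs invented, stand-ins for the LangSmith Run/Example objects):

import asyncio
from types import SimpleNamespace

# assumes correctness_evaluator from the module above is importable
example = SimpleNamespace(inputs={"problem": "1 + 2"}, outputs={"answer": "three"})
run = SimpleNamespace(outputs={"answer": "Three"})

print(asyncio.run(correctness_evaluator(run, example)))
# -> {'key': 'correctness', 'score': 1, 'comment': 'Pass: answer is correct'}
--------------------------------------------------------------------------------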
/experiments/multiclass_email10/__init__.py:
--------------------------------------------------------------------------------
1 | from .task import EmailClassificationTask10
2 |
3 | __all__ = ["EmailClassificationTask10"]
4 |
--------------------------------------------------------------------------------
/experiments/multiclass_email10/backup.config.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "multiclass_email10",
3 | "dataset": {
4 | "name": "multiclass_email10",
5 | "description": "A dataset of emails with their detailed emotion classifications",
6 | "url": "https://smith.langchain.com/public/18105687-0b24-404a-8679-e13fe6a5e383/d"
7 | },
8 | "description": "Classify emails into 10 sentiment categories",
9 | "evaluators": "multiclass_email10.evaluators:accuracy_evaluator",
10 | "evaluator_descriptions": {
11 | "accuracy": "Evaluates if the predicted category matches the reference class"
12 | },
13 | "optimizer": {
14 | "model": {
15 | "model": "claude-3-5-sonnet-20241022",
16 | "max_tokens_to_sample": 8192
17 | }
18 | },
19 | "initial_prompt": {
20 | "identifier": "emily-sentiment/multiclass-email10:4f6fd82a",
21 | "upload_to": "langchain-ai/multiclass-email10"
22 | },
23 | "$schema": "https://raw.githubusercontent.com/hinthornw/promptimizer/refs/heads/main/config-schema.json"
24 | }
--------------------------------------------------------------------------------
/experiments/multiclass_email10/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "multiclass_email10",
3 | "dataset": "multiclass_email10",
4 | "description": "Classify emails into 10 sentiment categories",
5 | "evaluators": "multiclass_email10.evaluators:accuracy_evaluator",
6 | "evaluator_descriptions": {
7 | "accuracy": "Evaluates if the predicted category matches the reference class"
8 | },
9 | "optimizer": {
10 | "model": {
11 | "model": "claude-3-5-sonnet-20241022"
12 | }
13 | },
14 | "initial_prompt": {
15 | "identifier": "emily-sentiment/multiclass-email10:4f6fd82a",
16 | "upload_to": "langchain-ai/multiclass-email10"
17 | },
18 | "$schema": "https://raw.githubusercontent.com/hinthornw/promptimizer/refs/heads/main/config-schema.json"
19 | }
20 |
--------------------------------------------------------------------------------
/experiments/multiclass_email10/evaluators.py:
--------------------------------------------------------------------------------
1 | from langchain_core.messages import AIMessage
2 | from langsmith.schemas import Run, Example
3 |
4 |
5 | def accuracy_evaluator(run: Run, example: Example) -> dict:
6 | """Evaluator to check if the predicted emotion class matches the reference."""
7 | reference_outputs = example.outputs
8 | predicted: AIMessage = run.outputs["output"]
9 | result = str(predicted.content)
10 | score = 1 if result == reference_outputs["ten_class"] else 0
11 | return {
12 | "key": "accuracy",
13 | "score": score,
14 | "comment": "Pass: emotion class is correct"
15 | if score == 1
16 | else "Fail: emotion class is not correct",
17 | }
18 |
19 |
20 | evaluators = [accuracy_evaluator]
21 |
--------------------------------------------------------------------------------
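The evaluator compares the model's raw message content, as an exact string, against the reference `ten_class` label. A minimal sketch (label value invented; stand-ins for Run/Example):

from types import SimpleNamespace
from langchain_core.messages import AIMessage

# assumes accuracy_evaluator from the module above is importable
example = SimpleNamespace(outputs={"ten_class": "gratitude"})
run = SimpleNamespace(outputs={"output": AIMessage(content="gratitude")})

print(accuracy_evaluator(run, example)["score"])  # -> 1 (exact match)
--------------------------------------------------------------------------------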
/experiments/multiclass_email10/task.py:
--------------------------------------------------------------------------------
1 | """Task definition for email classification (10-class)."""
2 |
3 | from typing import List, Callable
4 | import json
5 | from pathlib import Path
6 |
7 | from langsmith.schemas import Run, Example
8 | from krishpromptim.prompt_types import Task, Dataset
9 | from .evaluators import evaluators as ten_class_evaluators
10 |
11 |
12 | class EmailClassificationTask10(Task):
13 | """Task for classifying emails into ten categories."""
14 |
15 | def __init__(self, custom_evaluators: List[Callable[[Run, Example], dict]] = None):
16 | """Initialize the task with optional custom evaluators.
17 |
18 | Args:
19 | custom_evaluators: Optional list of custom evaluator functions. If provided,
20 | these will replace the default evaluator.
21 | """
22 | evaluators = (
23 | custom_evaluators if custom_evaluators is not None else ten_class_evaluators
24 | )
25 |
26 | # Load config from json
27 | config_path = Path(__file__).parent / "config.json"
28 | with open(config_path) as f:
29 | config = json.load(f)
30 |
31 | super().__init__(
32 | name=config["name"],
33 | description=config["description"],
34 | dataset=Dataset(**config["dataset"]),
35 | evaluators=evaluators,
36 | evaluator_descriptions=config["evaluator_descriptions"],
37 | initial_prompt=config["initial_prompt"],
38 | optimizer=config["optimizer"],
39 | )
40 |
41 |
42 | # Export the task class
43 | __all__ = ["EmailClassificationTask10"]
44 |
--------------------------------------------------------------------------------
/experiments/multiclass_email3/__init__.py:
--------------------------------------------------------------------------------
1 | from .task import EmailClassificationTask3
2 |
3 | __all__ = ["EmailClassificationTask3"]
4 |
--------------------------------------------------------------------------------
/experiments/multiclass_email3/backup.config.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "multiclass_email3",
3 | "dataset": {
4 | "name": "multiclass_email3",
5 | "description": "A dataset of emails with their sentiment classifications",
6 | "url": "https://smith.langchain.com/public/846e278a-9d0d-4d19-95c5-09803d722c36/d"
7 | },
8 | "description": "Classify emails into 3 sentiment categories",
9 | "evaluators": "multiclass_email3.evaluators:accuracy_evaluator",
10 | "evaluator_descriptions": {
11 | "accuracy": "Evaluates if the predicted category matches the reference class"
12 | },
13 | "optimizer": {
14 | "model": {
15 | "model": "claude-3-5-sonnet-20241022",
16 | "max_tokens_to_sample": 8192
17 | }
18 | },
19 | "initial_prompt": {
20 | "identifier": "emily-sentiment/multiclass-email3:e9d9beb2",
21 | "upload_to": "langchain-ai/multiclass-email3"
22 | },
23 | "$schema": "https://raw.githubusercontent.com/hinthornw/promptimizer/refs/heads/main/config-schema.json"
24 | }
--------------------------------------------------------------------------------
/experiments/multiclass_email3/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "multiclass_email3",
3 | "dataset": "multiclass_email3",
4 | "description": "Classify emails into 3 sentiment categories",
5 | "evaluators": "multiclass_email3.evaluators:accuracy_evaluator",
6 | "evaluator_descriptions": {
7 | "accuracy": "Evaluates if the predicted category matches the reference class"
8 | },
9 | "optimizer": {
10 | "model": {
11 | "model": "claude-3-5-sonnet-20241022",
12 | "max_tokens_to_sample": 8192
13 | }
14 | },
15 | "initial_prompt": {
16 | "identifier": "emily-sentiment/multiclass-email3:e9d9beb2",
17 | "upload_to": "langchain-ai/multiclass-email3"
18 | },
19 | "$schema": "https://raw.githubusercontent.com/hinthornw/promptimizer/refs/heads/main/config-schema.json"
20 | }
--------------------------------------------------------------------------------
/experiments/multiclass_email3/evaluators.py:
--------------------------------------------------------------------------------
1 | from langchain_core.messages import AIMessage
2 | from langsmith.schemas import Run, Example
3 |
4 |
5 | def accuracy_evaluator(run: Run, example: Example) -> dict:
6 | """Evaluator to check if the predicted emotion class matches the reference."""
7 | reference_outputs = example.outputs
8 | predicted: AIMessage = run.outputs["output"]
9 | result = str(predicted.content)
10 | score = 1 if result == reference_outputs["three_class"] else 0
11 | return {
12 | "key": "accuracy",
13 | "score": score,
14 | "comment": "Pass: emotion class is correct"
15 | if score == 1
16 | else "Fail: emotion class is not correct",
17 | }
18 |
19 |
20 | evaluators = [accuracy_evaluator]
21 |
--------------------------------------------------------------------------------
/experiments/multiclass_email3/task.py:
--------------------------------------------------------------------------------
1 | """Task definition for email classification (3-class)."""
2 |
3 | from typing import List, Callable
4 | import json
5 | from pathlib import Path
6 |
7 | from langsmith.schemas import Run, Example
8 | from krishpromptim.prompt_types import Task, Dataset
9 | from .evaluators import evaluators as three_class_evaluators
10 |
11 |
12 | class EmailClassificationTask3(Task):
13 | """Task for classifying emails into three categories."""
14 |
15 | def __init__(self, custom_evaluators: List[Callable[[Run, Example], dict]] = None):
16 | """Initialize the task with optional custom evaluators.
17 |
18 | Args:
19 | custom_evaluators: Optional list of custom evaluator functions. If provided,
20 | these will replace the default evaluator.
21 | """
22 | evaluators = (
23 | custom_evaluators
24 | if custom_evaluators is not None
25 | else three_class_evaluators
26 | )
27 |
28 | # Load config from json
29 | config_path = Path(__file__).parent / "config.json"
30 | with open(config_path) as f:
31 | config = json.load(f)
32 |
33 | super().__init__(
34 | name=config["name"],
35 | description=config["description"],
36 | dataset=Dataset(**config["dataset"]),
37 | evaluators=evaluators,
38 | evaluator_descriptions=config["evaluator_descriptions"],
39 | initial_prompt=config["initial_prompt"],
40 | optimizer=config["optimizer"],
41 | )
42 |
43 |
44 | # Export the task class
45 | __all__ = ["EmailClassificationTask3"]
46 |
--------------------------------------------------------------------------------
/experiments/multiclass_health10/__init__.py:
--------------------------------------------------------------------------------
1 | from .task import HealthClassificationTask10
2 |
3 | __all__ = ["HealthClassificationTask10"]
4 |
--------------------------------------------------------------------------------
/experiments/multiclass_health10/backup.config.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "multiclass-health10",
3 | "dataset": {
4 | "name": "multiclass-health10",
5 | "description": "A dataset of health-related texts with detailed classifications",
6 | "url": "https://smith.langchain.com/public/0a17ee5a-5f75-490f-835c-1d0ac70f957c/d"
7 | },
8 | "description": "Classify health-related text into 10 categories",
9 | "evaluators": "multiclass_health10.evaluators:accuracy_evaluator",
10 | "evaluator_descriptions": {
11 | "accuracy": "Evaluates if the predicted category matches the reference class"
12 | },
13 | "optimizer": {
14 | "model": {
15 | "model": "claude-3-5-sonnet-20241022",
16 | "max_tokens_to_sample": 8192
17 | }
18 | },
19 | "initial_prompt": {
20 | "identifier": "emily-sentiment/multiclass-health10:72da772b",
21 | "upload_to": "langchain-ai/multiclass-health10"
22 | },
23 | "$schema": "https://raw.githubusercontent.com/hinthornw/promptimizer/refs/heads/main/config-schema.json"
24 | }
--------------------------------------------------------------------------------
/experiments/multiclass_health10/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "multiclass-health10",
3 | "dataset": "multiclass-health10",
4 | "description": "Classify health-related text into 10 categories",
5 | "evaluators": "multiclass_health10.evaluators:accuracy_evaluator",
6 | "evaluator_descriptions": {
7 | "accuracy": "Evaluates if the predicted category matches the reference class"
8 | },
9 | "optimizer": {
10 | "model": {
11 | "model": "claude-3-5-sonnet-20241022"
12 | }
13 | },
14 | "initial_prompt": {
15 | "identifier": "emily-sentiment/multiclass-health10:72da772b",
16 | "upload_to": "langchain-ai/multiclass-health10"
17 | },
18 | "$schema": "https://raw.githubusercontent.com/hinthornw/promptimizer/refs/heads/main/config-schema.json"
19 | }
--------------------------------------------------------------------------------
/experiments/multiclass_health10/evaluators.py:
--------------------------------------------------------------------------------
1 | from langchain_core.messages import AIMessage
2 | from langsmith.schemas import Run, Example
3 |
4 |
5 | def accuracy_evaluator(run: Run, example: Example) -> dict:
6 | """Evaluator to check if the predicted emotion class matches the reference."""
7 | reference_outputs = example.outputs
8 | predicted: AIMessage = run.outputs["output"]
9 | result = str(predicted.content)
10 | score = 1 if result == reference_outputs["ten_class"] else 0
11 | return {
12 | "key": "accuracy",
13 | "score": score,
14 | "comment": "Pass: health disease is correct"
15 | if score == 1
16 | else "Fail: health disease is not correct",
17 | }
18 |
19 |
20 | evaluators = [accuracy_evaluator]
21 |
--------------------------------------------------------------------------------
/experiments/multiclass_health10/task.py:
--------------------------------------------------------------------------------
1 | """Task definition for health classification (10-class)."""
2 |
3 | from typing import List, Callable
4 | import json
5 | from pathlib import Path
6 |
7 | from langsmith.schemas import Run, Example
8 | from krishpromptim.prompt_types import Task, Dataset
9 | from .evaluators import evaluators as health_ten_class_evaluators
10 |
11 |
12 | class HealthClassificationTask10(Task):
13 | """Task for classifying health conditions into ten categories."""
14 |
15 | def __init__(self, custom_evaluators: List[Callable[[Run, Example], dict]] = None):
16 | """Initialize the task with optional custom evaluators.
17 |
18 | Args:
19 | custom_evaluators: Optional list of custom evaluator functions. If provided,
20 | these will replace the default evaluator.
21 | """
22 | evaluators = (
23 | custom_evaluators
24 | if custom_evaluators is not None
25 | else health_ten_class_evaluators
26 | )
27 |
28 | # Load config from json
29 | config_path = Path(__file__).parent / "config.json"
30 | with open(config_path) as f:
31 | config = json.load(f)
32 |
33 | super().__init__(
34 | name=config["name"],
35 | description=config["description"],
36 | dataset=Dataset(**config["dataset"]),
37 | evaluators=evaluators,
38 | evaluator_descriptions=config["evaluator_descriptions"],
39 | initial_prompt=config["initial_prompt"],
40 | optimizer=config["optimizer"],
41 | )
42 |
43 |
44 | # Export the task class
45 | __all__ = ["HealthClassificationTask10"]
46 |
--------------------------------------------------------------------------------
/experiments/multiclass_health3/__init__.py:
--------------------------------------------------------------------------------
1 | from .task import HealthClassificationTask3
2 |
3 | __all__ = ["HealthClassificationTask3"]
4 |
--------------------------------------------------------------------------------
/experiments/multiclass_health3/backup.config.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "multiclass-health3",
3 | "dataset": {
4 | "name": "health_classification_dataset",
5 | "description": "A dataset of health-related texts with sentiment classifications",
6 | "url": "https://smith.langchain.com/public/1cff4d33-c096-4460-9910-d453fd323820/d"
7 | },
8 | "description": "Classify health-related text into 3 categories",
9 | "evaluators": "multiclass_health3.evaluators:accuracy_evaluator",
10 | "evaluator_descriptions": {
11 | "accuracy": "Evaluates if the predicted category matches the reference class"
12 | },
13 | "optimizer": {
14 | "model": {
15 | "model": "claude-3-5-sonnet-20241022",
16 | "max_tokens_to_sample": 8192
17 | }
18 | },
19 | "initial_prompt": {
20 | "identifier": "emily-sentiment/multiclass-health3:6b523633",
21 | "upload_to": "langchain-ai/multiclass-health3"
22 | },
23 | "$schema": "https://raw.githubusercontent.com/hinthornw/promptimizer/refs/heads/main/config-schema.json"
24 | }
--------------------------------------------------------------------------------
/experiments/multiclass_health3/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "multiclass-health3",
3 | "dataset": "health_classification_dataset",
4 | "description": "Classify health-related text into 3 categories",
5 | "evaluators": "multiclass_health3.evaluators:accuracy_evaluator",
6 | "evaluator_descriptions": {
7 | "accuracy": "Evaluates if the predicted category matches the reference class"
8 | },
9 | "optimizer": {
10 | "model": {
11 | "model": "claude-3-5-sonnet-20241022"
12 | }
13 | },
14 | "initial_prompt": {
15 | "identifier": "emily-sentiment/multiclass-health3:6b523633",
16 | "upload_to": "langchain-ai/multiclass-health3"
17 | },
18 | "$schema": "https://raw.githubusercontent.com/hinthornw/promptimizer/refs/heads/main/config-schema.json"
19 | }
20 |
--------------------------------------------------------------------------------
/experiments/multiclass_health3/evaluators.py:
--------------------------------------------------------------------------------
1 | from langchain_core.messages import AIMessage
2 | from langsmith.schemas import Run, Example
3 |
4 |
5 | def accuracy_evaluator(run: Run, example: Example) -> dict:
6 | """Evaluator to check if the predicted emotion class matches the reference."""
7 | reference_outputs = example.outputs
8 | predicted: AIMessage = run.outputs["output"]
9 | result = str(predicted.content)
10 | score = 1 if result == reference_outputs["three_class"] else 0
11 | return {
12 | "key": "accuracy",
13 | "score": score,
14 | "comment": "Pass: health disease is correct"
15 | if score == 1
16 | else "Fail: health disease is not correct",
17 | }
18 |
19 |
20 | evaluators = [accuracy_evaluator]
21 |
--------------------------------------------------------------------------------
/experiments/multiclass_health3/task.py:
--------------------------------------------------------------------------------
1 | """Task definition for health classification (3-class)."""
2 |
3 | from typing import List, Callable
4 | import json
5 | from pathlib import Path
6 |
7 | from langsmith.schemas import Run, Example
8 | from krishpromptim.prompt_types import Task, Dataset
9 | from .evaluators import evaluators as health_three_class_evaluators
10 |
11 |
12 | class HealthClassificationTask3(Task):
13 | """Task for classifying health conditions into three categories."""
14 |
15 | def __init__(self, custom_evaluators: List[Callable[[Run, Example], dict]] = None):
16 | """Initialize the task with optional custom evaluators.
17 |
18 | Args:
19 | custom_evaluators: Optional list of custom evaluator functions. If provided,
20 | these will replace the default evaluator.
21 | """
22 | evaluators = (
23 | custom_evaluators
24 | if custom_evaluators is not None
25 | else health_three_class_evaluators
26 | )
27 |
28 | # Load config from json
29 | config_path = Path(__file__).parent / "config.json"
30 | with open(config_path) as f:
31 | config = json.load(f)
32 |
33 | super().__init__(
34 | name=config["name"],
35 | description=config["description"],
36 | dataset=Dataset(**config["dataset"]),
37 | evaluators=evaluators,
38 | evaluator_descriptions=config["evaluator_descriptions"],
39 | initial_prompt=config["initial_prompt"],
40 | optimizer=config["optimizer"],
41 | )
42 |
43 |
44 | # Export the task class
45 | __all__ = ["HealthClassificationTask3"]
46 |
--------------------------------------------------------------------------------
/experiments/sweeps.jsonl:
--------------------------------------------------------------------------------
1 | {"optimizer": {"model": {"model": "claude-3-5-sonnet-20241022","max_tokens_to_sample": 8192} }}
2 | {"optimizer": {"model": {"model": "openai:gpt-4o"} }}
3 | {"optimizer": {"model": {"model": "openai:o1-preview"} }}
--------------------------------------------------------------------------------
/experiments/tool_sweeps.jsonl:
--------------------------------------------------------------------------------
1 | {"optimizer": {"kind": "metaprompt", "max_reasoning_steps": 5, "model": {"model": "claude-3-5-sonnet-20241022","max_tokens_to_sample": 8192} }}
2 | // {"algorithm": {"kind": "mipro"}, "optimizer": {"model": {"model": "claude-3-5-sonnet-20241022","max_tokens_to_sample": 8192} }}
3 | {"optimizer": {"kind": "metaprompt", "max_reasoning_steps": 1, "model": {"model": "claude-3-5-sonnet-20241022","max_tokens_to_sample": 8192} }}
4 | {"optimizer": {"kind": "metaprompt", "max_reasoning_steps": 1, "model": {"model": "openai:gpt-4o"} }}
5 | {"optimizer": {"kind": "metaprompt", "max_reasoning_steps": 5, "model": {"model": "openai:gpt-4o"} }}
6 | {"optimizer": {"kind": "feedback_guided", "model": {"model": "claude-3-5-sonnet-20241022","max_tokens_to_sample": 8192} }}
7 | {"optimizer": {"kind": "feedback_guided", "model": {"model": "openai:gpt-4o"} }}
8 | {"optimizer": {"kind": "metaprompt", "max_reasoning_steps": 1, "model": {"model": "openai:o1"} }}
9 | {"optimizer": {"kind": "feedback_guided", "model": {"model": "openai:o1"} }}
10 | {"optimizer": {"kind": "metaprompt", "max_reasoning_steps": 5, "model": {"model": "openai:o1"} }}
11 | {"optimizer": {"model": {"model": "claude-3-5-sonnet-20241022","max_tokens_to_sample": 8192} }, "algorithm": {"kind": "phaseevo"}}
12 | {"optimizer": {"model": {"model": "openai:gpt-4o"} }, "algorithm": {"kind": "phaseevo"}}
13 | {"optimizer": {"model": {"model": "openai:o1"} }, "algorithm": {"kind": "phaseevo"}}
--------------------------------------------------------------------------------
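Each line in `sweeps.jsonl` / `tool_sweeps.jsonl` is a partial config meant to override a task's `config.json` for one sweep run (the `//`-prefixed line above is a commented-out entry). How the sweep tooling actually applies these lines is determined by `all_sweeps.sh`, which is not shown here; the following is only an illustrative sketch of the overlay idea, with the base config path chosen arbitrarily:

import json

def deep_merge(base: dict, override: dict) -> dict:
    """Recursively overlay `override` on top of `base` without mutating either dict."""
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged

with open("experiments/tooluse_finance/config.json") as f:
    base = json.load(f)
with open("experiments/tool_sweeps.jsonl") as f:
    for line in f:
        if line.strip() and not line.lstrip().startswith("//"):
            print(deep_merge(base, json.loads(line)).get("optimizer"))
--------------------------------------------------------------------------------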
/experiments/tooluse_ecommerce/__init__.py:
--------------------------------------------------------------------------------
1 | from .task import EcommerceToolUseTask
2 |
3 | __all__ = ["EcommerceToolUseTask"]
4 |
--------------------------------------------------------------------------------
/experiments/tooluse_ecommerce/backup.config.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "tooluse-ecommerce",
3 | "dataset": {
4 | "name": "ecommerce_tooluse_dataset",
5 | "description": "A dataset of e-commerce scenarios requiring tool use",
6 | "url": "https://smith.langchain.com/public/37d23fe7-3ec3-4fbe-a85b-0697eb7a2d7e/d"
7 | },
8 | "description": "Use tools for e-commerce related tasks",
9 | "evaluators": "tooluse_ecommerce.evaluators:tool_use_evaluator",
10 | "evaluator_descriptions": {
11 | "accuracy": "Evaluates if the tool usage matches the expected behavior"
12 | },
13 | "optimizer": {
14 | "model": {
15 | "model": "claude-3-5-sonnet-20241022",
16 | "max_tokens_to_sample": 8192
17 | }
18 | },
19 | "initial_prompt": {
20 | "identifier": "emily-sentiment/tooluse-ecommerce:a34c2697",
21 | "upload_to": "langchain-ai/tooluse-ecommerce"
22 | },
23 | "$schema": "https://raw.githubusercontent.com/hinthornw/promptimizer/refs/heads/main/config-schema.json"
24 | }
--------------------------------------------------------------------------------
/experiments/tooluse_ecommerce/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "tooluse-ecommerce",
3 | "dataset": "ecommerce_tooluse_dataset",
4 | "description": "Use tools for e-commerce related tasks",
5 | "evaluators": "tooluse_ecommerce.evaluators:tool_use_evaluator",
6 | "evaluator_descriptions": {
7 | "accuracy": "Evaluates if the tool usage matches the expected behavior"
8 | },
9 | "optimizer": {
10 | "model": {
11 | "model": "claude-3-5-sonnet-20241022",
12 | "max_tokens_to_sample": 8192
13 | }
14 | },
15 | "initial_prompt": {
16 | "identifier": "emily-sentiment/tooluse-ecommerce:a34c2697",
17 | "upload_to": "langchain-ai/tooluse-ecommerce"
18 | },
19 | "$schema": "https://raw.githubusercontent.com/hinthornw/promptimizer/refs/heads/main/config-schema.json"
20 | }
--------------------------------------------------------------------------------
/experiments/tooluse_ecommerce/evaluators.py:
--------------------------------------------------------------------------------
1 | """Evaluators to optimize task: tool-use-ecommerce.
2 |
3 | Evaluators compute scores for prompts run over the configured dataset."""
4 |
5 | from langsmith.schemas import Run, Example
6 | import json
7 |
8 |
9 | def extract_tool_call(run_outputs):
10 | """Helper function to extract tool call from nested run output structure."""
11 | try:
12 | if isinstance(run_outputs, dict) and "tool_calls" in run_outputs:
13 | tool_call = run_outputs["tool_calls"][0]
14 | if isinstance(tool_call, dict):
15 | return {"name": tool_call["name"], "args": tool_call["args"]}
16 |
17 | if isinstance(run_outputs, dict) and "output" in run_outputs:
18 | output = run_outputs["output"]
19 |
20 | if hasattr(output, "additional_kwargs"):
21 | additional_kwargs = output.additional_kwargs
22 | if (
23 | isinstance(additional_kwargs, dict)
24 | and "tool_calls" in additional_kwargs
25 | ):
26 | tool_call = additional_kwargs["tool_calls"][0]
27 | if "function" in tool_call:
28 | return {
29 | "name": tool_call["function"]["name"],
30 | "args": json.loads(tool_call["function"]["arguments"])
31 | if isinstance(tool_call["function"]["arguments"], str)
32 | else tool_call["function"]["arguments"],
33 | }
34 |
35 | if isinstance(output, dict):
36 | if "tool_calls" in output:
37 | tool_call = output["tool_calls"][0]
38 | if isinstance(tool_call, dict):
39 | return {"name": tool_call["name"], "args": tool_call["args"]}
40 |
41 | if "additional_kwargs" in output:
42 | additional_kwargs = output["additional_kwargs"]
43 | if "tool_calls" in additional_kwargs:
44 | tool_call = additional_kwargs["tool_calls"][0]
45 | if "function" in tool_call:
46 | return {
47 | "name": tool_call["function"]["name"],
48 | "args": json.loads(tool_call["function"]["arguments"])
49 | if isinstance(tool_call["function"]["arguments"], str)
50 | else tool_call["function"]["arguments"],
51 | }
52 |
53 | except (KeyError, IndexError, json.JSONDecodeError, AttributeError):
54 | return None
55 |
56 | return None
57 |
58 |
59 | def tool_use_evaluator(run: Run, example: Example) -> dict:
60 | """Evaluator for matching the correct tool and its inputs."""
61 | try:
62 | reference_outputs = example.outputs or {}
63 | correct_tool_str = reference_outputs.get("correct_tool", "{}")
64 | if isinstance(correct_tool_str, str):
65 | correct_tool = json.loads(correct_tool_str)
66 | else:
67 | correct_tool = correct_tool_str
68 |
69 | predicted_tool = extract_tool_call(run.outputs)
70 | if not predicted_tool:
71 | return {
72 | "key": "tool_use_accuracy",
73 | "score": 0,
74 | "comment": f"No valid tool calls found in run outputs: {str(run.outputs)[:200]}...",
75 | }
76 |
77 | tool_name_match = predicted_tool.get("name") == correct_tool.get("name")
78 |
79 | correct_inputs = set(correct_tool.get("inputs", {}).items())
80 | predicted_inputs = set(predicted_tool.get("args", {}).items())
81 |
82 | if correct_inputs:
83 | inputs_match_count = len(correct_inputs.intersection(predicted_inputs))
84 | tool_inputs_match = inputs_match_count / len(correct_inputs)
85 | else:
86 | tool_inputs_match = 1 if not predicted_inputs else 0
87 |
88 | score = (tool_name_match + tool_inputs_match) / 2
89 |
90 | if score == 1:
91 | comment = "Pass: Correct tool and inputs matched."
92 | elif score > 0:
93 | comment = (
94 | f"Partial match (score: {score:.2f}). Expected tool '{correct_tool.get('name')}' "
95 | f"with inputs {correct_tool.get('inputs')}, but got tool '{predicted_tool.get('name')}' "
96 | f"with inputs {predicted_tool.get('args')}."
97 | )
98 | else:
99 | comment = (
100 | f"Fail: Expected tool '{correct_tool.get('name')}' with inputs {correct_tool.get('inputs')}, "
101 | f"but got tool '{predicted_tool.get('name')}' with inputs {predicted_tool.get('args')}."
102 | )
103 |
104 | except Exception as e:
105 | score = 0
106 | comment = (
107 | f"Error in evaluation: {str(e)}. Run outputs: {str(run.outputs)[:200]}... "
108 | f"Reference correct tool: {str(correct_tool)}"
109 | )
110 |
111 | result = {
112 | "key": "tool_use_accuracy",
113 | "score": score,
114 | "comment": comment,
115 | }
116 |
117 | return result
118 |
119 |
120 | evaluators = [tool_use_evaluator]
121 |
--------------------------------------------------------------------------------
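`extract_tool_call` normalizes two payload shapes: a flat `tool_calls` list, and an OpenAI-style entry nested under `output` / `additional_kwargs` with JSON-encoded `arguments`. A small sketch showing both collapsing to the same dict (tool name and arguments are invented; assumes `extract_tool_call` from the module above is importable):

direct = {"tool_calls": [{"name": "search_products", "args": {"query": "laptop"}}]}
openai_style = {
    "output": {
        "additional_kwargs": {
            "tool_calls": [
                {"function": {"name": "search_products", "arguments": '{"query": "laptop"}'}}
            ]
        }
    }
}

assert extract_tool_call(direct) == extract_tool_call(openai_style)
print(extract_tool_call(direct))  # {'name': 'search_products', 'args': {'query': 'laptop'}}
--------------------------------------------------------------------------------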
/experiments/tooluse_ecommerce/task.py:
--------------------------------------------------------------------------------
1 | """Task definition for e-commerce tool use."""
2 |
3 | from typing import List, Callable
4 | import json
5 | from pathlib import Path
6 |
7 | from langsmith.schemas import Run, Example
8 | from krishpromptim.prompt_types import Task, Dataset
9 | from .evaluators import evaluators as ecommerce_tooluse_evaluators
10 |
11 |
12 | class EcommerceToolUseTask(Task):
13 | """Task for evaluating tool use in e-commerce scenarios."""
14 |
15 | def __init__(self, custom_evaluators: List[Callable[[Run, Example], dict]] = None):
16 | """Initialize the task with optional custom evaluators.
17 |
18 | Args:
19 | custom_evaluators: Optional list of custom evaluator functions. If provided,
20 | these will replace the default evaluator.
21 | """
22 | evaluators = (
23 | custom_evaluators
24 | if custom_evaluators is not None
25 | else ecommerce_tooluse_evaluators
26 | )
27 |
28 | # Load config from json
29 | config_path = Path(__file__).parent / "config.json"
30 | with open(config_path) as f:
31 | config = json.load(f)
32 |
33 | super().__init__(
34 | name=config["name"],
35 | description=config["description"],
36 | dataset=Dataset(**config["dataset"]),
37 | evaluators=evaluators,
38 | evaluator_descriptions=config["evaluator_descriptions"],
39 | initial_prompt=config["initial_prompt"],
40 | optimizer=config["optimizer"],
41 | )
42 |
43 |
44 | # Export the task class
45 | __all__ = ["EcommerceToolUseTask"]
46 |
--------------------------------------------------------------------------------
/experiments/tooluse_finance/__init__.py:
--------------------------------------------------------------------------------
1 | from .task import FinanceToolUseTask
2 |
3 | __all__ = ["FinanceToolUseTask"]
4 |
--------------------------------------------------------------------------------
/experiments/tooluse_finance/backup.config.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "tooluse-finance",
3 | "dataset": {
4 | "name": "finance_tooluse_dataset",
5 | "description": "A dataset of finance scenarios requiring tool use",
6 | "url": "https://smith.langchain.com/public/4e9220bb-d4f7-43ae-819b-4580811470fb/d"
7 | },
8 | "description": "Use tools for finance related tasks",
9 | "evaluators": "tooluse_finance.evaluators:tool_use_evaluator",
10 | "evaluator_descriptions": {
11 | "accuracy": "Evaluates if the tool usage matches the expected behavior"
12 | },
13 | "optimizer": {
14 | "model": {
15 | "model": "claude-3-5-sonnet-20241022",
16 | "max_tokens_to_sample": 8192
17 | }
18 | },
19 | "initial_prompt": {
20 | "identifier": "emily-sentiment/tooluse-finance:b9e0638f",
21 | "upload_to": "langchain-ai/tooluse-finance"
22 | },
23 | "$schema": "https://raw.githubusercontent.com/hinthornw/promptimizer/refs/heads/main/config-schema.json"
24 | }
--------------------------------------------------------------------------------
/experiments/tooluse_finance/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "tooluse-finance",
3 | "dataset": "finance_tooluse_dataset",
4 | "description": "Use tools for finance related tasks",
5 | "evaluators": "tooluse_finance.evaluators:tool_use_evaluator",
6 | "evaluator_descriptions": {
7 | "accuracy": "Evaluates if the tool usage matches the expected behavior"
8 | },
9 | "optimizer": {
10 | "model": {
11 | "model": "claude-3-5-sonnet-20241022",
12 | "max_tokens_to_sample": 8192
13 | }
14 | },
15 | "initial_prompt": {
16 | "identifier": "emily-sentiment/tooluse-finance:b9e0638f",
17 | "upload_to": "langchain-ai/tooluse-finance"
18 | },
19 | "$schema": "https://raw.githubusercontent.com/hinthornw/promptimizer/refs/heads/main/config-schema.json"
20 | }
--------------------------------------------------------------------------------
/experiments/tooluse_finance/evaluators.py:
--------------------------------------------------------------------------------
1 | """Evaluators to optimize task: tooluse-finance.
2 |
3 | Evaluators compute scores for prompts run over the configured dataset:
4 | https://smith.langchain.com/o/4b3539a7-f6b9-4950-a199-a27fd5dcbf2f/datasets/10de0a62-603f-429c-b222-e5fb624f7ca6
5 | """
6 |
7 | from langsmith.schemas import Run, Example
8 | import json
9 |
10 |
11 | def extract_tool_call(run_outputs):
12 | """Helper function to extract tool call from nested run output structure."""
13 | try:
14 | if isinstance(run_outputs, dict) and "tool_calls" in run_outputs:
15 | tool_call = run_outputs["tool_calls"][0]
16 | if isinstance(tool_call, dict):
17 | return {"name": tool_call["name"], "args": tool_call["args"]}
18 |
19 | if isinstance(run_outputs, dict) and "output" in run_outputs:
20 | output = run_outputs["output"]
21 |
22 | if hasattr(output, "additional_kwargs"):
23 | additional_kwargs = output.additional_kwargs
24 | if (
25 | isinstance(additional_kwargs, dict)
26 | and "tool_calls" in additional_kwargs
27 | ):
28 | tool_call = additional_kwargs["tool_calls"][0]
29 | if "function" in tool_call:
30 | return {
31 | "name": tool_call["function"]["name"],
32 | "args": json.loads(tool_call["function"]["arguments"])
33 | if isinstance(tool_call["function"]["arguments"], str)
34 | else tool_call["function"]["arguments"],
35 | }
36 |
37 | if isinstance(output, dict):
38 | if "tool_calls" in output:
39 | tool_call = output["tool_calls"][0]
40 | if isinstance(tool_call, dict):
41 | return {"name": tool_call["name"], "args": tool_call["args"]}
42 |
43 | if "additional_kwargs" in output:
44 | additional_kwargs = output["additional_kwargs"]
45 | if "tool_calls" in additional_kwargs:
46 | tool_call = additional_kwargs["tool_calls"][0]
47 | if "function" in tool_call:
48 | return {
49 | "name": tool_call["function"]["name"],
50 | "args": json.loads(tool_call["function"]["arguments"])
51 | if isinstance(tool_call["function"]["arguments"], str)
52 | else tool_call["function"]["arguments"],
53 | }
54 |
55 | except (KeyError, IndexError, json.JSONDecodeError, AttributeError):
56 | return None
57 |
58 | return None
59 |
60 |
61 | def tool_use_evaluator(run: Run, example: Example) -> dict:
62 | """Evaluator for matching the correct tool and its inputs."""
63 | try:
64 | reference_outputs = example.outputs or {}
65 | correct_tool_str = reference_outputs.get("correct_tool", "{}")
66 | if isinstance(correct_tool_str, str):
67 | correct_tool = json.loads(correct_tool_str)
68 | else:
69 | correct_tool = correct_tool_str
70 |
71 | predicted_tool = extract_tool_call(run.outputs)
72 | if not predicted_tool:
73 | return {
74 | "key": "tool_use_accuracy",
75 | "score": 0,
76 | "comment": f"No valid tool calls found in run outputs: {str(run.outputs)[:200]}...",
77 | }
78 |
79 | tool_name_match = predicted_tool.get("name") == correct_tool.get("name")
80 |
81 | correct_inputs = set(correct_tool.get("inputs", {}).items())
82 | predicted_inputs = set(predicted_tool.get("args", {}).items())
83 |
84 | if correct_inputs:
85 | inputs_match_count = len(correct_inputs.intersection(predicted_inputs))
86 | tool_inputs_match = inputs_match_count / len(correct_inputs)
87 | else:
88 | tool_inputs_match = 1 if not predicted_inputs else 0
89 |
90 | score = (tool_name_match + tool_inputs_match) / 2
91 |
92 | if score == 1:
93 | comment = "Pass: Correct tool and inputs matched."
94 | elif score > 0:
95 | comment = (
96 | f"Partial match (score: {score:.2f}). Expected tool '{correct_tool.get('name')}' "
97 | f"with inputs {correct_tool.get('inputs')}, but got tool '{predicted_tool.get('name')}' "
98 | f"with inputs {predicted_tool.get('args')}."
99 | )
100 | else:
101 | comment = (
102 | f"Fail: Expected tool '{correct_tool.get('name')}' with inputs {correct_tool.get('inputs')}, "
103 | f"but got tool '{predicted_tool.get('name')}' with inputs {predicted_tool.get('args')}."
104 | )
105 |
106 | except Exception as e:
107 | score = 0
108 | comment = (
109 | f"Error in evaluation: {str(e)}. Run outputs: {str(run.outputs)[:200]}... "
110 | f"Reference correct tool: {str(correct_tool)}"
111 | )
112 |
113 | result = {
114 | "key": "tool_use_accuracy",
115 | "score": score,
116 | "comment": comment,
117 | }
118 |
119 | return result
120 |
121 |
122 | evaluators = [tool_use_evaluator]
123 |
--------------------------------------------------------------------------------
/experiments/tooluse_finance/task.py:
--------------------------------------------------------------------------------
1 | """Task definition for finance tool use."""
2 |
3 | from typing import List, Callable
4 | import json
5 | from pathlib import Path
6 |
7 | from langsmith.schemas import Run, Example
8 | from krishpromptim.prompt_types import Task, Dataset
9 | from .evaluators import evaluators as finance_tooluse_evaluators
10 |
11 |
12 | class FinanceToolUseTask(Task):
13 | """Task for evaluating tool use in finance scenarios."""
14 |
15 | def __init__(self, custom_evaluators: List[Callable[[Run, Example], dict]] = None):
16 | """Initialize the task with optional custom evaluators.
17 |
18 | Args:
19 | custom_evaluators: Optional list of custom evaluator functions. If provided,
20 | these will replace the default evaluator.
21 | """
22 | evaluators = (
23 | custom_evaluators
24 | if custom_evaluators is not None
25 | else finance_tooluse_evaluators
26 | )
27 |
28 | # Load config from json
29 | config_path = Path(__file__).parent / "config.json"
30 | with open(config_path) as f:
31 | config = json.load(f)
32 |
33 | super().__init__(
34 | name=config["name"],
35 | description=config["description"],
36 | dataset=Dataset(**config["dataset"]),
37 | evaluators=evaluators,
38 | evaluator_descriptions=config["evaluator_descriptions"],
39 | initial_prompt=config["initial_prompt"],
40 | optimizer=config["optimizer"],
41 | )
42 |
43 |
44 | # Export the task class
45 | __all__ = ["FinanceToolUseTask"]
46 |
--------------------------------------------------------------------------------
/generate_schema.py:
--------------------------------------------------------------------------------
1 | import json
2 | from dataclasses import fields
3 | from typing import Any, Optional, Type
4 |
5 | from pydantic import BaseModel, Field, create_model
6 | from pydantic.json_schema import model_json_schema
7 |
8 | from promptim.config import Config
9 |
10 |
11 | def get_schema(cls: Type[Any]) -> dict:
12 | """Create a JSON schema dict from a dataclass or Pydantic model.
13 |
14 | Args:
15 | cls: A dataclass or Pydantic model type.
16 |
17 | Returns:
18 | A dict representing the JSON schema of the input class.
19 |
20 | Raises:
21 | TypeError: If the input is not a dataclass or Pydantic model.
22 | """
23 | if isinstance(cls, type) and issubclass(cls, BaseModel):
24 | return model_json_schema(cls)
25 | elif hasattr(cls, "__dataclass_fields__"):
26 | # Convert dataclass to Pydantic model
27 | fields_dict = {}
28 | for field in fields(cls):
29 | field_info = {}
30 | if field.default is not field.default_factory:
31 | # Field has a default value or default factory
32 | field_info["default"] = field.default
33 | if field.metadata.get("description"):
34 | field_info["description"] = field.metadata["description"]
35 |
36 | if field_info:
37 | fields_dict[field.name] = (Optional[field.type], Field(**field_info))
38 | else:
39 | # Field is required
40 | fields_dict[field.name] = (field.type, ...)
41 | pydantic_model = create_model(cls.__name__, **fields_dict)
42 | return model_json_schema(pydantic_model)
43 | else:
44 | raise TypeError("Input must be a dataclass or Pydantic model")
45 |
46 |
47 | config_schema = get_schema(Config)
48 | config_schema["$schema"] = "http://json-schema.org/draft-07/schema#"
49 |
50 | with open("config-schema.json", "w") as f:
51 | json.dump(config_schema, f, indent=2)
52 |
53 | print("Schema generated and saved to config-schema.json")
54 |
--------------------------------------------------------------------------------
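`get_schema` accepts either a Pydantic model (passed straight to `model_json_schema`) or a dataclass (converted to a Pydantic model first). A quick illustrative sketch with an invented model:

from pydantic import BaseModel

# assumes get_schema from the script above is importable
class DemoConfig(BaseModel):
    name: str
    epochs: int = 5

schema = get_schema(DemoConfig)
print(schema["title"])                                # DemoConfig
print(schema["properties"]["epochs"].get("default"))  # 5
--------------------------------------------------------------------------------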
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "promptim"
3 | version = "0.0.9"
4 | description = "A framework for optimizing prompts through multi-task evaluation and iterative improvement"
5 | authors = [
6 | { name = "William Fu-Hinthorn", email = "13333726+hinthornw@users.noreply.github.com" }
7 | ]
8 | readme = "README.md"
9 | requires-python = ">=3.11"
10 | dependencies = [
11 | "click",
12 | "langchain",
13 | "langchain-anthropic>=0.3",
14 | "langsmith>=0.2.11",
15 | "langchain-openai>=0.3",
16 | "pandas>=0.2.4",
17 | "rich",
18 | "python-dotenv>=1.0.1",
19 | "trustcall>=0.0.28",
20 | ]
21 |
22 | [project.scripts]
23 | promptim = "promptim.__main__:cli"
24 |
25 | [tool.setuptools]
26 | packages = ["promptim"]
27 |
28 | [build-system]
29 | requires = ["hatchling"]
30 | build-backend = "hatchling.build"
31 |
32 | [dependency-groups]
33 | dev = [
34 | "pytest>=8.3.4",
35 | "trustcall>=0.0.26",
36 | "vcrpy>=6.0.2",
37 | ]
38 |
--------------------------------------------------------------------------------
/src/promptim/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hinthornw/promptimizer/2052e812fc52cff19d8539b30b1782385f011c1e/src/promptim/__init__.py
--------------------------------------------------------------------------------
/src/promptim/_utils.py:
--------------------------------------------------------------------------------
1 | """Utility functions."""
2 |
3 | from difflib import SequenceMatcher
4 | from rich import print as richprint
5 | from rich.panel import Panel
6 |
7 | import re
8 | import uuid
9 | import langsmith as ls
10 | from collections import deque
11 |
12 |
13 | def _colorize_diff(diff):
14 | for op, i1, i2, j1, j2 in diff.get_opcodes():
15 | if op == "equal":
16 | yield diff.a[i1:i2]
17 | elif op == "insert":
18 | yield f"[green]{diff.b[j1:j2]}[/green]"
19 | elif op == "delete":
20 | yield f"[red]{diff.a[i1:i2]}[/red]"
21 | elif op == "replace":
22 | yield f"[red]{diff.a[i1:i2]}[/red][green]{diff.b[j1:j2]}[/green]"
23 |
24 |
25 | def print_rich_diff(original: str, updated: str, title: str = "") -> None:
26 | diff = SequenceMatcher(None, original, updated)
27 | colorized_diff = "".join(_colorize_diff(diff))
28 | panel = Panel(
29 | colorized_diff, title=title or "Prompt Diff", expand=False, border_style="bold"
30 | )
31 | richprint(panel)
32 |
33 |
34 | def get_var_healer(vars: set[str], all_required: bool = False):
35 | var_to_uuid = {f"{{{v}}}": uuid.uuid4().hex for v in vars}
36 | uuid_to_var = {v: k for k, v in var_to_uuid.items()}
37 |
38 | def escape(input_string: str) -> str:
39 | result = re.sub(r"(?|", re.MULTILINE | re.DOTALL
51 | )
52 |
53 | def assert_all_required(input_string: str) -> str:
54 | if not all_required:
55 | return input_string
56 |
57 | missing = [var for var in vars if f"{{{var}}}" not in input_string]
58 | if missing:
59 | raise ValueError(f"Missing required variable: {', '.join(missing)}")
60 |
61 | return input_string
62 |
63 | def mask(input_string: str) -> str:
64 | return mask_pattern.sub(lambda m: var_to_uuid[m.group(0)], input_string)
65 |
66 | def unmask(input_string: str) -> str:
67 | return unmask_pattern.sub(lambda m: uuid_to_var[m.group(0)], input_string)
68 |
69 | def pipe(input_string: str) -> str:
70 | return unmask(
71 | strip_to_optimize_pattern.sub(
72 | "", escape(mask(assert_all_required(input_string)))
73 | )
74 | )
75 |
76 | return pipe
77 |
78 |
79 | def get_token_usage() -> int | None:
80 | rt = ls.get_current_run_tree()
81 | if not rt:
82 | return
83 | runs = deque([rt])
84 | kept = []
85 | while runs:
86 | run = runs.popleft()
87 | if run.run_type == "llm":
88 | kept.append(run)
89 | runs.extend(run.child_runs)
90 | all_toks = []
91 | for r in kept:
92 | usage = ((r.outputs or {}).get("llm_output") or {}).get("usage")
93 | if not usage:
94 | continue
95 | input_tokens = usage.get("input_tokens")
96 | output_tokens = usage.get("output_tokens")
97 | if input_tokens is None or output_tokens is None:
98 | continue
99 | all_toks.append(output_tokens + input_tokens)
100 |
101 | if all_toks:
102 | return sum(all_toks)
103 |
--------------------------------------------------------------------------------
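`get_var_healer` returns a pipeline that escapes stray `{`/`}` braces (so they survive template formatting) while protecting the declared prompt variables, and strips any `<TO_OPTIMIZE...>` markers. A minimal sketch (prompt text invented; the exact escaped output depends on the escape regexes above):

# assumes get_var_healer from the module above is importable
heal = get_var_healer({"question"})

prompt = 'Answer {question} and reply with JSON like {"ok": true}'
print(heal(prompt))
# -> Answer {question} and reply with JSON like {{"ok": true}}
--------------------------------------------------------------------------------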
/src/promptim/algorithms/__init__.py:
--------------------------------------------------------------------------------
1 | from promptim.algorithms.base import BaseAlgorithm
2 | from promptim.algorithms.minibatch import MinibatchAlgorithm
3 | from promptim.algorithms.phaseevo import PhaseEvoAlgorithm
4 | from promptim.algorithms.mipro import MIPROAlgorithm
5 | from langchain_core.language_models import BaseChatModel
6 |
7 | _MAP = {
8 | "minibatch": MinibatchAlgorithm,
9 | "phaseevo": PhaseEvoAlgorithm,
10 | "mipro": MIPROAlgorithm,
11 | }
12 |
13 |
14 | def load_algorithm(config: dict, optimizer_model: BaseChatModel) -> BaseAlgorithm:
15 | """Load an algorithm from its config dictionary."""
16 | config = config.copy()
17 | kind = config.pop("kind", "minibatch")
18 | if kind not in _MAP:
19 | raise ValueError(
20 | f"Unknown algorithm kind: {kind}. Available kinds: {list(_MAP.keys())}"
21 | )
22 |
23 | return _MAP[kind](config, optimizer_model)
24 |
25 |
26 | __all__ = [
27 | "MinibatchAlgorithm",
28 | "PhaseEvoAlgorithm",
29 | "MIPROAlgorithm",
30 | "load_algorithm",
31 | ]
32 |
--------------------------------------------------------------------------------
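`load_algorithm` pops the `kind` key and forwards the rest of the dict to the selected algorithm as its config. A small sketch using a fake chat model as a placeholder for the optimizer model (not how the trainer wires it up in practice):

from langchain_core.language_models.fake_chat_models import FakeListChatModel

# assumes load_algorithm from the module above is importable
algo = load_algorithm({"kind": "minibatch", "epochs": 3}, FakeListChatModel(responses=["ok"]))
print(type(algo).__name__, algo.config.epochs)  # MinibatchAlgorithm 3
--------------------------------------------------------------------------------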
/src/promptim/algorithms/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import List, Union, Optional, TypeVar, Generic
3 | from dataclasses import dataclass
4 |
5 | from promptim import types as pm_types
6 | from promptim.trainer import PromptTrainer
7 | from langchain_core.language_models import BaseChatModel
8 |
9 |
10 | @dataclass
11 | class AlgorithmConfig:
12 | """Base configuration for training algorithms."""
13 |
14 | train_size: Optional[int] = None
15 | batch_size: int = 40
16 | epochs: int = 5
17 | debug: bool = False
18 | max_score: float = 1.0
19 |
20 |
21 | C = TypeVar("C", bound=AlgorithmConfig)
22 |
23 |
24 | class BaseAlgorithm(Generic[C], ABC):
25 | """
26 | Abstract base that defines the macro-level training loop
27 | or search procedure (epochs, phases, etc.).
28 | """
29 |
30 | config_cls = AlgorithmConfig
31 |
32 | def __init__(
33 | self, config: Optional[Union[dict, AlgorithmConfig]], model: BaseChatModel
34 | ):
35 | self.config = self._resolve_config(config or {})
36 | self.model = model
37 |
38 | def _resolve_config(self, config: Union[dict, AlgorithmConfig]) -> C:
39 | if isinstance(config, dict):
40 | return self.config_cls(**config) # type: ignore
41 | return config # type: ignore
42 |
43 | @abstractmethod
44 | async def run(
45 | self,
46 | trainer: PromptTrainer,
47 | task: pm_types.Task,
48 | initial_population: Union[pm_types.PromptWrapper, List[pm_types.PromptWrapper]],
49 | train_examples: list[pm_types.Example],
50 | dev_examples: list[pm_types.Example],
51 | *,
52 | system_config: Optional[dict] = None,
53 | annotation_queue: Optional[str] = None,
54 | commit_prompts: bool = False,
55 | experiment_name: str = "Prompt Optimization",
56 | baseline_scores: Optional[dict] = None,
57 | baseline_experiment_results: Optional[list] = None,
58 | ) -> tuple[pm_types.PromptWrapper, float, dict]:
59 | """
60 | Execute the training/evolution procedure using the trainer's capabilities.
61 |
 62 |         Args:
 63 |             trainer: The PromptTrainer instance providing evaluation and utilities
 64 |             task: The task to optimize for
 65 |             initial_population: Single prompt or list of prompts to start with
 66 |             train_examples: Examples used to drive the optimization loop
 67 |             dev_examples: Held-out examples used to score candidate prompts
 68 |             system_config: Optional system-level configuration
 69 |             annotation_queue: Optional queue for manual review
 70 |             commit_prompts: Whether to commit prompts to LangSmith
 71 |             experiment_name: Optional name for the experiment
 72 |             baseline_scores: Optional precomputed baseline scores per metric
 73 |             baseline_experiment_results: Optional precomputed baseline evaluation rows
 74 | 
 75 |         Returns:
 76 |             Tuple of (best prompt, best score)
 77 |         """
 78 |         pass
 79 | 
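A sketch of how a new algorithm could plug into this contract. The RandomSearch names and the stub body are hypothetical, not part of the library:

from dataclasses import dataclass

from promptim.algorithms.base import AlgorithmConfig, BaseAlgorithm


@dataclass
class RandomSearchConfig(AlgorithmConfig):
    n_candidates: int = 8  # extra hyperparameter on top of the shared fields


class RandomSearchAlgorithm(BaseAlgorithm[RandomSearchConfig]):
    config_cls = RandomSearchConfig

    async def run(self, trainer, task, initial_population, train_examples, dev_examples, **kwargs):
        # A real implementation would evaluate candidates via the trainer and keep the best;
        # this stub just returns the starting prompt with a sentinel score.
        prompts = initial_population if isinstance(initial_population, list) else [initial_population]
        return prompts[0], float("-inf")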
--------------------------------------------------------------------------------
/src/promptim/algorithms/minibatch.py:
--------------------------------------------------------------------------------
1 | from typing import List, Union, Optional
2 | from rich.panel import Panel
3 | from rich.progress import Progress
4 | from rich.table import Table
5 |
6 | from promptim import types as pm_types
7 | from promptim.trainer import PromptTrainer
8 |
9 | from promptim.algorithms.base import BaseAlgorithm, AlgorithmConfig
10 | from promptim import _utils as pm_utils
11 | import langsmith as ls
12 |
13 |
14 | class MinibatchAlgorithm(BaseAlgorithm[AlgorithmConfig]):
15 | """
16 | Classic epoch-based training that processes data in minibatches.
17 | This preserves the original optimize_prompt behavior.
18 | """
19 |
20 | @ls.traceable(name="MinibatchAlgorithm.run")
21 | async def run(
22 | self,
23 | trainer: PromptTrainer,
24 | task: pm_types.Task,
25 | initial_population: Union[pm_types.PromptWrapper, List[pm_types.PromptWrapper]],
26 | train_examples: list[pm_types.Example],
27 | dev_examples: list[pm_types.Example],
28 | *,
29 | system_config: Optional[dict] = None,
30 | annotation_queue: Optional[str] = None,
31 | commit_prompts: bool = False,
32 | experiment_name: str = "Prompt Optimization",
33 | baseline_scores: Optional[dict] = None,
34 | baseline_experiment_results: Optional[list] = None,
35 | ) -> tuple[pm_types.PromptWrapper, float]:
36 | """Implementation of the original optimize_prompt flow."""
37 | if isinstance(initial_population, pm_types.PromptWrapper):
38 | initial_population = [initial_population]
39 | history = [initial_population]
40 | best_score = float("-inf")
41 | best_prompt = initial_population[0]
42 | with Progress() as progress:
43 | main_task = progress.add_task(
44 | "[cyan]Optimizing prompt...", total=self.config.epochs + 2
45 | )
46 | progress.update(
47 | main_task, advance=10, description="[cyan]Getting baseline scores..."
48 | )
49 | baseline_scores = baseline_scores or {}
50 | best_score = (
51 | sum(baseline_scores.values()) / len(baseline_scores)
52 | if baseline_scores
53 | else float("-inf")
54 | )
55 | table = Table(
56 | title="Baseline Scores (Dev Set)",
57 | show_header=True,
58 | header_style="bold magenta",
59 | )
60 | table.add_column("Metric", style="cyan", no_wrap=True)
61 | table.add_column("Score", justify="right", style="green")
62 |
63 | for metric, score in baseline_scores.items():
64 | table.add_row(metric, f"{score:.4f}")
65 |
66 | table.add_row("Average", f"{best_score:.4f}", style="bold")
67 |
68 | progress.console.print(
69 | Panel(
70 | table,
71 | title="[bold]Initial Prompt Evaluation[/bold]",
72 | border_style="cyan",
73 | )
74 | )
75 | progress.console.print("\n[bold cyan]Beginning optimization.[/bold cyan]")
76 | progress.console.print()
77 |
78 | # Step 2: Train
79 | progress.update(
80 | main_task,
81 | advance=1,
82 | description="[cyan]Optimizing prompt on epoch 1...",
83 | )
84 | training_session_fut = trainer._enqueue_experiment(
85 | experiment_name=experiment_name,
86 | examples=train_examples,
87 | split="train",
88 | epoch=0,
89 | )
90 | for epoch in range(self.config.epochs):
91 | training_session = await training_session_fut
92 |
93 | trainer.optimizer.on_epoch_start(epoch, task)
94 | trainer.rng.shuffle(train_examples)
95 | if self.config.train_size:
96 | train_examples = train_examples[: self.config.train_size]
97 |
98 | batches = [
99 | train_examples[i : i + self.config.batch_size]
100 | for i in range(0, len(train_examples), self.config.batch_size)
101 | ]
102 |
103 | batch_task = progress.add_task(
104 | f"[yellow]Epoch {epoch+1} batches", total=len(batches)
105 | )
106 | all_train_scores = []
107 | avg_score = float("-inf")
108 | training_session_fut = trainer._enqueue_experiment(
109 | experiment_name=experiment_name,
110 | examples=train_examples,
111 | split="train",
112 | epoch=epoch + 1,
113 | )
114 | for bix, batch in enumerate(batches):
115 | results = None
116 | if bix == 0 and epoch == 0 and baseline_experiment_results:
117 | bindices = {e.id for e in batch}
118 | results = [
119 | r
120 | for r in baseline_experiment_results
121 | if r["example"].id in bindices
122 | ]
123 | if len(results) != len(batch):
124 | results = None
125 | if results is None:
126 | results = await trainer._evaluate_prompt(
127 | history[-1][-1],
128 | task,
129 | batch,
130 | debug=self.config.debug,
131 | experiment=training_session,
132 | system_config=system_config,
133 | )
134 | next_action = "continue"
135 |
136 | if annotation_queue:
137 | results, next_action = await trainer._wait_for_annotation_queue(
138 | results,
139 | annotation_queue,
140 | task,
141 | progress,
142 | )
143 | if next_action != "continue":
144 | break
145 | train_scores = await trainer.calculate_scores(results)
146 | train_score = (
147 | sum(train_scores.values()) / len(train_scores)
148 | if train_scores
149 | else None
150 | )
151 | all_train_scores.append(train_score)
152 | avg_score = sum(all_train_scores) / len(all_train_scores)
153 | progress.update(
154 | batch_task,
155 | description=f"[yellow]Epoch {epoch+1} (Avg training score: {avg_score:.4f})",
156 | )
157 | # Get improved population
158 | try:
159 | improved = await trainer.optimizer.improve_prompt(
160 | history=history,
161 | results=results,
162 | task=task,
163 | trainer=trainer,
164 | )
165 | history[-1].extend(improved)
166 |
167 | if commit_prompts:
168 | for prompt in improved:
169 | prompt.push_prompt(client=trainer.client)
170 | except Exception as e:
171 | progress.console.print(
172 | f"Failed to improve prompt: {e}", style="red"
173 | )
174 | break
175 |
176 | # Evaluate on dev set
177 | progress.update(main_task, description="[cyan]Evaluating on dev set...")
178 | dev_results = await trainer._evaluate_prompt(
179 | history[-1][-1],
180 | task,
181 | dev_examples,
182 | debug=self.config.debug,
183 | system_config=system_config,
184 | )
185 | dev_scores = await trainer.calculate_scores(dev_results)
186 | dev_score = (
187 | sum(dev_scores.values()) / len(dev_scores) if dev_scores else None
188 | )
189 | progress.update(
190 | batch_task,
191 | description=f'[yellow]Epoch {epoch+1} (Dev: {f"{dev_score:.4f}" if dev_score is not None else "-"}, Train: {f"{avg_score:.4f}" if avg_score is not None else "-"})',
192 | )
193 |
194 | if dev_score is not None and dev_score > best_score:
195 | best_score = dev_score
196 | best_prompt = history[-1][-1]
197 | progress.console.print(
198 | f"New best score: {best_score:.4f} (surpassed previous best)"
199 | )
200 | progress.console.print("Average of:")
201 | for metric, score in dev_scores.items():
202 | progress.console.print(f" {metric}: {score:.4f}")
203 | else:
204 | progress.console.print(
205 | f"Score {dev_score:.4f} did not surpass best score {best_score:.4f}"
206 | )
207 |
208 | trainer.log_metric(
209 | "score",
210 | value=best_score,
211 | x=epoch,
212 | x_label="epoch",
213 | split="dev",
214 | prompt=best_prompt,
215 | )
216 |
217 | tokens_used = pm_utils.get_token_usage()
218 | if tokens_used is not None:
219 | trainer.log_metric(
220 | "score",
221 | value=best_score,
222 | x=tokens_used,
223 | x_label="total tokens",
224 | split="dev",
225 | prompt=best_prompt,
226 | )
227 | history.append([best_prompt])
228 |
229 | return best_prompt, best_score
230 |
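A standalone illustration of the batch splitting used in the loop above; the toy integers stand in for pm_types.Example objects:

train_examples = list(range(10))  # stand-ins for pm_types.Example objects
batch_size = 4
batches = [
    train_examples[i : i + batch_size]
    for i in range(0, len(train_examples), batch_size)
]
assert batches == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]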
--------------------------------------------------------------------------------
/src/promptim/algorithms/phaseevo/__init__.py:
--------------------------------------------------------------------------------
1 | from promptim.algorithms.phaseevo.algo import PhaseEvoAlgorithm
2 |
3 | __all__ = ["PhaseEvoAlgorithm"]
4 |
--------------------------------------------------------------------------------
/src/promptim/algorithms/tpe_sampler.py:
--------------------------------------------------------------------------------
  1 | # Adapted from Optuna. All credit goes to the authors of that library.
2 | # https://optuna.readthedocs.io/en/stable/_modules/optuna/samplers/_tpe/sampler.html#TPESampler
3 | import math
4 | import random
5 | from typing import List, Dict, Tuple, Callable, Awaitable, Union, Any
6 |
7 | _MIN = -999999999.0
8 |
9 |
10 | class TPESampler:
 11 |     """Tree-structured Parzen estimator; based on Optuna's implementation but without the extra power.
12 |
13 | For each parameter, we store (value, objective) for each completed trial.
14 | We then:
15 | 1) Sort by objective (assume 'lower is better' by default).
16 | 2) Split into 'good' set (best fraction) vs 'bad' set (rest).
17 | 3) Model each set as a mixture of Gaussians (one Gaussian per data point).
18 | 4) Generate multiple candidate points from the mixture of 'good' set,
19 | evaluating ratio l(x)/g(x), and choose the x that maximizes it.
20 | """
21 |
22 | def __init__(self, seed: int = 42):
23 | self.rng = random.Random(seed)
24 | # Data structure to store param -> list of (value, objective)
25 | self.observations: Dict[str, List[Tuple[Union[float, int, str], float]]] = {}
26 | # You can store advanced settings here if desired (bandwidth, etc.)
27 |
28 | def register(
29 | self, param_name: str, value: Union[float, int, str], objective: float
30 | ):
31 | """
32 | Add one completed trial's param value and objective outcome.
33 | """
34 | if param_name not in self.observations:
35 | self.observations[param_name] = []
36 | self.observations[param_name].append((value, objective))
37 |
38 | def suggest_categorical(
39 | self,
40 | param_name: str,
41 | choices: List[Any],
42 | n_candidates: int = 24,
43 | gamma: float = 0.2,
44 | lower_is_better: bool = True,
45 | ) -> Any:
46 | """Return a suggested categorical value for the given param."""
47 | history = self.observations.get(param_name, [])
48 | if len(history) < 2:
49 | return self.rng.choice(choices)
50 |
51 | sorted_history = sorted(
52 | history, key=lambda x: x[1], reverse=(not lower_is_better)
53 | )
54 |
55 | n_good = max(1, int(math.ceil(len(sorted_history) * gamma)))
56 | good = sorted_history[:n_good]
57 | bad = sorted_history[n_good:]
58 |
59 | good_counts = {choice: 0.0 for choice in choices}
60 | bad_counts = {choice: 0.0 for choice in choices}
61 |
62 | pseudocount = 1.0
63 | for choice in choices:
64 | good_counts[choice] = pseudocount
65 | bad_counts[choice] = pseudocount
66 |
67 | for val, _ in good:
68 | good_counts[val] += 1.0
69 | for val, _ in bad:
70 | bad_counts[val] += 1.0
71 |
72 | good_total = sum(good_counts.values())
73 | bad_total = sum(bad_counts.values())
74 |
75 | for choice in choices:
76 | good_counts[choice] /= good_total
77 | bad_counts[choice] /= bad_total
78 |
79 | best_choice = None
80 | best_ratio = _MIN
81 |
82 | for _ in range(n_candidates):
83 | candidate = self.rng.choice(choices)
84 | ratio = math.log(good_counts[candidate]) - math.log(bad_counts[candidate])
85 |
86 | if ratio > best_ratio:
87 | best_ratio = ratio
88 | best_choice = candidate
89 |
90 | return best_choice if best_choice is not None else self.rng.choice(choices)
91 |
92 | def suggest(
93 | self,
94 | param_name: str,
95 | low: float,
96 | high: float,
97 | n_candidates: int = 24,
98 | gamma: float = 0.2,
99 | lower_is_better: bool = True,
100 | bandwidth: float = 0.1,
101 | ) -> float:
102 | """Return a suggested float value for the given param within [low, high].
103 |
104 | Args:
105 | n_candidates: Number of candidate samples from the 'good' mixture
106 | gamma: Fraction of trials to consider 'good' (0.2 => top 20%).
107 | lower_is_better: If True, smaller objective is better. If False, bigger is better.
108 | bandwidth: Kernel width (std dev) for each sample-based Gaussian in the mixture.
109 | """
110 | history = self.observations.get(param_name, [])
111 | if len(history) < 2:
112 | return self.rng.uniform(low, high)
113 |
114 | sorted_history = sorted(
115 | history, key=lambda x: x[1], reverse=(not lower_is_better)
116 | )
117 |
118 | n_good = max(1, int(math.ceil(len(sorted_history) * gamma)))
119 | good = sorted_history[:n_good]
120 | bad = sorted_history[n_good:]
121 |
122 | best_x = None
123 | best_obj = _MIN
124 |
125 | for _ in range(n_candidates):
126 | x_cand = self._sample_from_mixture(good, low, high, bandwidth)
127 | log_l_good = self._log_mixture_pdf(x_cand, good, bandwidth)
128 | log_l_bad = self._log_mixture_pdf(x_cand, bad, bandwidth)
129 | ratio = log_l_good - log_l_bad
130 |
131 | if ratio > best_obj:
132 | best_obj = ratio
133 | best_x = x_cand
134 |
135 | if best_x is None:
136 | return self.rng.uniform(low, high)
137 |
138 | return max(low, min(high, best_x))
139 |
140 | def suggest_int(
141 | self,
142 | param_name: str,
143 | low: int,
144 | high: int,
145 | n_candidates: int = 24,
146 | gamma: float = 0.2,
147 | lower_is_better: bool = True,
148 | ) -> int:
149 | """Return a suggested integer value for the given param within [low, high]."""
150 | float_val = self.suggest(
151 | param_name=param_name,
152 | low=float(low) - 0.4999,
153 | high=float(high) + 0.4999,
154 | n_candidates=n_candidates,
155 | gamma=gamma,
156 | lower_is_better=lower_is_better,
157 | )
158 | return int(round(float_val))
159 |
160 | async def optimize(
161 | self, objective_fn: Callable[[Any], Awaitable[float]], n_trials: int = 30
162 | ) -> "Trial":
163 | """Run optimization for n_trials, returning best trial."""
164 | best_score = float("-inf")
165 | best_trial = None
166 |
167 | for _ in range(n_trials):
168 | trial = Trial(self)
169 | score = await objective_fn(trial)
170 |
171 | if score > best_score:
172 | best_score = score
173 | best_trial = trial
174 |
175 | return best_trial
176 |
177 | def _sample_from_mixture(
178 | self,
179 | dataset: List[Tuple[float, float]],
180 | low: float,
181 | high: float,
182 | bandwidth: float,
183 | ) -> float:
184 | """
185 | Sample one x from the mixture of Gaussians, each centered on a
186 | data point from `dataset`.
187 | """
188 | if not dataset:
189 | return self.rng.uniform(low, high)
190 |
191 | idx = self.rng.randint(0, len(dataset) - 1)
192 | center = dataset[idx][0]
193 |
194 | min_distance = min(center - low, high - center)
195 | adj_bandwidth = min(bandwidth, min_distance / 3)
196 |
197 | return self.rng.gauss(center, adj_bandwidth)
198 |
199 | def _log_mixture_pdf(
200 | self, x: float, dataset: List[Tuple[float, float]], bandwidth: float
201 | ) -> float:
202 | """mixture is average of Normal(center=each data point, sigma=bandwidth)."""
203 | if not dataset:
204 | return _MIN
205 |
206 | log_vals = []
207 | for val, _ in dataset:
208 | log_vals.append(self._log_normal_pdf(x, val, bandwidth))
209 |
210 | max_log = max(log_vals)
211 | sum_exp = 0.0
212 | for log_val in log_vals:
213 | sum_exp += math.exp(log_val - max_log)
214 |
215 | return max_log + math.log(sum_exp) - math.log(len(log_vals))
216 |
217 | def _log_normal_pdf(self, x: float, mu: float, sigma: float) -> float:
218 | if sigma <= 0.0:
219 | return _MIN
220 |
221 | z = (x - mu) / sigma
222 | return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2 * math.pi)
223 |
224 |
225 | class Trial:
226 | def __init__(self, sampler: TPESampler):
227 | self.sampler = sampler
228 | self.params = {}
229 |
230 | def suggest_categorical(
231 | self,
232 | name: str,
233 | choices: List[Any],
234 | n_candidates: int = 24,
235 | gamma: float = 0.2,
236 | lower_is_better: bool = True,
237 | ) -> Any:
238 | value = self.sampler.suggest_categorical(
239 | name,
240 | choices,
241 | n_candidates=n_candidates,
242 | gamma=gamma,
243 | lower_is_better=lower_is_better,
244 | )
245 | self.params[name] = value
246 | return value
247 |
248 | def suggest_int(
249 | self,
250 | name: str,
251 | low: int,
252 | high: int,
253 | n_candidates: int = 24,
254 | gamma: float = 0.2,
255 | lower_is_better: bool = True,
256 | ) -> int:
257 | value = self.sampler.suggest_int(
258 | name,
259 | low,
260 | high,
261 | n_candidates=n_candidates,
262 | gamma=gamma,
263 | lower_is_better=lower_is_better,
264 | )
265 | self.params[name] = value
266 | return value
267 |
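A minimal end-to-end sketch of the sampler. The toy objective and the parameter name are made up for illustration: each trial suggests a value, the caller registers the observed objective, and optimize returns the best trial.

import asyncio

from promptim.algorithms.tpe_sampler import TPESampler

sampler = TPESampler(seed=0)


async def objective(trial) -> float:
    # We maximize, so lower_is_better=False; the best value is x == 7.
    x = trial.suggest_int("x", 0, 10, lower_is_better=False)
    score = -((x - 7) ** 2)
    sampler.register("x", x, score)  # feed the outcome back so later trials improve
    return score


best_trial = asyncio.run(sampler.optimize(objective, n_trials=20))
print(best_trial.params)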
--------------------------------------------------------------------------------
/src/promptim/config.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import Optional, Union
3 |
4 |
5 | # Import optimizer configs
6 | from promptim.optimizers.metaprompt import Config as MetaPromptConfig
7 | from promptim.optimizers.fewshot import Config as FewShotConfig
8 | from promptim.optimizers.feedback_guided import Config as FeedbackGuidedConfig
9 | from promptim.types import TaskLike
10 |
11 |
12 | OptimizerConfig = Union[MetaPromptConfig, FewShotConfig, FeedbackGuidedConfig]
13 |
14 |
15 | @dataclass(kw_only=True)
16 | class Config(TaskLike):
17 | optimizer: OptimizerConfig | None = field(
18 | default=None,
19 | metadata={
20 | "description": "Optimization configuration specifying model settings and hyperparameters. If None, default configuration will be used."
21 | },
22 | )
23 | evaluators: str = field(
24 | metadata={
25 | "description": (
26 | "Import path to evaluator functions in format 'file_path:variable_name'. The functions should evaluate prompt quality.\n\n"
27 | "Example:\n ./task/evaluators.py:evaluators"
28 | )
29 | }
30 | )
31 | system: Optional[str] = field(
32 | default=None,
33 | metadata={
34 | "description": (
35 | "Import path to system configuration in format 'file_path:variable_name'. Defines how prompts are executed.\n\n"
36 | "Example:\n ./task/my_system.py:chain"
37 | )
38 | },
39 | )
40 |
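A small sketch of how the optimizer field can be populated in code (the threshold value is an arbitrary example; any member of the OptimizerConfig union works):

from promptim.optimizers.feedback_guided import Config as FeedbackGuidedConfig
from promptim.optimizers.metaprompt import Config as MetaPromptConfig

opt_cfg = FeedbackGuidedConfig(score_threshold=0.9)
print(opt_cfg.kind)             # "feedback_guided"
print(MetaPromptConfig().kind)  # "metaprompt"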
--------------------------------------------------------------------------------
/src/promptim/optimizers/__init__.py:
--------------------------------------------------------------------------------
1 | from promptim.optimizers.fewshot import FewShotOptimizer
2 | from promptim.optimizers.metaprompt import MetaPromptOptimizer
3 | from promptim.optimizers.feedback_guided import FeedbackGuidedOptimizer
4 | from promptim.optimizers.base import BaseOptimizer
5 |
6 | # Use the config_cls.kind.default to get the map keys
7 |
8 | _MAP = {
9 | "metaprompt": MetaPromptOptimizer,
10 | "fewshot": FewShotOptimizer,
11 | "feedback_guided": FeedbackGuidedOptimizer,
12 | }
13 |
14 |
15 | def load_optimizer(config: dict) -> BaseOptimizer:
16 | """Load an optimizer from its config dictionary."""
17 | kind = config["kind"]
18 | if kind not in _MAP:
19 | raise ValueError(
20 | f"Unknown optimizer kind: {kind}. Available kinds: {list(_MAP.keys())}"
21 | )
22 |
23 | return _MAP[kind].from_config(config)
24 |
25 |
26 | __all__ = [
27 | "MetaPromptOptimizer",
28 | "FewShotOptimizer",
29 | "FeedbackGuidedOptimizer",
30 | "BaseOptimizer",
31 | "load_optimizer",
32 | ]
33 |
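A minimal usage sketch for load_optimizer. The model dict mirrors the default in optimizers/base.py and the threshold is an arbitrary example; running it requires the matching provider package and credentials:

from promptim.optimizers import load_optimizer

optimizer = load_optimizer(
    {
        "kind": "feedback_guided",
        "model": {"model": "claude-3-5-sonnet-20241022", "max_tokens_to_sample": 8192},
        "score_threshold": 0.9,
    }
)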
--------------------------------------------------------------------------------
/src/promptim/optimizers/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import List, Type, Sequence
3 | from langsmith.evaluation._arunner import ExperimentResultRow
4 | from promptim import types as pm_types
5 | from dataclasses import dataclass, field, is_dataclass, asdict
6 | from langchain_core.language_models import BaseChatModel
7 | from langchain.chat_models import init_chat_model
8 |
9 | MODEL_TYPE = str | BaseChatModel | dict
10 |
11 |
12 | @dataclass(kw_only=True)
13 | class Config:
14 | kind: str
15 | model: MODEL_TYPE = field(
16 | default_factory=lambda: {
17 | "model": "claude-3-5-sonnet-20241022",
18 | "max_tokens_to_sample": 8192,
19 | }
20 | )
21 |
22 |
23 | class BaseMutator(ABC):
24 | config_cls: Type[Config]
25 |
26 | def __init__(self, *, model: MODEL_TYPE):
27 | self.model = _resolve_model(model)
28 |
29 | @classmethod
30 | def from_config(cls, config: dict | Config):
31 | if is_dataclass(config):
32 | config = asdict(config)
33 | config_ = {k: v for k, v in config.items() if k != "kind"}
34 | return cls(**config_)
35 |
36 |
37 | class BaseOptimizer(BaseMutator):
38 | @abstractmethod
39 | async def improve_prompt(
40 | self,
41 | history: Sequence[Sequence[pm_types.PromptWrapper]],
42 | results: List[ExperimentResultRow],
43 | task: pm_types.Task,
44 | **kwargs,
45 | ) -> list[pm_types.PromptWrapper]:
46 | """Given the current generation of prompts and the latest evaluation results,
47 | propose a new and improved prompt variant."""
48 |
49 | def on_epoch_start(self, epoch: int, task: pm_types.Task):
50 | """Hook for any setup needed at the start of each epoch."""
51 |
52 |
53 | # Private utils
54 |
55 |
56 | def _resolve_model(model: MODEL_TYPE) -> BaseChatModel:
57 | if isinstance(model, dict):
58 | return init_chat_model(**model)
59 | elif isinstance(model, BaseChatModel):
60 | return model
61 | else:
62 | return init_chat_model(model=model)
63 |
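The three MODEL_TYPE forms accepted above, shown side by side as a sketch. The model names are placeholders, and actually running this needs the corresponding provider packages and API keys:

from langchain.chat_models import init_chat_model

from promptim.optimizers.base import _resolve_model

m1 = _resolve_model("claude-3-5-sonnet-20241022")                # model name string
m2 = _resolve_model({"model": "gpt-4o-mini", "temperature": 0})  # kwargs dict for init_chat_model
m3 = _resolve_model(init_chat_model(model="gpt-4o-mini"))        # already a BaseChatModel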
--------------------------------------------------------------------------------
/src/promptim/optimizers/feedback_guided.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Literal, Sequence, cast
2 | from langsmith.evaluation._arunner import ExperimentResultRow
3 | from dataclasses import dataclass, field
4 | from promptim import types as pm_types, _utils as pm_utils
5 | from promptim.optimizers import base as optimizers
6 | from pydantic import BaseModel, Field
7 | import langsmith as ls
8 | import random
9 | from promptim.optimizers.metaprompt import DEFAULT_METAPROMPT
10 | from trustcall import create_extractor
11 |
12 | _DEFAULT_RECOMMENDATION_PROMPT = """You are giving feedback on the performance of an AI model.
13 |
14 | Analyze the test case, along with the prompt and any evaluation scores. Based on those results,
15 | develop a theory of why the model failed. Perform a detailed analysis, commensurate to the complexity of the task.
16 | Then provide targeted recommendations for improvements.
17 |
18 | The current prompt is:
19 |
20 |
21 | {prompt}
22 |
23 | Another AI will optimize the above prompt based on your recommendations. Be as clear and specific as possible.
24 | """
25 |
26 |
27 | @dataclass(kw_only=True)
28 | class Config(optimizers.Config):
29 | kind: Literal["feedback_guided"] = field(
30 | default="feedback_guided",
31 | metadata={
32 | "description": "The feedback_guided optimizer that iteratively improves"
33 | " prompts based on feedback from evaluation results, focusing on examples that fall below a specified performance threshold."
34 | },
35 | )
36 | recommendation_prompt: str = field(
37 | default=_DEFAULT_RECOMMENDATION_PROMPT,
38 | )
39 | score_threshold: float = 0.8
40 | max_batch_size: Optional[int] = 20
41 |
42 |
43 | class Advise(BaseModel):
44 | """Think step-by-step, analyzing the task and test results. Provide a clear recommendation on why the prompt failed this
45 | test case, and what it should do to succeed next time for this type of input. Focus on the test metrics and expected output (if available).
46 | """
47 |
48 | analysis: str = Field(
49 | description="First, analyze why the prompt failed for this example. Think of what instructions in the prompt were poorly defined or missing."
50 | )
51 | recommended_changes: str = Field(
52 | description="Second, provide targeted recommendations for improvements."
53 | )
54 |
55 |
56 | class FeedbackGuidedOptimizer(optimizers.BaseOptimizer):
57 | """
58 | A two-phase optimization algorithm that:
59 | 1. Identifies examples with scores below a threshold
60 | 2. Generates targeted recommendations for improvements
61 | 3. Uses these recommendations to guide prompt refinement
62 |
63 | The algorithm is particularly effective when you want to focus
64 | optimization efforts on specific failure cases while maintaining
65 | overall prompt quality.
66 | """
67 |
68 | config_cls = Config
69 |
70 | def __init__(
71 | self,
72 | *,
73 | model: optimizers.MODEL_TYPE | None = None,
74 | score_threshold: float = 0.8,
75 | recommendation_prompt: Optional[str] = None,
76 | meta_prompt: Optional[str] = None,
77 | max_batch_size: Optional[int] = 20,
78 | ):
79 | super().__init__(model=model)
80 | self.score_threshold = score_threshold
81 | self.recommendation_prompt = (
82 | recommendation_prompt or _DEFAULT_RECOMMENDATION_PROMPT
83 | )
84 | self.meta_prompt = meta_prompt or DEFAULT_METAPROMPT
85 | self.max_batch_size = max_batch_size
86 |
87 | def _format_failing_examples(
88 | self, results: list[ExperimentResultRow]
 89 |     ) -> list[str]:
90 | """Identify and format examples that fall below the score threshold."""
91 | failing = []
92 | for r in results:
93 | # Consider "failing" if any evaluation score is below threshold
94 | if any(
95 | (
96 | eval_result.score is not None
97 | and eval_result.score < self.score_threshold
98 | )
99 | for eval_result in r["evaluation_results"]["results"]
100 | ):
101 | failing.append(self._format_example(r))
102 | return failing
103 |
104 | def _format_example(self, example: ExperimentResultRow) -> str:
105 |         """Format a single failing example into a string for analysis."""
106 | outputs = example["example"].outputs
107 |
108 | if outputs:
109 | ref_output = f"But we expected: {outputs}"
110 | else:
111 | ref_output = ""
112 | scores = []
113 | for eval_result in example["evaluation_results"]["results"]:
114 | scores.append(
115 | f"- {eval_result.key}: {eval_result.score}"
116 | f"{f' (Comment: {eval_result.comment})' if eval_result.comment else ''}"
117 | )
118 |
119 | scores = "\n".join(scores)
120 | if scores:
121 | scores = f"\n\nTest results:\n{scores}"
122 |
123 | return f"""Failing Example:
124 | For input: {example['example'].inputs}
125 | The prompt predicted: {example['run'].outputs}
126 | {ref_output}
127 | {scores}
128 | """
129 |
130 | async def improve_prompt(
131 | self,
132 | history: Sequence[Sequence[pm_types.PromptWrapper]],
133 | results: list[ExperimentResultRow],
134 | task: pm_types.Task,
135 | **kwargs,
136 | ) -> list[pm_types.PromptWrapper]:
137 | """Improve prompt using feedback from failing examples.
138 |
139 | 1. Select failing examples
140 | 2. If no failing examples, return current prompt
141 | 3. Batch advisor over failing examples
142 | 4. Format advisor responses into a string
143 | 5. Run metaprompt over formatted advice
144 | """
145 | current_prompt = history[-1][-1]
146 | other_attempts = [
147 | p for prompts in history for p in prompts if p is not current_prompt
148 | ]
149 | # 1. Identify failing examples
150 | failing_examples = self._format_failing_examples(results)
151 |
152 | # 2. If no failing examples, return current prompt unchanged
153 | if not failing_examples:
154 | return list(history[-1])
155 | if self.max_batch_size and len(failing_examples) > self.max_batch_size:
156 | random.shuffle(failing_examples)
157 | failing_examples = failing_examples[: self.max_batch_size]
158 | # 3. Generate targeted recommendations for each failing example
159 | advisor = create_extractor(self.model, tools=[Advise])
160 | advisor_inputs = [
161 | [
162 | (
163 | "system",
164 | self.recommendation_prompt.format(
165 | prompt=current_prompt.get_prompt_str_in_context()
166 | ),
167 | ),
168 | ("user", example),
169 | ]
170 | for example in failing_examples
171 | ]
172 | with ls.trace(
173 | name="Analyze examples", inputs={"num_examples": len(failing_examples)}
174 | ):
175 | results_ = await advisor.abatch(advisor_inputs)
176 | recommendations = cast(list[Advise], [r["responses"][0] for r in results_])
177 |
178 | # 4. Format recommendations into a consolidated string
179 | formatted_recommendations = []
180 | for i, (example, rec) in enumerate(zip(failing_examples, recommendations)):
181 |             formatted_recommendations.append(f"Recommended changes for example {i+1}:")
182 | formatted_recommendations.append(rec.recommended_changes)
183 | formatted_recommendations.append("-" * 40 + "\n")
184 |
185 | all_recommendations = "\n".join(formatted_recommendations)
186 |
187 | # 5. Use consolidated recommendations to guide final prompt improvement
188 | chain = create_extractor(
189 | self.model,
190 | tools=[pm_types.prompt_schema(current_prompt)],
191 | tool_choice="OptimizedPromptOutput",
192 | )
193 | inputs = {
194 | "current_hypo": "",
195 | "current_prompt": current_prompt.get_prompt_str_in_context(),
196 | "task_description": task.describe(),
197 | "annotated_results": all_recommendations,
198 | "other_attempts": (
199 | "\n\n---".join([p.get_prompt_str() for p in other_attempts])
200 | if other_attempts
201 | else "N/A"
202 | ),
203 | }
204 | with ls.trace("Apply Recommendations", inputs=inputs) as rt:
205 | prompt_output = await chain.ainvoke(self.meta_prompt.format(**inputs))
206 | prompt_output = cast(
207 | pm_types.OptimizedPromptOutput, prompt_output["responses"][0]
208 | )
209 | rt.add_outputs({"prompt_output": prompt_output})
210 |
211 | candidate = pm_types.PromptWrapper.from_prior(
212 | current_prompt, prompt_output.improved_prompt
213 | )
214 |
215 | pm_utils.print_rich_diff(
216 | current_prompt.get_prompt_str_in_context(),
217 | candidate.get_prompt_str_in_context(),
218 | "Updated Prompt with Targeted Improvements",
219 | )
220 | return [candidate]
221 |
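A toy illustration of the score_threshold rule used in _format_failing_examples (fabricated scores, purely for illustration): a row counts as failing if any non-None evaluation score falls below the threshold.

score_threshold = 0.8
rows = [
    {"id": "a", "scores": [1.0, 0.9]},    # all scores pass
    {"id": "b", "scores": [1.0, 0.6]},    # 0.6 < 0.8 -> failing
    {"id": "c", "scores": [None, 0.85]},  # None scores are ignored
]
failing = [
    row for row in rows
    if any(s is not None and s < score_threshold for s in row["scores"])
]
assert [row["id"] for row in failing] == ["b"]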
--------------------------------------------------------------------------------
/src/promptim/optimizers/fewshot.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | from dataclasses import dataclass, field
3 | from typing_extensions import Literal
4 | import random
5 |
6 | import langsmith as ls
7 | from promptim.optimizers import base as optimizers
8 | from promptim import types as pm_types
9 | from langsmith.evaluation._arunner import ExperimentResultRow
10 | from promptim import _utils as pm_utils
11 |
12 |
13 | @dataclass(kw_only=True)
14 | class Config(optimizers.Config):
15 | kind: Literal["fewshot"] = field(
16 | default="fewshot",
17 | metadata={
18 | "description": "The few-shot optimizer that uses TPE to select optimal example combinations"
19 | },
20 | )
21 | max_examples: int = field(
22 | default=50,
23 | metadata={"description": "Maximum number of few-shot examples in the pool"},
24 | )
25 | n_trials: int = field(
26 | default=5, metadata={"description": "Number of TPE optimization trials"}
27 | )
28 | minibatch_size: int = field(
29 | default=10,
30 | metadata={"description": "Number of few-shot examples per minibatch"},
31 | )
32 |
33 |
34 | class FewShotOptimizer(optimizers.BaseOptimizer):
35 | config_cls = Config
36 |
37 | def __init__(
38 | self,
39 | *,
40 | model: optimizers.MODEL_TYPE | None = None,
41 | max_examples: int = 50,
42 | minibatch_size: int = 10,
43 | n_trials: int = 5,
44 | ):
45 | super().__init__(model=model)
46 | self.max_examples = max_examples
47 | self.n_trials = n_trials
48 | self.minibatch_size = minibatch_size
49 | self._rng = random.Random(42) # Just for any extra randomization you might do
50 | from promptim.algorithms.tpe_sampler import TPESampler
51 |
52 | self.sampler = TPESampler(seed=42)
53 |
54 | @ls.traceable
55 | async def improve_prompt(
56 | self,
57 | history: List[List[pm_types.PromptWrapper]],
58 | results: List[ExperimentResultRow],
59 | task: pm_types.Task,
60 | trainer: "PromptTrainer" = None,
61 | **kwargs,
62 | ) -> List[pm_types.PromptWrapper]:
63 | """
64 | Improve the prompt by picking an optimal subset of few-shot examples
65 | that yields the highest average evaluation score.
66 | """
67 | try:
68 | url_ = ls.get_current_run_tree().get_url()
69 | print(f"See optimization run: {url_}")
70 | except Exception:
71 | pass
72 | from promptim.algorithms.tpe_sampler import Trial
73 |
74 | if not results:
75 | # No data to optimize with
76 | return list(history[-1])
77 |
78 | current_prompt = history[-1][-1]
79 | train_examples = [r["example"] for r in results]
80 | best_score = float("-inf")
81 | best_prompt = current_prompt
82 | n_examples = len(train_examples)
83 |
84 | async with ls.trace("FewShotOptimization") as run_tree:
85 |
86 | # This objective is called once per trial, with a new "Trial" each time
87 | @ls.traceable
88 | async def objective(trial: Trial) -> float:
89 | nonlocal best_prompt, best_score
90 |
91 | example_mask = []
92 | for i in range(n_examples):
93 | # We want to MAXIMIZE, so set lower_is_better=False
94 | include_flag = trial.suggest_categorical(
95 | f"include_example_{i}",
96 | choices=[0, 1],
97 | n_candidates=10,
98 | gamma=0.2,
99 | lower_is_better=False,
100 | )
101 | example_mask.append(bool(include_flag))
102 |
103 | while sum(example_mask) > self.max_examples:
104 | chosen_to_remove = self._rng.choice(
105 | [k for k, inc in enumerate(example_mask) if inc]
106 | )
107 | example_mask[chosen_to_remove] = False
108 |
109 | # Construct new prompt with selected few-shot examples
110 | selected_examples = [
111 | ex for ex, inc in zip(train_examples, example_mask) if inc
112 | ]
113 |
114 | if not selected_examples:
115 | score = float("-inf")
116 | else:
117 | shuffled = self._rng.sample(
118 | selected_examples, len(selected_examples)
119 | )
120 | candidate = self._create_prompt_with_examples(
121 | current_prompt, shuffled
122 | )
123 | other_examples = [
124 | ex for ex in train_examples if ex not in selected_examples
125 | ][: self.minibatch_size]
126 | if other_examples:
127 | results = await trainer._evaluate_prompt(
128 | candidate, task, other_examples, upload_results=False
129 | )
130 | score = self._calculate_score(results)
131 | # Keep track of best
132 | if score > best_score:
133 | best_score = score
134 | best_prompt = candidate
135 | # For logging
136 | rt = ls.get_current_run_tree()
137 | rt.metadata["best_score"] = score
138 | else:
139 | if best_score > float("-inf"):
140 | score = best_score
141 | else:
142 | best_prompt = candidate
143 | score = -99999.0
144 | best_score = score
145 |
146 | # Manually register each param w.r.t. outcome
147 | # We pass objective=score for each parameter we suggested
148 | for i, inc_val in enumerate(example_mask):
149 | self.sampler.register(
150 | param_name=f"include_example_{i}",
151 | value=int(inc_val),
152 | objective=score,
153 | )
154 |
155 | return score
156 |
157 | # Actually run TPE optimization for the given number of trials
158 | try:
159 | best_trial = await self.sampler.optimize(
160 | objective, n_trials=self.n_trials
161 | )
162 | run_tree.add_outputs(
163 | {
164 | "best_score": best_score,
165 | "n_trials": self.n_trials,
166 | "best_params": best_trial.params if best_trial else {},
167 | }
168 | )
169 | except Exception as e:
170 | print(f"TPE optimization failed: {e}")
171 | # If it fails, just fall back to last prompt
172 | return list(history[-1])
173 |
174 | # Print a side-by-side difference of the improved prompt
175 | pm_utils.print_rich_diff(
176 | current_prompt.get_prompt_str_in_context(),
177 | best_prompt.get_prompt_str_in_context(),
178 | "Updated Prompt with Optimized Few-Shot Examples",
179 | )
180 |
181 | return [best_prompt]
182 |
183 | def _calculate_score(self, results: List[ExperimentResultRow]) -> float:
184 | # We average over all "score" values that exist.
185 | # If none exist, returns -∞ to discourage such combos.
186 | scores = []
187 | for result in results:
188 | for eval_result in result["evaluation_results"]["results"]:
189 | if eval_result.score is not None:
190 | scores.append(eval_result.score)
191 | if not scores:
192 | return float("-inf")
193 | return sum(scores) / len(scores)
194 |
195 | def _create_prompt_with_examples(
196 | self, base_prompt: pm_types.PromptWrapper, examples: List[pm_types.Example]
197 | ) -> pm_types.PromptWrapper:
198 | """Create a new prompt with the given few-shot examples."""
199 |
200 | def sanitize(s: str) -> str:
201 | # Guard braces for typical format-string safety
202 | return str(s).replace("{", "{{").replace("}", "}}")
203 |
204 | few_shot_text = "\n\n# Few-Shot Examples:\n"
205 | for ex in examples:
206 | outputs = ex.outputs
207 | if isinstance(outputs, dict) and len(outputs) == 1 and "output" in outputs:
208 | outputs = outputs["output"]
209 |
210 | few_shot_text += (
211 | f"Input: {sanitize(ex.inputs)}\n"
212 | f"Output: {sanitize(outputs)}\n"
213 | "---\n"
214 | )
215 |
216 | new_prompt_text = base_prompt.get_prompt_str() + few_shot_text
217 | return pm_types.PromptWrapper.from_prior(
218 | base_prompt,
219 | new_prompt_text,
220 | extra_info={
221 | "num_fewshot": len(examples),
222 | "fewshot_indices": [ex.id for ex in examples],
223 | "optimizer": "fewshot",
224 | },
225 | )
226 |
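A toy sketch of the inclusion mask the TPE objective builds above (values are random stand-ins): one 0/1 decision per training example, then trimmed so at most max_examples stay selected.

import random

rng = random.Random(0)
max_examples = 3
example_mask = [bool(rng.randint(0, 1)) for _ in range(6)]
while sum(example_mask) > max_examples:
    drop = rng.choice([i for i, included in enumerate(example_mask) if included])
    example_mask[drop] = False
assert sum(example_mask) <= max_examples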
--------------------------------------------------------------------------------
/src/promptim/optimizers/metaprompt.py:
--------------------------------------------------------------------------------
1 | from typing import List, Sequence
2 | from langsmith.evaluation._arunner import ExperimentResultRow
3 | from langchain_core.messages import AIMessage
4 | from dataclasses import dataclass, field
5 | from promptim import types as pm_types
6 | from promptim import _utils as pm_utils
7 | from promptim.optimizers import base as optimizers
8 | from typing_extensions import Literal
9 | from trustcall import create_extractor
10 | import langsmith as ls
11 | import html
12 |
13 |
14 | DEFAULT_METAPROMPT = """Diagnose and optimize the quality of the prompt over the target task. Understand the underlying model's behavior patterns, and the underlying data generating process
 15 | so you know how to make the right improvements. Understand that the prompt only has the individual input context. Use the aggregate results for deeper understanding.
16 |
17 | ## Current prompt
18 |
19 | The following is the current best-performing prompt:
20 | {current_hypo}
21 |
22 | {current_prompt}
23 |
24 |
25 | Your generations will replace the content within the tags. The rest is fixed context over which you have no control. The TO_OPTIMIZE and CONTEXT\
 26 | tags are provided here to help you disambiguate and are not present in the prompt itself.
27 |
28 | ## Previous Prompt Attempts
29 |
30 | You previously attempted to use the following prompts, but they earned worse scores than the current one:
31 |
32 | {other_attempts}
33 |
34 |
35 | Think about what hypotheses you were testing in these previous attempts. None of them were optimal. Think through why to explore better options and better understand the underlying domain.
36 |
37 | ## Annotated results:
38 | The prompt sees the input variables. It produces the outputs.
 39 | The reference is hidden from the prompt and represents the expectations of the task.
40 |
41 | {annotated_results}
42 |
43 |
44 | ## Task description:
45 |
46 | {task_description}
47 |
48 |
49 | Unless otherwise specified, higher scores are better (try to maximize scores). Aim for perfect scores across all examples.
50 |
51 | In your head, search through all edits, planning the optimization step-by-step:
52 | 1. Analyze the current results and where they fall short
53 | 2. Identify patterns in the underlying data generating process for the dataset.
54 | 3. Identify patterns in successful vs unsuccessful cases.
 55 | 4. Generate hypotheses about what would help fix the shortcomings of the existing prompt. Propose specific improvements to address the shortcomings. Improvements cannot mention reference outputs, which are unavailable to the model.
56 | 5. Generate an improved prompt based on the most promising hypothesis.
57 |
58 | The improved prompt must keep all original input variables.
59 | Focus on targeted improvements rather than re-hashing what the prompt already handles well.
60 |
 61 | Use prompting strategies as appropriate for the task. For logic and math, consider encouraging more chain-of-thought reasoning,
62 | or include reasoning trajectories to induce better performance. For creative tasks, consider adding style guidelines.
63 | Or consider including synthetic exemplars. Take all the time you need, but DO NOT REPEAT HYPOTHESES FROM PREVIOUS ATTEMPTS. Update your priors by thinking why they're disproven, then try something new."""
64 |
65 |
66 | @dataclass(kw_only=True)
67 | class Config(optimizers.Config):
68 | """Configuration for the metaprompt optimization algorithm."""
69 |
70 | kind: Literal["metaprompt"] = field(
71 | default="metaprompt",
72 | metadata={
73 | "description": "The meta-prompt optimizer that uses an LLM to analyze and improve prompts."
74 | },
75 | )
76 | meta_prompt: str = field(
77 | default=DEFAULT_METAPROMPT,
78 | metadata={
79 | "description": "The meta-prompt to use for analyzing and improving prompts."
80 | },
81 | )
82 |
83 |
84 | class MetaPromptOptimizer(optimizers.BaseOptimizer):
85 | """
86 | This is the original style meta-prompt algorithm:
87 | It takes the current results and uses the meta-prompt to propose a new prompt.
88 | """
89 |
90 | config_cls = Config
91 |
92 | def __init__(
93 | self,
94 | *,
95 | model: optimizers.MODEL_TYPE | None = None,
96 | max_reasoning_steps: int = 5,
97 | meta_prompt: str | None = None,
98 | ):
99 | super().__init__(model=model)
100 | self.meta_prompt = meta_prompt or DEFAULT_METAPROMPT
101 | self.max_reasoning_steps = max_reasoning_steps
102 |
103 | @ls.traceable(run_type="prompt", name="meta_prompt")
104 | def format(self, **kwargs):
105 | return self.meta_prompt.format(**kwargs)
106 |
107 | def _format_results(self, results: List[ExperimentResultRow]) -> str:
108 | formatted = []
109 | for i, r in enumerate(results):
110 | formatted.append(f"Example {i+1}:")
111 | formatted.append(f'Input: {r["run"].inputs}')
112 | if r["example"].outputs:
113 | formatted.append(
114 | f'Reference (hidden from prompt): {r["example"].outputs}'
115 | )
116 | formatted.append(f'Prompt output: {r["run"].outputs}')
117 | formatted.append("Evaluations:")
118 | for eval_result in r["evaluation_results"]["results"]:
119 | formatted.append(f"- {eval_result.key}: {eval_result.score}")
120 | if eval_result.comment:
121 | formatted.append(f" Comment: {eval_result.comment}")
122 | formatted.append("---")
123 | return "\n".join(formatted)
124 |
125 | @ls.traceable
126 | async def improve_prompt(
127 | self,
128 | history: Sequence[Sequence[pm_types.PromptWrapper]],
129 | results: List[ExperimentResultRow],
130 | task: pm_types.Task,
131 | **kwargs,
132 | ) -> list[pm_types.PromptWrapper]:
133 | current_prompt = history[-1][-1]
134 | other_attempts = list(
135 | {
136 | html.escape(p.get_prompt_str()): (
137 | p,
138 | (p.extra.get("hypothesis") or "") if p.extra else "",
139 | )
140 | for ps in history
141 | for p in ps
142 | if p.get_prompt_str() != current_prompt.get_prompt_str()
143 | }.values()
144 | )[-5:]
145 |
146 | annotated_results = self._format_results(results)
147 | async with ls.trace("Optimize") as rt:
148 | print(f"Optimizing with url {rt.get_url()}", flush=True)
149 | formatted = current_prompt.get_prompt_str_in_context()
150 | hypo = (
151 | current_prompt.extra.get("hypothesis") if current_prompt.extra else None
152 | )
153 | if hypo:
154 | hypo = "Hypothesis for this prompt: " + hypo
155 | inputs = self.format(
156 | current_prompt=formatted,
157 | current_hypo=hypo or "",
158 | annotated_results=annotated_results,
159 | task_description=task.describe(),
160 | other_attempts=(
161 | "\n\n---".join(
162 | [
163 | f"{hypo}"
164 | f"\n{p.get_prompt_str()}\n"
165 | for i, (p, hypo) in enumerate(other_attempts)
166 | ]
167 | )
168 | if other_attempts
169 | else "N/A"
170 | ),
171 | )
172 | prompt_output = await self.react_agent(inputs, current_prompt)
173 | rt.add_outputs({"output": prompt_output})
174 | candidate = pm_types.PromptWrapper.from_prior(
175 | current_prompt,
176 | prompt_output.improved_prompt,
177 | extra_info={"hypothesis": prompt_output.hypothesis},
178 | )
179 |
180 | pm_utils.print_rich_diff(
181 | current_prompt.get_prompt_str_in_context(),
182 | candidate.get_prompt_str_in_context(),
183 | "Updated Prompt",
184 | )
185 |
186 | return [candidate]
187 |
188 | @ls.traceable
189 | async def react_agent(
190 | self, inputs: str, current_prompt, n=5
191 | ) -> pm_types.OptimizedPromptOutput:
192 | messages = [
193 | {"role": "user", "content": inputs},
194 | ]
195 | tooly = pm_types.prompt_schema(current_prompt)
196 | just_think = create_extractor(
197 | self.model,
198 | tools=[think, critique],
199 | tool_choice="any",
200 | )
201 | any_chain = create_extractor(
202 | self.model,
203 | tools=[think, critique, tooly],
204 | tool_choice="any",
205 | )
206 | final_chain = create_extractor(
207 | self.model,
208 | tools=[tooly],
209 | tool_choice="OptimizedPromptOutput",
210 | )
211 | for ix in range(n):
212 | if ix == n - 1:
213 | chain = final_chain
214 | elif ix == 0:
215 | chain = just_think
216 | else:
217 | chain = any_chain
218 | response = await chain.ainvoke(messages)
219 | final_response = next(
220 | (
221 | r
222 | for r in response["responses"]
223 | if r.__repr_name__() == "OptimizedPromptOutput"
224 | ),
225 | None,
226 | )
227 | if final_response:
228 | return final_response
229 | msg: AIMessage = response["messages"][-1]
230 | messages.append(msg)
231 | ids = [tc["id"] for tc in (msg.tool_calls or [])]
232 | for id_ in ids:
233 | messages.append({"role": "tool", "content": "", "tool_call_id": id_})
234 |
235 | raise ValueError(f"Failed to generate response after {n} attempts")
236 |
237 |
238 | def think(thought: str):
239 | """First call this to reason over complicated domains, uncover hidden input/output patterns, theorize why previous hypotheses failed, and creatively conduct error analyses (e.g., deep diagnostics/recursively analyzing "why" something failed). List characteristics of the data generating process you failed to notice before. Hypothesize fixes, prioritize, critique, and repeat calling this tool until you are confident in your next solution."""
240 | return "Take as much time as you need! If you're stuck, take a step back and try something new."
241 |
242 |
243 | def critique(criticism: str):
244 | """Then, critique your thoughts and hypotheses. Identify flaws in your previous hypotheses and current thinking. Forecast why the hypotheses won't work. Get to the bottom of what is really driving the problem. This tool returns no new information but gives you more time to plan."""
245 | return "Take as much time as you need. It's important to think through different strategies."
246 |
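A minimal construction sketch for the metaprompt optimizer. The model dict matches the default in optimizers/base.py and needs langchain-anthropic plus credentials to actually run:

from promptim.optimizers.metaprompt import MetaPromptOptimizer

optimizer = MetaPromptOptimizer(
    model={"model": "claude-3-5-sonnet-20241022", "max_tokens_to_sample": 8192},
)
print(optimizer.meta_prompt[:60])  # falls back to DEFAULT_METAPROMPT when none is given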
--------------------------------------------------------------------------------
/src/promptim/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hinthornw/promptimizer/2052e812fc52cff19d8539b30b1782385f011c1e/src/promptim/py.typed
--------------------------------------------------------------------------------
/src/promptim/tasks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hinthornw/promptimizer/2052e812fc52cff19d8539b30b1782385f011c1e/src/promptim/tasks/__init__.py
--------------------------------------------------------------------------------
/src/promptim/tasks/metaprompt.py:
--------------------------------------------------------------------------------
1 | import langsmith as ls
2 | from langchain_anthropic import ChatAnthropic
3 | from langchain_core.prompts import ChatPromptTemplate
4 |
5 | from promptim.tasks.scone import scone_task
6 | from promptim.tasks.simpleqa import simpleqa_task
7 | from promptim.tasks.ticket_classification import ticket_classification_task
8 | from promptim.tasks.tweet_generator import tweet_task
9 | from promptim.optimizers import metaprompt as metaprompt_optimizer
10 | from promptim import types as pm_types, trainer as pm_trainer
11 |
12 | DEFAULT_METAMETAPROMPT = """You are an expert in prompt optimization systems. Your task is to improve the effectiveness of prompt optimization prompts - the prompts used to guide the improvement of task-specific prompts.
13 |
14 | Current Optimization Prompt:
15 | {current_prompt}
16 |
17 | Performance Data:
18 | Shows how this optimization prompt performed in improving various task-specific prompts
19 | {annotated_results}
20 |
21 | Objective:
22 | Improve the optimization prompt to better guide the enhancement of task-specific prompts across:
23 | - Systematic analysis of prompt performance
24 | - Identification of improvement opportunities
25 | - Generation of enhanced prompts
26 | - Validation of improvements
27 |
28 | Analysis Steps:
29 | 1. Optimization Effectiveness
30 | - How well did this optimization prompt guide improvements?
31 | - Which aspects of prompt optimization were handled well/poorly?
32 | - What patterns emerge in successful vs unsuccessful optimization attempts?
33 |
34 | 2. Structural Assessment
35 | - How clearly does it guide the optimization process?
36 | - How well does it maintain prompt constraints?
37 | - What components are missing or ineffective?
38 |
39 | 3. Improvement Strategy
40 | - Which elements of the optimization process need enhancement?
41 | - How can we make the optimization guidance more effective?
42 | - What additional checks or validations would help?
43 |
44 | Output Format:
45 |
46 | Analysis of how well this optimization prompt guides improvements
47 |
48 |
49 |
50 | Specific changes to enhance optimization capabilities
51 |
52 |
53 |
54 | The enhanced prompt for optimizing task-specific prompts
55 | """
56 |
57 |
58 | class MetaPromptSystem:
59 | """System for running the metaprompt optimization task."""
60 |
61 | def __init__(
62 | self, task_map: dict[str, pm_types.Task], meta_prompt: pm_types.PromptWrapper
63 | ):
64 | from langchain.chat_models import init_chat_model
65 |
66 | self.task_map = task_map
67 | try:
68 | self.model = ChatAnthropic(
69 | model="claude-3-5-sonnet-20241022", max_tokens_to_sample=8192
70 | )
71 | except Exception:
72 | self.model = init_chat_model()
73 |
74 | self.trainer = pm_trainer.PromptOptimizer(
75 | self.model, meta_prompt.get_prompt_str()
76 | )
77 |
78 | async def __call__(self, prompt: ChatPromptTemplate, inputs: dict) -> dict:
79 | task = self.task_map[inputs["task"]]
80 | task.initial_prompt.load()
81 |
82 | # Run initial prompt on batch using aevaluate
83 | async def predict(example_inputs: dict):
84 | return await task.system_safe(task.initial_prompt.load(), example_inputs)
85 |
86 | train_batch = list(
87 | self.trainer.client.list_examples(example_ids=inputs["train_batch"])
88 | )
89 | dev_batch = list(
90 | self.trainer.client.list_examples(example_ids=inputs["dev_batch"])
91 | )
92 | with ls.tracing_context(parent={"langsmith-trace": ""}):
93 | initial_results = [
94 | r
95 | async for r in (
96 | await ls.aevaluate(
97 | predict,
98 | data=train_batch,
99 | evaluators=task.evaluators,
100 | )
101 | )
102 | ]
103 | task.initial_prompt.get_prompt_str()
104 |
105 | # Generate new downstream task prompt
106 | extracted = await self.trainer.apply_metaprompt(
107 | current_prompt=task.initial_prompt,
108 | meta_prompt=prompt.messages[0].prompt.template, # type: ignore
109 | task=task,
110 | results=initial_results,
111 | )
112 |
113 | # Now we actually evaluate based on how well the updated prompt's "improvements"
114 | # translate to a dev batch
115 | with ls.tracing_context(parent={"langsmith-trace": ""}):
116 | initial_dev_results = [
117 | r
118 | async for r in (
119 | await ls.aevaluate(
120 | predict,
121 | data=dev_batch,
122 | evaluators=task.evaluators,
123 | )
124 | )
125 | ]
126 | initial_dev_scores = await self.trainer.calculate_scores(initial_dev_results)
127 |
128 | async def predict_new(example_inputs: dict):
129 | return await task.system_safe(extracted._cached, example_inputs)
130 |
131 | with ls.tracing_context(parent={"langsmith-trace": ""}):
132 | new_results = [
133 | r
134 | async for r in (
135 | await ls.aevaluate(
136 | predict_new,
137 | data=dev_batch,
138 | evaluators=task.evaluators,
139 | )
140 | )
141 | ]
142 | new_scores = await self.trainer.calculate_scores(new_results)
143 | return {
144 | "original_prompt": task.initial_prompt,
145 | "new_prompt": extracted.get_prompt_str(),
146 | # "reasoning_for_changes": extracted.analysis,
147 | "initial_scores": initial_dev_scores,
148 | "new_scores": new_scores,
149 | }
150 |
151 |
152 | def metaprompt_evaluator(run, example):
153 | """Evaluate the performance of the metaprompt."""
154 | original_score = sum(run.outputs["initial_scores"].values()) / len(
155 | run.outputs["initial_scores"]
156 | )
157 | new_score = sum(run.outputs["new_scores"].values()) / len(run.outputs["new_scores"])
158 | # Map the difference in scores to a 0 to 1 scale
159 | score_diff = new_score - original_score
160 | normalized_score = max(0, min(1, (score_diff + 1) / 2))
161 | if normalized_score > 0.5:
162 | comment = "The average scores improved after making the changes suggested by the prompt optimizer."
163 | elif normalized_score < 0.5:
164 | comment = "The average score dropped after making the changes suggested by the prompt optimizer."
165 | else:
166 | comment = "The average score remained the same after making the changes suggested by the prompt optimizer."
167 |
168 | return {
169 | "key": "metaprompt_improvement",
170 | "score": normalized_score,
171 | "comment": comment,
172 | }
173 |
174 |
175 | prompt_config = pm_types.PromptWrapper(
176 | prompt_str=metaprompt_optimizer.DEFAULT_METAPROMPT
177 | )
178 | metaprompt_task = pm_types.Task(
179 | name="MetaPrompt Optimizer",
180 | description="A meta-optimization task that aims to improve the prompt used for optimizing task-specific prompts. This task evaluates and enhances the effectiveness of the prompt optimization process itself, leading to better performance across various language tasks.",
181 | dataset="metaprompt-optim",
182 | initial_prompt=prompt_config,
183 | evaluators=[metaprompt_evaluator],
184 | evaluator_descriptions={
185 | "metaprompt_improvement": "Checks if the new prompt leads to improved scores. 1 if better, 0.5 if same, 0 if worse."
186 | },
187 | system=MetaPromptSystem(
188 | {
189 | "scone": scone_task,
190 | "tweet": tweet_task,
191 | "simpleqa": simpleqa_task,
192 | "ticket_classification_task": ticket_classification_task,
193 | },
194 | meta_prompt=prompt_config,
195 | ),
196 | )
197 |
198 |
199 | if __name__ == "__main__":
200 | import argparse
201 | import random
202 |
203 | from langsmith import Client
204 |
205 | random.seed(42)
206 |
207 | def create_datasets(client, tasks, batchsize, overwrite=False):
208 | datasets = {}
209 | for split_name in ["train", "dev", "test"]:
210 | dataset_name = "metaprompt-optim"
211 | if overwrite:
212 | if client.has_dataset(dataset_name=dataset_name):
213 | client.delete_dataset(dataset_name=dataset_name)
214 | datasets[split_name] = client.create_dataset(dataset_name=dataset_name)
215 |
216 | for task_name in tasks:
217 | task_datasets = {
218 | "train": client.list_examples(dataset_name=f"{task_name}-train"),
219 | "dev": client.list_examples(dataset_name=f"{task_name}-dev"),
220 | "test": client.list_examples(dataset_name=f"{task_name}-test"),
221 | }
222 |
223 | for split_name in ["train", "dev", "test"]:
224 | examples = []
225 | task_split_examples = list(task_datasets[split_name])
226 | random.shuffle(task_split_examples)
227 | for i in range(len(task_split_examples) // (batchsize * 2)):
228 | batch = [
229 | str(example.id)
230 |                         for example in task_split_examples[i * batchsize * 2 : (i + 1) * batchsize * 2]
231 | ]
232 |
233 | examples.append(
234 | {
235 | "task": task_name,
236 | "train_batch": batch[0 : len(batch) // 2],
237 | "dev_batch": batch[len(batch) // 2 :],
238 | }
239 | )
240 |
241 | client.create_examples(
242 | inputs=examples,
243 | dataset_id=datasets[split_name].id,
244 | splits=[split_name] * len(examples),
245 | )
246 |
247 | return datasets
248 |
249 | parser = argparse.ArgumentParser(
250 | description="Generate datasets for metaprompt optimization"
251 | )
252 | parser.add_argument(
253 | "--batchsize", type=int, default=5, help="Number of examples in each batch."
254 | )
255 | parser.add_argument(
256 | "--overwrite",
257 | action="store_true",
258 | help="Overwrite existing datasets if they exist.",
259 | )
260 | args = parser.parse_args()
261 |
262 | client = Client()
263 | tasks = ["scone", "tweet", "simpleqa", "ticket_classification_task"]
264 |
265 | datasets = create_datasets(client, tasks, args.batchsize, args.overwrite)
266 |
267 | print("Datasets created successfully!")
268 | for name, dataset in datasets.items():
269 | print(f"{name.capitalize()} dataset ID: {dataset.id}")
270 |
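A worked example of the normalization in metaprompt_evaluator above (numbers are illustrative): an average-score improvement of +0.2 maps to 0.6 on the 0-1 scale, while no change sits exactly at 0.5.

score_diff = 0.7 - 0.5                                   # new average minus original average
normalized_score = max(0, min(1, (score_diff + 1) / 2))
assert abs(normalized_score - 0.6) < 1e-9
assert max(0, min(1, (0.0 + 1) / 2)) == 0.5              # unchanged scores map to the midpoint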
--------------------------------------------------------------------------------
/src/promptim/tasks/scone.py:
--------------------------------------------------------------------------------
1 | from promptim.types import PromptWrapper, Task
2 |
3 |
4 | def exact_match(run, example):
5 | """Evaluate the exact match correctness of the NLI result."""
6 | try:
7 | predicted = run.outputs["is_entailed"]
8 | expected = example.outputs["answer"]
9 | score = expected.lower() == predicted.lower()
10 | except Exception:
11 | try:
12 | expected = example.outputs["answer"]
13 | expected_bool = {"no": False, "yes": True}.get(expected.strip().lower())
14 | score = run.outputs["output"].is_entailed == expected_bool
15 | except Exception:
16 | score = 0
17 | return {
18 | "key": "exact_match",
19 | "score": int(score),
20 | }
21 |
22 |
23 | scone_task = Task(
24 | name="Scone (NLI)",
25 | dataset="scone-optim",
26 | initial_prompt=PromptWrapper(identifier="langchain-ai/scone-example:d49910d6"),
27 | evaluators=[exact_match],
28 | evaluator_descriptions={
29 | "exact_match": "Directly compares the expected against the predicted outputs. 1 if correct, 0 if incorrect."
30 | },
31 | )
32 |
--------------------------------------------------------------------------------
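
The evaluator above illustrates the contract every task evaluator in this repo follows: it receives a run carrying model outputs and an example carrying reference outputs, and returns a keyed score. A minimal sketch with `SimpleNamespace` stand-ins for the LangSmith objects, assuming promptim is installed locally:

```python
# Stand-in objects for langsmith's Run/Example; only the attributes the
# evaluator touches are provided.
from types import SimpleNamespace

from promptim.tasks.scone import exact_match

run = SimpleNamespace(outputs={"is_entailed": "Yes"})
example = SimpleNamespace(outputs={"answer": "yes"})

print(exact_match(run, example))  # {'key': 'exact_match', 'score': 1}
```
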
/src/promptim/tasks/simpleqa.py:
--------------------------------------------------------------------------------
1 | from typing import Literal
2 |
3 | from langchain_openai import ChatOpenAI
4 | from pydantic import BaseModel, Field
5 |
6 | from promptim.types import PromptWrapper, Task
7 |
8 | GRADER_TEMPLATE = """Your job is to look at a question, a gold target, and a predicted answer, and then assign a grade of either ["CORRECT", "INCORRECT", "NOT_ATTEMPTED"].
9 | First, I will give examples of each grade, and then you will grade a new example.
10 |
11 | The following are examples of CORRECT predicted answers.
12 | ```
13 | Question: What are the names of Barack Obama's children?
14 | Gold target: Malia Obama and Sasha Obama
15 | Predicted answer 1: sasha and malia obama
16 | Predicted answer 2: most people would say Malia and Sasha, but I'm not sure and would have to double check
17 | Predicted answer 3: Barack Obama has two daughters. Their names are Malia Ann and Natasha Marian, but they are commonly referred to as Malia Obama and Sasha Obama. Malia was born on July 4, 1998, and Sasha was born on June 10, 2001.
18 | ```
19 | These predicted answers are all CORRECT because:
20 | - They fully contain the important information in the gold target.
21 | - They do not contain any information that contradicts the gold target.
22 | - Only semantic meaning matters; capitalization, punctuation, grammar, and order don't matter.
23 | - Hedging and guessing are permissible, provided that the gold target is fully included and the response contains no incorrect information or contradictions.
24 |
25 | The following are examples of INCORRECT predicted answers.
26 | ```
27 | Question: What are the names of Barack Obama's children?
28 | Gold target: Malia and Sasha
29 | Predicted answer 1: Malia.
30 | Predicted answer 2: Malia, Sasha, and Susan.
31 | Predicted answer 3: Barack Obama does not have any children.
32 | Predicted answer 4: I think it's either Malia and Sasha. Or it could be Malia and Jackie. Or it could be Joey and Malia.
33 | Predicted answer 5: While I don't know their exact names, I can tell you that Barack Obama has three children.
34 | Predicted answer 6: It's possible you may mean Betsy and Olivia. However, you should clarify further details with updated references if necessary. Is that the correct answer?
35 | Predicted answer 7: It may be the case that Obama's child is named James. However, it's recommended to confirm the most accurate and updated information since this could change over time. This model may not always reflect the most current information.
36 | ```
37 | These predicted answers are all INCORRECT because:
38 | - A factual statement in the answer contradicts the gold target. Incorrect statements that have some hedging (e.g., "it is possible that", "although i'm not sure, i think") are also considered incorrect.
39 |
40 | The following are examples of NOT_ATTEMPTED predicted answers.
41 | ```
42 | Question: What are the names of Barack Obama's children?
43 | Gold target: Malia and Sasha
44 | Predicted answer 1: I don't know.
45 | Predicted answer 2: I need more context about which Obama you are talking about.
46 | Predicted answer 3: Without researching the web, I cannot answer this question. However, I can tell you that Barack Obama has two children.
47 | Predicted answer 4: Barack Obama has two children. I know that one of them is Malia, but I'm not sure about the other one.
48 | ```
49 | These predicted answers are all NOT_ATTEMPTED because:
50 | - The important information in the gold target is not included in the answer.
51 | - No statements in the answer contradict the gold target.
52 |
53 | Also note the following things:
54 | - For grading questions where the gold target is a number, the predicted answer needs to be correct to the last significant figure in the gold answer. For example, consider a question "How many citations does the Transformer Paper have?" with gold target "120k".
55 |  - Predicted answers "120k", "124k", and "115k" are all CORRECT.
56 | - Predicted answers "100k" and "113k" are INCORRECT.
57 | - Predicted answers "around 100k" and "more than 50k" are considered NOT_ATTEMPTED because they neither confirm nor contradict the gold target.
58 | - The gold target may contain more information than the question. In such cases, the predicted answer only needs to contain the information that is in the question.
59 | - For example, consider the question "What episode did Derek and Meredith get legally married in Grey's Anatomy?" with gold target "Season 7, Episode 20: White Wedding". Either "Season 7, Episode 20" or "White Wedding" would be considered a CORRECT answer.
60 | - Do not punish predicted answers if they omit information that would be clearly inferred from the question.
61 | - For example, consider the question "What city is OpenAI headquartered in?" and the gold target "San Francisco, California". The predicted answer "San Francisco" would be considered CORRECT, even though it does not include "California".
62 | - Consider the question "What award did A pretrainer's guide to training data: Measuring the effects of data age, domain coverage, quality, & toxicity win at NAACL '24?", the gold target is "Outstanding Paper Award". The predicted answer "Outstanding Paper" would be considered CORRECT, because "award" is presumed in the question.
63 | - For the question "What is the height of Jason Wei in meters?", the gold target is "1.73 m". The predicted answer "1.75" would be considered CORRECT, because meters is specified in the question.
64 | - For the question "What is the name of Barack Obama's wife?", the gold target is "Michelle Obama". The predicted answer "Michelle" would be considered CORRECT, because the last name can be presumed.
65 | - Do not punish for typos in people's name if it's clearly the same name.
66 | - For example, if the gold target is "Hyung Won Chung", you can consider the following predicted answers as correct: "Hyoong Won Choong", "Hyungwon Chung", or "Hyun Won Chung".
67 |
68 | Here is a new example. Simply reply with either CORRECT, INCORRECT, NOT_ATTEMPTED. Don't apologize or correct yourself if there was a mistake; we are just trying to grade the answer.
69 | ```
70 | Question: {question}
71 | Gold target: {target}
72 | Predicted answer: {predicted_answer}
73 | ```
74 |
75 | Grade the predicted answer of this new question as one of:
76 | - CORRECT
77 | - INCORRECT
78 | - NOT_ATTEMPTED
79 | """
80 |
81 |
82 | class GraderOutput(BaseModel):
83 | """Submit the grade."""
84 |
85 | grade: Literal["CORRECT", "INCORRECT", "NOT_ATTEMPTED"] = Field(
86 | description="The grade for the predicted answer"
87 | )
88 |
89 |
90 | grader = ChatOpenAI(model="gpt-4o-mini").with_structured_output(GraderOutput)
91 |
92 |
93 | async def simpleqa_evaluator(run, example):
94 | """Evaluate the correctness of the SimpleQA answer."""
95 | grader_prompt = GRADER_TEMPLATE.format(
96 | question=example.inputs["problem"],
97 | target=example.outputs["answer"],
98 | predicted_answer=run.outputs["output"],
99 | )
100 |
101 | grading_response = await grader.ainvoke(grader_prompt)
102 |
103 | grade = grading_response.grade
104 |
105 | score = 1 if grade == "CORRECT" else 0
106 | comment = grade.lower()
107 |
108 | return {"key": "simpleqa_score", "score": score, "comment": comment}
109 |
110 |
111 | simpleqa_task = Task(
112 | name="SimpleQA",
113 | description="A task to measure short-form factuality in large language models",
114 | dataset="simpleqa-optim",
115 | initial_prompt=PromptWrapper(identifier="langchain-ai/simpleqa-example:43349b82"),
116 | evaluators=[simpleqa_evaluator],
117 | evaluator_descriptions={
118 | "simpleqa_score": "Evaluates the correctness of the answer. 1 if correct, 0 if incorrect or not attempted."
119 | },
120 | )
121 |
122 | if __name__ == "__main__":
123 | import random
124 |
125 | import langsmith as ls
126 |
127 | c = ls.Client()
128 | examples = list(c.list_examples(dataset_name="Simple QA Full"))
129 |
130 | random.shuffle(examples)
131 | full = examples.copy()
132 | train, dev, test = [], [], []
133 | dataset = c.create_dataset(dataset_name="simpleqa-optim")
134 | for ds, size, name in zip(
135 | [train, dev, test], [200, 100, 100], ["train", "dev", "test"]
136 | ):
137 | for i in range(size):
138 | ds.append(full.pop())
139 |
140 | c.create_examples(
141 | inputs=[{"problem": e.inputs["problem"]} for e in ds],
142 | outputs=[e.outputs for e in ds],
143 | dataset_id=dataset.id,
144 | metadata=[e.inputs["metadata"] for e in ds],
145 | splits=[name] * len(ds),
146 | )
147 |
--------------------------------------------------------------------------------
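
The evaluator above fills the grader template, sends it to a structured-output judge, and collapses the three-way grade to a binary score. A self-contained sketch of just that grade-to-score mapping, mirroring the names in the file above without calling the OpenAI judge:

```python
# Self-contained mirror of the grade-to-score mapping in simpleqa_evaluator;
# the GraderOutput model is copied from the file above, and no judge is called.
from typing import Literal

from pydantic import BaseModel, Field


class GraderOutput(BaseModel):
    """Submit the grade."""

    grade: Literal["CORRECT", "INCORRECT", "NOT_ATTEMPTED"] = Field(
        description="The grade for the predicted answer"
    )


for grading_response in (GraderOutput(grade="CORRECT"), GraderOutput(grade="NOT_ATTEMPTED")):
    score = 1 if grading_response.grade == "CORRECT" else 0
    print({"key": "simpleqa_score", "score": score, "comment": grading_response.grade.lower()})
```
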
/src/promptim/tasks/ticket_classification.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import logging
3 |
4 | from langchain_openai import ChatOpenAI
5 | from pydantic import BaseModel, Field
6 |
7 | from promptim.types import PromptWrapper, Task
8 |
9 | logger = logging.getLogger(__name__)
10 |
11 |
12 | class Grade(BaseModel):
13 | """Call to submit your grade."""
14 |
15 | reasoning: str = Field(
16 | description="First, explain your thought process on why you are giving the provided grade."
17 | )
18 | score: int = Field(
19 | ge=0, le=5, description="Then, submit your score on a scale from 0 to 5."
20 | )
21 |
22 | @property
23 | def normalized(self):
24 | return self.score / 5
25 |
26 |
27 | @functools.lru_cache
28 | def _get_judge():
29 | from trustcall import create_extractor
30 |
31 | return create_extractor(
32 | ChatOpenAI(model="gpt-4o-mini"), tools=[Grade], tool_choice=Grade.__name__
33 | )
34 |
35 |
36 | utemplate = """Grade the following:
37 | Predicted: {predicted}
38 | Reference example: {reference}"""
39 |
40 |
41 | async def summary_quality(run, example):
42 | predicted = run.outputs.get("summary")
43 |     rubric = """Grade the quality of the summary. If it fails any criteria, give a 0. If it's perfect, give a 5.
44 | Criteria:
45 | - Must not include idle words like "the email is about X"
46 | - Preferred format is from needs/wants X
47 | """
48 | reference = example.outputs["summary"]
49 | result = await _get_judge().ainvoke(
50 | [
51 | ("system", rubric),
52 | ("user", utemplate.format(predicted=predicted, reference=reference)),
53 | ]
54 | )
55 | grade: Grade = result["responses"][0]
56 | pf = "Pass" if grade.score >= 4 else "Fail"
57 | return {"score": grade.normalized, "comment": f"{pf}: {grade.reasoning}"}
58 |
59 |
60 | def accuracy_check(run, example, key: str):
61 | predicted = run.outputs.get(key)
62 | reference = example.outputs.get(key)
63 | if reference is None:
64 | return {
65 | "key": f"{key}-correctness",
66 | "comment": "Skipping - reference label not found.",
67 | }
68 | score = (
69 | predicted == reference
70 | if not isinstance(reference, list)
71 | else predicted in reference
72 | )
73 | pf = "Pass" if score else "Fail"
74 | return {
75 | "key": f"{key}-correctness",
76 | "score": score,
77 | "comment": f"{pf}",
78 | } #: Expected {reference}. Got: {predicted}. Why did you get this wrong? Think deeply and update associations."}
79 |
80 |
81 | classifiers = [
82 | functools.partial(accuracy_check, key=key)
83 | for key in [
84 | "category",
85 | "support_category",
86 | "ticket_status",
87 | "requires_response",
88 | "non_support_category",
89 | ]
90 | ]
91 | evaluators = [summary_quality, *classifiers]
92 |
93 |
94 | ticket_classification_task = Task(
95 | name="Ticket Classification",
96 | description="A task to classify customer support tickets",
97 | dataset="ticket-classification-optim",
98 | initial_prompt=PromptWrapper(
99 | identifier="langchain-ai/ticket-classifier-example:376ab5e4",
100 | which=1,
101 | ),
102 | evaluators=evaluators,
103 | evaluator_descriptions={
104 | "summary_quality": "Evaluates the quality of the summary",
105 | "category-correctness": "Checks if the category is correct",
106 | "support_category-correctness": "Checks if the support category is correct",
107 | "ticket_status-correctness": "Checks if the ticket status is correct",
108 | "requires_response-correctness": "Checks if the requires_response field is correct",
109 | "non_support_category-correctness": "Checks if the non-support category is correct",
110 | },
111 | )
112 |
113 | if __name__ == "__main__":
114 | import random
115 |
116 | import langsmith as ls
117 |
118 | c = ls.Client()
119 | examples = list(
120 | c.list_examples(dataset_name="customer-support-bot.test_extraction")
121 | )
122 |
123 | random.shuffle(examples)
124 | full = examples.copy()
125 | train, dev, test = [], [], []
126 | dname = "ticket-classification-optim"
127 | try:
128 | dataset = c.create_dataset(dataset_name=dname)
129 | except Exception:
130 | c.delete_dataset(dataset_name=dname)
131 | dataset = c.create_dataset(dataset_name=dname)
132 | for ds, size, name in zip(
133 | [train, dev, test], [41, 20, 20], ["train", "dev", "test"]
134 | ):
135 | for i in range(size):
136 | ds.append(full.pop())
137 |
138 | outputs = [e.outputs for e in ds]
139 | for o in outputs:
140 | for k, v in o.pop("outputs", {}).items():
141 | o[k] = v
142 | c.create_examples(
143 | inputs=[e.inputs for e in ds],
144 | outputs=outputs,
145 | dataset_id=dataset.id,
146 | splits=[name] * len(ds),
147 | )
148 |
--------------------------------------------------------------------------------
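
The per-field checks above are all built from one function: `functools.partial` binds the field name, and a prediction passes when it equals the reference or is contained in a list of acceptable references. A minimal sketch with stand-in run/example objects (field values are illustrative), assuming promptim and its dependencies are importable:

```python
# Stand-in run/example objects; field values are illustrative.
import functools
from types import SimpleNamespace

from promptim.tasks.ticket_classification import accuracy_check

check_category = functools.partial(accuracy_check, key="category")

run = SimpleNamespace(outputs={"category": "billing"})
example = SimpleNamespace(outputs={"category": ["billing", "payments"]})

# A list-valued reference means any of the listed labels counts as correct.
print(check_category(run, example))  # {'key': 'category-correctness', 'score': True, 'comment': 'Pass'}
```
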
/src/promptim/tasks/tweet_generator.py:
--------------------------------------------------------------------------------
1 | from promptim.types import PromptWrapper, Task
2 |
3 |
4 | def under_180_chars(run, example):
5 | """Evaluate if the tweet is under 180 characters."""
6 | result = run.outputs.get("tweet", "")
7 | score = int(len(result) < 180)
8 | comment = "Pass" if score == 1 else "Fail"
9 | return {
10 | "key": "under_180_chars",
11 | "score": score,
12 | "comment": comment,
13 | }
14 |
15 |
16 | def no_hashtags(run, example):
17 | """Evaluate if the tweet contains no hashtags."""
18 | result = run.outputs.get("tweet", "")
19 | score = int("#" not in result)
20 | comment = "Pass" if score == 1 else "Fail"
21 | return {
22 | "key": "no_hashtags",
23 | "score": score,
24 | "comment": comment,
25 | }
26 |
27 |
28 | def multiple_lines(run, example):
29 | """Evaluate if the tweet contains multiple lines."""
30 | result = run.outputs.get("tweet", "")
31 | score = int("\n" in result)
32 | comment = "Pass" if score == 1 else "Fail"
33 | return {
34 | "key": "multiline",
35 | "score": score,
36 | "comment": comment,
37 | }
38 |
39 |
40 | tweet_task = Task(
41 | name="Tweet Generator",
42 | dataset="tweet-optim",
43 | initial_prompt=PromptWrapper(identifier="tweet-generator-example:c39837bd"),
44 | evaluators=[under_180_chars, no_hashtags, multiple_lines],
45 | evaluator_descriptions={
46 | "under_180_chars": "Checks if the tweet is under 180 characters. 1 if true, 0 if false.",
47 | "no_hashtags": "Checks if the tweet contains no hashtags. 1 if true, 0 if false.",
48 | "multiline": "Fails if the tweet is not multiple lines. 1 if true, 0 if false. 0 is bad.",
49 | },
50 | )
51 |
52 |
53 | ## Example of how to create the dataset
54 |
55 | if __name__ == "__main__":
56 | from langsmith import Client
57 |
58 | client = Client()
59 |
60 | topics = [
61 | "NBA",
62 | "NFL",
63 | "Movies",
64 | "Taylor Swift",
65 | "Artificial Intelligence",
66 | "Climate Change",
67 | "Space Exploration",
68 | "Cryptocurrency",
69 | "Healthy Living",
70 | "Travel Destinations",
71 | "Technology Trends",
72 | "Fashion",
73 | "Food and Cooking",
74 | "Music Festivals",
75 | "Entrepreneurship",
76 | "Fitness",
77 | "Gaming",
78 | "Politics",
79 | "Environmental Conservation",
80 | "Social Media Trends",
81 | "Education",
82 | "Mental Health",
83 | "Renewable Energy",
84 | "Virtual Reality",
85 | "Sustainable Fashion",
86 | "Robotics",
87 | "Quantum Computing",
88 | "Genetic Engineering",
89 | "Smart Cities",
90 | "Cybersecurity",
91 | "Augmented Reality",
92 | "Electric Vehicles",
93 | "Blockchain",
94 | "3D Printing",
95 | "Nanotechnology",
96 | "Biotechnology",
97 | "Internet of Things",
98 | "Cloud Computing",
99 | "Big Data",
100 | "Machine Learning",
101 | "Artificial General Intelligence",
102 | "Space Tourism",
103 | "Autonomous Vehicles",
104 | "Drones",
105 | "Wearable Technology",
106 | "Personalized Medicine",
107 | "Telemedicine",
108 | "Remote Work",
109 | "Digital Nomads",
110 | "Gig Economy",
111 | "Circular Economy",
112 | "Vertical Farming",
113 | "Lab-grown Meat",
114 | "Plant-based Diets",
115 | "Mindfulness",
116 | "Yoga",
117 | "Meditation",
118 | "Biohacking",
119 | "Nootropics",
120 | "Intermittent Fasting",
121 | "HIIT Workouts",
122 | "Esports",
123 | "Streaming Services",
124 | "Podcasting",
125 | "True Crime",
126 | "Tiny Houses",
127 | "Minimalism",
128 | "Zero Waste Living",
129 | "Upcycling",
130 | "Eco-tourism",
131 | "Voluntourism",
132 | "Digital Detox",
133 | "Slow Living",
134 | "Hygge",
135 | "Urban Gardening",
136 | "Permaculture",
137 | "Regenerative Agriculture",
138 | "Microplastics",
139 | "Ocean Conservation",
140 | "Rewilding",
141 | "Endangered Species",
142 | "Biodiversity",
143 | "Ethical AI",
144 | "Data Privacy",
145 | "Net Neutrality",
146 | "Deepfakes",
147 | "Fake News",
148 | "Social Media Activism",
149 | "Cancel Culture",
150 | "Meme Culture",
151 | "NFTs",
152 | "Decentralized Finance",
153 | "Universal Basic Income",
154 | "Gender Equality",
155 | "LGBTQ+ Rights",
156 | "Black Lives Matter",
157 | "Indigenous Rights",
158 | "Refugee Crisis",
159 | "Global Poverty",
160 | "Universal Healthcare",
161 | "Drug Decriminalization",
162 | "Prison Reform",
163 | "Gun Control",
164 | "Voting Rights",
165 | "Gerrymandering",
166 | "Campaign Finance Reform",
167 | "Term Limits",
168 | "Ranked Choice Voting",
169 | "Direct Democracy",
170 | "Space Debris",
171 | "Asteroid Mining",
172 | "Mars Colonization",
173 | "Extraterrestrial Life",
174 | "Dark Matter",
175 | "Black Holes",
176 | "Quantum Entanglement",
177 | "Fusion Energy",
178 | "Antimatter",
179 | "Cryonics",
180 | "Life Extension",
181 | "Transhumanism",
182 | "Cyborgs",
183 | "Brain-Computer Interfaces",
184 | "Memory Implants",
185 | "Holographic Displays",
186 | ]
187 |
188 | # Create datasets
189 | ds = client.create_dataset(dataset_name="tweet-optim")
190 |
191 | # Split topics into train, dev, and test sets
192 | train_topics = topics[:80]
193 | dev_topics = topics[80:90]
194 | test_topics = topics[90:]
195 |
196 | # Create examples for each dataset
197 | for split_name, dataset_topics in [
198 | ("train", train_topics),
199 | ("dev", dev_topics),
200 | ("test", test_topics),
201 | ]:
202 | client.create_examples(
203 | inputs=[{"topic": topic} for topic in dataset_topics],
204 | dataset_id=ds.id,
205 | splits=[split_name] * len(dataset_topics),
206 | )
207 |
208 | print("Dataset created successfully!")
209 |
--------------------------------------------------------------------------------
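
The three checks above only inspect `run.outputs["tweet"]`, so they can be exercised without any model call. A minimal sketch using `SimpleNamespace` stand-ins for the LangSmith `Run`/`Example` objects; the tweet text is made up:

```python
# Stand-in run/example objects; the tweet text is a made-up sample.
from types import SimpleNamespace

from promptim.tasks.tweet_generator import multiple_lines, no_hashtags, under_180_chars

run = SimpleNamespace(outputs={"tweet": "Big game tonight.\nWho are you backing?"})
example = SimpleNamespace(inputs={"topic": "NBA"})

for evaluator in (under_180_chars, no_hashtags, multiple_lines):
    print(evaluator(run, example))  # each returns a keyed 0/1 score with a Pass/Fail comment
```
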
/src/promptim/types.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import json
3 | from dataclasses import dataclass, field, fields
4 | from typing import Callable, Optional, Any, Protocol
5 | from uuid import UUID
6 |
7 | import langsmith as ls
8 | from langchain.chat_models import init_chat_model
9 | from langchain_core.language_models import BaseChatModel
10 | from langchain_core.load import dumps
11 | from langchain_core.prompts import ChatPromptTemplate
12 | from langchain_core.prompts.structured import StructuredPrompt
13 | from langchain_core.prompts import MessagesPlaceholder
14 | from langchain_core.runnables import RunnableBinding, RunnableSequence
15 | from langsmith.schemas import Example, Run
16 | from langsmith.utils import LangSmithConflictError
17 | from pydantic import BaseModel, Field, model_validator
18 | from promptim._utils import get_var_healer
19 | import logging
20 |
21 | logger = logging.getLogger(__name__)
22 |
23 | DEFAULT_PROMPT_MODEL_CONFIG = {"model": "claude-3-5-haiku-20241022"}
24 | DEFAULT_OPTIMIZER_MODEL_CONFIG = {
25 | "model": "claude-3-5-sonnet-20241022",
26 | "max_tokens_to_sample": 8192,
27 | }
28 |
29 |
30 | SystemType = Callable[[ChatPromptTemplate, dict], dict]
31 | """Takes the current prompt and the example inputs and returns the results."""
32 |
33 |
34 | @dataclass(kw_only=True)
35 | class PromptConfig:
36 | identifier: str | None = field(
37 | default=None,
38 | metadata={
39 | "description": "Identifier for a prompt from the hub repository. Mutually exclusive with prompt_str."
40 | },
41 | )
42 | prompt_str: str | None = field(
43 | default=None,
44 | metadata={
45 | "description": "Raw prompt string to optimize locally. Mutually exclusive with identifier."
46 | },
47 | )
48 | model_config: dict | None = field(
49 | default=None,
50 | metadata={
51 | "description": "Configuration dictionary specifying model parameters for optimization."
52 | },
53 | )
54 | which: int = field(
55 | default=0,
56 | metadata={"description": "Index of the message to optimize within the prompt."},
57 | )
58 | upload_to: str | None = field(
59 | default=None,
60 | metadata={
61 |             "description": "Hub identifier to upload the optimized prompt to. Defaults to the prompt's identifier if one is provided."
62 | },
63 | )
64 |
65 | def __post_init__(self):
66 | if self.identifier and self.prompt_str:
67 | raise ValueError(
68 | "Cannot provide both identifier and prompt_str. Choose one."
69 | )
70 | elif not self.identifier and not self.prompt_str:
71 | raise ValueError("Must provide either identifier or prompt_str.")
72 | if self.identifier and not self.upload_to:
73 | self.upload_to = self.identifier
74 |
75 |
76 | @dataclass(kw_only=True)
77 | class PromptWrapper(PromptConfig):
78 | _cached: ChatPromptTemplate | None = None
79 | _postlude: RunnableBinding | BaseChatModel | None = None
80 | lineage: list["PromptWrapper"] | None = None
81 | extra: dict | None = None
82 |
83 | @classmethod
84 | def from_config(cls, config: PromptConfig | dict):
85 | if isinstance(config, dict):
86 | config = PromptConfig(**config)
87 | return cls(
88 | identifier=config.identifier,
89 | prompt_str=config.prompt_str,
90 | model_config=config.model_config,
91 | which=config.which,
92 | )
93 |
94 | def load(self, client: ls.Client | None = None) -> ChatPromptTemplate:
95 | if self._cached is None:
96 | if self.prompt_str:
97 | self._cached = ChatPromptTemplate.from_messages(
98 | [("user", self.prompt_str)]
99 | )
100 | self._postlude = init_chat_model(
101 | **(self.model_config or DEFAULT_PROMPT_MODEL_CONFIG)
102 | )
103 | else:
104 | client = client or ls.Client()
105 | postlude = None
106 | prompt = client.pull_prompt(self.identifier, include_model=True)
107 | if isinstance(prompt, RunnableSequence):
108 | prompt, bound_llm = prompt.first, prompt.steps[1]
109 | new_model = None
110 |
111 | if isinstance(bound_llm, RunnableBinding):
112 | if tools := bound_llm.kwargs.get("tools"):
113 | bound_llm.kwargs["tools"] = _ensure_stricty(tools)
114 | if new_model:
115 | bound_llm = new_model.bind(
116 | **{
117 | k: v
118 | for k, v in bound_llm.kwargs.items()
119 | if k not in ("model", "model_name")
120 | }
121 | )
122 | else:
123 | if new_model:
124 | bound_llm = new_model
125 | if isinstance(prompt, StructuredPrompt) and isinstance(
126 | bound_llm, RunnableBinding
127 | ):
128 | seq: RunnableSequence = prompt | bound_llm.bound
129 |
130 | rebound_llm = seq.steps[1]
131 | if tools := rebound_llm.kwargs.get("tools"):
132 | rebound_llm.kwargs["tools"] = _ensure_stricty(tools)
133 | parser = seq.steps[2]
134 | postlude = RunnableSequence(
135 | rebound_llm.bind(
136 | **{
137 | k: v
138 | for k, v in (
139 | dict((bound_llm.kwargs or {}))
140 | | (self.model_config or {})
141 | ).items()
142 | if k not in rebound_llm.kwargs
143 | and k not in ("model", "model_name")
144 | }
145 | ),
146 | parser,
147 | )
148 | else:
149 | postlude = bound_llm
150 | else:
151 | postlude = init_chat_model(
152 | **(self.model_config or DEFAULT_PROMPT_MODEL_CONFIG)
153 | )
154 | if isinstance(prompt, StructuredPrompt):
155 | postlude = RunnableSequence(*(prompt | postlude).steps[1:])
156 | self._cached = prompt
157 | self._postlude = postlude
158 | return self._cached
159 |
160 | def get_prompt_str(self, client: ls.Client | None = None) -> str:
161 | tmpl = self.load(client)
162 | msg = tmpl.messages[self.which]
163 | try:
164 | return (
165 | "{{messages}}"
166 | if isinstance(msg, MessagesPlaceholder)
167 | else msg.prompt.template
168 | )
169 | except Exception as e:
170 | raise NotImplementedError(
171 | f"Unsupported message template format. {msg}"
172 | ) from e
173 |
174 | def required_variables(self) -> set[str]:
175 | tmpl = self.load()
176 | return set(tmpl.messages[self.which].input_variables)
177 |
178 | def get_prompt_str_in_context(self, client: ls.Client | None = None) -> str:
179 | tmpl = self.load(client)
180 | formatted = []
181 | for i, msg in enumerate(tmpl.messages):
182 | kind = msg.__class__.__name__.replace("MessagePromptTemplate", "").replace(
183 | "Human", "User"
184 | )
185 | if i == self.which:
186 | tmpl = "{{messages}}" if isinstance(msg, MessagesPlaceholder) else msg.prompt.template # type: ignore
187 | formatted.append(
188 | f"""
189 | {tmpl}
190 | """
191 | )
192 | else:
193 | tmpl = "{{messages}}" if isinstance(msg, MessagesPlaceholder) else msg.prompt.template # type: ignore
194 | formatted.append(
195 | f"""
196 | {tmpl}
197 |
198 | """
199 | )
200 | return "\n".join(formatted)
201 |
202 | @classmethod
203 | def from_prior(
204 | cls, prior: "PromptWrapper", output: str, extra_info: dict | None = None
205 | ):
206 | copied = prior._cached
207 | if not copied:
208 | raise ValueError("Cannot load from unloaded prior.")
209 | extra_info = extra_info or {}
210 | copied = copy.deepcopy(copied)
211 | tmpl = copied.messages[prior.which]
212 | tmpl.prompt.template = output # type: ignore
213 | lineage = prior.lineage.copy() if prior.lineage else []
214 | lineage.append(prior)
215 | return cls(
216 | identifier=prior.identifier,
217 | prompt_str=prior.prompt_str,
218 | which=prior.which,
219 | _cached=copied,
220 | _postlude=prior._postlude,
221 | lineage=lineage,
222 | extra=extra_info,
223 | upload_to=prior.upload_to,
224 | )
225 |
226 | def push_prompt(
227 | self,
228 | *,
229 | include_model_info: bool = True,
230 | client: ls.Client | None = None,
231 | ) -> str:
232 | if not self.upload_to:
233 | raise ValueError("Cannot push prompt without an upload target.")
234 | client = client or ls.Client()
235 | prompt = self.load(client)
236 | identifier = self.upload_to.rsplit(":", maxsplit=1)[0]
237 | try:
238 | if not include_model_info or not self._postlude:
239 | new_id = client.push_prompt(identifier, object=prompt)
240 | else:
241 | seq = self._get_seq(client)
242 | return self._push_seq(client, seq, identifier)
243 |
244 | except LangSmithConflictError:
245 | return identifier
246 |
247 | return ":".join(
248 | new_id
249 | # Remove the https:// prefix
250 | .split("/prompts/", maxsplit=1)[1]
251 | # Rm query string
252 | .split("?")[0]
253 | # Split the repo from the commit hash
254 | .rsplit("/", maxsplit=1)
255 | )
256 |
257 | def _get_seq(self, client: ls.Client | None = None):
258 | prompt = self.load(client)
259 | second = (
260 | self._postlude.first
261 | if isinstance(self._postlude, RunnableSequence)
262 | else self._postlude
263 | )
264 | if second:
265 | return RunnableSequence(prompt, second)
266 | return prompt
267 |
268 | @staticmethod
269 | def _push_seq(client: ls.Client, seq: RunnableSequence, identifier: str):
270 | manifest = json.loads(dumps(seq))
271 | manifest["id"] = ("langsmith", "playground", "PromptPlayground")
272 | return client.push_prompt(identifier, object=manifest)
273 |
274 | def dumps(self, push: bool = False) -> str:
275 | if push:
276 | identifier = self.push_prompt(include_model_info=False)
277 | else:
278 | identifier = self.identifier
279 | d = {
280 | "identifier": identifier,
281 | "prompt_str": (
282 | self.prompt_str if self.prompt_str else self.get_prompt_str_in_context()
283 | ),
284 | "model_config": self.model_config,
285 | "which": self.which,
286 | "manifest": self._get_seq(client=None),
287 | }
288 | return dumps(d)
289 |
290 |
291 | @dataclass(kw_only=True)
292 | class TaskLike:
293 | """Represents a specific task for prompt optimization."""
294 |
295 | name: str
296 | """The identifier for the task, used for logging and referencing."""
297 | dataset: str
298 | """The name of the dataset in LangSmith to be used for training and evaluation."""
299 | initial_prompt: PromptConfig
300 | """The starting prompt configuration, which will be optimized during the process."""
301 | description: str = ""
302 | """A detailed explanation of the task's objectives and constraints."""
303 | evaluator_descriptions: dict = field(default_factory=dict)
304 | """A mapping of evaluator names to their descriptions, used to guide the optimization process."""
305 | baseline_experiment: Optional[UUID] = None
306 | """The UUID of a previous experiment to use as a baseline for comparison, if available."""
307 |
308 |
309 | @dataclass(kw_only=True)
310 | class Task(TaskLike):
311 | """Represents a specific task for prompt optimization with additional execution details."""
312 |
313 | evaluators: list[Callable[[Run, Example], dict]]
314 | """A list of functions that assess the quality of model outputs, each returning a score and optional feedback."""
315 | system: Optional[SystemType] = None
316 | """A custom system configuration for executing the prompt, allowing for task-specific processing."""
317 |
318 | @classmethod
319 | def from_dict(cls, d: dict):
320 | d_ = d.copy()
321 | kwargs = {"initial_prompt": PromptWrapper(**d_.pop("initial_prompt")), **d_}
322 |
323 | field_names = {f.name for f in fields(cls)}
324 | kwargs = {k: v for k, v in kwargs.items() if k in field_names}
325 | return cls(**kwargs)
326 |
327 | def describe(self):
328 | descript = self.description if self.description else self.name
329 | evaluator_desc = "\n".join(
330 | [f"- {key}: {value}" for key, value in self.evaluator_descriptions.items()]
331 | )
332 | return f"{descript}\n\nDescription of scores:\n{evaluator_desc}"
333 |
334 | @staticmethod
335 | def get_prompt_system(prompt_wrapper: PromptWrapper):
336 | async def prompt_system(prompt: ChatPromptTemplate, inputs: dict):
337 | formatted = prompt.invoke(inputs)
338 | return await prompt_wrapper._postlude.ainvoke(formatted)
339 |
340 | return prompt_system
341 |
342 | @property
343 | def system_safe(self) -> SystemType:
344 | if self.system:
345 | return self.system
346 |
347 | prompt = PromptWrapper.from_config(self.initial_prompt)
348 | return self.get_prompt_system(prompt)
349 |
350 |
351 | class OptimizedPromptOutput(Protocol):
352 | analysis: str
353 | hypothesis: str
354 | improved_prompt: str
355 |
356 |
357 | def prompt_schema(
358 | og_prompt: PromptWrapper,
359 | schema: type[OptimizedPromptOutput] = OptimizedPromptOutput,
360 | ) -> type[OptimizedPromptOutput]:
361 | required_variables = og_prompt.required_variables()
362 | if required_variables:
363 | variables_str = ", ".join(f"{{{var}}}" for var in required_variables)
364 | prompt_description = (
365 | f" The prompt section being optimized contains the following f-string variables to be templated in: {variables_str}."
366 | " You must retain all of these variables in your improved prompt. No other input variables are allowed."
367 | )
368 | else:
369 | prompt_description = (
370 | " The prompt section being optimized contains no input f-string variables."
371 | " Any brackets {{ foo }} you emit will be escaped and not used."
372 | )
373 |
374 | pipeline = get_var_healer(set(required_variables), all_required=True)
375 |
376 | class OptimizedPromptOutput(BaseModel):
377 | """Schema for the optimized prompt output."""
378 |
379 | analysis: str = Field(
380 | description="First, analyze the current results and plan improvements to reconcile them."
381 | )
382 | hypothesis: str = Field(
383 | description="Second, write your hypothesis on what prompt intervention you are making to fix the prompt's errors."
384 | )
385 | improved_prompt: str = Field(
386 | description="The improved prompt text to replace the text contained within the"
387 |             f" opening and closing tags, in f-string format. Do not include the tags in your response. {prompt_description}"
388 | )
389 |
390 | @model_validator(mode="before")
391 | @classmethod
392 | def validate_input_variables(cls, data: Any) -> Any:
393 | assert "improved_prompt" in data
394 | data["improved_prompt"] = pipeline(data["improved_prompt"])
395 | return data
396 |
397 | return OptimizedPromptOutput
398 |
399 |
400 | def _ensure_stricty(tools: list) -> list:
401 | result = []
402 | for tool in tools:
403 | if isinstance(tool, dict):
404 | strict = None
405 | if func := tool.get("function"):
406 | if parameters := func.get("parameters"):
407 | if "strict" in parameters:
408 | strict = parameters["strict"]
409 | if strict is not None:
410 | tool = copy.deepcopy(tool)
411 | tool["function"]["strict"] = strict
412 | result.append(tool)
413 | return result
414 |
--------------------------------------------------------------------------------
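
The dataclasses above are the core configuration surface: `PromptWrapper` holds either a hub identifier or a raw prompt string, and `Task` bundles a prompt with a dataset name and evaluators. A minimal sketch of defining a local task from a raw prompt string; the task name, dataset, prompt text, and evaluator below are illustrative only:

```python
# Hypothetical local task: only the promptim types are real; everything else
# is an illustrative placeholder.
from promptim.types import PromptWrapper, Task


def contains_greeting(run, example):
    """Score 1 if the model output says hello, else 0."""
    text = run.outputs.get("output", "")
    return {"key": "contains_greeting", "score": int("hello" in text.lower())}


task = Task(
    name="Greeting Demo",
    dataset="greeting-demo",
    initial_prompt=PromptWrapper(prompt_str="Write a friendly greeting about {topic}."),
    evaluators=[contains_greeting],
    evaluator_descriptions={"contains_greeting": "1 if the reply says hello, else 0."},
)
print(task.describe())
```
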
/static/optimizer.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hinthornw/promptimizer/2052e812fc52cff19d8539b30b1782385f011c1e/static/optimizer.gif
--------------------------------------------------------------------------------
/tests/test_optimizers.py:
--------------------------------------------------------------------------------
1 | from promptim.optimizers import (
2 | load_optimizer,
3 | FeedbackGuidedOptimizer,
4 | MetaPromptOptimizer,
5 | FewShotOptimizer,
6 | )
7 |
8 |
9 | def test_config_kind():
10 | optimizers = [FewShotOptimizer, MetaPromptOptimizer, FeedbackGuidedOptimizer]
11 | _MAP = {OptimizerCls.config_cls.kind: OptimizerCls for OptimizerCls in optimizers}
12 | assert len(_MAP) == len(optimizers)
13 |
14 | for kind in _MAP:
15 | loaded = load_optimizer({"kind": kind})
16 | assert isinstance(loaded, _MAP[kind])
17 |
--------------------------------------------------------------------------------
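
The test above relies on each optimizer class advertising its registry key through `config_cls.kind`. A minimal sketch of that loader contract, assuming the default optimizer configuration can be instantiated in your environment (the test does exactly this):

```python
# Sketch of the loader contract the test exercises: each optimizer exposes its
# registry key via config_cls.kind, and load_optimizer resolves it back.
from promptim.optimizers import MetaPromptOptimizer, load_optimizer

kind = MetaPromptOptimizer.config_cls.kind
optimizer = load_optimizer({"kind": kind})
assert isinstance(optimizer, MetaPromptOptimizer)
print(kind, type(optimizer).__name__)
```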