├── .github
│   └── workflows
│       ├── pypi-release.yml
│       └── tests.yml
├── .gitignore
├── .idea
│   ├── .gitignore
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── lib.iml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── .vscode
│   └── settings.json
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── litellm
│   ├── .env.sample
│   ├── README.md
│   ├── config.yaml
│   └── docker-compose.yml
└── py
    ├── .env.sample
    ├── .gitignore
    ├── .pre-commit-config.yaml
    ├── .python-version
    ├── README.md
    ├── cookbooks
    │   ├── bootstrap_few_shot
    │   │   ├── arize.ipynb
    │   │   ├── langfuse.ipynb
    │   │   ├── langsmith.ipynb
    │   │   ├── lunary.ipynb
    │   │   └── parea.ipynb
    │   ├── labeled_few_shot
    │   │   ├── arize.ipynb
    │   │   ├── json.ipynb
    │   │   ├── langfuse.ipynb
    │   │   ├── langsmith.ipynb
    │   │   ├── lunary.ipynb
    │   │   └── parea.ipynb
    │   └── predefined
    │       ├── single_class_clasifier_trained_with_syn_data.ipynb
    │       ├── single_class_classifier.ipynb
    │       └── single_class_classifier_syn_data.ipynb
    ├── pyproject.toml
    ├── requirements-dev.lock
    ├── requirements.lock
    ├── src
    │   └── zenbase
    │       ├── __init__.py
    │       ├── adaptors
    │       │   ├── __init__.py
    │       │   ├── arize.py
    │       │   ├── arize
    │       │   │   ├── __init__.py
    │       │   │   ├── adaptor.py
    │       │   │   ├── dataset_helper.py
    │       │   │   └── evaluation_helper.py
    │       │   ├── base
    │       │   │   ├── adaptor.py
    │       │   │   ├── dataset_helper.py
    │       │   │   └── evaluation_helper.py
    │       │   ├── braintrust.py
    │       │   ├── json
    │       │   │   ├── adaptor.py
    │       │   │   ├── dataset_helper.py
    │       │   │   └── evaluation_helper.py
    │       │   ├── langchain
    │       │   │   ├── __init__.py
    │       │   │   ├── adaptor.py
    │       │   │   ├── dataset_helper.py
    │       │   │   └── evaluation_helper.py
    │       │   ├── langfuse_helper
    │       │   │   ├── __init__.py
    │       │   │   ├── adaptor.py
    │       │   │   ├── dataset_helper.py
    │       │   │   └── evaluation_helper.py
    │       │   ├── lunary
    │       │   │   ├── __init__.py
    │       │   │   ├── adaptor.py
    │       │   │   ├── dataset_helper.py
    │       │   │   └── evaluation_helper.py
    │       │   └── parea
    │       │       ├── __init__.py
    │       │       ├── adaptor.py
    │       │       ├── dataset_helper.py
    │       │       └── evaluation_helper.py
    │       ├── core
    │       │   └── managers.py
    │       ├── optim
    │       │   ├── base.py
    │       │   └── metric
    │       │       ├── bootstrap_few_shot.py
    │       │       ├── labeled_few_shot.py
    │       │       └── types.py
    │       ├── predefined
    │       │   ├── __init__.py
    │       │   ├── base
    │       │   │   ├── function_generator.py
    │       │   │   └── optimizer.py
    │       │   ├── generic_lm_function
    │       │   │   └── optimizer.py
    │       │   ├── single_class_classifier
    │       │   │   ├── __init__.py
    │       │   │   ├── classifier.py
    │       │   │   └── function_generator.py
    │       │   └── syntethic_data
    │       │       └── single_class_classifier.py
    │       ├── settings.py
    │       ├── types.py
    │       └── utils.py
    └── tests
        ├── adaptors
        │   ├── bootstrap_few_shot_optimizer_args.zenbase
        │   ├── parea_bootstrap_few_shot.zenbase
        │   ├── test_arize.py
        │   ├── test_braintrust.py
        │   ├── test_langchain.py
        │   ├── test_langfuse.py
        │   ├── test_lunary.py
        │   └── test_parea.py
        ├── conftest.py
        ├── core
        │   └── managers.py
        ├── optim
        │   └── metric
        │       ├── test_bootstrap_few_shot.py
        │       └── test_labeled_few_shot.py
        ├── predefined
        │   ├── test_generic_lm_function_optimizer.py
        │   ├── test_single_class_classifier.py
        │   └── test_single_class_classifier_syntethic_data.py
        ├── sciprts
        │   ├── clean_up_langsmith.py
        │   └── convert_notebooks.py
        └── test_types.py
/.github/workflows/pypi-release.yml:
--------------------------------------------------------------------------------
1 | name: Publish to PyPI
2 |
3 | on:
4 | release:
5 | types: [published]
6 |
7 |
8 | jobs:
9 | deploy:
10 | runs-on: ubuntu-latest
11 | strategy:
12 | matrix:
13 | python-version: ["3.10"]
14 |
15 | steps:
16 | - uses: actions/checkout@v4
17 | - uses: eifinger/setup-rye@v3
18 | id: setup-rye
19 | with:
20 | enable-cache: true
21 | working-directory: py
22 | cache-prefix: ${{ matrix.python-version }}
23 | - name: Pin python-version ${{ matrix.python-version }}
24 | working-directory: py
25 | run: rye pin ${{ matrix.python-version }}
26 | - name: Update Rye
27 | run: rye self update
28 | - name: Install dependencies
29 | working-directory: py
30 | run: rye sync --no-lock
31 | - name: Run Tests
32 | working-directory: py
33 | run: rye run pytest -v
34 | - name: Build package
35 | working-directory: py
36 | run: rye build
37 | - name: Publish to PyPI
38 | working-directory: py
39 | env:
40 | PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }}
41 | run: rye publish --token $PYPI_TOKEN --yes --skip-existing
42 |
43 |
--------------------------------------------------------------------------------
/.github/workflows/tests.yml:
--------------------------------------------------------------------------------
1 | name: CI
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | pull_request_target:
8 | branches:
9 | - main
10 | - rc*
11 |
12 | jobs:
13 | build:
14 | runs-on: ubuntu-latest
15 | strategy:
16 | matrix:
17 | python-version: ["3.10", "3.11", "3.12"]
18 |
19 | steps:
20 | - uses: actions/checkout@v4
21 | - uses: eifinger/setup-rye@v3
22 | id: setup-rye
23 | with:
24 | enable-cache: true
25 | working-directory: py
26 | cache-prefix: ${{ matrix.python-version }}
27 | - name: Pin python-version ${{ matrix.python-version }}
28 | working-directory: py
29 | run: rye pin ${{ matrix.python-version }}
30 | - name: Install dependencies
31 | working-directory: py
32 | run: rye sync --no-lock
33 | - name: Check py directory contents
34 | run: ls -la py
35 | - name: Run Tests
36 | working-directory: py
37 | run: rye run pytest -v -n auto
38 | - name: Run Integration Tests
39 | working-directory: py
40 | run: rye run pytest -v -m helpers -n auto
41 | env:
42 | LANGCHAIN_API_KEY: ${{ secrets.LANGCHAIN_API_KEY }}
43 | LANGCHAIN_TRACING_V2: ${{ secrets.LANGCHAIN_TRACING_V2 }}
44 | LANGFUSE_HOST: ${{ secrets.LANGFUSE_HOST }}
45 | LANGFUSE_PUBLIC_KEY: ${{ secrets.LANGFUSE_PUBLIC_KEY }}
46 | LANGFUSE_SECRET_KEY: ${{ secrets.LANGFUSE_SECRET_KEY }}
47 | LANGSMITH_TEST_TRACKING: ${{ secrets.LANGSMITH_TEST_TRACKING }}
48 | LUNARY_PUBLIC_KEY: ${{ secrets.LUNARY_PUBLIC_KEY }}
49 | OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
50 | PAREA_API_KEY: ${{ secrets.PAREA_API_KEY }}
51 | BRAINTRUST_API_KEY: ${{ secrets.BRAINTRUST_API_KEY }}
52 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .env*
2 | !.env.sample
3 | .DS_Store
4 |
5 | .idea/
6 | # User-specific stuff
7 | .idea/**/workspace.xml
8 | .idea/**/tasks.xml
9 | .idea/**/usage.statistics.xml
10 | .idea/**/dictionaries
11 | .idea/**/shelf
12 |
13 | litellm/redis-data
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/lib.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "files.exclude": {
3 | "**/.git": true,
4 | "**/.svn": true,
5 | "**/.hg": true,
6 | "**/CVS": true,
7 | "**/.DS_Store": true,
8 | "**/Thumbs.db": true,
9 | "py/.venv/": true,
10 | "py/**/__pycache__": true,
11 | "py/.pytest_cache/": true
12 | },
13 | "python.defaultInterpreterPath": "py/.venv/bin/python",
14 | "python.analysis.extraPaths": [
15 | "./py/src"
16 | ],
17 | "python.testing.pytestArgs": [
18 | "py"
19 | ],
20 | "python.testing.unittestEnabled": false,
21 | "python.testing.pytestEnabled": true
22 | }
23 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contribution Guidelines
2 |
3 | Thank you for your interest in contributing to our project! Any and all contributions are welcome. This document provides some basic guidelines for contributing.
4 |
5 | ## Bug Reports
6 |
7 | If you find a bug in the code, you can help us by submitting an issue to our GitHub Repository. Even better, you can submit a Pull Request with a fix.
8 |
9 | ## Feature Requests
10 |
11 | If you'd like to make a feature request, you can do so by submitting an issue to our GitHub Repository. You can also submit a Pull Request to help implement that feature.
12 |
13 | ## Pull Requests
14 |
15 | Here are some guidelines for submitting a pull request:
16 |
17 | - Fork the repository and clone it locally. Connect your local repository to the original `upstream` repository by adding it as a remote. Pull in changes from `upstream` often so that you stay up to date with the latest code.
18 | - Create a branch for your edits.
19 | - Be clear about what problem is occurring and how someone can reproduce that problem, or why your feature will help. Then be equally clear about the steps you took to make your changes.
20 | - It's always best to consult the existing issues before creating a new one.
21 |
22 | ## Questions?
23 |
24 | If you have any questions, please feel free to contact us.
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 zenbase-ai
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
8 |
9 | # Zenbase Core — Prompt engineering, automated.
10 |
11 | Zenbase Core is the library for programming — not prompting — AI in production. It is a spin-out of [Stanford NLP's DSPy](https://github.com/stanfordnlp/dspy), started by key DSPy contributors. DSPy is an excellent framework for R&D; Zenbase strives to be its counterpart for software engineering.
12 |
13 | ## Key Points
14 |
15 | - **DSPy** is optimized for research and development, providing tools specifically designed for deep exploration and optimization of AI systems. You run experiments with AI systems and report the results.
16 | - **Zenbase** focuses on adapting these research advancements to practical software engineering needs. Its optimizers can be integrated into existing systems, helping you deploy automatic prompt optimization in production.
17 |
18 | ## Quick Start with Python
19 |
20 | ```bash
21 | pip install zenbase
22 | ```
23 |
24 | For more information take a look at the [Zenbase Core Python README](./py/README.md).
25 |
26 | ## License
27 |
28 | This project is licensed under the MIT License. See the [LICENSE](./LICENSE) file for more information.
29 |
30 |
31 | ## Contribution
32 |
33 | Thank you for your interest in contributing to our project! Any and all contributions are welcome. See [CONTRIBUTING](./CONTRIBUTING.md) for more information.
34 |
--------------------------------------------------------------------------------
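The Quick Start above stops at installation. As a condensed, hedged sketch of what a first optimization run can look like, the snippet below is distilled from the predefined cookbook notebooks under `py/cookbooks/predefined/` included later in this dump; the prompt text and category descriptions are illustrative placeholders, and an `OPENAI_API_KEY` is assumed to be set in the environment.

```python
import instructor
from openai import OpenAI

from zenbase.core.managers import ZenbaseTracer
from zenbase.predefined.single_class_classifier import SingleClassClassifier
from zenbase.predefined.syntethic_data.single_class_classifier import (
    SingleClassClassifierSyntheticDataGenerator,
)

# Clients as used in the predefined cookbooks
openai_client = OpenAI()
instructor_client = instructor.from_openai(openai_client)
zenbase_tracer = ZenbaseTracer()

prompt = "Categorize each incoming news article into one of the given categories."
class_dict = {
    "Science": "News about scientific topics such as space, medicine, and electronics.",
    "Politics": "News and debates about political topics.",
}

# Generate small synthetic train/validation/test splits
generator = SingleClassClassifierSyntheticDataGenerator(
    instructor_client=instructor_client,
    prompt=prompt,
    class_dict=class_dict,
    model="gpt-4o-mini",
)
train_set = generator.generate_examples(10)
val_set = generator.generate_examples(3)
test_set = generator.generate_examples(3)

# Optimize the classifier and keep the best candidate function
classifier = SingleClassClassifier(
    instructor_client=instructor_client,
    prompt=prompt,
    class_dict=class_dict,
    model="gpt-4o-mini",
    zenbase_tracer=zenbase_tracer,
    training_set=train_set,
    validation_set=val_set,
    test_set=test_set,
    samples=20,
)
result = classifier.optimize()

article = "title: New rover lands on Mars\ncontent: The mission will search for signs of ancient life."
print(result.best_function(article).class_label.name)
```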
/litellm/.env.sample:
--------------------------------------------------------------------------------
1 | REDIS_HOST=redis
2 | REDIS_PORT=6379
3 | REDIS_SSL=False
4 | OPENAI_API_KEY="YOUR_API_KEY"
5 |
--------------------------------------------------------------------------------
/litellm/README.md:
--------------------------------------------------------------------------------
1 | ## LiteLLM Cache
2 |
3 | LiteLLM Cache is a proxy server designed to cache your LLM requests, helping to reduce costs and improve efficiency.
4 |
5 | ### Requirements
6 | - Docker Compose
7 | - Docker
8 |
9 | ### Setup Instructions
10 |
11 | 1. **Configure Settings:**
12 | - Navigate to `./config.yaml` and update the configuration as per your requirements. For more information, visit [LiteLLM Documentation](https://litellm.vercel.app/).
13 |
14 | 2. **Prepare Environment Variables:**
15 | - Create a `.env` file from the `.env.sample` file. Adjust the details in `.env` to match your `config.yaml` settings.
16 |
17 | 3. **Start the Docker Container:**
18 | ```bash
19 | docker-compose up -d
20 | ```
21 |
22 | 4. **Update Your LLM Server URL:**
23 | - Point your application's LLM base URL at `http://0.0.0.0:4000`.
24 |
25 | For example, using the OpenAI Python SDK:
26 | ```python
27 | from openai import OpenAI
28 |
29 | llm = OpenAI(
30 | base_url='http://0.0.0.0:4000'
31 | )
32 | ```
33 |
34 | With these steps, your LLM requests will be routed through the LiteLLM Cache proxy server, optimizing performance and reducing costs.
--------------------------------------------------------------------------------
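To make the switch concrete, here is a short, hedged sketch of sending a request through the proxy with the OpenAI SDK. The `gpt-3.5-turbo` model name mirrors `config.yaml`; the placeholder API key assumes no virtual keys are configured on the proxy, which forwards its own `OPENAI_API_KEY` from `.env` to the upstream provider.

```python
from openai import OpenAI

# Point the client at the LiteLLM proxy instead of api.openai.com
llm = OpenAI(
    base_url="http://0.0.0.0:4000",
    api_key="sk-placeholder",  # assumption: no proxy-side virtual keys are configured
)

response = llm.chat.completions.create(
    model="gpt-3.5-turbo",  # must match a model_name in config.yaml
    messages=[{"role": "user", "content": "Say hello through the cache proxy."}],
)
print(response.choices[0].message.content)
```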
/litellm/config.yaml:
--------------------------------------------------------------------------------
1 | model_list:
2 | - model_name: gpt-3.5-turbo
3 | litellm_params:
4 | model: openai/gpt-3.5-turbo # The `openai/` prefix will call openai.chat.completions.create
5 | api_key: os.environ/OPENAI_API_KEY # The `os.environ/` prefix will call os.environ.get
6 |
7 | router_settings:
8 | redis_host: redis
9 | redis_port: 6379
--------------------------------------------------------------------------------
/litellm/docker-compose.yml:
--------------------------------------------------------------------------------
1 | version: '3'
2 |
3 | services:
4 | litellm:
5 | image: ghcr.io/berriai/litellm:main-latest
6 | ports:
7 | - "4000:4000"
8 | volumes:
9 | - ./config.yaml:/app/config.yaml # Mount the local configuration file
10 | command: [ "--config", "/app/config.yaml", "--port", "4000", "--num_workers", "8", "--detailed_debug"]
11 | env_file:
12 | - .env
13 |
14 | redis:
15 | image: redis:alpine
16 | ports:
17 | - "6379:6379"
18 | volumes:
19 | - ./redis-data:/data
20 | command: redis-server --appendonly yes
21 | restart: always
22 |
--------------------------------------------------------------------------------
/py/.env.sample:
--------------------------------------------------------------------------------
1 | LANGCHAIN_API_KEY="YOUR-LANGCHAIN-API-KEY"
2 | LANGCHAIN_TRACING_V2=true
3 | LANGFUSE_HOST="https://us.cloud.langfuse.com"
4 | LANGFUSE_PUBLIC_KEY="YOUR-LANGFUSE-PUBLIC-KEY"
5 | LANGFUSE_SECRET_KEY="YOUR-LANGFUSE-SECRET-KEY"
6 | LANGSMITH_TEST_TRACKING=tests/cache/
7 | LUNARY_PUBLIC_KEY="YOUR-LUNARY-PUBLIC-KEY"
8 | OPENAI_API_KEY="YOUR-OPENAI-API"
9 | PAREA_API_KEY="YOUR-PAREA-API-KEY"
10 | BRAINTRUST_API_KEY="YOUR-BRAINTRUST-API-KEY"
--------------------------------------------------------------------------------
/py/.gitignore:
--------------------------------------------------------------------------------
1 | # python generated files
2 | __pycache__/
3 | *.py[oc]
4 | build/
5 | dist/
6 | wheels/
7 | *.egg-info
8 |
9 | # venv
10 | .venv
11 | .idea/
12 |
--------------------------------------------------------------------------------
/py/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: local
3 | hooks:
4 | - id: rye-lint
5 | name: rye-lint
6 | description: "Lint Python via 'rye lint'"
7 | entry: sh -c 'cd py && rye lint --fix'
8 | language: system
9 | types_or: [python, pyi]
10 | files: ^py/
11 | args: []
12 | require_serial: true
13 | additional_dependencies: []
14 | minimum_pre_commit_version: "2.9.2"
15 |
16 | - id: rye-format
17 | name: rye-format
18 | description: "Format Python via 'rye fmt'"
19 | entry: sh -c 'cd py && rye fmt'
20 | language: system
21 | types_or: [python, pyi]
22 | files: ^py/
23 | args: []
24 | require_serial: true
25 | additional_dependencies: []
26 | minimum_pre_commit_version: "2.9.2"
27 |
28 | - id: rye-test
29 | name: rye-test
30 | description: "Test Python via 'rye test'"
31 | entry: sh -c 'cd py && rye test'
32 | language: system
33 | types_or: [python, pyi]
34 | files: ^py/
35 | args: []
36 | pass_filenames: false
37 | require_serial: true
38 | additional_dependencies: []
39 | minimum_pre_commit_version: "2.9.2"
40 |
--------------------------------------------------------------------------------
/py/.python-version:
--------------------------------------------------------------------------------
1 | 3.10.13
2 |
--------------------------------------------------------------------------------
/py/README.md:
--------------------------------------------------------------------------------
1 | # Zenbase Python SDK
2 |
3 | ## Installation
4 |
5 |
6 |
7 | Zenbase requires Python ≥3.10. You can install it using your favorite package manager:
8 |
9 | ```bash
10 | pip install zenbase
11 | poetry add zenbase
12 | rye add zenbase
13 | ```
14 |
15 | ## Usage
16 |
17 | Zenbase is designed to require minimal changes to your existing codebase and integrate seamlessly with your existing eval/observability platforms. It works with any AI SDK (OpenAI, Anthropic, Cohere, Langchain, etc.).
18 |
19 |
20 | ### Labeled Few-Shot Learning Cookbooks:
21 |
22 | LabeledFewShot is useful for tasks that consist of a single layer of prompts.
23 |
24 | | Cookbook | Run in Colab |
25 | |---------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
26 | | [without_integration.ipynb](cookbooks/labeled_few_shot/json.ipynb) | [Open In Colab](https://colab.research.google.com/github/zenbase-ai/core/blob/main/py/cookbooks/labeled_few_shot/json.ipynb) |
27 | | [langsmith.ipynb](cookbooks/labeled_few_shot/langsmith.ipynb) | [Open In Colab](https://colab.research.google.com/github/zenbase-ai/core/blob/main/py/cookbooks/labeled_few_shot/langsmith.ipynb) |
28 | | [arize.ipynb](cookbooks/labeled_few_shot/arize.ipynb) | [Open In Colab](https://colab.research.google.com/github/zenbase-ai/core/blob/main/py/cookbooks/labeled_few_shot/arize.ipynb) |
29 | | [langfuse.ipynb](cookbooks/labeled_few_shot/langfuse.ipynb) | [Open In Colab](https://colab.research.google.com/github/zenbase-ai/core/blob/main/py/cookbooks/labeled_few_shot/langfuse.ipynb) |
30 | | [parea.ipynb](cookbooks/labeled_few_shot/parea.ipynb) | [Open In Colab](https://colab.research.google.com/github/zenbase-ai/core/blob/main/py/cookbooks/labeled_few_shot/parea.ipynb) |
31 | | [lunary.ipynb](cookbooks/labeled_few_shot/lunary.ipynb) | [Open In Colab](https://colab.research.google.com/github/zenbase-ai/core/blob/main/py/cookbooks/labeled_few_shot/lunary.ipynb) |
32 |
33 | ### Bootstrap Few-Shot Learning Cookbooks:
34 |
35 | BootstrapFewShot is useful for tasks that involve multiple layers of prompts.
36 |
37 | | Cookbook | Run in Colab |
38 | |-----------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
39 | | [langsmith.ipynb](cookbooks/bootstrap_few_shot/langsmith.ipynb) | [Open In Colab](https://colab.research.google.com/github/zenbase-ai/core/blob/main/py/cookbooks/bootstrap_few_shot/langsmith.ipynb) |
40 | | [arize.ipynb](cookbooks/bootstrap_few_shot/arize.ipynb) | [Open In Colab](https://colab.research.google.com/github/zenbase-ai/core/blob/main/py/cookbooks/bootstrap_few_shot/arize.ipynb) |
41 | | [langfuse.ipynb](cookbooks/bootstrap_few_shot/langfuse.ipynb) | [Open In Colab](https://colab.research.google.com/github/zenbase-ai/core/blob/main/py/cookbooks/bootstrap_few_shot/langfuse.ipynb) |
42 | | [parea.ipynb](cookbooks/bootstrap_few_shot/parea.ipynb) | [Open In Colab](https://colab.research.google.com/github/zenbase-ai/core/blob/main/py/cookbooks/bootstrap_few_shot/parea.ipynb) |
43 | | [lunary.ipynb](cookbooks/bootstrap_few_shot/lunary.ipynb) | [Open In Colab](https://colab.research.google.com/github/zenbase-ai/core/blob/main/py/cookbooks/bootstrap_few_shot/lunary.ipynb) |
44 |
45 |
46 | ### Predefined Cookbooks:
47 |
48 | | Cookbook | Description | Run in Colab |
49 | |-------------------------------------------------------------------------------------|---------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
50 | | [Single Class Classifier](cookbooks/predefined/single_class_classifier.ipynb) | Basic implementation of a single class classifier | [Open In Colab](https://colab.research.google.com/github/zenbase-ai/core/blob/main/py/cookbooks/predefined/single_class_classifier.ipynb) |
51 | | [Synthetic Data Generation](cookbooks/predefined/single_class_classifier_syn_data.ipynb) | Generate synthetic data for single class classification | [Open In Colab](https://colab.research.google.com/github/zenbase-ai/core/blob/main/py/cookbooks/predefined/single_class_classifier_syn_data.ipynb) |
52 | | [Classifier with Synthetic Data](cookbooks/predefined/single_class_clasifier_trained_with_syn_data.ipynb) | Train and test a single class classifier using synthetic data | [Open In Colab](https://colab.research.google.com/github/zenbase-ai/core/blob/main/py/cookbooks/predefined/single_class_clasifier_trained_with_syn_data.ipynb) |
53 |
54 |
55 | ## Development setup
56 |
57 | This repo uses Python 3.10 and [rye](https://rye.astral.sh/) to manage dependencies. Once you have rye installed, you can install dependencies by running:
58 |
59 | ```bash
60 | rye sync
61 | ```
62 |
63 | And activate the virtualenv with:
64 |
65 | ```bash
66 | . .venv/bin/activate
67 | ```
68 |
69 | You can run tests with:
70 |
71 | ```bash
72 | rye test # pytest -sv to see prints and verbose output
73 | rye test -- -m helpers # integration tests with helpers
74 | ```
75 |
--------------------------------------------------------------------------------
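For readers who want to see the "minimal changes to your existing codebase" claim in code, here is a hedged sketch that plugs a plain Python function and a homemade metric into the JSON adaptor implemented under `src/zenbase/adaptors/json/` later in this dump. The import path is inferred from the repo tree, and `answer_math` / `exact_match` are made-up stand-ins for a real LM function and eval.

```python
from zenbase.adaptors.json.adaptor import JSONAdaptor  # path inferred from the repo tree
from zenbase.types import LMDemo

# A tiny labeled demo set
demos = [
    LMDemo(inputs={"question": "2 + 2"}, outputs={"answer": "4"}),
    LMDemo(inputs={"question": "3 + 5"}, outputs={"answer": "8"}),
]

def answer_math(inputs: dict) -> str:
    # Stand-in for a real LLM call; Zenbase invokes it with each demo's inputs
    return {"2 + 2": "4", "3 + 5": "8"}.get(inputs["question"], "unknown")

def exact_match(output: str, ideal_output: dict) -> dict:
    # Eval functions return a dict with at least a "passed" key (see JSONEvaluationHelper)
    return {"passed": output == ideal_output["answer"]}

# Build an evaluator and score the candidate function
evaluate = JSONAdaptor.metric_evaluator(eval_function=exact_match, data=demos)
report = evaluate(answer_math)
print(report.evals)  # e.g. {"score": 1.0}
```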
/py/cookbooks/predefined/single_class_clasifier_trained_with_syn_data.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Single Class Classifier with Synthetic Data Generation\n",
8 | "\n",
9 | "This notebook demonstrates how to use the `SingleClassClassifierSyntheticDataGenerator` to create a synthetic dataset, and then use that dataset to train and test a `SingleClassClassifier`."
10 | ]
11 | },
12 | {
13 | "cell_type": "markdown",
14 | "metadata": {},
15 | "source": [
16 | "## Setup and Imports"
17 | ]
18 | },
19 | {
20 | "metadata": {},
21 | "cell_type": "markdown",
22 | "source": "### Import the Zenbase Library"
23 | },
24 | {
25 | "metadata": {
26 | "ExecuteTime": {
27 | "end_time": "2024-07-25T01:06:59.016591Z",
28 | "start_time": "2024-07-25T01:06:59.011483Z"
29 | }
30 | },
31 | "cell_type": "code",
32 | "source": [
33 | "import sys\n",
34 | "import subprocess\n",
35 | "\n",
36 | "def install_package(package):\n",
37 | " try:\n",
38 | " subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", package])\n",
39 | " except subprocess.CalledProcessError as e:\n",
40 | " print(f\"Failed to install {package}: {e}\")\n",
41 | " raise\n",
42 | "\n",
43 | "def install_packages(packages):\n",
44 | " for package in packages:\n",
45 | " install_package(package)\n",
46 | "\n",
47 | "try:\n",
48 | " # Check if running in Google Colab\n",
49 | " import google.colab\n",
50 | " IN_COLAB = True\n",
51 | "except ImportError:\n",
52 | " IN_COLAB = False\n",
53 | "\n",
54 | "if IN_COLAB:\n",
55 | " # Install the zenbase package if running in Google Colab\n",
56 | " # install_package('zenbase')\n",
57 | " # Install the zenbse package from a GitHub branch if running in Google Colab\n",
58 | " install_package('git+https://github.com/zenbase-ai/lib.git@main#egg=zenbase&subdirectory=py')\n",
59 | "\n",
60 | " # List of other packages to install in Google Colab\n",
61 | " additional_packages = [\n",
62 | " 'python-dotenv',\n",
63 | " 'openai',\n",
64 | " 'langchain',\n",
65 | " 'langchain_openai',\n",
66 | " 'instructor',\n",
67 | " 'datasets'\n",
68 | " ]\n",
69 | " \n",
70 | " # Install additional packages\n",
71 | " install_packages(additional_packages)\n",
72 | "\n",
73 | "# Now import the zenbase library\n",
74 | "try:\n",
75 | " import zenbase\n",
76 | "except ImportError as e:\n",
77 | " print(\"Failed to import zenbase: \", e)\n",
78 | " raise"
79 | ],
80 | "outputs": [],
81 | "execution_count": 8
82 | },
83 | {
84 | "metadata": {},
85 | "cell_type": "markdown",
86 | "source": "### Configure the Environment"
87 | },
88 | {
89 | "metadata": {
90 | "ExecuteTime": {
91 | "end_time": "2024-07-25T01:06:59.027812Z",
92 | "start_time": "2024-07-25T01:06:59.023030Z"
93 | }
94 | },
95 | "cell_type": "code",
96 | "source": [
97 | "from pathlib import Path\n",
98 | "from dotenv import load_dotenv\n",
99 | "\n",
100 | "# import os\n",
101 | "#\n",
102 | "# os.environ[\"OPENAI_API_KEY\"] = \"...\"\n",
103 | "\n",
104 | "load_dotenv(Path(\"../../.env.test\"), override=True)"
105 | ],
106 | "outputs": [
107 | {
108 | "data": {
109 | "text/plain": [
110 | "True"
111 | ]
112 | },
113 | "execution_count": 9,
114 | "metadata": {},
115 | "output_type": "execute_result"
116 | }
117 | ],
118 | "execution_count": 9
119 | },
120 | {
121 | "cell_type": "code",
122 | "metadata": {
123 | "ExecuteTime": {
124 | "end_time": "2024-07-25T01:06:59.051900Z",
125 | "start_time": "2024-07-25T01:06:59.029133Z"
126 | }
127 | },
128 | "source": [
129 | "import sys\n",
130 | "import subprocess\n",
131 | "import instructor\n",
132 | "from openai import OpenAI\n",
133 | "from zenbase.core.managers import ZenbaseTracer\n",
134 | "from zenbase.predefined.single_class_classifier import SingleClassClassifier\n",
135 | "from zenbase.predefined.syntethic_data.single_class_classifier import SingleClassClassifierSyntheticDataGenerator\n",
136 | "\n",
137 | "# Set up OpenAI and Instructor clients\n",
138 | "openai_client = OpenAI()\n",
139 | "instructor_client = instructor.from_openai(openai_client)\n",
140 | "zenbase_tracer = ZenbaseTracer()"
141 | ],
142 | "outputs": [],
143 | "execution_count": 10
144 | },
145 | {
146 | "cell_type": "markdown",
147 | "metadata": {},
148 | "source": [
149 | "## Define Classification Task"
150 | ]
151 | },
152 | {
153 | "cell_type": "code",
154 | "metadata": {
155 | "ExecuteTime": {
156 | "end_time": "2024-07-25T01:06:59.054916Z",
157 | "start_time": "2024-07-25T01:06:59.052873Z"
158 | }
159 | },
160 | "source": [
161 | "prompt_definition = \"\"\"Your task is to accurately categorize each incoming news article into one of the given categories based on its title and content.\"\"\"\n",
162 | "\n",
163 | "class_dict = {\n",
164 | " \"Automobiles\": \"Discussions and news about automobiles, including car maintenance, driving experiences, and the latest automotive technology.\",\n",
165 | " \"Computers\": \"Topics related to computer hardware, software, graphics, cryptography, and operating systems, including troubleshooting and advancements.\",\n",
166 | " \"Science\": \"News and discussions about scientific topics including space exploration, medicine, and electronics.\",\n",
167 | " \"Politics\": \"Debates and news about political topics, including gun control, Middle Eastern politics, and miscellaneous political discussions.\",\n",
168 | "}"
169 | ],
170 | "outputs": [],
171 | "execution_count": 11
172 | },
173 | {
174 | "cell_type": "markdown",
175 | "metadata": {},
176 | "source": [
177 | "## Generate Synthetic Data"
178 | ]
179 | },
180 | {
181 | "cell_type": "code",
182 | "metadata": {
183 | "ExecuteTime": {
184 | "end_time": "2024-07-25T01:08:00.370454Z",
185 | "start_time": "2024-07-25T01:06:59.056722Z"
186 | }
187 | },
188 | "source": [
189 | "# Set up the generator\n",
190 | "generator = SingleClassClassifierSyntheticDataGenerator(\n",
191 | " instructor_client=instructor_client,\n",
192 | " prompt=prompt_definition,\n",
193 | " class_dict=class_dict,\n",
194 | " model=\"gpt-4o-mini\"\n",
195 | ")\n",
196 | "\n",
197 | "# Define the number of examples per category for each set\n",
198 | "train_examples_per_category = 10\n",
199 | "val_examples_per_category = 3\n",
200 | "test_examples_per_category = 3\n",
201 | "\n",
202 | "# Generate train set\n",
203 | "train_examples = generator.generate_examples(train_examples_per_category)\n",
204 | "print(f\"Generated {len(train_examples)} examples for the train set.\\n\")\n",
205 | "\n",
206 | "# Generate validation set\n",
207 | "val_examples = generator.generate_examples(val_examples_per_category)\n",
208 | "print(f\"Generated {len(val_examples)} examples for the validation set.\\n\")\n",
209 | "\n",
210 | "# Generate test set\n",
211 | "test_examples = generator.generate_examples(test_examples_per_category)\n",
212 | "print(f\"Generated {len(test_examples)} examples for the test set.\\n\")"
213 | ],
214 | "outputs": [
215 | {
216 | "name": "stdout",
217 | "output_type": "stream",
218 | "text": [
219 | "Generated 40 examples for the train set.\n",
220 | "\n",
221 | "Generated 12 examples for the validation set.\n",
222 | "\n",
223 | "Generated 12 examples for the test set.\n",
224 | "\n"
225 | ]
226 | }
227 | ],
228 | "execution_count": 12
229 | },
230 | {
231 | "cell_type": "markdown",
232 | "metadata": {},
233 | "source": [
234 | "## Create and Train the Classifier"
235 | ]
236 | },
237 | {
238 | "cell_type": "code",
239 | "metadata": {
240 | "ExecuteTime": {
241 | "end_time": "2024-07-25T01:11:27.692383Z",
242 | "start_time": "2024-07-25T01:08:00.374312Z"
243 | }
244 | },
245 | "source": [
246 | "classifier = SingleClassClassifier(\n",
247 | " instructor_client=instructor_client,\n",
248 | " prompt=prompt_definition,\n",
249 | " class_dict=class_dict,\n",
250 | " model=\"gpt-4o-mini\",\n",
251 | " zenbase_tracer=zenbase_tracer,\n",
252 | " training_set=train_examples,\n",
253 | " validation_set=val_examples,\n",
254 | " test_set=test_examples,\n",
255 | " samples=20,\n",
256 | ")\n",
257 | "\n",
258 | "result = classifier.optimize()"
259 | ],
260 | "outputs": [],
261 | "execution_count": 13
262 | },
263 | {
264 | "cell_type": "markdown",
265 | "metadata": {},
266 | "source": [
267 | "## Analyze Results"
268 | ]
269 | },
270 | {
271 | "cell_type": "code",
272 | "metadata": {
273 | "ExecuteTime": {
274 | "end_time": "2024-07-25T01:11:27.696612Z",
275 | "start_time": "2024-07-25T01:11:27.693579Z"
276 | }
277 | },
278 | "source": [
279 | "print(\"Base Evaluation Score:\", classifier.base_evaluation.evals['score'])\n",
280 | "print(\"Best Evaluation Score:\", classifier.best_evaluation.evals['score'])\n",
281 | "\n",
282 | "print(\"\\nBest function:\", result.best_function)\n",
283 | "print(\"Number of candidate results:\", len(result.candidate_results))\n",
284 | "print(\"Best candidate result:\", result.best_candidate_result.evals)"
285 | ],
286 | "outputs": [
287 | {
288 | "name": "stdout",
289 | "output_type": "stream",
290 | "text": [
291 | "Base Evaluation Score: 0.9166666666666666\n",
292 | "Best Evaluation Score: 0.9166666666666666\n",
293 | "\n",
294 | "Best function: \n",
295 | "Number of candidate results: 20\n",
296 | "Best candidate result: {'score': 0.9166666666666666}\n",
297 | "Number of traces: 264\n"
298 | ]
299 | }
300 | ],
301 | "execution_count": 14
302 | },
303 | {
304 | "cell_type": "markdown",
305 | "metadata": {},
306 | "source": [
307 | "## Test the Classifier"
308 | ]
309 | },
310 | {
311 | "cell_type": "code",
312 | "metadata": {
313 | "ExecuteTime": {
314 | "end_time": "2024-07-25T01:11:28.490185Z",
315 | "start_time": "2024-07-25T01:11:27.697387Z"
316 | }
317 | },
318 | "source": [
319 | "new_article = \"\"\"\n",
320 | "title: Revolutionary Quantum Computer Achieves Milestone in Cryptography\n",
321 | "content: Scientists at a leading tech company have announced a breakthrough in quantum computing, \n",
322 | "demonstrating a quantum computer capable of solving complex cryptographic problems in record time. \n",
323 | "This development has significant implications for data security and could revolutionize fields \n",
324 | "ranging from finance to national security. However, experts warn that it also poses potential \n",
325 | "risks to current encryption methods.\n",
326 | "\"\"\"\n",
327 | "\n",
328 | "classification = result.best_function(new_article)\n",
329 | "print(f\"The article is classified as: {classification.class_label.name}\")"
330 | ],
331 | "outputs": [
332 | {
333 | "name": "stdout",
334 | "output_type": "stream",
335 | "text": [
336 | "The article is classified as: Computers\n"
337 | ]
338 | }
339 | ],
340 | "execution_count": 15
341 | },
342 | {
343 | "cell_type": "markdown",
344 | "metadata": {},
345 | "source": [
346 | "## Conclusion\n",
347 | "\n",
348 | "In this notebook, we've demonstrated how to:\n",
349 | "1. Generate synthetic data for a single-class classification task\n",
350 | "2. Prepare the synthetic data for training and testing\n",
351 | "3. Create and train a SingleClassClassifier using the synthetic data\n",
352 | "4. Analyze the results of the classifier\n",
353 | "5. Use the trained classifier to categorize new input\n",
354 | "\n",
355 | "This approach allows for rapid prototyping and testing of classification models, especially in scenarios where real-world labeled data might be scarce or difficult to obtain."
356 | ]
357 | }
358 | ],
359 | "metadata": {
360 | "kernelspec": {
361 | "display_name": "Python 3",
362 | "language": "python",
363 | "name": "python3"
364 | },
365 | "language_info": {
366 | "codemirror_mode": {
367 | "name": "ipython",
368 | "version": 3
369 | },
370 | "file_extension": ".py",
371 | "mimetype": "text/x-python",
372 | "name": "python",
373 | "nbconvert_exporter": "python",
374 | "pygments_lexer": "ipython3",
375 | "version": "3.8.0"
376 | }
377 | },
378 | "nbformat": 4,
379 | "nbformat_minor": 4
380 | }
381 |
--------------------------------------------------------------------------------
/py/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "zenbase"
3 | version = "0.0.22"
4 | description = "LLMs made Zen"
5 | authors = [{ name = "Cyrus Nouroozi", email = "cyrus@zenbase.ai" }]
6 | dependencies = [
7 | "pksuid>=1.1.2",
8 | "faker>=24.2.0",
9 | "anyio>=4.4.0",
10 | "opentelemetry-sdk>=1.25.0",
11 | "opentelemetry-api>=1.25.0",
12 | "structlog>=24.2.0",
13 | "pyee>=11.1.0",
14 | "posthog>=3.5.0",
15 | "cloudpickle>=3.0.0",
16 | "instructor>=1.3.5",
17 | ]
18 | readme = "README.md"
19 | requires-python = ">= 3.10"
20 | license = "MIT"
21 |
22 | [build-system]
23 | requires = ["hatchling"]
24 | build-backend = "hatchling.build"
25 |
26 | [tool.mypy]
27 | namespace_packages = true
28 |
29 | [tool.pytest.ini_options]
30 | markers = ["helpers"]
31 | addopts = "-m 'not helpers'"
32 |
33 | [tool.rye]
34 | managed = true
35 | dev-dependencies = [
36 | "pytest>=8.2.1",
37 | "ruff>=0.4.6",
38 | "langsmith[vcr]>=0.1.72",
39 | "pyright>=1.1.365",
40 | "datasets>=2.19.1",
41 | "ipython>=8.24.0",
42 | "ipdb>=0.13.13",
43 | "openai>=1.30.5",
44 | "pytest-recording>=0.13.1",
45 | "python-dotenv>=1.0.1",
46 | "vcrpy>=6.0.1",
47 | "arize-phoenix[evals]>=4.9.0",
48 | "nest-asyncio>=1.6.0",
49 | "langchain-openai>=0.1.8",
50 | "langchain-core>=0.2.3",
51 | "langchain>=0.2.1",
52 | "parea-ai>=0.2.164",
53 | "langfuse>=2.35.0",
54 | "lunary>=1.0.30",
55 | "autoevals>=0.0.68",
56 | "braintrust>=0.0.131",
57 | "pre-commit>=3.7.1",
58 | "pytest-xdist>=3.6.1",
59 | "openai-responses>=0.8.1",
60 | ]
61 |
62 | [tool.hatch.metadata]
63 | allow-direct-references = true
64 |
65 | [tool.hatch.build.targets.wheel]
66 | packages = ["src/zenbase"]
67 |
68 | [tool.ruff]
69 | exclude = [
70 | "venv",
71 | ".git",
72 | "__pycache__",
73 | "build",
74 | "dist",
75 | "venv",
76 | ]
77 | line-length = 120
78 | src = ["src", "tests"]
79 |
80 | [tool.ruff.lint]
81 | ignore = []
82 | select = [
83 | "E",
84 | "F",
85 | "W",
86 | "I001",
87 | ]
88 |
--------------------------------------------------------------------------------
/py/requirements.lock:
--------------------------------------------------------------------------------
1 | # generated by rye
2 | # use `rye lock` or `rye sync` to update this lockfile
3 | #
4 | # last locked with the following flags:
5 | # pre: false
6 | # features: []
7 | # all-features: false
8 | # with-sources: false
9 | # generate-hashes: false
10 |
11 | -e file:.
12 | aiohttp==3.9.5
13 | # via instructor
14 | aiosignal==1.3.1
15 | # via aiohttp
16 | annotated-types==0.7.0
17 | # via pydantic
18 | anyio==4.4.0
19 | # via httpx
20 | # via openai
21 | # via zenbase
22 | async-timeout==4.0.3
23 | # via aiohttp
24 | attrs==23.2.0
25 | # via aiohttp
26 | backoff==2.2.1
27 | # via posthog
28 | certifi==2024.6.2
29 | # via httpcore
30 | # via httpx
31 | # via requests
32 | charset-normalizer==3.3.2
33 | # via requests
34 | click==8.1.7
35 | # via typer
36 | cloudpickle==3.0.0
37 | # via zenbase
38 | deprecated==1.2.14
39 | # via opentelemetry-api
40 | distro==1.9.0
41 | # via openai
42 | docstring-parser==0.16
43 | # via instructor
44 | exceptiongroup==1.2.1
45 | # via anyio
46 | faker==26.0.0
47 | # via zenbase
48 | frozenlist==1.4.1
49 | # via aiohttp
50 | # via aiosignal
51 | h11==0.14.0
52 | # via httpcore
53 | httpcore==1.0.5
54 | # via httpx
55 | httpx==0.27.0
56 | # via openai
57 | idna==3.7
58 | # via anyio
59 | # via httpx
60 | # via requests
61 | # via yarl
62 | importlib-metadata==7.1.0
63 | # via opentelemetry-api
64 | instructor==1.3.5
65 | # via zenbase
66 | jiter==0.4.2
67 | # via instructor
68 | # via openai
69 | markdown-it-py==3.0.0
70 | # via rich
71 | mdurl==0.1.2
72 | # via markdown-it-py
73 | monotonic==1.6
74 | # via posthog
75 | multidict==6.0.5
76 | # via aiohttp
77 | # via yarl
78 | openai==1.43.0
79 | # via instructor
80 | opentelemetry-api==1.25.0
81 | # via opentelemetry-sdk
82 | # via opentelemetry-semantic-conventions
83 | # via zenbase
84 | opentelemetry-sdk==1.25.0
85 | # via zenbase
86 | opentelemetry-semantic-conventions==0.46b0
87 | # via opentelemetry-sdk
88 | pksuid==1.1.2
89 | # via zenbase
90 | posthog==3.5.0
91 | # via zenbase
92 | pybase62==0.4.3
93 | # via pksuid
94 | pydantic==2.8.2
95 | # via instructor
96 | # via openai
97 | pydantic-core==2.20.1
98 | # via instructor
99 | # via pydantic
100 | pyee==11.1.0
101 | # via zenbase
102 | pygments==2.18.0
103 | # via rich
104 | python-dateutil==2.9.0.post0
105 | # via faker
106 | # via posthog
107 | requests==2.32.3
108 | # via posthog
109 | rich==13.7.1
110 | # via instructor
111 | # via typer
112 | shellingham==1.5.4
113 | # via typer
114 | six==1.16.0
115 | # via posthog
116 | # via python-dateutil
117 | sniffio==1.3.1
118 | # via anyio
119 | # via httpx
120 | # via openai
121 | structlog==24.2.0
122 | # via zenbase
123 | tenacity==8.5.0
124 | # via instructor
125 | tqdm==4.66.4
126 | # via openai
127 | typer==0.12.3
128 | # via instructor
129 | typing-extensions==4.12.1
130 | # via anyio
131 | # via openai
132 | # via opentelemetry-sdk
133 | # via pydantic
134 | # via pydantic-core
135 | # via pyee
136 | # via typer
137 | urllib3==2.2.1
138 | # via requests
139 | wrapt==1.16.0
140 | # via deprecated
141 | yarl==1.9.4
142 | # via aiohttp
143 | zipp==3.19.2
144 | # via importlib-metadata
145 |
--------------------------------------------------------------------------------
/py/src/zenbase/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zenbase-ai/core/59971868e85784c54eddb6980cf791b975e0befb/py/src/zenbase/__init__.py
--------------------------------------------------------------------------------
/py/src/zenbase/adaptors/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zenbase-ai/core/59971868e85784c54eddb6980cf791b975e0befb/py/src/zenbase/adaptors/__init__.py
--------------------------------------------------------------------------------
/py/src/zenbase/adaptors/arize.py:
--------------------------------------------------------------------------------
1 | from typing import TYPE_CHECKING, Callable
2 |
3 | import pandas as pd
4 |
5 | from zenbase.optim.metric.types import CandidateEvalResult, OverallEvalValue
6 | from zenbase.types import LMDemo, LMFunction
7 | from zenbase.utils import amap
8 |
9 | if TYPE_CHECKING:
10 | from phoenix.evals import LLMEvaluator
11 |
12 |
13 | class ZenPhoenix:
14 | MetricEvaluator = Callable[[list["LLMEvaluator"], list[pd.DataFrame]], OverallEvalValue]
15 |
16 | @staticmethod
17 | def df_to_demos(df: pd.DataFrame) -> list[LMDemo]:
18 | raise NotImplementedError()
19 |
20 | @staticmethod
21 | def default_metric(evaluators: list["LLMEvaluator"], eval_dfs: list[pd.DataFrame]) -> OverallEvalValue:
22 | evals = {"score": sum(df.score.mean() for df in eval_dfs)}
23 | evals.update({e.__name__: df.score.mean() for e, df in zip(evaluators, eval_dfs)})
24 | return evals
25 |
26 | @classmethod
27 | def metric_evaluator(
28 | cls,
29 | dataset: pd.DataFrame,
30 | evaluators: list["LLMEvaluator"],
31 | metric_evals: MetricEvaluator = default_metric,
32 | concurrency: int = 20,
33 | *args,
34 | **kwargs,
35 | ):
36 | from phoenix.evals import run_evals
37 |
38 | async def run_experiment(function: LMFunction) -> CandidateEvalResult:
39 | nonlocal dataset
40 | run_df = dataset.copy()
41 | # TODO: Is it typical for there to only be 1 value?
42 | responses = await amap(
43 | function,
44 | run_df["attributes.input.value"].to_list(), # i don't think this works
45 | concurrency=concurrency,
46 | )
47 | run_df["attributes.output.value"] = responses
48 |
49 | eval_dfs = run_evals(run_df, evaluators, *args, concurrency=concurrency, **kwargs)
50 |
51 | return CandidateEvalResult(
52 | function,
53 | evals=metric_evals(evaluators, eval_dfs),
54 | )
55 |
56 | return run_experiment
57 |
--------------------------------------------------------------------------------
/py/src/zenbase/adaptors/arize/__init__.py:
--------------------------------------------------------------------------------
1 | from .adaptor import * # noqa
2 |
--------------------------------------------------------------------------------
/py/src/zenbase/adaptors/arize/adaptor.py:
--------------------------------------------------------------------------------
1 | from zenbase.adaptors.arize.dataset_helper import ArizeDatasetHelper
2 | from zenbase.adaptors.arize.evaluation_helper import ArizeEvaluationHelper
3 |
4 |
5 | class ZenArizeAdaptor(ArizeDatasetHelper, ArizeEvaluationHelper):
6 | def __init__(self, client=None):
7 | ArizeDatasetHelper.__init__(self, client)
8 | ArizeEvaluationHelper.__init__(self, client)
9 |
--------------------------------------------------------------------------------
/py/src/zenbase/adaptors/arize/dataset_helper.py:
--------------------------------------------------------------------------------
1 | from phoenix.experiments.types import Dataset, Example
2 |
3 | from zenbase.adaptors.base.dataset_helper import BaseDatasetHelper
4 | from zenbase.types import LMDemo
5 |
6 |
7 | class ArizeDatasetHelper(BaseDatasetHelper):
8 | def create_dataset(self, dataset_name: str, *args, **kwargs):
9 | raise NotImplementedError(
10 | "create_dataset not implemented / supported for Arize, dataset will be created"
11 | "automatically when adding examples to it."
12 | )
13 |
14 | def add_examples_to_dataset(self, dataset_name: str, inputs: list, outputs: list) -> Dataset:
15 | return self.client.upload_dataset(
16 | dataset_name=dataset_name,
17 | inputs=inputs,
18 | outputs=outputs,
19 | )
20 |
21 | def fetch_dataset_examples(self, dataset_name: str) -> list[Example]:
22 | return list(self.client.get_dataset(name=dataset_name).examples.values())
23 |
24 | def fetch_dataset_demos(self, dataset_name: str) -> list[LMDemo]:
25 | dataset_examples = self.fetch_dataset_examples(dataset_name)
26 | return [
27 | LMDemo(inputs=example.input, outputs=example.output) # noqa
28 | for example in dataset_examples
29 | ]
30 |
--------------------------------------------------------------------------------
/py/src/zenbase/adaptors/arize/evaluation_helper.py:
--------------------------------------------------------------------------------
1 | from zenbase.adaptors.base.evaluation_helper import BaseEvaluationHelper
2 | from zenbase.optim.metric.types import CandidateEvalResult, CandidateEvaluator, IndividualEvalValue
3 | from zenbase.types import LMDemo, LMFunction
4 | from zenbase.utils import random_name_generator
5 |
6 |
7 | class ArizeEvaluationHelper(BaseEvaluationHelper):
8 | def get_evaluator(self, data: str):
9 | evaluator_kwargs_to_pass = self.evaluator_kwargs.copy()
10 | dataset = self.client.get_dataset(name=data)
11 | evaluator_kwargs_to_pass.update({"dataset": dataset})
12 | return self._metric_evaluator_generator(**evaluator_kwargs_to_pass)
13 |
14 | def _metric_evaluator_generator(self, threshold: float = 0.5, **evaluate_kwargs) -> CandidateEvaluator:
15 | from phoenix.experiments import run_experiment
16 |
17 | gen_random_name = random_name_generator(evaluate_kwargs.pop("experiment_prefix", None))
18 |
19 | def evaluate_candidate(function: LMFunction) -> CandidateEvalResult:
20 | def arize_adapted_function(input):
21 | return function(input)
22 |
23 | experiment = run_experiment(
24 | evaluate_kwargs["dataset"],
25 | arize_adapted_function,
26 | experiment_name=gen_random_name(),
27 | evaluators=evaluate_kwargs.get("evaluators", None),
28 | )
29 | list_of_individual_evals = []
30 | for individual_eval in experiment.eval_runs:
31 | example_id = experiment.runs[individual_eval.experiment_run_id].dataset_example_id
32 | example = experiment.dataset.examples[example_id]
33 | if individual_eval.result:
34 | list_of_individual_evals.append(
35 | IndividualEvalValue(
36 | passed=individual_eval.result.score >= threshold,
37 | response=experiment.runs[individual_eval.experiment_run_id].output,
38 | demo=LMDemo(
39 | inputs=example.input,
40 | outputs=example.output,
41 | ),
42 | score=individual_eval.result.score,
43 | )
44 | )
45 |
46 | # make average scores of all evaluation metrics
47 | avg_scores = [i.stats["avg_score"][0] for i in experiment.eval_summaries]
48 | avg_score = sum(avg_scores) / len(avg_scores)
49 |
50 | return CandidateEvalResult(function, {"score": avg_score}, individual_evals=list_of_individual_evals)
51 |
52 | return evaluate_candidate
53 |
54 | @classmethod
55 | def metric_evaluator(cls, threshold: float = 0.5, **evaluate_kwargs) -> CandidateEvaluator:
56 | # TODO: Should remove and deprecate
57 | from phoenix.experiments import run_experiment
58 |
59 | gen_random_name = random_name_generator(evaluate_kwargs.pop("experiment_prefix", None))
60 |
61 | def evaluate_candidate(function: LMFunction) -> CandidateEvalResult:
62 | def arize_adapted_function(input):
63 | return function(input)
64 |
65 | experiment = run_experiment(
66 | evaluate_kwargs["dataset"],
67 | arize_adapted_function,
68 | experiment_name=gen_random_name(),
69 | evaluators=evaluate_kwargs.get("evaluators", None),
70 | )
71 | list_of_individual_evals = []
72 | for individual_eval in experiment.eval_runs:
73 | example_id = experiment.runs[individual_eval.experiment_run_id].dataset_example_id
74 | example = experiment.dataset.examples[example_id]
75 |
76 | list_of_individual_evals.append(
77 | IndividualEvalValue(
78 | passed=individual_eval.result.score >= threshold,
79 | response=experiment.runs[individual_eval.experiment_run_id].output,
80 | demo=LMDemo(
81 | inputs=example.input,
82 | outputs=example.output,
83 | ),
84 | score=individual_eval.result.score,
85 | )
86 | )
87 |
88 | # make average scores of all evaluation metrics
89 | avg_scores = [i.stats["avg_score"][0] for i in experiment.eval_summaries]
90 | avg_score = sum(avg_scores) / len(avg_scores)
91 |
92 | return CandidateEvalResult(function, {"score": avg_score}, individual_evals=list_of_individual_evals)
93 |
94 | return evaluate_candidate
95 |
--------------------------------------------------------------------------------
/py/src/zenbase/adaptors/base/adaptor.py:
--------------------------------------------------------------------------------
1 | from abc import ABC
2 |
3 |
4 | class ZenAdaptor(ABC):
5 | def __init__(self, client=None):
6 | self.client = client
7 |
--------------------------------------------------------------------------------
/py/src/zenbase/adaptors/base/dataset_helper.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from typing import Any
3 |
4 | from zenbase.adaptors.base.adaptor import ZenAdaptor
5 |
6 |
7 | class BaseDatasetHelper(ZenAdaptor):
8 | @abstractmethod
9 | def create_dataset(self, dataset_name: str, *args, **kwargs) -> Any: ...
10 |
11 | @abstractmethod
12 | def add_examples_to_dataset(self, dataset_id: Any, inputs: list, outputs: list) -> None: ...
13 |
14 | @abstractmethod
15 | def fetch_dataset_examples(self, dataset_name: str) -> Any: ...
16 |
17 | @abstractmethod
18 | def fetch_dataset_demos(self, dataset: Any) -> Any: ...
19 |
--------------------------------------------------------------------------------
/py/src/zenbase/adaptors/base/evaluation_helper.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from typing import Any
3 |
4 | from zenbase.adaptors.base.adaptor import ZenAdaptor
5 |
6 |
7 | class BaseEvaluationHelper(ZenAdaptor):
8 | evaluator_args = tuple()
9 | evaluator_kwargs = dict()
10 |
11 | def set_evaluator_kwargs(self, *args, **kwargs) -> None:
12 | self.evaluator_kwargs = kwargs
13 | self.evaluator_args = args
14 |
15 | @abstractmethod
16 | def get_evaluator(self, data: Any): ...
17 |
18 | @classmethod
19 | @abstractmethod
20 | def metric_evaluator(cls, *args, **evaluate_kwargs): ...
21 |
22 | # TODO: Should remove and deprecate
23 |
--------------------------------------------------------------------------------
/py/src/zenbase/adaptors/braintrust.py:
--------------------------------------------------------------------------------
1 | from asyncio import Task
2 | from dataclasses import asdict
3 | from typing import AsyncIterator, Awaitable, Callable, Iterator
4 |
5 | from braintrust import (
6 | Eval,
7 | EvalCase,
8 | EvalHooks,
9 | EvalScorer,
10 | Input,
11 | Metadata,
12 | Output,
13 | ReporterDef,
14 | )
15 |
16 | from zenbase.optim.metric.types import CandidateEvalResult
17 | from zenbase.types import LMFunction
18 | from zenbase.utils import random_name_generator
19 |
20 |
21 | class ZenBraintrust:
22 | @staticmethod
23 | def metric_evaluator(
24 | name: str,
25 | data: Callable[[], Iterator[EvalCase] | AsyncIterator[EvalCase]],
26 | task: Callable[[Input, EvalHooks], Output | Awaitable[Output]],
27 | scores: list[EvalScorer],
28 | experiment_name: str | None = None,
29 | trial_count: int = 1,
30 | metadata: Metadata | None = None,
31 | is_public: bool = False,
32 | update: bool = False,
33 | reporter: ReporterDef | str | None = None,
34 | ):
35 | gen_random_name = random_name_generator(experiment_name)
36 |
37 | def evaluate_candidate(function: LMFunction) -> CandidateEvalResult:
38 | eval_result = Eval(
39 | name=name,
40 | experiment_name=gen_random_name(),
41 | data=data,
42 | task=task,
43 | scores=scores,
44 | trial_count=trial_count,
45 | metadata={
46 | **(metadata or {}),  # avoid unpacking None when no metadata is passed
47 | **asdict(function.zenbase),
48 | },
49 | is_public=is_public,
50 | update=update,
51 | reporter=reporter,
52 | )
53 |
54 | if isinstance(eval_result, Task):
55 | eval_result = eval_result.result()
56 |
57 | assert eval_result is not None, "Failed to run Braintrust Eval"
58 |
59 | evals = {s.name: s.score for s in eval_result.summary.scores.values()}
60 | if "score" not in evals:
61 | evals["score"] = sum(evals.values())
62 |
63 | return CandidateEvalResult(function, evals)
64 |
65 | return evaluate_candidate
66 |
--------------------------------------------------------------------------------
/py/src/zenbase/adaptors/json/adaptor.py:
--------------------------------------------------------------------------------
1 | from zenbase.adaptors.json.dataset_helper import JSONDatasetHelper
2 | from zenbase.adaptors.json.evaluation_helper import JSONEvaluationHelper
3 |
4 |
5 | class JSONAdaptor(JSONDatasetHelper, JSONEvaluationHelper):
6 | def __init__(self, client=None):
7 | JSONDatasetHelper.__init__(self, client)
8 | JSONEvaluationHelper.__init__(self, client)
9 |
--------------------------------------------------------------------------------
/py/src/zenbase/adaptors/json/dataset_helper.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | from zenbase.adaptors.base.adaptor import ZenAdaptor
4 | from zenbase.types import LMDemo
5 |
6 |
7 | class JSONDatasetHelper(ZenAdaptor):
8 | datasets = {}
9 |
10 | def create_dataset(self, dataset_name: str, *args, **kwargs) -> Any:
11 | self.datasets[dataset_name] = []
12 |
13 | return self.datasets[dataset_name]
14 |
15 | def add_examples_to_dataset(self, dataset_id: Any, inputs: list, outputs: list) -> None:
16 | for input, output in zip(inputs, outputs):
17 | self.datasets[dataset_id].append({"input": input, "output": output})
18 |
19 | def fetch_dataset_examples(self, dataset_name: str) -> Any:
20 | return self.datasets[dataset_name]
21 |
22 | def fetch_dataset_demos(self, dataset: Any) -> Any:
23 | if isinstance(dataset[0], LMDemo):
24 | return dataset
25 | return [LMDemo(inputs=item["input"], outputs=item["output"]) for item in dataset]
26 |
--------------------------------------------------------------------------------
/py/src/zenbase/adaptors/json/evaluation_helper.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | from typing import Any, Callable
3 |
4 | from zenbase.adaptors.base.evaluation_helper import BaseEvaluationHelper
5 | from zenbase.optim.metric.types import CandidateEvalResult, CandidateEvaluator, IndividualEvalValue, OverallEvalValue
6 | from zenbase.types import LMDemo, LMFunction
7 | from zenbase.utils import pmap
8 |
9 |
10 | class JSONEvaluationHelper(BaseEvaluationHelper):
11 | MetricEvaluator = Callable[[list[tuple[bool, Any]]], OverallEvalValue]
12 |
13 | @staticmethod
14 | def default_metric(batch_results: list[tuple[bool, Any]]) -> OverallEvalValue:
15 | avg_pass = sum(int(result["passed"]) for result in batch_results) / len(batch_results)
16 | return {"score": avg_pass}
17 |
18 | def get_evaluator(self, data: list[LMDemo]):
19 | evaluator_kwargs_to_pass = self.evaluator_kwargs.copy()
20 | evaluator_kwargs_to_pass.update({"data": data})
21 | return self._metric_evaluator_generator(**evaluator_kwargs_to_pass)
22 |
23 | @staticmethod
24 | def _metric_evaluator_generator(
25 | *args,
26 | eval_function: Callable,
27 | data: list[LMDemo],
28 | eval_metrics: MetricEvaluator = default_metric,
29 | concurrency: int = 1,
30 | **kwargs,
31 | ) -> CandidateEvaluator:
32 | # TODO: Should remove and deprecate
33 | def evaluate_metric(function: LMFunction) -> CandidateEvalResult:
34 | individual_evals = []
35 |
36 | def run_and_evaluate(demo: LMDemo):
37 | nonlocal individual_evals
38 |
39 | response = function(demo.inputs)
40 |
41 | # Check if eval_function accepts 'input' parameter
42 | eval_params = inspect.signature(eval_function).parameters
43 | if "input" in eval_params:
44 | result = eval_function(
45 | input=demo.inputs,
46 | output=response,
47 | ideal_output=demo.outputs,
48 | *args,
49 | **kwargs,
50 | )
51 | else:
52 | result = eval_function(
53 | output=response,
54 | ideal_output=demo.outputs,
55 | *args,
56 | **kwargs,
57 | )
58 |
59 | individual_evals.append(
60 | IndividualEvalValue(
61 | passed=result["passed"],
62 | response=response,
63 | demo=demo,
64 | )
65 | )
66 |
67 | return result
68 |
69 | eval_results = pmap(
70 | run_and_evaluate,
71 | data,
72 | concurrency=concurrency,
73 | )
74 |
75 | return CandidateEvalResult(function, eval_metrics(eval_results), individual_evals=individual_evals)
76 |
77 | return evaluate_metric
78 |
79 | @classmethod
80 | def metric_evaluator(
81 | cls,
82 | eval_function: Callable,
83 | data: list[LMDemo],
84 | eval_metrics: MetricEvaluator = default_metric,
85 | concurrency: int = 1,
86 | threshold: float = 0.5,
87 | ) -> CandidateEvaluator:
88 | # TODO: Should remove and deprecate
89 | def evaluate_metric(function: LMFunction) -> CandidateEvalResult:
90 | individual_evals = []
91 |
92 | def run_and_evaluate(demo: LMDemo):
93 | nonlocal individual_evals
94 | response = function(demo.inputs)
95 |
96 | # Check if eval_function accepts 'input' parameter
97 | eval_params = inspect.signature(eval_function).parameters
98 | if "input" in eval_params:
99 | result = eval_function(
100 | input=demo.inputs,
101 | output=response,
102 | ideal_output=demo.outputs,
103 | )
104 | else:
105 | result = eval_function(
106 | output=response,
107 | ideal_output=demo.outputs,
108 | )
109 |
110 | individual_evals.append(
111 | IndividualEvalValue(
112 | passed=result["passed"],
113 | response=response,
114 | demo=demo,
115 | details=result,
116 | )
117 | )
118 |
119 | return result
120 |
121 | eval_results = pmap(
122 | run_and_evaluate,
123 | data,
124 | concurrency=concurrency,
125 | )
126 |
127 | return CandidateEvalResult(function, eval_metrics(eval_results), individual_evals=individual_evals)
128 |
129 | return evaluate_metric
130 |
--------------------------------------------------------------------------------
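A minimal usage sketch for the JSONAdaptor defined above, evaluating a plain callable against in-memory LMDemo examples. The demo payloads, echo_fn, and exact_match are hypothetical; the only contract is that eval_function returns a dict containing a "passed" key, which is what default_metric aggregates.

from zenbase.adaptors.json.adaptor import JSONAdaptor
from zenbase.types import LMDemo

demos = [
    LMDemo(inputs={"question": "2 + 2"}, outputs={"answer": "4"}),
    LMDemo(inputs={"question": "3 + 3"}, outputs={"answer": "6"}),
]

def echo_fn(inputs):
    # Hypothetical stand-in for a traced LMFunction; a real one would call an LLM.
    return {"answer": "4"}

def exact_match(output, ideal_output):
    # Must return a dict with a "passed" key; an "input" kwarg is also passed if accepted.
    return {"passed": int(output == ideal_output)}

evaluate_candidate = JSONAdaptor.metric_evaluator(eval_function=exact_match, data=demos)
result = evaluate_candidate(echo_fn)
print(result.evals["score"], len(result.individual_evals))
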
/py/src/zenbase/adaptors/langchain/__init__.py:
--------------------------------------------------------------------------------
1 | from .adaptor import * # noqa
2 |
--------------------------------------------------------------------------------
/py/src/zenbase/adaptors/langchain/adaptor.py:
--------------------------------------------------------------------------------
1 | __all__ = ["ZenLangSmith"]
2 |
3 | from zenbase.adaptors.langchain.dataset_helper import LangsmithDatasetHelper
4 | from zenbase.adaptors.langchain.evaluation_helper import LangsmithEvaluationHelper
5 |
6 |
7 | class ZenLangSmith(LangsmithDatasetHelper, LangsmithEvaluationHelper):
8 | def __init__(self, client=None):
9 | LangsmithDatasetHelper.__init__(self, client)
10 | LangsmithEvaluationHelper.__init__(self, client)
11 |
--------------------------------------------------------------------------------
/py/src/zenbase/adaptors/langchain/dataset_helper.py:
--------------------------------------------------------------------------------
1 | from typing import TYPE_CHECKING, Iterator
2 |
3 | from zenbase.adaptors.base.dataset_helper import BaseDatasetHelper
4 | from zenbase.types import LMDemo
5 |
6 | if TYPE_CHECKING:
7 | from langsmith import schemas
8 |
9 |
10 | class LangsmithDatasetHelper(BaseDatasetHelper):
11 | def create_dataset(self, dataset_name: str, description: str) -> "schemas.Dataset":
12 | dataset = self.client.create_dataset(dataset_name, description=description)
13 | return dataset
14 |
15 | def add_examples_to_dataset(self, dataset_name: str, inputs: list, outputs: list) -> None:
16 | self.client.create_examples(
17 | inputs=inputs,
18 | outputs=outputs,
19 | dataset_name=dataset_name,
20 | )
21 |
22 | def fetch_dataset_examples(self, dataset_name: str) -> Iterator["schemas.Example"]:
23 | dataset = self.fetch_dataset(dataset_name)
24 | return self.client.list_examples(dataset_id=dataset.id)
25 |
26 | def fetch_dataset(self, dataset_name: str):
27 |         datasets = list(self.client.list_datasets(dataset_name=dataset_name))
28 |         if not datasets:
29 |             raise ValueError(f"Dataset '{dataset_name}' not found")
30 |         dataset = datasets[0]
31 | return dataset
32 |
33 | def fetch_dataset_demos(self, dataset: "schemas.Dataset") -> list[LMDemo]:
34 | dataset_examples = self.fetch_dataset_examples(dataset.name)
35 | return self.examples_to_demos(dataset_examples)
36 |
37 | @staticmethod
38 | def examples_to_demos(examples: Iterator["schemas.Example"]) -> list[LMDemo]:
39 | return [LMDemo(inputs=e.inputs, outputs=e.outputs, adaptor_object=e) for e in examples]
40 |
--------------------------------------------------------------------------------
/py/src/zenbase/adaptors/langchain/evaluation_helper.py:
--------------------------------------------------------------------------------
1 | from dataclasses import asdict
2 | from typing import TYPE_CHECKING
3 |
4 | from langsmith.evaluation._runner import ExperimentResults # noqa
5 |
6 | from zenbase.adaptors.base.evaluation_helper import BaseEvaluationHelper
7 | from zenbase.optim.metric.types import CandidateEvalResult, CandidateEvaluator, IndividualEvalValue, OverallEvalValue
8 | from zenbase.types import LMDemo, LMFunction
9 | from zenbase.utils import random_name_generator
10 |
11 | if TYPE_CHECKING:
12 | from langsmith import schemas
13 |
14 |
15 | class LangsmithEvaluationHelper(BaseEvaluationHelper):
16 | def __init__(self, client=None):
17 | super().__init__(client)
18 | self.evaluator_args = None
19 | self.evaluator_kwargs = None
20 |
21 | def set_evaluator_kwargs(self, *args, **kwargs) -> None:
22 | self.evaluator_kwargs = kwargs
23 | self.evaluator_args = args
24 |
25 | def get_evaluator(self, data: "schemas.Dataset") -> CandidateEvaluator:
26 | evaluator_kwargs_to_pass = self.evaluator_kwargs.copy()
27 | evaluator_kwargs_to_pass.update({"data": data.name})
28 | return self._metric_evaluator_generator(**evaluator_kwargs_to_pass)
29 |
30 | def _metric_evaluator_generator(self, threshold: float = 0.5, **evaluate_kwargs) -> CandidateEvaluator:
31 | from langsmith import evaluate
32 |
33 | metadata = evaluate_kwargs.pop("metadata", {})
34 | gen_random_name = random_name_generator(evaluate_kwargs.pop("experiment_prefix", None))
35 |
36 | def evaluate_candidate(function: LMFunction) -> CandidateEvalResult:
37 | experiment_results = evaluate(
38 | function,
39 | experiment_prefix=gen_random_name(),
40 | metadata={
41 | **metadata,
42 | },
43 | **evaluate_kwargs,
44 | )
45 |
46 | individual_evals = self._experiment_results_to_individual_evals(experiment_results, threshold)
47 |
48 | if summary_results := experiment_results._summary_results["results"]: # noqa
49 | evals = self._eval_results_to_evals(summary_results)
50 | else:
51 | evals = self._individual_evals_to_overall_evals(individual_evals)
52 |
53 | return CandidateEvalResult(function, evals, individual_evals)
54 |
55 | return evaluate_candidate
56 |
57 | @classmethod
58 | def metric_evaluator(cls, threshold: float = 0.5, **evaluate_kwargs) -> CandidateEvaluator:
59 | # TODO: Should remove and deprecate
60 | from langsmith import evaluate
61 |
62 | metadata = evaluate_kwargs.pop("metadata", {})
63 | gen_random_name = random_name_generator(evaluate_kwargs.pop("experiment_prefix", None))
64 |
65 | def evaluate_candidate(function: LMFunction) -> CandidateEvalResult:
66 | experiment_results = evaluate(
67 | function,
68 | experiment_prefix=gen_random_name(),
69 | metadata={
70 | **metadata,
71 | **asdict(function.zenbase),
72 | },
73 | **evaluate_kwargs,
74 | )
75 |
76 | individual_evals = cls._experiment_results_to_individual_evals(experiment_results, threshold)
77 |
78 | if summary_results := experiment_results._summary_results["results"]: # noqa
79 | evals = cls._eval_results_to_evals(summary_results)
80 | else:
81 | evals = cls._individual_evals_to_overall_evals(individual_evals)
82 |
83 | return CandidateEvalResult(function, evals, individual_evals)
84 |
85 | return evaluate_candidate
86 |
87 | @staticmethod
88 | def _individual_evals_to_overall_evals(individual_evals: list[IndividualEvalValue]) -> OverallEvalValue:
89 | if not individual_evals:
90 | raise ValueError("No evaluation results")
91 |
92 | if individual_evals[0].score is not None:
93 | number_of_scores = sum(1 for e in individual_evals if e.score is not None)
94 | score = sum(e.score for e in individual_evals if e.score is not None) / number_of_scores
95 | else:
96 | number_of_filled_passed = sum(1 for e in individual_evals if e.passed is not None)
97 | number_of_actual_passed = sum(1 for e in individual_evals if e.passed)
98 | score = number_of_actual_passed / number_of_filled_passed
99 |
100 | return {"score": score}
101 |
102 | @staticmethod
103 | def _experiment_results_to_individual_evals(
104 | experiment_results: ExperimentResults, threshold=0.5
105 |     ) -> list[IndividualEvalValue]:
106 | individual_evals = []
107 | for res in experiment_results._results: # noqa
108 | if not res["evaluation_results"]["results"]:
109 | continue
110 | score = res["evaluation_results"]["results"][0].score
111 | inputs = res["example"].inputs
112 | outputs = res["example"].outputs
113 | individual_evals.append(
114 | IndividualEvalValue(
115 | passed=score >= threshold,
116 | response=outputs,
117 | demo=LMDemo(inputs=inputs, outputs=outputs, adaptor_object=res["example"]),
118 | details=res,
119 | score=score,
120 | )
121 | )
122 | return individual_evals
123 |
124 | @staticmethod
125 | def _eval_results_to_evals(eval_results: list) -> OverallEvalValue:
126 | if not eval_results:
127 | raise ValueError("No evaluation results")
128 |
129 | return {
130 | "score": eval_results[0].score,
131 | **{r.key: r.dict() for r in eval_results},
132 | }
133 |
--------------------------------------------------------------------------------
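A hedged sketch of wiring up ZenLangSmith from the files above. The dataset name, example payloads, and exact_match evaluator are hypothetical, and the evaluator follows the run/example signature that langsmith.evaluate expects; set_evaluator_kwargs simply stores kwargs that get_evaluator forwards to langsmith.evaluate.

from langsmith import Client
from zenbase.adaptors.langchain import ZenLangSmith

smith = ZenLangSmith(client=Client())

dataset = smith.create_dataset("zenbase-demo", "Toy QA dataset")  # hypothetical name and description
smith.add_examples_to_dataset(
    dataset_name="zenbase-demo",
    inputs=[{"question": "2 + 2"}],
    outputs=[{"answer": "4"}],
)

def exact_match(run, example):
    # Assumed LangSmith-style evaluator: compare the run's outputs to the reference example.
    return {"key": "exact_match", "score": float(run.outputs == example.outputs)}

smith.set_evaluator_kwargs(evaluators=[exact_match])
evaluate_candidate = smith.get_evaluator(data=dataset)
# evaluate_candidate(lm_function) runs langsmith.evaluate and returns a CandidateEvalResult.
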
/py/src/zenbase/adaptors/langfuse_helper/__init__.py:
--------------------------------------------------------------------------------
1 | from .adaptor import * # noqa
2 |
--------------------------------------------------------------------------------
/py/src/zenbase/adaptors/langfuse_helper/adaptor.py:
--------------------------------------------------------------------------------
1 | from zenbase.adaptors.langfuse_helper.dataset_helper import LangfuseDatasetHelper
2 | from zenbase.adaptors.langfuse_helper.evaluation_helper import LangfuseEvaluationHelper
3 |
4 |
5 | class ZenLangfuse(LangfuseDatasetHelper, LangfuseEvaluationHelper):
6 | def __init__(self, client=None):
7 | LangfuseDatasetHelper.__init__(self, client)
8 | LangfuseEvaluationHelper.__init__(self, client)
9 |
10 | def get_evaluator(self, data: str):
11 | data = self.fetch_dataset_demos(data)
12 | evaluator_kwargs_to_pass = self.evaluator_kwargs.copy()
13 | evaluator_kwargs_to_pass.update({"data": data})
14 | return self._metric_evaluator_generator(**evaluator_kwargs_to_pass)
15 |
--------------------------------------------------------------------------------
/py/src/zenbase/adaptors/langfuse_helper/dataset_helper.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | from langfuse.client import DatasetItemClient
4 |
5 | from zenbase.adaptors.base.dataset_helper import BaseDatasetHelper
6 | from zenbase.types import LMDemo
7 |
8 |
9 | class LangfuseDatasetHelper(BaseDatasetHelper):
10 | def create_dataset(self, dataset_name: str, *args, **kwargs) -> Any:
11 | return self.client.create_dataset(dataset_name, *args, **kwargs)
12 |
13 | def add_examples_to_dataset(self, dataset_name: str, inputs: list, outputs: list) -> None:
14 | for the_input, the_output in zip(inputs, outputs):
15 | self.client.create_dataset_item(
16 | dataset_name=dataset_name,
17 | input=the_input,
18 | expected_output=the_output,
19 | )
20 |
21 | def fetch_dataset_examples(self, dataset_name: str) -> list[DatasetItemClient]:
22 | return self.client.get_dataset(dataset_name).items
23 |
24 | def fetch_dataset_demos(self, dataset_name: str) -> list[LMDemo]:
25 | return [
26 | LMDemo(inputs=example.input, outputs=example.expected_output)
27 | for example in self.fetch_dataset_examples(dataset_name)
28 | ]
29 |
--------------------------------------------------------------------------------
/py/src/zenbase/adaptors/langfuse_helper/evaluation_helper.py:
--------------------------------------------------------------------------------
1 | from typing import Callable
2 |
3 | from langfuse import Langfuse
4 | from langfuse.client import Dataset
5 |
6 | from zenbase.adaptors.base.evaluation_helper import BaseEvaluationHelper
7 | from zenbase.optim.metric.types import (
8 | CandidateEvalResult,
9 | CandidateEvaluator,
10 | IndividualEvalValue,
11 | OverallEvalValue,
12 | )
13 | from zenbase.types import LMDemo, LMFunction, Outputs
14 | from zenbase.utils import pmap
15 |
16 |
17 | class LangfuseEvaluationHelper(BaseEvaluationHelper):
18 | MetricEvaluator = Callable[[list[OverallEvalValue]], OverallEvalValue]
19 |
20 | @staticmethod
21 | def default_candidate_evals(item_evals: list[OverallEvalValue]) -> OverallEvalValue:
22 | keys = item_evals[0].keys()
23 | evals = {k: sum(d[k] for d in item_evals) / len(item_evals) for k in keys}
24 | if not evals["score"]:
25 | evals["score"] = sum(evals.values()) / len(evals)
26 | return evals
27 |
28 | @staticmethod
29 | def dataset_demos(dataset: Dataset) -> list[LMDemo]:
30 | # TODO: Should remove and deprecate
31 | return [LMDemo(inputs=item.input, outputs=item.expected_output) for item in dataset.items]
32 |
33 | @staticmethod
34 | def _metric_evaluator_generator(
35 | data: list[LMDemo],
36 | evaluate: Callable[[Outputs, LMDemo, Langfuse], OverallEvalValue],
37 | candidate_evals: MetricEvaluator = default_candidate_evals,
38 | langfuse: Langfuse | None = None,
39 | concurrency: int = 20,
40 | threshold: float = 0.5,
41 | ) -> CandidateEvaluator:
42 |         # TODO: this is not the right way to run an experiment in Langfuse; update to the new beta feature
43 | from langfuse import Langfuse
44 | from langfuse.decorators import observe
45 |
46 | langfuse = langfuse or Langfuse()
47 |
48 | def evaluate_candidate(function: LMFunction) -> CandidateEvalResult:
49 | individual_evals = []
50 |
51 | @observe()
52 | def run_and_evaluate(demo: LMDemo):
53 | nonlocal individual_evals
54 |
55 | outputs = function(demo.inputs)
56 | evals = evaluate(outputs, demo, langfuse=langfuse)
57 | individual_evals.append(
58 | IndividualEvalValue(
59 | passed=evals["score"] >= threshold,
60 | response=outputs,
61 | demo=demo,
62 | score=evals["score"],
63 | )
64 | )
65 | return evals
66 |
67 | item_evals = pmap(
68 | run_and_evaluate,
69 | data,
70 | concurrency=concurrency,
71 | )
72 | candidate_eval = candidate_evals(item_evals)
73 |
74 | return CandidateEvalResult(function, candidate_eval, individual_evals=individual_evals)
75 |
76 | return evaluate_candidate
77 |
78 | @classmethod
79 | def metric_evaluator(
80 | cls,
81 | evalset: Dataset,
82 | evaluate: Callable[[Outputs, LMDemo, Langfuse], OverallEvalValue],
83 | candidate_evals: MetricEvaluator = default_candidate_evals,
84 | langfuse: Langfuse | None = None,
85 | concurrency: int = 20,
86 | threshold: float = 0.5,
87 | ) -> CandidateEvaluator:
88 | # TODO: Should remove and deprecate
89 | from langfuse import Langfuse
90 | from langfuse.decorators import observe
91 |
92 | langfuse = langfuse or Langfuse()
93 |
94 | def evaluate_candidate(function: LMFunction) -> CandidateEvalResult:
95 | @observe()
96 | def run_and_evaluate(demo: LMDemo):
97 | outputs = function(demo.inputs)
98 | evals = evaluate(outputs, demo, langfuse=langfuse)
99 | return evals
100 |
101 | item_evals = pmap(
102 | run_and_evaluate,
103 | cls.dataset_demos(evalset),
104 | concurrency=concurrency,
105 | )
106 | candidate_eval = candidate_evals(item_evals)
107 |
108 | return CandidateEvalResult(function, candidate_eval)
109 |
110 | return evaluate_candidate
111 |
112 | def get_evaluator(self, data: str):
113 |         raise NotImplementedError("This method should be implemented by a subclass as it needs access to data")
114 |
--------------------------------------------------------------------------------
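A hedged sketch for ZenLangfuse from the files above. It assumes a dataset named "my-langfuse-dataset" already exists in Langfuse, that credentials are configured via environment variables, and that set_evaluator_kwargs is provided by the shared base evaluation helper (the other adaptors rely on it the same way); the scoring function only has to return a dict with a "score" key.

from langfuse import Langfuse
from zenbase.adaptors.langfuse_helper import ZenLangfuse

zen_langfuse = ZenLangfuse(Langfuse())  # assumes Langfuse credentials are set in the environment

def score_answer(outputs, demo, langfuse):
    # Must return a dict with a "score" key; receives the Langfuse client for optional scoring calls.
    return {"score": float(outputs == demo.outputs)}

zen_langfuse.set_evaluator_kwargs(evaluate=score_answer, concurrency=4)
evaluate_candidate = zen_langfuse.get_evaluator("my-langfuse-dataset")  # hypothetical dataset name
# evaluate_candidate(lm_function) returns a CandidateEvalResult with per-demo pass/fail values.
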
/py/src/zenbase/adaptors/lunary/__init__.py:
--------------------------------------------------------------------------------
1 | from .adaptor import * # noqa
2 |
--------------------------------------------------------------------------------
/py/src/zenbase/adaptors/lunary/adaptor.py:
--------------------------------------------------------------------------------
1 | from zenbase.adaptors.lunary.dataset_helper import LunaryDatasetHelper
2 | from zenbase.adaptors.lunary.evaluation_helper import LunaryEvaluationHelper
3 |
4 |
5 | class ZenLunary(LunaryDatasetHelper, LunaryEvaluationHelper):
6 | def __init__(self, client=None):
7 | LunaryDatasetHelper.__init__(self, client)
8 | LunaryEvaluationHelper.__init__(self, client)
9 |
10 | def get_evaluator(self, data: str):
11 | data = self.fetch_dataset_demos(data)
12 | evaluator_kwargs_to_pass = self.evaluator_kwargs.copy()
13 | evaluator_kwargs_to_pass.update({"data": data})
14 | return self._metric_evaluator_generator(**evaluator_kwargs_to_pass)
15 |
--------------------------------------------------------------------------------
/py/src/zenbase/adaptors/lunary/dataset_helper.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | from zenbase.adaptors.base.dataset_helper import BaseDatasetHelper
4 | from zenbase.types import LMDemo
5 |
6 |
7 | class LunaryDatasetHelper(BaseDatasetHelper):
8 | def create_dataset(self, dataset_name: str, *args, **kwargs) -> Any:
9 | raise NotImplementedError("Lunary doesn't support creating datasets")
10 |
11 | def add_examples_to_dataset(self, dataset_id: Any, inputs: list, outputs: list) -> None:
12 | raise NotImplementedError("Lunary doesn't support adding examples to datasets")
13 |
14 | def fetch_dataset_examples(self, dataset_name: str):
15 | return self.client.get_dataset(dataset_name)
16 |
17 | def fetch_dataset_demos(self, dataset_name: str) -> list[LMDemo]:
18 | return [
19 | LMDemo(inputs=example.input, outputs=example.ideal_output, adaptor_object=example)
20 | for example in self.client.get_dataset(dataset_name)
21 | ]
22 |
--------------------------------------------------------------------------------
/py/src/zenbase/adaptors/lunary/evaluation_helper.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable
2 |
3 | import lunary
4 |
5 | from zenbase.adaptors.base.evaluation_helper import BaseEvaluationHelper
6 | from zenbase.optim.metric.types import (
7 | CandidateEvalResult,
8 | CandidateEvaluator,
9 | IndividualEvalValue,
10 | OverallEvalValue,
11 | )
12 | from zenbase.types import LMDemo, LMFunction
13 | from zenbase.utils import pmap
14 |
15 |
16 | class LunaryEvaluationHelper(BaseEvaluationHelper):
17 | MetricEvaluator = Callable[[list[tuple[bool, Any]]], OverallEvalValue]
18 |
19 | @staticmethod
20 | def default_metric(batch_results: list[tuple[bool, Any]]) -> OverallEvalValue:
21 | avg_pass = sum(int(passed) for passed, _ in batch_results) / len(batch_results)
22 | return {"score": avg_pass}
23 |
24 | def get_evaluator(self, data: str):
25 |         raise NotImplementedError("This method should be implemented by a subclass as it needs access to data")
26 |
27 | @staticmethod
28 | def _metric_evaluator_generator(
29 | *args,
30 | checklist: str,
31 | data: list[LMDemo],
32 | eval_metrics: MetricEvaluator = default_metric,
33 | concurrency: int = 1,
34 | **kwargs,
35 | ) -> CandidateEvaluator:
36 | # TODO: Should remove and deprecate
37 | def evaluate_metric(function: LMFunction) -> CandidateEvalResult:
38 | individual_evals = []
39 |
40 | def run_and_evaluate(demo: LMDemo):
41 | nonlocal individual_evals
42 |
43 | response = function(demo.inputs)
44 | result = lunary.evaluate(
45 | checklist,
46 | input=demo.inputs,
47 | output=response,
48 | ideal_output=demo.outputs,
49 | *args,
50 | **kwargs,
51 | )
52 |
53 | individual_evals.append(
54 | IndividualEvalValue(
55 | details=result[1][0]["details"],
56 | passed=result[0],
57 | response=response,
58 | demo=demo,
59 | )
60 | )
61 |
62 | return result
63 |
64 | eval_results = pmap(
65 | run_and_evaluate,
66 | data,
67 | concurrency=concurrency,
68 | )
69 |
70 | return CandidateEvalResult(function, eval_metrics(eval_results), individual_evals=individual_evals)
71 |
72 | return evaluate_metric
73 |
74 | @classmethod
75 | def dataset_to_demos(cls, dataset: list[lunary.DatasetItem]) -> list[LMDemo]:
76 | # TODO: Should remove and deprecate
77 | return [LMDemo(inputs=item.input, outputs=item.ideal_output, adaptor_object=item) for item in dataset]
78 |
79 | @classmethod
80 | def metric_evaluator(
81 | cls,
82 | *args,
83 | checklist: str,
84 | evalset: list[lunary.DatasetItem],
85 | eval_metrics: MetricEvaluator = default_metric,
86 | concurrency: int = 20,
87 | **kwargs,
88 | ) -> CandidateEvaluator:
89 | # TODO: Should remove and deprecate
90 | def evaluate_metric(function: LMFunction) -> CandidateEvalResult:
91 | individual_evals = []
92 |
93 | def run_and_evaluate(demo: LMDemo):
94 | nonlocal individual_evals
95 |
96 | item = demo.adaptor_object
97 |
98 | response = function(item.input)
99 | result = lunary.evaluate(
100 | checklist,
101 | input=item.input,
102 | output=response,
103 | ideal_output=item.ideal_output,
104 | *args,
105 | **kwargs,
106 | )
107 |
108 | individual_evals.append(
109 | IndividualEvalValue(
110 | details=result[1][0]["details"],
111 | passed=result[0],
112 | response=response,
113 | demo=demo,
114 | )
115 | )
116 |
117 | return result
118 |
119 | eval_results = pmap(
120 | run_and_evaluate,
121 | cls.dataset_to_demos(evalset),
122 | concurrency=concurrency,
123 | )
124 |
125 | return CandidateEvalResult(function, eval_metrics(eval_results), individual_evals=individual_evals)
126 |
127 | return evaluate_metric
128 |
--------------------------------------------------------------------------------
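A hedged sketch for ZenLunary from the files above. Lunary datasets are managed in the Lunary app, so the sketch only fetches and evaluates; the client is assumed to be the lunary module itself, since only get_dataset is called on it, and the dataset and checklist slugs are hypothetical.

import lunary

from zenbase.adaptors.lunary import ZenLunary

zen_lunary = ZenLunary(client=lunary)

# The checklist slug refers to a checklist configured in Lunary; lunary.evaluate is
# called with it for every demo in the dataset.
zen_lunary.set_evaluator_kwargs(checklist="exact-match", concurrency=4)
evaluate_candidate = zen_lunary.get_evaluator("my-lunary-dataset")
# evaluate_candidate(lm_function) returns a CandidateEvalResult.
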
/py/src/zenbase/adaptors/parea/__init__.py:
--------------------------------------------------------------------------------
1 | from .adaptor import * # noqa
2 |
--------------------------------------------------------------------------------
/py/src/zenbase/adaptors/parea/adaptor.py:
--------------------------------------------------------------------------------
1 | __all__ = ["ZenParea"]
2 |
3 | from zenbase.adaptors.parea.dataset_helper import PareaDatasetHelper
4 | from zenbase.adaptors.parea.evaluation_helper import PareaEvaluationHelper
5 | from zenbase.optim.metric.types import CandidateEvaluator
6 |
7 |
8 | class ZenParea(PareaDatasetHelper, PareaEvaluationHelper):
9 | def __init__(self, client=None):
10 | PareaDatasetHelper.__init__(self, client)
11 | PareaEvaluationHelper.__init__(self, client)
12 |
13 | def get_evaluator(self, data: str) -> CandidateEvaluator:
14 | evaluator_kwargs_to_pass = self.evaluator_kwargs.copy()
15 | evaluator_kwargs_to_pass.update({"data": self.fetch_dataset_list_of_dicts(data)})
16 | return self._metric_evaluator_generator(**evaluator_kwargs_to_pass)
17 |
--------------------------------------------------------------------------------
/py/src/zenbase/adaptors/parea/dataset_helper.py:
--------------------------------------------------------------------------------
1 | import json
2 | from typing import ValuesView
3 |
4 | from parea.schemas import TestCase, TestCaseCollection
5 |
6 | from zenbase.adaptors.base.dataset_helper import BaseDatasetHelper
7 | from zenbase.types import LMDemo
8 |
9 |
10 | class PareaDatasetHelper(BaseDatasetHelper):
11 | def create_dataset(self, dataset_name: str):
12 | dataset = self.client.create_test_collection(data=[], name=dataset_name)
13 | return dataset
14 |
15 | def add_examples_to_dataset(self, inputs, outputs, dataset_name: str):
16 | data = [{"inputs": inputs[i], "target": outputs[i]} for i in range(len(inputs))]
17 | self.client.add_test_cases(data, dataset_name)
18 |
19 | def create_dataset_and_add_examples(self, inputs, outputs, dataset_name: str):
20 | data = [{"inputs": inputs[i], "target": outputs[i]} for i in range(len(inputs))]
21 | dataset = self.client.create_test_collection(dataset_name, data)
22 | return dataset
23 |
24 | def fetch_dataset_examples(self, dataset_name: str):
25 | return self.fetch_dataset(dataset_name).test_cases.values()
26 |
27 | def fetch_dataset(self, dataset_name: str) -> TestCaseCollection:
28 | return self.client.get_collection(dataset_name)
29 |
30 | def fetch_dataset_demos(self, dataset_name: str) -> list[LMDemo]:
31 | return self.example_to_demo(self.fetch_dataset_examples(dataset_name))
32 |
33 | def fetch_dataset_list_of_dicts(self, dataset_name: str) -> list[dict]:
34 | return [
35 | {"inputs": json.loads(case.inputs["inputs"]), "target": case.target}
36 | for case in self.fetch_dataset_examples(dataset_name)
37 | ]
38 |
39 | @staticmethod
40 | def example_to_demo(examples: ValuesView[TestCase]) -> list[LMDemo]:
41 | return [LMDemo(inputs=example.inputs, outputs={"target": example.target}) for example in examples]
42 |
--------------------------------------------------------------------------------
/py/src/zenbase/adaptors/parea/evaluation_helper.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | from dataclasses import asdict
4 | from functools import partial
5 | from inspect import (
6 | _empty, # noqa
7 | signature,
8 | )
9 | from json import JSONDecodeError
10 | from typing import Callable
11 |
12 | from langsmith.evaluation._runner import ExperimentResults # noqa
13 | from parea import Parea
14 | from parea.schemas import ExperimentStatsSchema, ListExperimentUUIDsFilters
15 | from tenacity import (
16 | before_sleep_log,
17 | retry,
18 | stop_after_attempt,
19 | wait_exponential_jitter,
20 | )
21 |
22 | from zenbase.adaptors.base.evaluation_helper import BaseEvaluationHelper
23 | from zenbase.optim.metric.types import (
24 | CandidateEvalResult,
25 | CandidateEvaluator,
26 | IndividualEvalValue,
27 | OverallEvalValue,
28 | )
29 | from zenbase.types import LMDemo, LMFunction
30 | from zenbase.utils import expand_nested_json, get_logger, random_name_generator
31 |
32 | log = get_logger(__name__)
33 |
34 |
35 | class PareaEvaluationHelper(BaseEvaluationHelper):
36 | MetricEvaluator = Callable[[dict[str, float]], OverallEvalValue]
37 |
38 | def __init__(self, client=None):
39 | super().__init__(client)
40 | self.evaluator_args = None
41 | self.evaluator_kwargs = None
42 |
43 | def get_evaluator(self, data: str):
44 | pass
45 |
46 | @staticmethod
47 | def default_candidate_evals(stats: ExperimentStatsSchema) -> OverallEvalValue:
48 | return {**stats.avg_scores, "score": sum(stats.avg_scores.values())}
49 |
50 | def _metric_evaluator_generator(
51 | self,
52 | *args,
53 | p: Parea | None = None,
54 | candidate_evals: MetricEvaluator = default_candidate_evals,
55 | **kwargs,
56 | ) -> CandidateEvaluator:
57 | p = p or Parea()
58 | assert isinstance(p, Parea)
59 |
60 | base_metadata = kwargs.pop("metadata", {})
61 | random_name = random_name_generator(kwargs.pop("name", None))()
62 |
63 | @retry(
64 | stop=stop_after_attempt(3),
65 | wait=wait_exponential_jitter(max=8),
66 | before_sleep=before_sleep_log(log, logging.WARN),
67 | )
68 | def evaluate_candidate(function: LMFunction) -> CandidateEvalResult:
69 | # Check if the function has a default value for the 'request' parameter
70 | # TODO: Needs clean up, this is not the way to do it.
71 | if (
72 | "optimized_args_in_fn" in signature(function).parameters
73 | and "request" in signature(function).parameters
74 | and signature(function).parameters["request"].default is _empty
75 | ):
76 | # Create a new function with 'None' as the default value for 'request'
77 | function_with_default = partial(function, request=None)
78 | else:
79 | # If 'request' already has a default value, use the function as is
80 | function_with_default = function
81 |
82 | experiment = p.experiment(
83 | func=function_with_default,
84 | *args,
85 | **kwargs,
86 | name=random_name,
87 | metadata={
88 | **base_metadata,
89 | },
90 | )
91 |
92 | experiment.run()
93 |
94 | if not experiment.experiment_stats:
95 | raise RuntimeError("Failed to run experiment on Parea")
96 |
97 | experiments = p.list_experiments(
98 | ListExperimentUUIDsFilters(experiment_name_filter=experiment.experiment_name)
99 | )
100 | experiment__uuid = experiments[0].uuid
101 | print(f"Num. experiments: {len(experiments)}")
102 | individual_evals = self._experiment_results_to_individual_evals(
103 | experiment.experiment_stats, experiment__uuid=experiment__uuid
104 | )
105 | return CandidateEvalResult(
106 | function,
107 | evals=candidate_evals(experiment.experiment_stats),
108 | individual_evals=individual_evals,
109 | )
110 |
111 | return evaluate_candidate
112 |
113 | def _experiment_results_to_individual_evals(
114 | self,
115 | experiment_stats: ExperimentStatsSchema,
116 | threshold=0.5,
117 | score_name="",
118 | experiment__uuid=None,
119 | ) -> list[IndividualEvalValue]:
120 | if experiment_stats is None or experiment__uuid is None:
121 | raise ValueError("experiment_stats and experiment__uuid must not be None")
122 |
123 | individual_evals = []
124 | # Retrieve the JSON logs for the experiment using its UUID
125 | try:
126 | json_traces = self._get_experiment_logs(experiment__uuid)
127 | except JSONDecodeError:
128 | raise ValueError("Failed to parse experiment logs")
129 |
130 | def find_input_output_with_trace_id(trace_id):
131 | for trace in json_traces:
132 | try:
133 | if trace["trace_id"] == trace_id:
134 | inputs = expand_nested_json(trace["inputs"])
135 | outputs = expand_nested_json(trace["output"])
136 | for k, v in inputs.items():
137 | if isinstance(v, dict) and "zenbase" in v:
138 | return v["inputs"], outputs
139 | except KeyError:
140 | continue
141 | return None, None
142 |
143 | for res in experiment_stats.parent_trace_stats:
144 | # Skip this iteration if there are no scores associated with the current result
145 | if not res.scores:
146 | continue
147 |
148 | # Find the score, prioritizing scores that match the given score name, or defaulting to the first score
149 | score = next((i.score for i in res.scores if score_name and i.name == score_name), res.scores[0].score)
150 |
151 | if not res.trace_id or not json_traces:
152 | raise ValueError("Trace ID or logs not found in experiment results")
153 |
154 | inputs, outputs = find_input_output_with_trace_id(res.trace_id)
155 |
156 | if not inputs or not outputs:
157 | continue
158 |
159 | individual_evals.append(
160 | IndividualEvalValue(
161 | passed=score >= threshold,
162 | response=outputs,
163 | demo=LMDemo(inputs=inputs, outputs=outputs, adaptor_object=res),
164 | score=score,
165 | )
166 | )
167 | return individual_evals
168 |
169 | def _get_experiment_logs(self, experiment__uuid):
170 | from parea.client import GET_EXPERIMENT_LOGS_ENDPOINT
171 |
172 | filter_data = {"filter_field": None, "filter_operator": None, "filter_value": None}
173 | endpoint = GET_EXPERIMENT_LOGS_ENDPOINT.format(experiment_uuid=experiment__uuid)
174 | response = self.client._client.request("POST", endpoint, data=filter_data) # noqa
175 | return response.json()
176 |
177 | @classmethod
178 | def metric_evaluator(
179 | cls,
180 | *args,
181 | p: Parea | None = None,
182 | candidate_evals: MetricEvaluator = default_candidate_evals,
183 | **kwargs,
184 | ) -> CandidateEvaluator:
185 |         # TODO: Should remove and deprecate
186 | p = p or Parea()
187 | assert isinstance(p, Parea)
188 |
189 | base_metadata = kwargs.pop("metadata", {})
190 | gen_random_name = random_name_generator(kwargs.pop("name", None))
191 |
192 | def evaluate_candidate(function: LMFunction) -> CandidateEvalResult:
193 | experiment = p.experiment(
194 | func=function,
195 | *args,
196 | **kwargs,
197 | name=gen_random_name(),
198 | metadata={
199 | **base_metadata,
200 | "zenbase": json.dumps(asdict(function.zenbase)),
201 | },
202 | )
203 |
204 | experiment.run()
205 | assert experiment.experiment_stats, "failed to run experiment"
206 |
207 | return CandidateEvalResult(
208 | function,
209 | evals=candidate_evals(experiment.experiment_stats),
210 | )
211 |
212 | return evaluate_candidate
213 |
--------------------------------------------------------------------------------
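A hedged sketch for ZenParea from the files above. It assumes a Parea test-case collection already exists with inputs stored as JSON under an "inputs" key (which is what fetch_dataset_list_of_dicts expects) and that evaluation functions are attached on the Parea side of the traced function; the collection name is hypothetical.

from parea import Parea

from zenbase.adaptors.parea import ZenParea

p = Parea()  # assumes Parea credentials are already configured; pass api_key=... otherwise
zen_parea = ZenParea(client=p)

zen_parea.set_evaluator_kwargs(p=p)
evaluate_candidate = zen_parea.get_evaluator("my-test-collection")
# evaluate_candidate(lm_function) runs a Parea experiment and returns a CandidateEvalResult
# whose score is derived from the experiment's average scores.
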
/py/src/zenbase/core/managers.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | from abc import ABC
3 | from collections import OrderedDict
4 | from contextlib import contextmanager
5 | from typing import Any, Callable, Union
6 |
7 | from zenbase.types import LMFunction, LMZenbase
8 | from zenbase.utils import ksuid
9 |
10 |
11 | class BaseTracer(ABC):
12 | pass
13 |
14 |
15 | class ZenbaseTracer(BaseTracer):
16 | def __init__(self, max_traces=1000):
17 | self.all_traces = OrderedDict()
18 | self.max_traces = max_traces
19 | self.current_trace = None
20 | self.current_key = None
21 | self.optimized_args = {}
22 |
23 | def __call__(self, function: Callable[[Any], Any] = None, zenbase: LMZenbase = None) -> Union[Callable, LMFunction]:
24 | if function is None:
25 | return lambda f: self.trace_function(f, zenbase)
26 | return self.trace_function(function, zenbase)
27 |
28 | def flush(self):
29 | self.all_traces.clear()
30 |
31 | def add_trace(self, run_timestamp: str, func_name: str, trace_data: dict):
32 | is_new_key = run_timestamp not in self.all_traces
33 | if is_new_key and len(self.all_traces) >= self.max_traces:
34 | self.all_traces.popitem(last=False)
35 | traces_for_timestamp = self.all_traces.setdefault(run_timestamp, OrderedDict())
36 | traces_for_timestamp[func_name] = trace_data
37 | if not is_new_key:
38 | self.all_traces.move_to_end(run_timestamp)
39 |
40 | def trace_function(self, function: Callable[[Any], Any] = None, zenbase: LMZenbase = None) -> LMFunction:
41 | def wrapper(request, lm_function, *args, **kwargs):
42 | func_name = function.__name__
43 | run_timestamp = ksuid(func_name)
44 |
45 | if self.current_trace is None:
46 | with self.trace_context(func_name, run_timestamp):
47 | return self._execute_and_trace(function, func_name, request, lm_function, *args, **kwargs)
48 | else:
49 | return self._execute_and_trace(function, func_name, request, lm_function, *args, **kwargs)
50 |
51 | return LMFunction(wrapper, zenbase)
52 |
53 | @contextmanager
54 | def trace_context(self, func_name, run_timestamp, optimized_args=None):
55 | if self.current_trace is None:
56 | self.current_trace = {}
57 | self.current_key = run_timestamp
58 | if optimized_args:
59 | self.optimized_args = optimized_args
60 | try:
61 | yield
62 | finally:
63 | if self.current_key == run_timestamp:
64 | self.add_trace(run_timestamp, func_name, self.current_trace)
65 | self.current_trace = None
66 | self.current_key = None
67 | self.optimized_args = {}
68 |
69 | def _execute_and_trace(self, func, func_name, request, lm_function, *args, **kwargs):
70 | # Get the function signature
71 | sig = inspect.signature(func)
72 |
73 | # Map positional args to their names and combine with kwargs
74 | combined_args = {**kwargs}
75 | arg_names = list(sig.parameters.keys())[: len(args)]
76 | combined_args.update(zip(arg_names, args))
77 |
78 | # Include default values for missing arguments
79 | for param in sig.parameters.values():
80 | if param.name not in combined_args and param.default is not param.empty:
81 | combined_args[param.name] = param.default
82 |
83 | if func_name in self.optimized_args:
84 | optimized_args = self.optimized_args[func_name]["args"]
85 | if "zenbase" in optimized_args:
86 | request.zenbase = optimized_args["zenbase"]
87 | optimized_args.pop("zenbase", None)
88 |
89 | # Replace with optimized arguments if available
90 | if func_name in self.optimized_args:
91 | optimized_args = self.optimized_args[func_name]["args"]
92 | combined_args.update(optimized_args)
93 |
94 | combined_args.update(
95 | {
96 | "request": request,
97 | }
98 | )
99 | # Capture input arguments in trace_info
100 | trace_info = {"args": combined_args, "output": None, "request": request, "lm_function": lm_function}
101 |
102 | # Execute the function and capture its output
103 | output = func(**combined_args)
104 | trace_info["output"] = output
105 |
106 | # Store the trace information in the current_trace dictionary
107 | if self.current_trace is not None:
108 | self.current_trace[func_name] = trace_info
109 |
110 | return output
111 |
--------------------------------------------------------------------------------
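A minimal sketch of ZenbaseTracer from the file above. Decorating a function with the tracer wraps it into an LMFunction whose request carries the caller's inputs plus any task demos injected by an optimizer; the input shape and the function body here are hypothetical.

from zenbase.core.managers import ZenbaseTracer

tracer = ZenbaseTracer(max_traces=100)

@tracer  # equivalent to @tracer.trace_function
def answer_question(request):
    # request.inputs holds the caller's inputs; request.zenbase.task_demos holds
    # few-shot demos once an optimizer has injected optimized arguments.
    demos = request.zenbase.task_demos or ()
    question = request.inputs["question"]  # hypothetical input shape
    return {"answer": f"echo: {question} (seen {len(demos)} demos)"}

answer_question({"question": "What does the tracer record?"})
# Every top-level call opens a trace context keyed by a ksuid timestamp and stores
# the function's arguments and output in tracer.all_traces.
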
/py/src/zenbase/optim/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from dataclasses import dataclass, field
3 | from random import Random
4 | from typing import Generic
5 |
6 | from pyee.asyncio import AsyncIOEventEmitter
7 |
8 | from zenbase.types import Inputs, LMFunction, Outputs
9 | from zenbase.utils import random_factory
10 |
11 |
12 | @dataclass(kw_only=True)
13 | class LMOptim(Generic[Inputs, Outputs], ABC):
14 | random: Random = field(default_factory=random_factory)
15 | events: AsyncIOEventEmitter = field(default_factory=AsyncIOEventEmitter)
16 |
17 | @abstractmethod
18 | def perform(
19 | self,
20 | lmfn: LMFunction[Inputs, Outputs],
21 | *args,
22 | **kwargs,
23 | ): ...
24 |
--------------------------------------------------------------------------------
/py/src/zenbase/optim/metric/bootstrap_few_shot.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | from dataclasses import dataclass, field
3 | from datetime import datetime
4 | from functools import partial
5 | from typing import Any, Dict, NamedTuple
6 |
7 | import cloudpickle
8 |
9 | from zenbase.adaptors.langchain import ZenLangSmith
10 | from zenbase.core.managers import ZenbaseTracer
11 | from zenbase.optim.base import LMOptim
12 | from zenbase.optim.metric.labeled_few_shot import LabeledFewShot
13 | from zenbase.optim.metric.types import CandidateEvalResult
14 | from zenbase.types import Inputs, LMDemo, LMFunction, LMZenbase, Outputs
15 | from zenbase.utils import get_logger, ot_tracer
16 |
17 | log = get_logger(__name__)
18 |
19 |
20 | @dataclass(kw_only=True)
21 | class BootstrapFewShot(LMOptim[Inputs, Outputs]):
22 | class Result(NamedTuple):
23 | best_function: LMFunction[Inputs, Outputs]
24 | candidate_results: list[CandidateEvalResult] | None = None
25 |
26 | shots: int = field(default=5)
27 | training_set_demos: list[LMDemo[Inputs, Outputs]] | None = None
28 | training_set: Any = None # TODO: it needs to be more generic and pass our Dataset Object here
29 | test_set: Any = None
30 | validation_set: Any = None
31 | base_evaluation = None
32 | best_evaluation = None
33 |     optimizer_args: Dict[str, dict[str, dict[str, LMZenbase]]] = field(default_factory=dict)
34 | zen_adaptor: Any = None
35 | evaluator_kwargs: dict[str, Any] = field(default_factory=dict)
36 |
37 | def __post_init__(self):
38 | self.training_set_demos = self.zen_adaptor.fetch_dataset_demos(self.training_set)
39 | self.zen_adaptor.set_evaluator_kwargs(**self.evaluator_kwargs)
40 | assert 1 <= self.shots <= len(self.training_set_demos)
41 |
42 | @ot_tracer.start_as_current_span("perform")
43 | def perform(
44 | self,
45 | student_lm: LMFunction[Inputs, Outputs],
46 | teacher_lm: LMFunction[Inputs, Outputs] | None = None,
47 | samples: int = 5,
48 | rounds: int = 1,
49 | trace_manager: ZenbaseTracer = None,
50 | ) -> Result:
51 |         """
52 |         Perform the bootstrap few-shot optimization on the given student_lm function.
53 |         It returns the best function optimized from the given student_lm function.
54 |
55 |         :param student_lm: The student function that needs to be optimized
56 |         :param teacher_lm: The teacher function used to optimize the student function
57 |         :param samples: The number of samples to use for the optimization
58 |         :param rounds: The number of rounds to use for the LabeledFewShot optimization
59 |         :param trace_manager: The trace manager used to trace the function
60 |         :return: A Result named tuple containing the best optimized function
61 |         """
62 |
63 | assert trace_manager is not None, "Zenbase is required for this operation"
64 | # Clean up traces
65 | trace_manager.flush()
66 |
67 | test_set_evaluator = self.zen_adaptor.get_evaluator(data=self.test_set)
68 | self.base_evaluation = test_set_evaluator(student_lm)
69 |
70 | if not teacher_lm:
71 | # Create the base LabeledFewShot teacher model
72 | trace_manager.flush()
73 | teacher_lm = self._create_teacher_model(self.zen_adaptor, student_lm, samples, rounds)
74 |
75 | # Evaluate and validate the demo set
76 | validated_training_set_demos = self._validate_demo_set(self.zen_adaptor, teacher_lm)
77 |
78 | # Run each validated demo to fill up the traces
79 | trace_manager.flush()
80 | self._run_validated_demos(teacher_lm, validated_training_set_demos)
81 |
82 | # Consolidate the traces to optimized args
83 | optimized_args = self._consolidate_traces_to_optimized_args(trace_manager)
84 | self.set_optimizer_args(optimized_args)
85 |
86 | # Create the optimized function
87 | optimized_fn = self._create_optimized_function(student_lm, optimized_args, trace_manager)
88 |
89 | # Evaluate the optimized function
90 | self.best_evaluation = test_set_evaluator(optimized_fn)
91 |
92 | trace_manager.flush()
93 | return self.Result(best_function=optimized_fn)
94 |
95 | def _create_teacher_model(
96 | self, zen_adaptor: ZenLangSmith, student_lm: LMFunction, samples: int, rounds: int
97 | ) -> LMFunction:
98 | evaluator = zen_adaptor.get_evaluator(data=self.validation_set)
99 | teacher_lm, _, _ = LabeledFewShot(demoset=self.training_set_demos, shots=self.shots).perform(
100 | student_lm, evaluator=evaluator, samples=samples, rounds=rounds
101 | )
102 | return teacher_lm
103 |
104 | def _validate_demo_set(self, zen_adaptor: ZenLangSmith, teacher_lm: LMFunction) -> list[LMDemo]:
105 |         # TODO: there is an issue here: we are not removing the actual training examples from the task demos,
106 |         # so overfitting is possible. It is not a big issue for now, but we should remove them
107 |         # in the trace_manager.
108 | # def teacher_lm_tweaked(request: LMRequest):
109 | # #check inputs in the task demos
110 | # for demo in request.zenbase.task_demos:
111 | # if request.inputs == demo.inputs:
112 | # request.zenbase.task_demos.pop(demo)
113 | # return teacher_lm(request)
114 |
115 | # get evaluator for the training set
116 | evaluate_demo_set = zen_adaptor.get_evaluator(data=self.training_set)
117 | # run the evaluation and get the result of the evaluation
118 | result = evaluate_demo_set(teacher_lm)
119 | # find the validated training set that has been passed
120 | validated_demo_set = [eval.demo for eval in result.individual_evals if eval.passed]
121 | return validated_demo_set
122 |
123 | @staticmethod
124 | def _run_validated_demos(teacher_lm: LMFunction, validated_demo_set: list[LMDemo]) -> None:
125 | """
126 | Run each of the validated demos to fill up the traces
127 |
128 | :param teacher_lm: The teacher model to run the demos
129 | :param validated_demo_set: The validated demos to run
130 | """
131 | for validated_demo in validated_demo_set:
132 | teacher_lm(validated_demo.inputs)
133 |
134 | def _consolidate_traces_to_optimized_args(
135 | self, trace_manager: ZenbaseTracer
136 |     ) -> dict[str, dict[str, dict[str, LMZenbase]]]:
137 | """
138 | Consolidate the traces to optimized args that will be used to optimize the student function
139 |
140 | :param trace_manager: The trace manager that contains all the traces
141 | """
142 | all_traces = trace_manager.all_traces
143 | each_function_inputs = {}
144 |
145 | for trace_value in all_traces.values():
146 | for function_trace in trace_value.values():
147 | for inside_functions, inside_functions_traces in function_trace.items():
148 | input_args = inside_functions_traces["args"]["request"].inputs
149 | output_args = inside_functions_traces["output"]
150 |
151 | # Sanitize input and output arguments by replacing curly braces with spaces.
152 | # This prevents conflicts when using these arguments as keys in template rendering within LangChain.
153 | if isinstance(input_args, dict):
154 | input_args = {k: str(v).replace("{", " ").replace("}", " ") for k, v in input_args.items()}
155 | if isinstance(output_args, dict):
156 | output_args = {k: str(v).replace("{", " ").replace("}", " ") for k, v in output_args.items()}
157 |
158 | each_function_inputs.setdefault(inside_functions, []).append(
159 | LMDemo(inputs=input_args, outputs=output_args)
160 | )
161 |
162 | optimized_args = {
163 | function: {"args": {"zenbase": LMZenbase(task_demos=demos)}}
164 | for function, demos in each_function_inputs.items()
165 | }
166 | return optimized_args
167 |
168 | @staticmethod
169 | def _create_optimized_function(
170 | student_lm: LMFunction, optimized_args: dict, trace_manager: ZenbaseTracer
171 | ) -> LMFunction:
172 | """
173 | Create the optimized function that will be used to optimize the student function
174 |
175 | :param student_lm: The student function that needs to be optimized
176 | :param optimized_args: The optimized args that will be used to optimize the student function
177 | :param trace_manager: The trace manager that will be used to trace the function
178 | """
179 |
180 | def optimized_fn_base(request, zenbase, optimized_args_in_fn, trace_manager, *args, **kwargs):
181 | if request is None and "inputs" not in kwargs.keys():
182 | raise ValueError("Request or inputs should be passed")
183 | elif request is None:
184 | request = kwargs["inputs"]
185 | kwargs.pop("inputs")
186 |
187 | new_optimized_args = deepcopy(optimized_args_in_fn)
188 | with trace_manager.trace_context(
189 | "optimized", f"optimized_layer_0_{datetime.now().isoformat()}", new_optimized_args
190 | ):
191 | if request is None:
192 | return student_lm(*args, **kwargs)
193 | return student_lm(request, *args, **kwargs)
194 |
195 | optimized_fn = partial(
196 | optimized_fn_base,
197 | zenbase=LMZenbase(), # it doesn't do anything, it is just for type safety
198 | optimized_args_in_fn=optimized_args,
199 | trace_manager=trace_manager,
200 | )
201 | return optimized_fn
202 |
203 | def set_optimizer_args(self, args: Dict[str, Any]) -> None:
204 | """
205 | Set the optimizer arguments.
206 |
207 | :param args: A dictionary containing the optimizer arguments
208 | """
209 | self.optimizer_args = args
210 |
211 | def get_optimizer_args(self) -> Dict[str, Any]:
212 | """
213 | Get the current optimizer arguments.
214 |
215 | :return: A dictionary containing the current optimizer arguments
216 | """
217 | return self.optimizer_args
218 |
219 | def save_optimizer_args(self, file_path: str) -> None:
220 | """
221 |         Save the optimizer arguments to a file using cloudpickle.
222 |
223 |         :param file_path: The path of the cloudpickle file to write
224 | """
225 | with open(file_path, "wb") as f:
226 | cloudpickle.dump(self.optimizer_args, f)
227 |
228 | @classmethod
229 | def load_optimizer_and_function(
230 | cls, optimizer_args_file: str, student_lm: LMFunction[Inputs, Outputs], trace_manager: ZenbaseTracer
231 | ) -> LMFunction[Inputs, Outputs]:
232 | """
233 | Load optimizer arguments and create an optimized function.
234 |
235 |         :param optimizer_args_file: The path to the cloudpickle file containing optimizer arguments
236 | :param student_lm: The student function to be optimized
237 | :param trace_manager: The trace manager to be used
238 | :return: An optimized function
239 | """
240 | optimizer_args = cls._load_optimizer_args(optimizer_args_file)
241 | return cls._create_optimized_function(student_lm, optimizer_args, trace_manager)
242 |
243 | @classmethod
244 | def _load_optimizer_args(cls, file_path: str) -> Dict[str, Any]:
245 | """
246 |         Load optimizer arguments from a cloudpickle file.
247 |
248 |         :param file_path: The path of the cloudpickle file to load
249 | :return: A dictionary containing the loaded optimizer arguments
250 | """
251 | with open(file_path, "rb") as f:
252 | return cloudpickle.load(f)
253 |
--------------------------------------------------------------------------------
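A hedged end-to-end sketch of BootstrapFewShot from the file above, using the in-memory JSONAdaptor and ZenbaseTracer so every piece shown is defined in this repository. The toy demos, the student function body, and the exact_match evaluator are hypothetical; a real student function would call an LLM.

from zenbase.adaptors.json.adaptor import JSONAdaptor
from zenbase.core.managers import ZenbaseTracer
from zenbase.optim.metric.bootstrap_few_shot import BootstrapFewShot
from zenbase.types import LMDemo

tracer = ZenbaseTracer()
adaptor = JSONAdaptor()

demos = [LMDemo(inputs={"q": str(i)}, outputs={"a": str(i)}) for i in range(6)]

@tracer
def student(request):
    # Hypothetical student function; a real one would prompt an LLM with request.inputs
    # and request.zenbase.task_demos as few-shot examples.
    return {"a": request.inputs["q"]}

def exact_match(output, ideal_output):
    return {"passed": int(output == ideal_output)}

optimizer = BootstrapFewShot(
    shots=2,
    training_set=demos[:4],
    validation_set=demos[4:5],
    test_set=demos[5:],
    zen_adaptor=adaptor,
    evaluator_kwargs={"eval_function": exact_match},
)
result = optimizer.perform(student, samples=4, rounds=1, trace_manager=tracer)
best_fn = result.best_function
optimizer.save_optimizer_args("bootstrap_few_shot_optimizer_args.zenbase")
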
/py/src/zenbase/optim/metric/labeled_few_shot.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from math import factorial
3 | from typing import NamedTuple
4 |
5 | from zenbase.optim.base import LMOptim
6 | from zenbase.optim.metric.types import CandidateEvalResult, CandidateEvaluator
7 | from zenbase.types import Inputs, LMDemo, LMFunction, LMZenbase, Outputs
8 | from zenbase.utils import asyncify, get_logger, ksuid, ot_tracer, pmap, posthog
9 |
10 | log = get_logger(__name__)
11 |
12 |
13 | @dataclass(kw_only=True)
14 | class LabeledFewShot(LMOptim[Inputs, Outputs]):
15 | class Result(NamedTuple):
16 | best_function: LMFunction[Inputs, Outputs]
17 | candidate_results: list[CandidateEvalResult]
18 | best_candidate_result: CandidateEvalResult | None
19 |
20 | demoset: list[LMDemo[Inputs, Outputs]]
21 | shots: int = field(default=5)
22 |
23 | def __post_init__(self):
24 | assert 1 <= self.shots <= len(self.demoset)
25 |
26 | @ot_tracer.start_as_current_span("perform")
27 | def perform(
28 | self,
29 | lmfn: LMFunction[Inputs, Outputs],
30 | evaluator: CandidateEvaluator[Inputs, Outputs],
31 | samples: int = 0,
32 | rounds: int = 1,
33 | concurrency: int = 1,
34 | ) -> Result:
35 | samples = samples or len(self.demoset)
36 |
37 | best_score = float("-inf")
38 | best_lmfn = lmfn
39 | best_candidate_result = None
40 |
41 | @ot_tracer.start_as_current_span("run_experiment")
42 | def run_candidate_zenbase(zenbase: LMZenbase):
43 | nonlocal best_score, best_lmfn, best_candidate_result
44 |
45 | candidate_fn = lmfn.clean_and_duplicate(zenbase)
46 | try:
47 | candidate_result = evaluator(candidate_fn)
48 | except Exception as e:
49 | log.error("candidate evaluation failed", error=e)
50 | candidate_result = CandidateEvalResult(candidate_fn, {"score": float("-inf")})
51 |
52 | self.events.emit("candidate", candidate_result)
53 |
54 | if candidate_result.evals["score"] > best_score:
55 | best_score = candidate_result.evals["score"]
56 | best_lmfn = candidate_fn
57 | best_candidate_result = candidate_result
58 |
59 | return candidate_result
60 |
61 | candidates: list[CandidateEvalResult] = []
62 | for _ in range(rounds):
63 | candidates += pmap(
64 | run_candidate_zenbase,
65 | self.candidates(best_lmfn, samples),
66 | concurrency=concurrency,
67 | )
68 |
69 | posthog().capture(
70 | distinct_id=ksuid(),
71 | event="optimize_labeled_few_shot",
72 | properties={
73 | "evals": {c.function.id: c.evals for c in candidates},
74 | },
75 | )
76 |
77 | return self.Result(best_lmfn, candidates, best_candidate_result)
78 |
79 | async def aperform(
80 | self,
81 | lmfn: LMFunction[Inputs, Outputs],
82 | evaluator: CandidateEvaluator[Inputs, Outputs],
83 | samples: int = 0,
84 | rounds: int = 1,
85 | concurrency: int = 1,
86 | ) -> Result:
87 | return await asyncify(self.perform)(lmfn, evaluator, samples, rounds, concurrency)
88 |
89 | def candidates(self, _lmfn: LMFunction[Inputs, Outputs], samples: int):
90 | max_samples = factorial(len(self.demoset))
91 | if samples > max_samples:
92 | log.warn(
93 |                 "samples > factorial(len(demoset)), using factorial(len(demoset))",
94 | max_samples=max_samples,
95 | samples=samples,
96 | )
97 | samples = max_samples
98 |
99 | for _ in range(samples):
100 | demos = tuple(self.random.sample(self.demoset, k=self.shots))
101 | yield LMZenbase(task_demos=demos)
102 |
--------------------------------------------------------------------------------
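A minimal sketch of LabeledFewShot from the file above, again leaning on JSONAdaptor so the evaluator is fully in-repo. The demo shapes and student body are hypothetical; perform samples `shots` demos per candidate and keeps the best-scoring one.

from zenbase.adaptors.json.adaptor import JSONAdaptor
from zenbase.core.managers import ZenbaseTracer
from zenbase.optim.metric.labeled_few_shot import LabeledFewShot
from zenbase.types import LMDemo

tracer = ZenbaseTracer()
demos = [LMDemo(inputs={"q": str(i)}, outputs={"a": str(i)}) for i in range(5)]

@tracer
def student(request):
    # Hypothetical; a real function would prompt an LLM with request.zenbase.task_demos.
    return {"a": request.inputs["q"]}

def exact_match(output, ideal_output):
    return {"passed": int(output == ideal_output)}

best_fn, candidate_results, best_candidate = LabeledFewShot(demoset=demos, shots=2).perform(
    student,
    evaluator=JSONAdaptor.metric_evaluator(eval_function=exact_match, data=demos),
    samples=4,
    rounds=1,
)
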
/py/src/zenbase/optim/metric/types.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | from dataclasses import dataclass, field
3 | from typing import Callable, Generic, TypedDict
4 |
5 | from zenbase.types import Dataclass, Inputs, LMDemo, LMFunction, Outputs
6 |
7 |
8 | class OverallEvalValue(TypedDict):
9 | score: float
10 |
11 |
12 | @dataclasses.dataclass(frozen=True)
13 | class IndividualEvalValue(Dataclass, Generic[Outputs]):
14 | passed: bool
15 | response: Outputs
16 | demo: LMDemo
17 | score: float | None = None
18 | details: dict = field(default_factory=dict)
19 |
20 |
21 | @dataclass
22 | class CandidateEvalResult(Generic[Inputs, Outputs]):
23 | function: LMFunction[Inputs, Outputs]
24 | evals: OverallEvalValue = field(default_factory=dict)
25 | individual_evals: list[IndividualEvalValue] = field(default_factory=list)
26 |
27 |
28 | CandidateEvaluator = Callable[
29 | [LMFunction[Inputs, Outputs]],
30 | CandidateEvalResult[Inputs, Outputs],
31 | ]
32 |
--------------------------------------------------------------------------------
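The types above are all that is needed to plug a custom evaluator into either optimizer. A hand-rolled CandidateEvaluator, shown as a hedged sketch in which the equality check stands in for a real metric:

from zenbase.optim.metric.types import CandidateEvalResult, IndividualEvalValue
from zenbase.types import LMDemo

def make_exact_match_evaluator(demos: list[LMDemo]):
    # Returns a CandidateEvaluator: a callable mapping an LMFunction to a CandidateEvalResult.
    def evaluate(function):
        individual = []
        for demo in demos:
            response = function(demo.inputs)
            individual.append(IndividualEvalValue(passed=response == demo.outputs, response=response, demo=demo))
        score = sum(e.passed for e in individual) / len(individual)
        return CandidateEvalResult(function, {"score": score}, individual_evals=individual)

    return evaluate
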
/py/src/zenbase/predefined/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zenbase-ai/core/59971868e85784c54eddb6980cf791b975e0befb/py/src/zenbase/predefined/__init__.py
--------------------------------------------------------------------------------
/py/src/zenbase/predefined/base/function_generator.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 |
4 | class BaseLMFunctionGenerator(abc.ABC):
5 | @abc.abstractmethod
6 | def generate(self, *args, **kwargs): ...
7 |
--------------------------------------------------------------------------------
/py/src/zenbase/predefined/base/optimizer.py:
--------------------------------------------------------------------------------
1 | from abc import ABC
2 |
3 |
4 | class BasePredefinedOptimizer(ABC):
5 | pass
6 |
--------------------------------------------------------------------------------
/py/src/zenbase/predefined/generic_lm_function/optimizer.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import Any, Callable, List, NamedTuple, Type
3 |
4 | from instructor.client import Instructor
5 | from pydantic import BaseModel
6 |
7 | from zenbase.adaptors.json.adaptor import JSONAdaptor
8 | from zenbase.core.managers import ZenbaseTracer
9 | from zenbase.optim.metric.labeled_few_shot import LabeledFewShot
10 | from zenbase.optim.metric.types import CandidateEvalResult
11 | from zenbase.types import LMDemo, LMFunction
12 |
13 |
14 | @dataclass
15 | class GenericLMFunctionOptimizer:
16 | class Result(NamedTuple):
17 | best_function: LMFunction
18 | candidate_results: list[CandidateEvalResult]
19 | best_candidate_result: CandidateEvalResult | None
20 |
21 | instructor_client: Instructor
22 | prompt: str
23 | input_model: Type[BaseModel]
24 | output_model: Type[BaseModel]
25 | model: str
26 | zenbase_tracer: ZenbaseTracer
27 | training_set: List[dict]
28 | validation_set: List[dict]
29 | test_set: List[dict]
30 | custom_evaluator: Callable[[Any, dict], dict] = field(default=None)
31 | shots: int = 5
32 | samples: int = 10
33 | last_result: Result | None = field(default=None)
34 |
35 | lm_function: LMFunction = field(init=False)
36 | training_set_demos: List[LMDemo] = field(init=False)
37 | validation_set_demos: List[LMDemo] = field(init=False)
38 | test_set_demos: List[LMDemo] = field(init=False)
39 | best_evaluation: CandidateEvalResult | None = field(default=None)
40 | base_evaluation: CandidateEvalResult | None = field(default=None)
41 |
42 | def __post_init__(self):
43 | self.lm_function = self._generate_lm_function()
44 | self.training_set_demos = self._convert_dataset_to_demos(self.training_set)
45 | self.validation_set_demos = self._convert_dataset_to_demos(self.validation_set)
46 | self.test_set_demos = self._convert_dataset_to_demos(self.test_set)
47 |
48 | def _generate_lm_function(self) -> LMFunction:
49 | @self.zenbase_tracer.trace_function
50 | def generic_function(request):
51 | system_role = "assistant" if self.model.startswith("o1") else "system"
52 | messages = [
53 | {"role": system_role, "content": self.prompt},
54 | ]
55 |
56 | if request.zenbase.task_demos:
57 | messages.append({"role": system_role, "content": "Here are some examples:"})
58 | for demo in request.zenbase.task_demos:
59 | if demo.inputs == request.inputs:
60 | continue
61 | messages.extend(
62 | [
63 | {"role": "user", "content": str(demo.inputs)},
64 | {"role": "assistant", "content": str(demo.outputs)},
65 | ]
66 | )
67 | messages.append({"role": system_role, "content": "Now, please answer the following question:"})
68 |
69 | messages.append({"role": "user", "content": str(request.inputs)})
70 |
71 | kwargs = {
72 | "model": self.model,
73 | "response_model": self.output_model,
74 | "messages": messages,
75 | "max_retries": 3,
76 | }
77 |
78 | if not self.model.startswith("o1"):
79 | kwargs.update(
80 | {
81 | "logprobs": True,
82 | "top_logprobs": 5,
83 | }
84 | )
85 |
86 | return self.instructor_client.chat.completions.create(**kwargs)
87 |
88 | return generic_function
89 |
90 | def _convert_dataset_to_demos(self, dataset: List[dict]) -> List[LMDemo]:
91 | return [LMDemo(inputs=item["inputs"], outputs=item["outputs"]) for item in dataset]
92 |
93 | def optimize(self) -> Result:
94 | evaluator = self.custom_evaluator or self._create_default_evaluator()
95 | test_evaluator = self._create_test_evaluator(evaluator)
96 |
97 | # Perform base evaluation
98 | self.base_evaluation = self._perform_base_evaluation(test_evaluator)
99 |
100 | optimizer = LabeledFewShot(demoset=self.training_set_demos, shots=self.shots)
101 | optimizer_result = optimizer.perform(
102 | self.lm_function,
103 | evaluator=JSONAdaptor.metric_evaluator(
104 | data=self.validation_set_demos,
105 | eval_function=evaluator,
106 | ),
107 | samples=self.samples,
108 | rounds=1,
109 | )
110 |
111 | # Evaluate best function
112 | self.best_evaluation = self._evaluate_best_function(test_evaluator, optimizer_result)
113 |
114 | self.last_result = self.Result(
115 | best_function=optimizer_result.best_function,
116 | candidate_results=optimizer_result.candidate_results,
117 | best_candidate_result=optimizer_result.best_candidate_result,
118 | )
119 |
120 | return self.last_result
121 |
122 | def _create_default_evaluator(self):
123 | def evaluator(output: BaseModel, ideal_output: dict) -> dict:
124 | return {
125 | "passed": int(output.model_dump(mode="json") == ideal_output),
126 | }
127 |
128 | return evaluator
129 |
130 | def _create_test_evaluator(self, evaluator):
131 | return JSONAdaptor.metric_evaluator(
132 | data=self.test_set_demos,
133 | eval_function=evaluator,
134 | )
135 |
136 | def _perform_base_evaluation(self, test_evaluator):
137 | """Perform the base evaluation of the LM function."""
138 | return test_evaluator(self.lm_function)
139 |
140 | def _evaluate_best_function(self, test_evaluator, optimizer_result):
141 | """Evaluate the best function from the optimization result."""
142 | return test_evaluator(optimizer_result.best_function)
143 |
144 | def create_lm_function_with_demos(self, prompt: str, demos: List[dict]) -> LMFunction:
145 | @self.zenbase_tracer.trace_function
146 | def lm_function_with_demos(request):
147 | system_role = "assistant" if self.model.startswith("o1") else "system"
148 | messages = [
149 | {"role": system_role, "content": prompt},
150 | ]
151 |
152 | # Add demos to the messages
153 | if demos:
154 | messages.append({"role": system_role, "content": "Here are some examples:"})
155 | for demo in demos:
156 | messages.extend(
157 | [
158 | {"role": "user", "content": str(demo["inputs"])},
159 | {"role": "assistant", "content": str(demo["outputs"])},
160 | ]
161 | )
162 | messages.append({"role": system_role, "content": "Now, please answer the following question:"})
163 |
164 | # Add the actual request
165 | messages.append({"role": "user", "content": str(request.inputs)})
166 |
167 | kwargs = {
168 | "model": self.model,
169 | "response_model": self.output_model,
170 | "messages": messages,
171 | "max_retries": 3,
172 | }
173 |
174 | if not self.model.startswith("o1"):
175 | kwargs.update(
176 | {
177 | "logprobs": True,
178 | "top_logprobs": 5,
179 | }
180 | )
181 |
182 | return self.instructor_client.chat.completions.create(**kwargs)
183 |
184 | return lm_function_with_demos
185 |
186 | def generate_csv_report(self):
187 | if not self.last_result:
188 | raise ValueError("No results to generate report from")
189 |
190 | best_candidate_result = self.last_result.best_candidate_result
191 | base_evaluation = self.base_evaluation
192 | best_evaluation = self.best_evaluation
193 |
194 | list_of_rows = [("type", "input", "ideal_output", "output", "passed", "score", "details")]
195 |
196 | for eval_item in best_candidate_result.individual_evals:
197 | list_of_rows.append(
198 | (
199 | "best_candidate_result",
200 | eval_item.demo.inputs,
201 | eval_item.demo.outputs,
202 | eval_item.response,
203 | eval_item.passed,
204 | eval_item.score,
205 | eval_item.details,
206 | )
207 | )
208 |
209 | for eval_item in base_evaluation.individual_evals:
210 | list_of_rows.append(
211 | (
212 | "base_evaluation",
213 | eval_item.demo.inputs,
214 | eval_item.demo.outputs,
215 | eval_item.response,
216 | eval_item.passed,
217 | eval_item.score,
218 | eval_item.details,
219 | )
220 | )
221 |
222 | for eval_item in best_evaluation.individual_evals:
223 | list_of_rows.append(
224 | (
225 | "best_evaluation",
226 | eval_item.demo.inputs,
227 | eval_item.demo.outputs,
228 | eval_item.response,
229 | eval_item.passed,
230 | eval_item.score,
231 | eval_item.details,
232 | )
233 | )
234 |
235 | # save to csv
236 | import csv
237 |
238 | with open("report.csv", "w", newline="") as file:
239 | writer = csv.writer(file)
240 | writer.writerows(list_of_rows)
241 |
--------------------------------------------------------------------------------
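Note (illustrative, not part of the repository): a minimal sketch of driving GenericLMFunctionOptimizer end to end, following the fixtures in py/tests/predefined/test_generic_lm_function_optimizer.py. The tiny dataset, model name, and shot/sample counts below are assumptions.

    import instructor
    from openai import OpenAI
    from pydantic import BaseModel

    from zenbase.core.managers import ZenbaseTracer
    from zenbase.predefined.generic_lm_function.optimizer import GenericLMFunctionOptimizer


    class Question(BaseModel):
        question: str


    class Answer(BaseModel):
        answer: str


    # Tiny illustrative dataset; a real run would use separate train/validation/test splits.
    demos = [
        {"inputs": {"question": "What is the capital of France?"}, "outputs": {"answer": "Paris"}},
        {"inputs": {"question": "Who wrote Romeo and Juliet?"}, "outputs": {"answer": "William Shakespeare"}},
        {"inputs": {"question": "What is the chemical symbol for gold?"}, "outputs": {"answer": "Au"}},
    ]

    optimizer = GenericLMFunctionOptimizer(
        instructor_client=instructor.from_openai(OpenAI()),
        prompt="You are a helpful assistant. Answer the user's question concisely.",
        input_model=Question,
        output_model=Answer,
        model="gpt-4o-mini",
        zenbase_tracer=ZenbaseTracer(),
        training_set=demos,
        validation_set=demos,
        test_set=demos,
        shots=2,
        samples=4,
    )

    result = optimizer.optimize()      # LabeledFewShot search over subsets of the training demos
    best_fn = result.best_function     # LMFunction with the best-scoring demos attached
    print(best_fn({"question": "What is the largest planet in our solar system?"}))
    optimizer.generate_csv_report()    # writes report.csv comparing base vs. optimized per-item evals
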
/py/src/zenbase/predefined/single_class_classifier/__init__.py:
--------------------------------------------------------------------------------
1 | from .classifier import * # noqa
2 |
--------------------------------------------------------------------------------
/py/src/zenbase/predefined/single_class_classifier/classifier.py:
--------------------------------------------------------------------------------
1 | __all__ = ["SingleClassClassifier"]
2 |
3 | from dataclasses import dataclass, field
4 | from enum import Enum
5 | from typing import Any, Dict, NamedTuple, Type
6 |
7 | import cloudpickle
8 | from instructor.client import AsyncInstructor, Instructor
9 | from pydantic import BaseModel
10 |
11 | from zenbase.adaptors.json.adaptor import JSONAdaptor
12 | from zenbase.core.managers import ZenbaseTracer
13 | from zenbase.optim.metric.labeled_few_shot import LabeledFewShot
14 | from zenbase.optim.metric.types import CandidateEvalResult
15 | from zenbase.predefined.base.optimizer import BasePredefinedOptimizer
16 | from zenbase.predefined.single_class_classifier.function_generator import SingleClassClassifierLMFunctionGenerator
17 | from zenbase.predefined.syntethic_data.single_class_classifier import SingleClassClassifierSyntheticDataExample
18 | from zenbase.types import Inputs, LMDemo, LMFunction, Outputs
19 |
20 |
21 | @dataclass(kw_only=True)
22 | class SingleClassClassifier(BasePredefinedOptimizer):
23 | """
24 | A single-class classifier that optimizes and evaluates language model functions.
25 | """
26 |
27 | class Result(NamedTuple):
28 | best_function: LMFunction[Inputs, Outputs]
29 | candidate_results: list[CandidateEvalResult]
30 | best_candidate_result: CandidateEvalResult | None
31 |
32 | instructor_client: Instructor | AsyncInstructor
33 | prompt: str
34 | class_dict: Dict[str, str] | None = field(default=None)
35 | class_enum: Enum | None = field(default=None)
36 | prediction_class: Type[BaseModel] | None = field(default=None)
37 | model: str
38 | zenbase_tracer: ZenbaseTracer
39 | lm_function: LMFunction | None = field(default=None)
40 | training_set: list
41 | test_set: list
42 | validation_set: list
43 | shots: int = 5
44 | samples: int = 10
45 | best_evaluation: CandidateEvalResult | None = field(default=None)
46 | base_evaluation: CandidateEvalResult | None = field(default=None)
47 | optimizer_result: Result | None = field(default=None)
48 |
49 | def __post_init__(self):
50 | """Initialize the SingleClassClassifier after creation."""
51 | self.lm_function = self._generate_lm_function()
52 | self.training_set_demos = self._convert_dataset_to_demos(self.training_set)
53 | self.test_set_demos = self._convert_dataset_to_demos(self.test_set)
54 | self.validation_set_demos = self._convert_dataset_to_demos(self.validation_set)
55 |
56 | def _generate_lm_function(self) -> LMFunction:
57 | """Generate the language model function."""
58 | return SingleClassClassifierLMFunctionGenerator(
59 | instructor_client=self.instructor_client,
60 | prompt=self.prompt,
61 | class_dict=self.class_dict,
62 | class_enum=self.class_enum,
63 | prediction_class=self.prediction_class,
64 | model=self.model,
65 | zenbase_tracer=self.zenbase_tracer,
66 | ).generate()
67 |
68 | @staticmethod
69 | def _convert_dataset_to_demos(dataset: list) -> list[LMDemo]:
70 | """Convert a dataset to a list of LMDemo objects."""
71 |         if dataset and isinstance(dataset[0], dict):
72 |             return [
73 |                 LMDemo(inputs={"question": item["inputs"]}, outputs={"answer": item["outputs"]}) for item in dataset
74 |             ]
75 |         if dataset and isinstance(dataset[0], SingleClassClassifierSyntheticDataExample):
76 |             return [LMDemo(inputs={"question": item.inputs}, outputs={"answer": item.outputs}) for item in dataset]
77 |         return []
78 |
79 | def load_classifier(self, filename: str):
80 | with open(filename, "rb") as f:
81 | lm_zenbase = cloudpickle.load(f)
82 | return self.lm_function.clean_and_duplicate(lm_zenbase)
83 |
84 | def optimize(self) -> Result:
85 | """
86 | Perform the optimization and evaluation of the language model function.
87 |
88 | Returns:
89 | Result: The optimization result containing the best function and evaluation metrics.
90 | """
91 | # Define the evaluation function
92 | evaluator = self._create_evaluator()
93 |
94 | # Create test evaluator
95 | test_evaluator = self._create_test_evaluator(evaluator)
96 |
97 | # Perform base evaluation
98 | self.base_evaluation = self._perform_base_evaluation(test_evaluator)
99 |
100 | # Create and run optimizer
101 | optimizer_result = self._run_optimization(evaluator)
102 |
103 | # Evaluate best function
104 | self.best_evaluation = self._evaluate_best_function(test_evaluator, optimizer_result)
105 |
106 | # Save last optimizer_result
107 | self.optimizer_result = optimizer_result
108 |
109 | return optimizer_result
110 |
111 | @staticmethod
112 | def _create_evaluator():
113 | """Create the evaluation function."""
114 |
115 | def evaluator(output: Any, ideal_output: Dict[str, Any]) -> Dict[str, int]:
116 | return {
117 | "passed": int(ideal_output["answer"] == output.class_label.name),
118 | }
119 |
120 | return evaluator
121 |
122 | def _create_test_evaluator(self, evaluator):
123 | """Create the test evaluator using JSONAdaptor."""
124 | return JSONAdaptor.metric_evaluator(
125 | data=self.validation_set_demos,
126 | eval_function=evaluator,
127 | )
128 |
129 | def _perform_base_evaluation(self, test_evaluator):
130 | """Perform the base evaluation of the LM function."""
131 | return test_evaluator(self.lm_function)
132 |
133 | def _run_optimization(self, evaluator):
134 | """Run the optimization process."""
135 | optimizer = LabeledFewShot(demoset=self.training_set_demos, shots=self.shots)
136 | return optimizer.perform(
137 | self.lm_function,
138 | evaluator=JSONAdaptor.metric_evaluator(
139 | data=self.validation_set_demos,
140 | eval_function=evaluator,
141 | ),
142 | samples=self.samples,
143 | rounds=1,
144 | )
145 |
146 | def _evaluate_best_function(self, test_evaluator, optimizer_result):
147 | """Evaluate the best function from the optimization result."""
148 | return test_evaluator(optimizer_result.best_function)
149 |
--------------------------------------------------------------------------------
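Note (illustrative sketch; the ticket data and model are assumptions): class_dict keys double as enum member names, so dataset labels must match them exactly, because the built-in evaluator compares ideal_output["answer"] against prediction.class_label.name.

    import instructor
    from openai import OpenAI

    from zenbase.core.managers import ZenbaseTracer
    from zenbase.predefined.single_class_classifier import SingleClassClassifier

    data = [  # items are {"inputs": <text>, "outputs": <label>} dicts
        {"inputs": "I was charged twice for my subscription.", "outputs": "BILLING"},
        {"inputs": "The app crashes whenever I upload a file.", "outputs": "TECHNICAL"},
        {"inputs": "How do I change the email on my account?", "outputs": "ACCOUNT"},
    ]

    classifier = SingleClassClassifier(
        instructor_client=instructor.from_openai(OpenAI()),
        prompt="Categorize the customer support ticket.",
        class_dict={
            "BILLING": "Payments, invoices, and refunds",
            "TECHNICAL": "Bugs, crashes, and error messages",
            "ACCOUNT": "Profile, login, and subscription changes",
        },
        model="gpt-4o-mini",
        zenbase_tracer=ZenbaseTracer(),
        training_set=data,
        validation_set=data,
        test_set=data,
        shots=2,
        samples=4,
    )

    result = classifier.optimize()              # best_function, candidate_results, best_candidate_result
    print(classifier.base_evaluation.evals)     # metrics before few-shot optimization
    print(classifier.best_evaluation.evals)     # metrics after few-shot optimization
    prediction = result.best_function({"question": "My invoice total looks wrong."})
    print(prediction.class_label.name)
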
/py/src/zenbase/predefined/single_class_classifier/function_generator.py:
--------------------------------------------------------------------------------
1 | """
2 | This module provides a SingleClassClassifierLMFunctionGenerator for generating language model functions.
3 | """
4 |
5 | import logging
6 | from dataclasses import dataclass, field
7 | from enum import Enum
8 | from typing import Dict, Optional, Type
9 |
10 | from instructor.client import AsyncInstructor, Instructor
11 | from pydantic import BaseModel
12 | from tenacity import (
13 | before_sleep_log,
14 | retry,
15 | stop_after_attempt,
16 | wait_exponential_jitter,
17 | )
18 |
19 | from zenbase.core.managers import ZenbaseTracer
20 | from zenbase.predefined.base.function_generator import BaseLMFunctionGenerator
21 | from zenbase.types import LMFunction, LMRequest
22 |
23 | log = logging.getLogger(__name__)
24 |
25 |
26 | @dataclass(kw_only=True)
27 | class SingleClassClassifierLMFunctionGenerator(BaseLMFunctionGenerator):
28 | """
29 | A generator for creating single-class classifier language model functions.
30 | """
31 |
32 | instructor_client: Instructor | AsyncInstructor
33 | prompt: str
34 | class_dict: Optional[Dict[str, str]] = field(default=None)
35 | class_enum: Optional[Enum] = field(default=None)
36 | prediction_class: Optional[Type[BaseModel]] = field(default=None)
37 | model: str
38 | zenbase_tracer: ZenbaseTracer
39 |
40 | def __post_init__(self):
41 | """Initialize the generator after creation."""
42 | self._initialize_class_enum()
43 | self._initialize_prediction_class()
44 |
45 | def _initialize_class_enum(self):
46 | """Initialize the class enum if not provided."""
47 | if not self.class_enum and self.class_dict:
48 | self.class_enum = self._generate_class_enum()
49 |
50 | def _initialize_prediction_class(self):
51 | """Initialize the prediction class if not provided."""
52 | if not self.prediction_class and self.class_enum:
53 | self.prediction_class = self._generate_prediction_class()
54 |
55 | def generate(self) -> LMFunction:
56 | """Generate the classifier language model function."""
57 | return self._generate_classifier_prompt_lm_function()
58 |
59 | def _generate_class_enum(self) -> Enum:
60 | """Generate the class enum from the class dictionary."""
61 | return Enum("Labels", self.class_dict)
62 |
63 | def _generate_prediction_class(self) -> Type[BaseModel]:
64 | """Generate the prediction class based on the class enum."""
65 | class_enum = self.class_enum
66 |
67 | class SinglePrediction(BaseModel):
68 | reasoning: str
69 | class_label: class_enum
70 |
71 | return SinglePrediction
72 |
73 | def _generate_classifier_prompt_lm_function(self) -> LMFunction:
74 | """Generate the classifier prompt language model function."""
75 |
76 | @retry(
77 | stop=stop_after_attempt(3),
78 | wait=wait_exponential_jitter(max=8),
79 | before_sleep=before_sleep_log(log, logging.WARN),
80 | )
81 | def classifier_function(request: LMRequest):
82 | categories = "\n".join([f"- {key.upper()}: {value}" for key, value in self.class_dict.items()])
83 | messages = [
84 | {
85 | "role": "system",
86 | "content": f"""You are an expert classifier. Your task is to categorize inputs accurately based
87 | on the following instructions: {self.prompt}
88 |
89 | Categories and their descriptions:
90 | {categories}
91 |
92 | Rules:
93 | 1. Analyze the input carefully.
94 | 2. Choose the most appropriate category based on the descriptions provided.
95 | 3. Respond with ONLY the category name in UPPERCASE.
96 | 4. If unsure, choose the category that best fits the input.""",
97 | }
98 | ]
99 |
100 | if request.zenbase.task_demos:
101 | messages.append({"role": "system", "content": "Here are some examples of classifications:"})
102 | for demo in request.zenbase.task_demos:
103 | messages.extend(
104 | [
105 | {"role": "user", "content": demo.inputs["question"]},
106 | {"role": "assistant", "content": demo.outputs["answer"]},
107 | ]
108 | )
109 | messages.append(
110 | {"role": "system", "content": "Now, classify the new input following the same pattern."}
111 | )
112 |
113 | messages.extend(
114 | [
115 | {"role": "system", "content": "Please classify the following input:"},
116 | {"role": "user", "content": str(request.inputs)},
117 | ]
118 | )
119 |
120 | return self.instructor_client.chat.completions.create(
121 | model=self.model, response_model=self.prediction_class, messages=messages
122 | )
123 |
124 | return self.zenbase_tracer.trace_function(classifier_function)
125 |
--------------------------------------------------------------------------------
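Note (illustrative, assumed setup): the generator can also be used on its own to obtain a traced classifier LMFunction without running any optimization; the prediction class it derives from class_dict carries a reasoning string plus a class_label enum member.

    import instructor
    from openai import OpenAI

    from zenbase.core.managers import ZenbaseTracer
    from zenbase.predefined.single_class_classifier.function_generator import (
        SingleClassClassifierLMFunctionGenerator,
    )

    classify = SingleClassClassifierLMFunctionGenerator(
        instructor_client=instructor.from_openai(OpenAI()),
        prompt="Categorize the customer support ticket.",
        class_dict={"BILLING": "Payments and refunds", "TECHNICAL": "Bugs and crashes"},
        model="gpt-4o-mini",
        zenbase_tracer=ZenbaseTracer(),
    ).generate()

    prediction = classify({"question": "I never received my refund."})
    print(prediction.class_label.name, "-", prediction.reasoning)
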
/py/src/zenbase/predefined/syntethic_data/single_class_classifier.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import io
3 | from typing import Dict, List
4 |
5 | from instructor import Instructor
6 | from pydantic import BaseModel, Field
7 |
8 |
9 | class SingleClassClassifierSyntheticDataExample(BaseModel):
10 | inputs: str = Field(..., description="The input text for single class classification")
11 | outputs: str = Field(..., description="The correct classification category")
12 |
13 |
14 | class SingleClassClassifierSyntheticDataGenerator:
15 | def __init__(
16 | self,
17 | instructor_client: Instructor,
18 | prompt: str,
19 | class_dict: Dict[str, str],
20 | model: str = "gpt-4o-mini",
21 | ):
22 | self.instructor_client = instructor_client
23 | self.prompt = prompt
24 | self.class_dict = class_dict
25 | self.model = model
26 |
27 | def generate_examples_for_category(
28 | self, category: str, description: str, num_examples: int
29 | ) -> List[SingleClassClassifierSyntheticDataExample]:
30 | messages = [
31 | {
32 | "role": "system",
33 | "content": f"""You are an expert in generating synthetic datasets for single class classification
34 | tasks. Your task is to create diverse and realistic examples based on the following instructions:
35 |
36 | {self.prompt}
37 |
38 | You are focusing on generating examples for the following category:
39 | - {category}: {description}
40 |
41 | For each example, generate:
42 | 1. A realistic and diverse input text that should be classified into the given category.
43 | 2. The category name as the output.
44 |
45 | Ensure diversity in the generated examples.""",
46 | },
47 | {"role": "user", "content": f"Generate {num_examples} examples for the category '{category}'."},
48 | ]
49 |
50 | response = self.instructor_client.chat.completions.create(
51 | model=self.model, response_model=List[SingleClassClassifierSyntheticDataExample], messages=messages
52 | )
53 |
54 | return response
55 |
56 | def generate_examples(self, examples_per_category: int) -> List[SingleClassClassifierSyntheticDataExample]:
57 | all_examples = []
58 | for category, description in self.class_dict.items():
59 | category_examples = self.generate_examples_for_category(category, description, examples_per_category)
60 | all_examples.extend(category_examples)
61 | return all_examples
62 |
63 | def generate_csv(self, examples_per_category: int) -> str:
64 | examples = self.generate_examples(examples_per_category)
65 |
66 | output = io.StringIO()
67 | writer = csv.DictWriter(output, fieldnames=["inputs", "outputs"])
68 | writer.writeheader()
69 | for example in examples:
70 |             writer.writerow(example.model_dump())
71 |
72 | return output.getvalue()
73 |
74 | def save_csv(self, filename: str, examples_per_category: int):
75 | csv_content = self.generate_csv(examples_per_category)
76 | with open(filename, "w", newline="", encoding="utf-8") as f:
77 | f.write(csv_content)
78 |
--------------------------------------------------------------------------------
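Note (illustrative sketch; the categories and file name are assumptions): the generator requests examples_per_category items per label and can return them as Pydantic models or write them to a CSV with inputs/outputs columns. Each call re-queries the model, so generate once if cost matters.

    import instructor
    from openai import OpenAI

    from zenbase.predefined.syntethic_data.single_class_classifier import (
        SingleClassClassifierSyntheticDataGenerator,
    )

    generator = SingleClassClassifierSyntheticDataGenerator(
        instructor_client=instructor.from_openai(OpenAI()),
        prompt="Categorize customer support tickets.",
        class_dict={
            "BILLING": "Payments, invoices, and refunds",
            "TECHNICAL": "Bugs, crashes, and error messages",
        },
        model="gpt-4o-mini",
    )

    examples = generator.generate_examples(examples_per_category=5)  # list of SingleClassClassifierSyntheticDataExample
    print(examples[0].inputs, "->", examples[0].outputs)

    # save_csv() regenerates a fresh batch and writes it with an inputs,outputs header.
    generator.save_csv("synthetic_tickets.csv", examples_per_category=5)
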
/py/src/zenbase/settings.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | # Define the base directory as the parent of the parent directory of this file
4 | BASE_DIR = Path(__file__).resolve().parent.parent
5 |
6 | # Define the test directory as a subdirectory of the base directory's parent
7 | TEST_DIR = BASE_DIR.parent / "tests"
8 |
--------------------------------------------------------------------------------
/py/src/zenbase/types.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | import inspect
3 | import json
4 | from collections import deque
5 | from copy import copy
6 | from functools import partial
7 | from typing import Awaitable, Callable, Generic, TypeVar, Union, get_origin
8 |
9 | from zenbase.utils import asyncify, ksuid_generator
10 |
11 |
12 | class Dataclass:
13 | """
14 | Modified from Braintrust's SerializableDataClass
15 | """
16 |
17 | def copy(self, **changes):
18 | return dataclasses.replace(self, **changes)
19 |
20 | def as_dict(self):
21 | """Serialize the object to a dictionary."""
22 | return dataclasses.asdict(self)
23 |
24 | def as_json(self, **kwargs):
25 | """Serialize the object to JSON."""
26 | return json.dumps(self.as_dict(), **kwargs)
27 |
28 | @classmethod
29 | def from_dict(cls, d: dict):
30 | """Deserialize the object from a dictionary. This method
31 | is shallow and will not call from_dict() on nested objects."""
32 | fields = set(f.name for f in dataclasses.fields(cls))
33 | filtered = {k: v for k, v in d.items() if k in fields}
34 | return cls(**filtered)
35 |
36 | @classmethod
37 | def from_dict_deep(cls, d: dict):
38 | """Deserialize the object from a dictionary. This method
39 | is deep and will call from_dict_deep() on nested objects."""
40 | fields = {f.name: f for f in dataclasses.fields(cls)}
41 | filtered = {}
42 | for k, v in d.items():
43 | if k not in fields:
44 | continue
45 |
46 | if isinstance(v, dict) and isinstance(fields[k].type, type) and issubclass(fields[k].type, Dataclass):
47 | filtered[k] = fields[k].type.from_dict_deep(v)
48 | elif get_origin(fields[k].type) == Union:
49 | for t in fields[k].type.__args__:
50 | if isinstance(t, type) and issubclass(t, Dataclass):
51 | try:
52 | filtered[k] = t.from_dict_deep(v)
53 | break
54 | except TypeError:
55 | pass
56 | else:
57 | filtered[k] = v
58 | elif (
59 | isinstance(v, list)
60 | and get_origin(fields[k].type) == list
61 | and len(fields[k].type.__args__) == 1
62 | and isinstance(fields[k].type.__args__[0], type)
63 | and issubclass(fields[k].type.__args__[0], Dataclass)
64 | ):
65 | filtered[k] = [fields[k].type.__args__[0].from_dict_deep(i) for i in v]
66 | else:
67 | filtered[k] = v
68 | return cls(**filtered)
69 |
70 |
71 | Inputs = TypeVar("Inputs", covariant=True, bound=dict)
72 | Outputs = TypeVar("Outputs", covariant=True, bound=dict)
73 |
74 |
75 | @dataclasses.dataclass(frozen=True)
76 | class LMDemo(Dataclass, Generic[Inputs, Outputs]):
77 | inputs: Inputs
78 | outputs: Outputs
79 | adaptor_object: object | None = None
80 |
81 | def __hash__(self):
82 | def make_hashable(obj):
83 | if isinstance(obj, dict):
84 | return tuple(sorted((k, make_hashable(v)) for k, v in obj.items()))
85 | elif isinstance(obj, list):
86 | return tuple(make_hashable(x) for x in obj)
87 | elif isinstance(obj, set):
88 | return tuple(sorted(make_hashable(x) for x in obj))
89 | else:
90 | return obj
91 | return hash((make_hashable(self.inputs), make_hashable(self.outputs)))
92 |
93 |
94 | @dataclasses.dataclass(frozen=True)
95 | class LMZenbase(Dataclass, Generic[Inputs, Outputs]):
96 | task_demos: list[LMDemo[Inputs, Outputs]] = dataclasses.field(default_factory=list)
97 | model_params: dict = dataclasses.field(default_factory=dict) # OpenAI-compatible model params
98 |
99 |
100 | @dataclasses.dataclass()
101 | class LMRequest(Dataclass, Generic[Inputs, Outputs]):
102 | zenbase: LMZenbase[Inputs, Outputs]
103 | inputs: Inputs = dataclasses.field(default_factory=dict)
104 | id: str = dataclasses.field(default_factory=ksuid_generator("request"))
105 |
106 |
107 | @dataclasses.dataclass(frozen=True)
108 | class LMResponse(Dataclass, Generic[Outputs]):
109 | outputs: Outputs
110 | attributes: dict = dataclasses.field(default_factory=dict) # token_count, cost, inference_time, etc.
111 | id: str = dataclasses.field(default_factory=ksuid_generator("response"))
112 |
113 |
114 | @dataclasses.dataclass(frozen=True)
115 | class LMCall(Dataclass, Generic[Inputs, Outputs]):
116 | function: "LMFunction[Inputs, Outputs]"
117 | request: LMRequest[Inputs, Outputs]
118 | response: LMResponse[Outputs]
119 | id: str = dataclasses.field(default_factory=ksuid_generator("call"))
120 |
121 |
122 | class LMFunction(Generic[Inputs, Outputs]):
123 | gen_id = staticmethod(ksuid_generator("fn"))
124 |
125 | id: str
126 | fn: Callable[[LMRequest[Inputs, Outputs]], Outputs | Awaitable[Outputs]]
127 | __name__: str
128 | __qualname__: str
129 | __doc__: str
130 | __signature__: inspect.Signature
131 | zenbase: LMZenbase[Inputs, Outputs]
132 | history: deque[LMCall[Inputs, Outputs]]
133 |
134 | def __init__(
135 | self,
136 | fn: Callable[[LMRequest[Inputs, Outputs]], Outputs | Awaitable[Outputs]],
137 | zenbase: LMZenbase | None = None,
138 | maxhistory: int = 100,
139 | ):
140 | self.fn = fn
141 |
142 | if qualname := getattr(fn, "__qualname__", None):
143 | self.id = qualname
144 | self.__qualname__ = qualname
145 | else:
146 | self.id = self.gen_id()
147 | self.__qualname__ = f"zenbase_{self.id}"
148 |
149 | self.__name__ = getattr(fn, "__name__", f"zenbase_{self.id}")
150 |
151 | self.__doc__ = getattr(fn, "__doc__", "")
152 | self.__signature__ = inspect.signature(fn)
153 |
154 | self.zenbase = zenbase or LMZenbase()
155 | self.history = deque([], maxlen=maxhistory)
156 |
157 | def clean_and_duplicate(self, zenbase: LMZenbase | None = None) -> "LMFunction[Inputs, Outputs]":
158 | dup = copy(self)
159 | dup.id = self.gen_id()
160 | dup.zenbase = zenbase or self.zenbase.copy()
161 | dup.history = deque([], maxlen=self.history.maxlen)
162 | return dup
163 |
164 | def prepare_request(self, inputs: Inputs) -> LMRequest[Inputs, Outputs]:
165 | return LMRequest(zenbase=self.zenbase, inputs=inputs)
166 |
167 | def process_response(
168 | self,
169 | request: LMRequest[Inputs, Outputs],
170 | outputs: Outputs,
171 | ) -> Outputs:
172 | self.history.append(LMCall(self, request, LMResponse(outputs)))
173 | return outputs
174 |
175 | def __call__(self, inputs: Inputs, *args, **kwargs) -> Outputs:
176 | request = self.prepare_request(inputs)
177 | kwargs.update({"lm_function": self} if "lm_function" in inspect.signature(self.fn).parameters else {})
178 | response = self.fn(request, *args, **kwargs)
179 | return self.process_response(request, response)
180 |
181 | async def coro(
182 | self,
183 | inputs: Inputs,
184 | *args,
185 | **kwargs,
186 | ) -> Outputs:
187 | request = self.prepare_request(inputs)
188 | response = await asyncify(self.fn)(request, *args, **kwargs)
189 | return self.process_response(request, response)
190 |
191 |
192 | def deflm(
193 | function: (Callable[[LMRequest[Inputs, Outputs]], Outputs | Awaitable[Outputs]] | None) = None,
194 | zenbase: LMZenbase[Inputs, Outputs] | None = None,
195 | ) -> LMFunction[Inputs, Outputs]:
196 | if function is None:
197 | return partial(deflm, zenbase=zenbase)
198 |
199 | if isinstance(function, LMFunction):
200 | return function.clean_and_duplicate(zenbase)
201 |
202 | return LMFunction(function, zenbase)
203 |
--------------------------------------------------------------------------------
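Note (illustrative): deflm wraps a plain callable that accepts an LMRequest into an LMFunction, which builds the request, appends an LMCall to history, and supports clean_and_duplicate for attaching a new LMZenbase (e.g. optimized task demos) without mutating the original.

    from zenbase.types import LMDemo, LMRequest, LMZenbase, deflm

    @deflm
    def echo(request: LMRequest) -> dict:
        # Optimizers populate request.zenbase.task_demos; plain calls see an empty list.
        return {"echo": request.inputs["text"], "demos_seen": len(request.zenbase.task_demos)}

    print(echo({"text": "hi"}))      # {'echo': 'hi', 'demos_seen': 0}; the call is recorded in echo.history

    tuned = echo.clean_and_duplicate(
        LMZenbase(task_demos=[LMDemo(inputs={"text": "hi"}, outputs={"echo": "hi"})])
    )
    print(tuned({"text": "hello"}))  # same underlying fn, fresh id and history, demos attached
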
/py/src/zenbase/utils.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import functools
3 | import inspect
4 | import json
5 | import logging
6 | import os
7 | from random import Random
8 | from typing import AsyncIterable, Awaitable, Callable, ParamSpec, TypeVar
9 |
10 | import anyio
11 | from anyio._core._eventloop import threadlocals
12 | from faker import Faker
13 | from opentelemetry import trace
14 | from pksuid import PKSUID
15 | from posthog import Posthog
16 | from structlog import get_logger
17 |
18 | get_logger: Callable[..., logging.Logger] = get_logger
19 | ot_tracer = trace.get_tracer("zenbase")
20 |
21 |
22 | def posthog() -> Posthog:
23 | if project_api_key := os.getenv("ZENBASE_ANALYTICS_KEY"):
24 | client = Posthog(
25 | project_api_key=project_api_key,
26 | host="https://us.i.posthog.com",
27 | )
28 | client.identify(os.environ["ZENBASE_ANALYTICS_ID"])
29 | else:
30 | client = Posthog("")
31 | client.disabled = True
32 | return client
33 |
34 |
35 | def get_seed(seed: int | None = None) -> int:
36 | return seed or int(os.getenv("RANDOM_SEED", 42))
37 |
38 |
39 | def random_factory(seed: int | None = None) -> Random:
40 | return Random(get_seed(seed))
41 |
42 |
43 | def ksuid(prefix: str | None = None) -> str:
44 | return str(PKSUID(prefix))
45 |
46 |
47 | def ksuid_generator(prefix: str) -> Callable[[], str]:
48 | return functools.partial(ksuid, prefix)
49 |
50 |
51 | def random_name_generator(
52 | prefix: str | None = None,
53 | random_name_generator=Faker().catch_phrase,
54 | ) -> Callable[[], str]:
55 | head = f"zenbase-{prefix}" if prefix else "zenbase"
56 |
57 | def gen():
58 | return "-".join([head, *random_name_generator().lower().split(" ")[:2]])
59 |
60 | return gen
61 |
62 |
63 | I_ParamSpec = ParamSpec("I_ParamSpec")
64 | O_Retval = TypeVar("O_Retval")
65 |
66 |
67 | def asyncify(
68 | func: Callable[I_ParamSpec, O_Retval],
69 | *,
70 | cancellable: bool = True,
71 | limiter: anyio.CapacityLimiter | None = None,
72 | ) -> Callable[I_ParamSpec, Awaitable[O_Retval]]:
73 | if inspect.iscoroutinefunction(func):
74 | return func
75 |
76 | @functools.wraps(func)
77 | async def wrapper(*args: I_ParamSpec.args, **kwargs: I_ParamSpec.kwargs) -> O_Retval:
78 | partial_f = functools.partial(func, *args, **kwargs)
79 | return await anyio.to_thread.run_sync(
80 | partial_f,
81 | abandon_on_cancel=cancellable,
82 | limiter=limiter,
83 | )
84 |
85 | return wrapper
86 |
87 |
88 | def syncify(
89 | func: Callable[I_ParamSpec, O_Retval],
90 | ) -> Callable[I_ParamSpec, O_Retval]:
91 | if not inspect.iscoroutinefunction(func):
92 | return func
93 |
94 | @functools.wraps(func)
95 | def wrapper(*args: I_ParamSpec.args, **kwargs: I_ParamSpec.kwargs) -> O_Retval:
96 | partial_f = functools.partial(func, *args, **kwargs)
97 | if not getattr(threadlocals, "current_async_backend", None):
98 | try:
99 | return asyncio.get_running_loop().run_until_complete(partial_f())
100 | except RuntimeError:
101 | return anyio.run(partial_f)
102 | return anyio.from_thread.run(partial_f)
103 |
104 | return wrapper
105 |
106 |
107 | ReturnValue = TypeVar("ReturnValue", covariant=True)
108 |
109 |
110 | async def amap(
111 | func: Callable[..., Awaitable[ReturnValue]],
112 | iterable,
113 | *iterables,
114 | concurrency=10,
115 | ) -> list[ReturnValue]:
116 | assert concurrency >= 1, "Concurrency must be greater than or equal to 1"
117 |
118 | if concurrency == 1:
119 | return [await func(*args) for args in zip(iterable, *iterables)]
120 |
121 | if concurrency == float("inf"):
122 | return await asyncio.gather(*[func(*args) for args in zip(iterable, *iterables)])
123 |
124 | semaphore = asyncio.Semaphore(concurrency)
125 |
126 | @functools.wraps(func)
127 | async def mapper(*args):
128 | async with semaphore:
129 | return await func(*args)
130 |
131 | return await asyncio.gather(*[mapper(*args) for args in zip(iterable, *iterables)])
132 |
133 |
134 | def pmap(
135 | func: Callable[..., ReturnValue],
136 | iterable,
137 | *iterables,
138 | concurrency=10,
139 | ) -> list[ReturnValue]:
140 | # TODO: Should revert.
141 | return [func(*args) for args in zip(iterable, *iterables)]
142 |
143 |
144 | async def alist(aiterable: AsyncIterable[ReturnValue]) -> list[ReturnValue]:
145 | return [x async for x in aiterable]
146 |
147 |
148 | def expand_nested_json(d):
149 | def recursive_expand(value):
150 | if isinstance(value, str):
151 | try:
152 | # Try to parse the string value as JSON
153 | parsed_value = json.loads(value)
154 | # Recursively expand the parsed value in case it contains further nested JSON
155 | return recursive_expand(parsed_value)
156 | except json.JSONDecodeError:
157 | # If parsing fails, return the original string
158 | return value
159 | elif isinstance(value, dict):
160 | # Recursively expand each key-value pair in the dictionary
161 | return {k: recursive_expand(v) for k, v in value.items()}
162 | elif isinstance(value, list):
163 | # Recursively expand each element in the list
164 | return [recursive_expand(elem) for elem in value]
165 | else:
166 | return value
167 |
168 | return recursive_expand(d)
169 |
--------------------------------------------------------------------------------
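Note (illustrative): amap bounds concurrency with a semaphore (concurrency=1 falls back to a plain sequential loop), and expand_nested_json recursively parses JSON strings embedded in dicts and lists.

    import asyncio

    from zenbase.utils import amap, expand_nested_json

    async def double(x: int) -> int:
        await asyncio.sleep(0)  # stand-in for an awaited LLM or I/O call
        return 2 * x

    print(asyncio.run(amap(double, range(5), concurrency=2)))          # [0, 2, 4, 6, 8]
    print(expand_nested_json({"payload": '{"a": 1, "b": "[1, 2]"}'}))  # {'payload': {'a': 1, 'b': [1, 2]}}
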
/py/tests/adaptors/bootstrap_few_shot_optimizer_args.zenbase:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zenbase-ai/core/59971868e85784c54eddb6980cf791b975e0befb/py/tests/adaptors/bootstrap_few_shot_optimizer_args.zenbase
--------------------------------------------------------------------------------
/py/tests/adaptors/parea_bootstrap_few_shot.zenbase:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zenbase-ai/core/59971868e85784c54eddb6980cf791b975e0befb/py/tests/adaptors/parea_bootstrap_few_shot.zenbase
--------------------------------------------------------------------------------
/py/tests/adaptors/test_braintrust.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zenbase-ai/core/59971868e85784c54eddb6980cf791b975e0befb/py/tests/adaptors/test_braintrust.py
--------------------------------------------------------------------------------
/py/tests/adaptors/test_lunary.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import lunary
4 | import pytest
5 | from openai import OpenAI
6 | from tenacity import (
7 | before_sleep_log,
8 | retry,
9 | stop_after_attempt,
10 | wait_exponential_jitter,
11 | )
12 |
13 | from zenbase.adaptors.lunary import ZenLunary
14 | from zenbase.core.managers import ZenbaseTracer
15 | from zenbase.optim.metric.bootstrap_few_shot import BootstrapFewShot
16 | from zenbase.optim.metric.labeled_few_shot import LabeledFewShot
17 | from zenbase.types import LMRequest
18 |
19 | SAMPLES = 2
20 | SHOTS = 3
21 | TESTSET_SIZE = 5
22 |
23 | log = logging.getLogger(__name__)
24 |
25 |
26 | @pytest.fixture
27 | def optim(gsm8k_demoset: list):
28 | return LabeledFewShot(demoset=gsm8k_demoset, shots=SHOTS)
29 |
30 |
31 | @pytest.fixture
32 | def bootstrap_few_shot_optim(gsm8k_demoset: list):
33 | return BootstrapFewShot(training_set_demos=gsm8k_demoset, shots=SHOTS)
34 |
35 |
36 | @pytest.fixture(scope="module")
37 | def openai():
38 | client = OpenAI()
39 | lunary.monitor(client)
40 | return client
41 |
42 |
43 | @pytest.fixture(scope="module")
44 | def evalset():
45 | items = lunary.get_dataset("gsm8k-evalset")
46 | assert any(items)
47 | return items
48 |
49 |
50 | @pytest.mark.helpers
51 | def test_lunary_lcel_labeled_few_shot(optim: LabeledFewShot, evalset: list):
52 | trace_manager = ZenbaseTracer()
53 |
54 | @trace_manager.trace_function
55 | @retry(
56 | stop=stop_after_attempt(3),
57 | wait=wait_exponential_jitter(max=8),
58 | before_sleep=before_sleep_log(log, logging.WARN),
59 | )
60 | def langchain_chain(request: LMRequest):
61 | """
62 |         A math solver LLM call that can solve any math problem set up with the LangChain library.
63 | """
64 |
65 | from langchain_core.output_parsers import StrOutputParser
66 | from langchain_core.prompts import ChatPromptTemplate
67 | from langchain_openai import ChatOpenAI
68 |
69 | messages = [
70 | (
71 | "system",
72 | "You are an expert math solver. Your answer must be just the number with no separators, and nothing else. Follow the format of the examples.", # noqa
73 | # noqa
74 | )
75 | ]
76 | for demo in request.zenbase.task_demos:
77 | messages += [
78 | ("user", demo.inputs["question"]),
79 | ("assistant", demo.outputs["answer"]),
80 | ]
81 |
82 | messages.append(("user", "{question}"))
83 |
84 | chain = ChatPromptTemplate.from_messages(messages) | ChatOpenAI(model="gpt-4o-mini") | StrOutputParser()
85 |
86 | print("Mathing...")
87 | answer = chain.invoke(request.inputs)
88 | return answer.split("#### ")[-1]
89 |
90 | fn, candidates, _ = optim.perform(
91 | langchain_chain,
92 | evaluator=ZenLunary.metric_evaluator(
93 | checklist="exact-match",
94 | evalset=evalset,
95 | concurrency=2,
96 | ),
97 | samples=SAMPLES,
98 | rounds=1,
99 | )
100 |
101 | assert fn is not None
102 | assert any(candidates)
103 | assert next(c for c in candidates if 0.5 <= c.evals["score"] <= 1)
104 |
105 |
106 | @pytest.fixture(scope="module")
107 | def lunary_helper():
108 | return ZenLunary(client=lunary)
109 |
--------------------------------------------------------------------------------
/py/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Any, Callable
3 |
4 | import pytest
5 | from datasets import DatasetDict
6 |
7 | from zenbase.types import LMDemo
8 |
9 |
10 | def pytest_configure():
11 | from pathlib import Path
12 |
13 | if not os.getenv("CI"):
14 | from dotenv import load_dotenv
15 |
16 | load_dotenv(str(Path(__file__).parent.parent / ".env.test"))
17 |
18 | import nest_asyncio
19 |
20 | nest_asyncio.apply()
21 |
22 |
23 | @pytest.fixture
24 | def anyio_backend():
25 | return "asyncio"
26 |
27 |
28 | @pytest.fixture(scope="session", autouse=True)
29 | def vcr_config():
30 | def response_processor(exclude_headers: list[str]) -> Callable:
31 | def before_record_response(response: dict | Any) -> dict | Any:
32 | if isinstance(response, dict):
33 | try:
34 | response_str = response.get("body", {}).get("string", b"").decode("utf-8")
35 | if "Rate limit reached for" in response_str:
36 | # don't record rate-limiting responses
37 | return None
38 | except UnicodeDecodeError:
39 | pass # ignore if we can't parse response
40 |
41 | for header in exclude_headers:
42 | if header in response["headers"]:
43 | response["headers"].pop(header)
44 | return response
45 |
46 | return before_record_response
47 |
48 | return {
49 | "filter_headers": [
50 | "User-Agent",
51 | "Accept",
52 | "Accept-Encoding",
53 | "Connection",
54 | "Content-Length",
55 | "Content-Type",
56 | # OpenAI request headers we don't want
57 | "Cookie",
58 | "authorization",
59 | "X-OpenAI-Client-User-Agent",
60 | "OpenAI-Organization",
61 | "x-stainless-lang",
62 | "x-stainless-package-version",
63 | "x-stainless-os",
64 | "x-stainless-arch",
65 | "x-stainless-runtime",
66 | "x-stainless-runtime-version",
67 | "x-api-key",
68 | ],
69 | "filter_query_parameters": ["api_key"],
70 | "cassette_library_dir": "tests/cache/cassettes",
71 | "before_record_response": response_processor(
72 | exclude_headers=[
73 | # OpenAI response headers we don't want
74 | "Set-Cookie",
75 | "Server",
76 | "access-control-allow-origin",
77 | "alt-svc",
78 | "openai-organization",
79 | "openai-version",
80 | "strict-transport-security",
81 | "x-ratelimit-limit-requests",
82 | "x-ratelimit-limit-tokens",
83 | "x-ratelimit-remaining-requests",
84 | "x-ratelimit-remaining-tokens",
85 | "x-ratelimit-reset-requests",
86 | "x-ratelimit-reset-tokens",
87 | "x-request-id",
88 | ]
89 | ),
90 | "match_on": [
91 | "method",
92 | "scheme",
93 | "host",
94 | "port",
95 | "path",
96 | "query",
97 | "body",
98 | "headers",
99 | ],
100 | }
101 |
102 |
103 | @pytest.fixture(scope="session")
104 | def gsm8k_dataset():
105 | import datasets
106 |
107 | return datasets.load_dataset("gsm8k", "main")
108 |
109 |
110 | @pytest.fixture(scope="session")
111 | def arxiv_dataset():
112 | import datasets
113 |
114 | return datasets.load_dataset("dansbecker/arxiv_article_classification")
115 |
116 |
117 | @pytest.fixture(scope="session")
118 | def news_dataset():
119 | import datasets
120 |
121 | return datasets.load_dataset("SetFit/20_newsgroups")
122 |
123 |
124 | @pytest.fixture(scope="session")
125 | def gsm8k_demoset(gsm8k_dataset: DatasetDict) -> list[LMDemo]:
126 | return [
127 | LMDemo(
128 | inputs={"question": r["question"]},
129 | outputs={"answer": r["answer"]},
130 | )
131 | for r in gsm8k_dataset["train"].select(range(5))
132 | ]
133 |
--------------------------------------------------------------------------------
/py/tests/core/managers.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from datetime import datetime
3 | from unittest.mock import patch
4 |
5 | import pytest
6 |
7 | from zenbase.core.managers import ZenbaseTracer
8 | from zenbase.types import LMFunction, LMZenbase
9 |
10 |
11 | @pytest.fixture
12 | def zenbase_manager():
13 | return ZenbaseTracer()
14 |
15 |
16 | @pytest.fixture
17 | def layer_2_1(zenbase_manager):
18 | @zenbase_manager.trace_function
19 | def _layer_2_1(request, instruction="default_instruction_2_1", candidates=[]):
20 | layer_2_1_output = f"layer_2_1_output_{str(request.inputs)}"
21 | return layer_2_1_output
22 |
23 | return _layer_2_1
24 |
25 |
26 | @pytest.fixture
27 | def layer_1_1(zenbase_manager, layer_2_1):
28 | @zenbase_manager.trace_function
29 | def _layer_1_1(request, instruction="default_instruction_1_1", candidates=[]):
30 | layer_2_1(request.inputs, instruction=instruction)
31 | layer_1_1_output = f"layer_1_1_output_{str(request.inputs)}"
32 | return layer_1_1_output
33 |
34 | return _layer_1_1
35 |
36 |
37 | @pytest.fixture
38 | def layer_1_2(zenbase_manager):
39 | @zenbase_manager.trace_function
40 | def _layer_1_2(request, instruction="default_instruction_1_2", candidates=[]):
41 | layer_1_2_output = f"layer_1_2_output_{str(request.inputs)}"
42 | return layer_1_2_output
43 |
44 | return _layer_1_2
45 |
46 |
47 | @pytest.fixture
48 | def layer_0(zenbase_manager, layer_1_1, layer_1_2):
49 | @zenbase_manager.trace_function
50 | def _layer_0(request, instruction="default_instruction_0", candidates=[]):
51 | layer_1_1(inputs=request.inputs)
52 | layer_1_2(inputs=request.inputs)
53 | layer_0_output = f"layer_0_output_{str(request.inputs)}"
54 | return layer_0_output
55 |
56 | return _layer_0
57 |
58 |
59 | @pytest.fixture
60 | def layer_0_2(zenbase_manager, layer_1_1, layer_1_2):
61 | @zenbase_manager.trace_function
62 | def _layer_0_2(request, instruction="default_instruction_0_2", candidates=[]):
63 | layer_1_1(inputs=request.inputs["inputs"])
64 | layer_1_2(inputs=request.inputs["inputs"])
65 | layer_0_output = f"layer_0_2_output_{str(request.inputs['inputs'])}"
66 | return layer_0_output
67 |
68 | return _layer_0_2
69 |
70 |
71 | def test_trace_layer_0(zenbase_manager, layer_0):
72 | inputs = [{"inputs": i} for i in range(5)]
73 |
74 |     for the_input in inputs:
75 |         layer_0(inputs=the_input)
76 |
77 | assert len(zenbase_manager.all_traces) == 5
78 | for trace in zenbase_manager.all_traces.values():
79 | assert "_layer_0" in trace
80 |
81 |
82 | def test_trace_layer_0_multiple_runs(zenbase_manager, layer_0):
83 | inputs = [{"inputs": i} for i in range(5)]
84 |
85 | for the_input in inputs:
86 | layer_0(inputs=the_input)
87 | for the_input in inputs:
88 | layer_0(inputs=the_input)
89 |
90 | assert len(zenbase_manager.all_traces) == 10
91 |
92 |
93 | def test_trace_layer_0_2(zenbase_manager, layer_0_2):
94 | inputs = [{"inputs": i} for i in range(5)]
95 |
96 | for the_input in inputs:
97 | layer_0_2(inputs=the_input)
98 |
99 | assert len(zenbase_manager.all_traces) == 5
100 | for trace in zenbase_manager.all_traces.values():
101 | assert "_layer_0_2" in trace
102 |
103 |
104 | def test_trace_layer_0_with_optimized_args(zenbase_manager, layer_0):
105 | inputs = [{"inputs": i} for i in range(5)]
106 | optimized_args = {
107 | "layer_2_1": {"args": {"instruction": "optimized_instruction_2_1", "candidates": ["optimized_candidate_2_1"]}},
108 | "layer_1_1": {"args": {"instruction": "optimized_instruction_1_1", "candidates": ["optimized_candidate_1_1"]}},
109 | "layer_1_2": {"args": {"instruction": "optimized_instruction_1_2", "candidates": ["optimized_candidate_1_2"]}},
110 | "layer_0": {"args": {"instruction": "optimized_instruction_0", "candidates": ["optimized_candidate_0"]}},
111 | }
112 |
113 | def optimized_layer_0(*args, **kwargs):
114 | with zenbase_manager.trace_context(
115 | "optimized_layer_0", f"optimized_layer_0_{datetime.now().isoformat()}", optimized_args
116 | ):
117 | return layer_0(*args, **kwargs)
118 |
119 | for the_input in inputs:
120 | optimized_layer_0(inputs=the_input)
121 |
122 | assert len(zenbase_manager.all_traces) == 5
123 | for trace in zenbase_manager.all_traces.values():
124 | assert "optimized_layer_0" in trace
125 |
126 |
127 | def test_trace_layer_functions(zenbase_manager, layer_2_1, layer_1_1, layer_1_2):
128 | inputs = [{"inputs": i} for i in range(5)]
129 |
130 |     for the_input in inputs:
131 |         layer_2_1(inputs=the_input)
132 |         layer_1_1(inputs=the_input)
133 |         layer_1_2(inputs=the_input)
134 |
135 | assert len(zenbase_manager.all_traces) == 15
136 | for trace in zenbase_manager.all_traces.values():
137 | assert any(func in trace for func in ["_layer_2_1", "_layer_1_1", "_layer_1_2"])
138 |
139 |
140 | @pytest.fixture
141 | def tracer():
142 | return ZenbaseTracer(max_traces=3)
143 |
144 |
145 | def test_init(tracer):
146 | assert isinstance(tracer.all_traces, OrderedDict)
147 | assert tracer.max_traces == 3
148 | assert tracer.current_trace is None
149 | assert tracer.current_key is None
150 | assert tracer.optimized_args == {}
151 |
152 |
153 | def test_flush(tracer):
154 | tracer.all_traces = OrderedDict({"key1": "value1", "key2": "value2"})
155 | tracer.flush()
156 | assert len(tracer.all_traces) == 0
157 |
158 |
159 | def test_add_trace(tracer):
160 | # Add first trace
161 | tracer.add_trace("timestamp1", "func1", {"data": "trace1"})
162 | assert len(tracer.all_traces) == 1
163 | assert "timestamp1" in tracer.all_traces
164 |
165 | # Add second trace
166 | tracer.add_trace("timestamp2", "func2", {"data": "trace2"})
167 | assert len(tracer.all_traces) == 2
168 |
169 | # Add third trace
170 | tracer.add_trace("timestamp3", "func3", {"data": "trace3"})
171 | assert len(tracer.all_traces) == 3
172 |
173 | # Add fourth trace (should remove oldest)
174 | tracer.add_trace("timestamp4", "func4", {"data": "trace4"})
175 | assert len(tracer.all_traces) == 3
176 | assert "timestamp1" not in tracer.all_traces
177 | assert "timestamp4" in tracer.all_traces
178 |
179 |
180 | @patch("zenbase.utils.ksuid")
181 | def test_trace_function(mock_ksuid, tracer):
182 | mock_ksuid.return_value = "test_timestamp"
183 |
184 | def test_func(request):
185 | return request.inputs[0] + request.inputs[1]
186 |
187 | zenbase = LMZenbase()
188 | traced_func = tracer.trace_function(test_func, zenbase)
189 | assert isinstance(traced_func, LMFunction)
190 |
191 | result = traced_func(inputs=(2, 3))
192 |
193 | assert result == 5
194 | trace = tracer.all_traces[list(tracer.all_traces.keys())[0]]
195 | assert "test_func" in trace["test_func"]
196 | trace_info = trace["test_func"]["test_func"]
197 | assert trace_info["args"]["request"].inputs == (2, 3)
198 | assert trace_info["output"] == 5
199 |
200 |
201 | def test_trace_context(tracer):
202 | with tracer.trace_context("test_func", "test_timestamp"):
203 | assert tracer.current_key == "test_timestamp"
204 | assert isinstance(tracer.current_trace, dict)
205 |
206 | assert tracer.current_trace is None
207 | assert tracer.current_key is None
208 | assert "test_timestamp" in tracer.all_traces
209 |
210 |
211 | def test_max_traces_limit(tracer):
212 | for i in range(5):
213 | tracer.add_trace(f"timestamp{i}", f"func{i}", {"data": f"trace{i}"})
214 |
215 | assert len(tracer.all_traces) == 3
216 | assert "timestamp0" not in tracer.all_traces
217 | assert "timestamp1" not in tracer.all_traces
218 | assert "timestamp2" in tracer.all_traces
219 | assert "timestamp3" in tracer.all_traces
220 | assert "timestamp4" in tracer.all_traces
221 |
222 |
223 | @patch("zenbase.utils.ksuid")
224 | def test_optimized_args(mock_ksuid, tracer):
225 | mock_ksuid.return_value = "test_timestamp"
226 |
227 | def test_func(request, z=3):
228 | x, y = request.inputs
229 | return x + y + z
230 |
231 | tracer.optimized_args = {"test_func": {"args": {"z": 5}}}
232 | zenbase = LMZenbase()
233 | traced_func = tracer.trace_function(test_func, zenbase)
234 |
235 | result = traced_func(inputs=(2, 10))
236 |
237 | assert result == 17 # 2 + 10 + 5
238 | trace = tracer.all_traces[list(tracer.all_traces.keys())[0]]
239 | assert "test_func" in trace["test_func"]
240 | trace_info = trace["test_func"]["test_func"]
241 | assert trace_info["args"]["request"].inputs == (2, 10)
242 | assert trace_info["args"]["z"] == 5
243 | assert trace_info["output"] == 17
244 |
--------------------------------------------------------------------------------
/py/tests/optim/metric/test_bootstrap_few_shot.py:
--------------------------------------------------------------------------------
1 | from unittest.mock import Mock, mock_open, patch
2 |
3 | import pytest
4 |
5 | from zenbase.adaptors.langchain import ZenLangSmith
6 | from zenbase.core.managers import ZenbaseTracer
7 | from zenbase.optim.metric.bootstrap_few_shot import BootstrapFewShot
8 | from zenbase.types import LMDemo, LMFunction, LMZenbase
9 |
10 |
11 | @pytest.fixture
12 | def mock_zen_adaptor():
13 | adaptor = Mock(spec=ZenLangSmith)
14 | adaptor.fetch_dataset_demos.return_value = [
15 | LMDemo(inputs={"input": "test1"}, outputs={"output": "result1"}),
16 | LMDemo(inputs={"input": "test2"}, outputs={"output": "result2"}),
17 | LMDemo(inputs={"input": "test3"}, outputs={"output": "result3"}),
18 | ]
19 | adaptor.get_evaluator.return_value = Mock(
20 | return_value=Mock(
21 | individual_evals=[
22 | Mock(passed=True, demo=LMDemo(inputs={"input": "test1"}, outputs={"output": "result1"})),
23 | Mock(passed=True, demo=LMDemo(inputs={"input": "test2"}, outputs={"output": "result2"})),
24 | Mock(passed=False, demo=LMDemo(inputs={"input": "test3"}, outputs={"output": "result3"})),
25 | ]
26 | )
27 | )
28 | return adaptor
29 |
30 |
31 | @pytest.fixture
32 | def mock_trace_manager():
33 | return Mock(spec=ZenbaseTracer)
34 |
35 |
36 | @pytest.fixture
37 | def bootstrap_few_shot(mock_zen_adaptor):
38 | return BootstrapFewShot(
39 | shots=2,
40 | training_set=Mock(),
41 | test_set=Mock(),
42 | validation_set=Mock(),
43 | zen_adaptor=mock_zen_adaptor,
44 | )
45 |
46 |
47 | def test_init(bootstrap_few_shot):
48 | assert bootstrap_few_shot.shots == 2
49 | assert len(bootstrap_few_shot.training_set_demos) == 3
50 |
51 |
52 | def test_init_invalid_shots():
53 | with pytest.raises(AssertionError):
54 | BootstrapFewShot(shots=0, training_set=Mock(), test_set=Mock(), validation_set=Mock(), zen_adaptor=Mock())
55 |
56 |
57 | def test_create_teacher_model(bootstrap_few_shot, mock_zen_adaptor):
58 | mock_lmfn = Mock(spec=LMFunction)
59 | with patch("zenbase.optim.metric.bootstrap_few_shot.LabeledFewShot") as mock_labeled_few_shot:
60 | mock_labeled_few_shot.return_value.perform.return_value = (Mock(), None, None)
61 | teacher_model = bootstrap_few_shot._create_teacher_model(mock_zen_adaptor, mock_lmfn, 5, 1)
62 | assert teacher_model is not None
63 | mock_labeled_few_shot.assert_called_once()
64 |
65 |
66 | def test_validate_demo_set(bootstrap_few_shot, mock_zen_adaptor):
67 | mock_teacher_lm = Mock(spec=LMFunction)
68 | validated_demos = bootstrap_few_shot._validate_demo_set(mock_zen_adaptor, mock_teacher_lm)
69 | assert len(validated_demos) == 2
70 | assert all(demo.inputs["input"].startswith("test") for demo in validated_demos)
71 |
72 |
73 | def test_run_validated_demos():
74 | mock_teacher_lm = Mock(spec=LMFunction)
75 | validated_demo_set = [
76 | LMDemo(inputs={"input": "test1"}, outputs={"output": "result1"}),
77 | LMDemo(inputs={"input": "test2"}, outputs={"output": "result2"}),
78 | ]
79 | BootstrapFewShot._run_validated_demos(mock_teacher_lm, validated_demo_set)
80 | assert mock_teacher_lm.call_count == 2
81 |
82 |
83 | def test_consolidate_traces_to_optimized_args(mock_trace_manager, bootstrap_few_shot):
84 | mock_trace_manager.all_traces = {
85 | "trace1": {
86 | "func1": {
87 | "inner_func1": {"args": {"request": Mock(inputs={"input": "test1"})}, "output": {"output": "result1"}}
88 | }
89 | }
90 | }
91 | optimized_args = bootstrap_few_shot._consolidate_traces_to_optimized_args(mock_trace_manager)
92 | assert "inner_func1" in optimized_args
93 | assert isinstance(optimized_args["inner_func1"]["args"]["zenbase"], LMZenbase)
94 |
95 |
96 | def test_create_optimized_function():
97 | mock_student_lm = Mock(spec=LMFunction)
98 | mock_trace_manager = Mock(spec=ZenbaseTracer)
99 | optimized_args = {"func1": {"args": {"zenbase": LMZenbase(task_demos=[])}}}
100 |
101 | optimized_fn = BootstrapFewShot._create_optimized_function(mock_student_lm, optimized_args, mock_trace_manager)
102 | assert callable(optimized_fn)
103 |
104 |
105 | @patch("zenbase.optim.metric.bootstrap_few_shot.partial")
106 | def test_perform(mock_partial, bootstrap_few_shot, mock_zen_adaptor, mock_trace_manager):
107 | mock_student_lm = Mock(spec=LMFunction)
108 | mock_teacher_lm = Mock(spec=LMFunction)
109 |
110 | with patch.object(bootstrap_few_shot, "_create_teacher_model", return_value=mock_teacher_lm):
111 | with patch.object(bootstrap_few_shot, "_validate_demo_set"):
112 | with patch.object(bootstrap_few_shot, "_run_validated_demos"):
113 | with patch.object(bootstrap_few_shot, "_consolidate_traces_to_optimized_args"):
114 | with patch.object(bootstrap_few_shot, "_create_optimized_function"):
115 | result = bootstrap_few_shot.perform(
116 | mock_student_lm,
117 | mock_teacher_lm,
118 | samples=5,
119 | rounds=1,
120 | trace_manager=mock_trace_manager,
121 | )
122 |
123 | assert isinstance(result, BootstrapFewShot.Result)
124 | assert result.best_function is not None
125 |
126 |
127 | def test_set_and_get_optimizer_args(bootstrap_few_shot):
128 | test_args = {"test": "args"}
129 | bootstrap_few_shot.set_optimizer_args(test_args)
130 | assert bootstrap_few_shot.get_optimizer_args() == test_args
131 |
132 |
133 | @patch("cloudpickle.dump")
134 | def test_save_optimizer_args(mock_dump, bootstrap_few_shot, tmp_path):
135 | test_args = {"test": "args"}
136 | bootstrap_few_shot.set_optimizer_args(test_args)
137 | file_path = tmp_path / "test_optimizer_args.dill"
138 | bootstrap_few_shot.save_optimizer_args(str(file_path))
139 | mock_dump.assert_called_once()
140 |
141 |
142 | @patch("builtins.open", new_callable=mock_open, read_data="dummy data")
143 | @patch("cloudpickle.load")
144 | def test_load_optimizer_args(mock_load, mock_file):
145 | test_args = {"test": "args"}
146 | mock_load.return_value = test_args
147 | loaded_args = BootstrapFewShot._load_optimizer_args("dummy_path")
148 | mock_file.assert_called_once_with("dummy_path", "rb")
149 | mock_load.assert_called_once()
150 | assert loaded_args == test_args
151 |
152 |
153 | @patch("zenbase.optim.metric.bootstrap_few_shot.BootstrapFewShot._load_optimizer_args")
154 | @patch("zenbase.optim.metric.bootstrap_few_shot.BootstrapFewShot._create_optimized_function")
155 | def test_load_optimizer_and_function(mock_create_optimized_function, mock_load_optimizer_args):
156 | mock_student_lm = Mock(spec=LMFunction)
157 | mock_trace_manager = Mock(spec=ZenbaseTracer)
158 | mock_load_optimizer_args.return_value = {"test": "args"}
159 | mock_create_optimized_function.return_value = Mock(spec=LMFunction)
160 |
161 | result = BootstrapFewShot.load_optimizer_and_function("dummy_path", mock_student_lm, mock_trace_manager)
162 |
163 | mock_load_optimizer_args.assert_called_once_with("dummy_path")
164 | mock_create_optimized_function.assert_called_once()
165 | assert isinstance(result, Mock)
166 | assert isinstance(result, LMFunction)
167 |
--------------------------------------------------------------------------------
/py/tests/optim/metric/test_labeled_few_shot.py:
--------------------------------------------------------------------------------
1 | from random import Random, random
2 |
3 | import pytest
4 |
5 | from zenbase.optim.metric.labeled_few_shot import LabeledFewShot
6 | from zenbase.optim.metric.types import CandidateEvalResult
7 | from zenbase.types import LMDemo, LMFunction, LMRequest, deflm
8 |
9 | lmfn = deflm(lambda x: x)
10 |
11 |
12 | demoset = [
13 | LMDemo(inputs={}, outputs={"output": "a"}),
14 | LMDemo(inputs={}, outputs={"output": "b"}),
15 | LMDemo(inputs={}, outputs={"output": "c"}),
16 | LMDemo(inputs={}, outputs={"output": "d"}),
17 | LMDemo(inputs={}, outputs={"output": "e"}),
18 | LMDemo(inputs={}, outputs={"output": "f"}),
19 | ]
20 |
21 |
22 | def test_invalid_shots():
23 | with pytest.raises(AssertionError):
24 | LabeledFewShot(demoset=demoset, shots=0)
25 | with pytest.raises(AssertionError):
26 | LabeledFewShot(demoset=demoset, shots=len(demoset) + 1)
27 |
28 |
29 | def test_idempotency():
30 | shots = 2
31 | samples = 5
32 |
33 | optim1 = LabeledFewShot(demoset=demoset, shots=shots)
34 | optim2 = LabeledFewShot(demoset=demoset, shots=shots)
35 | optim3 = LabeledFewShot(demoset=demoset, shots=shots, random=Random(41))
36 |
37 | set1 = list(optim1.candidates(lmfn, samples))
38 | set2 = list(optim2.candidates(lmfn, samples))
39 | set3 = list(optim3.candidates(lmfn, samples))
40 |
41 | assert set1 == set2
42 | assert set1 != set3
43 | assert set2 != set3
44 |
45 |
46 | @pytest.fixture
47 | def optim():
48 | return LabeledFewShot(demoset=demoset, shots=2)
49 |
50 |
51 | def test_candidate_generation(optim: LabeledFewShot):
52 | samples = 5
53 |
54 | candidates = list(optim.candidates(lmfn, samples))
55 |
56 | assert all(len(c.task_demos) == optim.shots for c in candidates)
57 | assert len(candidates) == samples
58 |
59 |
60 | @deflm
61 | def dummy_lmfn(_: LMRequest):
62 | return {"answer": 42}
63 |
64 |
65 | def dummy_evalfn(fn: LMFunction):
66 | return CandidateEvalResult(fn, {"score": random()})
67 |
68 |
69 | def test_training(optim: LabeledFewShot):
70 | # Train the dummy function
71 | trained_lmfn, candidates, best_candidate_result = optim.perform(
72 | dummy_lmfn,
73 | dummy_evalfn,
74 | rounds=1,
75 | concurrency=1,
76 | )
77 |
78 | # Check that the best function is returned
79 | best_function = max(candidates, key=lambda r: r.evals["score"]).function
80 | assert trained_lmfn == best_function
81 |
82 | for demo in trained_lmfn.zenbase.task_demos:
83 | assert demo in demoset
84 |
85 |
86 | @pytest.mark.anyio
87 | async def test_async_training(optim: LabeledFewShot):
88 | # Train the dummy function
89 | trained_dummy_lmfn, candidates, best_candidate_result = await optim.aperform(
90 | dummy_lmfn,
91 | dummy_evalfn,
92 | rounds=1,
93 | concurrency=1,
94 | )
95 |
96 | # Check that the best function is returned
97 | best_function = max(candidates, key=lambda r: r.evals["score"]).function
98 | assert trained_dummy_lmfn == best_function
99 |
--------------------------------------------------------------------------------
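A minimal sketch of a deterministic evaluator with the same shape as `dummy_evalfn` above, assuming a candidate `LMFunction` can be invoked directly with a demo's inputs (as the predefined-optimizer tests later in this suite do with their best functions):

    def exact_match_evalfn(fn: LMFunction) -> CandidateEvalResult:
        # Score a candidate by exact match over the demoset instead of a random score.
        hits = sum(fn(demo.inputs) == demo.outputs for demo in demoset)
        return CandidateEvalResult(fn, {"score": hits / len(demoset)})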
/py/tests/predefined/test_generic_lm_function_optimizer.py:
--------------------------------------------------------------------------------
1 | import instructor
2 | import pytest
3 | from instructor import Instructor
4 | from openai import OpenAI
5 | from pydantic import BaseModel
6 |
7 | from zenbase.core.managers import ZenbaseTracer
8 | from zenbase.predefined.generic_lm_function.optimizer import GenericLMFunctionOptimizer
9 |
10 |
11 | class InputModel(BaseModel):
12 | question: str
13 |
14 |
15 | class OutputModel(BaseModel):
16 | answer: str
17 |
18 |
19 | @pytest.fixture(scope="module")
20 | def openai_client() -> OpenAI:
21 | return OpenAI()
22 |
23 |
24 | @pytest.fixture(scope="module")
25 | def instructor_client(openai_client: OpenAI) -> Instructor:
26 | return instructor.from_openai(openai_client)
27 |
28 |
29 | @pytest.fixture(scope="module")
30 | def zenbase_tracer() -> ZenbaseTracer:
31 | return ZenbaseTracer()
32 |
33 |
34 | @pytest.fixture
35 | def generic_optimizer(instructor_client, zenbase_tracer):
36 | training_set = [
37 | {"inputs": {"question": "What is the capital of France?"}, "outputs": {"answer": "Paris"}},
38 | {"inputs": {"question": "Who wrote Romeo and Juliet?"}, "outputs": {"answer": "William Shakespeare"}},
39 | {"inputs": {"question": "What is the largest planet in our solar system?"}, "outputs": {"answer": "Jupiter"}},
40 | {"inputs": {"question": "Who painted the Mona Lisa?"}, "outputs": {"answer": "Leonardo da Vinci"}},
41 | {"inputs": {"question": "What is the chemical symbol for gold?"}, "outputs": {"answer": "Au"}},
42 | ]
43 |
44 | return GenericLMFunctionOptimizer(
45 | instructor_client=instructor_client,
46 | prompt="You are a helpful assistant. Answer the user's question concisely.",
47 | input_model=InputModel,
48 | output_model=OutputModel,
49 | model="gpt-4o-mini",
50 | zenbase_tracer=zenbase_tracer,
51 | training_set=training_set,
52 | validation_set=[
53 | {"inputs": {"question": "What is the capital of Italy?"}, "outputs": {"answer": "Rome"}},
54 | {"inputs": {"question": "What is the capital of France?"}, "outputs": {"answer": "Paris"}},
55 | {"inputs": {"question": "What is the capital of Germany?"}, "outputs": {"answer": "Berlin"}},
56 | ],
57 | test_set=[
58 | {"inputs": {"question": "Who invented the telephone?"}, "outputs": {"answer": "Alexander Graham Bell"}},
59 | {"inputs": {"question": "Who is CEO of microsoft?"}, "outputs": {"answer": "Bill Gates"}},
60 | {"inputs": {"question": "Who is founder of Facebook?"}, "outputs": {"answer": "Mark Zuckerberg"}},
61 | ],
62 | shots=len(training_set), # Set shots to the number of training examples
63 | )
64 |
65 |
66 | @pytest.mark.helpers
67 | def test_generic_optimizer_optimize(generic_optimizer):
68 | result = generic_optimizer.optimize()
69 | assert result is not None
70 | assert isinstance(result, GenericLMFunctionOptimizer.Result)
71 | assert result.best_function is not None
72 | assert callable(result.best_function)
73 | assert isinstance(result.candidate_results, list)
74 | assert result.best_candidate_result is not None
75 |
76 | # Check base evaluation
77 | assert generic_optimizer.base_evaluation is not None
78 |
79 | # Check best evaluation
80 | assert generic_optimizer.best_evaluation is not None
81 |
82 | # Test the best function
83 | test_input = InputModel(question="What is the capital of Italy?")
84 | output = result.best_function(test_input)
85 | assert isinstance(output, OutputModel)
86 | assert isinstance(output.answer, str)
87 | assert output.answer.strip().lower() == "rome"
88 |
89 |
90 | @pytest.mark.helpers
91 | def test_generic_optimizer_evaluations(generic_optimizer):
92 | result = generic_optimizer.optimize()
93 |
94 | # Check that base and best evaluations exist
95 | assert generic_optimizer.base_evaluation is not None
96 | assert generic_optimizer.best_evaluation is not None
97 |
98 | # Additional checks to ensure the structure of the result
99 | assert isinstance(result, GenericLMFunctionOptimizer.Result)
100 | assert result.best_function is not None
101 | assert isinstance(result.candidate_results, list)
102 | assert result.best_candidate_result is not None
103 |
104 |
105 | @pytest.mark.helpers
106 | def test_generic_optimizer_custom_evaluator(instructor_client, zenbase_tracer):
107 | def custom_evaluator(output: OutputModel, ideal_output: dict) -> dict:
108 | return {"passed": int(output.answer.lower() == ideal_output["answer"].lower()), "length": len(output.answer)}
109 |
110 | training_set = [
111 | {"inputs": {"question": "What is 2+2?"}, "outputs": {"answer": "4"}},
112 | {"inputs": {"question": "What is the capital of France?"}, "outputs": {"answer": "Paris"}},
113 | {"inputs": {"question": "Who wrote Romeo and Juliet?"}, "outputs": {"answer": "William Shakespeare"}},
114 | {"inputs": {"question": "What is the largest planet in our solar system?"}, "outputs": {"answer": "Jupiter"}},
115 | {"inputs": {"question": "Who painted the Mona Lisa?"}, "outputs": {"answer": "Leonardo da Vinci"}},
116 | ]
117 |
118 | optimizer = GenericLMFunctionOptimizer(
119 | instructor_client=instructor_client,
120 | prompt="You are a helpful assistant. Answer the user's question concisely.",
121 | input_model=InputModel,
122 | output_model=OutputModel,
123 | model="gpt-4o-mini",
124 | zenbase_tracer=zenbase_tracer,
125 | training_set=training_set,
126 | validation_set=[{"inputs": {"question": "What is 3+3?"}, "outputs": {"answer": "6"}}],
127 | test_set=[{"inputs": {"question": "What is 4+4?"}, "outputs": {"answer": "8"}}],
128 | custom_evaluator=custom_evaluator,
129 | shots=len(training_set), # Set shots to the number of training examples
130 | )
131 |
132 | result = optimizer.optimize()
133 | assert result is not None
134 | assert isinstance(result, GenericLMFunctionOptimizer.Result)
135 | assert "length" in optimizer.best_evaluation.individual_evals[0].details
136 |
137 | # Test the custom evaluator
138 | test_input = InputModel(question="What is 5+5?")
139 | output = result.best_function(test_input)
140 | assert isinstance(output, OutputModel)
141 | assert isinstance(output.answer, str)
142 |
143 | # Manually apply the custom evaluator
144 | eval_result = custom_evaluator(output, {"answer": "10"})
145 | assert "passed" in eval_result
146 | assert "length" in eval_result
147 |
148 |
149 | @pytest.mark.helpers
150 | def test_create_lm_function_with_demos(generic_optimizer):
151 | prompt = "You are a helpful assistant. Answer the user's question concisely."
152 | demos = [
153 | {"inputs": {"question": "What is the capital of France?"}, "outputs": {"answer": "Paris"}},
154 | {"inputs": {"question": "Who wrote Romeo and Juliet?"}, "outputs": {"answer": "William Shakespeare"}},
155 | ]
156 |
157 | lm_function = generic_optimizer.create_lm_function_with_demos(prompt, demos)
158 |
159 | # Test that the function is created and can be called
160 | test_input = InputModel(question="What is the capital of Italy?")
161 | result = lm_function(test_input)
162 |
163 | assert isinstance(result, OutputModel)
164 | assert isinstance(result.answer, str)
165 | assert result.answer.strip().lower() == "rome"
166 |
167 | # Test with a question from the demos
168 | test_input_demo = InputModel(question="What is the capital of France?")
169 | result_demo = lm_function(test_input_demo)
170 |
171 | assert isinstance(result_demo, OutputModel)
172 | assert isinstance(result_demo.answer, str)
173 | assert result_demo.answer.strip().lower() == "paris"
174 |
--------------------------------------------------------------------------------
/py/tests/predefined/test_single_class_classifier.py:
--------------------------------------------------------------------------------
1 | import datasets
2 | import instructor
3 | import pandas as pd
4 | import pytest
5 | from instructor.client import Instructor
6 | from openai import OpenAI
7 | from pydantic import BaseModel
8 |
9 | from zenbase.core.managers import ZenbaseTracer
10 | from zenbase.predefined.single_class_classifier import SingleClassClassifier
11 | from zenbase.predefined.single_class_classifier.function_generator import SingleClassClassifierLMFunctionGenerator
12 |
13 | TRAINSET_SIZE = 100
14 | VALIDATIONSET_SIZE = 21
15 | TESTSET_SIZE = 21
16 |
17 |
18 | @pytest.fixture(scope="module")
19 | def prompt_definition() -> str:
20 | return """Your task is to accurately categorize each incoming news article into one of the given categories based
21 | on its title and content."""
22 |
23 |
24 | @pytest.fixture(scope="module")
25 | def class_dict() -> dict[str, str]:
26 | return {
27 | "Automobiles": "Discussions and news about automobiles, including car maintenance, driving experiences, "
28 | "and the latest automotive technology.",
29 | "Computers": "Topics related to computer hardware, software, graphics, cryptography, and operating systems, "
30 | "including troubleshooting and advancements.",
31 | "Science": "News and discussions about scientific topics including space exploration, medicine, and "
32 | "electronics.",
33 | "Politics": "Debates and news about political topics, including gun control, Middle Eastern politics,"
34 | " and miscellaneous political discussions.",
35 | "Religion": "Discussions about various religions, including beliefs, practices, atheism, and religious news.",
36 | "For Sale": "Classified ads for buying and selling miscellaneous items, from electronics to household goods.",
37 | "Sports": "Everything about sports, including discussions, news, player updates, and game analysis.",
38 | }
39 |
40 |
41 | @pytest.fixture(scope="module")
42 | def sample_news_article():
43 | return """title: New Advancements in Electric Vehicle Technology
44 | content: The automotive industry is witnessing a significant shift towards electric vehicles (EVs).
45 | Recent advancements in battery technology have led to increased range and reduced charging times.
46 | Companies like Tesla, Nissan, and BMW are at the forefront of this innovation, aiming to make EVs
47 | more accessible and efficient.
48 | With governments worldwide pushing for greener alternatives, the future of transportation looks
49 | electric."""
50 |
51 |
52 | def create_dataset_with_examples(item_set: list):
53 | return [{"inputs": item["text"], "outputs": convert_to_human_readable(item["label_text"])} for item in item_set]
54 |
55 |
56 | @pytest.fixture(scope="module")
57 | def openai_client() -> OpenAI:
58 | return OpenAI()
59 |
60 |
61 | @pytest.fixture(scope="module")
62 | def instructor_client(openai_client: OpenAI) -> Instructor:
63 | return instructor.from_openai(openai_client)
64 |
65 |
66 | @pytest.fixture(scope="module")
67 | def zenbase_tracer() -> ZenbaseTracer:
68 | return ZenbaseTracer()
69 |
70 |
71 | @pytest.fixture(scope="module")
72 | def single_class_classifier_generator(
73 | instructor_client: Instructor, prompt_definition: str, class_dict: dict[str, str], zenbase_tracer: ZenbaseTracer
74 | ) -> SingleClassClassifierLMFunctionGenerator:
75 | return SingleClassClassifierLMFunctionGenerator(
76 | instructor_client=instructor_client,
77 | prompt=prompt_definition,
78 | class_dict=class_dict,
79 | model="gpt-4o-mini",
80 | zenbase_tracer=zenbase_tracer,
81 | )
82 |
83 |
84 | @pytest.mark.helpers
85 | def test_single_class_classifier_lm_function_generator_initialization(
86 | single_class_classifier_generator: SingleClassClassifierLMFunctionGenerator,
87 | ):
88 | assert single_class_classifier_generator is not None
89 | assert single_class_classifier_generator.instructor_client is not None
90 | assert single_class_classifier_generator.prompt is not None
91 | assert single_class_classifier_generator.class_dict is not None
92 | assert single_class_classifier_generator.model is not None
93 | assert single_class_classifier_generator.zenbase_tracer is not None
94 |
95 | # Check generated class enum and prediction class
96 | assert single_class_classifier_generator.class_enum is not None
97 | assert issubclass(single_class_classifier_generator.prediction_class, BaseModel)
98 |
99 |
100 | @pytest.mark.helpers
101 | def test_single_class_classifier_lm_function_generator_prediction(
102 | single_class_classifier_generator: SingleClassClassifierLMFunctionGenerator, sample_news_article
103 | ):
104 | result = single_class_classifier_generator.generate()(sample_news_article)
105 |
106 | assert result.class_label.name == "Automobiles"
107 | assert single_class_classifier_generator.zenbase_tracer.all_traces is not None
108 |
109 |
110 | @pytest.mark.helpers
111 | def test_single_class_classifier_lm_function_generator_with_missing_data(
112 | single_class_classifier_generator: SingleClassClassifierLMFunctionGenerator,
113 | ):
114 | faulty_email = {"subject": "Meeting Reminder"}
115 | result = single_class_classifier_generator.generate()(faulty_email)
116 | assert result.class_label.name in {
117 | "Automobiles",
118 | "Computers",
119 | "Science",
120 | "Politics",
121 | "Religion",
122 | "For Sale",
123 | "Sports",
124 | "Other",
125 | }
126 |
127 |
128 | def convert_to_human_readable(category: str) -> str:
129 | human_readable_map = {
130 | "rec.autos": "Automobiles",
131 | "comp.sys.mac.hardware": "Computers",
132 | "comp.graphics": "Computers",
133 | "sci.space": "Science",
134 | "talk.politics.guns": "Politics",
135 | "sci.med": "Science",
136 | "comp.sys.ibm.pc.hardware": "Computers",
137 | "comp.os.ms-windows.misc": "Computers",
138 | "rec.motorcycles": "Automobiles",
139 | "talk.religion.misc": "Religion",
140 | "misc.forsale": "For Sale",
141 | "alt.atheism": "Religion",
142 | "sci.electronics": "Computers",
143 | "comp.windows.x": "Computers",
144 | "rec.sport.hockey": "Sports",
145 | "rec.sport.baseball": "Sports",
146 | "soc.religion.christian": "Religion",
147 | "talk.politics.mideast": "Politics",
148 | "talk.politics.misc": "Politics",
149 | "sci.crypt": "Computers",
150 | }
151 | return human_readable_map.get(category)
152 |
153 |
154 | @pytest.fixture(scope="module")
155 | def get_balanced_dataset():
156 | # Load the dataset
157 | split = "train"
158 | train_size = TRAINSET_SIZE
159 | validation_size = VALIDATIONSET_SIZE
160 | test_size = TESTSET_SIZE
161 | dataset = datasets.load_dataset("SetFit/20_newsgroups", split=split)
162 |
163 | # Convert to pandas DataFrame for easier manipulation
164 | df = pd.DataFrame(dataset)
165 |
166 | # Convert text labels to human-readable labels
167 | df["human_readable_label"] = df["label_text"].apply(convert_to_human_readable)
168 |
169 | # Find text labels
170 | text_labels = df["human_readable_label"].unique()
171 |
172 | # Determine the number of labels
173 | num_labels = len(text_labels)
174 |
175 | # Calculate the number of samples per label for each subset
176 | train_samples_per_label = train_size // num_labels
177 | validation_samples_per_label = validation_size // num_labels
178 | test_samples_per_label = test_size // num_labels
179 |
180 | # Create empty DataFrames for train, validation, and test sets
181 | train_set = pd.DataFrame()
182 | validation_set = pd.DataFrame()
183 | test_set = pd.DataFrame()
184 |
185 | # Split sequentially without shuffling
186 | for label in text_labels:
187 | label_df = df[df["human_readable_label"] == label]
188 |
189 | # Ensure there's enough data
190 | if len(label_df) < (train_samples_per_label + validation_samples_per_label + test_samples_per_label):
191 | raise ValueError(f"Not enough data for label {label}")
192 |
193 | # Split the label-specific DataFrame
194 | label_train = label_df.iloc[:train_samples_per_label]
195 | label_validation = label_df.iloc[
196 | train_samples_per_label : train_samples_per_label + validation_samples_per_label
197 | ]
198 | label_test = label_df.iloc[
199 | train_samples_per_label + validation_samples_per_label : train_samples_per_label
200 | + validation_samples_per_label
201 | + test_samples_per_label
202 | ]
203 |
204 | # Append to the respective sets
205 | train_set = pd.concat([train_set, label_train])
206 | validation_set = pd.concat([validation_set, label_validation])
207 | test_set = pd.concat([test_set, label_test])
208 |
209 | # Create dataset with examples
210 | train_set = create_dataset_with_examples(train_set.to_dict("records"))
211 | validation_set = create_dataset_with_examples(validation_set.to_dict("records"))
212 | test_set = create_dataset_with_examples(test_set.to_dict("records"))
213 |
214 | return train_set, validation_set, test_set
215 |
216 |
217 | @pytest.fixture(scope="module")
218 | def single_class_classifier(
219 | instructor_client: Instructor,
220 | prompt_definition: str,
221 | class_dict: dict[str, str],
222 | zenbase_tracer: ZenbaseTracer,
223 | get_balanced_dataset,
224 | ) -> SingleClassClassifier:
225 | train_set, validation_set, test_set = get_balanced_dataset
226 | return SingleClassClassifier(
227 | instructor_client=instructor_client,
228 | prompt=prompt_definition,
229 | class_dict=class_dict,
230 | model="gpt-4o-mini",
231 | zenbase_tracer=zenbase_tracer,
232 | training_set=train_set,
233 | validation_set=validation_set,
234 | test_set=test_set,
235 | )
236 |
237 |
238 | @pytest.mark.helpers
239 | def test_single_class_classifier_perform(single_class_classifier: SingleClassClassifier, sample_news_article):
240 | result = single_class_classifier.optimize()
241 | assert all(
242 | [result.best_function, result.candidate_results, result.best_candidate_result]
243 | ), "Assertions failed for result properties"
244 | traces = single_class_classifier.zenbase_tracer.all_traces
245 | assert traces, "No traces found"
246 | assert result is not None, "Result should not be None"
247 | assert hasattr(result, "best_function"), "Result should have a best_function attribute"
248 | best_fn = result.best_function
249 | assert callable(best_fn), "best_function should be callable"
250 | output = best_fn(sample_news_article)
251 | assert output is not None, "output should not be None"
252 |
--------------------------------------------------------------------------------
/py/tests/predefined/test_single_class_classifier_syntethic_data.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import io
3 | from unittest.mock import Mock, patch
4 |
5 | import instructor
6 | import pytest
7 | from openai import OpenAI
8 |
9 | from zenbase.predefined.syntethic_data.single_class_classifier import (
10 | SingleClassClassifierSyntheticDataExample,
11 | SingleClassClassifierSyntheticDataGenerator,
12 | )
13 |
14 |
15 | @pytest.fixture(scope="module")
16 | def prompt_definition() -> str:
17 | return """Your task is to accurately categorize each incoming news article into one of the given categories based
18 | on its title and content."""
19 |
20 |
21 | @pytest.fixture(scope="module")
22 | def class_dict() -> dict[str, str]:
23 | return {
24 | "Automobiles": "Discussions and news about automobiles, including car maintenance, driving experiences, "
25 | "and the latest automotive technology.",
26 | "Computers": "Topics related to computer hardware, software, graphics, cryptography, and operating systems, "
27 | "including troubleshooting and advancements.",
28 | "Science": "News and discussions about scientific topics including space exploration, medicine, and "
29 | "electronics.",
30 | "Politics": "Debates and news about political topics, including gun control, Middle Eastern politics,"
31 | " and miscellaneous political discussions.",
32 | }
33 |
34 |
35 | @pytest.fixture(scope="module")
36 | def openai_client() -> OpenAI:
37 | return OpenAI()
38 |
39 |
40 | @pytest.fixture(scope="module")
41 | def instructor_client(openai_client: OpenAI) -> instructor.Instructor:
42 | return instructor.from_openai(openai_client)
43 |
44 |
45 | @pytest.fixture(scope="module")
46 | def synthetic_data_generator(
47 | instructor_client: instructor.Instructor,
48 | prompt_definition: str,
49 | class_dict: dict[str, str],
50 | ) -> SingleClassClassifierSyntheticDataGenerator:
51 | return SingleClassClassifierSyntheticDataGenerator(
52 | instructor_client=instructor_client,
53 | prompt=prompt_definition,
54 | class_dict=class_dict,
55 | model="gpt-4o-mini",
56 | )
57 |
58 |
59 | @pytest.mark.helpers
60 | def test_synthetic_data_generator_initialization(
61 | synthetic_data_generator: SingleClassClassifierSyntheticDataGenerator,
62 | ):
63 | assert synthetic_data_generator is not None
64 | assert synthetic_data_generator.instructor_client is not None
65 | assert synthetic_data_generator.prompt is not None
66 | assert synthetic_data_generator.class_dict is not None
67 | assert synthetic_data_generator.model is not None
68 |
69 |
70 | @pytest.mark.helpers
71 | def test_generate_examples_for_category(
72 | synthetic_data_generator: SingleClassClassifierSyntheticDataGenerator,
73 | ):
74 | category = "Automobiles"
75 | description = synthetic_data_generator.class_dict[category]
76 | num_examples = 5
77 | examples = synthetic_data_generator.generate_examples_for_category(category, description, num_examples)
78 |
79 | assert len(examples) == num_examples
80 | for example in examples:
81 | assert example.inputs is not None
82 | assert example.outputs == category
83 |
84 |
85 | @pytest.mark.helpers
86 | def test_generate_examples(
87 | synthetic_data_generator: SingleClassClassifierSyntheticDataGenerator,
88 | ):
89 | examples_per_category = 3
90 | all_examples = synthetic_data_generator.generate_examples(examples_per_category)
91 |
92 | assert len(all_examples) == examples_per_category * len(synthetic_data_generator.class_dict)
93 | for example in all_examples:
94 | assert example.inputs is not None
95 | assert example.outputs in synthetic_data_generator.class_dict.keys()
96 |
97 |
98 | @pytest.mark.helpers
99 | def test_generate_csv(
100 | synthetic_data_generator: SingleClassClassifierSyntheticDataGenerator,
101 | ):
102 | examples_per_category = 2
103 | csv_content = synthetic_data_generator.generate_csv(examples_per_category)
104 |
105 | csv_reader = csv.DictReader(io.StringIO(csv_content))
106 | rows = list(csv_reader)
107 |
108 | assert len(rows) == examples_per_category * len(synthetic_data_generator.class_dict)
109 | for row in rows:
110 | assert "inputs" in row
111 | assert "outputs" in row
112 | assert row["outputs"] in synthetic_data_generator.class_dict.keys()
113 |
114 |
115 | @pytest.mark.helpers
116 | def test_save_csv(
117 | synthetic_data_generator: SingleClassClassifierSyntheticDataGenerator,
118 | tmp_path,
119 | ):
120 | examples_per_category = 2
121 | file_path = tmp_path / "test_synthetic_data.csv"
122 | synthetic_data_generator.save_csv(str(file_path), examples_per_category)
123 |
124 | assert file_path.exists()
125 |
126 | with open(file_path, "r", newline="", encoding="utf-8") as f:
127 | csv_reader = csv.DictReader(f)
128 | rows = list(csv_reader)
129 |
130 | assert len(rows) == examples_per_category * len(synthetic_data_generator.class_dict)
131 | for row in rows:
132 | assert "inputs" in row
133 | assert "outputs" in row
134 | assert row["outputs"] in synthetic_data_generator.class_dict.keys()
135 |
136 |
137 | @pytest.fixture
138 | def mock_openai_client():
139 | mock_client = Mock(spec=OpenAI)
140 | mock_client.chat = Mock()
141 | mock_client.chat.completions = Mock()
142 | mock_client.chat.completions.create = Mock()
143 | return mock_client
144 |
145 |
146 | @pytest.fixture
147 | def mock_instructor_client(mock_openai_client):
148 | return instructor.from_openai(mock_openai_client)
149 |
150 |
151 | @pytest.fixture
152 | def mock_generator(mock_instructor_client, class_dict):
153 | return SingleClassClassifierSyntheticDataGenerator(
154 | instructor_client=mock_instructor_client,
155 | prompt="Classify the given text into one of the categories",
156 | class_dict=class_dict,
157 | model="gpt-4o-mini",
158 | )
159 |
160 |
161 | def mock_generate_examples(
162 | category: str, description: str, num: int
163 | ) -> list[SingleClassClassifierSyntheticDataExample]:
164 | return [
165 | SingleClassClassifierSyntheticDataExample(
166 | inputs=f"Sample text for {category} {i}: {description[:20]}...", outputs=category
167 | )
168 | for i in range(num)
169 | ]
170 |
171 |
172 | def test_generate_csv_mock(mock_generator):
173 | examples_per_category = 2
174 |
175 | with patch.object(mock_generator, "generate_examples_for_category", side_effect=mock_generate_examples):
176 | csv_content = mock_generator.generate_csv(examples_per_category)
177 |
178 | # Parse the CSV content
179 | reader = csv.DictReader(io.StringIO(csv_content))
180 | rows = list(reader)
181 |
182 | assert len(rows) == len(mock_generator.class_dict) * examples_per_category
183 | for row in rows:
184 | assert "inputs" in row
185 | assert "outputs" in row
186 | assert row["outputs"] in mock_generator.class_dict
187 |
188 |
189 | def test_save_csv_mock(mock_generator, tmp_path):
190 | examples_per_category = 2
191 | filename = tmp_path / "test_output.csv"
192 |
193 | with patch.object(mock_generator, "generate_examples_for_category", side_effect=mock_generate_examples):
194 | mock_generator.save_csv(str(filename), examples_per_category)
195 |
196 | assert filename.exists()
197 |
198 | with open(filename, "r", newline="", encoding="utf-8") as f:
199 | reader = csv.DictReader(f)
200 | rows = list(reader)
201 |
202 | assert len(rows) == len(mock_generator.class_dict) * examples_per_category
203 | for row in rows:
204 | assert "inputs" in row
205 | assert "outputs" in row
206 | assert row["outputs"] in mock_generator.class_dict
207 |
208 |
209 | def test_integration_mock(mock_generator):
210 | examples_per_category = 1
211 |
212 | def mock_create(**kwargs):
213 | category = kwargs["messages"][1]["content"].split("'")[1]
214 | description = next(desc for cat, desc in mock_generator.class_dict.items() if cat == category)
215 | return mock_generate_examples(category, description, examples_per_category)
216 |
217 | with patch.object(mock_generator.instructor_client.chat.completions, "create", side_effect=mock_create):
218 | csv_content = mock_generator.generate_csv(examples_per_category)
219 |
220 | reader = csv.DictReader(io.StringIO(csv_content))
221 | rows = list(reader)
222 |
223 | assert len(rows) == len(mock_generator.class_dict) * examples_per_category
224 | for row in rows:
225 | assert row["outputs"] in mock_generator.class_dict
226 | assert row["inputs"].startswith(f"Sample text for {row['outputs']}")
227 |
--------------------------------------------------------------------------------
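A minimal glue sketch, assuming `generator` is a `SingleClassClassifierSyntheticDataGenerator` like the fixtures above, for converting its CSV output into the `{"inputs": ..., "outputs": ...}` records that the `SingleClassClassifier` fixtures in the previous file consume as `training_set`:

    import csv
    import io

    # Parse the generated CSV and reshape each row into a training example
    # (the generator variable and the example count are assumptions).
    rows = csv.DictReader(io.StringIO(generator.generate_csv(examples_per_category=5)))
    training_set = [{"inputs": row["inputs"], "outputs": row["outputs"]} for row in rows]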
/py/tests/sciprts/clean_up_langsmith.py:
--------------------------------------------------------------------------------
1 | from dotenv import load_dotenv
2 | from langsmith import Client
3 |
4 |
5 | def remove_all_datasets():
6 | # Initialize LangSmith client
7 | client = Client()
8 |
9 | # Fetch all datasets
10 | datasets = list(client.list_datasets())
11 |
12 | if not datasets:
13 | print("No datasets found in LangSmith.")
14 | return
15 |
16 | # Confirm with the user
17 | confirm = input(f"Are you sure you want to delete all {len(datasets)} datasets? (yes/no): ")
18 | if confirm.lower() != "yes":
19 | print("Operation cancelled.")
20 | return
21 |
22 | # Delete each dataset
23 | for dataset in datasets:
24 | try:
25 | client.delete_dataset(dataset_id=dataset.id)
26 | print(f"Deleted dataset: {dataset.name} (ID: {dataset.id})")
27 | except Exception as e:
28 | print(f"Error deleting dataset {dataset.name} (ID: {dataset.id}): {str(e)}")
29 |
30 | print("All datasets have been deleted.")
31 |
32 |
33 | # Run the function
34 | if __name__ == "__main__":
35 | load_dotenv("../../.env.test")
36 | remove_all_datasets()
37 |
--------------------------------------------------------------------------------
/py/tests/sciprts/convert_notebooks.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import nbformat
4 |
5 |
6 | def fix_notebooks_in_directory(directory):
7 | for root, dirs, files in os.walk(directory):
8 | for file in files:
9 | if file.endswith(".ipynb"):
10 | notebook_path = os.path.join(root, file)
11 | print(f"Fixing notebook: {notebook_path}")
12 |
13 | with open(notebook_path, "r", encoding="utf-8") as f:
14 | notebook = nbformat.read(f, as_version=4)
15 |
16 | for cell in notebook["cells"]:
17 | if cell["cell_type"] == "code" and "execution_count" not in cell:
18 | cell["execution_count"] = None
19 |
20 | with open(notebook_path, "w", encoding="utf-8") as f:
21 | nbformat.write(notebook, f)
22 |
23 | print(f"Fixed notebook: {notebook_path}")
24 |
25 |
26 | fix_notebooks_in_directory("../../cookbooks")
27 |
--------------------------------------------------------------------------------
/py/tests/test_types.py:
--------------------------------------------------------------------------------
1 | from zenbase.types import LMDemo, deflm
2 |
3 |
4 | def test_demo_eq():
5 | demoset = [
6 | LMDemo(inputs={}, outputs={"output": "a"}),
7 | LMDemo(inputs={}, outputs={"output": "b"}),
8 | ]
9 |
10 | # Structural inequality
11 | assert demoset[0] != demoset[1]
12 | # Structural equality
13 | assert demoset[0] == LMDemo(inputs={}, outputs={"output": "a"})
14 |
15 |
16 | def test_lm_function_refine():
17 | fn = deflm(lambda r: r.inputs)
18 | assert fn != fn.clean_and_duplicate()
19 |
--------------------------------------------------------------------------------