├── .github └── workflows │ ├── pypi-release.yml │ └── tests.yml ├── .gitignore ├── .idea ├── .gitignore ├── inspectionProfiles │ └── profiles_settings.xml ├── lib.iml ├── misc.xml ├── modules.xml └── vcs.xml ├── .vscode └── settings.json ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── litellm ├── .env.sample ├── README.md ├── config.yaml └── docker-compose.yml └── py ├── .env.sample ├── .gitignore ├── .pre-commit-config.yaml ├── .python-version ├── README.md ├── cookbooks ├── bootstrap_few_shot │ ├── arize.ipynb │ ├── langfuse.ipynb │ ├── langsmith.ipynb │ ├── lunary.ipynb │ └── parea.ipynb ├── labeled_few_shot │ ├── arize.ipynb │ ├── json.ipynb │ ├── langfuse.ipynb │ ├── langsmith.ipynb │ ├── lunary.ipynb │ └── parea.ipynb └── predefined │ ├── single_class_clasifier_trained_with_syn_data.ipynb │ ├── single_class_classifier.ipynb │ └── single_class_classifier_syn_data.ipynb ├── pyproject.toml ├── requirements-dev.lock ├── requirements.lock ├── src └── zenbase │ ├── __init__.py │ ├── adaptors │ ├── __init__.py │ ├── arize.py │ ├── arize │ │ ├── __init__.py │ │ ├── adaptor.py │ │ ├── dataset_helper.py │ │ └── evaluation_helper.py │ ├── base │ │ ├── adaptor.py │ │ ├── dataset_helper.py │ │ └── evaluation_helper.py │ ├── braintrust.py │ ├── json │ │ ├── adaptor.py │ │ ├── dataset_helper.py │ │ └── evaluation_helper.py │ ├── langchain │ │ ├── __init__.py │ │ ├── adaptor.py │ │ ├── dataset_helper.py │ │ └── evaluation_helper.py │ ├── langfuse_helper │ │ ├── __init__.py │ │ ├── adaptor.py │ │ ├── dataset_helper.py │ │ └── evaluation_helper.py │ ├── lunary │ │ ├── __init__.py │ │ ├── adaptor.py │ │ ├── dataset_helper.py │ │ └── evaluation_helper.py │ └── parea │ │ ├── __init__.py │ │ ├── adaptor.py │ │ ├── dataset_helper.py │ │ └── evaluation_helper.py │ ├── core │ └── managers.py │ ├── optim │ ├── base.py │ └── metric │ │ ├── bootstrap_few_shot.py │ │ ├── labeled_few_shot.py │ │ └── types.py │ ├── predefined │ ├── __init__.py │ ├── base │ │ ├── function_generator.py │ │ └── optimizer.py │ ├── generic_lm_function │ │ └── optimizer.py │ ├── single_class_classifier │ │ ├── __init__.py │ │ ├── classifier.py │ │ └── function_generator.py │ └── syntethic_data │ │ └── single_class_classifier.py │ ├── settings.py │ ├── types.py │ └── utils.py └── tests ├── adaptors ├── bootstrap_few_shot_optimizer_args.zenbase ├── parea_bootstrap_few_shot.zenbase ├── test_arize.py ├── test_braintrust.py ├── test_langchain.py ├── test_langfuse.py ├── test_lunary.py └── test_parea.py ├── conftest.py ├── core └── managers.py ├── optim └── metric │ ├── test_bootstrap_few_shot.py │ └── test_labeled_few_shot.py ├── predefined ├── test_generic_lm_function_optimizer.py ├── test_single_class_classifier.py └── test_single_class_classifier_syntethic_data.py ├── sciprts ├── clean_up_langsmith.py └── convert_notebooks.py └── test_types.py /.github/workflows/pypi-release.yml: -------------------------------------------------------------------------------- 1 | name: Publish to PyPI 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | 8 | jobs: 9 | deploy: 10 | runs-on: ubuntu-latest 11 | strategy: 12 | matrix: 13 | python-version: ["3.10"] 14 | 15 | steps: 16 | - uses: actions/checkout@v4 17 | - uses: eifinger/setup-rye@v3 18 | id: setup-rye 19 | with: 20 | enable-cache: true 21 | working-directory: py 22 | cache-prefix: ${{ matrix.python-version }} 23 | - name: Pin python-version ${{ matrix.python-version }} 24 | working-directory: py 25 | run: rye pin ${{ matrix.python-version }} 26 | - name: Update Rye 27 | run: rye self 
update 28 | - name: Install dependencies 29 | working-directory: py 30 | run: rye sync --no-lock 31 | - name: Run Tests 32 | working-directory: py 33 | run: rye run pytest -v 34 | - name: Build package 35 | working-directory: py 36 | run: rye build 37 | - name: Publish to PyPI 38 | working-directory: py 39 | env: 40 | PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }} 41 | run: rye publish --token $PYPI_TOKEN --yes --skip-existing 42 | 43 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request_target: 8 | branches: 9 | - main 10 | - rc* 11 | 12 | jobs: 13 | build: 14 | runs-on: ubuntu-latest 15 | strategy: 16 | matrix: 17 | python-version: ["3.10", "3.11", "3.12"] 18 | 19 | steps: 20 | - uses: actions/checkout@v4 21 | - uses: eifinger/setup-rye@v3 22 | id: setup-rye 23 | with: 24 | enable-cache: true 25 | working-directory: py 26 | cache-prefix: ${{ matrix.python-version }} 27 | - name: Pin python-version ${{ matrix.python-version }} 28 | working-directory: py 29 | run: rye pin ${{ matrix.python-version }} 30 | - name: Install dependencies 31 | working-directory: py 32 | run: rye sync --no-lock 33 | - name: Check py directory contents 34 | run: ls -la py 35 | - name: Run Tests 36 | working-directory: py 37 | run: rye run pytest -v -n auto 38 | - name: Run Integration Tests 39 | working-directory: py 40 | run: rye run pytest -v -m helpers -n auto 41 | env: 42 | LANGCHAIN_API_KEY: ${{ secrets.LANGCHAIN_API_KEY }} 43 | LANGCHAIN_TRACING_V2: ${{ secrets.LANGCHAIN_TRACING_V2 }} 44 | LANGFUSE_HOST: ${{ secrets.LANGFUSE_HOST }} 45 | LANGFUSE_PUBLIC_KEY: ${{ secrets.LANGFUSE_PUBLIC_KEY }} 46 | LANGFUSE_SECRET_KEY: ${{ secrets.LANGFUSE_SECRET_KEY }} 47 | LANGSMITH_TEST_TRACKING: ${{ secrets.LANGSMITH_TEST_TRACKING }} 48 | LUNARY_PUBLIC_KEY: ${{ secrets.LUNARY_PUBLIC_KEY }} 49 | OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} 50 | PAREA_API_KEY: ${{ secrets.PAREA_API_KEY }} 51 | BRAINTRUST_API_KEY: ${{ secrets.BRAINTRUST_API_KEY }} 52 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .env* 2 | !.env.sample 3 | .DS_Store 4 | 5 | .idea/ 6 | # User-specific stuff 7 | .idea/**/workspace.xml 8 | .idea/**/tasks.xml 9 | .idea/**/usage.statistics.xml 10 | .idea/**/dictionaries 11 | .idea/**/shelf 12 | 13 | litellm/redis-data -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Editor-based HTTP Client requests 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/lib.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 15 | 16 | 18 | 19 | -------------------------------------------------------------------------------- /.idea/misc.xml: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "files.exclude": { 3 | "**/.git": true, 4 | "**/.svn": true, 5 | "**/.hg": true, 6 | "**/CVS": true, 7 | "**/.DS_Store": true, 8 | "**/Thumbs.db": true, 9 | "py/.venv/": true, 10 | "py/**/__pycache__": true, 11 | "py/.pytest_cache/": true 12 | }, 13 | "python.defaultInterpreterPath": "py/.venv/bin/python", 14 | "python.analysis.extraPaths": [ 15 | "./py/src" 16 | ], 17 | "python.testing.pytestArgs": [ 18 | "py" 19 | ], 20 | "python.testing.unittestEnabled": false, 21 | "python.testing.pytestEnabled": true 22 | } 23 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contribution Guidelines 2 | 3 | Thank you for your interest in contributing to our project! Any and all contributions are welcome. This document provides some basic guidelines for contributing. 4 | 5 | ## Bug Reports 6 | 7 | If you find a bug in the code, you can help us by submitting an issue to our GitHub Repository. Even better, you can submit a Pull Request with a fix. 8 | 9 | ## Feature Requests 10 | 11 | If you'd like to make a feature request, you can do so by submitting an issue to our GitHub Repository. You can also submit a Pull Request to help implement that feature. 12 | 13 | ## Pull Requests 14 | 15 | Here are some guidelines for submitting a pull request: 16 | 17 | - Fork the repository and clone it locally. Connect your local repository to the original `upstream` repository by adding it as a remote. Pull in changes from `upstream` often so that you stay up to date with the latest code. 18 | - Create a branch for your edits. 19 | - Be clear about what problem is occurring and how someone can reproduce that problem or why your feature will help. Then be equally as clear about the steps you took to make your changes. 20 | - It's always best to consult the existing issues before creating a new one. 21 | 22 | ## Questions? 23 | 24 | If you have any questions, please feel free to contact us. 
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 zenbase-ai 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | [badge: MIT License] 4 | [badge: Y Combinator S24] 5 | [badge: zenbase Python package on PyPi] 6 | 7 |
8 | 9 | # Zenbase Core — Prompt engineering, automated. 10 | 11 | Zenbase Core is the library for programming — not prompting — AI in production. It is a spin-out of [Stanford NLP's DSPy](https://github.com/stanfordnlp/dspy), started by key contributors to that project. DSPy is an excellent framework for R&D; Zenbase strives to be its counterpart for software engineering. 12 | 13 | ## Key Points 14 | 15 | - **DSPy** is optimized for research and development, providing tools designed for deep exploration and optimization of AI systems: you run experiments and report the results. 16 | - **Zenbase** focuses on adapting these research advancements to practical software engineering needs: its optimizers can be integrated into existing systems, and we help you deploy automatic prompt optimization in production. 17 | 18 | ## Quick Start with Python 19 | 20 | ```bash 21 | pip install zenbase 22 | ``` 23 | 24 | For more information, take a look at the [Zenbase Core Python README](./py/README.md). 25 | 26 | ## License 27 | 28 | This project is licensed under the MIT License. See the [LICENSE](./LICENSE) file for more information. 29 | 30 | 31 | ## Contribution 32 | 33 | Thank you for your interest in contributing to our project! Any and all contributions are welcome. See [CONTRIBUTING](./CONTRIBUTING.md) for more information. 34 | -------------------------------------------------------------------------------- /litellm/.env.sample: -------------------------------------------------------------------------------- 1 | REDIS_HOST=redis 2 | REDIS_PORT=6379 3 | REDIS_SSL=False 4 | OPENAI_API_KEY="YOUR_API_KEY" -------------------------------------------------------------------------------- /litellm/README.md: -------------------------------------------------------------------------------- 1 | ## LiteLLM Cache 2 | 3 | LiteLLM Cache is a proxy server designed to cache your LLM requests, helping to reduce costs and improve efficiency. 4 | 5 | ### Requirements 6 | - Docker Compose 7 | - Docker 8 | 9 | ### Setup Instructions 10 | 11 | 1. **Configure Settings:** 12 | - Open `./config.yaml` and update the configuration to suit your requirements. For more information, visit the [LiteLLM Documentation](https://litellm.vercel.app/). 13 | 14 | 2. **Prepare Environment Variables:** 15 | - Create a `.env` file from the `.env.sample` file. Adjust the values in `.env` to match your `config.yaml` settings. 16 | 17 | 3. **Start the Docker Containers:** 18 | ```bash 19 | docker-compose up -d 20 | ``` 21 | 22 | 4. **Update Your LLM Server URL:** 23 | - Point your application's LLM base URL at `http://0.0.0.0:4000`. 24 | 25 | For example, using the OpenAI Python SDK: 26 | ```python 27 | from openai import OpenAI 28 | 29 | llm = OpenAI( 30 | base_url='http://0.0.0.0:4000' 31 | ) 32 | ``` 33 | 34 | With these steps, your LLM requests will be routed through the LiteLLM Cache proxy server, optimizing performance and reducing costs.
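To sanity-check the setup, you can send the same request twice through the proxy and compare latencies. Assuming caching is enabled for your deployment in `config.yaml`, the second call should be served from the Redis cache. This is only a sketch: the model name and placeholder API key below are assumptions, so match them to your own configuration.

```python
import time

from openai import OpenAI

# Point the SDK at the LiteLLM proxy instead of api.openai.com.
# The api_key is a placeholder; use your proxy/master key if you have configured one.
client = OpenAI(base_url="http://0.0.0.0:4000", api_key="sk-placeholder")


def ask(question: str) -> str:
    response = client.chat.completions.create(
        model="gpt-3.5-turbo",  # must match a model_name entry in config.yaml
        messages=[{"role": "user", "content": question}],
    )
    return response.choices[0].message.content


for attempt in (1, 2):
    start = time.perf_counter()
    answer = ask("In one sentence, what does a caching proxy do?")
    print(f"Attempt {attempt}: {time.perf_counter() - start:.2f}s -> {answer}")
```

If both calls return the same answer and the second one is noticeably faster, requests are being cached as expected.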
-------------------------------------------------------------------------------- /litellm/config.yaml: -------------------------------------------------------------------------------- 1 | model_list: 2 | - model_name: gpt-3.5-turbo 3 | litellm_params: 4 | model: openai/gpt-3.5-turbo # The `openai/` prefix will call openai.chat.completions.create 5 | api_key: os.environ/OPENAI_API_KEY # The `os.environ/` prefix will call os.environ.get 6 | 7 | router_settings: 8 | redis_host: redis 9 | redis_port: 6379 -------------------------------------------------------------------------------- /litellm/docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3' 2 | 3 | services: 4 | litellm: 5 | image: ghcr.io/berriai/litellm:main-latest 6 | ports: 7 | - "4000:4000" 8 | volumes: 9 | - ./config.yaml:/app/config.yaml # Mount the local configuration file 10 | command: [ "--config", "/app/config.yaml", "--port", "4000", "--num_workers", "8", "--detailed_debug"] 11 | env_file: 12 | - .env 13 | 14 | redis: 15 | image: redis:alpine 16 | ports: 17 | - "6379:6379" 18 | volumes: 19 | - ./redis-data:/data 20 | command: redis-server --appendonly yes 21 | restart: always 22 | -------------------------------------------------------------------------------- /py/.env.sample: -------------------------------------------------------------------------------- 1 | LANGCHAIN_API_KEY="YOUR-LANGCHAIN-API-KEY" 2 | LANGCHAIN_TRACING_V2=true 3 | LANGFUSE_HOST="https://us.cloud.langfuse.com" 4 | LANGFUSE_PUBLIC_KEY="YOUR-LANGFUSE-PUBLIC-KEY" 5 | LANGFUSE_SECRET_KEY="YOUR-LANGFUSE-SECRET-KEY" 6 | LANGSMITH_TEST_TRACKING=tests/cache/ 7 | LUNARY_PUBLIC_KEY="YOUR-LUNARY-PUBLIC-KEY" 8 | OPENAI_API_KEY="YOUR-OPENAI-API" 9 | PAREA_API_KEY="YOUR-PAREA-API-KEY" 10 | BRAINTRUST_API_KEY="YOUR-BRAINTRUST-API-KEY" -------------------------------------------------------------------------------- /py/.gitignore: -------------------------------------------------------------------------------- 1 | # python generated files 2 | __pycache__/ 3 | *.py[oc] 4 | build/ 5 | dist/ 6 | wheels/ 7 | *.egg-info 8 | 9 | # venv 10 | .venv 11 | .idea/ 12 | -------------------------------------------------------------------------------- /py/.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: local 3 | hooks: 4 | - id: rye-lint 5 | name: rye-lint 6 | description: "Lint Python via 'rye lint'" 7 | entry: sh -c 'cd py && rye lint --fix' 8 | language: system 9 | types_or: [python, pyi] 10 | files: ^py/ 11 | args: [] 12 | require_serial: true 13 | additional_dependencies: [] 14 | minimum_pre_commit_version: "2.9.2" 15 | 16 | - id: rye-format 17 | name: rye-format 18 | description: "Format Python via 'rye fmt'" 19 | entry: sh -c 'cd py && rye fmt' 20 | language: system 21 | types_or: [python, pyi] 22 | files: ^py/ 23 | args: [] 24 | require_serial: true 25 | additional_dependencies: [] 26 | minimum_pre_commit_version: "2.9.2" 27 | 28 | - id: rye-test 29 | name: rye-test 30 | description: "Test Python via 'rye test'" 31 | entry: sh -c 'cd py && rye test' 32 | language: system 33 | types_or: [python, pyi] 34 | files: ^py/ 35 | args: [] 36 | pass_filenames: false 37 | require_serial: true 38 | additional_dependencies: [] 39 | minimum_pre_commit_version: "2.9.2" 40 | -------------------------------------------------------------------------------- /py/.python-version: 
-------------------------------------------------------------------------------- 1 | 3.10.13 2 | -------------------------------------------------------------------------------- /py/README.md: -------------------------------------------------------------------------------- 1 | # Zenbase Python SDK 2 | 3 | ## Installation 4 | 5 | zenbase Python package on PyPi 6 | 7 | Zenbase requires Python ≥3.10. You can install it using your favorite package manager: 8 | 9 | ```bash 10 | pip install zenbase 11 | poetry add zenbase 12 | rye add zenbase 13 | ``` 14 | 15 | ## Usage 16 | 17 | Zenbase is designed to require minimal changes to your existing codebase and integrate seamlessly with your existing eval/observability platforms. It works with any AI SDK (OpenAI, Anthropic, Cohere, Langchain, etc.). 18 | 19 | 20 | ### Labeled Few-Shot Learning Cookbooks: 21 | 22 | LabeledFewShot will be useful for tasks that are just one layer of prompts. 23 | 24 | | Cookbook | Run in Colab | 25 | |---------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| 26 | | [without_integration.ipynb](cookbooks/labeled_few_shot/json.ipynb) | [](https://colab.research.google.com/github/zenbase-ai/core/blob/main/py/cookbooks/labeled_few_shot/json.ipynb) | 27 | | [langsmith.ipynb](cookbooks/labeled_few_shot/langsmith.ipynb) | [](https://colab.research.google.com/github/zenbase-ai/core/blob/main/py/cookbooks/labeled_few_shot/langsmith.ipynb) | 28 | | [arize.ipynb](cookbooks/labeled_few_shot/arize.ipynb) | [](https://colab.research.google.com/github/zenbase-ai/core/blob/main/py/cookbooks/labeled_few_shot/arize.ipynb) | 29 | | [langfuse.ipynb](cookbooks/labeled_few_shot/langfuse.ipynb) | [](https://colab.research.google.com/github/zenbase-ai/core/blob/main/py/cookbooks/labeled_few_shot/langfuse.ipynb) | 30 | | [parea.ipynb](cookbooks/labeled_few_shot/parea.ipynb) | [](https://colab.research.google.com/github/zenbase-ai/core/blob/main/py/cookbooks/labeled_few_shot/parea.ipynb) | 31 | | [lunary.ipynb](cookbooks/labeled_few_shot/lunary.ipynb) | [](https://colab.research.google.com/github/zenbase-ai/core/blob/main/py/cookbooks/labeled_few_shot/lunary.ipynb) | 32 | 33 | ### Bootstrap Few-Shot Learning Cookbooks: 34 | 35 | BootstrapFewShot will be useful for tasks that are multiple layers of prompts. 
36 | 37 | | Cookbook | Run in Colab | 38 | |-----------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| 39 | | [langsmith.ipynb](cookbooks/bootstrap_few_shot/langsmith.ipynb) | [](https://colab.research.google.com/github/zenbase-ai/core/blob/main/py/cookbooks/bootstrap_few_shot/langsmith.ipynb) | 40 | | [arize.ipynb](cookbooks/bootstrap_few_shot/arize.ipynb) | [](https://colab.research.google.com/github/zenbase-ai/core/blob/main/py/cookbooks/bootstrap_few_shot/arize.ipynb) | 41 | | [langfuse.ipynb](cookbooks/bootstrap_few_shot/langfuse.ipynb) | [](https://colab.research.google.com/github/zenbase-ai/core/blob/main/py/cookbooks/bootstrap_few_shot/langfuse.ipynb) | 42 | | [parea.ipynb](cookbooks/bootstrap_few_shot/parea.ipynb) | [](https://colab.research.google.com/github/zenbase-ai/core/blob/main/py/cookbooks/bootstrap_few_shot/parea.ipynb) | 43 | | [lunary.ipynb](cookbooks/bootstrap_few_shot/lunary.ipynb) | [](https://colab.research.google.com/github/zenbase-ai/core/blob/main/py/cookbooks/bootstrap_few_shot/lunary.ipynb) | 44 | 45 | 46 | ### Predefined Cookbooks: 47 | 48 | | Cookbook | Description | Run in Colab | 49 | |-------------------------------------------------------------------------------------|---------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| 50 | | [Single Class Classifier](cookbooks/predefined/single_class_classifier.ipynb) | Basic implementation of a single class classifier | [](https://colab.research.google.com/github/zenbase-ai/core/blob/main/py/cookbooks/predefined/single_class_classifier.ipynb) | 51 | | [Synthetic Data Generation](cookbooks/predefined/single_class_classifier_syn_data.ipynb) | Generate synthetic data for single class classification | [](https://colab.research.google.com/github/zenbase-ai/core/blob/main/py/cookbooks/predefined/single_class_classifier_syn_data.ipynb) | 52 | | [Classifier with Synthetic Data](cookbooks/predefined/single_class_clasifier_trained_with_syn_data.ipynb) | Train and test a single class classifier using synthetic data | [](https://colab.research.google.com/github/zenbase-ai/core/blob/main/py/cookbooks/predefined/single_class_clasifier_trained_with_syn_data.ipynb) | 53 | 54 | 55 | ## Development setup 56 | 57 | This repo uses Python 3.10 and [rye](https://rye.astral.sh/) to manage dependencies. Once you've gotten rye installed, you can install dependencies by running: 58 | 59 | ```bash 60 | rye sync 61 | ``` 62 | 63 | And activate the virtualenv with: 64 | 65 | ```bash 66 | . 
.venv/bin/activate 67 | ``` 68 | 69 | You can run tests with: 70 | 71 | ```bash 72 | rye test # pytest -sv to see prints and verbose output 73 | rye test -- -m helpers # integration tests with helpers 74 | ``` 75 | -------------------------------------------------------------------------------- /py/cookbooks/predefined/single_class_clasifier_trained_with_syn_data.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Single Class Classifier with Synthetic Data Generation\n", 8 | "\n", 9 | "This notebook demonstrates how to use the `SingleClassClassifierSyntheticDataGenerator` to create a synthetic dataset, and then use that dataset to train and test a `SingleClassClassifier`." 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "## Setup and Imports" 17 | ] 18 | }, 19 | { 20 | "metadata": {}, 21 | "cell_type": "markdown", 22 | "source": "### Import the Zenbase Library" 23 | }, 24 | { 25 | "metadata": { 26 | "ExecuteTime": { 27 | "end_time": "2024-07-25T01:06:59.016591Z", 28 | "start_time": "2024-07-25T01:06:59.011483Z" 29 | } 30 | }, 31 | "cell_type": "code", 32 | "source": [ 33 | "import sys\n", 34 | "import subprocess\n", 35 | "\n", 36 | "def install_package(package):\n", 37 | " try:\n", 38 | " subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", package])\n", 39 | " except subprocess.CalledProcessError as e:\n", 40 | " print(f\"Failed to install {package}: {e}\")\n", 41 | " raise\n", 42 | "\n", 43 | "def install_packages(packages):\n", 44 | " for package in packages:\n", 45 | " install_package(package)\n", 46 | "\n", 47 | "try:\n", 48 | " # Check if running in Google Colab\n", 49 | " import google.colab\n", 50 | " IN_COLAB = True\n", 51 | "except ImportError:\n", 52 | " IN_COLAB = False\n", 53 | "\n", 54 | "if IN_COLAB:\n", 55 | " # Install the zenbase package if running in Google Colab\n", 56 | " # install_package('zenbase')\n", 57 | " # Install the zenbse package from a GitHub branch if running in Google Colab\n", 58 | " install_package('git+https://github.com/zenbase-ai/lib.git@main#egg=zenbase&subdirectory=py')\n", 59 | "\n", 60 | " # List of other packages to install in Google Colab\n", 61 | " additional_packages = [\n", 62 | " 'python-dotenv',\n", 63 | " 'openai',\n", 64 | " 'langchain',\n", 65 | " 'langchain_openai',\n", 66 | " 'instructor',\n", 67 | " 'datasets'\n", 68 | " ]\n", 69 | " \n", 70 | " # Install additional packages\n", 71 | " install_packages(additional_packages)\n", 72 | "\n", 73 | "# Now import the zenbase library\n", 74 | "try:\n", 75 | " import zenbase\n", 76 | "except ImportError as e:\n", 77 | " print(\"Failed to import zenbase: \", e)\n", 78 | " raise" 79 | ], 80 | "outputs": [], 81 | "execution_count": 8 82 | }, 83 | { 84 | "metadata": {}, 85 | "cell_type": "markdown", 86 | "source": "### Configure the Environment" 87 | }, 88 | { 89 | "metadata": { 90 | "ExecuteTime": { 91 | "end_time": "2024-07-25T01:06:59.027812Z", 92 | "start_time": "2024-07-25T01:06:59.023030Z" 93 | } 94 | }, 95 | "cell_type": "code", 96 | "source": [ 97 | "from pathlib import Path\n", 98 | "from dotenv import load_dotenv\n", 99 | "\n", 100 | "# import os\n", 101 | "#\n", 102 | "# os.environ[\"OPENAI_API_KEY\"] = \"...\"\n", 103 | "\n", 104 | "load_dotenv(Path(\"../../.env.test\"), override=True)" 105 | ], 106 | "outputs": [ 107 | { 108 | "data": { 109 | "text/plain": [ 110 | "True" 111 | ] 112 | }, 
113 | "execution_count": 9, 114 | "metadata": {}, 115 | "output_type": "execute_result" 116 | } 117 | ], 118 | "execution_count": 9 119 | }, 120 | { 121 | "cell_type": "code", 122 | "metadata": { 123 | "ExecuteTime": { 124 | "end_time": "2024-07-25T01:06:59.051900Z", 125 | "start_time": "2024-07-25T01:06:59.029133Z" 126 | } 127 | }, 128 | "source": [ 129 | "import sys\n", 130 | "import subprocess\n", 131 | "import instructor\n", 132 | "from openai import OpenAI\n", 133 | "from zenbase.core.managers import ZenbaseTracer\n", 134 | "from zenbase.predefined.single_class_classifier import SingleClassClassifier\n", 135 | "from zenbase.predefined.syntethic_data.single_class_classifier import SingleClassClassifierSyntheticDataGenerator\n", 136 | "\n", 137 | "# Set up OpenAI and Instructor clients\n", 138 | "openai_client = OpenAI()\n", 139 | "instructor_client = instructor.from_openai(openai_client)\n", 140 | "zenbase_tracer = ZenbaseTracer()" 141 | ], 142 | "outputs": [], 143 | "execution_count": 10 144 | }, 145 | { 146 | "cell_type": "markdown", 147 | "metadata": {}, 148 | "source": [ 149 | "## Define Classification Task" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "metadata": { 155 | "ExecuteTime": { 156 | "end_time": "2024-07-25T01:06:59.054916Z", 157 | "start_time": "2024-07-25T01:06:59.052873Z" 158 | } 159 | }, 160 | "source": [ 161 | "prompt_definition = \"\"\"Your task is to accurately categorize each incoming news article into one of the given categories based on its title and content.\"\"\"\n", 162 | "\n", 163 | "class_dict = {\n", 164 | " \"Automobiles\": \"Discussions and news about automobiles, including car maintenance, driving experiences, and the latest automotive technology.\",\n", 165 | " \"Computers\": \"Topics related to computer hardware, software, graphics, cryptography, and operating systems, including troubleshooting and advancements.\",\n", 166 | " \"Science\": \"News and discussions about scientific topics including space exploration, medicine, and electronics.\",\n", 167 | " \"Politics\": \"Debates and news about political topics, including gun control, Middle Eastern politics, and miscellaneous political discussions.\",\n", 168 | "}" 169 | ], 170 | "outputs": [], 171 | "execution_count": 11 172 | }, 173 | { 174 | "cell_type": "markdown", 175 | "metadata": {}, 176 | "source": [ 177 | "## Generate Synthetic Data" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "metadata": { 183 | "ExecuteTime": { 184 | "end_time": "2024-07-25T01:08:00.370454Z", 185 | "start_time": "2024-07-25T01:06:59.056722Z" 186 | } 187 | }, 188 | "source": [ 189 | "# Set up the generator\n", 190 | "generator = SingleClassClassifierSyntheticDataGenerator(\n", 191 | " instructor_client=instructor_client,\n", 192 | " prompt=prompt_definition,\n", 193 | " class_dict=class_dict,\n", 194 | " model=\"gpt-4o-mini\"\n", 195 | ")\n", 196 | "\n", 197 | "# Define the number of examples per category for each set\n", 198 | "train_examples_per_category = 10\n", 199 | "val_examples_per_category = 3\n", 200 | "test_examples_per_category = 3\n", 201 | "\n", 202 | "# Generate train set\n", 203 | "train_examples = generator.generate_examples(train_examples_per_category)\n", 204 | "print(f\"Generated {len(train_examples)} examples for the train set.\\n\")\n", 205 | "\n", 206 | "# Generate validation set\n", 207 | "val_examples = generator.generate_examples(val_examples_per_category)\n", 208 | "print(f\"Generated {len(val_examples)} examples for the validation set.\\n\")\n", 209 | "\n", 210 | "# 
Generate test set\n", 211 | "test_examples = generator.generate_examples(test_examples_per_category)\n", 212 | "print(f\"Generated {len(test_examples)} examples for the test set.\\n\")" 213 | ], 214 | "outputs": [ 215 | { 216 | "name": "stdout", 217 | "output_type": "stream", 218 | "text": [ 219 | "Generated 40 examples for the train set.\n", 220 | "\n", 221 | "Generated 12 examples for the validation set.\n", 222 | "\n", 223 | "Generated 12 examples for the test set.\n", 224 | "\n" 225 | ] 226 | } 227 | ], 228 | "execution_count": 12 229 | }, 230 | { 231 | "cell_type": "markdown", 232 | "metadata": {}, 233 | "source": [ 234 | "## Create and Train the Classifier" 235 | ] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "metadata": { 240 | "ExecuteTime": { 241 | "end_time": "2024-07-25T01:11:27.692383Z", 242 | "start_time": "2024-07-25T01:08:00.374312Z" 243 | } 244 | }, 245 | "source": [ 246 | "classifier = SingleClassClassifier(\n", 247 | " instructor_client=instructor_client,\n", 248 | " prompt=prompt_definition,\n", 249 | " class_dict=class_dict,\n", 250 | " model=\"gpt-4o-mini\",\n", 251 | " zenbase_tracer=zenbase_tracer,\n", 252 | " training_set=train_examples,\n", 253 | " validation_set=val_examples,\n", 254 | " test_set=test_examples,\n", 255 | " samples=20,\n", 256 | ")\n", 257 | "\n", 258 | "result = classifier.optimize()" 259 | ], 260 | "outputs": [], 261 | "execution_count": 13 262 | }, 263 | { 264 | "cell_type": "markdown", 265 | "metadata": {}, 266 | "source": [ 267 | "## Analyze Results" 268 | ] 269 | }, 270 | { 271 | "cell_type": "code", 272 | "metadata": { 273 | "ExecuteTime": { 274 | "end_time": "2024-07-25T01:11:27.696612Z", 275 | "start_time": "2024-07-25T01:11:27.693579Z" 276 | } 277 | }, 278 | "source": [ 279 | "print(\"Base Evaluation Score:\", classifier.base_evaluation.evals['score'])\n", 280 | "print(\"Best Evaluation Score:\", classifier.best_evaluation.evals['score'])\n", 281 | "\n", 282 | "print(\"\\nBest function:\", result.best_function)\n", 283 | "print(\"Number of candidate results:\", len(result.candidate_results))\n", 284 | "print(\"Best candidate result:\", result.best_candidate_result.evals)" 285 | ], 286 | "outputs": [ 287 | { 288 | "name": "stdout", 289 | "output_type": "stream", 290 | "text": [ 291 | "Base Evaluation Score: 0.9166666666666666\n", 292 | "Best Evaluation Score: 0.9166666666666666\n", 293 | "\n", 294 | "Best function: \n", 295 | "Number of candidate results: 20\n", 296 | "Best candidate result: {'score': 0.9166666666666666}\n", 297 | "Number of traces: 264\n" 298 | ] 299 | } 300 | ], 301 | "execution_count": 14 302 | }, 303 | { 304 | "cell_type": "markdown", 305 | "metadata": {}, 306 | "source": [ 307 | "## Test the Classifier" 308 | ] 309 | }, 310 | { 311 | "cell_type": "code", 312 | "metadata": { 313 | "ExecuteTime": { 314 | "end_time": "2024-07-25T01:11:28.490185Z", 315 | "start_time": "2024-07-25T01:11:27.697387Z" 316 | } 317 | }, 318 | "source": [ 319 | "new_article = \"\"\"\n", 320 | "title: Revolutionary Quantum Computer Achieves Milestone in Cryptography\n", 321 | "content: Scientists at a leading tech company have announced a breakthrough in quantum computing, \n", 322 | "demonstrating a quantum computer capable of solving complex cryptographic problems in record time. \n", 323 | "This development has significant implications for data security and could revolutionize fields \n", 324 | "ranging from finance to national security. 
However, experts warn that it also poses potential \n", 325 | "risks to current encryption methods.\n", 326 | "\"\"\"\n", 327 | "\n", 328 | "classification = result.best_function(new_article)\n", 329 | "print(f\"The article is classified as: {classification.class_label.name}\")" 330 | ], 331 | "outputs": [ 332 | { 333 | "name": "stdout", 334 | "output_type": "stream", 335 | "text": [ 336 | "The article is classified as: Computers\n" 337 | ] 338 | } 339 | ], 340 | "execution_count": 15 341 | }, 342 | { 343 | "cell_type": "markdown", 344 | "metadata": {}, 345 | "source": [ 346 | "## Conclusion\n", 347 | "\n", 348 | "In this notebook, we've demonstrated how to:\n", 349 | "1. Generate synthetic data for a single-class classification task\n", 350 | "2. Prepare the synthetic data for training and testing\n", 351 | "3. Create and train a SingleClassClassifier using the synthetic data\n", 352 | "4. Analyze the results of the classifier\n", 353 | "5. Use the trained classifier to categorize new input\n", 354 | "\n", 355 | "This approach allows for rapid prototyping and testing of classification models, especially in scenarios where real-world labeled data might be scarce or difficult to obtain." 356 | ] 357 | } 358 | ], 359 | "metadata": { 360 | "kernelspec": { 361 | "display_name": "Python 3", 362 | "language": "python", 363 | "name": "python3" 364 | }, 365 | "language_info": { 366 | "codemirror_mode": { 367 | "name": "ipython", 368 | "version": 3 369 | }, 370 | "file_extension": ".py", 371 | "mimetype": "text/x-python", 372 | "name": "python", 373 | "nbconvert_exporter": "python", 374 | "pygments_lexer": "ipython3", 375 | "version": "3.8.0" 376 | } 377 | }, 378 | "nbformat": 4, 379 | "nbformat_minor": 4 380 | } 381 | -------------------------------------------------------------------------------- /py/pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "zenbase" 3 | version = "0.0.22" 4 | description = "LLMs made Zen" 5 | authors = [{ name = "Cyrus Nouroozi", email = "cyrus@zenbase.ai" }] 6 | dependencies = [ 7 | "pksuid>=1.1.2", 8 | "faker>=24.2.0", 9 | "anyio>=4.4.0", 10 | "opentelemetry-sdk>=1.25.0", 11 | "opentelemetry-api>=1.25.0", 12 | "structlog>=24.2.0", 13 | "pyee>=11.1.0", 14 | "posthog>=3.5.0", 15 | "cloudpickle>=3.0.0", 16 | "instructor>=1.3.5", 17 | ] 18 | readme = "README.md" 19 | requires-python = ">= 3.10" 20 | lisence = "MIT" 21 | 22 | [build-system] 23 | requires = ["hatchling"] 24 | build-backend = "hatchling.build" 25 | 26 | [tool.mypy] 27 | namespace_packages = true 28 | 29 | [tool.pytest.ini_options] 30 | markers = ["helpers"] 31 | addopts = "-m 'not helpers'" 32 | 33 | [tool.rye] 34 | managed = true 35 | dev-dependencies = [ 36 | "pytest>=8.2.1", 37 | "ruff>=0.4.6", 38 | "langsmith[vcr]>=0.1.72", 39 | "pyright>=1.1.365", 40 | "datasets>=2.19.1", 41 | "ipython>=8.24.0", 42 | "ipdb>=0.13.13", 43 | "openai>=1.30.5", 44 | "pytest-recording>=0.13.1", 45 | "python-dotenv>=1.0.1", 46 | "vcrpy>=6.0.1", 47 | "arize-phoenix[evals]>=4.9.0", 48 | "nest-asyncio>=1.6.0", 49 | "langchain-openai>=0.1.8", 50 | "langchain-core>=0.2.3", 51 | "langchain>=0.2.1", 52 | "parea-ai>=0.2.164", 53 | "langfuse>=2.35.0", 54 | "lunary>=1.0.30", 55 | "autoevals>=0.0.68", 56 | "braintrust>=0.0.131", 57 | "pre-commit>=3.7.1", 58 | "pytest-xdist>=3.6.1", 59 | "openai-responses>=0.8.1", 60 | ] 61 | 62 | [tool.hatch.metadata] 63 | allow-direct-references = true 64 | 65 | [tool.hatch.build.targets.wheel] 66 | packages = ["src/zenbase"] 67 | 
68 | [tool.ruff] 69 | exclude = [ 70 | "venv", 71 | ".git", 72 | "__pycache__", 73 | "build", 74 | "dist", 75 | "venv", 76 | ] 77 | line-length = 120 78 | src = ["src", "tests"] 79 | 80 | [tool.ruff.lint] 81 | ignore = [] 82 | select = [ 83 | "E", 84 | "F", 85 | "W", 86 | "I001", 87 | ] 88 | -------------------------------------------------------------------------------- /py/requirements.lock: -------------------------------------------------------------------------------- 1 | # generated by rye 2 | # use `rye lock` or `rye sync` to update this lockfile 3 | # 4 | # last locked with the following flags: 5 | # pre: false 6 | # features: [] 7 | # all-features: false 8 | # with-sources: false 9 | # generate-hashes: false 10 | 11 | -e file:. 12 | aiohttp==3.9.5 13 | # via instructor 14 | aiosignal==1.3.1 15 | # via aiohttp 16 | annotated-types==0.7.0 17 | # via pydantic 18 | anyio==4.4.0 19 | # via httpx 20 | # via openai 21 | # via zenbase 22 | async-timeout==4.0.3 23 | # via aiohttp 24 | attrs==23.2.0 25 | # via aiohttp 26 | backoff==2.2.1 27 | # via posthog 28 | certifi==2024.6.2 29 | # via httpcore 30 | # via httpx 31 | # via requests 32 | charset-normalizer==3.3.2 33 | # via requests 34 | click==8.1.7 35 | # via typer 36 | cloudpickle==3.0.0 37 | # via zenbase 38 | deprecated==1.2.14 39 | # via opentelemetry-api 40 | distro==1.9.0 41 | # via openai 42 | docstring-parser==0.16 43 | # via instructor 44 | exceptiongroup==1.2.1 45 | # via anyio 46 | faker==26.0.0 47 | # via zenbase 48 | frozenlist==1.4.1 49 | # via aiohttp 50 | # via aiosignal 51 | h11==0.14.0 52 | # via httpcore 53 | httpcore==1.0.5 54 | # via httpx 55 | httpx==0.27.0 56 | # via openai 57 | idna==3.7 58 | # via anyio 59 | # via httpx 60 | # via requests 61 | # via yarl 62 | importlib-metadata==7.1.0 63 | # via opentelemetry-api 64 | instructor==1.3.5 65 | # via zenbase 66 | jiter==0.4.2 67 | # via instructor 68 | # via openai 69 | markdown-it-py==3.0.0 70 | # via rich 71 | mdurl==0.1.2 72 | # via markdown-it-py 73 | monotonic==1.6 74 | # via posthog 75 | multidict==6.0.5 76 | # via aiohttp 77 | # via yarl 78 | openai==1.43.0 79 | # via instructor 80 | opentelemetry-api==1.25.0 81 | # via opentelemetry-sdk 82 | # via opentelemetry-semantic-conventions 83 | # via zenbase 84 | opentelemetry-sdk==1.25.0 85 | # via zenbase 86 | opentelemetry-semantic-conventions==0.46b0 87 | # via opentelemetry-sdk 88 | pksuid==1.1.2 89 | # via zenbase 90 | posthog==3.5.0 91 | # via zenbase 92 | pybase62==0.4.3 93 | # via pksuid 94 | pydantic==2.8.2 95 | # via instructor 96 | # via openai 97 | pydantic-core==2.20.1 98 | # via instructor 99 | # via pydantic 100 | pyee==11.1.0 101 | # via zenbase 102 | pygments==2.18.0 103 | # via rich 104 | python-dateutil==2.9.0.post0 105 | # via faker 106 | # via posthog 107 | requests==2.32.3 108 | # via posthog 109 | rich==13.7.1 110 | # via instructor 111 | # via typer 112 | shellingham==1.5.4 113 | # via typer 114 | six==1.16.0 115 | # via posthog 116 | # via python-dateutil 117 | sniffio==1.3.1 118 | # via anyio 119 | # via httpx 120 | # via openai 121 | structlog==24.2.0 122 | # via zenbase 123 | tenacity==8.5.0 124 | # via instructor 125 | tqdm==4.66.4 126 | # via openai 127 | typer==0.12.3 128 | # via instructor 129 | typing-extensions==4.12.1 130 | # via anyio 131 | # via openai 132 | # via opentelemetry-sdk 133 | # via pydantic 134 | # via pydantic-core 135 | # via pyee 136 | # via typer 137 | urllib3==2.2.1 138 | # via requests 139 | wrapt==1.16.0 140 | # via deprecated 141 | yarl==1.9.4 142 | # via 
aiohttp 143 | zipp==3.19.2 144 | # via importlib-metadata 145 | -------------------------------------------------------------------------------- /py/src/zenbase/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zenbase-ai/core/59971868e85784c54eddb6980cf791b975e0befb/py/src/zenbase/__init__.py -------------------------------------------------------------------------------- /py/src/zenbase/adaptors/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zenbase-ai/core/59971868e85784c54eddb6980cf791b975e0befb/py/src/zenbase/adaptors/__init__.py -------------------------------------------------------------------------------- /py/src/zenbase/adaptors/arize.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, Callable 2 | 3 | import pandas as pd 4 | 5 | from zenbase.optim.metric.types import CandidateEvalResult, OverallEvalValue 6 | from zenbase.types import LMDemo, LMFunction 7 | from zenbase.utils import amap 8 | 9 | if TYPE_CHECKING: 10 | from phoenix.evals import LLMEvaluator 11 | 12 | 13 | class ZenPhoenix: 14 | MetricEvaluator = Callable[[list["LLMEvaluator"], list[pd.DataFrame]], OverallEvalValue] 15 | 16 | @staticmethod 17 | def df_to_demos(df: pd.DataFrame) -> list[LMDemo]: 18 | raise NotImplementedError() 19 | 20 | @staticmethod 21 | def default_metric(evaluators: list["LLMEvaluator"], eval_dfs: list[pd.DataFrame]) -> OverallEvalValue: 22 | evals = {"score": sum(df.score.mean() for df in eval_dfs)} 23 | evals.update({e.__name__: df.score.mean() for e, df in zip(evaluators, eval_dfs)}) 24 | return evals 25 | 26 | @classmethod 27 | def metric_evaluator( 28 | cls, 29 | dataset: pd.DataFrame, 30 | evaluators: list["LLMEvaluator"], 31 | metric_evals: MetricEvaluator = default_metric, 32 | concurrency: int = 20, 33 | *args, 34 | **kwargs, 35 | ): 36 | from phoenix.evals import run_evals 37 | 38 | async def run_experiment(function: LMFunction) -> CandidateEvalResult: 39 | nonlocal dataset 40 | run_df = dataset.copy() 41 | # TODO: Is it typical for there to only be 1 value? 
42 | responses = await amap( 43 | function, 44 | run_df["attributes.input.value"].to_list(), # TODO: verify this column holds the inputs `function` expects 45 | concurrency=concurrency, 46 | ) 47 | run_df["attributes.output.value"] = responses 48 | 49 | eval_dfs = run_evals(run_df, evaluators, *args, concurrency=concurrency, **kwargs) 50 | 51 | return CandidateEvalResult( 52 | function, 53 | evals=metric_evals(evaluators, eval_dfs), 54 | ) 55 | 56 | return run_experiment 57 | -------------------------------------------------------------------------------- /py/src/zenbase/adaptors/arize/__init__.py: -------------------------------------------------------------------------------- 1 | from .adaptor import * # noqa 2 | -------------------------------------------------------------------------------- /py/src/zenbase/adaptors/arize/adaptor.py: -------------------------------------------------------------------------------- 1 | from zenbase.adaptors.arize.dataset_helper import ArizeDatasetHelper 2 | from zenbase.adaptors.arize.evaluation_helper import ArizeEvaluationHelper 3 | 4 | 5 | class ZenArizeAdaptor(ArizeDatasetHelper, ArizeEvaluationHelper): 6 | def __init__(self, client=None): 7 | ArizeDatasetHelper.__init__(self, client) 8 | ArizeEvaluationHelper.__init__(self, client) 9 | -------------------------------------------------------------------------------- /py/src/zenbase/adaptors/arize/dataset_helper.py: -------------------------------------------------------------------------------- 1 | from phoenix.experiments.types import Dataset, Example 2 | 3 | from zenbase.adaptors.base.dataset_helper import BaseDatasetHelper 4 | from zenbase.types import LMDemo 5 | 6 | 7 | class ArizeDatasetHelper(BaseDatasetHelper): 8 | def create_dataset(self, dataset_name: str, *args, **kwargs): 9 | raise NotImplementedError( 10 | "create_dataset is not supported for Arize; the dataset is created " 11 | "automatically when examples are added to it."
12 | ) 13 | 14 | def add_examples_to_dataset(self, dataset_name: str, inputs: list, outputs: list) -> Dataset: 15 | return self.client.upload_dataset( 16 | dataset_name=dataset_name, 17 | inputs=inputs, 18 | outputs=outputs, 19 | ) 20 | 21 | def fetch_dataset_examples(self, dataset_name: str) -> list[Example]: 22 | return list(self.client.get_dataset(name=dataset_name).examples.values()) 23 | 24 | def fetch_dataset_demos(self, dataset_name: str) -> list[LMDemo]: 25 | dataset_examples = self.fetch_dataset_examples(dataset_name) 26 | return [ 27 | LMDemo(inputs=example.input, outputs=example.output) # noqa 28 | for example in dataset_examples 29 | ] 30 | -------------------------------------------------------------------------------- /py/src/zenbase/adaptors/arize/evaluation_helper.py: -------------------------------------------------------------------------------- 1 | from zenbase.adaptors.base.evaluation_helper import BaseEvaluationHelper 2 | from zenbase.optim.metric.types import CandidateEvalResult, CandidateEvaluator, IndividualEvalValue 3 | from zenbase.types import LMDemo, LMFunction 4 | from zenbase.utils import random_name_generator 5 | 6 | 7 | class ArizeEvaluationHelper(BaseEvaluationHelper): 8 | def get_evaluator(self, data: str): 9 | evaluator_kwargs_to_pass = self.evaluator_kwargs.copy() 10 | dataset = self.client.get_dataset(name=data) 11 | evaluator_kwargs_to_pass.update({"dataset": dataset}) 12 | return self._metric_evaluator_generator(**evaluator_kwargs_to_pass) 13 | 14 | def _metric_evaluator_generator(self, threshold: float = 0.5, **evaluate_kwargs) -> CandidateEvaluator: 15 | from phoenix.experiments import run_experiment 16 | 17 | gen_random_name = random_name_generator(evaluate_kwargs.pop("experiment_prefix", None)) 18 | 19 | def evaluate_candidate(function: LMFunction) -> CandidateEvalResult: 20 | def arize_adapted_function(input): 21 | return function(input) 22 | 23 | experiment = run_experiment( 24 | evaluate_kwargs["dataset"], 25 | arize_adapted_function, 26 | experiment_name=gen_random_name(), 27 | evaluators=evaluate_kwargs.get("evaluators", None), 28 | ) 29 | list_of_individual_evals = [] 30 | for individual_eval in experiment.eval_runs: 31 | example_id = experiment.runs[individual_eval.experiment_run_id].dataset_example_id 32 | example = experiment.dataset.examples[example_id] 33 | if individual_eval.result: 34 | list_of_individual_evals.append( 35 | IndividualEvalValue( 36 | passed=individual_eval.result.score >= threshold, 37 | response=experiment.runs[individual_eval.experiment_run_id].output, 38 | demo=LMDemo( 39 | inputs=example.input, 40 | outputs=example.output, 41 | ), 42 | score=individual_eval.result.score, 43 | ) 44 | ) 45 | 46 | # make average scores of all evaluation metrics 47 | avg_scores = [i.stats["avg_score"][0] for i in experiment.eval_summaries] 48 | avg_score = sum(avg_scores) / len(avg_scores) 49 | 50 | return CandidateEvalResult(function, {"score": avg_score}, individual_evals=list_of_individual_evals) 51 | 52 | return evaluate_candidate 53 | 54 | @classmethod 55 | def metric_evaluator(cls, threshold: float = 0.5, **evaluate_kwargs) -> CandidateEvaluator: 56 | # TODO: Should remove and deprecate 57 | from phoenix.experiments import run_experiment 58 | 59 | gen_random_name = random_name_generator(evaluate_kwargs.pop("experiment_prefix", None)) 60 | 61 | def evaluate_candidate(function: LMFunction) -> CandidateEvalResult: 62 | def arize_adapted_function(input): 63 | return function(input) 64 | 65 | experiment = run_experiment( 66 | 
evaluate_kwargs["dataset"], 67 | arize_adapted_function, 68 | experiment_name=gen_random_name(), 69 | evaluators=evaluate_kwargs.get("evaluators", None), 70 | ) 71 | list_of_individual_evals = [] 72 | for individual_eval in experiment.eval_runs: 73 | example_id = experiment.runs[individual_eval.experiment_run_id].dataset_example_id 74 | example = experiment.dataset.examples[example_id] 75 | 76 | list_of_individual_evals.append( 77 | IndividualEvalValue( 78 | passed=individual_eval.result.score >= threshold, 79 | response=experiment.runs[individual_eval.experiment_run_id].output, 80 | demo=LMDemo( 81 | inputs=example.input, 82 | outputs=example.output, 83 | ), 84 | score=individual_eval.result.score, 85 | ) 86 | ) 87 | 88 | # make average scores of all evaluation metrics 89 | avg_scores = [i.stats["avg_score"][0] for i in experiment.eval_summaries] 90 | avg_score = sum(avg_scores) / len(avg_scores) 91 | 92 | return CandidateEvalResult(function, {"score": avg_score}, individual_evals=list_of_individual_evals) 93 | 94 | return evaluate_candidate 95 | -------------------------------------------------------------------------------- /py/src/zenbase/adaptors/base/adaptor.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | 3 | 4 | class ZenAdaptor(ABC): 5 | def __init__(self, client=None): 6 | self.client = client 7 | -------------------------------------------------------------------------------- /py/src/zenbase/adaptors/base/dataset_helper.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Any 3 | 4 | from zenbase.adaptors.base.adaptor import ZenAdaptor 5 | 6 | 7 | class BaseDatasetHelper(ZenAdaptor): 8 | @abstractmethod 9 | def create_dataset(self, dataset_name: str, *args, **kwargs) -> Any: ... 10 | 11 | @abstractmethod 12 | def add_examples_to_dataset(self, dataset_id: Any, inputs: list, outputs: list) -> None: ... 13 | 14 | @abstractmethod 15 | def fetch_dataset_examples(self, dataset_name: str) -> Any: ... 16 | 17 | @abstractmethod 18 | def fetch_dataset_demos(self, dataset: Any) -> Any: ... 19 | -------------------------------------------------------------------------------- /py/src/zenbase/adaptors/base/evaluation_helper.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Any 3 | 4 | from zenbase.adaptors.base.adaptor import ZenAdaptor 5 | 6 | 7 | class BaseEvaluationHelper(ZenAdaptor): 8 | evaluator_args = tuple() 9 | evaluator_kwargs = dict() 10 | 11 | def set_evaluator_kwargs(self, *args, **kwargs) -> None: 12 | self.evaluator_kwargs = kwargs 13 | self.evaluator_args = args 14 | 15 | @abstractmethod 16 | def get_evaluator(self, data: Any): ... 17 | 18 | @classmethod 19 | @abstractmethod 20 | def metric_evaluator(cls, *args, **evaluate_kwargs): ... 
21 | 22 | # TODO: Should remove and deprecate 23 | -------------------------------------------------------------------------------- /py/src/zenbase/adaptors/braintrust.py: -------------------------------------------------------------------------------- 1 | from asyncio import Task 2 | from dataclasses import asdict 3 | from typing import AsyncIterator, Awaitable, Callable, Iterator 4 | 5 | from braintrust import ( 6 | Eval, 7 | EvalCase, 8 | EvalHooks, 9 | EvalScorer, 10 | Input, 11 | Metadata, 12 | Output, 13 | ReporterDef, 14 | ) 15 | 16 | from zenbase.optim.metric.types import CandidateEvalResult 17 | from zenbase.types import LMFunction 18 | from zenbase.utils import random_name_generator 19 | 20 | 21 | class ZenBraintrust: 22 | @staticmethod 23 | def metric_evaluator( 24 | name: str, 25 | data: Callable[[], Iterator[EvalCase] | AsyncIterator[EvalCase]], 26 | task: Callable[[Input, EvalHooks], Output | Awaitable[Output]], 27 | scores: list[EvalScorer], 28 | experiment_name: str | None = None, 29 | trial_count: int = 1, 30 | metadata: Metadata | None = None, 31 | is_public: bool = False, 32 | update: bool = False, 33 | reporter: ReporterDef | str | None = None, 34 | ): 35 | gen_random_name = random_name_generator(experiment_name) 36 | 37 | def evaluate_candidate(function: LMFunction) -> CandidateEvalResult: 38 | eval_result = Eval( 39 | name=name, 40 | experiment_name=gen_random_name(), 41 | data=data, 42 | task=task, 43 | scores=scores, 44 | trial_count=trial_count, 45 | metadata={ 46 | **metadata, 47 | **asdict(function.zenbase), 48 | }, 49 | is_public=is_public, 50 | update=update, 51 | reporter=reporter, 52 | ) 53 | 54 | if isinstance(eval_result, Task): 55 | eval_result = eval_result.result() 56 | 57 | assert eval_result is not None, "Failed to run Braintrust Eval" 58 | 59 | evals = {s.name: s.score for s in eval_result.summary.scores.values()} 60 | if "score" not in evals: 61 | evals["score"] = sum(evals.values()) 62 | 63 | return CandidateEvalResult(function, evals) 64 | 65 | return evaluate_candidate 66 | -------------------------------------------------------------------------------- /py/src/zenbase/adaptors/json/adaptor.py: -------------------------------------------------------------------------------- 1 | from zenbase.adaptors.json.dataset_helper import JSONDatasetHelper 2 | from zenbase.adaptors.json.evaluation_helper import JSONEvaluationHelper 3 | 4 | 5 | class JSONAdaptor(JSONDatasetHelper, JSONEvaluationHelper): 6 | def __init__(self, client=None): 7 | JSONDatasetHelper.__init__(self, client) 8 | JSONEvaluationHelper.__init__(self, client) 9 | -------------------------------------------------------------------------------- /py/src/zenbase/adaptors/json/dataset_helper.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from zenbase.adaptors.base.adaptor import ZenAdaptor 4 | from zenbase.types import LMDemo 5 | 6 | 7 | class JSONDatasetHelper(ZenAdaptor): 8 | datasets = {} 9 | 10 | def create_dataset(self, dataset_name: str, *args, **kwargs) -> Any: 11 | self.datasets[dataset_name] = [] 12 | 13 | return self.datasets[dataset_name] 14 | 15 | def add_examples_to_dataset(self, dataset_id: Any, inputs: list, outputs: list) -> None: 16 | for input, output in zip(inputs, outputs): 17 | self.datasets[dataset_id].append({"input": input, "output": output}) 18 | 19 | def fetch_dataset_examples(self, dataset_name: str) -> Any: 20 | return self.datasets[dataset_name] 21 | 22 | def fetch_dataset_demos(self, dataset: 
Any) -> Any: 23 | if isinstance(dataset[0], LMDemo): 24 | return dataset 25 | return [LMDemo(input=item, output=item) for item in dataset] 26 | -------------------------------------------------------------------------------- /py/src/zenbase/adaptors/json/evaluation_helper.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from typing import Any, Callable 3 | 4 | from zenbase.adaptors.base.evaluation_helper import BaseEvaluationHelper 5 | from zenbase.optim.metric.types import CandidateEvalResult, CandidateEvaluator, IndividualEvalValue, OverallEvalValue 6 | from zenbase.types import LMDemo, LMFunction 7 | from zenbase.utils import pmap 8 | 9 | 10 | class JSONEvaluationHelper(BaseEvaluationHelper): 11 | MetricEvaluator = Callable[[list[tuple[bool, Any]]], OverallEvalValue] 12 | 13 | @staticmethod 14 | def default_metric(batch_results: list[tuple[bool, Any]]) -> OverallEvalValue: 15 | avg_pass = sum(int(result["passed"]) for result in batch_results) / len(batch_results) 16 | return {"score": avg_pass} 17 | 18 | def get_evaluator(self, data: list[LMDemo]): 19 | evaluator_kwargs_to_pass = self.evaluator_kwargs.copy() 20 | evaluator_kwargs_to_pass.update({"data": data}) 21 | return self._metric_evaluator_generator(**evaluator_kwargs_to_pass) 22 | 23 | @staticmethod 24 | def _metric_evaluator_generator( 25 | *args, 26 | eval_function: Callable, 27 | data: list[LMDemo], 28 | eval_metrics: MetricEvaluator = default_metric, 29 | concurrency: int = 1, 30 | **kwargs, 31 | ) -> CandidateEvaluator: 32 | # TODO: Should remove and deprecate 33 | def evaluate_metric(function: LMFunction) -> CandidateEvalResult: 34 | individual_evals = [] 35 | 36 | def run_and_evaluate(demo: LMDemo): 37 | nonlocal individual_evals 38 | 39 | response = function(demo.inputs) 40 | 41 | # Check if eval_function accepts 'input' parameter 42 | eval_params = inspect.signature(eval_function).parameters 43 | if "input" in eval_params: 44 | result = eval_function( 45 | input=demo.inputs, 46 | output=response, 47 | ideal_output=demo.outputs, 48 | *args, 49 | **kwargs, 50 | ) 51 | else: 52 | result = eval_function( 53 | output=response, 54 | ideal_output=demo.outputs, 55 | *args, 56 | **kwargs, 57 | ) 58 | 59 | individual_evals.append( 60 | IndividualEvalValue( 61 | passed=result["passed"], 62 | response=response, 63 | demo=demo, 64 | ) 65 | ) 66 | 67 | return result 68 | 69 | eval_results = pmap( 70 | run_and_evaluate, 71 | data, 72 | concurrency=concurrency, 73 | ) 74 | 75 | return CandidateEvalResult(function, eval_metrics(eval_results), individual_evals=individual_evals) 76 | 77 | return evaluate_metric 78 | 79 | @classmethod 80 | def metric_evaluator( 81 | cls, 82 | eval_function: Callable, 83 | data: list[LMDemo], 84 | eval_metrics: MetricEvaluator = default_metric, 85 | concurrency: int = 1, 86 | threshold: float = 0.5, 87 | ) -> CandidateEvaluator: 88 | # TODO: Should remove and deprecate 89 | def evaluate_metric(function: LMFunction) -> CandidateEvalResult: 90 | individual_evals = [] 91 | 92 | def run_and_evaluate(demo: LMDemo): 93 | nonlocal individual_evals 94 | response = function(demo.inputs) 95 | 96 | # Check if eval_function accepts 'input' parameter 97 | eval_params = inspect.signature(eval_function).parameters 98 | if "input" in eval_params: 99 | result = eval_function( 100 | input=demo.inputs, 101 | output=response, 102 | ideal_output=demo.outputs, 103 | ) 104 | else: 105 | result = eval_function( 106 | output=response, 107 | ideal_output=demo.outputs, 108 | ) 
109 | 110 | individual_evals.append( 111 | IndividualEvalValue( 112 | passed=result["passed"], 113 | response=response, 114 | demo=demo, 115 | details=result, 116 | ) 117 | ) 118 | 119 | return result 120 | 121 | eval_results = pmap( 122 | run_and_evaluate, 123 | data, 124 | concurrency=concurrency, 125 | ) 126 | 127 | return CandidateEvalResult(function, eval_metrics(eval_results), individual_evals=individual_evals) 128 | 129 | return evaluate_metric 130 | -------------------------------------------------------------------------------- /py/src/zenbase/adaptors/langchain/__init__.py: -------------------------------------------------------------------------------- 1 | from .adaptor import * # noqa 2 | -------------------------------------------------------------------------------- /py/src/zenbase/adaptors/langchain/adaptor.py: -------------------------------------------------------------------------------- 1 | __all__ = ["ZenLangSmith"] 2 | 3 | from zenbase.adaptors.langchain.dataset_helper import LangsmithDatasetHelper 4 | from zenbase.adaptors.langchain.evaluation_helper import LangsmithEvaluationHelper 5 | 6 | 7 | class ZenLangSmith(LangsmithDatasetHelper, LangsmithEvaluationHelper): 8 | def __init__(self, client=None): 9 | LangsmithDatasetHelper.__init__(self, client) 10 | LangsmithEvaluationHelper.__init__(self, client) 11 | -------------------------------------------------------------------------------- /py/src/zenbase/adaptors/langchain/dataset_helper.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, Iterator 2 | 3 | from zenbase.adaptors.base.dataset_helper import BaseDatasetHelper 4 | from zenbase.types import LMDemo 5 | 6 | if TYPE_CHECKING: 7 | from langsmith import schemas 8 | 9 | 10 | class LangsmithDatasetHelper(BaseDatasetHelper): 11 | def create_dataset(self, dataset_name: str, description: str) -> "schemas.Dataset": 12 | dataset = self.client.create_dataset(dataset_name, description=description) 13 | return dataset 14 | 15 | def add_examples_to_dataset(self, dataset_name: str, inputs: list, outputs: list) -> None: 16 | self.client.create_examples( 17 | inputs=inputs, 18 | outputs=outputs, 19 | dataset_name=dataset_name, 20 | ) 21 | 22 | def fetch_dataset_examples(self, dataset_name: str) -> Iterator["schemas.Example"]: 23 | dataset = self.fetch_dataset(dataset_name) 24 | return self.client.list_examples(dataset_id=dataset.id) 25 | 26 | def fetch_dataset(self, dataset_name: str): 27 | datasets = self.client.list_datasets(dataset_name=dataset_name) 28 | if not datasets: 29 | raise ValueError(f"Dataset '{dataset_name}' not found") 30 | dataset = [i for i in datasets][0] 31 | return dataset 32 | 33 | def fetch_dataset_demos(self, dataset: "schemas.Dataset") -> list[LMDemo]: 34 | dataset_examples = self.fetch_dataset_examples(dataset.name) 35 | return self.examples_to_demos(dataset_examples) 36 | 37 | @staticmethod 38 | def examples_to_demos(examples: Iterator["schemas.Example"]) -> list[LMDemo]: 39 | return [LMDemo(inputs=e.inputs, outputs=e.outputs, adaptor_object=e) for e in examples] 40 | -------------------------------------------------------------------------------- /py/src/zenbase/adaptors/langchain/evaluation_helper.py: -------------------------------------------------------------------------------- 1 | from dataclasses import asdict 2 | from typing import TYPE_CHECKING 3 | 4 | from langsmith.evaluation._runner import ExperimentResults # noqa 5 | 6 | from zenbase.adaptors.base.evaluation_helper 
import BaseEvaluationHelper 7 | from zenbase.optim.metric.types import CandidateEvalResult, CandidateEvaluator, IndividualEvalValue, OverallEvalValue 8 | from zenbase.types import LMDemo, LMFunction 9 | from zenbase.utils import random_name_generator 10 | 11 | if TYPE_CHECKING: 12 | from langsmith import schemas 13 | 14 | 15 | class LangsmithEvaluationHelper(BaseEvaluationHelper): 16 | def __init__(self, client=None): 17 | super().__init__(client) 18 | self.evaluator_args = None 19 | self.evaluator_kwargs = None 20 | 21 | def set_evaluator_kwargs(self, *args, **kwargs) -> None: 22 | self.evaluator_kwargs = kwargs 23 | self.evaluator_args = args 24 | 25 | def get_evaluator(self, data: "schemas.Dataset") -> CandidateEvaluator: 26 | evaluator_kwargs_to_pass = self.evaluator_kwargs.copy() 27 | evaluator_kwargs_to_pass.update({"data": data.name}) 28 | return self._metric_evaluator_generator(**evaluator_kwargs_to_pass) 29 | 30 | def _metric_evaluator_generator(self, threshold: float = 0.5, **evaluate_kwargs) -> CandidateEvaluator: 31 | from langsmith import evaluate 32 | 33 | metadata = evaluate_kwargs.pop("metadata", {}) 34 | gen_random_name = random_name_generator(evaluate_kwargs.pop("experiment_prefix", None)) 35 | 36 | def evaluate_candidate(function: LMFunction) -> CandidateEvalResult: 37 | experiment_results = evaluate( 38 | function, 39 | experiment_prefix=gen_random_name(), 40 | metadata={ 41 | **metadata, 42 | }, 43 | **evaluate_kwargs, 44 | ) 45 | 46 | individual_evals = self._experiment_results_to_individual_evals(experiment_results, threshold) 47 | 48 | if summary_results := experiment_results._summary_results["results"]: # noqa 49 | evals = self._eval_results_to_evals(summary_results) 50 | else: 51 | evals = self._individual_evals_to_overall_evals(individual_evals) 52 | 53 | return CandidateEvalResult(function, evals, individual_evals) 54 | 55 | return evaluate_candidate 56 | 57 | @classmethod 58 | def metric_evaluator(cls, threshold: float = 0.5, **evaluate_kwargs) -> CandidateEvaluator: 59 | # TODO: Should remove and deprecate 60 | from langsmith import evaluate 61 | 62 | metadata = evaluate_kwargs.pop("metadata", {}) 63 | gen_random_name = random_name_generator(evaluate_kwargs.pop("experiment_prefix", None)) 64 | 65 | def evaluate_candidate(function: LMFunction) -> CandidateEvalResult: 66 | experiment_results = evaluate( 67 | function, 68 | experiment_prefix=gen_random_name(), 69 | metadata={ 70 | **metadata, 71 | **asdict(function.zenbase), 72 | }, 73 | **evaluate_kwargs, 74 | ) 75 | 76 | individual_evals = cls._experiment_results_to_individual_evals(experiment_results, threshold) 77 | 78 | if summary_results := experiment_results._summary_results["results"]: # noqa 79 | evals = cls._eval_results_to_evals(summary_results) 80 | else: 81 | evals = cls._individual_evals_to_overall_evals(individual_evals) 82 | 83 | return CandidateEvalResult(function, evals, individual_evals) 84 | 85 | return evaluate_candidate 86 | 87 | @staticmethod 88 | def _individual_evals_to_overall_evals(individual_evals: list[IndividualEvalValue]) -> OverallEvalValue: 89 | if not individual_evals: 90 | raise ValueError("No evaluation results") 91 | 92 | if individual_evals[0].score is not None: 93 | number_of_scores = sum(1 for e in individual_evals if e.score is not None) 94 | score = sum(e.score for e in individual_evals if e.score is not None) / number_of_scores 95 | else: 96 | number_of_filled_passed = sum(1 for e in individual_evals if e.passed is not None) 97 | number_of_actual_passed = sum(1 for e in 
individual_evals if e.passed) 98 | score = number_of_actual_passed / number_of_filled_passed 99 | 100 | return {"score": score} 101 | 102 | @staticmethod 103 | def _experiment_results_to_individual_evals( 104 | experiment_results: ExperimentResults, threshold=0.5 105 | ) -> (list)[IndividualEvalValue]: 106 | individual_evals = [] 107 | for res in experiment_results._results: # noqa 108 | if not res["evaluation_results"]["results"]: 109 | continue 110 | score = res["evaluation_results"]["results"][0].score 111 | inputs = res["example"].inputs 112 | outputs = res["example"].outputs 113 | individual_evals.append( 114 | IndividualEvalValue( 115 | passed=score >= threshold, 116 | response=outputs, 117 | demo=LMDemo(inputs=inputs, outputs=outputs, adaptor_object=res["example"]), 118 | details=res, 119 | score=score, 120 | ) 121 | ) 122 | return individual_evals 123 | 124 | @staticmethod 125 | def _eval_results_to_evals(eval_results: list) -> OverallEvalValue: 126 | if not eval_results: 127 | raise ValueError("No evaluation results") 128 | 129 | return { 130 | "score": eval_results[0].score, 131 | **{r.key: r.dict() for r in eval_results}, 132 | } 133 | -------------------------------------------------------------------------------- /py/src/zenbase/adaptors/langfuse_helper/__init__.py: -------------------------------------------------------------------------------- 1 | from .adaptor import * # noqa 2 | -------------------------------------------------------------------------------- /py/src/zenbase/adaptors/langfuse_helper/adaptor.py: -------------------------------------------------------------------------------- 1 | from zenbase.adaptors.langfuse_helper.dataset_helper import LangfuseDatasetHelper 2 | from zenbase.adaptors.langfuse_helper.evaluation_helper import LangfuseEvaluationHelper 3 | 4 | 5 | class ZenLangfuse(LangfuseDatasetHelper, LangfuseEvaluationHelper): 6 | def __init__(self, client=None): 7 | LangfuseDatasetHelper.__init__(self, client) 8 | LangfuseEvaluationHelper.__init__(self, client) 9 | 10 | def get_evaluator(self, data: str): 11 | data = self.fetch_dataset_demos(data) 12 | evaluator_kwargs_to_pass = self.evaluator_kwargs.copy() 13 | evaluator_kwargs_to_pass.update({"data": data}) 14 | return self._metric_evaluator_generator(**evaluator_kwargs_to_pass) 15 | -------------------------------------------------------------------------------- /py/src/zenbase/adaptors/langfuse_helper/dataset_helper.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from langfuse.client import DatasetItemClient 4 | 5 | from zenbase.adaptors.base.dataset_helper import BaseDatasetHelper 6 | from zenbase.types import LMDemo 7 | 8 | 9 | class LangfuseDatasetHelper(BaseDatasetHelper): 10 | def create_dataset(self, dataset_name: str, *args, **kwargs) -> Any: 11 | return self.client.create_dataset(dataset_name, *args, **kwargs) 12 | 13 | def add_examples_to_dataset(self, dataset_name: str, inputs: list, outputs: list) -> None: 14 | for the_input, the_output in zip(inputs, outputs): 15 | self.client.create_dataset_item( 16 | dataset_name=dataset_name, 17 | input=the_input, 18 | expected_output=the_output, 19 | ) 20 | 21 | def fetch_dataset_examples(self, dataset_name: str) -> list[DatasetItemClient]: 22 | return self.client.get_dataset(dataset_name).items 23 | 24 | def fetch_dataset_demos(self, dataset_name: str) -> list[LMDemo]: 25 | return [ 26 | LMDemo(inputs=example.input, outputs=example.expected_output) 27 | for example in 
self.fetch_dataset_examples(dataset_name) 28 | ] 29 | -------------------------------------------------------------------------------- /py/src/zenbase/adaptors/langfuse_helper/evaluation_helper.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | from langfuse import Langfuse 4 | from langfuse.client import Dataset 5 | 6 | from zenbase.adaptors.base.evaluation_helper import BaseEvaluationHelper 7 | from zenbase.optim.metric.types import ( 8 | CandidateEvalResult, 9 | CandidateEvaluator, 10 | IndividualEvalValue, 11 | OverallEvalValue, 12 | ) 13 | from zenbase.types import LMDemo, LMFunction, Outputs 14 | from zenbase.utils import pmap 15 | 16 | 17 | class LangfuseEvaluationHelper(BaseEvaluationHelper): 18 | MetricEvaluator = Callable[[list[OverallEvalValue]], OverallEvalValue] 19 | 20 | @staticmethod 21 | def default_candidate_evals(item_evals: list[OverallEvalValue]) -> OverallEvalValue: 22 | keys = item_evals[0].keys() 23 | evals = {k: sum(d[k] for d in item_evals) / len(item_evals) for k in keys} 24 | if not evals["score"]: 25 | evals["score"] = sum(evals.values()) / len(evals) 26 | return evals 27 | 28 | @staticmethod 29 | def dataset_demos(dataset: Dataset) -> list[LMDemo]: 30 | # TODO: Should remove and deprecate 31 | return [LMDemo(inputs=item.input, outputs=item.expected_output) for item in dataset.items] 32 | 33 | @staticmethod 34 | def _metric_evaluator_generator( 35 | data: list[LMDemo], 36 | evaluate: Callable[[Outputs, LMDemo, Langfuse], OverallEvalValue], 37 | candidate_evals: MetricEvaluator = default_candidate_evals, 38 | langfuse: Langfuse | None = None, 39 | concurrency: int = 20, 40 | threshold: float = 0.5, 41 | ) -> CandidateEvaluator: 42 | # TODO: this is not the way to run experiment in the langfuse, we should update with the new beta feature 43 | from langfuse import Langfuse 44 | from langfuse.decorators import observe 45 | 46 | langfuse = langfuse or Langfuse() 47 | 48 | def evaluate_candidate(function: LMFunction) -> CandidateEvalResult: 49 | individual_evals = [] 50 | 51 | @observe() 52 | def run_and_evaluate(demo: LMDemo): 53 | nonlocal individual_evals 54 | 55 | outputs = function(demo.inputs) 56 | evals = evaluate(outputs, demo, langfuse=langfuse) 57 | individual_evals.append( 58 | IndividualEvalValue( 59 | passed=evals["score"] >= threshold, 60 | response=outputs, 61 | demo=demo, 62 | score=evals["score"], 63 | ) 64 | ) 65 | return evals 66 | 67 | item_evals = pmap( 68 | run_and_evaluate, 69 | data, 70 | concurrency=concurrency, 71 | ) 72 | candidate_eval = candidate_evals(item_evals) 73 | 74 | return CandidateEvalResult(function, candidate_eval, individual_evals=individual_evals) 75 | 76 | return evaluate_candidate 77 | 78 | @classmethod 79 | def metric_evaluator( 80 | cls, 81 | evalset: Dataset, 82 | evaluate: Callable[[Outputs, LMDemo, Langfuse], OverallEvalValue], 83 | candidate_evals: MetricEvaluator = default_candidate_evals, 84 | langfuse: Langfuse | None = None, 85 | concurrency: int = 20, 86 | threshold: float = 0.5, 87 | ) -> CandidateEvaluator: 88 | # TODO: Should remove and deprecate 89 | from langfuse import Langfuse 90 | from langfuse.decorators import observe 91 | 92 | langfuse = langfuse or Langfuse() 93 | 94 | def evaluate_candidate(function: LMFunction) -> CandidateEvalResult: 95 | @observe() 96 | def run_and_evaluate(demo: LMDemo): 97 | outputs = function(demo.inputs) 98 | evals = evaluate(outputs, demo, langfuse=langfuse) 99 | return evals 100 | 101 | item_evals = 
pmap( 102 | run_and_evaluate, 103 | cls.dataset_demos(evalset), 104 | concurrency=concurrency, 105 | ) 106 | candidate_eval = candidate_evals(item_evals) 107 | 108 | return CandidateEvalResult(function, candidate_eval) 109 | 110 | return evaluate_candidate 111 | 112 | def get_evaluator(self, data: str): 113 | raise NotImplementedError("This method should be implemented by the parent class as it needs access to data") 114 | -------------------------------------------------------------------------------- /py/src/zenbase/adaptors/lunary/__init__.py: -------------------------------------------------------------------------------- 1 | from .adaptor import * # noqa 2 | -------------------------------------------------------------------------------- /py/src/zenbase/adaptors/lunary/adaptor.py: -------------------------------------------------------------------------------- 1 | from zenbase.adaptors.lunary.dataset_helper import LunaryDatasetHelper 2 | from zenbase.adaptors.lunary.evaluation_helper import LunaryEvaluationHelper 3 | 4 | 5 | class ZenLunary(LunaryDatasetHelper, LunaryEvaluationHelper): 6 | def __init__(self, client=None): 7 | LunaryDatasetHelper.__init__(self, client) 8 | LunaryEvaluationHelper.__init__(self, client) 9 | 10 | def get_evaluator(self, data: str): 11 | data = self.fetch_dataset_demos(data) 12 | evaluator_kwargs_to_pass = self.evaluator_kwargs.copy() 13 | evaluator_kwargs_to_pass.update({"data": data}) 14 | return self._metric_evaluator_generator(**evaluator_kwargs_to_pass) 15 | -------------------------------------------------------------------------------- /py/src/zenbase/adaptors/lunary/dataset_helper.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from zenbase.adaptors.base.dataset_helper import BaseDatasetHelper 4 | from zenbase.types import LMDemo 5 | 6 | 7 | class LunaryDatasetHelper(BaseDatasetHelper): 8 | def create_dataset(self, dataset_name: str, *args, **kwargs) -> Any: 9 | raise NotImplementedError("Lunary doesn't support creating datasets") 10 | 11 | def add_examples_to_dataset(self, dataset_id: Any, inputs: list, outputs: list) -> None: 12 | raise NotImplementedError("Lunary doesn't support adding examples to datasets") 13 | 14 | def fetch_dataset_examples(self, dataset_name: str): 15 | return self.client.get_dataset(dataset_name) 16 | 17 | def fetch_dataset_demos(self, dataset_name: str) -> list[LMDemo]: 18 | return [ 19 | LMDemo(inputs=example.input, outputs=example.ideal_output, adaptor_object=example) 20 | for example in self.client.get_dataset(dataset_name) 21 | ] 22 | -------------------------------------------------------------------------------- /py/src/zenbase/adaptors/lunary/evaluation_helper.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable 2 | 3 | import lunary 4 | 5 | from zenbase.adaptors.base.evaluation_helper import BaseEvaluationHelper 6 | from zenbase.optim.metric.types import ( 7 | CandidateEvalResult, 8 | CandidateEvaluator, 9 | IndividualEvalValue, 10 | OverallEvalValue, 11 | ) 12 | from zenbase.types import LMDemo, LMFunction 13 | from zenbase.utils import pmap 14 | 15 | 16 | class LunaryEvaluationHelper(BaseEvaluationHelper): 17 | MetricEvaluator = Callable[[list[tuple[bool, Any]]], OverallEvalValue] 18 | 19 | @staticmethod 20 | def default_metric(batch_results: list[tuple[bool, Any]]) -> OverallEvalValue: 21 | avg_pass = sum(int(passed) for passed, _ in batch_results) / len(batch_results) 22 
| return {"score": avg_pass} 23 | 24 | def get_evaluator(self, data: str): 25 | raise NotImplementedError("This method should be implemented by the parent class as it needs access to data") 26 | 27 | @staticmethod 28 | def _metric_evaluator_generator( 29 | *args, 30 | checklist: str, 31 | data: list[LMDemo], 32 | eval_metrics: MetricEvaluator = default_metric, 33 | concurrency: int = 1, 34 | **kwargs, 35 | ) -> CandidateEvaluator: 36 | # TODO: Should remove and deprecate 37 | def evaluate_metric(function: LMFunction) -> CandidateEvalResult: 38 | individual_evals = [] 39 | 40 | def run_and_evaluate(demo: LMDemo): 41 | nonlocal individual_evals 42 | 43 | response = function(demo.inputs) 44 | result = lunary.evaluate( 45 | checklist, 46 | input=demo.inputs, 47 | output=response, 48 | ideal_output=demo.outputs, 49 | *args, 50 | **kwargs, 51 | ) 52 | 53 | individual_evals.append( 54 | IndividualEvalValue( 55 | details=result[1][0]["details"], 56 | passed=result[0], 57 | response=response, 58 | demo=demo, 59 | ) 60 | ) 61 | 62 | return result 63 | 64 | eval_results = pmap( 65 | run_and_evaluate, 66 | data, 67 | concurrency=concurrency, 68 | ) 69 | 70 | return CandidateEvalResult(function, eval_metrics(eval_results), individual_evals=individual_evals) 71 | 72 | return evaluate_metric 73 | 74 | @classmethod 75 | def dataset_to_demos(cls, dataset: list[lunary.DatasetItem]) -> list[LMDemo]: 76 | # TODO: Should remove and deprecate 77 | return [LMDemo(inputs=item.input, outputs=item.ideal_output, adaptor_object=item) for item in dataset] 78 | 79 | @classmethod 80 | def metric_evaluator( 81 | cls, 82 | *args, 83 | checklist: str, 84 | evalset: list[lunary.DatasetItem], 85 | eval_metrics: MetricEvaluator = default_metric, 86 | concurrency: int = 20, 87 | **kwargs, 88 | ) -> CandidateEvaluator: 89 | # TODO: Should remove and deprecate 90 | def evaluate_metric(function: LMFunction) -> CandidateEvalResult: 91 | individual_evals = [] 92 | 93 | def run_and_evaluate(demo: LMDemo): 94 | nonlocal individual_evals 95 | 96 | item = demo.adaptor_object 97 | 98 | response = function(item.input) 99 | result = lunary.evaluate( 100 | checklist, 101 | input=item.input, 102 | output=response, 103 | ideal_output=item.ideal_output, 104 | *args, 105 | **kwargs, 106 | ) 107 | 108 | individual_evals.append( 109 | IndividualEvalValue( 110 | details=result[1][0]["details"], 111 | passed=result[0], 112 | response=response, 113 | demo=demo, 114 | ) 115 | ) 116 | 117 | return result 118 | 119 | eval_results = pmap( 120 | run_and_evaluate, 121 | cls.dataset_to_demos(evalset), 122 | concurrency=concurrency, 123 | ) 124 | 125 | return CandidateEvalResult(function, eval_metrics(eval_results), individual_evals=individual_evals) 126 | 127 | return evaluate_metric 128 | -------------------------------------------------------------------------------- /py/src/zenbase/adaptors/parea/__init__.py: -------------------------------------------------------------------------------- 1 | from .adaptor import * # noqa 2 | -------------------------------------------------------------------------------- /py/src/zenbase/adaptors/parea/adaptor.py: -------------------------------------------------------------------------------- 1 | __all__ = ["ZenParea"] 2 | 3 | from zenbase.adaptors.parea.dataset_helper import PareaDatasetHelper 4 | from zenbase.adaptors.parea.evaluation_helper import PareaEvaluationHelper 5 | from zenbase.optim.metric.types import CandidateEvaluator 6 | 7 | 8 | class ZenParea(PareaDatasetHelper, PareaEvaluationHelper): 9 | def 
__init__(self, client=None): 10 | PareaDatasetHelper.__init__(self, client) 11 | PareaEvaluationHelper.__init__(self, client) 12 | 13 | def get_evaluator(self, data: str) -> CandidateEvaluator: 14 | evaluator_kwargs_to_pass = self.evaluator_kwargs.copy() 15 | evaluator_kwargs_to_pass.update({"data": self.fetch_dataset_list_of_dicts(data)}) 16 | return self._metric_evaluator_generator(**evaluator_kwargs_to_pass) 17 | -------------------------------------------------------------------------------- /py/src/zenbase/adaptors/parea/dataset_helper.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import ValuesView 3 | 4 | from parea.schemas import TestCase, TestCaseCollection 5 | 6 | from zenbase.adaptors.base.dataset_helper import BaseDatasetHelper 7 | from zenbase.types import LMDemo 8 | 9 | 10 | class PareaDatasetHelper(BaseDatasetHelper): 11 | def create_dataset(self, dataset_name: str): 12 | dataset = self.client.create_test_collection(data=[], name=dataset_name) 13 | return dataset 14 | 15 | def add_examples_to_dataset(self, inputs, outputs, dataset_name: str): 16 | data = [{"inputs": inputs[i], "target": outputs[i]} for i in range(len(inputs))] 17 | self.client.add_test_cases(data, dataset_name) 18 | 19 | def create_dataset_and_add_examples(self, inputs, outputs, dataset_name: str): 20 | data = [{"inputs": inputs[i], "target": outputs[i]} for i in range(len(inputs))] 21 | dataset = self.client.create_test_collection(dataset_name, data) 22 | return dataset 23 | 24 | def fetch_dataset_examples(self, dataset_name: str): 25 | return self.fetch_dataset(dataset_name).test_cases.values() 26 | 27 | def fetch_dataset(self, dataset_name: str) -> TestCaseCollection: 28 | return self.client.get_collection(dataset_name) 29 | 30 | def fetch_dataset_demos(self, dataset_name: str) -> list[LMDemo]: 31 | return self.example_to_demo(self.fetch_dataset_examples(dataset_name)) 32 | 33 | def fetch_dataset_list_of_dicts(self, dataset_name: str) -> list[dict]: 34 | return [ 35 | {"inputs": json.loads(case.inputs["inputs"]), "target": case.target} 36 | for case in self.fetch_dataset_examples(dataset_name) 37 | ] 38 | 39 | @staticmethod 40 | def example_to_demo(examples: ValuesView[TestCase]) -> list[LMDemo]: 41 | return [LMDemo(inputs=example.inputs, outputs={"target": example.target}) for example in examples] 42 | -------------------------------------------------------------------------------- /py/src/zenbase/adaptors/parea/evaluation_helper.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | from dataclasses import asdict 4 | from functools import partial 5 | from inspect import ( 6 | _empty, # noqa 7 | signature, 8 | ) 9 | from json import JSONDecodeError 10 | from typing import Callable 11 | 12 | from langsmith.evaluation._runner import ExperimentResults # noqa 13 | from parea import Parea 14 | from parea.schemas import ExperimentStatsSchema, ListExperimentUUIDsFilters 15 | from tenacity import ( 16 | before_sleep_log, 17 | retry, 18 | stop_after_attempt, 19 | wait_exponential_jitter, 20 | ) 21 | 22 | from zenbase.adaptors.base.evaluation_helper import BaseEvaluationHelper 23 | from zenbase.optim.metric.types import ( 24 | CandidateEvalResult, 25 | CandidateEvaluator, 26 | IndividualEvalValue, 27 | OverallEvalValue, 28 | ) 29 | from zenbase.types import LMDemo, LMFunction 30 | from zenbase.utils import expand_nested_json, get_logger, random_name_generator 31 | 32 | log = 
get_logger(__name__) 33 | 34 | 35 | class PareaEvaluationHelper(BaseEvaluationHelper): 36 | MetricEvaluator = Callable[[dict[str, float]], OverallEvalValue] 37 | 38 | def __init__(self, client=None): 39 | super().__init__(client) 40 | self.evaluator_args = None 41 | self.evaluator_kwargs = None 42 | 43 | def get_evaluator(self, data: str): 44 | pass 45 | 46 | @staticmethod 47 | def default_candidate_evals(stats: ExperimentStatsSchema) -> OverallEvalValue: 48 | return {**stats.avg_scores, "score": sum(stats.avg_scores.values())} 49 | 50 | def _metric_evaluator_generator( 51 | self, 52 | *args, 53 | p: Parea | None = None, 54 | candidate_evals: MetricEvaluator = default_candidate_evals, 55 | **kwargs, 56 | ) -> CandidateEvaluator: 57 | p = p or Parea() 58 | assert isinstance(p, Parea) 59 | 60 | base_metadata = kwargs.pop("metadata", {}) 61 | random_name = random_name_generator(kwargs.pop("name", None))() 62 | 63 | @retry( 64 | stop=stop_after_attempt(3), 65 | wait=wait_exponential_jitter(max=8), 66 | before_sleep=before_sleep_log(log, logging.WARN), 67 | ) 68 | def evaluate_candidate(function: LMFunction) -> CandidateEvalResult: 69 | # Check if the function has a default value for the 'request' parameter 70 | # TODO: Needs clean up, this is not the way to do it. 71 | if ( 72 | "optimized_args_in_fn" in signature(function).parameters 73 | and "request" in signature(function).parameters 74 | and signature(function).parameters["request"].default is _empty 75 | ): 76 | # Create a new function with 'None' as the default value for 'request' 77 | function_with_default = partial(function, request=None) 78 | else: 79 | # If 'request' already has a default value, use the function as is 80 | function_with_default = function 81 | 82 | experiment = p.experiment( 83 | func=function_with_default, 84 | *args, 85 | **kwargs, 86 | name=random_name, 87 | metadata={ 88 | **base_metadata, 89 | }, 90 | ) 91 | 92 | experiment.run() 93 | 94 | if not experiment.experiment_stats: 95 | raise RuntimeError("Failed to run experiment on Parea") 96 | 97 | experiments = p.list_experiments( 98 | ListExperimentUUIDsFilters(experiment_name_filter=experiment.experiment_name) 99 | ) 100 | experiment__uuid = experiments[0].uuid 101 | print(f"Num. 
experiments: {len(experiments)}") 102 | individual_evals = self._experiment_results_to_individual_evals( 103 | experiment.experiment_stats, experiment__uuid=experiment__uuid 104 | ) 105 | return CandidateEvalResult( 106 | function, 107 | evals=candidate_evals(experiment.experiment_stats), 108 | individual_evals=individual_evals, 109 | ) 110 | 111 | return evaluate_candidate 112 | 113 | def _experiment_results_to_individual_evals( 114 | self, 115 | experiment_stats: ExperimentStatsSchema, 116 | threshold=0.5, 117 | score_name="", 118 | experiment__uuid=None, 119 | ) -> list[IndividualEvalValue]: 120 | if experiment_stats is None or experiment__uuid is None: 121 | raise ValueError("experiment_stats and experiment__uuid must not be None") 122 | 123 | individual_evals = [] 124 | # Retrieve the JSON logs for the experiment using its UUID 125 | try: 126 | json_traces = self._get_experiment_logs(experiment__uuid) 127 | except JSONDecodeError: 128 | raise ValueError("Failed to parse experiment logs") 129 | 130 | def find_input_output_with_trace_id(trace_id): 131 | for trace in json_traces: 132 | try: 133 | if trace["trace_id"] == trace_id: 134 | inputs = expand_nested_json(trace["inputs"]) 135 | outputs = expand_nested_json(trace["output"]) 136 | for k, v in inputs.items(): 137 | if isinstance(v, dict) and "zenbase" in v: 138 | return v["inputs"], outputs 139 | except KeyError: 140 | continue 141 | return None, None 142 | 143 | for res in experiment_stats.parent_trace_stats: 144 | # Skip this iteration if there are no scores associated with the current result 145 | if not res.scores: 146 | continue 147 | 148 | # Find the score, prioritizing scores that match the given score name, or defaulting to the first score 149 | score = next((i.score for i in res.scores if score_name and i.name == score_name), res.scores[0].score) 150 | 151 | if not res.trace_id or not json_traces: 152 | raise ValueError("Trace ID or logs not found in experiment results") 153 | 154 | inputs, outputs = find_input_output_with_trace_id(res.trace_id) 155 | 156 | if not inputs or not outputs: 157 | continue 158 | 159 | individual_evals.append( 160 | IndividualEvalValue( 161 | passed=score >= threshold, 162 | response=outputs, 163 | demo=LMDemo(inputs=inputs, outputs=outputs, adaptor_object=res), 164 | score=score, 165 | ) 166 | ) 167 | return individual_evals 168 | 169 | def _get_experiment_logs(self, experiment__uuid): 170 | from parea.client import GET_EXPERIMENT_LOGS_ENDPOINT 171 | 172 | filter_data = {"filter_field": None, "filter_operator": None, "filter_value": None} 173 | endpoint = GET_EXPERIMENT_LOGS_ENDPOINT.format(experiment_uuid=experiment__uuid) 174 | response = self.client._client.request("POST", endpoint, data=filter_data) # noqa 175 | return response.json() 176 | 177 | @classmethod 178 | def metric_evaluator( 179 | cls, 180 | *args, 181 | p: Parea | None = None, 182 | candidate_evals: MetricEvaluator = default_candidate_evals, 183 | **kwargs, 184 | ) -> CandidateEvaluator: 185 | # TODO: should 186 | p = p or Parea() 187 | assert isinstance(p, Parea) 188 | 189 | base_metadata = kwargs.pop("metadata", {}) 190 | gen_random_name = random_name_generator(kwargs.pop("name", None)) 191 | 192 | def evaluate_candidate(function: LMFunction) -> CandidateEvalResult: 193 | experiment = p.experiment( 194 | func=function, 195 | *args, 196 | **kwargs, 197 | name=gen_random_name(), 198 | metadata={ 199 | **base_metadata, 200 | "zenbase": json.dumps(asdict(function.zenbase)), 201 | }, 202 | ) 203 | 204 | experiment.run() 205 | 
assert experiment.experiment_stats, "failed to run experiment" 206 | 207 | return CandidateEvalResult( 208 | function, 209 | evals=candidate_evals(experiment.experiment_stats), 210 | ) 211 | 212 | return evaluate_candidate 213 | -------------------------------------------------------------------------------- /py/src/zenbase/core/managers.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from abc import ABC 3 | from collections import OrderedDict 4 | from contextlib import contextmanager 5 | from typing import Any, Callable, Union 6 | 7 | from zenbase.types import LMFunction, LMZenbase 8 | from zenbase.utils import ksuid 9 | 10 | 11 | class BaseTracer(ABC): 12 | pass 13 | 14 | 15 | class ZenbaseTracer(BaseTracer): 16 | def __init__(self, max_traces=1000): 17 | self.all_traces = OrderedDict() 18 | self.max_traces = max_traces 19 | self.current_trace = None 20 | self.current_key = None 21 | self.optimized_args = {} 22 | 23 | def __call__(self, function: Callable[[Any], Any] = None, zenbase: LMZenbase = None) -> Union[Callable, LMFunction]: 24 | if function is None: 25 | return lambda f: self.trace_function(f, zenbase) 26 | return self.trace_function(function, zenbase) 27 | 28 | def flush(self): 29 | self.all_traces.clear() 30 | 31 | def add_trace(self, run_timestamp: str, func_name: str, trace_data: dict): 32 | is_new_key = run_timestamp not in self.all_traces 33 | if is_new_key and len(self.all_traces) >= self.max_traces: 34 | self.all_traces.popitem(last=False) 35 | traces_for_timestamp = self.all_traces.setdefault(run_timestamp, OrderedDict()) 36 | traces_for_timestamp[func_name] = trace_data 37 | if not is_new_key: 38 | self.all_traces.move_to_end(run_timestamp) 39 | 40 | def trace_function(self, function: Callable[[Any], Any] = None, zenbase: LMZenbase = None) -> LMFunction: 41 | def wrapper(request, lm_function, *args, **kwargs): 42 | func_name = function.__name__ 43 | run_timestamp = ksuid(func_name) 44 | 45 | if self.current_trace is None: 46 | with self.trace_context(func_name, run_timestamp): 47 | return self._execute_and_trace(function, func_name, request, lm_function, *args, **kwargs) 48 | else: 49 | return self._execute_and_trace(function, func_name, request, lm_function, *args, **kwargs) 50 | 51 | return LMFunction(wrapper, zenbase) 52 | 53 | @contextmanager 54 | def trace_context(self, func_name, run_timestamp, optimized_args=None): 55 | if self.current_trace is None: 56 | self.current_trace = {} 57 | self.current_key = run_timestamp 58 | if optimized_args: 59 | self.optimized_args = optimized_args 60 | try: 61 | yield 62 | finally: 63 | if self.current_key == run_timestamp: 64 | self.add_trace(run_timestamp, func_name, self.current_trace) 65 | self.current_trace = None 66 | self.current_key = None 67 | self.optimized_args = {} 68 | 69 | def _execute_and_trace(self, func, func_name, request, lm_function, *args, **kwargs): 70 | # Get the function signature 71 | sig = inspect.signature(func) 72 | 73 | # Map positional args to their names and combine with kwargs 74 | combined_args = {**kwargs} 75 | arg_names = list(sig.parameters.keys())[: len(args)] 76 | combined_args.update(zip(arg_names, args)) 77 | 78 | # Include default values for missing arguments 79 | for param in sig.parameters.values(): 80 | if param.name not in combined_args and param.default is not param.empty: 81 | combined_args[param.name] = param.default 82 | 83 | if func_name in self.optimized_args: 84 | optimized_args = self.optimized_args[func_name]["args"] 
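            # If the optimizer recorded an LMZenbase (e.g. selected few-shot demos) for this
            # function, attach it to the incoming request and strip it from the overrides so
            # it is not forwarded as an ordinary keyword argument.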
85 | if "zenbase" in optimized_args: 86 | request.zenbase = optimized_args["zenbase"] 87 | optimized_args.pop("zenbase", None) 88 | 89 | # Replace with optimized arguments if available 90 | if func_name in self.optimized_args: 91 | optimized_args = self.optimized_args[func_name]["args"] 92 | combined_args.update(optimized_args) 93 | 94 | combined_args.update( 95 | { 96 | "request": request, 97 | } 98 | ) 99 | # Capture input arguments in trace_info 100 | trace_info = {"args": combined_args, "output": None, "request": request, "lm_function": lm_function} 101 | 102 | # Execute the function and capture its output 103 | output = func(**combined_args) 104 | trace_info["output"] = output 105 | 106 | # Store the trace information in the current_trace dictionary 107 | if self.current_trace is not None: 108 | self.current_trace[func_name] = trace_info 109 | 110 | return output 111 | -------------------------------------------------------------------------------- /py/src/zenbase/optim/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from dataclasses import dataclass, field 3 | from random import Random 4 | from typing import Generic 5 | 6 | from pyee.asyncio import AsyncIOEventEmitter 7 | 8 | from zenbase.types import Inputs, LMFunction, Outputs 9 | from zenbase.utils import random_factory 10 | 11 | 12 | @dataclass(kw_only=True) 13 | class LMOptim(Generic[Inputs, Outputs], ABC): 14 | random: Random = field(default_factory=random_factory) 15 | events: AsyncIOEventEmitter = field(default_factory=AsyncIOEventEmitter) 16 | 17 | @abstractmethod 18 | def perform( 19 | self, 20 | lmfn: LMFunction[Inputs, Outputs], 21 | *args, 22 | **kwargs, 23 | ): ... 24 | -------------------------------------------------------------------------------- /py/src/zenbase/optim/metric/bootstrap_few_shot.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from dataclasses import dataclass, field 3 | from datetime import datetime 4 | from functools import partial 5 | from typing import Any, Dict, NamedTuple 6 | 7 | import cloudpickle 8 | 9 | from zenbase.adaptors.langchain import ZenLangSmith 10 | from zenbase.core.managers import ZenbaseTracer 11 | from zenbase.optim.base import LMOptim 12 | from zenbase.optim.metric.labeled_few_shot import LabeledFewShot 13 | from zenbase.optim.metric.types import CandidateEvalResult 14 | from zenbase.types import Inputs, LMDemo, LMFunction, LMZenbase, Outputs 15 | from zenbase.utils import get_logger, ot_tracer 16 | 17 | log = get_logger(__name__) 18 | 19 | 20 | @dataclass(kw_only=True) 21 | class BootstrapFewShot(LMOptim[Inputs, Outputs]): 22 | class Result(NamedTuple): 23 | best_function: LMFunction[Inputs, Outputs] 24 | candidate_results: list[CandidateEvalResult] | None = None 25 | 26 | shots: int = field(default=5) 27 | training_set_demos: list[LMDemo[Inputs, Outputs]] | None = None 28 | training_set: Any = None # TODO: it needs to be more generic and pass our Dataset Object here 29 | test_set: Any = None 30 | validation_set: Any = None 31 | base_evaluation = None 32 | best_evaluation = None 33 | optimizer_args: Dict[str, dict[str, dict[str, LMDemo]]] = field(default_factory=dict) 34 | zen_adaptor: Any = None 35 | evaluator_kwargs: dict[str, Any] = field(default_factory=dict) 36 | 37 | def __post_init__(self): 38 | self.training_set_demos = self.zen_adaptor.fetch_dataset_demos(self.training_set) 39 | 
self.zen_adaptor.set_evaluator_kwargs(**self.evaluator_kwargs) 40 | assert 1 <= self.shots <= len(self.training_set_demos) 41 | 42 | @ot_tracer.start_as_current_span("perform") 43 | def perform( 44 | self, 45 | student_lm: LMFunction[Inputs, Outputs], 46 | teacher_lm: LMFunction[Inputs, Outputs] | None = None, 47 | samples: int = 5, 48 | rounds: int = 1, 49 | trace_manager: ZenbaseTracer = None, 50 | ) -> Result: 51 | """ 52 | This function will perform the bootstrap few shot optimization on the given student_lm function. 53 | It will return the best function that is optimized based on the given student_lm function. 54 | 55 | 56 | :param student_lm: The student function that needs to be optimized 57 | :param teacher_lm: The teacher function that will be used to optimize the student function 58 | :param samples: The number of samples to be used for the optimization 59 | :param rounds: The number of rounds to be used for the optimization in the LabeledFewShot 60 | :param trace_manager: The trace manager that will be used to trace the function 61 | :param helper_class: The helper class that will be used to fetch the dataset and evaluator 62 | """ 63 | assert trace_manager is not None, "Zenbase is required for this operation" 64 | # Clean up traces 65 | trace_manager.flush() 66 | 67 | test_set_evaluator = self.zen_adaptor.get_evaluator(data=self.test_set) 68 | self.base_evaluation = test_set_evaluator(student_lm) 69 | 70 | if not teacher_lm: 71 | # Create the base LabeledFewShot teacher model 72 | trace_manager.flush() 73 | teacher_lm = self._create_teacher_model(self.zen_adaptor, student_lm, samples, rounds) 74 | 75 | # Evaluate and validate the demo set 76 | validated_training_set_demos = self._validate_demo_set(self.zen_adaptor, teacher_lm) 77 | 78 | # Run each validated demo to fill up the traces 79 | trace_manager.flush() 80 | self._run_validated_demos(teacher_lm, validated_training_set_demos) 81 | 82 | # Consolidate the traces to optimized args 83 | optimized_args = self._consolidate_traces_to_optimized_args(trace_manager) 84 | self.set_optimizer_args(optimized_args) 85 | 86 | # Create the optimized function 87 | optimized_fn = self._create_optimized_function(student_lm, optimized_args, trace_manager) 88 | 89 | # Evaluate the optimized function 90 | self.best_evaluation = test_set_evaluator(optimized_fn) 91 | 92 | trace_manager.flush() 93 | return self.Result(best_function=optimized_fn) 94 | 95 | def _create_teacher_model( 96 | self, zen_adaptor: ZenLangSmith, student_lm: LMFunction, samples: int, rounds: int 97 | ) -> LMFunction: 98 | evaluator = zen_adaptor.get_evaluator(data=self.validation_set) 99 | teacher_lm, _, _ = LabeledFewShot(demoset=self.training_set_demos, shots=self.shots).perform( 100 | student_lm, evaluator=evaluator, samples=samples, rounds=rounds 101 | ) 102 | return teacher_lm 103 | 104 | def _validate_demo_set(self, zen_adaptor: ZenLangSmith, teacher_lm: LMFunction) -> list[LMDemo]: 105 | # TODO: here is an issue that we are not removing the actual training set from the task demo 106 | # so it is possible of over fitting but it is not a big issue for now, 107 | # we should remove them in the trace_manager 108 | # def teacher_lm_tweaked(request: LMRequest): 109 | # #check inputs in the task demos 110 | # for demo in request.zenbase.task_demos: 111 | # if request.inputs == demo.inputs: 112 | # request.zenbase.task_demos.pop(demo) 113 | # return teacher_lm(request) 114 | 115 | # get evaluator for the training set 116 | evaluate_demo_set = 
zen_adaptor.get_evaluator(data=self.training_set) 117 | # run the evaluation and get the result of the evaluation 118 | result = evaluate_demo_set(teacher_lm) 119 | # find the validated training set that has been passed 120 | validated_demo_set = [eval.demo for eval in result.individual_evals if eval.passed] 121 | return validated_demo_set 122 | 123 | @staticmethod 124 | def _run_validated_demos(teacher_lm: LMFunction, validated_demo_set: list[LMDemo]) -> None: 125 | """ 126 | Run each of the validated demos to fill up the traces 127 | 128 | :param teacher_lm: The teacher model to run the demos 129 | :param validated_demo_set: The validated demos to run 130 | """ 131 | for validated_demo in validated_demo_set: 132 | teacher_lm(validated_demo.inputs) 133 | 134 | def _consolidate_traces_to_optimized_args( 135 | self, trace_manager: ZenbaseTracer 136 | ) -> dict[str, dict[str, dict[str, LMDemo]]]: 137 | """ 138 | Consolidate the traces to optimized args that will be used to optimize the student function 139 | 140 | :param trace_manager: The trace manager that contains all the traces 141 | """ 142 | all_traces = trace_manager.all_traces 143 | each_function_inputs = {} 144 | 145 | for trace_value in all_traces.values(): 146 | for function_trace in trace_value.values(): 147 | for inside_functions, inside_functions_traces in function_trace.items(): 148 | input_args = inside_functions_traces["args"]["request"].inputs 149 | output_args = inside_functions_traces["output"] 150 | 151 | # Sanitize input and output arguments by replacing curly braces with spaces. 152 | # This prevents conflicts when using these arguments as keys in template rendering within LangChain. 153 | if isinstance(input_args, dict): 154 | input_args = {k: str(v).replace("{", " ").replace("}", " ") for k, v in input_args.items()} 155 | if isinstance(output_args, dict): 156 | output_args = {k: str(v).replace("{", " ").replace("}", " ") for k, v in output_args.items()} 157 | 158 | each_function_inputs.setdefault(inside_functions, []).append( 159 | LMDemo(inputs=input_args, outputs=output_args) 160 | ) 161 | 162 | optimized_args = { 163 | function: {"args": {"zenbase": LMZenbase(task_demos=demos)}} 164 | for function, demos in each_function_inputs.items() 165 | } 166 | return optimized_args 167 | 168 | @staticmethod 169 | def _create_optimized_function( 170 | student_lm: LMFunction, optimized_args: dict, trace_manager: ZenbaseTracer 171 | ) -> LMFunction: 172 | """ 173 | Create the optimized function that will be used to optimize the student function 174 | 175 | :param student_lm: The student function that needs to be optimized 176 | :param optimized_args: The optimized args that will be used to optimize the student function 177 | :param trace_manager: The trace manager that will be used to trace the function 178 | """ 179 | 180 | def optimized_fn_base(request, zenbase, optimized_args_in_fn, trace_manager, *args, **kwargs): 181 | if request is None and "inputs" not in kwargs.keys(): 182 | raise ValueError("Request or inputs should be passed") 183 | elif request is None: 184 | request = kwargs["inputs"] 185 | kwargs.pop("inputs") 186 | 187 | new_optimized_args = deepcopy(optimized_args_in_fn) 188 | with trace_manager.trace_context( 189 | "optimized", f"optimized_layer_0_{datetime.now().isoformat()}", new_optimized_args 190 | ): 191 | if request is None: 192 | return student_lm(*args, **kwargs) 193 | return student_lm(request, *args, **kwargs) 194 | 195 | optimized_fn = partial( 196 | optimized_fn_base, 197 | zenbase=LMZenbase(), # it 
doesn't do anything, it is just for type safety 198 | optimized_args_in_fn=optimized_args, 199 | trace_manager=trace_manager, 200 | ) 201 | return optimized_fn 202 | 203 | def set_optimizer_args(self, args: Dict[str, Any]) -> None: 204 | """ 205 | Set the optimizer arguments. 206 | 207 | :param args: A dictionary containing the optimizer arguments 208 | """ 209 | self.optimizer_args = args 210 | 211 | def get_optimizer_args(self) -> Dict[str, Any]: 212 | """ 213 | Get the current optimizer arguments. 214 | 215 | :return: A dictionary containing the current optimizer arguments 216 | """ 217 | return self.optimizer_args 218 | 219 | def save_optimizer_args(self, file_path: str) -> None: 220 | """ 221 | Save the optimizer arguments to a dill file. 222 | 223 | :param file_path: The path to save the dill file 224 | """ 225 | with open(file_path, "wb") as f: 226 | cloudpickle.dump(self.optimizer_args, f) 227 | 228 | @classmethod 229 | def load_optimizer_and_function( 230 | cls, optimizer_args_file: str, student_lm: LMFunction[Inputs, Outputs], trace_manager: ZenbaseTracer 231 | ) -> LMFunction[Inputs, Outputs]: 232 | """ 233 | Load optimizer arguments and create an optimized function. 234 | 235 | :param optimizer_args_file: The path to the JSON file containing optimizer arguments 236 | :param student_lm: The student function to be optimized 237 | :param trace_manager: The trace manager to be used 238 | :return: An optimized function 239 | """ 240 | optimizer_args = cls._load_optimizer_args(optimizer_args_file) 241 | return cls._create_optimized_function(student_lm, optimizer_args, trace_manager) 242 | 243 | @classmethod 244 | def _load_optimizer_args(cls, file_path: str) -> Dict[str, Any]: 245 | """ 246 | Load optimizer arguments from a dill file. 247 | 248 | :param file_path: The path to load the dill file from 249 | :return: A dictionary containing the loaded optimizer arguments 250 | """ 251 | with open(file_path, "rb") as f: 252 | return cloudpickle.load(f) 253 | -------------------------------------------------------------------------------- /py/src/zenbase/optim/metric/labeled_few_shot.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from math import factorial 3 | from typing import NamedTuple 4 | 5 | from zenbase.optim.base import LMOptim 6 | from zenbase.optim.metric.types import CandidateEvalResult, CandidateEvaluator 7 | from zenbase.types import Inputs, LMDemo, LMFunction, LMZenbase, Outputs 8 | from zenbase.utils import asyncify, get_logger, ksuid, ot_tracer, pmap, posthog 9 | 10 | log = get_logger(__name__) 11 | 12 | 13 | @dataclass(kw_only=True) 14 | class LabeledFewShot(LMOptim[Inputs, Outputs]): 15 | class Result(NamedTuple): 16 | best_function: LMFunction[Inputs, Outputs] 17 | candidate_results: list[CandidateEvalResult] 18 | best_candidate_result: CandidateEvalResult | None 19 | 20 | demoset: list[LMDemo[Inputs, Outputs]] 21 | shots: int = field(default=5) 22 | 23 | def __post_init__(self): 24 | assert 1 <= self.shots <= len(self.demoset) 25 | 26 | @ot_tracer.start_as_current_span("perform") 27 | def perform( 28 | self, 29 | lmfn: LMFunction[Inputs, Outputs], 30 | evaluator: CandidateEvaluator[Inputs, Outputs], 31 | samples: int = 0, 32 | rounds: int = 1, 33 | concurrency: int = 1, 34 | ) -> Result: 35 | samples = samples or len(self.demoset) 36 | 37 | best_score = float("-inf") 38 | best_lmfn = lmfn 39 | best_candidate_result = None 40 | 41 | @ot_tracer.start_as_current_span("run_experiment") 
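        # Evaluate a single candidate: clone the LM function with this LMZenbase (a sampled
        # subset of few-shot demos), score it with the evaluator, and keep the best result.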
42 | def run_candidate_zenbase(zenbase: LMZenbase): 43 | nonlocal best_score, best_lmfn, best_candidate_result 44 | 45 | candidate_fn = lmfn.clean_and_duplicate(zenbase) 46 | try: 47 | candidate_result = evaluator(candidate_fn) 48 | except Exception as e: 49 | log.error("candidate evaluation failed", error=e) 50 | candidate_result = CandidateEvalResult(candidate_fn, {"score": float("-inf")}) 51 | 52 | self.events.emit("candidate", candidate_result) 53 | 54 | if candidate_result.evals["score"] > best_score: 55 | best_score = candidate_result.evals["score"] 56 | best_lmfn = candidate_fn 57 | best_candidate_result = candidate_result 58 | 59 | return candidate_result 60 | 61 | candidates: list[CandidateEvalResult] = [] 62 | for _ in range(rounds): 63 | candidates += pmap( 64 | run_candidate_zenbase, 65 | self.candidates(best_lmfn, samples), 66 | concurrency=concurrency, 67 | ) 68 | 69 | posthog().capture( 70 | distinct_id=ksuid(), 71 | event="optimize_labeled_few_shot", 72 | properties={ 73 | "evals": {c.function.id: c.evals for c in candidates}, 74 | }, 75 | ) 76 | 77 | return self.Result(best_lmfn, candidates, best_candidate_result) 78 | 79 | async def aperform( 80 | self, 81 | lmfn: LMFunction[Inputs, Outputs], 82 | evaluator: CandidateEvaluator[Inputs, Outputs], 83 | samples: int = 0, 84 | rounds: int = 1, 85 | concurrency: int = 1, 86 | ) -> Result: 87 | return await asyncify(self.perform)(lmfn, evaluator, samples, rounds, concurrency) 88 | 89 | def candidates(self, _lmfn: LMFunction[Inputs, Outputs], samples: int): 90 | max_samples = factorial(len(self.demoset)) 91 | if samples > max_samples: 92 | log.warn( 93 | "samples >= factorial(len(demoset)), using factorial(len(demoset))", 94 | max_samples=max_samples, 95 | samples=samples, 96 | ) 97 | samples = max_samples 98 | 99 | for _ in range(samples): 100 | demos = tuple(self.random.sample(self.demoset, k=self.shots)) 101 | yield LMZenbase(task_demos=demos) 102 | -------------------------------------------------------------------------------- /py/src/zenbase/optim/metric/types.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from dataclasses import dataclass, field 3 | from typing import Callable, Generic, TypedDict 4 | 5 | from zenbase.types import Dataclass, Inputs, LMDemo, LMFunction, Outputs 6 | 7 | 8 | class OverallEvalValue(TypedDict): 9 | score: float 10 | 11 | 12 | @dataclasses.dataclass(frozen=True) 13 | class IndividualEvalValue(Dataclass, Generic[Outputs]): 14 | passed: bool 15 | response: Outputs 16 | demo: LMDemo 17 | score: float | None = None 18 | details: dict = field(default_factory=dict) 19 | 20 | 21 | @dataclass 22 | class CandidateEvalResult(Generic[Inputs, Outputs]): 23 | function: LMFunction[Inputs, Outputs] 24 | evals: OverallEvalValue = field(default_factory=dict) 25 | individual_evals: list[IndividualEvalValue] = field(default_factory=list) 26 | 27 | 28 | CandidateEvaluator = Callable[ 29 | [LMFunction[Inputs, Outputs]], 30 | CandidateEvalResult[Inputs, Outputs], 31 | ] 32 | -------------------------------------------------------------------------------- /py/src/zenbase/predefined/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zenbase-ai/core/59971868e85784c54eddb6980cf791b975e0befb/py/src/zenbase/predefined/__init__.py -------------------------------------------------------------------------------- /py/src/zenbase/predefined/base/function_generator.py: 
-------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class BaseLMFunctionGenerator(abc.ABC): 5 | @abc.abstractmethod 6 | def generate(self, *args, **kwargs): ... 7 | -------------------------------------------------------------------------------- /py/src/zenbase/predefined/base/optimizer.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | 3 | 4 | class BasePredefinedOptimizer(ABC): 5 | pass 6 | -------------------------------------------------------------------------------- /py/src/zenbase/predefined/generic_lm_function/optimizer.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Any, Callable, List, NamedTuple, Type 3 | 4 | from instructor.client import Instructor 5 | from pydantic import BaseModel 6 | 7 | from zenbase.adaptors.json.adaptor import JSONAdaptor 8 | from zenbase.core.managers import ZenbaseTracer 9 | from zenbase.optim.metric.labeled_few_shot import LabeledFewShot 10 | from zenbase.optim.metric.types import CandidateEvalResult 11 | from zenbase.types import LMDemo, LMFunction 12 | 13 | 14 | @dataclass 15 | class GenericLMFunctionOptimizer: 16 | class Result(NamedTuple): 17 | best_function: LMFunction 18 | candidate_results: list[CandidateEvalResult] 19 | best_candidate_result: CandidateEvalResult | None 20 | 21 | instructor_client: Instructor 22 | prompt: str 23 | input_model: Type[BaseModel] 24 | output_model: Type[BaseModel] 25 | model: str 26 | zenbase_tracer: ZenbaseTracer 27 | training_set: List[dict] 28 | validation_set: List[dict] 29 | test_set: List[dict] 30 | custom_evaluator: Callable[[Any, dict], dict] = field(default=None) 31 | shots: int = 5 32 | samples: int = 10 33 | last_result: Result | None = field(default=None) 34 | 35 | lm_function: LMFunction = field(init=False) 36 | training_set_demos: List[LMDemo] = field(init=False) 37 | validation_set_demos: List[LMDemo] = field(init=False) 38 | test_set_demos: List[LMDemo] = field(init=False) 39 | best_evaluation: CandidateEvalResult | None = field(default=None) 40 | base_evaluation: CandidateEvalResult | None = field(default=None) 41 | 42 | def __post_init__(self): 43 | self.lm_function = self._generate_lm_function() 44 | self.training_set_demos = self._convert_dataset_to_demos(self.training_set) 45 | self.validation_set_demos = self._convert_dataset_to_demos(self.validation_set) 46 | self.test_set_demos = self._convert_dataset_to_demos(self.test_set) 47 | 48 | def _generate_lm_function(self) -> LMFunction: 49 | @self.zenbase_tracer.trace_function 50 | def generic_function(request): 51 | system_role = "assistant" if self.model.startswith("o1") else "system" 52 | messages = [ 53 | {"role": system_role, "content": self.prompt}, 54 | ] 55 | 56 | if request.zenbase.task_demos: 57 | messages.append({"role": system_role, "content": "Here are some examples:"}) 58 | for demo in request.zenbase.task_demos: 59 | if demo.inputs == request.inputs: 60 | continue 61 | messages.extend( 62 | [ 63 | {"role": "user", "content": str(demo.inputs)}, 64 | {"role": "assistant", "content": str(demo.outputs)}, 65 | ] 66 | ) 67 | messages.append({"role": system_role, "content": "Now, please answer the following question:"}) 68 | 69 | messages.append({"role": "user", "content": str(request.inputs)}) 70 | 71 | kwargs = { 72 | "model": self.model, 73 | "response_model": self.output_model, 74 | "messages": messages, 75 | 
"max_retries": 3, 76 | } 77 | 78 | if not self.model.startswith("o1"): 79 | kwargs.update( 80 | { 81 | "logprobs": True, 82 | "top_logprobs": 5, 83 | } 84 | ) 85 | 86 | return self.instructor_client.chat.completions.create(**kwargs) 87 | 88 | return generic_function 89 | 90 | def _convert_dataset_to_demos(self, dataset: List[dict]) -> List[LMDemo]: 91 | return [LMDemo(inputs=item["inputs"], outputs=item["outputs"]) for item in dataset] 92 | 93 | def optimize(self) -> Result: 94 | evaluator = self.custom_evaluator or self._create_default_evaluator() 95 | test_evaluator = self._create_test_evaluator(evaluator) 96 | 97 | # Perform base evaluation 98 | self.base_evaluation = self._perform_base_evaluation(test_evaluator) 99 | 100 | optimizer = LabeledFewShot(demoset=self.training_set_demos, shots=self.shots) 101 | optimizer_result = optimizer.perform( 102 | self.lm_function, 103 | evaluator=JSONAdaptor.metric_evaluator( 104 | data=self.validation_set_demos, 105 | eval_function=evaluator, 106 | ), 107 | samples=self.samples, 108 | rounds=1, 109 | ) 110 | 111 | # Evaluate best function 112 | self.best_evaluation = self._evaluate_best_function(test_evaluator, optimizer_result) 113 | 114 | self.last_result = self.Result( 115 | best_function=optimizer_result.best_function, 116 | candidate_results=optimizer_result.candidate_results, 117 | best_candidate_result=optimizer_result.best_candidate_result, 118 | ) 119 | 120 | return self.last_result 121 | 122 | def _create_default_evaluator(self): 123 | def evaluator(output: BaseModel, ideal_output: dict) -> dict: 124 | return { 125 | "passed": int(output.model_dump(mode="json") == ideal_output), 126 | } 127 | 128 | return evaluator 129 | 130 | def _create_test_evaluator(self, evaluator): 131 | return JSONAdaptor.metric_evaluator( 132 | data=self.test_set_demos, 133 | eval_function=evaluator, 134 | ) 135 | 136 | def _perform_base_evaluation(self, test_evaluator): 137 | """Perform the base evaluation of the LM function.""" 138 | return test_evaluator(self.lm_function) 139 | 140 | def _evaluate_best_function(self, test_evaluator, optimizer_result): 141 | """Evaluate the best function from the optimization result.""" 142 | return test_evaluator(optimizer_result.best_function) 143 | 144 | def create_lm_function_with_demos(self, prompt: str, demos: List[dict]) -> LMFunction: 145 | @self.zenbase_tracer.trace_function 146 | def lm_function_with_demos(request): 147 | system_role = "assistant" if self.model.startswith("o1") else "system" 148 | messages = [ 149 | {"role": system_role, "content": prompt}, 150 | ] 151 | 152 | # Add demos to the messages 153 | if demos: 154 | messages.append({"role": system_role, "content": "Here are some examples:"}) 155 | for demo in demos: 156 | messages.extend( 157 | [ 158 | {"role": "user", "content": str(demo["inputs"])}, 159 | {"role": "assistant", "content": str(demo["outputs"])}, 160 | ] 161 | ) 162 | messages.append({"role": system_role, "content": "Now, please answer the following question:"}) 163 | 164 | # Add the actual request 165 | messages.append({"role": "user", "content": str(request.inputs)}) 166 | 167 | kwargs = { 168 | "model": self.model, 169 | "response_model": self.output_model, 170 | "messages": messages, 171 | "max_retries": 3, 172 | } 173 | 174 | if not self.model.startswith("o1"): 175 | kwargs.update( 176 | { 177 | "logprobs": True, 178 | "top_logprobs": 5, 179 | } 180 | ) 181 | 182 | return self.instructor_client.chat.completions.create(**kwargs) 183 | 184 | return lm_function_with_demos 185 | 186 | def 
generate_csv_report(self): 187 | if not self.last_result: 188 | raise ValueError("No results to generate report from") 189 | 190 | best_candidate_result = self.last_result.best_candidate_result 191 | base_evaluation = self.base_evaluation 192 | best_evaluation = self.best_evaluation 193 | 194 | list_of_rows = [("type", "input", "ideal_output", "output", "passed", "score", "details")] 195 | 196 | for eval_item in best_candidate_result.individual_evals: 197 | list_of_rows.append( 198 | ( 199 | "best_candidate_result", 200 | eval_item.demo.inputs, 201 | eval_item.demo.outputs, 202 | eval_item.response, 203 | eval_item.passed, 204 | eval_item.score, 205 | eval_item.details, 206 | ) 207 | ) 208 | 209 | for eval_item in base_evaluation.individual_evals: 210 | list_of_rows.append( 211 | ( 212 | "base_evaluation", 213 | eval_item.demo.inputs, 214 | eval_item.demo.outputs, 215 | eval_item.response, 216 | eval_item.passed, 217 | eval_item.score, 218 | eval_item.details, 219 | ) 220 | ) 221 | 222 | for eval_item in best_evaluation.individual_evals: 223 | list_of_rows.append( 224 | ( 225 | "best_evaluation", 226 | eval_item.demo.inputs, 227 | eval_item.demo.outputs, 228 | eval_item.response, 229 | eval_item.passed, 230 | eval_item.score, 231 | eval_item.details, 232 | ) 233 | ) 234 | 235 | # save to csv 236 | import csv 237 | 238 | with open("report.csv", "w", newline="") as file: 239 | writer = csv.writer(file) 240 | writer.writerows(list_of_rows) 241 | -------------------------------------------------------------------------------- /py/src/zenbase/predefined/single_class_classifier/__init__.py: -------------------------------------------------------------------------------- 1 | from .classifier import * # noqa 2 | -------------------------------------------------------------------------------- /py/src/zenbase/predefined/single_class_classifier/classifier.py: -------------------------------------------------------------------------------- 1 | __all__ = ["SingleClassClassifier"] 2 | 3 | from dataclasses import dataclass, field 4 | from enum import Enum 5 | from typing import Any, Dict, NamedTuple, Type 6 | 7 | import cloudpickle 8 | from instructor.client import AsyncInstructor, Instructor 9 | from pydantic import BaseModel 10 | 11 | from zenbase.adaptors.json.adaptor import JSONAdaptor 12 | from zenbase.core.managers import ZenbaseTracer 13 | from zenbase.optim.metric.labeled_few_shot import LabeledFewShot 14 | from zenbase.optim.metric.types import CandidateEvalResult 15 | from zenbase.predefined.base.optimizer import BasePredefinedOptimizer 16 | from zenbase.predefined.single_class_classifier.function_generator import SingleClassClassifierLMFunctionGenerator 17 | from zenbase.predefined.syntethic_data.single_class_classifier import SingleClassClassifierSyntheticDataExample 18 | from zenbase.types import Inputs, LMDemo, LMFunction, Outputs 19 | 20 | 21 | @dataclass(kw_only=True) 22 | class SingleClassClassifier(BasePredefinedOptimizer): 23 | """ 24 | A single-class classifier that optimizes and evaluates language model functions. 
25 | """ 26 | 27 | class Result(NamedTuple): 28 | best_function: LMFunction[Inputs, Outputs] 29 | candidate_results: list[CandidateEvalResult] 30 | best_candidate_result: CandidateEvalResult | None 31 | 32 | instructor_client: Instructor | AsyncInstructor 33 | prompt: str 34 | class_dict: Dict[str, str] | None = field(default=None) 35 | class_enum: Enum | None = field(default=None) 36 | prediction_class: Type[BaseModel] | None = field(default=None) 37 | model: str 38 | zenbase_tracer: ZenbaseTracer 39 | lm_function: LMFunction | None = field(default=None) 40 | training_set: list 41 | test_set: list 42 | validation_set: list 43 | shots: int = 5 44 | samples: int = 10 45 | best_evaluation: CandidateEvalResult | None = field(default=None) 46 | base_evaluation: CandidateEvalResult | None = field(default=None) 47 | optimizer_result: Result | None = field(default=None) 48 | 49 | def __post_init__(self): 50 | """Initialize the SingleClassClassifier after creation.""" 51 | self.lm_function = self._generate_lm_function() 52 | self.training_set_demos = self._convert_dataset_to_demos(self.training_set) 53 | self.test_set_demos = self._convert_dataset_to_demos(self.test_set) 54 | self.validation_set_demos = self._convert_dataset_to_demos(self.validation_set) 55 | 56 | def _generate_lm_function(self) -> LMFunction: 57 | """Generate the language model function.""" 58 | return SingleClassClassifierLMFunctionGenerator( 59 | instructor_client=self.instructor_client, 60 | prompt=self.prompt, 61 | class_dict=self.class_dict, 62 | class_enum=self.class_enum, 63 | prediction_class=self.prediction_class, 64 | model=self.model, 65 | zenbase_tracer=self.zenbase_tracer, 66 | ).generate() 67 | 68 | @staticmethod 69 | def _convert_dataset_to_demos(dataset: list) -> list[LMDemo]: 70 | """Convert a dataset to a list of LMDemo objects.""" 71 | if dataset: 72 | if isinstance(dataset[0], dict): 73 | return [ 74 | LMDemo(inputs={"question": item["inputs"]}, outputs={"answer": item["outputs"]}) for item in dataset 75 | ] 76 | elif isinstance(dataset[0], SingleClassClassifierSyntheticDataExample): 77 | return [LMDemo(inputs={"question": item.inputs}, outputs={"answer": item.outputs}) for item in dataset] 78 | 79 | def load_classifier(self, filename: str): 80 | with open(filename, "rb") as f: 81 | lm_zenbase = cloudpickle.load(f) 82 | return self.lm_function.clean_and_duplicate(lm_zenbase) 83 | 84 | def optimize(self) -> Result: 85 | """ 86 | Perform the optimization and evaluation of the language model function. 87 | 88 | Returns: 89 | Result: The optimization result containing the best function and evaluation metrics. 
90 | """ 91 | # Define the evaluation function 92 | evaluator = self._create_evaluator() 93 | 94 | # Create test evaluator 95 | test_evaluator = self._create_test_evaluator(evaluator) 96 | 97 | # Perform base evaluation 98 | self.base_evaluation = self._perform_base_evaluation(test_evaluator) 99 | 100 | # Create and run optimizer 101 | optimizer_result = self._run_optimization(evaluator) 102 | 103 | # Evaluate best function 104 | self.best_evaluation = self._evaluate_best_function(test_evaluator, optimizer_result) 105 | 106 | # Save last optimizer_result 107 | self.optimizer_result = optimizer_result 108 | 109 | return optimizer_result 110 | 111 | @staticmethod 112 | def _create_evaluator(): 113 | """Create the evaluation function.""" 114 | 115 | def evaluator(output: Any, ideal_output: Dict[str, Any]) -> Dict[str, int]: 116 | return { 117 | "passed": int(ideal_output["answer"] == output.class_label.name), 118 | } 119 | 120 | return evaluator 121 | 122 | def _create_test_evaluator(self, evaluator): 123 | """Create the test evaluator using JSONAdaptor.""" 124 | return JSONAdaptor.metric_evaluator( 125 | data=self.validation_set_demos, 126 | eval_function=evaluator, 127 | ) 128 | 129 | def _perform_base_evaluation(self, test_evaluator): 130 | """Perform the base evaluation of the LM function.""" 131 | return test_evaluator(self.lm_function) 132 | 133 | def _run_optimization(self, evaluator): 134 | """Run the optimization process.""" 135 | optimizer = LabeledFewShot(demoset=self.training_set_demos, shots=self.shots) 136 | return optimizer.perform( 137 | self.lm_function, 138 | evaluator=JSONAdaptor.metric_evaluator( 139 | data=self.validation_set_demos, 140 | eval_function=evaluator, 141 | ), 142 | samples=self.samples, 143 | rounds=1, 144 | ) 145 | 146 | def _evaluate_best_function(self, test_evaluator, optimizer_result): 147 | """Evaluate the best function from the optimization result.""" 148 | return test_evaluator(optimizer_result.best_function) 149 | -------------------------------------------------------------------------------- /py/src/zenbase/predefined/single_class_classifier/function_generator.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module provides a SingleClassClassifierLMFunctionGenerator for generating language model functions. 3 | """ 4 | 5 | import logging 6 | from dataclasses import dataclass, field 7 | from enum import Enum 8 | from typing import Dict, Optional, Type 9 | 10 | from instructor.client import AsyncInstructor, Instructor 11 | from pydantic import BaseModel 12 | from tenacity import ( 13 | before_sleep_log, 14 | retry, 15 | stop_after_attempt, 16 | wait_exponential_jitter, 17 | ) 18 | 19 | from zenbase.core.managers import ZenbaseTracer 20 | from zenbase.predefined.base.function_generator import BaseLMFunctionGenerator 21 | from zenbase.types import LMFunction, LMRequest 22 | 23 | log = logging.getLogger(__name__) 24 | 25 | 26 | @dataclass(kw_only=True) 27 | class SingleClassClassifierLMFunctionGenerator(BaseLMFunctionGenerator): 28 | """ 29 | A generator for creating single-class classifier language model functions. 
30 | """ 31 | 32 | instructor_client: Instructor | AsyncInstructor 33 | prompt: str 34 | class_dict: Optional[Dict[str, str]] = field(default=None) 35 | class_enum: Optional[Enum] = field(default=None) 36 | prediction_class: Optional[Type[BaseModel]] = field(default=None) 37 | model: str 38 | zenbase_tracer: ZenbaseTracer 39 | 40 | def __post_init__(self): 41 | """Initialize the generator after creation.""" 42 | self._initialize_class_enum() 43 | self._initialize_prediction_class() 44 | 45 | def _initialize_class_enum(self): 46 | """Initialize the class enum if not provided.""" 47 | if not self.class_enum and self.class_dict: 48 | self.class_enum = self._generate_class_enum() 49 | 50 | def _initialize_prediction_class(self): 51 | """Initialize the prediction class if not provided.""" 52 | if not self.prediction_class and self.class_enum: 53 | self.prediction_class = self._generate_prediction_class() 54 | 55 | def generate(self) -> LMFunction: 56 | """Generate the classifier language model function.""" 57 | return self._generate_classifier_prompt_lm_function() 58 | 59 | def _generate_class_enum(self) -> Enum: 60 | """Generate the class enum from the class dictionary.""" 61 | return Enum("Labels", self.class_dict) 62 | 63 | def _generate_prediction_class(self) -> Type[BaseModel]: 64 | """Generate the prediction class based on the class enum.""" 65 | class_enum = self.class_enum 66 | 67 | class SinglePrediction(BaseModel): 68 | reasoning: str 69 | class_label: class_enum 70 | 71 | return SinglePrediction 72 | 73 | def _generate_classifier_prompt_lm_function(self) -> LMFunction: 74 | """Generate the classifier prompt language model function.""" 75 | 76 | @retry( 77 | stop=stop_after_attempt(3), 78 | wait=wait_exponential_jitter(max=8), 79 | before_sleep=before_sleep_log(log, logging.WARN), 80 | ) 81 | def classifier_function(request: LMRequest): 82 | categories = "\n".join([f"- {key.upper()}: {value}" for key, value in self.class_dict.items()]) 83 | messages = [ 84 | { 85 | "role": "system", 86 | "content": f"""You are an expert classifier. Your task is to categorize inputs accurately based 87 | on the following instructions: {self.prompt} 88 | 89 | Categories and their descriptions: 90 | {categories} 91 | 92 | Rules: 93 | 1. Analyze the input carefully. 94 | 2. Choose the most appropriate category based on the descriptions provided. 95 | 3. Respond with ONLY the category name in UPPERCASE. 96 | 4. 
If unsure, choose the category that best fits the input.""", 97 | } 98 | ] 99 | 100 | if request.zenbase.task_demos: 101 | messages.append({"role": "system", "content": "Here are some examples of classifications:"}) 102 | for demo in request.zenbase.task_demos: 103 | messages.extend( 104 | [ 105 | {"role": "user", "content": demo.inputs["question"]}, 106 | {"role": "assistant", "content": demo.outputs["answer"]}, 107 | ] 108 | ) 109 | messages.append( 110 | {"role": "system", "content": "Now, classify the new input following the same pattern."} 111 | ) 112 | 113 | messages.extend( 114 | [ 115 | {"role": "system", "content": "Please classify the following input:"}, 116 | {"role": "user", "content": str(request.inputs)}, 117 | ] 118 | ) 119 | 120 | return self.instructor_client.chat.completions.create( 121 | model=self.model, response_model=self.prediction_class, messages=messages 122 | ) 123 | 124 | return self.zenbase_tracer.trace_function(classifier_function) 125 | -------------------------------------------------------------------------------- /py/src/zenbase/predefined/syntethic_data/single_class_classifier.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import io 3 | from typing import Dict, List 4 | 5 | from instructor import Instructor 6 | from pydantic import BaseModel, Field 7 | 8 | 9 | class SingleClassClassifierSyntheticDataExample(BaseModel): 10 | inputs: str = Field(..., description="The input text for single class classification") 11 | outputs: str = Field(..., description="The correct classification category") 12 | 13 | 14 | class SingleClassClassifierSyntheticDataGenerator: 15 | def __init__( 16 | self, 17 | instructor_client: Instructor, 18 | prompt: str, 19 | class_dict: Dict[str, str], 20 | model: str = "gpt-4o-mini", 21 | ): 22 | self.instructor_client = instructor_client 23 | self.prompt = prompt 24 | self.class_dict = class_dict 25 | self.model = model 26 | 27 | def generate_examples_for_category( 28 | self, category: str, description: str, num_examples: int 29 | ) -> List[SingleClassClassifierSyntheticDataExample]: 30 | messages = [ 31 | { 32 | "role": "system", 33 | "content": f"""You are an expert in generating synthetic datasets for single class classification 34 | tasks. Your task is to create diverse and realistic examples based on the following instructions: 35 | 36 | {self.prompt} 37 | 38 | You are focusing on generating examples for the following category: 39 | - {category}: {description} 40 | 41 | For each example, generate: 42 | 1. A realistic and diverse input text that should be classified into the given category. 43 | 2. The category name as the output. 
44 | 45 | Ensure diversity in the generated examples.""", 46 | }, 47 | {"role": "user", "content": f"Generate {num_examples} examples for the category '{category}'."}, 48 | ] 49 | 50 | response = self.instructor_client.chat.completions.create( 51 | model=self.model, response_model=List[SingleClassClassifierSyntheticDataExample], messages=messages 52 | ) 53 | 54 | return response 55 | 56 | def generate_examples(self, examples_per_category: int) -> List[SingleClassClassifierSyntheticDataExample]: 57 | all_examples = [] 58 | for category, description in self.class_dict.items(): 59 | category_examples = self.generate_examples_for_category(category, description, examples_per_category) 60 | all_examples.extend(category_examples) 61 | return all_examples 62 | 63 | def generate_csv(self, examples_per_category: int) -> str: 64 | examples = self.generate_examples(examples_per_category) 65 | 66 | output = io.StringIO() 67 | writer = csv.DictWriter(output, fieldnames=["inputs", "outputs"]) 68 | writer.writeheader() 69 | for example in examples: 70 | writer.writerow(example.dict()) 71 | 72 | return output.getvalue() 73 | 74 | def save_csv(self, filename: str, examples_per_category: int): 75 | csv_content = self.generate_csv(examples_per_category) 76 | with open(filename, "w", newline="", encoding="utf-8") as f: 77 | f.write(csv_content) 78 | -------------------------------------------------------------------------------- /py/src/zenbase/settings.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | # Define the base directory as the parent of the parent directory of this file 4 | BASE_DIR = Path(__file__).resolve().parent.parent 5 | 6 | # Define the test directory as a subdirectory of the base directory's parent 7 | TEST_DIR = BASE_DIR.parent / "tests" 8 | -------------------------------------------------------------------------------- /py/src/zenbase/types.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import inspect 3 | import json 4 | from collections import deque 5 | from copy import copy 6 | from functools import partial 7 | from typing import Awaitable, Callable, Generic, TypeVar, Union, get_origin 8 | 9 | from zenbase.utils import asyncify, ksuid_generator 10 | 11 | 12 | class Dataclass: 13 | """ 14 | Modified from Braintrust's SerializableDataClass 15 | """ 16 | 17 | def copy(self, **changes): 18 | return dataclasses.replace(self, **changes) 19 | 20 | def as_dict(self): 21 | """Serialize the object to a dictionary.""" 22 | return dataclasses.asdict(self) 23 | 24 | def as_json(self, **kwargs): 25 | """Serialize the object to JSON.""" 26 | return json.dumps(self.as_dict(), **kwargs) 27 | 28 | @classmethod 29 | def from_dict(cls, d: dict): 30 | """Deserialize the object from a dictionary. This method 31 | is shallow and will not call from_dict() on nested objects.""" 32 | fields = set(f.name for f in dataclasses.fields(cls)) 33 | filtered = {k: v for k, v in d.items() if k in fields} 34 | return cls(**filtered) 35 | 36 | @classmethod 37 | def from_dict_deep(cls, d: dict): 38 | """Deserialize the object from a dictionary. 
This method 39 | is deep and will call from_dict_deep() on nested objects.""" 40 | fields = {f.name: f for f in dataclasses.fields(cls)} 41 | filtered = {} 42 | for k, v in d.items(): 43 | if k not in fields: 44 | continue 45 | 46 | if isinstance(v, dict) and isinstance(fields[k].type, type) and issubclass(fields[k].type, Dataclass): 47 | filtered[k] = fields[k].type.from_dict_deep(v) 48 | elif get_origin(fields[k].type) == Union: 49 | for t in fields[k].type.__args__: 50 | if isinstance(t, type) and issubclass(t, Dataclass): 51 | try: 52 | filtered[k] = t.from_dict_deep(v) 53 | break 54 | except TypeError: 55 | pass 56 | else: 57 | filtered[k] = v 58 | elif ( 59 | isinstance(v, list) 60 | and get_origin(fields[k].type) == list 61 | and len(fields[k].type.__args__) == 1 62 | and isinstance(fields[k].type.__args__[0], type) 63 | and issubclass(fields[k].type.__args__[0], Dataclass) 64 | ): 65 | filtered[k] = [fields[k].type.__args__[0].from_dict_deep(i) for i in v] 66 | else: 67 | filtered[k] = v 68 | return cls(**filtered) 69 | 70 | 71 | Inputs = TypeVar("Inputs", covariant=True, bound=dict) 72 | Outputs = TypeVar("Outputs", covariant=True, bound=dict) 73 | 74 | 75 | @dataclasses.dataclass(frozen=True) 76 | class LMDemo(Dataclass, Generic[Inputs, Outputs]): 77 | inputs: Inputs 78 | outputs: Outputs 79 | adaptor_object: object | None = None 80 | 81 | def __hash__(self): 82 | def make_hashable(obj): 83 | if isinstance(obj, dict): 84 | return tuple(sorted((k, make_hashable(v)) for k, v in obj.items())) 85 | elif isinstance(obj, list): 86 | return tuple(make_hashable(x) for x in obj) 87 | elif isinstance(obj, set): 88 | return tuple(sorted(make_hashable(x) for x in obj)) 89 | else: 90 | return obj 91 | return hash((make_hashable(self.inputs), make_hashable(self.outputs))) 92 | 93 | 94 | @dataclasses.dataclass(frozen=True) 95 | class LMZenbase(Dataclass, Generic[Inputs, Outputs]): 96 | task_demos: list[LMDemo[Inputs, Outputs]] = dataclasses.field(default_factory=list) 97 | model_params: dict = dataclasses.field(default_factory=dict) # OpenAI-compatible model params 98 | 99 | 100 | @dataclasses.dataclass() 101 | class LMRequest(Dataclass, Generic[Inputs, Outputs]): 102 | zenbase: LMZenbase[Inputs, Outputs] 103 | inputs: Inputs = dataclasses.field(default_factory=dict) 104 | id: str = dataclasses.field(default_factory=ksuid_generator("request")) 105 | 106 | 107 | @dataclasses.dataclass(frozen=True) 108 | class LMResponse(Dataclass, Generic[Outputs]): 109 | outputs: Outputs 110 | attributes: dict = dataclasses.field(default_factory=dict) # token_count, cost, inference_time, etc. 
111 | id: str = dataclasses.field(default_factory=ksuid_generator("response")) 112 | 113 | 114 | @dataclasses.dataclass(frozen=True) 115 | class LMCall(Dataclass, Generic[Inputs, Outputs]): 116 | function: "LMFunction[Inputs, Outputs]" 117 | request: LMRequest[Inputs, Outputs] 118 | response: LMResponse[Outputs] 119 | id: str = dataclasses.field(default_factory=ksuid_generator("call")) 120 | 121 | 122 | class LMFunction(Generic[Inputs, Outputs]): 123 | gen_id = staticmethod(ksuid_generator("fn")) 124 | 125 | id: str 126 | fn: Callable[[LMRequest[Inputs, Outputs]], Outputs | Awaitable[Outputs]] 127 | __name__: str 128 | __qualname__: str 129 | __doc__: str 130 | __signature__: inspect.Signature 131 | zenbase: LMZenbase[Inputs, Outputs] 132 | history: deque[LMCall[Inputs, Outputs]] 133 | 134 | def __init__( 135 | self, 136 | fn: Callable[[LMRequest[Inputs, Outputs]], Outputs | Awaitable[Outputs]], 137 | zenbase: LMZenbase | None = None, 138 | maxhistory: int = 100, 139 | ): 140 | self.fn = fn 141 | 142 | if qualname := getattr(fn, "__qualname__", None): 143 | self.id = qualname 144 | self.__qualname__ = qualname 145 | else: 146 | self.id = self.gen_id() 147 | self.__qualname__ = f"zenbase_{self.id}" 148 | 149 | self.__name__ = getattr(fn, "__name__", f"zenbase_{self.id}") 150 | 151 | self.__doc__ = getattr(fn, "__doc__", "") 152 | self.__signature__ = inspect.signature(fn) 153 | 154 | self.zenbase = zenbase or LMZenbase() 155 | self.history = deque([], maxlen=maxhistory) 156 | 157 | def clean_and_duplicate(self, zenbase: LMZenbase | None = None) -> "LMFunction[Inputs, Outputs]": 158 | dup = copy(self) 159 | dup.id = self.gen_id() 160 | dup.zenbase = zenbase or self.zenbase.copy() 161 | dup.history = deque([], maxlen=self.history.maxlen) 162 | return dup 163 | 164 | def prepare_request(self, inputs: Inputs) -> LMRequest[Inputs, Outputs]: 165 | return LMRequest(zenbase=self.zenbase, inputs=inputs) 166 | 167 | def process_response( 168 | self, 169 | request: LMRequest[Inputs, Outputs], 170 | outputs: Outputs, 171 | ) -> Outputs: 172 | self.history.append(LMCall(self, request, LMResponse(outputs))) 173 | return outputs 174 | 175 | def __call__(self, inputs: Inputs, *args, **kwargs) -> Outputs: 176 | request = self.prepare_request(inputs) 177 | kwargs.update({"lm_function": self} if "lm_function" in inspect.signature(self.fn).parameters else {}) 178 | response = self.fn(request, *args, **kwargs) 179 | return self.process_response(request, response) 180 | 181 | async def coro( 182 | self, 183 | inputs: Inputs, 184 | *args, 185 | **kwargs, 186 | ) -> Outputs: 187 | request = self.prepare_request(inputs) 188 | response = await asyncify(self.fn)(request, *args, **kwargs) 189 | return self.process_response(request, response) 190 | 191 | 192 | def deflm( 193 | function: (Callable[[LMRequest[Inputs, Outputs]], Outputs | Awaitable[Outputs]] | None) = None, 194 | zenbase: LMZenbase[Inputs, Outputs] | None = None, 195 | ) -> LMFunction[Inputs, Outputs]: 196 | if function is None: 197 | return partial(deflm, zenbase=zenbase) 198 | 199 | if isinstance(function, LMFunction): 200 | return function.clean_and_duplicate(zenbase) 201 | 202 | return LMFunction(function, zenbase) 203 | -------------------------------------------------------------------------------- /py/src/zenbase/utils.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import functools 3 | import inspect 4 | import json 5 | import logging 6 | import os 7 | from random import Random 8 | from 
typing import AsyncIterable, Awaitable, Callable, ParamSpec, TypeVar 9 | 10 | import anyio 11 | from anyio._core._eventloop import threadlocals 12 | from faker import Faker 13 | from opentelemetry import trace 14 | from pksuid import PKSUID 15 | from posthog import Posthog 16 | from structlog import get_logger 17 | 18 | get_logger: Callable[..., logging.Logger] = get_logger 19 | ot_tracer = trace.get_tracer("zenbase") 20 | 21 | 22 | def posthog() -> Posthog: 23 | if project_api_key := os.getenv("ZENBASE_ANALYTICS_KEY"): 24 | client = Posthog( 25 | project_api_key=project_api_key, 26 | host="https://us.i.posthog.com", 27 | ) 28 | client.identify(os.environ["ZENBASE_ANALYTICS_ID"]) 29 | else: 30 | client = Posthog("") 31 | client.disabled = True 32 | return client 33 | 34 | 35 | def get_seed(seed: int | None = None) -> int: 36 | return seed or int(os.getenv("RANDOM_SEED", 42)) 37 | 38 | 39 | def random_factory(seed: int | None = None) -> Random: 40 | return Random(get_seed(seed)) 41 | 42 | 43 | def ksuid(prefix: str | None = None) -> str: 44 | return str(PKSUID(prefix)) 45 | 46 | 47 | def ksuid_generator(prefix: str) -> Callable[[], str]: 48 | return functools.partial(ksuid, prefix) 49 | 50 | 51 | def random_name_generator( 52 | prefix: str | None = None, 53 | random_name_generator=Faker().catch_phrase, 54 | ) -> Callable[[], str]: 55 | head = f"zenbase-{prefix}" if prefix else "zenbase" 56 | 57 | def gen(): 58 | return "-".join([head, *random_name_generator().lower().split(" ")[:2]]) 59 | 60 | return gen 61 | 62 | 63 | I_ParamSpec = ParamSpec("I_ParamSpec") 64 | O_Retval = TypeVar("O_Retval") 65 | 66 | 67 | def asyncify( 68 | func: Callable[I_ParamSpec, O_Retval], 69 | *, 70 | cancellable: bool = True, 71 | limiter: anyio.CapacityLimiter | None = None, 72 | ) -> Callable[I_ParamSpec, Awaitable[O_Retval]]: 73 | if inspect.iscoroutinefunction(func): 74 | return func 75 | 76 | @functools.wraps(func) 77 | async def wrapper(*args: I_ParamSpec.args, **kwargs: I_ParamSpec.kwargs) -> O_Retval: 78 | partial_f = functools.partial(func, *args, **kwargs) 79 | return await anyio.to_thread.run_sync( 80 | partial_f, 81 | abandon_on_cancel=cancellable, 82 | limiter=limiter, 83 | ) 84 | 85 | return wrapper 86 | 87 | 88 | def syncify( 89 | func: Callable[I_ParamSpec, O_Retval], 90 | ) -> Callable[I_ParamSpec, O_Retval]: 91 | if not inspect.iscoroutinefunction(func): 92 | return func 93 | 94 | @functools.wraps(func) 95 | def wrapper(*args: I_ParamSpec.args, **kwargs: I_ParamSpec.kwargs) -> O_Retval: 96 | partial_f = functools.partial(func, *args, **kwargs) 97 | if not getattr(threadlocals, "current_async_backend", None): 98 | try: 99 | return asyncio.get_running_loop().run_until_complete(partial_f()) 100 | except RuntimeError: 101 | return anyio.run(partial_f) 102 | return anyio.from_thread.run(partial_f) 103 | 104 | return wrapper 105 | 106 | 107 | ReturnValue = TypeVar("ReturnValue", covariant=True) 108 | 109 | 110 | async def amap( 111 | func: Callable[..., Awaitable[ReturnValue]], 112 | iterable, 113 | *iterables, 114 | concurrency=10, 115 | ) -> list[ReturnValue]: 116 | assert concurrency >= 1, "Concurrency must be greater than or equal to 1" 117 | 118 | if concurrency == 1: 119 | return [await func(*args) for args in zip(iterable, *iterables)] 120 | 121 | if concurrency == float("inf"): 122 | return await asyncio.gather(*[func(*args) for args in zip(iterable, *iterables)]) 123 | 124 | semaphore = asyncio.Semaphore(concurrency) 125 | 126 | @functools.wraps(func) 127 | async def mapper(*args): 128 | async 
with semaphore: 129 | return await func(*args) 130 | 131 | return await asyncio.gather(*[mapper(*args) for args in zip(iterable, *iterables)]) 132 | 133 | 134 | def pmap( 135 | func: Callable[..., ReturnValue], 136 | iterable, 137 | *iterables, 138 | concurrency=10, 139 | ) -> list[ReturnValue]: 140 | # TODO: Should revert. 141 | return [func(*args) for args in zip(iterable, *iterables)] 142 | 143 | 144 | async def alist(aiterable: AsyncIterable[ReturnValue]) -> list[ReturnValue]: 145 | return [x async for x in aiterable] 146 | 147 | 148 | def expand_nested_json(d): 149 | def recursive_expand(value): 150 | if isinstance(value, str): 151 | try: 152 | # Try to parse the string value as JSON 153 | parsed_value = json.loads(value) 154 | # Recursively expand the parsed value in case it contains further nested JSON 155 | return recursive_expand(parsed_value) 156 | except json.JSONDecodeError: 157 | # If parsing fails, return the original string 158 | return value 159 | elif isinstance(value, dict): 160 | # Recursively expand each key-value pair in the dictionary 161 | return {k: recursive_expand(v) for k, v in value.items()} 162 | elif isinstance(value, list): 163 | # Recursively expand each element in the list 164 | return [recursive_expand(elem) for elem in value] 165 | else: 166 | return value 167 | 168 | return recursive_expand(d) 169 | -------------------------------------------------------------------------------- /py/tests/adaptors/bootstrap_few_shot_optimizer_args.zenbase: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zenbase-ai/core/59971868e85784c54eddb6980cf791b975e0befb/py/tests/adaptors/bootstrap_few_shot_optimizer_args.zenbase -------------------------------------------------------------------------------- /py/tests/adaptors/parea_bootstrap_few_shot.zenbase: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zenbase-ai/core/59971868e85784c54eddb6980cf791b975e0befb/py/tests/adaptors/parea_bootstrap_few_shot.zenbase -------------------------------------------------------------------------------- /py/tests/adaptors/test_braintrust.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zenbase-ai/core/59971868e85784c54eddb6980cf791b975e0befb/py/tests/adaptors/test_braintrust.py -------------------------------------------------------------------------------- /py/tests/adaptors/test_lunary.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import lunary 4 | import pytest 5 | from openai import OpenAI 6 | from tenacity import ( 7 | before_sleep_log, 8 | retry, 9 | stop_after_attempt, 10 | wait_exponential_jitter, 11 | ) 12 | 13 | from zenbase.adaptors.lunary import ZenLunary 14 | from zenbase.core.managers import ZenbaseTracer 15 | from zenbase.optim.metric.bootstrap_few_shot import BootstrapFewShot 16 | from zenbase.optim.metric.labeled_few_shot import LabeledFewShot 17 | from zenbase.types import LMRequest 18 | 19 | SAMPLES = 2 20 | SHOTS = 3 21 | TESTSET_SIZE = 5 22 | 23 | log = logging.getLogger(__name__) 24 | 25 | 26 | @pytest.fixture 27 | def optim(gsm8k_demoset: list): 28 | return LabeledFewShot(demoset=gsm8k_demoset, shots=SHOTS) 29 | 30 | 31 | @pytest.fixture 32 | def bootstrap_few_shot_optim(gsm8k_demoset: list): 33 | return BootstrapFewShot(training_set_demos=gsm8k_demoset, shots=SHOTS) 34 | 35 | 36 | 
@pytest.fixture(scope="module") 37 | def openai(): 38 | client = OpenAI() 39 | lunary.monitor(client) 40 | return client 41 | 42 | 43 | @pytest.fixture(scope="module") 44 | def evalset(): 45 | items = lunary.get_dataset("gsm8k-evalset") 46 | assert any(items) 47 | return items 48 | 49 | 50 | @pytest.mark.helpers 51 | def test_lunary_lcel_labeled_few_shot(optim: LabeledFewShot, evalset: list): 52 | trace_manager = ZenbaseTracer() 53 | 54 | @trace_manager.trace_function 55 | @retry( 56 | stop=stop_after_attempt(3), 57 | wait=wait_exponential_jitter(max=8), 58 | before_sleep=before_sleep_log(log, logging.WARN), 59 | ) 60 | def langchain_chain(request: LMRequest): 61 | """ 62 | A math solver LLM call that can solve any math problem, built with the LangChain library. 63 | """ 64 | 65 | from langchain_core.output_parsers import StrOutputParser 66 | from langchain_core.prompts import ChatPromptTemplate 67 | from langchain_openai import ChatOpenAI 68 | 69 | messages = [ 70 | ( 71 | "system", 72 | "You are an expert math solver. Your answer must be just the number with no separators, and nothing else. Follow the format of the examples.", # noqa 73 | # noqa 74 | ) 75 | ] 76 | for demo in request.zenbase.task_demos: 77 | messages += [ 78 | ("user", demo.inputs["question"]), 79 | ("assistant", demo.outputs["answer"]), 80 | ] 81 | 82 | messages.append(("user", "{question}")) 83 | 84 | chain = ChatPromptTemplate.from_messages(messages) | ChatOpenAI(model="gpt-4o-mini") | StrOutputParser() 85 | 86 | print("Mathing...") 87 | answer = chain.invoke(request.inputs) 88 | return answer.split("#### ")[-1] 89 | 90 | fn, candidates, _ = optim.perform( 91 | langchain_chain, 92 | evaluator=ZenLunary.metric_evaluator( 93 | checklist="exact-match", 94 | evalset=evalset, 95 | concurrency=2, 96 | ), 97 | samples=SAMPLES, 98 | rounds=1, 99 | ) 100 | 101 | assert fn is not None 102 | assert any(candidates) 103 | assert next(c for c in candidates if 0.5 <= c.evals["score"] <= 1) 104 | 105 | 106 | @pytest.fixture(scope="module") 107 | def lunary_helper(): 108 | return ZenLunary(client=lunary) 109 | -------------------------------------------------------------------------------- /py/tests/conftest.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any, Callable 3 | 4 | import pytest 5 | from datasets import DatasetDict 6 | 7 | from zenbase.types import LMDemo 8 | 9 | 10 | def pytest_configure(): 11 | from pathlib import Path 12 | 13 | if not os.getenv("CI"): 14 | from dotenv import load_dotenv 15 | 16 | load_dotenv(str(Path(__file__).parent.parent / ".env.test")) 17 | 18 | import nest_asyncio 19 | 20 | nest_asyncio.apply() 21 | 22 | 23 | @pytest.fixture 24 | def anyio_backend(): 25 | return "asyncio" 26 | 27 | 28 | @pytest.fixture(scope="session", autouse=True) 29 | def vcr_config(): 30 | def response_processor(exclude_headers: list[str]) -> Callable: 31 | def before_record_response(response: dict | Any) -> dict | Any: 32 | if isinstance(response, dict): 33 | try: 34 | response_str = response.get("body", {}).get("string", b"").decode("utf-8") 35 | if "Rate limit reached for" in response_str: 36 | # don't record rate-limiting responses 37 | return None 38 | except UnicodeDecodeError: 39 | pass # ignore if we can't parse response 40 | 41 | for header in exclude_headers: 42 | if header in response["headers"]: 43 | response["headers"].pop(header) 44 | return response 45 | 46 | return before_record_response 47 | 48 | return { 49 | "filter_headers": [ 50 |
"User-Agent", 51 | "Accept", 52 | "Accept-Encoding", 53 | "Connection", 54 | "Content-Length", 55 | "Content-Type", 56 | # OpenAI request headers we don't want 57 | "Cookie", 58 | "authorization", 59 | "X-OpenAI-Client-User-Agent", 60 | "OpenAI-Organization", 61 | "x-stainless-lang", 62 | "x-stainless-package-version", 63 | "x-stainless-os", 64 | "x-stainless-arch", 65 | "x-stainless-runtime", 66 | "x-stainless-runtime-version", 67 | "x-api-key", 68 | ], 69 | "filter_query_parameters": ["api_key"], 70 | "cassette_library_dir": "tests/cache/cassettes", 71 | "before_record_response": response_processor( 72 | exclude_headers=[ 73 | # OpenAI response headers we don't want 74 | "Set-Cookie", 75 | "Server", 76 | "access-control-allow-origin", 77 | "alt-svc", 78 | "openai-organization", 79 | "openai-version", 80 | "strict-transport-security", 81 | "x-ratelimit-limit-requests", 82 | "x-ratelimit-limit-tokens", 83 | "x-ratelimit-remaining-requests", 84 | "x-ratelimit-remaining-tokens", 85 | "x-ratelimit-reset-requests", 86 | "x-ratelimit-reset-tokens", 87 | "x-request-id", 88 | ] 89 | ), 90 | "match_on": [ 91 | "method", 92 | "scheme", 93 | "host", 94 | "port", 95 | "path", 96 | "query", 97 | "body", 98 | "headers", 99 | ], 100 | } 101 | 102 | 103 | @pytest.fixture(scope="session") 104 | def gsm8k_dataset(): 105 | import datasets 106 | 107 | return datasets.load_dataset("gsm8k", "main") 108 | 109 | 110 | @pytest.fixture(scope="session") 111 | def arxiv_dataset(): 112 | import datasets 113 | 114 | return datasets.load_dataset("dansbecker/arxiv_article_classification") 115 | 116 | 117 | @pytest.fixture(scope="session") 118 | def news_dataset(): 119 | import datasets 120 | 121 | return datasets.load_dataset("SetFit/20_newsgroups") 122 | 123 | 124 | @pytest.fixture(scope="session") 125 | def gsm8k_demoset(gsm8k_dataset: DatasetDict) -> list[LMDemo]: 126 | return [ 127 | LMDemo( 128 | inputs={"question": r["question"]}, 129 | outputs={"answer": r["answer"]}, 130 | ) 131 | for r in gsm8k_dataset["train"].select(range(5)) 132 | ] 133 | -------------------------------------------------------------------------------- /py/tests/core/managers.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from datetime import datetime 3 | from unittest.mock import patch 4 | 5 | import pytest 6 | 7 | from zenbase.core.managers import ZenbaseTracer 8 | from zenbase.types import LMFunction, LMZenbase 9 | 10 | 11 | @pytest.fixture 12 | def zenbase_manager(): 13 | return ZenbaseTracer() 14 | 15 | 16 | @pytest.fixture 17 | def layer_2_1(zenbase_manager): 18 | @zenbase_manager.trace_function 19 | def _layer_2_1(request, instruction="default_instruction_2_1", candidates=[]): 20 | layer_2_1_output = f"layer_2_1_output_{str(request.inputs)}" 21 | return layer_2_1_output 22 | 23 | return _layer_2_1 24 | 25 | 26 | @pytest.fixture 27 | def layer_1_1(zenbase_manager, layer_2_1): 28 | @zenbase_manager.trace_function 29 | def _layer_1_1(request, instruction="default_instruction_1_1", candidates=[]): 30 | layer_2_1(request.inputs, instruction=instruction) 31 | layer_1_1_output = f"layer_1_1_output_{str(request.inputs)}" 32 | return layer_1_1_output 33 | 34 | return _layer_1_1 35 | 36 | 37 | @pytest.fixture 38 | def layer_1_2(zenbase_manager): 39 | @zenbase_manager.trace_function 40 | def _layer_1_2(request, instruction="default_instruction_1_2", candidates=[]): 41 | layer_1_2_output = f"layer_1_2_output_{str(request.inputs)}" 42 | return layer_1_2_output 43 | 44 
| return _layer_1_2 45 | 46 | 47 | @pytest.fixture 48 | def layer_0(zenbase_manager, layer_1_1, layer_1_2): 49 | @zenbase_manager.trace_function 50 | def _layer_0(request, instruction="default_instruction_0", candidates=[]): 51 | layer_1_1(inputs=request.inputs) 52 | layer_1_2(inputs=request.inputs) 53 | layer_0_output = f"layer_0_output_{str(request.inputs)}" 54 | return layer_0_output 55 | 56 | return _layer_0 57 | 58 | 59 | @pytest.fixture 60 | def layer_0_2(zenbase_manager, layer_1_1, layer_1_2): 61 | @zenbase_manager.trace_function 62 | def _layer_0_2(request, instruction="default_instruction_0_2", candidates=[]): 63 | layer_1_1(inputs=request.inputs["inputs"]) 64 | layer_1_2(inputs=request.inputs["inputs"]) 65 | layer_0_output = f"layer_0_2_output_{str(request.inputs['inputs'])}" 66 | return layer_0_output 67 | 68 | return _layer_0_2 69 | 70 | 71 | def test_trace_layer_0(zenbase_manager, layer_0): 72 | inputs = [{"inputs": i} for i in range(5)] 73 | 74 | for the_input in inputs: 75 | layer_0(inputs=the_input) 76 | 77 | assert len(zenbase_manager.all_traces) == 5 78 | for trace in zenbase_manager.all_traces.values(): 79 | assert "_layer_0" in trace 80 | 81 | 82 | def test_trace_layer_0_multiple_runs(zenbase_manager, layer_0): 83 | inputs = [{"inputs": i} for i in range(5)] 84 | 85 | for the_input in inputs: 86 | layer_0(inputs=the_input) 87 | for the_input in inputs: 88 | layer_0(inputs=the_input) 89 | 90 | assert len(zenbase_manager.all_traces) == 10 91 | 92 | 93 | def test_trace_layer_0_2(zenbase_manager, layer_0_2): 94 | inputs = [{"inputs": i} for i in range(5)] 95 | 96 | for the_input in inputs: 97 | layer_0_2(inputs=the_input) 98 | 99 | assert len(zenbase_manager.all_traces) == 5 100 | for trace in zenbase_manager.all_traces.values(): 101 | assert "_layer_0_2" in trace 102 | 103 | 104 | def test_trace_layer_0_with_optimized_args(zenbase_manager, layer_0): 105 | inputs = [{"inputs": i} for i in range(5)] 106 | optimized_args = { 107 | "layer_2_1": {"args": {"instruction": "optimized_instruction_2_1", "candidates": ["optimized_candidate_2_1"]}}, 108 | "layer_1_1": {"args": {"instruction": "optimized_instruction_1_1", "candidates": ["optimized_candidate_1_1"]}}, 109 | "layer_1_2": {"args": {"instruction": "optimized_instruction_1_2", "candidates": ["optimized_candidate_1_2"]}}, 110 | "layer_0": {"args": {"instruction": "optimized_instruction_0", "candidates": ["optimized_candidate_0"]}}, 111 | } 112 | 113 | def optimized_layer_0(*args, **kwargs): 114 | with zenbase_manager.trace_context( 115 | "optimized_layer_0", f"optimized_layer_0_{datetime.now().isoformat()}", optimized_args 116 | ): 117 | return layer_0(*args, **kwargs) 118 | 119 | for the_input in inputs: 120 | optimized_layer_0(inputs=the_input) 121 | 122 | assert len(zenbase_manager.all_traces) == 5 123 | for trace in zenbase_manager.all_traces.values(): 124 | assert "optimized_layer_0" in trace 125 | 126 | 127 | def test_trace_layer_functions(zenbase_manager, layer_2_1, layer_1_1, layer_1_2): 128 | inputs = [{"inputs": i} for i in range(5)] 129 | 130 | for the_input in inputs: 131 | layer_2_1(inputs=the_input) 132 | layer_1_1(inputs=the_input) 133 | layer_1_2(inputs=the_input) 134 | 135 | assert len(zenbase_manager.all_traces) == 15 136 | for trace in zenbase_manager.all_traces.values(): 137 | assert any(func in trace for func in ["_layer_2_1", "_layer_1_1", "_layer_1_2"]) 138 | 139 | 140 | @pytest.fixture 141 | def tracer(): 142 | return ZenbaseTracer(max_traces=3) 143 | 144 | 145 | def test_init(tracer): 146 | assert
isinstance(tracer.all_traces, OrderedDict) 147 | assert tracer.max_traces == 3 148 | assert tracer.current_trace is None 149 | assert tracer.current_key is None 150 | assert tracer.optimized_args == {} 151 | 152 | 153 | def test_flush(tracer): 154 | tracer.all_traces = OrderedDict({"key1": "value1", "key2": "value2"}) 155 | tracer.flush() 156 | assert len(tracer.all_traces) == 0 157 | 158 | 159 | def test_add_trace(tracer): 160 | # Add first trace 161 | tracer.add_trace("timestamp1", "func1", {"data": "trace1"}) 162 | assert len(tracer.all_traces) == 1 163 | assert "timestamp1" in tracer.all_traces 164 | 165 | # Add second trace 166 | tracer.add_trace("timestamp2", "func2", {"data": "trace2"}) 167 | assert len(tracer.all_traces) == 2 168 | 169 | # Add third trace 170 | tracer.add_trace("timestamp3", "func3", {"data": "trace3"}) 171 | assert len(tracer.all_traces) == 3 172 | 173 | # Add fourth trace (should remove oldest) 174 | tracer.add_trace("timestamp4", "func4", {"data": "trace4"}) 175 | assert len(tracer.all_traces) == 3 176 | assert "timestamp1" not in tracer.all_traces 177 | assert "timestamp4" in tracer.all_traces 178 | 179 | 180 | @patch("zenbase.utils.ksuid") 181 | def test_trace_function(mock_ksuid, tracer): 182 | mock_ksuid.return_value = "test_timestamp" 183 | 184 | def test_func(request): 185 | return request.inputs[0] + request.inputs[1] 186 | 187 | zenbase = LMZenbase() 188 | traced_func = tracer.trace_function(test_func, zenbase) 189 | assert isinstance(traced_func, LMFunction) 190 | 191 | result = traced_func(inputs=(2, 3)) 192 | 193 | assert result == 5 194 | trace = tracer.all_traces[list(tracer.all_traces.keys())[0]] 195 | assert "test_func" in trace["test_func"] 196 | trace_info = trace["test_func"]["test_func"] 197 | assert trace_info["args"]["request"].inputs == (2, 3) 198 | assert trace_info["output"] == 5 199 | 200 | 201 | def test_trace_context(tracer): 202 | with tracer.trace_context("test_func", "test_timestamp"): 203 | assert tracer.current_key == "test_timestamp" 204 | assert isinstance(tracer.current_trace, dict) 205 | 206 | assert tracer.current_trace is None 207 | assert tracer.current_key is None 208 | assert "test_timestamp" in tracer.all_traces 209 | 210 | 211 | def test_max_traces_limit(tracer): 212 | for i in range(5): 213 | tracer.add_trace(f"timestamp{i}", f"func{i}", {"data": f"trace{i}"}) 214 | 215 | assert len(tracer.all_traces) == 3 216 | assert "timestamp0" not in tracer.all_traces 217 | assert "timestamp1" not in tracer.all_traces 218 | assert "timestamp2" in tracer.all_traces 219 | assert "timestamp3" in tracer.all_traces 220 | assert "timestamp4" in tracer.all_traces 221 | 222 | 223 | @patch("zenbase.utils.ksuid") 224 | def test_optimized_args(mock_ksuid, tracer): 225 | mock_ksuid.return_value = "test_timestamp" 226 | 227 | def test_func(request, z=3): 228 | x, y = request.inputs 229 | return x + y + z 230 | 231 | tracer.optimized_args = {"test_func": {"args": {"z": 5}}} 232 | zenbase = LMZenbase() 233 | traced_func = tracer.trace_function(test_func, zenbase) 234 | 235 | result = traced_func(inputs=(2, 10)) 236 | 237 | assert result == 17 # 2 + 10 + 5 238 | trace = tracer.all_traces[list(tracer.all_traces.keys())[0]] 239 | assert "test_func" in trace["test_func"] 240 | trace_info = trace["test_func"]["test_func"] 241 | assert trace_info["args"]["request"].inputs == (2, 10) 242 | assert trace_info["args"]["z"] == 5 243 | assert trace_info["output"] == 17 244 | -------------------------------------------------------------------------------- 
/py/tests/optim/metric/test_bootstrap_few_shot.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import Mock, mock_open, patch 2 | 3 | import pytest 4 | 5 | from zenbase.adaptors.langchain import ZenLangSmith 6 | from zenbase.core.managers import ZenbaseTracer 7 | from zenbase.optim.metric.bootstrap_few_shot import BootstrapFewShot 8 | from zenbase.types import LMDemo, LMFunction, LMZenbase 9 | 10 | 11 | @pytest.fixture 12 | def mock_zen_adaptor(): 13 | adaptor = Mock(spec=ZenLangSmith) 14 | adaptor.fetch_dataset_demos.return_value = [ 15 | LMDemo(inputs={"input": "test1"}, outputs={"output": "result1"}), 16 | LMDemo(inputs={"input": "test2"}, outputs={"output": "result2"}), 17 | LMDemo(inputs={"input": "test3"}, outputs={"output": "result3"}), 18 | ] 19 | adaptor.get_evaluator.return_value = Mock( 20 | return_value=Mock( 21 | individual_evals=[ 22 | Mock(passed=True, demo=LMDemo(inputs={"input": "test1"}, outputs={"output": "result1"})), 23 | Mock(passed=True, demo=LMDemo(inputs={"input": "test2"}, outputs={"output": "result2"})), 24 | Mock(passed=False, demo=LMDemo(inputs={"input": "test3"}, outputs={"output": "result3"})), 25 | ] 26 | ) 27 | ) 28 | return adaptor 29 | 30 | 31 | @pytest.fixture 32 | def mock_trace_manager(): 33 | return Mock(spec=ZenbaseTracer) 34 | 35 | 36 | @pytest.fixture 37 | def bootstrap_few_shot(mock_zen_adaptor): 38 | return BootstrapFewShot( 39 | shots=2, 40 | training_set=Mock(), 41 | test_set=Mock(), 42 | validation_set=Mock(), 43 | zen_adaptor=mock_zen_adaptor, 44 | ) 45 | 46 | 47 | def test_init(bootstrap_few_shot): 48 | assert bootstrap_few_shot.shots == 2 49 | assert len(bootstrap_few_shot.training_set_demos) == 3 50 | 51 | 52 | def test_init_invalid_shots(): 53 | with pytest.raises(AssertionError): 54 | BootstrapFewShot(shots=0, training_set=Mock(), test_set=Mock(), validation_set=Mock(), zen_adaptor=Mock()) 55 | 56 | 57 | def test_create_teacher_model(bootstrap_few_shot, mock_zen_adaptor): 58 | mock_lmfn = Mock(spec=LMFunction) 59 | with patch("zenbase.optim.metric.bootstrap_few_shot.LabeledFewShot") as mock_labeled_few_shot: 60 | mock_labeled_few_shot.return_value.perform.return_value = (Mock(), None, None) 61 | teacher_model = bootstrap_few_shot._create_teacher_model(mock_zen_adaptor, mock_lmfn, 5, 1) 62 | assert teacher_model is not None 63 | mock_labeled_few_shot.assert_called_once() 64 | 65 | 66 | def test_validate_demo_set(bootstrap_few_shot, mock_zen_adaptor): 67 | mock_teacher_lm = Mock(spec=LMFunction) 68 | validated_demos = bootstrap_few_shot._validate_demo_set(mock_zen_adaptor, mock_teacher_lm) 69 | assert len(validated_demos) == 2 70 | assert all(demo.inputs["input"].startswith("test") for demo in validated_demos) 71 | 72 | 73 | def test_run_validated_demos(): 74 | mock_teacher_lm = Mock(spec=LMFunction) 75 | validated_demo_set = [ 76 | LMDemo(inputs={"input": "test1"}, outputs={"output": "result1"}), 77 | LMDemo(inputs={"input": "test2"}, outputs={"output": "result2"}), 78 | ] 79 | BootstrapFewShot._run_validated_demos(mock_teacher_lm, validated_demo_set) 80 | assert mock_teacher_lm.call_count == 2 81 | 82 | 83 | def test_consolidate_traces_to_optimized_args(mock_trace_manager, bootstrap_few_shot): 84 | mock_trace_manager.all_traces = { 85 | "trace1": { 86 | "func1": { 87 | "inner_func1": {"args": {"request": Mock(inputs={"input": "test1"})}, "output": {"output": "result1"}} 88 | } 89 | } 90 | } 91 | optimized_args = 
bootstrap_few_shot._consolidate_traces_to_optimized_args(mock_trace_manager) 92 | assert "inner_func1" in optimized_args 93 | assert isinstance(optimized_args["inner_func1"]["args"]["zenbase"], LMZenbase) 94 | 95 | 96 | def test_create_optimized_function(): 97 | mock_student_lm = Mock(spec=LMFunction) 98 | mock_trace_manager = Mock(spec=ZenbaseTracer) 99 | optimized_args = {"func1": {"args": {"zenbase": LMZenbase(task_demos=[])}}} 100 | 101 | optimized_fn = BootstrapFewShot._create_optimized_function(mock_student_lm, optimized_args, mock_trace_manager) 102 | assert callable(optimized_fn) 103 | 104 | 105 | @patch("zenbase.optim.metric.bootstrap_few_shot.partial") 106 | def test_perform(mock_partial, bootstrap_few_shot, mock_zen_adaptor, mock_trace_manager): 107 | mock_student_lm = Mock(spec=LMFunction) 108 | mock_teacher_lm = Mock(spec=LMFunction) 109 | 110 | with patch.object(bootstrap_few_shot, "_create_teacher_model", return_value=mock_teacher_lm): 111 | with patch.object(bootstrap_few_shot, "_validate_demo_set"): 112 | with patch.object(bootstrap_few_shot, "_run_validated_demos"): 113 | with patch.object(bootstrap_few_shot, "_consolidate_traces_to_optimized_args"): 114 | with patch.object(bootstrap_few_shot, "_create_optimized_function"): 115 | result = bootstrap_few_shot.perform( 116 | mock_student_lm, 117 | mock_teacher_lm, 118 | samples=5, 119 | rounds=1, 120 | trace_manager=mock_trace_manager, 121 | ) 122 | 123 | assert isinstance(result, BootstrapFewShot.Result) 124 | assert result.best_function is not None 125 | 126 | 127 | def test_set_and_get_optimizer_args(bootstrap_few_shot): 128 | test_args = {"test": "args"} 129 | bootstrap_few_shot.set_optimizer_args(test_args) 130 | assert bootstrap_few_shot.get_optimizer_args() == test_args 131 | 132 | 133 | @patch("cloudpickle.dump") 134 | def test_save_optimizer_args(mock_dump, bootstrap_few_shot, tmp_path): 135 | test_args = {"test": "args"} 136 | bootstrap_few_shot.set_optimizer_args(test_args) 137 | file_path = tmp_path / "test_optimizer_args.dill" 138 | bootstrap_few_shot.save_optimizer_args(str(file_path)) 139 | mock_dump.assert_called_once() 140 | 141 | 142 | @patch("builtins.open", new_callable=mock_open, read_data="dummy data") 143 | @patch("cloudpickle.load") 144 | def test_load_optimizer_args(mock_load, mock_file): 145 | test_args = {"test": "args"} 146 | mock_load.return_value = test_args 147 | loaded_args = BootstrapFewShot._load_optimizer_args("dummy_path") 148 | mock_file.assert_called_once_with("dummy_path", "rb") 149 | mock_load.assert_called_once() 150 | assert loaded_args == test_args 151 | 152 | 153 | @patch("zenbase.optim.metric.bootstrap_few_shot.BootstrapFewShot._load_optimizer_args") 154 | @patch("zenbase.optim.metric.bootstrap_few_shot.BootstrapFewShot._create_optimized_function") 155 | def test_load_optimizer_and_function(mock_create_optimized_function, mock_load_optimizer_args): 156 | mock_student_lm = Mock(spec=LMFunction) 157 | mock_trace_manager = Mock(spec=ZenbaseTracer) 158 | mock_load_optimizer_args.return_value = {"test": "args"} 159 | mock_create_optimized_function.return_value = Mock(spec=LMFunction) 160 | 161 | result = BootstrapFewShot.load_optimizer_and_function("dummy_path", mock_student_lm, mock_trace_manager) 162 | 163 | mock_load_optimizer_args.assert_called_once_with("dummy_path") 164 | mock_create_optimized_function.assert_called_once() 165 | assert isinstance(result, Mock) 166 | assert isinstance(result, LMFunction) 167 | 
-------------------------------------------------------------------------------- /py/tests/optim/metric/test_labeled_few_shot.py: -------------------------------------------------------------------------------- 1 | from random import Random, random 2 | 3 | import pytest 4 | 5 | from zenbase.optim.metric.labeled_few_shot import LabeledFewShot 6 | from zenbase.optim.metric.types import CandidateEvalResult 7 | from zenbase.types import LMDemo, LMFunction, LMRequest, deflm 8 | 9 | lmfn = deflm(lambda x: x) 10 | 11 | 12 | demoset = [ 13 | LMDemo(inputs={}, outputs={"output": "a"}), 14 | LMDemo(inputs={}, outputs={"output": "b"}), 15 | LMDemo(inputs={}, outputs={"output": "c"}), 16 | LMDemo(inputs={}, outputs={"output": "d"}), 17 | LMDemo(inputs={}, outputs={"output": "e"}), 18 | LMDemo(inputs={}, outputs={"output": "f"}), 19 | ] 20 | 21 | 22 | def test_invalid_shots(): 23 | with pytest.raises(AssertionError): 24 | LabeledFewShot(demoset=demoset, shots=0) 25 | with pytest.raises(AssertionError): 26 | LabeledFewShot(demoset=demoset, shots=len(demoset) + 1) 27 | 28 | 29 | def test_idempotency(): 30 | shots = 2 31 | samples = 5 32 | 33 | optim1 = LabeledFewShot(demoset=demoset, shots=shots) 34 | optim2 = LabeledFewShot(demoset=demoset, shots=shots) 35 | optim3 = LabeledFewShot(demoset=demoset, shots=shots, random=Random(41)) 36 | 37 | set1 = list(optim1.candidates(lmfn, samples)) 38 | set2 = list(optim2.candidates(lmfn, samples)) 39 | set3 = list(optim3.candidates(lmfn, samples)) 40 | 41 | assert set1 == set2 42 | assert set1 != set3 43 | assert set2 != set3 44 | 45 | 46 | @pytest.fixture 47 | def optim(): 48 | return LabeledFewShot(demoset=demoset, shots=2) 49 | 50 | 51 | def test_candidate_generation(optim: LabeledFewShot): 52 | samples = 5 53 | 54 | candidates = list(optim.candidates(lmfn, samples)) 55 | 56 | assert all(len(c.task_demos) == optim.shots for c in candidates) 57 | assert len(candidates) == samples 58 | 59 | 60 | @deflm 61 | def dummy_lmfn(_: LMRequest): 62 | return {"answer": 42} 63 | 64 | 65 | def dummy_evalfn(fn: LMFunction): 66 | return CandidateEvalResult(fn, {"score": random()}) 67 | 68 | 69 | def test_training(optim: LabeledFewShot): 70 | # Train the dummy function 71 | trained_lmfn, candidates, best_candidate_result = optim.perform( 72 | dummy_lmfn, 73 | dummy_evalfn, 74 | rounds=1, 75 | concurrency=1, 76 | ) 77 | 78 | # Check that the best function is returned 79 | best_function = max(candidates, key=lambda r: r.evals["score"]).function 80 | assert trained_lmfn == best_function 81 | 82 | for demo in trained_lmfn.zenbase.task_demos: 83 | assert demo in demoset 84 | 85 | 86 | @pytest.mark.anyio 87 | async def test_async_training(optim: LabeledFewShot): 88 | # Train the dummy function 89 | trained_dummy_lmfn, candidates, best_candidate_result = await optim.aperform( 90 | dummy_lmfn, 91 | dummy_evalfn, 92 | rounds=1, 93 | concurrency=1, 94 | ) 95 | 96 | # Check that the best function is returned 97 | best_function = max(candidates, key=lambda r: r.evals["score"]).function 98 | assert trained_dummy_lmfn == best_function 99 | -------------------------------------------------------------------------------- /py/tests/predefined/test_generic_lm_function_optimizer.py: -------------------------------------------------------------------------------- 1 | import instructor 2 | import pytest 3 | from instructor import Instructor 4 | from openai import OpenAI 5 | from pydantic import BaseModel 6 | 7 | from zenbase.core.managers import ZenbaseTracer 8 | from 
zenbase.predefined.generic_lm_function.optimizer import GenericLMFunctionOptimizer 9 | 10 | 11 | class InputModel(BaseModel): 12 | question: str 13 | 14 | 15 | class OutputModel(BaseModel): 16 | answer: str 17 | 18 | 19 | @pytest.fixture(scope="module") 20 | def openai_client() -> OpenAI: 21 | return OpenAI() 22 | 23 | 24 | @pytest.fixture(scope="module") 25 | def instructor_client(openai_client: OpenAI) -> Instructor: 26 | return instructor.from_openai(openai_client) 27 | 28 | 29 | @pytest.fixture(scope="module") 30 | def zenbase_tracer() -> ZenbaseTracer: 31 | return ZenbaseTracer() 32 | 33 | 34 | @pytest.fixture 35 | def generic_optimizer(instructor_client, zenbase_tracer): 36 | training_set = [ 37 | {"inputs": {"question": "What is the capital of France?"}, "outputs": {"answer": "Paris"}}, 38 | {"inputs": {"question": "Who wrote Romeo and Juliet?"}, "outputs": {"answer": "William Shakespeare"}}, 39 | {"inputs": {"question": "What is the largest planet in our solar system?"}, "outputs": {"answer": "Jupiter"}}, 40 | {"inputs": {"question": "Who painted the Mona Lisa?"}, "outputs": {"answer": "Leonardo da Vinci"}}, 41 | {"inputs": {"question": "What is the chemical symbol for gold?"}, "outputs": {"answer": "Au"}}, 42 | ] 43 | 44 | return GenericLMFunctionOptimizer( 45 | instructor_client=instructor_client, 46 | prompt="You are a helpful assistant. Answer the user's question concisely.", 47 | input_model=InputModel, 48 | output_model=OutputModel, 49 | model="gpt-4o-mini", 50 | zenbase_tracer=zenbase_tracer, 51 | training_set=training_set, 52 | validation_set=[ 53 | {"inputs": {"question": "What is the capital of Italy?"}, "outputs": {"answer": "Rome"}}, 54 | {"inputs": {"question": "What is the capital of France?"}, "outputs": {"answer": "Paris"}}, 55 | {"inputs": {"question": "What is the capital of Germany?"}, "outputs": {"answer": "Berlin"}}, 56 | ], 57 | test_set=[ 58 | {"inputs": {"question": "Who invented the telephone?"}, "outputs": {"answer": "Alexander Graham Bell"}}, 59 | {"inputs": {"question": "Who is CEO of microsoft?"}, "outputs": {"answer": "Bill Gates"}}, 60 | {"inputs": {"question": "Who is founder of Facebook?"}, "outputs": {"answer": "Mark Zuckerberg"}}, 61 | ], 62 | shots=len(training_set), # Set shots to the number of training examples 63 | ) 64 | 65 | 66 | @pytest.mark.helpers 67 | def test_generic_optimizer_optimize(generic_optimizer): 68 | result = generic_optimizer.optimize() 69 | assert result is not None 70 | assert isinstance(result, GenericLMFunctionOptimizer.Result) 71 | assert result.best_function is not None 72 | assert callable(result.best_function) 73 | assert isinstance(result.candidate_results, list) 74 | assert result.best_candidate_result is not None 75 | 76 | # Check base evaluation 77 | assert generic_optimizer.base_evaluation is not None 78 | 79 | # Check best evaluation 80 | assert generic_optimizer.best_evaluation is not None 81 | 82 | # Test the best function 83 | test_input = InputModel(question="What is the capital of Italy?") 84 | output = result.best_function(test_input) 85 | assert isinstance(output, OutputModel) 86 | assert isinstance(output.answer, str) 87 | assert output.answer.strip().lower() == "rome" 88 | 89 | 90 | @pytest.mark.helpers 91 | def test_generic_optimizer_evaluations(generic_optimizer): 92 | result = generic_optimizer.optimize() 93 | 94 | # Check that base and best evaluations exist 95 | assert generic_optimizer.base_evaluation is not None 96 | assert generic_optimizer.best_evaluation is not None 97 | 98 | # Additional 
checks to ensure the structure of the result 99 | assert isinstance(result, GenericLMFunctionOptimizer.Result) 100 | assert result.best_function is not None 101 | assert isinstance(result.candidate_results, list) 102 | assert result.best_candidate_result is not None 103 | 104 | 105 | @pytest.mark.helpers 106 | def test_generic_optimizer_custom_evaluator(instructor_client, zenbase_tracer): 107 | def custom_evaluator(output: OutputModel, ideal_output: dict) -> dict: 108 | return {"passed": int(output.answer.lower() == ideal_output["answer"].lower()), "length": len(output.answer)} 109 | 110 | training_set = [ 111 | {"inputs": {"question": "What is 2+2?"}, "outputs": {"answer": "4"}}, 112 | {"inputs": {"question": "What is the capital of France?"}, "outputs": {"answer": "Paris"}}, 113 | {"inputs": {"question": "Who wrote Romeo and Juliet?"}, "outputs": {"answer": "William Shakespeare"}}, 114 | {"inputs": {"question": "What is the largest planet in our solar system?"}, "outputs": {"answer": "Jupiter"}}, 115 | {"inputs": {"question": "Who painted the Mona Lisa?"}, "outputs": {"answer": "Leonardo da Vinci"}}, 116 | ] 117 | 118 | optimizer = GenericLMFunctionOptimizer( 119 | instructor_client=instructor_client, 120 | prompt="You are a helpful assistant. Answer the user's question concisely.", 121 | input_model=InputModel, 122 | output_model=OutputModel, 123 | model="gpt-4o-mini", 124 | zenbase_tracer=zenbase_tracer, 125 | training_set=training_set, 126 | validation_set=[{"inputs": {"question": "What is 3+3?"}, "outputs": {"answer": "6"}}], 127 | test_set=[{"inputs": {"question": "What is 4+4?"}, "outputs": {"answer": "8"}}], 128 | custom_evaluator=custom_evaluator, 129 | shots=len(training_set), # Set shots to the number of training examples 130 | ) 131 | 132 | result = optimizer.optimize() 133 | assert result is not None 134 | assert isinstance(result, GenericLMFunctionOptimizer.Result) 135 | assert "length" in optimizer.best_evaluation.individual_evals[0].details 136 | 137 | # Test the custom evaluator 138 | test_input = InputModel(question="What is 5+5?") 139 | output = result.best_function(test_input) 140 | assert isinstance(output, OutputModel) 141 | assert isinstance(output.answer, str) 142 | 143 | # Manually apply the custom evaluator 144 | eval_result = custom_evaluator(output, {"answer": "10"}) 145 | assert "passed" in eval_result 146 | assert "length" in eval_result 147 | 148 | 149 | @pytest.mark.helpers 150 | def test_create_lm_function_with_demos(generic_optimizer): 151 | prompt = "You are a helpful assistant. Answer the user's question concisely." 
152 | demos = [ 153 | {"inputs": {"question": "What is the capital of France?"}, "outputs": {"answer": "Paris"}}, 154 | {"inputs": {"question": "Who wrote Romeo and Juliet?"}, "outputs": {"answer": "William Shakespeare"}}, 155 | ] 156 | 157 | lm_function = generic_optimizer.create_lm_function_with_demos(prompt, demos) 158 | 159 | # Test that the function is created and can be called 160 | test_input = InputModel(question="What is the capital of Italy?") 161 | result = lm_function(test_input) 162 | 163 | assert isinstance(result, OutputModel) 164 | assert isinstance(result.answer, str) 165 | assert result.answer.strip().lower() == "rome" 166 | 167 | # Test with a question from the demos 168 | test_input_demo = InputModel(question="What is the capital of France?") 169 | result_demo = lm_function(test_input_demo) 170 | 171 | assert isinstance(result_demo, OutputModel) 172 | assert isinstance(result_demo.answer, str) 173 | assert result_demo.answer.strip().lower() == "paris" 174 | -------------------------------------------------------------------------------- /py/tests/predefined/test_single_class_classifier.py: -------------------------------------------------------------------------------- 1 | import datasets 2 | import instructor 3 | import pandas as pd 4 | import pytest 5 | from instructor.client import Instructor 6 | from openai import OpenAI 7 | from pydantic import BaseModel 8 | 9 | from zenbase.core.managers import ZenbaseTracer 10 | from zenbase.predefined.single_class_classifier import SingleClassClassifier 11 | from zenbase.predefined.single_class_classifier.function_generator import SingleClassClassifierLMFunctionGenerator 12 | 13 | TRAINSET_SIZE = 100 14 | VALIDATIONSET_SIZE = 21 15 | TESTSET_SIZE = 21 16 | 17 | 18 | @pytest.fixture(scope="module") 19 | def prompt_definition() -> str: 20 | return """Your task is to accurately categorize each incoming news article into one of the given categories based 21 | on its title and content.""" 22 | 23 | 24 | @pytest.fixture(scope="module") 25 | def class_dict() -> dict[str, str]: 26 | return { 27 | "Automobiles": "Discussions and news about automobiles, including car maintenance, driving experiences, " 28 | "and the latest automotive technology.", 29 | "Computers": "Topics related to computer hardware, software, graphics, cryptography, and operating systems, " 30 | "including troubleshooting and advancements.", 31 | "Science": "News and discussions about scientific topics including space exploration, medicine, and " 32 | "electronics.", 33 | "Politics": "Debates and news about political topics, including gun control, Middle Eastern politics," 34 | " and miscellaneous political discussions.", 35 | "Religion": "Discussions about various religions, including beliefs, practices, atheism, and religious news.", 36 | "For Sale": "Classified ads for buying and selling miscellaneous items, from electronics to household goods.", 37 | "Sports": "Everything about sports, including discussions, news, player updates, and game analysis.", 38 | } 39 | 40 | 41 | @pytest.fixture(scope="module") 42 | def sample_news_article(): 43 | return """title: New Advancements in Electric Vehicle Technology 44 | content: The automotive industry is witnessing a significant shift towards electric vehicles (EVs). 45 | Recent advancements in battery technology have led to increased range and reduced charging times. 46 | Companies like Tesla, Nissan, and BMW are at the forefront of this innovation, aiming to make EVs 47 | more accessible and efficient. 
48 | With governments worldwide pushing for greener alternatives, the future of transportation looks 49 | electric.""" 50 | 51 | 52 | def create_dataset_with_examples(item_set: list): 53 | return [{"inputs": item["text"], "outputs": convert_to_human_readable(item["label_text"])} for item in item_set] 54 | 55 | 56 | @pytest.fixture(scope="module") 57 | def openai_client() -> OpenAI: 58 | return OpenAI() 59 | 60 | 61 | @pytest.fixture(scope="module") 62 | def instructor_client(openai_client: OpenAI) -> Instructor: 63 | return instructor.from_openai(openai_client) 64 | 65 | 66 | @pytest.fixture(scope="module") 67 | def zenbase_tracer() -> ZenbaseTracer: 68 | return ZenbaseTracer() 69 | 70 | 71 | @pytest.fixture(scope="module") 72 | def single_class_classifier_generator( 73 | instructor_client: Instructor, prompt_definition: str, class_dict: dict[str, str], zenbase_tracer: ZenbaseTracer 74 | ) -> SingleClassClassifierLMFunctionGenerator: 75 | return SingleClassClassifierLMFunctionGenerator( 76 | instructor_client=instructor_client, 77 | prompt=prompt_definition, 78 | class_dict=class_dict, 79 | model="gpt-4o-mini", 80 | zenbase_tracer=zenbase_tracer, 81 | ) 82 | 83 | 84 | @pytest.mark.helpers 85 | def test_single_class_classifier_lm_function_generator_initialization( 86 | single_class_classifier_generator: SingleClassClassifierLMFunctionGenerator, 87 | ): 88 | assert single_class_classifier_generator is not None 89 | assert single_class_classifier_generator.instructor_client is not None 90 | assert single_class_classifier_generator.prompt is not None 91 | assert single_class_classifier_generator.class_dict is not None 92 | assert single_class_classifier_generator.model is not None 93 | assert single_class_classifier_generator.zenbase_tracer is not None 94 | 95 | # Check generated class enum and prediction class 96 | assert single_class_classifier_generator.class_enum is not None 97 | assert issubclass(single_class_classifier_generator.prediction_class, BaseModel) 98 | 99 | 100 | @pytest.mark.helpers 101 | def test_single_class_classifier_lm_function_generator_prediction( 102 | single_class_classifier_generator: SingleClassClassifierLMFunctionGenerator, sample_news_article 103 | ): 104 | result = single_class_classifier_generator.generate()(sample_news_article) 105 | 106 | assert result.class_label.name == "Automobiles" 107 | assert single_class_classifier_generator.zenbase_tracer.all_traces is not None 108 | 109 | 110 | @pytest.mark.helpers 111 | def test_single_class_classifier_lm_function_generator_with_missing_data( 112 | single_class_classifier_generator: SingleClassClassifierLMFunctionGenerator, 113 | ): 114 | faulty_email = {"subject": "Meeting Reminder"} 115 | result = single_class_classifier_generator.generate()(faulty_email) 116 | assert result.class_label.name in { 117 | "Automobiles", 118 | "Computers", 119 | "Science", 120 | "Politics", 121 | "Religion", 122 | "For Sale", 123 | "Sports", 124 | "Other", 125 | } 126 | 127 | 128 | def convert_to_human_readable(category: str) -> str: 129 | human_readable_map = { 130 | "rec.autos": "Automobiles", 131 | "comp.sys.mac.hardware": "Computers", 132 | "comp.graphics": "Computers", 133 | "sci.space": "Science", 134 | "talk.politics.guns": "Politics", 135 | "sci.med": "Science", 136 | "comp.sys.ibm.pc.hardware": "Computers", 137 | "comp.os.ms-windows.misc": "Computers", 138 | "rec.motorcycles": "Automobiles", 139 | "talk.religion.misc": "Religion", 140 | "misc.forsale": "For Sale", 141 | "alt.atheism": "Religion", 142 | "sci.electronics": "Computers", 143 | "comp.windows.x": "Computers", 144 | 
"rec.sport.hockey": "Sports", 145 | "rec.sport.baseball": "Sports", 146 | "soc.religion.christian": "Religion", 147 | "talk.politics.mideast": "Politics", 148 | "talk.politics.misc": "Politics", 149 | "sci.crypt": "Computers", 150 | } 151 | return human_readable_map.get(category) 152 | 153 | 154 | @pytest.fixture(scope="module") 155 | def get_balanced_dataset(): 156 | # Load the dataset 157 | split = "train" 158 | train_size = TRAINSET_SIZE 159 | validation_size = VALIDATIONSET_SIZE 160 | test_size = TESTSET_SIZE 161 | dataset = datasets.load_dataset("SetFit/20_newsgroups", split=split) 162 | 163 | # Convert to pandas DataFrame for easier manipulation 164 | df = pd.DataFrame(dataset) 165 | 166 | # Convert text labels to human-readable labels 167 | df["human_readable_label"] = df["label_text"].apply(convert_to_human_readable) 168 | 169 | # Find text labels 170 | text_labels = df["human_readable_label"].unique() 171 | 172 | # Determine the number of labels 173 | num_labels = len(text_labels) 174 | 175 | # Calculate the number of samples per label for each subset 176 | train_samples_per_label = train_size // num_labels 177 | validation_samples_per_label = validation_size // num_labels 178 | test_samples_per_label = test_size // num_labels 179 | 180 | # Create empty DataFrames for train, validation, and test sets 181 | train_set = pd.DataFrame() 182 | validation_set = pd.DataFrame() 183 | test_set = pd.DataFrame() 184 | 185 | # Split sequentially without shuffling 186 | for label in text_labels: 187 | label_df = df[df["human_readable_label"] == label] 188 | 189 | # Ensure there's enough data 190 | if len(label_df) < (train_samples_per_label + validation_samples_per_label + test_samples_per_label): 191 | raise ValueError(f"Not enough data for label {label}") 192 | 193 | # Split the label-specific DataFrame 194 | label_train = label_df.iloc[:train_samples_per_label] 195 | label_validation = label_df.iloc[ 196 | train_samples_per_label : train_samples_per_label + validation_samples_per_label 197 | ] 198 | label_test = label_df.iloc[ 199 | train_samples_per_label + validation_samples_per_label : train_samples_per_label 200 | + validation_samples_per_label 201 | + test_samples_per_label 202 | ] 203 | 204 | # Append to the respective sets 205 | train_set = pd.concat([train_set, label_train]) 206 | validation_set = pd.concat([validation_set, label_validation]) 207 | test_set = pd.concat([test_set, label_test]) 208 | 209 | # Create dataset with examples 210 | train_set = create_dataset_with_examples(train_set.to_dict("records")) 211 | validation_set = create_dataset_with_examples(validation_set.to_dict("records")) 212 | test_set = create_dataset_with_examples(test_set.to_dict("records")) 213 | 214 | return train_set, validation_set, test_set 215 | 216 | 217 | @pytest.fixture(scope="module") 218 | def single_class_classifier( 219 | instructor_client: Instructor, 220 | prompt_definition: str, 221 | class_dict: dict[str, str], 222 | zenbase_tracer: ZenbaseTracer, 223 | get_balanced_dataset, 224 | ) -> SingleClassClassifier: 225 | train_set, validation_set, test_set = get_balanced_dataset 226 | return SingleClassClassifier( 227 | instructor_client=instructor_client, 228 | prompt=prompt_definition, 229 | class_dict=class_dict, 230 | model="gpt-4o-mini", 231 | zenbase_tracer=zenbase_tracer, 232 | training_set=train_set, 233 | validation_set=validation_set, 234 | test_set=test_set, 235 | ) 236 | 237 | 238 | @pytest.mark.helpers 239 | def test_single_class_classifier_perform(single_class_classifier: 
SingleClassClassifier, sample_news_article): 240 | result = single_class_classifier.optimize() 241 | assert all( 242 | [result.best_function, result.candidate_results, result.best_candidate_result] 243 | ), "Assertions failed for result properties" 244 | traces = single_class_classifier.zenbase_tracer.all_traces 245 | assert traces, "No traces found" 246 | assert result is not None, "Result should not be None" 247 | assert hasattr(result, "best_function"), "Result should have a best_function attribute" 248 | best_fn = result.best_function 249 | assert callable(best_fn), "best_function should be callable" 250 | output = best_fn(sample_news_article) 251 | assert output is not None, "output should not be None" 252 | -------------------------------------------------------------------------------- /py/tests/predefined/test_single_class_classifier_syntethic_data.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import io 3 | from unittest.mock import Mock, patch 4 | 5 | import instructor 6 | import pytest 7 | from openai import OpenAI 8 | 9 | from zenbase.predefined.syntethic_data.single_class_classifier import ( 10 | SingleClassClassifierSyntheticDataExample, 11 | SingleClassClassifierSyntheticDataGenerator, 12 | ) 13 | 14 | 15 | @pytest.fixture(scope="module") 16 | def prompt_definition() -> str: 17 | return """Your task is to accurately categorize each incoming news article into one of the given categories based 18 | on its title and content.""" 19 | 20 | 21 | @pytest.fixture(scope="module") 22 | def class_dict() -> dict[str, str]: 23 | return { 24 | "Automobiles": "Discussions and news about automobiles, including car maintenance, driving experiences, " 25 | "and the latest automotive technology.", 26 | "Computers": "Topics related to computer hardware, software, graphics, cryptography, and operating systems, " 27 | "including troubleshooting and advancements.", 28 | "Science": "News and discussions about scientific topics including space exploration, medicine, and " 29 | "electronics.", 30 | "Politics": "Debates and news about political topics, including gun control, Middle Eastern politics," 31 | " and miscellaneous political discussions.", 32 | } 33 | 34 | 35 | @pytest.fixture(scope="module") 36 | def openai_client() -> OpenAI: 37 | return OpenAI() 38 | 39 | 40 | @pytest.fixture(scope="module") 41 | def instructor_client(openai_client: OpenAI) -> instructor.Instructor: 42 | return instructor.from_openai(openai_client) 43 | 44 | 45 | @pytest.fixture(scope="module") 46 | def synthetic_data_generator( 47 | instructor_client: instructor.Instructor, 48 | prompt_definition: str, 49 | class_dict: dict[str, str], 50 | ) -> SingleClassClassifierSyntheticDataGenerator: 51 | return SingleClassClassifierSyntheticDataGenerator( 52 | instructor_client=instructor_client, 53 | prompt=prompt_definition, 54 | class_dict=class_dict, 55 | model="gpt-4o-mini", 56 | ) 57 | 58 | 59 | @pytest.mark.helpers 60 | def test_synthetic_data_generator_initialization( 61 | synthetic_data_generator: SingleClassClassifierSyntheticDataGenerator, 62 | ): 63 | assert synthetic_data_generator is not None 64 | assert synthetic_data_generator.instructor_client is not None 65 | assert synthetic_data_generator.prompt is not None 66 | assert synthetic_data_generator.class_dict is not None 67 | assert synthetic_data_generator.model is not None 68 | 69 | 70 | @pytest.mark.helpers 71 | def test_generate_examples_for_category( 72 | synthetic_data_generator: 
SingleClassClassifierSyntheticDataGenerator, 73 | ): 74 | category = "Automobiles" 75 | description = synthetic_data_generator.class_dict[category] 76 | num_examples = 5 77 | examples = synthetic_data_generator.generate_examples_for_category(category, description, num_examples) 78 | 79 | assert len(examples) == num_examples 80 | for example in examples: 81 | assert example.inputs is not None 82 | assert example.outputs == category 83 | 84 | 85 | @pytest.mark.helpers 86 | def test_generate_examples( 87 | synthetic_data_generator: SingleClassClassifierSyntheticDataGenerator, 88 | ): 89 | examples_per_category = 3 90 | all_examples = synthetic_data_generator.generate_examples(examples_per_category) 91 | 92 | assert len(all_examples) == examples_per_category * len(synthetic_data_generator.class_dict) 93 | for example in all_examples: 94 | assert example.inputs is not None 95 | assert example.outputs in synthetic_data_generator.class_dict.keys() 96 | 97 | 98 | @pytest.mark.helpers 99 | def test_generate_csv( 100 | synthetic_data_generator: SingleClassClassifierSyntheticDataGenerator, 101 | ): 102 | examples_per_category = 2 103 | csv_content = synthetic_data_generator.generate_csv(examples_per_category) 104 | 105 | csv_reader = csv.DictReader(io.StringIO(csv_content)) 106 | rows = list(csv_reader) 107 | 108 | assert len(rows) == examples_per_category * len(synthetic_data_generator.class_dict) 109 | for row in rows: 110 | assert "inputs" in row 111 | assert "outputs" in row 112 | assert row["outputs"] in synthetic_data_generator.class_dict.keys() 113 | 114 | 115 | @pytest.mark.helpers 116 | def test_save_csv( 117 | synthetic_data_generator: SingleClassClassifierSyntheticDataGenerator, 118 | tmp_path, 119 | ): 120 | examples_per_category = 2 121 | file_path = tmp_path / "test_synthetic_data.csv" 122 | synthetic_data_generator.save_csv(str(file_path), examples_per_category) 123 | 124 | assert file_path.exists() 125 | 126 | with open(file_path, "r", newline="", encoding="utf-8") as f: 127 | csv_reader = csv.DictReader(f) 128 | rows = list(csv_reader) 129 | 130 | assert len(rows) == examples_per_category * len(synthetic_data_generator.class_dict) 131 | for row in rows: 132 | assert "inputs" in row 133 | assert "outputs" in row 134 | assert row["outputs"] in synthetic_data_generator.class_dict.keys() 135 | 136 | 137 | @pytest.fixture 138 | def mock_openai_client(): 139 | mock_client = Mock(spec=OpenAI) 140 | mock_client.chat = Mock() 141 | mock_client.chat.completions = Mock() 142 | mock_client.chat.completions.create = Mock() 143 | return mock_client 144 | 145 | 146 | @pytest.fixture 147 | def mock_instructor_client(mock_openai_client): 148 | return instructor.from_openai(mock_openai_client) 149 | 150 | 151 | @pytest.fixture 152 | def mock_generator(mock_instructor_client, class_dict): 153 | return SingleClassClassifierSyntheticDataGenerator( 154 | instructor_client=mock_instructor_client, 155 | prompt="Classify the given text into one of the categories", 156 | class_dict=class_dict, 157 | model="gpt-4o-mini", 158 | ) 159 | 160 | 161 | def mock_generate_examples( 162 | category: str, description: str, num: int 163 | ) -> list[SingleClassClassifierSyntheticDataExample]: 164 | return [ 165 | SingleClassClassifierSyntheticDataExample( 166 | inputs=f"Sample text for {category} {i}: {description[:20]}...", outputs=category 167 | ) 168 | for i in range(num) 169 | ] 170 | 171 | 172 | def test_generate_csv_mock(mock_generator): 173 | examples_per_category = 2 174 | 175 | with patch.object(mock_generator, 
"generate_examples_for_category", side_effect=mock_generate_examples): 176 | csv_content = mock_generator.generate_csv(examples_per_category) 177 | 178 | # Parse the CSV content 179 | reader = csv.DictReader(io.StringIO(csv_content)) 180 | rows = list(reader) 181 | 182 | assert len(rows) == len(mock_generator.class_dict) * examples_per_category 183 | for row in rows: 184 | assert "inputs" in row 185 | assert "outputs" in row 186 | assert row["outputs"] in mock_generator.class_dict 187 | 188 | 189 | def test_save_csv_mock(mock_generator, tmp_path): 190 | examples_per_category = 2 191 | filename = tmp_path / "test_output.csv" 192 | 193 | with patch.object(mock_generator, "generate_examples_for_category", side_effect=mock_generate_examples): 194 | mock_generator.save_csv(str(filename), examples_per_category) 195 | 196 | assert filename.exists() 197 | 198 | with open(filename, "r", newline="", encoding="utf-8") as f: 199 | reader = csv.DictReader(f) 200 | rows = list(reader) 201 | 202 | assert len(rows) == len(mock_generator.class_dict) * examples_per_category 203 | for row in rows: 204 | assert "inputs" in row 205 | assert "outputs" in row 206 | assert row["outputs"] in mock_generator.class_dict 207 | 208 | 209 | def test_integration_mock(mock_generator): 210 | examples_per_category = 1 211 | 212 | def mock_create(**kwargs): 213 | category = kwargs["messages"][1]["content"].split("'")[1] 214 | description = next(desc for cat, desc in mock_generator.class_dict.items() if cat == category) 215 | return mock_generate_examples(category, description, examples_per_category) 216 | 217 | with patch.object(mock_generator.instructor_client.chat.completions, "create", side_effect=mock_create): 218 | csv_content = mock_generator.generate_csv(examples_per_category) 219 | 220 | reader = csv.DictReader(io.StringIO(csv_content)) 221 | rows = list(reader) 222 | 223 | assert len(rows) == len(mock_generator.class_dict) * examples_per_category 224 | for row in rows: 225 | assert row["outputs"] in mock_generator.class_dict 226 | assert row["inputs"].startswith(f"Sample text for {row['outputs']}") 227 | -------------------------------------------------------------------------------- /py/tests/sciprts/clean_up_langsmith.py: -------------------------------------------------------------------------------- 1 | from dotenv import load_dotenv 2 | from langsmith import Client 3 | 4 | 5 | def remove_all_datasets(): 6 | # Initialize LangSmith client 7 | client = Client() 8 | 9 | # Fetch all datasets 10 | datasets = list(client.list_datasets()) 11 | 12 | if not datasets: 13 | print("No datasets found in LangSmith.") 14 | return 15 | 16 | # Confirm with the user 17 | confirm = input(f"Are you sure you want to delete all {len(datasets)} datasets? 
(yes/no): ") 18 | if confirm.lower() != "yes": 19 | print("Operation cancelled.") 20 | return 21 | 22 | # Delete each dataset 23 | for dataset in datasets: 24 | try: 25 | client.delete_dataset(dataset_id=dataset.id) 26 | print(f"Deleted dataset: {dataset.name} (ID: {dataset.id})") 27 | except Exception as e: 28 | print(f"Error deleting dataset {dataset.name} (ID: {dataset.id}): {str(e)}") 29 | 30 | print("All datasets have been deleted.") 31 | 32 | 33 | # Run the function 34 | if __name__ == "__main__": 35 | load_dotenv("../../.env.test") 36 | remove_all_datasets() 37 | -------------------------------------------------------------------------------- /py/tests/sciprts/convert_notebooks.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import nbformat 4 | 5 | 6 | def fix_notebooks_in_directory(directory): 7 | for root, dirs, files in os.walk(directory): 8 | for file in files: 9 | if file.endswith(".ipynb"): 10 | notebook_path = os.path.join(root, file) 11 | print(f"Fixing notebook: {notebook_path}") 12 | 13 | with open(notebook_path, "r", encoding="utf-8") as f: 14 | notebook = nbformat.read(f, as_version=4) 15 | 16 | for cell in notebook["cells"]: 17 | if cell["cell_type"] == "code" and "execution_count" not in cell: 18 | cell["execution_count"] = None 19 | 20 | with open(notebook_path, "w", encoding="utf-8") as f: 21 | nbformat.write(notebook, f) 22 | 23 | print(f"Fixed notebook: {notebook_path}") 24 | 25 | 26 | fix_notebooks_in_directory("../../cookbooks") 27 | -------------------------------------------------------------------------------- /py/tests/test_types.py: -------------------------------------------------------------------------------- 1 | from zenbase.types import LMDemo, deflm 2 | 3 | 4 | def test_demo_eq(): 5 | demoset = [ 6 | LMDemo(inputs={}, outputs={"output": "a"}), 7 | LMDemo(inputs={}, outputs={"output": "b"}), 8 | ] 9 | 10 | # Structural inequality 11 | assert demoset[0] != demoset[1] 12 | # Structural equality 13 | assert demoset[0] == LMDemo(inputs={}, outputs={"output": "a"}) 14 | 15 | 16 | def test_lm_function_refine(): 17 | fn = deflm(lambda r: r.inputs) 18 | assert fn != fn.clean_and_duplicate() 19 | --------------------------------------------------------------------------------