├── .cache.json ├── .flake8 ├── .github ├── bump_version.py └── workflows │ ├── pypi.yml │ ├── test-demos.yml │ ├── tests.yml │ └── update-docs.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .run └── pycharm │ ├── Python tests for test_interfaces.run.xml │ ├── Python tests for test_routing.run.xml │ ├── Python tests for test_universal_api.run.xml │ ├── Python tests for test_utils.run.xml │ ├── Template Python tests.run.xml │ ├── Template Python.run.xml │ └── generate_docs.run.xml ├── LICENSE ├── README.md ├── generate_docs.py ├── poetry.lock ├── pydoc-markdown.yml ├── pyproject.toml ├── tests ├── __init__.py ├── conftest.py ├── test_logging │ ├── __init__.py │ ├── helpers.py │ ├── test_dataset.py │ ├── test_evals.py │ ├── test_logs.py │ ├── test_projects.py │ ├── test_tracing.py │ └── test_utils │ │ ├── __init__.py │ │ ├── test_artifacts.py │ │ ├── test_async_logger.py │ │ ├── test_contexts.py │ │ ├── test_datasets.py │ │ ├── test_logs.py │ │ └── test_projects.py ├── test_routing │ ├── __init__.py │ ├── test_fallbacks.py │ └── test_routing_syntax.py ├── test_universal_api │ ├── __init__.py │ ├── test_basics.py │ ├── test_chatbot.py │ ├── test_json_mode.py │ ├── test_multi_llm.py │ ├── test_stateful.py │ ├── test_usage.py │ ├── test_user_input.py │ └── test_utils │ │ ├── __init__.py │ │ ├── test_credits.py │ │ ├── test_custom_api_keys.py │ │ ├── test_custom_endpoints.py │ │ ├── test_endpoint_metrics.py │ │ ├── test_supported_endpoints.py │ │ └── test_usage.py └── test_utils │ ├── __init__.py │ ├── helpers.py │ ├── test_caching.py │ └── test_map.py └── unify ├── __init__.py ├── logging ├── __init__.py ├── dataset.py ├── logs.py └── utils │ ├── __init__.py │ ├── artifacts.py │ ├── async_logger.py │ ├── compositions.py │ ├── contexts.py │ ├── datasets.py │ ├── logs.py │ ├── projects.py │ └── tracing.py ├── universal_api ├── __init__.py ├── casting.py ├── chatbot.py ├── clients │ ├── __init__.py │ ├── base.py │ ├── helpers.py │ ├── multi_llm.py │ └── uni_llm.py ├── 
types │ ├── __init__.py │ └── prompt.py ├── usage.py └── utils │ ├── __init__.py │ ├── credits.py │ ├── custom_api_keys.py │ ├── custom_endpoints.py │ ├── endpoint_metrics.py │ ├── queries.py │ └── supported_endpoints.py └── utils ├── __init__.py ├── _caching.py ├── _requests.py ├── helpers.py └── map.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-complexity = 6 3 | inline-quotes = double 4 | max-line-length = 88 5 | extend-ignore = E203 6 | docstring_style = sphinx 7 | 8 | ignore = 9 | ; Found `f` string 10 | WPS305, 11 | ; Missing docstring in public module 12 | D100, 13 | ; Missing docstring in magic method 14 | D105, 15 | ; Missing docstring in __init__ 16 | D107, 17 | ; Found `__init__.py` module with logic 18 | WPS412, 19 | ; Found class without a base class 20 | WPS306, 21 | ; Missing docstring in public nested class 22 | D106, 23 | ; First line should be in imperative mood 24 | D401, 25 | ; Found wrong variable name 26 | WPS110, 27 | ; Found implicit string concatenation 28 | WPS326, 29 | ; Found string constant over-use 30 | WPS226, 31 | ; Found upper-case constant in a class 32 | WPS115, 33 | ; Found nested function 34 | WPS602, 35 | ; Found method without arguments 36 | WPS605, 37 | ; Found overused expression 38 | WPS204, 39 | ; Found too many module members 40 | WPS202, 41 | ; Found too high module cognitive complexity 42 | WPS232, 43 | ; line break before binary operator 44 | W503, 45 | ; Found module with too many imports 46 | WPS201, 47 | ; Inline strong start-string without end-string. 
48 | RST210, 49 | ; Found nested class 50 | WPS431, 51 | ; Found wrong module name 52 | WPS100, 53 | ; Found too many methods 54 | WPS214, 55 | ; Found too long ``try`` body 56 | WPS229, 57 | ; Found unpythonic getter or setter 58 | WPS615, 59 | ; Found a line that starts with a dot 60 | WPS348, 61 | ; Found complex default value (for dependency injection) 62 | WPS404, 63 | ; Do not perform function calls in argument defaults (for dependency injection) 64 | B008, 65 | ; Model should define verbose_name in its Meta inner class 66 | DJ10, 67 | ; Model should define verbose_name_plural in its Meta inner class 68 | DJ11, 69 | ; Found mutable module constant. 70 | WPS407, 71 | ; Found too many empty lines in `def` 72 | WPS473, 73 | ; too many no-cover comments. 74 | WPS403, 75 | ; Found `noqa` comments overuse 76 | WPS402, 77 | ; Found protected attribute usage 78 | WPS437, 79 | ; Found too short name 80 | WPS111, 81 | ; Found error raising from itself 82 | WPS469, 83 | ; Found a too complex `f` string 84 | WPS237, 85 | ; Imports not being at top-level 86 | E402 87 | ; function is too complex 88 | C901 89 | ; missing docstring in public package 90 | D104 91 | 92 | per-file-ignores = 93 | ; all tests 94 | test_*.py,tests.py,tests_*.py,*/tests/*,conftest.py: 95 | ; Use of assert detected 96 | S101, 97 | ; Found outer scope names shadowing 98 | WPS442, 99 | ; Found too many local variables 100 | WPS210, 101 | ; Found magic number 102 | WPS432, 103 | ; Missing parameter(s) in Docstring 104 | DAR101, 105 | ; Found too many arguments 106 | WPS211, 107 | ; Missing docstring in public class 108 | D101, 109 | ; Missing docstring in public method 110 | D102, 111 | ; Found too long name 112 | WPS118, 113 | 114 | ; all init files 115 | __init__.py: 116 | ; ignore unused imports 117 | F401, 118 | ; ignore import with wildcard 119 | F403, 120 | ; Found wrong metadata variable 121 | WPS410, 122 | 123 | ; exceptions file 124 | unify/exceptions.py: 125 | ; Found wrong keyword 126 | 
WPS420, 127 | ; Found incorrect node inside `class` body 128 | WPS604, 129 | 130 | exclude = 131 | ./.cache, 132 | ./.git, 133 | ./.idea, 134 | ./.mypy_cache, 135 | ./.pytest_cache, 136 | ./.venv, 137 | ./venv, 138 | ./env, 139 | ./cached_venv, 140 | ./docs, 141 | ./deploy, 142 | ./var, 143 | ./.vscode, 144 | ./generate_docs.py 145 | -------------------------------------------------------------------------------- /.github/bump_version.py: -------------------------------------------------------------------------------- 1 | import re 2 | import subprocess 3 | 4 | import requests 5 | 6 | pypi_rss_url = "https://pypi.org/rss/project/unifyai/releases.xml" 7 | 8 | if __name__ == "__main__": 9 | # Get the latest version from the RSS feed 10 | response = requests.get(pypi_rss_url) 11 | version = re.findall(r"\d+\.\d+\.\d+", response.text)[0] 12 | if not version: 13 | raise Exception("Failed parsing version") 14 | 15 | # Bump patch version 16 | subprocess.run(["poetry", "version", version]) 17 | subprocess.run(["poetry", "version", "patch"]) 18 | -------------------------------------------------------------------------------- /.github/workflows/pypi.yml: -------------------------------------------------------------------------------- 1 | name: Publish release to PyPI 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | paths: 8 | - 'unify/**' 9 | - 'LICENSE' 10 | - 'README.md' 11 | - 'pyproject.toml' 12 | - 'poetry.lock' 13 | 14 | jobs: 15 | publish_pypi_release: 16 | runs-on: ubuntu-latest 17 | environment: pypi 18 | continue-on-error: false 19 | 20 | steps: 21 | - name: Checkout repository 22 | uses: actions/checkout@v4 23 | 24 | - name: Install Python 25 | uses: actions/setup-python@v5 26 | with: 27 | python-version: '3.9' 28 | 29 | - name: Install Poetry 30 | run: | 31 | python -m pip install --upgrade pip 32 | pip install poetry 33 | 34 | - name: Bump package version 35 | run: | 36 | python .github/bump_version.py 37 | 38 | - name: Build package 39 | run: poetry build 40 
| 41 | - name: Publish to TestPyPI 42 | run: | 43 | poetry config repositories.testpypi https://test.pypi.org/legacy/ 44 | poetry publish --skip-existing -r testpypi -u __token__ -p ${{ secrets.TEST_PYPI_API_KEY }} 45 | 46 | - name: Publish to PyPI 47 | run: | 48 | poetry publish -u __token__ -p ${{ secrets.PYPI_API_KEY }} 49 | -------------------------------------------------------------------------------- /.github/workflows/test-demos.yml: -------------------------------------------------------------------------------- 1 | name: Test demos 2 | on: 3 | workflow_dispatch: 4 | # push: 5 | # branches: 6 | # - main 7 | # permissions: 8 | # contents: write 9 | # actions: read 10 | # id-token: write 11 | 12 | jobs: 13 | test-demos: 14 | name: Test demos 15 | uses: unifyai/workflows/.github/workflows/test-demos.yml@main 16 | secrets: 17 | GITHUB_TOKEN: ${{ secrets.CONSOLE_TOKEN }} 18 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Testing unify 2 | 3 | on: push 4 | 5 | jobs: 6 | black: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v2 10 | - name: Set up Python 11 | uses: actions/setup-python@v2 12 | with: 13 | python-version: '3.9' 14 | - name: Install deps 15 | uses: knowsuchagency/poetry-install@v1 16 | env: 17 | POETRY_VIRTUALENVS_CREATE: false 18 | - name: Run black check 19 | run: poetry run black --check . 20 | 21 | pytest: 22 | runs-on: ubuntu-latest 23 | environment: unify-testing 24 | timeout-minutes: 120 25 | steps: 26 | - uses: actions/checkout@v2 27 | - name: Set up Python 28 | uses: actions/setup-python@v2 29 | with: 30 | python-version: '3.9' 31 | - name: Install deps 32 | run: | 33 | python -m pip install --upgrade pip 34 | pip install poetry 35 | poetry install --with dev 36 | - name: Run unit tests 37 | run: poetry run pytest --timeout=120 -p no:warnings -vv . 
38 | env: 39 | UNIFY_KEY: ${{ secrets.USER_API_KEY }} 40 | -------------------------------------------------------------------------------- /.github/workflows/update-docs.yml: -------------------------------------------------------------------------------- 1 | name: Publish docs 2 | on: 3 | push: 4 | branches: 5 | - main 6 | workflow_dispatch: 7 | permissions: 8 | contents: write 9 | 10 | jobs: 11 | publish-docs: 12 | name: Update docs 13 | uses: unifyai/workflows/.github/workflows/publish-docs.yml@main 14 | secrets: inherit 15 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ### Python template 2 | 3 | .idea/ 4 | .vscode/ 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | cover/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | *.sqlite3 66 | *.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | .pybuilder/ 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | # For a library or package, you might want to ignore these files since the code is 91 | # intended to run in multiple environments; otherwise, check them in: 92 | # .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 102 | __pypackages__/ 103 | 104 | # Celery stuff 105 | celerybeat-schedule 106 | celerybeat.pid 107 | 108 | # SageMath parsed files 109 | *.sage.py 110 | 111 | # Environments 112 | .env 113 | .env.tmp 114 | .venv 115 | env/ 116 | venv/ 117 | ENV/ 118 | env.bak/ 119 | venv.bak/ 120 | 121 | # Spyder project settings 122 | .spyderproject 123 | .spyproject 124 | 125 | # Rope project settings 126 | .ropeproject 127 | 128 | # mkdocs documentation 129 | /site 130 | 131 | # mypy 132 | .mypy_cache/ 133 | .dmypy.json 134 | dmypy.json 135 | 136 | # Pyre type checker 137 | .pyre/ 138 | 139 | # pytype static type analyzer 140 | .pytype/ 141 | 142 | # Cython debug symbols 143 | cython_debug/ 144 | 145 | # Cloud SDKs 146 | google-cloud-sdk/ 147 | google-cloud-cli-*.tar.gz 148 | 149 | # SQL Dump files 150 | *.dump.sql 151 | 152 | # Binaries 153 | bin/ 154 | 155 | # Tests 156 | new_module.py 157 | test.py 158 | implementations.py 159 | .test_cache.json 160 | *.cache.json 161 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v5.0.0 6 | hooks: 7 | - id: check-ast 8 | - id: trailing-whitespace 9 | - id: check-toml 10 | - id: end-of-file-fixer 11 | 12 | - repo: https://github.com/asottile/add-trailing-comma 13 | rev: v3.1.0 14 | hooks: 15 | - id: add-trailing-comma 16 | 17 | - repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks 18 | rev: v2.14.0 19 | hooks: 20 | - id: pretty-format-yaml 21 | args: 22 | - --autofix 23 | - --preserve-quotes 24 | - --indent=2 25 | 26 | - repo: local 27 | hooks: 28 | - id: autoflake 29 | name: autoflake 30 | entry: poetry run autoflake 31 | language: system 32 | types: [python] 
33 | args: [--in-place, --remove-all-unused-imports, --remove-duplicate-keys] 34 | exclude: | 35 | (?x)^( 36 | .*/__init__\.py # Exclude all __init__.py files 37 | )$ 38 | 39 | - id: black 40 | name: Format with Black 41 | entry: poetry run black 42 | language: system 43 | types: [python] 44 | 45 | - id: isort 46 | name: isort 47 | entry: poetry run isort 48 | language: system 49 | types: [python] 50 | exclude: | 51 | (?x)^( 52 | .*/__init__\.py # Exclude all __init__.py files 53 | )$ 54 | -------------------------------------------------------------------------------- /.run/pycharm/Python tests for test_interfaces.run.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 29 | 30 | -------------------------------------------------------------------------------- /.run/pycharm/Python tests for test_routing.run.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 29 | 30 | -------------------------------------------------------------------------------- /.run/pycharm/Python tests for test_universal_api.run.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 29 | 30 | -------------------------------------------------------------------------------- /.run/pycharm/Python tests for test_utils.run.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 29 | 30 | -------------------------------------------------------------------------------- /.run/pycharm/Template Python tests.run.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 30 | 31 | 32 | 60 | 61 | -------------------------------------------------------------------------------- /.run/pycharm/Template Python.run.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 37 | 38 | 
-------------------------------------------------------------------------------- /.run/pycharm/generate_docs.run.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 35 | 36 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ---- 10 | 11 | ![Static Badge](https://img.shields.io/badge/Y%20Combinator-W23-orange) 12 | ![X (formerly Twitter) Follow](https://img.shields.io/twitter/follow/letsunifyai) 13 | ![Static Badge](https://img.shields.io/badge/Join_Discord-464646?&logo=discord&logoColor=5865F2) 14 | 15 |
16 | 17 | 18 | 19 |
20 | 21 | **Fully hackable** LLMOps. Build *custom* interfaces for: logging, evals, guardrails, labelling, tracing, agents, human-in-the-loop, hyperparam sweeps, and anything else you can think of ✨ 22 | 23 | Just `unify.log` your data, and add an interface using the four building blocks: 24 | 25 | 1. **tables** 🔢 26 | 2. **views** 🔍 27 | 3. **plots** 📊 28 | 4. **editor** 🕹️ (coming soon) 29 | 30 | Every LLM product has **unique** and **changing** requirements, as do the **users**. Your infra should reflect this! 31 | 32 | We've tried to make Unify as **(a) simple**, **(b) modular** and **(c) hackable** as possible, so you can quickly probe, analyze, and iterate on the data that's important for **you**, your **product** and your **users** ⚡ 33 | 34 | ## Quickstart 35 | 36 | [Sign up](https://console.unify.ai/), `pip install unifyai`, run your first eval ⬇️, and then check out the logs in your first [interface](https://console.unify.ai) 📊 37 | 38 | ```python 39 | import unify 40 | from random import randint, choice 41 | 42 | # initialize project 43 | unify.activate("Maths Assistant") 44 | 45 | # build agent 46 | client = unify.Unify("o3-mini@openai", traced=True) 47 | client.set_system_message( 48 | "You are a helpful maths assistant, " 49 | "tasked with adding and subtracting integers." 50 | ) 51 | 52 | # add test cases 53 | qs = [ 54 | f"{randint(0, 100)} {choice(['+', '-'])} {randint(0, 100)}" 55 | for i in range(10) 56 | ] 57 | 58 | # define evaluator 59 | @unify.traced 60 | def evaluate_response(question: str, response: str) -> float: 61 | correct_answer = eval(question) 62 | try: 63 | response_int = int( 64 | "".join( 65 | [ 66 | c for c in response.split(" ")[-1] 67 | if c.isdigit() 68 | ] 69 | ), 70 | ) 71 | return float(correct_answer == response_int) 72 | except ValueError: 73 | return 0. 
74 | 75 | # define evaluation 76 | @unify.traced 77 | def evaluate(q: str): 78 | response = client.copy().generate(q) 79 | score = evaluate_response(q, response) 80 | unify.log( 81 | question=q, 82 | response=response, 83 | score=score 84 | ) 85 | 86 | # execute + log your evaluation 87 | with unify.Experiment(): 88 | unify.map(evaluate, qs) 89 | ``` 90 | 91 | Check out our [Quickstart Video](https://youtu.be/fl9SzsoCegw?si=MhQZDfNS6U-ZsVYc) for a guided walkthrough. 92 | 93 | ## Focus on your *product*, not the *LLM* 🎯 94 | 95 | Despite all of the hype, abstractions, and jargon, the *process* for building quality LLM apps is pretty simple. 96 | 97 | ``` 98 | create simplest possible agent 🤖 99 | while True: 100 | create/expand unit tests (evals) 🗂️ 101 | while run(tests) failing: 🧪 102 | Analyze failures, understand the root cause 🔍 103 | Vary system prompt, in-context examples, tools etc. to rectify 🔀 104 | Beta test with users, find more failures 🚦 105 | ``` 106 | 107 | We've tried to strip away all of the excessive LLM jargon, so you can focus on your *product*, your *users*, and the *data* you care about, and *nothing else* 📈 108 | 109 | Unify takes inspiration from: 110 | - [PostHog](https://posthog.com/) / [Grafana](https://grafana.com/) / [LogFire](https://pydantic.dev/logfire) for powerful observability 🔬 111 | - [LangSmith](https://www.langchain.com/langsmith) / [BrainTrust](https://www.braintrust.dev/) / [Weave](https://wandb.ai/site/weave/) for LLM abstractions 🤖 112 | - [Notion](https://www.notion.com/) / [Airtable](https://www.airtable.com/) for composability and versatility 🧱 113 | 114 | Whether you're technical or non-technical, we hope Unify can help you to rapidly build top-notch LLM apps, and to remain fully focused on your *product* (not the *LLM*). 
115 | 116 | ## Learn More 117 | 118 | Check out our [docs](https://docs.unify.ai/), and if you have any questions feel free to reach out to us on [Discord](https://discord.com/invite/sXyFF8tDtm) 👾 119 | 120 | Unify is under active development 🚧, so feedback in all shapes and sizes is very welcome! 🙏 121 | 122 | Happy prompting! 🧑‍💻 123 | -------------------------------------------------------------------------------- /pydoc-markdown.yml: -------------------------------------------------------------------------------- 1 | loader: 2 | - type: package 3 | packages: 4 | - unify # Package to document 5 | 6 | renderer: 7 | type: markdown 8 | options: 9 | output: docs # Output directory for generated Markdown files 10 | template: | 11 | # {{ module_name }} 12 | 13 | {{ module_doc }} 14 | 15 | ## Classes 16 | 17 | {{ classes }} 18 | 19 | ## Functions 20 | 21 | {{ functions }} 22 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "unifyai" 3 | packages = [{include = "unify"}] 4 | version = "0.9.10" 5 | readme = "README.md" 6 | description = "A Python package for interacting with the Unify API" 7 | authors = ["Unify "] 8 | repository = "https://github.com/unifyai/unify" 9 | 10 | [tool.poetry.dependencies] 11 | python = "^3.9" 12 | requests = "^2.31.0" 13 | requests-toolbelt = "^1.0.0" 14 | openai = "^1.47.0" 15 | jsonlines = "^4.0.0" 16 | rich = "^13.8.1" 17 | pytest = "^8.3.3" 18 | pytest-timeout = "^2.3.1" 19 | pytest-asyncio = "^0.24.0" 20 | termcolor ="2.5.0" 21 | aiohttp = "^3.11.12" 22 | 23 | [tool.poetry.group.dev.dependencies] 24 | types-requests = "*" 25 | flake8 = "~4.0.1" 26 | mypy = "^1.1.1" 27 | isort = "^5.11.4" 28 | pre-commit = "^3.0.1" 29 | wemake-python-styleguide = "^0.17.0" 30 | black = "^24.3.0" 31 | autoflake = "^1.6.1" 32 | pydoc-markdown = "^4.0.0" 33 | 34 | [tool.isort] 35 | profile = 
"black" 36 | multi_line_output = 3 37 | src_paths = ["orchestra",] 38 | 39 | [tool.mypy] 40 | strict = true 41 | ignore_missing_imports = true 42 | allow_subclassing_any = true 43 | allow_untyped_calls = true 44 | pretty = true 45 | show_error_codes = true 46 | implicit_reexport = true 47 | allow_untyped_decorators = true 48 | warn_unused_ignores = false 49 | warn_return_any = false 50 | namespace_packages = true 51 | 52 | [build-system] 53 | requires = ["poetry-core>=1.0.0"] 54 | build-backend = "poetry.core.masonry.api" 55 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | pytest_plugins = ("pytest_asyncio",) 2 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import unify 4 | 5 | 6 | def pytest_sessionstart(session): 7 | if os.environ.get("CI"): 8 | unify.delete_logs() 9 | -------------------------------------------------------------------------------- /tests/test_logging/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unifyai/unify/fbb548073fb0dc9f0cc0f67e28220011cbf9ec5b/tests/test_logging/__init__.py -------------------------------------------------------------------------------- /tests/test_logging/helpers.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import functools 3 | import sys 4 | import traceback 5 | 6 | import unify 7 | 8 | 9 | def _handle_project(test_fn): 10 | # noinspection PyBroadException 11 | @functools.wraps(test_fn) 12 | def wrapper(*args, **kwargs): 13 | project = test_fn.__name__ 14 | if project in unify.list_projects(): 15 | unify.delete_project(project) 16 | try: 17 | with unify.Project(project): 18 | 
test_fn(*args, **kwargs) 19 | unify.delete_project(project) 20 | except: 21 | unify.delete_project(project) 22 | exc_type, exc_value, exc_tb = sys.exc_info() 23 | tb_string = "".join(traceback.format_exception(exc_type, exc_value, exc_tb)) 24 | raise Exception(f"{tb_string}") 25 | 26 | @functools.wraps(test_fn) 27 | async def async_wrapper(*args, **kwargs): 28 | project = test_fn.__name__ 29 | if project in unify.list_projects(): 30 | unify.delete_project(project) 31 | try: 32 | with unify.Project(project): 33 | await test_fn(*args, **kwargs) 34 | unify.delete_project(project) 35 | except: 36 | unify.delete_project(project) 37 | exc_type, exc_value, exc_tb = sys.exc_info() 38 | tb_string = "".join(traceback.format_exception(exc_type, exc_value, exc_tb)) 39 | raise Exception(f"{tb_string}") 40 | 41 | return async_wrapper if asyncio.iscoroutinefunction(test_fn) else wrapper 42 | -------------------------------------------------------------------------------- /tests/test_logging/test_projects.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import unify 4 | 5 | 6 | def test_set_project(): 7 | unify.deactivate() 8 | assert unify.active_project() is None 9 | unify.activate("my_project") 10 | assert unify.active_project() == "my_project" 11 | unify.deactivate() 12 | 13 | 14 | def test_unset_project(): 15 | unify.deactivate() 16 | assert unify.active_project() is None 17 | unify.activate("my_project") 18 | assert unify.active_project() == "my_project" 19 | unify.deactivate() 20 | assert unify.active_project() is None 21 | 22 | 23 | def test_with_project(): 24 | unify.deactivate() 25 | assert unify.active_project() is None 26 | with unify.Project("my_project"): 27 | assert unify.active_project() == "my_project" 28 | assert unify.active_project() is None 29 | 30 | 31 | def test_set_project_then_log(): 32 | unify.deactivate() 33 | assert unify.active_project() is None 34 | unify.activate("test_set_project_then_log") 
35 | assert unify.active_project() == "test_set_project_then_log" 36 | unify.log(key=1.0) 37 | unify.deactivate() 38 | assert unify.active_project() is None 39 | unify.delete_project("test_set_project_then_log") 40 | 41 | 42 | def test_with_project_then_log(): 43 | unify.deactivate() 44 | assert unify.active_project() is None 45 | with unify.Project("test_with_project_then_log"): 46 | assert unify.active_project() == "test_with_project_then_log" 47 | unify.log(key=1.0) 48 | assert unify.active_project() is None 49 | unify.delete_project("test_with_project_then_log") 50 | 51 | 52 | def test_project_env_var(): 53 | unify.deactivate() 54 | assert unify.active_project() is None 55 | os.environ["UNIFY_PROJECT"] = "test_project_env_var" 56 | assert unify.active_project() == "test_project_env_var" 57 | unify.delete_logs(project="test_project_env_var") 58 | unify.log(x=0, y=1, z=2) 59 | del os.environ["UNIFY_PROJECT"] 60 | assert unify.active_project() is None 61 | try: 62 | logs = unify.get_logs(project="test_project_env_var") 63 | assert len(logs) == 1 64 | assert logs[0].entries == {"x": 0, "y": 1, "z": 2} 65 | finally: 66 | unify.delete_project("test_project_env_var") 67 | 68 | 69 | if __name__ == "__main__": 70 | pass 71 | -------------------------------------------------------------------------------- /tests/test_logging/test_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unifyai/unify/fbb548073fb0dc9f0cc0f67e28220011cbf9ec5b/tests/test_logging/test_utils/__init__.py -------------------------------------------------------------------------------- /tests/test_logging/test_utils/test_artifacts.py: -------------------------------------------------------------------------------- 1 | import unify 2 | 3 | 4 | def test_artifacts(): 5 | project = "my_project" 6 | if project in unify.list_projects(): 7 | unify.delete_project(project) 8 | unify.create_project(project) 9 | artifacts = 
{"dataset": "my_dataset", "description": "this is my dataset"} 10 | assert len(unify.get_project_artifacts(project=project)) == 0 11 | unify.add_project_artifacts(project=project, **artifacts) 12 | assert "dataset" in unify.get_project_artifacts(project=project) 13 | unify.delete_project_artifact( 14 | "dataset", 15 | project=project, 16 | ) 17 | assert "dataset" not in unify.get_project_artifacts(project=project) 18 | assert "description" in unify.get_project_artifacts(project=project) 19 | 20 | 21 | if __name__ == "__main__": 22 | pass 23 | -------------------------------------------------------------------------------- /tests/test_logging/test_utils/test_async_logger.py: -------------------------------------------------------------------------------- 1 | import unify 2 | 3 | from ..helpers import _handle_project 4 | 5 | 6 | @_handle_project 7 | def test_async_logger(): 8 | try: 9 | logs_sync = [unify.log(x=i, y=i * 2, z=i * 3) for i in range(10)] 10 | 11 | unify.initialize_async_logger() 12 | logs_async = [unify.log(x=i, y=i * 2, z=i * 3) for i in range(10)] 13 | unify.shutdown_async_logger() 14 | 15 | assert len(logs_async) == len(logs_sync) 16 | for log_async, log_sync in zip( 17 | sorted(logs_async, key=lambda x: x.entries["x"]), 18 | sorted(logs_sync, key=lambda x: x.entries["x"]), 19 | ): 20 | assert log_async.entries == log_sync.entries 21 | assert unify.ASYNC_LOGGING == False 22 | except Exception as e: 23 | unify.shutdown_async_logger() 24 | raise e 25 | 26 | 27 | if __name__ == "__main__": 28 | pass 29 | -------------------------------------------------------------------------------- /tests/test_logging/test_utils/test_contexts.py: -------------------------------------------------------------------------------- 1 | import unify 2 | 3 | from ..helpers import _handle_project 4 | 5 | 6 | @_handle_project 7 | def test_create_context(): 8 | assert len(unify.get_contexts()) == 0 9 | unify.create_context("my_context") 10 | assert len(unify.get_contexts()) == 1 
11 | assert "my_context" in unify.get_contexts() 12 | 13 | 14 | @_handle_project 15 | def test_get_contexts(): 16 | assert len(unify.get_contexts()) == 0 17 | unify.log(x=0, context="a/b") 18 | unify.log(x=1, context="a/b") 19 | unify.log(x=0, context="b/c") 20 | unify.log(x=1, context="b/c") 21 | contexts = unify.get_contexts() 22 | assert len(contexts) == 2 23 | assert "a/b" in contexts 24 | assert "b/c" in contexts 25 | contexts = unify.get_contexts(prefix="a") 26 | assert len(contexts) == 1 27 | assert "a/b" in contexts 28 | assert "a/c" not in contexts 29 | contexts = unify.get_contexts(prefix="b") 30 | assert len(contexts) == 1 31 | assert "b/c" in contexts 32 | assert "a/b" not in contexts 33 | 34 | 35 | @_handle_project 36 | def test_delete_context(): 37 | unify.log(x=0, context="a/b") 38 | contexts = unify.get_contexts() 39 | assert len(contexts) == 1 40 | assert "a/b" in contexts 41 | unify.delete_context("a/b") 42 | assert "a/b" not in unify.get_contexts() 43 | assert len(unify.get_logs()) == 0 44 | 45 | 46 | @_handle_project 47 | def test_add_logs_to_context(): 48 | l0 = unify.log(x=0, context="a/b") 49 | l1 = unify.log(x=1, context="a/b") 50 | l2 = unify.log(x=2, context="b/c") 51 | l3 = unify.log(x=3, context="b/c") 52 | unify.add_logs_to_context(log_ids=[l0.id, l1.id], context="b/c") 53 | assert len(unify.get_logs(context="a/b")) == 2 54 | assert unify.get_logs(context="a/b", return_ids_only=True) == [l1.id, l0.id] 55 | assert len(unify.get_logs(context="b/c")) == 4 56 | assert unify.get_logs(context="b/c", return_ids_only=True) == [ 57 | l3.id, 58 | l2.id, 59 | l1.id, 60 | l0.id, 61 | ] 62 | 63 | 64 | @_handle_project 65 | def test_rename_context(): 66 | unify.log(x=0, context="a/b") 67 | unify.rename_context("a/b", "a/c") 68 | contexts = unify.get_contexts() 69 | assert "a/b" not in contexts 70 | assert "a/c" in contexts 71 | logs = unify.get_logs(context="a/c") 72 | assert len(logs) == 1 73 | assert logs[0].context == "a/c" 74 | 75 | 76 | 
@_handle_project 77 | def test_get_context(): 78 | name = "foo" 79 | desc = "my_description" 80 | is_versioned = True 81 | allow_duplicates = True 82 | unify.create_context( 83 | name, 84 | description=desc, 85 | is_versioned=is_versioned, 86 | allow_duplicates=allow_duplicates, 87 | ) 88 | 89 | context = unify.get_context(name) 90 | assert context["name"] == name 91 | assert context["description"] == desc 92 | assert context["is_versioned"] is is_versioned 93 | assert context["allow_duplicates"] is allow_duplicates 94 | 95 | 96 | if __name__ == "__main__": 97 | pass 98 | -------------------------------------------------------------------------------- /tests/test_logging/test_utils/test_datasets.py: -------------------------------------------------------------------------------- 1 | import unify 2 | 3 | from ..helpers import _handle_project 4 | 5 | 6 | @_handle_project 7 | def test_list_datasets(): 8 | assert len(unify.list_datasets()) == 0 9 | unify.log(x=0, context="Datasets/Prod/TestSet") 10 | unify.log(x=1, context="Datasets/Prod/TestSet") 11 | unify.log(x=0, context="Datasets/Eval/ValidationSet") 12 | unify.log(x=1, context="Datasets/Eval/ValidationSet") 13 | datasets = unify.list_datasets() 14 | assert len(datasets) == 2 15 | assert "Prod/TestSet" in datasets 16 | assert "Eval/ValidationSet" in datasets 17 | datasets = unify.list_datasets(prefix="Prod") 18 | assert len(datasets) == 1 19 | assert "Prod/TestSet" in datasets 20 | assert "Eval/ValidationSet" not in datasets 21 | datasets = unify.list_datasets(prefix="Eval") 22 | assert len(datasets) == 1 23 | assert "Eval/ValidationSet" in datasets 24 | assert "Prod/TestSet" not in datasets 25 | 26 | 27 | @_handle_project 28 | def test_upload_dataset(): 29 | dataset = [ 30 | { 31 | "name": "Dan", 32 | "age": 31, 33 | "gender": "male", 34 | }, 35 | { 36 | "name": "Jane", 37 | "age": 25, 38 | "gender": "female", 39 | }, 40 | { 41 | "name": "John", 42 | "age": 35, 43 | "gender": "male", 44 | }, 45 | ] 46 | data = 
unify.upload_dataset("staff", dataset) 47 | assert len(data) == 3 48 | 49 | 50 | @_handle_project 51 | def test_add_dataset_entries(): 52 | dataset = [ 53 | { 54 | "name": "Dan", 55 | "age": 31, 56 | "gender": "male", 57 | }, 58 | { 59 | "name": "Jane", 60 | "age": 25, 61 | "gender": "female", 62 | }, 63 | { 64 | "name": "John", 65 | "age": 35, 66 | "gender": "male", 67 | }, 68 | ] 69 | ids = unify.upload_dataset("staff", dataset) 70 | assert len(ids) == 3 71 | dataset = unify.download_dataset("staff") 72 | assert len(dataset) == 3 73 | assert dataset[0].entries["name"] == "Dan" 74 | assert dataset[1].entries["name"] == "Jane" 75 | assert dataset[2].entries["name"] == "John" 76 | new_entries = [ 77 | { 78 | "name": "Chloe", 79 | "age": 28, 80 | "gender": "female", 81 | }, 82 | { 83 | "name": "Tom", 84 | "age": 32, 85 | "gender": "male", 86 | }, 87 | ] 88 | ids = unify.add_dataset_entries("staff", new_entries) 89 | assert len(ids) == 2 90 | dataset = unify.download_dataset("staff") 91 | assert len(dataset) == 5 92 | assert dataset[0].entries["name"] == "Dan" 93 | assert dataset[1].entries["name"] == "Jane" 94 | assert dataset[2].entries["name"] == "John" 95 | assert dataset[3].entries["name"] == "Chloe" 96 | assert dataset[4].entries["name"] == "Tom" 97 | 98 | 99 | @_handle_project 100 | def test_download_dataset(): 101 | dataset = [ 102 | { 103 | "name": "Dan", 104 | "age": 31, 105 | "gender": "male", 106 | }, 107 | { 108 | "name": "Jane", 109 | "age": 25, 110 | "gender": "female", 111 | }, 112 | { 113 | "name": "John", 114 | "age": 35, 115 | "gender": "male", 116 | }, 117 | ] 118 | unify.upload_dataset("staff", dataset) 119 | data = unify.download_dataset("staff") 120 | assert len(data) == 3 121 | 122 | 123 | @_handle_project 124 | def test_delete_dataset(): 125 | dataset = [ 126 | { 127 | "name": "Dan", 128 | "age": 31, 129 | "gender": "male", 130 | }, 131 | { 132 | "name": "Jane", 133 | "age": 25, 134 | "gender": "female", 135 | }, 136 | { 137 | "name": "John", 
138 | "age": 35, 139 | "gender": "male", 140 | }, 141 | ] 142 | unify.upload_dataset("staff", dataset) 143 | unify.delete_dataset("staff") 144 | assert "staff" not in unify.list_datasets() 145 | 146 | 147 | if __name__ == "__main__": 148 | pass 149 | -------------------------------------------------------------------------------- /tests/test_logging/test_utils/test_projects.py: -------------------------------------------------------------------------------- 1 | import unify 2 | 3 | from ..helpers import _handle_project 4 | 5 | 6 | def test_project(): 7 | name = "my_project" 8 | if name in unify.list_projects(): 9 | unify.delete_project(name) 10 | assert name not in unify.list_projects() 11 | unify.create_project(name) 12 | assert name in unify.list_projects() 13 | new_name = "my_project1" 14 | unify.rename_project(name, new_name) 15 | assert new_name in unify.list_projects() 16 | unify.delete_project(new_name) 17 | assert new_name not in unify.list_projects() 18 | 19 | 20 | def test_project_thread_lock(): 21 | # all 10 threads would try to create the project at the same time without 22 | # thread locking, but only one should acquire the lock, and this should pass 23 | unify.map( 24 | unify.log, 25 | project="test_project", 26 | a=[1] * 10, 27 | b=[2] * 10, 28 | c=[3] * 10, 29 | from_args=True, 30 | ) 31 | unify.delete_project("test_project") 32 | 33 | 34 | @_handle_project 35 | def test_delete_project_logs(): 36 | [unify.log(x=i) for i in range(10)] 37 | assert len(unify.get_logs()) == 10 38 | unify.delete_project_logs("test_delete_project_logs") 39 | assert len(unify.get_logs()) == 0 40 | assert "test_delete_project_logs" in unify.list_projects() 41 | 42 | 43 | @_handle_project 44 | def test_delete_project_contexts(): 45 | unify.create_context("foo") 46 | unify.create_context("bar") 47 | 48 | assert len(unify.get_contexts()) == 2 49 | unify.delete_project_contexts("test_delete_project_contexts") 50 | 51 | assert len(unify.get_contexts()) == 0 52 | assert 
"test_delete_project_contexts" in unify.list_projects() 53 | 54 | 55 | if __name__ == "__main__": 56 | pass 57 | -------------------------------------------------------------------------------- /tests/test_routing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unifyai/unify/fbb548073fb0dc9f0cc0f67e28220011cbf9ec5b/tests/test_routing/__init__.py -------------------------------------------------------------------------------- /tests/test_routing/test_fallbacks.py: -------------------------------------------------------------------------------- 1 | import unify 2 | 3 | 4 | def test_provider_fallback(): 5 | unify.Unify("claude-3-opus@anthropic->aws-bedrock").generate("Hello.") 6 | 7 | 8 | def test_model_fallback(): 9 | unify.Unify("gemini-1.5-pro->gemini-1.5-flash@vertex-ai").generate("Hello.") 10 | 11 | 12 | def test_endpoint_fallback(): 13 | unify.Unify( 14 | "llama-3.1-405b-chat@together-ai->gpt-4o@openai", 15 | ).generate("Hello.") 16 | 17 | 18 | if __name__ == "__main__": 19 | pass 20 | -------------------------------------------------------------------------------- /tests/test_routing/test_routing_syntax.py: -------------------------------------------------------------------------------- 1 | import unify 2 | 3 | # Meta Providers # 4 | # ---------------# 5 | 6 | 7 | def test_ttft(): 8 | for pre in ("", "lowest-"): 9 | unify.Unify(f"claude-3-opus@{pre}time-to-first-token").generate("Hello.") 10 | unify.Unify(f"gpt-4o@{pre}ttft").generate("Hello.") 11 | unify.Unify(f"mixtral-8x22b-instruct-v0.1@{pre}t").generate("Hello.") 12 | 13 | 14 | def test_itl(): 15 | for pre in ("", "lowest-"): 16 | unify.Unify(f"claude-3-opus@{pre}inter-token-latency").generate("Hello.") 17 | unify.Unify(f"gpt-4o@{pre}itl").generate("Hello.") 18 | unify.Unify(f"mixtral-8x22b-instruct-v0.1@{pre}i").generate("Hello.") 19 | 20 | 21 | def test_cost(): 22 | for pre in ("", "lowest-"): 23 | 
unify.Unify(f"claude-3-opus@{pre}cost").generate("Hello.") 24 | unify.Unify(f"mixtral-8x22b-instruct-v0.1@{pre}c").generate("Hello.") 25 | 26 | 27 | def test_input_cost(): 28 | for pre in ("", "lowest-"): 29 | unify.Unify(f"claude-3-opus@{pre}input-cost").generate("Hello.") 30 | unify.Unify(f"gpt-4o@{pre}ic").generate("Hello.") 31 | unify.Unify(f"mixtral-8x22b-instruct-v0.1@{pre}i").generate("Hello.") 32 | 33 | 34 | def test_output_cost(): 35 | for pre in ("", "lowest-"): 36 | unify.Unify(f"claude-3-opus@{pre}output-cost").generate("Hello.") 37 | unify.Unify(f"mixtral-8x22b-instruct-v0.1@{pre}oc").generate("Hello.") 38 | 39 | 40 | # Thresholds # 41 | # -----------# 42 | 43 | 44 | def test_thresholds(): 45 | unify.Unify("llama-3.1-405b-chat@inter-token-latency|c<5").generate("Hello.") 46 | 47 | 48 | # Search Space # 49 | # -------------# 50 | 51 | 52 | def test_routing_w_providers(): 53 | unify.Unify( 54 | "llama-3.1-405b-chat@itl|providers:groq,fireworks-ai,together-ai", 55 | ).generate("Hello.") 56 | 57 | 58 | def test_routing_skip_providers(): 59 | unify.Unify( 60 | "llama-3.1-405b-chat@itl|skip_providers:vertex-ai,aws-bedrock", 61 | ).generate("Hello.") 62 | 63 | 64 | if __name__ == "__main__": 65 | pass 66 | -------------------------------------------------------------------------------- /tests/test_universal_api/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unifyai/unify/fbb548073fb0dc9f0cc0f67e28220011cbf9ec5b/tests/test_universal_api/__init__.py -------------------------------------------------------------------------------- /tests/test_universal_api/test_basics.py: -------------------------------------------------------------------------------- 1 | import json 2 | from types import AsyncGeneratorType, GeneratorType 3 | 4 | import pytest 5 | from openai.types.chat import ParsedChatCompletion 6 | from pydantic import BaseModel 7 | from unify import AsyncUnify, Unify 8 | 9 | 10 | 
class Response(BaseModel): 11 | number: int 12 | 13 | 14 | class TestUnifyBasics: 15 | def test_invalid_api_key_raises_authentication_error(self) -> None: 16 | with pytest.raises(Exception): 17 | client = Unify( 18 | api_key="invalid_api_key", 19 | endpoint="gpt-4o@openai", 20 | ) 21 | client.generate(user_message="hello") 22 | 23 | def test_incorrect_model_name_raises_internal_server_error(self) -> None: 24 | with pytest.raises(Exception): 25 | Unify(model="wong-model-name") 26 | 27 | def test_generate_returns_string_when_stream_false(self) -> None: 28 | client = Unify( 29 | endpoint="gpt-4o@openai", 30 | ) 31 | result = client.generate(user_message="hello", stream=False) 32 | assert isinstance(result, str) 33 | 34 | def test_traced_and_cached(self) -> None: 35 | client = Unify( 36 | endpoint="gpt-4o@openai", 37 | traced=True, 38 | cache=True, 39 | ) 40 | client.generate("hello") 41 | 42 | def test_copy_client(self) -> None: 43 | client = Unify( 44 | endpoint="gpt-4o@openai", 45 | ) 46 | client.set_system_message("you are a helpful agent") 47 | clone = client.copy() 48 | assert clone.endpoint == "gpt-4o@openai" 49 | assert clone.system_message == "you are a helpful agent" 50 | clone.set_system_message("you are not helpful") 51 | assert clone.system_message == "you are not helpful" 52 | assert client.system_message == "you are a helpful agent" 53 | clone.set_endpoint("o1@openai") 54 | assert clone.endpoint == "o1@openai" 55 | assert client.endpoint == "gpt-4o@openai" 56 | 57 | def test_structured_output(self) -> None: 58 | client = Unify( 59 | endpoint="gpt-4o@openai", 60 | response_format=Response, 61 | ) 62 | 63 | result = client.generate( 64 | user_message="what is 1 + 1?", 65 | return_full_completion=True, 66 | ) 67 | assert isinstance(result, ParsedChatCompletion) 68 | assert isinstance(result.choices[0].message.content, str) 69 | result = json.loads(result.choices[0].message.content) 70 | assert isinstance(result, dict) 71 | assert result == {"number": 2} 72 
| 73 | result = client.generate( 74 | user_message="what is 1 + 1?", 75 | ) 76 | assert isinstance(result, str) 77 | result = json.loads(result) 78 | assert isinstance(result, dict) 79 | assert result == {"number": 2} 80 | 81 | def test_structured_output_w_caching(self) -> None: 82 | client = Unify( 83 | endpoint="gpt-4o@openai", 84 | response_format=Response, 85 | cache=True, 86 | ) 87 | assert json.loads(client.generate(user_message="what is 1 + 1?"))["number"] == 2 88 | assert json.loads(client.generate(user_message="what is 1 + 1?"))["number"] == 2 89 | 90 | def test_generate_returns_generator_when_stream_true(self) -> None: 91 | client = Unify( 92 | endpoint="gpt-4o@openai", 93 | ) 94 | result = client.generate(user_message="hello", stream=True) 95 | assert isinstance(result, GeneratorType) 96 | 97 | def test_default_params_handled_correctly(self) -> None: 98 | client = Unify( 99 | endpoint="gpt-4o@openai", 100 | n=2, 101 | return_full_completion=True, 102 | ) 103 | result = client.generate(user_message="hello") 104 | assert len(result.choices) == 2 105 | 106 | def test_setter_chaining(self): 107 | client = Unify("gpt-4o@openai") 108 | client.set_temperature(0.5).set_n(2) 109 | assert client.temperature == 0.5 110 | assert client.n == 2 111 | 112 | def test_stateful(self): 113 | 114 | # via generate 115 | client = Unify("gpt-4o@openai", stateful=True) 116 | client.set_system_message("you are a good mathematician.") 117 | client.generate("What is 1 + 1?") 118 | client.generate("How do you know?") 119 | assert len(client.messages) == 5 120 | assert client.messages[0]["role"] == "system" 121 | assert client.messages[1]["role"] == "user" 122 | assert client.messages[2]["role"] == "assistant" 123 | assert client.messages[3]["role"] == "user" 124 | assert client.messages[4]["role"] == "assistant" 125 | 126 | # via append 127 | client = Unify("gpt-4o@openai", return_full_completion=True) 128 | client.set_stateful(True) 129 | client.set_system_message("You are an 
expert.") 130 | client.append_messages( 131 | [ 132 | { 133 | "role": "user", 134 | "content": [ 135 | { 136 | "type": "text", 137 | "text": "Hello", 138 | }, 139 | ], 140 | }, 141 | ], 142 | ) 143 | assert len(client.messages) == 2 144 | assert client.messages[0]["role"] == "system" 145 | assert client.messages[1]["role"] == "user" 146 | 147 | def test_json_structor(self): 148 | client = Unify( 149 | endpoint="gpt-4o@openai", 150 | temperature=0.5, 151 | ) 152 | serialized = client.json() 153 | assert serialized 154 | assert serialized["endpoint"] == "gpt-4o@openai" 155 | 156 | 157 | @pytest.mark.asyncio 158 | class TestAsyncUnifyBasics: 159 | async def test_invalid_api_key_raises_authentication_error(self) -> None: 160 | with pytest.raises(Exception): 161 | async_client = AsyncUnify( 162 | api_key="invalid_api_key", 163 | endpoint="gpt-4o@openai", 164 | ) 165 | await async_client.generate(user_message="hello") 166 | 167 | async def test_incorrect_model_name_raises_internal_server_error(self) -> None: 168 | with pytest.raises(Exception): 169 | AsyncUnify(model="wong-model-name") 170 | 171 | @pytest.mark.skip() 172 | async def test_generate_returns_string_when_stream_false(self) -> None: 173 | async_client = AsyncUnify( 174 | endpoint="gpt-4o@openai", 175 | ) 176 | result = await async_client.generate(user_message="hello", stream=False) 177 | assert isinstance(result, str) 178 | 179 | async def test_generate_returns_generator_when_stream_true(self) -> None: 180 | async_client = AsyncUnify( 181 | endpoint="gpt-4o@openai", 182 | ) 183 | result = await async_client.generate(user_message="hello", stream=True) 184 | assert isinstance(result, AsyncGeneratorType) 185 | 186 | @pytest.mark.skip() 187 | async def test_default_params_handled_correctly(self) -> None: 188 | async_client = AsyncUnify( 189 | endpoint="gpt-4o@openai", 190 | n=2, 191 | return_full_completion=True, 192 | ) 193 | result = await async_client.generate(user_message="hello") 194 | assert 
len(result.choices) == 2 195 | 196 | 197 | if __name__ == "__main__": 198 | pass 199 | -------------------------------------------------------------------------------- /tests/test_universal_api/test_chatbot.py: -------------------------------------------------------------------------------- 1 | import builtins 2 | import traceback 3 | 4 | import pytest 5 | from unify import ChatBot, MultiUnify, Unify 6 | 7 | 8 | class SimulateInput: 9 | def __init__(self): 10 | self._messages = [ 11 | "What is the capital of Spain? Be succinct.", 12 | "Who is their most famous sports player? Be succinct.", 13 | "quit", 14 | ] 15 | self._count = 0 16 | self._true_input = None 17 | 18 | def _new_input(self): 19 | message = self._messages[self._count] 20 | self._count += 1 21 | print(message) 22 | return message 23 | 24 | def __enter__(self): 25 | self._true_input = builtins.__dict__["input"] 26 | builtins.__dict__["input"] = self._new_input 27 | 28 | def __exit__(self, exc_type, exc_value, tb): 29 | builtins.__dict__["input"] = self._true_input 30 | self._count = 0 31 | if exc_type is not None: 32 | traceback.print_exception(exc_type, exc_value, tb) 33 | return False 34 | return True 35 | 36 | 37 | class TestChatbotUniLLM: 38 | def test_constructor(self) -> None: 39 | client = Unify( 40 | endpoint="gpt-4o@openai", 41 | cache=True, 42 | ) 43 | ChatBot(client) 44 | 45 | def test_simple_non_stream_chat_n_quit(self): 46 | client = Unify( 47 | endpoint="gpt-4o@openai", 48 | cache=True, 49 | ) 50 | chatbot = ChatBot(client) 51 | with SimulateInput(): 52 | chatbot.run() 53 | 54 | @pytest.mark.skip() 55 | def test_simple_stream_chat_n_quit(self): 56 | client = Unify( 57 | endpoint="gpt-4o@openai", 58 | cache=True, 59 | stream=True, 60 | ) 61 | chatbot = ChatBot(client) 62 | with SimulateInput(): 63 | chatbot.run() 64 | 65 | 66 | class TestChatbotMultiUnify: 67 | def test_constructor(self) -> None: 68 | client = MultiUnify( 69 | endpoints=["gpt-4@openai", "gpt-4o@openai"], 70 | cache=True, 71 
| ) 72 | ChatBot(client) 73 | 74 | def test_simple_non_stream_chat_n_quit(self): 75 | client = MultiUnify( 76 | endpoints=["gpt-4@openai", "gpt-4o@openai"], 77 | cache=True, 78 | ) 79 | chatbot = ChatBot(client) 80 | with SimulateInput(): 81 | chatbot.run() 82 | 83 | 84 | if __name__ == "__main__": 85 | pass 86 | -------------------------------------------------------------------------------- /tests/test_universal_api/test_json_mode.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from unify import Unify 4 | 5 | 6 | def test_openai_json_mode() -> None: 7 | client = Unify(endpoint="gpt-4o@openai") 8 | result = client.generate( 9 | system_message="You are a helpful assistant designed to output JSON.", 10 | user_message="Who won the world series in 2020?", 11 | response_format={"type": "json_object"}, 12 | ) 13 | assert isinstance(result, str) 14 | result = json.loads(result) 15 | assert isinstance(result, dict) 16 | 17 | 18 | def test_anthropic_json_mode() -> None: 19 | client = Unify(endpoint="claude-3-opus@anthropic") 20 | result = client.generate( 21 | system_message="You are a helpful assistant designed to output JSON.", 22 | user_message="Who won the world series in 2020?", 23 | ) 24 | assert isinstance(result, str) 25 | result = json.loads(result) 26 | assert isinstance(result, dict) 27 | 28 | 29 | if __name__ == "__main__": 30 | pass 31 | -------------------------------------------------------------------------------- /tests/test_universal_api/test_multi_llm.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from unify import AsyncMultiUnify, MultiUnify 3 | 4 | 5 | class TestMultiUnify: 6 | def test_constructor(self) -> None: 7 | MultiUnify( 8 | endpoints=["llama-3-8b-chat@together-ai", "gpt-4o@openai"], 9 | cache=True, 10 | ) 11 | 12 | def test_add_endpoints(self): 13 | endpoints = ("llama-3-8b-chat@together-ai", "gpt-4o@openai") 14 | client = 
MultiUnify(endpoints=endpoints, cache=True) 15 | assert client.endpoints == endpoints 16 | assert tuple(client.clients.keys()) == endpoints 17 | client.add_endpoints("claude-3.5-sonnet@anthropic") 18 | endpoints = ( 19 | "llama-3-8b-chat@together-ai", 20 | "gpt-4o@openai", 21 | "claude-3.5-sonnet@anthropic", 22 | ) 23 | assert client.endpoints == endpoints 24 | assert tuple(client.clients.keys()) == endpoints 25 | client.add_endpoints("claude-3.5-sonnet@anthropic") 26 | assert client.endpoints == endpoints 27 | assert tuple(client.clients.keys()) == endpoints 28 | with pytest.raises(Exception): 29 | client.add_endpoints("claude-3.5-sonnet@anthropic", ignore_duplicates=False) 30 | 31 | def test_remove_endpoints(self): 32 | endpoints = ( 33 | "llama-3-8b-chat@together-ai", 34 | "gpt-4o@openai", 35 | "claude-3.5-sonnet@anthropic", 36 | ) 37 | client = MultiUnify(endpoints=endpoints, cache=True) 38 | assert client.endpoints == endpoints 39 | assert tuple(client.clients.keys()) == endpoints 40 | client.remove_endpoints("claude-3.5-sonnet@anthropic") 41 | endpoints = ("llama-3-8b-chat@together-ai", "gpt-4o@openai") 42 | assert client.endpoints == endpoints 43 | assert tuple(client.clients.keys()) == endpoints 44 | client.remove_endpoints("claude-3.5-sonnet@anthropic") 45 | assert client.endpoints == endpoints 46 | assert tuple(client.clients.keys()) == endpoints 47 | with pytest.raises(Exception): 48 | client.remove_endpoints("claude-3.5-sonnet@anthropic", ignore_missing=False) 49 | 50 | def test_generate(self): 51 | endpoints = ( 52 | "gpt-4o@openai", 53 | "claude-3.5-sonnet@anthropic", 54 | ) 55 | client = MultiUnify(endpoints=endpoints, cache=True) 56 | responses = client.generate("Hello, how it is going?") 57 | for endpoint, (response_endpoint, response) in zip( 58 | endpoints, 59 | responses.items(), 60 | ): 61 | assert endpoint == response_endpoint 62 | assert isinstance(response, str) 63 | assert len(response) > 0 64 | 65 | def test_multi_message_histories(self): 
66 | endpoints = ("claude-3.5-sonnet@anthropic", "gpt-4o@openai") 67 | messages = { 68 | "claude-3.5-sonnet@anthropic": [ 69 | {"role": "assistant", "content": "Let's talk about cats"}, 70 | ], 71 | "gpt-4o@openai": [ 72 | {"role": "assistant", "content": "Let's talk about dogs"}, 73 | ], 74 | } 75 | animals = {"claude-3.5-sonnet@anthropic": "cat", "gpt-4o@openai": "dog"} 76 | client = MultiUnify( 77 | endpoints=endpoints, 78 | messages=messages, 79 | cache=True, 80 | ) 81 | responses = client.generate("What animal did you want to talk about?") 82 | for endpoint, (response_endpoint, response) in zip( 83 | endpoints, 84 | responses.items(), 85 | ): 86 | assert endpoint == response_endpoint 87 | assert isinstance(response, str) 88 | assert len(response) > 0 89 | assert animals[endpoint] in response.lower() 90 | 91 | def test_setter_chaining(self): 92 | endpoints = ( 93 | "llama-3-8b-chat@together-ai", 94 | "gpt-4o@openai", 95 | "claude-3.5-sonnet@anthropic", 96 | ) 97 | client = MultiUnify(endpoints=endpoints, cache=True) 98 | client.add_endpoints(["gpt-4@openai", "gpt-4-turbo@openai"]).remove_endpoints( 99 | "claude-3.5-sonnet@anthropic", 100 | ) 101 | assert set(client.endpoints) == { 102 | "llama-3-8b-chat@together-ai", 103 | "gpt-4o@openai", 104 | "gpt-4@openai", 105 | "gpt-4-turbo@openai", 106 | } 107 | 108 | 109 | @pytest.mark.asyncio 110 | class TestAsyncMultiUnify: 111 | async def test_async_generate(self): 112 | endpoints = ( 113 | "gpt-4o@openai", 114 | "claude-3.5-sonnet@anthropic", 115 | ) 116 | client = AsyncMultiUnify(endpoints=endpoints, cache=True) 117 | responses = await client.generate("Hello, how it is going?") 118 | for endpoint, (response_endpoint, response) in zip( 119 | endpoints, 120 | responses.items(), 121 | ): 122 | assert endpoint == response_endpoint 123 | assert isinstance(response, str) 124 | assert len(response) > 0 125 | 126 | 127 | if __name__ == "__main__": 128 | pass 129 | 
-------------------------------------------------------------------------------- /tests/test_universal_api/test_stateful.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from unify import AsyncUnify, Unify 3 | 4 | 5 | class TestStateful: 6 | # ──────────────────────────────── SYNC ──────────────────────────────── # 7 | def test_stateful_sync_non_stream(self): 8 | client = Unify(endpoint="gpt-4o@openai", cache=True, stateful=True) 9 | 10 | client.generate(user_message="hi") # non-stream 11 | assert len(client.messages) == 2 # user + assistant 12 | assert client.messages[-1]["role"] == "assistant" 13 | assert isinstance(client.messages[-1]["content"], str) 14 | 15 | def test_stateful_sync_stream(self): 16 | client = Unify(endpoint="gpt-4o@openai", cache=True, stateful=True) 17 | 18 | chunks = list(client.generate(user_message="hello", stream=True)) 19 | assert all([isinstance(c, str) for c in chunks]) 20 | assert len(client.messages) == 2 # user + assistant 21 | assert isinstance(client.messages[-1]["content"], str) 22 | 23 | def test_stateless_sync_stream_clears_history(self): 24 | client = Unify(endpoint="gpt-4o@openai", stateful=False) 25 | 26 | list(client.generate(user_message="hello", stream=True)) 27 | assert client.messages == [] # history wiped 28 | 29 | # ─────────────────────────────── ASYNC ──────────────────────────────── # 30 | @pytest.mark.asyncio 31 | async def test_stateful_async_non_stream(self): 32 | client = AsyncUnify(endpoint="gpt-4o@openai", cache=True, stateful=True) 33 | 34 | await client.generate(user_message="hi") # non-stream 35 | assert len(client.messages) == 2 # user + assistant 36 | assert isinstance(client.messages[-1]["content"], str) 37 | 38 | @pytest.mark.asyncio 39 | async def test_stateful_async_stream(self): 40 | client = AsyncUnify(endpoint="gpt-4o@openai", cache=True, stateful=True) 41 | 42 | stream = await client.generate(user_message="hello", stream=True) 43 | assert 
all([isinstance(c, str) async for c in stream]) 44 | assert len(client.messages) == 2 # user + assistant 45 | assert isinstance(client.messages[-1]["content"], str) 46 | -------------------------------------------------------------------------------- /tests/test_universal_api/test_usage.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import unify 4 | 5 | dir_path = os.path.dirname(os.path.realpath(__file__)) 6 | 7 | 8 | def test_with_logging() -> None: 9 | model_fn = lambda msg: "This is my response." 10 | model_fn = unify.with_logging(model_fn, endpoint="my_model") 11 | model_fn(msg="Hello?") 12 | 13 | 14 | if __name__ == "__main__": 15 | pass 16 | -------------------------------------------------------------------------------- /tests/test_universal_api/test_user_input.py: -------------------------------------------------------------------------------- 1 | import builtins 2 | import traceback 3 | 4 | import unify 5 | from openai.types.chat import ChatCompletion 6 | 7 | 8 | class SimulateInput: 9 | def __init__(self): 10 | self._message = "Hi, how can I help you?" 
11 | self._true_input = None 12 | 13 | def _new_input(self, user_instructions): 14 | if user_instructions is not None: 15 | print(user_instructions) 16 | print(self._message) 17 | return self._message 18 | 19 | def __enter__(self): 20 | self._true_input = builtins.__dict__["input"] 21 | builtins.__dict__["input"] = self._new_input 22 | 23 | def __exit__(self, exc_type, exc_value, tb): 24 | builtins.__dict__["input"] = self._true_input 25 | self._count = 0 26 | if exc_type is not None: 27 | traceback.print_exception(exc_type, exc_value, tb) 28 | return False 29 | return True 30 | 31 | 32 | def test_user_input_client(): 33 | client = unify.Unify("user-input") 34 | with SimulateInput(): 35 | response = client.generate("hello") 36 | assert isinstance(response, str) 37 | response = client.generate("hello", return_full_completion=True) 38 | assert isinstance(response, ChatCompletion) 39 | 40 | 41 | if __name__ == "__main__": 42 | pass 43 | -------------------------------------------------------------------------------- /tests/test_universal_api/test_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unifyai/unify/fbb548073fb0dc9f0cc0f67e28220011cbf9ec5b/tests/test_universal_api/test_utils/__init__.py -------------------------------------------------------------------------------- /tests/test_universal_api/test_utils/test_credits.py: -------------------------------------------------------------------------------- 1 | import unify 2 | 3 | 4 | def test_get_credits() -> None: 5 | creds = unify.get_credits() 6 | assert isinstance(creds, float) 7 | 8 | 9 | if __name__ == "__main__": 10 | pass 11 | -------------------------------------------------------------------------------- /tests/test_universal_api/test_utils/test_custom_api_keys.py: -------------------------------------------------------------------------------- 1 | import unify 2 | 3 | 4 | # noinspection PyBroadException 5 | class 
CustomAPIKeyHandler: 6 | def __init__(self, ky_name, ky_value, nw_name): 7 | self._key_name = ky_name 8 | self._key_value = ky_value 9 | self._new_name = nw_name 10 | 11 | def _handle(self): 12 | # should work even if list_custom_api_keys does not 13 | for name in (self._key_name, self._new_name): 14 | try: 15 | unify.delete_custom_api_key(name) 16 | except: 17 | pass 18 | # should if other keys have wrongly been created 19 | try: 20 | custom_keys = unify.list_custom_api_keys() 21 | for dct in custom_keys: 22 | unify.delete_custom_api_key(dct["name"]) 23 | except: 24 | pass 25 | 26 | def __enter__(self): 27 | self._handle() 28 | 29 | def __exit__(self, exc_type, exc_val, exc_tb): 30 | self._handle() 31 | 32 | 33 | def _find_key(key_to_find, list_of_keys): 34 | for key in list_of_keys: 35 | if key["name"] == key_to_find: 36 | return True 37 | return False 38 | 39 | 40 | key_name = "my_test_key2" 41 | key_value = "1234" 42 | new_name = "new_test_key" 43 | handler = CustomAPIKeyHandler( 44 | key_name, 45 | key_value, 46 | new_name, 47 | ) 48 | 49 | 50 | def test_create_custom_api_key(): 51 | with handler: 52 | response = unify.create_custom_api_key(key_name, key_value) 53 | assert response == {"info": "API key created successfully!"} 54 | 55 | 56 | def test_list_custom_api_keys(): 57 | with handler: 58 | custom_keys = unify.list_custom_api_keys() 59 | assert isinstance(custom_keys, list) 60 | assert len(custom_keys) == 0 61 | unify.create_custom_api_key(key_name, key_value) 62 | custom_keys = unify.list_custom_api_keys() 63 | assert isinstance(custom_keys, list) 64 | assert len(custom_keys) == 1 65 | assert custom_keys[0]["name"] == key_name 66 | assert custom_keys[0]["value"] == "*" * 4 + key_value 67 | 68 | 69 | def test_get_custom_api_key(): 70 | with handler: 71 | unify.create_custom_api_key(key_name, key_value) 72 | retrieved_key = unify.get_custom_api_key(key_name) 73 | assert isinstance(retrieved_key, dict) 74 | assert retrieved_key["name"] == key_name 75 | 
# noinspection PyBroadException
class CustomEndpointHandler:
    """Context manager that resets the state used by the custom-endpoint tests.

    On entry and on exit it deletes the configured API key and every configured
    endpoint name (ignoring failures of those deletes), and then creates the
    API key again via ``unify.create_custom_api_key``.

    Args:
        ky_name: Name of the custom API key to delete and re-create.
        ky_value: Value passed when the key is re-created.
        endpoint_names: Iterable of custom endpoint names to delete.
    """

    def __init__(self, ky_name, ky_value, endpoint_names):
        self._key_name = ky_name
        self._key_value = ky_value
        self._endpoint_names = endpoint_names

    def _handle(self):
        """Delete the key and endpoints (best effort), then re-create the key."""
        # The deletes are deliberately best-effort: the key or an endpoint may
        # simply not exist yet. Only ``Exception`` is swallowed so that
        # KeyboardInterrupt / SystemExit still stop the test run.
        try:
            unify.delete_custom_api_key(self._key_name)
        except Exception:
            pass
        for endpoint_nm in self._endpoint_names:
            try:
                unify.delete_custom_endpoint(endpoint_nm)
            except Exception:
                pass
        # Not guarded: a failure to create the key is allowed to surface.
        unify.create_custom_api_key(self._key_name, self._key_value)

    def __enter__(self):
        self._handle()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Same reset as on entry, so the key exists again after each test.
        self._handle()
# noinspection PyBroadException
class CustomEndpointHandler:
    """Context manager providing a fresh custom endpoint for the metrics tests.

    On entry it removes any leftovers (metrics, endpoint, API key) and then
    creates the API key and the custom endpoint. On exit it removes all three
    again.

    Args:
        ep_name: Name of the custom endpoint (exposed as ``endpoint_name``).
        ep_url: URL passed when the endpoint is created.
        ky_name: Name of the custom API key the endpoint is created with.
        ky_value: Value passed when the key is created.
    """

    def __init__(self, ep_name, ep_url, ky_name, ky_value):
        self.endpoint_name = ep_name
        self._endpoint_url = ep_url
        self._key_name = ky_name
        self._key_value = ky_value

    def _cleanup(self):
        """Delete the endpoint's metrics, the endpoint and the key (best effort)."""
        # Each delete is independent and best-effort: the resource may not
        # exist. Only ``Exception`` is swallowed so that KeyboardInterrupt /
        # SystemExit still stop the test run.
        try:
            unify.delete_endpoint_metrics(self.endpoint_name)
        except Exception:
            pass
        try:
            unify.delete_custom_endpoint(self.endpoint_name)
        except Exception:
            pass
        try:
            unify.delete_custom_api_key(self._key_name)
        except Exception:
            pass

    def __enter__(self):
        self._cleanup()
        # Creation is not guarded: a setup failure is allowed to surface.
        unify.create_custom_api_key(self._key_name, self._key_value)
        unify.create_custom_endpoint(
            name=self.endpoint_name,
            url=self._endpoint_url,
            key_name=self._key_name,
        )
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._cleanup()
"input_cost") 66 | assert isinstance(metrics.input_cost, float) 67 | assert hasattr(metrics, "output_cost") 68 | assert isinstance(metrics.output_cost, float) 69 | assert hasattr(metrics, "measured_at") 70 | assert isinstance(metrics.measured_at, str) 71 | 72 | 73 | def test_client_metric_properties(): 74 | client = unify.Unify("gpt-4o@openai", cache=True) 75 | assert isinstance(client.input_cost, float) 76 | assert isinstance(client.output_cost, float) 77 | assert isinstance(client.ttft, float) 78 | assert isinstance(client.itl, float) 79 | client = unify.MultiUnify( 80 | ["gpt-4o@openai", "claude-3-haiku@anthropic"], 81 | cache=True, 82 | ) 83 | assert isinstance(client.input_cost, dict) 84 | assert isinstance(client.input_cost["gpt-4o@openai"], float) 85 | assert isinstance(client.input_cost["claude-3-haiku@anthropic"], float) 86 | assert isinstance(client.output_cost, dict) 87 | assert isinstance(client.output_cost["gpt-4o@openai"], float) 88 | assert isinstance(client.output_cost["claude-3-haiku@anthropic"], float) 89 | assert isinstance(client.ttft, dict) 90 | assert isinstance(client.ttft["gpt-4o@openai"], float) 91 | assert isinstance(client.ttft["claude-3-haiku@anthropic"], float) 92 | assert isinstance(client.itl, dict) 93 | assert isinstance(client.itl["gpt-4o@openai"], float) 94 | assert isinstance(client.itl["claude-3-haiku@anthropic"], float) 95 | 96 | 97 | def test_log_endpoint_metric(): 98 | with handler: 99 | unify.log_endpoint_metric( 100 | endpoint_name, 101 | metric_name="itl", 102 | value=1.23, 103 | ) 104 | 105 | 106 | def test_log_and_get_endpoint_metric(): 107 | with handler: 108 | now = datetime.now(timezone.utc) 109 | unify.log_endpoint_metric( 110 | endpoint_name, 111 | metric_name="itl", 112 | value=1.23, 113 | ) 114 | metrics = unify.get_endpoint_metrics(endpoint_name, start_time=now) 115 | assert isinstance(metrics, list) 116 | assert len(metrics) == 1 117 | metrics = metrics[0] 118 | assert hasattr(metrics, "itl") 119 | assert 
def test_log_and_get_endpoint_metric_with_time_windows():
    """Log itl and ttft, wait, log itl again, then query with different windows.

    Checks what ``unify.get_endpoint_metrics`` returns when called with
    ``start_time=t0`` (before all logs), ``start_time=t1`` (after the sleep)
    and with no ``start_time`` at all.
    """
    with handler:
        # t0 precedes every log call below
        t0 = datetime.now(timezone.utc)
        unify.log_endpoint_metric(
            endpoint_name,
            metric_name="itl",
            value=1.23,
        )
        unify.log_endpoint_metric(
            endpoint_name,
            metric_name="ttft",
            value=4.56,
        )
        # separate the first two logs from the third in time
        time.sleep(0.5)
        # t1 falls between the first two logs and the final itl log
        t1 = datetime.now(timezone.utc)
        unify.log_endpoint_metric(
            endpoint_name,
            metric_name="itl",
            value=7.89,
        )
        all_metrics = unify.get_endpoint_metrics(endpoint_name, start_time=t0)
        # two entries are expected back, because itl was logged twice
        assert len(all_metrics) == 2
        # entry 0 is expected to hold only the first itl; the ttft logged
        # before the sleep is expected on entry 1 together with the second itl
        assert isinstance(all_metrics[0].itl, float)
        assert all_metrics[0].ttft is None
        assert isinstance(all_metrics[1].itl, float)
        assert isinstance(all_metrics[1].ttft, float)
        assert all_metrics[0].itl == 1.23
        assert all_metrics[1].ttft == 4.56
        assert all_metrics[1].itl == 7.89
        # with start_time=t1 the two logs made before the sleep are not
        # expected back: a single entry with the second itl and no ttft
        limited_metrics = unify.get_endpoint_metrics(
            endpoint_name,
            start_time=t1,
        )
        assert len(limited_metrics) == 1
        assert limited_metrics[0].ttft is None
        assert isinstance(limited_metrics[0].itl, float)
        assert limited_metrics[0].itl == 7.89
        # without start_time a single entry is expected that holds the ttft and
        # the most recent itl (presumably the API's 'latest' mode - confirm
        # against the get_endpoint_metrics documentation)
        latest_metrics = unify.get_endpoint_metrics(endpoint_name)
        assert len(latest_metrics) == 1
        assert isinstance(latest_metrics[0].ttft, float)
        assert isinstance(latest_metrics[0].itl, float)
        assert latest_metrics[0].ttft == 4.56
        assert latest_metrics[0].itl == 7.89
def test_delete_some_metrics_for_endpoint():
    """Delete logged metrics one timestamp at a time and check what remains.

    Logs itl and ttft, sleeps, logs itl again, then repeatedly calls
    ``unify.delete_endpoint_metrics`` with a ``timestamps`` value read from
    ``measured_at`` of the first returned entry, re-querying after each delete.
    """
    with handler:
        # log metrics at t0
        t0 = datetime.now(timezone.utc)
        unify.log_endpoint_metric(
            endpoint_name,
            metric_name="itl",
            value=1.23,
        )
        unify.log_endpoint_metric(
            endpoint_name,
            metric_name="ttft",
            value=4.56,
        )
        # separate the first two logs from the third in time
        time.sleep(0.5)
        # log metric at t1
        unify.log_endpoint_metric(
            endpoint_name,
            metric_name="itl",
            value=7.89,
        )
        # verify both entries exist
        metrics = unify.get_endpoint_metrics(endpoint_name, start_time=t0)
        assert len(metrics) == 2
        # delete the first itl entry; measured_at is indexed by metric name here
        unify.delete_endpoint_metrics(
            endpoint_name,
            timestamps=metrics[0].measured_at["itl"],
        )
        # verify only the latest entry exists, with both itl and ttft
        metrics = unify.get_endpoint_metrics(endpoint_name, start_time=t0)
        assert len(metrics) == 1
        assert isinstance(metrics[0].itl, float)
        assert isinstance(metrics[0].ttft, float)
        # delete the ttft entry
        unify.delete_endpoint_metrics(
            endpoint_name,
            timestamps=metrics[0].measured_at["ttft"],
        )
        # verify only the latest entry exists, with only the itl
        metrics = unify.get_endpoint_metrics(endpoint_name, start_time=t0)
        assert len(metrics) == 1
        assert isinstance(metrics[0].itl, float)
        assert metrics[0].ttft is None
        # delete the final itl entry
        unify.delete_endpoint_metrics(
            endpoint_name,
            timestamps=metrics[0].measured_at["itl"],
        )
        # verify no metrics exist
        metrics = unify.get_endpoint_metrics(endpoint_name, start_time=t0)
        assert len(metrics) == 0
class TestSupportedEndpoints:
    """Checks on ``unify.list_endpoints`` with and without model/provider filters."""

    @staticmethod
    def _assert_valid_listing(listing) -> None:
        # Shared checks: the result is a list, is not empty, has no repeats.
        assert isinstance(listing, list), f"return type was not a list: {listing}"
        assert listing, f"returned list was empty: {listing}"
        assert len(listing) == len(set(listing)), f"duplication detected: {listing}"

    def test_list_endpoints(self) -> None:
        self._assert_valid_listing(unify.list_endpoints())

    def test_list_endpoints_w_model(self) -> None:
        found = unify.list_endpoints(model="llama-3.2-90b-chat")
        self._assert_valid_listing(found)
        # one endpoint is expected per provider of the model
        provider_count = len(unify.list_providers("llama-3.2-90b-chat"))
        assert len(found) == provider_count, (
            "number of endpoints for the model did not match the number of providers "
            "for the model"
        )

    def test_list_endpoints_w_provider(self) -> None:
        found = unify.list_endpoints(provider="openai")
        self._assert_valid_listing(found)
        # one endpoint is expected per model of the provider
        model_count = len(unify.list_models("openai"))
        assert len(found) == model_count, (
            "number of endpoints for the provider did not match the number of models "
            "for the provider"
        )

    def test_list_endpoints_w_model_w_provider(self) -> None:
        # passing both a model and a provider is expected to raise
        with pytest.raises(Exception):
            unify.list_endpoints("gpt-4o", "openai")
def _thread_locked(fn: Callable) -> Callable:
    """Decorate ``fn`` so that calls to it are serialised through ``THREAD_LOCK``.

    The lock is held for the duration of the call and is released whether the
    call returns or raises. Exceptions raised by ``fn`` (including failed
    ``assert`` statements in a decorated test) propagate to the caller instead
    of being swallowed.

    Args:
        fn: The callable to wrap.

    Returns:
        A wrapper that forwards all positional and keyword arguments to ``fn``
        and returns its result.
    """
    # Local import keeps this helper self-contained within the test module.
    import functools

    @functools.wraps(fn)
    def wrapped(*args, **kwargs):
        # ``with`` guarantees the release on both the success and error path.
        with THREAD_LOCK:
            return fn(*args, **kwargs)

    return wrapped
115 | history = unify.get_queries( 116 | endpoints="gpt-4o@openai", 117 | start_time=start_time, 118 | ) 119 | assert len(history) == 1 120 | history = unify.get_queries( 121 | endpoints="gpt-4o@openai", 122 | start_time=datetime.now(timezone.utc) + timedelta(seconds=1), 123 | ) 124 | assert len(history) == 0 125 | 126 | 127 | @_thread_locked 128 | def test_get_query_failures(): 129 | start_time = datetime.now(timezone.utc) 130 | client = unify.Unify("gpt-4o@openai") 131 | client.generate( 132 | "hello", 133 | log_query_body=True, 134 | log_response_body=True, 135 | ) 136 | with pytest.raises(Exception): 137 | client.generate( 138 | "hello", 139 | log_query_body=True, 140 | log_response_body=True, 141 | drop_params=False, 142 | invalid_arg="invalid_value", 143 | ) 144 | 145 | # inside logged timeframe 146 | history_w_both = unify.get_queries( 147 | endpoints="gpt-4o@openai", 148 | start_time=start_time, 149 | failures=True, 150 | ) 151 | assert len(history_w_both) == 2 152 | history_only_failures = unify.get_queries( 153 | endpoints="gpt-4o@openai", 154 | start_time=start_time, 155 | failures="only", 156 | ) 157 | assert len(history_only_failures) == 1 158 | history_only_success = unify.get_queries( 159 | endpoints="gpt-4o@openai", 160 | start_time=start_time, 161 | failures=False, 162 | ) 163 | assert len(history_only_success) == 1 164 | 165 | # Outside logged timeframe 166 | history_w_both = unify.get_queries( 167 | endpoints="gpt-4o@openai", 168 | start_time=datetime.now(timezone.utc) + timedelta(seconds=1), 169 | failures=True, 170 | ) 171 | assert len(history_w_both) == 0 172 | history_only_failures = unify.get_queries( 173 | endpoints="gpt-4o@openai", 174 | start_time=datetime.now(timezone.utc) + timedelta(seconds=1), 175 | failures="only", 176 | ) 177 | assert len(history_only_failures) == 0 178 | history_only_success = unify.get_queries( 179 | endpoints="gpt-4o@openai", 180 | start_time=datetime.now(timezone.utc) + timedelta(seconds=1), 181 | 
class _CacheHandler:
    """Context manager that points unify's cache at a throwaway test file.

    The cache path reported by ``_get_caching_fpath()`` at construction time is
    remembered and handed back to ``unify.set_caching_fname`` on exit.
    """

    def __init__(self, fname=".test_cache.json"):
        self._fname = fname
        self.test_path = ""
        self._old_cache_fpath = _get_caching_fpath()

    def _discard_test_file(self):
        # Remove the test cache file if it is present on disk.
        target = self.test_path
        if os.path.exists(target):
            os.remove(target)

    def __enter__(self):
        unify.set_caching_fname(self._fname)
        # Resolve where the test cache lives only after switching file name.
        self.test_path = _get_caching_fpath()
        self._discard_test_file()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self._discard_test_file()
        unify.set_caching_fname(self._old_cache_fpath)
def evaluate_response(question: str, response: str) -> float:
    """Score ``response`` against the arithmetic expression in ``question``.

    The expected answer is obtained by evaluating ``question`` with ``eval``
    (the questions used in this module are fixed arithmetic strings). The
    candidate answer is built from the digit characters of the text after the
    last space in ``response``.

    Returns:
        ``1.0`` if the parsed integer equals the expected answer, ``0.0`` if it
        differs or if no integer could be parsed.
    """
    expected = eval(question)
    final_token = response.split(" ")[-1]
    digits_only = "".join(filter(str.isdigit, final_token))
    try:
        return float(expected == int(digits_only))
    except ValueError:
        # e.g. the final token contained no digits at all
        return 0.0
def test_threaded_map_with_context() -> None:
    """Map a function that logs inside ``unify.Entries`` and check its results.

    ``unify.map`` is called twice with a list of ``(args, kwargs)`` pairs: once
    with ``b`` passed positionally and once with ``b`` passed by keyword.
    """

    def contextual_func(a, b, c=3):
        with unify.Entries(a=a, b=b, c=c):
            unify.log(test="some random value")
        return a + b + c

    positional_b = [
        ((1, 3), {"c": 2}),
        ((2, 4), {"c": 4}),
    ]
    keyword_b = [
        ((1,), {"b": 2, "c": 2}),
        ((3,), {"b": 4, "c": 4}),
    ]
    with ProjectHandling(), unify.Project("test_project"):
        totals = unify.map(contextual_func, positional_b)
        assert totals == [6, 10]
        totals = unify.map(contextual_func, keyword_b)
        assert totals == [5, 11]
def test_asyncio_map_with_context() -> None:
    """Run ``unify.map`` in ``"asyncio"`` mode over a function that logs in context.

    This is a plain synchronous test (``unify.map`` is called directly, nothing
    is awaited), so it carries no ``pytest.mark.asyncio`` marker.
    """
    with ProjectHandling():
        with unify.Project("test_project"):

            def contextual_func(a, b, c=3):
                with unify.Entries(a=a, b=b, c=c):
                    # short pause before logging inside the Entries context
                    time.sleep(0.1)
                    unify.log(test="some random value")
                return a + b + c

            # ``b`` passed positionally
            results = unify.map(
                contextual_func,
                [
                    ((1, 3), {"c": 2}),
                    ((2, 4), {"c": 4}),
                ],
                mode="asyncio",
            )
            assert results == [1 + 3 + 2, 2 + 4 + 4]
            # ``b`` passed by keyword
            results = unify.map(
                contextual_func,
                [
                    ((1,), {"b": 2, "c": 2}),
                    ((3,), {"b": 4, "c": 4}),
                ],
                mode="asyncio",
            )
            assert results == [1 + 2 + 2, 3 + 4 + 4]
def register_local_model(model_name: str, fn: Callable):
    """Store ``fn`` in ``LOCAL_MODELS`` under a name containing ``"@local"``.

    Args:
        model_name: Name to register under. ``"@local"`` is appended when the
            name does not already contain that substring.
        fn: The callable to associate with the name.
    """
    registry_key = model_name if "@local" in model_name else model_name + "@local"
    LOCAL_MODELS[registry_key] = fn
# noinspection PyShadowingNames
def activate(project: str, overwrite: bool = False, api_key: str = None) -> None:
    """
    Make `project` the active project, creating it on the server if needed.

    Args:
        project: Name of the project to activate.

        overwrite: If the project already exists, re-create it with
        `overwrite=True`.

        api_key: API key forwarded to `list_projects` / `create_project`.
    """
    global PROJECT
    already_exists = project in list_projects(api_key=api_key)
    if already_exists and overwrite:
        create_project(project, api_key=api_key, overwrite=True)
    elif not already_exists:
        create_project(project, api_key=api_key)
    PROJECT = project
class Project:
    """
    Handle for a named project that can also be used as a context manager.

    Inside a `with` block the project is the active one; on exit the project
    that was active before entry is restored (or none, if none was active).
    """

    # noinspection PyShadowingNames
    def __init__(
        self,
        project: str,
        overwrite: bool = False,
        api_key: Optional[str] = None,
    ) -> None:
        """
        Args:
            project: Name of the project.

            overwrite: Whether creating/entering the project should overwrite
            an existing project of the same name.

            api_key: API key to use; passed through `_validate_api_key`.
        """
        self._project = project
        self._overwrite = overwrite
        # noinspection PyProtectedMember
        self._api_key = helpers._validate_api_key(api_key)
        self._entered = False
        # Project that was active when the context was entered.
        self._previous_project = None

    def create(self) -> None:
        """Create the project on the server, honouring the `overwrite` flag."""
        create_project(self._project, overwrite=self._overwrite, api_key=self._api_key)

    def delete(self):
        """Delete the project from the server."""
        delete_project(self._project, api_key=self._api_key)

    def rename(self, new_name: str):
        """
        Rename the project on the server and track the new name locally.

        If the context is currently entered, the renamed project is
        re-activated so the active project name stays in sync.
        """
        rename_project(self._project, new_name, api_key=self._api_key)
        self._project = new_name
        if self._entered:
            # Use this instance's key, not the environment default.
            activate(self._project, api_key=self._api_key)

    def __enter__(self):
        """Activate the project (creating/overwriting as needed) and return self."""
        # Remember the explicitly activated project so nested contexts unwind.
        self._previous_project = PROJECT
        # `activate` already creates a missing project and re-creates an
        # existing one when `overwrite` is set, so no second existence check
        # is needed; the instance's API key is used throughout.
        activate(
            self._project,
            overwrite=self._overwrite,
            api_key=self._api_key,
        )
        self._entered = True
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Restore whichever project was active before the context was entered."""
        global PROJECT
        # Equivalent to `deactivate()` when nothing was active before entry.
        PROJECT = self._previous_project
        self._entered = False
def _check_response(response: requests.Response):
    """Raise `RequestError` for any response whose `ok` flag is false."""
    if response.ok:
        return
    raise RequestError(response)
def delete_project_artifact(
    key: str,
    *,
    project: Optional[str] = None,
    api_key: Optional[str] = None,
) -> str:
    """
    Deletes a single artifact from a project.

    Args:
        key: Key of the artifact to delete.

        project: Name of the project to delete the artifact from.

        api_key: If specified, unify API key to be used. Defaults to the value in the
        `UNIFY_KEY` environment variable.

    Returns:
        The decoded JSON body of the server's response.
    """
    api_key = _validate_api_key(api_key)
    project = _get_and_maybe_create_project(project, api_key=api_key)
    url = f"{BASE_URL}/project/{project}/artifacts/{key}"
    response = _requests.delete(
        url,
        headers={
            "accept": "application/json",
            "Authorization": f"Bearer {api_key}",
        },
    )
    _check_response(response)
    return response.json()
class AsyncLoggerManager:
    """
    Ships log events to the server from a background thread.

    A daemon thread runs a private asyncio event loop. `num_consumers`
    consumer coroutines drain a bounded queue of "create" and "update" events
    and send them to the `logs` endpoint through one shared aiohttp session.
    The public methods (`log_create`, `log_update`, `join`, `stop_sync`) are
    called from other threads and hand work to the loop via
    `asyncio.run_coroutine_threadsafe`.
    """

    def __init__(
        self,
        *,
        base_url: str = BASE_URL,
        api_key: str = None,
        num_consumers: int = 256,
        max_queue_size: int = 10000,
    ):
        """
        Args:
            base_url: API root; requests go to `<base_url>/logs`.

            api_key: Bearer token. When omitted, `UNIFY_KEY` is read from the
            environment at construction time (not at import time).

            num_consumers: Number of consumer coroutines; the connection pool
            is limited to half this number.

            max_queue_size: Capacity of the event queue; producers block when
            it is full.
        """
        # Resolve the key when the manager is built, so a key exported after
        # this module was imported is still picked up.
        if api_key is None:
            api_key = os.getenv("UNIFY_KEY")

        self.loop = asyncio.new_event_loop()
        self.queue = None
        # Holds the consumer coroutine objects; `asyncio.gather` in
        # `_main_loop` wraps them in tasks.
        self.consumers: List[asyncio.Task] = []
        self.num_consumers = num_consumers
        self.start_flag = threading.Event()
        self.shutting_down = False
        self.max_queue_size = max_queue_size
        # Initialised before the worker thread starts so consumers can never
        # observe a manager without a `callbacks` attribute.
        self.callbacks = []

        headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
            "accept": "application/json",
        }
        url = base_url + "/"
        connector = aiohttp.TCPConnector(limit=num_consumers // 2, loop=self.loop)
        self.session = aiohttp.ClientSession(
            url,
            headers=headers,
            loop=self.loop,
            connector=connector,
        )

        self.thread = threading.Thread(target=self._run_loop, daemon=True)
        self.thread.start()
        # Block until the loop thread has reached `_main_loop`.
        self.start_flag.wait()

    def register_callback(self, fn):
        """Register `fn` to be called (with no arguments) after every processed event."""
        self.callbacks.append(fn)

    def clear_callbacks(self):
        """Remove all registered callbacks."""
        self.callbacks = []

    def _notify_callbacks(self):
        """Invoke every registered callback in registration order."""
        for fn in self.callbacks:
            fn()

    async def _join(self):
        """Wait until every queued event has been marked done."""
        await self.queue.join()

    def join(self):
        """
        Block the calling thread until the queue has been fully processed.

        The wait is done in 0.5 s slices rather than one unbounded wait —
        presumably so the calling thread stays interruptible; confirm.
        """
        try:
            future = asyncio.run_coroutine_threadsafe(self._join(), self.loop)
            while True:
                try:
                    future.result(timeout=0.5)
                    break
                except (asyncio.TimeoutError, TimeoutError):
                    continue
        except Exception as e:
            logger.error(f"Error in join: {e}")
            raise

    async def _main_loop(self):
        """Signal readiness, then run all consumers until they finish."""
        self.start_flag.set()
        await asyncio.gather(*self.consumers, return_exceptions=True)

    def _run_loop(self):
        """Thread target: set up the queue and consumers, then run the loop."""
        asyncio.set_event_loop(self.loop)
        self.queue = asyncio.Queue(maxsize=self.max_queue_size)

        for _ in range(self.num_consumers):
            self.consumers.append(self._log_consumer())

        try:
            self.loop.run_until_complete(self._main_loop())
        except Exception as e:
            logger.error(f"Event loop error: {e}")
            raise
        finally:
            self.loop.close()

    async def _consume_create(self, body, future, idx):
        """
        POST a new log and resolve `future` with the created log's id.

        On a non-200 response the future is failed instead of left pending,
        so update events waiting on it do not block forever.
        """
        async with self.session.post("logs", json=body) as res:
            if res.status != 200:
                txt = await res.text()
                logger.error(f"Error in consume_create {idx}: {txt}")
                if not future.done():
                    future.set_exception(
                        RuntimeError(
                            f"Log creation failed with status {res.status}: {txt}",
                        ),
                    )
                return
            res_json = await res.json()
            logger.debug(f"Created {idx} with response {res.status}: {res_json}")
            future.set_result(res_json["log_event_ids"][0])

    async def _consume_update(self, body, future, idx):
        """
        PUT an update for the log whose id `future` resolves to.

        Waits for the corresponding create to finish first; if the create
        failed, awaiting/reading the future raises and the update is skipped
        by the caller's error handling.
        """
        if not future.done():
            await future
        body["logs"] = [future.result()]
        async with self.session.put("logs", json=body) as res:
            if res.status != 200:
                txt = await res.text()
                logger.error(f"Error in consume_update {idx}: {txt}")
                return
            res_json = await res.json()
            logger.debug(f"Updated {idx} with response {res.status}: {res_json}")

    async def _log_consumer(self):
        """
        Consumer loop: take events off the queue and dispatch them by type.

        A failing event is logged and recorded on its future (if still
        pending); the consumer then carries on with the next event, so errors
        do not shrink the consumer pool.
        """
        while True:
            # Fetch outside the `try` so `task_done()` below is only ever
            # called for an event that was actually dequeued.
            event = await self.queue.get()
            try:
                idx = self.queue.qsize() + 1
                logger.debug(f"Processing event {event['type']}: {idx}")
                if event["type"] == "create":
                    await self._consume_create(event["_data"], event["future"], idx)
                elif event["type"] == "update":
                    await self._consume_update(event["_data"], event["future"], idx)
                else:
                    raise Exception(f"Unknown event type: {event['type']}")
            except Exception as e:
                fut = event.get("future")
                # The future may already be done (e.g. a failed create that an
                # update was waiting on); setting it twice would raise.
                if fut is not None and not fut.done():
                    fut.set_exception(e)
                logger.error(f"Error in consumer: {e}")
            finally:
                self.queue.task_done()
                self._notify_callbacks()

    def log_create(
        self,
        project: str,
        context: str,
        params: dict,
        entries: dict,
    ) -> asyncio.Future:
        """
        Queue the creation of a log.

        Blocks until the event has been placed on the queue (which may wait
        for free capacity).

        Returns:
            A future, bound to the manager's loop, that resolves to the id of
            the created log or fails if creation fails.
        """
        fut = self.loop.create_future()
        event = {
            "_data": {
                "project": project,
                "context": context,
                "params": params,
                "entries": entries,
            },
            "type": "create",
            "future": fut,
        }
        asyncio.run_coroutine_threadsafe(self.queue.put(event), self.loop).result()
        return fut

    def log_update(
        self,
        project: str,
        context: str,
        future: asyncio.Future,
        mode: str,
        overwrite: bool,
        data: dict,
    ) -> None:
        """
        Queue an update of a previously queued log.

        Args:
            project: Project name sent with the update.

            context: Context name sent with the update.

            future: The future returned by `log_create` for the log to update.

            mode: Key under which `data` is placed in the request body.

            overwrite: Sent as the `overwrite` field of the request body.

            data: Payload stored under `mode`.
        """
        event = {
            "_data": {
                mode: data,
                "project": project,
                "context": context,
                "overwrite": overwrite,
            },
            "type": "update",
            "future": future,
        }
        asyncio.run_coroutine_threadsafe(self.queue.put(event), self.loop).result()

    def stop_sync(self, immediate=False):
        """
        Shut the logger down; subsequent calls are no-ops.

        Args:
            immediate: If true, ask the event loop to stop without draining
            the queue; otherwise block until all queued events are processed.
        """
        if self.shutting_down:
            return

        self.shutting_down = True
        if immediate:
            logger.debug("Stopping async logger immediately")
            # This runs on a foreign thread: event loops are not thread-safe,
            # so the stop must be scheduled (this also wakes a loop that is
            # idle waiting for I/O).
            if not self.loop.is_closed():
                self.loop.call_soon_threadsafe(self.loop.stop)
        else:
            self.join()
def get_experiment_version(name: str, api_key: Optional[str] = None) -> int:
    """
    Looks up the version of an experiment from its name.

    Args:
        name: The name of the experiment to look up.

        api_key: If specified, unify API key to be used. Defaults to the value in the
        `UNIFY_KEY` environment variable.

    Returns:
        The version of the experiment with that name as an integer, or `None`
        when there are no experiments or none has that name.
    """
    experiments = get_groups(key="experiment", api_key=api_key)
    if not experiments:
        return None
    # Invert the version -> name mapping so it can be looked up by name.
    version_by_name = {
        exp_name: version for version, exp_name in experiments.items()
    }
    if name in version_by_name:
        return int(version_by_name[name])
    return None
def get_context(
    name: str,
    *,
    project: Optional[str] = None,
    api_key: Optional[str] = None,
) -> Dict[str, str]:
    """
    Fetch the server's record for a single context.

    Args:
        name: Name of the context to get.

        project: Name of the project the context belongs to. The project is
        not created if it is missing.

        api_key: If specified, unify API key to be used. Defaults to the value in the
        `UNIFY_KEY` environment variable.

    Returns:
        The decoded JSON body of the server's response.
    """
    api_key = _validate_api_key(api_key)
    project = _get_and_maybe_create_project(
        project,
        api_key=api_key,
        create_if_missing=False,
    )
    url = f"{BASE_URL}/project/{project}/contexts/{name}"
    response = _requests.get(
        url,
        headers={
            "accept": "application/json",
            "Authorization": f"Bearer {api_key}",
        },
    )
    _check_response(response)
    return response.json()
def delete_context(
    name: str,
    *,
    delete_children: bool = True,
    project: Optional[str] = None,
    api_key: Optional[str] = None,
) -> None:
    """
    Delete a context from the server.

    Args:
        name: Name of the context to delete.

        delete_children: Whether to also delete child contexts, i.e. contexts
        whose name starts with `name` followed by "/". When False, only the
        context named exactly `name` is deleted.

        project: Name of the project the context belongs to.

        api_key: If specified, unify API key to be used. Defaults to the value in the
        `UNIFY_KEY` environment variable.

    Returns:
        The decoded JSON body of the last delete response, or `None` when no
        matching context exists.
    """
    api_key = _validate_api_key(api_key)
    project = _get_and_maybe_create_project(
        project,
        api_key=api_key,
        create_if_missing=False,
    )
    headers = {
        "accept": "application/json",
        "Authorization": f"Bearer {api_key}",
    }

    # ToDo: remove this hack once this task [https://app.clickup.com/t/86c3kuch6] is done
    # The listing must use the caller's key, not the environment default.
    candidates = get_contexts(project, prefix=name, api_key=api_key)
    # A raw prefix match would also hit unrelated siblings ("foo" matching
    # "foo_bar"); only the context itself and its "/"-separated descendants
    # are real targets.
    child_prefix = name if name.endswith("/") else name + "/"
    targets = [
        ctx
        for ctx in candidates
        if ctx == name or (delete_children and ctx.startswith(child_prefix))
    ]
    response = None
    for ctx in targets:
        response = _requests.delete(
            BASE_URL + f"/project/{project}/contexts/{ctx}",
            headers=headers,
        )
        _check_response(response)
    if response is not None:
        return response.json()
def list_datasets(
    *,
    project: Optional[str] = None,
    prefix: str = "",
    api_key: Optional[str] = None,
) -> Dict[str, str]:
    """
    List the datasets of a project, i.e. its contexts under "Datasets/".

    Args:
        project: Name of the project the datasets belong to.

        prefix: Prefix of the dataset names to get.

        api_key: If specified, unify API key to be used. Defaults to the value in the
        `UNIFY_KEY` environment variable.

    Returns:
        A dictionary mapping each dataset name (the context name with its
        first "/"-separated segment removed) to the context's description.
    """
    api_key = _validate_api_key(api_key)
    dataset_contexts = get_contexts(
        prefix=f"Datasets/{prefix}",
        project=project,
        api_key=api_key,
    )
    datasets = {}
    for context_name, description in dataset_contexts.items():
        # Drop everything up to and including the first "/".
        _, _, dataset_name = context_name.partition("/")
        datasets[dataset_name] = description
    return datasets
71 | """ 72 | api_key = _validate_api_key(api_key) 73 | project = _get_and_maybe_create_project(project, api_key=api_key) 74 | context = f"Datasets/{name}" 75 | log_instances = [isinstance(item, unify.Log) for item in data] 76 | are_logs = False 77 | if not allow_duplicates and not overwrite: 78 | # ToDo: remove this verbose logic once ignore_duplicates is implemented 79 | if name in unify.list_datasets(): 80 | upstream_dataset = unify.Dataset( 81 | unify.download_dataset(name, project=project, api_key=api_key), 82 | ) 83 | else: 84 | upstream_dataset = unify.Dataset([]) 85 | if any(log_instances): 86 | assert all(log_instances), "If any items are logs, all items must be logs" 87 | are_logs = True 88 | # ToDo: remove this verbose logic once ignore_duplicates is implemented 89 | if not allow_duplicates and not overwrite: 90 | data = [l for l in data if l not in upstream_dataset] 91 | elif not all(isinstance(item, dict) for item in data): 92 | # ToDo: remove this verbose logic once ignore_duplicates is implemented 93 | if not allow_duplicates and not overwrite: 94 | data = [item for item in data if item not in upstream_dataset] 95 | data = [{"data": item} for item in data] 96 | if name in unify.list_datasets(): 97 | upstream_ids = get_logs( 98 | project=project, 99 | context=context, 100 | return_ids_only=True, 101 | ) 102 | else: 103 | upstream_ids = [] 104 | if not are_logs: 105 | return upstream_ids + create_logs( 106 | project=project, 107 | context=context, 108 | entries=data, 109 | mutable=True, 110 | batched=True, 111 | # ToDo: uncomment once ignore_duplicates is implemented 112 | # ignore_duplicates=not allow_duplicates, 113 | ) 114 | local_ids = [l.id for l in data] 115 | matching_ids = [id for id in upstream_ids if id in local_ids] 116 | matching_data = [l.entries for l in data if l.id in matching_ids] 117 | assert len(matching_data) == len( 118 | matching_ids, 119 | ), "matching data and ids must be the same length" 120 | if matching_data: 121 | 
def download_dataset(
    name: str,
    *,
    project: Optional[str] = None,
    api_key: Optional[str] = None,
) -> List[Log]:
    """
    Download a dataset from the server.

    Args:
        name: Name of the dataset.

        project: Name of the project the dataset belongs to.

        api_key: If specified, unify API key to be used. Defaults to the value in the
        `UNIFY_KEY` environment variable.

    Returns:
        The logs stored in the `Datasets/<name>` context, in the reverse of
        the order returned by `get_logs`.
    """
    api_key = _validate_api_key(api_key)
    project = _get_and_maybe_create_project(project, api_key=api_key)
    logs = get_logs(
        project=project,
        context=f"Datasets/{name}",
        # Forward the caller's key; otherwise an explicitly passed key is
        # validated above but the fetch silently falls back to the default.
        api_key=api_key,
    )
    # NOTE(review): presumably `get_logs` returns newest-first and this
    # restores insertion order — confirm against `get_logs`.
    return list(reversed(logs))
def create_project(
    name: str,
    *,
    overwrite: bool = False,
    api_key: Optional[str] = None,
) -> Dict[str, str]:
    """
    Create a logging project by POSTing its name to the `/project` route.

    Args:
        name: A unique, user-defined name used when referencing the project.

        overwrite: When True and a project called `name` is already listed by
        `list_projects`, that project is deleted before the new one is created.

        api_key: If specified, unify API key to be used. Defaults to the value in the
        `UNIFY_KEY` environment variable.

    Returns:
        The decoded JSON body of the creation response.
    """
    api_key = _validate_api_key(api_key)
    # Short-circuit keeps the project listing request to the overwrite case only.
    if overwrite and name in list_projects(api_key=api_key):
        delete_project(name=name, api_key=api_key)
    response = _requests.post(
        f"{BASE_URL}/project",
        headers={
            "accept": "application/json",
            "Authorization": f"Bearer {api_key}",
        },
        json={"name": name},
    )
    _check_response(response)
    return response.json()
def delete_project_logs(
    name: str,
    *,
    api_key: Optional[str] = None,
) -> None:
    """
    Delete every log belonging to a project via `DELETE /project/{name}/logs`.

    Args:
        name: Name of the project to delete logs from.

        api_key: If specified, unify API key to be used. Defaults to the value in the
        `UNIFY_KEY` environment variable.

    Returns:
        NOTE(review): the signature is annotated `None`, yet the decoded JSON body
        of the response is returned - confirm which of the two is intended.
    """
    api_key = _validate_api_key(api_key)
    url = f"{BASE_URL}/project/{name}/logs"
    response = _requests.delete(
        url,
        headers={
            "accept": "application/json",
            "Authorization": f"Bearer {api_key}",
        },
    )
    _check_response(response)
    return response.json()
class TraceLoader(importlib.abc.Loader):
    """Loader wrapper that applies `unify.traced` to a module after executing it."""

    def __init__(self, original_loader, filter: Callable = None):
        """
        Args:
            original_loader: The loader that actually creates and executes the module.

            filter: A filter function that is passed to the traced decorator.
        """
        self._original_loader = original_loader
        self.filter = filter

    def create_module(self, spec):
        """Delegate module creation to the wrapped loader."""
        return self._original_loader.create_module(spec)

    def exec_module(self, module):
        """Execute the module with the wrapped loader, then wrap it with tracing."""
        self._original_loader.exec_module(module)
        unify.traced(module, filter=self.filter)


class TraceFinder(importlib.abc.MetaPathFinder):
    """Meta path finder that swaps in a `TraceLoader` for targeted modules."""

    def __init__(self, targets: List[str], filter: Callable = None):
        """
        Args:
            targets: List of module name prefixes to target for tracing.

            filter: A filter function that is passed to the traced decorator.
        """
        self.targets = targets
        self.filter = filter

    def find_spec(self, fullname, path, target=None):
        """
        Return a spec whose loader is a `TraceLoader`, or None to defer.

        Args:
            fullname: Fully qualified name of the module being imported.

            path: Forwarded unchanged to `importlib.util.find_spec`.

            target: Unused.

        Returns:
            The resolved spec with its loader wrapped, or None when the module is
            not targeted, cannot be found, or does not originate from a `.py` file.
        """
        # A module is targeted when it matches ANY prefix. The previous loop
        # bailed out on the first non-matching prefix, so with more than one
        # target nothing could ever match. An empty target list keeps its
        # historical meaning of "no prefix restriction".
        if self.targets and not any(
            fullname.startswith(prefix) for prefix in self.targets
        ):
            return None

        # Resolve the real spec with every TraceFinder temporarily removed, so
        # the lookup below does not recurse back into this finder.
        saved_meta_path = sys.meta_path[:]
        sys.meta_path = [
            finder for finder in sys.meta_path if not isinstance(finder, TraceFinder)
        ]
        try:
            spec = importlib.util.find_spec(fullname, path)
            if spec is None:
                return None
        finally:
            sys.meta_path = saved_meta_path

        if spec.origin is None or not spec.origin.endswith(".py"):
            return None

        spec.loader = TraceLoader(spec.loader, filter=self.filter)
        return spec


def install_tracing_hook(targets: List[str], filter: Callable = None):
    """Install an import hook that wraps imported modules with the traced decorator.

    This function adds a TraceFinder to the front of sys.meta_path. The hook is
    only installed if sys.meta_path does not already contain a TraceFinder.

    Args:
        targets: List of module name prefixes to target for tracing. Only modules
        whose names start with one of these prefixes will be wrapped.

        filter: A filter function that is passed to the traced decorator.

    """
    if not any(isinstance(finder, TraceFinder) for finder in sys.meta_path):
        sys.meta_path.insert(0, TraceFinder(targets, filter))


def disable_tracing_hook():
    """Remove the tracing import hook from sys.meta_path.

    This function removes every TraceFinder instance from sys.meta_path, which
    disables tracing for subsequent module imports.

    """
    # Rebuild the list in place rather than calling remove() while iterating:
    # removing during iteration skips the element that follows each removal,
    # which left adjacent TraceFinder instances installed.
    sys.meta_path[:] = [
        finder for finder in sys.meta_path if not isinstance(finder, TraceFinder)
    ]
def cast(
    inp: Union[str, bool, float, Prompt, ChatCompletion],
    to_type: Union[
        Type[Union[str, bool, float, Prompt, ChatCompletion]],
        List[Type[Union[str, bool, float, Prompt, ChatCompletion]]],
    ],
) -> Union[str, bool, float, Prompt, ChatCompletion]:
    """
    Cast the input to the specified type.

    Args:
        inp: The input to cast.

        to_type: The type to cast the input to, or a list of candidate types, in
        which case the choice is delegated to `_cast_from_selection`.

    Returns:
        The input after casting to the new type. The input is returned unchanged
        when its exact type already is `to_type`.

    Raises:
        KeyError: If `_CAST_DICT` has no entry for the input type / target type pair.
    """
    if isinstance(to_type, list):
        return _cast_from_selection(inp, to_type)
    input_type = type(inp)
    if input_type is to_type:
        return inp
    return _CAST_DICT[input_type][to_type](inp)


def try_cast(
    inp: Union[str, bool, float, Prompt, ChatCompletion],
    to_type: Union[
        Type[Union[str, bool, float, Prompt, ChatCompletion]],
        List[Type[Union[str, bool, float, Prompt, ChatCompletion]]],
    ],
) -> Union[str, bool, float, Prompt, ChatCompletion]:
    """
    Best-effort variant of `cast` that falls back to the original input.

    Args:
        inp: The input to cast.

        to_type: The type, or list of candidate types, to cast the input to.

    Returns:
        The cast value, or `inp` unchanged if `cast` raised an exception.
    """
    try:
        return cast(inp, to_type)
    except Exception:
        # Deliberate best-effort fallback. `Exception` rather than a bare
        # `except:` so KeyboardInterrupt / SystemExit are no longer swallowed.
        return inp
46 | """ 47 | if isinstance(value, _Client): 48 | self._client = value 49 | else: 50 | raise Exception("Invalid client!") 51 | 52 | def _get_credits(self) -> float: 53 | """ 54 | Retrieves the current credit balance from associated with the UNIFY account. 55 | 56 | Returns: 57 | Current credit balance. 58 | """ 59 | return self._client.get_credit_balance() 60 | 61 | def _update_message_history( 62 | self, 63 | role: str, 64 | content: Union[str, Dict[str, str]], 65 | ) -> None: 66 | """ 67 | Updates message history with user input. 68 | 69 | Args: 70 | role: Either "assistant" or "user". 71 | content: User input message. 72 | """ 73 | if isinstance(self._client, _UniClient): 74 | self._client.messages.append( 75 | { 76 | "role": role, 77 | "content": content, 78 | }, 79 | ) 80 | elif isinstance(self._client, _MultiClient): 81 | if isinstance(content, str): 82 | content = {endpoint: content for endpoint in self._client.endpoints} 83 | for endpoint, cont in content.items(): 84 | self._client.messages[endpoint].append( 85 | { 86 | "role": role, 87 | "content": cont, 88 | }, 89 | ) 90 | else: 91 | raise Exception( 92 | "client must either be a UniClient or MultiClient instance.", 93 | ) 94 | 95 | def clear_chat_history(self) -> None: 96 | """Clears the chat history.""" 97 | if isinstance(self._client, _UniClient): 98 | self._client.set_messages([]) 99 | elif isinstance(self._client, _MultiClient): 100 | self._client.set_messages( 101 | {endpoint: [] for endpoint in self._client.endpoints}, 102 | ) 103 | else: 104 | raise Exception( 105 | "client must either be a UniClient or MultiClient instance.", 106 | ) 107 | 108 | @staticmethod 109 | def _stream_response(response) -> str: 110 | words = "" 111 | for chunk in response: 112 | words += chunk 113 | sys.stdout.write(chunk) 114 | sys.stdout.flush() 115 | sys.stdout.write("\n") 116 | return words 117 | 118 | def _handle_uni_llm_response( 119 | self, 120 | response: str, 121 | endpoint: Union[bool, str], 122 | ) -> str: 123 
def run(self, show_credits: bool = False, show_endpoint: bool = False) -> None:
    """
    Starts the chat interaction loop.

    The loop reads one line of user input per turn. Entering `quit` clears the
    chat history and exits; entering `pause` exits while keeping the history so
    that the next call to `run` resumes the conversation.

    Args:
        show_credits: Whether to show credit consumption. Defaults to False.
        show_endpoint: Whether to show the endpoint used. Defaults to False.
    """
    if not self._paused:
        sys.stdout.write(
            "Let's have a chat. (Enter `pause` to pause and `quit` to exit)\n",
        )
        self.clear_chat_history()
    else:
        sys.stdout.write(
            "Welcome back! (Remember, enter `pause` to pause and `quit` to exit)\n",
        )
        self._paused = False
    while True:
        sys.stdout.write("> ")
        inp = input()
        if inp == "quit":
            self.clear_chat_history()
            break
        elif inp == "pause":
            self._paused = True
            break
        self._update_message_history(role="user", content=inp)
        # The balance is only looked up when it is going to be reported; it was
        # previously fetched twice per turn even with show_credits=False.
        initial_credit_balance = self._get_credits() if show_credits else None
        if isinstance(self._client, unify.AsyncUnify):
            response = asyncio.run(self._client.generate())
        else:
            response = self._client.generate()
        self._handle_response(response, show_endpoint)
        if show_credits:
            final_credit_balance = self._get_credits()
            sys.stdout.write(
                "\n(spent {:.6f} credits)".format(
                    initial_credit_balance - final_credit_balance,
                ),
            )
import multi_llm 6 | from .multi_llm import _MultiClient, MultiUnify, AsyncMultiUnify 7 | -------------------------------------------------------------------------------- /unify/universal_api/clients/helpers.py: -------------------------------------------------------------------------------- 1 | import unify 2 | 3 | # Helpers 4 | 5 | 6 | def _is_custom_endpoint(endpoint: str): 7 | _, provider = endpoint.split("@") 8 | return "custom" in provider 9 | 10 | 11 | def _is_local_endpoint(endpoint: str): 12 | _, provider = endpoint.split("@") 13 | return provider == "local" 14 | 15 | 16 | def _is_fallback_provider(provider: str, api_key: str = None): 17 | public_providers = unify.list_providers(api_key=api_key) 18 | return all(p in public_providers for p in provider.split("->")) 19 | 20 | 21 | def _is_fallback_model(model: str, api_key: str = None): 22 | public_models = unify.list_models(api_key=api_key) 23 | return all(p in public_models for p in model.split("->")) 24 | 25 | 26 | def _is_fallback_endpoint(endpoint: str, api_key: str = None): 27 | public_endpoints = unify.list_endpoints(api_key=api_key) 28 | return all(e in public_endpoints for e in endpoint.split("->")) 29 | 30 | 31 | def _is_meta_provider(provider: str, api_key: str = None): 32 | public_providers = unify.list_providers(api_key=api_key) 33 | if "skip_providers:" in provider: 34 | skip_provs = provider.split("skip_providers:")[-1].split("|")[0] 35 | for prov in skip_provs.split(","): 36 | if prov.strip() not in public_providers: 37 | return False 38 | chnk0, chnk1 = provider.split("skip_providers:") 39 | chnk2 = "|".join(chnk1.split("|")[1:]) 40 | provider = "".join([chnk0, chnk2]) 41 | if "providers:" in provider: 42 | provs = provider.split("providers:")[-1].split("|")[0] 43 | for prov in provs.split(","): 44 | if prov.strip() not in public_providers: 45 | return False 46 | chnk0, chnk1 = provider.split("providers:") 47 | chnk2 = "|".join(chnk1.split("|")[1:]) 48 | provider = "".join([chnk0, chnk2]) 49 
| if provider[-1] == "|": 50 | provider = provider[:-1] 51 | public_models = unify.list_models(api_key=api_key) 52 | if "skip_models:" in provider: 53 | skip_mods = provider.split("skip_models:")[-1].split("|")[0] 54 | for md in skip_mods.split(","): 55 | if md.strip() not in public_models: 56 | return False 57 | chnk0, chnk1 = provider.split("skip_models:") 58 | chnk2 = "|".join(chnk1.split("|")[1:]) 59 | provider = "".join([chnk0, chnk2]) 60 | if "models:" in provider: 61 | mods = provider.split("models:")[-1].split("|")[0] 62 | for md in mods.split(","): 63 | if md.strip() not in public_models: 64 | return False 65 | chnk0, chnk1 = provider.split("models:") 66 | chnk2 = "|".join(chnk1.split("|")[1:]) 67 | provider = "".join([chnk0, chnk2]) 68 | meta_providers = ( 69 | ( 70 | "highest-quality", 71 | "lowest-time-to-first-token", 72 | "lowest-inter-token-latency", 73 | "lowest-input-cost", 74 | "lowest-output-cost", 75 | "lowest-cost", 76 | "lowest-ttft", 77 | "lowest-itl", 78 | "lowest-ic", 79 | "lowest-oc", 80 | "highest-q", 81 | "lowest-t", 82 | "lowest-i", 83 | "lowest-c", 84 | ) 85 | + ( 86 | "quality", 87 | "time-to-first-token", 88 | "inter-token-latency", 89 | "input-cost", 90 | "output-cost", 91 | "cost", 92 | ) 93 | + ( 94 | "q", 95 | "ttft", 96 | "itl", 97 | "ic", 98 | "oc", 99 | "t", 100 | "i", 101 | "c", 102 | ) 103 | ) 104 | operators = ("<", ">", "=", "|", ".", ":") 105 | for s in meta_providers + operators: 106 | provider = provider.replace(s, "") 107 | return all(c.isnumeric() for c in provider) 108 | 109 | 110 | # Checks 111 | 112 | 113 | def _is_valid_endpoint(endpoint: str, api_key: str = None): 114 | if endpoint == "user-input": 115 | return True 116 | if _is_fallback_endpoint(endpoint, api_key): 117 | return True 118 | model, provider = endpoint.split("@") 119 | if _is_valid_provider(provider) and _is_valid_model(model): 120 | return True 121 | if endpoint in unify.list_endpoints(api_key=api_key): 122 | return True 123 | if 
_is_custom_endpoint(endpoint) or _is_local_endpoint(endpoint): 124 | return True 125 | return False 126 | 127 | 128 | def _is_valid_provider(provider: str, api_key: str = None): 129 | if _is_meta_provider(provider): 130 | return True 131 | if provider in unify.list_providers(api_key=api_key): 132 | return True 133 | if _is_fallback_provider(provider): 134 | return True 135 | if provider == "local" or "custom" in provider: 136 | return True 137 | return False 138 | 139 | 140 | def _is_valid_model(model: str, custom_or_local: bool = False, api_key: str = None): 141 | if custom_or_local: 142 | return True 143 | if model in unify.list_models(api_key=api_key): 144 | return True 145 | if _is_fallback_model(model): 146 | return True 147 | if model == "router": 148 | return True 149 | return False 150 | 151 | 152 | # Assertions 153 | 154 | 155 | def _assert_is_valid_endpoint(endpoint: str, api_key: str = None): 156 | assert _is_valid_endpoint(endpoint, api_key), f"{endpoint} is not a valid endpoint" 157 | 158 | 159 | def _assert_is_valid_provider(provider: str, api_key: str = None): 160 | assert _is_valid_provider(provider, api_key), f"{provider} is not a valid provider" 161 | 162 | 163 | def _assert_is_valid_model( 164 | model: str, 165 | custom_or_local: bool = False, 166 | api_key: str = None, 167 | ): 168 | assert _is_valid_model( 169 | model, 170 | custom_or_local, 171 | api_key, 172 | ), f"{model} is not a valid model" 173 | -------------------------------------------------------------------------------- /unify/universal_api/types/__init__.py: -------------------------------------------------------------------------------- 1 | from .prompt import * 2 | -------------------------------------------------------------------------------- /unify/universal_api/types/prompt.py: -------------------------------------------------------------------------------- 1 | class Prompt: 2 | def __init__( 3 | self, 4 | **components, 5 | ): 6 | """ 7 | Create Prompt instance. 
8 | 9 | Args: 10 | components: All components of the prompt. 11 | 12 | Returns: 13 | The Prompt instance. 14 | """ 15 | self.components = components 16 | -------------------------------------------------------------------------------- /unify/universal_api/usage.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | from typing import List, Optional 3 | 4 | import unify 5 | 6 | from ..utils.helpers import _validate_api_key 7 | 8 | 9 | def with_logging( 10 | model_fn: Optional[callable] = None, 11 | *, 12 | endpoint: str, 13 | tags: Optional[List[str]] = None, 14 | timestamp: Optional[datetime.datetime] = None, 15 | log_query_body: bool = True, 16 | log_response_body: bool = True, 17 | api_key: Optional[str] = None, 18 | ): 19 | """ 20 | Wrap a local model callable with logging of the queries. 21 | 22 | Args: 23 | model_fn: The model callable to wrap logging around. 24 | endpoint: The endpoint name to give to this local callable. 25 | tags: Tags for later filtering. 26 | timestamp: A timestamp (if not set, will be the time of sending). 27 | log_query_body: Whether or not to log the query body. 28 | log_response_body: Whether or not to log the response body. 29 | api_key: If specified, unify API key to be used. Defaults to the value in the `UNIFY_KEY` environment variable. 30 | 31 | Returns: 32 | A new callable, but with logging added every time the function is called. 33 | 34 | Raises: 35 | requests.HTTPError: If the API request fails. 
def get_credits(*, api_key: Optional[str] = None) -> float:
    """
    Returns the credits remaining in the user account, in USD.

    Args:
        api_key: If specified, unify API key to be used. Defaults to the value in the
        `UNIFY_KEY` environment variable.

    Returns:
        The `credits` entry of the `/credits` response.

    Raises:
        Exception: If the response status code is not 200; the decoded JSON body
        of the response is used as the exception argument.
    """
    api_key = _validate_api_key(api_key)
    # Send GET request to the /credits endpoint
    response = _requests.get(
        f"{BASE_URL}/credits",
        headers={
            "accept": "application/json",
            "Authorization": f"Bearer {api_key}",
        },
    )
    if response.status_code == 200:
        return _res_to_list(response)["credits"]
    raise Exception(response.json())
def get_custom_api_key(
    name: str,
    *,
    api_key: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Look up a custom API key by name via `GET /custom_api_key`.

    Args:
        name: Name of the API key to get the value for.
        api_key: If specified, unify API key to be used. Defaults
        to the value in the `UNIFY_KEY` environment variable.

    Returns:
        The decoded JSON body of the response.

    Raises:
        Exception: If the response status code is not 200; the decoded JSON body
        of the response is used as the exception argument.
    """
    api_key = _validate_api_key(api_key)
    response = _requests.get(
        f"{BASE_URL}/custom_api_key",
        headers={
            "accept": "application/json",
            "Authorization": f"Bearer {api_key}",
        },
        params={"name": name},
    )
    if response.status_code == 200:
        return response.json()
    raise Exception(response.json())
def rename_custom_api_key(
    name: str,
    new_name: str,
    *,
    api_key: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Rename a custom API key via `POST /custom_api_key/rename`.

    Args:
        name: Name of the custom API key to be updated.
        new_name: New name for the custom API key.
        api_key: If specified, unify API key to be used. Defaults
        to the value in the `UNIFY_KEY` environment variable.

    Returns:
        The decoded JSON body of the response.

    Raises:
        Exception: If the response status code is not 200; the decoded JSON body
        of the response is used as the exception argument.
    """
    api_key = _validate_api_key(api_key)
    rename_url = f"{BASE_URL}/custom_api_key/rename"
    auth_headers = {
        "accept": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    response = _requests.post(
        rename_url,
        headers=auth_headers,
        params={"name": name, "new_name": new_name},
    )
    if response.status_code == 200:
        return response.json()
    raise Exception(response.json())
Defaults 165 | to the value in the `UNIFY_KEY` environment variable. 166 | 167 | Returns: 168 | A list of dictionaries containing custom API key information. 169 | Each dictionary has 'name' and 'value' keys. 170 | 171 | """ 172 | api_key = _validate_api_key(api_key) 173 | headers = { 174 | "accept": "application/json", 175 | "Authorization": f"Bearer {api_key}", 176 | } 177 | url = f"{BASE_URL}/custom_api_key/list" 178 | 179 | response = _requests.get(url, headers=headers) 180 | if response.status_code != 200: 181 | raise Exception(response.json()) 182 | 183 | return response.json() 184 | -------------------------------------------------------------------------------- /unify/universal_api/utils/custom_endpoints.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional 2 | 3 | from unify import BASE_URL 4 | from unify.utils import _requests 5 | 6 | from ...utils.helpers import _validate_api_key 7 | 8 | 9 | def create_custom_endpoint( 10 | *, 11 | name: str, 12 | url: str, 13 | key_name: str, 14 | model_name: Optional[str] = None, 15 | provider: Optional[str] = None, 16 | api_key: Optional[str] = None, 17 | ) -> Dict[str, Any]: 18 | """ 19 | Create a custom endpoint for API calls. 20 | 21 | Args: 22 | name: Alias for the custom endpoint. This will be the name used to call the endpoint. 23 | url: Base URL of the endpoint being called. Must support the OpenAI format. 24 | key_name: Name of the API key that will be passed as part of the query. 25 | model_name: Name passed to the custom endpoint as model name. If not specified, it will default to the endpoint alias. 26 | provider: If the custom endpoint is for a fine-tuned model which is hosted directly via one of the supported providers, 27 | then this argument should be specified as the provider used. 28 | api_key: If specified, unify API key to be used. Defaults to the value in the `UNIFY_KEY` environment variable. 
def rename_custom_endpoint(
    name: str,
    new_name: str,
    *,
    api_key: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Give an existing custom endpoint a new name.

    Args:
        name: Current name of the custom endpoint.
        new_name: Name the custom endpoint should have afterwards.
        api_key: Unify API key used to authenticate the request. When omitted,
        the value of the `UNIFY_KEY` environment variable is used.

    Returns:
        The decoded JSON body of the `/custom_endpoint/rename` response.

    Raises:
        KeyError: If no API key is passed and `UNIFY_KEY` is not set.
        Exception: If the response status code is not 200; the exception
        argument is the decoded JSON error body.
    """
    api_key = _validate_api_key(api_key)
    response = _requests.post(
        f"{BASE_URL}/custom_endpoint/rename",
        headers={
            "accept": "application/json",
            "Authorization": f"Bearer {api_key}",
        },
        # Both names travel as query parameters, not as a request body.
        params={"name": name, "new_name": new_name},
    )
    if response.status_code == 200:
        return response.json()
    raise Exception(response.json())
def get_endpoint_metrics(
    endpoint: str,
    *,
    start_time: Optional[Union[datetime.datetime, str]] = None,
    end_time: Optional[Union[datetime.datetime, str]] = None,
    api_key: Optional[str] = None,
) -> List[Metrics]:
    """
    Fetch the cost and speed metrics recorded for one endpoint.

    Args:
        endpoint: Endpoint to query, written as `model@provider`. The text
        before the first "@" is sent as the model and the text between the
        first and second "@" (or the end) as the provider.

        start_time: Start of the time window, forwarded as a query parameter.

        end_time: End of the time window, forwarded as a query parameter.

        api_key: Unify API key used to authenticate the request. When omitted,
        the value of the `UNIFY_KEY` environment variable is used.

    Returns:
        One `Metrics` object per entry of the decoded JSON response, built
        from the entry's `ttft`, `itl`, `input_cost`, `output_cost` and
        `measured_at` fields.

    Raises:
        KeyError: If no API key is passed and `UNIFY_KEY` is not set, or if a
        response entry lacks one of the expected fields.
        IndexError: If `endpoint` contains no "@".
        Exception: If the response status code is not 200; the exception
        argument is the decoded JSON error body.
    """
    api_key = _validate_api_key(api_key)
    headers = {
        "accept": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    endpoint_parts = endpoint.split("@")
    query = {
        "model": endpoint_parts[0],
        "provider": endpoint_parts[1],
        "start_time": start_time,
        "end_time": end_time,
    }
    response = _requests.get(
        f"{BASE_URL}/endpoint-metrics",
        headers=headers,
        params=query,
    )
    if response.status_code != 200:
        raise Exception(response.json())

    # Only these fields are copied over; any other keys in an entry are ignored.
    metric_fields = ("ttft", "itl", "input_cost", "output_cost", "measured_at")
    results = []
    for entry in response.json():
        results.append(Metrics(**{field: entry[field] for field in metric_fields}))
    return results
def delete_endpoint_metrics(
    endpoint_name: str,
    *,
    timestamps: Optional[Union[datetime.datetime, List[datetime.datetime]]] = None,
    api_key: Optional[str] = None,
) -> Dict[str, str]:
    """
    Delete benchmark metrics recorded for an endpoint.

    Args:
        endpoint_name: Name of the endpoint whose metrics are deleted.

        timestamps: A single timestamp or a list of timestamps, forwarded
        unchanged as the `timestamps` query parameter.
        NOTE(review): what the server deletes when this is left as None is not
        visible from this client code; confirm against the API docs.

        api_key: Unify API key used to authenticate the request. When omitted,
        the value of the `UNIFY_KEY` environment variable is used.

    Returns:
        The decoded JSON body of the response.

    Raises:
        KeyError: If no API key is passed and `UNIFY_KEY` is not set.
        Exception: If the response status code is not 200; the exception
        argument is the decoded JSON error body.
    """
    api_key = _validate_api_key(api_key)
    auth_headers = {
        "accept": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    query = {"endpoint_name": endpoint_name, "timestamps": timestamps}
    response = _requests.delete(
        f"{BASE_URL}/endpoint-metrics",
        headers=auth_headers,
        params=query,
    )
    if response.status_code == 200:
        return response.json()
    raise Exception(response.json())
def get_queries(
    *,
    tags: Optional[Union[str, List[str]]] = None,
    endpoints: Optional[Union[str, List[str]]] = None,
    start_time: Optional[Union[datetime.datetime, str]] = None,
    end_time: Optional[Union[datetime.datetime, str]] = None,
    page_number: Optional[int] = None,
    failures: Optional[Union[bool, str]] = None,
    api_key: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Retrieve the query history, optionally narrowed by filters.

    Every filter is optional. A filter is only sent to the API when its value
    is truthy, so `None`, empty strings/lists, `0` and `False` all mean
    "not specified".

    Args:
        tags: A tag or list of tags; only queries marked with them are returned.

        endpoints: An endpoint or list of endpoints to filter for.

        start_time: Timestamp of the earliest query to include.
        Format is `YYYY-MM-DD hh:mm:ss`.

        end_time: Timestamp of the latest query to include.
        Format is `YYYY-MM-DD hh:mm:ss`.

        page_number: Page of the history to return; higher pages hold older
        prompts.

        failures: `True` to include failed queries, or `"only"` to return
        failed queries exclusively.

        api_key: Unify API key used to authenticate the request. When omitted,
        the value of the `UNIFY_KEY` environment variable is used.

    Returns:
        The decoded JSON body of the `/queries` response.

    Raises:
        KeyError: If no API key is passed and `UNIFY_KEY` is not set.
        Exception: If the response status code is not 200; the exception
        argument is the decoded JSON error body.
    """
    api_key = _validate_api_key(api_key)
    headers = {
        "accept": "application/json",
        "Authorization": f"Bearer {api_key}",
    }

    candidate_filters = (
        ("tags", tags),
        ("endpoints", endpoints),
        ("start_time", start_time),
        ("end_time", end_time),
        ("page_number", page_number),
        ("failures", failures),
    )
    # Truthiness (not `is not None`) decides whether a filter is sent.
    params = {key: value for key, value in candidate_filters if value}

    response = _requests.get(f"{BASE_URL}/queries", headers=headers, params=params)
    if response.status_code == 200:
        return response.json()
    raise Exception(response.json())
def get_query_metrics(
    *,
    start_time: Optional[Union[datetime.datetime, str]] = None,
    end_time: Optional[Union[datetime.datetime, str]] = None,
    models: Optional[str] = None,
    providers: Optional[str] = None,
    interval: int = 300,
    secondary_user_id: Optional[str] = None,
    api_key: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Retrieve aggregated query metrics.

    Args:
        start_time: Timestamp of the earliest query to aggregate. Format is `YYYY-MM-DD hh:mm:ss`.
        end_time: Timestamp of the latest query to aggregate. Format is `YYYY-MM-DD hh:mm:ss`.
        models: Comma-separated string of model names to fetch metrics for.
        providers: Comma-separated string of provider names to fetch metrics for.
        interval: Number of seconds in the aggregation interval. Default is 300.
        secondary_user_id: Secondary user id to match the `user` attribute from `/chat/completions`.
        api_key: Unify API key used to authenticate the request. When omitted,
        the value of the `UNIFY_KEY` environment variable is used.

    Returns:
        The decoded JSON body of the `/metrics` response.

    Raises:
        KeyError: If no API key is passed and `UNIFY_KEY` is not set.
        Exception: If the response status code is not 200; the exception
        argument is the decoded JSON error body.
    """
    api_key = _validate_api_key(api_key)
    headers = {
        "accept": "application/json",
        "Authorization": f"Bearer {api_key}",
    }

    # Arguments left as None are omitted from the query string entirely.
    params = {}
    for key, value in (
        ("start_time", start_time),
        ("end_time", end_time),
        ("models", models),
        ("providers", providers),
        ("interval", interval),
        ("secondary_user_id", secondary_user_id),
    ):
        if value is not None:
            params[key] = value

    response = _requests.get(f"{BASE_URL}/metrics", headers=headers, params=params)
    if response.status_code == 200:
        return response.json()
    raise Exception(response.json())
def list_endpoints(
    model: Optional[str] = None,
    provider: Optional[str] = None,
    *,
    api_key: Optional[str] = None,
) -> List[str]:
    """
    Get a list of available endpoints, either in total or for a specific model
    or provider.

    Args:
        model: If specified, only endpoints serving this model are returned.

        provider: If specified, only endpoints of this provider are returned.
        Mutually exclusive with `model`.

        api_key: If specified, unify API key to be used. Defaults to the value
        in the `UNIFY_KEY` environment variable.

    Returns:
        The decoded JSON body of the `/endpoints` response.

    Raises:
        ValueError: If both `model` and `provider` are given. This is checked
        before anything else, so it does not depend on an API key being
        available.
        KeyError: If no API key is passed and `UNIFY_KEY` is not set.
        Exception: If the response status code is not 200; the exception
        argument is the decoded JSON error body.
    """
    # Reject contradictory arguments up front, before touching credentials
    # or the network.
    if model and provider:
        raise ValueError("Please specify either model OR provider, not both.")

    api_key = _validate_api_key(api_key)
    headers = {
        "accept": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    url = f"{BASE_URL}/endpoints"

    if model:
        params = {"model": model}
    elif provider:
        params = {"provider": provider}
    else:
        params = None

    # Every branch goes through the same request and status check; the model
    # branch used to return early and hand API error bodies back as results.
    response = _requests.get(url, headers=headers, params=params)
    if response.status_code != 200:
        raise Exception(response.json())
    return _res_to_list(response)
def _log(type: str, url: str, mask_key: bool = True, /, **kwargs):
    """
    Emit a debug record describing an outgoing request or a received response.

    Does nothing unless request logging is enabled (`_log_enabled`).

    Args:
        type: Label for the record header, e.g. "GET" or "GET response:200".
        url: URL the request was sent to.
        mask_key: When true, the `Authorization` entry of a `headers` keyword
        argument is shown as "***" in the log output.
        **kwargs: Further request/response details to print, one per line.
        Dict values are pretty-printed as JSON; everything else is formatted
        with `str()`.
    """
    if not _log_enabled:
        return

    from collections.abc import Mapping

    headers = kwargs.get("headers")
    if mask_key and isinstance(headers, Mapping) and "Authorization" in headers:
        # Mask on a copy. The caller's headers dict is the one actually sent
        # with the request, so it must never be mutated - not even briefly,
        # because an error while formatting would leave "***" behind.
        kwargs = {**kwargs, "headers": {**headers, "Authorization": "***"}}

    parts = []
    for k, v in kwargs.items():
        if isinstance(v, dict):
            # default=str keeps logging from crashing on values JSON cannot
            # encode, such as datetime query parameters.
            parts.append(f"{k}:{json.dumps(v, indent=2, default=str)},\n")
        else:
            parts.append(f"{k}:{v},\n")
    _kwargs_str = "".join(parts)

    log_msg = f"""
====== {type} =======
url:{url}
{_kwargs_str}
"""
    _logger.debug(log_msg)
def post(url, data=None, json=None, **kwargs):
    """
    Send a POST request through `requests`, with debug logging.

    The response body is always decoded as JSON, even when logging is
    disabled, so a non-JSON response is reported as an error.

    Args:
        url: URL to send the request to.
        data: Passed through to `requests.post` as `data`.
        json: Passed through to `requests.post` as `json`.
        **kwargs: Any further keyword arguments for `requests.post`
        (e.g. `headers`, `params`).

    Returns:
        The `requests.Response` object.

    Raises:
        ResponseDecodeError: If the response body cannot be decoded as JSON.
        The original decode error is attached as `__cause__`.
    """
    _log("POST", url, True, data=data, json=json, **kwargs)
    res = requests.post(url, data=data, json=json, **kwargs)
    # Keep the try body to the one call that can fail to decode, and chain
    # the cause so the traceback shows why decoding failed.
    try:
        payload = res.json()
    except requests.exceptions.JSONDecodeError as e:
        raise ResponseDecodeError(res) from e
    _log(f"POST response:{res.status_code}", url, response=payload)
    return res
def _validate_api_key(api_key: Optional[str]) -> str:
    """
    Resolve the Unify API key to use for a request.

    Args:
        api_key: Explicitly supplied key, or None to fall back to the
        `UNIFY_KEY` environment variable.

    Returns:
        `api_key` unchanged when it is not None, otherwise the value of
        `UNIFY_KEY`.

    Raises:
        KeyError: If `api_key` is None and `UNIFY_KEY` is not set.
    """
    # An explicit key always wins over the environment.
    if api_key is not None:
        return api_key

    env_key = os.environ.get("UNIFY_KEY")
    if env_key is None:
        raise KeyError(
            "UNIFY_KEY is missing. Please make sure it is set correctly!",
        )
    return env_key
def _prune_dict(val):
    """
    Recursively strip "empty" entries from nested dicts, lists and tuples.

    An entry is dropped when its value is None, equals the string
    "NOT_GIVEN", or is a dict/list/tuple that is empty after pruning.
    Falsy scalars such as 0, "" and False are kept.

    Args:
        val: Value to prune. Anything that is not a dict, list or tuple is
        returned unchanged.

    Returns:
        A new dict, list or tuple (matching the kind of `val`) holding the
        surviving, pruned entries; or `val` itself for non-containers. The
        top-level result may be empty.
    """
    if not isinstance(val, (dict, list, tuple)):
        return val

    def surviving(pairs):
        # Yield (key, pruned_value) for every entry that should be kept.
        # Each child is pruned exactly once and that result is reused both for
        # the emptiness test and as the output value; pruning it twice per
        # level makes the cost grow as 2**depth.
        for key, item in pairs:
            if item in (None, "NOT_GIVEN"):
                continue
            pruned = _prune_dict(item)
            if isinstance(pruned, (dict, list, tuple)) and not pruned:
                continue
            yield key, pruned

    if isinstance(val, dict):
        return dict(surviving(val.items()))

    kept = [pruned for _, pruned in surviving(enumerate(val))]
    return kept if isinstance(val, list) else tuple(kept)
class _SkipType:
    """Type of the internal sentinel that signals "skip this element"."""


_SKIP = _SkipType()

Container = Union[Dict[Any, Any], List[Any], Tuple[Any, ...], Set[Any]]


def flexible_deepcopy(
    obj: Any,
    on_fail: str = "raise",
    _memo: Optional[Dict[int, Any]] = None,
) -> Any:
    """
    Perform a deepcopy that tolerates un-copyable elements.

    Dicts, lists, tuples and sets are rebuilt element by element (subclasses
    of these come back as the plain built-in type); everything else is
    delegated to ``copy.deepcopy`` with a shared memo.

    Parameters
    ----------
    obj : Any
        The object you wish to copy.
    on_fail : {'raise', 'skip', 'shallow'}, default 'raise'
        - 'raise'   : re-raise the copy error (standard behaviour).
        - 'skip'    : drop the offending element from the result. A dict
          entry is dropped if either its key or its value fails.
        - 'shallow' : insert the original element unchanged.
    _memo : dict or None (internal)
        Memoisation dict (``id(original) -> copy``) used to preserve identity
        and to avoid infinite recursion on cyclic structures.

    Returns
    -------
    Any
        A deep-copied version of *obj*, modified per the *on_fail* strategy.
        With ``on_fail='skip'``, if *obj* itself is a non-container that
        cannot be copied, the internal ``_SKIP`` sentinel is returned.

    Raises
    ------
    ValueError
        If *on_fail* is not one of the accepted values.
    Exception
        Re-raises whatever copy error occurred when *on_fail* == 'raise'.
    """
    # Validate eagerly so a bad option is reported even when nothing fails.
    if on_fail not in ("raise", "skip", "shallow"):
        raise ValueError(f"Invalid on_fail option: {on_fail!r}")

    if _memo is None:
        _memo = {}

    obj_id = id(obj)
    if obj_id in _memo:  # Already copied (shared reference or cycle).
        return _memo[obj_id]

    def _attempt(value: Any) -> Union[Any, _SkipType]:
        """Try to deepcopy *value*; fall back per on_fail."""
        try:
            return flexible_deepcopy(value, on_fail, _memo)
        except Exception:
            if on_fail == "raise":
                raise
            if on_fail == "shallow":
                return value
            return _SKIP

    # --- Handle built-in containers explicitly ---------------------------
    if isinstance(obj, dict):
        new_dict: Dict[Any, Any] = {}
        _memo[obj_id] = new_dict  # Early memoisation for cycles
        for k, v in obj.items():
            nk = _attempt(k)
            nv = _attempt(v)
            # Identity test: `in` would fall back to ==, which arbitrary
            # copied objects may implement in surprising ways.
            if nk is _SKIP or nv is _SKIP:
                continue
            new_dict[nk] = nv
        return new_dict

    if isinstance(obj, list):
        new_list: List[Any] = []
        _memo[obj_id] = new_list
        for item in obj:
            nitem = _attempt(item)
            if nitem is not _SKIP:
                new_list.append(nitem)
        return new_list

    if isinstance(obj, tuple):
        # A tuple cannot be memoised before it is built. That is safe: any
        # cycle back to this tuple must pass through a mutable object, which
        # is memoised early and therefore terminates the recursion.
        items = []
        for item in obj:
            nitem = _attempt(item)
            if nitem is not _SKIP:
                items.append(nitem)
        # Such a cycle will already have produced (and memoised) the copy of
        # this tuple; reuse it so the copied structure stays self-consistent.
        if obj_id in _memo:
            return _memo[obj_id]
        new_tuple = tuple(items)
        _memo[obj_id] = new_tuple
        return new_tuple

    if isinstance(obj, set):
        new_set: Set[Any] = set()
        _memo[obj_id] = new_set
        for item in obj:
            nitem = _attempt(item)
            if nitem is not _SKIP:
                new_set.add(nitem)
        return new_set

    # --- Non-container: fall back to standard deepcopy -------------------
    try:
        copied = copy.deepcopy(obj, _memo)
    except Exception:
        if on_fail == "raise":
            raise
        if on_fail == "shallow":
            _memo[obj_id] = obj
            return obj
        return _SKIP
    _memo[obj_id] = copied
    return copied
# Module-wide default execution mode used by `map` when `mode` is None.
MAP_MODE = "threading"


def set_map_mode(mode: str):
    """Set the module-wide default mode used by `map` (not validated here)."""
    global MAP_MODE
    MAP_MODE = mode


def get_map_mode() -> str:
    """Return the module-wide default mode used by `map`."""
    return MAP_MODE


def _is_iterable(item: Any) -> bool:
    """Return True if ``iter(item)`` succeeds, False if it raises TypeError."""
    try:
        iter(item)
    except TypeError:
        return False
    return True


# noinspection PyShadowingBuiltins
def map(
    fn: callable,
    *args,
    mode=None,
    name="",
    from_args=False,
    raise_exceptions=True,
    **kwargs,
) -> Any:
    """
    Call *fn* once per item, sequentially, in threads, or via asyncio.

    Parameters
    ----------
    fn : callable
        The function to call.
    *args
        With ``from_args=False`` (default) only ``args[0]`` is used: a
        sequence whose items are each either a ``(args_tuple, kwargs_dict)``
        pair, a kwargs dict, a tuple of positional arguments, or a single
        positional argument. The shape is decided from the first item.
        With ``from_args=True`` every positional argument is a sequence and
        call ``i`` receives the ``i``-th element of each.
    mode : {'threading', 'asyncio', 'loop'} or None
        Execution strategy; None means ``get_map_mode()``.
    name : str
        Optional snake_case label; it is title-cased and prefixed to the
        progress-bar description.
    from_args : bool, default False
        Selects how ``*args`` / ``**kwargs`` are interpreted (see above).
    raise_exceptions : bool, default True
        When False, an exception raised by *fn* is swallowed and that call
        yields None.
    **kwargs
        Only used with ``from_args=True``: list/tuple values are indexed per
        call, any other value is passed unchanged to every call.

    Returns
    -------
    list
        One result per call; an empty list when there is nothing to map.

    Notes
    -----
    In 'loop' mode an exception from *fn* propagates to the caller. In
    'threading' mode it is re-raised inside the worker thread, so it does not
    reach the caller and that result slot stays None. In 'asyncio' mode it
    terminates the helper thread and an empty list is returned.
    """
    if name:
        # substr[:1] rather than substr[0]: a leading or doubled underscore
        # yields empty pieces, which must not raise IndexError.
        name = (
            " ".join(substr[:1].upper() + substr[1:] for substr in name.split("_"))
            + " "
        )

    if mode is None:
        mode = get_map_mode()

    assert mode in (
        "threading",
        "asyncio",
        "loop",
    ), "map mode must be one of threading, asyncio or loop."

    def fn_w_exception_handling(*a, **kw):
        try:
            return fn(*a, **kw)
        except Exception:
            if raise_exceptions:
                raise
            return None

    if from_args:
        args = list(args)
        for i, a in enumerate(args):
            if _is_iterable(a):
                args[i] = list(a)

        if args:
            num_calls = len(args[0])
        else:
            # Tuples count too, matching the per-call indexing below.
            for v in kwargs.values():
                if isinstance(v, (list, tuple)):
                    num_calls = len(v)
                    break
            else:
                raise Exception(
                    "At least one of the args or kwargs must be a list, "
                    "which is to be mapped across the threads",
                )
        args_n_kwargs = [
            (
                tuple(a[i] for a in args),
                {
                    k: v[i] if isinstance(v, (list, tuple)) else v
                    for k, v in kwargs.items()
                },
            )
            for i in range(num_calls)
        ]
    else:
        args_n_kwargs = list(args[0])
        if not args_n_kwargs:
            # Nothing to map; also avoids indexing into an empty sequence.
            return []
        first = args_n_kwargs[0]
        if not isinstance(first, tuple):
            if isinstance(first, dict):
                args_n_kwargs = [((), item) for item in args_n_kwargs]
            else:
                args_n_kwargs = [((item,), {}) for item in args_n_kwargs]
        elif (
            len(first) < 2
            or not isinstance(first[0], tuple)
            or not isinstance(first[1], dict)
        ):
            # Items are plain tuples of positional arguments.
            args_n_kwargs = [(item, {}) for item in args_n_kwargs]
        num_calls = len(args_n_kwargs)

    if num_calls == 0:
        return []

    if mode == "loop":
        pbar = tqdm(total=num_calls)
        pbar.set_description(f"{name}Iterations")
        try:
            returns = []
            for a, kw in args_n_kwargs:
                returns.append(fn_w_exception_handling(*a, **kw))
                pbar.update(1)
        finally:
            # Close the bar even when fn raises.
            pbar.close()
        return returns

    if mode == "threading":
        pbar = tqdm(total=num_calls)
        pbar.set_description(f"{name}Threads")
        returns = [None] * num_calls

        def fn_w_indexing(thread_idx: int, ctx, a, kw):
            # Run inside a copy of the caller's context so context variables
            # set by the caller are visible to fn. The context is passed
            # explicitly instead of being smuggled through kw, which would
            # mutate the caller's dicts and clobber a kwarg named "context".
            ret = ctx.run(fn_w_exception_handling, *a, **kw)
            pbar.update(1)
            returns[thread_idx] = ret

        threads = []
        for i, (a, kw) in enumerate(args_n_kwargs):
            thread = threading.Thread(
                target=fn_w_indexing,
                args=(i, contextvars.copy_context(), a, kw),
            )
            thread.start()
            threads.append(thread)
        for thread in threads:
            thread.join()
        pbar.close()
        return returns

    def _run_asyncio_in_thread(ret):
        max_workers = 100

        async def main():
            # Created inside the running loop so the semaphore and the
            # coroutines are bound to the loop that asyncio.run() drives.
            semaphore = asyncio.Semaphore(max_workers)

            async def fn_wrapper(a, kw):
                async with semaphore:
                    return await asyncio.to_thread(
                        fn_w_exception_handling,
                        *a,
                        **kw,
                    )

            coros = [fn_wrapper(a, kw) for a, kw in args_n_kwargs]
            return await tqdm_asyncio.gather(*coros, desc=f"{name}Coroutines")

        ret.extend(asyncio.run(main()))

    # asyncio.run() cannot be called from a thread that already has a running
    # event loop, so the work is driven from a dedicated helper thread.
    ret = []
    thread = threading.Thread(target=_run_asyncio_in_thread, args=(ret,))
    thread.start()
    thread.join()
    return ret