├── .cache.json
├── .flake8
├── .github
├── bump_version.py
└── workflows
│ ├── pypi.yml
│ ├── test-demos.yml
│ ├── tests.yml
│ └── update-docs.yml
├── .gitignore
├── .pre-commit-config.yaml
├── .run
└── pycharm
│ ├── Python tests for test_interfaces.run.xml
│ ├── Python tests for test_routing.run.xml
│ ├── Python tests for test_universal_api.run.xml
│ ├── Python tests for test_utils.run.xml
│ ├── Template Python tests.run.xml
│ ├── Template Python.run.xml
│ └── generate_docs.run.xml
├── LICENSE
├── README.md
├── generate_docs.py
├── poetry.lock
├── pydoc-markdown.yml
├── pyproject.toml
├── tests
├── __init__.py
├── conftest.py
├── test_logging
│ ├── __init__.py
│ ├── helpers.py
│ ├── test_dataset.py
│ ├── test_evals.py
│ ├── test_logs.py
│ ├── test_projects.py
│ ├── test_tracing.py
│ └── test_utils
│ │ ├── __init__.py
│ │ ├── test_artifacts.py
│ │ ├── test_async_logger.py
│ │ ├── test_contexts.py
│ │ ├── test_datasets.py
│ │ ├── test_logs.py
│ │ └── test_projects.py
├── test_routing
│ ├── __init__.py
│ ├── test_fallbacks.py
│ └── test_routing_syntax.py
├── test_universal_api
│ ├── __init__.py
│ ├── test_basics.py
│ ├── test_chatbot.py
│ ├── test_json_mode.py
│ ├── test_multi_llm.py
│ ├── test_stateful.py
│ ├── test_usage.py
│ ├── test_user_input.py
│ └── test_utils
│ │ ├── __init__.py
│ │ ├── test_credits.py
│ │ ├── test_custom_api_keys.py
│ │ ├── test_custom_endpoints.py
│ │ ├── test_endpoint_metrics.py
│ │ ├── test_supported_endpoints.py
│ │ └── test_usage.py
└── test_utils
│ ├── __init__.py
│ ├── helpers.py
│ ├── test_caching.py
│ └── test_map.py
└── unify
├── __init__.py
├── logging
├── __init__.py
├── dataset.py
├── logs.py
└── utils
│ ├── __init__.py
│ ├── artifacts.py
│ ├── async_logger.py
│ ├── compositions.py
│ ├── contexts.py
│ ├── datasets.py
│ ├── logs.py
│ ├── projects.py
│ └── tracing.py
├── universal_api
├── __init__.py
├── casting.py
├── chatbot.py
├── clients
│ ├── __init__.py
│ ├── base.py
│ ├── helpers.py
│ ├── multi_llm.py
│ └── uni_llm.py
├── types
│ ├── __init__.py
│ └── prompt.py
├── usage.py
└── utils
│ ├── __init__.py
│ ├── credits.py
│ ├── custom_api_keys.py
│ ├── custom_endpoints.py
│ ├── endpoint_metrics.py
│ ├── queries.py
│ └── supported_endpoints.py
└── utils
├── __init__.py
├── _caching.py
├── _requests.py
├── helpers.py
└── map.py
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-complexity = 6
3 | inline-quotes = double
4 | max-line-length = 88
5 | extend-ignore = E203
6 | docstring_style = sphinx
7 |
8 | ignore =
9 | ; Found `f` string
10 | WPS305,
11 | ; Missing docstring in public module
12 | D100,
13 | ; Missing docstring in magic method
14 | D105,
15 | ; Missing docstring in __init__
16 | D107,
17 | ; Found `__init__.py` module with logic
18 | WPS412,
19 | ; Found class without a base class
20 | WPS306,
21 | ; Missing docstring in public nested class
22 | D106,
23 | ; First line should be in imperative mood
24 | D401,
25 | ; Found wrong variable name
26 | WPS110,
27 | ; Found `__init__.py` module with logic
28 | WPS326,
29 | ; Found string constant over-use
30 | WPS226,
31 | ; Found upper-case constant in a class
32 | WPS115,
33 | ; Found nested function
34 | WPS602,
35 | ; Found method without arguments
36 | WPS605,
37 | ; Found overused expression
38 | WPS204,
39 | ; Found too many module members
40 | WPS202,
41 | ; Found too high module cognitive complexity
42 | WPS232,
43 | ; line break before binary operator
44 | W503,
45 | ; Found module with too many imports
46 | WPS201,
47 | ; Inline strong start-string without end-string.
48 | RST210,
49 | ; Found nested class
50 | WPS431,
51 | ; Found wrong module name
52 | WPS100,
53 | ; Found too many methods
54 | WPS214,
55 | ; Found too long ``try`` body
56 | WPS229,
57 | ; Found unpythonic getter or setter
58 | WPS615,
59 | ; Found a line that starts with a dot
60 | WPS348,
61 | ; Found complex default value (for dependency injection)
62 | WPS404,
63 | ; not perform function calls in argument defaults (for dependency injection)
64 | B008,
65 | ; Model should define verbose_name in its Meta inner class
66 | DJ10,
67 | ; Model should define verbose_name_plural in its Meta inner class
68 | DJ11,
69 | ; Found mutable module constant.
70 | WPS407,
71 | ; Found too many empty lines in `def`
72 | WPS473,
73 | ; too many no-cover comments.
74 | WPS403,
75 | ; Found `noqa` comments overuse
76 | WPS402,
77 | ; Found protected attribute usage
78 | WPS437,
79 | ; Found too short name
80 | WPS111,
81 | ; Found error raising from itself
82 | WPS469,
83 | ; Found a too complex `f` string
84 | WPS237,
85 | ; Imports not being at top-level
E402,
; complex function names
C901,
; missing docstring in public package
D104
91 |
92 | per-file-ignores =
93 | ; all tests
94 | test_*.py,tests.py,tests_*.py,*/tests/*,conftest.py:
95 | ; Use of assert detected
96 | S101,
97 | ; Found outer scope names shadowing
98 | WPS442,
99 | ; Found too many local variables
100 | WPS210,
101 | ; Found magic number
102 | WPS432,
103 | ; Missing parameter(s) in Docstring
104 | DAR101,
105 | ; Found too many arguments
106 | WPS211,
107 | ; Missing docstring in public class
108 | D101,
109 | ; Missing docstring in public method
110 | D102,
111 | ; Found too long name
112 | WPS118,
113 |
114 | ; all init files
115 | __init__.py:
116 | ; ignore not used imports
117 | F401,
118 | ; ignore import with wildcard
119 | F403,
120 | ; Found wrong metadata variable
121 | WPS410,
122 |
123 | ; exceptions file
124 | unify/exceptions.py:
125 | ; Found wrong keyword
126 | WPS420,
127 | ; Found incorrect node inside `class` body
128 | WPS604,
129 |
130 | exclude =
131 | ./.cache,
132 | ./.git,
133 | ./.idea,
134 | ./.mypy_cache,
135 | ./.pytest_cache,
136 | ./.venv,
137 | ./venv,
138 | ./env,
139 | ./cached_venv,
140 | ./docs,
141 | ./deploy,
142 | ./var,
143 | ./.vscode,
144 | ./generate_docs.py
145 |
--------------------------------------------------------------------------------
/.github/bump_version.py:
--------------------------------------------------------------------------------
import re
import subprocess

import requests

# RSS feed listing the releases published for the `unifyai` package on PyPI.
pypi_rss_url = "https://pypi.org/rss/project/unifyai/releases.xml"

if __name__ == "__main__":
    # Get the latest published version from the RSS feed.
    response = requests.get(pypi_rss_url, timeout=30)
    response.raise_for_status()
    # `findall` returns an empty list on no match, so guard *before*
    # indexing — indexing first would raise IndexError and the clearer
    # "Failed parsing version" error below could never be reached.
    versions = re.findall(r"\d+\.\d+\.\d+", response.text)
    if not versions:
        raise Exception("Failed parsing version")
    version = versions[0]

    # Sync the local version to PyPI's latest, then bump the patch number.
    # check=True makes a failing poetry command fail the workflow loudly
    # instead of publishing with a stale version.
    subprocess.run(["poetry", "version", version], check=True)
    subprocess.run(["poetry", "version", "patch"], check=True)
--------------------------------------------------------------------------------
/.github/workflows/pypi.yml:
--------------------------------------------------------------------------------
1 | name: Publish release to PyPI
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | paths:
8 | - 'unify/**'
9 | - 'LICENSE'
10 | - 'README.md'
11 | - 'pyproject.toml'
12 | - 'poetry.lock'
13 |
14 | jobs:
15 | publish_pypi_release:
16 | runs-on: ubuntu-latest
17 | environment: pypi
18 | continue-on-error: false
19 |
20 | steps:
21 | - name: Checkout repository
22 | uses: actions/checkout@v4
23 |
24 | - name: Install Python
25 | uses: actions/setup-python@v5
26 | with:
27 | python-version: '3.9'
28 |
29 | - name: Install Poetry
30 | run: |
31 | python -m pip install --upgrade pip
32 | pip install poetry
33 |
34 | - name: Bump package version
35 | run: |
36 | python .github/bump_version.py
37 |
38 | - name: Build package
39 | run: poetry build
40 |
41 | - name: Publish to TestPyPI
42 | run: |
43 | poetry config repositories.testpypi https://test.pypi.org/legacy/
44 | poetry publish --skip-existing -r testpypi -u __token__ -p ${{ secrets.TEST_PYPI_API_KEY }}
45 |
46 | - name: Publish to PyPI
47 | run: |
48 | poetry publish -u __token__ -p ${{ secrets.PYPI_API_KEY }}
49 |
--------------------------------------------------------------------------------
/.github/workflows/test-demos.yml:
--------------------------------------------------------------------------------
1 | name: Test demos
2 | on:
3 | workflow_dispatch:
4 | # push:
5 | # branches:
6 | # - main
7 | # permissions:
8 | # contents: write
9 | # actions: read
10 | # id-token: write
11 |
12 | jobs:
13 | test-demos:
14 | name: Test demos
15 | uses: unifyai/workflows/.github/workflows/test-demos.yml@main
16 | secrets:
17 | GITHUB_TOKEN: ${{ secrets.CONSOLE_TOKEN }}
18 |
--------------------------------------------------------------------------------
/.github/workflows/tests.yml:
--------------------------------------------------------------------------------
1 | name: Testing unify
2 |
3 | on: push
4 |
5 | jobs:
6 | black:
7 | runs-on: ubuntu-latest
8 | steps:
9 | - uses: actions/checkout@v2
10 | - name: Set up Python
11 | uses: actions/setup-python@v2
12 | with:
13 | python-version: '3.9'
14 | - name: Install deps
15 | uses: knowsuchagency/poetry-install@v1
16 | env:
17 | POETRY_VIRTUALENVS_CREATE: false
18 | - name: Run black check
19 | run: poetry run black --check .
20 |
21 | pytest:
22 | runs-on: ubuntu-latest
23 | environment: unify-testing
24 | timeout-minutes: 120
25 | steps:
26 | - uses: actions/checkout@v2
27 | - name: Set up Python
28 | uses: actions/setup-python@v2
29 | with:
30 | python-version: '3.9'
31 | - name: Install deps
32 | run: |
33 | python -m pip install --upgrade pip
34 | pip install poetry
35 | poetry install --with dev
36 | - name: Run unit tests
37 | run: poetry run pytest --timeout=120 -p no:warnings -vv .
38 | env:
39 | UNIFY_KEY: ${{ secrets.USER_API_KEY }}
40 |
--------------------------------------------------------------------------------
/.github/workflows/update-docs.yml:
--------------------------------------------------------------------------------
1 | name: Publish docs
2 | on:
3 | push:
4 | branches:
5 | - main
6 | workflow_dispatch:
7 | permissions:
8 | contents: write
9 |
10 | jobs:
11 | publish-docs:
12 | name: Update docs
13 | uses: unifyai/workflows/.github/workflows/publish-docs.yml@main
14 | secrets: inherit
15 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | ### Python template
2 |
3 | .idea/
4 | .vscode/
5 | # Byte-compiled / optimized / DLL files
6 | __pycache__/
7 | *.py[cod]
8 | *$py.class
9 |
10 | # C extensions
11 | *.so
12 |
13 | # Distribution / packaging
14 | .Python
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | *.py,cover
54 | .hypothesis/
55 | .pytest_cache/
56 | cover/
57 |
58 | # Translations
59 | *.mo
60 | *.pot
61 |
62 | # Django stuff:
63 | *.log
64 | local_settings.py
65 | *.sqlite3
66 | *.sqlite3-journal
67 |
68 | # Flask stuff:
69 | instance/
70 | .webassets-cache
71 |
72 | # Scrapy stuff:
73 | .scrapy
74 |
75 | # Sphinx documentation
76 | docs/_build/
77 |
78 | # PyBuilder
79 | .pybuilder/
80 | target/
81 |
82 | # Jupyter Notebook
83 | .ipynb_checkpoints
84 |
85 | # IPython
86 | profile_default/
87 | ipython_config.py
88 |
89 | # pyenv
90 | # For a library or package, you might want to ignore these files since the code is
91 | # intended to run in multiple environments; otherwise, check them in:
92 | # .python-version
93 |
94 | # pipenv
95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
98 | # install all needed dependencies.
99 | #Pipfile.lock
100 |
101 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
102 | __pypackages__/
103 |
104 | # Celery stuff
105 | celerybeat-schedule
106 | celerybeat.pid
107 |
108 | # SageMath parsed files
109 | *.sage.py
110 |
111 | # Environments
112 | .env
113 | .env.tmp
114 | .venv
115 | env/
116 | venv/
117 | ENV/
118 | env.bak/
119 | venv.bak/
120 |
121 | # Spyder project settings
122 | .spyderproject
123 | .spyproject
124 |
125 | # Rope project settings
126 | .ropeproject
127 |
128 | # mkdocs documentation
129 | /site
130 |
131 | # mypy
132 | .mypy_cache/
133 | .dmypy.json
134 | dmypy.json
135 |
136 | # Pyre type checker
137 | .pyre/
138 |
139 | # pytype static type analyzer
140 | .pytype/
141 |
142 | # Cython debug symbols
143 | cython_debug/
144 |
145 | # Cloud SDKs
146 | google-cloud-sdk/
147 | google-cloud-cli-*.tar.gz
148 |
149 | # SQL Dump files
150 | *.dump.sql
151 |
152 | # Binaries
153 | bin/
154 |
155 | # Tests
156 | new_module.py
157 | test.py
158 | implementations.py
159 | .test_cache.json
160 | *.cache.json
161 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | # See https://pre-commit.com for more information
2 | # See https://pre-commit.com/hooks.html for more hooks
3 | repos:
4 | - repo: https://github.com/pre-commit/pre-commit-hooks
5 | rev: v5.0.0
6 | hooks:
7 | - id: check-ast
8 | - id: trailing-whitespace
9 | - id: check-toml
10 | - id: end-of-file-fixer
11 |
12 | - repo: https://github.com/asottile/add-trailing-comma
13 | rev: v3.1.0
14 | hooks:
15 | - id: add-trailing-comma
16 |
17 | - repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks
18 | rev: v2.14.0
19 | hooks:
20 | - id: pretty-format-yaml
21 | args:
22 | - --autofix
23 | - --preserve-quotes
24 | - --indent=2
25 |
26 | - repo: local
27 | hooks:
28 | - id: autoflake
29 | name: autoflake
30 | entry: poetry run autoflake
31 | language: system
32 | types: [python]
33 | args: [--in-place, --remove-all-unused-imports, --remove-duplicate-keys]
34 | exclude: |
35 | (?x)^(
36 | .*/__init__\.py # Exclude all __init__.py files
37 | )$
38 |
39 | - id: black
40 | name: Format with Black
41 | entry: poetry run black
42 | language: system
43 | types: [python]
44 |
45 | - id: isort
46 | name: isort
47 | entry: poetry run isort
48 | language: system
49 | types: [python]
50 | exclude: |
51 | (?x)^(
52 | .*/__init__\.py # Exclude all __init__.py files
53 | )$
54 |
--------------------------------------------------------------------------------
/.run/pycharm/Python tests for test_interfaces.run.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
--------------------------------------------------------------------------------
/.run/pycharm/Python tests for test_routing.run.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
--------------------------------------------------------------------------------
/.run/pycharm/Python tests for test_universal_api.run.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
--------------------------------------------------------------------------------
/.run/pycharm/Python tests for test_utils.run.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
--------------------------------------------------------------------------------
/.run/pycharm/Template Python tests.run.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
--------------------------------------------------------------------------------
/.run/pycharm/Template Python.run.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
--------------------------------------------------------------------------------
/.run/pycharm/generate_docs.run.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 | ----
10 |
11 | 
12 | 
13 | 
14 |
15 |
20 |
21 | **Fully hackable** LLMOps. Build *custom* interfaces for: logging, evals, guardrails, labelling, tracing, agents, human-in-the-loop, hyperparam sweeps, and anything else you can think of ✨
22 |
23 | Just `unify.log` your data, and add an interface using the four building blocks:
24 |
25 | 1. **tables** 🔢
26 | 2. **views** 🔍
27 | 3. **plots** 📊
28 | 4. **editor** 🕹️ (coming soon)
29 |
30 | Every LLM product has **unique** and **changing** requirements, as do the **users**. Your infra should reflect this!
31 |
32 | We've tried to make Unify as **(a) simple**, **(b) modular** and **(c) hackable** as possible, so you can quickly probe, analyze, and iterate on the data that's important for **you**, your **product** and your **users** ⚡
33 |
34 | ## Quickstart
35 |
36 | [Sign up](https://console.unify.ai/), `pip install unifyai`, run your first eval ⬇️, and then check out the logs in your first [interface](https://console.unify.ai) 📊
37 |
38 | ```python
39 | import unify
40 | from random import randint, choice
41 |
42 | # initialize project
43 | unify.activate("Maths Assistant")
44 |
45 | # build agent
46 | client = unify.Unify("o3-mini@openai", traced=True)
47 | client.set_system_message(
48 | "You are a helpful maths assistant, "
49 | "tasked with adding and subtracting integers."
50 | )
51 |
52 | # add test cases
53 | qs = [
54 | f"{randint(0, 100)} {choice(['+', '-'])} {randint(0, 100)}"
55 | for i in range(10)
56 | ]
57 |
58 | # define evaluator
59 | @unify.traced
60 | def evaluate_response(question: str, response: str) -> float:
61 | correct_answer = eval(question)
62 | try:
63 | response_int = int(
64 | "".join(
65 | [
66 | c for c in response.split(" ")[-1]
67 | if c.isdigit()
68 | ]
69 | ),
70 | )
71 | return float(correct_answer == response_int)
72 | except ValueError:
73 | return 0.
74 |
75 | # define evaluation
76 | @unify.traced
77 | def evaluate(q: str):
78 | response = client.copy().generate(q)
79 | score = evaluate_response(q, response)
80 | unify.log(
81 | question=q,
82 | response=response,
83 | score=score
84 | )
85 |
86 | # execute + log your evaluation
87 | with unify.Experiment():
88 | unify.map(evaluate, qs)
89 | ```
90 |
91 | Check out our [Quickstart Video](https://youtu.be/fl9SzsoCegw?si=MhQZDfNS6U-ZsVYc) for a guided walkthrough.
92 |
93 | ## Focus on your *product*, not the *LLM* 🎯
94 |
95 | Despite all of the hype, abstractions, and jargon, the *process* for building quality LLM apps is pretty simple.
96 |
97 | ```
98 | create simplest possible agent 🤖
99 | while True:
100 | create/expand unit tests (evals) 🗂️
101 | while run(tests) failing: 🧪
102 | Analyze failures, understand the root cause 🔍
103 | Vary system prompt, in-context examples, tools etc. to rectify 🔀
104 | Beta test with users, find more failures 🚦
105 | ```
106 |
107 | We've tried to strip away all of the excessive LLM jargon, so you can focus on your *product*, your *users*, and the *data* you care about, and *nothing else* 📈
108 |
109 | Unify takes inspiration from:
110 | - [PostHog](https://posthog.com/) / [Grafana](https://grafana.com/) / [LogFire](https://pydantic.dev/logfire) for powerful observability 🔬
111 | - [LangSmith](https://www.langchain.com/langsmith) / [BrainTrust](https://www.braintrust.dev/) / [Weave](https://wandb.ai/site/weave/) for LLM abstractions 🤖
112 | - [Notion](https://www.notion.com/) / [Airtable](https://www.airtable.com/) for composability and versatility 🧱
113 |
114 | Whether you're technical or non-technical, we hope Unify can help you to rapidly build top-notch LLM apps, and to remain fully focused on your *product* (not the *LLM*).
115 |
116 | ## Learn More
117 |
118 | Check out our [docs](https://docs.unify.ai/), and if you have any questions feel free to reach out to us on [discord](https://discord.com/invite/sXyFF8tDtm) 👾
119 |
120 | Unify is under active development 🚧, feedback in all shapes/sizes is also very welcome! 🙏
121 |
122 | Happy prompting! 🧑💻
123 |
--------------------------------------------------------------------------------
/pydoc-markdown.yml:
--------------------------------------------------------------------------------
1 | loader:
2 | - type: package
3 | packages:
- unify  # the package whose API docs are generated
5 |
6 | renderer:
7 | type: markdown
8 | options:
9 | output: docs # Output directory for generated Markdown files
10 | template: |
11 | # {{ module_name }}
12 |
13 | {{ module_doc }}
14 |
15 | ## Classes
16 |
17 | {{ classes }}
18 |
19 | ## Functions
20 |
21 | {{ functions }}
22 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "unifyai"
3 | packages = [{include = "unify"}]
4 | version = "0.9.10"
5 | readme = "README.md"
6 | description = "A Python package for interacting with the Unify API"
7 | authors = ["Unify "]
8 | repository = "https://github.com/unifyai/unify"
9 |
10 | [tool.poetry.dependencies]
11 | python = "^3.9"
12 | requests = "^2.31.0"
13 | requests-toolbelt = "^1.0.0"
14 | openai = "^1.47.0"
15 | jsonlines = "^4.0.0"
16 | rich = "^13.8.1"
17 | pytest = "^8.3.3"
18 | pytest-timeout = "^2.3.1"
19 | pytest-asyncio = "^0.24.0"
termcolor = "2.5.0"
21 | aiohttp = "^3.11.12"
22 |
23 | [tool.poetry.group.dev.dependencies]
24 | types-requests = "*"
25 | flake8 = "~4.0.1"
26 | mypy = "^1.1.1"
27 | isort = "^5.11.4"
28 | pre-commit = "^3.0.1"
29 | wemake-python-styleguide = "^0.17.0"
30 | black = "^24.3.0"
31 | autoflake = "^1.6.1"
32 | pydoc-markdown = "^4.0.0"
33 |
34 | [tool.isort]
35 | profile = "black"
36 | multi_line_output = 3
37 | src_paths = ["orchestra",]
38 |
39 | [tool.mypy]
40 | strict = true
41 | ignore_missing_imports = true
42 | allow_subclassing_any = true
43 | allow_untyped_calls = true
44 | pretty = true
45 | show_error_codes = true
46 | implicit_reexport = true
47 | allow_untyped_decorators = true
48 | warn_unused_ignores = false
49 | warn_return_any = false
50 | namespace_packages = true
51 |
52 | [build-system]
53 | requires = ["poetry-core>=1.0.0"]
54 | build-backend = "poetry.core.masonry.api"
55 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | pytest_plugins = ("pytest_asyncio",)
2 |
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import unify
4 |
5 |
def pytest_sessionstart(session):
    """Purge all existing logs before the test session when running in CI."""
    in_ci = os.environ.get("CI")
    if not in_ci:
        return
    unify.delete_logs()
9 |
--------------------------------------------------------------------------------
/tests/test_logging/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/unifyai/unify/fbb548073fb0dc9f0cc0f67e28220011cbf9ec5b/tests/test_logging/__init__.py
--------------------------------------------------------------------------------
/tests/test_logging/helpers.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import functools
3 | import sys
4 | import traceback
5 |
6 | import unify
7 |
8 |
def _handle_project(test_fn):
    """Decorator that runs ``test_fn`` inside a dedicated unify project.

    The project is named after the test function, recreated fresh (any stale
    project with the same name is deleted first), and deleted again on both
    success and failure. On failure, the full formatted traceback is embedded
    in a plain ``Exception`` and re-raised, so the details survive the
    cleanup. Supports both sync and async test functions.
    """

    def _setup(project):
        # Remove any leftover project from a previous (crashed) run.
        if project in unify.list_projects():
            unify.delete_project(project)

    def _cleanup_and_reraise(project):
        # Delete the project, then re-raise. The original exception type is
        # intentionally collapsed into Exception; its message carries the
        # complete formatted traceback.
        unify.delete_project(project)
        tb_string = traceback.format_exc()
        raise Exception(tb_string)

    # noinspection PyBroadException
    @functools.wraps(test_fn)
    def wrapper(*args, **kwargs):
        project = test_fn.__name__
        _setup(project)
        try:
            with unify.Project(project):
                test_fn(*args, **kwargs)
            unify.delete_project(project)
        # BaseException is the explicit spelling of the original bare
        # `except:` — cleanup must run even on KeyboardInterrupt.
        except BaseException:
            _cleanup_and_reraise(project)

    @functools.wraps(test_fn)
    async def async_wrapper(*args, **kwargs):
        project = test_fn.__name__
        _setup(project)
        try:
            with unify.Project(project):
                await test_fn(*args, **kwargs)
            unify.delete_project(project)
        except BaseException:
            _cleanup_and_reraise(project)

    return async_wrapper if asyncio.iscoroutinefunction(test_fn) else wrapper
42 |
--------------------------------------------------------------------------------
/tests/test_logging/test_projects.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import unify
4 |
5 |
def test_set_project():
    """Activating a project makes it the active one."""
    # Start from a clean slate: nothing active.
    unify.deactivate()
    assert unify.active_project() is None
    name = "my_project"
    unify.activate(name)
    assert unify.active_project() == name
    # Leave no active project behind for the next test.
    unify.deactivate()
12 |
13 |
def test_unset_project():
    """Deactivation restores the no-active-project state."""
    name = "my_project"
    unify.deactivate()
    assert unify.active_project() is None
    unify.activate(name)
    assert unify.active_project() == name
    # Deactivating again must fully clear the active project.
    unify.deactivate()
    assert unify.active_project() is None
21 |
22 |
def test_with_project():
    """The Project context manager activates on entry, deactivates on exit."""
    unify.deactivate()
    assert unify.active_project() is None
    name = "my_project"
    with unify.Project(name):
        # Inside the block the project is active...
        assert unify.active_project() == name
    # ...and on exit it is cleared again.
    assert unify.active_project() is None
29 |
30 |
def test_set_project_then_log():
    """Logs are written against the explicitly activated project."""
    name = "test_set_project_then_log"
    unify.deactivate()
    assert unify.active_project() is None
    unify.activate(name)
    assert unify.active_project() == name
    # Log while the project is active.
    unify.log(key=1.0)
    unify.deactivate()
    assert unify.active_project() is None
    # Clean up the project created by the log call.
    unify.delete_project(name)
40 |
41 |
def test_with_project_then_log():
    """Logs are written against the project active inside the context."""
    name = "test_with_project_then_log"
    unify.deactivate()
    assert unify.active_project() is None
    with unify.Project(name):
        assert unify.active_project() == name
        unify.log(key=1.0)
    assert unify.active_project() is None
    # Clean up the project created by the log call.
    unify.delete_project(name)
50 |
51 |
def test_project_env_var():
    """The UNIFY_PROJECT env var implicitly sets the active project."""
    unify.deactivate()
    assert unify.active_project() is None
    os.environ["UNIFY_PROJECT"] = "test_project_env_var"
    try:
        assert unify.active_project() == "test_project_env_var"
        unify.delete_logs(project="test_project_env_var")
        unify.log(x=0, y=1, z=2)
        del os.environ["UNIFY_PROJECT"]
        assert unify.active_project() is None
        logs = unify.get_logs(project="test_project_env_var")
        assert len(logs) == 1
        assert logs[0].entries == {"x": 0, "y": 1, "z": 2}
    finally:
        # Always unset the env var (pop is a no-op if the success path
        # already deleted it) and remove the project — a failure part-way
        # through must not leak an implicit project into later tests.
        os.environ.pop("UNIFY_PROJECT", None)
        unify.delete_project("test_project_env_var")
67 |
68 |
# Allow direct execution of this module; tests are normally run via pytest.
if __name__ == "__main__":
    pass
71 |
--------------------------------------------------------------------------------
/tests/test_logging/test_utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/unifyai/unify/fbb548073fb0dc9f0cc0f67e28220011cbf9ec5b/tests/test_logging/test_utils/__init__.py
--------------------------------------------------------------------------------
/tests/test_logging/test_utils/test_artifacts.py:
--------------------------------------------------------------------------------
1 | import unify
2 |
3 |
def test_artifacts():
    """Artifacts can be added to, listed from, and deleted off a project."""
    project = "my_project"
    # Start from a freshly created project.
    if project in unify.list_projects():
        unify.delete_project(project)
    unify.create_project(project)
    artifacts = {"dataset": "my_dataset", "description": "this is my dataset"}
    # No artifacts yet on a fresh project.
    assert len(unify.get_project_artifacts(project=project)) == 0
    unify.add_project_artifacts(project=project, **artifacts)
    assert "dataset" in unify.get_project_artifacts(project=project)
    # Deleting one artifact must leave the other untouched.
    unify.delete_project_artifact(
        "dataset",
        project=project,
    )
    assert "dataset" not in unify.get_project_artifacts(project=project)
    assert "description" in unify.get_project_artifacts(project=project)
19 |
20 |
# Allow direct execution of this module; tests are normally run via pytest.
if __name__ == "__main__":
    pass
23 |
--------------------------------------------------------------------------------
/tests/test_logging/test_utils/test_async_logger.py:
--------------------------------------------------------------------------------
1 | import unify
2 |
3 | from ..helpers import _handle_project
4 |
5 |
@_handle_project
def test_async_logger():
    """The async logger must produce the same entries as sync logging."""
    try:
        logs_sync = [unify.log(x=i, y=i * 2, z=i * 3) for i in range(10)]

        unify.initialize_async_logger()
        logs_async = [unify.log(x=i, y=i * 2, z=i * 3) for i in range(10)]
        unify.shutdown_async_logger()

        assert len(logs_async) == len(logs_sync)
        # Compare after sorting both sides on the same key so ordering
        # differences between the two logging paths don't matter.
        for log_async, log_sync in zip(
            sorted(logs_async, key=lambda log: log.entries["x"]),
            sorted(logs_sync, key=lambda log: log.entries["x"]),
        ):
            assert log_async.entries == log_sync.entries
        # `is False` (identity) rather than `== False` (E712): the shutdown
        # call is expected to reset the flag to the literal False.
        assert unify.ASYNC_LOGGING is False
    except Exception:
        # Never leave the background logger running; bare `raise` keeps the
        # original exception and traceback intact (unlike `raise e`).
        unify.shutdown_async_logger()
        raise
25 |
26 |
# Allow direct execution of this module; tests are normally run via pytest.
if __name__ == "__main__":
    pass
29 |
--------------------------------------------------------------------------------
/tests/test_logging/test_utils/test_contexts.py:
--------------------------------------------------------------------------------
1 | import unify
2 |
3 | from ..helpers import _handle_project
4 |
5 |
@_handle_project
def test_create_context():
    """A newly created context shows up in the project's context listing."""
    # Fresh project (from the decorator) starts with no contexts.
    assert len(unify.get_contexts()) == 0
    unify.create_context("my_context")
    assert len(unify.get_contexts()) == 1
    assert "my_context" in unify.get_contexts()
12 |
13 |
@_handle_project
def test_get_contexts():
    """Contexts can be listed in full and filtered by prefix."""
    assert len(unify.get_contexts()) == 0
    # Two logs in each of two nested contexts, in the same order as before.
    for ctx in ("a/b", "b/c"):
        for val in (0, 1):
            unify.log(x=val, context=ctx)
    contexts = unify.get_contexts()
    assert len(contexts) == 2
    assert "a/b" in contexts
    assert "b/c" in contexts
    # Prefix "a" matches only the "a/..." context.
    contexts = unify.get_contexts(prefix="a")
    assert len(contexts) == 1
    assert "a/b" in contexts
    assert "a/c" not in contexts
    # Prefix "b" matches only the "b/..." context.
    contexts = unify.get_contexts(prefix="b")
    assert len(contexts) == 1
    assert "b/c" in contexts
    assert "a/b" not in contexts
33 |
34 |
@_handle_project
def test_delete_context():
    """Deleting a context removes both the context and its logs."""
    unify.log(x=0, context="a/b")
    before = unify.get_contexts()
    assert len(before) == 1
    assert "a/b" in before
    unify.delete_context("a/b")
    assert "a/b" not in unify.get_contexts()
    assert len(unify.get_logs()) == 0
44 |
45 |
@_handle_project
def test_add_logs_to_context():
    """Logs can be attached to a second context; ids list newest-first."""
    logs = [unify.log(x=i, context="a/b") for i in (0, 1)]
    logs += [unify.log(x=i, context="b/c") for i in (2, 3)]
    unify.add_logs_to_context(log_ids=[logs[0].id, logs[1].id], context="b/c")
    assert len(unify.get_logs(context="a/b")) == 2
    assert unify.get_logs(context="a/b", return_ids_only=True) == [
        logs[1].id,
        logs[0].id,
    ]
    assert len(unify.get_logs(context="b/c")) == 4
    expected_ids = [logs[i].id for i in (3, 2, 1, 0)]
    assert unify.get_logs(context="b/c", return_ids_only=True) == expected_ids
62 |
63 |
@_handle_project
def test_rename_context():
    """Renaming a context moves its logs under the new name."""
    unify.log(x=0, context="a/b")
    unify.rename_context("a/b", "a/c")
    names = unify.get_contexts()
    assert "a/b" not in names
    assert "a/c" in names
    renamed_logs = unify.get_logs(context="a/c")
    assert len(renamed_logs) == 1
    assert renamed_logs[0].context == "a/c"
74 |
75 |
@_handle_project
def test_get_context():
    """get_context() returns the metadata the context was created with."""
    options = dict(
        description="my_description",
        is_versioned=True,
        allow_duplicates=True,
    )
    unify.create_context("foo", **options)

    context = unify.get_context("foo")
    assert context["name"] == "foo"
    assert context["description"] == options["description"]
    assert context["is_versioned"] is True
    assert context["allow_duplicates"] is True
94 |
95 |
# No standalone entry point; run this module via pytest.
if __name__ == "__main__":
    pass
98 |
--------------------------------------------------------------------------------
/tests/test_logging/test_utils/test_datasets.py:
--------------------------------------------------------------------------------
1 | import unify
2 |
3 | from ..helpers import _handle_project
4 |
5 |
@_handle_project
def test_list_datasets():
    """Datasets derive from 'Datasets/...' contexts, with prefix filtering."""
    assert len(unify.list_datasets()) == 0
    for ctx in ("Datasets/Prod/TestSet", "Datasets/Eval/ValidationSet"):
        unify.log(x=0, context=ctx)
        unify.log(x=1, context=ctx)
    names = unify.list_datasets()
    assert len(names) == 2
    assert "Prod/TestSet" in names
    assert "Eval/ValidationSet" in names
    prod_only = unify.list_datasets(prefix="Prod")
    assert len(prod_only) == 1
    assert "Prod/TestSet" in prod_only
    assert "Eval/ValidationSet" not in prod_only
    eval_only = unify.list_datasets(prefix="Eval")
    assert len(eval_only) == 1
    assert "Eval/ValidationSet" in eval_only
    assert "Prod/TestSet" not in eval_only
25 |
26 |
@_handle_project
def test_upload_dataset():
    """Uploading a dataset returns one record per entry."""
    staff = [
        {"name": name, "age": age, "gender": gender}
        for name, age, gender in (
            ("Dan", 31, "male"),
            ("Jane", 25, "female"),
            ("John", 35, "male"),
        )
    ]
    uploaded = unify.upload_dataset("staff", staff)
    assert len(uploaded) == 3
48 |
49 |
@_handle_project
def test_add_dataset_entries():
    """Entries appended to a dataset are stored after the original rows."""
    initial = [
        {"name": name, "age": age, "gender": gender}
        for name, age, gender in (
            ("Dan", 31, "male"),
            ("Jane", 25, "female"),
            ("John", 35, "male"),
        )
    ]
    assert len(unify.upload_dataset("staff", initial)) == 3
    downloaded = unify.download_dataset("staff")
    assert len(downloaded) == 3
    assert [row.entries["name"] for row in downloaded] == ["Dan", "Jane", "John"]
    extra = [
        {"name": name, "age": age, "gender": gender}
        for name, age, gender in (
            ("Chloe", 28, "female"),
            ("Tom", 32, "male"),
        )
    ]
    assert len(unify.add_dataset_entries("staff", extra)) == 2
    downloaded = unify.download_dataset("staff")
    assert len(downloaded) == 5
    assert [row.entries["name"] for row in downloaded] == [
        "Dan",
        "Jane",
        "John",
        "Chloe",
        "Tom",
    ]
97 |
98 |
@_handle_project
def test_download_dataset():
    """A freshly uploaded dataset downloads with the same entry count."""
    records = [
        {"name": name, "age": age, "gender": gender}
        for name, age, gender in (
            ("Dan", 31, "male"),
            ("Jane", 25, "female"),
            ("John", 35, "male"),
        )
    ]
    unify.upload_dataset("staff", records)
    downloaded = unify.download_dataset("staff")
    assert len(downloaded) == 3
121 |
122 |
@_handle_project
def test_delete_dataset():
    """Deleting a dataset removes it from the listing."""
    records = [
        {"name": name, "age": age, "gender": gender}
        for name, age, gender in (
            ("Dan", 31, "male"),
            ("Jane", 25, "female"),
            ("John", 35, "male"),
        )
    ]
    unify.upload_dataset("staff", records)
    unify.delete_dataset("staff")
    assert "staff" not in unify.list_datasets()
145 |
146 |
# No standalone entry point; run this module via pytest.
if __name__ == "__main__":
    pass
149 |
--------------------------------------------------------------------------------
/tests/test_logging/test_utils/test_projects.py:
--------------------------------------------------------------------------------
1 | import unify
2 |
3 | from ..helpers import _handle_project
4 |
5 |
def test_project():
    """Exercise the create / rename / delete project lifecycle."""
    name = "my_project"
    # Start from a clean slate in case an earlier run left the project behind.
    if name in unify.list_projects():
        unify.delete_project(name)
    assert name not in unify.list_projects()

    unify.create_project(name)
    assert name in unify.list_projects()

    new_name = "my_project1"
    unify.rename_project(name, new_name)
    projects = unify.list_projects()
    assert new_name in projects

    unify.delete_project(new_name)
    assert new_name not in unify.list_projects()
18 |
19 |
def test_project_thread_lock():
    """Concurrent logging must not race on implicit project creation.

    All 10 mapped workers would try to create the project simultaneously
    without thread locking; only one should acquire the lock, and the run
    should complete cleanly.
    """
    fields = dict(a=[1] * 10, b=[2] * 10, c=[3] * 10)
    unify.map(unify.log, project="test_project", from_args=True, **fields)
    unify.delete_project("test_project")
32 |
33 |
@_handle_project
def test_delete_project_logs():
    """delete_project_logs() wipes logs but keeps the project itself."""
    for i in range(10):
        unify.log(x=i)
    assert len(unify.get_logs()) == 10
    unify.delete_project_logs("test_delete_project_logs")
    assert len(unify.get_logs()) == 0
    assert "test_delete_project_logs" in unify.list_projects()
41 |
42 |
@_handle_project
def test_delete_project_contexts():
    """delete_project_contexts() wipes contexts but keeps the project itself."""
    for ctx in ("foo", "bar"):
        unify.create_context(ctx)
    assert len(unify.get_contexts()) == 2

    unify.delete_project_contexts("test_delete_project_contexts")
    assert len(unify.get_contexts()) == 0
    assert "test_delete_project_contexts" in unify.list_projects()
53 |
54 |
# No standalone entry point; run this module via pytest.
if __name__ == "__main__":
    pass
57 |
--------------------------------------------------------------------------------
/tests/test_routing/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/unifyai/unify/fbb548073fb0dc9f0cc0f67e28220011cbf9ec5b/tests/test_routing/__init__.py
--------------------------------------------------------------------------------
/tests/test_routing/test_fallbacks.py:
--------------------------------------------------------------------------------
1 | import unify
2 |
3 |
def test_provider_fallback():
    """Same model, falling back across providers (anthropic -> aws-bedrock)."""
    client = unify.Unify("claude-3-opus@anthropic->aws-bedrock")
    client.generate("Hello.")
6 |
7 |
def test_model_fallback():
    """Same provider, falling back across models (pro -> flash)."""
    client = unify.Unify("gemini-1.5-pro->gemini-1.5-flash@vertex-ai")
    client.generate("Hello.")
10 |
11 |
def test_endpoint_fallback():
    """Fall back across two fully-specified model@provider endpoints."""
    client = unify.Unify("llama-3.1-405b-chat@together-ai->gpt-4o@openai")
    client.generate("Hello.")
16 |
17 |
# No standalone entry point; run this module via pytest.
if __name__ == "__main__":
    pass
20 |
--------------------------------------------------------------------------------
/tests/test_routing/test_routing_syntax.py:
--------------------------------------------------------------------------------
1 | import unify
2 |
3 | # Meta Providers #
4 | # ---------------#
5 |
6 |
def test_ttft():
    """Route by time-to-first-token via full, short, and one-letter aliases."""
    for pre in ("", "lowest-"):
        for model, metric in (
            ("claude-3-opus", "time-to-first-token"),
            ("gpt-4o", "ttft"),
            ("mixtral-8x22b-instruct-v0.1", "t"),
        ):
            unify.Unify(f"{model}@{pre}{metric}").generate("Hello.")
12 |
13 |
def test_itl():
    """Route by inter-token-latency via full, short, and one-letter aliases."""
    for pre in ("", "lowest-"):
        for model, metric in (
            ("claude-3-opus", "inter-token-latency"),
            ("gpt-4o", "itl"),
            ("mixtral-8x22b-instruct-v0.1", "i"),
        ):
            unify.Unify(f"{model}@{pre}{metric}").generate("Hello.")
19 |
20 |
def test_cost():
    """Route by overall cost via its full name and 'c' alias."""
    for pre in ("", "lowest-"):
        for model, metric in (
            ("claude-3-opus", "cost"),
            ("mixtral-8x22b-instruct-v0.1", "c"),
        ):
            unify.Unify(f"{model}@{pre}{metric}").generate("Hello.")
25 |
26 |
def test_input_cost():
    """Route by input-cost via its full name and 'ic' abbreviation."""
    for pre in ("", "lowest-"):
        unify.Unify(f"claude-3-opus@{pre}input-cost").generate("Hello.")
        unify.Unify(f"gpt-4o@{pre}ic").generate("Hello.")
        # The third case previously used "@{pre}i", which is the
        # inter-token-latency alias (see test_itl), not input-cost — a
        # copy-paste remnant. Cover mixtral with the full metric name instead.
        unify.Unify(f"mixtral-8x22b-instruct-v0.1@{pre}input-cost").generate("Hello.")
32 |
33 |
def test_output_cost():
    """Route by output-cost via its full name and 'oc' alias."""
    for pre in ("", "lowest-"):
        for model, metric in (
            ("claude-3-opus", "output-cost"),
            ("mixtral-8x22b-instruct-v0.1", "oc"),
        ):
            unify.Unify(f"{model}@{pre}{metric}").generate("Hello.")
38 |
39 |
40 | # Thresholds #
41 | # -----------#
42 |
43 |
def test_thresholds():
    """A metric threshold ('c<5') can be appended to the routing expression."""
    client = unify.Unify("llama-3.1-405b-chat@inter-token-latency|c<5")
    client.generate("Hello.")
46 |
47 |
48 | # Search Space #
49 | # -------------#
50 |
51 |
def test_routing_w_providers():
    """Restrict the routing search space to an explicit provider list."""
    endpoint = "llama-3.1-405b-chat@itl|providers:groq,fireworks-ai,together-ai"
    unify.Unify(endpoint).generate("Hello.")
56 |
57 |
def test_routing_skip_providers():
    """Exclude specific providers from the routing search space."""
    endpoint = "llama-3.1-405b-chat@itl|skip_providers:vertex-ai,aws-bedrock"
    unify.Unify(endpoint).generate("Hello.")
62 |
63 |
# No standalone entry point; run this module via pytest.
if __name__ == "__main__":
    pass
66 |
--------------------------------------------------------------------------------
/tests/test_universal_api/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/unifyai/unify/fbb548073fb0dc9f0cc0f67e28220011cbf9ec5b/tests/test_universal_api/__init__.py
--------------------------------------------------------------------------------
/tests/test_universal_api/test_basics.py:
--------------------------------------------------------------------------------
1 | import json
2 | from types import AsyncGeneratorType, GeneratorType
3 |
4 | import pytest
5 | from openai.types.chat import ParsedChatCompletion
6 | from pydantic import BaseModel
7 | from unify import AsyncUnify, Unify
8 |
9 |
class Response(BaseModel):
    """Structured-output schema used by the response_format tests below."""

    number: int
12 |
13 |
class TestUnifyBasics:
    """End-to-end smoke tests for the synchronous ``Unify`` client.

    NOTE(review): these hit the live gpt-4o@openai endpoint, so they need
    network access and a valid API key in the environment.
    """

    def test_invalid_api_key_raises_authentication_error(self) -> None:
        # Construction may succeed; the invalid key must fail by generate time.
        with pytest.raises(Exception):
            client = Unify(
                api_key="invalid_api_key",
                endpoint="gpt-4o@openai",
            )
            client.generate(user_message="hello")

    def test_incorrect_model_name_raises_internal_server_error(self) -> None:
        # An unknown model name is rejected when constructing the client.
        with pytest.raises(Exception):
            Unify(model="wong-model-name")

    def test_generate_returns_string_when_stream_false(self) -> None:
        client = Unify(
            endpoint="gpt-4o@openai",
        )
        result = client.generate(user_message="hello", stream=False)
        assert isinstance(result, str)

    def test_traced_and_cached(self) -> None:
        # Tracing and caching flags must not break a basic generate call.
        client = Unify(
            endpoint="gpt-4o@openai",
            traced=True,
            cache=True,
        )
        client.generate("hello")

    def test_copy_client(self) -> None:
        # copy() must yield an independent client: mutating the clone's
        # system message or endpoint must not affect the original.
        client = Unify(
            endpoint="gpt-4o@openai",
        )
        client.set_system_message("you are a helpful agent")
        clone = client.copy()
        assert clone.endpoint == "gpt-4o@openai"
        assert clone.system_message == "you are a helpful agent"
        clone.set_system_message("you are not helpful")
        assert clone.system_message == "you are not helpful"
        assert client.system_message == "you are a helpful agent"
        clone.set_endpoint("o1@openai")
        assert clone.endpoint == "o1@openai"
        assert client.endpoint == "gpt-4o@openai"

    def test_structured_output(self) -> None:
        # response_format=Response (pydantic) constrains output to the schema,
        # both with and without the full completion object.
        client = Unify(
            endpoint="gpt-4o@openai",
            response_format=Response,
        )

        result = client.generate(
            user_message="what is 1 + 1?",
            return_full_completion=True,
        )
        assert isinstance(result, ParsedChatCompletion)
        assert isinstance(result.choices[0].message.content, str)
        result = json.loads(result.choices[0].message.content)
        assert isinstance(result, dict)
        assert result == {"number": 2}

        result = client.generate(
            user_message="what is 1 + 1?",
        )
        assert isinstance(result, str)
        result = json.loads(result)
        assert isinstance(result, dict)
        assert result == {"number": 2}

    def test_structured_output_w_caching(self) -> None:
        # The repeated call should be served from cache yet still honour
        # the structured schema.
        client = Unify(
            endpoint="gpt-4o@openai",
            response_format=Response,
            cache=True,
        )
        assert json.loads(client.generate(user_message="what is 1 + 1?"))["number"] == 2
        assert json.loads(client.generate(user_message="what is 1 + 1?"))["number"] == 2

    def test_generate_returns_generator_when_stream_true(self) -> None:
        client = Unify(
            endpoint="gpt-4o@openai",
        )
        result = client.generate(user_message="hello", stream=True)
        assert isinstance(result, GeneratorType)

    def test_default_params_handled_correctly(self) -> None:
        # Constructor defaults (n=2) must propagate into generate calls.
        client = Unify(
            endpoint="gpt-4o@openai",
            n=2,
            return_full_completion=True,
        )
        result = client.generate(user_message="hello")
        assert len(result.choices) == 2

    def test_setter_chaining(self):
        # Setters return the client, so calls can be chained fluently.
        client = Unify("gpt-4o@openai")
        client.set_temperature(0.5).set_n(2)
        assert client.temperature == 0.5
        assert client.n == 2

    def test_stateful(self):
        # Stateful clients accumulate the full message history.

        # via generate
        client = Unify("gpt-4o@openai", stateful=True)
        client.set_system_message("you are a good mathematician.")
        client.generate("What is 1 + 1?")
        client.generate("How do you know?")
        assert len(client.messages) == 5
        assert client.messages[0]["role"] == "system"
        assert client.messages[1]["role"] == "user"
        assert client.messages[2]["role"] == "assistant"
        assert client.messages[3]["role"] == "user"
        assert client.messages[4]["role"] == "assistant"

        # via append
        client = Unify("gpt-4o@openai", return_full_completion=True)
        client.set_stateful(True)
        client.set_system_message("You are an expert.")
        client.append_messages(
            [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": "Hello",
                        },
                    ],
                },
            ],
        )
        assert len(client.messages) == 2
        assert client.messages[0]["role"] == "system"
        assert client.messages[1]["role"] == "user"

    def test_json_structor(self):
        # json() serialises the client config, including the endpoint.
        client = Unify(
            endpoint="gpt-4o@openai",
            temperature=0.5,
        )
        serialized = client.json()
        assert serialized
        assert serialized["endpoint"] == "gpt-4o@openai"
155 |
156 |
@pytest.mark.asyncio
class TestAsyncUnifyBasics:
    """Async counterparts of the basic client tests, using ``AsyncUnify``."""

    async def test_invalid_api_key_raises_authentication_error(self) -> None:
        # The invalid key must surface as an exception at generate time.
        with pytest.raises(Exception):
            async_client = AsyncUnify(
                api_key="invalid_api_key",
                endpoint="gpt-4o@openai",
            )
            await async_client.generate(user_message="hello")

    async def test_incorrect_model_name_raises_internal_server_error(self) -> None:
        with pytest.raises(Exception):
            AsyncUnify(model="wong-model-name")

    @pytest.mark.skip()
    async def test_generate_returns_string_when_stream_false(self) -> None:
        async_client = AsyncUnify(
            endpoint="gpt-4o@openai",
        )
        result = await async_client.generate(user_message="hello", stream=False)
        assert isinstance(result, str)

    async def test_generate_returns_generator_when_stream_true(self) -> None:
        async_client = AsyncUnify(
            endpoint="gpt-4o@openai",
        )
        result = await async_client.generate(user_message="hello", stream=True)
        assert isinstance(result, AsyncGeneratorType)

    @pytest.mark.skip()
    async def test_default_params_handled_correctly(self) -> None:
        async_client = AsyncUnify(
            endpoint="gpt-4o@openai",
            n=2,
            return_full_completion=True,
        )
        result = await async_client.generate(user_message="hello")
        assert len(result.choices) == 2
195 |
196 |
# No standalone entry point; run this module via pytest.
if __name__ == "__main__":
    pass
199 |
--------------------------------------------------------------------------------
/tests/test_universal_api/test_chatbot.py:
--------------------------------------------------------------------------------
1 | import builtins
2 | import traceback
3 |
4 | import pytest
5 | from unify import ChatBot, MultiUnify, Unify
6 |
7 |
class SimulateInput:
    """Context manager that swaps ``builtins.input`` for a scripted replay.

    Each ``input()`` call inside the ``with`` block prints and returns the
    next canned chat message, ending with "quit" so ChatBot.run() terminates.
    The genuine ``input`` is restored (and the script rewound) on exit.
    """

    def __init__(self):
        self._messages = [
            "What is the capital of Spain? Be succinct.",
            "Who is their most famous sports player? Be succinct.",
            "quit",
        ]
        self._count = 0
        self._true_input = None

    def _new_input(self):
        # Replay the next scripted message, echoing it like a typed prompt.
        next_message = self._messages[self._count]
        self._count += 1
        print(next_message)
        return next_message

    def __enter__(self):
        self._true_input = builtins.__dict__["input"]
        builtins.__dict__["input"] = self._new_input

    def __exit__(self, exc_type, exc_value, tb):
        builtins.__dict__["input"] = self._true_input
        self._count = 0
        if exc_type is None:
            return True
        traceback.print_exception(exc_type, exc_value, tb)
        return False
35 |
36 |
class TestChatbotUniLLM:
    """ChatBot tests backed by a single-endpoint ``Unify`` client."""

    def test_constructor(self) -> None:
        client = Unify(
            endpoint="gpt-4o@openai",
            cache=True,
        )
        ChatBot(client)

    def test_simple_non_stream_chat_n_quit(self):
        # Drive a short scripted conversation, terminated by "quit".
        client = Unify(
            endpoint="gpt-4o@openai",
            cache=True,
        )
        chatbot = ChatBot(client)
        with SimulateInput():
            chatbot.run()

    @pytest.mark.skip()
    def test_simple_stream_chat_n_quit(self):
        # Same scripted conversation, but with streaming responses.
        client = Unify(
            endpoint="gpt-4o@openai",
            cache=True,
            stream=True,
        )
        chatbot = ChatBot(client)
        with SimulateInput():
            chatbot.run()
64 |
65 |
class TestChatbotMultiUnify:
    """ChatBot tests backed by a multi-endpoint ``MultiUnify`` client."""

    def test_constructor(self) -> None:
        client = MultiUnify(
            endpoints=["gpt-4@openai", "gpt-4o@openai"],
            cache=True,
        )
        ChatBot(client)

    def test_simple_non_stream_chat_n_quit(self):
        # Drive a short scripted conversation, terminated by "quit".
        client = MultiUnify(
            endpoints=["gpt-4@openai", "gpt-4o@openai"],
            cache=True,
        )
        chatbot = ChatBot(client)
        with SimulateInput():
            chatbot.run()
82 |
83 |
# No standalone entry point; run this module via pytest.
if __name__ == "__main__":
    pass
86 |
--------------------------------------------------------------------------------
/tests/test_universal_api/test_json_mode.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | from unify import Unify
4 |
5 |
def test_openai_json_mode() -> None:
    """OpenAI json_object response_format must yield parseable JSON text."""
    client = Unify(endpoint="gpt-4o@openai")
    raw = client.generate(
        system_message="You are a helpful assistant designed to output JSON.",
        user_message="Who won the world series in 2020?",
        response_format={"type": "json_object"},
    )
    assert isinstance(raw, str)
    parsed = json.loads(raw)
    assert isinstance(parsed, dict)
16 |
17 |
def test_anthropic_json_mode() -> None:
    """Anthropic should emit JSON from the system prompt alone (no format arg)."""
    client = Unify(endpoint="claude-3-opus@anthropic")
    raw = client.generate(
        system_message="You are a helpful assistant designed to output JSON.",
        user_message="Who won the world series in 2020?",
    )
    assert isinstance(raw, str)
    parsed = json.loads(raw)
    assert isinstance(parsed, dict)
27 |
28 |
# No standalone entry point; run this module via pytest.
if __name__ == "__main__":
    pass
31 |
--------------------------------------------------------------------------------
/tests/test_universal_api/test_multi_llm.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from unify import AsyncMultiUnify, MultiUnify
3 |
4 |
class TestMultiUnify:
    """Tests for the fan-out ``MultiUnify`` client managing several endpoints."""

    def test_constructor(self) -> None:
        MultiUnify(
            endpoints=["llama-3-8b-chat@together-ai", "gpt-4o@openai"],
            cache=True,
        )

    def test_add_endpoints(self):
        # Adding endpoints extends both the endpoint tuple and the
        # per-endpoint sub-client mapping; duplicates are ignored unless
        # ignore_duplicates=False, in which case an exception is raised.
        endpoints = ("llama-3-8b-chat@together-ai", "gpt-4o@openai")
        client = MultiUnify(endpoints=endpoints, cache=True)
        assert client.endpoints == endpoints
        assert tuple(client.clients.keys()) == endpoints
        client.add_endpoints("claude-3.5-sonnet@anthropic")
        endpoints = (
            "llama-3-8b-chat@together-ai",
            "gpt-4o@openai",
            "claude-3.5-sonnet@anthropic",
        )
        assert client.endpoints == endpoints
        assert tuple(client.clients.keys()) == endpoints
        client.add_endpoints("claude-3.5-sonnet@anthropic")
        assert client.endpoints == endpoints
        assert tuple(client.clients.keys()) == endpoints
        with pytest.raises(Exception):
            client.add_endpoints("claude-3.5-sonnet@anthropic", ignore_duplicates=False)

    def test_remove_endpoints(self):
        # Removal mirrors addition: missing endpoints are ignored unless
        # ignore_missing=False.
        endpoints = (
            "llama-3-8b-chat@together-ai",
            "gpt-4o@openai",
            "claude-3.5-sonnet@anthropic",
        )
        client = MultiUnify(endpoints=endpoints, cache=True)
        assert client.endpoints == endpoints
        assert tuple(client.clients.keys()) == endpoints
        client.remove_endpoints("claude-3.5-sonnet@anthropic")
        endpoints = ("llama-3-8b-chat@together-ai", "gpt-4o@openai")
        assert client.endpoints == endpoints
        assert tuple(client.clients.keys()) == endpoints
        client.remove_endpoints("claude-3.5-sonnet@anthropic")
        assert client.endpoints == endpoints
        assert tuple(client.clients.keys()) == endpoints
        with pytest.raises(Exception):
            client.remove_endpoints("claude-3.5-sonnet@anthropic", ignore_missing=False)

    def test_generate(self):
        # generate() returns a mapping of endpoint -> response, preserving
        # the registration order of the endpoints.
        endpoints = (
            "gpt-4o@openai",
            "claude-3.5-sonnet@anthropic",
        )
        client = MultiUnify(endpoints=endpoints, cache=True)
        responses = client.generate("Hello, how it is going?")
        for endpoint, (response_endpoint, response) in zip(
            endpoints,
            responses.items(),
        ):
            assert endpoint == response_endpoint
            assert isinstance(response, str)
            assert len(response) > 0

    def test_multi_message_histories(self):
        # Each endpoint keeps an independent message history, verified by
        # seeding each with a different topic and checking the replies.
        endpoints = ("claude-3.5-sonnet@anthropic", "gpt-4o@openai")
        messages = {
            "claude-3.5-sonnet@anthropic": [
                {"role": "assistant", "content": "Let's talk about cats"},
            ],
            "gpt-4o@openai": [
                {"role": "assistant", "content": "Let's talk about dogs"},
            ],
        }
        animals = {"claude-3.5-sonnet@anthropic": "cat", "gpt-4o@openai": "dog"}
        client = MultiUnify(
            endpoints=endpoints,
            messages=messages,
            cache=True,
        )
        responses = client.generate("What animal did you want to talk about?")
        for endpoint, (response_endpoint, response) in zip(
            endpoints,
            responses.items(),
        ):
            assert endpoint == response_endpoint
            assert isinstance(response, str)
            assert len(response) > 0
            assert animals[endpoint] in response.lower()

    def test_setter_chaining(self):
        # add/remove return the client, so mutations can be chained.
        endpoints = (
            "llama-3-8b-chat@together-ai",
            "gpt-4o@openai",
            "claude-3.5-sonnet@anthropic",
        )
        client = MultiUnify(endpoints=endpoints, cache=True)
        client.add_endpoints(["gpt-4@openai", "gpt-4-turbo@openai"]).remove_endpoints(
            "claude-3.5-sonnet@anthropic",
        )
        assert set(client.endpoints) == {
            "llama-3-8b-chat@together-ai",
            "gpt-4o@openai",
            "gpt-4@openai",
            "gpt-4-turbo@openai",
        }
107 |
108 |
@pytest.mark.asyncio
class TestAsyncMultiUnify:
    """Async variant: ``AsyncMultiUnify.generate`` returns per-endpoint replies."""

    async def test_async_generate(self):
        endpoints = (
            "gpt-4o@openai",
            "claude-3.5-sonnet@anthropic",
        )
        client = AsyncMultiUnify(endpoints=endpoints, cache=True)
        responses = await client.generate("Hello, how it is going?")
        # Responses map endpoint -> text, in registration order.
        for endpoint, (response_endpoint, response) in zip(
            endpoints,
            responses.items(),
        ):
            assert endpoint == response_endpoint
            assert isinstance(response, str)
            assert len(response) > 0
125 |
126 |
# No standalone entry point; run this module via pytest.
if __name__ == "__main__":
    pass
129 |
--------------------------------------------------------------------------------
/tests/test_universal_api/test_stateful.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from unify import AsyncUnify, Unify
3 |
4 |
class TestStateful:
    """Verify ``stateful=True`` accumulates message history (and stateless
    clients clear it), across sync/async and stream/non-stream modes."""

    # ──────────────────────────────── SYNC ──────────────────────────────── #
    def test_stateful_sync_non_stream(self):
        client = Unify(endpoint="gpt-4o@openai", cache=True, stateful=True)

        client.generate(user_message="hi")  # non-stream
        assert len(client.messages) == 2  # user + assistant
        assert client.messages[-1]["role"] == "assistant"
        assert isinstance(client.messages[-1]["content"], str)

    def test_stateful_sync_stream(self):
        client = Unify(endpoint="gpt-4o@openai", cache=True, stateful=True)

        # Consuming the stream should leave the full reply in the history.
        chunks = list(client.generate(user_message="hello", stream=True))
        assert all([isinstance(c, str) for c in chunks])
        assert len(client.messages) == 2  # user + assistant
        assert isinstance(client.messages[-1]["content"], str)

    def test_stateless_sync_stream_clears_history(self):
        client = Unify(endpoint="gpt-4o@openai", stateful=False)

        list(client.generate(user_message="hello", stream=True))
        assert client.messages == []  # history wiped

    # ─────────────────────────────── ASYNC ──────────────────────────────── #
    @pytest.mark.asyncio
    async def test_stateful_async_non_stream(self):
        client = AsyncUnify(endpoint="gpt-4o@openai", cache=True, stateful=True)

        await client.generate(user_message="hi")  # non-stream
        assert len(client.messages) == 2  # user + assistant
        assert isinstance(client.messages[-1]["content"], str)

    @pytest.mark.asyncio
    async def test_stateful_async_stream(self):
        client = AsyncUnify(endpoint="gpt-4o@openai", cache=True, stateful=True)

        # Fully consuming the async stream should populate the history.
        stream = await client.generate(user_message="hello", stream=True)
        assert all([isinstance(c, str) async for c in stream])
        assert len(client.messages) == 2  # user + assistant
        assert isinstance(client.messages[-1]["content"], str)
46 |
--------------------------------------------------------------------------------
/tests/test_universal_api/test_usage.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import unify
4 |
# Directory containing this test file; not referenced anywhere in the visible
# tests — presumably kept for fixture paths. TODO(review): confirm still needed.
dir_path = os.path.dirname(os.path.realpath(__file__))
6 |
7 |
def test_with_logging() -> None:
    """with_logging() wraps a model function so its calls are logged
    under the given endpoint name, without changing the call signature."""

    # def instead of a named lambda (PEP 8 / E731); behaviour is identical.
    def model_fn(msg):
        return "This is my response."

    model_fn = unify.with_logging(model_fn, endpoint="my_model")
    model_fn(msg="Hello?")
12 |
13 |
# No standalone entry point; run this module via pytest.
if __name__ == "__main__":
    pass
16 |
--------------------------------------------------------------------------------
/tests/test_universal_api/test_user_input.py:
--------------------------------------------------------------------------------
1 | import builtins
2 | import traceback
3 |
4 | import unify
5 | from openai.types.chat import ChatCompletion
6 |
7 |
class SimulateInput:
    """Context manager that patches ``builtins.input`` to return a canned reply.

    Unlike the original version, the replacement accepts its prompt argument
    optionally, so both ``input()`` and ``input(prompt)`` work under the patch
    (a bare ``input()`` previously raised TypeError). The vestigial ``_count``
    reset in ``__exit__`` (copy-pasted from the chatbot helper, never read)
    has been dropped.
    """

    def __init__(self):
        self._message = "Hi, how can I help you?"
        self._true_input = None

    def _new_input(self, user_instructions=None):
        # Echo the prompt (if any) and the canned reply, like a real session.
        if user_instructions is not None:
            print(user_instructions)
        print(self._message)
        return self._message

    def __enter__(self):
        self._true_input = builtins.__dict__["input"]
        builtins.__dict__["input"] = self._new_input

    def __exit__(self, exc_type, exc_value, tb):
        # Always restore the genuine input(); print and propagate any error.
        builtins.__dict__["input"] = self._true_input
        if exc_type is not None:
            traceback.print_exception(exc_type, exc_value, tb)
            return False
        return True
30 |
31 |
def test_user_input_client():
    """The special 'user-input' endpoint sources responses from input()."""
    client = unify.Unify("user-input")
    with SimulateInput():
        text = client.generate("hello")
        assert isinstance(text, str)
        completion = client.generate("hello", return_full_completion=True)
        assert isinstance(completion, ChatCompletion)
39 |
40 |
# No standalone entry point; run this module via pytest.
if __name__ == "__main__":
    pass
43 |
--------------------------------------------------------------------------------
/tests/test_universal_api/test_utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/unifyai/unify/fbb548073fb0dc9f0cc0f67e28220011cbf9ec5b/tests/test_universal_api/test_utils/__init__.py
--------------------------------------------------------------------------------
/tests/test_universal_api/test_utils/test_credits.py:
--------------------------------------------------------------------------------
1 | import unify
2 |
3 |
def test_get_credits() -> None:
    """get_credits() reports the remaining account balance as a float."""
    balance = unify.get_credits()
    assert isinstance(balance, float)
7 |
8 |
# No standalone entry point; run this module via pytest.
if __name__ == "__main__":
    pass
11 |
--------------------------------------------------------------------------------
/tests/test_universal_api/test_utils/test_custom_api_keys.py:
--------------------------------------------------------------------------------
1 | import unify
2 |
3 |
# noinspection PyBroadException
class CustomAPIKeyHandler:
    """Context manager ensuring a clean custom-API-key state around a test.

    On both entry and exit it best-effort deletes the two key names the test
    uses, plus any other keys that may have leaked from earlier runs. Cleanup
    failures are deliberately swallowed, but only ``Exception`` subclasses —
    the previous bare ``except:`` also trapped KeyboardInterrupt/SystemExit.
    """

    def __init__(self, ky_name, ky_value, nw_name):
        self._key_name = ky_name
        self._key_value = ky_value
        self._new_name = nw_name

    def _handle(self):
        # Delete the known names directly, so cleanup works even if
        # list_custom_api_keys does not.
        for name in (self._key_name, self._new_name):
            try:
                unify.delete_custom_api_key(name)
            except Exception:
                # Best-effort: the key may simply not exist.
                pass
        # Sweep any other keys that have wrongly been created.
        try:
            custom_keys = unify.list_custom_api_keys()
            for dct in custom_keys:
                unify.delete_custom_api_key(dct["name"])
        except Exception:
            pass

    def __enter__(self):
        self._handle()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._handle()
31 |
32 |
33 | def _find_key(key_to_find, list_of_keys):
34 | for key in list_of_keys:
35 | if key["name"] == key_to_find:
36 | return True
37 | return False
38 |
39 |
# Fixture values shared by every custom-API-key test in this module.
key_name = "my_test_key2"
key_value = "1234"
new_name = "new_test_key"
# Module-level handler reused by each test to guarantee a clean key state
# before and after the test body runs.
handler = CustomAPIKeyHandler(
    key_name,
    key_value,
    new_name,
)
48 |
49 |
def test_create_custom_api_key():
    """Creating a key returns the success info payload."""
    with handler:
        result = unify.create_custom_api_key(key_name, key_value)
        assert result == {"info": "API key created successfully!"}
54 |
55 |
def test_list_custom_api_keys():
    """Listing starts empty and shows the created key with a masked value."""
    with handler:
        keys = unify.list_custom_api_keys()
        assert isinstance(keys, list)
        assert len(keys) == 0
        unify.create_custom_api_key(key_name, key_value)
        keys = unify.list_custom_api_keys()
        assert isinstance(keys, list)
        assert len(keys) == 1
        assert keys[0]["name"] == key_name
        assert keys[0]["value"] == "*" * 4 + key_value
67 |
68 |
def test_get_custom_api_key():
    """A created key can be fetched by name, with its value masked."""
    with handler:
        unify.create_custom_api_key(key_name, key_value)
        fetched = unify.get_custom_api_key(key_name)
        assert isinstance(fetched, dict)
        assert fetched["name"] == key_name
        assert fetched["value"] == "*" * 4 + key_value
76 |
77 |
def test_rename_custom_api_key():
    """Renaming replaces the single key's name in the listing."""
    with handler:
        unify.create_custom_api_key(key_name, key_value)

        def _sole_key_name():
            # exactly one key must exist at each checkpoint
            keys = unify.list_custom_api_keys()
            assert isinstance(keys, list)
            assert len(keys) == 1
            return keys[0]["name"]

        assert _sole_key_name() == key_name
        unify.rename_custom_api_key(key_name, new_name)
        assert _sole_key_name() == new_name
90 |
91 |
def test_delete_custom_api_key():
    """Deleting the only key leaves the listing empty."""
    with handler:
        unify.create_custom_api_key(key_name, key_value)
        keys = unify.list_custom_api_keys()
        assert isinstance(keys, list)
        assert len(keys) == 1
        assert keys[0]["name"] == key_name
        unify.delete_custom_api_key(key_name)
        keys = unify.list_custom_api_keys()
        assert isinstance(keys, list)
        assert len(keys) == 0
103 |
104 |
# No standalone behavior; this module is exercised via pytest.
if __name__ == "__main__":
    pass
107 |
--------------------------------------------------------------------------------
/tests/test_universal_api/test_utils/test_custom_endpoints.py:
--------------------------------------------------------------------------------
1 | import unify
2 |
3 |
# noinspection PyBroadException
class CustomEndpointHandler:
    """Context manager giving each test a clean custom-endpoint state.

    On entry and exit it deletes the test API key and every test endpoint,
    then re-creates the API key so endpoint creation inside the test works.
    """

    def __init__(self, ky_name, ky_value, endpoint_names):
        self._key_name = ky_name
        self._key_value = ky_value
        self._endpoint_names = endpoint_names

    def _handle(self):
        try:
            unify.delete_custom_api_key(self._key_name)
        # except Exception (not bare except) so Ctrl-C / SystemExit still
        # interrupt the test run; deletion is best-effort either way.
        except Exception:
            pass
        for endpoint_nm in self._endpoint_names:
            try:
                unify.delete_custom_endpoint(endpoint_nm)
            except Exception:
                pass
        unify.create_custom_api_key(self._key_name, self._key_value)

    def __enter__(self):
        self._handle()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._handle()
28 |
29 |
# Fixture values shared by every test in this module.
endpoint_name = "my_endpoint@custom"
new_endpoint_name = "renamed@custom"
endpoint_url = "test.com"
key_name = "test_key"
key_value = "4321"
# NOTE(review): this runs at import time, before the handler's own
# create_custom_api_key inside _handle(); looks redundant — confirm needed.
unify.create_custom_api_key(key_name, key_value)
custom_endpoint_handler = CustomEndpointHandler(
    key_name,
    key_value,
    [endpoint_name, new_endpoint_name],
)
41 |
42 |
def test_create_custom_endpoint():
    """Endpoint creation with a registered key should not raise."""
    with custom_endpoint_handler:
        kwargs = dict(name=endpoint_name, url=endpoint_url, key_name=key_name)
        unify.create_custom_endpoint(**kwargs)
50 |
51 |
def test_list_custom_endpoints():
    """The listing contains exactly the endpoint just created."""
    with custom_endpoint_handler:
        unify.create_custom_endpoint(
            name=endpoint_name,
            url=endpoint_url,
            key_name=key_name,
        )
        listed = unify.list_custom_endpoints()
        assert len(listed) == 1
        assert listed[0]["name"] == endpoint_name
62 |
63 |
def test_rename_custom_endpoint():
    """Renaming swaps the single endpoint's name in the listing."""
    with custom_endpoint_handler:
        unify.create_custom_endpoint(
            name=endpoint_name,
            url=endpoint_url,
            key_name=key_name,
        )

        def _sole_endpoint_name():
            # exactly one endpoint must exist at each checkpoint
            endpoints = unify.list_custom_endpoints()
            assert len(endpoints) == 1
            return endpoints[0]["name"]

        assert _sole_endpoint_name() == endpoint_name
        unify.rename_custom_endpoint(endpoint_name, new_endpoint_name)
        assert _sole_endpoint_name() == new_endpoint_name
81 |
82 |
def test_delete_custom_endpoints():
    """Deleting the only endpoint leaves the listing empty."""
    with custom_endpoint_handler:
        unify.create_custom_endpoint(
            name=endpoint_name,
            url=endpoint_url,
            key_name=key_name,
        )
        endpoints = unify.list_custom_endpoints()
        assert len(endpoints) == 1
        assert endpoints[0]["name"] == endpoint_name
        unify.delete_custom_endpoint(endpoint_name)
        assert len(unify.list_custom_endpoints()) == 0
96 |
97 |
# No standalone behavior; this module is exercised via pytest.
if __name__ == "__main__":
    pass
100 |
--------------------------------------------------------------------------------
/tests/test_universal_api/test_utils/test_endpoint_metrics.py:
--------------------------------------------------------------------------------
1 | import time
2 | from datetime import datetime, timezone
3 |
4 | import unify
5 | from unify import Metrics
6 |
7 |
# noinspection PyBroadException
class CustomEndpointHandler:
    """Context manager creating a disposable custom endpoint (plus API key).

    Entry wipes leftovers from crashed runs, then creates the key and the
    endpoint; exit wipes everything again so metric tests stay isolated.
    """

    def __init__(self, ep_name, ep_url, ky_name, ky_value):
        self.endpoint_name = ep_name
        self._endpoint_url = ep_url
        self._key_name = ky_name
        self._key_value = ky_value

    def _cleanup(self):
        # Each resource is removed independently, best-effort: any of them
        # may be missing if a previous run crashed part-way through.
        # except Exception (not bare except) keeps Ctrl-C / SystemExit alive.
        try:
            unify.delete_endpoint_metrics(self.endpoint_name)
        except Exception:
            pass
        try:
            unify.delete_custom_endpoint(self.endpoint_name)
        except Exception:
            pass
        try:
            unify.delete_custom_api_key(self._key_name)
        except Exception:
            pass

    def __enter__(self):
        self._cleanup()
        unify.create_custom_api_key(self._key_name, self._key_value)
        unify.create_custom_endpoint(
            name=self.endpoint_name,
            url=self._endpoint_url,
            key_name=self._key_name,
        )

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._cleanup()
41 |
42 |
# Disposable endpoint/key fixture shared by the metric tests below.
endpoint_name = "my_endpoint@custom"
endpoint_url = "test.com"
key_name = "test_key"
key_value = "4321"
handler = CustomEndpointHandler(
    endpoint_name,
    endpoint_url,
    key_name,
    key_value,
)
53 |
54 |
def test_get_public_endpoint_metrics():
    """A public endpoint returns one Metrics entry with typed fields."""
    metrics = unify.get_endpoint_metrics("gpt-4o@openai")
    assert isinstance(metrics, list)
    assert len(metrics) == 1
    entry = metrics[0]
    assert isinstance(entry, Metrics)
    # speed and cost fields are floats; the measurement time is a string
    for attr in ("ttft", "itl", "input_cost", "output_cost"):
        assert hasattr(entry, attr)
        assert isinstance(getattr(entry, attr), float)
    assert hasattr(entry, "measured_at")
    assert isinstance(entry.measured_at, str)
71 |
72 |
def test_client_metric_properties():
    """Metric properties are floats on Unify, endpoint->float dicts on MultiUnify."""
    metric_attrs = ("input_cost", "output_cost", "ttft", "itl")
    single = unify.Unify("gpt-4o@openai", cache=True)
    for attr in metric_attrs:
        assert isinstance(getattr(single, attr), float)
    endpoints = ["gpt-4o@openai", "claude-3-haiku@anthropic"]
    multi = unify.MultiUnify(endpoints, cache=True)
    for attr in metric_attrs:
        per_endpoint = getattr(multi, attr)
        assert isinstance(per_endpoint, dict)
        for endpoint in endpoints:
            assert isinstance(per_endpoint[endpoint], float)
95 |
96 |
def test_log_endpoint_metric():
    """Logging a single metric value should simply not raise."""
    with handler:
        unify.log_endpoint_metric(endpoint_name, metric_name="itl", value=1.23)
104 |
105 |
def test_log_and_get_endpoint_metric():
    """A logged metric is retrievable and round-trips its value."""
    with handler:
        start = datetime.now(timezone.utc)
        unify.log_endpoint_metric(endpoint_name, metric_name="itl", value=1.23)
        metrics = unify.get_endpoint_metrics(endpoint_name, start_time=start)
        assert isinstance(metrics, list)
        assert len(metrics) == 1
        entry = metrics[0]
        assert hasattr(entry, "itl")
        assert isinstance(entry.itl, float)
        assert entry.itl == 1.23
121 |
122 |
def test_log_and_get_endpoint_metric_with_time_windows():
    """Log itl/ttft at two timestamps and verify ``start_time`` filtering.

    Entries are returned oldest-first; omitting ``start_time`` returns a
    single 'latest' entry in which each metric holds its newest value.
    """
    with handler:
        t0 = datetime.now(timezone.utc)
        unify.log_endpoint_metric(
            endpoint_name,
            metric_name="itl",
            value=1.23,
        )
        unify.log_endpoint_metric(
            endpoint_name,
            metric_name="ttft",
            value=4.56,
        )
        # ensure a measurable gap between the two logging timestamps
        time.sleep(0.5)
        t1 = datetime.now(timezone.utc)
        unify.log_endpoint_metric(
            endpoint_name,
            metric_name="itl",
            value=7.89,
        )
        all_metrics = unify.get_endpoint_metrics(endpoint_name, start_time=t0)
        # two log events detected, due to double itl logging
        assert len(all_metrics) == 2
        # Data all accumulates at the latest entry (top of the stack)
        assert isinstance(all_metrics[0].itl, float)
        assert all_metrics[0].ttft is None
        assert isinstance(all_metrics[1].itl, float)
        assert isinstance(all_metrics[1].ttft, float)
        assert all_metrics[0].itl == 1.23
        assert all_metrics[1].ttft == 4.56
        assert all_metrics[1].itl == 7.89
        # The original two logs are not retrieved
        limited_metrics = unify.get_endpoint_metrics(
            endpoint_name,
            start_time=t1,
        )
        assert len(limited_metrics) == 1
        assert limited_metrics[0].ttft is None
        assert isinstance(limited_metrics[0].itl, float)
        assert limited_metrics[0].itl == 7.89
        # The ttft is now retrieved due to 'latest' mode
        latest_metrics = unify.get_endpoint_metrics(endpoint_name)
        assert len(latest_metrics) == 1
        assert isinstance(latest_metrics[0].ttft, float)
        assert isinstance(latest_metrics[0].itl, float)
        assert latest_metrics[0].ttft == 4.56
        assert latest_metrics[0].itl == 7.89
170 |
171 |
def test_delete_all_metrics_for_endpoint():
    """Deleting with no timestamps removes every metric for the endpoint."""
    with handler:
        unify.log_endpoint_metric(endpoint_name, metric_name="itl", value=1.23)
        # the metric is now retrievable
        logged = unify.get_endpoint_metrics(endpoint_name)
        assert isinstance(logged, list)
        assert len(logged) == 1
        # delete and confirm nothing remains
        unify.delete_endpoint_metrics(endpoint_name)
        remaining = unify.get_endpoint_metrics(endpoint_name)
        assert isinstance(remaining, list)
        assert len(remaining) == 0
190 |
191 |
def test_delete_some_metrics_for_endpoint():
    """Delete metric entries one measured_at timestamp at a time.

    ``delete_endpoint_metrics`` accepts the per-metric ``measured_at``
    timestamps, so individual itl/ttft entries can be removed selectively.
    """
    with handler:
        # log metrics at t0
        t0 = datetime.now(timezone.utc)
        unify.log_endpoint_metric(
            endpoint_name,
            metric_name="itl",
            value=1.23,
        )
        unify.log_endpoint_metric(
            endpoint_name,
            metric_name="ttft",
            value=4.56,
        )
        time.sleep(0.5)
        # log metric at t1
        unify.log_endpoint_metric(
            endpoint_name,
            metric_name="itl",
            value=7.89,
        )
        # verify both exist
        metrics = unify.get_endpoint_metrics(endpoint_name, start_time=t0)
        assert len(metrics) == 2
        # delete the first itl entry
        unify.delete_endpoint_metrics(
            endpoint_name,
            timestamps=metrics[0].measured_at["itl"],
        )
        # verify only the latest entry exists, with both itl and ttft
        metrics = unify.get_endpoint_metrics(endpoint_name, start_time=t0)
        assert len(metrics) == 1
        assert isinstance(metrics[0].itl, float)
        assert isinstance(metrics[0].ttft, float)
        # delete the ttft entry
        unify.delete_endpoint_metrics(
            endpoint_name,
            timestamps=metrics[0].measured_at["ttft"],
        )
        # verify only the latest entry exists, with only the itl
        metrics = unify.get_endpoint_metrics(endpoint_name, start_time=t0)
        assert len(metrics) == 1
        assert isinstance(metrics[0].itl, float)
        assert metrics[0].ttft is None
        # delete the final itl entry
        unify.delete_endpoint_metrics(
            endpoint_name,
            timestamps=metrics[0].measured_at["itl"],
        )
        # verify no metrics exist
        metrics = unify.get_endpoint_metrics(endpoint_name, start_time=t0)
        assert len(metrics) == 0
244 |
245 |
# No standalone behavior; this module is exercised via pytest.
if __name__ == "__main__":
    pass
248 |
--------------------------------------------------------------------------------
/tests/test_universal_api/test_utils/test_supported_endpoints.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import unify
3 |
4 |
class TestSupportedModels:
    """Checks for unify.list_models with and without a provider filter."""

    @staticmethod
    def _assert_unique_nonempty_list(models) -> None:
        # Shared sanity checks: correct type, non-empty, no duplicates.
        assert isinstance(models, list), "return type was not a list: {}".format(
            models,
        )
        assert models, "returned list was empty: {}".format(models)
        assert len(models) == len(set(models)), "duplication detected: {}".format(
            models,
        )

    def test_list_models(self) -> None:
        self._assert_unique_nonempty_list(unify.list_models())

    def test_list_models_w_provider(self):
        models = unify.list_models("openai")
        self._assert_unique_nonempty_list(models)
        openai_endpoints = [e for e in unify.list_endpoints() if "openai" in e]
        assert len(models) == len(openai_endpoints), (
            "number of models for the provider did not match the number of endpoints "
            "with the provider in the string"
        )
31 |
32 |
class TestSupportedProviders:
    """Checks for unify.list_providers with and without a model filter."""

    @staticmethod
    def _assert_unique_nonempty_list(providers) -> None:
        # Shared sanity checks: correct type, non-empty, no duplicates.
        assert isinstance(providers, list), "return type was not a list: {}".format(
            providers,
        )
        assert providers, "returned list was empty: {}".format(providers)
        assert len(providers) == len(set(providers)), "duplication detected: {}".format(
            providers,
        )

    def test_list_providers(self) -> None:
        self._assert_unique_nonempty_list(unify.list_providers())

    def test_list_providers_w_model(self):
        providers = unify.list_providers("llama-3.2-90b-chat")
        self._assert_unique_nonempty_list(providers)
        matching = [e for e in unify.list_endpoints() if "llama-3.2-90b-chat" in e]
        assert len(providers) == len(matching), (
            "number of providers for the model did not match the number of endpoints "
            "with the model in the string"
        )
59 |
60 |
class TestSupportedEndpoints:
    """Checks for unify.list_endpoints with model/provider filters."""

    @staticmethod
    def _assert_unique_nonempty_list(endpoints) -> None:
        # Shared sanity checks: correct type, non-empty, no duplicates.
        assert isinstance(endpoints, list), "return type was not a list: {}".format(
            endpoints,
        )
        assert endpoints, "returned list was empty: {}".format(endpoints)
        assert len(endpoints) == len(set(endpoints)), "duplication detected: {}".format(
            endpoints,
        )

    def test_list_endpoints(self) -> None:
        self._assert_unique_nonempty_list(unify.list_endpoints())

    def test_list_endpoints_w_model(self) -> None:
        endpoints = unify.list_endpoints(model="llama-3.2-90b-chat")
        self._assert_unique_nonempty_list(endpoints)
        assert len(endpoints) == len(unify.list_providers("llama-3.2-90b-chat")), (
            "number of endpoints for the model did not match the number of providers "
            "for the model"
        )

    def test_list_endpoints_w_provider(self) -> None:
        endpoints = unify.list_endpoints(provider="openai")
        self._assert_unique_nonempty_list(endpoints)
        assert len(endpoints) == len(unify.list_models("openai")), (
            "number of endpoints for the provider did not match the number of models "
            "for the provider"
        )

    def test_list_endpoints_w_model_w_provider(self) -> None:
        # Supplying both filters at once is rejected.
        with pytest.raises(Exception):
            unify.list_endpoints("gpt-4o", "openai")
103 |
104 |
# No standalone behavior; this module is exercised via pytest.
if __name__ == "__main__":
    pass
107 |
--------------------------------------------------------------------------------
/tests/test_universal_api/test_utils/test_usage.py:
--------------------------------------------------------------------------------
1 | import threading
2 | from datetime import datetime, timedelta, timezone
3 | from typing import Callable
4 |
5 | import pytest
6 | import unify
7 |
# Serializes the tests in this module: they all touch shared query history.
THREAD_LOCK = threading.Lock()

# Tag attached to the manually-logged query payload below.
tag = "test_tag"
# A complete hand-written query-log payload: the query sent, the response
# received, a slightly-in-the-future timestamp, and the tag above.
data = {
    "endpoint": "local_model_test@external",
    "query_body": {
        "messages": [
            {"role": "system", "content": "You are an useful assistant"},
            {"role": "user", "content": "Explain who Newton was."},
        ],
        "model": "llama-3-8b-chat@aws-bedrock",
        "max_tokens": 100,
        "temperature": 0.5,
    },
    "response_body": {
        "model": "meta.llama3-8b-instruct-v1:0",
        "created": 1725396241,
        "id": "chatcmpl-92d3b36e-7b64-4ae8-8102-9b7e3f5dd30f",
        "object": "chat.completion",
        "usage": {
            "completion_tokens": 100,
            "prompt_tokens": 44,
            "total_tokens": 144,
        },
        "choices": [
            {
                "finish_reason": "stop",
                "index": 0,
                "message": {
                    "content": "Sir Isaac Newton was an English mathematician, "
                    "physicist, and astronomer who lived from 1643 "
                    "to 1727.\\n\\nHe is widely recognized as one "
                    "of the most influential scientists in history, "
                    "and his work laid the foundation for the "
                    "Scientific Revolution of the 17th century."
                    "\\n\\nNewton's most famous achievement is his "
                    "theory of universal gravitation, which he "
                    "presented in his groundbreaking book "
                    '"Philosophi\\u00e6 Naturalis Principia '
                    'Mathematica" in 1687.',
                    "role": "assistant",
                },
            },
        ],
    },
    "timestamp": str(datetime.now(timezone.utc) + timedelta(seconds=0.01)),
    "tags": [tag],
}
56 |
57 |
def _thread_locked(fn: Callable) -> Callable:
    """Decorator serializing *fn* behind the module-level THREAD_LOCK.

    Bug fixed: the previous implementation caught every exception from *fn*
    and returned None, which silently swallowed AssertionError — decorated
    tests could never fail. Using ``with THREAD_LOCK`` both guarantees the
    lock is released and lets exceptions propagate to pytest.
    """

    def wrapped(*args, **kwargs):
        with THREAD_LOCK:
            return fn(*args, **kwargs)

    return wrapped
70 |
71 |
@_thread_locked
def test_log_query_manually():
    """A manually-constructed payload logs successfully."""
    result = unify.log_query(**data)
    assert isinstance(result, dict)
    assert "info" in result
    assert result["info"] == "Query logged successfully"
78 |
79 |
@_thread_locked
def test_log_query_via_chat_completion():
    """generate() with logging flags enabled still returns a string reply."""
    client = unify.Unify("gpt-4o@openai")
    reply = client.generate("hello", log_query_body=True, log_response_body=True)
    assert isinstance(reply, str)
89 |
90 |
@_thread_locked
def test_get_queries_from_manual():
    """A manually-logged query appears only in windows that include it."""
    start_time = datetime.now(timezone.utc)
    unify.log_query(**data)

    def _history(since):
        return unify.get_queries(
            endpoints="local_model_test@external",
            start_time=since,
        )

    # one entry when filtering from just before the log call ...
    assert len(_history(start_time)) == 1
    # ... and none when the window starts in the future
    assert len(_history(datetime.now(timezone.utc) + timedelta(seconds=1))) == 0
105 |
106 |
@_thread_locked
def test_get_queries_from_chat_completion():
    """A generate() call with logging enabled shows up in query history."""
    start_time = datetime.now(timezone.utc)
    unify.Unify("gpt-4o@openai").generate(
        "hello",
        log_query_body=True,
        log_response_body=True,
    )

    def _history(since):
        return unify.get_queries(endpoints="gpt-4o@openai", start_time=since)

    # one entry when filtering from just before the call ...
    assert len(_history(start_time)) == 1
    # ... and none when the window starts in the future
    assert len(_history(datetime.now(timezone.utc) + timedelta(seconds=1))) == 0
125 |
126 |
@_thread_locked
def test_get_query_failures():
    """Exercise the `failures` filter of get_queries.

    One successful and one failed generate() are logged; `failures=True`
    returns both, `failures="only"` just the failure, `failures=False`
    just the success. A future start_time returns nothing in every mode.
    """
    start_time = datetime.now(timezone.utc)
    client = unify.Unify("gpt-4o@openai")
    client.generate(
        "hello",
        log_query_body=True,
        log_response_body=True,
    )
    # drop_params=False forces the invalid kwarg through, so this call fails
    # and is logged as a failure.
    with pytest.raises(Exception):
        client.generate(
            "hello",
            log_query_body=True,
            log_response_body=True,
            drop_params=False,
            invalid_arg="invalid_value",
        )

    # inside logged timeframe
    history_w_both = unify.get_queries(
        endpoints="gpt-4o@openai",
        start_time=start_time,
        failures=True,
    )
    assert len(history_w_both) == 2
    history_only_failures = unify.get_queries(
        endpoints="gpt-4o@openai",
        start_time=start_time,
        failures="only",
    )
    assert len(history_only_failures) == 1
    history_only_success = unify.get_queries(
        endpoints="gpt-4o@openai",
        start_time=start_time,
        failures=False,
    )
    assert len(history_only_success) == 1

    # Outside logged timeframe
    history_w_both = unify.get_queries(
        endpoints="gpt-4o@openai",
        start_time=datetime.now(timezone.utc) + timedelta(seconds=1),
        failures=True,
    )
    assert len(history_w_both) == 0
    history_only_failures = unify.get_queries(
        endpoints="gpt-4o@openai",
        start_time=datetime.now(timezone.utc) + timedelta(seconds=1),
        failures="only",
    )
    assert len(history_only_failures) == 0
    history_only_success = unify.get_queries(
        endpoints="gpt-4o@openai",
        start_time=datetime.now(timezone.utc) + timedelta(seconds=1),
        failures=False,
    )
    assert len(history_only_success) == 0
184 |
185 |
@_thread_locked
def test_get_query_tags():
    """The tag carried by the logged payload appears in the tag listing."""
    unify.log_query(**data)
    returned_tags = unify.get_query_tags()
    assert isinstance(returned_tags, list)
    assert tag in returned_tags
192 |
193 |
# No standalone behavior; this module is exercised via pytest.
if __name__ == "__main__":
    pass
196 |
--------------------------------------------------------------------------------
/tests/test_utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/unifyai/unify/fbb548073fb0dc9f0cc0f67e28220011cbf9ec5b/tests/test_utils/__init__.py
--------------------------------------------------------------------------------
/tests/test_utils/helpers.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import unify
4 | from unify.utils._caching import _get_caching_fpath
5 |
6 |
class _CacheHandler:
    """Context manager redirecting unify's cache to a throwaway test file.

    Entering switches caching to *fname* and removes any stale file;
    exiting removes the test file again and restores the previous setting.
    """

    def __init__(self, fname=".test_cache.json"):
        # NOTE(review): _get_caching_fpath() returns a full path, but __exit__
        # passes it back into set_caching_fname (which, by name, expects a
        # file *name*) — confirm this round-trips correctly.
        self._old_cache_fpath = _get_caching_fpath()
        self._fname = fname
        self.test_path = ""

    def __enter__(self):
        unify.set_caching_fname(self._fname)
        # resolve the full path of the test cache file for cleanup
        self.test_path = _get_caching_fpath()
        if os.path.exists(self.test_path):
            os.remove(self.test_path)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if os.path.exists(self.test_path):
            os.remove(self.test_path)
        unify.set_caching_fname(self._old_cache_fpath)
24 |
--------------------------------------------------------------------------------
/tests/test_utils/test_map.py:
--------------------------------------------------------------------------------
1 | import random
2 | import time
3 |
4 | import pytest
5 | import unify
6 | from tests.test_utils.helpers import _CacheHandler
7 | from unify.utils._caching import _get_cache, _write_to_cache
8 |
9 | from ..test_logging.helpers import _handle_project
10 |
11 | # Helpers #
12 | # --------#
13 |
14 |
class ProjectHandling:
    """Context manager deleting any stale 'test_project' on entry and exit."""

    @staticmethod
    def _wipe():
        # delete the project only if it currently exists
        if "test_project" in unify.list_projects():
            unify.delete_project("test_project")

    def __enter__(self):
        self._wipe()

    def __exit__(self, exc_type, exc_value, tb):
        self._wipe()
23 |
24 |
# Cached clients so repeated test runs avoid re-querying the LLM endpoints.
client = unify.Unify("gpt-4o@openai", cache=True)
async_client = unify.AsyncUnify("gpt-4o@openai", cache=True)
# Simple arithmetic questions whose ground truth is computed with eval().
qs = ["3 - 1", "4 + 7", "6 + 2", "9 - 3", "7 + 9"]
28 |
29 |
def evaluate_response(question: str, response: str) -> float:
    """Score *response* (1.0/0.0) against eval-ing the arithmetic *question*.

    The candidate answer is the digits of the last whitespace-separated word
    of the response; a word with no digits scores 0.0.
    """
    # `question` comes from the trusted, hard-coded `qs` list, so eval is safe.
    correct_answer = eval(question)
    last_word = response.split(" ")[-1]
    digits = "".join(filter(str.isdigit, last_word))
    try:
        return float(int(digits) == correct_answer)
    except ValueError:
        # no digits at all in the final word
        return 0.0
39 |
40 |
def evaluate(q: str):
    """Generate an answer for *q* and return its correctness score.

    Fix: the score was previously computed but discarded; returning it makes
    this consistent with ``async_evaluate`` and is backward-compatible
    (existing callers ignore the return value).
    """
    response = client.generate(q)
    return evaluate_response(q, response)
44 |
45 |
def evaluate_w_log(q: str):
    """Generate, score, and persist a log entry for question *q*."""
    answer = client.generate(q)
    return unify.log(
        question=q,
        response=answer,
        score=evaluate_response(q, answer),
        skip_duplicates=False,
    )
55 |
56 |
async def async_evaluate(q: str):
    """Async variant of `evaluate`: generate an answer for *q* and score it.

    Fix: dropped the stray @pytest.mark.asyncio — this is a helper driven
    through unify.map (its name doesn't match pytest's test pattern), so the
    marker served no purpose on it.
    """
    response = await async_client.generate(q)
    return evaluate_response(q, response)
61 |
62 |
63 | # Tests #
64 | # ------#
65 |
66 |
def test_threaded_map() -> None:
    """Mapped and sequential logging both work inside a fresh project."""
    with ProjectHandling():
        with unify.Project("test_project"):
            unify.map(evaluate_w_log, qs)
            for question in qs:
                evaluate_w_log(question)
73 |
74 |
def test_map_mode() -> None:
    """Exercise unify.map under each supported map mode.

    Fix: the original left the global map mode set to "loop" on exit,
    leaking state into later tests; the previous mode is now restored
    in a finally block.
    """
    previous_mode = unify.get_map_mode()
    try:
        for mode in ("threading", "asyncio", "loop"):
            unify.set_map_mode(mode)
            assert unify.get_map_mode() == mode
            unify.map(evaluate_w_log, qs)
    finally:
        unify.set_map_mode(previous_mode)
85 |
86 |
@_handle_project
def test_map_w_cache() -> None:
    """Run traced, cached generation under unify.map in two cache modes.

    The first pass (cache=True) populates the cache; the second pass
    ("read-only") must be served entirely from it.
    """
    with _CacheHandler():

        @unify.traced(name="gen{x}")
        def gen(x, cache):
            # Read from the cache for True/"read"/"read-only" modes;
            # raise_on_empty makes a missing entry raise in "read-only" mode.
            ret = None
            if cache in [True, "read", "read-only"]:
                ret = _get_cache(
                    fn_name="gen",
                    kw={"x": x},
                    raise_on_empty=(cache == "read-only"),
                )
            if ret is None:
                ret = random.randint(1, 5) + x
            # Persist freshly-computed values for True/"write" modes.
            if cache in [True, "write"]:
                _write_to_cache(
                    fn_name="gen",
                    kw={"x": x},
                    response=ret,
                )
            return ret

        @unify.traced
        def fn(cache):
            # Chain three generations, each feeding the next, with jitter
            # between them so traces overlap across mapped workers.
            x = gen(0, cache)
            time.sleep(random.uniform(0, 0.1))
            y = gen(x, cache)
            time.sleep(random.uniform(0, 0.1))
            z = gen(y, cache)

        @unify.traced
        def cache_is_true():
            unify.map(fn, [True] * 10)

        @unify.traced
        def cache_is_read_only():
            unify.map(fn, ["read-only"] * 10)

        cache_is_true()
        cache_is_read_only()
128 |
129 |
def test_threaded_map_from_args() -> None:
    """Same as test_threaded_map, but passing qs via from_args=True."""
    with ProjectHandling():
        with unify.Project("test_project"):
            unify.map(evaluate_w_log, qs, from_args=True)
            for question in qs:
                evaluate_w_log(question)
136 |
137 |
def test_threaded_map_with_context() -> None:
    """unify.map over ((args...), {kwargs...}) tuples with context logging."""
    with ProjectHandling():
        with unify.Project("test_project"):

            def contextual_func(a, b, c=3):
                with unify.Entries(a=a, b=b, c=c):
                    unify.log(test="some random value")
                return a + b + c

            # both a and b positional, c via kwargs
            calls = [((1, 3), {"c": 2}), ((2, 4), {"c": 4})]
            results = unify.map(contextual_func, calls)
            assert results == [1 + 3 + 2, 2 + 4 + 4]
            # only a positional, b and c via kwargs
            calls = [((1,), {"b": 2, "c": 2}), ((3,), {"b": 4, "c": 4})]
            results = unify.map(contextual_func, calls)
            assert results == [1 + 2 + 2, 3 + 4 + 4]
163 |
164 |
def test_threaded_map_with_context_from_args() -> None:
    """from_args=True zips positional/keyword sequences element-wise."""
    with ProjectHandling():
        with unify.Project("test_project"):

            def contextual_func(a, b, c=3):
                with unify.Entries(a=a, b=b, c=c):
                    unify.log(test="some random value")
                return a + b + c

            # a and b positional sequences, c as a keyword sequence
            results = unify.map(
                contextual_func,
                (1, 2),
                (3, 4),
                c=(2, 4),
                from_args=True,
            )
            assert results == [1 + 3 + 2, 2 + 4 + 4]
            # a positional, b and c as keyword sequences
            results = unify.map(
                contextual_func,
                (1, 3),
                b=(2, 4),
                c=(2, 4),
                from_args=True,
            )
            assert results == [1 + 2 + 2, 3 + 4 + 4]
190 |
191 |
def test_asyncio_map() -> None:
    """asyncio-mode mapping, then the plain sequential equivalent."""
    unify.map(async_evaluate, qs, mode="asyncio")
    for question in qs:
        evaluate(question)
196 |
197 |
def test_asyncio_map_from_args() -> None:
    """asyncio-mode mapping via from_args, then the sequential equivalent."""
    unify.map(async_evaluate, qs, mode="asyncio", from_args=True)
    for question in qs:
        evaluate(question)
202 |
203 |
def test_loop_map() -> None:
    # "loop" mode mapping; presumably sequential execution — see unify.map.
    unify.map(evaluate_w_log, qs, mode="loop")
206 |
207 |
def test_loop_map_from_args() -> None:
    # "loop" mode mapping with arguments supplied via from_args=True.
    unify.map(evaluate_w_log, qs, mode="loop", from_args=True)
210 |
211 |
def test_asyncio_map_with_context() -> None:
    """unify.map in asyncio mode with context-managed logging in the fn.

    Fix: dropped @pytest.mark.asyncio — this is a plain (non-async) test
    function, and pytest-asyncio's marker is only meant for coroutine tests.
    """
    with ProjectHandling():
        with unify.Project("test_project"):

            def contextual_func(a, b, c=3):
                with unify.Entries(a=a, b=b, c=c):
                    time.sleep(0.1)
                    unify.log(test="some random value")
                return a + b + c

            results = unify.map(
                contextual_func,
                [
                    ((1, 3), {"c": 2}),
                    ((2, 4), {"c": 4}),
                ],
                mode="asyncio",
            )
            assert results == [1 + 3 + 2, 2 + 4 + 4]
            results = unify.map(
                contextual_func,
                [
                    ((1,), {"b": 2, "c": 2}),
                    ((3,), {"b": 4, "c": 4}),
                ],
                mode="asyncio",
            )
            assert results == [1 + 2 + 2, 3 + 4 + 4]
241 |
242 |
def test_asyncio_map_with_context_from_args() -> None:
    """asyncio-mode mapping via from_args with context-managed logging.

    Fix: dropped @pytest.mark.asyncio — this is a plain (non-async) test
    function, and pytest-asyncio's marker is only meant for coroutine tests.
    Scalar keyword values (c=2) are broadcast; sequences are zipped.
    """
    with ProjectHandling():
        with unify.Project("test_project"):

            def contextual_func(a, b, c=3):
                with unify.Entries(a=a, b=b, c=c):
                    time.sleep(0.1)
                    unify.log(test="some random value")
                return a + b + c

            results = unify.map(
                contextual_func,
                (1, 2),
                (3, 4),
                c=2,
                mode="asyncio",
                from_args=True,
            )
            assert results == [1 + 3 + 2, 2 + 4 + 2]
            results = unify.map(
                contextual_func,
                (1, 2),
                (3, 4),
                c=[2, 4],
                mode="asyncio",
                from_args=True,
            )
            assert results == [1 + 3 + 2, 2 + 4 + 4]
            results = unify.map(
                contextual_func,
                (1, 3),
                b=[2, 4],
                c=[2, 4],
                mode="asyncio",
                from_args=True,
            )
            assert results == [1 + 2 + 2, 3 + 4 + 4]
281 |
282 |
# No standalone behavior; this module is exercised via pytest.
if __name__ == "__main__":
    pass
285 |
--------------------------------------------------------------------------------
/unify/__init__.py:
--------------------------------------------------------------------------------
1 | """Unify python module."""
2 |
3 | import os
4 | from typing import Callable, Optional
5 |
6 |
# Base URL of the Unify API; overridable via the UNIFY_BASE_URL env var.
# (os.environ.get replaces the previous `in os.environ.keys()` if/else.)
BASE_URL = os.environ.get("UNIFY_BASE_URL", "https://api.unify.ai/v0")


CLIENT_LOGGING = False
LOCAL_MODELS = dict()  # maps "<name>@local" names to handler callables
SEED = None  # module-wide seed, managed by set_seed / get_seed
UNIFY_DIR = os.path.dirname(__file__)  # directory of the installed package
17 |
18 |
def set_seed(seed: int) -> None:
    """Store *seed* in the module-level SEED global."""
    global SEED
    SEED = seed


def get_seed() -> Optional[int]:
    """Return the module-level SEED (None when never set)."""
    return SEED
26 |
27 |
def register_local_model(model_name: str, fn: Callable):
    """Register *fn* in LOCAL_MODELS, appending '@local' when absent."""
    local_tag = "@local"
    if local_tag not in model_name:
        model_name = f"{model_name}{local_tag}"
    LOCAL_MODELS[model_name] = fn
32 |
33 |
34 | from .universal_api.utils import (
35 | credits,
36 | custom_api_keys,
37 | custom_endpoints,
38 | endpoint_metrics,
39 | queries,
40 | supported_endpoints,
41 | )
42 | from .universal_api.utils.credits import *
43 | from .universal_api.utils.custom_api_keys import *
44 | from .universal_api.utils.custom_endpoints import *
45 | from .universal_api.utils.endpoint_metrics import *
46 | from .universal_api.utils.queries import *
47 | from .universal_api.utils.supported_endpoints import *
48 |
49 | from .logging.utils import artifacts
50 | from .logging.utils import compositions
51 | from .logging.utils import contexts
52 | from .logging.utils import datasets
53 | from .logging.utils import logs
54 | from .logging.utils import projects
55 |
56 | from .logging.utils.artifacts import *
57 | from .logging.utils.compositions import *
58 | from .logging.utils.contexts import *
59 | from .logging.utils.datasets import *
60 | from .logging.utils.logs import *
61 | from .logging.utils.projects import *
62 | from .logging.utils.tracing import install_tracing_hook, disable_tracing_hook
63 |
64 | from .utils import helpers, map, get_map_mode, set_map_mode, _caching
65 | from .utils._caching import (
66 | set_caching,
67 | set_caching_fname,
68 | cache_file_union,
69 | cache_file_intersection,
70 | subtract_cache_files,
71 | cached,
72 | )
73 |
74 | from .universal_api import chatbot, clients, usage
75 | from .universal_api.clients import multi_llm
76 | from .universal_api.chatbot import *
77 | from unify.universal_api.clients.uni_llm import *
78 | from unify.universal_api.clients.multi_llm import *
79 |
80 | from .universal_api import casting, types
81 | from .logging import dataset, logs
82 |
83 | from .universal_api.casting import *
84 | from .universal_api.usage import *
85 | from .universal_api.types import *
86 |
87 | from .logging.dataset import *
88 | from .logging.logs import *
89 |
90 |
91 | # Project #
92 | # --------#
93 |
94 | PROJECT: Optional[str] = None
95 |
96 |
# noinspection PyShadowingNames
def activate(project: str, overwrite: bool = False, api_key: Optional[str] = None) -> None:
    """
    Set *project* as the globally active project, creating it upstream if it
    does not already exist (or re-creating it when `overwrite` is True).

    Args:
        project: Name of the project to activate.

        overwrite: Whether to re-create the project if it already exists.

        api_key: If specified, unify API key to be used. Defaults to the value in the
        `UNIFY_KEY` environment variable.
    """
    if project not in list_projects(api_key=api_key):
        create_project(project, api_key=api_key)
    elif overwrite:
        # Project exists: only re-create it when explicitly asked to.
        create_project(project, api_key=api_key, overwrite=True)
    global PROJECT
    PROJECT = project
105 |
106 |
def deactivate() -> None:
    """Clear the globally active project (reverts to the env-var fallback)."""
    global PROJECT
    PROJECT = None
110 |
111 |
def active_project() -> str:
    """Return the active project name, falling back to the UNIFY_PROJECT
    environment variable (which may be unset, yielding None)."""
    global PROJECT
    return os.environ.get("UNIFY_PROJECT") if PROJECT is None else PROJECT
117 |
118 |
class Project:
    """Handle for a Unify project, usable as a context manager.

    Entering the context activates the project globally (creating it
    upstream when needed); exiting deactivates it again.
    """

    # noinspection PyShadowingNames
    def __init__(
        self,
        project: str,
        overwrite: bool = False,
        api_key: Optional[str] = None,
    ) -> None:
        """
        Args:
            project: Name of the project to manage.

            overwrite: Whether to re-create (wipe) the project on entry/create.

            api_key: If specified, unify API key to be used. Defaults to the
            value in the `UNIFY_KEY` environment variable.
        """
        self._project = project
        self._overwrite = overwrite
        # noinspection PyProtectedMember
        self._api_key = helpers._validate_api_key(api_key)
        self._entered = False

    def create(self) -> None:
        """Create the project upstream, honouring the overwrite flag."""
        create_project(self._project, overwrite=self._overwrite, api_key=self._api_key)

    def delete(self):
        """Delete the project upstream."""
        delete_project(self._project, api_key=self._api_key)

    def rename(self, new_name: str):
        """Rename the project upstream; re-activate under the new name if
        this context is currently entered."""
        rename_project(self._project, new_name, api_key=self._api_key)
        self._project = new_name
        if self._entered:
            # Fix: pass the stored key rather than relying on the env var.
            activate(self._project, api_key=self._api_key)

    def __enter__(self):
        # Fix: pass the stored key rather than relying on the env var.
        activate(self._project, api_key=self._api_key)
        if self._project not in list_projects(api_key=self._api_key) or self._overwrite:
            self.create()
        self._entered = True
        # Fix: return self so `with unify.Project(...) as project:` binds
        # the instance instead of None.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        deactivate()
        self._entered = False
155 |
--------------------------------------------------------------------------------
/unify/logging/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/unifyai/unify/fbb548073fb0dc9f0cc0f67e28220011cbf9ec5b/unify/logging/__init__.py
--------------------------------------------------------------------------------
/unify/logging/utils/__init__.py:
--------------------------------------------------------------------------------
1 | import requests
2 |
3 |
class RequestError(Exception):
    """Raised when an HTTP response comes back with a non-OK status.

    The offending response is kept on ``self.response`` for callers.
    """

    def __init__(self, response: requests.Response):
        request = response.request
        detail = (
            f"{request.method} {request.url} failed with status code "
            f"{response.status_code}. "
            f"Request body: {request.body}, Response: {response.text}"
        )
        super().__init__(detail)
        self.response = response
13 |
14 |
def _check_response(response: requests.Response):
    """Raise :class:`RequestError` if *response* has a failure status."""
    if response.ok:
        return
    raise RequestError(response)
18 |
--------------------------------------------------------------------------------
/unify/logging/utils/artifacts.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, Optional
2 |
3 | from unify import BASE_URL
4 | from unify.utils import _requests
5 |
6 | from ...utils.helpers import (
7 | _check_response,
8 | _get_and_maybe_create_project,
9 | _validate_api_key,
10 | )
11 |
12 | # Artifacts #
13 | # ----------#
14 |
15 |
def add_project_artifacts(
    *,
    project: Optional[str] = None,
    api_key: Optional[str] = None,
    **kwargs,
) -> Dict[str, str]:
    """
    Creates one or more artifacts associated to a project. Artifacts are
    project-level metadata that don't depend on other variables.

    Args:
        project: Name of the project the artifacts belong to.

        api_key: If specified, unify API key to be used. Defaults to the value in the
        `UNIFY_KEY` environment variable.

        kwargs: Dictionary containing one or more key:value pairs that will be stored
        as artifacts.

    Returns:
        A message indicating whether the artifacts were successfully added.
    """
    api_key = _validate_api_key(api_key)
    project = _get_and_maybe_create_project(project, api_key=api_key)
    response = _requests.post(
        BASE_URL + f"/project/{project}/artifacts",
        headers={
            "accept": "application/json",
            "Authorization": f"Bearer {api_key}",
        },
        json={"artifacts": kwargs},
    )
    _check_response(response)
    return response.json()
52 |
53 |
def delete_project_artifact(
    key: str,
    *,
    project: Optional[str] = None,
    api_key: Optional[str] = None,
) -> str:
    """
    Deletes an artifact from a project.

    Args:
        key: Key of the artifact to delete.

        project: Name of the project to delete an artifact from.

        api_key: If specified, unify API key to be used. Defaults to the value in the
        `UNIFY_KEY` environment variable.

    Returns:
        Whether the artifact was successfully deleted.
    """
    api_key = _validate_api_key(api_key)
    project = _get_and_maybe_create_project(project, api_key=api_key)
    response = _requests.delete(
        BASE_URL + f"/project/{project}/artifacts/{key}",
        headers={
            "accept": "application/json",
            "Authorization": f"Bearer {api_key}",
        },
    )
    _check_response(response)
    return response.json()
86 |
87 |
def get_project_artifacts(
    *,
    project: Optional[str] = None,
    api_key: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Returns the key-value pairs for all artifacts in a project.

    Args:
        project: Name of the project to fetch the artifacts from.

        api_key: If specified, unify API key to be used. Defaults to the value in the
        `UNIFY_KEY` environment variable.

    Returns:
        A dictionary of all artifacts associated with the project, with keys for
        artifact names and values for the artifacts themselves.
    """
    api_key = _validate_api_key(api_key)
    headers = {
        "accept": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    project = _get_and_maybe_create_project(project, api_key=api_key)
    response = _requests.get(
        BASE_URL + f"/project/{project}/artifacts",
        headers=headers,
    )
    _check_response(response)
    return response.json()
118 |
--------------------------------------------------------------------------------
/unify/logging/utils/async_logger.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import logging
3 | import os
4 | import threading
5 | from concurrent.futures import TimeoutError
6 | from typing import List
7 |
8 | import aiohttp
9 | from unify import BASE_URL
10 |
11 | # Configure logging based on environment variable
12 | ASYNC_LOGGER_DEBUG = os.getenv("UNIFY_ASYNC_LOGGER_DEBUG", "false").lower() in (
13 | "true",
14 | "1",
15 | )
16 | logger = logging.getLogger("async_logger")
17 | logger.setLevel(logging.DEBUG if ASYNC_LOGGER_DEBUG else logging.WARNING)
18 |
19 |
class AsyncLoggerManager:
    """Asynchronous, thread-backed dispatcher for Unify log traffic.

    A private asyncio event loop runs on a daemon thread; ``num_consumers``
    consumer coroutines drain an internal queue and send each event to the
    ``logs`` endpoint over a shared aiohttp session. Producer threads
    enqueue work through :meth:`log_create` / :meth:`log_update`; each
    created log is tied to an :class:`asyncio.Future` carrying its log id.
    """

    def __init__(
        self,
        *,
        base_url: str = BASE_URL,
        api_key: str = os.getenv("UNIFY_KEY"),
        num_consumers: int = 256,
        max_queue_size: int = 10000,
    ):
        """
        Args:
            base_url: Root URL of the logging API.

            api_key: Bearer token. Defaults to the ``UNIFY_KEY`` env var.

            num_consumers: Number of concurrent consumer coroutines.

            max_queue_size: Max pending events before producers block.
        """

        self.loop = asyncio.new_event_loop()
        # The queue must be created on the loop's own thread (see _run_loop).
        self.queue = None
        self.consumers: List[asyncio.Task] = []
        self.num_consumers = num_consumers
        self.start_flag = threading.Event()
        self.shutting_down = False
        self.max_queue_size = max_queue_size

        headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
            "accept": "application/json",
        }
        url = base_url + "/"
        # Cap connections to half the consumer count to bound socket usage.
        connector = aiohttp.TCPConnector(limit=num_consumers // 2, loop=self.loop)
        self.session = aiohttp.ClientSession(
            url,
            headers=headers,
            loop=self.loop,
            connector=connector,
        )

        self.thread = threading.Thread(target=self._run_loop, daemon=True)
        self.thread.start()
        # Wait until the loop thread has initialised the queue before
        # letting producers enqueue events.
        self.start_flag.wait()
        self.callbacks = []

    def register_callback(self, fn):
        """Register *fn* to be invoked after every processed event."""
        self.callbacks.append(fn)

    def clear_callbacks(self):
        """Remove all registered callbacks."""
        self.callbacks = []

    def _notify_callbacks(self):
        # Fired once per consumed event (runs on the loop thread).
        for fn in self.callbacks:
            fn()

    async def _join(self):
        await self.queue.join()

    def join(self):
        """Block the calling thread until the queue is fully drained."""
        try:
            future = asyncio.run_coroutine_threadsafe(self._join(), self.loop)
            while True:
                try:
                    # Poll with a short timeout so the wait stays responsive.
                    future.result(timeout=0.5)
                    break
                except (asyncio.TimeoutError, TimeoutError):
                    continue
        except Exception as e:
            logger.error(f"Error in join: {e}")
            raise e

    async def _main_loop(self):
        # Signal __init__ that the loop (and queue) are ready for producers.
        self.start_flag.set()
        await asyncio.gather(*self.consumers, return_exceptions=True)

    def _run_loop(self):
        """Thread target: build the queue + consumers, then run the loop."""
        asyncio.set_event_loop(self.loop)
        self.queue = asyncio.Queue(maxsize=self.max_queue_size)

        for _ in range(self.num_consumers):
            self.consumers.append(self._log_consumer())

        try:
            self.loop.run_until_complete(self._main_loop())
        except Exception as e:
            logger.error(f"Event loop error: {e}")
            raise e
        finally:
            self.loop.close()

    async def _consume_create(self, body, future, idx):
        """POST a new log and resolve *future* with the created log id."""
        async with self.session.post("logs", json=body) as res:
            if res.status != 200:
                txt = await res.text()
                logger.error(f"Error in consume_create {idx}: {txt}")
                # NOTE(review): on a non-200 the future is left unresolved,
                # so a queued update awaiting it would wait forever —
                # confirm this is intended before changing it.
                return
            res_json = await res.json()
            logger.debug(f"Created {idx} with response {res.status}: {res_json}")
            future.set_result(res_json["log_event_ids"][0])

    async def _consume_update(self, body, future, idx):
        """PUT an update for the log id carried by *future* (awaiting it first)."""
        if not future.done():
            await future
        body["logs"] = [future.result()]
        async with self.session.put("logs", json=body) as res:
            if res.status != 200:
                txt = await res.text()
                logger.error(f"Error in consume_update {idx}: {txt}")
                return
            res_json = await res.json()
            logger.debug(f"Updated {idx} with response {res.status}: {res_json}")

    async def _log_consumer(self):
        """Consume queued events forever, dispatching on the event type.

        Fixes over the previous revision:
        * the event is dequeued *outside* the try block, so a failed or
          cancelled ``get()`` can no longer reach the except handler with
          ``event`` unbound (NameError), nor call ``task_done()`` without
          a matching ``get()`` (ValueError);
        * the failure is only propagated to the future when it is not
          already resolved — update events arrive with a completed future,
          and ``set_exception`` on a done future raises InvalidStateError,
          masking the original error.
        """
        while True:
            event = await self.queue.get()
            idx = self.queue.qsize() + 1  # rough queue position, debug only
            try:
                logger.debug(f"Processing event {event['type']}: {idx}")
                if event["type"] == "create":
                    await self._consume_create(event["_data"], event["future"], idx)
                elif event["type"] == "update":
                    await self._consume_update(event["_data"], event["future"], idx)
                else:
                    raise Exception(f"Unknown event type: {event['type']}")
            except Exception as e:
                if not event["future"].done():
                    event["future"].set_exception(e)
                logger.error(f"Error in consumer: {e}")
                raise
            finally:
                self.queue.task_done()
                self._notify_callbacks()

    def log_create(
        self,
        project: str,
        context: str,
        params: dict,
        entries: dict,
    ) -> asyncio.Future:
        """Enqueue a log-creation event.

        Returns:
            A future that resolves to the created log's id.
        """
        fut = self.loop.create_future()
        event = {
            "_data": {
                "project": project,
                "context": context,
                "params": params,
                "entries": entries,
            },
            "type": "create",
            "future": fut,
        }
        # put() runs on the loop thread; .result() blocks while the queue is full.
        asyncio.run_coroutine_threadsafe(self.queue.put(event), self.loop).result()
        return fut

    def log_update(
        self,
        project: str,
        context: str,
        future: asyncio.Future,
        mode: str,
        overwrite: bool,
        data: dict,
    ) -> None:
        """Enqueue an update (payload keyed by *mode*) for the log behind *future*."""
        event = {
            "_data": {
                mode: data,
                "project": project,
                "context": context,
                "overwrite": overwrite,
            },
            "type": "update",
            "future": future,
        }
        asyncio.run_coroutine_threadsafe(self.queue.put(event), self.loop).result()

    def stop_sync(self, immediate=False):
        """Shut down the logger: stop the loop at once, or drain the queue first."""
        if self.shutting_down:
            return

        self.shutting_down = True
        if immediate:
            logger.debug("Stopping async logger immediately")
            self.loop.stop()
        else:
            self.join()
196 |
--------------------------------------------------------------------------------
/unify/logging/utils/compositions.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import json
4 |
5 | from ...utils.helpers import _validate_api_key
6 | from .logs import *
7 |
8 | # Parameters #
9 | # -----------#
10 |
11 |
def get_param_by_version(
    field: str,
    version: Union[str, int],
    api_key: Optional[str] = None,
) -> Any:
    """
    Gets the parameter by version.

    Args:
        field: The field of the parameter to get.

        version: The version of the parameter to get.

        api_key: If specified, unify API key to be used. Defaults to the value in the
        `UNIFY_KEY` environment variable.

    Returns:
        The parameter by version.
    """
    api_key = _validate_api_key(api_key)
    filter_exp = f"version({field}) == {str(version)}"
    matches = get_logs(filter=filter_exp, limit=1, api_key=api_key)
    # params[field] is a (value, versioned-value) pair; index 1 is by version.
    return matches[0].params[field][1]
35 |
36 |
def get_param_by_value(
    field: str,
    value: Any,
    api_key: Optional[str] = None,
) -> Any:
    """
    Gets the parameter by value.

    Args:
        field: The field of the parameter to get.

        value: The value of the parameter to get.

        api_key: If specified, unify API key to be used. Defaults to the value in the
        `UNIFY_KEY` environment variable.

    Returns:
        The parameter by value.
    """
    api_key = _validate_api_key(api_key)
    # JSON-encode the value so strings are quoted correctly in the filter.
    encoded = json.dumps(value)
    matches = get_logs(filter=f"{field} == {encoded}", limit=1, api_key=api_key)
    return matches[0].params[field][0]
59 |
60 |
def get_source() -> str:
    """
    Extracts the source code for the file from where this function was called.

    Returns:
        The source code for the file, wrapped in a markdown python code block.
    """
    # stack()[0] is this frame; stack()[1] is the caller's frame.
    caller = inspect.stack()[1]
    with open(caller.filename, "r") as file:
        contents = file.read()
    return "```python\n" + contents + "\n```"
72 |
73 |
74 | # Experiments #
75 | # ------------#
76 |
77 |
def get_experiment_name(version: int, api_key: Optional[str] = None) -> str:
    """
    Gets the experiment name (by version).

    Args:
        version: The version of the experiment to get.

        api_key: If specified, unify API key to be used. Defaults to the value in the
        `UNIFY_KEY` environment variable.

    Returns:
        The experiment name with said version, or None if not found.
    """
    experiments = get_groups(key="experiment", api_key=api_key)
    if not experiments:
        return None
    if version < 0:
        # Negative versions index from the end, like negative list indices.
        version = len(experiments) + version
    return experiments.get(str(version))
99 |
100 |
def get_experiment_version(name: str, api_key: Optional[str] = None) -> int:
    """
    Gets the experiment version (by name).

    Args:
        name: The name of the experiment to get.

        api_key: If specified, unify API key to be used. Defaults to the value in the
        `UNIFY_KEY` environment variable.

    Returns:
        The experiment version with said name, or None if not found.
    """
    experiments = get_groups(key="experiment", api_key=api_key)
    if not experiments:
        return None
    # Invert the version -> name mapping to look up by name.
    by_name = {v: k for k, v in experiments.items()}
    version = by_name.get(name)
    return None if version is None else int(version)
121 |
--------------------------------------------------------------------------------
/unify/logging/utils/contexts.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Optional
2 |
3 | from unify import BASE_URL
4 | from unify.utils import _requests
5 |
6 | from ...utils.helpers import (
7 | _check_response,
8 | _get_and_maybe_create_project,
9 | _validate_api_key,
10 | )
11 | from .logs import CONTEXT_WRITE
12 |
13 | # Contexts #
14 | # ---------#
15 |
16 |
def create_context(
    name: str,
    description: Optional[str] = None,
    is_versioned: bool = False,
    allow_duplicates: bool = True,
    *,
    project: Optional[str] = None,
    api_key: Optional[str] = None,
) -> Dict[str, str]:
    """
    Create a context.

    Args:
        name: Name of the context to create.

        description: Description of the context to create.

        is_versioned: Whether the context is versioned.

        allow_duplicates: Whether to allow duplicates in the context.

        project: Name of the project the context belongs to.

        api_key: If specified, unify API key to be used. Defaults to the value in the
        `UNIFY_KEY` environment variable.

    Returns:
        A message indicating whether the context was successfully created.
    """
    api_key = _validate_api_key(api_key)
    # The project must already exist; do not create it implicitly here.
    project = _get_and_maybe_create_project(
        project,
        api_key=api_key,
        create_if_missing=False,
    )
    headers = {
        "accept": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    body = {
        "name": name,
        "description": description,
        "is_versioned": is_versioned,
        "allow_duplicates": allow_duplicates,
    }
    response = _requests.post(
        BASE_URL + f"/project/{project}/contexts",
        headers=headers,
        json=body,
    )
    _check_response(response)
    return response.json()
69 |
70 |
def rename_context(
    name: str,
    new_name: str,
    *,
    project: Optional[str] = None,
    api_key: Optional[str] = None,
) -> Dict[str, str]:
    """
    Rename a context.

    Args:
        name: Name of the context to rename.

        new_name: New name of the context.

        project: Name of the project the context belongs to.

        api_key: If specified, unify API key to be used. Defaults to the value in the
        `UNIFY_KEY` environment variable.

    Returns:
        The server's JSON response for the rename request.
    """
    api_key = _validate_api_key(api_key)
    # The project must already exist; do not create it implicitly here.
    project = _get_and_maybe_create_project(
        project,
        api_key=api_key,
        create_if_missing=False,
    )
    headers = {
        "accept": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    response = _requests.patch(
        BASE_URL + f"/project/{project}/contexts/{name}/rename",
        headers=headers,
        json={"name": new_name},
    )
    _check_response(response)
    return response.json()
108 |
109 |
def get_context(
    name: str,
    *,
    project: Optional[str] = None,
    api_key: Optional[str] = None,
) -> Dict[str, str]:
    """
    Get information about a specific context including its versioning status and current version.

    Args:
        name: Name of the context to get.

        project: Name of the project the context belongs to.

        api_key: If specified, unify API key to be used. Defaults to the value in the
        `UNIFY_KEY` environment variable.
    """
    api_key = _validate_api_key(api_key)
    project = _get_and_maybe_create_project(
        project,
        api_key=api_key,
        create_if_missing=False,
    )
    response = _requests.get(
        BASE_URL + f"/project/{project}/contexts/{name}",
        headers={
            "accept": "application/json",
            "Authorization": f"Bearer {api_key}",
        },
    )
    _check_response(response)
    return response.json()
143 |
144 |
def get_contexts(
    project: Optional[str] = None,
    *,
    prefix: Optional[str] = None,
    api_key: Optional[str] = None,
) -> Dict[str, str]:
    """
    Gets all contexts associated with a project, optionally filtered by a
    name prefix.

    Args:
        project: Name of the project the contexts belong to.

        prefix: If specified, only contexts whose name starts with this
        prefix are returned.

        api_key: If specified, unify API key to be used. Defaults to the value in the
        `UNIFY_KEY` environment variable.

    Returns:
        A dictionary mapping context names to their descriptions.
    """
    api_key = _validate_api_key(api_key)
    headers = {
        "accept": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    # The project must already exist; do not create it implicitly here.
    project = _get_and_maybe_create_project(
        project,
        api_key=api_key,
        create_if_missing=False,
    )
    response = _requests.get(
        BASE_URL + f"/project/{project}/contexts",
        headers=headers,
    )
    _check_response(response)
    contexts = response.json()
    # Flatten the server's list of context objects into name -> description.
    contexts = {context["name"]: context["description"] for context in contexts}
    if prefix:
        contexts = {
            context: description
            for context, description in contexts.items()
            if context.startswith(prefix)
        }
    return contexts
192 |
193 |
def delete_context(
    name: str,
    *,
    delete_children: bool = True,
    project: Optional[str] = None,
    api_key: Optional[str] = None,
) -> None:
    """
    Delete a context from the server.

    Args:
        name: Name of the context to delete.

        delete_children: Whether to delete child contexts (which share the same "/" separated prefix).

        project: Name of the project the context belongs to.

        api_key: If specified, unify API key to be used. Defaults to the value in the
        `UNIFY_KEY` environment variable.

    Returns:
        The JSON response of the last deletion, or None if no matching
        context was found.
    """
    api_key = _validate_api_key(api_key)
    # The project must already exist; do not create it implicitly here.
    project = _get_and_maybe_create_project(
        project,
        api_key=api_key,
        create_if_missing=False,
    )
    headers = {
        "accept": "application/json",
        "Authorization": f"Bearer {api_key}",
    }

    # ToDo: remove this hack once this task [https://app.clickup.com/t/86c3kuch6] is done
    # NOTE(review): this deletes every context whose name *starts with* `name`
    # (string prefix, not path components), regardless of `delete_children` —
    # confirm intended. `response` is only defined when at least one context
    # matched, hence the `if all_contexts` guard below.
    all_contexts = get_contexts(project, prefix=name)
    for ctx in all_contexts:
        response = _requests.delete(
            BASE_URL + f"/project/{project}/contexts/{ctx}",
            headers=headers,
        )
        _check_response(response)
    if all_contexts:
        return response.json()
235 |
236 |
def add_logs_to_context(
    log_ids: List[int],
    *,
    context: Optional[str] = None,
    project: Optional[str] = None,
    api_key: Optional[str] = None,
) -> None:
    """
    Add logs to a context.

    Args:
        log_ids: List of log ids to add to the context.

        context: Name of the context to add the logs to. Defaults to the
        currently active write context.

        project: Name of the project the logs belong to.

        api_key: If specified, unify API key to be used. Defaults to the value in the
        `UNIFY_KEY` environment variable.

    Returns:
        A message indicating whether the logs were successfully added to the context.
    """
    api_key = _validate_api_key(api_key)
    if not context:
        context = CONTEXT_WRITE.get()
    project = _get_and_maybe_create_project(
        project,
        api_key=api_key,
        create_if_missing=False,
    )
    response = _requests.post(
        BASE_URL + f"/project/{project}/contexts/add_logs",
        headers={
            "accept": "application/json",
            "Authorization": f"Bearer {api_key}",
        },
        json={
            "context_name": context,
            "log_ids": log_ids,
        },
    )
    _check_response(response)
    return response.json()
282 |
--------------------------------------------------------------------------------
/unify/logging/utils/datasets.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List, Optional
2 |
3 | from ...utils.helpers import _get_and_maybe_create_project, _validate_api_key
4 | from ..logs import Log
5 | from .contexts import *
6 | from .logs import *
7 |
8 | # Datasets #
9 | # ---------#
10 |
11 |
def list_datasets(
    *,
    project: Optional[str] = None,
    prefix: str = "",
    api_key: Optional[str] = None,
) -> Dict[str, str]:
    """
    List all datasets associated with a project and context.

    Args:
        project: Name of the project the datasets belong to.

        prefix: Prefix of the datasets to get.

        api_key: If specified, unify API key to be used. Defaults to the value in the
        `UNIFY_KEY` environment variable.

    Returns:
        A mapping of dataset names to their descriptions.
    """
    api_key = _validate_api_key(api_key)
    # Datasets are stored as contexts under the "Datasets/" namespace.
    contexts = get_contexts(
        prefix=f"Datasets/{prefix}",
        project=project,
        api_key=api_key,
    )
    datasets = {}
    for name, description in contexts.items():
        # Strip the leading "Datasets/" component from the context name.
        datasets["/".join(name.split("/")[1:])] = description
    return datasets
42 |
43 |
def upload_dataset(
    name: str,
    data: List[Any],
    *,
    overwrite: bool = False,
    allow_duplicates: bool = False,
    project: Optional[str] = None,
    api_key: Optional[str] = None,
) -> List[int]:
    """
    Upload a dataset to the server.

    Args:
        name: Name of the dataset.

        data: Contents of the dataset. Either a list of `unify.Log` instances
        (in which case *all* items must be logs), or raw values (non-dict
        items are wrapped as `{"data": item}`).

        overwrite: Whether to overwrite the dataset if it already exists.

        allow_duplicates: Whether to allow duplicates in the dataset.

        project: Name of the project the dataset belongs to.

        api_key: If specified, unify API key to be used. Defaults to the value in the
        `UNIFY_KEY` environment variable.
    Returns:
        A list all log ids in the dataset.
    """
    api_key = _validate_api_key(api_key)
    project = _get_and_maybe_create_project(project, api_key=api_key)
    # Datasets live as logs under a dedicated "Datasets/<name>" context.
    context = f"Datasets/{name}"
    log_instances = [isinstance(item, unify.Log) for item in data]
    are_logs = False
    if not allow_duplicates and not overwrite:
        # ToDo: remove this verbose logic once ignore_duplicates is implemented
        # Download the current upstream contents so local duplicates can be
        # filtered out before uploading.
        if name in unify.list_datasets():
            upstream_dataset = unify.Dataset(
                unify.download_dataset(name, project=project, api_key=api_key),
            )
        else:
            upstream_dataset = unify.Dataset([])
    if any(log_instances):
        assert all(log_instances), "If any items are logs, all items must be logs"
        are_logs = True
        # ToDo: remove this verbose logic once ignore_duplicates is implemented
        if not allow_duplicates and not overwrite:
            data = [l for l in data if l not in upstream_dataset]
    elif not all(isinstance(item, dict) for item in data):
        # ToDo: remove this verbose logic once ignore_duplicates is implemented
        if not allow_duplicates and not overwrite:
            data = [item for item in data if item not in upstream_dataset]
        # Wrap raw values so every entry is a dict payload.
        data = [{"data": item} for item in data]
    if name in unify.list_datasets():
        upstream_ids = get_logs(
            project=project,
            context=context,
            return_ids_only=True,
        )
    else:
        upstream_ids = []
    if not are_logs:
        # Raw entries: simply append them as new logs.
        return upstream_ids + create_logs(
            project=project,
            context=context,
            entries=data,
            mutable=True,
            batched=True,
            # ToDo: uncomment once ignore_duplicates is implemented
            # ignore_duplicates=not allow_duplicates,
        )
    # Log instances: reconcile local ids against upstream ids.
    local_ids = [l.id for l in data]
    matching_ids = [id for id in upstream_ids if id in local_ids]
    matching_data = [l.entries for l in data if l.id in matching_ids]
    assert len(matching_data) == len(
        matching_ids,
    ), "matching data and ids must be the same length"
    if matching_data:
        # Logs present both locally and upstream: push local entries upstream.
        update_logs(
            logs=matching_ids,
            api_key=api_key,
            entries=matching_data,
            overwrite=True,
        )
    if overwrite:
        # Remove upstream logs that have no local counterpart.
        upstream_only_ids = [id for id in upstream_ids if id not in local_ids]
        if upstream_only_ids:
            delete_logs(
                logs=upstream_only_ids,
                context=context,
                project=project,
                api_key=api_key,
            )
        upstream_ids = [id for id in upstream_ids if id not in upstream_only_ids]
    # Logs with ids that exist elsewhere but are not yet in this dataset's
    # context: attach them to the context (creating it if needed).
    ids_not_in_dataset = [
        id for id in local_ids if id not in matching_ids and id is not None
    ]
    if ids_not_in_dataset:
        if context not in unify.get_contexts():
            unify.create_context(
                context,
                project=project,
                api_key=api_key,
            )
        unify.add_logs_to_context(
            log_ids=ids_not_in_dataset,
            context=context,
            project=project,
            api_key=api_key,
        )
    # Logs that only exist locally (no id yet): create them upstream.
    local_only_data = [l.entries for l in data if l.id is None]
    if local_only_data:
        return upstream_ids + create_logs(
            project=project,
            context=context,
            entries=local_only_data,
            mutable=True,
            batched=True,
        )
    return upstream_ids + ids_not_in_dataset
163 |
164 |
def download_dataset(
    name: str,
    *,
    project: Optional[str] = None,
    api_key: Optional[str] = None,
) -> List[Log]:
    """
    Download a dataset from the server.

    Args:
        name: Name of the dataset.

        project: Name of the project the dataset belongs to.

        api_key: If specified, unify API key to be used. Defaults to the value in the
        `UNIFY_KEY` environment variable.
    """
    api_key = _validate_api_key(api_key)
    project = _get_and_maybe_create_project(project, api_key=api_key)
    logs = get_logs(
        project=project,
        context=f"Datasets/{name}",
    )
    # Logs come back newest-first; reverse to restore insertion order.
    return logs[::-1]
189 |
190 |
def delete_dataset(
    name: str,
    *,
    project: Optional[str] = None,
    api_key: Optional[str] = None,
) -> None:
    """
    Delete a dataset from the server.

    Args:
        name: Name of the dataset.

        project: Name of the project the dataset belongs to.

        api_key: If specified, unify API key to be used. Defaults to the value in the
        `UNIFY_KEY` environment variable.
    """
    api_key = _validate_api_key(api_key)
    project = _get_and_maybe_create_project(project, api_key=api_key)
    # Removing the backing context removes the dataset.
    delete_context(f"Datasets/{name}", project=project, api_key=api_key)
211 |
212 |
def add_dataset_entries(
    name: str,
    data: List[Any],
    *,
    project: Optional[str] = None,
    api_key: Optional[str] = None,
) -> List[int]:
    """
    Adds entries to an existing dataset in the server.

    Args:
        name: Name of the dataset.

        data: Contents to add to the dataset. Non-dict items are wrapped as
        `{"data": item}`.

        project: Name of the project the dataset belongs to.

        api_key: If specified, unify API key to be used. Defaults to the value in the
        `UNIFY_KEY` environment variable.
    Returns:
        A list of the newly added dataset logs.
    """
    api_key = _validate_api_key(api_key)
    # The project must already exist; do not create it implicitly here.
    project = _get_and_maybe_create_project(
        project,
        api_key=api_key,
        create_if_missing=False,
    )
    if not all(isinstance(item, dict) for item in data):
        data = [{"data": item} for item in data]
    logs = create_logs(
        project=project,
        context=f"Datasets/{name}",
        entries=data,
        mutable=True,
        batched=True,
    )
    return logs
251 |
--------------------------------------------------------------------------------
/unify/logging/utils/projects.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Optional
2 |
3 | from unify import BASE_URL
4 | from unify.utils import _requests
5 |
6 | from ...utils.helpers import _check_response, _validate_api_key
7 |
8 | # Projects #
9 | # ---------#
10 |
11 |
def create_project(
    name: str,
    *,
    overwrite: bool = False,
    api_key: Optional[str] = None,
) -> Dict[str, str]:
    """
    Creates a logging project and adds this to your account. This project will have
    a set of logs associated with it.

    Args:
        name: A unique, user-defined name used when referencing the project.

        overwrite: Whether to overwrite an existing project if it already exists.

        api_key: If specified, unify API key to be used. Defaults to the value in the
            `UNIFY_KEY` environment variable.

    Returns:
        A message indicating whether the project was created successfully.
    """
    api_key = _validate_api_key(api_key)
    auth_headers = {
        "accept": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    # When overwriting, drop any existing project of the same name first.
    if overwrite and name in list(list_projects(api_key=api_key)):
        delete_project(name=name, api_key=api_key)
    response = _requests.post(
        BASE_URL + "/project",
        headers=auth_headers,
        json={"name": name},
    )
    _check_response(response)
    return response.json()
45 |
46 |
def rename_project(
    name: str,
    new_name: str,
    *,
    api_key: Optional[str] = None,
) -> Dict[str, str]:
    """
    Renames a project from `name` to `new_name` in your account.

    Args:
        name: Name of the project to rename.

        new_name: A unique, user-defined name used when referencing the project.

        api_key: If specified, unify API key to be used. Defaults to the value in the
            `UNIFY_KEY` environment variable.

    Returns:
        A message indicating whether the project was successfully renamed.
    """
    validated_key = _validate_api_key(api_key)
    response = _requests.patch(
        BASE_URL + f"/project/{name}",
        headers={
            "accept": "application/json",
            "Authorization": f"Bearer {validated_key}",
        },
        json={"name": new_name},
    )
    _check_response(response)
    return response.json()
80 |
81 |
def delete_project(
    name: str,
    *,
    api_key: Optional[str] = None,
) -> str:
    """
    Deletes a project from your account.

    Args:
        name: Name of the project to delete.

        api_key: If specified, unify API key to be used. Defaults to the value in the
            `UNIFY_KEY` environment variable.

    Returns:
        Whether the project was successfully deleted.
    """
    validated_key = _validate_api_key(api_key)
    auth_headers = {
        "accept": "application/json",
        "Authorization": f"Bearer {validated_key}",
    }
    response = _requests.delete(
        BASE_URL + f"/project/{name}",
        headers=auth_headers,
    )
    _check_response(response)
    return response.json()
107 |
108 |
def delete_project_logs(
    name: str,
    *,
    api_key: Optional[str] = None,
) -> Dict[str, str]:
    """
    Deletes all logs from a project.

    Args:
        name: Name of the project to delete logs from.

        api_key: If specified, unify API key to be used. Defaults to the value in the
            `UNIFY_KEY` environment variable.

    Returns:
        The JSON response from the server.
    """
    api_key = _validate_api_key(api_key)
    headers = {
        "accept": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    response = _requests.delete(BASE_URL + f"/project/{name}/logs", headers=headers)
    _check_response(response)
    # Bug fix: the signature was annotated `-> None` but the body returns the
    # response payload; the annotation now matches the actual return.
    return response.json()
131 |
132 |
def delete_project_contexts(
    name: str,
    *,
    api_key: Optional[str] = None,
) -> Dict[str, str]:
    """
    Deletes all contexts and their associated logs from a project.

    Args:
        name: Name of the project to delete contexts from.

        api_key: If specified, unify API key to be used. Defaults to the value in the
            `UNIFY_KEY` environment variable.

    Returns:
        The JSON response from the server.
    """
    api_key = _validate_api_key(api_key)
    headers = {
        "accept": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    response = _requests.delete(BASE_URL + f"/project/{name}/contexts", headers=headers)
    _check_response(response)
    # Bug fix: the signature was annotated `-> None` but the body returns the
    # response payload; the annotation now matches the actual return.
    return response.json()
155 |
156 |
def list_projects(
    *,
    api_key: Optional[str] = None,
) -> List[str]:
    """
    Returns the names of all projects stored in your account.

    Args:
        api_key: If specified, unify API key to be used. Defaults to the value in the
            `UNIFY_KEY` environment variable.

    Returns:
        List of all project names.
    """
    validated_key = _validate_api_key(api_key)
    response = _requests.get(
        BASE_URL + "/projects",
        headers={
            "accept": "application/json",
            "Authorization": f"Bearer {validated_key}",
        },
    )
    _check_response(response)
    return response.json()
179 |
--------------------------------------------------------------------------------
/unify/logging/utils/tracing.py:
--------------------------------------------------------------------------------
import importlib
import importlib.abc
import importlib.util
import sys
from typing import Callable, List

import unify
6 |
7 |
class TraceLoader(importlib.abc.Loader):
    """Loader wrapper that applies tracing to a module right after execution."""

    def __init__(self, original_loader, filter: Callable = None):
        # The real loader to delegate to, and the filter forwarded to
        # unify.traced.
        self._original_loader = original_loader
        self.filter = filter

    def create_module(self, spec):
        # Defer module creation entirely to the wrapped loader.
        return self._original_loader.create_module(spec)

    def exec_module(self, module):
        # Execute the module normally, then wrap it with the traced decorator.
        self._original_loader.exec_module(module)
        unify.traced(module, filter=self.filter)
19 |
20 |
class TraceFinder(importlib.abc.MetaPathFinder):
    """Meta-path finder that wraps matching modules' loaders with TraceLoader.

    Args:
        targets: Module-name prefixes to trace.
        filter: Filter callable forwarded to the traced decorator.
    """

    def __init__(self, targets: List[str], filter: Callable = None):
        self.targets = targets
        self.filter = filter

    def find_spec(self, fullname, path, target=None):
        # Bug fix: a module should be traced when it matches *any* target
        # prefix; the old loop returned None as soon as any single prefix
        # failed to match, i.e. it required matching *every* prefix.
        # (An empty targets list still matches everything, as before.)
        if self.targets and not any(
            fullname.startswith(prefix) for prefix in self.targets
        ):
            return None

        # Temporarily remove all TraceFinders from sys.meta_path so the
        # recursive find_spec call below does not re-enter this finder.
        original_sys_meta_path = sys.meta_path[:]
        sys.meta_path = [
            finder for finder in sys.meta_path if not isinstance(finder, TraceFinder)
        ]
        try:
            spec = importlib.util.find_spec(fullname, path)
            if spec is None:
                return None
        finally:
            sys.meta_path = original_sys_meta_path

        # Only trace plain .py modules (skip builtins, extensions, namespaces).
        if spec.origin is None or not spec.origin.endswith(".py"):
            return None

        spec.loader = TraceLoader(spec.loader, filter=self.filter)
        return spec
47 |
48 |
def install_tracing_hook(targets: List[str], filter: Callable = None):
    """Install an import hook that wraps imported modules with the traced decorator.

    Adds a TraceFinder to the front of sys.meta_path so that subsequent imports
    of matching modules are wrapped with the traced decorator. Installation is
    skipped if a TraceFinder is already present.

    Args:
        targets: List of module name prefixes to target for tracing. Only modules
            whose names start with these prefixes will be wrapped.

        filter: A filter function that is passed to the traced decorator.

    """
    already_installed = any(
        isinstance(entry, TraceFinder) for entry in sys.meta_path
    )
    if not already_installed:
        sys.meta_path.insert(0, TraceFinder(targets, filter))
65 |
66 |
def disable_tracing_hook():
    """Remove the tracing import hook from sys.meta_path.

    Removes every TraceFinder instance from sys.meta_path, effectively
    disabling the tracing functionality for subsequent module imports.

    """
    # Bug fix: iterate over a snapshot — removing items from a list while
    # iterating it skips the element after each removal, so two adjacent
    # TraceFinders would leave one behind.
    for finder in sys.meta_path[:]:
        if isinstance(finder, TraceFinder):
            sys.meta_path.remove(finder)
77 |
--------------------------------------------------------------------------------
/unify/universal_api/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/unifyai/unify/fbb548073fb0dc9f0cc0f67e28220011cbf9ec5b/unify/universal_api/__init__.py
--------------------------------------------------------------------------------
/unify/universal_api/casting.py:
--------------------------------------------------------------------------------
1 | from typing import List, Type, Union
2 |
3 | from openai.types.chat import ChatCompletion
4 | from unify.universal_api.types import Prompt
5 |
6 | # Upcasting
7 |
8 |
def _usr_msg_to_prompt(user_message: str) -> Prompt:
    # Upcast a plain user message string into a Prompt.
    # NOTE(review): passes the message positionally, but the Prompt class in
    # unify/universal_api/types/prompt.py only accepts keyword components
    # (**components) — confirm which Prompt definition is actually in use.
    return Prompt(user_message)
11 |
12 |
13 | def _bool_to_float(boolean: bool) -> float:
14 | return float(boolean)
15 |
16 |
17 | # Downcasting
18 |
19 |
def _prompt_to_usr_msg(prompt: Prompt) -> str:
    """Downcast a prompt to the content of its final message."""
    last_message = prompt.messages[-1]
    return last_message["content"]
22 |
23 |
def _chat_completion_to_assis_msg(chat_completion: ChatCompletion) -> str:
    """Downcast a chat completion to the content of its first choice's message."""
    first_choice = chat_completion.choices[0]
    return first_choice.message.content
26 |
27 |
28 | def _float_to_bool(float_in: float) -> bool:
29 | return bool(float_in)
30 |
31 |
32 | # Cast Dict
33 |
# Maps source type -> {target type: cast function}. Lookup pattern is
# _CAST_DICT[type(inp)][to_type] (see `cast` below). Absence of a key pair
# means that cast is unsupported.
_CAST_DICT = {
    str: {Prompt: _usr_msg_to_prompt},
    Prompt: {
        str: _prompt_to_usr_msg,
    },
    ChatCompletion: {str: _chat_completion_to_assis_msg},
    bool: {
        float: _bool_to_float,
    },
    float: {
        bool: _float_to_bool,
    },
}
47 |
48 |
def _cast_from_selection(
    inp: Union[str, bool, float, Prompt, ChatCompletion],
    targets: List[Union[float, Prompt, ChatCompletion]],
) -> Union[str, bool, float, Prompt, ChatCompletion]:
    """
    Upcasts the input if possible, based on the permitted upcasting targets provided.

    Args:
        inp: The input to cast.

        targets: The set of permitted upcasting targets.

    Returns:
        The input after casting to the new type, if it was possible.
    """
    input_type = type(inp)
    # Bug fix: the original message contained bare "{}" placeholders that were
    # never interpolated; the value and type are now formatted in.
    assert input_type in _CAST_DICT, (
        f"Cannot upcast input {inp} of type {input_type}, because this type is "
        "not in the _CAST_DICT, meaning there are no functions for casting "
        "this type."
    )
    cast_fns = _CAST_DICT[input_type]
    # Exactly one of the permitted targets must be castable-to from here.
    valid_targets = [target for target in targets if target in cast_fns]
    assert len(valid_targets) == 1, "There must be exactly one valid casting target."
    to_type = valid_targets[0]
    return cast_fns[to_type](inp)
74 |
75 |
76 | # Public function
77 |
78 |
def cast(
    inp: Union[str, bool, float, Prompt, ChatCompletion],
    to_type: Union[
        Type[Union[str, bool, float, Prompt, ChatCompletion]],
        List[Type[Union[str, bool, float, Prompt, ChatCompletion]]],
    ],
) -> Union[str, bool, float, Prompt, ChatCompletion]:
    """
    Cast the input to the specified type.

    Args:
        inp: The input to cast.

        to_type: The type (or list of candidate types) to cast the input to.

    Returns:
        The input after casting to the new type.
    """
    # A list of candidate types delegates to selection-based casting.
    if isinstance(to_type, list):
        return _cast_from_selection(inp, to_type)
    source_type = type(inp)
    # Nothing to do when the input already has the requested type.
    if source_type is to_type:
        return inp
    return _CAST_DICT[source_type][to_type](inp)
103 |
104 |
def try_cast(
    inp: Union[str, bool, float, Prompt, ChatCompletion],
    to_type: Union[
        Type[Union[str, bool, float, Prompt, ChatCompletion]],
        List[Type[Union[str, bool, float, Prompt, ChatCompletion]]],
    ],
) -> Union[str, bool, float, Prompt, ChatCompletion]:
    """
    Attempt to cast the input to the specified type, returning the input
    unchanged if the cast fails.

    Args:
        inp: The input to cast.

        to_type: The type (or list of candidate types) to cast the input to.

    Returns:
        The cast value on success, otherwise the original input.
    """
    try:
        return cast(inp, to_type)
    # Bug fix: a bare `except:` also swallowed KeyboardInterrupt/SystemExit;
    # only genuine casting failures should fall back to returning the input.
    except Exception:
        return inp
117 |
--------------------------------------------------------------------------------
/unify/universal_api/chatbot.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import sys
3 | from typing import Dict, Union
4 |
5 | import unify
6 | from unify.universal_api.clients import _Client, _MultiClient, _UniClient
7 |
8 |
class ChatBot:  # noqa: WPS338
    """Interactive REPL-style chat agent wrapped around a Unify client.

    Maintains the message history on the underlying client and reads user
    input from stdin until `quit` or `pause` is entered.
    """

    def __init__(
        self,
        client: _Client,
    ) -> None:
        """
        Initializes the ChatBot object, wrapped around a client.

        Args:
            client: The Client instance to wrap the chatbot logic around.
                Must be configured to return only the message content
                (not the full completion object).
        """
        self._paused = False
        # The response handlers below assume plain strings / string chunks,
        # so full completion objects are not supported.
        assert not client.return_full_completion, (
            "ChatBot currently only supports clients which only generate the message "
            "content in the return"
        )
        self._client = client
        self.clear_chat_history()

    @property
    def client(self) -> _Client:
        """
        Get the client object.  # noqa: DAR201.

        Returns:
            The client.
        """
        return self._client

    def set_client(self, value: _Client) -> None:
        """
        Set the client.  # noqa: DAR101.

        Args:
            value: The unify client.

        Raises:
            Exception: If `value` is not a _Client instance.
        """
        if isinstance(value, _Client):
            self._client = value
        else:
            raise Exception("Invalid client!")

    def _get_credits(self) -> float:
        """
        Retrieves the current credit balance associated with the UNIFY account.

        Returns:
            Current credit balance.
        """
        return self._client.get_credit_balance()

    def _update_message_history(
        self,
        role: str,
        content: Union[str, Dict[str, str]],
    ) -> None:
        """
        Updates message history with user input.

        Args:
            role: Either "assistant" or "user".
            content: The message content; for multi-LLM clients this may be a
                dict mapping each endpoint to its own content.
        """
        if isinstance(self._client, _UniClient):
            self._client.messages.append(
                {
                    "role": role,
                    "content": content,
                },
            )
        elif isinstance(self._client, _MultiClient):
            # Broadcast a plain string to every endpoint's history.
            if isinstance(content, str):
                content = {endpoint: content for endpoint in self._client.endpoints}
            for endpoint, cont in content.items():
                self._client.messages[endpoint].append(
                    {
                        "role": role,
                        "content": cont,
                    },
                )
        else:
            raise Exception(
                "client must either be a UniClient or MultiClient instance.",
            )

    def clear_chat_history(self) -> None:
        """Clears the chat history (per-endpoint for multi-LLM clients)."""
        if isinstance(self._client, _UniClient):
            self._client.set_messages([])
        elif isinstance(self._client, _MultiClient):
            self._client.set_messages(
                {endpoint: [] for endpoint in self._client.endpoints},
            )
        else:
            raise Exception(
                "client must either be a UniClient or MultiClient instance.",
            )

    @staticmethod
    def _stream_response(response) -> str:
        # Echo each streamed chunk to stdout as it arrives, returning the
        # fully accumulated message at the end.
        words = ""
        for chunk in response:
            words += chunk
            sys.stdout.write(chunk)
            sys.stdout.flush()
        sys.stdout.write("\n")
        return words

    def _handle_uni_llm_response(
        self,
        response: str,
        endpoint: Union[bool, str],
    ) -> str:
        # Print the response (streamed or whole), optionally prefixed with the
        # endpoint name, and return the complete message text.
        if endpoint:
            # `endpoint` may be True (use the client's own endpoint) or a name.
            endpoint = self._client.endpoint if endpoint is True else endpoint
            sys.stdout.write(endpoint + ":\n")
        if self._client.stream:
            words = self._stream_response(response)
        else:
            words = response
            sys.stdout.write(words)
        sys.stdout.write("\n\n")
        return words

    def _handle_multi_llm_response(self, response: Dict[str, str]) -> Dict[str, str]:
        # Print each endpoint's response in turn, labelled by endpoint name.
        for endpoint, resp in response.items():
            self._handle_uni_llm_response(resp, endpoint)
        return response

    def _handle_response(
        self,
        response: Union[str, Dict[str, str]],
        show_endpoint: bool,
    ) -> None:
        # Display the response and append it to the assistant-side history.
        if isinstance(self._client, _UniClient):
            response = self._handle_uni_llm_response(response, show_endpoint)
        elif isinstance(self._client, _MultiClient):
            response = self._handle_multi_llm_response(response)
        else:
            raise Exception(
                "client must either be a UniClient or MultiClient instance.",
            )
        self._update_message_history(
            role="assistant",
            content=response,
        )

    def run(self, show_credits: bool = False, show_endpoint: bool = False) -> None:
        """
        Starts the chat interaction loop.

        Reads user input from stdin until `quit` (clears history and exits)
        or `pause` (keeps history for a later `run`) is entered.

        Args:
            show_credits: Whether to show credit consumption. Defaults to False.
            show_endpoint: Whether to show the endpoint used. Defaults to False.
        """
        if not self._paused:
            sys.stdout.write(
                "Let's have a chat. (Enter `pause` to pause and `quit` to exit)\n",
            )
            self.clear_chat_history()
        else:
            sys.stdout.write(
                "Welcome back! (Remember, enter `pause` to pause and `quit` to exit)\n",
            )
            self._paused = False
        while True:
            sys.stdout.write("> ")
            inp = input()
            if inp == "quit":
                self.clear_chat_history()
                break
            elif inp == "pause":
                self._paused = True
                break
            self._update_message_history(role="user", content=inp)
            # Credits are sampled before and after so the spend can be shown.
            initial_credit_balance = self._get_credits()
            # Async clients need an event loop to generate.
            if isinstance(self._client, unify.AsyncUnify):
                response = asyncio.run(self._client.generate())
            else:
                response = self._client.generate()
            self._handle_response(response, show_endpoint)
            final_credit_balance = self._get_credits()
            if show_credits:
                sys.stdout.write(
                    "\n(spent {:.6f} credits)".format(
                        initial_credit_balance - final_credit_balance,
                    ),
                )
198 |
--------------------------------------------------------------------------------
/unify/universal_api/clients/__init__.py:
--------------------------------------------------------------------------------
1 | from . import base
2 | from .base import _Client
3 | from . import uni_llm
4 | from .uni_llm import _UniClient, Unify, AsyncUnify
5 | from . import multi_llm
6 | from .multi_llm import _MultiClient, MultiUnify, AsyncMultiUnify
7 |
--------------------------------------------------------------------------------
/unify/universal_api/clients/helpers.py:
--------------------------------------------------------------------------------
1 | import unify
2 |
3 | # Helpers
4 |
5 |
6 | def _is_custom_endpoint(endpoint: str):
7 | _, provider = endpoint.split("@")
8 | return "custom" in provider
9 |
10 |
11 | def _is_local_endpoint(endpoint: str):
12 | _, provider = endpoint.split("@")
13 | return provider == "local"
14 |
15 |
def _is_fallback_provider(provider: str, api_key: str = None):
    """True if every "->"-separated provider in the chain is publicly listed."""
    known_providers = unify.list_providers(api_key=api_key)
    for candidate in provider.split("->"):
        if candidate not in known_providers:
            return False
    return True
19 |
20 |
def _is_fallback_model(model: str, api_key: str = None):
    """True if every "->"-separated model in the chain is publicly listed."""
    known_models = unify.list_models(api_key=api_key)
    for candidate in model.split("->"):
        if candidate not in known_models:
            return False
    return True
24 |
25 |
def _is_fallback_endpoint(endpoint: str, api_key: str = None):
    """True if every "->"-separated endpoint in the chain is publicly listed."""
    known_endpoints = unify.list_endpoints(api_key=api_key)
    for candidate in endpoint.split("->"):
        if candidate not in known_endpoints:
            return False
    return True
29 |
30 |
def _is_meta_provider(provider: str, api_key: str = None):
    """
    Heuristically check whether `provider` is a valid meta-provider expression.

    A meta-provider string combines quality/cost/latency keywords, numeric
    thresholds, operators, and optional "providers:", "skip_providers:",
    "models:" and "skip_models:" sections (each a comma-separated list
    terminated by "|"). Each section's names are validated against the public
    lists, the section is stripped out, and finally every recognised keyword
    and operator is removed — a valid expression leaves only numeric
    characters behind.
    """
    public_providers = unify.list_providers(api_key=api_key)
    # Validate then strip an optional "skip_providers:a,b,...|" section.
    if "skip_providers:" in provider:
        skip_provs = provider.split("skip_providers:")[-1].split("|")[0]
        for prov in skip_provs.split(","):
            if prov.strip() not in public_providers:
                return False
        # Remove the section (and its trailing "|") from the expression.
        chnk0, chnk1 = provider.split("skip_providers:")
        chnk2 = "|".join(chnk1.split("|")[1:])
        provider = "".join([chnk0, chnk2])
    # Validate then strip an optional "providers:a,b,...|" section.
    if "providers:" in provider:
        provs = provider.split("providers:")[-1].split("|")[0]
        for prov in provs.split(","):
            if prov.strip() not in public_providers:
                return False
        chnk0, chnk1 = provider.split("providers:")
        chnk2 = "|".join(chnk1.split("|")[1:])
        provider = "".join([chnk0, chnk2])
    # Drop a dangling separator left over from the section removal above.
    # NOTE(review): raises IndexError if the expression is empty here — confirm
    # empty expressions cannot reach this point.
    if provider[-1] == "|":
        provider = provider[:-1]
    public_models = unify.list_models(api_key=api_key)
    # Validate then strip an optional "skip_models:a,b,...|" section.
    if "skip_models:" in provider:
        skip_mods = provider.split("skip_models:")[-1].split("|")[0]
        for md in skip_mods.split(","):
            if md.strip() not in public_models:
                return False
        chnk0, chnk1 = provider.split("skip_models:")
        chnk2 = "|".join(chnk1.split("|")[1:])
        provider = "".join([chnk0, chnk2])
    # Validate then strip an optional "models:a,b,...|" section.
    if "models:" in provider:
        mods = provider.split("models:")[-1].split("|")[0]
        for md in mods.split(","):
            if md.strip() not in public_models:
                return False
        chnk0, chnk1 = provider.split("models:")
        chnk2 = "|".join(chnk1.split("|")[1:])
        provider = "".join([chnk0, chnk2])
    # Recognised meta-provider keywords: long forms first, then medium, then
    # single-letter abbreviations (order matters for the replace() loop below,
    # longest substrings must be removed first).
    meta_providers = (
        (
            "highest-quality",
            "lowest-time-to-first-token",
            "lowest-inter-token-latency",
            "lowest-input-cost",
            "lowest-output-cost",
            "lowest-cost",
            "lowest-ttft",
            "lowest-itl",
            "lowest-ic",
            "lowest-oc",
            "highest-q",
            "lowest-t",
            "lowest-i",
            "lowest-c",
        )
        + (
            "quality",
            "time-to-first-token",
            "inter-token-latency",
            "input-cost",
            "output-cost",
            "cost",
        )
        + (
            "q",
            "ttft",
            "itl",
            "ic",
            "oc",
            "t",
            "i",
            "c",
        )
    )
    operators = ("<", ">", "=", "|", ".", ":")
    # Strip every recognised keyword and operator; a valid expression leaves
    # only numeric threshold characters behind.
    for s in meta_providers + operators:
        provider = provider.replace(s, "")
    return all(c.isnumeric() for c in provider)
108 |
109 |
110 | # Checks
111 |
112 |
def _is_valid_endpoint(endpoint: str, api_key: str = None):
    """True if `endpoint` ("model@provider") is usable in any supported form.

    Accepts the "user-input" pseudo-endpoint, fallback chains, valid
    model/provider pairs, publicly listed endpoints, and custom/local
    endpoints.
    """
    if endpoint == "user-input":
        return True
    if _is_fallback_endpoint(endpoint, api_key):
        return True
    model, provider = endpoint.split("@")
    # Bug fix: forward api_key so the nested checks query the same account
    # as the outer call (it was previously dropped here).
    if _is_valid_provider(provider, api_key) and _is_valid_model(
        model,
        api_key=api_key,
    ):
        return True
    if endpoint in unify.list_endpoints(api_key=api_key):
        return True
    if _is_custom_endpoint(endpoint) or _is_local_endpoint(endpoint):
        return True
    return False
126 |
127 |
def _is_valid_provider(provider: str, api_key: str = None):
    """True if `provider` is a meta-provider expression, a public provider,
    a fallback chain, or a custom/local provider."""
    # Bug fix: forward api_key to the nested checks (it was previously
    # dropped, so they queried with the default key instead).
    if _is_meta_provider(provider, api_key):
        return True
    if provider in unify.list_providers(api_key=api_key):
        return True
    if _is_fallback_provider(provider, api_key):
        return True
    if provider == "local" or "custom" in provider:
        return True
    return False
138 |
139 |
def _is_valid_model(model: str, custom_or_local: bool = False, api_key: str = None):
    """True if `model` is acceptable: anything goes for custom/local
    endpoints, otherwise a public model, a fallback chain, or "router"."""
    if custom_or_local:
        return True
    if model in unify.list_models(api_key=api_key):
        return True
    # Bug fix: forward api_key (it was previously dropped here).
    if _is_fallback_model(model, api_key):
        return True
    if model == "router":
        return True
    return False
150 |
151 |
152 | # Assertions
153 |
154 |
def _assert_is_valid_endpoint(endpoint: str, api_key: str = None):
    """Raise AssertionError if `endpoint` is not a valid endpoint."""
    assert _is_valid_endpoint(
        endpoint,
        api_key,
    ), f"{endpoint} is not a valid endpoint"
157 |
158 |
def _assert_is_valid_provider(provider: str, api_key: str = None):
    """Raise AssertionError if `provider` is not a valid provider."""
    assert _is_valid_provider(
        provider,
        api_key,
    ), f"{provider} is not a valid provider"
161 |
162 |
def _assert_is_valid_model(
    model: str,
    custom_or_local: bool = False,
    api_key: str = None,
):
    """Raise AssertionError if `model` is not a valid model."""
    assert _is_valid_model(model, custom_or_local, api_key), (
        f"{model} is not a valid model"
    )
173 |
--------------------------------------------------------------------------------
/unify/universal_api/types/__init__.py:
--------------------------------------------------------------------------------
1 | from .prompt import *
2 |
--------------------------------------------------------------------------------
/unify/universal_api/types/prompt.py:
--------------------------------------------------------------------------------
class Prompt:
    """Container for arbitrary prompt components, stored as a dict."""

    def __init__(
        self,
        **components,
    ):
        """
        Create Prompt instance.

        Args:
            components: All components of the prompt, as keyword arguments.

        Returns:
            The Prompt instance.
        """
        # All keyword arguments are captured verbatim.
        self.components = components

    def __repr__(self) -> str:
        # Added for debuggability: show every component, dataclass-style.
        inner = ", ".join(f"{k}={v!r}" for k, v in self.components.items())
        return f"Prompt({inner})"
16 |
--------------------------------------------------------------------------------
/unify/universal_api/usage.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | from typing import List, Optional
3 |
4 | import unify
5 |
6 | from ..utils.helpers import _validate_api_key
7 |
8 |
def with_logging(
    model_fn: Optional[callable] = None,
    *,
    endpoint: str,
    tags: Optional[List[str]] = None,
    timestamp: Optional[datetime.datetime] = None,
    log_query_body: bool = True,
    log_response_body: bool = True,
    api_key: Optional[str] = None,
):
    """
    Wrap a local model callable with logging of the queries.

    Args:
        model_fn: The model callable to wrap logging around.
        endpoint: The endpoint name to give to this local callable.
        tags: Tags for later filtering.
        timestamp: A timestamp (if not set, will be the time of sending).
        log_query_body: Whether or not to log the query body.
        log_response_body: Whether or not to log the response body.
        api_key: If specified, unify API key to be used. Defaults to the value in the `UNIFY_KEY` environment variable.

    Returns:
        A new callable, but with logging added every time the function is called.
        The wrapper accepts the same keyword arguments (tags, timestamp,
        log_query_body, log_response_body) to override the values above on a
        per-call basis.

    Raises:
        requests.HTTPError: If the API request fails.
    """
    _tags = tags
    _timestamp = timestamp
    _log_query_body = log_query_body
    _log_response_body = log_response_body
    api_key = _validate_api_key(api_key)

    # noinspection PyShadowingNames
    def model_fn_w_logging(
        *args,
        tags: Optional[List[str]] = None,
        timestamp: Optional[datetime.datetime] = None,
        log_query_body: Optional[bool] = None,
        log_response_body: Optional[bool] = None,
        **kwargs,
    ):
        # Bug fix: per-call arguments now *fall back* to the values passed to
        # with_logging. Previously the decorator-level settings were captured
        # (as _tags etc.) but never read, so they were silently ignored and
        # the wrapper's own defaults always won.
        tags = tags if tags is not None else _tags
        timestamp = timestamp if timestamp is not None else _timestamp
        log_query_body = (
            log_query_body if log_query_body is not None else _log_query_body
        )
        log_response_body = (
            log_response_body if log_response_body is not None else _log_response_body
        )
        if len(args) != 0:
            raise Exception(
                "When logging queries for a local model, all arguments to "
                "the model callable must be provided as keyword arguments. "
                "Positional arguments are not supported. This is so the "
                "query body dict can be fully populated with keys for each "
                "entry.",
            )
        query_body = kwargs
        response = model_fn(**query_body)
        # Non-dict responses are wrapped so the logged body is always a dict.
        if not isinstance(response, dict):
            response = {"response": response}
        kw = dict(
            endpoint=endpoint,
            query_body=query_body,
            response_body=response,
            tags=tags,
            timestamp=timestamp,
            api_key=api_key,
        )
        # NOTE(review): nothing is logged at all when log_query_body is False
        # (not even the response body) — confirm this is the intended contract.
        if log_query_body:
            if not log_response_body:
                del kw["response_body"]
            unify.log_query(**kw)
        return response

    return model_fn_w_logging
79 |
--------------------------------------------------------------------------------
/unify/universal_api/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/unifyai/unify/fbb548073fb0dc9f0cc0f67e28220011cbf9ec5b/unify/universal_api/utils/__init__.py
--------------------------------------------------------------------------------
/unify/universal_api/utils/credits.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from unify import BASE_URL
4 | from unify.utils import _requests
5 |
6 | from ...utils.helpers import _res_to_list, _validate_api_key
7 |
8 |
def get_credits(*, api_key: Optional[str] = None) -> float:
    """
    Returns the credits remaining in the user account, in USD.

    Args:
        api_key: If specified, unify API key to be used. Defaults to the value in the
            `UNIFY_KEY` environment variable.

    Returns:
        The credits remaining in USD.

    Raises:
        Exception: If the request does not return status 200.
    """
    validated_key = _validate_api_key(api_key)
    # Query the /credits endpoint for the current balance.
    response = _requests.get(
        BASE_URL + "/credits",
        headers={
            "accept": "application/json",
            "Authorization": f"Bearer {validated_key}",
        },
    )
    if response.status_code != 200:
        raise Exception(response.json())
    return _res_to_list(response)["credits"]
32 |
--------------------------------------------------------------------------------
/unify/universal_api/utils/custom_api_keys.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List, Optional
2 |
3 | from unify import BASE_URL
4 | from unify.utils import _requests
5 |
6 | from ...utils.helpers import _validate_api_key
7 |
8 |
def create_custom_api_key(
    name: str,
    value: str,
    *,
    api_key: Optional[str] = None,
) -> Dict[str, str]:
    """
    Create a custom API key.

    Args:
        name: Name of the API key.
        value: Value of the API key.
        api_key: If specified, unify API key to be used. Defaults
            to the value in the `UNIFY_KEY` environment variable.

    Returns:
        A dictionary containing the response information.

    """
    validated_key = _validate_api_key(api_key)
    response = _requests.post(
        f"{BASE_URL}/custom_api_key",
        headers={
            "accept": "application/json",
            "Authorization": f"Bearer {validated_key}",
        },
        params={"name": name, "value": value},
    )
    if response.status_code != 200:
        raise Exception(response.json())
    return response.json()
42 |
43 |
def get_custom_api_key(
    name: str,
    *,
    api_key: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Get the value of a custom API key.

    Args:
        name: Name of the API key to get the value for.
        api_key: If specified, unify API key to be used. Defaults
            to the value in the `UNIFY_KEY` environment variable.

    Returns:
        A dictionary containing the custom API key information.

    Raises:
        Exception: If the request does not return status 200.
    """
    validated_key = _validate_api_key(api_key)
    response = _requests.get(
        f"{BASE_URL}/custom_api_key",
        headers={
            "accept": "application/json",
            "Authorization": f"Bearer {validated_key}",
        },
        params={"name": name},
    )
    if response.status_code != 200:
        raise Exception(response.json())
    return response.json()
76 |
77 |
def delete_custom_api_key(
    name: str,
    *,
    api_key: Optional[str] = None,
) -> Dict[str, str]:
    """
    Delete a custom API key.

    Args:
        name: Name of the custom API key to delete.
        api_key: If specified, unify API key to be used. Defaults
            to the value in the `UNIFY_KEY` environment variable.

    Returns:
        A dictionary containing the response message if successful.

    Raises:
        KeyError: If the API key is not found (HTTP 404).
        Exception: For any other non-200 response.
    """
    api_key = _validate_api_key(api_key)
    headers = {
        "accept": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    url = f"{BASE_URL}/custom_api_key"

    params = {"name": name}

    response = _requests.delete(url, headers=headers, params=params)

    if response.status_code == 200:
        return response.json()
    if response.status_code == 404:
        raise KeyError("API key not found.")
    # Bug fix: the original guarded this raise with a redundant nested
    # `if response.status_code != 200` (always true in this branch), which
    # would have silently returned None had it ever been false.
    raise Exception(response.json())
115 | raise Exception(response.json())
116 |
117 |
def rename_custom_api_key(
    name: str,
    new_name: str,
    *,
    api_key: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Rename a custom API key.

    Args:
        name: Current name of the custom API key to be updated.
        new_name: New name for the custom API key.
        api_key: If specified, unify API key to be used. Defaults
            to the value in the `UNIFY_KEY` environment variable.

    Returns:
        A dictionary containing the response information.

    Raises:
        Exception: If the request does not return HTTP 200.
    """
    key = _validate_api_key(api_key)
    response = _requests.post(
        f"{BASE_URL}/custom_api_key/rename",
        headers={
            "accept": "application/json",
            "Authorization": f"Bearer {key}",
        },
        params={"name": name, "new_name": new_name},
    )
    if response.status_code != 200:
        raise Exception(response.json())
    return response.json()
154 |
155 |
def list_custom_api_keys(
    *,
    api_key: Optional[str] = None,
) -> List[Dict[str, str]]:
    """
    Get a list of custom API keys associated with the user's account.

    Args:
        api_key: If specified, unify API key to be used. Defaults
            to the value in the `UNIFY_KEY` environment variable.

    Returns:
        A list of dictionaries containing custom API key information.
        Each dictionary has 'name' and 'value' keys.

    Raises:
        Exception: If the request does not return HTTP 200.
    """
    key = _validate_api_key(api_key)
    response = _requests.get(
        f"{BASE_URL}/custom_api_key/list",
        headers={
            "accept": "application/json",
            "Authorization": f"Bearer {key}",
        },
    )
    if response.status_code != 200:
        raise Exception(response.json())
    return response.json()
184 |
--------------------------------------------------------------------------------
/unify/universal_api/utils/custom_endpoints.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List, Optional
2 |
3 | from unify import BASE_URL
4 | from unify.utils import _requests
5 |
6 | from ...utils.helpers import _validate_api_key
7 |
8 |
def create_custom_endpoint(
    *,
    name: str,
    url: str,
    key_name: str,
    model_name: Optional[str] = None,
    provider: Optional[str] = None,
    api_key: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Create a custom endpoint for API calls.

    Args:
        name: Alias for the custom endpoint. This will be the name used to
            call the endpoint.
        url: Base URL of the endpoint being called. Must support the OpenAI
            format.
        key_name: Name of the API key that will be passed as part of the query.
        model_name: Name passed to the custom endpoint as model name. If not
            specified, it will default to the endpoint alias.
        provider: If the custom endpoint is for a fine-tuned model which is
            hosted directly via one of the supported providers, then this
            argument should be specified as the provider used.
        api_key: If specified, unify API key to be used. Defaults to the value
            in the `UNIFY_KEY` environment variable.

    Returns:
        A dictionary containing the response from the API.

    Raises:
        Exception: If the request does not return HTTP 200.
        KeyError: If the UNIFY_KEY is not set and no api_key is provided.
    """
    key = _validate_api_key(api_key)

    query = {
        "name": name,
        "url": url,
        "key_name": key_name,
    }
    # Optional fields are only forwarded when truthy, mirroring API defaults.
    for field, value in (("model_name", model_name), ("provider", provider)):
        if value:
            query[field] = value

    response = _requests.post(
        f"{BASE_URL}/custom_endpoint",
        headers={
            "accept": "application/json",
            "Authorization": f"Bearer {key}",
        },
        params=query,
    )
    if response.status_code != 200:
        raise Exception(response.json())
    return response.json()
63 |
64 |
def delete_custom_endpoint(
    name: str,
    *,
    api_key: Optional[str] = None,
) -> Dict[str, str]:
    """
    Delete a custom endpoint.

    Args:
        name: Name of the custom endpoint to delete.
        api_key: If specified, unify API key to be used. Defaults
            to the value in the `UNIFY_KEY` environment variable.

    Returns:
        A dictionary containing the response message.

    Raises:
        Exception: If the request does not return HTTP 200.
    """
    key = _validate_api_key(api_key)
    response = _requests.delete(
        f"{BASE_URL}/custom_endpoint",
        headers={
            "accept": "application/json",
            "Authorization": f"Bearer {key}",
        },
        params={"name": name},
    )
    if response.status_code != 200:
        raise Exception(response.json())
    return response.json()
98 |
99 |
def rename_custom_endpoint(
    name: str,
    new_name: str,
    *,
    api_key: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Rename a custom endpoint.

    Args:
        name: Current name of the custom endpoint to be updated.
        new_name: New name for the custom endpoint.
        api_key: If specified, unify API key to be used. Defaults
            to the value in the `UNIFY_KEY` environment variable.

    Returns:
        A dictionary containing the response information.

    Raises:
        Exception: If the request does not return HTTP 200.
    """
    key = _validate_api_key(api_key)
    response = _requests.post(
        f"{BASE_URL}/custom_endpoint/rename",
        headers={
            "accept": "application/json",
            "Authorization": f"Bearer {key}",
        },
        params={"name": name, "new_name": new_name},
    )
    if response.status_code != 200:
        raise Exception(response.json())
    return response.json()
135 |
136 |
def list_custom_endpoints(
    *,
    api_key: Optional[str] = None,
) -> List[Dict[str, str]]:
    """
    Get a list of custom endpoints for the authenticated user.

    Args:
        api_key: If specified, unify API key to be used. Defaults
            to the value in the `UNIFY_KEY` environment variable.

    Returns:
        A list of dictionaries containing information about custom endpoints.
        Each dictionary has keys: 'name', 'mdl_name', 'url', and 'key'.

    Raises:
        Exception: If the request does not return HTTP 200.
    """
    key = _validate_api_key(api_key)
    response = _requests.get(
        f"{BASE_URL}/custom_endpoint/list",
        headers={
            "accept": "application/json",
            "Authorization": f"Bearer {key}",
        },
    )
    if response.status_code != 200:
        raise Exception(response.json())
    return response.json()
167 |
--------------------------------------------------------------------------------
/unify/universal_api/utils/endpoint_metrics.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | from typing import Dict, List, Optional, Union
3 |
4 | from pydantic import BaseModel
5 | from unify import BASE_URL
6 | from unify.utils import _requests
7 |
8 | from ...utils.helpers import _validate_api_key
9 |
10 |
class Metrics(BaseModel, extra="allow"):
    # Benchmark record for an endpoint; extra="allow" permits additional
    # fields from the API beyond those declared below.
    # ttft/itl presumably mean time-to-first-token and inter-token latency —
    # TODO confirm units against the API docs.
    ttft: Optional[float]
    itl: Optional[float]
    input_cost: Optional[float]
    output_cost: Optional[float]
    # Timestamp(s) of the measurement; may be a single value or a mapping.
    measured_at: Union[datetime.datetime, str, Dict[str, Union[datetime.datetime, str]]]
17 |
18 |
def get_endpoint_metrics(
    endpoint: str,
    *,
    start_time: Optional[Union[datetime.datetime, str]] = None,
    end_time: Optional[Union[datetime.datetime, str]] = None,
    api_key: Optional[str] = None,
) -> List[Metrics]:
    """
    Retrieve the set of cost and speed metrics for the specified endpoint.

    Args:
        endpoint: The endpoint to retrieve the metrics for, in model@provider format

        start_time: Window start time. Only returns the latest benchmark if unspecified.

        end_time: Window end time. Assumed to be the current time if this is unspecified
        and start_time is specified. Only the latest benchmark is returned if both are
        unspecified.

        api_key: If specified, unify API key to be used. Defaults to the value in the
        `UNIFY_KEY` environment variable.

    Returns:
        The set of metrics for the specified endpoint.

    Raises:
        ValueError: If `endpoint` is not in model@provider format.
        Exception: If the request does not return HTTP 200.
    """
    api_key = _validate_api_key(api_key)
    headers = {
        "accept": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    # Split once and validate, instead of splitting twice and raising a bare
    # IndexError on malformed input.
    parts = endpoint.split("@")
    if len(parts) < 2:
        raise ValueError(
            f"Invalid endpoint {endpoint!r}: expected model@provider format.",
        )
    params = {
        "model": parts[0],
        "provider": parts[1],
        "start_time": start_time,
        "end_time": end_time,
    }
    response = _requests.get(
        BASE_URL + "/endpoint-metrics",
        headers=headers,
        params=params,
    )
    if response.status_code != 200:
        raise Exception(response.json())
    return [
        Metrics(
            ttft=metrics_dct["ttft"],
            itl=metrics_dct["itl"],
            input_cost=metrics_dct["input_cost"],
            output_cost=metrics_dct["output_cost"],
            measured_at=metrics_dct["measured_at"],
        )
        for metrics_dct in response.json()
    ]
72 |
73 |
def log_endpoint_metric(
    endpoint_name: str,
    *,
    metric_name: str,
    value: float,
    measured_at: Optional[Union[str, datetime.datetime]] = None,
    api_key: Optional[str] = None,
) -> Dict[str, str]:
    """
    Append speed or cost data to the standardized time-series benchmarks for a
    custom endpoint (only custom endpoints are publishable by end users).

    Args:
        endpoint_name: Name of the custom endpoint to append benchmark data for.
        metric_name: Name of the metric to submit. Allowed metrics are:
            "input_cost", "output_cost", "ttft", "itl".
        value: Value of the metric to submit.
        measured_at: The timestamp to associate with the submission. Defaults
            to current time if unspecified.
        api_key: If specified, unify API key to be used. Defaults to the value
            in the `UNIFY_KEY` environment variable.

    Raises:
        Exception: If the request does not return HTTP 200.
    """
    key = _validate_api_key(api_key)
    response = _requests.post(
        BASE_URL + "/endpoint-metrics",
        headers={
            "accept": "application/json",
            "Authorization": f"Bearer {key}",
        },
        params={
            "endpoint_name": endpoint_name,
            "metric_name": metric_name,
            "value": value,
            "measured_at": measured_at,
        },
    )
    if response.status_code != 200:
        raise Exception(response.json())
    return response.json()
119 |
120 |
def delete_endpoint_metrics(
    endpoint_name: str,
    *,
    timestamps: Optional[Union[datetime.datetime, List[datetime.datetime]]] = None,
    api_key: Optional[str] = None,
) -> Dict[str, str]:
    """
    Delete benchmark metric entries previously logged for a custom endpoint.

    Args:
        endpoint_name: Name of the custom endpoint to delete metrics for.
        timestamps: Timestamp(s) identifying which metric entries to delete.
            Passed through to the API as-is; presumably all entries are
            deleted when unspecified — TODO confirm against the API docs.
        api_key: If specified, unify API key to be used. Defaults to the value
            in the `UNIFY_KEY` environment variable.

    Returns:
        A dictionary containing the response message.

    Raises:
        Exception: If the request does not return HTTP 200.
    """
    api_key = _validate_api_key(api_key)
    headers = {
        "accept": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    params = {
        "endpoint_name": endpoint_name,
        "timestamps": timestamps,
    }
    response = _requests.delete(
        BASE_URL + "/endpoint-metrics",
        headers=headers,
        params=params,
    )
    if response.status_code != 200:
        raise Exception(response.json())
    return response.json()
144 |
--------------------------------------------------------------------------------
/unify/universal_api/utils/queries.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | from typing import Any, Dict, List, Optional, Union
3 |
4 | from unify import BASE_URL
5 | from unify.utils import _requests
6 |
7 | from ...utils.helpers import _validate_api_key
8 |
9 |
def get_query_tags(
    *,
    api_key: Optional[str] = None,
) -> List[str]:
    """
    Get a list of available query tags.

    Args:
        api_key: If specified, unify API key to be used. Defaults
            to the value in the `UNIFY_KEY` environment variable.

    Returns:
        A list of available query tags.

    Raises:
        Exception: If the request does not return HTTP 200.
    """
    key = _validate_api_key(api_key)
    response = _requests.get(
        f"{BASE_URL}/tags",
        headers={
            "accept": "application/json",
            "Authorization": f"Bearer {key}",
        },
    )
    if response.status_code != 200:
        raise Exception(response.json())
    return response.json()
35 |
36 |
def get_queries(
    *,
    tags: Optional[Union[str, List[str]]] = None,
    endpoints: Optional[Union[str, List[str]]] = None,
    start_time: Optional[Union[datetime.datetime, str]] = None,
    end_time: Optional[Union[datetime.datetime, str]] = None,
    page_number: Optional[int] = None,
    failures: Optional[Union[bool, str]] = None,
    api_key: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Get query history based on specified filters.

    Args:
        tags: Tags to filter for queries that are marked with these tags.

        endpoints: Optionally specify an endpoint, or a list of endpoints to
        filter for.

        start_time: Timestamp of the earliest query to aggregate.
        Format is `YYYY-MM-DD hh:mm:ss`.

        end_time: Timestamp of the latest query to aggregate.
        Format is `YYYY-MM-DD hh:mm:ss`.

        page_number: The query history is returned in pages, with up to 100
        prompts per page. Increase the page number to see older prompts.
        Default is 1.

        failures: Indicates whether to include failures in the return
        (when set as True), or whether to return failures exclusively
        (when set as 'only'). Default is False.

        api_key: If specified, unify API key to be used.
        Defaults to the value in the `UNIFY_KEY` environment variable.

    Returns:
        A dictionary containing the query history data.
    """
    key = _validate_api_key(api_key)
    candidate_filters = (
        ("tags", tags),
        ("endpoints", endpoints),
        ("start_time", start_time),
        ("end_time", end_time),
        ("page_number", page_number),
        ("failures", failures),
    )
    # Only truthy filters are forwarded, so the API applies its own defaults.
    params = {field: value for field, value in candidate_filters if value}
    response = _requests.get(
        f"{BASE_URL}/queries",
        headers={
            "accept": "application/json",
            "Authorization": f"Bearer {key}",
        },
        params=params,
    )
    if response.status_code != 200:
        raise Exception(response.json())
    return response.json()
100 |
101 |
def log_query(
    *,
    endpoint: str,
    query_body: Dict,
    response_body: Optional[Dict] = None,
    tags: Optional[List[str]] = None,
    timestamp: Optional[Union[datetime.datetime, str]] = None,
    api_key: Optional[str] = None,
):
    """
    Log a query (and optionally response) for a locally deployed
    (non-Unify-registered) model, with optional tagging (default None) and
    timestamp (default the time of sending).

    Args:
        endpoint: Endpoint to log query for.
        query_body: A dict containing the body of the request.
        response_body: An optional dict containing the response to the request.
        tags: Custom tags for later filtering.
        timestamp: A timestamp (if not set, will be the time of sending).
        api_key: If specified, unify API key to be used. Defaults to the value
            in the `UNIFY_KEY` environment variable.

    Returns:
        A dictionary containing the response message if successful.

    Raises:
        Exception: If the request does not return HTTP 200.
    """
    key = _validate_api_key(api_key)

    payload = {}
    for field, value in (
        ("endpoint", endpoint),
        ("query_body", query_body),
        ("response_body", response_body),
        ("tags", tags),
        ("timestamp", timestamp),
    ):
        # Omit unset fields so the API applies its own defaults.
        if value is not None:
            payload[field] = value

    response = _requests.post(
        f"{BASE_URL}/queries",
        headers={
            "accept": "application/json",
            "Authorization": f"Bearer {key}",
        },
        json=payload,
    )
    if response.status_code != 200:
        raise Exception(response.json())
    return response.json()
154 |
155 |
def get_query_metrics(
    *,
    start_time: Optional[Union[datetime.datetime, str]] = None,
    end_time: Optional[Union[datetime.datetime, str]] = None,
    models: Optional[str] = None,
    providers: Optional[str] = None,
    interval: int = 300,
    secondary_user_id: Optional[str] = None,
    api_key: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Get query metrics for specified parameters.

    Args:
        start_time: Timestamp of the earliest query to aggregate. Format is `YYYY-MM-DD hh:mm:ss`.
        end_time: Timestamp of the latest query to aggregate. Format is `YYYY-MM-DD hh:mm:ss`.
        models: Models to fetch metrics from. Comma-separated string of model names.
        providers: Providers to fetch metrics from. Comma-separated string of provider names.
        interval: Number of seconds in the aggregation interval. Default is 300.
        secondary_user_id: Secondary user id to match the `user` attribute from `/chat/completions`.
        api_key: If specified, unify API key to be used. Defaults to the value in the `UNIFY_KEY` environment variable.

    Returns:
        A dictionary containing the query metrics.
    """
    key = _validate_api_key(api_key)

    params = {}
    for field, value in (
        ("start_time", start_time),
        ("end_time", end_time),
        ("models", models),
        ("providers", providers),
        ("interval", interval),
        ("secondary_user_id", secondary_user_id),
    ):
        # Drop unset (None) values; the API applies its own defaults.
        if value is not None:
            params[field] = value

    response = _requests.get(
        f"{BASE_URL}/metrics",
        headers={
            "accept": "application/json",
            "Authorization": f"Bearer {key}",
        },
        params=params,
    )
    if response.status_code != 200:
        raise Exception(response.json())
    return response.json()
206 |
--------------------------------------------------------------------------------
/unify/universal_api/utils/supported_endpoints.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional
2 |
3 | from unify import BASE_URL
4 | from unify.utils import _requests
5 |
6 | from ...utils.helpers import _res_to_list, _validate_api_key
7 |
8 |
def list_providers(
    model: Optional[str] = None,
    *,
    api_key: Optional[str] = None,
) -> List[str]:
    """
    Get a list of available providers, either in total or for a specific model.

    Args:
        model: If specified, returns the list of providers supporting this model.
        api_key: If specified, unify API key to be used. Defaults
            to the value in the `UNIFY_KEY` environment variable.

    Returns:
        A list of provider names associated with the model if successful,
        otherwise an empty list.

    Raises:
        Exception: If the request does not return HTTP 200.
        ValueError: If there was an error parsing the JSON response.
    """
    key = _validate_api_key(api_key)
    request_kwargs = {
        "headers": {
            "accept": "application/json",
            "Authorization": f"Bearer {key}",
        },
    }
    if model:
        request_kwargs["params"] = {"model": model}
    response = _requests.get(f"{BASE_URL}/providers", **request_kwargs)
    if response.status_code != 200:
        raise Exception(response.json())
    return _res_to_list(response)
43 |
44 |
def list_models(
    provider: Optional[str] = None,
    *,
    api_key: Optional[str] = None,
) -> List[str]:
    """
    Get a list of available models, either in total or for a specific provider.

    Args:
        provider: If specified, returns the list of models supporting this provider.
        api_key: If specified, unify API key to be used. Defaults
            to the value in the `UNIFY_KEY` environment variable.

    Returns:
        A list of available model names if successful, otherwise an empty list.

    Raises:
        Exception: If the request does not return HTTP 200.
        ValueError: If there was an error parsing the JSON response.
    """
    key = _validate_api_key(api_key)
    request_kwargs = {
        "headers": {
            "accept": "application/json",
            "Authorization": f"Bearer {key}",
        },
    }
    if provider:
        request_kwargs["params"] = {"provider": provider}
    response = _requests.get(f"{BASE_URL}/models", **request_kwargs)
    if response.status_code != 200:
        raise Exception(response.json())
    return _res_to_list(response)
78 |
79 |
def list_endpoints(
    model: Optional[str] = None,
    provider: Optional[str] = None,
    *,
    api_key: Optional[str] = None,
) -> List[str]:
    """
    Get a list of available endpoints, either in total or for a specific model
    or provider.

    Args:
        model: If specified, returns the list of endpoints supporting this model.
        provider: If specified, returns the list of endpoints supporting this
            provider.
        api_key: If specified, unify API key to be used. Defaults to the value
            in the `UNIFY_KEY` environment variable.

    Returns:
        A list of endpoint names if successful, otherwise an empty list.

    Raises:
        ValueError: If both model and provider are specified, or if there was
            an error parsing the JSON response.
        Exception: If the request does not return HTTP 200.
    """
    api_key = _validate_api_key(api_key)
    headers = {
        "accept": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    url = f"{BASE_URL}/endpoints"
    if model and provider:
        raise ValueError("Please specify either model OR provider, not both.")
    elif model:
        # Fall through to the shared request below, so the model branch gets
        # the same status-code check as the other branches (the original code
        # returned early here without it, leaving `kw` dead).
        kw = dict(headers=headers, params={"model": model})
    elif provider:
        kw = dict(headers=headers, params={"provider": provider})
    else:
        kw = dict(headers=headers)
    response = _requests.get(url, **kw)
    if response.status_code != 200:
        raise Exception(response.json())
    return _res_to_list(response)
124 |
--------------------------------------------------------------------------------
/unify/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from . import helpers
2 | from .map import *
3 |
--------------------------------------------------------------------------------
/unify/utils/_requests.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import os
4 |
5 | import requests
6 |
# Module-level logger for HTTP traffic. Verbosity is opt-in: set the
# UNIFY_REQUESTS_DEBUG environment variable to "true" or "1" to enable
# DEBUG-level request/response logging.
_logger = logging.getLogger("unify_requests")
_log_enabled = os.getenv("UNIFY_REQUESTS_DEBUG", "false").lower() in ("true", "1")
_logger.setLevel(logging.DEBUG if _log_enabled else logging.WARNING)
10 |
11 |
class ResponseDecodeError(Exception):
    """Raised when an HTTP response body cannot be parsed as JSON."""

    def __init__(self, response: requests.Response):
        # Keep the raw response available so callers can inspect status/body.
        self.response = response
        super().__init__(f"Request failed to parse response: {response.text}")
16 |
17 |
def _log(kind: str, url: str, mask_key: bool = True, /, **kwargs):
    """Emit a DEBUG log entry describing an HTTP request or response.

    No-op when UNIFY_REQUESTS_DEBUG logging is disabled.

    Args:
        kind: Label for the entry, e.g. "GET" or "POST response:200".
            (Renamed from `type`, which shadowed the builtin; safe because
            all parameters before `/` are positional-only.)
        url: The request URL.
        mask_key: Whether to mask the Authorization header in the log output.
        **kwargs: Arbitrary request/response details to include.
    """
    if not _log_enabled:
        return
    kwargs_str = ""
    masked = mask_key and "headers" in kwargs
    if masked:
        # Temporarily hide the bearer token so it never reaches the logs.
        original_auth = kwargs["headers"]["Authorization"]
        kwargs["headers"]["Authorization"] = "***"

    for k, v in kwargs.items():
        if isinstance(v, dict):
            kwargs_str += f"{k}:{json.dumps(v, indent=2)},\n"
        else:
            kwargs_str += f"{k}:{v},\n"

    if masked:
        # Restore the real token on the caller's (shared) headers dict.
        kwargs["headers"]["Authorization"] = original_auth

    log_msg = f"""
====== {kind} =======
url:{url}
{kwargs_str}
"""
    _logger.debug(log_msg)
41 |
42 |
def request(method, url, **kwargs):
    """Issue an HTTP request via `requests.request`, logging request/response.

    Raises:
        ResponseDecodeError: If the response body cannot be parsed as JSON
            (note: `res.json()` is evaluated even when logging is disabled).
    """
    _log(f"request:{method}", url, True, **kwargs)
    res = requests.request(method, url, **kwargs)
    try:
        _log(f"request:{method} response:{res.status_code}", url, response=res.json())
    except requests.exceptions.JSONDecodeError as e:
        # Chain the decode error so the root cause is preserved in tracebacks.
        raise ResponseDecodeError(res) from e
    return res
51 |
52 |
def get(url, params=None, **kwargs):
    """Issue a GET request via `requests.get`, logging request and response.

    Raises:
        ResponseDecodeError: If the response body cannot be parsed as JSON
            (note: `res.json()` is evaluated even when logging is disabled).
    """
    _log("GET", url, True, params=params, **kwargs)
    res = requests.get(url, params=params, **kwargs)
    try:
        _log(f"GET response:{res.status_code}", url, response=res.json())
    except requests.exceptions.JSONDecodeError as e:
        # Chain the decode error so the root cause is preserved in tracebacks.
        raise ResponseDecodeError(res) from e
    return res
61 |
62 |
def options(url, **kwargs):
    """Issue an OPTIONS request via `requests.options`, logging request/response.

    Raises:
        ResponseDecodeError: If the response body cannot be parsed as JSON
            (note: `res.json()` is evaluated even when logging is disabled).
    """
    _log("OPTIONS", url, True, **kwargs)
    res = requests.options(url, **kwargs)
    try:
        _log(f"OPTIONS response:{res.status_code}", url, response=res.json())
    except requests.exceptions.JSONDecodeError as e:
        # Chain the decode error so the root cause is preserved in tracebacks.
        raise ResponseDecodeError(res) from e
    return res
71 |
72 |
def head(url, **kwargs):
    """Issue a HEAD request via `requests.head`, logging request and response.

    Raises:
        ResponseDecodeError: If the response body cannot be parsed as JSON
            (note: `res.json()` is evaluated even when logging is disabled).
    """
    _log("HEAD", url, True, **kwargs)
    res = requests.head(url, **kwargs)
    try:
        _log(f"HEAD response:{res.status_code}", url, response=res.json())
    except requests.exceptions.JSONDecodeError as e:
        # Chain the decode error so the root cause is preserved in tracebacks.
        raise ResponseDecodeError(res) from e
    return res
81 |
82 |
def post(url, data=None, json=None, **kwargs):
    """Issue a POST request via `requests.post`, logging request and response.

    Raises:
        ResponseDecodeError: If the response body cannot be parsed as JSON
            (note: `res.json()` is evaluated even when logging is disabled).
    """
    _log("POST", url, True, data=data, json=json, **kwargs)
    res = requests.post(url, data=data, json=json, **kwargs)
    try:
        _log(f"POST response:{res.status_code}", url, response=res.json())
    except requests.exceptions.JSONDecodeError as e:
        # Chain the decode error so the root cause is preserved in tracebacks.
        raise ResponseDecodeError(res) from e
    return res
91 |
92 |
def put(url, data=None, **kwargs):
    """Issue a PUT request via `requests.put`, logging request and response.

    Raises:
        ResponseDecodeError: If the response body cannot be parsed as JSON
            (note: `res.json()` is evaluated even when logging is disabled).
    """
    _log("PUT", url, True, data=data, **kwargs)
    res = requests.put(url, data=data, **kwargs)
    try:
        _log(f"PUT response:{res.status_code}", url, response=res.json())
    except requests.exceptions.JSONDecodeError as e:
        # Chain the decode error so the root cause is preserved in tracebacks.
        raise ResponseDecodeError(res) from e
    return res
101 |
102 |
def patch(url, data=None, **kwargs):
    """Issue a PATCH request via `requests.patch`, logging request and response.

    Raises:
        ResponseDecodeError: If the response body cannot be parsed as JSON
            (note: `res.json()` is evaluated even when logging is disabled).
    """
    _log("PATCH", url, True, data=data, **kwargs)
    res = requests.patch(url, data=data, **kwargs)
    try:
        _log(f"PATCH response:{res.status_code}", url, response=res.json())
    except requests.exceptions.JSONDecodeError as e:
        # Chain the decode error so the root cause is preserved in tracebacks.
        raise ResponseDecodeError(res) from e
    return res
111 |
112 |
def delete(url, **kwargs):
    """Issue a DELETE request via `requests.delete`, logging request/response.

    Raises:
        ResponseDecodeError: If the response body cannot be parsed as JSON
            (note: `res.json()` is evaluated even when logging is disabled).
    """
    _log("DELETE", url, True, **kwargs)
    res = requests.delete(url, **kwargs)
    try:
        _log(f"DELETE response:{res.status_code}", url, response=res.json())
    except requests.exceptions.JSONDecodeError as e:
        # Chain the decode error so the root cause is preserved in tracebacks.
        raise ResponseDecodeError(res) from e
    return res
121 |
--------------------------------------------------------------------------------
/unify/utils/helpers.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | import json
3 | import os
4 | import threading
5 | from typing import Any, Dict, List, Optional, Tuple, Union
6 |
7 | import requests
8 | import unify
9 | from pydantic import BaseModel, ValidationError
10 |
11 | PROJECT_LOCK = threading.Lock()
12 |
13 |
class RequestError(Exception):
    """Raised when an HTTP request to the API returns a non-OK status."""

    def __init__(self, response: requests.Response):
        """Build an error message summarizing the failed request and response."""
        req = response.request
        message = (
            f"{req.method} {req.url} failed with status code {response.status_code}. "
            f"Request body: {req.body}, Response: {response.text}"
        )
        super().__init__(message)
        # Keep the raw response so callers can inspect status code and body.
        self.response = response
23 |
24 |
def _check_response(response: requests.Response):
    """Raise a RequestError if *response* has a non-OK (non-2xx) status."""
    if response.ok:
        return
    raise RequestError(response)
28 |
29 |
def _res_to_list(response: requests.Response) -> Union[List, Dict]:
    """Parse the response body text as JSON (typically a list or dict)."""
    return json.loads(response.text)
32 |
33 |
34 | def _validate_api_key(api_key: Optional[str]) -> str:
35 | if api_key is None:
36 | api_key = os.environ.get("UNIFY_KEY")
37 | if api_key is None:
38 | raise KeyError(
39 | "UNIFY_KEY is missing. Please make sure it is set correctly!",
40 | )
41 | return api_key
42 |
43 |
44 | def _default(value: Any, default_value: Any) -> Any:
45 | return value if value is not None else default_value
46 |
47 |
def _dict_aligns_with_pydantic(dict_in: Dict, pydantic_cls: type(BaseModel)) -> bool:
    """Check whether *dict_in* validates against the given pydantic model class.

    Args:
        dict_in: The dictionary to validate.
        pydantic_cls: The pydantic model class to validate against.

    Returns:
        True if the dict passes `model_validate`, False otherwise.
    """
    try:
        pydantic_cls.model_validate(dict_in)
        return True
    except ValidationError:
        return False
54 |
55 |
def _make_json_serializable(
    item: Any,
) -> Union[Dict, List, Tuple]:
    """
    Recursively convert *item* into JSON-serializable primitives.

    Containers are converted element-wise; pydantic model classes become
    their JSON schema and instances are dumped via `model_dump`; objects
    exposing a callable `.json` are serialized through that method; threading
    objects become a `<TypeName at id>` placeholder. Anything else falls back
    to `json.dumps`, then `str`, then a placeholder. A transient
    `_being_serialized` attribute is set on arbitrary objects to break
    reference cycles (cyclic references serialize as an empty string).
    """
    # Add a recursion guard using getattr to avoid infinite recursion
    if hasattr(item, "_being_serialized") and getattr(item, "_being_serialized", False):
        return ""

    try:
        # For objects that might cause recursion, set a flag
        if hasattr(item, "__dict__") and not isinstance(
            item,
            (dict, list, tuple, BaseModel),
        ):
            setattr(item, "_being_serialized", True)

        if isinstance(item, list):
            result = [_make_json_serializable(i) for i in item]
        elif isinstance(item, dict):
            result = {k: _make_json_serializable(v) for k, v in item.items()}
        elif isinstance(item, tuple):
            result = tuple(_make_json_serializable(i) for i in item)
        elif inspect.isclass(item) and issubclass(item, BaseModel):
            result = item.model_json_schema()
        elif isinstance(item, BaseModel):
            result = item.model_dump()
        elif hasattr(item, "json") and callable(item.json):
            result = _make_json_serializable(item.json())
        # Handle threading objects specifically
        elif "threading" in type(item).__module__:
            result = f"<{type(item).__name__} at {id(item)}>"
        elif isinstance(item, (int, float, bool, str, type(None))):
            result = item
        else:
            try:
                result = json.dumps(item)
            except Exception:
                try:
                    result = str(item)
                except Exception:
                    result = f"<{type(item).__name__} at {id(item)}>"

        return result
    finally:
        # Clean up the recursion guard flag
        if hasattr(item, "__dict__") and not isinstance(
            item,
            (dict, list, tuple, BaseModel),
        ):
            try:
                delattr(item, "_being_serialized")
            except (AttributeError, TypeError):
                pass
108 |
109 |
def _get_and_maybe_create_project(
    project: Optional[str] = None,
    required: bool = True,
    api_key: Optional[str] = None,
    create_if_missing: bool = True,
) -> Optional[str]:
    """
    Resolve the project name to use, optionally creating it remotely.

    Falls back to the active project when *project* is None, then to the
    placeholder name "_" when a project is required.

    Args:
        project: Explicit project name, or None to use the active project.
        required: If True, the placeholder "_" is used when no project is
            active; if False, None is returned instead.
        api_key: If specified, unify API key to be used. Defaults to the
            value in the `UNIFY_KEY` environment variable.
        create_if_missing: If True, create the project remotely when it is
            not in the project list (skipped in async-logging mode).

    Returns:
        The resolved project name, or None when not required and no project
        is active.
    """
    # noinspection PyUnresolvedReferences
    from unify.logging.utils.logs import ASYNC_LOGGING

    api_key = _validate_api_key(api_key)
    if project is None:
        project = unify.active_project()
    if project is None:
        if required:
            project = "_"
        else:
            return None
    if not create_if_missing:
        return project
    if ASYNC_LOGGING:
        # acquiring the project lock here will block the async logger
        # so we skip the lock if we are in async mode
        return project
    with PROJECT_LOCK:
        if project not in unify.list_projects(api_key=api_key):
            unify.create_project(project, api_key=api_key)
    return project
137 |
138 |
139 | def _prune_dict(val):
140 | def keep(v):
141 | if v in (None, "NOT_GIVEN"):
142 | return False
143 | else:
144 | ret = _prune_dict(v)
145 | if isinstance(ret, dict) or isinstance(ret, list) or isinstance(ret, tuple):
146 | return bool(ret)
147 | return True
148 |
149 | if (
150 | not isinstance(val, dict)
151 | and not isinstance(val, list)
152 | and not isinstance(val, tuple)
153 | ):
154 | return val
155 | elif isinstance(val, dict):
156 | return {k: _prune_dict(v) for k, v in val.items() if keep(v)}
157 | elif isinstance(val, list):
158 | return [_prune_dict(v) for i, v in enumerate(val) if keep(v)]
159 | else:
160 | return tuple(_prune_dict(v) for i, v in enumerate(val) if keep(v))
161 |
162 |
163 | import copy
164 | from typing import Any, Dict, List, Set, Tuple, Union
165 |
166 | __all__ = ["flexible_deepcopy"]
167 |
168 |
# Internal sentinel: return this to signal "skip me".
class _SkipType:
    pass


_SKIP = _SkipType()

Container = Union[Dict[Any, Any], List[Any], Tuple[Any, ...], Set[Any]]


def flexible_deepcopy(
    obj: Any,
    on_fail: str = "raise",
    _memo: Optional[Dict[int, Any]] = None,
) -> Any:
    """
    Perform a deepcopy that tolerates un-copyable elements.

    Parameters
    ----------
    obj : Any
        The object you wish to copy.
    on_fail : {'raise', 'skip', 'shallow'}, default 'raise'
        • 'raise' – re-raise copy error (standard behaviour).
        • 'skip' – drop the offending element from the result.
        • 'shallow' – insert the original element unchanged.
    _memo : dict or None (internal)
        Memoisation dict to preserve identity & avoid infinite recursion.

    Returns
    -------
    Any
        A deep-copied version of *obj*, modified per *on_fail* strategy.

    Raises
    ------
    ValueError
        If *on_fail* is not one of the accepted values.
    Exception
        Re-raises whatever copy error occurred when *on_fail* == 'raise'.
    """
    if _memo is None:
        _memo = {}

    obj_id = id(obj)
    if obj_id in _memo:  # Handle circular references.
        return _memo[obj_id]

    def _attempt(value: Any) -> Union[Any, _SkipType]:
        """Try to deepcopy *value*; fall back per on_fail."""
        try:
            return flexible_deepcopy(value, on_fail, _memo)
        except Exception:
            if on_fail == "raise":
                raise
            if on_fail == "shallow":
                return value
            if on_fail == "skip":
                return _SKIP
            raise ValueError(f"Invalid on_fail option: {on_fail!r}")

    # --- Handle built-in containers explicitly ---------------------------
    if isinstance(obj, dict):
        result: Dict[Any, Any] = {}
        _memo[obj_id] = result  # Early memoisation for cycles
        for k, v in obj.items():
            nk = _attempt(k)
            nv = _attempt(v)
            # Compare by identity: a membership test (`_SKIP in (nk, nv)`)
            # would invoke the copied objects' __eq__, which can wrongly drop
            # entries (or raise) for exotic key/value types.
            if nk is _SKIP or nv is _SKIP:
                continue
            result[nk] = nv
        return result

    if isinstance(obj, list):
        result: List[Any] = []
        _memo[obj_id] = result
        for item in obj:
            nitem = _attempt(item)
            if nitem is not _SKIP:
                result.append(nitem)
        return result

    if isinstance(obj, tuple):
        # No early memoisation here: tuples are immutable, so any cycle back
        # to *obj* must pass through a mutable container that IS memoised
        # early, which bounds the recursion. This mirrors the stdlib's
        # copy._deepcopy_tuple and fixes the previous behaviour where a cycle
        # into the tuple was replaced by a bare None placeholder.
        items = []
        for item in obj:
            nitem = _attempt(item)
            if nitem is not _SKIP:
                items.append(nitem)
        # Copying the elements may already have produced a copy of *obj*
        # (cycle through a mutable child); reuse it to preserve identity.
        if obj_id in _memo:
            return _memo[obj_id]
        result = tuple(items)
        _memo[obj_id] = result
        return result

    if isinstance(obj, set):
        result: Set[Any] = set()
        _memo[obj_id] = result
        for item in obj:
            nitem = _attempt(item)
            if nitem is not _SKIP:
                result.add(nitem)
        return result

    # --- Non-container: fall back to standard deepcopy -------------------
    try:
        result = copy.deepcopy(obj, _memo)
        _memo[obj_id] = result
        return result
    except Exception:
        if on_fail == "raise":
            raise
        if on_fail == "shallow":
            _memo[obj_id] = obj
            return obj
        if on_fail == "skip":
            return _SKIP
        raise ValueError(f"Invalid on_fail option: {on_fail!r}")
285 |
--------------------------------------------------------------------------------
/unify/utils/map.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import contextvars
3 | import threading
4 | from typing import Any, List
5 |
6 | from tqdm import tqdm
7 | from tqdm.asyncio import tqdm_asyncio
8 |
9 | MAP_MODE = "threading"
10 |
11 |
def set_map_mode(mode: str):
    """Set the module-level default execution mode used by :func:`map`.

    Args:
        mode: Expected to be one of "threading", "asyncio" or "loop". The
            value is not validated here; `map` asserts it when it is used.
    """
    global MAP_MODE
    MAP_MODE = mode
15 |
16 |
def get_map_mode() -> str:
    """Return the module-level default execution mode used by :func:`map`."""
    return MAP_MODE
19 |
20 |
21 | def _is_iterable(item: Any) -> bool:
22 | try:
23 | iter(item)
24 | return True
25 | except TypeError:
26 | return False
27 |
28 |
# noinspection PyShadowingBuiltins
def map(
    fn: callable,
    *args,
    mode=None,
    name="",
    from_args=False,
    raise_exceptions=True,
    **kwargs,
) -> Any:
    """
    Call `fn` once per argument set, serially or concurrently, with a
    progress bar.

    Args:
        fn: The function to call for each argument set.
        *args: With ``from_args=True``, one iterable per positional parameter
            of `fn` (zipped together). Otherwise a single iterable of argument
            sets, where each item may be a bare value, a kwargs dict, an args
            tuple, or an ``(args_tuple, kwargs_dict)`` pair.
        mode: One of "threading", "asyncio" or "loop". Defaults to the
            module-level map mode (see `set_map_mode`).
        name: snake_case label rendered in the progress-bar description.
        from_args: Interpret ``*args``/``**kwargs`` as parallel iterables
            rather than a pre-built iterable of argument sets.
        raise_exceptions: If True, the first exception from any call is
            re-raised in the caller's thread; otherwise failed calls
            contribute ``None`` (best effort).

    Returns:
        The list of return values of `fn`, in input order.
    """
    if name:
        # "some_name" -> "Some Name " for the progress-bar description.
        # Empty parts (e.g. from "a__b") are passed through instead of
        # raising an IndexError.
        name = (
            " ".join(
                part[0].upper() + part[1:] if part else ""
                for part in name.split("_")
            )
            + " "
        )

    if mode is None:
        mode = get_map_mode()

    assert mode in (
        "threading",
        "asyncio",
        "loop",
    ), "map mode must be one of threading, asyncio or loop."

    def fn_w_exception_handling(*a, **kw):
        # Best-effort wrapper: swallow exceptions (returning None) unless
        # raise_exceptions is set.
        try:
            return fn(*a, **kw)
        except Exception as e:
            if raise_exceptions:
                raise e

    if from_args:
        args = list(args)
        for i, a in enumerate(args):
            if _is_iterable(a):
                args[i] = list(a)

        if args:
            num_calls = len(args[0])
        else:
            for v in kwargs.values():
                # Tuples are indexed below just like lists, so accept both
                # here (previously only lists were detected, wrongly raising
                # when the mapped kwarg was a tuple).
                if isinstance(v, (list, tuple)):
                    num_calls = len(v)
                    break
            else:
                raise Exception(
                    "At least one of the args or kwargs must be a list, "
                    "which is to be mapped across the threads",
                )
        args_n_kwargs = [
            (
                tuple(a[i] for a in args),
                {
                    k: v[i] if isinstance(v, (list, tuple)) else v
                    for k, v in kwargs.items()
                },
            )
            for i in range(num_calls)
        ]
    else:
        # A single iterable of argument sets was passed; normalize every item
        # to an (args_tuple, kwargs_dict) pair.
        args_n_kwargs = args[0]
        if not isinstance(args_n_kwargs[0], tuple):
            if isinstance(args_n_kwargs[0], dict):
                args_n_kwargs = [((), item) for item in args_n_kwargs]
            else:
                args_n_kwargs = [((item,), {}) for item in args_n_kwargs]
        elif (
            not isinstance(args_n_kwargs[0][0], tuple)
            or len(args_n_kwargs[0]) < 2
            or not isinstance(args_n_kwargs[0][1], dict)
        ):
            args_n_kwargs = [(item, {}) for item in args_n_kwargs]
        num_calls = len(args_n_kwargs)

    if mode == "loop":

        pbar = tqdm(total=num_calls)
        pbar.set_description(f"{name}Iterations")

        returns = list()
        try:
            for a, kw in args_n_kwargs:
                returns.append(fn_w_exception_handling(*a, **kw))
                pbar.update(1)
        finally:
            # Close the bar even if a call raised, so the terminal is not
            # left with a dangling progress line.
            pbar.close()
        return returns

    elif mode == "threading":

        pbar = tqdm(total=num_calls)
        pbar.set_description(f"{name}Threads")

        errors: List[Exception] = []

        def fn_w_indexing(rets: List[None], thread_idx: int, *a, **kw):
            # Replay the caller's context variables inside the worker thread.
            for var, value in kw["context"].items():
                var.set(value)
            del kw["context"]
            try:
                ret = fn_w_exception_handling(*a, **kw)
            except Exception as e:
                # Record the failure so it can be re-raised in the caller's
                # thread; a raise here would otherwise die silently with the
                # worker thread and the caller would receive a None result.
                errors.append(e)
                return
            pbar.update(1)
            rets[thread_idx] = ret

        threads = list()
        returns = [None] * num_calls
        for i, (a, kw) in enumerate(args_n_kwargs):
            kw["context"] = contextvars.copy_context()
            thread = threading.Thread(
                target=fn_w_indexing,
                args=(returns, i, *a),
                kwargs=kw,
            )
            thread.start()
            threads.append(thread)
        for thread in threads:
            thread.join()
        pbar.close()
        if errors:
            # fn_w_exception_handling only raises when raise_exceptions=True.
            raise errors[0]
        return returns

    # mode == "asyncio"
    errors = []

    def _run_asyncio_in_thread(ret):
        # Run in a dedicated thread so map() also works when the caller is
        # itself inside a running event loop.
        asyncio.set_event_loop(asyncio.new_event_loop())
        MAX_WORKERS = 100
        semaphore = asyncio.Semaphore(MAX_WORKERS)

        async def fn_wrapper(*a, **kw):
            # Cap concurrency so at most MAX_WORKERS calls overlap.
            async with semaphore:
                return await asyncio.to_thread(fn_w_exception_handling, *a, **kw)

        coros = [fn_wrapper(*a, **kw) for a, kw in args_n_kwargs]

        async def main(coros):
            return await tqdm_asyncio.gather(*coros, desc=f"{name}Coroutines")

        try:
            ret += asyncio.run(main(coros))
        except Exception as e:
            # Surface the failure to the caller's thread instead of silently
            # returning a partial (empty) result list.
            errors.append(e)

    ret = []
    thread = threading.Thread(target=_run_asyncio_in_thread, args=(ret,))
    thread.start()
    thread.join()
    if errors:
        raise errors[0]
    return ret
170 |
--------------------------------------------------------------------------------